From fb570e5a0ecbcd6b083a0908c3055960ff852726 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng <mercy_zheng@sjtu.edu.cn> Date: Tue, 18 Sep 2018 09:29:49 -0700 Subject: [PATCH] [CODEGEN] Fix let expression (#1727) --- src/codegen/codegen_c.cc | 3 +-- topi/tests/python/test_topi_transform.py | 15 +++++++++++++++ 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/src/codegen/codegen_c.cc b/src/codegen/codegen_c.cc index 09a6c7e6a..c3b0d278c 100644 --- a/src/codegen/codegen_c.cc +++ b/src/codegen/codegen_c.cc @@ -652,11 +652,10 @@ void CodeGenC::VisitStmt_(const Store* op) { } void CodeGenC::VisitExpr_(const Let* op, std::ostream& os) { // NOLINT(*) - CHECK(print_ssa_form_) - << "LetExpr is only supported by print SSA form"; std::string value = PrintExpr(op->value); CHECK(!var_idmap_.count(op->var.get())); var_idmap_[op->var.get()] = value; + os << PrintExpr(op->body); } void CodeGenC::VisitExpr_(const Ramp* op, std::ostream& os) { // NOLINT(*) diff --git a/topi/tests/python/test_topi_transform.py b/topi/tests/python/test_topi_transform.py index 123df331e..ce2505e0d 100644 --- a/topi/tests/python/test_topi_transform.py +++ b/topi/tests/python/test_topi_transform.py @@ -307,6 +307,21 @@ def test_squeeze(): verify_squeeze((1, 1, 1, 4), (1, 2)) verify_squeeze((1, 1, 1, 1), None) + # a special case to trigger inline let expression + A = tvm.placeholder((2,), 'float32', 'A') + E = topi.squeeze(A) + C = tvm.compute((1,), lambda i: E[(2 * A[0] - 1).astype('int32')]) + for device in ['cuda', 'opencl']: + ctx = tvm.context(device, 0) + if ctx.exist: + with tvm.target.create(device): + s = topi.generic.schedule_injective(C) + func = tvm.build(s, [A, C]) + a = tvm.nd.array(np.array((1, 2)).astype('float32'), ctx=ctx) + c = tvm.nd.empty((1,), dtype='float32', ctx=ctx) + func(a, c) + assert c.asnumpy()[0] == 2 + def test_concatenate(): verify_concatenate([(2,), (2,), (2,)], 0) -- GitLab