From fb570e5a0ecbcd6b083a0908c3055960ff852726 Mon Sep 17 00:00:00 2001
From: Lianmin Zheng <mercy_zheng@sjtu.edu.cn>
Date: Tue, 18 Sep 2018 09:29:49 -0700
Subject: [PATCH] [CODEGEN] Fix let expression (#1727)

---
 src/codegen/codegen_c.cc                 |  3 +--
 topi/tests/python/test_topi_transform.py | 15 +++++++++++++++
 2 files changed, 16 insertions(+), 2 deletions(-)

diff --git a/src/codegen/codegen_c.cc b/src/codegen/codegen_c.cc
index 09a6c7e6a..c3b0d278c 100644
--- a/src/codegen/codegen_c.cc
+++ b/src/codegen/codegen_c.cc
@@ -652,11 +652,10 @@ void CodeGenC::VisitStmt_(const Store* op) {
 }
 
 void CodeGenC::VisitExpr_(const Let* op, std::ostream& os) {  // NOLINT(*)
-  CHECK(print_ssa_form_)
-      << "LetExpr is only supported by print SSA form";
   std::string value = PrintExpr(op->value);
   CHECK(!var_idmap_.count(op->var.get()));
   var_idmap_[op->var.get()] = value;
+  os << PrintExpr(op->body);
 }
 
 void CodeGenC::VisitExpr_(const Ramp* op, std::ostream& os) {  // NOLINT(*)
diff --git a/topi/tests/python/test_topi_transform.py b/topi/tests/python/test_topi_transform.py
index 123df331e..ce2505e0d 100644
--- a/topi/tests/python/test_topi_transform.py
+++ b/topi/tests/python/test_topi_transform.py
@@ -307,6 +307,21 @@ def test_squeeze():
     verify_squeeze((1, 1, 1, 4), (1, 2))
     verify_squeeze((1, 1, 1, 1), None)
 
+    # a special case to trigger inline let expression
+    A = tvm.placeholder((2,), 'float32', 'A')
+    E = topi.squeeze(A)
+    C = tvm.compute((1,), lambda i: E[(2 * A[0] - 1).astype('int32')])
+    for device in ['cuda', 'opencl']:
+        ctx = tvm.context(device, 0)
+        if ctx.exist:
+            with tvm.target.create(device):
+                s = topi.generic.schedule_injective(C)
+                func = tvm.build(s, [A, C])
+            a = tvm.nd.array(np.array((1, 2)).astype('float32'), ctx=ctx)
+            c = tvm.nd.empty((1,), dtype='float32', ctx=ctx)
+            func(a, c)
+            assert c.asnumpy()[0] == 2
+
 
 def test_concatenate():
     verify_concatenate([(2,), (2,), (2,)], 0)
-- 
GitLab