From 3c1020dffb94e8ee4e076fa69374ad9df1d339ae Mon Sep 17 00:00:00 2001
From: Tianqi Chen <tqchen@users.noreply.github.com>
Date: Fri, 20 Jan 2017 12:41:52 -0800
Subject: [PATCH] [CODEGEN] Add CodeGenC (#22)

---
 HalideIR                          |   2 +-
 python/tvm/__init__.py            |   1 +
 python/tvm/_ctypes/_api.py        |   1 +
 python/tvm/codegen.py             |   1 +
 src/base/common.h                 |   2 +-
 src/c_api/c_api_codegen.cc        |  25 ++
 src/codegen/codegen_c.cc          | 483 ++++++++++++++++++++++++++++++
 src/codegen/codegen_c.h           | 140 +++++++++
 tests/python/test_codegen_cuda.py |   8 +
 9 files changed, 661 insertions(+), 2 deletions(-)
 create mode 100644 python/tvm/codegen.py
 create mode 100644 src/c_api/c_api_codegen.cc
 create mode 100644 src/codegen/codegen_c.cc
 create mode 100644 src/codegen/codegen_c.h

diff --git a/HalideIR b/HalideIR
index b6637f611..adfa66240 160000
--- a/HalideIR
+++ b/HalideIR
@@ -1 +1 @@
-Subproject commit b6637f611f91dd075dc251438f72ad38901d17fb
+Subproject commit adfa662402650e2f9b02ea600ffb70d6e7bb5adf
diff --git a/python/tvm/__init__.py b/python/tvm/__init__.py
index b3a376de3..91b5abb6c 100644
--- a/python/tvm/__init__.py
+++ b/python/tvm/__init__.py
@@ -8,6 +8,7 @@ from . import expr
 from . import stmt
 from . import make
 from . import ir_pass
+from . import codegen
 from . import collections
 from . import schedule
 
diff --git a/python/tvm/_ctypes/_api.py b/python/tvm/_ctypes/_api.py
index 3ad5aa330..4b9d9d493 100644
--- a/python/tvm/_ctypes/_api.py
+++ b/python/tvm/_ctypes/_api.py
@@ -281,6 +281,7 @@ def _init_function_module(root_namespace):
     namespace_match = {
         "_make_": sys.modules["%s.make" % root_namespace],
         "_pass_": sys.modules["%s.ir_pass" % root_namespace],
+        "_codegen_": sys.modules["%s.codegen" % root_namespace],
         "_schedule_": sys.modules["%s.schedule" % root_namespace]
     }
 
diff --git a/python/tvm/codegen.py b/python/tvm/codegen.py
new file mode 100644
index 000000000..02dda155c
--- /dev/null
+++ b/python/tvm/codegen.py
@@ -0,0 +1 @@
+"""Code generation related functions"""
diff --git a/src/base/common.h b/src/base/common.h
index 0485bdfc4..ea2f4bdad 100644
--- a/src/base/common.h
+++ b/src/base/common.h
@@ -30,7 +30,7 @@ inline Type String2Type(std::string s) {
   } else if (s.substr(0, 5) == "float") {
     code = Type::Float; s = s.substr(5);
   } else if (s == "handle") {
-    return Type(Type::Handle, 0, 0);
+    return Type(Type::Handle, 32, 1);
   } else {
     LOG(FATAL) << "unknown type " << s;
   }
diff --git a/src/c_api/c_api_codegen.cc b/src/c_api/c_api_codegen.cc
new file mode 100644
index 000000000..365033ea4
--- /dev/null
+++ b/src/c_api/c_api_codegen.cc
@@ -0,0 +1,25 @@
+/*!
+ *  Copyright (c) 2016 by Contributors
+ *  Implementation of API functions related to code generation
+ * \file c_api_codegen.cc
+ */
+#include <tvm/expr.h>
+#include <tvm/ir.h>
+
+#include "./c_api_registry.h"
+#include "../codegen/codegen_c.h"
+
+namespace tvm {
+namespace codegen {
+
+using ArgStack = const std::vector<APIVariantValue>;
+using RetValue = APIVariantValue;
+
+TVM_REGISTER_API(_codegen_CompileToC)
+.set_body([](const ArgStack& args, RetValue *ret) {
+    *ret = CodeGenC().Compile(
+        args.at(0), args.at(1), args.at(2), args.at(3));
+  });
+
+}  // namespace codegen
+}  // namespace tvm
diff --git a/src/codegen/codegen_c.cc b/src/codegen/codegen_c.cc
new file mode 100644
index 000000000..a42569e9a
--- /dev/null
+++ b/src/codegen/codegen_c.cc
@@ -0,0 +1,483 @@
+/*!
+ *  Copyright (c) 2017 by Contributors
+ * \file codegen_c.cc
+ */
+#include "./codegen_c.h"
+
+namespace tvm {
+namespace codegen {
+
+using namespace ir;
+
+std::string CodeGenC::Compile(
+    Stmt stmt, std::string fun_name,
+    Array<Var> args, bool output_ssa) {
+  print_ssa_form_ = output_ssa;
+  // skip the first underscore, so SSA variable starts from _1
+  if (print_ssa_form_) GetUniqueName("_");
+
+  this->indent += 2;
+  this->stream << "void " << fun_name << "(";
+  for (size_t i = 0; i < args.size(); ++i) {
+    Var v = args[i];
+    std::string vid = AllocVarID(v.get());
+    if (i != 0) stream << ", ";
+    PrintType(v.type(), stream);
+    stream << ' ' << vid;
+  }
+  stream << ") {\n";
+  this->PrintStmt(stmt);
+  this->indent -= 2;
+  this->PrintIndent();
+  this->stream << "}\n";
+  return stream.str();
+}
+
+void CodeGenC::PrintStmt(const Stmt& n) {
+  static const FPrintStmt& f = vtable_print_stmt();
+  f(n, this);
+}
+
+std::string CodeGenC::SSAGetID(std::string src, Type t) {
+  if (name_alloc_map_.count(src)) return src;
+  auto it = ssa_assign_map_.find(src);
+  if (it != ssa_assign_map_.end()) {
+    return it->second;
+  } else {
+    this->PrintIndent();
+    std::string id = GetUniqueName("_");
+    ssa_assign_map_[src] = id;
+    if (src.length() > 3 &&
+        src[0] == '(' && src[src.length() - 1] == ')') {
+      src = src.substr(1, src.length() - 2);
+    }
+    PrintType(t, stream);
+    stream << ' ' << id << " = " << src << ";\n";
+    return id;
+  }
+}
+
+void CodeGenC::PrintExpr(const Expr& n, std::ostream& os) {  // NOLINT(*)
+  static const FPrintExpr& f = vtable_print_expr();
+  if (print_ssa_form_) {
+    std::ostringstream temp;
+    f(n, temp, this);
+    os << SSAGetID(temp.str(), n.type());
+  } else {
+    f(n, os, this);
+  }
+}
+
+std::string CodeGenC::GetUniqueName(std::string prefix) {
+  auto it = name_alloc_map_.find(prefix);
+  if (it != name_alloc_map_.end()) {
+    while (true) {
+      std::ostringstream os;
+      os << prefix << (++it->second);
+      std::string name = os.str();
+      if (name_alloc_map_.count(name) == 0) {
+        prefix = name;
+        break;
+      }
+    }
+  }
+  name_alloc_map_[prefix] = 0;
+  return prefix;
+}
+
+std::string CodeGenC::AllocVarID(const Variable* v) {
+  CHECK(!var_idmap_.count(v))
+      << "Need input to be in SSA form dup " << v->name_hint;
+  // Build the identifier from the name hint, replacing every '.' by '_'.
+  std::string sanitized = v->name_hint;
+  for (char& ch : sanitized) {
+    if (ch == '.') ch = '_';
+  }
+  const std::string vid = GetUniqueName(sanitized);
+  return var_idmap_[v] = vid;
+}
+
+std::string CodeGenC::GetVarID(const Variable* v) const {
+  auto it = var_idmap_.find(v);
+  CHECK(it != var_idmap_.end())
+      << "Find undefined Variable " << v->name_hint;
+  return it->second;
+}
+
+bool CodeGenC::BufferTypeMatch(const Variable* buf_var, Type t) const {
+  // A buffer with no recorded allocation type never matches.
+  const auto entry = alloc_buf_type_.find(buf_var);
+  return entry != alloc_buf_type_.end() && entry->second == t;
+}
+
+void CodeGenC::PrintIndent() {
+  for (int i = 0; i < this->indent; ++i) {
+    this->stream << ' ';
+  }
+}
+
+void CodeGenC::MarkConst(std::string vid) {
+  // Constants are only tracked when emitting SSA form.
+  if (!print_ssa_form_) return;
+  auto it = ssa_assign_map_.find(vid);
+  if (it != ssa_assign_map_.end()) {
+    CHECK_EQ(it->second, vid);
+    return;
+  }
+  ssa_assign_map_[vid] = vid;  // first sighting: the constant is its own id
+}
+
+void CodeGenC::PrintType(Type t, std::ostream& os) const {  // NOLINT(*)
+  CHECK_EQ(t.lanes(), 1)
+      << "do not yet support vector types";
+  if (t.is_handle()) {
+    os << "void*"; return;
+  }
+  if (t.is_float()) {
+    if (t.bits() == 32) {
+      os << "float"; return;
+    }
+    if (t.bits() == 64) {
+      os << "double"; return;
+    }
+  } else if (t.is_uint()) {
+    switch (t.bits()) {
+      case 8: case 16: case 32: case 64: {
+        os << "uint" << t.bits() << "_t"; return;
+      }
+      case 1: os << "int"; return;
+    }
+  } else if (t.is_int()) {
+    switch (t.bits()) {
+      case 8: case 16: case 32: case 64: {
+        os << "int" << t.bits() << "_t";  return;
+      }
+    }
+  }
+  LOG(FATAL) << "Cannot convert type " << t << " to C type";
+}
+
+CodeGenC::FPrintStmt& CodeGenC::vtable_print_stmt() {  // NOLINT(*)
+  static FPrintStmt inst; return inst;
+}
+
+CodeGenC::FPrintExpr& CodeGenC::vtable_print_expr() {  // NOLINT(*)
+  static FPrintExpr inst; return inst;
+}
+
+inline void PrintConst(const IntImm* op, std::ostream& os, CodeGenC* p) { // NOLINT(*)
+  if (op->type == Int(32)) {
+    std::ostringstream temp;
+    temp << op->value;
+    p->MarkConst(temp.str());
+    os << temp.str();
+  } else {
+    os << "(";
+    p->PrintType(op->type, os);
+    os << ")" << op->value;
+  }
+}
+
+inline void PrintConst(const UIntImm* op, std::ostream& os, CodeGenC* p) { // NOLINT(*)
+  if (op->type == UInt(32)) {
+    std::ostringstream temp;
+    temp << op->value << "U";
+    p->MarkConst(temp.str());
+    os << temp.str();
+  } else {
+    os << "(";
+    p->PrintType(op->type, os);
+    os << ")" << op->value;
+  }
+}
+
+inline void PrintConst(const FloatImm* op, std::ostream& os, CodeGenC* p) { // NOLINT(*)
+  // Emit a float literal. std::scientific always yields a floating token
+  // (1.000000e+00); the default format prints 1.0 as "1", i.e. "1f" / int 1.
+  std::ostringstream temp;
+  temp << std::scientific << op->value;
+  switch (op->type.bits()) {
+    case 64: case 32:
+      if (op->type.bits() == 32) temp << 'f';
+      p->MarkConst(temp.str());
+      os << temp.str();
+      break;
+    case 16:
+      os << '(';
+      p->PrintType(op->type, os);
+      os << ')' << temp.str() << 'f';
+      break;
+    default: LOG(FATAL) << "Bad bit-width for float: " << op->type << "\n";
+  }
+}
+
+TVM_STATIC_IR_FUNCTOR(CodeGenC, vtable_print_expr)
+.set_dispatch<IntImm>([](const IntImm *op, std::ostream& os, CodeGenC *p) {  // NOLINT(*)
+    PrintConst(op, os, p);
+  })
+.set_dispatch<UIntImm>([](const UIntImm *op, std::ostream& os, CodeGenC *p) {  // NOLINT(*)
+    PrintConst(op, os, p);
+  })
+.set_dispatch<FloatImm>([](const FloatImm *op, std::ostream& os, CodeGenC *p) { // NOLINT(*)
+    PrintConst(op, os, p);
+  });
+
+template<typename T>
+inline void PrintBinaryExpr(const T* op,
+                            const char *opstr,
+                            std::ostream& os,  // NOLINT(*)
+                            CodeGenC* p) {
+  os << '(';
+  p->PrintExpr(op->a, os);
+  os << opstr;
+  p->PrintExpr(op->b, os);
+  os << ')';
+}
+
+TVM_STATIC_IR_FUNCTOR(CodeGenC, vtable_print_expr)
+.set_dispatch<Cast>([](const Cast *op, std::ostream& os, CodeGenC *p) {  // NOLINT(*)
+    // C-style cast "((T)(v))": the functional form "T(v)" is not valid C.
+    os << "((";
+    p->PrintType(op->type, os);
+    os << ")(";
+    p->PrintExpr(op->value, os);
+    os << "))";
+  })
+.set_dispatch<Variable>([](const Variable *op, std::ostream& os, CodeGenC* p) {  // NOLINT(*)
+    os << p->GetVarID(op);
+  })
+.set_dispatch<Add>([](const Add *op, std::ostream& os, CodeGenC* p) {  // NOLINT(*)
+    PrintBinaryExpr(op, " + ", os, p);
+  })
+.set_dispatch<Sub>([](const Sub *op, std::ostream& os, CodeGenC* p) {  // NOLINT(*)
+    PrintBinaryExpr(op, " - ", os, p);
+  })
+.set_dispatch<Mul>([](const Mul *op, std::ostream& os, CodeGenC* p) {  // NOLINT(*)
+    PrintBinaryExpr(op, " * ", os, p);
+  })
+.set_dispatch<Div>([](const Div *op, std::ostream& os, CodeGenC* p) {  // NOLINT(*)
+    PrintBinaryExpr(op, " / ", os, p);
+  })
+.set_dispatch<Mod>([](const Mod *op, std::ostream& os, CodeGenC* p) {  // NOLINT(*)
+    PrintBinaryExpr(op, " % ", os, p);
+  })
+.set_dispatch<Min>([](const Min *op, std::ostream& os, CodeGenC* p) {  // NOLINT(*)
+    os << "min(";
+    p->PrintExpr(op->a, os);
+    os << ", ";
+    p->PrintExpr(op->b, os);
+    os << ")";
+  })
+.set_dispatch<Max>([](const Max *op, std::ostream& os, CodeGenC* p) {  // NOLINT(*)
+    os << "max(";
+    p->PrintExpr(op->a, os);
+    os << ", ";
+    p->PrintExpr(op->b, os);
+    os << ")";
+  })
+.set_dispatch<EQ>([](const EQ *op, std::ostream& os, CodeGenC* p) {  // NOLINT(*)
+    PrintBinaryExpr(op, " == ", os, p);
+  })
+.set_dispatch<NE>([](const NE *op, std::ostream& os, CodeGenC* p) {  // NOLINT(*)
+    PrintBinaryExpr(op, " != ", os, p);
+  })
+.set_dispatch<LT>([](const LT *op, std::ostream& os, CodeGenC* p) {  // NOLINT(*)
+    PrintBinaryExpr(op, " < ", os, p);
+  })
+.set_dispatch<LE>([](const LE *op, std::ostream& os, CodeGenC* p) {  // NOLINT(*)
+    PrintBinaryExpr(op, " <= ", os, p);
+  })
+.set_dispatch<GT>([](const GT *op, std::ostream& os, CodeGenC* p) {  // NOLINT(*)
+    PrintBinaryExpr(op, " > ", os, p);
+  })
+.set_dispatch<GE>([](const GE *op, std::ostream& os, CodeGenC* p) {  // NOLINT(*)
+    PrintBinaryExpr(op, " >= ", os, p);
+  })
+.set_dispatch<And>([](const And *op, std::ostream& os, CodeGenC* p) {  // NOLINT(*)
+    PrintBinaryExpr(op, " && ", os, p);
+  })
+.set_dispatch<Or>([](const Or *op, std::ostream& os, CodeGenC* p) {  // NOLINT(*)
+    PrintBinaryExpr(op, " || ", os, p);
+  })
+.set_dispatch<Not>([](const Not *op, std::ostream& os, CodeGenC* p) {  // NOLINT(*)
+    os << '!';
+    p->PrintExpr(op->a, os);
+  })
+.set_dispatch<Call>([](const Call *op, std::ostream& os, CodeGenC* p) {  // NOLINT(*)
+    os << op->name << "(";
+    for (size_t i = 0; i < op->args.size(); i++) {
+      if (i != 0) os << ", ";
+      p->PrintExpr(op->args[i], os);
+    }
+    os << ")";
+  });
+
+TVM_STATIC_IR_FUNCTOR(CodeGenC, vtable_print_stmt)
+.set_dispatch<AssertStmt>([](const AssertStmt *op, CodeGenC* p) {
+    std::string cond = p->PrintExpr(op->condition);
+    p->PrintIndent();
+    p->stream << "assert(" << cond << ");\n";
+  })
+.set_dispatch<ProducerConsumer>([](const ProducerConsumer *op, CodeGenC* p) {
+    p->PrintStmt(op->body);
+  })
+.set_dispatch<For>([](const For *op, CodeGenC* p) {
+    std::string extent = p->PrintExpr(op->extent);
+    p->PrintIndent();
+    std::string vid = p->AllocVarID(op->loop_var.get());
+    CHECK(is_zero(op->min));
+    p->stream << "for (";
+    p->PrintType(op->loop_var.type(), p->stream);
+    p->stream << ' ' << vid << " = 0; "
+              << vid << " < " << extent
+              << "; ++" << vid << ") {\n";
+    p->indent += 2;
+    p->PrintStmt(op->body);
+    p->indent -= 2;
+    p->PrintIndent();
+    p->stream << "}\n";
+  })
+.set_dispatch<Block>([](const Block *op, CodeGenC* p) {
+    p->PrintStmt(op->first);
+    if (op->rest.defined()) p->PrintStmt(op->rest);
+  })
+.set_dispatch<Evaluate>([](const Evaluate *op, CodeGenC* p) {
+    if (is_const(op->value)) return;
+    std::string vid = p->PrintExpr(op->value);
+    p->PrintIndent();
+    p->stream << "(void)" << vid << ";\n";
+  })
+.set_dispatch<IfThenElse>([](const IfThenElse *op, CodeGenC* p) {
+    std::string cond = p->PrintExpr(op->condition);
+    p->PrintIndent();
+    p->stream << "if (" << cond << ") {\n";
+    p->indent += 2;
+    p->PrintStmt(op->then_case);
+    p->indent -= 2;
+    if (op->else_case.defined()) {
+      p->PrintIndent();
+      p->stream << "} else {\n";
+      p->indent += 2;
+      p->PrintStmt(op->else_case);
+      p->indent -= 2;
+    }
+    p->PrintIndent();
+    p->stream << "}\n";
+});
+
+
+#define DISPATCH_EXPR(OP)                            \
+  set_dispatch<OP>([](const OP *op, std::ostream&os, CodeGenC* p) { \
+      p->PrintExpr(op, os); })
+
+TVM_STATIC_IR_FUNCTOR(CodeGenC, vtable_print_expr)
+.DISPATCH_EXPR(Load)
+.DISPATCH_EXPR(Let)
+.DISPATCH_EXPR(Ramp)
+.DISPATCH_EXPR(Broadcast)
+.DISPATCH_EXPR(Select);
+
+void CodeGenC::PrintExpr(const Load* op, std::ostream& os) {  // NOLINT(*)
+  std::string vid = GetVarID(op->buffer_var.get());
+  if (!BufferTypeMatch(op->buffer_var.get(), op->type)) {
+    os << "((const ";
+    PrintType(op->type, os);
+    os << "*)" << vid << ')';
+  } else {
+    os << vid;
+  }
+  os << '[';
+  PrintExpr(op->index, os);
+  os << ']';
+}
+
+void CodeGenC::PrintExpr(const Let* op, std::ostream& os) {  // NOLINT(*)
+  CHECK(print_ssa_form_) << "LetExpr is only supported by print SSA form";
+  std::string value = PrintExpr(op->value);
+  CHECK(!var_idmap_.count(op->var.get()));
+  var_idmap_[op->var.get()] = value;  // bind the let variable to its value's id
+  // A Let evaluates to its body; without this nothing is written to os.
+  PrintExpr(op->body, os);
+}
+
+void CodeGenC::PrintExpr(const Ramp* op, std::ostream& os) {  // NOLINT(*)
+  LOG(FATAL) << "not supported ";
+}
+
+void CodeGenC::PrintExpr(const Broadcast* op, std::ostream& os) {   // NOLINT(*)
+  LOG(FATAL) << "not supported ";
+}
+
+void CodeGenC::PrintExpr(const Select* op, std::ostream& os) {  // NOLINT(*)
+  LOG(FATAL) << "not supported ";
+}
+
+// Dispatch back to member functions
+TVM_STATIC_IR_FUNCTOR(CodeGenC, vtable_print_stmt)
+.set_dispatch<LetStmt>([](const LetStmt *op, CodeGenC* p) { p->PrintStmt(op); })
+.set_dispatch<Store>([](const Store *op, CodeGenC* p) { p->PrintStmt(op); })
+.set_dispatch<Allocate>([](const Allocate *op, CodeGenC* p) { p->PrintStmt(op); })
+.set_dispatch<AttrStmt>([](const AttrStmt *op, CodeGenC* p) { p->PrintStmt(op); });
+
+
+void CodeGenC::PrintStmt(const LetStmt* op) {
+  std::string value = PrintExpr(op->value);
+  if (print_ssa_form_) {
+    CHECK(!var_idmap_.count(op->var.get()));
+    var_idmap_[op->var.get()] = value;
+  } else {
+    PrintIndent();
+    PrintType(op->var.type(), this->stream);
+    this->stream << ' '
+       << AllocVarID(op->var.get())
+       << " = " << value << ";\n";
+  }
+  PrintStmt(op->body);
+}
+
+void CodeGenC::PrintStmt(const Store* op) {
+  std::string index = this->PrintExpr(op->index);
+  std::string value = this->PrintExpr(op->value);
+  this->PrintIndent();
+  std::string vid = GetVarID(op->buffer_var.get());
+  if (!BufferTypeMatch(op->buffer_var.get(), op->value.type())) {
+    this->stream << "((";
+    PrintType(op->value.type(), this->stream);
+    this->stream << "*)" << vid << ')';
+  } else {
+    this->stream << vid;
+  }
+  this->stream << '[' << index
+               << "] = " << value
+               << ";\n";
+}
+
+void CodeGenC::PrintStmt(const Allocate* op) {
+  this->PrintIndent();
+  int32_t constant_size = op->constant_allocation_size();
+  std::string vid = AllocVarID(op->buffer_var.get());
+  CHECK(!op->new_expr.defined());
+  CHECK(!is_zero(op->condition));
+  CHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation for now";
+  // Record the element type so BufferTypeMatch can recognise this buffer.
+  alloc_buf_type_[op->buffer_var.get()] = op->type;
+  PrintType(op->type, stream);
+  stream << ' ' << vid << '[' << constant_size << "];\n";  // ';' stays on the declaration line
+  this->PrintStmt(op->body);
+}
+
+void CodeGenC::PrintStmt(const AttrStmt* op) {
+  if (op->type_key == "scope") {
+    IterVar iv(op->node.node_);
+    if (iv->thread_tag.length() != 0) {
+      this->PrintIndent();
+      PrintType(iv->var.type(), stream);
+      stream << ' '
+             << AllocVarID(iv->var.get())
+             << " = " << iv->thread_tag << ";\n";
+    }
+  }
+  this->PrintStmt(op->body);
+}
+
+}  // namespace codegen
+}  // namespace tvm
diff --git a/src/codegen/codegen_c.h b/src/codegen/codegen_c.h
new file mode 100644
index 000000000..a8ce1828e
--- /dev/null
+++ b/src/codegen/codegen_c.h
@@ -0,0 +1,140 @@
+/*!
+ *  Copyright (c) 2016 by Contributors
+ * \file codegen_c.h
+ * \brief Common utilities to generate C style code.
+ */
+#ifndef TVM_CODEGEN_CODEGEN_C_H_
+#define TVM_CODEGEN_CODEGEN_C_H_
+
+#include <tvm/ir.h>
+#include <tvm/ir_visitor.h>
+#include <string>
+#include <unordered_map>
+
+namespace tvm {
+namespace codegen {
+
+/*!
+ * \brief A base class to generate C code.
+ *
+ *  CodeGenC has two modes: generate SSA-form C code or normal-form C code.
+ */
+class CodeGenC {
+ public:
+  /*!
+   * \brief Generate the C code of a statement, wrapped in a void function.
+   * \param body The body of the function.
+   * \param fun_name The name of the function.
+   * \param args The arguments to the function.
+   * \param output_ssa Whether to output SSA form.
+   * \note Only call Compile once;
+   *  create a new codegen object each time.
+   */
+  std::string Compile(Stmt body,
+                      std::string fun_name,
+                      Array<Var> args,
+                      bool output_ssa);
+  /*!
+   * \brief Print the Stmt n to CodeGenC->stream
+   * \param n The statement to be printed.
+   */
+  void PrintStmt(const Stmt& n);
+  /*!
+   * \brief Print the expression n (or its SSA id if in SSA mode) into os
+   * \param n The expression to be printed.
+   * \param os The output stream
+   */
+  void PrintExpr(const Expr& n, std::ostream& os);  // NOLINT(*)
+  /*!
+   * \brief Same as PrintExpr, but simply returns the result string
+   * \param n The expression to be printed.
+   */
+  inline std::string PrintExpr(const Expr& n) {
+    std::ostringstream os;
+    PrintExpr(n, os);
+    return os.str();
+  }
+  /*! \brief Print the current indentation (indent spaces) to stream */
+  void PrintIndent();
+  /*!
+   * \brief Register a constant value that appears in the expression tree.
+   *  This avoids generating an SSA id for each appearance of the value.
+   * \param value The constant value.
+   */
+  void MarkConst(std::string value);
+  /*!
+   * \brief Allocate a variable name for a newly defined var.
+   * \param v The variable.
+   * \return the variable name.
+   */
+  std::string AllocVarID(const Variable* v);
+  /*!
+   * \brief Get a variable name.
+   * \param v The variable.
+   * \return the variable name.
+   */
+  std::string GetVarID(const Variable* v) const;
+  /*!
+   * \brief Print the C type representation of type t.
+   * \param t The type representation.
+   * \param os The stream to print the C type into.
+   */
+  virtual void PrintType(Type t, std::ostream& os) const; // NOLINT(*)
+  // The following parts are overloadable print operations.
+  virtual void PrintStmt(const ir::LetStmt* op);
+  virtual void PrintStmt(const ir::Store* op);
+  virtual void PrintStmt(const ir::Allocate* op);
+  virtual void PrintStmt(const ir::AttrStmt* op);
+  virtual void PrintExpr(const ir::Load* op, std::ostream& os);  // NOLINT(*)
+  virtual void PrintExpr(const ir::Let* op, std::ostream& os);  // NOLINT(*)
+  virtual void PrintExpr(const ir::Ramp* op, std::ostream& os);  // NOLINT(*)
+  virtual void PrintExpr(const ir::Broadcast* op, std::ostream& os);  // NOLINT(*)
+  virtual void PrintExpr(const ir::Select* op, std::ostream& os);  // NOLINT(*)
+  /*! \brief function to print an expression into the ostream */
+  using FPrintExpr = IRFunctor<void(const NodeRef&, std::ostream& os, CodeGenC *)>; // NOLINT(*)
+  /*! \brief function to print normal code */
+  using FPrintStmt = IRFunctor<void(const NodeRef&, CodeGenC *)>;
+  // vtable used to dispatch statement printing
+  static FPrintStmt& vtable_print_stmt();
+  // vtable used to dispatch expression printing
+  static FPrintExpr& vtable_print_expr();
+  /*! \brief The current indentation value */
+  int indent{0};
+  /*! \brief the stream to be printed */
+  std::ostringstream stream;
+
+ private:
+  /*!
+   * \brief Get the SSA ID corresponding to src.
+   *  If necessary, generate a new assignment.
+   * \param src The source expression
+   * \param t The type of the expression.
+   */
+  std::string SSAGetID(std::string src, Type t);
+  /*!
+   * \brief Check whether the buffer is allocated as type t.
+   * \param buf_var The buffer variable.
+   * \param t The type to be checked.
+   */
+  bool BufferTypeMatch(const Variable* buf_var, Type t) const;
+  /*!
+   * \brief Get a unique name with the corresponding prefix
+   * \param prefix The prefix of the name
+   * \return The returned name.
+   */
+  std::string GetUniqueName(std::string prefix);
+  /*! \brief whether to print in SSA form */
+  bool print_ssa_form_{true};
+  /*! \brief name of each variable */
+  std::unordered_map<const Variable*, std::string> var_idmap_;
+  /*! \brief the data type of allocated buffers */
+  std::unordered_map<const Variable*, Type> alloc_buf_type_;
+  /*! \brief name allocation map */
+  std::unordered_map<std::string, int> name_alloc_map_;
+  /*! \brief assignment map of ssa */
+  std::unordered_map<std::string, std::string> ssa_assign_map_;
+};
+
+}  // namespace codegen
+}  // namespace tvm
+#endif  // TVM_CODEGEN_CODEGEN_C_H_
diff --git a/tests/python/test_codegen_cuda.py b/tests/python/test_codegen_cuda.py
index b93e80e52..0f0a8df30 100644
--- a/tests/python/test_codegen_cuda.py
+++ b/tests/python/test_codegen_cuda.py
@@ -19,10 +19,18 @@ def mock_test_add():
     bounds = tvm.schedule.InferBound(s)
     stmt = tvm.ir_pass.ScheduleOps(s, bounds)
 
+
     Ab = tvm.Buffer(A.shape, A.dtype, name='A')
     Bb = tvm.Buffer(B.shape, B.dtype, name='B')
     Cb = tvm.Buffer(C.shape, C.dtype, name='C')
+    stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, B:Bb, C:Cb})
+    print(stmt)
+    output_ssa = False
+    code = tvm.codegen.CompileToC(stmt, "myadd",
+                                  [Ab.ptr, Bb.ptr, Cb.ptr, n],
+                                  output_ssa)
 
+    print(code)
     def codegen():
         # generate host/device code
         host_code, device_code = tvm.codegen.GenCUDA(
-- 
GitLab