From eee0ebef17af4de8ff3af348771f99d4c6b01bd8 Mon Sep 17 00:00:00 2001
From: tqchen <tianqi.tchen@gmail.com>
Date: Fri, 6 Jan 2017 15:39:31 -0800
Subject: [PATCH] Stronger type checker during conversion

---
 include/tvm/expr.h          |  4 ++++
 include/tvm/ir.h            |  1 -
 python/tvm/__init__.py      |  1 +
 src/c_api/c_api_registry.h  | 43 +++++++++++++++++++++++++++++++++----
 tests/python/test_inline.py | 11 ++++++++++
 5 files changed, 55 insertions(+), 5 deletions(-)

diff --git a/include/tvm/expr.h b/include/tvm/expr.h
index bb23ab457..adf0d245d 100644
--- a/include/tvm/expr.h
+++ b/include/tvm/expr.h
@@ -27,6 +27,7 @@ using Halide::IR::FunctionRef;
 using Halide::IR::FunctionBaseNode;
 using Halide::Internal::Stmt;
 using Halide::Internal::IRPrinter;
+using Halide::Internal::Variable;
 
 /*! \brief a named variable in TVM */
 class Var : public Halide::VarExpr {
@@ -35,6 +36,9 @@ class Var : public Halide::VarExpr {
                Type t = Int(32)) : VarExpr(name_hint, t) {}
 
   explicit Var(std::shared_ptr<Node> n) : VarExpr(n) {}
+
+  /*! \brief type indicate the container type */
+  using ContainerType = Variable;
 };
 
 
diff --git a/include/tvm/ir.h b/include/tvm/ir.h
index 5fa906296..37ce3352e 100644
--- a/include/tvm/ir.h
+++ b/include/tvm/ir.h
@@ -83,7 +83,6 @@ using Halide::Internal::UIntImm;
 using Halide::Internal::FloatImm;
 using Halide::Internal::StringImm;
 using Halide::Internal::Cast;
-using Halide::Internal::Variable;
 using Halide::Internal::Add;
 using Halide::Internal::Sub;
 using Halide::Internal::Mul;
diff --git a/python/tvm/__init__.py b/python/tvm/__init__.py
index f1c2ea41a..07fca6eca 100644
--- a/python/tvm/__init__.py
+++ b/python/tvm/__init__.py
@@ -10,4 +10,5 @@ from . import ir_pass
 from . import collections
 from . import schedule
 
+from ._base import TVMError
 from .function import *
diff --git a/src/c_api/c_api_registry.h b/src/c_api/c_api_registry.h
index 9edce8ea6..885c45b14 100644
--- a/src/c_api/c_api_registry.h
+++ b/src/c_api/c_api_registry.h
@@ -54,6 +54,43 @@ inline const char* TypeId2Str(ArgVariantID type_id) {
   }
 }
 
+template<typename T>
+struct NodeTypeChecker {
+  static inline void Check(Node* sptr) {
+    using ContainerType = typename T::ContainerType;
+    // use dynamic RTTI for safety
+    CHECK(dynamic_cast<ContainerType*>(sptr))
+        << "wrong type specified, expected " << typeid(ContainerType).name();
+  }
+};
+
+template<typename T>
+struct NodeTypeChecker<Array<T> > {
+  static inline void Check(Node* sptr) {
+    // use dynamic RTTI for safety
+    CHECK(sptr != nullptr && sptr->is_type<ArrayNode>())
+        << "wrong type specified, expected Array";
+    ArrayNode* n = static_cast<ArrayNode*>(sptr);
+    for (const auto& p : n->data) {
+      NodeTypeChecker<T>::Check(p.get());
+    }
+  }
+};
+
+template<typename K, typename V>
+struct NodeTypeChecker<Map<K, V> > {
+  static inline void Check(Node* sptr) {
+    // use dynamic RTTI for safety
+    CHECK(sptr != nullptr && sptr->is_type<MapNode>())
+        << "wrong type specified, expected Map";
+    MapNode* n = static_cast<MapNode*>(sptr);
+    for (const auto& kv : n->data) {
+      NodeTypeChecker<K>::Check(kv.first.get());
+      NodeTypeChecker<V>::Check(kv.second.get());
+    }
+  }
+};
+
 /*! \brief Variant container for API calls */
 class APIVariantValue {
  public:
@@ -109,13 +146,11 @@ class APIVariantValue {
     return operator=(Type2String(value));
   }
   template<typename T,
-         typename = typename std::enable_if<std::is_base_of<NodeRef, T>::value>::type>
+           typename = typename std::enable_if<std::is_base_of<NodeRef, T>::value>::type>
   inline operator T() const {
     if (type_id == kNull) return T();
     CHECK_EQ(type_id, kNodeHandle);
-    // use dynamic RTTI for safety
-    CHECK(dynamic_cast<typename T::ContainerType*>(sptr.get()))
-        << "wrong type specified, expected " << typeid(typename T::ContainerType).name();
+    NodeTypeChecker<T>::Check(sptr.get());
     return T(sptr);
   }
   inline operator Expr() const {
diff --git a/tests/python/test_inline.py b/tests/python/test_inline.py
index 73305688f..c3f6b6aa7 100644
--- a/tests/python/test_inline.py
+++ b/tests/python/test_inline.py
@@ -10,5 +10,16 @@ def test_inline():
     print(stmt)
     assert(tvm.ir_pass.VerifySSA(stmt))
 
+    try:
+        # pass in int array(wrong argument type)
+        # must raise an error
+        stmt = tvm.ir_pass.Inline(
+            T, [1,2,3], T.op.body, stmt)
+        assert False
+    except tvm.TVMError:
+        pass
+
+
+
 if __name__ == "__main__":
     test_inline()
-- 
GitLab