diff --git a/include/tvm/expr.h b/include/tvm/expr.h index bb23ab4574abf4475a6fa6b8c4d1951f8ad6a41d..adf0d245d20a84a9bbcfc8508aaaa7e865f8413e 100644 --- a/include/tvm/expr.h +++ b/include/tvm/expr.h @@ -27,6 +27,7 @@ using Halide::IR::FunctionRef; using Halide::IR::FunctionBaseNode; using Halide::Internal::Stmt; using Halide::Internal::IRPrinter; +using Halide::Internal::Variable; /*! \brief a named variable in TVM */ class Var : public Halide::VarExpr { @@ -35,6 +36,9 @@ class Var : public Halide::VarExpr { Type t = Int(32)) : VarExpr(name_hint, t) {} explicit Var(std::shared_ptr<Node> n) : VarExpr(n) {} + + /*! \brief type indicate the container type */ + using ContainerType = Variable; }; diff --git a/include/tvm/ir.h b/include/tvm/ir.h index 5fa906296304efc332a43c5de90b686dc5ad74d3..37ce3352e95dff9cb3328d40328e5d828decfa1e 100644 --- a/include/tvm/ir.h +++ b/include/tvm/ir.h @@ -83,7 +83,6 @@ using Halide::Internal::UIntImm; using Halide::Internal::FloatImm; using Halide::Internal::StringImm; using Halide::Internal::Cast; -using Halide::Internal::Variable; using Halide::Internal::Add; using Halide::Internal::Sub; using Halide::Internal::Mul; diff --git a/python/tvm/__init__.py b/python/tvm/__init__.py index f1c2ea41afead15ad5c2b0ec8c6bda63063f15e4..07fca6ecab39c00c3b297c0a605357601981d68e 100644 --- a/python/tvm/__init__.py +++ b/python/tvm/__init__.py @@ -10,4 +10,5 @@ from . import ir_pass from . import collections from . import schedule +from ._base import TVMError from .function import * diff --git a/src/c_api/c_api_registry.h b/src/c_api/c_api_registry.h index 9edce8ea64f91de7e19bd50e54bccfa6810a1892..885c45b14432a5f52a0a22d4226376ee741599cb 100644 --- a/src/c_api/c_api_registry.h +++ b/src/c_api/c_api_registry.h @@ -54,6 +54,43 @@ inline const char* TypeId2Str(ArgVariantID type_id) { } } +template<typename T> +struct NodeTypeChecker { + static inline void Check(Node* sptr) { + using ContainerType = typename T::ContainerType; + // use dynamic RTTI for safety + CHECK(dynamic_cast<ContainerType*>(sptr)) + << "wrong type specified, expected " << typeid(ContainerType).name(); + } +}; + +template<typename T> +struct NodeTypeChecker<Array<T> > { + static inline void Check(Node* sptr) { + // use dynamic RTTI for safety + CHECK(sptr != nullptr && sptr->is_type<ArrayNode>()) + << "wrong type specified, expected Array"; + ArrayNode* n = static_cast<ArrayNode*>(sptr); + for (const auto& p : n->data) { + NodeTypeChecker<T>::Check(p.get()); + } + } +}; + +template<typename K, typename V> +struct NodeTypeChecker<Map<K, V> > { + static inline void Check(Node* sptr) { + // use dynamic RTTI for safety + CHECK(sptr != nullptr && sptr->is_type<MapNode>()) + << "wrong type specified, expected Map"; + MapNode* n = static_cast<MapNode*>(sptr); + for (const auto& kv : n->data) { + NodeTypeChecker<K>::Check(kv.first.get()); + NodeTypeChecker<V>::Check(kv.second.get()); + } + } +}; + /*! \brief Variant container for API calls */ class APIVariantValue { public: @@ -109,13 +146,11 @@ class APIVariantValue { return operator=(Type2String(value)); } template<typename T, - typename = typename std::enable_if<std::is_base_of<NodeRef, T>::value>::type> + typename = typename std::enable_if<std::is_base_of<NodeRef, T>::value>::type> inline operator T() const { if (type_id == kNull) return T(); CHECK_EQ(type_id, kNodeHandle); - // use dynamic RTTI for safety - CHECK(dynamic_cast<typename T::ContainerType*>(sptr.get())) - << "wrong type specified, expected " << typeid(typename T::ContainerType).name(); + NodeTypeChecker<T>::Check(sptr.get()); return T(sptr); } inline operator Expr() const { diff --git a/tests/python/test_inline.py b/tests/python/test_inline.py index 73305688f51d5617d9d76e88e5fa059d2c3c1775..c3f6b6aa7b15cb62c36d11c0808fc1379bde693e 100644 --- a/tests/python/test_inline.py +++ b/tests/python/test_inline.py @@ -10,5 +10,16 @@ def test_inline(): print(stmt) assert(tvm.ir_pass.VerifySSA(stmt)) + try: + # pass in int array(wrong argument type) + # must raise an error + stmt = tvm.ir_pass.Inline( + T, [1,2,3], T.op.body, stmt) + assert False + except tvm.TVMError: + pass + + + if __name__ == "__main__": test_inline()