From eee0ebef17af4de8ff3af348771f99d4c6b01bd8 Mon Sep 17 00:00:00 2001 From: tqchen <tianqi.tchen@gmail.com> Date: Fri, 6 Jan 2017 15:39:31 -0800 Subject: [PATCH] Stronger type checker during conversion --- include/tvm/expr.h | 4 ++++ include/tvm/ir.h | 1 - python/tvm/__init__.py | 1 + src/c_api/c_api_registry.h | 43 +++++++++++++++++++++++++++++++++---- tests/python/test_inline.py | 11 ++++++++++ 5 files changed, 55 insertions(+), 5 deletions(-) diff --git a/include/tvm/expr.h b/include/tvm/expr.h index bb23ab457..adf0d245d 100644 --- a/include/tvm/expr.h +++ b/include/tvm/expr.h @@ -27,6 +27,7 @@ using Halide::IR::FunctionRef; using Halide::IR::FunctionBaseNode; using Halide::Internal::Stmt; using Halide::Internal::IRPrinter; +using Halide::Internal::Variable; /*! \brief a named variable in TVM */ class Var : public Halide::VarExpr { @@ -35,6 +36,9 @@ class Var : public Halide::VarExpr { Type t = Int(32)) : VarExpr(name_hint, t) {} explicit Var(std::shared_ptr<Node> n) : VarExpr(n) {} + + /*! \brief type indicate the container type */ + using ContainerType = Variable; }; diff --git a/include/tvm/ir.h b/include/tvm/ir.h index 5fa906296..37ce3352e 100644 --- a/include/tvm/ir.h +++ b/include/tvm/ir.h @@ -83,7 +83,6 @@ using Halide::Internal::UIntImm; using Halide::Internal::FloatImm; using Halide::Internal::StringImm; using Halide::Internal::Cast; -using Halide::Internal::Variable; using Halide::Internal::Add; using Halide::Internal::Sub; using Halide::Internal::Mul; diff --git a/python/tvm/__init__.py b/python/tvm/__init__.py index f1c2ea41a..07fca6eca 100644 --- a/python/tvm/__init__.py +++ b/python/tvm/__init__.py @@ -10,4 +10,5 @@ from . import ir_pass from . import collections from . import schedule +from ._base import TVMError from .function import * diff --git a/src/c_api/c_api_registry.h b/src/c_api/c_api_registry.h index 9edce8ea6..885c45b14 100644 --- a/src/c_api/c_api_registry.h +++ b/src/c_api/c_api_registry.h @@ -54,6 +54,43 @@ inline const char* TypeId2Str(ArgVariantID type_id) { } } +template<typename T> +struct NodeTypeChecker { + static inline void Check(Node* sptr) { + using ContainerType = typename T::ContainerType; + // use dynamic RTTI for safety + CHECK(dynamic_cast<ContainerType*>(sptr)) + << "wrong type specified, expected " << typeid(ContainerType).name(); + } +}; + +template<typename T> +struct NodeTypeChecker<Array<T> > { + static inline void Check(Node* sptr) { + // use dynamic RTTI for safety + CHECK(sptr != nullptr && sptr->is_type<ArrayNode>()) + << "wrong type specified, expected Array"; + ArrayNode* n = static_cast<ArrayNode*>(sptr); + for (const auto& p : n->data) { + NodeTypeChecker<T>::Check(p.get()); + } + } +}; + +template<typename K, typename V> +struct NodeTypeChecker<Map<K, V> > { + static inline void Check(Node* sptr) { + // use dynamic RTTI for safety + CHECK(sptr != nullptr && sptr->is_type<MapNode>()) + << "wrong type specified, expected Map"; + MapNode* n = static_cast<MapNode*>(sptr); + for (const auto& kv : n->data) { + NodeTypeChecker<K>::Check(kv.first.get()); + NodeTypeChecker<V>::Check(kv.second.get()); + } + } +}; + /*! \brief Variant container for API calls */ class APIVariantValue { public: @@ -109,13 +146,11 @@ class APIVariantValue { return operator=(Type2String(value)); } template<typename T, - typename = typename std::enable_if<std::is_base_of<NodeRef, T>::value>::type> + typename = typename std::enable_if<std::is_base_of<NodeRef, T>::value>::type> inline operator T() const { if (type_id == kNull) return T(); CHECK_EQ(type_id, kNodeHandle); - // use dynamic RTTI for safety - CHECK(dynamic_cast<typename T::ContainerType*>(sptr.get())) - << "wrong type specified, expected " << typeid(typename T::ContainerType).name(); + NodeTypeChecker<T>::Check(sptr.get()); return T(sptr); } inline operator Expr() const { diff --git a/tests/python/test_inline.py b/tests/python/test_inline.py index 73305688f..c3f6b6aa7 100644 --- a/tests/python/test_inline.py +++ b/tests/python/test_inline.py @@ -10,5 +10,16 @@ def test_inline(): print(stmt) assert(tvm.ir_pass.VerifySSA(stmt)) + try: + # pass in int array(wrong argument type) + # must raise an error + stmt = tvm.ir_pass.Inline( + T, [1,2,3], T.op.body, stmt) + assert False + except tvm.TVMError: + pass + + + if __name__ == "__main__": test_inline() -- GitLab