Skip to content
Snippets Groups Projects
Commit eee0ebef authored by tqchen's avatar tqchen
Browse files

Stronger type checker during conversion

parent 57a74936
No related branches found
No related tags found
No related merge requests found
......@@ -27,6 +27,7 @@ using Halide::IR::FunctionRef;
using Halide::IR::FunctionBaseNode;
using Halide::Internal::Stmt;
using Halide::Internal::IRPrinter;
using Halide::Internal::Variable;
/*! \brief a named variable in TVM */
class Var : public Halide::VarExpr {
......@@ -35,6 +36,9 @@ class Var : public Halide::VarExpr {
Type t = Int(32)) : VarExpr(name_hint, t) {}
explicit Var(std::shared_ptr<Node> n) : VarExpr(n) {}
/*! \brief type indicate the container type */
using ContainerType = Variable;
};
......
......@@ -83,7 +83,6 @@ using Halide::Internal::UIntImm;
using Halide::Internal::FloatImm;
using Halide::Internal::StringImm;
using Halide::Internal::Cast;
using Halide::Internal::Variable;
using Halide::Internal::Add;
using Halide::Internal::Sub;
using Halide::Internal::Mul;
......
......@@ -10,4 +10,5 @@ from . import ir_pass
from . import collections
from . import schedule
from ._base import TVMError
from .function import *
......@@ -54,6 +54,43 @@ inline const char* TypeId2Str(ArgVariantID type_id) {
}
}
template<typename T>
struct NodeTypeChecker {
static inline void Check(Node* sptr) {
using ContainerType = typename T::ContainerType;
// use dynamic RTTI for safety
CHECK(dynamic_cast<ContainerType*>(sptr))
<< "wrong type specified, expected " << typeid(ContainerType).name();
}
};
template<typename T>
struct NodeTypeChecker<Array<T> > {
static inline void Check(Node* sptr) {
// use dynamic RTTI for safety
CHECK(sptr != nullptr && sptr->is_type<ArrayNode>())
<< "wrong type specified, expected Array";
ArrayNode* n = static_cast<ArrayNode*>(sptr);
for (const auto& p : n->data) {
NodeTypeChecker<T>::Check(p.get());
}
}
};
template<typename K, typename V>
struct NodeTypeChecker<Map<K, V> > {
static inline void Check(Node* sptr) {
// use dynamic RTTI for safety
CHECK(sptr != nullptr && sptr->is_type<MapNode>())
<< "wrong type specified, expected Map";
MapNode* n = static_cast<MapNode*>(sptr);
for (const auto& kv : n->data) {
NodeTypeChecker<K>::Check(kv.first.get());
NodeTypeChecker<V>::Check(kv.second.get());
}
}
};
/*! \brief Variant container for API calls */
class APIVariantValue {
public:
......@@ -109,13 +146,11 @@ class APIVariantValue {
return operator=(Type2String(value));
}
template<typename T,
typename = typename std::enable_if<std::is_base_of<NodeRef, T>::value>::type>
typename = typename std::enable_if<std::is_base_of<NodeRef, T>::value>::type>
inline operator T() const {
if (type_id == kNull) return T();
CHECK_EQ(type_id, kNodeHandle);
// use dynamic RTTI for safety
CHECK(dynamic_cast<typename T::ContainerType*>(sptr.get()))
<< "wrong type specified, expected " << typeid(typename T::ContainerType).name();
NodeTypeChecker<T>::Check(sptr.get());
return T(sptr);
}
inline operator Expr() const {
......
......@@ -10,5 +10,16 @@ def test_inline():
print(stmt)
assert(tvm.ir_pass.VerifySSA(stmt))
try:
# pass in int array(wrong argument type)
# must raise an error
stmt = tvm.ir_pass.Inline(
T, [1,2,3], T.op.body, stmt)
assert False
except tvm.TVMError:
pass
if __name__ == "__main__":
test_inline()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment