Skip to content
Snippets Groups Projects
Commit 383494a5 authored by Tianqi Chen's avatar Tianqi Chen Committed by GitHub
Browse files

[API] Move all RTTI related code to one place (#20)

* [API] Move all RTTI related code to one place

* add back rtti comment
parent 4d4e19ce
No related branches found
No related tags found
No related merge requests found
Subproject commit 6375e6b76f6b70d58f66b357d946c971843f3169
Subproject commit af2a2fcee59378f33817d7745a8110b9cc836438
Subproject commit 749e570c19423fe679a5f496e2394ba3bed75a16
Subproject commit 3a51614d39b69fdb5de1efcf1016426626d267a6
......@@ -15,7 +15,7 @@ using RetValue = APIVariantValue;
TVM_REGISTER_API(_pass_Simplify)
.set_body([](const ArgStack& args, RetValue *ret) {
if (dynamic_cast<Stmt::ContainerType*>(args.at(0).sptr.get())) {
if (NodeTypeChecker<Stmt>::Check(args.at(0).sptr.get())) {
*ret = Simplify(args.at(0).operator Stmt());
} else {
*ret = Simplify(args.at(0).operator Expr());
......@@ -24,13 +24,10 @@ TVM_REGISTER_API(_pass_Simplify)
TVM_REGISTER_API(_pass_Equal)
.set_body([](const ArgStack& args, RetValue *ret) {
if (dynamic_cast<Stmt::ContainerType*>(args.at(0).sptr.get())) {
CHECK(args.at(1).type_id == kNodeHandle);
if (NodeTypeChecker<Stmt>::Check(args.at(0).sptr.get())) {
*ret = Equal(args.at(0).operator Stmt(), args.at(1).operator Stmt());
} else {
Expr a = args.at(0).operator Expr();
Expr b = args.at(1).operator Expr();
*ret = Equal(a, b);
*ret = Equal(args.at(0).operator Expr(), args.at(1).operator Expr());
}
});
......
......@@ -33,41 +33,65 @@ inline const char* TypeId2Str(ArgVariantID type_id) {
template<typename T>
struct NodeTypeChecker {
static inline void Check(Node* sptr) {
static inline bool Check(Node* sptr) {
// This is the only place in the project where RTTI is used
// It can be turned off, but will make non strict checking.
// TODO(tqchen) possibly find alternative to turn of RTTI
using ContainerType = typename T::ContainerType;
// use dynamic RTTI for safety
CHECK(dynamic_cast<ContainerType*>(sptr))
<< "wrong type specified, expected " << typeid(ContainerType).name();
return (dynamic_cast<ContainerType*>(sptr) != nullptr);
}
static inline void PrintName(std::ostringstream& os) { // NOLINT(*)
using ContainerType = typename T::ContainerType;
os << ContainerType::_type_key;
}
};
template<typename T>
struct NodeTypeChecker<Array<T> > {
static inline void Check(Node* sptr) {
// use dynamic RTTI for safety
CHECK(sptr != nullptr && sptr->is_type<ArrayNode>())
<< "wrong type specified, expected Array";
static inline bool Check(Node* sptr) {
if (sptr == nullptr) return false;
if (!sptr->is_type<ArrayNode>()) return false;
ArrayNode* n = static_cast<ArrayNode*>(sptr);
for (const auto& p : n->data) {
NodeTypeChecker<T>::Check(p.get());
if (!NodeTypeChecker<T>::Check(p.get())) return false;
}
return true;
}
static inline void PrintName(std::ostringstream& os) { // NOLINT(*)
os << "array<";
NodeTypeChecker<T>::PrintName(os);
os << ">";
}
};
template<typename K, typename V>
struct NodeTypeChecker<Map<K, V> > {
static inline void Check(Node* sptr) {
// use dynamic RTTI for safety
CHECK(sptr != nullptr && sptr->is_type<MapNode>())
<< "wrong type specified, expected Map";
static inline bool Check(Node* sptr) {
if (sptr == nullptr) return false;
if (!sptr->is_type<MapNode>()) return false;
MapNode* n = static_cast<MapNode*>(sptr);
for (const auto& kv : n->data) {
NodeTypeChecker<K>::Check(kv.first.get());
NodeTypeChecker<V>::Check(kv.second.get());
if (!NodeTypeChecker<K>::Check(kv.first.get())) return false;
if (!NodeTypeChecker<V>::Check(kv.second.get())) return false;
}
return true;
}
static inline void PrintName(std::ostringstream& os) { // NOLINT(*)
os << "map<";
NodeTypeChecker<K>::PrintName(os);
os << ',';
NodeTypeChecker<V>::PrintName(os);
os << '>';
}
};
template<typename T>
inline std::string NodeTypeName() {
std::ostringstream os;
NodeTypeChecker<T>::PrintName(os);
return os.str();
}
/*! \brief Variant container for API calls */
class APIVariantValue {
public:
......@@ -127,7 +151,8 @@ class APIVariantValue {
inline operator T() const {
if (type_id == kNull) return T();
CHECK_EQ(type_id, kNodeHandle);
NodeTypeChecker<T>::Check(sptr.get());
CHECK(NodeTypeChecker<T>::Check(sptr.get()))
<< "Did not get expected type " << NodeTypeName<T>();
return T(sptr);
}
inline operator Expr() const {
......@@ -140,7 +165,7 @@ class APIVariantValue {
if (sptr->is_type<IterVarNode>()) {
return IterVar(sptr)->var;
} else {
CHECK(dynamic_cast<typename Expr::ContainerType*>(sptr.get()))
CHECK(NodeTypeChecker<Expr>::Check(sptr.get()))
<< "did not pass in Expr in a place need Expr";
return Expr(sptr);
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment