From 383494a51a3c9169e50650570ffd178e750e60e5 Mon Sep 17 00:00:00 2001
From: Tianqi Chen <tqchen@users.noreply.github.com>
Date: Thu, 19 Jan 2017 07:42:33 -0800
Subject: [PATCH] [API] Move all RTTI related code to one place (#20)

* [API] Move all RTTI related code to one place

* add back rtti comment
---
 HalideIR                   |  2 +-
 dmlc-core                  |  2 +-
 src/c_api/c_api_pass.cc    |  9 ++----
 src/c_api/c_api_registry.h | 59 +++++++++++++++++++++++++++-----------
 4 files changed, 47 insertions(+), 25 deletions(-)

diff --git a/HalideIR b/HalideIR
index 6375e6b76..af2a2fcee 160000
--- a/HalideIR
+++ b/HalideIR
@@ -1 +1 @@
-Subproject commit 6375e6b76f6b70d58f66b357d946c971843f3169
+Subproject commit af2a2fcee59378f33817d7745a8110b9cc836438
diff --git a/dmlc-core b/dmlc-core
index 749e570c1..3a51614d3 160000
--- a/dmlc-core
+++ b/dmlc-core
@@ -1 +1 @@
-Subproject commit 749e570c19423fe679a5f496e2394ba3bed75a16
+Subproject commit 3a51614d39b69fdb5de1efcf1016426626d267a6
diff --git a/src/c_api/c_api_pass.cc b/src/c_api/c_api_pass.cc
index e45e25a26..d7d41c4f6 100644
--- a/src/c_api/c_api_pass.cc
+++ b/src/c_api/c_api_pass.cc
@@ -15,7 +15,7 @@ using RetValue = APIVariantValue;
 
 TVM_REGISTER_API(_pass_Simplify)
 .set_body([](const ArgStack& args, RetValue *ret) {
-    if (dynamic_cast<Stmt::ContainerType*>(args.at(0).sptr.get())) {
+    if (NodeTypeChecker<Stmt>::Check(args.at(0).sptr.get())) {
       *ret = Simplify(args.at(0).operator Stmt());
     } else {
       *ret = Simplify(args.at(0).operator Expr());
@@ -24,13 +24,10 @@ TVM_REGISTER_API(_pass_Simplify)
 
 TVM_REGISTER_API(_pass_Equal)
 .set_body([](const ArgStack& args, RetValue *ret) {
-    if (dynamic_cast<Stmt::ContainerType*>(args.at(0).sptr.get())) {
-      CHECK(args.at(1).type_id == kNodeHandle);
+    if (NodeTypeChecker<Stmt>::Check(args.at(0).sptr.get())) {
       *ret = Equal(args.at(0).operator Stmt(), args.at(1).operator Stmt());
     } else {
-      Expr a = args.at(0).operator Expr();
-      Expr b = args.at(1).operator Expr();
-      *ret = Equal(a, b);
+      *ret = Equal(args.at(0).operator Expr(), args.at(1).operator Expr());
     }
   });
 
diff --git a/src/c_api/c_api_registry.h b/src/c_api/c_api_registry.h
index 835368f50..7223ebaee 100644
--- a/src/c_api/c_api_registry.h
+++ b/src/c_api/c_api_registry.h
@@ -33,41 +33,65 @@ inline const char* TypeId2Str(ArgVariantID type_id) {
 
 template<typename T>
 struct NodeTypeChecker {
-  static inline void Check(Node* sptr) {
+  static inline bool Check(Node* sptr) {
+    // This is the only place in the project where RTTI is used
+    // It can be turned off, but will make non strict checking.
+    // TODO(tqchen) possibly find alternative to turn of RTTI
     using ContainerType = typename T::ContainerType;
-    // use dynamic RTTI for safety
-    CHECK(dynamic_cast<ContainerType*>(sptr))
-        << "wrong type specified, expected " << typeid(ContainerType).name();
+    return (dynamic_cast<ContainerType*>(sptr) != nullptr);
+  }
+  static inline void PrintName(std::ostringstream& os) { // NOLINT(*)
+    using ContainerType = typename T::ContainerType;
+    os << ContainerType::_type_key;
   }
 };
 
 template<typename T>
 struct NodeTypeChecker<Array<T> > {
-  static inline void Check(Node* sptr) {
-    // use dynamic RTTI for safety
-    CHECK(sptr != nullptr && sptr->is_type<ArrayNode>())
-        << "wrong type specified, expected Array";
+  static inline bool Check(Node* sptr) {
+    if (sptr == nullptr) return false;
+    if (!sptr->is_type<ArrayNode>()) return false;
     ArrayNode* n = static_cast<ArrayNode*>(sptr);
     for (const auto& p : n->data) {
-      NodeTypeChecker<T>::Check(p.get());
+      if (!NodeTypeChecker<T>::Check(p.get())) return false;
     }
+    return true;
+  }
+  static inline void PrintName(std::ostringstream& os) { // NOLINT(*)
+    os << "array<";
+    NodeTypeChecker<T>::PrintName(os);
+    os << ">";
   }
 };
 
 template<typename K, typename V>
 struct NodeTypeChecker<Map<K, V> > {
-  static inline void Check(Node* sptr) {
-    // use dynamic RTTI for safety
-    CHECK(sptr != nullptr && sptr->is_type<MapNode>())
-        << "wrong type specified, expected Map";
+  static inline bool Check(Node* sptr) {
+    if (sptr == nullptr) return false;
+    if (!sptr->is_type<MapNode>()) return false;
     MapNode* n = static_cast<MapNode*>(sptr);
     for (const auto& kv : n->data) {
-      NodeTypeChecker<K>::Check(kv.first.get());
-      NodeTypeChecker<V>::Check(kv.second.get());
+      if (!NodeTypeChecker<K>::Check(kv.first.get())) return false;
+      if (!NodeTypeChecker<V>::Check(kv.second.get())) return false;
     }
+    return true;
+  }
+  static inline void PrintName(std::ostringstream& os) { // NOLINT(*)
+    os << "map<";
+    NodeTypeChecker<K>::PrintName(os);
+    os << ',';
+    NodeTypeChecker<V>::PrintName(os);
+    os << '>';
   }
 };
 
+template<typename T>
+inline std::string NodeTypeName() {
+  std::ostringstream os;
+  NodeTypeChecker<T>::PrintName(os);
+  return os.str();
+}
+
 /*! \brief Variant container for API calls */
 class APIVariantValue {
  public:
@@ -127,7 +151,8 @@ class APIVariantValue {
   inline operator T() const {
     if (type_id == kNull) return T();
     CHECK_EQ(type_id, kNodeHandle);
-    NodeTypeChecker<T>::Check(sptr.get());
+    CHECK(NodeTypeChecker<T>::Check(sptr.get()))
+        << "Did not get expected type " << NodeTypeName<T>();
     return T(sptr);
   }
   inline operator Expr() const {
@@ -140,7 +165,7 @@ class APIVariantValue {
     if (sptr->is_type<IterVarNode>()) {
       return IterVar(sptr)->var;
     } else {
-      CHECK(dynamic_cast<typename Expr::ContainerType*>(sptr.get()))
+      CHECK(NodeTypeChecker<Expr>::Check(sptr.get()))
           << "did not pass in Expr in a place need Expr";
       return Expr(sptr);
     }
-- 
GitLab