diff --git a/src/lang/attrs.cc b/src/lang/attrs.cc index 0d8d1f3c9ece1163e44a3be766204c437b997375..360c5b2e983342b4262eb36141638ad371249cda 100644 --- a/src/lang/attrs.cc +++ b/src/lang/attrs.cc @@ -42,4 +42,6 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) TVM_REGISTER_NODE_TYPE(DictAttrsNode); +TVM_REGISTER_NODE_TYPE(AttrFieldInfoNode); + } // namespace tvm diff --git a/src/lang/reflection.cc b/src/lang/reflection.cc index 93e2defd5aef2f32811c039319d83ada3f86b197..a33594107a69dee22bbb32564117c63109cced8d 100644 --- a/src/lang/reflection.cc +++ b/src/lang/reflection.cc @@ -236,6 +236,11 @@ class JSONAttrGetter : public AttrVisitor { node_index_->at(kv.second.get())); } } else { + // do not need to recover content of global singleton object + // they are registered via the environment + auto* f = dmlc::Registry<NodeFactoryReg>::Find(node->type_key()); + if (f != nullptr && f->fglobal_key != nullptr) return; + // recursively index normal object. node->VisitAttrs(this); } } diff --git a/tests/python/unittest/test_lang_reflection.py b/tests/python/unittest/test_lang_reflection.py index 2ba67b8d9c864f4b05a547fa5db8b90c798b403d..9678fff8ef9b7d1f23d6c4ef36025dbb37eb68ba 100644 --- a/tests/python/unittest/test_lang_reflection.py +++ b/tests/python/unittest/test_lang_reflection.py @@ -58,7 +58,8 @@ def test_make_attrs(): dattr = tvm.make.node("DictAttrs", x=1, y=10, name="xyz", padding=(0,0)) assert dattr.x.value == 1 - + datrr = tvm.load_json(tvm.save_json(dattr)) + assert dattr.name.value == "xyz" def test_make_sum():