diff --git a/dmlc-core b/dmlc-core index a384fb9ed09d0c430c468db91abb3694deb88e54..04f91953ace74aced3bb317990515304c5425849 160000 --- a/dmlc-core +++ b/dmlc-core @@ -1 +1 @@ -Subproject commit a384fb9ed09d0c430c468db91abb3694deb88e54 +Subproject commit 04f91953ace74aced3bb317990515304c5425849 diff --git a/src/runtime/graph/graph_runtime.cc b/src/runtime/graph/graph_runtime.cc index d1160f5237735f87f350dc78499c81e5d007a432..f305367d4a2d82a27a7026c2a0086ce810c23450 100644 --- a/src/runtime/graph/graph_runtime.cc +++ b/src/runtime/graph/graph_runtime.cc @@ -156,6 +156,37 @@ class GraphRuntime : public ModuleNode { // control deps std::vector<uint32_t> control_deps; // JSON Loader + void LoadAttrs(dmlc::JSONReader *reader, TVMOpParam* param) { + int bitmask = 0; + std::string key, value; + reader->BeginObject(); + while (reader->NextObjectItem(&key)) { + if (key == "func_name") { + reader->Read(&value); + param->func_name = value; + bitmask |= 1; + } else if (key == "num_inputs") { + reader->Read(&value); + std::istringstream is(value); + is >> param->num_inputs; + bitmask |= 2; + } else if (key == "num_outputs") { + reader->Read(&value); + std::istringstream is(value); + is >> param->num_outputs; + bitmask |= 4; + } else if (key == "flatten_data") { + reader->Read(&value); + std::istringstream is(value); + is >> param->flatten_data; + bitmask |= 8; + } else { + reader->Read(&value); + } + } + CHECK_EQ(bitmask, 1|2|4|8) << "invalid format"; + } + // JSON Loader void Load(dmlc::JSONReader *reader) { reader->BeginObject(); std::unordered_map<std::string, std::string> dict; @@ -172,8 +203,7 @@ class GraphRuntime : public ModuleNode { reader->Read(&inputs); bitmask |= 4; } else if (key == "attr" || key == "attrs") { - reader->Read(&dict); - param.Init(dict); + this->LoadAttrs(reader, ¶m); } else if (key == "control_deps") { reader->Read(&control_deps); } else { @@ -263,6 +293,8 @@ class GraphRuntime : public ModuleNode { } else if (key == "attrs") { reader->Read(&attrs_); bitmask |= 16; + } else { + LOG(FATAL) << "key " << key << " is not supported"; } } CHECK_EQ(bitmask, 1|2|4|8|16) << "invalid format"; @@ -320,7 +352,6 @@ class GraphRuntime : public ModuleNode { std::vector<std::function<void()> > op_execs_; }; -DMLC_REGISTER_PARAMETER(TVMOpParam); bool GraphRuntime::LoadDLTensor(dmlc::Stream* strm, DLTensor* tensor) { uint64_t header, reserved; diff --git a/src/runtime/graph/graph_runtime.h b/src/runtime/graph/graph_runtime.h index 2df909f9be48537e0c0cc9e6626abcefc0d04006..8e2590dc6359bbf22104b338b7829bab4972453b 100644 --- a/src/runtime/graph/graph_runtime.h +++ b/src/runtime/graph/graph_runtime.h @@ -8,7 +8,6 @@ #ifndef TVM_RUNTIME_GRAPH_GRAPH_RUNTIME_H_ #define TVM_RUNTIME_GRAPH_GRAPH_RUNTIME_H_ -#include <dmlc/parameter.h> #include <string> namespace tvm { @@ -20,18 +19,11 @@ constexpr uint64_t kTVMNDArrayMagic = 0xDD5E40F096B4A13F; constexpr uint64_t kTVMNDArrayListMagic = 0xF7E58D4F05049CB7; /*! \brief operator attributes about tvm op */ -struct TVMOpParam : public dmlc::Parameter<TVMOpParam> { +struct TVMOpParam { std::string func_name; uint32_t num_inputs; uint32_t num_outputs; uint32_t flatten_data; - - DMLC_DECLARE_PARAMETER(TVMOpParam) { - DMLC_DECLARE_FIELD(func_name); - DMLC_DECLARE_FIELD(num_inputs).set_default(1); - DMLC_DECLARE_FIELD(num_outputs).set_default(1); - DMLC_DECLARE_FIELD(flatten_data).set_default(0); - } }; } // namespace runtime