Commit 1b2e5ced by Tianqi Chen Committed by GitHub

[RUNTIME] Remove parameter def from runtime (#486)

parent b18143e5
Subproject commit a384fb9ed09d0c430c468db91abb3694deb88e54
Subproject commit 04f91953ace74aced3bb317990515304c5425849
......@@ -156,6 +156,37 @@ class GraphRuntime : public ModuleNode {
// control deps
std::vector<uint32_t> control_deps;
// JSON Loader
void LoadAttrs(dmlc::JSONReader *reader, TVMOpParam* param) {
int bitmask = 0;
std::string key, value;
reader->BeginObject();
while (reader->NextObjectItem(&key)) {
if (key == "func_name") {
reader->Read(&value);
param->func_name = value;
bitmask |= 1;
} else if (key == "num_inputs") {
reader->Read(&value);
std::istringstream is(value);
is >> param->num_inputs;
bitmask |= 2;
} else if (key == "num_outputs") {
reader->Read(&value);
std::istringstream is(value);
is >> param->num_outputs;
bitmask |= 4;
} else if (key == "flatten_data") {
reader->Read(&value);
std::istringstream is(value);
is >> param->flatten_data;
bitmask |= 8;
} else {
reader->Read(&value);
}
}
CHECK_EQ(bitmask, 1|2|4|8) << "invalid format";
}
// JSON Loader
void Load(dmlc::JSONReader *reader) {
reader->BeginObject();
std::unordered_map<std::string, std::string> dict;
......@@ -172,8 +203,7 @@ class GraphRuntime : public ModuleNode {
reader->Read(&inputs);
bitmask |= 4;
} else if (key == "attr" || key == "attrs") {
reader->Read(&dict);
param.Init(dict);
this->LoadAttrs(reader, &param);
} else if (key == "control_deps") {
reader->Read(&control_deps);
} else {
......@@ -263,6 +293,8 @@ class GraphRuntime : public ModuleNode {
} else if (key == "attrs") {
reader->Read(&attrs_);
bitmask |= 16;
} else {
LOG(FATAL) << "key " << key << " is not supported";
}
}
CHECK_EQ(bitmask, 1|2|4|8|16) << "invalid format";
......@@ -320,7 +352,6 @@ class GraphRuntime : public ModuleNode {
std::vector<std::function<void()> > op_execs_;
};
DMLC_REGISTER_PARAMETER(TVMOpParam);
bool GraphRuntime::LoadDLTensor(dmlc::Stream* strm, DLTensor* tensor) {
uint64_t header, reserved;
......
......@@ -8,7 +8,6 @@
#ifndef TVM_RUNTIME_GRAPH_GRAPH_RUNTIME_H_
#define TVM_RUNTIME_GRAPH_GRAPH_RUNTIME_H_
#include <dmlc/parameter.h>
#include <string>
namespace tvm {
......@@ -20,18 +19,11 @@ constexpr uint64_t kTVMNDArrayMagic = 0xDD5E40F096B4A13F;
constexpr uint64_t kTVMNDArrayListMagic = 0xF7E58D4F05049CB7;
/*! \brief operator attributes about tvm op */
struct TVMOpParam : public dmlc::Parameter<TVMOpParam> {
struct TVMOpParam {
std::string func_name;
uint32_t num_inputs;
uint32_t num_outputs;
uint32_t flatten_data;
DMLC_DECLARE_PARAMETER(TVMOpParam) {
DMLC_DECLARE_FIELD(func_name);
DMLC_DECLARE_FIELD(num_inputs).set_default(1);
DMLC_DECLARE_FIELD(num_outputs).set_default(1);
DMLC_DECLARE_FIELD(flatten_data).set_default(0);
}
};
} // namespace runtime
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment