Commit 6196cd50 by ziheng Committed by Tianqi Chen

[EXECUTOR] Move tvm_op and Handler<DLTensor> to graph_executor.cc (#259)

parent eaea99c5
...@@ -399,6 +399,55 @@ FOpExec GraphExecutor::CreateTVMOp(const nnvm::NodeAttrs& attrs, ...@@ -399,6 +399,55 @@ FOpExec GraphExecutor::CreateTVMOp(const nnvm::NodeAttrs& attrs,
return fexec; return fexec;
} }
struct TVMOpParam : public dmlc::Parameter<TVMOpParam> {
std::string func_name;
uint32_t num_inputs;
uint32_t num_outputs;
bool flatten_data;
DMLC_DECLARE_PARAMETER(TVMOpParam) {
DMLC_DECLARE_FIELD(func_name);
DMLC_DECLARE_FIELD(num_inputs)
.set_default(1);
DMLC_DECLARE_FIELD(num_outputs)
.set_default(1);
DMLC_DECLARE_FIELD(flatten_data)
.set_default(false);
}
};
DMLC_REGISTER_PARAMETER(TVMOpParam);
/*! \brief Parse keyword arguments as PType arguments and save to parsed */
template<typename PType>
inline void ParamParser(nnvm::NodeAttrs* attrs) {
PType param;
try {
param.Init(attrs->dict);
} catch (const dmlc::ParamError& e) {
std::ostringstream os;
os << e.what();
os << ", in operator " << attrs->op->name << "("
<< "name=\"" << attrs->name << "\"";
for (const auto& k : attrs->dict) {
os << ", " << k.first << "=\"" << k.second << "\"";
}
os << ")";
throw dmlc::ParamError(os.str());
}
attrs->parsed = std::move(param);
}
// ewise tvm op
NNVM_REGISTER_OP(tvm_op)
.set_attr_parser(ParamParser<TVMOpParam>)
.set_num_inputs([](const NodeAttrs& attrs) {
const TVMOpParam& param = nnvm::get<TVMOpParam>(attrs.parsed);
return param.num_inputs;
})
.set_num_outputs([](const NodeAttrs& attrs) {
const TVMOpParam& param = nnvm::get<TVMOpParam>(attrs.parsed);
return param.num_outputs;
});
// Create executor // Create executor
tvm::runtime::Module CreateExecutor(nnvm::Graph g, TVMContext ctx) { tvm::runtime::Module CreateExecutor(nnvm::Graph g, TVMContext ctx) {
std::shared_ptr<GraphExecutor> exec = std::shared_ptr<GraphExecutor> exec =
...@@ -460,3 +509,27 @@ TVM_REGISTER_GLOBAL("tvm_graph._load_executor") ...@@ -460,3 +509,27 @@ TVM_REGISTER_GLOBAL("tvm_graph._load_executor")
}); });
} // namespace contrib } // namespace contrib
} // namespace tvm } // namespace tvm
namespace dmlc {
namespace json {
template<>
struct Handler<DLDataType> {
static void Write(JSONWriter *writer, const DLDataType& data) {
std::vector<int> tmp({data.code, data.bits, data.lanes});
writer->Write(tmp);
}
static void Read(JSONReader *reader, DLDataType* data) {
std::vector<int> tmp;
reader->Read(&tmp);
data->code = tmp[0];
data->bits = tmp[1];
data->lanes = tmp[2];
}
};
DMLC_JSON_ENABLE_ANY(std::vector<DLDataType>, list_dltype);
} // namespace dmlc
} // namespace json
...@@ -340,54 +340,6 @@ nnvm::Graph GraphFuse(nnvm::Graph g) { ...@@ -340,54 +340,6 @@ nnvm::Graph GraphFuse(nnvm::Graph g) {
NNVM_REGISTER_PASS(GraphFuse) NNVM_REGISTER_PASS(GraphFuse)
.set_body(GraphFuse); .set_body(GraphFuse);
struct TVMOpParam : public dmlc::Parameter<TVMOpParam> {
std::string func_name;
uint32_t num_inputs;
uint32_t num_outputs;
bool flatten_data;
DMLC_DECLARE_PARAMETER(TVMOpParam) {
DMLC_DECLARE_FIELD(func_name);
DMLC_DECLARE_FIELD(num_inputs)
.set_default(1);
DMLC_DECLARE_FIELD(num_outputs)
.set_default(1);
DMLC_DECLARE_FIELD(flatten_data)
.set_default(false);
}
};
DMLC_REGISTER_PARAMETER(TVMOpParam);
/*! \brief Parse keyword arguments as PType arguments and save to parsed */
template<typename PType>
inline void ParamParser(nnvm::NodeAttrs* attrs) {
PType param;
try {
param.Init(attrs->dict);
} catch (const dmlc::ParamError& e) {
std::ostringstream os;
os << e.what();
os << ", in operator " << attrs->op->name << "("
<< "name=\"" << attrs->name << "\"";
for (const auto& k : attrs->dict) {
os << ", " << k.first << "=\"" << k.second << "\"";
}
os << ")";
throw dmlc::ParamError(os.str());
}
attrs->parsed = std::move(param);
}
// ewise tvm op
NNVM_REGISTER_OP(tvm_op)
.set_attr_parser(ParamParser<TVMOpParam>)
.set_num_inputs([](const NodeAttrs& attrs) {
const TVMOpParam& param = nnvm::get<TVMOpParam>(attrs.parsed);
return param.num_inputs;
})
.set_num_outputs([](const NodeAttrs& attrs) {
const TVMOpParam& param = nnvm::get<TVMOpParam>(attrs.parsed);
return param.num_outputs;
});
inline bool IsIdentityLayout(const LayoutInfo& layout) { inline bool IsIdentityLayout(const LayoutInfo& layout) {
if (layout.src == "" && layout.dst == "") return true; if (layout.src == "" && layout.dst == "") return true;
...@@ -515,27 +467,3 @@ NNVM_REGISTER_OP(layout_transform) ...@@ -515,27 +467,3 @@ NNVM_REGISTER_OP(layout_transform)
.set_num_outputs(1); .set_num_outputs(1);
} // namespace contrib } // namespace contrib
} // namespace tvm } // namespace tvm
namespace dmlc {
namespace json {
template<>
struct Handler<DLDataType> {
static void Write(JSONWriter *writer, const DLDataType& data) {
std::vector<int> tmp({data.code, data.bits, data.lanes});
writer->Write(tmp);
}
static void Read(JSONReader *reader, DLDataType* data) {
std::vector<int> tmp;
reader->Read(&tmp);
data->code = tmp[0];
data->bits = tmp[1];
data->lanes = tmp[2];
}
};
DMLC_JSON_ENABLE_ANY(std::vector<DLDataType>, list_dltype);
} // namespace dmlc
} // namespace json
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment