Commit 8e2ea2c4 by ziheng Committed by Tianqi Chen

[APP] Improve GraphExecutor (#216)

* Remove 'final' in GraphExecutor for extension

* Dynamic num of inputs/outputs for tvm_op
parent 4bb3c35a
......@@ -36,14 +36,14 @@ using FOpExec = std::function<void()>;
}
/*! \brief Graph Executor with TVM runtime */
class GraphExecutor final : public runtime::ModuleNode {
class GraphExecutor : public runtime::ModuleNode {
public:
const char* type_key() const final {
const char* type_key() const {
return "GraphExecutor";
}
PackedFunc GetFunction(
const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) final;
const std::shared_ptr<ModuleNode>& sptr_to_self);
// Destructor
~GraphExecutor();
// Setup with a given graph
......@@ -271,10 +271,5 @@ TVM_REGISTER_GLOBAL("tvm_graph._create_executor")
nnvm::Graph g = static_cast<nnvm::Graph*>(graph_handle)[0];
*rv = CreateExecutor(g, ctx);
});
// ewise tvm op
NNVM_REGISTER_OP(tvm_op)
.set_num_inputs(-1);
} // namespace contrib
} // namespace tvm
......@@ -286,8 +286,11 @@ nnvm::Graph GraphFuse(nnvm::Graph g) {
nnvm::NodePtr np = nnvm::Node::Create();
np->attrs.op = tvm_op;
np->attrs.name = inode.source->attrs.name;
np->attrs.dict["num_inputs"] = std::to_string(fe.inputs.size());
np->attrs.dict["num_outputs"] = std::to_string(fe.outputs.size());
np->attrs.dict["func_name"] = fuse_vec[nid].func_name;
np->attrs.dict["flatten_data"] = std::to_string(pattern_vec[nid] == kElemWise);
np->op()->attr_parser(&(np->attrs));
for (const auto& e : fe.inputs) {
auto it = old_new.find(e.node_id);
CHECK(it != old_new.end())
......@@ -336,5 +339,53 @@ nnvm::Graph GraphFuse(nnvm::Graph g) {
NNVM_REGISTER_PASS(GraphFuse)
.set_body(GraphFuse);
struct TVMOpParam : public dmlc::Parameter<TVMOpParam> {
std::string func_name;
uint32_t num_inputs;
uint32_t num_outputs;
bool flatten_data;
DMLC_DECLARE_PARAMETER(TVMOpParam) {
DMLC_DECLARE_FIELD(func_name);
DMLC_DECLARE_FIELD(num_inputs)
.set_default(1);
DMLC_DECLARE_FIELD(num_outputs)
.set_default(1);
DMLC_DECLARE_FIELD(flatten_data)
.set_default(false);
}
};
DMLC_REGISTER_PARAMETER(TVMOpParam);
/*! \brief Parse keyword arguments as PType arguments and save to parsed */
template<typename PType>
inline void ParamParser(nnvm::NodeAttrs* attrs) {
PType param;
try {
param.Init(attrs->dict);
} catch (const dmlc::ParamError& e) {
std::ostringstream os;
os << e.what();
os << ", in operator " << attrs->op->name << "("
<< "name=\"" << attrs->name << "\"";
for (const auto& k : attrs->dict) {
os << ", " << k.first << "=\"" << k.second << "\"";
}
os << ")";
throw dmlc::ParamError(os.str());
}
attrs->parsed = std::move(param);
}
// ewise tvm op
NNVM_REGISTER_OP(tvm_op)
.set_attr_parser(ParamParser<TVMOpParam>)
.set_num_inputs([](const NodeAttrs& attrs) {
const TVMOpParam& param = nnvm::get<TVMOpParam>(attrs.parsed);
return param.num_inputs;
})
.set_num_outputs([](const NodeAttrs& attrs) {
const TVMOpParam& param = nnvm::get<TVMOpParam>(attrs.parsed);
return param.num_outputs;
});
} // namespace contrib
} // namespace tvm
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment