Commit a3a9dbeb by Tianqi Chen

Make compiler more robust (#378)

parent 860adec8
@@ -279,7 +279,9 @@ def _run_graph(graph, params):
     graph, libmod, _ = build(graph, target, shape, dtype)
     m = graph_runtime.create(graph, libmod, ctx)
     set_input, run, get_output = m["set_input"], m["run"], m["get_output"]
+    kset = set(graph.symbol.list_input_names())
     for k, v in params.items():
-        set_input(k, tvm.nd.array(v))
+        if k in kset:
+            set_input(k, tvm.nd.array(v))
     run()
     out_data = []
@@ -135,7 +135,8 @@ class CompileEngine {
   }
   // get schedule and its args
-  std::pair<Schedule, Array<tvm::Tensor> > GetScheduleArgs(Graph graph,
+  std::tuple<Schedule, Array<tvm::Tensor>, Graph>
+  GetScheduleArgs(Graph graph,
                   const Array<tvm::Tensor> &inputs,
                   const std::string &target,
                   int master_idx,
@@ -221,12 +222,14 @@ class CompileEngine {
         idx[master_idx].source->attrs, outs, target);
     // store extra return values
-    if (readable_name != nullptr)
+    if (readable_name != nullptr) {
       *readable_name = readable_name_os.str();
-    if (outputs != nullptr)
+    }
+    if (outputs != nullptr) {
       *outputs = outs;
-    return std::make_pair(sch, all_args);
+    }
+    return std::make_tuple(sch, all_args, graph);
   }
   // run the actual lowering process
@@ -239,7 +242,8 @@ class CompileEngine {
     Array<tvm::Tensor> outputs;
     Schedule sch;
-    std::tie(sch, all_args) = GetScheduleArgs(graph, inputs, target, master_idx,
-                                              &readable_name, &outputs);
+    std::tie(sch, all_args, graph) = GetScheduleArgs(
+        graph, inputs, target, master_idx,
+        &readable_name, &outputs);
     std::shared_ptr<GraphFuncNode> gf = std::make_shared<GraphFuncNode>();
@@ -335,7 +339,8 @@ TVM_REGISTER_GLOBAL("nnvm.compiler.CacheItem2ScheduleArgs")
     Schedule sch;
     Array<tvm::Tensor> all_args;
-    std::tie(sch, all_args) = CompileEngine::Global()->GetScheduleArgs(
-        graph, inputs, target, master_idx, nullptr, nullptr);
+    std::tie(sch, all_args, graph) =
+        CompileEngine::Global()->GetScheduleArgs(
+            graph, inputs, target, master_idx, nullptr, nullptr);
     Array<tvm::NodeRef> ret;
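
The hunks above widen GetScheduleArgs from a pair to a three-element tuple so that the (possibly updated) Graph is handed back to the caller rather than dropped. The following is a minimal, self-contained sketch of that caller-side std::tie pattern; the struct and alias names below are placeholders standing in for nnvm::Graph, tvm::Schedule and Array<tvm::Tensor>, not the real NNVM/TVM types.

#include <string>
#include <tuple>
#include <vector>

// Placeholder types for illustration only.
struct Graph {};
struct Schedule {};
using TensorArray = std::vector<int>;

// Mirrors the shape of the new signature: the graph is returned alongside
// the schedule and its arguments instead of being discarded.
std::tuple<Schedule, TensorArray, Graph>
GetScheduleArgs(Graph graph, const TensorArray& inputs,
                const std::string& target, int master_idx,
                std::string* readable_name, TensorArray* outputs) {
  (void)target;
  (void)master_idx;
  Schedule sch;
  TensorArray all_args = inputs;
  // Optional outputs are only written when the caller asks for them,
  // matching the braced null checks in the diff.
  if (readable_name != nullptr) {
    *readable_name = "fuse_example";
  }
  if (outputs != nullptr) {
    *outputs = inputs;
  }
  return std::make_tuple(sch, all_args, graph);
}

int main() {
  Graph graph;
  TensorArray inputs{1, 2, 3};
  Schedule sch;
  TensorArray all_args;
  std::string readable_name;
  TensorArray outputs;
  // Caller-side pattern from the diff: std::tie now also rebinds graph,
  // so any graph-level changes made while building the schedule are kept.
  std::tie(sch, all_args, graph) = GetScheduleArgs(
      graph, inputs, "llvm", 0, &readable_name, &outputs);
  return 0;
}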