Commit a3a9dbeb by Tianqi Chen

Make compiler more robust (#378)

parent 860adec8
......@@ -279,8 +279,10 @@ def _run_graph(graph, params):
graph, libmod, _ = build(graph, target, shape, dtype)
m = graph_runtime.create(graph, libmod, ctx)
set_input, run, get_output = m["set_input"], m["run"], m["get_output"]
kset = set(graph.symbol.list_input_names())
for k, v in params.items():
set_input(k, tvm.nd.array(v))
if k in kset:
set_input(k, tvm.nd.array(v))
run()
out_data = []
for i, kv in enumerate(zip(oshape, odtype)):
......
......@@ -135,12 +135,13 @@ class CompileEngine {
}
// get schedule and its args
std::pair<Schedule, Array<tvm::Tensor> > GetScheduleArgs(Graph graph,
const Array<tvm::Tensor> &inputs,
const std::string &target,
int master_idx,
std::string *readable_name,
Array<tvm::Tensor> *outputs) {
std::tuple<Schedule, Array<tvm::Tensor>, Graph>
GetScheduleArgs(Graph graph,
const Array<tvm::Tensor> &inputs,
const std::string &target,
int master_idx,
std::string *readable_name,
Array<tvm::Tensor> *outputs) {
// shape, type
static auto& fcompute =
nnvm::Op::GetAttr<FTVMCompute>("FTVMCompute");
......@@ -221,12 +222,14 @@ class CompileEngine {
idx[master_idx].source->attrs, outs, target);
// store extra return values
if (readable_name != nullptr)
if (readable_name != nullptr) {
*readable_name = readable_name_os.str();
if (outputs != nullptr)
}
if (outputs != nullptr) {
*outputs = outs;
}
return std::make_pair(sch, all_args);
return std::make_tuple(sch, all_args, graph);
}
// run the actual lowering process
......@@ -239,8 +242,9 @@ class CompileEngine {
Array<tvm::Tensor> outputs;
Schedule sch;
std::tie(sch, all_args) = GetScheduleArgs(graph, inputs, target, master_idx,
&readable_name, &outputs);
std::tie(sch, all_args, graph) = GetScheduleArgs(
graph, inputs, target, master_idx,
&readable_name, &outputs);
std::shared_ptr<GraphFuncNode> gf = std::make_shared<GraphFuncNode>();
gf->target = target;
......@@ -335,7 +339,8 @@ TVM_REGISTER_GLOBAL("nnvm.compiler.CacheItem2ScheduleArgs")
Schedule sch;
Array<tvm::Tensor> all_args;
std::tie(sch, all_args) = CompileEngine::Global()->GetScheduleArgs(
std::tie(sch, all_args, graph) =
CompileEngine::Global()->GetScheduleArgs(
graph, inputs, target, master_idx, nullptr, nullptr);
Array<tvm::NodeRef> ret;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment