Commit 3279957f by hlu1 Committed by Zhi

[Relay] Invoke tvm::build from relay compile_engine and interpreter (#4723)

parent b171cf1d
...@@ -599,12 +599,13 @@ class CompileEngineImpl : public CompileEngineNode { ...@@ -599,12 +599,13 @@ class CompileEngineImpl : public CompileEngineNode {
CCacheValue value = LowerInternal(key); CCacheValue value = LowerInternal(key);
if (value->packed_func != nullptr) return value->packed_func; if (value->packed_func != nullptr) return value->packed_func;
// build the function. // build the function.
tvm::runtime::Module m;
if (const auto* f = runtime::Registry::Get("relay.backend.build")) { if (const auto* f = runtime::Registry::Get("relay.backend.build")) {
tvm::runtime::Module m = (*f)(value->cached_func->funcs, key->target); m = (*f)(value->cached_func->funcs, key->target);
value->packed_func = m.GetFunction(value->cached_func->func_name);
} else { } else {
LOG(FATAL) << "relay.backend.build is not registered"; m = build(value->cached_func->funcs, key->target, Target(nullptr), BuildConfig::Current());
} }
value->packed_func = m.GetFunction(value->cached_func->func_name);
return value->packed_func; return value->packed_func;
} }
......
...@@ -418,13 +418,14 @@ class Interpreter : ...@@ -418,13 +418,14 @@ class Interpreter :
<< "Shape function output sizes mismatch"; << "Shape function output sizes mismatch";
PackedFunc shape_func; PackedFunc shape_func;
Module m;
TVMRetValue rv; TVMRetValue rv;
if (const auto* f = runtime::Registry::Get("relay.backend.build")) { if (const auto* f = runtime::Registry::Get("relay.backend.build")) {
tvm::runtime::Module m = (*f)(cfunc->funcs, cfunc->target); m = (*f)(cfunc->funcs, cfunc->target);
shape_func = m.GetFunction(cfunc->func_name);
} else { } else {
LOG(FATAL) << "relay.backend.build is not registered"; m = build(cfunc->funcs, cfunc->target, Target(nullptr), BuildConfig::Current());
} }
shape_func = m.GetFunction(cfunc->func_name);
shape_func.CallPacked(TVMArgs(values.data(), codes.data(), arity), &rv); shape_func.CallPacked(TVMArgs(values.data(), codes.data(), arity), &rv);
// Get output shapes // Get output shapes
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment