Commit 3279957f by hlu1 Committed by Zhi

[Relay] Invoke tvm::build from relay compile_engine and interpreter (#4723)

parent b171cf1d
......@@ -599,12 +599,13 @@ class CompileEngineImpl : public CompileEngineNode {
CCacheValue value = LowerInternal(key);
if (value->packed_func != nullptr) return value->packed_func;
// build the function.
tvm::runtime::Module m;
if (const auto* f = runtime::Registry::Get("relay.backend.build")) {
tvm::runtime::Module m = (*f)(value->cached_func->funcs, key->target);
value->packed_func = m.GetFunction(value->cached_func->func_name);
m = (*f)(value->cached_func->funcs, key->target);
} else {
LOG(FATAL) << "relay.backend.build is not registered";
m = build(value->cached_func->funcs, key->target, Target(nullptr), BuildConfig::Current());
}
value->packed_func = m.GetFunction(value->cached_func->func_name);
return value->packed_func;
}
......
......@@ -418,13 +418,14 @@ class Interpreter :
<< "Shape function output sizes mismatch";
PackedFunc shape_func;
Module m;
TVMRetValue rv;
if (const auto* f = runtime::Registry::Get("relay.backend.build")) {
tvm::runtime::Module m = (*f)(cfunc->funcs, cfunc->target);
shape_func = m.GetFunction(cfunc->func_name);
m = (*f)(cfunc->funcs, cfunc->target);
} else {
LOG(FATAL) << "relay.backend.build is not registered";
m = build(cfunc->funcs, cfunc->target, Target(nullptr), BuildConfig::Current());
}
shape_func = m.GetFunction(cfunc->func_name);
shape_func.CallPacked(TVMArgs(values.data(), codes.data(), arity), &rv);
// Get output shapes
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment