Commit d51119e6 by Zhi Committed by Haichen Shen

vm external codegen (#4544)

parent bc5367a0
...@@ -476,24 +476,32 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> { ...@@ -476,24 +476,32 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
argument_registers.push_back(reg->second); argument_registers.push_back(reg->second);
} }
// Next generate the invoke instruction.
Target target; Target target;
if (!func->UseDefaultCompiler()) {
target = tvm::target::ext_dev();
} else {
// Next generate the invoke instruction.
if (targets_.size() == 1) { if (targets_.size() == 1) {
// homogeneous execution. // homogeneous execution.
for (auto kv : targets_) { const auto& it = targets_.begin();
target = kv.second; target = (*it).second;
}
} else { } else {
// heterogeneous execution. // heterogeneous execution.
LOG(FATAL) << "Currently VM compiler doesn't support heterogeneous compilation"; LOG(FATAL) << "Currently VM compiler doesn't support heterogeneous compilation";
} }
}
auto key = CCacheKeyNode::make(func, target); auto key = CCacheKeyNode::make(func, target);
auto cfunc = engine_->Lower(key); auto cfunc = engine_->Lower(key);
auto op_index = -1;
if (!func->UseDefaultCompiler()) {
op_index = context_->cached_funcs.size();
context_->cached_funcs.push_back(cfunc);
} else {
// TODO(jroesch): support lowered funcs for multiple targets // TODO(jroesch): support lowered funcs for multiple targets
CHECK_EQ(cfunc->funcs.size(), 1); CHECK_EQ(cfunc->funcs.size(), 1);
auto op_index = -1;
if (context_->seen_funcs.find(cfunc->funcs[0]) == context_->seen_funcs.end()) { if (context_->seen_funcs.find(cfunc->funcs[0]) == context_->seen_funcs.end()) {
op_index = context_->cached_funcs.size(); op_index = context_->cached_funcs.size();
context_->cached_funcs.push_back(cfunc); context_->cached_funcs.push_back(cfunc);
...@@ -501,6 +509,7 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> { ...@@ -501,6 +509,7 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
} else { } else {
op_index = context_->seen_funcs[cfunc->funcs[0]]; op_index = context_->seen_funcs[cfunc->funcs[0]];
} }
}
Emit(Instruction::InvokePacked(op_index, Emit(Instruction::InvokePacked(op_index,
argument_registers.size(), argument_registers.size(),
...@@ -950,33 +959,47 @@ void VMCompiler::LibraryCodegen() { ...@@ -950,33 +959,47 @@ void VMCompiler::LibraryCodegen() {
if (cached_funcs.size() == 0) { if (cached_funcs.size() == 0) {
return; return;
} }
std::unordered_map<std::string, Array<LoweredFunc>> tgt_funcs; std::unordered_map<std::string, Array<LoweredFunc>> funcs;
for (auto &cfunc : cached_funcs) { for (auto& cfunc : cached_funcs) {
std::string target_str = cfunc->target->str(); std::string target_str = cfunc->target->str();
if (tgt_funcs.count(target_str) == 0) { if (target_str == "ext_dev") {
tgt_funcs.emplace(target_str, Array<LoweredFunc>{cfunc->funcs[0]}); continue;
} else if (funcs.count(target_str) == 0) {
funcs.emplace(target_str, Array<LoweredFunc>{cfunc->funcs[0]});
} else { } else {
tgt_funcs[target_str].push_back(cfunc->funcs[0]); funcs[target_str].push_back(cfunc->funcs[0]);
}
} }
Map<Target, Array<LoweredFunc>> funcs;
for (auto &it : tgt_funcs) {
funcs.Set(Target::Create(it.first), it.second);
} }
if (const auto *f = runtime::Registry::Get("relay.backend.build")) { auto compile_engine = CompileEngine::Global();
// The target is just a dummy arg because funcs already contains corresponding target auto ext_mods = compile_engine->LowerExternalFunctions();
// therefore target won't be used in the build function runtime::Module mod;
runtime::Module mod = (*f)(funcs, Target(), target_host_); if (funcs.size() > 0) {
mod = tvm::build(funcs, target_host_, tvm::BuildConfig::Current());
CHECK(mod.operator->()); CHECK(mod.operator->());
exec_->lib = mod;
} else { } else {
LOG(FATAL) << "relay.backend.build is not registered"; CHECK_EQ(ext_mods.size(), 1U)
<< "Expect to have a TVM DSOModule when multiple runtime modules exist";
}
if (!ext_mods.empty()) {
if (funcs.size() == 0) {
mod = ext_mods[0];
} else {
// Import all external runtime modules.
for (auto it : ext_mods) {
mod.Import(it);
} }
}
}
exec_->lib = mod;
size_t primitive_index = 0; size_t primitive_index = 0;
for (auto cfunc : cached_funcs) { for (auto cfunc : cached_funcs) {
if (cfunc->target->str() == "ext_dev") {
exec_->primitive_map.insert({cfunc->func_name, primitive_index++});
} else {
exec_->primitive_map.insert({cfunc->funcs[0]->name, primitive_index++}); exec_->primitive_map.insert({cfunc->funcs[0]->name, primitive_index++});
} }
}
} }
runtime::Module CreateVMCompiler() { runtime::Module CreateVMCompiler() {
......
...@@ -800,7 +800,9 @@ void VirtualMachine::LoadExecutable(const Executable* exec) { ...@@ -800,7 +800,9 @@ void VirtualMachine::LoadExecutable(const Executable* exec) {
if (packed_funcs_.size() <= packed_index) { if (packed_funcs_.size() <= packed_index) {
packed_funcs_.resize(packed_index + 1); packed_funcs_.resize(packed_index + 1);
} }
packed_funcs_[packed_index] = lib.GetFunction(packed_name); tvm::runtime::PackedFunc pf = lib.GetFunction(packed_name, true);
CHECK(pf != nullptr) << "Cannot find function in module: " << packed_name;
packed_funcs_[packed_index] = pf;
} }
} }
......
...@@ -26,13 +26,13 @@ import tvm.relay.transform ...@@ -26,13 +26,13 @@ import tvm.relay.transform
from tvm import relay from tvm import relay
from tvm.contrib import util from tvm.contrib import util
def check_result(mod, map_inputs, out_shape, result, tol=1e-5): def check_result(mod, map_inputs, out_shape, result, tol=1e-5, target="llvm",
ctx=tvm.cpu()):
if sys.platform == "win32": if sys.platform == "win32":
print("Skip test on Windows for now") print("Skip test on Windows for now")
return return
with relay.build_config(opt_level=3, disabled_pass=["AlterOpLayout"]): def update_lib(lib):
json, lib, _ = relay.build(mod, "llvm")
test_dir = os.path.dirname(os.path.realpath(os.path.expanduser(__file__))) test_dir = os.path.dirname(os.path.realpath(os.path.expanduser(__file__)))
source_dir = os.path.join(test_dir, "..", "..", "..") source_dir = os.path.join(test_dir, "..", "..", "..")
contrib_path = os.path.join(source_dir, "src", "runtime", "contrib") contrib_path = os.path.join(source_dir, "src", "runtime", "contrib")
...@@ -45,18 +45,36 @@ def check_result(mod, map_inputs, out_shape, result, tol=1e-5): ...@@ -45,18 +45,36 @@ def check_result(mod, map_inputs, out_shape, result, tol=1e-5):
lib.export_library(lib_path, fcompile=False, **kwargs) lib.export_library(lib_path, fcompile=False, **kwargs)
lib = tvm.module.load(lib_path) lib = tvm.module.load(lib_path)
ctx = tvm.cpu() return lib
def check_vm_result():
with relay.build_config(opt_level=3, disabled_pass=["AlterOpLayout"]):
exe = relay.vm.compile(mod, target=target)
code, lib = exe.save()
lib = update_lib(lib)
exe = relay.vm.Executable.load_exec(code, lib)
vm = relay.vm.VirtualMachine(exe)
vm.init(ctx)
out = vm.run(**map_inputs)
tvm.testing.assert_allclose(out.asnumpy(), result, rtol=tol, atol=tol)
def check_graph_runtime_result():
with relay.build_config(opt_level=3, disabled_pass=["AlterOpLayout"]):
json, lib, _ = relay.build(mod, target=target)
lib = update_lib(lib)
rt_mod = tvm.contrib.graph_runtime.create(json, lib, ctx) rt_mod = tvm.contrib.graph_runtime.create(json, lib, ctx)
for name, data in map_inputs.items(): for name, data in map_inputs.items():
rt_mod.set_input(name, data) rt_mod.set_input(name, data)
rt_mod.run() rt_mod.run()
out = tvm.nd.empty(out_shape, ctx=ctx) out = tvm.nd.empty(out_shape, ctx=ctx)
out = rt_mod.get_output(0, out) out = rt_mod.get_output(0, out)
tvm.testing.assert_allclose(out.asnumpy(), result, rtol=tol, atol=tol) tvm.testing.assert_allclose(out.asnumpy(), result, rtol=tol, atol=tol)
check_vm_result()
check_graph_runtime_result()
def set_external_func_attr(func, compiler, ext_symbol): def set_external_func_attr(func, compiler, ext_symbol):
func = func.set_attribute("Primitive", tvm.expr.IntImm("int32", 1)) func = func.set_attribute("Primitive", tvm.expr.IntImm("int32", 1))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment