Commit d51119e6 by Zhi Committed by Haichen Shen

vm external codegen (#4544)

parent bc5367a0
...@@ -476,30 +476,39 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> { ...@@ -476,30 +476,39 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
argument_registers.push_back(reg->second); argument_registers.push_back(reg->second);
} }
// Next generate the invoke instruction.
Target target; Target target;
if (targets_.size() == 1) {
// homogeneous execution. if (!func->UseDefaultCompiler()) {
for (auto kv : targets_) { target = tvm::target::ext_dev();
target = kv.second;
}
} else { } else {
// heterogeneous execution. // Next generate the invoke instruction.
LOG(FATAL) << "Currently VM compiler doesn't support heterogeneous compilation"; if (targets_.size() == 1) {
// homogeneous execution.
const auto& it = targets_.begin();
target = (*it).second;
} else {
// heterogeneous execution.
LOG(FATAL) << "Currently VM compiler doesn't support heterogeneous compilation";
}
} }
auto key = CCacheKeyNode::make(func, target); auto key = CCacheKeyNode::make(func, target);
auto cfunc = engine_->Lower(key); auto cfunc = engine_->Lower(key);
// TODO(jroesch): support lowered funcs for multiple targets
CHECK_EQ(cfunc->funcs.size(), 1);
auto op_index = -1; auto op_index = -1;
if (context_->seen_funcs.find(cfunc->funcs[0]) == context_->seen_funcs.end()) { if (!func->UseDefaultCompiler()) {
op_index = context_->cached_funcs.size(); op_index = context_->cached_funcs.size();
context_->cached_funcs.push_back(cfunc); context_->cached_funcs.push_back(cfunc);
context_->seen_funcs[cfunc->funcs[0]] = op_index;
} else { } else {
op_index = context_->seen_funcs[cfunc->funcs[0]]; // TODO(jroesch): support lowered funcs for multiple targets
CHECK_EQ(cfunc->funcs.size(), 1);
if (context_->seen_funcs.find(cfunc->funcs[0]) == context_->seen_funcs.end()) {
op_index = context_->cached_funcs.size();
context_->cached_funcs.push_back(cfunc);
context_->seen_funcs[cfunc->funcs[0]] = op_index;
} else {
op_index = context_->seen_funcs[cfunc->funcs[0]];
}
} }
Emit(Instruction::InvokePacked(op_index, Emit(Instruction::InvokePacked(op_index,
...@@ -950,32 +959,46 @@ void VMCompiler::LibraryCodegen() { ...@@ -950,32 +959,46 @@ void VMCompiler::LibraryCodegen() {
if (cached_funcs.size() == 0) { if (cached_funcs.size() == 0) {
return; return;
} }
std::unordered_map<std::string, Array<LoweredFunc>> tgt_funcs; std::unordered_map<std::string, Array<LoweredFunc>> funcs;
for (auto &cfunc : cached_funcs) { for (auto& cfunc : cached_funcs) {
std::string target_str = cfunc->target->str(); std::string target_str = cfunc->target->str();
if (tgt_funcs.count(target_str) == 0) { if (target_str == "ext_dev") {
tgt_funcs.emplace(target_str, Array<LoweredFunc>{cfunc->funcs[0]}); continue;
} else if (funcs.count(target_str) == 0) {
funcs.emplace(target_str, Array<LoweredFunc>{cfunc->funcs[0]});
} else { } else {
tgt_funcs[target_str].push_back(cfunc->funcs[0]); funcs[target_str].push_back(cfunc->funcs[0]);
} }
} }
Map<Target, Array<LoweredFunc>> funcs;
for (auto &it : tgt_funcs) {
funcs.Set(Target::Create(it.first), it.second);
}
if (const auto *f = runtime::Registry::Get("relay.backend.build")) { auto compile_engine = CompileEngine::Global();
// The target is just a dummy arg because funcs already contains corresponding target auto ext_mods = compile_engine->LowerExternalFunctions();
// therefore target won't be used in the build function runtime::Module mod;
runtime::Module mod = (*f)(funcs, Target(), target_host_); if (funcs.size() > 0) {
mod = tvm::build(funcs, target_host_, tvm::BuildConfig::Current());
CHECK(mod.operator->()); CHECK(mod.operator->());
exec_->lib = mod;
} else { } else {
LOG(FATAL) << "relay.backend.build is not registered"; CHECK_EQ(ext_mods.size(), 1U)
<< "Expect to have a TVM DSOModule when multiple runtime modules exist";
}
if (!ext_mods.empty()) {
if (funcs.size() == 0) {
mod = ext_mods[0];
} else {
// Import all external runtime modules.
for (auto it : ext_mods) {
mod.Import(it);
}
}
} }
exec_->lib = mod;
size_t primitive_index = 0; size_t primitive_index = 0;
for (auto cfunc : cached_funcs) { for (auto cfunc : cached_funcs) {
exec_->primitive_map.insert({cfunc->funcs[0]->name, primitive_index++}); if (cfunc->target->str() == "ext_dev") {
exec_->primitive_map.insert({cfunc->func_name, primitive_index++});
} else {
exec_->primitive_map.insert({cfunc->funcs[0]->name, primitive_index++});
}
} }
} }
......
...@@ -800,7 +800,9 @@ void VirtualMachine::LoadExecutable(const Executable* exec) { ...@@ -800,7 +800,9 @@ void VirtualMachine::LoadExecutable(const Executable* exec) {
if (packed_funcs_.size() <= packed_index) { if (packed_funcs_.size() <= packed_index) {
packed_funcs_.resize(packed_index + 1); packed_funcs_.resize(packed_index + 1);
} }
packed_funcs_[packed_index] = lib.GetFunction(packed_name); tvm::runtime::PackedFunc pf = lib.GetFunction(packed_name, true);
CHECK(pf != nullptr) << "Cannot find function in module: " << packed_name;
packed_funcs_[packed_index] = pf;
} }
} }
......
...@@ -26,36 +26,54 @@ import tvm.relay.transform ...@@ -26,36 +26,54 @@ import tvm.relay.transform
from tvm import relay from tvm import relay
from tvm.contrib import util from tvm.contrib import util
def check_result(mod, map_inputs, out_shape, result, tol=1e-5): def check_result(mod, map_inputs, out_shape, result, tol=1e-5, target="llvm",
ctx=tvm.cpu()):
if sys.platform == "win32": if sys.platform == "win32":
print("Skip test on Windows for now") print("Skip test on Windows for now")
return return
with relay.build_config(opt_level=3, disabled_pass=["AlterOpLayout"]): def update_lib(lib):
json, lib, _ = relay.build(mod, "llvm") test_dir = os.path.dirname(os.path.realpath(os.path.expanduser(__file__)))
test_dir = os.path.dirname(os.path.realpath(os.path.expanduser(__file__))) source_dir = os.path.join(test_dir, "..", "..", "..")
source_dir = os.path.join(test_dir, "..", "..", "..") contrib_path = os.path.join(source_dir, "src", "runtime", "contrib")
contrib_path = os.path.join(source_dir, "src", "runtime", "contrib")
kwargs = {}
kwargs = {} kwargs["options"] = ["-O2", "-std=c++11", "-I" + contrib_path]
kwargs["options"] = ["-O2", "-std=c++11", "-I" + contrib_path] tmp_path = util.tempdir()
tmp_path = util.tempdir() lib_name = 'lib.so'
lib_name = 'lib.so' lib_path = tmp_path.relpath(lib_name)
lib_path = tmp_path.relpath(lib_name) lib.export_library(lib_path, fcompile=False, **kwargs)
lib.export_library(lib_path, fcompile=False, **kwargs) lib = tvm.module.load(lib_path)
lib = tvm.module.load(lib_path)
return lib
ctx = tvm.cpu()
rt_mod = tvm.contrib.graph_runtime.create(json, lib, ctx) def check_vm_result():
with relay.build_config(opt_level=3, disabled_pass=["AlterOpLayout"]):
for name, data in map_inputs.items(): exe = relay.vm.compile(mod, target=target)
rt_mod.set_input(name, data) code, lib = exe.save()
lib = update_lib(lib)
rt_mod.run() exe = relay.vm.Executable.load_exec(code, lib)
out = tvm.nd.empty(out_shape, ctx=ctx) vm = relay.vm.VirtualMachine(exe)
out = rt_mod.get_output(0, out) vm.init(ctx)
out = vm.run(**map_inputs)
tvm.testing.assert_allclose(out.asnumpy(), result, rtol=tol, atol=tol) tvm.testing.assert_allclose(out.asnumpy(), result, rtol=tol, atol=tol)
def check_graph_runtime_result():
with relay.build_config(opt_level=3, disabled_pass=["AlterOpLayout"]):
json, lib, _ = relay.build(mod, target=target)
lib = update_lib(lib)
rt_mod = tvm.contrib.graph_runtime.create(json, lib, ctx)
for name, data in map_inputs.items():
rt_mod.set_input(name, data)
rt_mod.run()
out = tvm.nd.empty(out_shape, ctx=ctx)
out = rt_mod.get_output(0, out)
tvm.testing.assert_allclose(out.asnumpy(), result, rtol=tol, atol=tol)
check_vm_result()
check_graph_runtime_result()
def set_external_func_attr(func, compiler, ext_symbol): def set_external_func_attr(func, compiler, ext_symbol):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment