Commit b131d836 by Bing Xu Committed by Jared Roesch

Relay C++ Build Module (#3082)

* [Relay] C++ Build module

* asdf
parent 472c3146
...@@ -344,6 +344,19 @@ TVM_DLL Array<LoweredFunc> lower(Schedule sch, ...@@ -344,6 +344,19 @@ TVM_DLL Array<LoweredFunc> lower(Schedule sch,
const std::string& name, const std::string& name,
const std::unordered_map<Tensor, Buffer>& binds, const std::unordered_map<Tensor, Buffer>& binds,
const BuildConfig& config); const BuildConfig& config);
/*!
* \brief Split host/device function and running necessary pass before build
* \param funcs The functions to be built.
* \param target The target device to build for.
* \param target_host The target for building host code. To use the default, pass Target()
* \param config The build configuration.
* \return The Array<Array<LoweredFunc>> with 2 elements. First is host function Array,
second is device function array
*/
TVM_DLL Array<Array<LoweredFunc> > split_dev_host_funcs(const Array<LoweredFunc>& funcs,
const Target& target,
const Target& target_host,
const BuildConfig& config);
/*! /*!
* \brief Build a device and host module for a specific target from an array of lowered functions. * \brief Build a device and host module for a specific target from an array of lowered functions.
......
...@@ -423,7 +423,7 @@ Array<LoweredFunc> lower(Schedule sch, ...@@ -423,7 +423,7 @@ Array<LoweredFunc> lower(Schedule sch,
return Array<LoweredFunc>({ ir::MakeAPI(stmt, name, out_arg_list, 0, config->restricted_func) }); return Array<LoweredFunc>({ ir::MakeAPI(stmt, name, out_arg_list, 0, config->restricted_func) });
} }
runtime::Module build(const Array<LoweredFunc>& funcs, Array<Array<LoweredFunc> > split_dev_host_funcs(const Array<LoweredFunc>& funcs,
const Target& target, const Target& target,
const Target& target_host, const Target& target_host,
const BuildConfig& config) { const BuildConfig& config) {
...@@ -493,6 +493,17 @@ runtime::Module build(const Array<LoweredFunc>& funcs, ...@@ -493,6 +493,17 @@ runtime::Module build(const Array<LoweredFunc>& funcs,
func = ir::CombineContextCall(func); func = ir::CombineContextCall(func);
fhost.Set(i, func); fhost.Set(i, func);
} }
return {fhost, fdevice};
}
runtime::Module build(const Array<LoweredFunc>& funcs,
const Target& target,
const Target& target_host,
const BuildConfig& config) {
auto target_host_val = target_host.defined() ? target_host : DefaultTargetHost(target);
auto host_dev_funcs = split_dev_host_funcs(funcs, target, target_host, config);
auto& fhost = host_dev_funcs[0];
auto& fdevice = host_dev_funcs[1];
auto mhost = codegen::Build(fhost, target_host_val->str()); auto mhost = codegen::Build(fhost, target_host_val->str());
......
...@@ -371,7 +371,9 @@ class CompileEngineImpl : public CompileEngineNode { ...@@ -371,7 +371,9 @@ class CompileEngineImpl : public CompileEngineNode {
cache_node->funcs = (*f)( cache_node->funcs = (*f)(
spair.first, all_args, cache_node->func_name, key->source_func); spair.first, all_args, cache_node->func_name, key->source_func);
} else { } else {
LOG(FATAL) << "relay.backend.lower is not registred"; tvm::BuildConfig bcfg = tvm::build_config();
std::unordered_map<Tensor, Buffer> binds;
cache_node->funcs = tvm::lower(spair.first, all_args, cache_node->func_name, binds, bcfg);
} }
value->cached_func = CachedFunc(cache_node); value->cached_func = CachedFunc(cache_node);
return value; return value;
......
...@@ -416,7 +416,12 @@ class GraphRuntimeCodegen ...@@ -416,7 +416,12 @@ class GraphRuntimeCodegen
} else { } else {
// heterogeneous execution. // heterogeneous execution.
const auto call_dev_key = std::to_string(call_dev_type); const auto call_dev_key = std::to_string(call_dev_type);
const auto call_dev_name = runtime::DeviceName(call_dev_type); std::string call_dev_name;
if (call_dev_type == 0) {
call_dev_name = "llvm";
} else {
call_dev_name = runtime::DeviceName(call_dev_type);
}
if (targets_.count(call_dev_name) == 0 && targets_.count(call_dev_key) == 0) { if (targets_.count(call_dev_name) == 0 && targets_.count(call_dev_key) == 0) {
LOG(FATAL) << "No target is provided for device " LOG(FATAL) << "No target is provided for device "
<< call_dev_name; << call_dev_name;
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include <gtest/gtest.h>
#include <tvm/tvm.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/type.h>
#include <tvm/relay/pass.h>
#include <topi/generic/injective.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/module.h>
#include <tvm/runtime/registry.h>
#include <tvm/packed_func_ext.h>
TVM_REGISTER_GLOBAL("test.sch")
.set_body([](tvm::TVMArgs args, tvm::TVMRetValue *rv) {
*rv = topi::generic::schedule_injective(args[0], args[1]);
});
TEST(Relay, BuildModule) {
using namespace tvm;
auto tensor_type = relay::TensorTypeNode::make({2, 3}, ::tvm::Float(32));
auto a = relay::VarNode::make("a", tensor_type);
auto b = relay::VarNode::make("b", tensor_type);
auto add_op = relay::Op::Get("add");
auto x = relay::CallNode::make(add_op, {a, b}, tvm::Attrs(), {});
auto c = relay::VarNode::make("c", tensor_type);
auto y = relay::CallNode::make(add_op, {x, c}, tvm::Attrs(), {});
auto func = relay::FunctionNode::make(relay::FreeVars(y), y, relay::Type(), {});
auto A = tvm::runtime::NDArray::Empty({2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0});
auto B = tvm::runtime::NDArray::Empty({2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0});
auto C = tvm::runtime::NDArray::Empty({2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0});
auto pA = (float*)A.ToDLPack()->dl_tensor.data;
auto pB = (float*)B.ToDLPack()->dl_tensor.data;
auto pC = (float*)C.ToDLPack()->dl_tensor.data;
for (int i = 0; i < 6; ++i) {
pA[i] = i;
pB[i] = i + 1;
pC[i] = i + 2;
}
// get schedule
auto reg = tvm::runtime::Registry::Get("relay.op._Register");
auto s_i = tvm::runtime::Registry::Get("test.sch");
if (!reg) {
LOG(FATAL) << "no _Register";
}
if (!s_i) {
LOG(FATAL) << "no _Register";
}
(*reg)("add", "FTVMSchedule", *s_i, 10);
// build
auto pfb = tvm::runtime::Registry::Get("relay.build_module._BuildModule");
tvm::runtime::Module build_mod = (*pfb)();
auto build_f = build_mod.GetFunction("build", false);
auto json_f = build_mod.GetFunction("get_graph_json", false);
auto mod_f = build_mod.GetFunction("get_module", false);
Array<HalideIR::Expr> target_pair;
target_pair.push_back(ir::StringImm::make("cpu"));
target_pair.push_back(ir::StringImm::make("llvm"));
build_f(func, target_pair, "llvm");
std::string json = json_f();
tvm::runtime::Module mod = mod_f();
// run
auto ctx = A->ctx;
auto pfr = tvm::runtime::Registry::Get("tvm.graph_runtime.create");
tvm::runtime::Module run_mod = (*pfr)(json, mod, (int)ctx.device_type, (int)ctx.device_id);
auto set_input_f = run_mod.GetFunction("set_input", false);
auto run_f = run_mod.GetFunction("run", false);
auto get_output_f = run_mod.GetFunction("get_output", false);
set_input_f("a", A);
set_input_f("b", B);
set_input_f("c", C);
run_f();
tvm::runtime::NDArray Y = get_output_f(0);
auto pY = (float*)Y.ToDLPack()->dl_tensor.data;
for (int i = 0; i < 6; ++i) {
CHECK_LT(fabs(pY[i] - (i + (i + 1) + (i + 2))), 1e-4);
}
}
int main(int argc, char ** argv) {
testing::InitGoogleTest(&argc, argv);
testing::FLAGS_gtest_death_test_style = "threadsafe";
return RUN_ALL_TESTS();
}
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import numpy as np
import tvm
from tvm import relay
from tvm._ffi.function import _init_api
_init_api("tvm.relay.build_module")
class BuildModule(object):
def __init__(self):
self.mod = relay.build_module._BuildModule()
self._get_graph_json = self.mod["get_graph_json"]
self._get_module = self.mod["get_module"]
self._build = self.mod["build"]
self._set_opt_level = self.mod["set_opt_level"]
self._set_params_func = self.mod["set_params"]
self._get_params_func = self.mod["get_params"]
def build(self, func, target, target_host, params):
tgts = []
for kv in target.items():
tgts.append(kv[0])
tgts.append(kv[1])
self._set_params(params)
self._build(func, tgts, target_host)
def get_json(self):
return self._get_graph_json()
def get_module(self):
return self._get_module()
def set_opt_level(self, level):
self._set_opt_level(level)
def _set_params(self, params):
inputs = {}
for name, param in params.items():
inputs[name] = relay.Constant(param)
self._set_params_func(inputs)
def get_params(self):
params = self._get_params_func()
ret = {}
for key, value in params.items():
ret[key] = value.data
return ret
def test_build():
m_bld = BuildModule()
tgt_name = "llvm"
tgt = "llvm"
ctx = tvm.cpu()
# func
a = relay.var("a", dtype="float32", shape=(16, 8))
b = relay.var("b", dtype="float32", shape=(8, 8))
c = relay.var("c", dtype="float32", shape=(16, 8))
x = relay.nn.dense(a, b)
y = relay.nn.relu(x)
z = y + c
func = relay.Function([a, b, c], z)
A = tvm.nd.array(np.random.uniform(-1, 1, (16, 8)).astype("float32"), ctx=ctx)
B = tvm.nd.array(np.random.uniform(-1, 1, (8, 8)).astype("float32"), ctx=ctx)
C = tvm.nd.array(np.random.uniform(-1, 1, (16, 8)).astype("float32"), ctx=ctx)
params = {
"b" : B,
"c" : C
}
# build
targets = {
tgt: tgt
}
m_bld.set_opt_level(3)
m_bld.build(func, targets, "llvm -mcpu=sse3", params=params)
g_json = m_bld.get_json()
mmod = m_bld.get_module()
params = m_bld.get_params()
# test
rt = tvm.contrib.graph_runtime.create(g_json, mmod, ctx)
rt.set_input("a", A)
rt.load_params(relay.save_param_dict(params))
rt.run()
out = rt.get_output(0)
np.testing.assert_allclose(out.asnumpy(),
np.maximum(np.dot(A.asnumpy(), B.asnumpy().T), 0) + C.asnumpy(), atol=1e-5, rtol=1e-5)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment