Commit 0f2a3086 by Zhi, committed by Tianqi Chen

[Relay][Compilation] replace relay.build_module with C++ BuildModule (#3174)

parent 7d845f0d
@@ -25,7 +25,7 @@ from . import expr_functor
from . import module
from . import adt
from . import ir_pass
from .build_module import build, build_config, create_executor, optimize
from .build_module import build, build_config, create_executor
from . import prelude
from . import parser
from . import debug
......
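With optimize no longer re-exported, user code reaches the new C++ BuildModule through relay.build. A minimal sketch of the post-change entry point, using a hypothetical single-op function that is not part of this diff:

from tvm import relay

# relay.build returns the graph JSON, the compiled module, and the params
# dict; the heavy lifting now happens in the C++ BuildModule.
x = relay.var("x", shape=(2, 2), dtype="float32")
func = relay.Function([x], relay.add(x, x))
graph_json, lib, params = relay.build(func, "llvm")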
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable
"""The interface for building Relay functions exposed from C++."""
from tvm._ffi.function import _init_api
_init_api("relay.build_module", __name__)
@@ -36,12 +36,9 @@ contrib.graph_runtime or any other TVM runtime compatible systems.
from __future__ import absolute_import
from tvm.ndarray import empty
from tvm._ffi.function import _init_api
from tvm.relay import build_module
from tvm import target as _target
_init_api("tvm.relay.build_module")
from tvm import expr as _expr
class GraphRuntimeCodegen(object):
"""The compiler from Relay to the TVM runtime system."""
@@ -57,17 +54,14 @@ class GraphRuntimeCodegen(object):
self._setup(mod, target)
def _setup(self, mod, target):
tgts = []
tgts = {}
if isinstance(target, dict):
for kv in target.items():
tgts.append(kv[0])
if isinstance(kv[1], (str, _target.Target)):
tgts.append(str(kv[1]))
else:
for dev, tgt in target.items():
if not isinstance(tgt, (str, _target.Target)):
raise Exception("Unknown target type")
tgts[dev] = _target.create(tgt)
elif isinstance(target, (str, _target.Target)):
tgts.append("0")
tgts.append(str(target))
tgts[_expr.IntImm("int32", 0)] = _target.create(target)
self._init(mod, tgts)
def codegen(self, func):
......
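The reworked _setup always hands _init a dict from device type to Target instead of a flat list of strings. A sketch of a heterogeneous spec under that convention; the numeric device types are the standard DLPack values (1 for kDLCPU, 2 for kDLGPU), an assumption of this example rather than something fixed by the diff:

import tvm
from tvm import target as _target

# Keys are IntImm device types, values are Target objects; _setup accepts
# either strings or Targets and normalizes both via _target.create.
targets = {
    tvm.expr.IntImm("int32", 1): _target.create("llvm"),  # 1 == kDLCPU
    tvm.expr.IntImm("int32", 2): _target.create("cuda"),  # 2 == kDLGPU
}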
@@ -269,6 +269,77 @@ def realize(graph):
return _quantize.realize(graph)
def optimize(func, params=None):
""" Perform "SimplifyInference", "FoldScaleAxis", "FoldConstant", and
"CanonicalizeOps" optimization before quantization.
# TODO(zhiics) These passes are executed one by one so far. We need to
# move them to the pass manager.
Parameters
---------
func: tvm.relay.Function
The original Relay function to be optimized.
params : dict of str to tvm.NDArray
Input parameters to the graph that do not change
during inference time. Used for constant folding.
Returns
-------
ret: tvm.relay.Function
The graph after quantization
"""
opt_passes = ["SimplifyInference",
"FoldScaleAxis",
"FoldConstant",
"CanonicalizeOps"]
cfg = _build.build_config(add_pass=opt_passes)
if params:
name_dict = {}
for arg in func.params:
name = arg.name_hint
if name in name_dict:
name_dict[name] = None
else:
name_dict[name] = arg
bind_dict = {}
for k, v in params.items():
if k not in name_dict:
continue
arg = name_dict[k]
if arg is None:
raise ValueError("Multiple args in the function have name %s" % k)
bind_dict[arg] = _expr.const(v)
func = _expr.bind(func, bind_dict)
if "SimplifyInference" in cfg.add_pass:
func = _ir_pass.infer_type(func)
func = _ir_pass.simplify_inference(func)
if "FoldConstant" in cfg.add_pass:
func = _ir_pass.fold_constant(func)
if "FoldScaleAxis" in cfg.add_pass:
func = _ir_pass.infer_type(func)
func = _ir_pass.backward_fold_scale_axis(func)
func = _ir_pass.infer_type(func)
func = _ir_pass.forward_fold_scale_axis(func)
func = _ir_pass.fold_constant(func)
if "CanonicalizeOps" in cfg.add_pass:
func = _ir_pass.infer_type(func)
func = _ir_pass.canonicalize_ops(func)
if "FoldConstant" in cfg.add_pass:
func = _ir_pass.fold_constant(func)
return func
def quantize(graph, params=None, dataset=None):
""" The quantization procedure. Before running the three main
procedures of quantization, "annotate", "calibrate" and "realize"
@@ -292,12 +363,8 @@ def quantize(graph, params=None, dataset=None):
ret: Function
The graph after quantization
"""
opt_passes = ["SimplifyInference",
"FoldScaleAxis",
"FoldConstant",
"CanonicalizeOps"]
with _build.build_config(add_pass=opt_passes):
graph = _build.optimize(graph, params=params)
# TODO(zhiics) Move this to the pass manager.
graph = optimize(graph, params)
graph = annotate(graph)
graph = calibrate(graph, dataset)
......
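quantize now invokes the local optimize directly rather than going through _build.optimize under a build_config. A sketch of that prequantization step in isolation, assuming a small hypothetical conv workload and that optimize is re-exported from the quantize package like its neighbors:

from tvm import relay
from tvm.relay import quantize as qtz

# optimize binds any supplied params as constants, then runs
# SimplifyInference, FoldScaleAxis, FoldConstant and CanonicalizeOps in order.
data = relay.var("data", shape=(1, 3, 16, 16), dtype="float32")
weight = relay.var("weight", shape=(8, 3, 3, 3), dtype="float32")
conv = relay.nn.conv2d(data, weight, kernel_size=(3, 3), padding=(1, 1))
func = relay.Function([data, weight], conv)
func = qtz.optimize(func)  # params omitted, so the binding step is skipped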
@@ -311,7 +311,7 @@ bool LLVMEnabled() {
/*! \return The default host target for a given device target */
Target DefaultTargetHost(Target target) {
if (target->device_type == kDLCPU) {
if (target.defined() && target->device_type == kDLCPU) {
return target;
} else {
if (LLVMEnabled()) {
......
@@ -51,7 +51,7 @@ using GraphAttrs = std::unordered_map<std::string, dmlc::any>;
using GraphNodePtr = std::shared_ptr<GraphNode>;
using GraphInputNodePtr = std::shared_ptr<GraphInputNode>;
using GraphOpNodePtr = std::shared_ptr<GraphOpNode>;
using TargetsMap = std::unordered_map<std::string, Target>;
using TargetsMap = std::unordered_map<int, Target>;
/*! \brief Lowered outputs */
struct LoweredOutput {
@@ -193,12 +193,10 @@ class GraphOpNode : public GraphNode {
class GraphRuntimeCodegen
: public ::tvm::relay::ExprFunctor<std::vector<GraphNodeRef>(const Expr&)> {
public:
GraphRuntimeCodegen(runtime::Module* mod,
const std::unordered_map<std::string, std::string>& targets) : mod_(mod) {
GraphRuntimeCodegen(runtime::Module* mod, const TargetsMap& targets)
: mod_(mod) {
compile_engine_ = CompileEngine::Global();
for (auto &kv : targets) {
targets_[kv.first] = Target::create(kv.second);
}
targets_ = targets;
}
LoweredOutput Codegen(relay::Function func) {
@@ -406,7 +404,7 @@ class GraphRuntimeCodegen
auto pf0 = GetPackedFunc("relay.backend._make_CCacheKey");
auto pf1 = GetPackedFunc("relay.backend._CompileEngineLower");
auto &device_type = storage_device_map_[expr][1];
auto call_dev_type = device_type[0]->value; //-> int to string
auto call_dev_type = device_type[0]->value;
Target target;
if (targets_.size() == 1) {
// homogeneous execution.
@@ -415,22 +413,17 @@
}
} else {
// heterogeneous execution.
const auto call_dev_key = std::to_string(call_dev_type);
std::string call_dev_name;
if (call_dev_type == 0) {
call_dev_name = "llvm";
} else {
call_dev_name = runtime::DeviceName(call_dev_type);
}
if (targets_.count(call_dev_name) == 0 && targets_.count(call_dev_key) == 0) {
if (targets_.count(call_dev_type) == 0) {
LOG(FATAL) << "No target is provided for device "
<< call_dev_name;
}
if (targets_.count(call_dev_key)) {
target = targets_[call_dev_key];
} else {
target = targets_[call_dev_name];
}
target = targets_[call_dev_type];
}
CCacheKey key = (*pf0)(func, target);
CachedFunc lowerd_func = (*pf1)(compile_engine_, key);
@@ -604,30 +597,21 @@ class GraphRuntimeCodegenModule : public runtime::ModuleNode {
virtual PackedFunc GetFunction(const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) {
if (name == "init") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
CHECK_EQ(args.num_args, 2) << "The expected of arguments are: "
<< "runtime::Module mod and Map<str, StringImm> targets";
void* mod = args[0];
auto& sptr = args[1].node_sptr();
auto* node = static_cast<const ArrayNode*>(sptr.get());
auto& tmp_targets = node->data;
std::unordered_map<std::string, std::string> targets;
for (size_t i = 0; i < tmp_targets.size(); i += 2) {
std::string key;
auto sk = Expr(tmp_targets[i]).as<ir::StringImm>();
auto ik = Expr(tmp_targets[i]).as<ir::IntImm>();
if (sk) {
key = sk->value;
}
if (ik) {
key = std::to_string(ik->value);
}
auto v = Expr(tmp_targets[i + 1]).as<ir::StringImm>();
targets[key] = v->value;
}
codegen_ = std::make_shared<GraphRuntimeCodegen>(
reinterpret_cast<runtime::Module*>(mod), targets);
});
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
CHECK_EQ(args.num_args, 2)
<< "The expected of arguments are: "
<< "runtime::Module mod and Map<int, Target> targets";
void* mod = args[0];
Map<Integer, tvm::Target> tmp = args[1];
TargetsMap targets;
for (const auto& it : tmp) {
auto dev_type = it.first.as<ir::IntImm>();
CHECK(dev_type);
targets[dev_type->value] = it.second;
}
codegen_ = std::make_shared<GraphRuntimeCodegen>(
reinterpret_cast<runtime::Module*>(mod), targets);
});
} else if (name == "codegen") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
Function func = args[0];
......
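From Python, the Map<Integer, Target> that init now expects is an IntImm-keyed dict run through the FFI converter; the CHECK(dev_type) above is what rejects non-integer keys. A sketch, again assuming device type 1 for kDLCPU:

import tvm
from tvm import target as _target

# tvm.convert turns the dict into the Map node the C++ loop iterates over;
# each key must be an IntImm for it.first.as<ir::IntImm>() to succeed.
targets = tvm.convert({
    tvm.expr.IntImm("int32", 1): _target.create("llvm"),
})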
@@ -18,6 +18,7 @@
*/
#include <gtest/gtest.h>
#include <tvm/build_module.h>
#include <tvm/tvm.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/type.h>
@@ -73,10 +74,10 @@ TEST(Relay, BuildModule) {
auto build_f = build_mod.GetFunction("build", false);
auto json_f = build_mod.GetFunction("get_graph_json", false);
auto mod_f = build_mod.GetFunction("get_module", false);
Array<HalideIR::Expr> target_pair;
target_pair.push_back(ir::StringImm::make("cpu"));
target_pair.push_back(ir::StringImm::make("llvm"));
build_f(func, target_pair, "llvm");
Map<tvm::Integer, tvm::Target> targets;
Target llvm_tgt = Target::create("llvm");
targets.Set(0, llvm_tgt);
build_f(func, targets, llvm_tgt);
std::string json = json_f();
tvm::runtime::Module mod = mod_f();
// run
......
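The same flow can be driven from Python through the packed functions that _init_api exposes, mirroring both this C++ test and the wrapper class deleted from the Python test below; the single-op function is again hypothetical:

import tvm
from tvm import relay
from tvm import target as _target

# _BuildModule is the C++ registration surfaced by _init_api above.
bld_mod = relay.build_module._BuildModule()
build_f = bld_mod["build"]
json_f = bld_mod["get_graph_json"]
mod_f = bld_mod["get_module"]

x = relay.var("x", shape=(2, 2), dtype="float32")
func = relay.Function([x], relay.add(x, x))
llvm_tgt = _target.create("llvm")
build_f(func, {tvm.expr.IntImm("int32", 0): llvm_tgt}, llvm_tgt)
graph_json = json_f()
lib = mod_f()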
@@ -74,13 +74,12 @@ def test_alter_layout_conv2d():
for tgt in targets:
with tvm.target.create(tgt) as target:
with relay.build_config(opt_level=-1, add_pass='AlterOpLayout'):
with autotvm.tophub.context(target):
O = relay.optimize(N, target, params=None)
O = relay.ir_pass.infer_type(O)
with autotvm.tophub.context(target):
O = relay.ir_pass.alter_op_layout(N)
O = relay.ir_pass.infer_type(O)
# graph should differ
assert not relay.ir_pass.alpha_equal(N, O)
# graph should differ
assert not relay.ir_pass.alpha_equal(N, O)
if __name__ == "__main__":
np.random.seed(42)
......
@@ -18,55 +18,10 @@ import numpy as np
import tvm
from tvm import relay
from tvm.contrib.nvcc import have_fp16
from tvm._ffi.function import _init_api
_init_api("tvm.relay.build_module")
class BuildModule(object):
def __init__(self):
self.mod = relay.build_module._BuildModule()
self._get_graph_json = self.mod["get_graph_json"]
self._get_module = self.mod["get_module"]
self._build = self.mod["build"]
self._set_opt_level = self.mod["set_opt_level"]
self._set_params_func = self.mod["set_params"]
self._get_params_func = self.mod["get_params"]
def build(self, func, target, target_host, params):
tgts = []
for kv in target.items():
tgts.append(kv[0])
tgts.append(kv[1])
self._set_params(params)
self._build(func, tgts, target_host)
def get_json(self):
return self._get_graph_json()
def get_module(self):
return self._get_module()
def set_opt_level(self, level):
self._set_opt_level(level)
def _set_params(self, params):
inputs = {}
for name, param in params.items():
inputs[name] = relay.Constant(param)
self._set_params_func(inputs)
def get_params(self):
params = self._get_params_func()
ret = {}
for key, value in params.items():
ret[key] = value.data
return ret
def test_build():
m_bld = BuildModule()
tgt_name = "llvm"
def test_basic_build():
tgt = "llvm"
ctx = tvm.cpu()
# func
@@ -86,21 +41,96 @@
}
# build
targets = {
tgt: tgt
tvm.expr.IntImm("int32", ctx.device_type): tgt
}
m_bld.set_opt_level(3)
m_bld.build(func, targets, "llvm", params=params)
g_json = m_bld.get_json()
mmod = m_bld.get_module()
params = m_bld.get_params()
g_json, mmod, params = relay.build(func, targets, "llvm", params=params)
# test
rt = tvm.contrib.graph_runtime.create(g_json, mmod, ctx)
rt.set_input("a", A)
rt.load_params(relay.save_param_dict(params))
rt.run()
out = rt.get_output(0)
np.testing.assert_allclose(out.asnumpy(),
np.maximum(np.dot(A.asnumpy(), B.asnumpy().T), 0) + C.asnumpy(), atol=1e-5, rtol=1e-5)
np.testing.assert_allclose(out.asnumpy(), np.maximum(np.dot(A.asnumpy(),
B.asnumpy().T),
0) + C.asnumpy(),
atol=1e-5, rtol=1e-5)
def test_fp16_build():
dtype = "float16"
if not tvm.module.enabled("cuda") or not tvm.gpu(0).exist:
print("skip because cuda is not enabled.")
return
ctx = tvm.gpu(0)
if dtype == "float16" and not have_fp16(ctx.compute_version):
print("skip because gpu does not support fp16")
return
x = relay.var("x", dtype=dtype, shape=(4, 4))
y = relay.var("y", dtype=dtype, shape=(4, 4))
z = x + y
func = relay.Function([x, y], z)
X = tvm.nd.array(np.random.uniform(-1, 1, (4, 4)).astype(dtype), ctx=ctx)
Y = tvm.nd.array(np.random.uniform(-1, 1, (4, 4)).astype(dtype), ctx=ctx)
params = {
"x": X,
"y": Y,
}
# build
g_json, mmod, params = relay.build(func, "cuda", params=params)
# test
rt = tvm.contrib.graph_runtime.create(g_json, mmod, ctx)
rt.load_params(relay.save_param_dict(params))
rt.run()
out = rt.get_output(0)
np.testing.assert_allclose(out.asnumpy(), X.asnumpy() + Y.asnumpy(),
atol=1e-5, rtol=1e-5)
def test_fp16_conversion():
def check_conversion(tgt, ctx):
if not tvm.module.enabled(tgt):
print("skip because {} is not enabled.".format(tgt))
return
elif tgt == "cuda" and ctx.exist and not have_fp16(ctx.compute_version):
print("skip because gpu does not support fp16")
return
n = 10
for (src, dst) in [('float32', 'float16'), ('float16', 'float32')]:
x = relay.var("x", relay.TensorType((n,), src))
y = x.astype(dst)
func = relay.Function([x], y)
# init input
X = tvm.nd.array(n * np.random.randn(n).astype(src) - n / 2)
# build
with relay.build_config(opt_level=1):
g_json, mmod, params = relay.build(func, tgt)
# test
rt = tvm.contrib.graph_runtime.create(g_json, mmod, ctx)
rt.set_input("x", X)
rt.run()
out = rt.get_output(0)
np.testing.assert_allclose(out.asnumpy(), X.asnumpy().astype(dst),
atol=1e-5, rtol=1e-5)
for target, ctx in [('llvm', tvm.cpu()), ('cuda', tvm.gpu())]:
check_conversion(target, ctx)
if __name__ == "__main__":
test_basic_build()
test_fp16_build()
test_fp16_conversion()
@@ -411,7 +411,7 @@ def run_fusible_network(dev, tgt):
expected_index)
def test_fallback_all_operators(device, tgt):
target = {device: tgt}
target = {device: tgt, "cpu": "llvm"}
annotated_func = get_func()
expected_func = get_func()
check_annotated_graph(annotated_func, expected_func)
......
@@ -47,54 +47,54 @@ def test_simulated_quantize():
assert out.args[3].checked_type == relay.ty.TensorType(tuple(), "float32")
# def test_quantize_pass():
# def quantize_weight(arr):
# maximum = np.amax(np.abs(arr.asnumpy()))
# scale = 2**math.ceil(math.log(maximum, 2))
# out = np.around(arr.asnumpy() / scale * 128).astype('int8')
# out = np.clip(out, -127, 127)
# return relay.const(out, 'int8')
#
# n, c, h, w = 1, 3, 224, 224
# def make_graph(data):
# weight = relay.var("conv_weight")
# out = relay.nn.conv2d(data, weight, kernel_size=(3, 3), padding=(1, 1), channels=c)
# out = relay.Function(relay.ir_pass.free_vars(out), out)
# return out
#
# def make_qgraph(data, weight):
# out = data * relay.const(32.0)
# out = relay.round(out)
# out = relay.clip(out, a_min=-127, a_max=127)
# out = out.astype('int8')
#
# out = relay.nn.conv2d(out, weight, kernel_size=(3, 3),
# padding=(1, 1), channels=c, out_dtype='int32')
# out = out.astype('float32')
# out = relay.multiply(out, relay.const(0.00024414062))
# out = relay.Function(relay.ir_pass.free_vars(out), out)
# return out
#
# data = relay.var("data", relay.TensorType((n, c, h, w), "float32"))
# graph = make_graph(data)
# dataset, params = make_dataset(graph, 10)
#
# with qtz.qconfig(skip_k_conv=0, global_scale=4.0,
# round_for_shift=False, store_lowbit_output=False):
# qgraph0 = qtz.quantize(graph, params)
# qgraph0 = relay.ir_pass.infer_type(qgraph0)
#
# conv_weight = quantize_weight(params['conv_weight'])
# qgraph1 = make_qgraph(data, conv_weight)
# qgraph1 = relay.ir_pass.infer_type(qgraph1)
#
# graph = relay.create_executor('graph')
# res0 = graph.evaluate(qgraph0)(dataset[0]['data'])
# res1 = graph.evaluate(qgraph1)(dataset[0]['data'])
# tvm.testing.assert_allclose(res0.asnumpy(), res1.asnumpy(), rtol=1e-3)
def test_quantize_pass():
def quantize_weight(arr):
maximum = np.amax(np.abs(arr.asnumpy()))
scale = 2**math.ceil(math.log(maximum, 2))
out = np.around(arr.asnumpy() / scale * 128).astype('int8')
out = np.clip(out, -127, 127)
return relay.const(out, 'int8')
n, c, h, w = 1, 3, 224, 224
def make_graph(data):
weight = relay.var("conv_weight")
out = relay.nn.conv2d(data, weight, kernel_size=(3, 3), padding=(1, 1), channels=c)
out = relay.Function(relay.ir_pass.free_vars(out), out)
return out
def make_qgraph(data, weight):
out = data * relay.const(32.0)
out = relay.round(out)
out = relay.clip(out, a_min=-127, a_max=127)
out = out.astype('int8')
out = relay.nn.conv2d(out, weight, kernel_size=(3, 3),
padding=(1, 1), channels=c, out_dtype='int32')
out = out.astype('float32')
out = relay.multiply(out, relay.const(0.00024414062))
out = relay.Function(relay.ir_pass.free_vars(out), out)
return out
data = relay.var("data", relay.TensorType((n, c, h, w), "float32"))
graph = make_graph(data)
dataset, params = make_dataset(graph, 10)
with qtz.qconfig(skip_k_conv=0, global_scale=4.0,
round_for_shift=False, store_lowbit_output=False):
qgraph0 = qtz.quantize(graph, params)
qgraph0 = relay.ir_pass.infer_type(qgraph0)
conv_weight = quantize_weight(params['conv_weight'])
qgraph1 = make_qgraph(data, conv_weight)
qgraph1 = relay.ir_pass.infer_type(qgraph1)
graph = relay.create_executor('graph')
res0 = graph.evaluate(qgraph0)(dataset[0]['data'])
res1 = graph.evaluate(qgraph1)(dataset[0]['data'])
tvm.testing.assert_allclose(res0.asnumpy(), res1.asnumpy(), rtol=1e-3)
if __name__ == "__main__":
np.random.seed(42)
test_simulated_quantize()
# test_quantize_pass()
test_quantize_pass()