Commit 0f2a3086 by Zhi, committed by Tianqi Chen

[Relay][Compilation] replace relay.build_module with C++ BuildModule (#3174)

parent 7d845f0d
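For orientation, here is a minimal usage sketch of the interface this change settles on, mirroring the updated test_basic_build further down; the tiny ReLU workload and shapes are illustrative only and not part of the commit.

import numpy as np
import tvm
from tvm import relay
from tvm.contrib import graph_runtime

# A tiny function compiled through the (now C++-backed) relay.build.
x = relay.var("x", shape=(1, 4), dtype="float32")
func = relay.Function([x], relay.nn.relu(x))

ctx = tvm.cpu()
# Targets are keyed by device type (an IntImm) rather than by a string key.
targets = {tvm.expr.IntImm("int32", ctx.device_type): "llvm"}

graph_json, lib, params = relay.build(func, targets, "llvm")

rt = graph_runtime.create(graph_json, lib, ctx)
rt.set_input("x", tvm.nd.array(np.random.rand(1, 4).astype("float32")))
rt.run()
print(rt.get_output(0))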
@@ -25,7 +25,7 @@ from . import expr_functor
 from . import module
 from . import adt
 from . import ir_pass
-from .build_module import build, build_config, create_executor, optimize
+from .build_module import build, build_config, create_executor
 from . import prelude
 from . import parser
 from . import debug
...
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable
"""The interface for building Relay functions exposed from C++."""
from tvm._ffi.function import _init_api
_init_api("relay.build_module", __name__)
@@ -36,12 +36,9 @@ contrib.graph_runtime or any other TVM runtime compatible systems.
 from __future__ import absolute_import

 from tvm.ndarray import empty
-from tvm._ffi.function import _init_api
 from tvm.relay import build_module
 from tvm import target as _target
+from tvm import expr as _expr

-_init_api("tvm.relay.build_module")

 class GraphRuntimeCodegen(object):
     """The compiler from Relay to the TVM runtime system."""
@@ -57,17 +54,14 @@ class GraphRuntimeCodegen(object):
         self._setup(mod, target)

     def _setup(self, mod, target):
-        tgts = []
+        tgts = {}
         if isinstance(target, dict):
-            for kv in target.items():
-                tgts.append(kv[0])
-                if isinstance(kv[1], (str, _target.Target)):
-                    tgts.append(str(kv[1]))
-                else:
-                    raise Exception("Unknown target type")
+            for dev, tgt in target.items():
+                if not isinstance(tgt, (str, _target.Target)):
+                    raise Exception("Unknown target type")
+                tgts[dev] = _target.create(tgt)
         elif isinstance(target, (str, _target.Target)):
-            tgts.append("0")
-            tgts.append(str(target))
+            tgts[_expr.IntImm("int32", 0)] = _target.create(target)
         self._init(mod, tgts)

     def codegen(self, func):
...
@@ -269,6 +269,77 @@ def realize(graph):
     return _quantize.realize(graph)

+
+def optimize(func, params=None):
+    """ Perform "SimplifyInference", "FoldScaleAxis", "FoldConstant", and
+    "CanonicalizeOps" optimization before quantization.
+
+    # TODO(zhiics) These passes are executed one by one so far. We need to
+    # move them to the pass manager.
+
+    Parameters
+    ---------
+    func: tvm.relay.Function
+        The original Relay function to be optimized.
+
+    params : dict of str to tvm.NDArray
+        Input parameters to the graph that do not change
+        during inference time. Used for constant folding.
+
+    Returns
+    -------
+    ret: tvm.relay.Function
+        The graph after quantization
+    """
+    opt_passes = ["SimplifyInference",
+                  "FoldScaleAxis",
+                  "FoldConstant",
+                  "CanonicalizeOps"]
+    cfg = _build.build_config(add_pass=opt_passes)
+
+    if params:
+        name_dict = {}
+        for arg in func.params:
+            name = arg.name_hint
+            if name in name_dict:
+                name_dict[name] = None
+            else:
+                name_dict[name] = arg
+        bind_dict = {}
+        for k, v in params.items():
+            if k not in name_dict:
+                continue
+            arg = name_dict[k]
+            if arg is None:
+                raise ValueError("Multiple args in the function have name %s" % k)
+            bind_dict[arg] = _expr.const(v)
+        func = _expr.bind(func, bind_dict)
+
+    if "SimplifyInference" in cfg.add_pass:
+        func = _ir_pass.infer_type(func)
+        func = _ir_pass.simplify_inference(func)
+
+    if "FoldConstant" in cfg.add_pass:
+        func = _ir_pass.fold_constant(func)
+
+    if "FoldScaleAxis" in cfg.add_pass:
+        func = _ir_pass.infer_type(func)
+        func = _ir_pass.backward_fold_scale_axis(func)
+        func = _ir_pass.infer_type(func)
+        func = _ir_pass.forward_fold_scale_axis(func)
+        func = _ir_pass.fold_constant(func)
+
+    if "CanonicalizeOps" in cfg.add_pass:
+        func = _ir_pass.infer_type(func)
+        func = _ir_pass.canonicalize_ops(func)
+
+    if "FoldConstant" in cfg.add_pass:
+        func = _ir_pass.fold_constant(func)
+
+    return func
+
+
 def quantize(graph, params=None, dataset=None):
     """ The quantization procedure. Before running the three main
     procedure of quantization, "annotate", "calibrate" and "realize"
@@ -292,12 +363,8 @@ def quantize(graph, params=None, dataset=None):
     ret: Function
         The graph after quantization
     """
-    opt_passes = ["SimplifyInference",
-                  "FoldScaleAxis",
-                  "FoldConstant",
-                  "CanonicalizeOps"]
-    with _build.build_config(add_pass=opt_passes):
-        graph = _build.optimize(graph, params=params)
+    # TODO(zhiics) Move this to the pass manager.
+    graph = optimize(graph, params)
     graph = annotate(graph)
     graph = calibrate(graph, dataset)
...
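A hedged usage sketch of the optimize helper added in this hunk; the import path and the conv2d workload are assumptions for illustration, since the helper may only be reachable through the quantize module internals.

import numpy as np
import tvm
from tvm import relay
# Assumed import path for the new pre-quantization helper.
from tvm.relay.quantize.quantize import optimize

# A small conv2d graph whose weight is supplied as a parameter.
data = relay.var("data", relay.TensorType((1, 3, 224, 224), "float32"))
weight = relay.var("conv_weight")
out = relay.nn.conv2d(data, weight, kernel_size=(3, 3), padding=(1, 1), channels=16)
func = relay.Function(relay.ir_pass.free_vars(out), out)

params = {"conv_weight": tvm.nd.array(
    np.random.uniform(size=(16, 3, 3, 3)).astype("float32"))}

# Binds conv_weight, then runs SimplifyInference, FoldScaleAxis,
# FoldConstant and CanonicalizeOps ahead of annotate/calibrate/realize.
func = optimize(func, params)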
@@ -311,7 +311,7 @@ bool LLVMEnabled() {
 /*! \return The default host target for a given device target */
 Target DefaultTargetHost(Target target) {
-  if (target->device_type == kDLCPU) {
+  if (target.defined() && target->device_type == kDLCPU) {
     return target;
   } else {
     if (LLVMEnabled()) {
...
@@ -51,7 +51,7 @@ using GraphAttrs = std::unordered_map<std::string, dmlc::any>;
 using GraphNodePtr = std::shared_ptr<GraphNode>;
 using GraphInputNodePtr = std::shared_ptr<GraphInputNode>;
 using GraphOpNodePtr = std::shared_ptr<GraphOpNode>;
-using TargetsMap = std::unordered_map<std::string, Target>;
+using TargetsMap = std::unordered_map<int, Target>;

 /*! \brief Lowered outputs */
 struct LoweredOutput {
@@ -193,12 +193,10 @@ class GraphOpNode : public GraphNode {
 class GraphRuntimeCodegen
     : public ::tvm::relay::ExprFunctor<std::vector<GraphNodeRef>(const Expr&)> {
  public:
-  GraphRuntimeCodegen(runtime::Module* mod,
-                      const std::unordered_map<std::string, std::string>& targets) : mod_(mod) {
+  GraphRuntimeCodegen(runtime::Module* mod, const TargetsMap& targets)
+      : mod_(mod) {
     compile_engine_ = CompileEngine::Global();
-    for (auto &kv : targets) {
-      targets_[kv.first] = Target::create(kv.second);
-    }
+    targets_ = targets;
   }

   LoweredOutput Codegen(relay::Function func) {
@@ -406,7 +404,7 @@ class GraphRuntimeCodegen
     auto pf0 = GetPackedFunc("relay.backend._make_CCacheKey");
     auto pf1 = GetPackedFunc("relay.backend._CompileEngineLower");
     auto &device_type = storage_device_map_[expr][1];
-    auto call_dev_type = device_type[0]->value;  //-> int to string
+    auto call_dev_type = device_type[0]->value;
     Target target;
     if (targets_.size() == 1) {
       // homogeneous execution.
@@ -415,22 +413,17 @@
       }
     } else {
       // heterogeneous execution.
-      const auto call_dev_key = std::to_string(call_dev_type);
       std::string call_dev_name;
       if (call_dev_type == 0) {
         call_dev_name = "llvm";
       } else {
         call_dev_name = runtime::DeviceName(call_dev_type);
       }
-      if (targets_.count(call_dev_name) == 0 && targets_.count(call_dev_key) == 0) {
+      if (targets_.count(call_dev_type) == 0) {
         LOG(FATAL) << "No target is provided for device "
                    << call_dev_name;
       }
-      if (targets_.count(call_dev_key)) {
-        target = targets_[call_dev_key];
-      } else {
-        target = targets_[call_dev_name];
-      }
+      target = targets_[call_dev_type];
     }
     CCacheKey key = (*pf0)(func, target);
     CachedFunc lowerd_func = (*pf1)(compile_engine_, key);
@@ -604,30 +597,21 @@ class GraphRuntimeCodegenModule : public runtime::ModuleNode {
   virtual PackedFunc GetFunction(const std::string& name,
                                  const std::shared_ptr<ModuleNode>& sptr_to_self) {
     if (name == "init") {
       return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
-        CHECK_EQ(args.num_args, 2) << "The expected of arguments are: "
-                                   << "runtime::Module mod and Map<str, StringImm> targets";
-        void* mod = args[0];
-        auto& sptr = args[1].node_sptr();
-        auto* node = static_cast<const ArrayNode*>(sptr.get());
-        auto& tmp_targets = node->data;
-        std::unordered_map<std::string, std::string> targets;
-        for (size_t i = 0; i < tmp_targets.size(); i += 2) {
-          std::string key;
-          auto sk = Expr(tmp_targets[i]).as<ir::StringImm>();
-          auto ik = Expr(tmp_targets[i]).as<ir::IntImm>();
-          if (sk) {
-            key = sk->value;
-          }
-          if (ik) {
-            key = std::to_string(ik->value);
-          }
-          auto v = Expr(tmp_targets[i + 1]).as<ir::StringImm>();
-          targets[key] = v->value;
-        }
-        codegen_ = std::make_shared<GraphRuntimeCodegen>(
-            reinterpret_cast<runtime::Module*>(mod), targets);
-      });
+        CHECK_EQ(args.num_args, 2)
+            << "The expected of arguments are: "
+            << "runtime::Module mod and Map<int, Target> targets";
+        void* mod = args[0];
+        Map<Integer, tvm::Target> tmp = args[1];
+        TargetsMap targets;
+        for (const auto& it : tmp) {
+          auto dev_type = it.first.as<ir::IntImm>();
+          CHECK(dev_type);
+          targets[dev_type->value] = it.second;
+        }
+        codegen_ = std::make_shared<GraphRuntimeCodegen>(
+            reinterpret_cast<runtime::Module*>(mod), targets);
+      });
     } else if (name == "codegen") {
       return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
         Function func = args[0];
...
@@ -18,6 +18,7 @@
  */
 #include <gtest/gtest.h>
+#include <tvm/build_module.h>
 #include <tvm/tvm.h>
 #include <tvm/relay/expr.h>
 #include <tvm/relay/type.h>
@@ -73,10 +74,10 @@ TEST(Relay, BuildModule) {
   auto build_f = build_mod.GetFunction("build", false);
   auto json_f = build_mod.GetFunction("get_graph_json", false);
   auto mod_f = build_mod.GetFunction("get_module", false);
-  Array<HalideIR::Expr> target_pair;
-  target_pair.push_back(ir::StringImm::make("cpu"));
-  target_pair.push_back(ir::StringImm::make("llvm"));
-  build_f(func, target_pair, "llvm");
+  Map<tvm::Integer, tvm::Target> targets;
+  Target llvm_tgt = Target::create("llvm");
+  targets.Set(0, llvm_tgt);
+  build_f(func, targets, llvm_tgt);
   std::string json = json_f();
   tvm::runtime::Module mod = mod_f();
   // run
...
@@ -74,13 +74,12 @@ def test_alter_layout_conv2d():
     for tgt in targets:
         with tvm.target.create(tgt) as target:
-            with relay.build_config(opt_level=-1, add_pass='AlterOpLayout'):
-                with autotvm.tophub.context(target):
-                    O = relay.optimize(N, target, params=None)
-                    O = relay.ir_pass.infer_type(O)
+            with autotvm.tophub.context(target):
+                O = relay.ir_pass.alter_op_layout(N)
+                O = relay.ir_pass.infer_type(O)

     # graph should differ
     assert not relay.ir_pass.alpha_equal(N, O)

 if __name__ == "__main__":
     np.random.seed(42)
...
@@ -18,55 +18,10 @@ import numpy as np
 import tvm
 from tvm import relay
-from tvm._ffi.function import _init_api
+from tvm.contrib.nvcc import have_fp16

-_init_api("tvm.relay.build_module")
-
-class BuildModule(object):
-    def __init__(self):
-        self.mod = relay.build_module._BuildModule()
-        self._get_graph_json = self.mod["get_graph_json"]
-        self._get_module = self.mod["get_module"]
-        self._build = self.mod["build"]
-        self._set_opt_level = self.mod["set_opt_level"]
-        self._set_params_func = self.mod["set_params"]
-        self._get_params_func = self.mod["get_params"]
-
-    def build(self, func, target, target_host, params):
-        tgts = []
-        for kv in target.items():
-            tgts.append(kv[0])
-            tgts.append(kv[1])
-        self._set_params(params)
-        self._build(func, tgts, target_host)
-
-    def get_json(self):
-        return self._get_graph_json()
-
-    def get_module(self):
-        return self._get_module()
-
-    def set_opt_level(self, level):
-        self._set_opt_level(level)
-
-    def _set_params(self, params):
-        inputs = {}
-        for name, param in params.items():
-            inputs[name] = relay.Constant(param)
-        self._set_params_func(inputs)
-
-    def get_params(self):
-        params = self._get_params_func()
-        ret = {}
-        for key, value in params.items():
-            ret[key] = value.data
-        return ret
-
-def test_build():
-    m_bld = BuildModule()
-    tgt_name = "llvm"
+def test_basic_build():
     tgt = "llvm"
     ctx = tvm.cpu()

     # func
@@ -86,21 +41,96 @@ def test_build():
     }
     # build
     targets = {
-        tgt: tgt
+        tvm.expr.IntImm("int32", ctx.device_type): tgt
     }
-    m_bld.set_opt_level(3)
-    m_bld.build(func, targets, "llvm", params=params)
-    g_json = m_bld.get_json()
-    mmod = m_bld.get_module()
-    params = m_bld.get_params()
+    g_json, mmod, params = relay.build(func, targets, "llvm", params=params)
     # test
     rt = tvm.contrib.graph_runtime.create(g_json, mmod, ctx)
     rt.set_input("a", A)
     rt.load_params(relay.save_param_dict(params))
     rt.run()
     out = rt.get_output(0)
-    np.testing.assert_allclose(out.asnumpy(),
-        np.maximum(np.dot(A.asnumpy(), B.asnumpy().T), 0) + C.asnumpy(), atol=1e-5, rtol=1e-5)
+    np.testing.assert_allclose(out.asnumpy(), np.maximum(np.dot(A.asnumpy(),
+                                                                B.asnumpy().T),
+                                                         0) + C.asnumpy(),
+                               atol=1e-5, rtol=1e-5)
+
+
+def test_fp16_build():
+    dtype = "float16"
+    if not tvm.module.enabled("cuda") or not tvm.gpu(0).exist:
+        print("skip because cuda is not enabled.")
+        return
+
+    ctx = tvm.gpu(0)
+    if dtype == "float16" and not have_fp16(ctx.compute_version):
+        print("skip because gpu does not support fp16")
+        return
+
+    x = relay.var("x", dtype=dtype, shape=(4, 4))
+    y = relay.var("y", dtype=dtype, shape=(4, 4))
+    z = x + y
+    func = relay.Function([x, y], z)
+    X = tvm.nd.array(np.random.uniform(-1, 1, (4, 4)).astype(dtype), ctx=ctx)
+    Y = tvm.nd.array(np.random.uniform(-1, 1, (4, 4)).astype(dtype), ctx=ctx)
+    params = {
+        "x": X,
+        "y": Y,
+    }
+    # build
+    g_json, mmod, params = relay.build(func, "cuda", params=params)
+    # test
+    rt = tvm.contrib.graph_runtime.create(g_json, mmod, ctx)
+    rt.load_params(relay.save_param_dict(params))
+    rt.run()
+    out = rt.get_output(0)
+    np.testing.assert_allclose(out.asnumpy(), X.asnumpy() + Y.asnumpy(),
+                               atol=1e-5, rtol=1e-5)
+
+
+def test_fp16_conversion():
+    def check_conversion(tgt, ctx):
+        if not tvm.module.enabled(tgt):
+            print("skip because {} is not enabled.".format(tgt))
+            return
+        elif tgt == "cuda" and ctx.exist and not have_fp16(ctx.compute_version):
+            print("skip because gpu does not support fp16")
+            return
+
+        n = 10
+        for (src, dst) in [('float32', 'float16'), ('float16', 'float32')]:
+            x = relay.var("x", relay.TensorType((n,), src))
+            y = x.astype(dst)
+            func = relay.Function([x], y)
+            # init input
+            X = tvm.nd.array(n * np.random.randn(n).astype(src) - n / 2)
+            # build
+            with relay.build_config(opt_level=1):
+                g_json, mmod, params = relay.build(func, tgt)
+            # test
+            rt = tvm.contrib.graph_runtime.create(g_json, mmod, ctx)
+            rt.set_input("x", X)
+            rt.run()
+            out = rt.get_output(0)
+            np.testing.assert_allclose(out.asnumpy(), X.asnumpy().astype(dst),
+                                       atol=1e-5, rtol=1e-5)
+
+    for target, ctx in [('llvm', tvm.cpu()), ('cuda', tvm.gpu())]:
+        check_conversion(target, ctx)
+

 if __name__ == "__main__":
+    test_basic_build()
+    test_fp16_build()
+    test_fp16_conversion()
@@ -411,7 +411,7 @@ def run_fusible_network(dev, tgt):
                                     expected_index)

     def test_fallback_all_operators(device, tgt):
-        target = {device: tgt}
+        target = {device: tgt, "cpu": "llvm"}
         annotated_func = get_func()
         expected_func = get_func()
         check_annotated_graph(annotated_func, expected_func)
...
@@ -47,54 +47,54 @@ def test_simulated_quantize():
     assert out.args[3].checked_type == relay.ty.TensorType(tuple(), "float32")

-# def test_quantize_pass():
-#     def quantize_weight(arr):
-#         maximum = np.amax(np.abs(arr.asnumpy()))
-#         scale = 2**math.ceil(math.log(maximum, 2))
-#         out = np.around(arr.asnumpy() / scale * 128).astype('int8')
-#         out = np.clip(out, -127, 127)
-#         return relay.const(out, 'int8')
-#
-#     n, c, h, w = 1, 3, 224, 224
-#     def make_graph(data):
-#         weight = relay.var("conv_weight")
-#         out = relay.nn.conv2d(data, weight, kernel_size=(3, 3), padding=(1, 1), channels=c)
-#         out = relay.Function(relay.ir_pass.free_vars(out), out)
-#         return out
-#
-#     def make_qgraph(data, weight):
-#         out = data * relay.const(32.0)
-#         out = relay.round(out)
-#         out = relay.clip(out, a_min=-127, a_max=127)
-#         out = out.astype('int8')
-#
-#         out = relay.nn.conv2d(out, weight, kernel_size=(3, 3),
-#                               padding=(1, 1), channels=c, out_dtype='int32')
-#         out = out.astype('float32')
-#         out = relay.multiply(out, relay.const(0.00024414062))
-#         out = relay.Function(relay.ir_pass.free_vars(out), out)
-#         return out
-#
-#     data = relay.var("data", relay.TensorType((n, c, h, w), "float32"))
-#     graph = make_graph(data)
-#     dataset, params = make_dataset(graph, 10)
-#
-#     with qtz.qconfig(skip_k_conv=0, global_scale=4.0,
-#                      round_for_shift=False, store_lowbit_output=False):
-#         qgraph0 = qtz.quantize(graph, params)
-#         qgraph0 = relay.ir_pass.infer_type(qgraph0)
-#
-#     conv_weight = quantize_weight(params['conv_weight'])
-#     qgraph1 = make_qgraph(data, conv_weight)
-#     qgraph1 = relay.ir_pass.infer_type(qgraph1)
-#
-#     graph = relay.create_executor('graph')
-#     res0 = graph.evaluate(qgraph0)(dataset[0]['data'])
-#     res1 = graph.evaluate(qgraph1)(dataset[0]['data'])
-#     tvm.testing.assert_allclose(res0.asnumpy(), res1.asnumpy(), rtol=1e-3)
+def test_quantize_pass():
+    def quantize_weight(arr):
+        maximum = np.amax(np.abs(arr.asnumpy()))
+        scale = 2**math.ceil(math.log(maximum, 2))
+        out = np.around(arr.asnumpy() / scale * 128).astype('int8')
+        out = np.clip(out, -127, 127)
+        return relay.const(out, 'int8')
+
+    n, c, h, w = 1, 3, 224, 224
+    def make_graph(data):
+        weight = relay.var("conv_weight")
+        out = relay.nn.conv2d(data, weight, kernel_size=(3, 3), padding=(1, 1), channels=c)
+        out = relay.Function(relay.ir_pass.free_vars(out), out)
+        return out
+
+    def make_qgraph(data, weight):
+        out = data * relay.const(32.0)
+        out = relay.round(out)
+        out = relay.clip(out, a_min=-127, a_max=127)
+        out = out.astype('int8')
+
+        out = relay.nn.conv2d(out, weight, kernel_size=(3, 3),
+                              padding=(1, 1), channels=c, out_dtype='int32')
+        out = out.astype('float32')
+        out = relay.multiply(out, relay.const(0.00024414062))
+        out = relay.Function(relay.ir_pass.free_vars(out), out)
+        return out
+
+    data = relay.var("data", relay.TensorType((n, c, h, w), "float32"))
+    graph = make_graph(data)
+    dataset, params = make_dataset(graph, 10)
+
+    with qtz.qconfig(skip_k_conv=0, global_scale=4.0,
+                     round_for_shift=False, store_lowbit_output=False):
+        qgraph0 = qtz.quantize(graph, params)
+        qgraph0 = relay.ir_pass.infer_type(qgraph0)
+
+    conv_weight = quantize_weight(params['conv_weight'])
+    qgraph1 = make_qgraph(data, conv_weight)
+    qgraph1 = relay.ir_pass.infer_type(qgraph1)
+
+    graph = relay.create_executor('graph')
+    res0 = graph.evaluate(qgraph0)(dataset[0]['data'])
+    res1 = graph.evaluate(qgraph1)(dataset[0]['data'])
+    tvm.testing.assert_allclose(res0.asnumpy(), res1.asnumpy(), rtol=1e-3)

 if __name__ == "__main__":
     np.random.seed(42)
     test_simulated_quantize()
-    # test_quantize_pass()
+    test_quantize_pass()