Commit 0f2a3086 by Zhi Committed by Tianqi Chen

[Relay][Compilation] replace relay.build_module with C++ BuildModule (#3174)

parent 7d845f0d
......@@ -25,7 +25,7 @@ from . import expr_functor
from . import module
from . import adt
from . import ir_pass
from .build_module import build, build_config, create_executor, optimize
from .build_module import build, build_config, create_executor
from . import prelude
from . import parser
from . import debug
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable
"""The interface for building Relay functions exposed from C++."""
from tvm._ffi.function import _init_api
_init_api("relay.build_module", __name__)
......@@ -36,12 +36,9 @@ contrib.graph_runtime or any other TVM runtime compatible systems.
from __future__ import absolute_import
from tvm.ndarray import empty
from tvm._ffi.function import _init_api
from tvm.relay import build_module
from tvm import target as _target
_init_api("tvm.relay.build_module")
from tvm import expr as _expr
class GraphRuntimeCodegen(object):
"""The compiler from Relay to the TVM runtime system."""
......@@ -57,17 +54,14 @@ class GraphRuntimeCodegen(object):
self._setup(mod, target)
def _setup(self, mod, target):
tgts = []
tgts = {}
if isinstance(target, dict):
for kv in target.items():
tgts.append(kv[0])
if isinstance(kv[1], (str, _target.Target)):
tgts.append(str(kv[1]))
else:
for dev, tgt in target.items():
if not isinstance(tgt, (str, _target.Target)):
raise Exception("Unknown target type")
tgts[dev] = _target.create(tgt)
elif isinstance(target, (str, _target.Target)):
tgts.append("0")
tgts.append(str(target))
tgts[_expr.IntImm("int32", 0)] = _target.create(target)
self._init(mod, tgts)
def codegen(self, func):
......
......@@ -18,32 +18,19 @@
Construct the necessary state for the TVM graph runtime
from a Relay expression.
"""
import warnings
import numpy as np
from tvm._ffi.runtime_ctypes import TVMContext
from ..build_module import build as _tvm_build_module
from tvm import expr as tvm_expr
from .. import nd as _nd, target as _target, autotvm
from ..contrib import graph_runtime as _graph_rt
from . import _build_module
from . import ir_pass
from . import expr as _expr
from . import ty as _ty
from . import expr as _expr
from .backend import interpreter as _interpreter
from .backend import graph_runtime_codegen as _graph_gen
from .backend.vm import VMExecutor
# List of optimization pass and level when switch on
OPT_PASS_LEVEL = {
"SimplifyInference": 0,
"OpFusion": 1,
"FoldConstant": 2,
"CombineParallelConv2D": 3,
"FoldScaleAxis": 3,
"AlterOpLayout": 3,
"CanonicalizeOps": 3,
"EliminateCommonSubexpr": 3,
}
class BuildConfig(object):
"""Configuration scope to set a build config option.
......@@ -56,6 +43,7 @@ class BuildConfig(object):
defaults = {
"opt_level": 2,
"add_pass": None,
"disable_pass": None,
"fallback_device": None,
}
......@@ -85,23 +73,6 @@ class BuildConfig(object):
assert self._old_scope
BuildConfig.current = self._old_scope
def pass_enabled(self, pass_name):
"""Get whether pass is enabled.
Parameters
----------
pass_name : str
The optimization pass name
Returns
-------
enabled : bool
Whether pass is enabled.
"""
if self.add_pass and pass_name in self.add_pass:
return True
return self.opt_level >= OPT_PASS_LEVEL[pass_name]
BuildConfig.current = BuildConfig()
......@@ -117,6 +88,9 @@ def build_config(**kwargs):
add_pass: set of str
Optimization pass to be added regardless of optimization level.
disable_pass: set of str
Optimization pass to be disabled during optimization.
fallback_device : str or tvm.TVMContext
The fallback device. It is also used as the default device for
operators without specified device during heterogeneous execution.
......@@ -129,108 +103,203 @@ def build_config(**kwargs):
return BuildConfig(**kwargs)
def _bind_params_by_name(func, params):
"""Bind parameters of function by its name."""
name_dict = {}
for arg in func.params:
name = arg.name_hint
if name in name_dict:
name_dict[name] = None
else:
name_dict[name] = arg
bind_dict = {}
for k, v in params.items():
if k not in name_dict:
continue
arg = name_dict[k]
if arg is None:
raise ValueError("Multiple args in the function have name %s" % k)
bind_dict[arg] = _expr.const(v)
return _expr.bind(func, bind_dict)
def optimize(func, target=None, params=None):
"""Perform target invariant optimizations.
Parameters
----------
func : tvm.relay.Function
The input to optimization.
def _update_target(target):
target = target if target else _target.current_target()
if target is None:
raise ValueError("Target is not set in env or passed as argument.")
target : Optional[:any:`tvm.target.Target`, Dict[int, tvm.target.Target]]
The optimization target. For heterogeneous compilation, it is a
dictionary mapping device type to compilation target. For homogeneous
compilation, it is a build target.
tgts = {}
if isinstance(target, (str, _target.Target)):
dev_type = tvm_expr.IntImm("int32", _nd.context(str(target)).device_type)
tgts[dev_type] = _target.create(target)
elif isinstance(target, dict):
for dev, tgt in target.items():
dev_type = tvm_expr.IntImm("int32", _nd.context(dev).device_type)
tgts[dev_type] = _target.create(tgt)
else:
raise TypeError("target is expected to be str or " +
"tvm.target.Target, but received " +
"{}".format(type(target)))
return tgts
params : Optional[Dict[str, tvm.nd.NDArray]]
Input parameters to the graph that do not change
during inference time. used for constant folding.
Returns
-------
opt_func : tvm.relay.Function
The optimized version of the function.
class BuildModule(object):
"""Build a Relay function to run on TVM graph runtime. This class is used
to expose the `RelayBuildModule` APIs implemented in C++.
"""
cfg = BuildConfig.current
# bind expressions
if params:
func = _bind_params_by_name(func, params)
if cfg.pass_enabled("SimplifyInference"):
func = ir_pass.infer_type(func)
func = ir_pass.simplify_inference(func)
if cfg.pass_enabled("EliminateCommonSubexpr"):
def fskip(expr):
if isinstance(expr, _expr.Call) and expr.op.name == 'cast' and \
expr.attrs.dtype == 'int32':
return True
return False
func = ir_pass.infer_type(func)
func = ir_pass.eliminate_common_subexpr(func, fskip)
if cfg.pass_enabled("CombineParallelConv2D"):
func = ir_pass.infer_type(func)
func = ir_pass.combine_parallel_conv2d(func)
# The constant folding pass is necessary because FoldScaleAxis pass needs
# to check the constantness and positiveness of scales.
if cfg.pass_enabled("FoldConstant"):
func = ir_pass.fold_constant(func)
if cfg.pass_enabled("FoldScaleAxis"):
func = ir_pass.infer_type(func)
func = ir_pass.backward_fold_scale_axis(func)
func = ir_pass.infer_type(func)
func = ir_pass.forward_fold_scale_axis(func)
func = ir_pass.fold_constant(func)
if cfg.pass_enabled("CanonicalizeOps"):
func = ir_pass.infer_type(func)
func = ir_pass.canonicalize_ops(func)
# FIXME(zhiics) Skip AlterOpLayout pass for heterogeneous compilation for
# now. We probably need to pass target to this pass as well. Fix it in
# a followup PR.
if cfg.pass_enabled("AlterOpLayout"):
if isinstance(target, _target.Target):
func = ir_pass.infer_type(func)
with target:
func = ir_pass.alter_op_layout(func)
elif isinstance(target, dict):
warnings.warn("AlterOpLayout pass is not enabled for heterogeneous"
" execution yet.")
if cfg.pass_enabled("FoldConstant"):
func = ir_pass.fold_constant(func)
return func
def __init__(self):
self.mod = _build_module._BuildModule()
self._get_graph_json = self.mod["get_graph_json"]
self._get_module = self.mod["get_module"]
self._build = self.mod["build"]
self._add_pass = self.mod["add_pass"]
self._disable_pass = self.mod["disable_pass"]
self._set_opt_level = self.mod["set_opt_level"]
self._set_fallback_device = self.mod["set_fallback_device"]
self._set_params_func = self.mod["set_params"]
self._get_params_func = self.mod["get_params"]
def build(self, func, target=None, target_host=None, params=None):
"""
Parameters
----------
func: relay.Function
The function to build.
target : str, :any:`tvm.target.Target`, or dict of str(i.e.
device/context name) to str/tvm.target.Target, optional
For heterogeneous compilation, it is a dictionary indicating context
to target mapping. For homogeneous compilation, it is a build target.
target_host : str or :any:`tvm.target.Target`, optional
Host compilation target, if target is device.
When TVM compiles device specific program such as CUDA,
we also need host(CPU) side code to interact with the driver
to setup the dimensions and parameters correctly.
target_host is used to specify the host side codegen target.
By default, llvm is used if it is enabled,
otherwise a stackvm intepreter is used.
params : dict of str to NDArray
Input parameters to the graph that do not change
during inference time. Used for constant folding.
Returns
-------
graph_json : str
The json string that can be accepted by graph runtime.
mod : tvm.Module
The module containing necessary libraries.
params : dict
The parameters of the final graph.
"""
target = _update_target(target)
# Setup the build configurations passed in through `with build_config`.
self._setup_build_config(params)
# Build the function
self._build(func, target, target_host)
# Get artifacts
graph_json = self.get_json()
mod = self.get_module()
params = self.get_params()
return graph_json, mod, params
def _setup_build_config(self, params):
cfg = BuildConfig.current
# Set opt_level.
self.set_opt_level(cfg.opt_level)
# Set fallback device if it is available.
if cfg.fallback_device:
self.set_fallback_device(cfg.fallback_device)
# Add required passes.
if cfg.add_pass:
passes = set()
if isinstance(cfg.add_pass, (list, tuple, set)):
passes = set(cfg.add_pass)
else:
raise TypeError("add_pass must be list, tuple, or set, but " +
"got {}".format(type(cfg.add_pass)))
for pass_name in passes:
self.add_pass(pass_name)
# Add disabled passes.
if cfg.disable_pass:
passes = set()
if isinstance(cfg.disable_pass, (list, tuple, set)):
passes = set(cfg.disable_pass)
else:
raise TypeError("disable_pass must be list, tuple, or set, " +
"but got {}".format(type(cfg.disable_pass)))
for pass_name in passes:
self.disable_pass(pass_name)
if params:
self._set_params(params)
def _set_params(self, params):
inputs = {}
for name, param in params.items():
if isinstance(param, np.ndarray):
param = _nd.array(param)
inputs[name] = _expr.const(param)
self._set_params_func(inputs)
def add_pass(self, pass_name):
"""Add a pass to the pass list.
Parameters
----------
pass_name : str
The name of the pass that will be added to the list of passes used
for optimizations.
"""
self._add_pass(pass_name)
def disable_pass(self, pass_name):
"""Add a pass to the disabled pass list.
Parameters
----------
pass_name : str
The name of a pass. This pass will be added to the list of passes
that are disabled during optimization.
"""
self._disable_pass(pass_name)
def get_json(self):
"""Return the json file of the built program."""
return self._get_graph_json()
def get_module(self):
"""Return the built module."""
return self._get_module()
def get_params(self):
"""Return the updated weights."""
params = self._get_params_func()
ret = {}
for key, value in params.items():
ret[key] = value.data
return ret
def set_opt_level(self, level):
"""Set the optimization level.
Parameters
----------
level : int
The optimization level for build.
"""
self._set_opt_level(level)
def set_fallback_device(self, fallback_device):
"""Set the fallback device for heterogeneous execution.
Parameters
----------
fallback_device : str or tvm.TVMContext
The fallback device used for heterogeneous execution.
"""
if isinstance(fallback_device, str):
fallback_device = _nd.context(fallback_device)
if not isinstance(fallback_device, TVMContext):
raise TypeError("fallback_device is expected to be str " +
"TVMContext, or dict of device name to target, " +
"but received: {}".format(type(fallback_device)))
self._set_fallback_device(fallback_device.device_type)
def build(func, target=None, target_host=None, params=None):
"""Build a function to run on TVM graph runtime.
"""Helper function that builds a Relay function to run on TVM graph
runtime.
Parameters
----------
......@@ -266,146 +335,28 @@ def build(func, target=None, target_host=None, params=None):
params : dict
The parameters of the final graph.
"""
target = target if target else _target.current_target()
if target is None:
raise ValueError("Target is not set in env or passed as argument.")
target = _update_target(target)
if isinstance(target, dict):
target, fallback_device = _update_heterogeneous_inputs(target)
elif isinstance(target, (str, _target.Target)):
target = _target.create(target)
else:
raise ValueError("target must be the type of str, tvm.target.Target," +
"or dict of device name to target")
if isinstance(target_host, (str, _target.Target)):
target_host = _target.create(target_host)
elif target_host:
raise ValueError("target host must be the type of str, " +
"tvm.target.Target, or None")
# If current dispatch context is fallback context (the default root context),
# then load pre-tuned parameters from TopHub
if isinstance(autotvm.DispatchContext.current, autotvm.FallbackContext):
if isinstance(target, dict):
tophub_context = autotvm.tophub.context(list(target.values()))
else:
tophub_context = autotvm.tophub.context(target)
tophub_context = autotvm.tophub.context(list(target.values()))
else:
tophub_context = autotvm.util.EmptyContext()
cfg = BuildConfig.current
with tophub_context:
func = optimize(func, target, params)
# Annotate the ops for heterogeneous execution.
if isinstance(target, dict):
func, target = _run_device_annotation_passes(func, target,
fallback_device)
# Fuse ops before running code gen
func = ir_pass.infer_type(func)
func = ir_pass.fuse_ops(func, cfg.opt_level)
# Graph code generation
func = ir_pass.infer_type(func)
graph_gen = _graph_gen.GraphRuntimeCodegen(mod=None, target=target)
graph_json, lowered_funcs, params = graph_gen.codegen(func)
mod = _tvm_build_module(
lowered_funcs, target=target, target_host=target_host)
bld_mod = BuildModule()
graph_json, mod, params = bld_mod.build(func, target, target_host,
params)
return graph_json, mod, params
def _update_heterogeneous_inputs(target):
"""Update the target and fallback device required for heterogeneous
compilation. CPU is used as the fallback device if it wasn't provided.
Meanwhile, a CPU device type and "llvm" pair will be added to the target
dictionary in this case.
Parameters
----------
target : dict of str(i.e. device/context name) to str/tvm.target.Target.
A dict contains context to target pairs.
Returns
-------
device_target : dict of int to tvm.target.Target.
The updated device type to target dict.
fallback_device : int
The updated fallback device type.
"""
if not isinstance(target, dict):
raise ValueError("target must be dict of device name to target for " +
"heterogeneous execution, but received %s."
% type(target))
fallback_device = BuildConfig.current.fallback_device
if fallback_device is None:
# cpu is used as the default fallback device when heterogeneous
# execution is needed, but no fallback device is provided.
fallback_device = _nd.cpu(0).device_type
target[fallback_device] = str(_target.create("llvm"))
elif isinstance(fallback_device, str):
fallback_device = _nd.context(fallback_device).device_type
elif isinstance(fallback_device, TVMContext):
fallback_device = fallback_device.device_type
else:
raise ValueError("fallback_device expects the type of str or " +
"TVMContext, but received %s." % type(fallback_device))
device_target = {}
for dev, tgt in target.items():
device_target[_nd.context(dev).device_type] = _target.create(tgt)
if fallback_device not in device_target:
raise ValueError("%s is used as the default device, but the target" +
"is not provided."
% _nd.context(fallback_device).device_name)
return device_target, fallback_device
def _run_device_annotation_passes(func, target, fallback_device):
"""Execute the device annotation passes to update the input program and
target information.
Parameters
----------
func: tvm.relay.Function
The function where annotation passes will be execute at.
target : Dict[int, tvm.target.Target]
A dict contains device type to target pairs.
fallback_device : int
The fallback device type.
Returns
-------
target : Dict[int, tvm.target.Target]
The updated device type to target dict.
func : tvm.relay.Function
The updated func.
"""
func = ir_pass.infer_type(func)
func = ir_pass.rewrite_annotated_ops(func, fallback_device)
device_map = ir_pass.collect_device_info(func)
# The expression to device type map will be empty if all or none of
# the expressions in the `func` are annotated because this map is
# obtained by propagating the device information in the device copy
# operator. None of the above cases needs device copy operator.
if not device_map:
annotation_map = ir_pass.collect_device_annotation_ops(func)
# No annotation.
if not annotation_map:
target = {0: target[fallback_device]}
else:
dev_type = next(iter(annotation_map.values()))
# All annotated with the same device type.
if all(val == dev_type for val in annotation_map.values()):
target = {0: target[dev_type]}
else:
raise RuntimeError("Expressions in the function are "
"annotated with various device types,"
"but not device copy operators "
"found. Please check the "
"RewriteAnnotation pass.")
return func, target
class GraphExecutor(_interpreter.Executor):
"""Wrapper around Executor interface.
......
......@@ -269,6 +269,77 @@ def realize(graph):
return _quantize.realize(graph)
def optimize(func, params=None):
""" Perform "SimplifyInference", "FoldScaleAxis", "FoldConstant", and
"CanonicalizeOps" optimization before quantization.
# TODO(zhiics) These passes are executed one by one so far. We need to
# move them to the pass manager.
Parameters
---------
func: tvm.relay.Function
The original Relay function to be optimized.
params : dict of str to tvm.NDArray
Input parameters to the graph that do not change
during inference time. Used for constant folding.
Returns
-------
ret: tvm.relay.Function
The graph after quantization
"""
opt_passes = ["SimplifyInference",
"FoldScaleAxis",
"FoldConstant",
"CanonicalizeOps"]
cfg = _build.build_config(add_pass=opt_passes)
if params:
name_dict = {}
for arg in func.params:
name = arg.name_hint
if name in name_dict:
name_dict[name] = None
else:
name_dict[name] = arg
bind_dict = {}
for k, v in params.items():
if k not in name_dict:
continue
arg = name_dict[k]
if arg is None:
raise ValueError("Multiple args in the function have name %s" % k)
bind_dict[arg] = _expr.const(v)
func = _expr.bind(func, bind_dict)
if "SimplifyInference" in cfg.add_pass:
func = _ir_pass.infer_type(func)
func = _ir_pass.simplify_inference(func)
if "FoldConstant" in cfg.add_pass:
func = _ir_pass.fold_constant(func)
if "FoldScaleAxis" in cfg.add_pass:
func = _ir_pass.infer_type(func)
func = _ir_pass.backward_fold_scale_axis(func)
func = _ir_pass.infer_type(func)
func = _ir_pass.forward_fold_scale_axis(func)
func = _ir_pass.fold_constant(func)
if "CanonicalizeOps" in cfg.add_pass:
func = _ir_pass.infer_type(func)
func = _ir_pass.canonicalize_ops(func)
if "FoldConstant" in cfg.add_pass:
func = _ir_pass.fold_constant(func)
return func
def quantize(graph, params=None, dataset=None):
""" The quantization procedure. Before running the three main
procedure of quantization, "annotate", "calibrate" and "realize"
......@@ -292,12 +363,8 @@ def quantize(graph, params=None, dataset=None):
ret: Function
The graph after quantization
"""
opt_passes = ["SimplifyInference",
"FoldScaleAxis",
"FoldConstant",
"CanonicalizeOps"]
with _build.build_config(add_pass=opt_passes):
graph = _build.optimize(graph, params=params)
# TODO(zhiics) Move this to the pass manager.
graph = optimize(graph, params)
graph = annotate(graph)
graph = calibrate(graph, dataset)
......
......@@ -311,7 +311,7 @@ bool LLVMEnabled() {
/*! \return The default host target for a given device target */
Target DefaultTargetHost(Target target) {
if (target->device_type == kDLCPU) {
if (target.defined() && target->device_type == kDLCPU) {
return target;
} else {
if (LLVMEnabled()) {
......
......@@ -38,54 +38,31 @@ namespace tvm {
namespace relay {
namespace backend {
using TargetsMap = Map<tvm::Integer, tvm::Target>;
/*!
* \brief Context name / index
* See: python/tvm/_ffi/runtime_ctypes.py
* \brief Context index to Target
*/
struct ContextMap {
static const std::unordered_map<int, std::string> mask2str;
static const std::unordered_map<std::string, int> str2mask;
static std::string Mask2Str(int mask) {
struct ContextTargetMap {
static const std::unordered_map<int, tvm::Target> mask2str;
static tvm::Target Mask2Str(int mask) {
CHECK_GT(mask2str.count(mask), 0) << "Unknown mask.";
return mask2str.at(mask);
}
static int Str2Mask(const std::string& str) {
CHECK_GT(str2mask.count(str), 0) << "Unknown context.";
return str2mask.at(str);
}
};
const std::unordered_map<int, std::string> ContextMap::mask2str = {
{1, "cpu"},
{2, "gpu"},
{4, "opencl"},
{5, "aocl"},
{6, "sdaccel"},
{7, "vulkan"},
{8, "metal"},
{9, "vpi"},
{10, "rocm"},
{11, "opengl"},
{12, "ext_dev"}
};
const std::unordered_map<std::string, int> ContextMap::str2mask = {
{"llvm", 1},
{"cpu", 1},
{"c", 1},
{"gpu", 2},
{"cuda", 2},
{"nvptx", 2},
{"cl", 4},
{"opencl", 4},
{"aocl", 5},
{"aocl_sw_emu", 5},
{"vulkan", 7},
{"metal", 8},
{"vpi", 9},
{"rocm", 10},
{"opengl", 11},
{"ext_dev", 12}
const std::unordered_map<int, tvm::Target> ContextTargetMap::mask2str = {
{1, tvm::Target::create("llvm")},
{2, tvm::Target::create("cuda")},
{4, tvm::Target::create("opencl")},
{5, tvm::Target::create("aocl")},
{6, tvm::Target::create("sdaccel")},
{7, tvm::Target::create("vulkan")},
{8, tvm::Target::create("metal")},
{9, tvm::Target::create("vpi")},
{10, tvm::Target::create("rocm")},
{11, tvm::Target::create("opengl")},
{12, tvm::Target::create("ext_dev")}
};
/*!
......@@ -137,7 +114,7 @@ struct BuildOutput {
*/
struct RelayBuildConfig {
int opt_level{2};
std::string fallback_device{"llvm"};
int fallback_device{static_cast<int>(kDLCPU)};
std::unordered_set<std::string> enabled_pass;
std::unordered_set<std::string> disabled_pass;
OptPassLevel OPT_PASS_LEVEL;
......@@ -164,14 +141,8 @@ struct GraphCodegen {
}
~GraphCodegen() {}
void Init(runtime::Module* m,
Map<HalideIR::Expr, HalideIR::Expr> targets) {
Array<HalideIR::Expr> tgts;
for (auto kv : targets) {
tgts.push_back(kv.first);
tgts.push_back(kv.second);
}
CallFunc("init", m, tgts);
void Init(runtime::Module* m, TargetsMap targets) {
CallFunc("init", m, targets);
}
void Codegen(const Function& func) {
......@@ -248,14 +219,7 @@ class RelayBuildModule : public runtime::ModuleNode {
} else if (name == "build") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
CHECK_EQ(args.num_args, 3);
Array<HalideIR::Expr> tmp = args[1];
std::unordered_map<std::string, std::string> targets;
for (size_t i = 0; i < tmp.size(); i += 2) {
auto k = tmp[i].as<ir::StringImm>()->value;
auto v = tmp[i + 1].as<ir::StringImm>()->value;
targets[k] = v;
}
this->Build(args[0], targets, args[2]);
this->Build(args[0], args[1], args[2]);
});
} else if (name == "list_params") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
......@@ -273,7 +237,8 @@ class RelayBuildModule : public runtime::ModuleNode {
});
} else if (name == "set_fallback_device") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
std::string dev = args[0];
CHECK_EQ(args.num_args, 1);
int dev = args[0];
this->SetFallBackDev(dev);
});
} else if (name == "add_pass") {
......@@ -328,7 +293,7 @@ class RelayBuildModule : public runtime::ModuleNode {
*
* \param device name
*/
void SetFallBackDev(const std::string& dev) {
void SetFallBackDev(int dev) {
cfg_.fallback_device = dev;
}
/*!
......@@ -402,8 +367,8 @@ class RelayBuildModule : public runtime::ModuleNode {
* \param target_host Host target device
*/
void Build(Function func,
const std::unordered_map<std::string, std::string>& targets,
const std::string& target_host) {
const TargetsMap& targets,
const tvm::Target& target_host) {
targets_ = targets;
target_host_ = target_host;
BuildRelay(func, cfg_, params_);
......@@ -416,8 +381,9 @@ class RelayBuildModule : public runtime::ModuleNode {
* \param params params dict
* \return relay::Function
*/
relay::Function BindParamsByName(relay::Function func,
const std::unordered_map<std::string, runtime::NDArray>& params) {
relay::Function BindParamsByName(
relay::Function func,
const std::unordered_map<std::string, runtime::NDArray>& params) {
std::unordered_map<std::string, relay::Var> name_dict;
std::unordered_set<relay::Var, NodeHash, NodeEqual> repeat_var;
for (auto arg : func->params) {
......@@ -454,7 +420,7 @@ class RelayBuildModule : public runtime::ModuleNode {
* \return relay::Function
*/
relay::Function Optimize(relay::Function func,
const std::unordered_map<std::string, std::string>& targets,
const TargetsMap& targets,
const RelayBuildConfig& cfg,
const std::unordered_map<std::string, runtime::NDArray>& params) {
if (params.size()) {
......@@ -507,8 +473,7 @@ class RelayBuildModule : public runtime::ModuleNode {
auto enter_pf = GetPackedFunc("_EnterTargetScope");
auto exit_pf = GetPackedFunc("_ExitTargetScope");
for (const auto& kv : targets) {
auto target = Target::create(kv.second);
(*enter_pf)(target);
(*enter_pf)(kv.second);
func = CallPackedFunc("relay._ir_pass.AlterOpLayout", func);
(*exit_pf)();
}
......@@ -530,25 +495,19 @@ class RelayBuildModule : public runtime::ModuleNode {
*
* \param targets dictionary
* \param cfg
* \return Map<HalideIR::Expr, HalideIR::Expr>
* \return Map<tvm::Integer, tvm::Target>
*/
Map<HalideIR::Expr, HalideIR::Expr> UpdateHeterogeneousInputs(
const std::unordered_map<std::string, std::string>& targets,
const RelayBuildConfig& cfg) {
Map<HalideIR::Expr, HalideIR::Expr> device_target;
std::unordered_map<int64_t, std::string> tmp_map;
auto fallback_idx = ContextMap::Str2Mask(cfg.fallback_device);
TargetsMap UpdateHeterogeneousInputs(const TargetsMap& targets,
const RelayBuildConfig& cfg) {
TargetsMap device_target = targets;
std::unordered_map<int64_t, tvm::Target> tmp_map;
for (const auto& kv : targets) {
tmp_map[ContextMap::Str2Mask(kv.first)] = kv.second;
}
if (tmp_map.count(fallback_idx) == 0) {
tmp_map[fallback_idx] = cfg.fallback_device;
tmp_map[kv.first->value] = kv.second;
}
for (const auto& kv : tmp_map) {
if (tmp_map.count(cfg.fallback_device) == 0) {
device_target.Set(
ir::IntImm::make(HalideIR::Int(64), kv.first),
ir::StringImm::make(kv.second));
cfg.fallback_device,
ContextTargetMap::Mask2Str(cfg.fallback_device));
}
return device_target;
}
......@@ -561,25 +520,19 @@ class RelayBuildModule : public runtime::ModuleNode {
* \param targets_map_ptr
* \return Function
*/
Function RunDeviceAnnotationPass(
Function func,
const RelayBuildConfig& cfg,
Map<HalideIR::Expr, HalideIR::Expr>* targets_map_ptr) {
auto fallback_idx = ContextMap::Str2Mask(cfg.fallback_device);
Function RunDeviceAnnotationPass(Function func, const RelayBuildConfig& cfg,
TargetsMap* targets_map_ptr) {
func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr);
func = CallPackedFunc("relay._ir_pass.RewriteDeviceAnnotation", func, fallback_idx);
auto device_map = CallPackedFunc<Map<Expr, Integer> >("relay._ir_pass.CollectDeviceInfo",
func,
nullptr);
func = CallPackedFunc("relay._ir_pass.RewriteDeviceAnnotation", func,
cfg.fallback_device);
auto device_map = CallPackedFunc<Map<Expr, Integer> >(
"relay._ir_pass.CollectDeviceInfo", func, nullptr);
if (device_map.size() == 0) {
auto annotation_map =
CallPackedFunc<Map<Expr, Integer> >("relay._ir_pass.CollectDeviceAnnotationOps",
func,
nullptr);
auto annotation_map = CallPackedFunc<Map<Expr, Integer> >(
"relay._ir_pass.CollectDeviceAnnotationOps", func, nullptr);
if (annotation_map.size() == 0) {
targets_map_ptr->Set(
ir::IntImm::make(HalideIR::Int(64), 0),
ir::StringImm::make(cfg.fallback_device));
0, ContextTargetMap::Mask2Str(cfg.fallback_device));
} else {
int64_t dev_type = -1;
for (auto kv : annotation_map) {
......@@ -594,9 +547,7 @@ class RelayBuildModule : public runtime::ModuleNode {
<< "found. Please check the "
<< "RewriteAnnotation pass.";
}
targets_map_ptr->Set(
ir::IntImm::make(HalideIR::Int(64), 0),
ir::StringImm::make(ContextMap::Mask2Str(dev_type)));
targets_map_ptr->Set(0, ContextTargetMap::Mask2Str(dev_type));
}
}
return func;
......@@ -614,15 +565,11 @@ class RelayBuildModule : public runtime::ModuleNode {
const std::unordered_map<std::string, tvm::runtime::NDArray> &params) {
// convert
tvm_cfg_ = build_config();
Map<HalideIR::Expr, HalideIR::Expr> device_target;
TargetsMap device_target;
if (targets_.size() > 1) {
device_target = UpdateHeterogeneousInputs(targets_, cfg);
} else {
for (auto &kv : targets_) {
device_target.Set(
ir::IntImm::make(HalideIR::Int(64), ContextMap::Str2Mask(kv.first)),
ir::StringImm::make(kv.second));
}
device_target = targets_;
}
func = Optimize(func, targets_, cfg, params);
if (device_target.size() > 1) {
......@@ -640,16 +587,15 @@ class RelayBuildModule : public runtime::ModuleNode {
ret_.graph_json = graph_codegen_->GetJSON();
ret_.params = graph_codegen_->GetParams();
auto target_host = Target::create(target_host_);
ret_.mod = tvm::build(graph_codegen_->GetLoweredFunc(), target_host, tvm_cfg_);
ret_.mod = tvm::build(graph_codegen_->GetLoweredFunc(), target_host_, tvm_cfg_);
}
protected:
std::unique_ptr<GraphCodegen> graph_codegen_;
/*! \brief target device */
std::unordered_map<std::string, std::string> targets_;
TargetsMap targets_;
/*! \brief target host device */
std::string target_host_;
tvm::Target target_host_;
/*! \brief frontend optimization configure */
RelayBuildConfig cfg_;
/*! \brief parameters */
......
......@@ -51,7 +51,7 @@ using GraphAttrs = std::unordered_map<std::string, dmlc::any>;
using GraphNodePtr = std::shared_ptr<GraphNode>;
using GraphInputNodePtr = std::shared_ptr<GraphInputNode>;
using GraphOpNodePtr = std::shared_ptr<GraphOpNode>;
using TargetsMap = std::unordered_map<std::string, Target>;
using TargetsMap = std::unordered_map<int, Target>;
/*! \brief Lowered outputs */
struct LoweredOutput {
......@@ -193,12 +193,10 @@ class GraphOpNode : public GraphNode {
class GraphRuntimeCodegen
: public ::tvm::relay::ExprFunctor<std::vector<GraphNodeRef>(const Expr&)> {
public:
GraphRuntimeCodegen(runtime::Module* mod,
const std::unordered_map<std::string, std::string>& targets) : mod_(mod) {
GraphRuntimeCodegen(runtime::Module* mod, const TargetsMap& targets)
: mod_(mod) {
compile_engine_ = CompileEngine::Global();
for (auto &kv : targets) {
targets_[kv.first] = Target::create(kv.second);
}
targets_ = targets;
}
LoweredOutput Codegen(relay::Function func) {
......@@ -406,7 +404,7 @@ class GraphRuntimeCodegen
auto pf0 = GetPackedFunc("relay.backend._make_CCacheKey");
auto pf1 = GetPackedFunc("relay.backend._CompileEngineLower");
auto &device_type = storage_device_map_[expr][1];
auto call_dev_type = device_type[0]->value; //-> int to string
auto call_dev_type = device_type[0]->value;
Target target;
if (targets_.size() == 1) {
// homogeneous execution.
......@@ -415,22 +413,17 @@ class GraphRuntimeCodegen
}
} else {
// heterogeneous execution.
const auto call_dev_key = std::to_string(call_dev_type);
std::string call_dev_name;
if (call_dev_type == 0) {
call_dev_name = "llvm";
} else {
call_dev_name = runtime::DeviceName(call_dev_type);
}
if (targets_.count(call_dev_name) == 0 && targets_.count(call_dev_key) == 0) {
if (targets_.count(call_dev_type) == 0) {
LOG(FATAL) << "No target is provided for device "
<< call_dev_name;
}
if (targets_.count(call_dev_key)) {
target = targets_[call_dev_key];
} else {
target = targets_[call_dev_name];
}
target = targets_[call_dev_type];
}
CCacheKey key = (*pf0)(func, target);
CachedFunc lowerd_func = (*pf1)(compile_engine_, key);
......@@ -604,30 +597,21 @@ class GraphRuntimeCodegenModule : public runtime::ModuleNode {
virtual PackedFunc GetFunction(const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) {
if (name == "init") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
CHECK_EQ(args.num_args, 2) << "The expected of arguments are: "
<< "runtime::Module mod and Map<str, StringImm> targets";
void* mod = args[0];
auto& sptr = args[1].node_sptr();
auto* node = static_cast<const ArrayNode*>(sptr.get());
auto& tmp_targets = node->data;
std::unordered_map<std::string, std::string> targets;
for (size_t i = 0; i < tmp_targets.size(); i += 2) {
std::string key;
auto sk = Expr(tmp_targets[i]).as<ir::StringImm>();
auto ik = Expr(tmp_targets[i]).as<ir::IntImm>();
if (sk) {
key = sk->value;
}
if (ik) {
key = std::to_string(ik->value);
}
auto v = Expr(tmp_targets[i + 1]).as<ir::StringImm>();
targets[key] = v->value;
}
codegen_ = std::make_shared<GraphRuntimeCodegen>(
reinterpret_cast<runtime::Module*>(mod), targets);
});
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
CHECK_EQ(args.num_args, 2)
<< "The expected of arguments are: "
<< "runtime::Module mod and Map<int, Target> targets";
void* mod = args[0];
Map<Integer, tvm::Target> tmp = args[1];
TargetsMap targets;
for (const auto& it : tmp) {
auto dev_type = it.first.as<ir::IntImm>();
CHECK(dev_type);
targets[dev_type->value] = it.second;
}
codegen_ = std::make_shared<GraphRuntimeCodegen>(
reinterpret_cast<runtime::Module*>(mod), targets);
});
} else if (name == "codegen") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
Function func = args[0];
......
......@@ -18,6 +18,7 @@
*/
#include <gtest/gtest.h>
#include <tvm/build_module.h>
#include <tvm/tvm.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/type.h>
......@@ -73,10 +74,10 @@ TEST(Relay, BuildModule) {
auto build_f = build_mod.GetFunction("build", false);
auto json_f = build_mod.GetFunction("get_graph_json", false);
auto mod_f = build_mod.GetFunction("get_module", false);
Array<HalideIR::Expr> target_pair;
target_pair.push_back(ir::StringImm::make("cpu"));
target_pair.push_back(ir::StringImm::make("llvm"));
build_f(func, target_pair, "llvm");
Map<tvm::Integer, tvm::Target> targets;
Target llvm_tgt = Target::create("llvm");
targets.Set(0, llvm_tgt);
build_f(func, targets, llvm_tgt);
std::string json = json_f();
tvm::runtime::Module mod = mod_f();
// run
......
......@@ -74,13 +74,12 @@ def test_alter_layout_conv2d():
for tgt in targets:
with tvm.target.create(tgt) as target:
with relay.build_config(opt_level=-1, add_pass='AlterOpLayout'):
with autotvm.tophub.context(target):
O = relay.optimize(N, target, params=None)
O = relay.ir_pass.infer_type(O)
with autotvm.tophub.context(target):
O = relay.ir_pass.alter_op_layout(N)
O = relay.ir_pass.infer_type(O)
# graph should differ
assert not relay.ir_pass.alpha_equal(N, O)
# graph should differ
assert not relay.ir_pass.alpha_equal(N, O)
if __name__ == "__main__":
np.random.seed(42)
......
......@@ -18,55 +18,10 @@ import numpy as np
import tvm
from tvm import relay
from tvm.contrib.nvcc import have_fp16
from tvm._ffi.function import _init_api
_init_api("tvm.relay.build_module")
class BuildModule(object):
def __init__(self):
self.mod = relay.build_module._BuildModule()
self._get_graph_json = self.mod["get_graph_json"]
self._get_module = self.mod["get_module"]
self._build = self.mod["build"]
self._set_opt_level = self.mod["set_opt_level"]
self._set_params_func = self.mod["set_params"]
self._get_params_func = self.mod["get_params"]
def build(self, func, target, target_host, params):
tgts = []
for kv in target.items():
tgts.append(kv[0])
tgts.append(kv[1])
self._set_params(params)
self._build(func, tgts, target_host)
def get_json(self):
return self._get_graph_json()
def get_module(self):
return self._get_module()
def set_opt_level(self, level):
self._set_opt_level(level)
def _set_params(self, params):
inputs = {}
for name, param in params.items():
inputs[name] = relay.Constant(param)
self._set_params_func(inputs)
def get_params(self):
params = self._get_params_func()
ret = {}
for key, value in params.items():
ret[key] = value.data
return ret
def test_build():
m_bld = BuildModule()
tgt_name = "llvm"
def test_basic_build():
tgt = "llvm"
ctx = tvm.cpu()
# func
......@@ -86,21 +41,96 @@ def test_build():
}
# build
targets = {
tgt: tgt
tvm.expr.IntImm("int32", ctx.device_type): tgt
}
m_bld.set_opt_level(3)
m_bld.build(func, targets, "llvm", params=params)
g_json = m_bld.get_json()
mmod = m_bld.get_module()
params = m_bld.get_params()
g_json, mmod, params = relay.build(func, targets, "llvm", params=params)
# test
rt = tvm.contrib.graph_runtime.create(g_json, mmod, ctx)
rt.set_input("a", A)
rt.load_params(relay.save_param_dict(params))
rt.run()
out = rt.get_output(0)
np.testing.assert_allclose(out.asnumpy(),
np.maximum(np.dot(A.asnumpy(), B.asnumpy().T), 0) + C.asnumpy(), atol=1e-5, rtol=1e-5)
np.testing.assert_allclose(out.asnumpy(), np.maximum(np.dot(A.asnumpy(),
B.asnumpy().T),
0) + C.asnumpy(),
atol=1e-5, rtol=1e-5)
def test_fp16_build():
dtype = "float16"
if not tvm.module.enabled("cuda") or not tvm.gpu(0).exist:
print("skip because cuda is not enabled.")
return
ctx = tvm.gpu(0)
if dtype == "float16" and not have_fp16(ctx.compute_version):
print("skip because gpu does not support fp16")
return
x = relay.var("x", dtype=dtype, shape=(4, 4))
y = relay.var("y", dtype=dtype, shape=(4, 4))
z = x + y
func = relay.Function([x, y], z)
X = tvm.nd.array(np.random.uniform(-1, 1, (4, 4)).astype(dtype), ctx=ctx)
Y = tvm.nd.array(np.random.uniform(-1, 1, (4, 4)).astype(dtype), ctx=ctx)
params = {
"x": X,
"y": Y,
}
# build
g_json, mmod, params = relay.build(func, "cuda", params=params)
# test
rt = tvm.contrib.graph_runtime.create(g_json, mmod, ctx)
rt.load_params(relay.save_param_dict(params))
rt.run()
out = rt.get_output(0)
np.testing.assert_allclose(out.asnumpy(), X.asnumpy() + Y.asnumpy(),
atol=1e-5, rtol=1e-5)
def test_fp16_conversion():
def check_conversion(tgt, ctx):
if not tvm.module.enabled(tgt):
print("skip because {} is not enabled.".format(tgt))
return
elif tgt == "cuda" and ctx.exist and not have_fp16(ctx.compute_version):
print("skip because gpu does not support fp16")
return
n = 10
for (src, dst) in [('float32', 'float16'), ('float16', 'float32')]:
x = relay.var("x", relay.TensorType((n,), src))
y = x.astype(dst)
func = relay.Function([x], y)
# init input
X = tvm.nd.array(n * np.random.randn(n).astype(src) - n / 2)
# build
with relay.build_config(opt_level=1):
g_json, mmod, params = relay.build(func, tgt)
# test
rt = tvm.contrib.graph_runtime.create(g_json, mmod, ctx)
rt.set_input("x", X)
rt.run()
out = rt.get_output(0)
np.testing.assert_allclose(out.asnumpy(), X.asnumpy().astype(dst),
atol=1e-5, rtol=1e-5)
for target, ctx in [('llvm', tvm.cpu()), ('cuda', tvm.gpu())]:
check_conversion(target, ctx)
if __name__ == "__main__":
test_basic_build()
test_fp16_build()
test_fp16_conversion()
......@@ -411,7 +411,7 @@ def run_fusible_network(dev, tgt):
expected_index)
def test_fallback_all_operators(device, tgt):
target = {device: tgt}
target = {device: tgt, "cpu": "llvm"}
annotated_func = get_func()
expected_func = get_func()
check_annotated_graph(annotated_func, expected_func)
......
......@@ -47,54 +47,54 @@ def test_simulated_quantize():
assert out.args[3].checked_type == relay.ty.TensorType(tuple(), "float32")
# def test_quantize_pass():
# def quantize_weight(arr):
# maximum = np.amax(np.abs(arr.asnumpy()))
# scale = 2**math.ceil(math.log(maximum, 2))
# out = np.around(arr.asnumpy() / scale * 128).astype('int8')
# out = np.clip(out, -127, 127)
# return relay.const(out, 'int8')
#
# n, c, h, w = 1, 3, 224, 224
# def make_graph(data):
# weight = relay.var("conv_weight")
# out = relay.nn.conv2d(data, weight, kernel_size=(3, 3), padding=(1, 1), channels=c)
# out = relay.Function(relay.ir_pass.free_vars(out), out)
# return out
#
# def make_qgraph(data, weight):
# out = data * relay.const(32.0)
# out = relay.round(out)
# out = relay.clip(out, a_min=-127, a_max=127)
# out = out.astype('int8')
#
# out = relay.nn.conv2d(out, weight, kernel_size=(3, 3),
# padding=(1, 1), channels=c, out_dtype='int32')
# out = out.astype('float32')
# out = relay.multiply(out, relay.const(0.00024414062))
# out = relay.Function(relay.ir_pass.free_vars(out), out)
# return out
#
# data = relay.var("data", relay.TensorType((n, c, h, w), "float32"))
# graph = make_graph(data)
# dataset, params = make_dataset(graph, 10)
#
# with qtz.qconfig(skip_k_conv=0, global_scale=4.0,
# round_for_shift=False, store_lowbit_output=False):
# qgraph0 = qtz.quantize(graph, params)
# qgraph0 = relay.ir_pass.infer_type(qgraph0)
#
# conv_weight = quantize_weight(params['conv_weight'])
# qgraph1 = make_qgraph(data, conv_weight)
# qgraph1 = relay.ir_pass.infer_type(qgraph1)
#
# graph = relay.create_executor('graph')
# res0 = graph.evaluate(qgraph0)(dataset[0]['data'])
# res1 = graph.evaluate(qgraph1)(dataset[0]['data'])
# tvm.testing.assert_allclose(res0.asnumpy(), res1.asnumpy(), rtol=1e-3)
def test_quantize_pass():
def quantize_weight(arr):
maximum = np.amax(np.abs(arr.asnumpy()))
scale = 2**math.ceil(math.log(maximum, 2))
out = np.around(arr.asnumpy() / scale * 128).astype('int8')
out = np.clip(out, -127, 127)
return relay.const(out, 'int8')
n, c, h, w = 1, 3, 224, 224
def make_graph(data):
weight = relay.var("conv_weight")
out = relay.nn.conv2d(data, weight, kernel_size=(3, 3), padding=(1, 1), channels=c)
out = relay.Function(relay.ir_pass.free_vars(out), out)
return out
def make_qgraph(data, weight):
out = data * relay.const(32.0)
out = relay.round(out)
out = relay.clip(out, a_min=-127, a_max=127)
out = out.astype('int8')
out = relay.nn.conv2d(out, weight, kernel_size=(3, 3),
padding=(1, 1), channels=c, out_dtype='int32')
out = out.astype('float32')
out = relay.multiply(out, relay.const(0.00024414062))
out = relay.Function(relay.ir_pass.free_vars(out), out)
return out
data = relay.var("data", relay.TensorType((n, c, h, w), "float32"))
graph = make_graph(data)
dataset, params = make_dataset(graph, 10)
with qtz.qconfig(skip_k_conv=0, global_scale=4.0,
round_for_shift=False, store_lowbit_output=False):
qgraph0 = qtz.quantize(graph, params)
qgraph0 = relay.ir_pass.infer_type(qgraph0)
conv_weight = quantize_weight(params['conv_weight'])
qgraph1 = make_qgraph(data, conv_weight)
qgraph1 = relay.ir_pass.infer_type(qgraph1)
graph = relay.create_executor('graph')
res0 = graph.evaluate(qgraph0)(dataset[0]['data'])
res1 = graph.evaluate(qgraph1)(dataset[0]['data'])
tvm.testing.assert_allclose(res0.asnumpy(), res1.asnumpy(), rtol=1e-3)
if __name__ == "__main__":
np.random.seed(42)
test_simulated_quantize()
# test_quantize_pass()
test_quantize_pass()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment