Commit 0f2a3086 authored by Zhi, committed by Tianqi Chen

[Relay][Compilation] replace relay.build_module with C++ BuildModule (#3174)

parent 7d845f0d
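This change moves Relay's graph compilation out of the Python-side relay.build_module pipeline and into a C++ RelayBuildModule exposed over the FFI: relay.optimize and the Python pass_enabled machinery are removed, and relay.build now drives the C++ module. A rough sketch of the user-facing flow after this patch (the tensor shapes and variable names below are illustrative, not taken from the diff):

import numpy as np
import tvm
from tvm import relay
from tvm.contrib import graph_runtime

# target can be a plain string, or a dict mapping device to target for
# heterogeneous compilation (see _update_target in the diff below).
x = relay.var("x", shape=(1, 10))
y = relay.var("y", shape=(1, 10))
func = relay.Function([x, y], x + y)

with relay.build_config(opt_level=3):
    graph_json, lib, params = relay.build(func, target="llvm")

# The returned artifacts feed the graph runtime exactly as before.
rt = graph_runtime.create(graph_json, lib, tvm.cpu())
rt.set_input("x", np.ones((1, 10), dtype="float32"))
rt.set_input("y", np.ones((1, 10), dtype="float32"))
rt.run()
out = rt.get_output(0)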
@@ -25,7 +25,7 @@ from . import expr_functor
 from . import module
 from . import adt
 from . import ir_pass
-from .build_module import build, build_config, create_executor, optimize
+from .build_module import build, build_config, create_executor
 from . import prelude
 from . import parser
 from . import debug
...
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable
"""The interface for building Relay functions exposed from C++."""
from tvm._ffi.function import _init_api
_init_api("relay.build_module", __name__)
@@ -36,12 +36,9 @@ contrib.graph_runtime or any other TVM runtime compatible systems.
 from __future__ import absolute_import
 from tvm.ndarray import empty
-from tvm._ffi.function import _init_api
 from tvm.relay import build_module
 from tvm import target as _target
+from tvm import expr as _expr
-_init_api("tvm.relay.build_module")
 
 class GraphRuntimeCodegen(object):
     """The compiler from Relay to the TVM runtime system."""
@@ -57,17 +54,14 @@ class GraphRuntimeCodegen(object):
         self._setup(mod, target)
 
     def _setup(self, mod, target):
-        tgts = []
+        tgts = {}
         if isinstance(target, dict):
-            for kv in target.items():
-                tgts.append(kv[0])
-                if isinstance(kv[1], (str, _target.Target)):
-                    tgts.append(str(kv[1]))
-                else:
-                    raise Exception("Unknown target type")
+            for dev, tgt in target.items():
+                if not isinstance(tgt, (str, _target.Target)):
+                    raise Exception("Unknown target type")
+                tgts[dev] = _target.create(tgt)
         elif isinstance(target, (str, _target.Target)):
-            tgts.append("0")
-            tgts.append(str(target))
+            tgts[_expr.IntImm("int32", 0)] = _target.create(target)
         self._init(mod, tgts)
 
     def codegen(self, func):
...
@@ -18,32 +18,19 @@
 Construct the necessary state for the TVM graph runtime
 from a Relay expression.
 """
-import warnings
+import numpy as np
 from tvm._ffi.runtime_ctypes import TVMContext
-from ..build_module import build as _tvm_build_module
+from tvm import expr as tvm_expr
 from .. import nd as _nd, target as _target, autotvm
 from ..contrib import graph_runtime as _graph_rt
+from . import _build_module
 from . import ir_pass
-from . import expr as _expr
 from . import ty as _ty
+from . import expr as _expr
 from .backend import interpreter as _interpreter
-from .backend import graph_runtime_codegen as _graph_gen
 from .backend.vm import VMExecutor
 
-# List of optimization pass and level when switch on
-OPT_PASS_LEVEL = {
-    "SimplifyInference": 0,
-    "OpFusion": 1,
-    "FoldConstant": 2,
-    "CombineParallelConv2D": 3,
-    "FoldScaleAxis": 3,
-    "AlterOpLayout": 3,
-    "CanonicalizeOps": 3,
-    "EliminateCommonSubexpr": 3,
-}
 
 class BuildConfig(object):
     """Configuration scope to set a build config option.
@@ -56,6 +43,7 @@ class BuildConfig(object):
     defaults = {
         "opt_level": 2,
         "add_pass": None,
+        "disable_pass": None,
         "fallback_device": None,
     }
@@ -85,23 +73,6 @@ class BuildConfig(object):
         assert self._old_scope
         BuildConfig.current = self._old_scope
 
-    def pass_enabled(self, pass_name):
-        """Get whether pass is enabled.
-
-        Parameters
-        ----------
-        pass_name : str
-            The optimization pass name
-
-        Returns
-        -------
-        enabled : bool
-            Whether pass is enabled.
-        """
-        if self.add_pass and pass_name in self.add_pass:
-            return True
-        return self.opt_level >= OPT_PASS_LEVEL[pass_name]
-
 
 BuildConfig.current = BuildConfig()
 
@@ -117,6 +88,9 @@ def build_config(**kwargs):
     add_pass: set of str
         Optimization pass to be added regardless of optimization level.
 
+    disable_pass: set of str
+        Optimization pass to be disabled during optimization.
+
     fallback_device : str or tvm.TVMContext
         The fallback device. It is also used as the default device for
         operators without specified device during heterogeneous execution.
@@ -129,108 +103,203 @@ def build_config(**kwargs):
     return BuildConfig(**kwargs)

-def _bind_params_by_name(func, params):
-    """Bind parameters of function by its name."""
-    name_dict = {}
-    for arg in func.params:
-        name = arg.name_hint
-        if name in name_dict:
-            name_dict[name] = None
-        else:
-            name_dict[name] = arg
-    bind_dict = {}
-    for k, v in params.items():
-        if k not in name_dict:
-            continue
-        arg = name_dict[k]
-        if arg is None:
-            raise ValueError("Multiple args in the function have name %s" % k)
-        bind_dict[arg] = _expr.const(v)
-    return _expr.bind(func, bind_dict)
-
-
-def optimize(func, target=None, params=None):
-    """Perform target invariant optimizations.
-
-    Parameters
-    ----------
-    func : tvm.relay.Function
-        The input to optimization.
-
-    target : Optional[:any:`tvm.target.Target`, Dict[int, tvm.target.Target]]
-        The optimization target. For heterogeneous compilation, it is a
-        dictionary mapping device type to compilation target. For homogeneous
-        compilation, it is a build target.
-
-    params : Optional[Dict[str, tvm.nd.NDArray]]
-        Input parameters to the graph that do not change
-        during inference time. used for constant folding.
-
-    Returns
-    -------
-    opt_func : tvm.relay.Function
-        The optimized version of the function.
-    """
-    cfg = BuildConfig.current
-
-    # bind expressions
-    if params:
-        func = _bind_params_by_name(func, params)
-
-    if cfg.pass_enabled("SimplifyInference"):
-        func = ir_pass.infer_type(func)
-        func = ir_pass.simplify_inference(func)
-
-    if cfg.pass_enabled("EliminateCommonSubexpr"):
-        def fskip(expr):
-            if isinstance(expr, _expr.Call) and expr.op.name == 'cast' and \
-                expr.attrs.dtype == 'int32':
-                return True
-            return False
-
-        func = ir_pass.infer_type(func)
-        func = ir_pass.eliminate_common_subexpr(func, fskip)
-
-    if cfg.pass_enabled("CombineParallelConv2D"):
-        func = ir_pass.infer_type(func)
-        func = ir_pass.combine_parallel_conv2d(func)
-
-    # The constant folding pass is necessary because FoldScaleAxis pass needs
-    # to check the constantness and positiveness of scales.
-    if cfg.pass_enabled("FoldConstant"):
-        func = ir_pass.fold_constant(func)
-
-    if cfg.pass_enabled("FoldScaleAxis"):
-        func = ir_pass.infer_type(func)
-        func = ir_pass.backward_fold_scale_axis(func)
-        func = ir_pass.infer_type(func)
-        func = ir_pass.forward_fold_scale_axis(func)
-        func = ir_pass.fold_constant(func)
-
-    if cfg.pass_enabled("CanonicalizeOps"):
-        func = ir_pass.infer_type(func)
-        func = ir_pass.canonicalize_ops(func)
-
-    # FIXME(zhiics) Skip AlterOpLayout pass for heterogeneous compilation for
-    # now. We probably need to pass target to this pass as well. Fix it in
-    # a followup PR.
-    if cfg.pass_enabled("AlterOpLayout"):
-        if isinstance(target, _target.Target):
-            func = ir_pass.infer_type(func)
-            with target:
-                func = ir_pass.alter_op_layout(func)
-        elif isinstance(target, dict):
-            warnings.warn("AlterOpLayout pass is not enabled for heterogeneous"
-                          " execution yet.")
-
-    if cfg.pass_enabled("FoldConstant"):
-        func = ir_pass.fold_constant(func)
-
-    return func
-
-
+def _update_target(target):
+    target = target if target else _target.current_target()
+    if target is None:
+        raise ValueError("Target is not set in env or passed as argument.")
+
+    tgts = {}
+    if isinstance(target, (str, _target.Target)):
+        dev_type = tvm_expr.IntImm("int32", _nd.context(str(target)).device_type)
+        tgts[dev_type] = _target.create(target)
+    elif isinstance(target, dict):
+        for dev, tgt in target.items():
+            dev_type = tvm_expr.IntImm("int32", _nd.context(dev).device_type)
+            tgts[dev_type] = _target.create(tgt)
+    else:
+        raise TypeError("target is expected to be str or " +
+                        "tvm.target.Target, but received " +
+                        "{}".format(type(target)))
+    return tgts
+
+
+class BuildModule(object):
+    """Build a Relay function to run on TVM graph runtime. This class is used
+    to expose the `RelayBuildModule` APIs implemented in C++.
+    """
+    def __init__(self):
+        self.mod = _build_module._BuildModule()
+        self._get_graph_json = self.mod["get_graph_json"]
+        self._get_module = self.mod["get_module"]
+        self._build = self.mod["build"]
+        self._add_pass = self.mod["add_pass"]
+        self._disable_pass = self.mod["disable_pass"]
+        self._set_opt_level = self.mod["set_opt_level"]
+        self._set_fallback_device = self.mod["set_fallback_device"]
+        self._set_params_func = self.mod["set_params"]
+        self._get_params_func = self.mod["get_params"]
+
+    def build(self, func, target=None, target_host=None, params=None):
+        """
+        Parameters
+        ----------
+        func: relay.Function
+            The function to build.
+
+        target : str, :any:`tvm.target.Target`, or dict of str(i.e.
+        device/context name) to str/tvm.target.Target, optional
+            For heterogeneous compilation, it is a dictionary indicating context
+            to target mapping. For homogeneous compilation, it is a build target.
+
+        target_host : str or :any:`tvm.target.Target`, optional
+            Host compilation target, if target is device.
+            When TVM compiles device specific program such as CUDA,
+            we also need host(CPU) side code to interact with the driver
+            to setup the dimensions and parameters correctly.
+            target_host is used to specify the host side codegen target.
+            By default, llvm is used if it is enabled,
+            otherwise a stackvm intepreter is used.
+
+        params : dict of str to NDArray
+            Input parameters to the graph that do not change
+            during inference time. Used for constant folding.
+
+        Returns
+        -------
+        graph_json : str
+            The json string that can be accepted by graph runtime.
+
+        mod : tvm.Module
+            The module containing necessary libraries.
+
+        params : dict
+            The parameters of the final graph.
+        """
+        target = _update_target(target)
+
+        # Setup the build configurations passed in through `with build_config`.
+        self._setup_build_config(params)
+        # Build the function
+        self._build(func, target, target_host)
+        # Get artifacts
+        graph_json = self.get_json()
+        mod = self.get_module()
+        params = self.get_params()
+
+        return graph_json, mod, params
def _setup_build_config(self, params):
cfg = BuildConfig.current
# Set opt_level.
self.set_opt_level(cfg.opt_level)
# Set fallback device if it is available.
if cfg.fallback_device:
self.set_fallback_device(cfg.fallback_device)
# Add required passes.
if cfg.add_pass:
passes = set()
if isinstance(cfg.add_pass, (list, tuple, set)):
passes = set(cfg.add_pass)
else:
raise TypeError("add_pass must be list, tuple, or set, but " +
"got {}".format(type(cfg.add_pass)))
for pass_name in passes:
self.add_pass(pass_name)
# Add disabled passes.
if cfg.disable_pass:
passes = set()
if isinstance(cfg.disable_pass, (list, tuple, set)):
passes = set(cfg.disable_pass)
else:
raise TypeError("disable_pass must be list, tuple, or set, " +
"but got {}".format(type(cfg.disable_pass)))
for pass_name in passes:
self.disable_pass(pass_name)
if params:
self._set_params(params)
def _set_params(self, params):
inputs = {}
for name, param in params.items():
if isinstance(param, np.ndarray):
param = _nd.array(param)
inputs[name] = _expr.const(param)
self._set_params_func(inputs)
def add_pass(self, pass_name):
"""Add a pass to the pass list.
Parameters
----------
pass_name : str
The name of the pass that will be added to the list of passes used
for optimizations.
"""
self._add_pass(pass_name)
def disable_pass(self, pass_name):
"""Add a pass to the disabled pass list.
Parameters
----------
pass_name : str
The name of a pass. This pass will be added to the list of passes
that are disabled during optimization.
"""
self._disable_pass(pass_name)
def get_json(self):
"""Return the json file of the built program."""
return self._get_graph_json()
def get_module(self):
"""Return the built module."""
return self._get_module()
def get_params(self):
"""Return the updated weights."""
params = self._get_params_func()
ret = {}
for key, value in params.items():
ret[key] = value.data
return ret
def set_opt_level(self, level):
"""Set the optimization level.
Parameters
----------
level : int
The optimization level for build.
"""
self._set_opt_level(level)
def set_fallback_device(self, fallback_device):
"""Set the fallback device for heterogeneous execution.
Parameters
----------
fallback_device : str or tvm.TVMContext
The fallback device used for heterogeneous execution.
"""
if isinstance(fallback_device, str):
fallback_device = _nd.context(fallback_device)
if not isinstance(fallback_device, TVMContext):
raise TypeError("fallback_device is expected to be str " +
"TVMContext, or dict of device name to target, " +
"but received: {}".format(type(fallback_device)))
self._set_fallback_device(fallback_device.device_type)
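# The module-level build() helper below is a thin wrapper over this class.
# A minimal sketch of driving BuildModule directly (func/params are assumed
# to be prepared by the caller; names are illustrative):
#
#     bld_mod = BuildModule()
#     graph_json, mod, params = bld_mod.build(func, target="llvm",
#                                             target_host=None, params=params)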
 def build(func, target=None, target_host=None, params=None):
-    """Build a function to run on TVM graph runtime.
+    """Helper function that builds a Relay function to run on TVM graph
+    runtime.
 
     Parameters
     ----------
@@ -266,146 +335,28 @@ def build(func, target=None, target_host=None, params=None):
     params : dict
         The parameters of the final graph.
     """
-    target = target if target else _target.current_target()
-    if target is None:
-        raise ValueError("Target is not set in env or passed as argument.")
-
-    if isinstance(target, dict):
-        target, fallback_device = _update_heterogeneous_inputs(target)
-    elif isinstance(target, (str, _target.Target)):
-        target = _target.create(target)
-    else:
-        raise ValueError("target must be the type of str, tvm.target.Target," +
-                         "or dict of device name to target")
+    target = _update_target(target)
+
+    if isinstance(target_host, (str, _target.Target)):
+        target_host = _target.create(target_host)
+    elif target_host:
+        raise ValueError("target host must be the type of str, " +
+                         "tvm.target.Target, or None")
 
     # If current dispatch context is fallback context (the default root context),
     # then load pre-tuned parameters from TopHub
     if isinstance(autotvm.DispatchContext.current, autotvm.FallbackContext):
-        if isinstance(target, dict):
-            tophub_context = autotvm.tophub.context(list(target.values()))
-        else:
-            tophub_context = autotvm.tophub.context(target)
+        tophub_context = autotvm.tophub.context(list(target.values()))
     else:
        tophub_context = autotvm.util.EmptyContext()
 
-    cfg = BuildConfig.current
-
     with tophub_context:
-        func = optimize(func, target, params)
-
-        # Annotate the ops for heterogeneous execution.
-        if isinstance(target, dict):
-            func, target = _run_device_annotation_passes(func, target,
-                                                         fallback_device)
-
-        # Fuse ops before running code gen
-        func = ir_pass.infer_type(func)
-        func = ir_pass.fuse_ops(func, cfg.opt_level)
-
-        # Graph code generation
-        func = ir_pass.infer_type(func)
-        graph_gen = _graph_gen.GraphRuntimeCodegen(mod=None, target=target)
-        graph_json, lowered_funcs, params = graph_gen.codegen(func)
-
-        mod = _tvm_build_module(
-            lowered_funcs, target=target, target_host=target_host)
+        bld_mod = BuildModule()
+        graph_json, mod, params = bld_mod.build(func, target, target_host,
+                                                params)
 
     return graph_json, mod, params
def _update_heterogeneous_inputs(target):
"""Update the target and fallback device required for heterogeneous
compilation. CPU is used as the fallback device if it wasn't provided.
Meanwhile, a CPU device type and "llvm" pair will be added to the target
dictionary in this case.
Parameters
----------
target : dict of str(i.e. device/context name) to str/tvm.target.Target.
A dict contains context to target pairs.
Returns
-------
device_target : dict of int to tvm.target.Target.
The updated device type to target dict.
fallback_device : int
The updated fallback device type.
"""
if not isinstance(target, dict):
raise ValueError("target must be dict of device name to target for " +
"heterogeneous execution, but received %s."
% type(target))
fallback_device = BuildConfig.current.fallback_device
if fallback_device is None:
# cpu is used as the default fallback device when heterogeneous
# execution is needed, but no fallback device is provided.
fallback_device = _nd.cpu(0).device_type
target[fallback_device] = str(_target.create("llvm"))
elif isinstance(fallback_device, str):
fallback_device = _nd.context(fallback_device).device_type
elif isinstance(fallback_device, TVMContext):
fallback_device = fallback_device.device_type
else:
raise ValueError("fallback_device expects the type of str or " +
"TVMContext, but received %s." % type(fallback_device))
device_target = {}
for dev, tgt in target.items():
device_target[_nd.context(dev).device_type] = _target.create(tgt)
if fallback_device not in device_target:
raise ValueError("%s is used as the default device, but the target" +
"is not provided."
% _nd.context(fallback_device).device_name)
return device_target, fallback_device
def _run_device_annotation_passes(func, target, fallback_device):
"""Execute the device annotation passes to update the input program and
target information.
Parameters
----------
func: tvm.relay.Function
The function where annotation passes will be execute at.
target : Dict[int, tvm.target.Target]
A dict contains device type to target pairs.
fallback_device : int
The fallback device type.
Returns
-------
target : Dict[int, tvm.target.Target]
The updated device type to target dict.
func : tvm.relay.Function
The updated func.
"""
func = ir_pass.infer_type(func)
func = ir_pass.rewrite_annotated_ops(func, fallback_device)
device_map = ir_pass.collect_device_info(func)
# The expression to device type map will be empty if all or none of
# the expressions in the `func` are annotated because this map is
# obtained by propagating the device information in the device copy
# operator. None of the above cases needs device copy operator.
if not device_map:
annotation_map = ir_pass.collect_device_annotation_ops(func)
# No annotation.
if not annotation_map:
target = {0: target[fallback_device]}
else:
dev_type = next(iter(annotation_map.values()))
# All annotated with the same device type.
if all(val == dev_type for val in annotation_map.values()):
target = {0: target[dev_type]}
else:
raise RuntimeError("Expressions in the function are "
"annotated with various device types,"
"but not device copy operators "
"found. Please check the "
"RewriteAnnotation pass.")
return func, target
 class GraphExecutor(_interpreter.Executor):
     """Wrapper around Executor interface.
...
@@ -269,6 +269,77 @@ def realize(graph):
     return _quantize.realize(graph)
def optimize(func, params=None):
""" Perform "SimplifyInference", "FoldScaleAxis", "FoldConstant", and
"CanonicalizeOps" optimization before quantization.
# TODO(zhiics) These passes are executed one by one so far. We need to
# move them to the pass manager.
Parameters
---------
func: tvm.relay.Function
The original Relay function to be optimized.
params : dict of str to tvm.NDArray
Input parameters to the graph that do not change
during inference time. Used for constant folding.
Returns
-------
ret: tvm.relay.Function
The graph after quantization
"""
opt_passes = ["SimplifyInference",
"FoldScaleAxis",
"FoldConstant",
"CanonicalizeOps"]
cfg = _build.build_config(add_pass=opt_passes)
if params:
name_dict = {}
for arg in func.params:
name = arg.name_hint
if name in name_dict:
name_dict[name] = None
else:
name_dict[name] = arg
bind_dict = {}
for k, v in params.items():
if k not in name_dict:
continue
arg = name_dict[k]
if arg is None:
raise ValueError("Multiple args in the function have name %s" % k)
bind_dict[arg] = _expr.const(v)
func = _expr.bind(func, bind_dict)
if "SimplifyInference" in cfg.add_pass:
func = _ir_pass.infer_type(func)
func = _ir_pass.simplify_inference(func)
if "FoldConstant" in cfg.add_pass:
func = _ir_pass.fold_constant(func)
if "FoldScaleAxis" in cfg.add_pass:
func = _ir_pass.infer_type(func)
func = _ir_pass.backward_fold_scale_axis(func)
func = _ir_pass.infer_type(func)
func = _ir_pass.forward_fold_scale_axis(func)
func = _ir_pass.fold_constant(func)
if "CanonicalizeOps" in cfg.add_pass:
func = _ir_pass.infer_type(func)
func = _ir_pass.canonicalize_ops(func)
if "FoldConstant" in cfg.add_pass:
func = _ir_pass.fold_constant(func)
return func
 def quantize(graph, params=None, dataset=None):
     """ The quantization procedure. Before running the three main
     procedure of quantization, "annotate", "calibrate" and "realize"
@@ -292,12 +363,8 @@ def quantize(graph, params=None, dataset=None):
     ret: Function
         The graph after quantization
     """
-    opt_passes = ["SimplifyInference",
-                  "FoldScaleAxis",
-                  "FoldConstant",
-                  "CanonicalizeOps"]
-    with _build.build_config(add_pass=opt_passes):
-        graph = _build.optimize(graph, params=params)
+    # TODO(zhiics) Move this to the pass manager.
+    graph = optimize(graph, params)
 
     graph = annotate(graph)
     graph = calibrate(graph, dataset)
......
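With this change quantize() runs the local optimize() helper defined above instead of relay.optimize under a build_config scope. A hedged sketch of the quantization entry point, mirroring the re-enabled test near the end of this diff (graph and params are assumed to be prepared by the caller):

from tvm import relay
from tvm.relay import quantize as qtz

with qtz.qconfig(skip_k_conv=0, global_scale=4.0):
    qgraph = qtz.quantize(graph, params)
    qgraph = relay.ir_pass.infer_type(qgraph)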
@@ -311,7 +311,7 @@ bool LLVMEnabled() {
 /*! \return The default host target for a given device target */
 Target DefaultTargetHost(Target target) {
-  if (target->device_type == kDLCPU) {
+  if (target.defined() && target->device_type == kDLCPU) {
     return target;
   } else {
     if (LLVMEnabled()) {
...
@@ -38,54 +38,31 @@ namespace tvm {
 namespace relay {
 namespace backend {
 
+using TargetsMap = Map<tvm::Integer, tvm::Target>;
+
 /*!
- * \brief Context name / index
- * See: python/tvm/_ffi/runtime_ctypes.py
+ * \brief Context index to Target
  */
-struct ContextMap {
-  static const std::unordered_map<int, std::string> mask2str;
-  static const std::unordered_map<std::string, int> str2mask;
-  static std::string Mask2Str(int mask) {
+struct ContextTargetMap {
+  static const std::unordered_map<int, tvm::Target> mask2str;
+  static tvm::Target Mask2Str(int mask) {
     CHECK_GT(mask2str.count(mask), 0) << "Unknown mask.";
     return mask2str.at(mask);
   }
-  static int Str2Mask(const std::string& str) {
-    CHECK_GT(str2mask.count(str), 0) << "Unknown context.";
-    return str2mask.at(str);
-  }
 };
 
-const std::unordered_map<int, std::string> ContextMap::mask2str = {
-  {1, "cpu"},
-  {2, "gpu"},
-  {4, "opencl"},
-  {5, "aocl"},
-  {6, "sdaccel"},
-  {7, "vulkan"},
-  {8, "metal"},
-  {9, "vpi"},
-  {10, "rocm"},
-  {11, "opengl"},
-  {12, "ext_dev"}
-};
-
-const std::unordered_map<std::string, int> ContextMap::str2mask = {
-  {"llvm", 1},
-  {"cpu", 1},
-  {"c", 1},
-  {"gpu", 2},
-  {"cuda", 2},
-  {"nvptx", 2},
-  {"cl", 4},
-  {"opencl", 4},
-  {"aocl", 5},
-  {"aocl_sw_emu", 5},
-  {"vulkan", 7},
-  {"metal", 8},
-  {"vpi", 9},
-  {"rocm", 10},
-  {"opengl", 11},
-  {"ext_dev", 12}
-};
+const std::unordered_map<int, tvm::Target> ContextTargetMap::mask2str = {
+  {1, tvm::Target::create("llvm")},
+  {2, tvm::Target::create("cuda")},
+  {4, tvm::Target::create("opencl")},
+  {5, tvm::Target::create("aocl")},
+  {6, tvm::Target::create("sdaccel")},
+  {7, tvm::Target::create("vulkan")},
+  {8, tvm::Target::create("metal")},
+  {9, tvm::Target::create("vpi")},
+  {10, tvm::Target::create("rocm")},
+  {11, tvm::Target::create("opengl")},
+  {12, tvm::Target::create("ext_dev")}
+};
 /*!
@@ -137,7 +114,7 @@ struct BuildOutput {
  */
 struct RelayBuildConfig {
   int opt_level{2};
-  std::string fallback_device{"llvm"};
+  int fallback_device{static_cast<int>(kDLCPU)};
   std::unordered_set<std::string> enabled_pass;
   std::unordered_set<std::string> disabled_pass;
   OptPassLevel OPT_PASS_LEVEL;
@@ -164,14 +141,8 @@ struct GraphCodegen {
   }
   ~GraphCodegen() {}
 
-  void Init(runtime::Module* m,
-            Map<HalideIR::Expr, HalideIR::Expr> targets) {
-    Array<HalideIR::Expr> tgts;
-    for (auto kv : targets) {
-      tgts.push_back(kv.first);
-      tgts.push_back(kv.second);
-    }
-    CallFunc("init", m, tgts);
+  void Init(runtime::Module* m, TargetsMap targets) {
+    CallFunc("init", m, targets);
   }
 
   void Codegen(const Function& func) {
@@ -248,14 +219,7 @@ class RelayBuildModule : public runtime::ModuleNode {
     } else if (name == "build") {
       return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
         CHECK_EQ(args.num_args, 3);
-        Array<HalideIR::Expr> tmp = args[1];
-        std::unordered_map<std::string, std::string> targets;
-        for (size_t i = 0; i < tmp.size(); i += 2) {
-          auto k = tmp[i].as<ir::StringImm>()->value;
-          auto v = tmp[i + 1].as<ir::StringImm>()->value;
-          targets[k] = v;
-        }
-        this->Build(args[0], targets, args[2]);
+        this->Build(args[0], args[1], args[2]);
       });
     } else if (name == "list_params") {
       return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
@@ -273,7 +237,8 @@ class RelayBuildModule : public runtime::ModuleNode {
       });
     } else if (name == "set_fallback_device") {
       return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
-        std::string dev = args[0];
+        CHECK_EQ(args.num_args, 1);
+        int dev = args[0];
         this->SetFallBackDev(dev);
       });
     } else if (name == "add_pass") {
@@ -328,7 +293,7 @@ class RelayBuildModule : public runtime::ModuleNode {
   *
   * \param device name
   */
-  void SetFallBackDev(const std::string& dev) {
+  void SetFallBackDev(int dev) {
     cfg_.fallback_device = dev;
   }
   /*!
@@ -402,8 +367,8 @@ class RelayBuildModule : public runtime::ModuleNode {
   * \param target_host Host target device
   */
   void Build(Function func,
-             const std::unordered_map<std::string, std::string>& targets,
-             const std::string& target_host) {
+             const TargetsMap& targets,
+             const tvm::Target& target_host) {
     targets_ = targets;
     target_host_ = target_host;
     BuildRelay(func, cfg_, params_);
...@@ -416,8 +381,9 @@ class RelayBuildModule : public runtime::ModuleNode { ...@@ -416,8 +381,9 @@ class RelayBuildModule : public runtime::ModuleNode {
* \param params params dict * \param params params dict
* \return relay::Function * \return relay::Function
*/ */
relay::Function BindParamsByName(relay::Function func, relay::Function BindParamsByName(
const std::unordered_map<std::string, runtime::NDArray>& params) { relay::Function func,
const std::unordered_map<std::string, runtime::NDArray>& params) {
std::unordered_map<std::string, relay::Var> name_dict; std::unordered_map<std::string, relay::Var> name_dict;
std::unordered_set<relay::Var, NodeHash, NodeEqual> repeat_var; std::unordered_set<relay::Var, NodeHash, NodeEqual> repeat_var;
for (auto arg : func->params) { for (auto arg : func->params) {
...@@ -454,7 +420,7 @@ class RelayBuildModule : public runtime::ModuleNode { ...@@ -454,7 +420,7 @@ class RelayBuildModule : public runtime::ModuleNode {
* \return relay::Function * \return relay::Function
*/ */
relay::Function Optimize(relay::Function func, relay::Function Optimize(relay::Function func,
const std::unordered_map<std::string, std::string>& targets, const TargetsMap& targets,
const RelayBuildConfig& cfg, const RelayBuildConfig& cfg,
const std::unordered_map<std::string, runtime::NDArray>& params) { const std::unordered_map<std::string, runtime::NDArray>& params) {
if (params.size()) { if (params.size()) {
...@@ -507,8 +473,7 @@ class RelayBuildModule : public runtime::ModuleNode { ...@@ -507,8 +473,7 @@ class RelayBuildModule : public runtime::ModuleNode {
auto enter_pf = GetPackedFunc("_EnterTargetScope"); auto enter_pf = GetPackedFunc("_EnterTargetScope");
auto exit_pf = GetPackedFunc("_ExitTargetScope"); auto exit_pf = GetPackedFunc("_ExitTargetScope");
for (const auto& kv : targets) { for (const auto& kv : targets) {
auto target = Target::create(kv.second); (*enter_pf)(kv.second);
(*enter_pf)(target);
func = CallPackedFunc("relay._ir_pass.AlterOpLayout", func); func = CallPackedFunc("relay._ir_pass.AlterOpLayout", func);
(*exit_pf)(); (*exit_pf)();
} }
...@@ -530,25 +495,19 @@ class RelayBuildModule : public runtime::ModuleNode { ...@@ -530,25 +495,19 @@ class RelayBuildModule : public runtime::ModuleNode {
* *
* \param targets dictionary * \param targets dictionary
* \param cfg * \param cfg
* \return Map<HalideIR::Expr, HalideIR::Expr> * \return Map<tvm::Integer, tvm::Target>
*/ */
Map<HalideIR::Expr, HalideIR::Expr> UpdateHeterogeneousInputs( TargetsMap UpdateHeterogeneousInputs(const TargetsMap& targets,
const std::unordered_map<std::string, std::string>& targets, const RelayBuildConfig& cfg) {
const RelayBuildConfig& cfg) { TargetsMap device_target = targets;
Map<HalideIR::Expr, HalideIR::Expr> device_target; std::unordered_map<int64_t, tvm::Target> tmp_map;
std::unordered_map<int64_t, std::string> tmp_map;
auto fallback_idx = ContextMap::Str2Mask(cfg.fallback_device);
for (const auto& kv : targets) { for (const auto& kv : targets) {
tmp_map[ContextMap::Str2Mask(kv.first)] = kv.second; tmp_map[kv.first->value] = kv.second;
}
if (tmp_map.count(fallback_idx) == 0) {
tmp_map[fallback_idx] = cfg.fallback_device;
} }
for (const auto& kv : tmp_map) { if (tmp_map.count(cfg.fallback_device) == 0) {
device_target.Set( device_target.Set(
ir::IntImm::make(HalideIR::Int(64), kv.first), cfg.fallback_device,
ir::StringImm::make(kv.second)); ContextTargetMap::Mask2Str(cfg.fallback_device));
} }
return device_target; return device_target;
} }
...@@ -561,25 +520,19 @@ class RelayBuildModule : public runtime::ModuleNode { ...@@ -561,25 +520,19 @@ class RelayBuildModule : public runtime::ModuleNode {
* \param targets_map_ptr * \param targets_map_ptr
* \return Function * \return Function
*/ */
Function RunDeviceAnnotationPass( Function RunDeviceAnnotationPass(Function func, const RelayBuildConfig& cfg,
Function func, TargetsMap* targets_map_ptr) {
const RelayBuildConfig& cfg,
Map<HalideIR::Expr, HalideIR::Expr>* targets_map_ptr) {
auto fallback_idx = ContextMap::Str2Mask(cfg.fallback_device);
func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr); func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr);
func = CallPackedFunc("relay._ir_pass.RewriteDeviceAnnotation", func, fallback_idx); func = CallPackedFunc("relay._ir_pass.RewriteDeviceAnnotation", func,
auto device_map = CallPackedFunc<Map<Expr, Integer> >("relay._ir_pass.CollectDeviceInfo", cfg.fallback_device);
func, auto device_map = CallPackedFunc<Map<Expr, Integer> >(
nullptr); "relay._ir_pass.CollectDeviceInfo", func, nullptr);
if (device_map.size() == 0) { if (device_map.size() == 0) {
auto annotation_map = auto annotation_map = CallPackedFunc<Map<Expr, Integer> >(
CallPackedFunc<Map<Expr, Integer> >("relay._ir_pass.CollectDeviceAnnotationOps", "relay._ir_pass.CollectDeviceAnnotationOps", func, nullptr);
func,
nullptr);
if (annotation_map.size() == 0) { if (annotation_map.size() == 0) {
targets_map_ptr->Set( targets_map_ptr->Set(
ir::IntImm::make(HalideIR::Int(64), 0), 0, ContextTargetMap::Mask2Str(cfg.fallback_device));
ir::StringImm::make(cfg.fallback_device));
} else { } else {
int64_t dev_type = -1; int64_t dev_type = -1;
for (auto kv : annotation_map) { for (auto kv : annotation_map) {
...@@ -594,9 +547,7 @@ class RelayBuildModule : public runtime::ModuleNode { ...@@ -594,9 +547,7 @@ class RelayBuildModule : public runtime::ModuleNode {
<< "found. Please check the " << "found. Please check the "
<< "RewriteAnnotation pass."; << "RewriteAnnotation pass.";
} }
targets_map_ptr->Set( targets_map_ptr->Set(0, ContextTargetMap::Mask2Str(dev_type));
ir::IntImm::make(HalideIR::Int(64), 0),
ir::StringImm::make(ContextMap::Mask2Str(dev_type)));
} }
} }
return func; return func;
...@@ -614,15 +565,11 @@ class RelayBuildModule : public runtime::ModuleNode { ...@@ -614,15 +565,11 @@ class RelayBuildModule : public runtime::ModuleNode {
const std::unordered_map<std::string, tvm::runtime::NDArray> &params) { const std::unordered_map<std::string, tvm::runtime::NDArray> &params) {
// convert // convert
tvm_cfg_ = build_config(); tvm_cfg_ = build_config();
Map<HalideIR::Expr, HalideIR::Expr> device_target; TargetsMap device_target;
if (targets_.size() > 1) { if (targets_.size() > 1) {
device_target = UpdateHeterogeneousInputs(targets_, cfg); device_target = UpdateHeterogeneousInputs(targets_, cfg);
} else { } else {
for (auto &kv : targets_) { device_target = targets_;
device_target.Set(
ir::IntImm::make(HalideIR::Int(64), ContextMap::Str2Mask(kv.first)),
ir::StringImm::make(kv.second));
}
} }
func = Optimize(func, targets_, cfg, params); func = Optimize(func, targets_, cfg, params);
if (device_target.size() > 1) { if (device_target.size() > 1) {
@@ -640,16 +587,15 @@ class RelayBuildModule : public runtime::ModuleNode {
     ret_.graph_json = graph_codegen_->GetJSON();
     ret_.params = graph_codegen_->GetParams();
 
-    auto target_host = Target::create(target_host_);
-    ret_.mod = tvm::build(graph_codegen_->GetLoweredFunc(), target_host, tvm_cfg_);
+    ret_.mod = tvm::build(graph_codegen_->GetLoweredFunc(), target_host_, tvm_cfg_);
   }
 
  protected:
   std::unique_ptr<GraphCodegen> graph_codegen_;
   /*! \brief target device */
-  std::unordered_map<std::string, std::string> targets_;
+  TargetsMap targets_;
   /*! \brief target host device */
-  std::string target_host_;
+  tvm::Target target_host_;
   /*! \brief frontend optimization configure */
   RelayBuildConfig cfg_;
   /*! \brief parameters */
......
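On the Python side, the packed functions registered by this C++ module are reached through the module handle returned by _BuildModule(), as the BuildModule wrapper earlier in this diff does. A minimal sketch (the function names follow the GetFunction branches shown above):

from tvm.relay import _build_module

mod = _build_module._BuildModule()   # runtime::Module backed by RelayBuildModule
mod["set_opt_level"](3)
build_f = mod["build"]               # build_f(func, targets, target_host)
json_f = mod["get_graph_json"]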
@@ -51,7 +51,7 @@ using GraphAttrs = std::unordered_map<std::string, dmlc::any>;
 using GraphNodePtr = std::shared_ptr<GraphNode>;
 using GraphInputNodePtr = std::shared_ptr<GraphInputNode>;
 using GraphOpNodePtr = std::shared_ptr<GraphOpNode>;
-using TargetsMap = std::unordered_map<std::string, Target>;
+using TargetsMap = std::unordered_map<int, Target>;
 
 /*! \brief Lowered outputs */
 struct LoweredOutput {
@@ -193,12 +193,10 @@ class GraphOpNode : public GraphNode {
 class GraphRuntimeCodegen
     : public ::tvm::relay::ExprFunctor<std::vector<GraphNodeRef>(const Expr&)> {
  public:
-  GraphRuntimeCodegen(runtime::Module* mod,
-                      const std::unordered_map<std::string, std::string>& targets)
-      : mod_(mod) {
+  GraphRuntimeCodegen(runtime::Module* mod, const TargetsMap& targets)
+      : mod_(mod) {
     compile_engine_ = CompileEngine::Global();
-    for (auto &kv : targets) {
-      targets_[kv.first] = Target::create(kv.second);
-    }
+    targets_ = targets;
   }
 
   LoweredOutput Codegen(relay::Function func) {
...@@ -406,7 +404,7 @@ class GraphRuntimeCodegen ...@@ -406,7 +404,7 @@ class GraphRuntimeCodegen
auto pf0 = GetPackedFunc("relay.backend._make_CCacheKey"); auto pf0 = GetPackedFunc("relay.backend._make_CCacheKey");
auto pf1 = GetPackedFunc("relay.backend._CompileEngineLower"); auto pf1 = GetPackedFunc("relay.backend._CompileEngineLower");
auto &device_type = storage_device_map_[expr][1]; auto &device_type = storage_device_map_[expr][1];
auto call_dev_type = device_type[0]->value; //-> int to string auto call_dev_type = device_type[0]->value;
Target target; Target target;
if (targets_.size() == 1) { if (targets_.size() == 1) {
// homogeneous execution. // homogeneous execution.
...@@ -415,22 +413,17 @@ class GraphRuntimeCodegen ...@@ -415,22 +413,17 @@ class GraphRuntimeCodegen
} }
} else { } else {
// heterogeneous execution. // heterogeneous execution.
const auto call_dev_key = std::to_string(call_dev_type);
std::string call_dev_name; std::string call_dev_name;
if (call_dev_type == 0) { if (call_dev_type == 0) {
call_dev_name = "llvm"; call_dev_name = "llvm";
} else { } else {
call_dev_name = runtime::DeviceName(call_dev_type); call_dev_name = runtime::DeviceName(call_dev_type);
} }
if (targets_.count(call_dev_name) == 0 && targets_.count(call_dev_key) == 0) { if (targets_.count(call_dev_type) == 0) {
LOG(FATAL) << "No target is provided for device " LOG(FATAL) << "No target is provided for device "
<< call_dev_name; << call_dev_name;
} }
if (targets_.count(call_dev_key)) { target = targets_[call_dev_type];
target = targets_[call_dev_key];
} else {
target = targets_[call_dev_name];
}
} }
CCacheKey key = (*pf0)(func, target); CCacheKey key = (*pf0)(func, target);
CachedFunc lowerd_func = (*pf1)(compile_engine_, key); CachedFunc lowerd_func = (*pf1)(compile_engine_, key);
...@@ -604,30 +597,21 @@ class GraphRuntimeCodegenModule : public runtime::ModuleNode { ...@@ -604,30 +597,21 @@ class GraphRuntimeCodegenModule : public runtime::ModuleNode {
virtual PackedFunc GetFunction(const std::string& name, virtual PackedFunc GetFunction(const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) { const std::shared_ptr<ModuleNode>& sptr_to_self) {
if (name == "init") { if (name == "init") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
CHECK_EQ(args.num_args, 2) << "The expected of arguments are: " CHECK_EQ(args.num_args, 2)
<< "runtime::Module mod and Map<str, StringImm> targets"; << "The expected of arguments are: "
void* mod = args[0]; << "runtime::Module mod and Map<int, Target> targets";
auto& sptr = args[1].node_sptr(); void* mod = args[0];
auto* node = static_cast<const ArrayNode*>(sptr.get()); Map<Integer, tvm::Target> tmp = args[1];
auto& tmp_targets = node->data; TargetsMap targets;
std::unordered_map<std::string, std::string> targets; for (const auto& it : tmp) {
for (size_t i = 0; i < tmp_targets.size(); i += 2) { auto dev_type = it.first.as<ir::IntImm>();
std::string key; CHECK(dev_type);
auto sk = Expr(tmp_targets[i]).as<ir::StringImm>(); targets[dev_type->value] = it.second;
auto ik = Expr(tmp_targets[i]).as<ir::IntImm>(); }
if (sk) { codegen_ = std::make_shared<GraphRuntimeCodegen>(
key = sk->value; reinterpret_cast<runtime::Module*>(mod), targets);
} });
if (ik) {
key = std::to_string(ik->value);
}
auto v = Expr(tmp_targets[i + 1]).as<ir::StringImm>();
targets[key] = v->value;
}
codegen_ = std::make_shared<GraphRuntimeCodegen>(
reinterpret_cast<runtime::Module*>(mod), targets);
});
} else if (name == "codegen") { } else if (name == "codegen") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
Function func = args[0]; Function func = args[0];
......
@@ -18,6 +18,7 @@
  */
 #include <gtest/gtest.h>
+#include <tvm/build_module.h>
 #include <tvm/tvm.h>
 #include <tvm/relay/expr.h>
 #include <tvm/relay/type.h>
@@ -73,10 +74,10 @@ TEST(Relay, BuildModule) {
   auto build_f = build_mod.GetFunction("build", false);
   auto json_f = build_mod.GetFunction("get_graph_json", false);
   auto mod_f = build_mod.GetFunction("get_module", false);
-  Array<HalideIR::Expr> target_pair;
-  target_pair.push_back(ir::StringImm::make("cpu"));
-  target_pair.push_back(ir::StringImm::make("llvm"));
-  build_f(func, target_pair, "llvm");
+  Map<tvm::Integer, tvm::Target> targets;
+  Target llvm_tgt = Target::create("llvm");
+  targets.Set(0, llvm_tgt);
+  build_f(func, targets, llvm_tgt);
   std::string json = json_f();
   tvm::runtime::Module mod = mod_f();
   // run
...
@@ -74,13 +74,12 @@ def test_alter_layout_conv2d():
     for tgt in targets:
         with tvm.target.create(tgt) as target:
-            with relay.build_config(opt_level=-1, add_pass='AlterOpLayout'):
-                with autotvm.tophub.context(target):
-                    O = relay.optimize(N, target, params=None)
-                    O = relay.ir_pass.infer_type(O)
+            with autotvm.tophub.context(target):
+                O = relay.ir_pass.alter_op_layout(N)
+                O = relay.ir_pass.infer_type(O)
 
     # graph should differ
     assert not relay.ir_pass.alpha_equal(N, O)
 
 if __name__ == "__main__":
     np.random.seed(42)
...
@@ -18,55 +18,10 @@ import numpy as np
 import tvm
 from tvm import relay
+from tvm.contrib.nvcc import have_fp16
-from tvm._ffi.function import _init_api
-
-_init_api("tvm.relay.build_module")
-
-class BuildModule(object):
-    def __init__(self):
-        self.mod = relay.build_module._BuildModule()
-        self._get_graph_json = self.mod["get_graph_json"]
-        self._get_module = self.mod["get_module"]
-        self._build = self.mod["build"]
-        self._set_opt_level = self.mod["set_opt_level"]
-        self._set_params_func = self.mod["set_params"]
-        self._get_params_func = self.mod["get_params"]
-
-    def build(self, func, target, target_host, params):
-        tgts = []
-        for kv in target.items():
-            tgts.append(kv[0])
-            tgts.append(kv[1])
-        self._set_params(params)
-        self._build(func, tgts, target_host)
-
-    def get_json(self):
-        return self._get_graph_json()
-
-    def get_module(self):
-        return self._get_module()
-
-    def set_opt_level(self, level):
-        self._set_opt_level(level)
-
-    def _set_params(self, params):
-        inputs = {}
-        for name, param in params.items():
-            inputs[name] = relay.Constant(param)
-        self._set_params_func(inputs)
-
-    def get_params(self):
-        params = self._get_params_func()
-        ret = {}
-        for key, value in params.items():
-            ret[key] = value.data
-        return ret
-
-def test_build():
-    m_bld = BuildModule()
-    tgt_name = "llvm"
+def test_basic_build():
     tgt = "llvm"
     ctx = tvm.cpu()
     # func
@@ -86,21 +41,96 @@ def test_build():
     }
     # build
     targets = {
-        tgt: tgt
+        tvm.expr.IntImm("int32", ctx.device_type): tgt
     }
-    m_bld.set_opt_level(3)
-    m_bld.build(func, targets, "llvm", params=params)
-    g_json = m_bld.get_json()
-    mmod = m_bld.get_module()
-    params = m_bld.get_params()
+    g_json, mmod, params = relay.build(func, targets, "llvm", params=params)
     # test
     rt = tvm.contrib.graph_runtime.create(g_json, mmod, ctx)
     rt.set_input("a", A)
     rt.load_params(relay.save_param_dict(params))
     rt.run()
     out = rt.get_output(0)
-    np.testing.assert_allclose(out.asnumpy(),
-        np.maximum(np.dot(A.asnumpy(), B.asnumpy().T), 0) + C.asnumpy(), atol=1e-5, rtol=1e-5)
+    np.testing.assert_allclose(out.asnumpy(), np.maximum(np.dot(A.asnumpy(),
+                                                                B.asnumpy().T),
+                                                         0) + C.asnumpy(),
+                               atol=1e-5, rtol=1e-5)
def test_fp16_build():
dtype = "float16"
if not tvm.module.enabled("cuda") or not tvm.gpu(0).exist:
print("skip because cuda is not enabled.")
return
ctx = tvm.gpu(0)
if dtype == "float16" and not have_fp16(ctx.compute_version):
print("skip because gpu does not support fp16")
return
x = relay.var("x", dtype=dtype, shape=(4, 4))
y = relay.var("y", dtype=dtype, shape=(4, 4))
z = x + y
func = relay.Function([x, y], z)
X = tvm.nd.array(np.random.uniform(-1, 1, (4, 4)).astype(dtype), ctx=ctx)
Y = tvm.nd.array(np.random.uniform(-1, 1, (4, 4)).astype(dtype), ctx=ctx)
params = {
"x": X,
"y": Y,
}
# build
g_json, mmod, params = relay.build(func, "cuda", params=params)
# test
rt = tvm.contrib.graph_runtime.create(g_json, mmod, ctx)
rt.load_params(relay.save_param_dict(params))
rt.run()
out = rt.get_output(0)
np.testing.assert_allclose(out.asnumpy(), X.asnumpy() + Y.asnumpy(),
atol=1e-5, rtol=1e-5)
def test_fp16_conversion():
def check_conversion(tgt, ctx):
if not tvm.module.enabled(tgt):
print("skip because {} is not enabled.".format(tgt))
return
elif tgt == "cuda" and ctx.exist and not have_fp16(ctx.compute_version):
print("skip because gpu does not support fp16")
return
n = 10
for (src, dst) in [('float32', 'float16'), ('float16', 'float32')]:
x = relay.var("x", relay.TensorType((n,), src))
y = x.astype(dst)
func = relay.Function([x], y)
# init input
X = tvm.nd.array(n * np.random.randn(n).astype(src) - n / 2)
# build
with relay.build_config(opt_level=1):
g_json, mmod, params = relay.build(func, tgt)
# test
rt = tvm.contrib.graph_runtime.create(g_json, mmod, ctx)
rt.set_input("x", X)
rt.run()
out = rt.get_output(0)
np.testing.assert_allclose(out.asnumpy(), X.asnumpy().astype(dst),
atol=1e-5, rtol=1e-5)
for target, ctx in [('llvm', tvm.cpu()), ('cuda', tvm.gpu())]:
check_conversion(target, ctx)
if __name__ == "__main__":
test_basic_build()
test_fp16_build()
test_fp16_conversion()
@@ -411,7 +411,7 @@ def run_fusible_network(dev, tgt):
                              expected_index)
 
     def test_fallback_all_operators(device, tgt):
-        target = {device: tgt}
+        target = {device: tgt, "cpu": "llvm"}
        annotated_func = get_func()
        expected_func = get_func()
        check_annotated_graph(annotated_func, expected_func)
...
...@@ -47,54 +47,54 @@ def test_simulated_quantize(): ...@@ -47,54 +47,54 @@ def test_simulated_quantize():
assert out.args[3].checked_type == relay.ty.TensorType(tuple(), "float32") assert out.args[3].checked_type == relay.ty.TensorType(tuple(), "float32")
# def test_quantize_pass(): def test_quantize_pass():
# def quantize_weight(arr): def quantize_weight(arr):
# maximum = np.amax(np.abs(arr.asnumpy())) maximum = np.amax(np.abs(arr.asnumpy()))
# scale = 2**math.ceil(math.log(maximum, 2)) scale = 2**math.ceil(math.log(maximum, 2))
# out = np.around(arr.asnumpy() / scale * 128).astype('int8') out = np.around(arr.asnumpy() / scale * 128).astype('int8')
# out = np.clip(out, -127, 127) out = np.clip(out, -127, 127)
# return relay.const(out, 'int8') return relay.const(out, 'int8')
#
# n, c, h, w = 1, 3, 224, 224 n, c, h, w = 1, 3, 224, 224
# def make_graph(data): def make_graph(data):
# weight = relay.var("conv_weight") weight = relay.var("conv_weight")
# out = relay.nn.conv2d(data, weight, kernel_size=(3, 3), padding=(1, 1), channels=c) out = relay.nn.conv2d(data, weight, kernel_size=(3, 3), padding=(1, 1), channels=c)
# out = relay.Function(relay.ir_pass.free_vars(out), out) out = relay.Function(relay.ir_pass.free_vars(out), out)
# return out return out
#
# def make_qgraph(data, weight): def make_qgraph(data, weight):
# out = data * relay.const(32.0) out = data * relay.const(32.0)
# out = relay.round(out) out = relay.round(out)
# out = relay.clip(out, a_min=-127, a_max=127) out = relay.clip(out, a_min=-127, a_max=127)
# out = out.astype('int8') out = out.astype('int8')
#
# out = relay.nn.conv2d(out, weight, kernel_size=(3, 3), out = relay.nn.conv2d(out, weight, kernel_size=(3, 3),
# padding=(1, 1), channels=c, out_dtype='int32') padding=(1, 1), channels=c, out_dtype='int32')
# out = out.astype('float32') out = out.astype('float32')
# out = relay.multiply(out, relay.const(0.00024414062)) out = relay.multiply(out, relay.const(0.00024414062))
# out = relay.Function(relay.ir_pass.free_vars(out), out) out = relay.Function(relay.ir_pass.free_vars(out), out)
# return out return out
#
# data = relay.var("data", relay.TensorType((n, c, h, w), "float32")) data = relay.var("data", relay.TensorType((n, c, h, w), "float32"))
# graph = make_graph(data) graph = make_graph(data)
# dataset, params = make_dataset(graph, 10) dataset, params = make_dataset(graph, 10)
#
# with qtz.qconfig(skip_k_conv=0, global_scale=4.0, with qtz.qconfig(skip_k_conv=0, global_scale=4.0,
# round_for_shift=False, store_lowbit_output=False): round_for_shift=False, store_lowbit_output=False):
# qgraph0 = qtz.quantize(graph, params) qgraph0 = qtz.quantize(graph, params)
# qgraph0 = relay.ir_pass.infer_type(qgraph0) qgraph0 = relay.ir_pass.infer_type(qgraph0)
#
# conv_weight = quantize_weight(params['conv_weight']) conv_weight = quantize_weight(params['conv_weight'])
# qgraph1 = make_qgraph(data, conv_weight) qgraph1 = make_qgraph(data, conv_weight)
# qgraph1 = relay.ir_pass.infer_type(qgraph1) qgraph1 = relay.ir_pass.infer_type(qgraph1)
#
# graph = relay.create_executor('graph') graph = relay.create_executor('graph')
# res0 = graph.evaluate(qgraph0)(dataset[0]['data']) res0 = graph.evaluate(qgraph0)(dataset[0]['data'])
# res1 = graph.evaluate(qgraph1)(dataset[0]['data']) res1 = graph.evaluate(qgraph1)(dataset[0]['data'])
# tvm.testing.assert_allclose(res0.asnumpy(), res1.asnumpy(), rtol=1e-3) tvm.testing.assert_allclose(res0.asnumpy(), res1.asnumpy(), rtol=1e-3)
if __name__ == "__main__": if __name__ == "__main__":
np.random.seed(42) np.random.seed(42)
test_simulated_quantize() test_simulated_quantize()
# test_quantize_pass() test_quantize_pass()