Unverified Commit c91ded32 by Tianqi Chen Committed by GitHub

[RELAY][BACKEND] CompileEngine refactor. (#2059)

parent 4e77eeb2
/*!
* Copyright (c) 2018 by Contributors
* \file tvm/relay/build_module.h
* \brief The passes and data structures needed to build a
* tvm::Module from a Relay program.
*/
#ifndef TVM_RELAY_BUILD_MODULE_H_
#define TVM_RELAY_BUILD_MODULE_H_
#include <tvm/lowered_func.h>
#include <tvm/relay/module.h>
#include <tvm/relay/expr.h>
#include <string>
namespace tvm {
namespace relay {
/*! \brief A lowered Relay operation.
*
* A lowered operation is a pair containing the "primitive" function used
* to produce the lowered function as well as the lowered function itself.
*/
class LoweredOp;
/*!
* \brief Container node of LoweredOp.
*
* Stores the primitive Relay function together with the LoweredFunc
* that belongs to it.
*/
class LoweredOpNode : public Node {
public:
/*!
* \brief The primitive function to be lowered.
*
* A primitive function consists only of calls to relay::Op which
* can be fused.
*/
Function func;
/*!
* \brief The lowered function.
*/
LoweredFunc lowered_func;
/*!
* \brief Expose the fields to an attribute visitor.
*
* The fields are visited under the names "func" and "lowered_func".
*
* \param v The visitor.
*/
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("func", &func);
v->Visit("lowered_func", &lowered_func);
}
/*!
* \brief Factory for LoweredOp (declared here, defined elsewhere).
* \param func The primitive function to be lowered.
* \param lowered_func The lowered function.
* \return A LoweredOp.
*/
TVM_DLL static LoweredOp make(
Function func,
LoweredFunc lowered_func);
static constexpr const char* _type_key = "relay.LoweredOp";
TVM_DECLARE_NODE_TYPE_INFO(LoweredOpNode, Node);
};
RELAY_DEFINE_NODE_REF(LoweredOp, LoweredOpNode, NodeRef);
/*!
* \brief Lower the operations contained in a Relay expression.
*
* The lowering pass will only lower functions marked as primitive,
* the FuseOps pass will provide this behavior, if run before LowerOps.
*
* \note This will do a reachability analysis and lower all definitions
* reachable from the provided expression.
*
* \param mod The module.
* \param expr The expression with operations to be lowered.
* \param target The target to lower the functions to.
*
* \return The set of lowered operations.
*/
Array<LoweredOp> LowerOps(const Module& mod, const Expr& expr,
const std::string& target = "llvm");
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_BUILD_MODULE_H_
......@@ -16,6 +16,7 @@
#ifndef TVM_RELAY_INTERPRETER_H_
#define TVM_RELAY_INTERPRETER_H_
#include <tvm/build_module.h>
#include <tvm/relay/module.h>
#include <tvm/relay/expr.h>
......@@ -27,7 +28,9 @@ namespace relay {
*/
class Value;
/*! \brief Evaluate an expression using the interpreter producing a value.
/*!
* \brief Create an interpreter function that can
* evaluate an expression and produce a value.
*
* The resulting value can be passed to Python, making it easy to use
* for testing and debugging.
......@@ -38,8 +41,14 @@ class Value;
*
* Our intent is that this will never be the most efficient implementation of
* Relay's semantics, but a readable and clear one.
*
* \param mod The function module.
* \param context The primary context that the interpreter runs on.
* \param target Compiler target flag to compile the functions on the context.
* \return A function that takes in an expression and returns a value.
*/
Value Evaluate(Module mod, Expr e);
runtime::TypedPackedFunc<Value(Expr)>
CreateInterpreter(Module mod, DLContext context, Target target);
/*! \brief The base container type of Relay values. */
class ValueNode : public RelayNode {
......@@ -125,9 +134,6 @@ struct TensorValueNode : ValueNode {
/*! \brief Build a value from an NDArray. */
TVM_DLL static TensorValue make(runtime::NDArray data);
/*! \brief Construct an empty tensor value from t. */
TVM_DLL static TensorValue FromType(const Type& t);
static constexpr const char* _type_key = "relay.TensorValue";
TVM_DECLARE_NODE_TYPE_INFO(TensorValueNode, ValueNode);
};
......
/*!
* Copyright (c) 2017 by Contributors
* \file tvm/relay/op_attr_types.h
* \brief Attribute types attached to Relay operators: the fusion pattern
* and the compute/schedule function signatures.
*/
#ifndef TVM_RELAY_OP_ATTR_TYPES_H_
#define TVM_RELAY_OP_ATTR_TYPES_H_
#include <tvm/tensor.h>
#include <tvm/schedule.h>
#include <tvm/build_module.h>
#include <tvm/relay/type.h>
#include <tvm/relay/expr.h>
namespace tvm {
namespace relay {
/*! \brief operator pattern used in graph fusion */
enum OpPatternKind {
// Elementwise operation
kElemWise = 0,
// Broadcasting operator, can always map output axis to the input in order.
// for example :code:`out[i, ax1, j, ax2] = input[i, j]`.
// Note that the axis need to be in order so transpose is not a bcast operator.
kBroadcast = 1,
// Injective operator, can always injectively map output axis to a single input axis.
// All injective operator can still be safely fused to injective and reduction.
kInjective = 2,
// Commutative reduction operator.
kCommReduce = 3,
// Complex operation, can still fuse elemwise operations into its output.
// but cannot chain another complex op
kOutEWiseFusable = 4,
// Opaque operation, cannot fuse anything.
kOpaque = 8
};
/*! \brief the operator pattern */
using TOpPattern = int;
/*!
* \brief Computation description interface.
*
* \note This function have a special convention
* for functions with tuple input/output.
*
* So far we restrict tuple support to the following case:
* - Function which takes a single tuple as input.
* - Function which outputs a single tuple.
*
* In both cases, the tuple is flattened as array.
*
* \param attrs The attribute of the primitive
* \param inputs The input tensors.
* \param out_type The output type information,
*                 these are always placeholders.
* \param target The build target.
* \return The output compute description of the operator.
*/
using FTVMCompute = runtime::TypedPackedFunc<
Array<Tensor>(const Attrs& attrs,
const Array<Tensor>& inputs,
const Type& out_type,
const Target& target)>;
/*!
* \brief Build the computation schedule for
* op whose root is at current op.
*
* \param outs The output tensors.
* \param target The build target.
* \return schedule The computation schedule.
*/
using FTVMSchedule = runtime::TypedPackedFunc<
Schedule(const Array<Tensor>& outs,
const Target& target)>;
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_OP_ATTR_TYPES_H_
......@@ -178,6 +178,40 @@ class DeviceAPI {
/*! \brief The device type bigger than this is RPC device */
constexpr int kRPCSessMask = 128;
/*!
 * \brief Translate a device type code into its textual name.
 *
 * An unrecognised code is reported through LOG(FATAL).
 *
 * \param type The device type.
 * \return The device name.
 */
inline const char* DeviceName(int type) {
  const char* name = nullptr;
  switch (type) {
    case kDLCPU:
      name = "cpu";
      break;
    case kDLGPU:
      name = "gpu";
      break;
    case kDLOpenCL:
      name = "opencl";
      break;
    case kDLSDAccel:
      name = "sdaccel";
      break;
    case kDLAOCL:
      name = "aocl";
      break;
    case kDLVulkan:
      name = "vulkan";
      break;
    case kDLMetal:
      name = "metal";
      break;
    case kDLVPI:
      name = "vpi";
      break;
    case kDLROCM:
      name = "rocm";
      break;
    case kOpenGL:
      name = "opengl";
      break;
    case kDLExtDev:
      name = "ext_dev";
      break;
    default:
      LOG(FATAL) << "unknown type =" << type;
      name = "Unknown";
      break;
  }
  return name;
}
#ifndef _LIBCPP_SGX_NO_IOSTREAMS
/*!
 * \brief Print a DLContext as "<device name>(<device id>)".
 *
 * When the device type is larger than kRPCSessMask the output is prefixed
 * with "remote[k]-", where k is the device type divided by kRPCSessMask,
 * and the device name is looked up from the remainder.
 */
inline std::ostream& operator<<(std::ostream& os, DLContext ctx) {  // NOLINT(*)
  const int raw_type = static_cast<int>(ctx.device_type);
  const bool is_remote = raw_type > kRPCSessMask;
  if (is_remote) {
    os << "remote[" << (raw_type / kRPCSessMask) << "]-";
  }
  const int local_type = is_remote ? (raw_type % kRPCSessMask) : raw_type;
  os << runtime::DeviceName(local_type) << "(" << ctx.device_id << ")";
  return os;
}
#endif
} // namespace runtime
} // namespace tvm
#endif // TVM_RUNTIME_DEVICE_API_H_
......@@ -888,6 +888,7 @@ inline std::ostream& operator<<(std::ostream& os, TVMType t) { // NOLINT(*)
}
return os;
}
#endif
inline std::string TVMType2String(TVMType t) {
......
......@@ -132,7 +132,7 @@ class GraphModule(object):
params : dict of str to NDArray
Additional arguments
"""
if key:
if key is not None:
self._get_input(key).copyfrom(value)
if params:
......
......@@ -7,8 +7,7 @@ from . import ty
from . import expr
from . import module
from . import ir_pass
from .build_module import build
from .interpreter import create_executor
from .build_module import build, create_executor
# Root operators
from .op import Op
......@@ -18,7 +17,7 @@ from .op.transform import *
from . import nn
from . import vision
from . import image
from . import backend
from .scope_builder import ScopeBuilder
......@@ -56,13 +55,6 @@ TupleGetItem = expr.TupleGetItem
var = expr.var
const = expr.const
@register_func("relay._tensor_value_repr")
def _tensor_value_repr(tv):
return str(tv.data.asnumpy())
@register_func("relay._constant_repr")
def _tensor_constant_repr(tv):
return str(tv.data.asnumpy())
# pylint: disable=unused-argument
@register_func("relay.debug")
......
"""The interface to the Evaluator exposed from C++."""
from tvm._ffi.function import _init_api
_init_api("relay._interpreter", __name__)
"""Backend codegen modules for relay."""
from . import compile_engine
"""The interface of expr function exposed from C++."""
from __future__ import absolute_import
import logging
from ... import build_module as _build
from ... import container as _container
from ..._ffi.function import _init_api, register_func
@register_func("relay.backend.lower")
def lower(sch, inputs, func_name, source_func):
    """Backend function for lowering.

    Parameters
    ----------
    sch : tvm.Schedule
        The schedule.

    inputs : List[tvm.Tensor]
        The inputs to the function.

    func_name : str
        The name of the function.

    source_func : tvm.relay.Function
        The source function to be lowered.

    Returns
    -------
    lowered_funcs : List[tvm.LoweredFunc]
        The result of lowering.

    Raises
    ------
    RuntimeError
        If lowering fails. The message carries the original traceback
        followed by the text form of ``source_func``.
    """
    import traceback
    # pylint: disable=broad-except
    try:
        f = _build.lower(sch, inputs, name=func_name)
        logging.debug("lower function %s", func_name)
        # The simple-mode lowering is produced only to be logged, so skip
        # the extra lowering pass unless debug logging is actually enabled.
        if logging.getLogger().isEnabledFor(logging.DEBUG):
            logging.debug("%s", _build.lower(sch, inputs, simple_mode=True))
    except Exception:
        msg = traceback.format_exc()
        msg += "Error during compile function\n"
        msg += "-----------------------------\n"
        msg += source_func.astext()
        raise RuntimeError(msg)
    # Always hand back a sequence, even when a single function was produced.
    return f if isinstance(
        f, (_container.Array, tuple, list)) else [f]
@register_func("relay.backend.build")
def build(funcs, target, target_host=None):
    """Compile lowered functions into a runtime module.

    Parameters
    ----------
    funcs : List[tvm.LoweredFunc]
        The list of lowered functions.

    target : tvm.Target
        The target to run the code on.

    target_host : tvm.Target
        The host target. An empty string is treated as "not given".

    Returns
    -------
    module : tvm.Module
        The runtime module.
    """
    host = None if target_host == "" else target_host
    return _build.build(funcs, target=target, target_host=host)
@register_func("relay._tensor_value_repr")
def _tensor_value_repr(tvalue):
    """Return the string form of the numpy array held in ``tvalue.data``."""
    as_numpy = tvalue.data.asnumpy()
    return str(as_numpy)
@register_func("relay._constant_repr")
def _tensor_constant_repr(tvalue):
    """Return the string form of the numpy array held in ``tvalue.data``."""
    as_numpy = tvalue.data.asnumpy()
    return str(as_numpy)
_init_api("relay.backend", __name__)
"""Backend code generation engine."""
from __future__ import absolute_import
from ..base import register_relay_node, NodeBase
from ... import target as _target
from .. import expr as _expr
from . import _backend
@register_relay_node
class CachedFunc(NodeBase):
    """Low-level tensor function that backs a relay primitive function."""
@register_relay_node
class CCacheKey(NodeBase):
    """Key used to look up an entry in the CompileEngine.

    Parameters
    ----------
    source_func : tvm.relay.Function
        The source function.

    target : tvm.Target
        The target we want to run the function on.
    """
    def __init__(self, source_func, target):
        make_key = _backend._make_CCacheKey
        self.__init_handle_by_constructor__(make_key, source_func, target)
@register_relay_node
class CCacheValue(NodeBase):
    """Entry stored in the CompileEngine, including usage statistics."""
def _get_cache_key(source_func, target):
if isinstance(source_func, _expr.Function):
if isinstance(target, str):
target = _target.create(target)
if not target:
raise ValueError("Need target when source_func is a Function")
return CCacheKey(source_func, target)
if not isinstance(source_func, CCacheKey):
raise TypeError("Expect source_func to be CCacheKey")
return source_func
@register_relay_node
class CompileEngine(NodeBase):
    """Handle to the backend compile engine.

    Every method forwards to a ``_backend._CompileEngine*`` function.
    Instances cannot be constructed directly from Python.
    """
    def __init__(self):
        raise RuntimeError("Cannot construct a CompileEngine")

    def lower(self, source_func, target=None):
        """Lower a source_func to a CachedFunc.

        Parameters
        ----------
        source_func : Union[tvm.relay.Function, CCacheKey]
            The source relay function.

        target : tvm.Target
            The target platform.

        Returns
        -------
        cached_func: CachedFunc
            The result of lowering.
        """
        cache_key = _get_cache_key(source_func, target)
        return _backend._CompileEngineLower(self, cache_key)

    def jit(self, source_func, target=None):
        """JIT a source_func to a tvm.Function.

        Parameters
        ----------
        source_func : Union[tvm.relay.Function, CCacheKey]
            The source relay function.

        target : tvm.Target
            The target platform.

        Returns
        -------
        jited_func: tvm.Function
            The result of ``_backend._CompileEngineJIT`` for the cache key.
        """
        cache_key = _get_cache_key(source_func, target)
        return _backend._CompileEngineJIT(self, cache_key)

    def clear(self):
        """Clear the existing cached functions."""
        _backend._CompileEngineClear(self)

    def items(self):
        """List items in the cache.

        Returns
        -------
        item_list : List[Tuple[CCacheKey, CCacheValue]]
            The list of items.
        """
        # The backend returns one flat sequence: key0, value0, key1, value1, ...
        flat = _backend._CompileEngineListItems(self)
        assert len(flat) % 2 == 0
        pairs = []
        for pos in range(0, len(flat) - 1, 2):
            pairs.append((flat[pos], flat[pos + 1]))
        return pairs

    def dump(self):
        """Return a string representation of engine dump.

        Returns
        -------
        dump : str
            The dumped string representation
        """
        entries = self.items()
        pieces = [
            "====================================\n",
            "CompilerEngine dump, %d items cached\n" % len(entries),
        ]
        for key, value in entries:
            pieces.append("------------------------------------\n")
            pieces.append("target={}\n".format(key.target))
            pieces.append("use_count={}\n".format(value.use_count))
            pieces.append("func_name={}\n".format(value.cached_func.func_name))
            pieces.append(key.source_func.astext() + "\n")
        pieces.append("===================================\n")
        return "".join(pieces)
def get():
    """Get the global compile engine.

    Forwards to ``_backend._CompileEngineGlobal`` without arguments.

    Returns
    -------
    engine : tvm.relay.backend.CompileEngine
        The compile engine.
    """
    return _backend._CompileEngineGlobal()
#pylint: disable=no-else-return
"""An interface to the Relay interpreter."""
from __future__ import absolute_import
import numpy as np
from .. import register_func, nd
from .base import NodeBase, register_relay_node
from . import build_module
from . import _make
from . import _interpreter
from . import ir_pass
from .module import Module
from .expr import Call, Constant, GlobalVar, Function, const
from .scope_builder import ScopeBuilder
from .._ffi.base import integer_types
from ..contrib import graph_runtime as tvm_runtime
from .. import cpu
from . import _backend
from .. import _make, ir_pass
from ... import register_func, nd
from ..base import NodeBase, register_relay_node
from ..expr import Call, Constant, GlobalVar, Function, const
from ..scope_builder import ScopeBuilder
class Value(NodeBase):
"""Base class of all values.
"""
@staticmethod
@register_func("relay.from_scalar")
def from_scalar(i, dtype=None):
def from_scalar(value, dtype=None):
"""Convert a Python scalar to a Relay scalar."""
if dtype is None:
if isinstance(i, integer_types):
dtype = 'int32'
elif isinstance(i, float):
dtype = 'float32'
elif isinstance(i, bool):
dtype = 'uint8'
else:
raise Exception("unable to infer dtype {0}".format(type(i)))
return TensorValue(nd.array(np.array(i, dtype=dtype)))
return TensorValue(const(value, dtype).data)
@register_relay_node
......@@ -65,10 +50,6 @@ class TensorValue(Value):
self.__init_handle_by_constructor__(
_make.TensorValue, data)
def as_ndarray(self):
"""Convert a Relay TensorValue into a tvm.ndarray."""
return self.data
def asnumpy(self):
"""Convert a Relay TensorValue into a numpy.ndarray."""
return self.data.asnumpy()
......@@ -79,7 +60,7 @@ class TensorValue(Value):
def _arg_to_ast(arg):
if isinstance(arg, TensorValue):
return Constant(arg.data)
return Constant(arg.data.copyto(_nd.cpu(0)))
elif isinstance(arg, np.ndarray):
return Constant(nd.array(arg))
elif isinstance(arg, Constant):
......@@ -87,29 +68,9 @@ def _arg_to_ast(arg):
else:
return const(arg)
class Executor(object):
"""An abstract interface for executing Relay programs."""
def __init__(self, mod=None):
"""
Parameters
----------
mod: relay.Module
The module.
"""
if mod is None:
self.mod = Module({})
else:
self.mod = mod
def optimize(self, expr):
# TODO: We need to move this optimization code into the optimizer/pass manager
ck_expr = ir_pass.infer_type(expr, mod=self.mod)
fused_expr = ir_pass.fuse_ops(self.mod, ck_expr)
ck_fused = ir_pass.infer_type(fused_expr, mod=self.mod)
return ck_fused
def _make_executor(self, _):
"""
Construct a Python function that implements the evaluation
......@@ -122,50 +83,85 @@ class Executor(object):
Returns
-------
executor: function
executor: function,
A Python function which implements the behavior of `expr`.
"""
raise Exception("abstract method: please implement me.")
raise NotImplementedError()
def evaluate(self, expr, params=None):
def evaluate(self, expr, binds=None):
"""
Evaluate a Relay expression on the interpreter.
Evaluate a Relay expression on the executor.
Parameters
----------
expr: tvm.relay.Expr
The expression to evaluate.
binds: Map[tvm.relay.Var, tvm.relay.Expr]
Additional binding of free variable.
Returns
-------
val : Union[function, Value]
The evaluation result.
"""
if params:
if binds:
scope_builder = ScopeBuilder()
for key in params:
value = params[key]
scope_builder.let(key, value)
for key, value in binds.items():
scope_builder.let(key, _arg_to_ast(value))
scope_builder.ret(expr)
expr = scope_builder.get()
if isinstance(expr, Function):
assert not ir_pass.free_vars(expr)
executor = self._make_executor(expr)
# If we are evaluating a function or top-level defintion
# the user must call the function themselves.
#
# If we are evaluating an open term with parameters we will
# just return them the result.
if isinstance(expr, (Function, GlobalVar)):
return executor
else:
return executor()
return self._make_executor(expr)
# normal expression evaluated by running a function.
func = Function([], expr)
return self._make_executor(func)()
class Interpreter(Executor):
"""
A wrapper around the Relay interpreter, implements the excecutor interface.
Simple interpreter interface.
Parameters
----------
mod : tvm.relay.Module
The module to support the execution.
ctx : tvm.TVMContext
The runtime context to run the code on.
target : tvm.Target
The target option to build the function.
"""
def __init__(self, mod=None):
Executor.__init__(self, mod)
def __init__(self, mod, ctx, target):
self.mod = mod
self.ctx = ctx
self.target = target
self._intrp = _backend.CreateInterpreter(mod, ctx, target)
def optimize(self, expr):
    """Type-check, fuse and re-type-check an expression.

    Parameters
    ----------
    expr : Expr
        The expression to be optimized.

    Returns
    -------
    opt_expr : Expr
        The optimized expression.
    """
    # TODO: We need to move this optimization code into the optimizer/pass manager
    typed = ir_pass.infer_type(expr, mod=self.mod)
    fused = ir_pass.fuse_ops(typed)
    # Fusion builds new nodes, so types are inferred once more afterwards.
    return ir_pass.infer_type(fused, mod=self.mod)
def _make_executor(self, expr):
def _interp_wrapper(*args):
......@@ -178,46 +174,9 @@ class Interpreter(Executor):
func = self.optimize(func)
self.mod._add(expr, func, True)
opt_expr = Call(expr, relay_args)
return _interpreter.evaluate(self.mod, opt_expr)
elif isinstance(expr, Function):
return self._intrp(opt_expr)
else:
call = Call(expr, relay_args)
opt_expr = self.optimize(call)
return _interpreter.evaluate(self.mod, opt_expr)
else:
assert not args
opt_expr = self.optimize(expr)
return _interpreter.evaluate(self.mod, opt_expr)
return self._intrp(opt_expr)
return _interp_wrapper
class GraphRuntime(Executor):
"""A wrapper around the TVM graph runtime, implements the Executor interface."""
def __init__(self, mod=None):
Executor.__init__(self, mod)
def _make_executor(self, expr):
def _graph_wrapper(*args):
func = self.optimize(expr)
graph_json, mod, params = build_module.build(func, mod=self.mod)
assert params is None
gmodule = tvm_runtime.create(graph_json, mod, cpu(0))
# Create map of inputs.
inputs = {}
for i, arg in enumerate(args):
inputs[func.params[i].name_hint] = arg
# Set the inputs here.
gmodule.set_input(**inputs)
# Run the module, and fetch the output.
gmodule.run()
return gmodule.get_output(0)
return _graph_wrapper
def create_executor(mode='debug', mod=None):
if mode == 'debug':
return Interpreter(mod)
elif mode == 'graph':
return GraphRuntime(mod)
else:
raise Exception("unknown mode {0}".format(mode))
......@@ -2,45 +2,257 @@
Construct the necessary state for the TVM graph runtime
from a Relay expression.
"""
from ..build_module import build as tvm_build_module
from . graph_runtime_codegen import GraphRuntimeCodegen
from ..build_module import build as _tvm_build_module
from .. import nd as _nd, target as _target, autotvm
from ..contrib import graph_runtime as _graph_rt
from . import ir_pass
from .module import Module
from .backend import interpreter as _interpreter
from .backend import graph_runtime_codegen as _graph_gen
def build(func, params=None, target=None, mod=None):
# List of optimization pass and level when switch on
OPT_PASS_LEVEL = {
"SimplifyInference": 0,
"OpFusion": 1,
"FoldScaleAxis": 3,
}
class BuildConfig(object):
    """Configuration scope to set a build config option.

    An instance can be used as a context manager: on entry it becomes
    ``BuildConfig.current`` (inheriting the options of the previously
    current config), and on exit the previous config is restored.

    Parameters
    ----------
    kwargs
        Keyword arguments of configurations to set. Only the keys of
        ``BuildConfig.defaults`` are accepted.

    Raises
    ------
    ValueError
        If an unknown configuration name is passed.
    """
    # The configuration currently in effect.
    current = None
    # Fallback values for options that were not set explicitly.
    defaults = {
        "opt_level": 2,
        "add_pass": None,
    }

    def __init__(self, **kwargs):
        self._old_scope = None
        for k in kwargs:
            if k not in BuildConfig.defaults:
                raise ValueError(
                    "invalid argument %s, candidates are %s" % (k, BuildConfig.defaults.keys()))
        self._attr = kwargs

    def __getattr__(self, name):
        # Only reached when normal attribute lookup fails. ``_attr`` is read
        # through __dict__ so that an instance without it (e.g. one being
        # rebuilt by copy/pickle) cannot recurse into __getattr__ forever.
        attr = self.__dict__.get("_attr", {})
        if name in attr:
            return attr[name]
        if name in BuildConfig.defaults:
            return BuildConfig.defaults[name]
        # Unknown names must raise AttributeError so that hasattr() and
        # getattr(obj, name, default) behave as the language expects.
        raise AttributeError(
            "%s object has no attribute %s" % (type(self).__name__, name))

    def __enter__(self):
        # pylint: disable=protected-access
        self._old_scope = BuildConfig.current
        # Start from the enclosing scope's options, then apply our own on top.
        attr = BuildConfig.current._attr.copy()
        attr.update(self._attr)
        self._attr = attr
        BuildConfig.current = self
        return self

    def __exit__(self, ptype, value, trace):
        assert self._old_scope
        BuildConfig.current = self._old_scope

    def pass_enabled(self, pass_name):
        """Get whether pass is enabled.

        Parameters
        ----------
        pass_name : str
            The optimization pass name

        Returns
        -------
        enabled : bool
            Whether pass is enabled.
        """
        # Passes listed in add_pass are enabled regardless of opt_level.
        if self.add_pass and pass_name in self.add_pass:
            return True
        return self.opt_level >= OPT_PASS_LEVEL[pass_name]
BuildConfig.current = BuildConfig()
def build_config(**kwargs):
    """Configure the build behavior by setting config variables.

    Forwards all keyword arguments to the ``BuildConfig`` constructor.

    Parameters
    ----------
    opt_level: int, default=2
        Optimization level. See OPT_PASS_LEVEL for level of each pass.

    add_pass: set of str
        Optimization pass to be added regardless of optimization level.

    Returns
    -------
    config: BuildConfig
        The build configuration
    """
    return BuildConfig(**kwargs)
def optimize(func):
    """Perform target invariant optimizations.

    Which passes run is decided by ``BuildConfig.current``.

    Parameters
    ----------
    func : tvm.relay.Function
        The input to optimization.

    Returns
    -------
    opt_func : tvm.relay.Function
        The optimized version of the function.
    """
    cfg = BuildConfig.current

    # Each pass is gated on its own entry in OPT_PASS_LEVEL; type inference
    # is run before every pass.
    if cfg.pass_enabled("SimplifyInference"):
        func = ir_pass.infer_type(func)
        func = ir_pass.simplify_inference(func)

    if cfg.pass_enabled("FoldScaleAxis"):
        func = ir_pass.infer_type(func)
        func = ir_pass.backward_fold_scale_axis(func)
        func = ir_pass.infer_type(func)
        func = ir_pass.forward_fold_scale_axis(func)
    return func
def build(func,
target=None,
target_host=None,
params=None):
"""Build a function to run on TVM graph runtime.
Parameters
----------
func: relay.Expr
func: relay.Function
The function to build.
target: optional str
The target platform.
target : str or :any:`tvm.target.Target`, optional
The build target
target_host : str or :any:`tvm.target.Target` optional
Host compilation target, if target is device.
When TVM compiles device specific program such as CUDA,
we also need host(CPU) side code to interact with the driver
setup the dimensions and parameters correctly.
target_host is used to specify the host side codegen target.
By default, llvm is used if it is enabled,
otherwise a stackvm interpreter is used.
params : dict of str to NDArray
Input parameters to the graph that do not change
during inference time. Used for pre-compute
folding optimization.
Returns
-------
(graph_json, mod, params): tuple of (str, tvm.Module, dict)
The outputs of building a Relay function for the TVM runtime.
graph_json : str
The json string that can be accepted by graph runtime.
mod : tvm.Module
The module containing necessary libraries.
params : dict
The parameters of the final graph.
"""
target = target if target else _target.current_target()
if target is None:
target = 'llvm'
if mod is None:
mod = Module({})
comp = GraphRuntimeCodegen(mod)
# NB(@jroesch) This creates lowered functions, and generates names for them
#
# We need these names to emit the correct graph as these are names of the
# functions contained in the module.
lowered_ops = ir_pass.lower_ops(mod, func)
mod = tvm_build_module([lf.lowered_func for lf in lowered_ops], target)
# Therefore the call to compile must come after.
comp.codegen(func)
graph_json = comp.to_json()
raise ValueError("Target is not set in env or passed as argument.")
target = _target.create(target)
# If current dispatch context is fallback context (the default root context),
# then load pre-tuned parameters from TopHub
if isinstance(autotvm.DispatchContext.current, autotvm.FallbackContext):
tophub_context = autotvm.tophub.context(target)
else:
tophub_context = autotvm.util.EmptyContext()
with tophub_context:
func = optimize(func)
# Fuse ops before running code gen
func = ir_pass.infer_type(func)
func = ir_pass.fuse_ops(func)
# Graph code generation
func = ir_pass.infer_type(func)
graph_gen = _graph_gen.GraphRuntimeCodegen(mod=None, target=target)
graph_json, lowered_funcs = graph_gen.codegen(func)
mod = _tvm_build_module(lowered_funcs, target=target, target_host=target_host)
return graph_json, mod, params
class GraphExecutor(_interpreter.Executor):
    """Executor that builds a function and runs it on the graph runtime.

    This executor is used for debug and testing purposes.

    Parameters
    ----------
    mod : tvm.relay.Module
        The module to support the execution.

    ctx : tvm.TVMContext
        The runtime context to run the code on.

    target : tvm.Target
        The target option to build the function.
    """
    def __init__(self, mod, ctx, target):
        self.mod = mod
        self.ctx = ctx
        self.target = target

    def _make_executor(self, func):
        def _graph_wrapper(*args):
            # Building happens on every call of the returned wrapper.
            graph_json, lib, params = build(func, target=self.target)
            assert params is None
            runtime_module = _graph_rt.create(graph_json, lib, self.ctx)
            # Positional arguments are bound to graph inputs by index.
            for index, value in enumerate(args):
                runtime_module.set_input(index, value)
            # Run the module, and fetch the output.
            runtime_module.run()
            return runtime_module.get_output(0)

        return _graph_wrapper
def create_executor(kind="debug",
                    mod=None,
                    ctx=None,
                    target="llvm"):
    """Factory function to create an executor.

    Parameters
    ----------
    kind : str
        The type of executor: "debug" for the interpreter,
        "graph" for the graph runtime executor.

    mod : relay.Module
        The module handed to the executor.

    ctx : tvm.TVMContext
        The context to execute the code. When omitted, device 0 of the
        context named by ``target`` is used.

    target : tvm.Target
        The build target; a string is converted with ``_target.create``.

    Returns
    -------
    executor : Union[Interpreter, GraphExecutor]
        The executor selected by ``kind``.

    Raises
    ------
    RuntimeError
        If ``kind`` is not one of the supported values.
    """
    if ctx is not None:
        # A user supplied context must be on the device type of the target.
        assert ctx.device_type == _nd.context(str(target), 0).device_type
    else:
        ctx = _nd.context(str(target), 0)

    if isinstance(target, str):
        target = _target.create(target)
    if kind == "debug":
        return _interpreter.Interpreter(mod, ctx, target)
    if kind == "graph":
        return GraphExecutor(mod, ctx, target)
    raise RuntimeError("unknown mode {0}".format(kind))
......@@ -319,12 +319,11 @@ class TupleGetItem(Expr):
self.__init_handle_by_constructor__(
_make.TupleGetItem, tuple_value, index)
class ExprFunctor(object):
"""
An abstract visitor defined over Expr.
A Python version of the class defined in `expr_functor.h`.
Defines the default dispatch over expressions, and
implements memoization.
"""
......@@ -352,6 +351,8 @@ class ExprFunctor(object):
res = self.visit_if(expr)
elif isinstance(expr, Tuple):
res = self.visit_tuple(expr)
elif isinstance(expr, TupleGetItem):
res = self.visit_tuple_getitem(expr)
elif isinstance(expr, Constant):
res = self.visit_constant(expr)
else:
......@@ -361,31 +362,34 @@ class ExprFunctor(object):
return res
def visit_function(self, _):
raise Exception("Abstract method please implement me.")
raise NotImplementedError()
def visit_let(self, _):
raise Exception("Abstract method please implement me.")
raise NotImplementedError()
def visit_call(self, _):
raise Exception("Abstract method please implement me.")
raise NotImplementedError()
def visit_var(self, _):
raise Exception("Abstract method please implement me.")
raise NotImplementedError()
def visit_type(self, typ):
return typ
def visit_if(self, _):
raise Exception("Abstract method please implement me.")
raise NotImplementedError()
def visit_tuple(self, _):
raise Exception("Abstract method please implement me.")
raise NotImplementedError()
def visit_tuple_getitem(self, _):
raise NotImplementedError()
def visit_constant(self, _):
raise Exception("Abstract method please implement me.")
raise NotImplementedError()
def visit_global_var(self, _):
raise Exception("Abstract method please implement me.")
raise NotImplementedError()
class ExprMutator(ExprFunctor):
......@@ -395,7 +399,6 @@ class ExprMutator(ExprFunctor):
The default behavior recursively traverses the AST
and reconstructs the AST.
"""
def visit_function(self, fn):
new_body = self.visit(fn.body)
return Function(
......@@ -429,9 +432,19 @@ class ExprMutator(ExprFunctor):
def visit_tuple(self, tup):
    """Rebuild a Tuple from the visited form of each of its fields."""
    new_fields = []
    for item in tup.fields:
        new_fields.append(self.visit(item))
    return Tuple(new_fields)
def visit_tuple_getitem(self, op):
    """Visit the tuple operand; rebuild the node only if it changed."""
    new_tuple = self.visit(op.tuple_value)
    if new_tuple.same_as(op.tuple_value):
        # Operand untouched: reuse the existing node.
        return op
    return TupleGetItem(new_tuple, op.index)
def visit_global_var(self, gvar):
    """Return the global variable unchanged."""
    return gvar
def visit_constant(self, rconst):
    """Return the constant unchanged."""
    return rconst
class TupleWrapper(object):
"""TupleWrapper.
......
......@@ -160,6 +160,7 @@ def free_type_vars(expr):
"""
return _ir_pass.free_type_vars(expr)
def simplify_inference(expr):
""" Simplify the data-flow graph for inference phase.
......@@ -176,6 +177,7 @@ def simplify_inference(expr):
"""
return _ir_pass.simplify_inference(expr)
def dead_code_elimination(expr):
""" Remove expressions which does not effect the program result (dead code).
......@@ -256,8 +258,18 @@ def structural_hash(value):
"relay.Expr or relay.Type").format(type(value))
raise TypeError(msg)
def fuse_ops(expr, mod):
return _ir_pass.FuseOps(mod, expr)
def lower_ops(mod, expr, target='llvm'):
return _ir_pass.LowerOps(mod, expr, target)
def fuse_ops(expr):
    """Fuse operators in expr together.

    Forwards to ``_ir_pass.FuseOps``.

    Parameters
    ----------
    expr : tvm.relay.Expr
        The input expression.

    Returns
    -------
    transformed_expr : tvm.relay.Expr
        Transformed expression, containing fused result.
    """
    return _ir_pass.FuseOps(expr)
#pylint: disable=invalid-name, unused-argument
"""Backend compiler related feature registration"""
from __future__ import absolute_import
import tvm
import topi
import topi.cuda
from . import register_schedule, register_compute
from .op import register_compute, register_schedule, register_pattern, OpPattern
def schedule_injective(outputs, target):
"""Generic schedule for binary broadcast."""
with tvm.target.create(target):
with target:
return topi.generic.schedule_injective(outputs)
schedule_broadcast = schedule_injective
schedule_elemwise = schedule_injective
# log
@register_compute("log")
def log_compute(attrs, inputs, output_type, target):
    """Compute for ``log``: apply ``topi.log`` to the single input."""
    assert len(inputs) == 1
    data = inputs[0]
    return [topi.log(data)]
register_compute("log", log_compute)
register_schedule("log", schedule_broadcast)
# exp
@register_compute("exp")
def exp_compute(attrs, inputs, output_type, target):
    """Compute for ``exp``: apply ``topi.exp`` to the single input."""
    assert len(inputs) == 1
    data = inputs[0]
    return [topi.exp(data)]
register_compute("exp", exp_compute)
register_schedule("exp", schedule_broadcast)
# sqrt
@register_compute("sqrt")
def sqrt_compute(attrs, inputs, output_type, target):
    """Compute for ``sqrt``: apply ``topi.sqrt`` to the single input."""
    assert len(inputs) == 1
    data = inputs[0]
    return [topi.sqrt(data)]
register_compute("sqrt", sqrt_compute)
register_schedule("sqrt", schedule_broadcast)
# sigmoid
@register_compute("sigmoid")
def sigmoid_compute(attrs, inputs, output_type, target):
    """Compute for ``sigmoid``: apply ``topi.sigmoid`` to the single input."""
    assert len(inputs) == 1
    data = inputs[0]
    return [topi.sigmoid(data)]
register_compute("sigmoid", sigmoid_compute)
register_schedule("sigmoid", schedule_broadcast)
# floor
@register_compute("floor")
def floor_compute(attrs, inputs, output_type, target):
    """Compute for ``floor``: apply ``topi.floor`` to the single input."""
    assert len(inputs) == 1
    data = inputs[0]
    return [topi.floor(data)]
register_compute("floor", floor_compute)
register_schedule("floor", schedule_broadcast)
# ceil
@register_compute("ceil")
def ceil_compute(attrs, inputs, output_type, target):
    """Compute for ``ceil``: apply ``topi.ceil`` to the single input."""
    assert len(inputs) == 1
    data = inputs[0]
    return [topi.ceil(data)]
register_compute("ceil", ceil_compute)
register_schedule("ceil", schedule_broadcast)
# trunc
@register_compute("trunc")
def trunc_compute(attrs, inputs, output_type, target):
assert len(inputs) == 1
return [topi.trunc(inputs[0])]
register_compute("trunc", trunc_compute)
register_schedule("trunc", schedule_broadcast)
# round
@register_compute("round")
def round_compute(attrs, inputs, output_type, target):
assert len(inputs) == 1
return [topi.round(inputs[0])]
register_compute("round", round_compute)
register_schedule("round", schedule_broadcast)
# abs
@register_compute("abs")
def abs_compute(attrs, inputs, output_type, target):
assert len(inputs) == 1
return [topi.abs(inputs[0])]
register_compute("abs", abs_compute)
register_schedule("abs", schedule_broadcast)
# tanh
@register_compute("tanh")
def tanh_compute(attrs, inputs, output_type, target):
assert len(inputs) == 1
return [topi.tanh(inputs[0])]
register_compute("tanh", tanh_compute)
register_schedule("tanh", schedule_broadcast)
# negative
@register_compute("negative")
def negative_compute(attrs, inputs, output_type, target):
assert len(inputs) == 1
return [topi.negative(inputs[0])]
register_compute("negative", negative_compute)
register_schedule("negative", schedule_broadcast)
# add
@register_compute("add")
def add_compute(attrs, inputs, output_type, target):
assert len(inputs) == 2
return [topi.add(inputs[0], inputs[1])]
register_compute("add", add_compute)
register_schedule("add", schedule_injective)
# subtract
@register_compute("subtract")
def subtract_compute(attrs, inputs, output_type, target):
assert len(inputs) == 2
return [topi.subtract(inputs[0], inputs[1])]
register_compute("subtract", subtract_compute)
register_schedule("subtract", schedule_broadcast)
# multiply
@register_compute("multiply")
def multiply_compute(attrs, inputs, output_type, target):
assert len(inputs) == 2
return [topi.multiply(inputs[0], inputs[1])]
register_compute("multiply", multiply_compute)
register_schedule("multiply", schedule_broadcast)
# divide
@register_compute("divide")
def divide_compute(attrs, inputs, output_type, target):
assert len(inputs) == 2
return [topi.divide(inputs[0], inputs[1])]
register_compute("divide", divide_compute)
register_schedule("divide", schedule_broadcast)
# pow
def pow_compute(attrs, inputs, output_type, target):
# power
@register_compute("power")
def power_compute(attrs, inputs, output_type, target):
assert len(inputs) == 2
return [topi.power(inputs[0], inputs[1])]
register_compute("pow", pow_compute)
register_schedule("pow", schedule_injective)
register_schedule("power", schedule_injective)
# mod
@register_compute("mod")
def mod_compute(attrs, inputs, output_type, target):
assert len(inputs) == 2
return [topi.mod(inputs[0], inputs[1])]
register_compute("mod", mod_compute)
register_schedule("mod", schedule_broadcast)
# equal
@register_compute("equal")
def equal_compute(attrs, inputs, output_type, target):
assert len(inputs) == 2
return [topi.equal(inputs[0], inputs[1])]
register_compute("equal", equal_compute)
register_schedule("equal", schedule_broadcast)
# not_equal
@register_compute("not_equal")
def not_equal_compute(attrs, inputs, output_type, target):
assert len(inputs) == 2
return [topi.not_equal(inputs[0], inputs[1])]
register_compute("not_equal", not_equal_compute)
register_schedule("not_equal", schedule_broadcast)
# less
@register_compute("less")
def less_compute(attrs, inputs, output_type, target):
assert len(inputs) == 2
return [topi.less(inputs[0], inputs[1])]
register_compute("less", less_compute)
register_schedule("less", schedule_broadcast)
# less equal
@register_compute("less_equal")
def less_equal_compute(attrs, inputs, output_type, target):
assert len(inputs) == 2
return [topi.less_equal(inputs[0], inputs[1])]
register_compute("less_equal", less_equal_compute)
register_schedule("less_equal", schedule_broadcast)
# greater
@register_compute("greater")
def greater_compute(attrs, inputs, output_type, target):
assert len(inputs) == 2
return [topi.greater(inputs[0], inputs[1])]
register_compute("greater", greater_compute)
register_schedule("greater", schedule_broadcast)
# greater equal
@register_compute("greater_equal")
def greater_equal_compute(attrs, inputs, output_type, target):
assert len(inputs) == 2
return [topi.greater_equal(inputs[0], inputs[1])]
register_compute("greater_equal", greater_equal_compute)
register_schedule("greater_equal", schedule_broadcast)
# maximum
@register_compute("maximum")
def maximum_compute(attrs, inputs, output_type, target):
assert len(inputs) == 2
return [topi.maximum(inputs[0], inputs[1])]
register_compute("maximum_compute", maximum_compute)
register_schedule("maximum_compute", schedule_injective)
# minimum
@register_compute("minimum")
def minimum_compute(attrs, inputs, output_type, target):
assert len(inputs) == 2
return [topi.minimum(inputs[0], inputs[1])]
register_compute("minimum", minimum_compute)
register_schedule("minimum", schedule_injective)
# right shift
@register_compute("right_shift")
def right_shift_compute(attrs, inputs, output_type, target):
assert len(inputs) == 2
return [topi.right_shift(inputs[0], inputs[1])]
register_compute("right_shift", right_shift_compute)
register_schedule("right_shift", schedule_injective)
# lift shift
# left shift
@register_compute("left_shift")
def left_shift_compute(attrs, inputs, output_type, target):
assert len(inputs) == 2
return [topi.left_shift(inputs[0], inputs[1])]
register_compute("left_shift", left_shift_compute)
register_schedule("left_shift", schedule_injective)
# zeros
@register_compute("zeros")
def zeros_compute(attrs, inputs, output_type, target):
assert not inputs
return [topi.full(output_type.shape, output_type.dtype, 0.0)]
register_compute("zeros", zeros_compute)
register_schedule("zeros", schedule_injective)
register_schedule("zeros", schedule_broadcast)
register_pattern("zeros", OpPattern.ELEMWISE)
# zeros_like
@register_compute("zeros_like")
def zeros_like_compute(attrs, inputs, output_type, target):
assert len(inputs) == 1
return [topi.full_like(inputs[0], 0.0)]
register_compute("zeros_like", zeros_like_compute)
register_schedule("zeros_like", schedule_injective)
register_schedule("zeros_like", schedule_broadcast)
# ones
@register_compute("ones")
def ones_compute(attrs, inputs, output_type, target):
assert not inputs
return [topi.full(output_type.shape, output_type.dtype, 1.0)]
register_compute("ones", ones_compute)
register_schedule("ones", schedule_injective)
register_schedule("ones", schedule_broadcast)
register_pattern("ones", OpPattern.ELEMWISE)
# ones_like
@register_compute("ones_like")
def ones_like(attrs, inputs, output_type, target):
assert len(inputs) == 1
return [topi.full_like(inputs[0], 1.0)]
register_compute("ones_like", ones_like)
register_schedule("ones_like", schedule_injective)
register_schedule("ones_like", schedule_broadcast)
# clip
@register_compute("clip")
def clip_compute(attrs, inputs, output_type, target):
assert len(inputs) == 1
return [topi.clip(inputs[0], attrs.a_min, attrs.a_max)]
register_schedule("clip", schedule_elemwise)
register_pattern("clip", OpPattern.ELEMWISE)
register_compute("clip", clip_compute)
register_schedule("clip", schedule_injective)
# concatenate
@register_compute("concatenate")
def concatenate_compute(attrs, inputs, output_type, target):
return [topi.concatenate(inputs, axis=attrs.axis)]
register_schedule("concatenate", schedule_injective)
register_pattern("concatenate", OpPattern.INJECTIVE)
......@@ -72,13 +72,80 @@ def register(op_name, attr_key, value=None, level=10):
"""internal register function"""
_Register(op_name, attr_key, v, level)
return v
return _register(value) if value else _register
return _register(value) if value is not None else _register
def register_schedule(op_name, schedule):
register(op_name, "FTVMSchedule", schedule)
def register_compute(op_name, compute):
register(op_name, "FTVMCompute", compute)
class OpPattern(object):
    """Operator generic patterns

    Each attribute is an integer code naming one operator pattern.

    See Also
    --------
    topi.tag : Contains explanation of the tag type.
    """
    # Elementwise operator
    ELEMWISE = 0
    # Broadcast operator
    BROADCAST = 1
    # Injective mapping
    INJECTIVE = 2
    # Commutative reduction
    # NOTE(review): the original comment read "Comunication"; meaning
    # inferred from the constant name -- confirm.
    COMM_REDUCE = 3
    # Complex op, can still fuse ewise into it
    OUT_ELEMWISE_FUSABLE = 4
    # Not fusable opaque op
    OPAQUE = 8
def register_schedule(op_name, schedule=None, level=10):
    """Attach a schedule function to an operator.

    The function is stored under the ``"FTVMSchedule"`` attribute key.

    Parameters
    ----------
    op_name : str
        Name of the operator to annotate.

    schedule : function, optional
        The schedule function; forwarded unchanged to ``register``.

    level : int
        Priority level of this registration.

    Returns
    -------
    The value returned by ``register`` for this attribute.
    """
    attr_key = "FTVMSchedule"
    return register(op_name, attr_key, schedule, level)
def register_compute(op_name, compute=None, level=10):
    """Attach a compute function to an operator.

    The function is stored under the ``"FTVMCompute"`` attribute key.

    Parameters
    ----------
    op_name : str
        Name of the operator to annotate.

    compute : function, optional
        The compute function; forwarded unchanged to ``register``.

    level : int
        Priority level of this registration.

    Returns
    -------
    The value returned by ``register`` for this attribute.
    """
    attr_key = "FTVMCompute"
    return register(op_name, attr_key, compute, level)
def register_pattern(op_name, pattern, level=10):
    """Attach an operator pattern code to an operator.

    The code is stored under the ``"TOpPattern"`` attribute key.

    Parameters
    ----------
    op_name : str
        Name of the operator to annotate.

    pattern : int
        The pattern code to record (see ``OpPattern``).

    level : int
        Priority level of this registration.

    Returns
    -------
    The value returned by ``register`` for this attribute.
    """
    attr_key = "TOpPattern"
    return register(op_name, attr_key, pattern, level)
_init_api("relay.op", __name__)
......
......@@ -266,7 +266,7 @@ def divide(lhs, rhs):
return _make.divide(lhs, rhs)
def pow(lhs, rhs):
def power(lhs, rhs):
"""Power with numpy-style broadcasting.
Parameters
......@@ -281,7 +281,7 @@ def pow(lhs, rhs):
result : relay.Expr
The computed result.
"""
return _make.pow(lhs, rhs)
return _make.power(lhs, rhs)
def mod(lhs, rhs):
......
......@@ -6,3 +6,4 @@ from . import resnet
from . import dqn
from . import dcgan
from . import mobilenet
from .config import ctx_list
"""Configuration about tests"""
from __future__ import absolute_import as _abs
import os
import tvm
def ctx_list():
    """Get context list for testcases.

    Candidate targets are read from the ``RELAY_TEST_TARGETS`` environment
    variable as a comma separated list. When the variable is unset or empty,
    ``llvm`` and ``cuda`` are used.

    Returns
    -------
    res : list of (str, context)
        One ``(device, tvm.context(device, 0))`` pair for every requested
        device whose context reports ``exist``, in the order the devices
        were requested.
    """
    requested = os.environ.get("RELAY_TEST_TARGETS", "")
    candidates = requested.split(",") if requested else ["llvm", "cuda"]
    # De-duplicate while keeping the requested order: a plain set() would make
    # the order of the returned list depend on string hashing. Surrounding
    # whitespace and empty entries (e.g. from a trailing comma) are dropped so
    # they are never handed to tvm.context.
    seen = set()
    device_list = []
    for device in candidates:
        device = device.strip()
        if not device or device in seen:
            continue
        seen.add(device)
        device_list.append(device)
    res = [(device, tvm.context(device, 0)) for device in device_list]
    return [x for x in res if x[1].exist]
......@@ -154,13 +154,15 @@ std::unordered_set<std::string> TargetNode::libs() const {
return result;
}
std::string TargetNode::str() const {
const std::string& TargetNode::str() const {
if (str_repr_.length() != 0) return str_repr_;
std::ostringstream result;
result << target_name;
for (const auto &x : options()) {
result << " " << x;
}
return result.str();
str_repr_ = result.str();
return str_repr_;
}
......
/*!
* Copyright (c) 2018 by Contributors
* \file relay/backend/compile_engine.h
* \brief Internal compilation engine that handles the function cache
* and the interface to low-level code generation.
*/
#ifndef TVM_RELAY_BACKEND_COMPILE_ENGINE_H_
#define TVM_RELAY_BACKEND_COMPILE_ENGINE_H_
#include <tvm/lowered_func.h>
#include <tvm/relay/expr.h>
#include <string>
#include <functional>
namespace tvm {
namespace relay {
/*!
 * \brief Node container to represent a cached function.
 *
 * Every field is exposed to the attribute visitor in VisitAttrs.
 */
struct CachedFuncNode : public Node {
  /*! \brief compiled target */
  tvm::Target target;
  /*! \brief Function name */
  std::string func_name;
  /*! \brief The inputs to the function */
  tvm::Array<Tensor> inputs;
  /*! \brief The outputs to the function */
  tvm::Array<Tensor> outputs;
  /*! \brief The lowered functions to support the function. */
  tvm::Array<tvm::LoweredFunc> funcs;

  /*!
   * \brief Visit all fields of the node.
   * \param v The attribute visitor.
   */
  void VisitAttrs(tvm::AttrVisitor* v) final {
    v->Visit("target", &target);
    v->Visit("func_name", &func_name);
    v->Visit("inputs", &inputs);
    v->Visit("outputs", &outputs);
    v->Visit("funcs", &funcs);
  }

  static constexpr const char* _type_key = "relay.CachedFunc";
  TVM_DECLARE_NODE_TYPE_INFO(CachedFuncNode, Node);
};

/*! \brief Reference type for CachedFuncNode. */
TVM_DEFINE_NODE_REF(CachedFunc, CachedFuncNode);
class CCacheKey;
/*!
 * \brief Compile cache key
 *
 * A key is the pair (source function, target). Hash() and Equal() compare
 * keys by content rather than by node address.
 */
class CCacheKeyNode : public Node {
 public:
  /*! \brief The source function to be lowered. */
  Function source_func;
  /*! \brief The hardware target.*/
  Target target;

  /*!
   * \brief Visit the fields of the key.
   * \param v The attribute visitor.
   */
  void VisitAttrs(tvm::AttrVisitor* v) final {
    v->Visit("source_func", &source_func);
    v->Visit("target", &target);
  }
  /*!
   * \return The hash value of CCacheKey.
   * \note The value is computed on first use and stored in hash_.
   */
  inline size_t Hash() const;
  /*!
   * \brief check content equality
   * \param other The other value.
   * \return The result of equality check.
   */
  inline bool Equal(const CCacheKeyNode* other) const;
  /*!
   * \brief create a cache key.
   * \param source_func The source function.
   * \param target The target device.
   * \return the created key.
   */
  TVM_DLL static CCacheKey make(Function source_func,
                                Target target);

  static constexpr const char* _type_key = "relay.CCacheKey";
  TVM_DECLARE_NODE_TYPE_INFO(CCacheKeyNode, tvm::Node);

 private:
  /*!
   * \brief internal cached hash value.
   *
   * Zero means "not computed yet"; Hash() never stores zero as a result.
   * Declared mutable so the const Hash() can fill it in.
   */
  mutable size_t hash_{0};
};
/*! \brief Reference to a CCacheKeyNode, the key type of the compile cache. */
class CCacheKey : public NodeRef {
 public:
  /*! \brief Construct an undefined (null) key. */
  CCacheKey() {}
  /*!
   * \brief Construct a key reference from a node pointer.
   * \param n The node to refer to.
   */
  explicit CCacheKey(NodePtr<Node> n) : NodeRef(n) {}
  /*! \return Read-only access to the underlying key node. */
  const CCacheKeyNode* operator->() const {
    return static_cast<CCacheKeyNode*>(node_.get());
  }
  // comparator
  /*!
   * \brief Content equality of two keys; both must be defined.
   * \param other The key to compare with.
   * \return The result of CCacheKeyNode::Equal on the two nodes.
   */
  inline bool operator==(const CCacheKey& other) const {
    CHECK(defined() && other.defined());
    return (*this)->Equal(other.operator->());
  }
  using ContainerType = CCacheKeyNode;
};
/*! \brief Node container for compile cache. */
class CCacheValueNode : public Node {
 public:
  /*! \brief The corresponding function */
  CachedFunc cached_func;
  /*! \brief Result of Packed function generated by JIT */
  PackedFunc packed_func;
  /*! \brief usage statistics */
  int use_count{0};

  /*!
   * \brief Visit the fields of the value.
   * \param v The attribute visitor.
   * \note packed_func is not passed to the visitor; only cached_func and
   *       use_count are.
   */
  void VisitAttrs(tvm::AttrVisitor* v) final {
    v->Visit("cached_func", &cached_func);
    v->Visit("use_count", &use_count);
  }
  static constexpr const char* _type_key = "relay.CCacheValue";
  TVM_DECLARE_NODE_TYPE_INFO(CCacheValueNode, tvm::Node);
};
/*! \brief cache entry used in compile engine (reference to CCacheValueNode) */
class CCacheValue : public NodeRef {
 public:
  /*! \brief Construct an undefined (null) entry. */
  CCacheValue() {}
  /*!
   * \brief Construct an entry reference from a node pointer.
   * \param n The node to refer to.
   */
  explicit CCacheValue(NodePtr<Node> n) : NodeRef(n) {}
  /*! \return Mutable access to the underlying value node. */
  CCacheValueNode* operator->() {
    return static_cast<CCacheValueNode*>(node_.get());
  }
  /*! \return Read-only access to the underlying value node. */
  const CCacheValueNode* operator->() const {
    return static_cast<const CCacheValueNode*>(node_.get());
  }
  using ContainerType = CCacheValueNode;
};
/*!
 * \brief Backend compilation engine for
 *  low level code generation.
 *
 * This is an abstract interface: Lower, JIT and Clear are pure virtual.
 */
class CompileEngineNode : public Node {
 public:
  /*!
   * \brief Get lowered result.
   * \param key The key to the cached function.
   * \return The result.
   */
  virtual CachedFunc Lower(const CCacheKey& key) = 0;
  /*!
   * \brief Just in time compile to get a PackedFunc.
   * \param key The key to the cached function.
   * \return The result.
   */
  virtual PackedFunc JIT(const CCacheKey& key) = 0;
  /*! \brief clear the cache. */
  virtual void Clear() = 0;

  // VisitAttrs
  /*! \brief The engine exposes no fields to the attribute visitor. */
  void VisitAttrs(AttrVisitor*) final {}

  static constexpr const char* _type_key = "relay.CompileEngine";
  TVM_DECLARE_NODE_TYPE_INFO(CompileEngineNode, Node);
};
/*! \brief Reference to a CompileEngineNode. */
class CompileEngine : public NodeRef {
 public:
  /*! \brief Construct an undefined (null) engine reference. */
  CompileEngine() {}
  /*!
   * \brief Construct an engine reference from a node pointer.
   * \param n The node to refer to.
   */
  explicit CompileEngine(NodePtr<Node> n) : NodeRef(n) {}
  /*! \return Mutable access to the underlying engine node. */
  CompileEngineNode* operator->() {
    return static_cast<CompileEngineNode*>(node_.get());
  }
  using ContainerType = CompileEngineNode;
  /*! \brief The global compile engine. */
  TVM_DLL static const CompileEngine& Global();
};
// implementations
/*!
 * \brief Content hash of the key, computed once and then reused.
 *
 * Combines the structural hash of source_func with the hash of the target
 * string. A stored value of zero means "not computed", so a computed zero
 * is replaced by one.
 */
inline size_t CCacheKeyNode::Hash() const {
  if (hash_ == 0) {
    const size_t func_hash = StructuralHash()(this->source_func);
    const size_t target_hash = std::hash<std::string>()(target->str());
    hash_ = dmlc::HashCombine(func_hash, target_hash);
    // keep zero reserved as the "not computed" marker
    if (hash_ == 0) {
      hash_ = 1;
    }
  }
  return hash_;
}
/*!
 * \brief Content equality of two cache keys.
 *
 * Keys with different hashes are rejected first; otherwise the target
 * strings and then the source functions (by AlphaEqual) are compared.
 */
inline bool CCacheKeyNode::Equal(const CCacheKeyNode* other) const {
  if (Hash() != other->Hash()) {
    return false;
  }
  if (this->target->str() != other->target->str()) {
    return false;
  }
  return AlphaEqual(this->source_func, other->source_func);
}
} // namespace relay
} // namespace tvm
namespace std {
/*!
 * \brief std::hash specialization for CCacheKey.
 *
 * Forwards to CCacheKeyNode::Hash(); the key must be defined.
 */
template<>
struct hash<::tvm::relay::CCacheKey> {
  size_t operator()(const ::tvm::relay::CCacheKey& cache_key) const {
    CHECK(cache_key.defined());
    const size_t value = cache_key->Hash();
    return value;
  }
};
}  // namespace std
#endif // TVM_RELAY_BACKEND_COMPILE_ENGINE_H_
......@@ -34,10 +34,12 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
TensorType ConstantNode::tensor_type() const {
auto dtype = TVMType2Type(data->dtype);
Array<tvm::Expr> shape;
for (int i = 0; i < data->ndim; i++) {
shape.push_back(tvm::ir::IntImm::make(HalideIR::Int(64), data->shape[i]));
CHECK_LE(data->shape[i], std::numeric_limits<int32_t>::max());
CHECK_GE(data->shape[i], std::numeric_limits<int32_t>::min());
shape.push_back(
tvm::ir::IntImm::make(Int(32), data->shape[i]));
}
return TensorTypeNode::make(shape, dtype);
......
......@@ -67,13 +67,15 @@ const GenericOpMap& Op::GetGenericAttr(const std::string& key) {
return *it->second.get();
}
void OpRegistry::UpdateAttr(const std::string& key, TVMRetValue value,
void OpRegistry::UpdateAttr(const std::string& key,
TVMRetValue value,
int plevel) {
OpManager* mgr = OpManager::Global();
std::lock_guard<std::mutex> lock(mgr->mutex);
std::unique_ptr<GenericOpMap>& op_map = mgr->attr[key];
if (op_map == nullptr) {
op_map.reset(new GenericOpMap());
op_map->attr_name_ = key;
}
uint32_t index = op_->index_;
if (op_map->data_.size() <= index) {
......@@ -112,31 +114,31 @@ TVM_REGISTER_API("relay.op._OpGetAttr")
});
TVM_REGISTER_API("relay.op._Register")
.set_body([](TVMArgs args, TVMRetValue* rv) {
std::string op_name = args[0];
std::string attr_key = args[1];
runtime::TVMArgValue value = args[2];
int plevel = args[3];
auto& reg =
OpRegistry::Registry()->__REGISTER_OR_GET__(op_name).set_name();
// enable registration and override of certain properties
if (attr_key == "num_inputs" && plevel > 128) {
reg.set_num_inputs(value);
} else if (attr_key == "attrs_type_key" && plevel > 128) {
reg.set_attrs_type_key(value);
.set_body([](TVMArgs args, TVMRetValue* rv) {
std::string op_name = args[0];
std::string attr_key = args[1];
runtime::TVMArgValue value = args[2];
int plevel = args[3];
auto& reg =
OpRegistry::Registry()->__REGISTER_OR_GET__(op_name).set_name();
// enable registration and override of certain properties
if (attr_key == "num_inputs" && plevel > 128) {
reg.set_num_inputs(value);
} else if (attr_key == "attrs_type_key" && plevel > 128) {
reg.set_attrs_type_key(value);
} else {
// normal attr table override.
if (args[2].type_code() == kFuncHandle) {
// do an eager copy of the PackedFunc
PackedFunc f = args[2];
// If we get a function from frontend, avoid deleting it.
OpManager::Global()->frontend_funcs.push_back(new PackedFunc(f));
reg.set_attr(attr_key, f, plevel);
} else {
// normal attr table override.
if (args[2].type_code() == kFuncHandle) {
// do an eager copy of the PackedFunc
PackedFunc f = args[2];
// If we get a function from frontend, avoid deleting it.
OpManager::Global()->frontend_funcs.push_back(new PackedFunc(f));
reg.set_attr(attr_key, f, plevel);
} else {
reg.set_attr(attr_key, args[2], plevel);
}
reg.set_attr(attr_key, args[2], plevel);
}
});
}
});
NodePtr<Node> CreateOp(const std::string& name) {
auto op = Op::Get(name);
......
......@@ -271,7 +271,7 @@ class TextPrinter :
TextValue VisitExpr_(const FunctionNode* op) final {
TextValue id = AllocTempVar();
std::ostringstream os;
os << id << " = function";
os << id << " = fn";
this->PrintFuncInternal(os.str(), GetRef<Function>(op));
this->PrintEndInst("\n");
return id;
......@@ -516,11 +516,14 @@ class TextPrinter :
stream_ << ",\n";
}
}
stream_ << ") ";
stream_ << ')';
if (fn->ret_type.defined()) {
stream_ << " -> ";
stream_ << '\n';
this->PrintIndent(decl_indent);
stream_ << "-> ";
this->PrintType(fn->ret_type, stream_);
}
stream_ << ' ';
this->PrintScope(fn->body);
}
/*!
......
......@@ -9,6 +9,7 @@
#include <tvm/relay/expr.h>
#include <tvm/relay/op.h>
#include <tvm/relay/op_attr_types.h>
#include <vector>
namespace tvm {
......@@ -44,7 +45,8 @@ std::vector<T> AsVector(const Array<T> &array) {
}); \
RELAY_REGISTER_OP(OpName) \
.set_num_inputs(1) \
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("data", "Tensor", "The input tensor.") \
.set_attr<TOpPattern>("TOpPattern", kElemWise)
/*! Quick helper macro
* - Expose a positional make function to construct the node.
......@@ -68,7 +70,8 @@ std::vector<T> AsVector(const Array<T> &array) {
.set_num_inputs(2) \
.add_argument("lhs", "Tensor", "The left hand side tensor.") \
.add_argument("rhs", "Tensor", "The right hand side tensor.") \
.add_type_rel("Broadcast", BroadcastRel)
.add_type_rel("Broadcast", BroadcastRel) \
.set_attr<TOpPattern>("TOpPattern", kBroadcast)
} // namespace relay
} // namespace tvm
......
......@@ -46,7 +46,7 @@ RELAY_REGISTER_BINARY_OP("relay.op._make.", "multiply")
.describe("Elementwise multiply with broadcasting")
.set_support_level(1);
RELAY_REGISTER_BINARY_OP("relay.op._make.", "pow")
RELAY_REGISTER_BINARY_OP("relay.op._make.", "power")
.describe("Elementwise power with broadcasting")
.set_support_level(4);
......@@ -65,7 +65,8 @@ RELAY_REGISTER_BINARY_OP("relay.op._make.", "mod")
.set_num_inputs(2) \
.add_argument("lhs", "Tensor", "The left hand side tensor.") \
.add_argument("rhs", "Tensor", "The right hand side tensor.") \
.add_type_rel("BroadcastComp", BroadcastCompRel)
.add_type_rel("BroadcastComp", BroadcastCompRel) \
.set_attr<TOpPattern>("TOpPattern", kBroadcast)
RELAY_REGISTER_CMP_OP("equal")
.describe("Elementwise equal compare with broadcasting")
......
......@@ -3,32 +3,32 @@
*
* \file src/tvm/relay/pass/fuse_ops.cc
*
* \brief Fuse eligible sequences of Relay operators into a single one.
*
* \brief This is a backend-aware optimization pass.
* Fuse necessary ops into a single one.
*/
#include <tvm/ir_operator.h>
#include <tvm/relay/pass.h>
#include <tvm/runtime/module.h>
#include <tvm/lowered_func.h>
#include <tvm/operation.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/logging.h>
#include "../ir/type_functor.h"
namespace tvm {
namespace relay {
using namespace runtime;
struct AbstractFusableOps : ExprMutator {
Module mod;
Array<GlobalVar> fusable_funcs;
int counter = 0;
size_t expr_hash;
AbstractFusableOps(Module mod, size_t expr_hash) : mod(mod), expr_hash(expr_hash) {}
// Simple fuser that only makes each operator function as primitive.
class SimpleFuser : public ExprMutator {
public:
// Skip primitive function.
Expr VisitExpr_(const FunctionNode* fn_node) {
NodeRef res = FunctionGetAttr(GetRef<Function>(fn_node), "Primitive");
const ir::IntImm* pval = res.as<ir::IntImm>();
if (pval && pval->value != 0) {
return GetRef<Expr>(fn_node);
} else {
return ExprMutator::VisitExpr_(fn_node);
}
}
Expr VisitExpr_(const CallNode* call) {
if (auto op_node = call->op.as<OpNode>()) {
if (call->op.as<OpNode>()) {
// Placeholder fusion algorithm which abstracts
// single definitions into functions only.
Array<Var> params;
......@@ -37,50 +37,37 @@ struct AbstractFusableOps : ExprMutator {
int param_number = 0;
for (auto arg : call->args) {
auto name = std::string("p") + std::to_string(param_number++);
std::ostringstream os;
os << "p" << param_number++;
auto type = arg->checked_type();
auto var = VarNode::make(name, type);
auto var = VarNode::make(os.str(), type);
params.push_back(var);
inner_args.push_back(var);
args.push_back(VisitExpr(arg));
args.push_back(this->Mutate(arg));
}
auto body = CallNode::make(call->op, inner_args, call->attrs);
auto func = FunctionNode::make(params, body, call->checked_type(), {});
auto func = FunctionNode::make(
params, body, call->checked_type(), {});
func = FunctionSetAttr(func, "Primitive", tvm::Integer(1));
std::string func_name = "fused_";
func_name += op_node->name;
func_name += "_";
func_name += std::to_string(counter++);
func_name += "_";
func_name += std::to_string(expr_hash);
auto gv = GlobalVarNode::make(func_name);
mod->Add(gv, func);
fusable_funcs.push_back(gv);
return CallNode::make(gv, args, Attrs());
return CallNode::make(func, args, Attrs());
} else {
return ExprMutator::VisitExpr_(call);
}
}
};
Expr FuseOps(const Module& mod, const Expr& e) {
Expr FuseOps(const Expr& expr) {
// First we convert all chains of fusable ops into
// abstracted functions which we mark as primitive
// then we convert these primitive functions into
// new operators.
auto abstract = AbstractFusableOps(mod, StructuralHash()(e));
auto abstracted_e = abstract.VisitExpr(e);
RELAY_LOG(INFO) << "FuseOps: before=" << e
<< "Fuse: after=" << abstracted_e;
return abstracted_e;
return SimpleFuser().Mutate(expr);
}
TVM_REGISTER_API("relay._ir_pass.FuseOps")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = FuseOps(args[1], args[0]);
*ret = FuseOps(args[0]);
});
} // namespace relay
} // namespace tvm
/*!
* Copyright (c) 2018 by Contributors
*
* \file src/tvm/relay/pass/lower_ops.cc
*
* \brief Lower a Relay program to set of TVM operators.
*
*/
#include <tvm/lowered_func.h>
#include <tvm/operation.h>
#include <tvm/build_module.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/logging.h>
#include <tvm/relay/pass.h>
#include <tvm/runtime/module.h>
#include <tvm/relay/build_module.h>
#include "../ir/type_functor.h"
namespace tvm {
namespace relay {
using namespace runtime;
/*!
 * \brief Build a LoweredOp from a function and its lowered form.
 * \param func The function stored in the node.
 * \param lowered_func The lowered function stored in the node.
 * \return A reference to the newly created node.
 */
LoweredOp LoweredOpNode::make(Function func, LoweredFunc lowered_func) {
  auto n = make_node<LoweredOpNode>();
  n->lowered_func = lowered_func;
  n->func = func;
  return LoweredOp(n);
}
/*!
 * \brief Mutator that replaces local function expressions with calls to
 *  newly added module-level definitions.
 *
 * Global variables reached during the traversal have their definitions
 * rewritten in place in the module (once per variable).
 */
struct AbstractLocalFunctions : ExprMutator {
  /*! \brief The module that is looked up and updated. */
  Module mod;
  /*! \brief Structural hash of the expression passed to Abstract. */
  size_t expr_hash;
  /*! \brief Number of functions lifted so far; used in generated names. */
  int counter = 0;
  /*! \brief Global variables whose definitions were already rewritten. */
  std::unordered_set<GlobalVar, NodeHash, NodeEqual> visited_funcs;

  explicit AbstractLocalFunctions(Module mod)
      : mod(mod), expr_hash(0), counter(0), visited_funcs() {}

  /*!
   * \brief Entry point: record the hash of e, then rewrite it.
   * \param e The expression to transform.
   * \return The transformed expression.
   */
  Expr Abstract(const Expr& e) {
    expr_hash = StructuralHash()(e);
    return VisitExpr(e);
  }

  /*!
   * \brief Rewrite the body of the referenced global function once.
   *
   * The variable is marked as visited before its body is traversed, and the
   * module entry is replaced with a function built from the rewritten body.
   * \return The global variable itself.
   */
  Expr VisitExpr_(const GlobalVarNode* gvar_node) final {
    auto gvar = GetRef<GlobalVar>(gvar_node);
    auto it = visited_funcs.find(gvar);
    if (it == visited_funcs.end()) {
      auto func = mod->Lookup(gvar);
      visited_funcs.insert(gvar);
      auto new_func = FunctionNode::make(
          func->params,
          VisitExpr(func->body),
          func->ret_type,
          func->type_params,
          func->attrs);
      mod->Update(gvar, new_func);
    }
    return gvar;
  }

  /*!
   * \brief Lift a function expression into the module.
   *
   * A new global named "abstracted_func_" followed by the counter and the
   * expression hash is added; its body is the original function and it takes
   * one parameter per free variable of that function. The result is a call
   * of the new global with the free variables as arguments. The body of
   * func is not traversed by this override.
   */
  Expr VisitExpr_(const FunctionNode* func_node) final {
    Function func = GetRef<Function>(func_node);
    auto free_vars = FreeVars(func);
    Array<Var> params;
    // NOTE(review): the parameters are fresh variables all named "free_var";
    // nothing in this block rewrites func to refer to them -- confirm that
    // this is intended.
    for (auto free_var : free_vars) {
      auto var = VarNode::make("free_var", free_var->checked_type());
      params.push_back(var);
    }
    std::string abs_func = "abstracted_func_";
    abs_func += std::to_string(counter++);
    abs_func += std::to_string(expr_hash);
    auto gv = GlobalVarNode::make(abs_func);
    auto lifted_func = FunctionNode::make(params, func, Type(), {}, {});
    mod->Add(gv, lifted_func);
    Array<Expr> args;
    for (auto free_var : free_vars) {
      args.push_back(free_var);
    }
    return CallNode::make(gv, args, {});
  }
};
/*!
 * \brief Visitor that collects the global functions marked "Primitive"
 *  which are called from an expression.
 *
 * The result is accumulated in global_funcs.
 */
struct LiveFunctions : ExprVisitor {
  /*! \brief The module used to look up global definitions. */
  Module mod;
  explicit LiveFunctions(Module mod) : mod(mod), global_funcs() {}
  /*! \brief Global variables whose definitions were already searched. */
  std::unordered_set<GlobalVar, NodeHash, NodeEqual> visited_funcs;
  /*! \brief Called globals whose "Primitive" attribute equals 1. */
  std::unordered_set<GlobalVar, NodeHash, NodeEqual> global_funcs;

  /*!
   * \brief Entry point: check that e is not itself a function, then visit it.
   * \param e The expression to search.
   */
  void Live(const Expr& e) {
    CHECK(!e.as<FunctionNode>())
        << "functions should of been transformed away by previous pass";
    VisitExpr(e);
  }

  /*! \brief Function expressions are not expected here; reaching one is fatal. */
  void VisitExpr_(const FunctionNode* func_node) {
    LOG(FATAL) << "functions should of been transformed away by previous pass";
  }

  /*! \brief Search the definition of a global variable the first time it is seen. */
  void VisitExpr_(const GlobalVarNode* var_node) final {
    GlobalVar var = GetRef<GlobalVar>(var_node);
    auto it = visited_funcs.find(var);
    if (it == visited_funcs.end()) {
      auto func = mod->Lookup(var);
      visited_funcs.insert(var);
      // The last pass has transformed functions of the form:
      //
      // let x = fn (p_1, ..., p_n) { ... };
      // ...
      //
      // into a top-level declaration:
      //
      // def abs_f(fv_1, ..., fv_n) {
      //   return (fn (p_1, ..., p_n) { ... });
      // }
      //
      // and:
      //
      // let x = abs_f(fv_1, ... fv_n);
      //
      // The only other case we can handle is
      //
      // fn foo(...) { body }
      //
      // We just search through the body in this case.
      if (auto inner_func = func->body.as<FunctionNode>()) {
        return VisitExpr(inner_func->body);
      } else {
        return VisitExpr(func->body);
      }
    }
  }

  /*!
   * \brief Record primitive callees and keep searching everything else.
   *
   * For a call whose operator is a global variable: the global is recorded
   * when its "Primitive" attribute is defined and equal to 1, otherwise its
   * definition is searched. Arguments are visited in both cases. Any other
   * call is handled by the default visitor.
   */
  void VisitExpr_(const CallNode* call) final {
    RELAY_LOG(INFO) << "LiveOps: CallNode=" << GetRef<Call>(call);
    if (auto gv_node = call->op.as<GlobalVarNode>()) {
      GlobalVar gvar = GetRef<GlobalVar>(gv_node);
      Function func = mod->Lookup(gvar);
      auto attr = FunctionGetAttr(func, "Primitive");
      if (attr.defined() && Downcast<Integer>(attr)->value == 1) {
        global_funcs.insert(gvar);
      } else {
        VisitExpr(gvar);
      }
      // Finally we need to ensure to visit all the args no matter what.
      for (auto arg : call->args) {
        VisitExpr(arg);
      }
    } else {
      return ExprVisitor::VisitExpr_(call);
    }
  }
};
using FCompute = TypedPackedFunc<Array<Tensor>(
const Attrs&, const Array<Tensor>&, Type, tvm::Target)>;
using FSchedule = TypedPackedFunc<Schedule(const Array<Tensor>&, tvm::Target)>;
/*!
 * \brief Return the set of operators in their TVM format.
 *
 * Steps: lift local functions into the module, collect the primitive global
 * functions called from the result, then build a compute and a schedule for
 * each of them and lower it through the "relay.op.compiler._lower" packed
 * function. Each lowered function is also stored back on the module
 * definition under the "LoweredFunc" attribute.
 *
 * \param mod The module; it is modified by this function.
 * \param e The expression whose reachable operators are lowered.
 * \param target The target string passed to Target::create.
 * \return One LoweredOp per collected primitive function.
 */
Array<LoweredOp> LowerOps(const Module& mod, const Expr& e,
                          const std::string& target) {
  RELAY_LOG(INFO) << "LowerOps: e=" << e;
  // The lowering routine is looked up in the global registry and must exist.
  auto flower_ptr = Registry::Get("relay.op.compiler._lower");
  CHECK(flower_ptr);
  PackedFunc flower = *flower_ptr;
  auto abstracted_e = AbstractLocalFunctions(mod).Abstract(e);
  auto live_funcs = LiveFunctions(mod);
  live_funcs.VisitExpr(abstracted_e);
  auto schedule_reg = Op::GetAttr<FSchedule>("FTVMSchedule");
  auto compute_reg = Op::GetAttr<FCompute>("FTVMCompute");
  Array<LoweredOp> lowered_funcs;
  for (auto func_name : live_funcs.global_funcs) {
    auto func = mod->Lookup(func_name);
    // A primitive function body is expected to be a single call to an Op.
    auto call = Downcast<Call>(func->body);
    auto op_node = call->op.as<OpNode>();
    CHECK(op_node) << "violated invariant that primtive calls contain a single op call";
    auto op = GetRef<Op>(op_node);
    RELAY_LOG(INFO) << "LowerOps: Lowering " << op->name;
    // NOTE(review): the message below has no space between the op name and
    // "can only lower" -- left unchanged here.
    CHECK(IsPrimitiveOp(op)) << "failed to lower "
                             << op->name << "can only lower primitve operations";
    // One placeholder tensor "in<i>" per type argument of the call.
    Array<Tensor> inputs;
    std::string input_name = "in";
    int i = 0;
    for (auto type_arg : call->type_args) {
      auto tt = Downcast<TensorType>(type_arg);
      inputs.push_back(PlaceholderOpNode::make(input_name + std::to_string(i),
                                               tt->shape, tt->dtype)
                           .output(0));
      i++;
    }
    auto output_tt = call->checked_type();
    // NOTE(review): the target node depends only on `target`, yet it is
    // re-created on every iteration -- confirm whether it can be hoisted.
    auto target_node = Target::create(target);
    Array<Tensor> outputs =
        compute_reg[op](call->attrs, inputs, output_tt, target_node);
    auto schedule = schedule_reg[op](outputs, target_node);
    // The lowered function is named after the op plus the function's hash.
    size_t hash = StructuralHash()(func);
    LoweredFunc lf =
        flower(op->name + std::to_string(hash), schedule, inputs, outputs);
    func = FunctionSetAttr(func, "LoweredFunc", lf);
    mod->Add(func_name, func, true);
    lowered_funcs.push_back(LoweredOpNode::make(func, lf));
  }
  return lowered_funcs;
}
// Register LowerOps under the name "relay._ir_pass.LowerOps".
// The three packed arguments are forwarded in order as (mod, e, target).
TVM_REGISTER_API("relay._ir_pass.LowerOps")
.set_body([](TVMArgs args, TVMRetValue *ret) {
  *ret = LowerOps(args[0], args[1], args[2]);
});
} // namespace relay
} // namespace tvm
......@@ -22,27 +22,6 @@
namespace tvm {
namespace runtime {
/*!
 * \brief The name of Device API factory.
 * \param type The device type.
 * \return The lower-case device name; an unrecognized type is reported via
 *  LOG(FATAL) and "Unknown" is the fallback return value.
 */
inline std::string DeviceName(int type) {
  const char* name = nullptr;
  switch (type) {
    case kDLCPU: name = "cpu"; break;
    case kDLGPU: name = "gpu"; break;
    case kDLOpenCL: name = "opencl"; break;
    case kDLSDAccel: name = "sdaccel"; break;
    case kDLAOCL: name = "aocl"; break;
    case kDLVulkan: name = "vulkan"; break;
    case kDLMetal: name = "metal"; break;
    case kDLVPI: name = "vpi"; break;
    case kDLROCM: name = "rocm"; break;
    case kOpenGL: name = "opengl"; break;
    case kDLExtDev: name = "ext_dev"; break;
    default: break;
  }
  if (name == nullptr) {
    LOG(FATAL) << "unknown type =" << type;
    return "Unknown";
  }
  return name;
}
class DeviceAPIManager {
public:
static const int kMaxDeviceAPI = 32;
......
......@@ -187,8 +187,8 @@ void GraphRuntime::SetupStorage() {
CHECK_GE(storage_id, 0) << "Do not support runtime shape op";
DLDataType t = vtype[i];
size_t bits = t.bits * t.lanes;
CHECK_EQ(bits % 8U, 0U);
size_t bytes = (bits / 8U) * size;
CHECK(bits % 8U == 0U || bits ==1U);
size_t bytes = ((bits + 7U) / 8U) * size;
uint32_t sid = static_cast<uint32_t>(storage_id);
if (sid >= pool_entry.size()) {
......
import tvm
import tvm.testing
import numpy as np
from tvm import relay
def test_compile_engine():
    """Check the lowering results and JIT execution of the compile engine."""
    engine = relay.backend.compile_engine.get()

    def build_triple(shape):
        # Type-inferred function computing (x + x) + x for the given shape.
        data = relay.var("x", shape=shape)
        doubled = relay.add(data, data)
        tripled = relay.add(doubled, data)
        return relay.ir_pass.infer_type(relay.Function([data], tripled))

    # Two separately built but identical functions must lower to the same
    # object; a different shape must not.
    z1 = engine.lower(build_triple((10,)), "llvm")
    z2 = engine.lower(build_triple((10,)), "llvm")
    z3 = engine.lower(build_triple(()), "llvm")
    assert z1.same_as(z2)
    assert not z3.same_as(z1)
    # The same function on another target must give a different result too.
    if tvm.context("cuda").exist:
        z4 = engine.lower(build_triple(()), "cuda")
        assert not z3.same_as(z4)

    # Test JIT target
    for target in ["llvm"]:
        ctx = tvm.context(target)
        if not ctx.exist:
            continue
        jitted = engine.jit(build_triple((10,)), target)
        x = tvm.nd.array(np.ones(10).astype("float32"), ctx=ctx)
        y = tvm.nd.empty((10,), ctx=ctx)
        jitted(x, y)
        tvm.testing.assert_allclose(
            y.asnumpy(), x.asnumpy() * 3)
    engine.dump()
if __name__ == "__main__":
test_compile_engine()
import numpy as np
from tvm import relay
from tvm.relay import create_executor
from tvm.relay.ir_pass import infer_type
from tvm.relay.interpreter import Interpreter
from tvm.relay.scope_builder import ScopeBuilder
from tvm.relay.op import add
from tvm.relay.module import Module
......@@ -25,8 +23,8 @@ def check_rts(expr, args, expected_result, mod=None):
expected_result:
The expected result of running the expression.
"""
intrp = create_executor('graph', mod=mod)
graph = create_executor('graph', mod=mod)
intrp = relay.create_executor('debug', mod=mod)
graph = relay.create_executor('graph', mod=mod)
eval_result = intrp.evaluate(expr)(*args)
rts_result = graph.evaluate(expr)(*args)
np.testing.assert_allclose(eval_result.asnumpy(), rts_result.asnumpy())
......
import numpy as np
import tvm
import tvm.testing
from tvm import relay
from tvm.relay.interpreter import Value, TupleValue
from tvm.relay import op
from tvm.relay.backend.interpreter import Value, TupleValue
from tvm.relay.scope_builder import ScopeBuilder
from tvm.relay import testing, create_executor
def check_eval(expr, args, expected_result, mod=None, rtol=1e-07):
    """Evaluate ``expr`` on ``args`` and compare with ``expected_result``.

    Parameters
    ----------
    expr :
        The expression (or global variable) handed to the executor.
    args : list
        Positional arguments passed to the evaluated function.
    expected_result :
        Reference value the result is compared against.
    mod : optional
        Module forwarded to ``create_executor``.
    rtol : float
        Relative tolerance for the comparison.
    """
    # TODO(tqchen) add more types once the schedule register is fixed.
    for target in ["llvm"]:
        ctx = tvm.context(target, 0)
        if not ctx.exist:
            # Skip only the unavailable target; any remaining targets
            # must still be checked.
            continue
        intrp = create_executor(mod=mod, ctx=ctx, target=target)
        result = intrp.evaluate(expr)(*args)
        # tvm.testing.assert_allclose is used because it also sets atol.
        tvm.testing.assert_allclose(
            result.asnumpy(), expected_result, rtol=rtol)
def test_from_scalar():
......@@ -34,7 +41,7 @@ def test_id():
def test_add_const():
    """Adding two constant ones must evaluate to two."""
    two = relay.add(relay.const(1), relay.const(1))
    func = relay.Function([], two)
    check_eval(func, [], 2)
......@@ -42,7 +49,7 @@ def test_add_const():
def test_mul_param():
    """Multiply a (10, 10) parameter by a (1, 10) parameter."""
    x = relay.var('x', shape=(10, 10))
    y = relay.var('y', shape=(1, 10))
    func = relay.Function([x, y], relay.multiply(x, y))
    x_data = np.random.rand(10, 10).astype('float32')
    y_data = np.random.rand(1, 10).astype('float32')
    # The NumPy product is the reference result.
    check_eval(func, [x_data, y_data], x_data * y_data)
......@@ -53,7 +60,7 @@ def test_mul_param():
# def test_dense():
# x = relay.var('x', shape=(10, 10))
# w = relay.var('w', shape=(10, 10))
# y = op.nn.dense(x, w)
# y = relay.nn.dense(x, w)
# func = relay.Function([x, w], y)
# x_data = np.random.rand(10, 10).astype('float32')
# w_data = np.random.rand(10, 10).astype('float32')
......@@ -63,7 +70,7 @@ def test_mul_param():
# x = relay.var('x', shape=(10, 10))
# w = relay.var('w', shape=(10, 10))
# b = relay.var('b', shape=(10,))
# y = op.add(op.nn.dense(x, w), b)
# y = relay.add(relay.nn.dense(x, w), b)
# func = relay.Function([x, w, b], y)
# x_data = np.random.rand(10, 10).astype('float32')
# w_data = np.random.rand(10, 10).astype('float32')
......@@ -73,46 +80,49 @@ def test_mul_param():
def test_equal():
    """Comparing two equal int32 scalars must evaluate to True."""
    i = relay.var('i', shape=[], dtype='int32')
    # Give the second variable its own name hint rather than reusing 'i'.
    j = relay.var('j', shape=[], dtype='int32')
    z = relay.equal(i, j)
    func = relay.Function([i, j], z, ret_type=relay.TensorType([], 'bool'))
    i_data = relay.const(0)
    j_data = relay.const(0)
    check_eval(func, [i_data, j_data], True)
def test_subtract():
    """Subtracting the constant one from the input one must give zero."""
    i = relay.var('i', shape=[], dtype='int32')
    sub = relay.subtract(i, relay.const(1, dtype='int32'))
    func = relay.Function([i], sub, ret_type=relay.TensorType([], 'int32'))
    i_data = np.array(1, dtype='int32')
    check_eval(func, [i_data], 0)
def test_simple_loop():
    """Sum the integers 1..10 with a self-recursive global function.

    ``sum_up(i)`` returns ``i`` when ``i == 0`` and ``sum_up(i - 1) + i``
    otherwise; the function is registered in a fresh module so the
    recursive call through the global variable can be resolved.
    """
    mod = relay.module.Module({})
    sum_up = relay.GlobalVar('sum_up')
    i = relay.var('i', shape=[], dtype='int32')
    sb = ScopeBuilder()
    with sb.if_scope(relay.equal(i, relay.const(0, dtype='int32'))):
        # Base case: i is zero, so it is returned as the sum.
        sb.ret(i)
    with sb.else_scope():
        one_less = relay.subtract(i, relay.const(1, dtype='int32'))
        rec_call = relay.Call(sum_up, [one_less])
        sb.ret(relay.add(rec_call, i))
    func = relay.Function([i], sb.get(), ret_type=relay.TensorType([], 'int32'))
    mod[sum_up] = func
    i_data = np.array(10, dtype='int32')
    check_eval(sum_up, [i_data], sum(range(1, 11)), mod=mod)
def test_loop():
mod = relay.module.Module({})
sum_up = relay.GlobalVar('sum_up')
i = relay.var('i', shape=[], dtype='int32')
accum = relay.var('accum', shape=[], dtype='int32')
sb = ScopeBuilder()
with sb.if_scope(op.equal(i, relay.const(0))):
with sb.if_scope(relay.equal(i, relay.const(0))):
sb.ret(accum)
with sb.else_scope():
one_less = op.subtract(i, relay.const(1))
new_accum = op.add(accum, i)
one_less = relay.subtract(i, relay.const(1))
new_accum = relay.add(accum, i)
sb.ret(relay.Call(sum_up, [one_less, new_accum]))
func = relay.Function([i, accum], sb.get())
mod[sum_up] = func
......@@ -120,19 +130,21 @@ def test_loop():
accum_data = np.array(0, dtype='int32')
check_eval(sum_up, [i_data, accum_data], sum(range(1, 11)), mod=mod)
def test_mlp():
    """Placeholder for an MLP workload test; currently does nothing."""
    # net = testing.mlp.get_workload(1)
    # import pdb; pdb.set_trace()
def test_binds():
    """Evaluate an expression whose free variable is supplied via ``binds``."""
    var = relay.var("x")
    doubled = relay.add(var, var)
    executor = create_executor("debug")
    ones = np.ones((10, 20))
    # The free variable is bound to concrete data at evaluation time.
    evaluated = executor.evaluate(doubled, binds={var: ones})
    res = evaluated.asnumpy()
    tvm.testing.assert_allclose(ones + ones, res)
if __name__ == "__main__":
    # Run the enabled tests in their original order.
    # test_dense() and test_linear() remain disabled.
    for _case in (test_id,
                  test_add_const,
                  test_equal,
                  test_subtract,
                  test_simple_loop,
                  test_loop,
                  test_mlp,
                  test_binds):
        _case()
......@@ -2,7 +2,7 @@ import math
import tvm
import numpy as np
from tvm import relay
from tvm.relay.interpreter import create_executor
from tvm.relay.testing import ctx_list
def sigmoid(x):
one = np.ones_like(x)
......@@ -27,10 +27,15 @@ def test_unary_op():
if ref is not None:
data = np.random.rand(*shape).astype(dtype)
intrp = create_executor()
op_res = intrp.evaluate(y, { x: relay.const(data) })
ref_res = ref(data)
np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=0.01)
func = relay.Function([x], y)
for target, ctx in ctx_list():
# use the graph executor by default for testing, as we need to
# create the function explicitly to avoid constant-folding.
intrp = relay.create_executor("graph", ctx=ctx, target=target)
op_res = intrp.evaluate(func)(data)
np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=0.01)
for opfunc, ref in [(tvm.relay.log, np.log),
(tvm.relay.exp, np.exp),
......@@ -67,14 +72,17 @@ def test_binary_op():
z = opfunc(x, y)
x_data = np.random.rand(5, 10, 5).astype(t1.dtype)
y_data = np.random.rand(5, 10, 5).astype(t2.dtype)
intrp = create_executor()
op_res = intrp.evaluate(z, { x: relay.const(x_data), y: relay.const(y_data) })
ref_res = ref(x_data, y_data)
np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=0.01)
func = relay.Function([x, y], z)
for target, ctx in ctx_list():
# use the graph executor by default for testing, as we need to
# create the function explicitly to avoid constant-folding.
intrp = relay.create_executor("graph", ctx=ctx, target=target)
op_res = intrp.evaluate(func)(x_data, y_data)
np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=0.01)
for opfunc, ref in [(relay.add, np.add),
(relay.subtract, np.subtract),
(relay.mod, np.mod),
(relay.multiply, np.multiply),
(relay.divide, np.divide)]:
check_binary_op(opfunc, ref)
......@@ -116,7 +124,7 @@ def test_log_softmax():
assert yy.checked_type == relay.TensorType((n, d))
def test_concatenate_infer_type():
def test_concatenate():
n, t, d = tvm.var("n"), tvm.var("t"), 100
x = relay.var("x", shape=(n, t, d))
y = relay.var("y", shape=(n, t, d))
......@@ -134,15 +142,23 @@ def test_concatenate_infer_type():
zz = relay.ir_pass.infer_type(z)
assert zz.checked_type == relay.TensorType((n, t + t, 100))
# x = relay.var("x", shape=(10, 5))
# y = relay.var("y", shape=(10, 5))
# z = relay.concatenate((x, y), axis=1)
# intrp = create_executor()
# x_data = np.random.rand(10, 5).astype('float32')
# y_data = np.random.rand(10, 5).astype('float32')
# op_res = intrp.evaluate(z, { x: relay.const(x_data), y: relay.const(y_data) })
# ref_res = np.concatenate(x_data, y_data, axis=1)
# np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=0.01)
x = relay.var("x", shape=(10, 5))
y = relay.var("y", shape=(10, 5))
z = relay.concatenate((x, y), axis=1)
# Check result.
func = relay.Function([x, y], z)
x_data = np.random.rand(10, 5).astype('float32')
y_data = np.random.rand(10, 5).astype('float32')
ref_res = np.concatenate((x_data, y_data), axis=1)
for target, ctx in ctx_list():
intrp1 = relay.create_executor("graph", ctx=ctx, target=target)
intrp2 = relay.create_executor("debug", ctx=ctx, target=target)
op_res1 = intrp1.evaluate(func)(x_data, y_data)
tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=0.01)
op_res2 = intrp2.evaluate(func)(x_data, y_data)
tvm.testing.assert_allclose(op_res2.asnumpy(), ref_res, rtol=0.01)
def test_dropout():
n, t, d = tvm.var("n"), tvm.var("t"), tvm.var("d")
......@@ -206,7 +222,7 @@ if __name__ == "__main__":
test_unary_op()
test_binary_op()
test_expand_dims_infer_type()
test_concatenate_infer_type()
test_concatenate()
test_softmax()
test_log_softmax()
test_dropout()
......
import tvm
import numpy as np
from tvm import relay
from tvm.relay import create_executor
from tvm.relay.testing import ctx_list
def test_binary_op():
......@@ -24,12 +24,15 @@ def test_binary_op():
z = opfunc(x, y)
x_data = np.random.rand(5, 10, 5).astype(t1.dtype)
y_data = np.random.rand(5, 10, 5).astype(t2.dtype)
intrp = create_executor()
op_res = intrp.evaluate(z, { x: relay.const(x_data), y: relay.const(y_data) })
ref_res = ref(x_data, y_data)
np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=0.01)
func = relay.Function([x, y], z)
for opfunc, ref in [(relay.pow, np.power)]:
for target, ctx in ctx_list():
intrp = relay.create_executor("graph", ctx=ctx, target=target)
op_res = intrp.evaluate(func)(x_data, y_data)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res)
for opfunc, ref in [(relay.power, np.power)]:
check_binary_op(opfunc, ref)
......@@ -57,15 +60,19 @@ def test_cmp_type():
z = op(x, y)
x_data = np.random.rand(*x_shape).astype(t1.dtype)
y_data = np.random.rand(*y_shape).astype(t2.dtype)
intrp = create_executor()
op_res = intrp.evaluate(z, { x: relay.const(x_data), y: relay.const(y_data) })
ref_res = ref(x_data, y_data)
np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=0.01)
func = relay.Function([x, y], z)
for target, ctx in ctx_list():
intrp = relay.create_executor("graph", ctx=ctx, target=target)
op_res = intrp.evaluate(func)(x_data, y_data)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res)
def test_binary_int_broadcast():
for op, ref in [(relay.right_shift, np.right_shift),
(relay.left_shift, np.left_shift),
(relay.mod, np.mod),
(relay.maximum, np.maximum),
(relay.minimum, np.minimum)]:
x = relay.var("x", relay.TensorType((10, 4), "int32"))
......@@ -81,10 +88,14 @@ def test_binary_int_broadcast():
t2 = relay.TensorType(y_shape, 'int32')
x_data = np.random.rand(*x_shape).astype(t1.dtype)
y_data = np.random.rand(*y_shape).astype(t2.dtype)
intrp = create_executor()
op_res = intrp.evaluate(z, { x: relay.const(x_data), y: relay.const(y_data) })
func = relay.Function([x, y], z)
ref_res = ref(x_data, y_data)
np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=0.01)
for target, ctx in ctx_list():
intrp = relay.create_executor("graph", ctx=ctx, target=target)
op_res = intrp.evaluate(func)(x_data, y_data)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res)
def test_where():
cond = relay.var("cond", relay.TensorType((3, 4), "float32"))
......
import tvm
from tvm import relay
def test_fuse_simple():
    """Simple testcase: fuse exp(x + x), re-fuse, and print the result."""
    data = relay.var("x", shape=(10, 20))
    summed = relay.add(data, data)
    expr = relay.ir_pass.infer_type(relay.exp(summed))
    fused = relay.ir_pass.fuse_ops(expr)
    # The fusion pass is deliberately applied a second time to its own output.
    fused = relay.ir_pass.fuse_ops(fused)
    fused = relay.ir_pass.infer_type(fused)
    fused.astext()


if __name__ == "__main__":
    test_fuse_simple()
......@@ -3,10 +3,9 @@
from __future__ import absolute_import as _abs
import tvm
import topi
from . import tag
from . import cpp
@tvm.tag_scope(tag=tag.BROADCAST)
def expand_dims(a, axis, num_newaxis=1):
"""Expand the shape of an array.
......@@ -25,7 +24,6 @@ def expand_dims(a, axis, num_newaxis=1):
return cpp.expand_dims(a, axis, num_newaxis)
@tvm.tag_scope(tag=tag.BROADCAST)
def expand_like(a, shape_like, axis):
"""Expand an input array with the shape of second array.
This operation can always be composed of unsqueezing and
......@@ -79,7 +77,6 @@ def expand_like(a, shape_like, axis):
return tvm.compute(shape_like.shape, _compute)
@tvm.tag_scope(tag=tag.INJECTIVE)
def transpose(a, axes=None):
"""Permute the dimensions of an array.
......@@ -141,7 +138,6 @@ def strided_slice(a, begin, end, strides=None):
return cpp.strided_slice(a, begin, end, strides)
@tvm.tag_scope(tag=tag.INJECTIVE)
def reshape(a, newshape):
"""Reshape the array
......@@ -159,7 +155,6 @@ def reshape(a, newshape):
return cpp.reshape(a, newshape)
@tvm.tag_scope(tag=tag.INJECTIVE)
def squeeze(a, axis=None):
"""Remove single-dimensional entries from the shape of an array.
......@@ -178,7 +173,6 @@ def squeeze(a, axis=None):
return cpp.squeeze(a, axis)
@tvm.tag_scope(tag=tag.INJECTIVE)
def concatenate(a_tuple, axis=0):
"""Join a sequence of arrays along an existing axis.
......@@ -197,7 +191,6 @@ def concatenate(a_tuple, axis=0):
return cpp.concatenate(a_tuple, axis)
@tvm.tag_scope(tag=tag.INJECTIVE)
def split(ary, indices_or_sections, axis=0):
"""Split an array into multiple sub-arrays.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment