Unverified Commit c91ded32 by Tianqi Chen Committed by GitHub

[RELAY][BACKEND] CompileEngine refactor. (#2059)

parent 4e77eeb2
/*!
* Copyright (c) 2018 by Contributors
* \file tvm/relay/build_module.h
* \brief The passes and data structures needed to build a
* tvm::Module from a Relay program.
*/
#ifndef TVM_RELAY_BUILD_MODULE_H_
#define TVM_RELAY_BUILD_MODULE_H_
#include <tvm/lowered_func.h>
#include <tvm/relay/module.h>
#include <tvm/relay/expr.h>
#include <string>
namespace tvm {
namespace relay {
/*! \brief A lowered Relay operation.
*
* A lowered operation is a pair containing the "primitive" function used
* to produce the lowered function as well as the lowered function itself.
*/
class LoweredOp;
/*!
* \brief Container node of LoweredOp.
*
* Stores a primitive Relay function together with the LoweredFunc
* associated with it.
*/
class LoweredOpNode : public Node {
public:
/*!
* \brief The primitive function to be lowered.
*
* A primitive function consists only of calls to relay::Op which
* can be fused.
*/
Function func;
/*!
* \brief The lowered function.
*/
LoweredFunc lowered_func;
/*!
* \brief Expose the fields to an attribute visitor.
*
* Visits `func` and `lowered_func` under the names "func" and
* "lowered_func".
*
* \param v The attribute visitor.
*/
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("func", &func);
v->Visit("lowered_func", &lowered_func);
}
/*!
* \brief Factory function for LoweredOp.
*
* \param func The primitive function to be lowered.
* \param lowered_func The lowered function.
*
* \return A LoweredOp; presumably one holding the two arguments
* (the definition is not part of this header).
*/
TVM_DLL static LoweredOp make(
Function func,
LoweredFunc lowered_func);
/*! \brief Type key under which this node type is declared. */
static constexpr const char* _type_key = "relay.LoweredOp";
TVM_DECLARE_NODE_TYPE_INFO(LoweredOpNode, Node);
};
RELAY_DEFINE_NODE_REF(LoweredOp, LoweredOpNode, NodeRef);
/*!
* \brief Lower the operations contained in a Relay expression.
*
* The lowering pass only lowers functions marked as primitive;
* the FuseOps pass provides this marking if it is run before LowerOps.
*
* \note This will do a reachability analysis and lower all definitions
* reachable from the provided expression.
*
* \param mod The module.
* \param expr The expression with operations to be lowered.
* \param target The target to lower the functions to, "llvm" by default.
*
* \return The set of lowered operations.
*/
Array<LoweredOp> LowerOps(const Module& mod, const Expr& expr,
const std::string& target = "llvm");
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_BUILD_MODULE_H_
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#ifndef TVM_RELAY_INTERPRETER_H_ #ifndef TVM_RELAY_INTERPRETER_H_
#define TVM_RELAY_INTERPRETER_H_ #define TVM_RELAY_INTERPRETER_H_
#include <tvm/build_module.h>
#include <tvm/relay/module.h> #include <tvm/relay/module.h>
#include <tvm/relay/expr.h> #include <tvm/relay/expr.h>
...@@ -27,7 +28,9 @@ namespace relay { ...@@ -27,7 +28,9 @@ namespace relay {
*/ */
class Value; class Value;
/*! \brief Evaluate an expression using the interpreter producing a value. /*!
*\brief Create an interpreter function that can
* evaluate an expression and produce a value.
* *
* The resulting value can be passed to Python, making it easy to use * The resulting value can be passed to Python, making it easy to use
* for testing and debugging. * for testing and debugging.
...@@ -38,8 +41,14 @@ class Value; ...@@ -38,8 +41,14 @@ class Value;
* *
* Our intent is that this will never be the most efficient implementation of * Our intent is that this will never be the most efficient implementation of
* Relay's semantics, but a readable and clear one. * Relay's semantics, but a readable and clear one.
*
* \param mod The function module.
* \param context The primary context that the interpreter runs on.
* \param target Compiler target flag to compile the functions on the context.
* \return A function that takes in an expression and returns a value.
*/ */
Value Evaluate(Module mod, Expr e); runtime::TypedPackedFunc<Value(Expr)>
CreateInterpreter(Module mod, DLContext context, Target target);
/*! \brief The base container type of Relay values. */ /*! \brief The base container type of Relay values. */
class ValueNode : public RelayNode { class ValueNode : public RelayNode {
...@@ -125,9 +134,6 @@ struct TensorValueNode : ValueNode { ...@@ -125,9 +134,6 @@ struct TensorValueNode : ValueNode {
/*! \brief Build a value from an NDArray. */ /*! \brief Build a value from an NDArray. */
TVM_DLL static TensorValue make(runtime::NDArray data); TVM_DLL static TensorValue make(runtime::NDArray data);
/*! \brief Construct an empty tensor value from t. */
TVM_DLL static TensorValue FromType(const Type& t);
static constexpr const char* _type_key = "relay.TensorValue"; static constexpr const char* _type_key = "relay.TensorValue";
TVM_DECLARE_NODE_TYPE_INFO(TensorValueNode, ValueNode); TVM_DECLARE_NODE_TYPE_INFO(TensorValueNode, ValueNode);
}; };
......
/*!
* Copyright (c) 2017 by Contributors
* \file tvm/relay/op_attr_types.h
* \brief Operator attribute types used by Relay: the fusion pattern
* and the compute/schedule function signatures.
*/
#ifndef TVM_RELAY_OP_ATTR_TYPES_H_
#define TVM_RELAY_OP_ATTR_TYPES_H_
#include <tvm/tensor.h>
#include <tvm/schedule.h>
#include <tvm/build_module.h>
#include <tvm/relay/type.h>
#include <tvm/relay/expr.h>
namespace tvm {
namespace relay {
/*!
* \brief operator pattern used in graph fusion
*
* \note The enumerator values are not contiguous: kOpaque is 8 while
* the preceding enumerator is 4.
*/
enum OpPatternKind {
// Elementwise operation
kElemWise = 0,
// Broadcasting operator, can always map output axis to the input in order.
// for example :code:`out[i, ax1, j, ax2] = input[i, j]`.
// Note that the axis need to be in order so transpose is not a bcast operator.
kBroadcast = 1,
// Injective operator, can always injectively map output axis to a single input axis.
// All injective operators can still be safely fused to injective and reduction.
kInjective = 2,
// Commutative reduction operator.
kCommReduce = 3,
// Complex operation, can still fuse elemwise operations into its output.
// but cannot chain another complex op
kOutEWiseFusable = 4,
// Opaque operation, cannot fuse anything.
kOpaque = 8
};
/*!
* \brief the operator pattern
*
* \note Declared as a plain int; presumably it carries an
* OpPatternKind value -- confirm against the code that reads it.
*/
using TOpPattern = int;
/*!
* \brief Computation description interface.
*
* \note This function has a special convention
* for functions with tuple input/output.
*
* So far we restrict tuple support to the following case:
* - Function which takes a single tuple as input.
* - Function which outputs a single tuple.
*
* In both cases, the tuple is flattened as array.
*
* \param attrs The attribute of the primitive
* \param inputs The input tensors.
* \param out_type The output type information
* these are always placeholders.
* \param target The build target.
* \return The output compute description of the operator.
*/
using FTVMCompute = runtime::TypedPackedFunc<
Array<Tensor>(const Attrs& attrs,
const Array<Tensor>& inputs,
const Type& out_type,
const Target& target)>;
/*!
* \brief Build the computation schedule for
* op whose root is at current op.
*
* \note The signature takes only the output tensors and the target;
* no attribute argument is passed.
*
* \param outs The output tensors.
* \param target The build target.
* \return The computation schedule.
*/
using FTVMSchedule = runtime::TypedPackedFunc<
Schedule(const Array<Tensor>& outs,
const Target& target)>;
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_OP_ATTR_TYPES_H_
...@@ -178,6 +178,40 @@ class DeviceAPI { ...@@ -178,6 +178,40 @@ class DeviceAPI {
/*! \brief The device type bigger than this is RPC device */ /*! \brief The device type bigger than this is RPC device */
constexpr int kRPCSessMask = 128; constexpr int kRPCSessMask = 128;
/*!
 * \brief Look up the Device API factory name for a device type code.
 *
 * An unrecognised code is reported through LOG(FATAL); the literal
 * "Unknown" is the value returned after that statement.
 *
 * \param type The device type.
 * \return the device name.
 */
inline const char* DeviceName(int type) {
  switch (type) {
    case kDLCPU:
      return "cpu";
    case kDLGPU:
      return "gpu";
    case kDLOpenCL:
      return "opencl";
    case kDLSDAccel:
      return "sdaccel";
    case kDLAOCL:
      return "aocl";
    case kDLVulkan:
      return "vulkan";
    case kDLMetal:
      return "metal";
    case kDLVPI:
      return "vpi";
    case kDLROCM:
      return "rocm";
    case kOpenGL:
      return "opengl";
    case kDLExtDev:
      return "ext_dev";
    default:
      break;
  }
  // No known device matched the given code.
  LOG(FATAL) << "unknown type =" << type;
  return "Unknown";
}
#ifndef _LIBCPP_SGX_NO_IOSTREAMS
/*!
 * \brief Print a DLContext as "<device name>(<device id>)".
 *
 * When the device type is greater than kRPCSessMask the output is
 * prefixed with "remote[<type / kRPCSessMask>]-" and the device name
 * is looked up from the remainder (type % kRPCSessMask).
 *
 * \param os The output stream.
 * \param ctx The context to print.
 * \return The same output stream.
 */
inline std::ostream& operator<<(std::ostream& os, DLContext ctx) {  // NOLINT(*)
  const int raw_type = static_cast<int>(ctx.device_type);
  int local_type = raw_type;
  if (raw_type > kRPCSessMask) {
    os << "remote[" << (raw_type / kRPCSessMask) << "]-";
    local_type = raw_type % kRPCSessMask;
  }
  os << runtime::DeviceName(local_type) << "(" << ctx.device_id << ")";
  return os;
}
#endif
} // namespace runtime } // namespace runtime
} // namespace tvm } // namespace tvm
#endif // TVM_RUNTIME_DEVICE_API_H_ #endif // TVM_RUNTIME_DEVICE_API_H_
...@@ -888,6 +888,7 @@ inline std::ostream& operator<<(std::ostream& os, TVMType t) { // NOLINT(*) ...@@ -888,6 +888,7 @@ inline std::ostream& operator<<(std::ostream& os, TVMType t) { // NOLINT(*)
} }
return os; return os;
} }
#endif #endif
inline std::string TVMType2String(TVMType t) { inline std::string TVMType2String(TVMType t) {
......
...@@ -132,7 +132,7 @@ class GraphModule(object): ...@@ -132,7 +132,7 @@ class GraphModule(object):
params : dict of str to NDArray params : dict of str to NDArray
Additonal arguments Additonal arguments
""" """
if key: if key is not None:
self._get_input(key).copyfrom(value) self._get_input(key).copyfrom(value)
if params: if params:
......
...@@ -7,8 +7,7 @@ from . import ty ...@@ -7,8 +7,7 @@ from . import ty
from . import expr from . import expr
from . import module from . import module
from . import ir_pass from . import ir_pass
from .build_module import build from .build_module import build, create_executor
from .interpreter import create_executor
# Root operators # Root operators
from .op import Op from .op import Op
...@@ -18,7 +17,7 @@ from .op.transform import * ...@@ -18,7 +17,7 @@ from .op.transform import *
from . import nn from . import nn
from . import vision from . import vision
from . import image from . import image
from . import backend
from .scope_builder import ScopeBuilder from .scope_builder import ScopeBuilder
...@@ -56,13 +55,6 @@ TupleGetItem = expr.TupleGetItem ...@@ -56,13 +55,6 @@ TupleGetItem = expr.TupleGetItem
var = expr.var var = expr.var
const = expr.const const = expr.const
@register_func("relay._tensor_value_repr")
def _tensor_value_repr(tv):
return str(tv.data.asnumpy())
@register_func("relay._constant_repr")
def _tensor_constant_repr(tv):
return str(tv.data.asnumpy())
# pylint: disable=unused-argument # pylint: disable=unused-argument
@register_func("relay.debug") @register_func("relay.debug")
......
"""The interface to the Evaluator exposed from C++."""
from tvm._ffi.function import _init_api
_init_api("relay._interpreter", __name__)
"""Backend codege modules for relay."""
from . import compile_engine
"""The interface of expr function exposed from C++."""
from __future__ import absolute_import
import logging
from ... import build_module as _build
from ... import container as _container
from ..._ffi.function import _init_api, register_func
@register_func("relay.backend.lower")
def lower(sch, inputs, func_name, source_func):
"""Backend function for lowering.
Parameters
----------
sch : tvm.Schedule
The schedule.
inputs : List[tvm.Tensor]
The inputs to the function.
func_name : str
The name of the function.
source-func : tvm.relay.Function
The source function to be lowered.
Returns
-------
lowered_funcs : List[tvm.LoweredFunc]
The result of lowering.
"""
import traceback
# pylint: disable=broad-except
try:
f = _build.lower(sch, inputs, name=func_name)
logging.debug("lower function %s", func_name)
logging.debug("%s", _build.lower(sch, inputs, simple_mode=True))
except Exception:
msg = traceback.format_exc()
msg += "Error during compile function\n"
msg += "-----------------------------\n"
msg += source_func.astext()
raise RuntimeError(msg)
return f if isinstance(
f, (_container.Array, tuple, list)) else [f]
@register_func("relay.backend.build")
def build(funcs, target, target_host=None):
"""Backend build function.
Parameters
----------
funcs : List[tvm.LoweredFunc]
The list of lowered functions.
target : tvm.Target
The target to run the code on.
target_host : tvm.Target
The host target.
Returns
-------
module : tvm.Module
The runtime module.
"""
if target_host == "":
target_host = None
return _build.build(funcs, target=target, target_host=target_host)
@register_func("relay._tensor_value_repr")
def _tensor_value_repr(tvalue):
return str(tvalue.data.asnumpy())
@register_func("relay._constant_repr")
def _tensor_constant_repr(tvalue):
return str(tvalue.data.asnumpy())
_init_api("relay.backend", __name__)
"""Backend code generation engine."""
from __future__ import absolute_import
from ..base import register_relay_node, NodeBase
from ... import target as _target
from .. import expr as _expr
from . import _backend
@register_relay_node
class CachedFunc(NodeBase):
    """Low-level tensor function that backs a relay primitive function."""
@register_relay_node
class CCacheKey(NodeBase):
    """Key in the CompileEngine.

    Parameters
    ----------
    source_func : tvm.relay.Function
        The source function.

    target : tvm.Target
        The target we want to run the function on.
    """
    def __init__(self, source_func, target):
        constructor = _backend._make_CCacheKey
        self.__init_handle_by_constructor__(constructor, source_func, target)
@register_relay_node
class CCacheValue(NodeBase):
    """Value stored in the CompileEngine, including usage statistics."""
def _get_cache_key(source_func, target):
    """Turn ``source_func`` into a CCacheKey.

    A relay Function is paired with ``target`` (a string target is created
    through ``_target.create`` first) to form a new key. An existing
    CCacheKey is returned unchanged.

    Raises
    ------
    ValueError
        If ``source_func`` is a Function and no target is available.
    TypeError
        If ``source_func`` is neither a Function nor a CCacheKey.
    """
    if not isinstance(source_func, _expr.Function):
        # Anything that is not a Function must already be a key.
        if isinstance(source_func, CCacheKey):
            return source_func
        raise TypeError("Expect source_func to be CCacheKey")

    if isinstance(target, str):
        target = _target.create(target)
    if not target:
        raise ValueError("Need target when source_func is a Function")
    return CCacheKey(source_func, target)
@register_relay_node
class CompileEngine(NodeBase):
    """CompileEngine to get lowered code.

    Instances cannot be built directly; ``__init__`` always raises.
    """
    def __init__(self):
        raise RuntimeError("Cannot construct a CompileEngine")

    def lower(self, source_func, target=None):
        """Lower a source_func to a CachedFunc.

        Parameters
        ----------
        source_func : Union[tvm.relay.Function, CCacheKey]
            The source relay function.

        target : tvm.Target
            The target platform.

        Returns
        -------
        cached_func: CachedFunc
            The result of lowering.
        """
        cache_key = _get_cache_key(source_func, target)
        return _backend._CompileEngineLower(self, cache_key)

    def jit(self, source_func, target=None):
        """JIT a source_func to a tvm.Function.

        Parameters
        ----------
        source_func : Union[tvm.relay.Function, CCacheKey]
            The source relay function.

        target : tvm.Target
            The target platform.

        Returns
        -------
        jited_func: tvm.Function
            The result of the JIT compilation.
        """
        cache_key = _get_cache_key(source_func, target)
        return _backend._CompileEngineJIT(self, cache_key)

    def clear(self):
        """clear the existing cached functions"""
        _backend._CompileEngineClear(self)

    def items(self):
        """List items in the cache.

        Returns
        -------
        item_list : List[Tuple[CCacheKey, CCacheValue]]
            The list of items.
        """
        # The backend returns a flat sequence: key, value, key, value, ...
        flat = _backend._CompileEngineListItems(self)
        assert len(flat) % 2 == 0
        pairs = []
        for idx in range(0, len(flat), 2):
            pairs.append((flat[idx], flat[idx + 1]))
        return pairs

    def dump(self):
        """Return a string representation of engine dump.

        Returns
        -------
        dump : str
            The dumped string representation
        """
        entries = self.items()
        parts = [
            "====================================\n",
            "CompilerEngine dump, %d items cached\n" % len(entries),
        ]
        for key, value in entries:
            parts.append("------------------------------------\n")
            parts.append("target={}\n".format(key.target))
            parts.append("use_count={}\n".format(value.use_count))
            parts.append("func_name={}\n".format(value.cached_func.func_name))
            parts.append(key.source_func.astext() + "\n")
        parts.append("===================================\n")
        return "".join(parts)
def get():
    """Fetch the global compile engine.

    Returns
    -------
    engine : tvm.relay.backend.CompileEngine
        The compile engine.
    """
    engine = _backend._CompileEngineGlobal()
    return engine
#pylint: disable=no-else-return #pylint: disable=no-else-return
"""An interface to the Realy interpreter.""" """An interface to the Realy interpreter."""
from __future__ import absolute_import from __future__ import absolute_import
import numpy as np import numpy as np
from .. import register_func, nd
from .base import NodeBase, register_relay_node from . import _backend
from . import build_module from .. import _make, ir_pass
from . import _make from ... import register_func, nd
from . import _interpreter from ..base import NodeBase, register_relay_node
from . import ir_pass from ..expr import Call, Constant, GlobalVar, Function, const
from .module import Module from ..scope_builder import ScopeBuilder
from .expr import Call, Constant, GlobalVar, Function, const
from .scope_builder import ScopeBuilder
from .._ffi.base import integer_types
from ..contrib import graph_runtime as tvm_runtime
from .. import cpu
class Value(NodeBase): class Value(NodeBase):
"""Base class of all values. """Base class of all values.
""" """
@staticmethod @staticmethod
@register_func("relay.from_scalar") @register_func("relay.from_scalar")
def from_scalar(i, dtype=None): def from_scalar(value, dtype=None):
"""Convert a Python scalar to a Relay scalar.""" """Convert a Python scalar to a Relay scalar."""
if dtype is None: return TensorValue(const(value, dtype).data)
if isinstance(i, integer_types):
dtype = 'int32'
elif isinstance(i, float):
dtype = 'float32'
elif isinstance(i, bool):
dtype = 'uint8'
else:
raise Exception("unable to infer dtype {0}".format(type(i)))
return TensorValue(nd.array(np.array(i, dtype=dtype)))
@register_relay_node @register_relay_node
...@@ -65,10 +50,6 @@ class TensorValue(Value): ...@@ -65,10 +50,6 @@ class TensorValue(Value):
self.__init_handle_by_constructor__( self.__init_handle_by_constructor__(
_make.TensorValue, data) _make.TensorValue, data)
def as_ndarray(self):
"""Convert a Relay TensorValue into a tvm.ndarray."""
return self.data
def asnumpy(self): def asnumpy(self):
"""Convert a Relay TensorValue into a numpy.ndarray.""" """Convert a Relay TensorValue into a numpy.ndarray."""
return self.data.asnumpy() return self.data.asnumpy()
...@@ -79,7 +60,7 @@ class TensorValue(Value): ...@@ -79,7 +60,7 @@ class TensorValue(Value):
def _arg_to_ast(arg): def _arg_to_ast(arg):
if isinstance(arg, TensorValue): if isinstance(arg, TensorValue):
return Constant(arg.data) return Constant(arg.data.copyto(_nd.cpu(0)))
elif isinstance(arg, np.ndarray): elif isinstance(arg, np.ndarray):
return Constant(nd.array(arg)) return Constant(nd.array(arg))
elif isinstance(arg, Constant): elif isinstance(arg, Constant):
...@@ -87,29 +68,9 @@ def _arg_to_ast(arg): ...@@ -87,29 +68,9 @@ def _arg_to_ast(arg):
else: else:
return const(arg) return const(arg)
class Executor(object): class Executor(object):
"""An abstract interface for executing Relay programs.""" """An abstract interface for executing Relay programs."""
def __init__(self, mod=None):
"""
Parameters
----------
mod: relay.Module
The module.
"""
if mod is None:
self.mod = Module({})
else:
self.mod = mod
def optimize(self, expr):
# TODO: We need to move this optimization code into the optimizer/pass manager
ck_expr = ir_pass.infer_type(expr, mod=self.mod)
fused_expr = ir_pass.fuse_ops(self.mod, ck_expr)
ck_fused = ir_pass.infer_type(fused_expr, mod=self.mod)
return ck_fused
def _make_executor(self, _): def _make_executor(self, _):
""" """
Construct a Python function that implements the evaluation Construct a Python function that implements the evaluation
...@@ -122,50 +83,85 @@ class Executor(object): ...@@ -122,50 +83,85 @@ class Executor(object):
Returns Returns
------- -------
executor: function executor: function,
A Python function which implements the behavior of `expr`. A Python function which implements the behavior of `expr`.
""" """
raise Exception("abstract method: please implement me.") raise NotImplementedError()
def evaluate(self, expr, params=None): def evaluate(self, expr, binds=None):
""" """
Evaluate a Relay expression on the interpreter. Evaluate a Relay expression on the executor.
Parameters Parameters
---------- ----------
expr: tvm.relay.Expr expr: tvm.relay.Expr
The expression to evaluate. The expression to evaluate.
binds: Map[tvm.relay.Var, tvm.relay.Expr]
Additional binding of free variable.
Returns
-------
val : Union[function, Value]
The evaluation result.
""" """
if params: if binds:
scope_builder = ScopeBuilder() scope_builder = ScopeBuilder()
for key in params: for key, value in binds.items():
value = params[key] scope_builder.let(key, _arg_to_ast(value))
scope_builder.let(key, value)
scope_builder.ret(expr) scope_builder.ret(expr)
expr = scope_builder.get() expr = scope_builder.get()
if isinstance(expr, Function): if isinstance(expr, Function):
assert not ir_pass.free_vars(expr) assert not ir_pass.free_vars(expr)
executor = self._make_executor(expr)
# If we are evaluating a function or top-level defintion
# the user must call the function themselves.
#
# If we are evaluating an open term with parameters we will
# just return them the result.
if isinstance(expr, (Function, GlobalVar)): if isinstance(expr, (Function, GlobalVar)):
return executor return self._make_executor(expr)
else:
return executor() # normal expression evaluated by running a function.
func = Function([], expr)
return self._make_executor(func)()
class Interpreter(Executor): class Interpreter(Executor):
""" """
A wrapper around the Relay interpreter, implements the excecutor interface. Simple interpreter interface.
Parameters
----------
mod : tvm.relay.Module
The module to support the execution.
ctx : tvm.TVMContext
The runtime context to run the code on.
target : tvm.Target
The target option to build the function.
""" """
def __init__(self, mod=None): def __init__(self, mod, ctx, target):
Executor.__init__(self, mod) self.mod = mod
self.ctx = ctx
self.target = target
self._intrp = _backend.CreateInterpreter(mod, ctx, target)
def optimize(self, expr):
    """Optimize an expr.

    Runs type inference, operator fusion, and type inference again.

    Parameters
    ----------
    expr : Expr
        The expression to be optimized.

    Returns
    -------
    opt_expr : Expr
        The optimized expression.
    """
    # TODO: We need to move this optimization code into the optimizer/pass manager
    typed = ir_pass.infer_type(expr, mod=self.mod)
    fused = ir_pass.fuse_ops(typed)
    return ir_pass.infer_type(fused, mod=self.mod)
def _make_executor(self, expr): def _make_executor(self, expr):
def _interp_wrapper(*args): def _interp_wrapper(*args):
...@@ -178,46 +174,9 @@ class Interpreter(Executor): ...@@ -178,46 +174,9 @@ class Interpreter(Executor):
func = self.optimize(func) func = self.optimize(func)
self.mod._add(expr, func, True) self.mod._add(expr, func, True)
opt_expr = Call(expr, relay_args) opt_expr = Call(expr, relay_args)
return _interpreter.evaluate(self.mod, opt_expr) return self._intrp(opt_expr)
elif isinstance(expr, Function): else:
call = Call(expr, relay_args) call = Call(expr, relay_args)
opt_expr = self.optimize(call) opt_expr = self.optimize(call)
return _interpreter.evaluate(self.mod, opt_expr) return self._intrp(opt_expr)
else:
assert not args
opt_expr = self.optimize(expr)
return _interpreter.evaluate(self.mod, opt_expr)
return _interp_wrapper return _interp_wrapper
class GraphRuntime(Executor):
"""A wrapper around the TVM graph runtime, implements the Executor interface."""
def __init__(self, mod=None):
Executor.__init__(self, mod)
def _make_executor(self, expr):
def _graph_wrapper(*args):
func = self.optimize(expr)
graph_json, mod, params = build_module.build(func, mod=self.mod)
assert params is None
gmodule = tvm_runtime.create(graph_json, mod, cpu(0))
# Create map of inputs.
inputs = {}
for i, arg in enumerate(args):
inputs[func.params[i].name_hint] = arg
# Set the inputs here.
gmodule.set_input(**inputs)
# Run the module, and fetch the output.
gmodule.run()
return gmodule.get_output(0)
return _graph_wrapper
def create_executor(mode='debug', mod=None):
if mode == 'debug':
return Interpreter(mod)
elif mode == 'graph':
return GraphRuntime(mod)
else:
raise Exception("unknown mode {0}".format(mode))
...@@ -2,45 +2,257 @@ ...@@ -2,45 +2,257 @@
Construct the necessary state for the TVM graph runtime Construct the necessary state for the TVM graph runtime
from a Relay expression. from a Relay expression.
""" """
from ..build_module import build as tvm_build_module from ..build_module import build as _tvm_build_module
from . graph_runtime_codegen import GraphRuntimeCodegen from .. import nd as _nd, target as _target, autotvm
from ..contrib import graph_runtime as _graph_rt
from . import ir_pass from . import ir_pass
from .module import Module from .backend import interpreter as _interpreter
from .backend import graph_runtime_codegen as _graph_gen
def build(func, params=None, target=None, mod=None): # List of optimization pass and level when switch on
# Maps each optimization pass name to the minimum opt_level at which
# BuildConfig.pass_enabled reports the pass as enabled.
OPT_PASS_LEVEL = {
"SimplifyInference": 0,
"OpFusion": 1,
"FoldScaleAxis": 3,
}
class BuildConfig(object):
    """Configuration scope to set a build config option.

    Used as a context manager: entering makes this config the current one
    (merged on top of the previously current config) and leaving restores
    the previous one.

    Parameters
    ----------
    kwargs
        Keyword arguments of configurations to set. Only keys present in
        ``BuildConfig.defaults`` are accepted.
    """
    current = None
    defaults = {
        "opt_level": 2,
        "add_pass": None,
    }

    def __init__(self, **kwargs):
        self._old_scope = None
        for k in kwargs:
            if k not in BuildConfig.defaults:
                raise ValueError(
                    "invalid argument %s, candidates are %s" % (k, BuildConfig.defaults.keys()))
        self._attr = kwargs

    def __getattr__(self, name):
        # Only reached when normal attribute lookup fails. Read _attr through
        # __dict__ so a lookup made before __init__ has set it cannot recurse.
        attr = self.__dict__.get("_attr", {})
        if name in attr:
            return attr[name]
        if name in BuildConfig.defaults:
            return BuildConfig.defaults[name]
        # Unknown names must raise AttributeError (not KeyError) so that
        # hasattr() and getattr(obj, name, default) behave as expected.
        raise AttributeError(
            "%s has no attribute %s" % (type(self).__name__, name))

    def __enter__(self):
        # pylint: disable=protected-access
        self._old_scope = BuildConfig.current
        # Options of the enclosing scope apply unless overridden here.
        attr = BuildConfig.current._attr.copy()
        attr.update(self._attr)
        self._attr = attr
        BuildConfig.current = self
        return self

    def __exit__(self, ptype, value, trace):
        assert self._old_scope
        BuildConfig.current = self._old_scope

    def pass_enabled(self, pass_name):
        """Get whether pass is enabled.

        Parameters
        ----------
        pass_name : str
            The optimization pass name

        Returns
        -------
        enabled : bool
            Whether pass is enabled.
        """
        # An explicitly added pass is enabled regardless of opt_level.
        if self.add_pass and pass_name in self.add_pass:
            return True
        return self.opt_level >= OPT_PASS_LEVEL[pass_name]


BuildConfig.current = BuildConfig()
def build_config(**kwargs):
    """Create a build configuration from config variables.

    Parameters
    ----------
    opt_level: int, default=2
        Optimization level. See OPT_PASS_LEVEL for level of each pass.

    add_pass: set of str
        Optimization pass to be added regardless of optimization level.

    Returns
    -------
    config: BuildConfig
        The build configuration
    """
    config = BuildConfig(**kwargs)
    return config
def optimize(func):
    """Perform target invariant optimizations.

    Which passes run is decided by the current BuildConfig.

    Parameters
    ----------
    func : tvm.relay.Function
        The input to optimization.

    Returns
    -------
    opt_func : tvm.relay.Function
        The optimized version of the function.
    """
    cfg = BuildConfig.current

    # Gate each pass on its own OPT_PASS_LEVEL entry; simplify_inference
    # was previously tied to the FoldScaleAxis level by mistake.
    if cfg.pass_enabled("SimplifyInference"):
        func = ir_pass.infer_type(func)
        func = ir_pass.simplify_inference(func)

    if cfg.pass_enabled("FoldScaleAxis"):
        func = ir_pass.infer_type(func)
        func = ir_pass.backward_fold_scale_axis(func)
        func = ir_pass.infer_type(func)
        func = ir_pass.forward_fold_scale_axis(func)
    return func
def build(func,
target=None,
target_host=None,
params=None):
"""Build a function to run on TVM graph runtime.
Parameters Parameters
---------- ----------
func: relay.Expr func: relay.Function
The function to build. The function to build.
target: optional str target : str or :any:`tvm.target.Target`, optional
The target platform. The build target
target_host : str or :any:`tvm.target.Target` optional
Host compilation target, if target is device.
When TVM compiles device specific program such as CUDA,
we also need host(CPU) side code to interact with the driver
setup the dimensions and parameters correctly.
target_host is used to specify the host side codegen target.
By default, llvm is used if it is enabled,
otherwise a stackvm intepreter is used.
params : dict of str to NDArray
Input parameters to the graph that do not change
during inference time. Used for pre-compute
folding optimization.
Returns Returns
------- -------
(graph_json, mod, params): tuple of (str, tvm.Module, dict) graph_json : str
The outputs of building a Relay function for the TVM runtime. The json string that can be accepted by graph runtime.
mod : tvm.Module
The module containing necessary libraries.
params : dict
The parameters of the final graph.
""" """
target = target if target else _target.current_target()
if target is None: if target is None:
target = 'llvm' raise ValueError("Target is not set in env or passed as argument.")
target = _target.create(target)
if mod is None:
mod = Module({}) # If current dispatch context is fallback context (the default root context),
# then load pre-tuned parameters from TopHub
comp = GraphRuntimeCodegen(mod) if isinstance(autotvm.DispatchContext.current, autotvm.FallbackContext):
# NB(@jroesch) This creates lowered functions, and generates names for them tophub_context = autotvm.tophub.context(target)
# else:
# We need these names to emit the correct graph as these are names of the tophub_context = autotvm.util.EmptyContext()
# functions contained in the module.
lowered_ops = ir_pass.lower_ops(mod, func) with tophub_context:
mod = tvm_build_module([lf.lowered_func for lf in lowered_ops], target) func = optimize(func)
# Fuse ops before running code gen
# Therefore the call to compile must come after. func = ir_pass.infer_type(func)
comp.codegen(func) func = ir_pass.fuse_ops(func)
graph_json = comp.to_json() # Graph code generation
func = ir_pass.infer_type(func)
graph_gen = _graph_gen.GraphRuntimeCodegen(mod=None, target=target)
graph_json, lowered_funcs = graph_gen.codegen(func)
mod = _tvm_build_module(lowered_funcs, target=target, target_host=target_host)
return graph_json, mod, params return graph_json, mod, params
class GraphExecutor(_interpreter.Executor):
    """Executor that runs functions through the graph runtime.

    This executor is used for debug and testing purposes.

    Parameters
    ----------
    mod : tvm.relay.Module
        The module to support the execution.

    ctx : tvm.TVMContext
        The runtime context to run the code on.

    target : tvm.Target
        The target option to build the function.
    """
    def __init__(self, mod, ctx, target):
        self.mod = mod
        self.ctx = ctx
        self.target = target

    def _make_executor(self, func):
        def _graph_wrapper(*args):
            # The function is built each time the wrapper is called.
            graph_json, lib, params = build(func, target=self.target)
            assert params is None
            runtime_module = _graph_rt.create(graph_json, lib, self.ctx)
            # Bind the positional arguments as inputs by index.
            for index, arg in enumerate(args):
                runtime_module.set_input(index, arg)
            # Run the module, and fetch the output.
            runtime_module.run()
            return runtime_module.get_output(0)

        return _graph_wrapper
def create_executor(kind="debug",
mod=None,
ctx=None,
target="llvm"):
"""Factory function to create an executor.
Parameters
----------
kind : str
The type of executor
mod : relay.Mod
The mod
ctx : tvm.TVMContext
The context to execute the code.
target : tvm.Target
The corresponding context
"""
if ctx is not None:
assert ctx.device_type == _nd.context(str(target), 0).device_type
else:
ctx = _nd.context(str(target), 0)
if isinstance(target, str):
target = _target.create(target)
if kind == "debug":
return _interpreter.Interpreter(mod, ctx, target)
elif kind == "graph":
return GraphExecutor(mod, ctx, target)
else:
raise RuntimeError("unknown mode {0}".format(mode))
...@@ -319,12 +319,11 @@ class TupleGetItem(Expr): ...@@ -319,12 +319,11 @@ class TupleGetItem(Expr):
self.__init_handle_by_constructor__( self.__init_handle_by_constructor__(
_make.TupleGetItem, tuple_value, index) _make.TupleGetItem, tuple_value, index)
class ExprFunctor(object): class ExprFunctor(object):
""" """
An abstract visitor defined over Expr. An abstract visitor defined over Expr.
A Python version of the class defined in `expr_functor.h`.
Defines the default dispatch over expressions, and Defines the default dispatch over expressions, and
implements memoization. implements memoization.
""" """
...@@ -352,6 +351,8 @@ class ExprFunctor(object): ...@@ -352,6 +351,8 @@ class ExprFunctor(object):
res = self.visit_if(expr) res = self.visit_if(expr)
elif isinstance(expr, Tuple): elif isinstance(expr, Tuple):
res = self.visit_tuple(expr) res = self.visit_tuple(expr)
elif isinstance(expr, TupleGetItem):
res = self.visit_tuple_getitem(expr)
elif isinstance(expr, Constant): elif isinstance(expr, Constant):
res = self.visit_constant(expr) res = self.visit_constant(expr)
else: else:
...@@ -361,31 +362,34 @@ class ExprFunctor(object): ...@@ -361,31 +362,34 @@ class ExprFunctor(object):
return res return res
def visit_function(self, _): def visit_function(self, _):
raise Exception("Abstract method please implement me.") raise NotImplementedError()
def visit_let(self, _): def visit_let(self, _):
raise Exception("Abstract method please implement me.") raise NotImplementedError()
def visit_call(self, _): def visit_call(self, _):
raise Exception("Abstract method please implement me.") raise NotImplementedError()
def visit_var(self, _): def visit_var(self, _):
raise Exception("Abstract method please implement me.") raise NotImplementedError()
def visit_type(self, typ): def visit_type(self, typ):
return typ return typ
def visit_if(self, _): def visit_if(self, _):
raise Exception("Abstract method please implement me.") raise NotImplementedError()
def visit_tuple(self, _): def visit_tuple(self, _):
raise Exception("Abstract method please implement me.") raise NotImplementedError()
def visit_tuple_getitem(self, _):
raise NotImplementedError()
def visit_constant(self, _): def visit_constant(self, _):
raise Exception("Abstract method please implement me.") raise NotImplementedError()
def visit_global_var(self, _): def visit_global_var(self, _):
raise Exception("Abstract method please implement me.") raise NotImplementedError()
class ExprMutator(ExprFunctor): class ExprMutator(ExprFunctor):
...@@ -395,7 +399,6 @@ class ExprMutator(ExprFunctor): ...@@ -395,7 +399,6 @@ class ExprMutator(ExprFunctor):
The default behavior recursively traverses the AST The default behavior recursively traverses the AST
and reconstructs the AST. and reconstructs the AST.
""" """
def visit_function(self, fn): def visit_function(self, fn):
new_body = self.visit(fn.body) new_body = self.visit(fn.body)
return Function( return Function(
...@@ -429,9 +432,19 @@ class ExprMutator(ExprFunctor): ...@@ -429,9 +432,19 @@ class ExprMutator(ExprFunctor):
def visit_tuple(self, tup): def visit_tuple(self, tup):
return Tuple([self.visit(field) for field in tup.fields]) return Tuple([self.visit(field) for field in tup.fields])
def visit_tuple_getitem(self, op):
tuple_value = self.visit(op.tuple_value)
if not tuple_value.same_as(op.tuple_value):
return TupleGetItem(tuple_value, op.index)
return op
def visit_global_var(self, gvar):
return gvar
def visit_constant(self, rconst): def visit_constant(self, rconst):
return rconst return rconst
class TupleWrapper(object): class TupleWrapper(object):
"""TupleWrapper. """TupleWrapper.
......
...@@ -160,6 +160,7 @@ def free_type_vars(expr): ...@@ -160,6 +160,7 @@ def free_type_vars(expr):
""" """
return _ir_pass.free_type_vars(expr) return _ir_pass.free_type_vars(expr)
def simplify_inference(expr): def simplify_inference(expr):
""" Simplify the data-flow graph for inference phase. """ Simplify the data-flow graph for inference phase.
...@@ -176,6 +177,7 @@ def simplify_inference(expr): ...@@ -176,6 +177,7 @@ def simplify_inference(expr):
""" """
return _ir_pass.simplify_inference(expr) return _ir_pass.simplify_inference(expr)
def dead_code_elimination(expr): def dead_code_elimination(expr):
""" Remove expressions which does not effect the program result (dead code). """ Remove expressions which does not effect the program result (dead code).
...@@ -256,8 +258,18 @@ def structural_hash(value): ...@@ -256,8 +258,18 @@ def structural_hash(value):
"relay.Expr or relay.Type").format(type(value)) "relay.Expr or relay.Type").format(type(value))
raise TypeError(msg) raise TypeError(msg)
def fuse_ops(expr, mod):
return _ir_pass.FuseOps(mod, expr)
def lower_ops(mod, expr, target='llvm'): def fuse_ops(expr):
return _ir_pass.LowerOps(mod, expr, target) """Fuse operators in expr together.
Parameters
----------
expr : tvm.relay.Expr
The input expression.
Returns
-------
transformed_expr : tvm.relay.Expr
Transformed expression, containing fused result.
"""
return _ir_pass.FuseOps(expr)
#pylint: disable=invalid-name, unused-argument #pylint: disable=invalid-name, unused-argument
"""Backend compiler related feature registration""" """Backend compiler related feature registration"""
from __future__ import absolute_import from __future__ import absolute_import
import tvm
import topi import topi
import topi.cuda import topi.cuda
from . import register_schedule, register_compute from .op import register_compute, register_schedule, register_pattern, OpPattern
def schedule_injective(outputs, target): def schedule_injective(outputs, target):
"""Generic schedule for binary broadcast.""" """Generic schedule for binary broadcast."""
with tvm.target.create(target): with target:
return topi.generic.schedule_injective(outputs) return topi.generic.schedule_injective(outputs)
schedule_broadcast = schedule_injective schedule_broadcast = schedule_injective
schedule_elemwise = schedule_injective schedule_elemwise = schedule_injective
# log # log
@register_compute("log")
def log_compute(attrs, inputs, output_type, target): def log_compute(attrs, inputs, output_type, target):
assert len(inputs) == 1 assert len(inputs) == 1
return [topi.log(inputs[0])] return [topi.log(inputs[0])]
register_compute("log", log_compute)
register_schedule("log", schedule_broadcast) register_schedule("log", schedule_broadcast)
# exp # exp
@register_compute("exp")
def exp_compute(attrs, inputs, output_type, target): def exp_compute(attrs, inputs, output_type, target):
assert len(inputs) == 1 assert len(inputs) == 1
return [topi.exp(inputs[0])] return [topi.exp(inputs[0])]
register_compute("exp", exp_compute)
register_schedule("exp", schedule_broadcast) register_schedule("exp", schedule_broadcast)
# sqrt # sqrt
@register_compute("sqrt")
def sqrt_compute(attrs, inputs, output_type, target): def sqrt_compute(attrs, inputs, output_type, target):
assert len(inputs) == 1 assert len(inputs) == 1
return [topi.sqrt(inputs[0])] return [topi.sqrt(inputs[0])]
register_compute("sqrt", sqrt_compute)
register_schedule("sqrt", schedule_broadcast) register_schedule("sqrt", schedule_broadcast)
# sigmoid # sigmoid
@register_compute("sigmoid")
def sigmoid_compute(attrs, inputs, output_type, target): def sigmoid_compute(attrs, inputs, output_type, target):
assert len(inputs) == 1 assert len(inputs) == 1
return [topi.sigmoid(inputs[0])] return [topi.sigmoid(inputs[0])]
register_compute("sigmoid", sigmoid_compute)
register_schedule("sigmoid", schedule_broadcast) register_schedule("sigmoid", schedule_broadcast)
# floor # floor
@register_compute("floor")
def floor_compute(attrs, inputs, output_type, target): def floor_compute(attrs, inputs, output_type, target):
assert len(inputs) == 1 assert len(inputs) == 1
return [topi.floor(inputs[0])] return [topi.floor(inputs[0])]
register_compute("floor", floor_compute)
register_schedule("floor", schedule_broadcast) register_schedule("floor", schedule_broadcast)
# ceil # ceil
@register_compute("ceil")
def ceil_compute(attrs, inputs, output_type, target): def ceil_compute(attrs, inputs, output_type, target):
assert len(inputs) == 1 assert len(inputs) == 1
return [topi.ceil(inputs[0])] return [topi.ceil(inputs[0])]
register_compute("ceil", ceil_compute)
register_schedule("ceil", schedule_broadcast) register_schedule("ceil", schedule_broadcast)
# trunc # trunc
@register_compute("trunc")
def trunc_compute(attrs, inputs, output_type, target): def trunc_compute(attrs, inputs, output_type, target):
assert len(inputs) == 1 assert len(inputs) == 1
return [topi.trunc(inputs[0])] return [topi.trunc(inputs[0])]
register_compute("trunc", trunc_compute)
register_schedule("trunc", schedule_broadcast) register_schedule("trunc", schedule_broadcast)
# round # round
@register_compute("round")
def round_compute(attrs, inputs, output_type, target): def round_compute(attrs, inputs, output_type, target):
assert len(inputs) == 1 assert len(inputs) == 1
return [topi.round(inputs[0])] return [topi.round(inputs[0])]
register_compute("round", round_compute)
register_schedule("round", schedule_broadcast) register_schedule("round", schedule_broadcast)
# abs # abs
@register_compute("abs")
def abs_compute(attrs, inputs, output_type, target): def abs_compute(attrs, inputs, output_type, target):
assert len(inputs) == 1 assert len(inputs) == 1
return [topi.abs(inputs[0])] return [topi.abs(inputs[0])]
register_compute("abs", abs_compute)
register_schedule("abs", schedule_broadcast) register_schedule("abs", schedule_broadcast)
# tanh # tanh
@register_compute("tanh")
def tanh_compute(attrs, inputs, output_type, target): def tanh_compute(attrs, inputs, output_type, target):
assert len(inputs) == 1 assert len(inputs) == 1
return [topi.tanh(inputs[0])] return [topi.tanh(inputs[0])]
register_compute("tanh", tanh_compute)
register_schedule("tanh", schedule_broadcast) register_schedule("tanh", schedule_broadcast)
# negative # negative
@register_compute("negative")
def negative_compute(attrs, inputs, output_type, target): def negative_compute(attrs, inputs, output_type, target):
assert len(inputs) == 1 assert len(inputs) == 1
return [topi.negative(inputs[0])] return [topi.negative(inputs[0])]
register_compute("negative", negative_compute)
register_schedule("negative", schedule_broadcast) register_schedule("negative", schedule_broadcast)
# add # add
@register_compute("add")
def add_compute(attrs, inputs, output_type, target): def add_compute(attrs, inputs, output_type, target):
assert len(inputs) == 2 assert len(inputs) == 2
return [topi.add(inputs[0], inputs[1])] return [topi.add(inputs[0], inputs[1])]
register_compute("add", add_compute)
register_schedule("add", schedule_injective) register_schedule("add", schedule_injective)
# subtract # subtract
@register_compute("subtract")
def subtract_compute(attrs, inputs, output_type, target): def subtract_compute(attrs, inputs, output_type, target):
assert len(inputs) == 2 assert len(inputs) == 2
return [topi.subtract(inputs[0], inputs[1])] return [topi.subtract(inputs[0], inputs[1])]
register_compute("subtract", subtract_compute)
register_schedule("subtract", schedule_broadcast) register_schedule("subtract", schedule_broadcast)
# multiply # multiply
@register_compute("multiply")
def multiply_compute(attrs, inputs, output_type, target): def multiply_compute(attrs, inputs, output_type, target):
assert len(inputs) == 2 assert len(inputs) == 2
return [topi.multiply(inputs[0], inputs[1])] return [topi.multiply(inputs[0], inputs[1])]
register_compute("multiply", multiply_compute)
register_schedule("multiply", schedule_broadcast) register_schedule("multiply", schedule_broadcast)
# divide # divide
@register_compute("divide")
def divide_compute(attrs, inputs, output_type, target): def divide_compute(attrs, inputs, output_type, target):
assert len(inputs) == 2 assert len(inputs) == 2
return [topi.divide(inputs[0], inputs[1])] return [topi.divide(inputs[0], inputs[1])]
register_compute("divide", divide_compute)
register_schedule("divide", schedule_broadcast) register_schedule("divide", schedule_broadcast)
# pow # power
def pow_compute(attrs, inputs, output_type, target): @register_compute("power")
def power_compute(attrs, inputs, output_type, target):
assert len(inputs) == 2 assert len(inputs) == 2
return [topi.power(inputs[0], inputs[1])] return [topi.power(inputs[0], inputs[1])]
register_compute("pow", pow_compute) register_schedule("power", schedule_injective)
register_schedule("pow", schedule_injective)
# mod # mod
@register_compute("mod")
def mod_compute(attrs, inputs, output_type, target): def mod_compute(attrs, inputs, output_type, target):
assert len(inputs) == 2 assert len(inputs) == 2
return [topi.mod(inputs[0], inputs[1])] return [topi.mod(inputs[0], inputs[1])]
register_compute("mod", mod_compute)
register_schedule("mod", schedule_broadcast) register_schedule("mod", schedule_broadcast)
# equal # equal
@register_compute("equal")
def equal_compute(attrs, inputs, output_type, target): def equal_compute(attrs, inputs, output_type, target):
assert len(inputs) == 2 assert len(inputs) == 2
return [topi.equal(inputs[0], inputs[1])] return [topi.equal(inputs[0], inputs[1])]
register_compute("equal", equal_compute)
register_schedule("equal", schedule_broadcast) register_schedule("equal", schedule_broadcast)
# not_equal # not_equal
@register_compute("not_equal")
def not_equal_compute(attrs, inputs, output_type, target): def not_equal_compute(attrs, inputs, output_type, target):
assert len(inputs) == 2 assert len(inputs) == 2
return [topi.not_equal(inputs[0], inputs[1])] return [topi.not_equal(inputs[0], inputs[1])]
register_compute("not_equal", not_equal_compute)
register_schedule("not_equal", schedule_broadcast) register_schedule("not_equal", schedule_broadcast)
# less # less
@register_compute("less")
def less_compute(attrs, inputs, output_type, target): def less_compute(attrs, inputs, output_type, target):
assert len(inputs) == 2 assert len(inputs) == 2
return [topi.less(inputs[0], inputs[1])] return [topi.less(inputs[0], inputs[1])]
register_compute("less", less_compute)
register_schedule("less", schedule_broadcast) register_schedule("less", schedule_broadcast)
# less equal # less equal
@register_compute("less_equal")
def less_equal_compute(attrs, inputs, output_type, target): def less_equal_compute(attrs, inputs, output_type, target):
assert len(inputs) == 2 assert len(inputs) == 2
return [topi.less_equal(inputs[0], inputs[1])] return [topi.less_equal(inputs[0], inputs[1])]
register_compute("less_equal", less_equal_compute)
register_schedule("less_equal", schedule_broadcast) register_schedule("less_equal", schedule_broadcast)
# greater # greater
@register_compute("greater")
def greater_compute(attrs, inputs, output_type, target): def greater_compute(attrs, inputs, output_type, target):
assert len(inputs) == 2 assert len(inputs) == 2
return [topi.greater(inputs[0], inputs[1])] return [topi.greater(inputs[0], inputs[1])]
register_compute("greater", greater_compute)
register_schedule("greater", schedule_broadcast) register_schedule("greater", schedule_broadcast)
# greater equal # greater equal
@register_compute("greater_equal")
def greater_equal_compute(attrs, inputs, output_type, target): def greater_equal_compute(attrs, inputs, output_type, target):
assert len(inputs) == 2 assert len(inputs) == 2
return [topi.greater_equal(inputs[0], inputs[1])] return [topi.greater_equal(inputs[0], inputs[1])]
register_compute("greater_equal", greater_equal_compute)
register_schedule("greater_equal", schedule_broadcast) register_schedule("greater_equal", schedule_broadcast)
# maximum # maximum
@register_compute("maximum")
def maximum_compute(attrs, inputs, output_type, target): def maximum_compute(attrs, inputs, output_type, target):
assert len(inputs) == 2 assert len(inputs) == 2
return [topi.maximum(inputs[0], inputs[1])] return [topi.maximum(inputs[0], inputs[1])]
register_compute("maximum_compute", maximum_compute)
register_schedule("maximum_compute", schedule_injective) register_schedule("maximum_compute", schedule_injective)
# minimum # minimum
@register_compute("minimum")
def minimum_compute(attrs, inputs, output_type, target): def minimum_compute(attrs, inputs, output_type, target):
assert len(inputs) == 2 assert len(inputs) == 2
return [topi.minimum(inputs[0], inputs[1])] return [topi.minimum(inputs[0], inputs[1])]
register_compute("minimum", minimum_compute)
register_schedule("minimum", schedule_injective) register_schedule("minimum", schedule_injective)
# right shift # right shift
@register_compute("right_shift")
def right_shift_compute(attrs, inputs, output_type, target): def right_shift_compute(attrs, inputs, output_type, target):
assert len(inputs) == 2 assert len(inputs) == 2
return [topi.right_shift(inputs[0], inputs[1])] return [topi.right_shift(inputs[0], inputs[1])]
register_compute("right_shift", right_shift_compute)
register_schedule("right_shift", schedule_injective) register_schedule("right_shift", schedule_injective)
# lift shift # left shift
@register_compute("left_shift")
def left_shift_compute(attrs, inputs, output_type, target): def left_shift_compute(attrs, inputs, output_type, target):
assert len(inputs) == 2 assert len(inputs) == 2
return [topi.left_shift(inputs[0], inputs[1])] return [topi.left_shift(inputs[0], inputs[1])]
register_compute("left_shift", left_shift_compute)
register_schedule("left_shift", schedule_injective) register_schedule("left_shift", schedule_injective)
# zeros # zeros
@register_compute("zeros")
def zeros_compute(attrs, inputs, output_type, target): def zeros_compute(attrs, inputs, output_type, target):
assert not inputs assert not inputs
return [topi.full(output_type.shape, output_type.dtype, 0.0)] return [topi.full(output_type.shape, output_type.dtype, 0.0)]
register_compute("zeros", zeros_compute) register_schedule("zeros", schedule_broadcast)
register_schedule("zeros", schedule_injective) register_pattern("zeros", OpPattern.ELEMWISE)
# zeros_like # zeros_like
@register_compute("zeros_like")
def zeros_like_compute(attrs, inputs, output_type, target): def zeros_like_compute(attrs, inputs, output_type, target):
assert len(inputs) == 1 assert len(inputs) == 1
return [topi.full_like(inputs[0], 0.0)] return [topi.full_like(inputs[0], 0.0)]
register_compute("zeros_like", zeros_like_compute) register_schedule("zeros_like", schedule_broadcast)
register_schedule("zeros_like", schedule_injective)
# ones # ones
@register_compute("ones")
def ones_compute(attrs, inputs, output_type, target): def ones_compute(attrs, inputs, output_type, target):
assert not inputs assert not inputs
return [topi.full(output_type.shape, output_type.dtype, 1.0)] return [topi.full(output_type.shape, output_type.dtype, 1.0)]
register_compute("ones", ones_compute) register_schedule("ones", schedule_broadcast)
register_schedule("ones", schedule_injective) register_pattern("ones", OpPattern.ELEMWISE)
# ones_like # ones_like
@register_compute("ones_like")
def ones_like(attrs, inputs, output_type, target): def ones_like(attrs, inputs, output_type, target):
assert len(inputs) == 1 assert len(inputs) == 1
return [topi.full_like(inputs[0], 1.0)] return [topi.full_like(inputs[0], 1.0)]
register_compute("ones_like", ones_like) register_schedule("ones_like", schedule_broadcast)
register_schedule("ones_like", schedule_injective)
# clip # clip
@register_compute("clip")
def clip_compute(attrs, inputs, output_type, target): def clip_compute(attrs, inputs, output_type, target):
assert len(inputs) == 1 assert len(inputs) == 1
return [topi.clip(inputs[0], attrs.a_min, attrs.a_max)] return [topi.clip(inputs[0], attrs.a_min, attrs.a_max)]
register_schedule("clip", schedule_elemwise)
register_pattern("clip", OpPattern.ELEMWISE)
register_compute("clip", clip_compute) # concatenate
register_schedule("clip", schedule_injective) @register_compute("concatenate")
def concatenate_compute(attrs, inputs, output_type, target):
return [topi.concatenate(inputs, axis=attrs.axis)]
register_schedule("concatenate", schedule_injective)
register_pattern("concatenate", OpPattern.INJECTIVE)
...@@ -72,13 +72,80 @@ def register(op_name, attr_key, value=None, level=10): ...@@ -72,13 +72,80 @@ def register(op_name, attr_key, value=None, level=10):
"""internal register function""" """internal register function"""
_Register(op_name, attr_key, v, level) _Register(op_name, attr_key, v, level)
return v return v
return _register(value) if value else _register return _register(value) if value is not None else _register
def register_schedule(op_name, schedule):
register(op_name, "FTVMSchedule", schedule)
def register_compute(op_name, compute): class OpPattern(object):
register(op_name, "FTVMCompute", compute) """Operator generic patterns
See Also
--------
topi.tag : Contains explanation of the tag type.
"""
# Elementwise operator
ELEMWISE = 0
# Broadcast operator
BROADCAST = 1
# Injective mapping
INJECTIVE = 2
# Commutative reduction
COMM_REDUCE = 3
# Complex op, can still fuse ewise into it
OUT_ELEMWISE_FUSABLE = 4
# Not fusable opaque op
OPAQUE = 8
def register_schedule(op_name, schedule=None, level=10):
    """Attach a schedule function to an operator.

    Forwards to ``register`` with the ``"FTVMSchedule"`` attribute key.

    Parameters
    ----------
    op_name : str
        Name of the operator.

    schedule : function, optional
        The schedule function to attach. When left as None the result of
        ``register`` is returned unchanged, which allows decorator-style use.

    level : int
        Priority level of this registration.

    Returns
    -------
    Whatever ``register`` returns for these arguments.
    """
    attr_key = "FTVMSchedule"
    return register(op_name, attr_key, schedule, level)
def register_compute(op_name, compute=None, level=10):
    """Attach a compute function to an operator.

    Forwards to ``register`` with the ``"FTVMCompute"`` attribute key.

    Parameters
    ----------
    op_name : str
        Name of the operator.

    compute : function, optional
        The compute function to attach. When left as None the result of
        ``register`` is returned unchanged, which allows decorator-style use
        (``@register_compute("log")``).

    level : int
        Priority level of this registration.

    Returns
    -------
    Whatever ``register`` returns for these arguments.
    """
    attr_key = "FTVMCompute"
    return register(op_name, attr_key, compute, level)
def register_pattern(op_name, pattern, level=10):
    """Attach an operator pattern to an operator.

    Forwards to ``register`` with the ``"TOpPattern"`` attribute key.

    Parameters
    ----------
    op_name : str
        Name of the operator.

    pattern : int
        The pattern value to record (for example an ``OpPattern`` constant).

    level : int
        Priority level of this registration.

    Returns
    -------
    Whatever ``register`` returns for these arguments.
    """
    attr_key = "TOpPattern"
    return register(op_name, attr_key, pattern, level)
_init_api("relay.op", __name__) _init_api("relay.op", __name__)
......
...@@ -266,7 +266,7 @@ def divide(lhs, rhs): ...@@ -266,7 +266,7 @@ def divide(lhs, rhs):
return _make.divide(lhs, rhs) return _make.divide(lhs, rhs)
def pow(lhs, rhs): def power(lhs, rhs):
"""Power with numpy-style broadcasting. """Power with numpy-style broadcasting.
Parameters Parameters
...@@ -281,7 +281,7 @@ def pow(lhs, rhs): ...@@ -281,7 +281,7 @@ def pow(lhs, rhs):
result : relay.Expr result : relay.Expr
The computed result. The computed result.
""" """
return _make.pow(lhs, rhs) return _make.power(lhs, rhs)
def mod(lhs, rhs): def mod(lhs, rhs):
......
...@@ -6,3 +6,4 @@ from . import resnet ...@@ -6,3 +6,4 @@ from . import resnet
from . import dqn from . import dqn
from . import dcgan from . import dcgan
from . import mobilenet from . import mobilenet
from .config import ctx_list
"""Configuration about tests"""
from __future__ import absolute_import as _abs
import os
import tvm
def ctx_list():
    """Get context list for testcases.

    Device names come from the comma-separated ``RELAY_TEST_TARGETS``
    environment variable; when it is unset or empty, ``llvm`` and ``cuda``
    are used. Surrounding whitespace is stripped, empty entries are ignored
    and duplicates are dropped while keeping first-occurrence order, so the
    result is deterministic from run to run.

    Returns
    -------
    ctxs : list of (str, context)
        ``(device, tvm.context(device, 0))`` pairs, restricted to contexts
        whose ``exist`` attribute is truthy.
    """
    raw = os.environ.get("RELAY_TEST_TARGETS", "")
    candidates = raw.split(",") if raw else ["llvm", "cuda"]

    # Order-preserving de-duplication: a plain set() would make the order of
    # the returned contexts depend on string hashing.
    seen = set()
    devices = []
    for name in candidates:
        name = name.strip()
        if name and name not in seen:
            seen.add(name)
            devices.append(name)

    res = [(device, tvm.context(device, 0)) for device in devices]
    return [x for x in res if x[1].exist]
...@@ -154,13 +154,15 @@ std::unordered_set<std::string> TargetNode::libs() const { ...@@ -154,13 +154,15 @@ std::unordered_set<std::string> TargetNode::libs() const {
return result; return result;
} }
std::string TargetNode::str() const { const std::string& TargetNode::str() const {
if (str_repr_.length() != 0) return str_repr_;
std::ostringstream result; std::ostringstream result;
result << target_name; result << target_name;
for (const auto &x : options()) { for (const auto &x : options()) {
result << " " << x; result << " " << x;
} }
return result.str(); str_repr_ = result.str();
return str_repr_;
} }
......
/*!
* Copyright (c) 2018 by Contributors
* \file relay/backend/compile_engine.h
* \brief Internal compilation engine that handles the function cache
* and provides the interface to low-level code generation.
*/
#ifndef TVM_RELAY_BACKEND_COMPILE_ENGINE_H_
#define TVM_RELAY_BACKEND_COMPILE_ENGINE_H_
#include <tvm/lowered_func.h>
#include <tvm/relay/expr.h>
#include <string>
#include <functional>
namespace tvm {
namespace relay {
/*!
 * \brief Node container to represent a cached function.
 *
 * Bundles the target, the function name, the input/output tensors and the
 * lowered functions. All five fields are exposed through VisitAttrs.
 */
struct CachedFuncNode : public Node {
/*! \brief compiled target */
tvm::Target target;
/*! \brief Function name */
std::string func_name;
/*! \brief The inputs to the function */
tvm::Array<Tensor> inputs;
/*! \brief The outputs to the function */
tvm::Array<Tensor> outputs;
/*! \brief The lowered functions to support the function. */
tvm::Array<tvm::LoweredFunc> funcs;
/*!
 * \brief Visit every field, each under its own member name.
 * \param v The attribute visitor.
 */
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("target", &target);
v->Visit("func_name", &func_name);
v->Visit("inputs", &inputs);
v->Visit("outputs", &outputs);
v->Visit("funcs", &funcs);
}
/*! \brief Type key of this node type. */
static constexpr const char* _type_key = "relay.CachedFunc";
TVM_DECLARE_NODE_TYPE_INFO(CachedFuncNode, Node);
};
/*! \brief CachedFunc: the node reference declared for CachedFuncNode. */
TVM_DEFINE_NODE_REF(CachedFunc, CachedFuncNode);
class CCacheKey;
/*!
 * \brief Compile cache key.
 *
 * A key is the pair (source function, target). Hashing and equality are
 * content-based; see Hash() and Equal().
 */
class CCacheKeyNode : public Node {
public:
/*! \brief The source function to be lowered. */
Function source_func;
/*! \brief The hardware target.*/
Target target;
/*!
 * \brief Visit the two key fields (the cached hash is not visited).
 * \param v The attribute visitor.
 */
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("source_func", &source_func);
v->Visit("target", &target);
}
/*!
 * \brief Hash of the key, computed from the structural hash of
 *  source_func combined with the hash of the target string.
 *  The value is cached in hash_ after the first call.
 * \return The hash value of CCacheKey (never 0).
 */
inline size_t Hash() const;
/*!
* \brief check content equality
*
* Two keys are equal when their hashes match, their target strings
* match and their source functions are alpha-equal.
* \param other The other value.
* \return The result of equality check.
*/
inline bool Equal(const CCacheKeyNode* other) const;
/*!
* \brief create a cache key.
* \param source_func The source function.
* \param target The target device.
* \return the created key.
*/
TVM_DLL static CCacheKey make(Function source_func,
Target target);
/*! \brief Type key of this node type. */
static constexpr const char* _type_key = "relay.CCacheKey";
TVM_DECLARE_NODE_TYPE_INFO(CCacheKeyNode, tvm::Node);
private:
/*!
* \brief internal cached hash value.
*
* 0 means "not computed yet"; Hash() never stores 0. Declared mutable
* so that the const Hash() can fill it in.
*/
mutable size_t hash_{0};
};
/*! \brief cache entry used in compile engine */
class CCacheKey : public NodeRef {
public:
CCacheKey() {}
explicit CCacheKey(NodePtr<Node> n) : NodeRef(n) {}
const CCacheKeyNode* operator->() const {
return static_cast<CCacheKeyNode*>(node_.get());
}
// comparator
inline bool operator==(const CCacheKey& other) const {
CHECK(defined() && other.defined());
return (*this)->Equal(other.operator->());
}
using ContainerType = CCacheKeyNode;
};
/*!
 * \brief Node container for compile cache.
 *
 * Stores the cached function, the PackedFunc produced by JIT and a use
 * counter.
 */
class CCacheValueNode : public Node {
public:
/*! \brief The corresponding function */
CachedFunc cached_func;
/*! \brief Result of Packed function generated by JIT */
PackedFunc packed_func;
/*! \brief usage statistics */
int use_count{0};
/*!
 * \brief Visit cached_func and use_count.
 * \note packed_func is not passed to the visitor.
 * \param v The attribute visitor.
 */
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("cached_func", &cached_func);
v->Visit("use_count", &use_count);
}
/*! \brief Type key of this node type. */
static constexpr const char* _type_key = "relay.CCacheValue";
TVM_DECLARE_NODE_TYPE_INFO(CCacheValueNode, tvm::Node);
};
/*!
 * \brief Cache entry used in compile engine: a reference to a
 *  CCacheValueNode with both mutable and read-only access.
 */
class CCacheValue : public NodeRef {
public:
CCacheValue() {}
explicit CCacheValue(NodePtr<Node> n) : NodeRef(n) {}
/*! \return Mutable pointer to the underlying value node. */
CCacheValueNode* operator->() {
return static_cast<CCacheValueNode*>(node_.get());
}
/*! \return Read-only pointer to the underlying value node. */
const CCacheValueNode* operator->() const {
return static_cast<const CCacheValueNode*>(node_.get());
}
using ContainerType = CCacheValueNode;
};
/*!
* \brief Backend compilation engine for
* low level code generation.
*
* Abstract interface: Lower, JIT and Clear are pure virtual and must be
* provided by a concrete engine.
*/
class CompileEngineNode : public Node {
public:
/*!
* \brief Get lowered result.
* \param key The key to the cached function.
* \return The result.
*/
virtual CachedFunc Lower(const CCacheKey& key) = 0;
/*!
* \brief Just in time compile to get a PackedFunc.
* \param key The key to the cached function.
* \return The result.
*/
virtual PackedFunc JIT(const CCacheKey& key) = 0;
/*! \brief clear the cache. */
virtual void Clear() = 0;
/*! \brief No attributes are exposed to the visitor. */
void VisitAttrs(AttrVisitor*) final {}
/*! \brief Type key of this node type. */
static constexpr const char* _type_key = "relay.CompileEngine";
TVM_DECLARE_NODE_TYPE_INFO(CompileEngineNode, Node);
};
/*!
 * \brief Reference to a CompileEngineNode.
 *
 * \note Only a non-const operator-> is provided, while Global() returns a
 *  const reference.
 */
class CompileEngine : public NodeRef {
public:
CompileEngine() {}
explicit CompileEngine(NodePtr<Node> n) : NodeRef(n) {}
/*! \return Mutable pointer to the underlying engine node. */
CompileEngineNode* operator->() {
return static_cast<CompileEngineNode*>(node_.get());
}
using ContainerType = CompileEngineNode;
/*! \brief The global compile engine. */
TVM_DLL static const CompileEngine& Global();
};
// implementations
/*!
 * \brief Compute the hash of the key, caching it after the first call.
 *
 * The hash combines the structural hash of source_func with the hash of the
 * target string. The value is assembled in locals and written to hash_
 * exactly once, so a failure part-way through never leaves a partial
 * (function-only) hash behind in the cache.
 *
 * \return The hash value; never 0, because 0 marks "not computed yet".
 */
inline size_t CCacheKeyNode::Hash() const {
  // A non-zero hash_ means the value has already been computed.
  if (hash_ != 0) return hash_;
  // Do the structural hash of the function, then mix in the target string.
  const size_t func_hash = StructuralHash()(this->source_func);
  const size_t target_hash = std::hash<std::string>()(target->str());
  const size_t combined = dmlc::HashCombine(func_hash, target_hash);
  // Avoid 0: it is reserved as the "not computed" marker.
  hash_ = (combined == 0) ? 1 : combined;
  return hash_;
}
/*!
 * \brief Content equality between two cache keys.
 * \param other The key node to compare against.
 * \return true when the hashes match, the target strings match and the
 *  source functions are alpha-equal.
 */
inline bool CCacheKeyNode::Equal(const CCacheKeyNode* other) const {
  // Cheap rejection first: keys with different hashes cannot be equal.
  if (Hash() != other->Hash()) return false;
  // Then the target strings, and only last the alpha-equality check.
  if (!(this->target->str() == other->target->str())) return false;
  return AlphaEqual(this->source_func, other->source_func);
}
} // namespace relay
} // namespace tvm
namespace std {
/*!
 * \brief std::hash specialization for CCacheKey.
 *
 * Forwards to CCacheKeyNode::Hash(); the key must be defined.
 */
template<>
struct hash<::tvm::relay::CCacheKey> {
/*!
 * \param key The cache key to hash (checked to be defined).
 * \return The value of key->Hash().
 */
size_t operator()(const ::tvm::relay::CCacheKey& key) const {
CHECK(key.defined());
return key->Hash();
}
};
} // namespace std
#endif // TVM_RELAY_BACKEND_COMPILE_ENGINE_H_
...@@ -34,10 +34,12 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) ...@@ -34,10 +34,12 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
TensorType ConstantNode::tensor_type() const { TensorType ConstantNode::tensor_type() const {
auto dtype = TVMType2Type(data->dtype); auto dtype = TVMType2Type(data->dtype);
Array<tvm::Expr> shape; Array<tvm::Expr> shape;
for (int i = 0; i < data->ndim; i++) { for (int i = 0; i < data->ndim; i++) {
shape.push_back(tvm::ir::IntImm::make(HalideIR::Int(64), data->shape[i])); CHECK_LE(data->shape[i], std::numeric_limits<int32_t>::max());
CHECK_GE(data->shape[i], std::numeric_limits<int32_t>::min());
shape.push_back(
tvm::ir::IntImm::make(Int(32), data->shape[i]));
} }
return TensorTypeNode::make(shape, dtype); return TensorTypeNode::make(shape, dtype);
......
...@@ -67,13 +67,15 @@ const GenericOpMap& Op::GetGenericAttr(const std::string& key) { ...@@ -67,13 +67,15 @@ const GenericOpMap& Op::GetGenericAttr(const std::string& key) {
return *it->second.get(); return *it->second.get();
} }
void OpRegistry::UpdateAttr(const std::string& key, TVMRetValue value, void OpRegistry::UpdateAttr(const std::string& key,
TVMRetValue value,
int plevel) { int plevel) {
OpManager* mgr = OpManager::Global(); OpManager* mgr = OpManager::Global();
std::lock_guard<std::mutex> lock(mgr->mutex); std::lock_guard<std::mutex> lock(mgr->mutex);
std::unique_ptr<GenericOpMap>& op_map = mgr->attr[key]; std::unique_ptr<GenericOpMap>& op_map = mgr->attr[key];
if (op_map == nullptr) { if (op_map == nullptr) {
op_map.reset(new GenericOpMap()); op_map.reset(new GenericOpMap());
op_map->attr_name_ = key;
} }
uint32_t index = op_->index_; uint32_t index = op_->index_;
if (op_map->data_.size() <= index) { if (op_map->data_.size() <= index) {
...@@ -112,31 +114,31 @@ TVM_REGISTER_API("relay.op._OpGetAttr") ...@@ -112,31 +114,31 @@ TVM_REGISTER_API("relay.op._OpGetAttr")
}); });
TVM_REGISTER_API("relay.op._Register") TVM_REGISTER_API("relay.op._Register")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body([](TVMArgs args, TVMRetValue* rv) {
std::string op_name = args[0]; std::string op_name = args[0];
std::string attr_key = args[1]; std::string attr_key = args[1];
runtime::TVMArgValue value = args[2]; runtime::TVMArgValue value = args[2];
int plevel = args[3]; int plevel = args[3];
auto& reg = auto& reg =
OpRegistry::Registry()->__REGISTER_OR_GET__(op_name).set_name(); OpRegistry::Registry()->__REGISTER_OR_GET__(op_name).set_name();
// enable resgiteration and override of certain properties // enable resgiteration and override of certain properties
if (attr_key == "num_inputs" && plevel > 128) { if (attr_key == "num_inputs" && plevel > 128) {
reg.set_num_inputs(value); reg.set_num_inputs(value);
} else if (attr_key == "attrs_type_key" && plevel > 128) { } else if (attr_key == "attrs_type_key" && plevel > 128) {
reg.set_attrs_type_key(value); reg.set_attrs_type_key(value);
} else {
// normal attr table override.
if (args[2].type_code() == kFuncHandle) {
// do an eager copy of the PackedFunc
PackedFunc f = args[2];
// If we get a function from frontend, avoid deleting it.
OpManager::Global()->frontend_funcs.push_back(new PackedFunc(f));
reg.set_attr(attr_key, f, plevel);
} else { } else {
// normal attr table override. reg.set_attr(attr_key, args[2], plevel);
if (args[2].type_code() == kFuncHandle) {
// do an eager copy of the PackedFunc
PackedFunc f = args[2];
// If we get a function from frontend, avoid deleting it.
OpManager::Global()->frontend_funcs.push_back(new PackedFunc(f));
reg.set_attr(attr_key, f, plevel);
} else {
reg.set_attr(attr_key, args[2], plevel);
}
} }
}); }
});
NodePtr<Node> CreateOp(const std::string& name) { NodePtr<Node> CreateOp(const std::string& name) {
auto op = Op::Get(name); auto op = Op::Get(name);
......
...@@ -271,7 +271,7 @@ class TextPrinter : ...@@ -271,7 +271,7 @@ class TextPrinter :
TextValue VisitExpr_(const FunctionNode* op) final { TextValue VisitExpr_(const FunctionNode* op) final {
TextValue id = AllocTempVar(); TextValue id = AllocTempVar();
std::ostringstream os; std::ostringstream os;
os << id << " = function"; os << id << " = fn";
this->PrintFuncInternal(os.str(), GetRef<Function>(op)); this->PrintFuncInternal(os.str(), GetRef<Function>(op));
this->PrintEndInst("\n"); this->PrintEndInst("\n");
return id; return id;
...@@ -516,11 +516,14 @@ class TextPrinter : ...@@ -516,11 +516,14 @@ class TextPrinter :
stream_ << ",\n"; stream_ << ",\n";
} }
} }
stream_ << ") "; stream_ << ')';
if (fn->ret_type.defined()) { if (fn->ret_type.defined()) {
stream_ << " -> "; stream_ << '\n';
this->PrintIndent(decl_indent);
stream_ << "-> ";
this->PrintType(fn->ret_type, stream_); this->PrintType(fn->ret_type, stream_);
} }
stream_ << ' ';
this->PrintScope(fn->body); this->PrintScope(fn->body);
} }
/*! /*!
......
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
#include <tvm/relay/expr.h> #include <tvm/relay/expr.h>
#include <tvm/relay/op.h> #include <tvm/relay/op.h>
#include <tvm/relay/op_attr_types.h>
#include <vector> #include <vector>
namespace tvm { namespace tvm {
...@@ -44,7 +45,8 @@ std::vector<T> AsVector(const Array<T> &array) { ...@@ -44,7 +45,8 @@ std::vector<T> AsVector(const Array<T> &array) {
}); \ }); \
RELAY_REGISTER_OP(OpName) \ RELAY_REGISTER_OP(OpName) \
.set_num_inputs(1) \ .set_num_inputs(1) \
.add_argument("data", "Tensor", "The input tensor.") .add_argument("data", "Tensor", "The input tensor.") \
.set_attr<TOpPattern>("TOpPattern", kElemWise)
/*! Quick helper macro /*! Quick helper macro
* - Expose a positional make function to construct the node. * - Expose a positional make function to construct the node.
...@@ -68,7 +70,8 @@ std::vector<T> AsVector(const Array<T> &array) { ...@@ -68,7 +70,8 @@ std::vector<T> AsVector(const Array<T> &array) {
.set_num_inputs(2) \ .set_num_inputs(2) \
.add_argument("lhs", "Tensor", "The left hand side tensor.") \ .add_argument("lhs", "Tensor", "The left hand side tensor.") \
.add_argument("rhs", "Tensor", "The right hand side tensor.") \ .add_argument("rhs", "Tensor", "The right hand side tensor.") \
.add_type_rel("Broadcast", BroadcastRel) .add_type_rel("Broadcast", BroadcastRel) \
.set_attr<TOpPattern>("TOpPattern", kBroadcast)
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
......
...@@ -46,7 +46,7 @@ RELAY_REGISTER_BINARY_OP("relay.op._make.", "multiply") ...@@ -46,7 +46,7 @@ RELAY_REGISTER_BINARY_OP("relay.op._make.", "multiply")
.describe("Elementwise multiply with broadcasting") .describe("Elementwise multiply with broadcasting")
.set_support_level(1); .set_support_level(1);
RELAY_REGISTER_BINARY_OP("relay.op._make.", "pow") RELAY_REGISTER_BINARY_OP("relay.op._make.", "power")
.describe("Elementwise power with broadcasting") .describe("Elementwise power with broadcasting")
.set_support_level(4); .set_support_level(4);
...@@ -65,7 +65,8 @@ RELAY_REGISTER_BINARY_OP("relay.op._make.", "mod") ...@@ -65,7 +65,8 @@ RELAY_REGISTER_BINARY_OP("relay.op._make.", "mod")
.set_num_inputs(2) \ .set_num_inputs(2) \
.add_argument("lhs", "Tensor", "The left hand side tensor.") \ .add_argument("lhs", "Tensor", "The left hand side tensor.") \
.add_argument("rhs", "Tensor", "The right hand side tensor.") \ .add_argument("rhs", "Tensor", "The right hand side tensor.") \
.add_type_rel("BroadcastComp", BroadcastCompRel) .add_type_rel("BroadcastComp", BroadcastCompRel) \
.set_attr<TOpPattern>("TOpPattern", kBroadcast)
RELAY_REGISTER_CMP_OP("equal") RELAY_REGISTER_CMP_OP("equal")
.describe("Elementwise equal compare with broadcasting") .describe("Elementwise equal compare with broadcasting")
......
...@@ -3,32 +3,32 @@ ...@@ -3,32 +3,32 @@
* *
* \file src/tvm/relay/pass/fuse_ops.cc * \file src/tvm/relay/pass/fuse_ops.cc
* *
* \brief Fuse Relay eligble sequences of Relay operators into a single one. * \brief This is a backend-aware optimization pass.
* * Fuse necessary ops into a single one.
*/ */
#include <tvm/ir_operator.h>
#include <tvm/relay/pass.h> #include <tvm/relay/pass.h>
#include <tvm/runtime/module.h>
#include <tvm/lowered_func.h>
#include <tvm/operation.h>
#include <tvm/relay/expr_functor.h> #include <tvm/relay/expr_functor.h>
#include <tvm/relay/logging.h>
#include "../ir/type_functor.h"
namespace tvm { namespace tvm {
namespace relay { namespace relay {
using namespace runtime; // Simple fuser that only makes each operator function as primitive.
class SimpleFuser : public ExprMutator {
struct AbstractFusableOps : ExprMutator { public:
Module mod; // Skip primitive function.
Array<GlobalVar> fusable_funcs; Expr VisitExpr_(const FunctionNode* fn_node) {
int counter = 0; NodeRef res = FunctionGetAttr(GetRef<Function>(fn_node), "Primitive");
size_t expr_hash; const ir::IntImm* pval = res.as<ir::IntImm>();
if (pval && pval->value != 0) {
AbstractFusableOps(Module mod, size_t expr_hash) : mod(mod), expr_hash(expr_hash) {} return GetRef<Expr>(fn_node);
} else {
return ExprMutator::VisitExpr_(fn_node);
}
}
Expr VisitExpr_(const CallNode* call) { Expr VisitExpr_(const CallNode* call) {
if (auto op_node = call->op.as<OpNode>()) { if (call->op.as<OpNode>()) {
// Placeholder fusion algorithm which abstracts // Placeholder fusion algorithm which abstracts
// single definitions into functions only. // single definitions into functions only.
Array<Var> params; Array<Var> params;
...@@ -37,50 +37,37 @@ struct AbstractFusableOps : ExprMutator { ...@@ -37,50 +37,37 @@ struct AbstractFusableOps : ExprMutator {
int param_number = 0; int param_number = 0;
for (auto arg : call->args) { for (auto arg : call->args) {
auto name = std::string("p") + std::to_string(param_number++); std::ostringstream os;
os << "p" << param_number++;
auto type = arg->checked_type(); auto type = arg->checked_type();
auto var = VarNode::make(name, type); auto var = VarNode::make(os.str(), type);
params.push_back(var); params.push_back(var);
inner_args.push_back(var); inner_args.push_back(var);
args.push_back(VisitExpr(arg)); args.push_back(this->Mutate(arg));
} }
auto body = CallNode::make(call->op, inner_args, call->attrs); auto body = CallNode::make(call->op, inner_args, call->attrs);
auto func = FunctionNode::make(params, body, call->checked_type(), {}); auto func = FunctionNode::make(
params, body, call->checked_type(), {});
func = FunctionSetAttr(func, "Primitive", tvm::Integer(1)); func = FunctionSetAttr(func, "Primitive", tvm::Integer(1));
std::string func_name = "fused_"; return CallNode::make(func, args, Attrs());
func_name += op_node->name;
func_name += "_";
func_name += std::to_string(counter++);
func_name += "_";
func_name += std::to_string(expr_hash);
auto gv = GlobalVarNode::make(func_name);
mod->Add(gv, func);
fusable_funcs.push_back(gv);
return CallNode::make(gv, args, Attrs());
} else { } else {
return ExprMutator::VisitExpr_(call); return ExprMutator::VisitExpr_(call);
} }
} }
}; };
Expr FuseOps(const Module& mod, const Expr& e) {
Expr FuseOps(const Expr& expr) {
// First we convert all chains of fusable ops into // First we convert all chains of fusable ops into
// abstracted functions which we mark as primtive // abstracted functions which we mark as primtive
// then we convert these primtive functions into // then we convert these primtive functions into
// new operators. // new operators.
auto abstract = AbstractFusableOps(mod, StructuralHash()(e)); return SimpleFuser().Mutate(expr);
auto abstracted_e = abstract.VisitExpr(e);
RELAY_LOG(INFO) << "FuseOps: before=" << e
<< "Fuse: after=" << abstracted_e;
return abstracted_e;
} }
TVM_REGISTER_API("relay._ir_pass.FuseOps") TVM_REGISTER_API("relay._ir_pass.FuseOps")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = FuseOps(args[1], args[0]); *ret = FuseOps(args[0]);
}); });
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
/*!
* Copyright (c) 2018 by Contributors
*
* \file src/tvm/relay/pass/lower_ops.cc
*
* \brief Lower a Relay program to a set of TVM operators.
*
*/
#include <tvm/lowered_func.h>
#include <tvm/operation.h>
#include <tvm/build_module.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/logging.h>
#include <tvm/relay/pass.h>
#include <tvm/runtime/module.h>
#include <tvm/relay/build_module.h>
#include "../ir/type_functor.h"
namespace tvm {
namespace relay {
using namespace runtime;
/*!
 * \brief Build a LoweredOp from a primitive function and its lowered form.
 * \param func The primitive Relay function that was lowered.
 * \param lowered_func The lowered function produced for func.
 * \return A LoweredOp reference wrapping a freshly allocated node.
 */
LoweredOp LoweredOpNode::make(Function func, LoweredFunc lowered_func) {
  auto n = make_node<LoweredOpNode>();
  n->lowered_func = lowered_func;
  n->func = func;
  return LoweredOp(n);
}
/*!
 * \brief Mutator that turns function literals into module-level definitions.
 *
 * Each FunctionNode reached by the traversal is wrapped in a new global
 * function that is added to mod, and the literal is replaced by a call to
 * that global. Global variables are followed (once each) so that the bodies
 * of the functions they name are rewritten and written back to mod.
 */
struct AbstractLocalFunctions : ExprMutator {
  /*! \brief Module that is read from and that receives the new globals. */
  Module mod;
  /*! \brief Structural hash of the expression last passed to Abstract(). */
  size_t expr_hash = 0;
  /*! \brief Running index used when naming generated globals. */
  int counter = 0;
  /*! \brief Globals whose definitions have already been rewritten. */
  std::unordered_set<GlobalVar, NodeHash, NodeEqual> visited_funcs;

  explicit AbstractLocalFunctions(Module mod) : mod(mod) {}

  /*!
   * \brief Entry point: record the hash of e, then rewrite it.
   * \param e The expression to rewrite.
   * \return The rewritten expression.
   */
  Expr Abstract(const Expr& e) {
    expr_hash = StructuralHash()(e);
    return VisitExpr(e);
  }

  /*! \brief Rewrite the definition behind a global the first time it is seen. */
  Expr VisitExpr_(const GlobalVarNode* gvar_node) final {
    GlobalVar global = GetRef<GlobalVar>(gvar_node);
    if (visited_funcs.count(global) != 0) {
      return global;
    }
    Function definition = mod->Lookup(global);
    // Mark the global before descending so a reference to it from inside
    // its own body is not expanded again.
    visited_funcs.insert(global);
    Expr rewritten_body = VisitExpr(definition->body);
    mod->Update(global,
                FunctionNode::make(definition->params,
                                   rewritten_body,
                                   definition->ret_type,
                                   definition->type_params,
                                   definition->attrs));
    return global;
  }

  /*!
   * \brief Lift a function literal into a new global and call it.
   *
   * The new global takes one parameter per free variable of the literal and
   * has the (unmodified) literal as its body; the returned call passes the
   * free variables as arguments.
   *
   * NOTE(review): the fresh "free_var" parameters are not substituted for
   * the free variables inside the literal, and the literal's own body is
   * not visited here -- confirm both are intended.
   */
  Expr VisitExpr_(const FunctionNode* func_node) final {
    Function local_fn = GetRef<Function>(func_node);
    auto captured = FreeVars(local_fn);

    Array<Var> lifted_params;
    Array<Expr> call_args;
    for (auto fv : captured) {
      lifted_params.push_back(VarNode::make("free_var", fv->checked_type()));
      call_args.push_back(fv);
    }

    // Name layout: fixed prefix, then the running counter, then the hash.
    const std::string global_name =
        "abstracted_func_" + std::to_string(counter++) + std::to_string(expr_hash);
    auto lifted_gv = GlobalVarNode::make(global_name);
    mod->Add(lifted_gv,
             FunctionNode::make(lifted_params, local_fn, Type(), {}, {}));
    return CallNode::make(lifted_gv, call_args, {});
  }
};
struct LiveFunctions : ExprVisitor {
Module mod;
explicit LiveFunctions(Module mod) : mod(mod), global_funcs() {}
std::unordered_set<GlobalVar, NodeHash, NodeEqual> visited_funcs;
std::unordered_set<GlobalVar, NodeHash, NodeEqual> global_funcs;
void Live(const Expr& e) {
CHECK(!e.as<FunctionNode>())
<< "functions should of been transformed away by previous pass";
VisitExpr(e);
}
void VisitExpr_(const FunctionNode* func_node) {
LOG(FATAL) << "functions should of been transformed away by previous pass";
}
void VisitExpr_(const GlobalVarNode* var_node) final {
GlobalVar var = GetRef<GlobalVar>(var_node);
auto it = visited_funcs.find(var);
if (it == visited_funcs.end()) {
auto func = mod->Lookup(var);
visited_funcs.insert(var);
// The last pass has trasnformed functions of the form:
//
// let x = fn (p_1, ..., p_n) { ... };
// ...
//
// into, a top-level declaration:
//
// def abs_f(fv_1, ..., fv_n) {
// return (fn (p_1...,p_N) { ... };)
// }
//
// and:
//
// let x = abs_f(fv_1, ... fv_n);
//
// The only other case we can handle is
//
// fn foo(...) { body }
//
// We just search through the body in this case.
if (auto inner_func = func->body.as<FunctionNode>()) {
return VisitExpr(inner_func->body);
} else {
return VisitExpr(func->body);
}
}
}
void VisitExpr_(const CallNode* call) final {
RELAY_LOG(INFO) << "LiveOps: CallNode=" << GetRef<Call>(call);
if (auto gv_node = call->op.as<GlobalVarNode>()) {
GlobalVar gvar = GetRef<GlobalVar>(gv_node);
Function func = mod->Lookup(gvar);
auto attr = FunctionGetAttr(func, "Primitive");
if (attr.defined() && Downcast<Integer>(attr)->value == 1) {
global_funcs.insert(gvar);
} else {
VisitExpr(gvar);
}
// Finally we need to ensure to visit all the args no matter what.
for (auto arg : call->args) {
VisitExpr(arg);
}
} else {
return ExprVisitor::VisitExpr_(call);
}
}
};
// Type of the per-operator "FTVMCompute" attribute as read by LowerOps:
// called with the call attributes, the input tensors, the output type and
// the target, and returns the output tensors.
using FCompute = TypedPackedFunc<Array<Tensor>(
const Attrs&, const Array<Tensor>&, Type, tvm::Target)>;
// Type of the per-operator "FTVMSchedule" attribute as read by LowerOps:
// called with the output tensors and the target, and returns a Schedule.
using FSchedule = TypedPackedFunc<Schedule(const Array<Tensor>&, tvm::Target)>;
/*!
 * \brief Return the set of operators in their TVM format.
 *
 * Function literals in e are first lifted into mod by
 * AbstractLocalFunctions, then LiveFunctions collects the global functions
 * marked "Primitive". The body of each collected function must be a single
 * call to a primitive op. That op's "FTVMCompute" and "FTVMSchedule"
 * attributes are applied to placeholder inputs, and the result is lowered
 * through the packed function registered as "relay.op.compiler._lower".
 *
 * \param mod The module; lifted functions are added to it and each lowered
 *  function is re-added with a "LoweredFunc" attribute.
 * \param e The expression with operations to be lowered.
 * \param target The target string, passed to Target::create.
 *
 * \return One LoweredOp per collected primitive function.
 */
Array<LoweredOp> LowerOps(const Module& mod, const Expr& e,
                          const std::string& target) {
  RELAY_LOG(INFO) << "LowerOps: e=" << e;

  // The lowering itself is delegated to a globally registered packed function.
  auto flower_ptr = Registry::Get("relay.op.compiler._lower");
  CHECK(flower_ptr);
  PackedFunc flower = *flower_ptr;

  auto abstracted_e = AbstractLocalFunctions(mod).Abstract(e);
  auto live_funcs = LiveFunctions(mod);
  live_funcs.VisitExpr(abstracted_e);

  auto schedule_reg = Op::GetAttr<FSchedule>("FTVMSchedule");
  auto compute_reg = Op::GetAttr<FCompute>("FTVMCompute");

  Array<LoweredOp> lowered_funcs;

  for (auto func_name : live_funcs.global_funcs) {
    auto func = mod->Lookup(func_name);
    auto call = Downcast<Call>(func->body);
    auto op_node = call->op.as<OpNode>();
    CHECK(op_node) << "violated invariant that primtive calls contain a single op call";
    auto op = GetRef<Op>(op_node);
    RELAY_LOG(INFO) << "LowerOps: Lowering " << op->name;

    // Separate the operator name from the explanation so the message
    // stays readable.
    CHECK(IsPrimitiveOp(op)) << "failed to lower " << op->name
                             << ": can only lower primitive operations";

    // One placeholder tensor per entry of the call's type_args; each entry
    // must be a TensorType for the Downcast to succeed.
    Array<Tensor> inputs;
    std::string input_name = "in";
    int i = 0;
    for (auto type_arg : call->type_args) {
      auto tt = Downcast<TensorType>(type_arg);
      inputs.push_back(PlaceholderOpNode::make(input_name + std::to_string(i),
                                               tt->shape, tt->dtype)
                           .output(0));
      i++;
    }

    auto output_tt = call->checked_type();
    auto target_node = Target::create(target);

    Array<Tensor> outputs =
        compute_reg[op](call->attrs, inputs, output_tt, target_node);
    auto schedule = schedule_reg[op](outputs, target_node);

    // The lowered function is named after the op plus the structural hash
    // of the primitive function.
    size_t hash = StructuralHash()(func);
    LoweredFunc lf =
        flower(op->name + std::to_string(hash), schedule, inputs, outputs);

    // Attach the result to the function and overwrite its module entry.
    func = FunctionSetAttr(func, "LoweredFunc", lf);
    mod->Add(func_name, func, true);
    lowered_funcs.push_back(LoweredOpNode::make(func, lf));
  }

  return lowered_funcs;
}
// Expose LowerOps under the global name "relay._ir_pass.LowerOps".
// The three packed arguments are forwarded positionally as
// (module, expression, target).
TVM_REGISTER_API("relay._ir_pass.LowerOps")
.set_body([](TVMArgs packed_args, TVMRetValue* rv) {
    *rv = LowerOps(packed_args[0], packed_args[1], packed_args[2]);
  });
} // namespace relay
} // namespace tvm
...@@ -22,27 +22,6 @@ ...@@ -22,27 +22,6 @@
namespace tvm { namespace tvm {
namespace runtime { namespace runtime {
/*!
 * \brief Translate a device type code into its name string.
 * \param type The device type.
 * \return The name for a recognised type. An unrecognised type is reported
 *  through LOG(FATAL), with "Unknown" as the fallback value.
 */
inline std::string DeviceName(int type) {
  const char* name = "Unknown";
  switch (type) {
    case kDLCPU:     name = "cpu";     break;
    case kDLGPU:     name = "gpu";     break;
    case kDLOpenCL:  name = "opencl";  break;
    case kDLSDAccel: name = "sdaccel"; break;
    case kDLAOCL:    name = "aocl";    break;
    case kDLVulkan:  name = "vulkan";  break;
    case kDLMetal:   name = "metal";   break;
    case kDLVPI:     name = "vpi";     break;
    case kDLROCM:    name = "rocm";    break;
    case kOpenGL:    name = "opengl";  break;
    case kDLExtDev:  name = "ext_dev"; break;
    default:
      LOG(FATAL) << "unknown type =" << type;
      break;
  }
  return name;
}
class DeviceAPIManager { class DeviceAPIManager {
public: public:
static const int kMaxDeviceAPI = 32; static const int kMaxDeviceAPI = 32;
......
...@@ -187,8 +187,8 @@ void GraphRuntime::SetupStorage() { ...@@ -187,8 +187,8 @@ void GraphRuntime::SetupStorage() {
CHECK_GE(storage_id, 0) << "Do not support runtime shape op"; CHECK_GE(storage_id, 0) << "Do not support runtime shape op";
DLDataType t = vtype[i]; DLDataType t = vtype[i];
size_t bits = t.bits * t.lanes; size_t bits = t.bits * t.lanes;
CHECK_EQ(bits % 8U, 0U); CHECK(bits % 8U == 0U || bits ==1U);
size_t bytes = (bits / 8U) * size; size_t bytes = ((bits + 7U) / 8U) * size;
uint32_t sid = static_cast<uint32_t>(storage_id); uint32_t sid = static_cast<uint32_t>(storage_id);
if (sid >= pool_entry.size()) { if (sid >= pool_entry.size()) {
......
import tvm
import tvm.testing
import numpy as np
from tvm import relay
def test_compile_engine():
    """Exercise the compile engine: lowering results and the JIT path.

    Lowering two separately built but identical functions must return the
    same object, while a different input shape (or a different target, when
    CUDA is present) must return a different one. The JIT-compiled function
    is then run and compared against numpy.
    """
    engine = relay.backend.compile_engine.get()

    def get_func(shape):
        # Type-inferred function computing (x + x) + x for the given shape.
        data = relay.var("x", shape=shape)
        doubled = relay.add(data, data)
        tripled = relay.add(doubled, data)
        return relay.ir_pass.infer_type(relay.Function([data], tripled))

    z1 = engine.lower(get_func((10,)), "llvm")
    z2 = engine.lower(get_func((10,)), "llvm")
    z3 = engine.lower(get_func(()), "llvm")
    assert z1.same_as(z2)
    assert not z3.same_as(z1)
    if tvm.context("cuda").exist:
        z4 = engine.lower(get_func(()), "cuda")
        assert not z3.same_as(z4)

    # Test JIT target
    for target in ["llvm"]:
        ctx = tvm.context(target)
        if not ctx.exist:
            continue
        jitted = engine.jit(get_func((10,)), target)
        inp = tvm.nd.array(np.ones(10).astype("float32"), ctx=ctx)
        out = tvm.nd.empty((10,), ctx=ctx)
        jitted(inp, out)
        tvm.testing.assert_allclose(out.asnumpy(), inp.asnumpy() * 3)
    engine.dump()


if __name__ == "__main__":
    test_compile_engine()
import numpy as np import numpy as np
from tvm import relay from tvm import relay
from tvm.relay import create_executor
from tvm.relay.ir_pass import infer_type from tvm.relay.ir_pass import infer_type
from tvm.relay.interpreter import Interpreter
from tvm.relay.scope_builder import ScopeBuilder from tvm.relay.scope_builder import ScopeBuilder
from tvm.relay.op import add from tvm.relay.op import add
from tvm.relay.module import Module from tvm.relay.module import Module
...@@ -25,8 +23,8 @@ def check_rts(expr, args, expected_result, mod=None): ...@@ -25,8 +23,8 @@ def check_rts(expr, args, expected_result, mod=None):
expected_result: expected_result:
The expected result of running the expression. The expected result of running the expression.
""" """
intrp = create_executor('graph', mod=mod) intrp = relay.create_executor('debug', mod=mod)
graph = create_executor('graph', mod=mod) graph = relay.create_executor('graph', mod=mod)
eval_result = intrp.evaluate(expr)(*args) eval_result = intrp.evaluate(expr)(*args)
rts_result = graph.evaluate(expr)(*args) rts_result = graph.evaluate(expr)(*args)
np.testing.assert_allclose(eval_result.asnumpy(), rts_result.asnumpy()) np.testing.assert_allclose(eval_result.asnumpy(), rts_result.asnumpy())
......
import numpy as np import numpy as np
import tvm import tvm
import tvm.testing
from tvm import relay from tvm import relay
from tvm.relay.interpreter import Value, TupleValue from tvm.relay.backend.interpreter import Value, TupleValue
from tvm.relay import op
from tvm.relay.scope_builder import ScopeBuilder from tvm.relay.scope_builder import ScopeBuilder
from tvm.relay import testing, create_executor from tvm.relay import testing, create_executor
def check_eval(expr, args, expected_result, mod=None, rtol=1e-07): def check_eval(expr, args, expected_result, mod=None, rtol=1e-07):
intrp = create_executor(mod=mod) # TODO(tqchen) add more types once the schedule register is fixed.
result = intrp.evaluate(expr)(*args) for target in ["llvm"]:
np.testing.assert_allclose(result.asnumpy(), expected_result, rtol=rtol) ctx = tvm.context(target, 0)
if not ctx.exist:
return
intrp = create_executor(mod=mod, ctx=ctx, target=target)
result = intrp.evaluate(expr)(*args)
# use tvm.testing which also set atol
tvm.testing.assert_allclose(
result.asnumpy(), expected_result, rtol=rtol)
def test_from_scalar(): def test_from_scalar():
...@@ -34,7 +41,7 @@ def test_id(): ...@@ -34,7 +41,7 @@ def test_id():
def test_add_const(): def test_add_const():
two = op.add(relay.const(1), relay.const(1)) two = relay.add(relay.const(1), relay.const(1))
func = relay.Function([], two) func = relay.Function([], two)
check_eval(func, [], 2) check_eval(func, [], 2)
...@@ -42,7 +49,7 @@ def test_add_const(): ...@@ -42,7 +49,7 @@ def test_add_const():
def test_mul_param(): def test_mul_param():
x = relay.var('x', shape=(10, 10)) x = relay.var('x', shape=(10, 10))
y = relay.var('y', shape=(1, 10)) y = relay.var('y', shape=(1, 10))
func = relay.Function([x, y], op.multiply(x, y)) func = relay.Function([x, y], relay.multiply(x, y))
x_data = np.random.rand(10, 10).astype('float32') x_data = np.random.rand(10, 10).astype('float32')
y_data = np.random.rand(1, 10).astype('float32') y_data = np.random.rand(1, 10).astype('float32')
check_eval(func, [x_data, y_data], x_data * y_data) check_eval(func, [x_data, y_data], x_data * y_data)
...@@ -53,7 +60,7 @@ def test_mul_param(): ...@@ -53,7 +60,7 @@ def test_mul_param():
# def test_dense(): # def test_dense():
# x = relay.var('x', shape=(10, 10)) # x = relay.var('x', shape=(10, 10))
# w = relay.var('w', shape=(10, 10)) # w = relay.var('w', shape=(10, 10))
# y = op.nn.dense(x, w) # y = relay.nn.dense(x, w)
# func = relay.Function([x, w], y) # func = relay.Function([x, w], y)
# x_data = np.random.rand(10, 10).astype('float32') # x_data = np.random.rand(10, 10).astype('float32')
# w_data = np.random.rand(10, 10).astype('float32') # w_data = np.random.rand(10, 10).astype('float32')
...@@ -63,7 +70,7 @@ def test_mul_param(): ...@@ -63,7 +70,7 @@ def test_mul_param():
# x = relay.var('x', shape=(10, 10)) # x = relay.var('x', shape=(10, 10))
# w = relay.var('w', shape=(10, 10)) # w = relay.var('w', shape=(10, 10))
# b = relay.var('b', shape=(10,)) # b = relay.var('b', shape=(10,))
# y = op.add(op.nn.dense(x, w), b) # y = relay.add(relay.nn.dense(x, w), b)
# func = relay.Function([x, w, b], y) # func = relay.Function([x, w, b], y)
# x_data = np.random.rand(10, 10).astype('float32') # x_data = np.random.rand(10, 10).astype('float32')
# w_data = np.random.rand(10, 10).astype('float32') # w_data = np.random.rand(10, 10).astype('float32')
...@@ -73,46 +80,49 @@ def test_mul_param(): ...@@ -73,46 +80,49 @@ def test_mul_param():
def test_equal(): def test_equal():
i = relay.var('i', shape=[], dtype='int32') i = relay.var('i', shape=[], dtype='int32')
j = relay.var('i', shape=[], dtype='int32') j = relay.var('i', shape=[], dtype='int32')
z = op.equal(i, j) z = relay.equal(i, j)
func = relay.Function([i, j], z, ret_type=relay.TensorType([], 'bool')) func = relay.Function([i, j], z, ret_type=relay.TensorType([], 'bool'))
i_data = relay.const(0) i_data = relay.const(0)
j_data = relay.const(0) j_data = relay.const(0)
check_eval(func, [i_data, j_data], True) check_eval(func, [i_data, j_data], True)
def test_subtract(): def test_subtract():
i = relay.var('i', shape=[], dtype='int32') i = relay.var('i', shape=[], dtype='int32')
sub = op.subtract(i, relay.const(1, dtype='int32')) sub = relay.subtract(i, relay.const(1, dtype='int32'))
func = relay.Function([i], sub, ret_type=relay.TensorType([], 'int32')) func = relay.Function([i], sub, ret_type=relay.TensorType([], 'int32'))
i_data = np.array(1, dtype='int32') i_data = np.array(1, dtype='int32')
check_eval(func, [i_data], 0) check_eval(func, [i_data], 0)
def test_simple_loop(): def test_simple_loop():
mod = relay.module.Module({}) mod = relay.module.Module({})
sum_up = relay.GlobalVar('sum_up') sum_up = relay.GlobalVar('sum_up')
i = relay.var('i', shape=[], dtype='int32') i = relay.var('i', shape=[], dtype='int32')
sb = ScopeBuilder() sb = ScopeBuilder()
with sb.if_scope(op.equal(i, relay.const(0, dtype='int32'))): with sb.if_scope(relay.equal(i, relay.const(0, dtype='int32'))):
sb.ret(i) sb.ret(i)
with sb.else_scope(): with sb.else_scope():
one_less = op.subtract(i, relay.const(1, dtype='int32')) one_less = relay.subtract(i, relay.const(1, dtype='int32'))
rec_call = relay.Call(sum_up, [one_less]) rec_call = relay.Call(sum_up, [one_less])
sb.ret(op.add(rec_call, i)) sb.ret(relay.add(rec_call, i))
func = relay.Function([i], sb.get(), ret_type=relay.TensorType([], 'int32')) func = relay.Function([i], sb.get(), ret_type=relay.TensorType([], 'int32'))
mod[sum_up] = func mod[sum_up] = func
i_data = np.array(10, dtype='int32') i_data = np.array(10, dtype='int32')
check_eval(sum_up, [i_data], sum(range(1, 11)), mod=mod) check_eval(sum_up, [i_data], sum(range(1, 11)), mod=mod)
def test_loop(): def test_loop():
mod = relay.module.Module({}) mod = relay.module.Module({})
sum_up = relay.GlobalVar('sum_up') sum_up = relay.GlobalVar('sum_up')
i = relay.var('i', shape=[], dtype='int32') i = relay.var('i', shape=[], dtype='int32')
accum = relay.var('accum', shape=[], dtype='int32') accum = relay.var('accum', shape=[], dtype='int32')
sb = ScopeBuilder() sb = ScopeBuilder()
with sb.if_scope(op.equal(i, relay.const(0))): with sb.if_scope(relay.equal(i, relay.const(0))):
sb.ret(accum) sb.ret(accum)
with sb.else_scope(): with sb.else_scope():
one_less = op.subtract(i, relay.const(1)) one_less = relay.subtract(i, relay.const(1))
new_accum = op.add(accum, i) new_accum = relay.add(accum, i)
sb.ret(relay.Call(sum_up, [one_less, new_accum])) sb.ret(relay.Call(sum_up, [one_less, new_accum]))
func = relay.Function([i, accum], sb.get()) func = relay.Function([i, accum], sb.get())
mod[sum_up] = func mod[sum_up] = func
...@@ -120,19 +130,21 @@ def test_loop(): ...@@ -120,19 +130,21 @@ def test_loop():
accum_data = np.array(0, dtype='int32') accum_data = np.array(0, dtype='int32')
check_eval(sum_up, [i_data, accum_data], sum(range(1, 11)), mod=mod) check_eval(sum_up, [i_data, accum_data], sum(range(1, 11)), mod=mod)
def test_mlp():
pass def test_binds():
# net = testing.mlp.get_workload(1) x = relay.var("x")
# import pdb; pdb.set_trace() y = relay.add(x, x)
intrp = create_executor("debug")
xx = np.ones((10, 20))
res = intrp.evaluate(y, binds={x: xx}).asnumpy()
tvm.testing.assert_allclose(xx + xx, res)
if __name__ == "__main__": if __name__ == "__main__":
test_id() test_id()
test_add_const() test_add_const()
# test_dense()
# test_linear()
test_equal() test_equal()
test_subtract() test_subtract()
test_simple_loop() test_simple_loop()
test_loop() test_loop()
test_mlp() test_binds()
...@@ -2,7 +2,7 @@ import math ...@@ -2,7 +2,7 @@ import math
import tvm import tvm
import numpy as np import numpy as np
from tvm import relay from tvm import relay
from tvm.relay.interpreter import create_executor from tvm.relay.testing import ctx_list
def sigmoid(x): def sigmoid(x):
one = np.ones_like(x) one = np.ones_like(x)
...@@ -27,10 +27,15 @@ def test_unary_op(): ...@@ -27,10 +27,15 @@ def test_unary_op():
if ref is not None: if ref is not None:
data = np.random.rand(*shape).astype(dtype) data = np.random.rand(*shape).astype(dtype)
intrp = create_executor()
op_res = intrp.evaluate(y, { x: relay.const(data) })
ref_res = ref(data) ref_res = ref(data)
np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=0.01) func = relay.Function([x], y)
for target, ctx in ctx_list():
# use graph by execuor default for testing, as we need
# create function explicitly to avoid constant-folding.
intrp = relay.create_executor("graph", ctx=ctx, target=target)
op_res = intrp.evaluate(func)(data)
np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=0.01)
for opfunc, ref in [(tvm.relay.log, np.log), for opfunc, ref in [(tvm.relay.log, np.log),
(tvm.relay.exp, np.exp), (tvm.relay.exp, np.exp),
...@@ -67,14 +72,17 @@ def test_binary_op(): ...@@ -67,14 +72,17 @@ def test_binary_op():
z = opfunc(x, y) z = opfunc(x, y)
x_data = np.random.rand(5, 10, 5).astype(t1.dtype) x_data = np.random.rand(5, 10, 5).astype(t1.dtype)
y_data = np.random.rand(5, 10, 5).astype(t2.dtype) y_data = np.random.rand(5, 10, 5).astype(t2.dtype)
intrp = create_executor()
op_res = intrp.evaluate(z, { x: relay.const(x_data), y: relay.const(y_data) })
ref_res = ref(x_data, y_data) ref_res = ref(x_data, y_data)
np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=0.01) func = relay.Function([x, y], z)
for target, ctx in ctx_list():
# use graph by execuor default for testing, as we need
# create function explicitly to avoid constant-folding.
intrp = relay.create_executor("graph", ctx=ctx, target=target)
op_res = intrp.evaluate(func)(x_data, y_data)
np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=0.01)
for opfunc, ref in [(relay.add, np.add), for opfunc, ref in [(relay.add, np.add),
(relay.subtract, np.subtract), (relay.subtract, np.subtract),
(relay.mod, np.mod),
(relay.multiply, np.multiply), (relay.multiply, np.multiply),
(relay.divide, np.divide)]: (relay.divide, np.divide)]:
check_binary_op(opfunc, ref) check_binary_op(opfunc, ref)
...@@ -116,7 +124,7 @@ def test_log_softmax(): ...@@ -116,7 +124,7 @@ def test_log_softmax():
assert yy.checked_type == relay.TensorType((n, d)) assert yy.checked_type == relay.TensorType((n, d))
def test_concatenate_infer_type(): def test_concatenate():
n, t, d = tvm.var("n"), tvm.var("t"), 100 n, t, d = tvm.var("n"), tvm.var("t"), 100
x = relay.var("x", shape=(n, t, d)) x = relay.var("x", shape=(n, t, d))
y = relay.var("y", shape=(n, t, d)) y = relay.var("y", shape=(n, t, d))
...@@ -134,15 +142,23 @@ def test_concatenate_infer_type(): ...@@ -134,15 +142,23 @@ def test_concatenate_infer_type():
zz = relay.ir_pass.infer_type(z) zz = relay.ir_pass.infer_type(z)
assert zz.checked_type == relay.TensorType((n, t + t, 100)) assert zz.checked_type == relay.TensorType((n, t + t, 100))
# x = relay.var("x", shape=(10, 5)) x = relay.var("x", shape=(10, 5))
# y = relay.var("y", shape=(10, 5)) y = relay.var("y", shape=(10, 5))
# z = relay.concatenate((x, y), axis=1) z = relay.concatenate((x, y), axis=1)
# intrp = create_executor()
# x_data = np.random.rand(10, 5).astype('float32') # Check result.
# y_data = np.random.rand(10, 5).astype('float32') func = relay.Function([x, y], z)
# op_res = intrp.evaluate(z, { x: relay.const(x_data), y: relay.const(y_data) }) x_data = np.random.rand(10, 5).astype('float32')
# ref_res = np.concatenate(x_data, y_data, axis=1) y_data = np.random.rand(10, 5).astype('float32')
# np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=0.01) ref_res = np.concatenate((x_data, y_data), axis=1)
for target, ctx in ctx_list():
intrp1 = relay.create_executor("graph", ctx=ctx, target=target)
intrp2 = relay.create_executor("debug", ctx=ctx, target=target)
op_res1 = intrp1.evaluate(func)(x_data, y_data)
tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=0.01)
op_res2 = intrp2.evaluate(func)(x_data, y_data)
tvm.testing.assert_allclose(op_res2.asnumpy(), ref_res, rtol=0.01)
def test_dropout(): def test_dropout():
n, t, d = tvm.var("n"), tvm.var("t"), tvm.var("d") n, t, d = tvm.var("n"), tvm.var("t"), tvm.var("d")
...@@ -206,7 +222,7 @@ if __name__ == "__main__": ...@@ -206,7 +222,7 @@ if __name__ == "__main__":
test_unary_op() test_unary_op()
test_binary_op() test_binary_op()
test_expand_dims_infer_type() test_expand_dims_infer_type()
test_concatenate_infer_type() test_concatenate()
test_softmax() test_softmax()
test_log_softmax() test_log_softmax()
test_dropout() test_dropout()
......
import tvm import tvm
import numpy as np import numpy as np
from tvm import relay from tvm import relay
from tvm.relay import create_executor from tvm.relay.testing import ctx_list
def test_binary_op(): def test_binary_op():
...@@ -24,12 +24,15 @@ def test_binary_op(): ...@@ -24,12 +24,15 @@ def test_binary_op():
z = opfunc(x, y) z = opfunc(x, y)
x_data = np.random.rand(5, 10, 5).astype(t1.dtype) x_data = np.random.rand(5, 10, 5).astype(t1.dtype)
y_data = np.random.rand(5, 10, 5).astype(t2.dtype) y_data = np.random.rand(5, 10, 5).astype(t2.dtype)
intrp = create_executor()
op_res = intrp.evaluate(z, { x: relay.const(x_data), y: relay.const(y_data) })
ref_res = ref(x_data, y_data) ref_res = ref(x_data, y_data)
np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=0.01) func = relay.Function([x, y], z)
for opfunc, ref in [(relay.pow, np.power)]: for target, ctx in ctx_list():
intrp = relay.create_executor("graph", ctx=ctx, target=target)
op_res = intrp.evaluate(func)(x_data, y_data)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res)
for opfunc, ref in [(relay.power, np.power)]:
check_binary_op(opfunc, ref) check_binary_op(opfunc, ref)
...@@ -57,15 +60,19 @@ def test_cmp_type(): ...@@ -57,15 +60,19 @@ def test_cmp_type():
z = op(x, y) z = op(x, y)
x_data = np.random.rand(*x_shape).astype(t1.dtype) x_data = np.random.rand(*x_shape).astype(t1.dtype)
y_data = np.random.rand(*y_shape).astype(t2.dtype) y_data = np.random.rand(*y_shape).astype(t2.dtype)
intrp = create_executor()
op_res = intrp.evaluate(z, { x: relay.const(x_data), y: relay.const(y_data) })
ref_res = ref(x_data, y_data) ref_res = ref(x_data, y_data)
np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=0.01) func = relay.Function([x, y], z)
for target, ctx in ctx_list():
intrp = relay.create_executor("graph", ctx=ctx, target=target)
op_res = intrp.evaluate(func)(x_data, y_data)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res)
def test_binary_int_broadcast(): def test_binary_int_broadcast():
for op, ref in [(relay.right_shift, np.right_shift), for op, ref in [(relay.right_shift, np.right_shift),
(relay.left_shift, np.left_shift), (relay.left_shift, np.left_shift),
(relay.mod, np.mod),
(relay.maximum, np.maximum), (relay.maximum, np.maximum),
(relay.minimum, np.minimum)]: (relay.minimum, np.minimum)]:
x = relay.var("x", relay.TensorType((10, 4), "int32")) x = relay.var("x", relay.TensorType((10, 4), "int32"))
...@@ -81,10 +88,14 @@ def test_binary_int_broadcast(): ...@@ -81,10 +88,14 @@ def test_binary_int_broadcast():
t2 = relay.TensorType(y_shape, 'int32') t2 = relay.TensorType(y_shape, 'int32')
x_data = np.random.rand(*x_shape).astype(t1.dtype) x_data = np.random.rand(*x_shape).astype(t1.dtype)
y_data = np.random.rand(*y_shape).astype(t2.dtype) y_data = np.random.rand(*y_shape).astype(t2.dtype)
intrp = create_executor() func = relay.Function([x, y], z)
op_res = intrp.evaluate(z, { x: relay.const(x_data), y: relay.const(y_data) })
ref_res = ref(x_data, y_data) ref_res = ref(x_data, y_data)
np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=0.01)
for target, ctx in ctx_list():
intrp = relay.create_executor("graph", ctx=ctx, target=target)
op_res = intrp.evaluate(func)(x_data, y_data)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res)
def test_where(): def test_where():
cond = relay.var("cond", relay.TensorType((3, 4), "float32")) cond = relay.var("cond", relay.TensorType((3, 4), "float32"))
......
import tvm
from tvm import relay
def test_fuse_simple():
    """Smoke test: run operator fusion over a tiny elementwise chain.

    Builds ``exp(x + x)`` for a variable ``x`` of shape (10, 20), infers
    its type, applies ``fuse_ops`` twice in a row (the second call runs on
    the output of the first), re-infers types on the result, and renders
    it as text.  There are no assertions; the test passes as long as none
    of these steps raises.
    """
    inp = relay.var("x", shape=(10, 20))
    doubled = relay.add(inp, inp)
    expr = relay.exp(doubled)

    # fuse_ops is run on a type-annotated expression.
    fused = relay.ir_pass.infer_type(expr)

    # Two consecutive fusion passes, each fed the previous pass's output.
    for _ in range(2):
        fused = relay.ir_pass.fuse_ops(fused)

    # Type-check the fused program and make sure it can be printed.
    fused = relay.ir_pass.infer_type(fused)
    fused.astext()
# Allow running this test file directly as a script, without a test runner.
if __name__ == "__main__":
    test_fuse_simple()
...@@ -3,10 +3,9 @@ ...@@ -3,10 +3,9 @@
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
import tvm import tvm
import topi import topi
from . import tag
from . import cpp from . import cpp
@tvm.tag_scope(tag=tag.BROADCAST)
def expand_dims(a, axis, num_newaxis=1): def expand_dims(a, axis, num_newaxis=1):
"""Expand the shape of an array. """Expand the shape of an array.
...@@ -25,7 +24,6 @@ def expand_dims(a, axis, num_newaxis=1): ...@@ -25,7 +24,6 @@ def expand_dims(a, axis, num_newaxis=1):
return cpp.expand_dims(a, axis, num_newaxis) return cpp.expand_dims(a, axis, num_newaxis)
@tvm.tag_scope(tag=tag.BROADCAST)
def expand_like(a, shape_like, axis): def expand_like(a, shape_like, axis):
"""Expand an input array with the shape of second array. """Expand an input array with the shape of second array.
This operation can always be composed of unsqueezing and This operation can always be composed of unsqueezing and
...@@ -79,7 +77,6 @@ def expand_like(a, shape_like, axis): ...@@ -79,7 +77,6 @@ def expand_like(a, shape_like, axis):
return tvm.compute(shape_like.shape, _compute) return tvm.compute(shape_like.shape, _compute)
@tvm.tag_scope(tag=tag.INJECTIVE)
def transpose(a, axes=None): def transpose(a, axes=None):
"""Permute the dimensions of an array. """Permute the dimensions of an array.
...@@ -141,7 +138,6 @@ def strided_slice(a, begin, end, strides=None): ...@@ -141,7 +138,6 @@ def strided_slice(a, begin, end, strides=None):
return cpp.strided_slice(a, begin, end, strides) return cpp.strided_slice(a, begin, end, strides)
@tvm.tag_scope(tag=tag.INJECTIVE)
def reshape(a, newshape): def reshape(a, newshape):
"""Reshape the array """Reshape the array
...@@ -159,7 +155,6 @@ def reshape(a, newshape): ...@@ -159,7 +155,6 @@ def reshape(a, newshape):
return cpp.reshape(a, newshape) return cpp.reshape(a, newshape)
@tvm.tag_scope(tag=tag.INJECTIVE)
def squeeze(a, axis=None): def squeeze(a, axis=None):
"""Remove single-dimensional entries from the shape of an array. """Remove single-dimensional entries from the shape of an array.
...@@ -178,7 +173,6 @@ def squeeze(a, axis=None): ...@@ -178,7 +173,6 @@ def squeeze(a, axis=None):
return cpp.squeeze(a, axis) return cpp.squeeze(a, axis)
@tvm.tag_scope(tag=tag.INJECTIVE)
def concatenate(a_tuple, axis=0): def concatenate(a_tuple, axis=0):
"""Join a sequence of arrays along an existing axis. """Join a sequence of arrays along an existing axis.
...@@ -197,7 +191,6 @@ def concatenate(a_tuple, axis=0): ...@@ -197,7 +191,6 @@ def concatenate(a_tuple, axis=0):
return cpp.concatenate(a_tuple, axis) return cpp.concatenate(a_tuple, axis)
@tvm.tag_scope(tag=tag.INJECTIVE)
def split(ary, indices_or_sections, axis=0): def split(ary, indices_or_sections, axis=0):
"""Split an array into multiple sub-arrays. """Split an array into multiple sub-arrays.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment