Unverified Commit c91ded32 by Tianqi Chen Committed by GitHub

[RELAY][BACKEND] CompileEngine refactor. (#2059)

parent 4e77eeb2
......@@ -14,7 +14,6 @@
#include "lowered_func.h"
namespace tvm {
using namespace tvm::runtime;
/*!
* \brief Container for target device information.
......@@ -40,7 +39,7 @@ class TargetNode : public Node {
Array<Expr> libs_array;
/*! \return the full device string to pass to codegen::Build */
EXPORT std::string str() const;
TVM_DLL const std::string& str() const;
void VisitAttrs(AttrVisitor* v) final {
v->Visit("target_name", &target_name);
......@@ -54,16 +53,20 @@ class TargetNode : public Node {
}
/*! \brief Get the keys for this target as a vector of string */
EXPORT std::vector<std::string> keys() const;
TVM_DLL std::vector<std::string> keys() const;
/*! \brief Get the options for this target as a vector of string */
EXPORT std::vector<std::string> options() const;
TVM_DLL std::vector<std::string> options() const;
/*! \brief Get the keys for this target as an unordered_set of string */
EXPORT std::unordered_set<std::string> libs() const;
TVM_DLL std::unordered_set<std::string> libs() const;
static constexpr const char* _type_key = "Target";
TVM_DECLARE_NODE_TYPE_INFO(TargetNode, Node);
private:
/*! \brief Internal string repr. */
mutable std::string str_repr_;
};
class Target : public NodeRef {
......@@ -75,20 +78,20 @@ class Target : public NodeRef {
* \brief Create a Target given a string
* \param target_str the string to parse
*/
EXPORT static Target create(const std::string& target_str);
TVM_DLL static Target create(const std::string& target_str);
/*!
* \brief Push a new target context onto the thread local stack. The Target on top of
* the stack is used to determine which specialization to use when invoking a GenericFunc.
* \param target The target to set as the current context.
*/
EXPORT static void EnterTargetScope(const tvm::Target& target);
TVM_DLL static void EnterTargetScope(const tvm::Target& target);
/*!
* \brief Pop a target off the thread local context stack, restoring the previous target
* as the current context.
*/
EXPORT static void ExitTargetScope();
TVM_DLL static void ExitTargetScope();
/*!
* \brief Get the current target context from thread local storage.
......@@ -98,7 +101,7 @@ class Target : public NodeRef {
* \return The target that is the current context. The target may not be defined if
* allow_not_defined is true.
*/
EXPORT static tvm::Target current_target(bool allow_not_defined = true);
TVM_DLL static tvm::Target current_target(bool allow_not_defined = true);
inline const TargetNode* operator->() const {
return static_cast<const TargetNode*>(node_.get());
......@@ -130,39 +133,39 @@ struct TargetContext {
/*! \brief This namespace provides functions to construct Target instances */
namespace target {
/*! \return A target for LLVM */
EXPORT Target llvm(const std::vector<std::string>& options =
TVM_DLL Target llvm(const std::vector<std::string>& options =
std::vector<std::string>());
/*! \return A target for CUDA */
EXPORT Target cuda(const std::vector<std::string>& options =
TVM_DLL Target cuda(const std::vector<std::string>& options =
std::vector<std::string>());
/*! \return A target for ROCm */
EXPORT Target rocm(const std::vector<std::string>& options =
TVM_DLL Target rocm(const std::vector<std::string>& options =
std::vector<std::string>());
/*! \return A target for OpenCL */
EXPORT Target opencl(const std::vector<std::string>& options =
TVM_DLL Target opencl(const std::vector<std::string>& options =
std::vector<std::string>());
/*! \return A target for Metal */
EXPORT Target metal(const std::vector<std::string>& options =
TVM_DLL Target metal(const std::vector<std::string>& options =
std::vector<std::string>());
/*! \return A target for rasp */
EXPORT Target rasp(const std::vector<std::string>& options =
TVM_DLL Target rasp(const std::vector<std::string>& options =
std::vector<std::string>());
/*! \return A target for Mali */
EXPORT Target mali(const std::vector<std::string>& options =
TVM_DLL Target mali(const std::vector<std::string>& options =
std::vector<std::string>());
/*! \return A target for Intel Graphics */
EXPORT Target intel_graphics(const std::vector<std::string>& options =
TVM_DLL Target intel_graphics(const std::vector<std::string>& options =
std::vector<std::string>());
/*! \return A target for stackvm */
EXPORT Target stackvm(const std::vector<std::string>& options =
TVM_DLL Target stackvm(const std::vector<std::string>& options =
std::vector<std::string>());
} // namespace target
......@@ -212,7 +215,7 @@ class BuildConfigNode : public Node {
bool partition_const_loop = false;
/*! \brief Whether to dump the IR of each pass (only when building from python) */
std::vector< std::pair<int, PackedFunc> > add_lower_pass;
std::vector< std::pair<int, runtime::PackedFunc> > add_lower_pass;
/*! \brief Whether to dump the IR of each pass (only when building from python) */
bool dump_pass_ir = false;
......@@ -255,20 +258,20 @@ class BuildConfig : public ::tvm::NodeRef {
* \brief Push a new BuildConfig context onto the thread local stack.
* \param build_config The configuration to set as the current context.
*/
EXPORT static void EnterBuildConfigScope(const tvm::BuildConfig& build_config);
TVM_DLL static void EnterBuildConfigScope(const tvm::BuildConfig& build_config);
/*!
* \brief Pop a build config off the thread local context stack, restoring the previous
* configuration as the current context.
*/
EXPORT static void ExitBuildConfigScope();
TVM_DLL static void ExitBuildConfigScope();
/*!
* \brief Get the current BuildConfig context from thread local storage, or a default
* configuration if a BuildConfig scope has not been entered.
* \return The configuration that is the current context.
*/
EXPORT static tvm::BuildConfig Current();
TVM_DLL static tvm::BuildConfig Current();
using ContainerType = BuildConfigNode;
};
......@@ -297,7 +300,7 @@ struct BuildConfigContext {
* \brief Construct a BuildConfig containing a new BuildConfigNode
* \return The new BuildConfig
*/
EXPORT BuildConfig build_config();
TVM_DLL BuildConfig build_config();
/*!
* \brief Build a LoweredFunc given a schedule, args and binds
......@@ -308,7 +311,7 @@ EXPORT BuildConfig build_config();
* \param config The build configuration.
* \return The lowered function.
*/
EXPORT Array<LoweredFunc> lower(Schedule sch,
TVM_DLL Array<LoweredFunc> lower(Schedule sch,
const Array<Tensor>& args,
const std::string& name,
const std::unordered_map<Tensor, Buffer>& binds,
......@@ -322,7 +325,7 @@ EXPORT Array<LoweredFunc> lower(Schedule sch,
* \param config The build configuration.
* \return The built module.
*/
EXPORT runtime::Module build(const Array<LoweredFunc>& funcs,
TVM_DLL runtime::Module build(const Array<LoweredFunc>& funcs,
const Target& target,
const Target& target_host,
const BuildConfig& config);
......@@ -344,7 +347,7 @@ class GenericFunc : public NodeRef {
* false, an error will be logged if the call would override a previously registered function.
* \return reference to self.
*/
TVM_DLL GenericFunc& set_default(const PackedFunc value,
TVM_DLL GenericFunc& set_default(const runtime::PackedFunc value,
bool allow_override = false);
/*!
* \brief Register a specialized function
......@@ -355,7 +358,7 @@ class GenericFunc : public NodeRef {
* \return reference to self.
*/
TVM_DLL GenericFunc& register_func(const std::vector<std::string>& tags,
const PackedFunc value,
const runtime::PackedFunc value,
bool allow_override = false);
/*!
* \brief Call generic function by directly passing in unpacked format.
......@@ -372,14 +375,15 @@ class GenericFunc : public NodeRef {
* \endcode
*/
template<typename... Args>
inline TVMRetValue operator()(Args&& ...args) const;
inline runtime::TVMRetValue operator()(Args&& ...args) const;
/*!
* \brief Invoke the relevant function for the current target context, set by set_target_context.
* Arguments are passed in packed format.
* \param args The arguments to pass to the function.
* \param ret The return value
*/
TVM_DLL void CallPacked(TVMArgs args, TVMRetValue* ret) const;
TVM_DLL void CallPacked(runtime::TVMArgs args,
runtime::TVMRetValue* ret) const;
/*!
* \brief Find or register the GenericFunc instance corresponding to the give name
......@@ -412,14 +416,14 @@ class GenericFunc : public NodeRef {
};
template<typename... Args>
inline TVMRetValue GenericFunc::operator()(Args&& ...args) const {
inline runtime::TVMRetValue GenericFunc::operator()(Args&& ...args) const {
const int kNumArgs = sizeof...(Args);
const int kArraySize = kNumArgs > 0 ? kNumArgs : 1;
TVMValue values[kArraySize];
int type_codes[kArraySize];
runtime::detail::for_each(TVMArgsSetter(values, type_codes),
std::forward<Args>(args)...);
TVMRetValue rv;
runtime::TVMRetValue rv;
CallPacked(TVMArgs(values, type_codes, kNumArgs), &rv);
return rv;
}
......@@ -432,9 +436,9 @@ class GenericFuncNode : public Node {
/*! \brief name of the function */
std::string name_;
/* \brief the generic builder */
PackedFunc generic_func_;
runtime::PackedFunc generic_func_;
/* \brief map from keys to registered functions */
std::unordered_map<std::string, PackedFunc> dispatch_dict_;
std::unordered_map<std::string, runtime::PackedFunc> dispatch_dict_;
static constexpr const char* _type_key = "GenericFunc";
TVM_DECLARE_NODE_TYPE_INFO(GenericFuncNode, Node);
......
/*!
* Copyright (c) 2018 by Contributors
* \file tvm/relay/build_module.h
* \brief The passes and data structures needed to build a
* tvm::Module from a Relay program.
*/
#ifndef TVM_RELAY_BUILD_MODULE_H_
#define TVM_RELAY_BUILD_MODULE_H_
#include <tvm/lowered_func.h>
#include <tvm/relay/module.h>
#include <tvm/relay/expr.h>
#include <string>
namespace tvm {
namespace relay {
/*! \brief A lowered Relay operation.
*
* A lowered operation is a pair containing the "primitive" function used
* to produce the lowered function as well as the lowered function itself.
*/
class LoweredOp;
/*! \brief Call container. */
class LoweredOpNode : public Node {
public:
/*!
* \brief The primitive function to be lowered.
*
* A primitive function consists only of calls to relay::Op which
* can be fused.
*/
Function func;
/*!
* \brief The lowered function.
*/
LoweredFunc lowered_func;
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("func", &func);
v->Visit("lowered_func", &lowered_func);
}
TVM_DLL static LoweredOp make(
Function func,
LoweredFunc lowered_func);
static constexpr const char* _type_key = "relay.LoweredOp";
TVM_DECLARE_NODE_TYPE_INFO(LoweredOpNode, Node);
};
RELAY_DEFINE_NODE_REF(LoweredOp, LoweredOpNode, NodeRef);
/*!
* \brief Lower the operations contained in a Relay expression.
*
* The lowering pass will only lower functions marked as primitive,
* the FuseOps pass will provide this behavior, if run before LowerOps.
*
* \note This will do a reachability analysis and lower all definitions
* reachable from the provided expression.
*
* \param mod The module.
* \param expr The expression with operations to be lowered.
* \param target The target to lower the functions to.
*
* \return The set of lowered operations.
*/
Array<LoweredOp> LowerOps(const Module& mod, const Expr& expr,
const std::string& target = "llvm");
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_BUILD_MODULE_H_
......@@ -16,6 +16,7 @@
#ifndef TVM_RELAY_INTERPRETER_H_
#define TVM_RELAY_INTERPRETER_H_
#include <tvm/build_module.h>
#include <tvm/relay/module.h>
#include <tvm/relay/expr.h>
......@@ -27,7 +28,9 @@ namespace relay {
*/
class Value;
/*! \brief Evaluate an expression using the interpreter producing a value.
/*!
*\brief Create a Interpreter function that can
* evaluate an expression and produce a value.
*
* The resulting value can be passed to Python, making it easy to use
* for testing and debugging.
......@@ -38,8 +41,14 @@ class Value;
*
* Our intent is that this will never be the most efficient implementation of
* Relay's semantics, but a readable and clear one.
*
* \param mod The function module.
* \param context The primary context that the interepreter runs on.
* \param target Compiler target flag to compile the functions on the context.
* \return A function that takes in an expression and returns a value.
*/
Value Evaluate(Module mod, Expr e);
runtime::TypedPackedFunc<Value(Expr)>
CreateInterpreter(Module mod, DLContext context, Target target);
/*! \brief The base container type of Relay values. */
class ValueNode : public RelayNode {
......@@ -125,9 +134,6 @@ struct TensorValueNode : ValueNode {
/*! \brief Build a value from an NDArray. */
TVM_DLL static TensorValue make(runtime::NDArray data);
/*! \brief Construct an empty tensor value from t. */
TVM_DLL static TensorValue FromType(const Type& t);
static constexpr const char* _type_key = "relay.TensorValue";
TVM_DECLARE_NODE_TYPE_INFO(TensorValueNode, ValueNode);
};
......
/*!
* Copyright (c) 2017 by Contributors
* \file nnvm/compiler/op_attr_types.h
* \brief The Expr and related elements in DataFlow construction.
*/
#ifndef TVM_RELAY_OP_ATTR_TYPES_H_
#define TVM_RELAY_OP_ATTR_TYPES_H_
#include <tvm/tensor.h>
#include <tvm/schedule.h>
#include <tvm/build_module.h>
#include <tvm/relay/type.h>
#include <tvm/relay/expr.h>
namespace tvm {
namespace relay {
/*! \brief operator pattern used in graph fusion */
enum OpPatternKind {
// Elementwise operation
kElemWise = 0,
// Broadcasting operator, can always map output axis to the input in order.
// for example :code:`out[i, ax1, j, ax2] = input[i, j]`.
// Note that the axis need to be in order so transpose is not a bcast operator.
kBroadcast = 1,
// Injective operator, can always injectively map output axis to a single input axis.
// All injective operator can still be safely fused to injective and reduction.
kInjective = 2,
// Communicative reduction operator.
kCommReduce = 3,
// Complex operation, can still fuse elemwise operations into its output.
// but cannot chain another complex op
kOutEWiseFusable = 4,
// Opaque operation, cannot fuse anything.
kOpaque = 8
};
/*! \brief the operator pattern */
using TOpPattern = int;
/*!
* \brief Computation description interface.
*
* \note This function have a special convention
* for functions with tuple input/output.
*
* So far we restrict tuple support to the following case:
* - Function which takes a single tuple as input.
* - Function which outputs a single tuple.
*
* In both cases, the tuple is flattened as array.
*
* \param attrs The attribute of the primitive
* \param inputs The input tensors.
* \param out_type The output type information
& these are always placeholders.
* \return The output compute description of the operator.
*/
using FTVMCompute = runtime::TypedPackedFunc<
Array<Tensor>(const Attrs& attrs,
const Array<Tensor>& inputs,
const Type& out_type,
const Target& target)>;
/*!
* \brief Build the computation schedule for
* op whose root is at current op.
*
* \param attrs The attribute of the node.
* \param outs The output tensors.
* \param target The build target.
* \return schedule The computation schedule.
*/
using FTVMSchedule = runtime::TypedPackedFunc<
Schedule(const Array<Tensor>& outs,
const Target& target)>;
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_OP_ATTR_TYPES_H_
......@@ -178,6 +178,40 @@ class DeviceAPI {
/*! \brief The device type bigger than this is RPC device */
constexpr int kRPCSessMask = 128;
/*!
* \brief The name of Device API factory.
* \param type The device type.
* \return the device name.
*/
inline const char* DeviceName(int type) {
switch (type) {
case kDLCPU: return "cpu";
case kDLGPU: return "gpu";
case kDLOpenCL: return "opencl";
case kDLSDAccel: return "sdaccel";
case kDLAOCL: return "aocl";
case kDLVulkan: return "vulkan";
case kDLMetal: return "metal";
case kDLVPI: return "vpi";
case kDLROCM: return "rocm";
case kOpenGL: return "opengl";
case kDLExtDev: return "ext_dev";
default: LOG(FATAL) << "unknown type =" << type; return "Unknown";
}
}
#ifndef _LIBCPP_SGX_NO_IOSTREAMS
inline std::ostream& operator<<(std::ostream& os, DLContext ctx) { // NOLINT(*)
int device_type = static_cast<int>(ctx.device_type);
if (device_type > kRPCSessMask) {
os << "remote[" << (device_type / kRPCSessMask) << "]-";
device_type = device_type % kRPCSessMask;
}
os << runtime::DeviceName(device_type) << "(" << ctx.device_id << ")";
return os;
}
#endif
} // namespace runtime
} // namespace tvm
#endif // TVM_RUNTIME_DEVICE_API_H_
......@@ -888,6 +888,7 @@ inline std::ostream& operator<<(std::ostream& os, TVMType t) { // NOLINT(*)
}
return os;
}
#endif
inline std::string TVMType2String(TVMType t) {
......
......@@ -132,7 +132,7 @@ class GraphModule(object):
params : dict of str to NDArray
Additonal arguments
"""
if key:
if key is not None:
self._get_input(key).copyfrom(value)
if params:
......
......@@ -7,8 +7,7 @@ from . import ty
from . import expr
from . import module
from . import ir_pass
from .build_module import build
from .interpreter import create_executor
from .build_module import build, create_executor
# Root operators
from .op import Op
......@@ -18,7 +17,7 @@ from .op.transform import *
from . import nn
from . import vision
from . import image
from . import backend
from .scope_builder import ScopeBuilder
......@@ -56,13 +55,6 @@ TupleGetItem = expr.TupleGetItem
var = expr.var
const = expr.const
@register_func("relay._tensor_value_repr")
def _tensor_value_repr(tv):
return str(tv.data.asnumpy())
@register_func("relay._constant_repr")
def _tensor_constant_repr(tv):
return str(tv.data.asnumpy())
# pylint: disable=unused-argument
@register_func("relay.debug")
......
"""The interface to the Evaluator exposed from C++."""
from tvm._ffi.function import _init_api
_init_api("relay._interpreter", __name__)
"""Backend codege modules for relay."""
from . import compile_engine
"""The interface of expr function exposed from C++."""
from __future__ import absolute_import
import logging
from ... import build_module as _build
from ... import container as _container
from ..._ffi.function import _init_api, register_func
@register_func("relay.backend.lower")
def lower(sch, inputs, func_name, source_func):
"""Backend function for lowering.
Parameters
----------
sch : tvm.Schedule
The schedule.
inputs : List[tvm.Tensor]
The inputs to the function.
func_name : str
The name of the function.
source-func : tvm.relay.Function
The source function to be lowered.
Returns
-------
lowered_funcs : List[tvm.LoweredFunc]
The result of lowering.
"""
import traceback
# pylint: disable=broad-except
try:
f = _build.lower(sch, inputs, name=func_name)
logging.debug("lower function %s", func_name)
logging.debug("%s", _build.lower(sch, inputs, simple_mode=True))
except Exception:
msg = traceback.format_exc()
msg += "Error during compile function\n"
msg += "-----------------------------\n"
msg += source_func.astext()
raise RuntimeError(msg)
return f if isinstance(
f, (_container.Array, tuple, list)) else [f]
@register_func("relay.backend.build")
def build(funcs, target, target_host=None):
"""Backend build function.
Parameters
----------
funcs : List[tvm.LoweredFunc]
The list of lowered functions.
target : tvm.Target
The target to run the code on.
target_host : tvm.Target
The host target.
Returns
-------
module : tvm.Module
The runtime module.
"""
if target_host == "":
target_host = None
return _build.build(funcs, target=target, target_host=target_host)
@register_func("relay._tensor_value_repr")
def _tensor_value_repr(tvalue):
return str(tvalue.data.asnumpy())
@register_func("relay._constant_repr")
def _tensor_constant_repr(tvalue):
return str(tvalue.data.asnumpy())
_init_api("relay.backend", __name__)
"""Backend code generation engine."""
from __future__ import absolute_import
from ..base import register_relay_node, NodeBase
from ... import target as _target
from .. import expr as _expr
from . import _backend
@register_relay_node
class CachedFunc(NodeBase):
"""Low-level tensor function to back a relay primitive function.
"""
pass
@register_relay_node
class CCacheKey(NodeBase):
"""Key in the CompileEngine.
Parameters
----------
source_func : tvm.relay.Function
The source function.
target : tvm.Target
The target we want to run the function on.
"""
def __init__(self, source_func, target):
self.__init_handle_by_constructor__(
_backend._make_CCacheKey, source_func, target)
@register_relay_node
class CCacheValue(NodeBase):
"""Value in the CompileEngine, including usage statistics.
"""
pass
def _get_cache_key(source_func, target):
if isinstance(source_func, _expr.Function):
if isinstance(target, str):
target = _target.create(target)
if not target:
raise ValueError("Need target when source_func is a Function")
return CCacheKey(source_func, target)
if not isinstance(source_func, CCacheKey):
raise TypeError("Expect source_func to be CCacheKey")
return source_func
@register_relay_node
class CompileEngine(NodeBase):
"""CompileEngine to get lowered code.
"""
def __init__(self):
raise RuntimeError("Cannot construct a CompileEngine")
def lower(self, source_func, target=None):
"""Lower a source_func to a CachedFunc.
Parameters
----------
source_func : Union[tvm.relay.Function, CCacheKey]
The source relay function.
target : tvm.Target
The target platform.
Returns
-------
cached_func: CachedFunc
The result of lowering.
"""
key = _get_cache_key(source_func, target)
return _backend._CompileEngineLower(self, key)
def jit(self, source_func, target=None):
"""JIT a source_func to a tvm.Function.
Parameters
----------
source_func : Union[tvm.relay.Function, CCacheKey]
The source relay function.
target : tvm.Target
The target platform.
Returns
-------
cached_func: CachedFunc
The result of lowering.
"""
key = _get_cache_key(source_func, target)
return _backend._CompileEngineJIT(self, key)
def clear(self):
"""clear the existing cached functions"""
_backend._CompileEngineClear(self)
def items(self):
"""List items in the cache.
Returns
-------
item_list : List[Tuple[CCacheKey, CCacheValue]]
The list of items.
"""
res = _backend._CompileEngineListItems(self)
assert len(res) % 2 == 0
return [(res[2*i], res[2*i+1]) for i in range(len(res) // 2)]
def dump(self):
"""Return a string representation of engine dump.
Returns
-------
dump : str
The dumped string representation
"""
items = self.items()
res = "====================================\n"
res += "CompilerEngine dump, %d items cached\n" % len(items)
for k, v in items:
res += "------------------------------------\n"
res += "target={}\n".format(k.target)
res += "use_count={}\n".format(v.use_count)
res += "func_name={}\n".format(v.cached_func.func_name)
res += k.source_func.astext() + "\n"
res += "===================================\n"
return res
def get():
"""Get the global compile engine.
Returns
-------
engine : tvm.relay.backend.CompileEngine
The compile engine.
"""
return _backend._CompileEngineGlobal()
"""
A compiler from a Relay expression to TVM's graph runtime.
The compiler is built from a few pieces.
First we define a compiler from a single Relay expression to the
graph langauge. We require the expression to be a function.
The function's parameters correpond to the placeholder/inputs
and model parameters found in the computation graph representation.
The body of the function represents the computation graph.
The compiler's output is a program in the graph language, which is composed of
graph langauge is composed of Node, NodeRef, InputNode, OpNode.
This "little language" represents programs in TVM's graph format.
To connect to the graph runtime, we use a printer that converts our graph format
into TVM's JSON format. The resulting string can be loaded by
contrib.graph_runtime or any other TVM runtime comptatible system.
"""
from __future__ import absolute_import
import json
import attr
from . import compile_engine
from ..op import Op
from ..expr import Function, GlobalVar, ExprFunctor
from ..ty import TupleType, TensorType
@attr.s
class NodeRef(object):
"""A reference to a node, used for constructing the graph."""
ident = attr.ib()
index = attr.ib(default=0)
version = attr.ib(default=0)
def to_json(self):
return [self.ident, self.index, self.version]
@attr.s
class Node(object):
"""The base class for nodes in the TVM runtime system graph input."""
name = attr.ib()
attrs = attr.ib()
def to_json(self):
raise Exception("Abstract method, please implement me.")
@attr.s
class InputNode(Node):
"""An input node in the TVM runtime system graph input."""
name = attr.ib()
attrs = attr.ib()
def to_json(self):
return {
"op": "null",
"name": self.name,
"inputs": []
}
@attr.s
class OpNode(Node):
"""An operator node in the TVM runtime system"s graph input."""
op_name = attr.ib()
inputs = attr.ib()
op_attrs = attr.ib()
num_outputs = attr.ib(default=1)
def to_json(self):
attrs = dict.copy(self.op_attrs)
# Extend ops with extra info.
attrs["func_name"] = self.op_name
attrs["flatten_data"] = "0"
attrs["num_inputs"] = str(len(self.inputs))
attrs["num_outputs"] = str(self.num_outputs)
return {
"op": "tvm_op",
"name": self.name,
"attrs": attrs,
"inputs": self.inputs
}
def shape_to_json(shape):
"""Convert symbolic shape to json compatible forma."""
return [sh.value for sh in shape]
class GraphRuntimeCodegen(ExprFunctor):
"""The compiler from Relay to the TVM runtime system."""
nodes = attr.ib()
var_map = attr.ib()
def __init__(self, mod, target):
ExprFunctor.__init__(self)
self.mod = mod
self.target = target
self.nodes = []
self.var_map = {}
self.compile_engine = compile_engine.get()
self.lowered_funcs = set()
self._name_map = {}
def add_node(self, node, checked_type):
"""
Add a node to the graph.
Parameters
----------
node: Node
The node to add to the graph.
checked_type: Type
The type of the node.
Returns
-------
node_ref: Union[NodeRef, List[NodeRef]]
A reference to the node.
"""
node_id = len(self.nodes)
self.nodes.append(node)
# Tuple return value, flatten as tuple
if isinstance(checked_type, TupleType):
ret = []
shape = []
dtype = []
for i, typ in enumerate(checked_type.fields):
if not isinstance(typ, TensorType):
raise RuntimeError("type %s not supported" % typ)
ret.append(NodeRef(node_id, i))
shape.append(shape_to_json(typ.shape))
dtype.append(typ.dtype)
node.attrs["shape"] = shape
node.attrs["dtype"] = dtype
assert isinstance(node, OpNode)
node.num_outputs = len(checked_type.fields)
return tuple(ret)
# Normal tensor return type
if not isinstance(checked_type, TensorType):
raise RuntimeError("type %s not supported" % checked_type)
node.attrs["shape"] = [shape_to_json(checked_type.shape)]
node.attrs["dtype"] = [checked_type.dtype]
node.num_outputs = 1
return NodeRef(node_id, 0)
def visit_tuple(self, vtuple):
fields = []
for field in vtuple.fields:
ref = self.visit(field)
assert isinstance(ref, NodeRef)
fields.append(ref)
return tuple(fields)
def visit_tuple_getitem(self, op):
vtuple = self.visit(op.tuple_value)
assert isinstance(vtuple, tuple)
return vtuple[op.index]
def visit_constant(self, _):
raise RuntimeError("constant not supported")
def visit_function(self, _):
raise RuntimeError("function not supported")
def visit_if(self, _):
raise RuntimeError("if not supported")
def visit_global_var(self, _):
raise RuntimeError()
def visit_let(self, let):
"""
Visit the let binding, by first traversing its value,
then setting the metadata on the returned NodeRef.
Finally visit the body, and return the NodeRef corresponding
to it.
Parameters
----------
let: tvm.relay.Expr
The let binding to transform.
Returns
-------
ref: NodeRef
The node reference to the body.
"""
assert let.var not in self.var_map
self.var_map[let.var] = self.visit(let.value)
return self.visit(let.body)
def visit_var(self, rvar):
return self.var_map[rvar]
def visit_call(self, call):
"""Transform a ::tvm.relay.Call into an operator in the TVM graph."""
if isinstance(call.op, Op):
raise Exception(
"Operators should be transformed away; try applying" +
"the fuse_ops transformation to the expression.")
elif isinstance(call.op, GlobalVar):
func = self.mod[call.op]
elif isinstance(call.op, Function):
func = call.op
else:
raise Exception(
"TVM runtime does not support calls to {0}".format(type(call.op)))
if int(func.attrs.Primitive) != 1:
raise Exception(
"TVM only support calls to primitive functions " +
"(i.e functions composed of fusable operator invocations)")
cached_func = self.compile_engine.lower(func, self.target)
for loweredf in cached_func.funcs:
self.lowered_funcs.add(loweredf)
inputs = []
tuple_arg_count = 0
for arg in call.args:
if isinstance(arg.checked_type, TupleType):
tuple_arg_count += 1
inputs.append(self.visit(arg))
# We need to specially handle tuple inputs and
# tuple output cases.
# Tuple input function(e.g. concat)
if tuple_arg_count:
assert len(call.args) == 1
assert isinstance(inputs[0], tuple)
inputs = list(inputs[0])
inputs = [x.to_json() for x in inputs]
op_name = cached_func.func_name
op_node = OpNode(self._get_unique_name(op_name), {},
op_name, inputs, {})
return self.add_node(op_node, call.checked_type)
def _get_json(self):
"""
Convert the sequence of nodes stored by the compiler into the
TVM graph runtime format.
Returns
-------
graph_json : str
The generated JSON as a string.
"""
nodes = []
# First we compute "nodes" field.
for node in self.nodes:
nodes.append(node.to_json())
arg_nodes = []
# Compute "arg_nodes" and "heads" fields.
for i, node in enumerate(self.nodes):
if isinstance(node, InputNode):
arg_nodes.append(i)
heads = self.heads
heads = heads if isinstance(heads, tuple) else [heads]
heads = [x.to_json() for x in heads]
# Compute "node_row_ptr" and entry attributes.
num_entry = 0
shapes = []
storage_ids = []
dltypes = []
node_row_ptr = [0]
for node in self.nodes:
assert node.num_outputs == len(node.attrs["shape"])
shapes += node.attrs["shape"]
dltypes += node.attrs["dtype"]
for i in range(node.num_outputs):
storage_ids.append(i + num_entry)
num_entry += node.num_outputs
node_row_ptr.append(num_entry)
# Compute "attrs" field.
attrs = {}
attrs["shape"] = ["list_shape", shapes]
attrs["storage_id"] = ["list_int", storage_ids]
attrs["dltype"] = ["list_str", dltypes]
json_dict = {
"nodes": nodes,
"arg_nodes": arg_nodes,
"heads": heads,
"attrs": attrs,
"node_row_ptr": node_row_ptr
}
return json.dumps(json_dict, indent=2)
def codegen(self, func):
"""Compile a single function into a graph.
Parameters
----------
func: tvm.relay.Expr
The function to compile.
Returns
-------
graph_json : str
The graph json that can be consumed by runtime.
lowered_funcs : List[tvm.LoweredFunc]
The lowered functions.
"""
# First we convert all the parameters into input nodes.
for param in func.params:
node = InputNode(param.name_hint, {})
self.var_map[param] = self.add_node(
node, param.type_annotation)
# Then we compile the body into a graph which can depend
# on input variables.
self.heads = self.visit(func.body)
graph_json = self._get_json()
lowered_funcs = list(self.lowered_funcs)
return graph_json, lowered_funcs
def _get_unique_name(self, name):
if name not in self._name_map:
self._name_map[name] = 1
return name
index = self._name_map[name]
self._name_map[name] += 1
return self.get_unique_name(name + str(index))
#pylint: disable=no-else-return
"""An interface to the Realy interpreter."""
from __future__ import absolute_import
import numpy as np
from .. import register_func, nd
from .base import NodeBase, register_relay_node
from . import build_module
from . import _make
from . import _interpreter
from . import ir_pass
from .module import Module
from .expr import Call, Constant, GlobalVar, Function, const
from .scope_builder import ScopeBuilder
from .._ffi.base import integer_types
from ..contrib import graph_runtime as tvm_runtime
from .. import cpu
from . import _backend
from .. import _make, ir_pass
from ... import register_func, nd
from ..base import NodeBase, register_relay_node
from ..expr import Call, Constant, GlobalVar, Function, const
from ..scope_builder import ScopeBuilder
class Value(NodeBase):
"""Base class of all values.
"""
@staticmethod
@register_func("relay.from_scalar")
def from_scalar(i, dtype=None):
def from_scalar(value, dtype=None):
"""Convert a Python scalar to a Relay scalar."""
if dtype is None:
if isinstance(i, integer_types):
dtype = 'int32'
elif isinstance(i, float):
dtype = 'float32'
elif isinstance(i, bool):
dtype = 'uint8'
else:
raise Exception("unable to infer dtype {0}".format(type(i)))
return TensorValue(nd.array(np.array(i, dtype=dtype)))
return TensorValue(const(value, dtype).data)
@register_relay_node
......@@ -65,10 +50,6 @@ class TensorValue(Value):
self.__init_handle_by_constructor__(
_make.TensorValue, data)
def as_ndarray(self):
"""Convert a Relay TensorValue into a tvm.ndarray."""
return self.data
def asnumpy(self):
"""Convert a Relay TensorValue into a numpy.ndarray."""
return self.data.asnumpy()
......@@ -79,7 +60,7 @@ class TensorValue(Value):
def _arg_to_ast(arg):
if isinstance(arg, TensorValue):
return Constant(arg.data)
return Constant(arg.data.copyto(_nd.cpu(0)))
elif isinstance(arg, np.ndarray):
return Constant(nd.array(arg))
elif isinstance(arg, Constant):
......@@ -87,29 +68,9 @@ def _arg_to_ast(arg):
else:
return const(arg)
class Executor(object):
"""An abstract interface for executing Relay programs."""
def __init__(self, mod=None):
"""
Parameters
----------
mod: relay.Module
The module.
"""
if mod is None:
self.mod = Module({})
else:
self.mod = mod
def optimize(self, expr):
# TODO: We need to move this optimization code into the optimizer/pass manager
ck_expr = ir_pass.infer_type(expr, mod=self.mod)
fused_expr = ir_pass.fuse_ops(self.mod, ck_expr)
ck_fused = ir_pass.infer_type(fused_expr, mod=self.mod)
return ck_fused
def _make_executor(self, _):
"""
Construct a Python function that implements the evaluation
......@@ -122,50 +83,85 @@ class Executor(object):
Returns
-------
executor: function
executor: function,
A Python function which implements the behavior of `expr`.
"""
raise Exception("abstract method: please implement me.")
raise NotImplementedError()
def evaluate(self, expr, params=None):
def evaluate(self, expr, binds=None):
"""
Evaluate a Relay expression on the interpreter.
Evaluate a Relay expression on the executor.
Parameters
----------
expr: tvm.relay.Expr
The expression to evaluate.
binds: Map[tvm.relay.Var, tvm.relay.Expr]
Additional binding of free variable.
Returns
-------
val : Union[function, Value]
The evaluation result.
"""
if params:
if binds:
scope_builder = ScopeBuilder()
for key in params:
value = params[key]
scope_builder.let(key, value)
for key, value in binds.items():
scope_builder.let(key, _arg_to_ast(value))
scope_builder.ret(expr)
expr = scope_builder.get()
if isinstance(expr, Function):
assert not ir_pass.free_vars(expr)
executor = self._make_executor(expr)
# If we are evaluating a function or top-level defintion
# the user must call the function themselves.
#
# If we are evaluating an open term with parameters we will
# just return them the result.
if isinstance(expr, (Function, GlobalVar)):
return executor
else:
return executor()
return self._make_executor(expr)
# normal expression evaluated by running a function.
func = Function([], expr)
return self._make_executor(func)()
class Interpreter(Executor):
"""
A wrapper around the Relay interpreter, implements the excecutor interface.
Simple interpreter interface.
Parameters
----------
mod : tvm.relay.Module
The module to support the execution.
ctx : tvm.TVMContext
The runtime context to run the code on.
target : tvm.Target
The target option to build the function.
"""
def __init__(self, mod, ctx, target):
self.mod = mod
self.ctx = ctx
self.target = target
self._intrp = _backend.CreateInterpreter(mod, ctx, target)
def optimize(self, expr):
"""Optimize an expr.
Parameters
----------
expr : Expr
The expression to be optimized.
Returns
-------
opt_expr : Expr
The optimized expression.
"""
def __init__(self, mod=None):
Executor.__init__(self, mod)
# TODO: We need to move this optimization code into the optimizer/pass manager
ck_expr = ir_pass.infer_type(expr, mod=self.mod)
fused_expr = ir_pass.fuse_ops(ck_expr)
ck_fused = ir_pass.infer_type(fused_expr, mod=self.mod)
return ck_fused
def _make_executor(self, expr):
def _interp_wrapper(*args):
......@@ -178,46 +174,9 @@ class Interpreter(Executor):
func = self.optimize(func)
self.mod._add(expr, func, True)
opt_expr = Call(expr, relay_args)
return _interpreter.evaluate(self.mod, opt_expr)
elif isinstance(expr, Function):
return self._intrp(opt_expr)
else:
call = Call(expr, relay_args)
opt_expr = self.optimize(call)
return _interpreter.evaluate(self.mod, opt_expr)
else:
assert not args
opt_expr = self.optimize(expr)
return _interpreter.evaluate(self.mod, opt_expr)
return self._intrp(opt_expr)
return _interp_wrapper
class GraphRuntime(Executor):
"""A wrapper around the TVM graph runtime, implements the Executor interface."""
def __init__(self, mod=None):
Executor.__init__(self, mod)
def _make_executor(self, expr):
def _graph_wrapper(*args):
func = self.optimize(expr)
graph_json, mod, params = build_module.build(func, mod=self.mod)
assert params is None
gmodule = tvm_runtime.create(graph_json, mod, cpu(0))
# Create map of inputs.
inputs = {}
for i, arg in enumerate(args):
inputs[func.params[i].name_hint] = arg
# Set the inputs here.
gmodule.set_input(**inputs)
# Run the module, and fetch the output.
gmodule.run()
return gmodule.get_output(0)
return _graph_wrapper
def create_executor(mode='debug', mod=None):
if mode == 'debug':
return Interpreter(mod)
elif mode == 'graph':
return GraphRuntime(mod)
else:
raise Exception("unknown mode {0}".format(mode))
......@@ -2,45 +2,257 @@
Construct the necessary state for the TVM graph runtime
from a Relay expression.
"""
from ..build_module import build as tvm_build_module
from . graph_runtime_codegen import GraphRuntimeCodegen
from ..build_module import build as _tvm_build_module
from .. import nd as _nd, target as _target, autotvm
from ..contrib import graph_runtime as _graph_rt
from . import ir_pass
from .module import Module
from .backend import interpreter as _interpreter
from .backend import graph_runtime_codegen as _graph_gen
def build(func, params=None, target=None, mod=None):
# List of optimization pass and level when switch on
OPT_PASS_LEVEL = {
"SimplifyInference": 0,
"OpFusion": 1,
"FoldScaleAxis": 3,
}
class BuildConfig(object):
"""Configuration scope to set a build config option.
Parameters
----------
kwargs
Keyword arguments of configurations to set.
"""
current = None
defaults = {
"opt_level": 2,
"add_pass": None,
}
def __init__(self, **kwargs):
self._old_scope = None
for k, _ in kwargs.items():
if k not in BuildConfig.defaults:
raise ValueError(
"invalid argument %s, candidates are %s" % (k, BuildConfig.defaults.keys()))
self._attr = kwargs
def __getattr__(self, name):
if name not in self._attr:
return BuildConfig.defaults[name]
return self._attr[name]
def __enter__(self):
# pylint: disable=protected-access
self._old_scope = BuildConfig.current
attr = BuildConfig.current._attr.copy()
attr.update(self._attr)
self._attr = attr
BuildConfig.current = self
return self
def __exit__(self, ptype, value, trace):
assert self._old_scope
BuildConfig.current = self._old_scope
def pass_enabled(self, pass_name):
"""Get whether pass is enabled.
Parameters
----------
pass_name : str
The optimization pass name
Returns
-------
enabled : bool
Whether pass is enabled.
"""
if self.add_pass and pass_name in self.add_pass:
return True
return self.opt_level >= OPT_PASS_LEVEL[pass_name]
BuildConfig.current = BuildConfig()
def build_config(**kwargs):
"""Configure the build behavior by setting config variables.
Parameters
----------
opt_level: int, default=2
Optimization level. See OPT_PASS_LEVEL for level of each pass.
add_pass: set of str
Optimization pass to be added regardless of optimization level.
Returns
-------
config: BuildConfig
The build configuration
"""
Compile a single function to the components needed by the
TVM RTS.
return BuildConfig(**kwargs)
def optimize(func):
"""Perform target invariant optimizations.
Parameters
----------
func: relay.Expr
func : tvm.relay.Function
The input to optimization.
Returns
-------
opt_func : tvm.relay.Function
The optimized version of the function.
"""
cfg = BuildConfig.current
if cfg.pass_enabled("FoldScaleAxis"):
func = ir_pass.infer_type(func)
func = ir_pass.simplify_inference(func)
if cfg.pass_enabled("FoldScaleAxis"):
func = ir_pass.infer_type(func)
func = ir_pass.backward_fold_scale_axis(func)
func = ir_pass.infer_type(func)
func = ir_pass.forward_fold_scale_axis(func)
return func
def build(func,
target=None,
target_host=None,
params=None):
"""Build a function to run on TVM graph runtime.
Parameters
----------
func: relay.Function
The function to build.
target: optional str
The target platform.
target : str or :any:`tvm.target.Target`, optional
The build target
target_host : str or :any:`tvm.target.Target` optional
Host compilation target, if target is device.
When TVM compiles device specific program such as CUDA,
we also need host(CPU) side code to interact with the driver
setup the dimensions and parameters correctly.
target_host is used to specify the host side codegen target.
By default, llvm is used if it is enabled,
otherwise a stackvm intepreter is used.
params : dict of str to NDArray
Input parameters to the graph that do not change
during inference time. Used for pre-compute
folding optimization.
Returns
-------
(graph_json, mod, params): tuple of (str, tvm.Module, dict)
The outputs of building a Relay function for the TVM runtime.
graph_json : str
The json string that can be accepted by graph runtime.
mod : tvm.Module
The module containing necessary libraries.
params : dict
The parameters of the final graph.
"""
target = target if target else _target.current_target()
if target is None:
target = 'llvm'
if mod is None:
mod = Module({})
comp = GraphRuntimeCodegen(mod)
# NB(@jroesch) This creates lowered functions, and generates names for them
#
# We need these names to emit the correct graph as these are names of the
# functions contained in the module.
lowered_ops = ir_pass.lower_ops(mod, func)
mod = tvm_build_module([lf.lowered_func for lf in lowered_ops], target)
# Therefore the call to compile must come after.
comp.codegen(func)
graph_json = comp.to_json()
raise ValueError("Target is not set in env or passed as argument.")
target = _target.create(target)
# If current dispatch context is fallback context (the default root context),
# then load pre-tuned parameters from TopHub
if isinstance(autotvm.DispatchContext.current, autotvm.FallbackContext):
tophub_context = autotvm.tophub.context(target)
else:
tophub_context = autotvm.util.EmptyContext()
with tophub_context:
func = optimize(func)
# Fuse ops before running code gen
func = ir_pass.infer_type(func)
func = ir_pass.fuse_ops(func)
# Graph code generation
func = ir_pass.infer_type(func)
graph_gen = _graph_gen.GraphRuntimeCodegen(mod=None, target=target)
graph_json, lowered_funcs = graph_gen.codegen(func)
mod = _tvm_build_module(lowered_funcs, target=target, target_host=target_host)
return graph_json, mod, params
class GraphExecutor(_interpreter.Executor):
"""Wrapper around Executor interface.
This executor is used for debug and testing purpoes.
Parameters
----------
mod : tvm.relay.Module
The module to support the execution.
ctx : tvm.TVMContext
The runtime context to run the code on.
target : tvm.Target
The target option to build the function.
"""
def __init__(self, mod, ctx, target):
self.mod = mod
self.ctx = ctx
self.target = target
def _make_executor(self, func):
def _graph_wrapper(*args):
graph_json, mod, params = build(func, target=self.target)
assert params is None
gmodule = _graph_rt.create(graph_json, mod, self.ctx)
# Create map of inputs.
for i, arg in enumerate(args):
gmodule.set_input(i, arg)
# Run the module, and fetch the output.
gmodule.run()
return gmodule.get_output(0)
return _graph_wrapper
def create_executor(kind="debug",
mod=None,
ctx=None,
target="llvm"):
"""Factory function to create an executor.
Parameters
----------
kind : str
The type of executor
mod : relay.Mod
The mod
ctx : tvm.TVMContext
The context to execute the code.
target : tvm.Target
The corresponding context
"""
if ctx is not None:
assert ctx.device_type == _nd.context(str(target), 0).device_type
else:
ctx = _nd.context(str(target), 0)
if isinstance(target, str):
target = _target.create(target)
if kind == "debug":
return _interpreter.Interpreter(mod, ctx, target)
elif kind == "graph":
return GraphExecutor(mod, ctx, target)
else:
raise RuntimeError("unknown mode {0}".format(mode))
......@@ -319,12 +319,11 @@ class TupleGetItem(Expr):
self.__init_handle_by_constructor__(
_make.TupleGetItem, tuple_value, index)
class ExprFunctor(object):
"""
An abstract visitor defined over Expr.
A Python version of the class defined in `expr_functor.h`.
Defines the default dispatch over expressions, and
implements memoization.
"""
......@@ -352,6 +351,8 @@ class ExprFunctor(object):
res = self.visit_if(expr)
elif isinstance(expr, Tuple):
res = self.visit_tuple(expr)
elif isinstance(expr, TupleGetItem):
res = self.visit_tuple_getitem(expr)
elif isinstance(expr, Constant):
res = self.visit_constant(expr)
else:
......@@ -361,31 +362,34 @@ class ExprFunctor(object):
return res
def visit_function(self, _):
raise Exception("Abstract method please implement me.")
raise NotImplementedError()
def visit_let(self, _):
raise Exception("Abstract method please implement me.")
raise NotImplementedError()
def visit_call(self, _):
raise Exception("Abstract method please implement me.")
raise NotImplementedError()
def visit_var(self, _):
raise Exception("Abstract method please implement me.")
raise NotImplementedError()
def visit_type(self, typ):
return typ
def visit_if(self, _):
raise Exception("Abstract method please implement me.")
raise NotImplementedError()
def visit_tuple(self, _):
raise Exception("Abstract method please implement me.")
raise NotImplementedError()
def visit_tuple_getitem(self, _):
raise NotImplementedError()
def visit_constant(self, _):
raise Exception("Abstract method please implement me.")
raise NotImplementedError()
def visit_global_var(self, _):
raise Exception("Abstract method please implement me.")
raise NotImplementedError()
class ExprMutator(ExprFunctor):
......@@ -395,7 +399,6 @@ class ExprMutator(ExprFunctor):
The default behavior recursively traverses the AST
and reconstructs the AST.
"""
def visit_function(self, fn):
new_body = self.visit(fn.body)
return Function(
......@@ -429,9 +432,19 @@ class ExprMutator(ExprFunctor):
def visit_tuple(self, tup):
return Tuple([self.visit(field) for field in tup.fields])
def visit_tuple_getitem(self, op):
tuple_value = self.visit(op.tuple_value)
if not tuple_value.same_as(op.tuple_value):
return TupleGetItem(tuple_value, op.index)
return op
def visit_global_var(self, gvar):
return gvar
def visit_constant(self, rconst):
return rconst
class TupleWrapper(object):
"""TupleWrapper.
......
"""
A compiler from a Relay expression to TVM's graph runtime.
The compiler is built from a few pieces.
First we define a compiler from a single Relay expression to the
graph langauge. We require the expression to be a function.
The function's parameters correpond to the placeholder/inputs
and model parameters found in the computation graph representation.
The body of the function represents the computation graph.
The compiler's output is a program in the graph language, which is composed of
graph langauge is composed of Node, NodeRef, InputNode, OpNode.
This "little language" represents programs in TVM's graph format.
To connect to the graph runtime, we use a printer that converts our graph format
into TVM's JSON format. The resulting string can be loaded by
contrib.graph_runtime or any other TVM runtime comptatible system.
We expose this functionality in compile_to_tvm.
"""
from __future__ import absolute_import
import json
import attr
from . import ir_pass
from .op import Op
from .expr import Function, GlobalVar, ExprMutator
@attr.s
class NodeRef(object):
"""A reference to a node, used for constructing the graph."""
ident = attr.ib()
index = attr.ib(default=0)
version = attr.ib(default=0)
def to_json(self):
return [self.ident, self.index, self.version]
@attr.s
class Node(object):
"""The base class for nodes in the TVM runtime system graph input."""
name = attr.ib()
attrs = attr.ib()
is_output = attr.ib()
def to_json(self):
raise Exception("Abstract method, please implement me.")
@attr.s
class InputNode(Node):
"""An input node in the TVM runtime system graph input."""
name = attr.ib()
attrs = attr.ib()
is_output = attr.ib(default=False)
def to_json(self):
return {
"op": "null",
"name": self.name,
"inputs": []
}
@attr.s
class OpNode(Node):
"""An operator node in the TVM runtime system's graph input."""
op_name = attr.ib()
inputs = attr.ib()
op_attrs = attr.ib()
is_output = attr.ib(default=False)
def to_json(self):
attrs = dict.copy(self.op_attrs)
# Extend ops with extra info.
attrs['func_name'] = self.op_name
# When do we flatten?
attrs['flatten_data'] = "0"
# Fix me!
attrs['num_inputs'] = str(len(self.inputs))
attrs['num_outputs'] = "1"
return {
"op": "tvm_op",
"name": self.name,
"attrs": attrs,
"inputs": self.inputs
}
def shape_to_json(shape):
return [sh.value for sh in shape]
def from_tensor(typ):
return (typ.dtype, shape_to_json(typ.shape))
class GraphRuntimeCodegen(ExprMutator):
"""The compiler from Relay to the TVM runtime system."""
nodes = attr.ib()
id_map = attr.ib()
def __init__(self, env):
ExprMutator.__init__(self)
self.nodes = []
self.id_map = {}
self.env = env
def add_node(self, node):
"""
Add a node to the graph.
Parameters
----------
node: Node
The node to add to the graph.
Returns
-------
node_ref: NodeRef
A reference to the node.
"""
self.nodes.append(node)
ident = len(self.nodes) - 1
return NodeRef(ident)
def add_binding(self, ident, ref):
"""
Add a identifier to node mapping.
Parameters
----------
ident: relay.Var
The variable to map
ref: NodeRef
The node the identifier points.
"""
self.id_map[ident] = ref
def let_bind(self, ident, node):
"""
Let bind node to ident.
Parameters
----------
ident: relay.Var
The variable to map.
ref: NodeRef
The node the identifier points.
Returns
-------
ref: NodeRef
Return reference to the node.
"""
ref = self.add_node(node)
self.add_binding(ident, ref)
return ref
def get_node(self, ref):
"""
Lookup a node by a node reference.
Parameters
----------
ref: NodeRef
The reference to lookup.
Returns
-------
node: Node
The node.
"""
return self.nodes[ref.ident]
def lookup(self, ident):
"""
Lookup a node by identifier.
Parameters
----------
ident: relay.Var
The reference to lookup.
Returns
-------
node: Node
The node.
"""
return self.id_map[ident]
def codegen(self, func):
"""Compile a single function into a graph.
Parameters
----------
func: tvm.relay.Expr
The function to compile.
"""
# First we convert all the parameters into input nodes.
params = func.params
for param in params:
dtype, shape = from_tensor(param.type_annotation)
node = InputNode("{0}".format(param.name_hint), {
"shape": shape,
"dtype": dtype,
})
self.let_bind(param, node)
# Then we compile the body into a graph which can depend
# on input variables.
output_ref = self.visit(func.body)
# Finally we retreive return value of program, which will
# become our output node.
self.get_node(output_ref).is_output = True
def visit_let(self, let):
"""
Visit the let binding, by first traversing its value,
then setting the metadata on the returned NodeRef.
Finally visit the body, and return the NodeRef corresponding
to it.
Parameters
----------
let: tvm.relay.Expr
The let binding to transform.
Returns
-------
ref: NodeRef
The node reference to the body.
"""
ident = let.var
val = let.value
body = let.body
val_ref = self.visit(val)
dtype, shape = from_tensor(val.checked_type())
val_node = self.get_node(val_ref)
val_node.attrs["dtype"] = dtype
val_node.attrs["shape"] = shape
self.add_binding(ident, val_ref)
return self.visit(body)
def visit_var(self, rvar):
return self.lookup(rvar)
def visit_call(self, call):
"""Transform a ::tvm.relay.Call into an operator in the TVM graph."""
inputs = []
for arg in call.args:
inputs.append(self.visit(arg).to_json())
if isinstance(call.op, Op):
raise Exception(
"Operators should be transformed away; try applying" +
"the fuse_ops transformation to the expression.")
elif isinstance(call.op, GlobalVar):
func = self.env[call.op]
elif isinstance(call.op, Function):
func = call.op
else:
raise Exception(
"TVM runtime does not support calls to {0}".format(type(call.op)))
if int(func.attrs.Primitive) != 1:
raise Exception(
"TVM only support calls to primitive functions " +
"(i.e functions composed of fusable operator invocations)")
op_name = func.attrs.LoweredFunc.name
attrs = {'shape': shape_to_json(call.checked_type.shape),
'dtype': call.checked_type.dtype}
call_hash = str(ir_pass.structural_hash(call))
op_node = OpNode("call_" + call_hash, attrs, op_name, inputs, {})
return self.add_node(op_node)
def to_json(self):
"""
Convert the sequence of nodes stored by the compiler into the
TVM graph runtime format.
Returns
-------
graph_json : str
The generated JSON as a string.
"""
nodes = []
# First we compute "nodes" field.
for node in self.nodes:
nodes.append(node.to_json())
arg_nodes = []
heads = []
# Compute "arg_nodes" and "heads" fields.
for i, node in enumerate(self.nodes):
if isinstance(node, InputNode):
arg_nodes.append(i)
if node.is_output:
# Need to fix this.
heads.append(NodeRef(i).to_json())
def compute_node_row_ptr(nodes):
"""Calculate the node_row_ptr field by doing a DFS backwards
from the output and reversing the path.
"""
row_ptr = [len(nodes)]
discovered = set()
stack = []
stack.append(len(nodes) - 1)
while stack:
i = stack.pop()
if i not in discovered:
discovered.add(i)
row_ptr.append(i)
node = nodes[i]
if isinstance(node, OpNode):
for inp in node.inputs:
stack.append(inp[0])
row_ptr.reverse()
return row_ptr
# Compute "node_row_ptr".
node_row_ptr = compute_node_row_ptr(self.nodes)
# Compute "attrs" field.
attrs = {}
# These fields are mandatory.
shapes = []
storage_ids = []
dtype = []
dltype = []
for i, node in enumerate(self.nodes):
storage_ids.append(i)
shapes.append(node.attrs['shape'])
if node.attrs['dtype'] == 'float32':
dtype.append(0)
dltype.append('float32')
attrs["shape"] = ["list_shape", shapes]
attrs["storage_id"] = ["list_int", storage_ids]
attrs["dtype"] = ["list_int", dtype]
attrs["dltype"] = ["list_str", dltype]
json_dict = {
"nodes": nodes,
"arg_nodes": arg_nodes,
"heads": heads,
"attrs": attrs,
"node_row_ptr": node_row_ptr
}
return json.dumps(json_dict)
......@@ -160,6 +160,7 @@ def free_type_vars(expr):
"""
return _ir_pass.free_type_vars(expr)
def simplify_inference(expr):
""" Simplify the data-flow graph for inference phase.
......@@ -176,6 +177,7 @@ def simplify_inference(expr):
"""
return _ir_pass.simplify_inference(expr)
def dead_code_elimination(expr):
""" Remove expressions which does not effect the program result (dead code).
......@@ -256,8 +258,18 @@ def structural_hash(value):
"relay.Expr or relay.Type").format(type(value))
raise TypeError(msg)
def fuse_ops(expr, mod):
return _ir_pass.FuseOps(mod, expr)
def lower_ops(mod, expr, target='llvm'):
return _ir_pass.LowerOps(mod, expr, target)
def fuse_ops(expr):
"""Fuse operators in expr together.
Parameters
----------
expr : tvm.relay.Expr
The input expression.
Returns
-------
transformed_expr : tvm.relay.Expr
Transformed expression, containing fused result.
"""
return _ir_pass.FuseOps(expr)
#pylint: disable=invalid-name, unused-argument
"""Backend compiler related feature registration"""
from __future__ import absolute_import
import tvm
import topi
import topi.cuda
from . import register_schedule, register_compute
from .op import register_compute, register_schedule, register_pattern, OpPattern
def schedule_injective(outputs, target):
"""Generic schedule for binary broadcast."""
with tvm.target.create(target):
with target:
return topi.generic.schedule_injective(outputs)
schedule_broadcast = schedule_injective
schedule_elemwise = schedule_injective
# log
@register_compute("log")
def log_compute(attrs, inputs, output_type, target):
assert len(inputs) == 1
return [topi.log(inputs[0])]
register_compute("log", log_compute)
register_schedule("log", schedule_broadcast)
# exp
@register_compute("exp")
def exp_compute(attrs, inputs, output_type, target):
assert len(inputs) == 1
return [topi.exp(inputs[0])]
register_compute("exp", exp_compute)
register_schedule("exp", schedule_broadcast)
# sqrt
@register_compute("sqrt")
def sqrt_compute(attrs, inputs, output_type, target):
assert len(inputs) == 1
return [topi.sqrt(inputs[0])]
register_compute("sqrt", sqrt_compute)
register_schedule("sqrt", schedule_broadcast)
# sigmoid
@register_compute("sigmoid")
def sigmoid_compute(attrs, inputs, output_type, target):
assert len(inputs) == 1
return [topi.sigmoid(inputs[0])]
register_compute("sigmoid", sigmoid_compute)
register_schedule("sigmoid", schedule_broadcast)
# floor
@register_compute("floor")
def floor_compute(attrs, inputs, output_type, target):
assert len(inputs) == 1
return [topi.floor(inputs[0])]
register_compute("floor", floor_compute)
register_schedule("floor", schedule_broadcast)
# ceil
@register_compute("ceil")
def ceil_compute(attrs, inputs, output_type, target):
assert len(inputs) == 1
return [topi.ceil(inputs[0])]
register_compute("ceil", ceil_compute)
register_schedule("ceil", schedule_broadcast)
# trunc
@register_compute("trunc")
def trunc_compute(attrs, inputs, output_type, target):
assert len(inputs) == 1
return [topi.trunc(inputs[0])]
register_compute("trunc", trunc_compute)
register_schedule("trunc", schedule_broadcast)
# round
@register_compute("round")
def round_compute(attrs, inputs, output_type, target):
assert len(inputs) == 1
return [topi.round(inputs[0])]
register_compute("round", round_compute)
register_schedule("round", schedule_broadcast)
# abs
@register_compute("abs")
def abs_compute(attrs, inputs, output_type, target):
assert len(inputs) == 1
return [topi.abs(inputs[0])]
register_compute("abs", abs_compute)
register_schedule("abs", schedule_broadcast)
# tanh
@register_compute("tanh")
def tanh_compute(attrs, inputs, output_type, target):
assert len(inputs) == 1
return [topi.tanh(inputs[0])]
register_compute("tanh", tanh_compute)
register_schedule("tanh", schedule_broadcast)
# negative
@register_compute("negative")
def negative_compute(attrs, inputs, output_type, target):
assert len(inputs) == 1
return [topi.negative(inputs[0])]
register_compute("negative", negative_compute)
register_schedule("negative", schedule_broadcast)
# add
@register_compute("add")
def add_compute(attrs, inputs, output_type, target):
assert len(inputs) == 2
return [topi.add(inputs[0], inputs[1])]
register_compute("add", add_compute)
register_schedule("add", schedule_injective)
# subtract
@register_compute("subtract")
def subtract_compute(attrs, inputs, output_type, target):
assert len(inputs) == 2
return [topi.subtract(inputs[0], inputs[1])]
register_compute("subtract", subtract_compute)
register_schedule("subtract", schedule_broadcast)
# multiply
@register_compute("multiply")
def multiply_compute(attrs, inputs, output_type, target):
assert len(inputs) == 2
return [topi.multiply(inputs[0], inputs[1])]
register_compute("multiply", multiply_compute)
register_schedule("multiply", schedule_broadcast)
# divide
@register_compute("divide")
def divide_compute(attrs, inputs, output_type, target):
assert len(inputs) == 2
return [topi.divide(inputs[0], inputs[1])]
register_compute("divide", divide_compute)
register_schedule("divide", schedule_broadcast)
# pow
def pow_compute(attrs, inputs, output_type, target):
# power
@register_compute("power")
def power_compute(attrs, inputs, output_type, target):
assert len(inputs) == 2
return [topi.power(inputs[0], inputs[1])]
register_compute("pow", pow_compute)
register_schedule("pow", schedule_injective)
register_schedule("power", schedule_injective)
# mod
@register_compute("mod")
def mod_compute(attrs, inputs, output_type, target):
assert len(inputs) == 2
return [topi.mod(inputs[0], inputs[1])]
register_compute("mod", mod_compute)
register_schedule("mod", schedule_broadcast)
# equal
@register_compute("equal")
def equal_compute(attrs, inputs, output_type, target):
assert len(inputs) == 2
return [topi.equal(inputs[0], inputs[1])]
register_compute("equal", equal_compute)
register_schedule("equal", schedule_broadcast)
# not_equal
@register_compute("not_equal")
def not_equal_compute(attrs, inputs, output_type, target):
assert len(inputs) == 2
return [topi.not_equal(inputs[0], inputs[1])]
register_compute("not_equal", not_equal_compute)
register_schedule("not_equal", schedule_broadcast)
# less
@register_compute("less")
def less_compute(attrs, inputs, output_type, target):
assert len(inputs) == 2
return [topi.less(inputs[0], inputs[1])]
register_compute("less", less_compute)
register_schedule("less", schedule_broadcast)
# less equal
@register_compute("less_equal")
def less_equal_compute(attrs, inputs, output_type, target):
assert len(inputs) == 2
return [topi.less_equal(inputs[0], inputs[1])]
register_compute("less_equal", less_equal_compute)
register_schedule("less_equal", schedule_broadcast)
# greater
@register_compute("greater")
def greater_compute(attrs, inputs, output_type, target):
assert len(inputs) == 2
return [topi.greater(inputs[0], inputs[1])]
register_compute("greater", greater_compute)
register_schedule("greater", schedule_broadcast)
# greater equal
@register_compute("greater_equal")
def greater_equal_compute(attrs, inputs, output_type, target):
assert len(inputs) == 2
return [topi.greater_equal(inputs[0], inputs[1])]
register_compute("greater_equal", greater_equal_compute)
register_schedule("greater_equal", schedule_broadcast)
# maximum
@register_compute("maximum")
def maximum_compute(attrs, inputs, output_type, target):
assert len(inputs) == 2
return [topi.maximum(inputs[0], inputs[1])]
register_compute("maximum_compute", maximum_compute)
register_schedule("maximum_compute", schedule_injective)
# minimum
@register_compute("minimum")
def minimum_compute(attrs, inputs, output_type, target):
assert len(inputs) == 2
return [topi.minimum(inputs[0], inputs[1])]
register_compute("minimum", minimum_compute)
register_schedule("minimum", schedule_injective)
# right shift
@register_compute("right_shift")
def right_shift_compute(attrs, inputs, output_type, target):
assert len(inputs) == 2
return [topi.right_shift(inputs[0], inputs[1])]
register_compute("right_shift", right_shift_compute)
register_schedule("right_shift", schedule_injective)
# lift shift
# left shift
@register_compute("left_shift")
def left_shift_compute(attrs, inputs, output_type, target):
assert len(inputs) == 2
return [topi.left_shift(inputs[0], inputs[1])]
register_compute("left_shift", left_shift_compute)
register_schedule("left_shift", schedule_injective)
# zeros
@register_compute("zeros")
def zeros_compute(attrs, inputs, output_type, target):
assert not inputs
return [topi.full(output_type.shape, output_type.dtype, 0.0)]
register_compute("zeros", zeros_compute)
register_schedule("zeros", schedule_injective)
register_schedule("zeros", schedule_broadcast)
register_pattern("zeros", OpPattern.ELEMWISE)
# zeros_like
@register_compute("zeros_like")
def zeros_like_compute(attrs, inputs, output_type, target):
assert len(inputs) == 1
return [topi.full_like(inputs[0], 0.0)]
register_compute("zeros_like", zeros_like_compute)
register_schedule("zeros_like", schedule_injective)
register_schedule("zeros_like", schedule_broadcast)
# ones
@register_compute("ones")
def ones_compute(attrs, inputs, output_type, target):
assert not inputs
return [topi.full(output_type.shape, output_type.dtype, 1.0)]
register_compute("ones", ones_compute)
register_schedule("ones", schedule_injective)
register_schedule("ones", schedule_broadcast)
register_pattern("ones", OpPattern.ELEMWISE)
# ones_like
@register_compute("ones_like")
def ones_like(attrs, inputs, output_type, target):
assert len(inputs) == 1
return [topi.full_like(inputs[0], 1.0)]
register_compute("ones_like", ones_like)
register_schedule("ones_like", schedule_injective)
register_schedule("ones_like", schedule_broadcast)
# clip
@register_compute("clip")
def clip_compute(attrs, inputs, output_type, target):
assert len(inputs) == 1
return [topi.clip(inputs[0], attrs.a_min, attrs.a_max)]
register_schedule("clip", schedule_elemwise)
register_pattern("clip", OpPattern.ELEMWISE)
register_compute("clip", clip_compute)
register_schedule("clip", schedule_injective)
# concatenate
@register_compute("concatenate")
def concatenate_compute(attrs, inputs, output_type, target):
return [topi.concatenate(inputs, axis=attrs.axis)]
register_schedule("concatenate", schedule_injective)
register_pattern("concatenate", OpPattern.INJECTIVE)
......@@ -72,13 +72,80 @@ def register(op_name, attr_key, value=None, level=10):
"""internal register function"""
_Register(op_name, attr_key, v, level)
return v
return _register(value) if value else _register
return _register(value) if value is not None else _register
def register_schedule(op_name, schedule):
register(op_name, "FTVMSchedule", schedule)
def register_compute(op_name, compute):
register(op_name, "FTVMCompute", compute)
class OpPattern(object):
"""Operator generic patterns
See Also
--------
top.tag : Contains explanation of the tag type.
"""
# Elementwise operator
ELEMWISE = 0
# Broadcast operator
BROADCAST = 1
# Injective mapping
INJECTIVE = 2
# Comunication
COMM_REDUCE = 3
# Complex op, can still fuse ewise into it
OUT_ELEMWISE_FUSABLE = 4
# Not fusable opaque op
OPAQUE = 8
def register_schedule(op_name, schedule=None, level=10):
"""Register schedule function for an op
Parameters
----------
op_name : str
The name of the op.
schedule : function
The schedule function.
level : int
The priority level
"""
return register(op_name, "FTVMSchedule", schedule, level)
def register_compute(op_name, compute=None, level=10):
"""Register compute function for an op.
Parameters
----------
op_name : str
The name of the op.
compute : function
The compute function.
level : int
The priority level
"""
return register(op_name, "FTVMCompute", compute, level)
def register_pattern(op_name, pattern, level=10):
"""Register operator pattern for an op.
Parameters
----------
op_name : str
The name of the op.
pattern : int
The pattern being used.
level : int
The priority level
"""
return register(op_name, "TOpPattern", pattern, level)
_init_api("relay.op", __name__)
......
......@@ -266,7 +266,7 @@ def divide(lhs, rhs):
return _make.divide(lhs, rhs)
def pow(lhs, rhs):
def power(lhs, rhs):
"""Power with numpy-style broadcasting.
Parameters
......@@ -281,7 +281,7 @@ def pow(lhs, rhs):
result : relay.Expr
The computed result.
"""
return _make.pow(lhs, rhs)
return _make.power(lhs, rhs)
def mod(lhs, rhs):
......
......@@ -6,3 +6,4 @@ from . import resnet
from . import dqn
from . import dcgan
from . import mobilenet
from .config import ctx_list
"""Configuration about tests"""
from __future__ import absolute_import as _abs
import os
import tvm
def ctx_list():
"""Get context list for testcases"""
device_list = os.environ.get("RELAY_TEST_TARGETS", "")
device_list = (device_list.split(",") if device_list
else ["llvm", "cuda"])
device_list = set(device_list)
res = [(device, tvm.context(device, 0)) for device in device_list]
return [x for x in res if x[1].exist]
......@@ -154,13 +154,15 @@ std::unordered_set<std::string> TargetNode::libs() const {
return result;
}
std::string TargetNode::str() const {
const std::string& TargetNode::str() const {
if (str_repr_.length() != 0) return str_repr_;
std::ostringstream result;
result << target_name;
for (const auto &x : options()) {
result << " " << x;
}
return result.str();
str_repr_ = result.str();
return str_repr_;
}
......
/*!
* Copyright (c) 2018 by Contributors
* \file relay/backend/compile_engine.cc
* \brief Internal compialtion engine.
*/
#include <tvm/schedule.h>
#include <tvm/packed_func_ext.h>
#include <tvm/operation.h>
#include <tvm/runtime/registry.h>
#include <tvm/relay/pass.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op_attr_types.h>
#include <utility>
#include <limits>
#include <mutex>
#include <functional>
#include "compile_engine.h"
namespace tvm {
namespace relay {
CCacheKey CCacheKeyNode::make(Function source_func, Target target) {
auto n = make_node<CCacheKeyNode>();
n->source_func = std::move(source_func);
n->target = std::move(target);
return CCacheKey(n);
}
// The getter to get schedule from compile engine.
// Get schedule from functor.
class ScheduleGetter :
public ExprFunctor<Array<Tensor>(const Expr&)> {
public:
explicit ScheduleGetter(Target target)
: target_(target) {}
Array<IndexExpr> GetShape(const Array<IndexExpr>& shape) {
// for now, we always use int32 shape when possible
// even if the result of shape inference becomes int64.
Array<IndexExpr> res;
for (IndexExpr val : shape) {
const int64_t* pval = as_const_int(val);
if (pval != nullptr) {
CHECK_LE(pval[0], std::numeric_limits<int32_t>::max());
CHECK_GE(pval[0], std::numeric_limits<int32_t>::min());
res.push_back(ir::IntImm::make(Int(32), *pval));
} else {
res.push_back(val);
}
}
return res;
}
std::pair<Schedule, CachedFunc> Create(const Function& prim_func) {
static auto fschedule =
Op::GetAttr<FTVMSchedule>("FTVMSchedule");
auto cache_node = make_node<CachedFuncNode>();
cache_node->target = target_;
if (prim_func->params.size() == 1 &&
prim_func->params[0]->checked_type().as<TupleTypeNode>()) {
// Handle tuple input type by flattening them.
// This is the current calling convention of tuple input.
Array<tvm::Tensor> inputs;
for (Type field : prim_func->params[0]->type_as<TupleTypeNode>()->fields) {
const auto* ttype = field.as<TensorTypeNode>();
CHECK(ttype != nullptr);
tvm::Tensor tensor = tvm::placeholder(
GetShape(ttype->shape), ttype->dtype);
cache_node->inputs.push_back(tensor);
inputs.push_back(tensor);
}
memo_[prim_func->params[0]] = inputs;
} else {
for (Var param : prim_func->params) {
const auto* ttype = param->type_as<TensorTypeNode>();
tvm::Tensor tensor = tvm::placeholder(
GetShape(ttype->shape), ttype->dtype);
cache_node->inputs.push_back(tensor);
memo_[param] = Array<Tensor>({tensor});
}
}
readable_name_stream_ << "fused";
// enter the target context
TargetContext target_ctx(target_);
cache_node->outputs = this->VisitExpr(prim_func->body);
cache_node->func_name = readable_name_stream_.str();
CachedFunc cfunc(cache_node);
CHECK(master_op_.defined());
Schedule schedule = fschedule[master_op_](
cache_node->outputs, target_);
return std::make_pair(schedule, cfunc);
}
Array<Tensor> VisitExpr(const Expr& expr) {
auto it = memo_.find(expr);
if (it != memo_.end()) {
return it->second;
} else {
Array<Tensor> res = ExprFunctor::VisitExpr(expr);
memo_[expr] = res;
return res;
}
}
Array<Tensor> VisitExpr_(const VarNode* op) final {
LOG(FATAL) << "Free variable " << op->name_hint;
return {};
}
Array<Tensor> VisitExpr_(const CallNode* call_node) final {
static auto fcompute =
Op::GetAttr<FTVMCompute>("FTVMCompute");
static auto fpattern =
Op::GetAttr<TOpPattern>("TOpPattern");
Array<Tensor> inputs;
int count_tuple = 0;
for (Expr arg : call_node->args) {
if (arg->checked_type().as<TupleTypeNode>()) {
++count_tuple;
}
for (Tensor tensor : VisitExpr(arg)) {
inputs.push_back(tensor);
}
}
if (count_tuple) {
CHECK_EQ(call_node->args.size(), 1U)
<< "Only allow function with a single tuple input";
}
CHECK(call_node->op.as<OpNode>())
<< "Primitive function only allows call into primitive ops";
Op op = Downcast<Op>(call_node->op);
Array<Tensor> outputs = fcompute[op](
call_node->attrs,
inputs,
call_node->checked_type(),
target_);
int op_pattern = fpattern[op];
if (op_pattern >= kCommReduce) {
CHECK(!master_op_.defined())
<< "Two complicated op in a primitive function";
}
if (op_pattern >= master_op_patetrn_) {
master_op_ = op;
master_op_patetrn_ = op_pattern;
}
if (outputs.size() != 1) {
const auto* tuple_type =
call_node->checked_type().as<TupleTypeNode>();
CHECK(tuple_type) << "Expect output to be a tuple type";
CHECK_EQ(tuple_type->fields.size(), outputs.size());
}
readable_name_stream_ << '_' << op->name;
return outputs;
}
Array<Tensor> VisitExpr_(const FunctionNode* op) final {
LOG(FATAL) << "Do not support sub function";
return Array<Tensor>();
}
Array<Tensor> VisitExpr_(const LetNode* op) final {
Array<Tensor> val = VisitExpr(op->value);
CHECK(!memo_.count(op->var));
memo_[op->var] = val;
return VisitExpr(op->body);
}
Array<Tensor> VisitExpr_(const TupleNode* op) final {
Array<Tensor> fields;
for (Expr field : op->fields) {
CHECK(field->checked_type().as<TensorTypeNode>())
<< "Only allow Tuple of Tensor";
Array<Tensor> res = VisitExpr(field);
CHECK_EQ(res.size(), 1);
fields.push_back(res[0]);
}
return fields;
}
Array<Tensor> VisitExpr_(const TupleGetItemNode* op) final {
const auto* tuple_type = op->tuple->type_as<TupleTypeNode>();
Array<Tensor> tuple = VisitExpr(op->tuple);
CHECK_EQ(tuple_type->fields.size(), tuple.size());
CHECK_GE(op->index, 0);
CHECK_LT(static_cast<size_t>(op->index), tuple.size());
return {tuple[op->index]};
}
private:
tvm::Target target_;
Op master_op_;
int master_op_patetrn_{0};
std::ostringstream readable_name_stream_;
std::unordered_map<Expr, Array<Tensor>, NodeHash, NodeEqual> memo_;
};
class CompileEngineImpl : public CompileEngineNode {
public:
// Lower the fucntion.
CachedFunc Lower(const CCacheKey& key) {
return LowerInternal(key)->cached_func;
}
// For now, build one module per function.
PackedFunc JIT(const CCacheKey& key) final {
CCacheValue value = LowerInternal(key);
if (value->packed_func != nullptr) return value->packed_func;
// build the function.
if (const auto* f = runtime::Registry::Get("relay.backend.build")) {
tvm::runtime::Module m = (*f)(value->cached_func->funcs, key->target);
value->packed_func = m.GetFunction(value->cached_func->func_name);
} else {
LOG(FATAL) << "relay.backend.build is not registered";
}
return value->packed_func;
}
void Clear() final {
cache_.clear();
}
// List all items in the cache.
Array<NodeRef> ListItems() {
std::lock_guard<std::mutex> lock(mutex_);
Array<NodeRef> items;
for (auto& kv : cache_) {
items.push_back(kv.first);
items.push_back(kv.second);
}
return items;
}
/*!
* \brief Create schedule for target.
* \param source_func The primitive function to be lowered.
* \param target The target we want to create schedule for.
* \return Pair of schedule and cache.
* The funcs field in cache is not yet populated.
*/
std::pair<Schedule, CachedFunc> CreateSchedule(
const Function& source_func, const Target& target) {
return ScheduleGetter(target).Create(source_func);
}
private:
// implement lowered func
CCacheValue LowerInternal(const CCacheKey& key) {
std::lock_guard<std::mutex> lock(mutex_);
CCacheValue value;
auto it = cache_.find(key);
if (it != cache_.end()) {
it->second->use_count += 1;
if (it->second->cached_func.defined()) return it->second;
value = it->second;
} else {
value = CCacheValue(make_node<CCacheValueNode>());
value->use_count = 0;
cache_[key] = value;
}
CHECK(!value->cached_func.defined());
auto spair = CreateSchedule(key->source_func, key->target);
auto cache_node = make_node<CachedFuncNode>(
*(spair.second.operator->()));
cache_node->func_name = GetUniqeName(cache_node->func_name);
// NOTE: array will copy on write.
Array<Tensor> all_args = cache_node->inputs;
for (Tensor arg : cache_node->outputs) {
all_args.push_back(arg);
}
// lower the function
if (const auto* f = runtime::Registry::Get("relay.backend.lower")) {
cache_node->funcs = (*f)(
spair.first, all_args, cache_node->func_name, key->source_func);
} else {
LOG(FATAL) << "relay.backend._lower is not registred";
}
value->cached_func = CachedFunc(cache_node);
return value;
}
/*!
* \brief Get unique name from name.
* \param name The orginal name.
* \return Updated name which is unique.
*/
std::string GetUniqeName(std::string name) {
while (true) {
auto it = name_map_.find(name);
if (it == name_map_.end()) {
name_map_[name] = 1;
return name;
} else {
std::ostringstream os;
os << name << "_" << it->second;
++(it->second);
name = os.str();
}
}
return name;
}
/*! \brief compiler cache lock*/
std::mutex mutex_;
/*! \brief internal name map to get an unique name */
std::unordered_map<std::string, int> name_map_;
/*! \brief internal compiler cache */
std::unordered_map<CCacheKey, CCacheValue> cache_;
};
/*! \brief The global compile engine */
const CompileEngine& CompileEngine::Global() {
// intentionally allocate raw pointer to avoid
// free during destructuion.
static CompileEngine* inst = new CompileEngine(
make_node<CompileEngineImpl>());
return *inst;
}
TVM_REGISTER_GLOBAL("relay.backend._make_CCacheKey")
.set_body_typed<CCacheKey(Function, Target)>(CCacheKeyNode::make);
TVM_REGISTER_GLOBAL("relay.backend._CompileEngineGlobal")
.set_body_typed<CompileEngine()>([]() {
return CompileEngine::Global();
});
TVM_REGISTER_GLOBAL("relay.backend._CompileEngineClear")
.set_body_typed<void(const CompileEngine&)>([](CompileEngine self) {
self->Clear();
});
TVM_REGISTER_GLOBAL("relay.backend._CompileEngineLower")
.set_body_typed<CachedFunc(CompileEngine, CCacheKey)>(
[](CompileEngine self, CCacheKey key) {
return self->Lower(key);
});
TVM_REGISTER_GLOBAL("relay.backend._CompileEngineJIT")
.set_body_typed<PackedFunc(CompileEngine, CCacheKey)>(
[](CompileEngine self, CCacheKey key) {
return self->JIT(key);
});
TVM_REGISTER_GLOBAL("relay.backend._CompileEngineListItems")
.set_body_typed<Array<NodeRef>(CompileEngine)>(
[](CompileEngine self){
return static_cast<CompileEngineImpl*>(self.operator->())->ListItems();
});
} // namespace relay
} // namespace tvm
/*!
* Copyright (c) 2018 by Contributors
* \file relay/backend/compile_engine.h
* \brief Internal compialtion engine handle function cache.
* and interface to low level code generation.
*/
#ifndef TVM_RELAY_BACKEND_COMPILE_ENGINE_H_
#define TVM_RELAY_BACKEND_COMPILE_ENGINE_H_
#include <tvm/lowered_func.h>
#include <tvm/relay/expr.h>
#include <string>
#include <functional>
namespace tvm {
namespace relay {
/*! \brief Node container to represent a cached function. */
struct CachedFuncNode : public Node {
/* \brief compiled target */
tvm::Target target;
/*! \brief Function name */
std::string func_name;
/* \brief The inputs to the function */
tvm::Array<Tensor> inputs;
/* \brief The outputs to the function */
tvm::Array<Tensor> outputs;
/*! \brief The lowered functions to support the function. */
tvm::Array<tvm::LoweredFunc> funcs;
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("target", &target);
v->Visit("func_name", &func_name);
v->Visit("inputs", &inputs);
v->Visit("outputs", &outputs);
v->Visit("funcs", &funcs);
}
static constexpr const char* _type_key = "relay.CachedFunc";
TVM_DECLARE_NODE_TYPE_INFO(CachedFuncNode, Node);
};
TVM_DEFINE_NODE_REF(CachedFunc, CachedFuncNode);
class CCacheKey;
/*! \brief Compile cache key */
class CCacheKeyNode : public Node {
public:
/*! \brief The source function to be lowered. */
Function source_func;
/*! \brief The hardware target.*/
Target target;
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("source_func", &source_func);
v->Visit("target", &target);
}
/*! \return The hash value of CCacheKey. */
inline size_t Hash() const;
/*!
* \brief check content equality
* \param other The other value.
* \return The result of equality check.
*/
inline bool Equal(const CCacheKeyNode* other) const;
/*!
* \brief create a cache key.
* \param source_func The source function.
* \param target The target device.
* \return the created key.
*/
TVM_DLL static CCacheKey make(Function source_func,
Target target);
static constexpr const char* _type_key = "relay.CCacheKey";
TVM_DECLARE_NODE_TYPE_INFO(CCacheKeyNode, tvm::Node);
private:
/*!
* \brief internal cached hash value.
*/
mutable size_t hash_{0};
};
/*! \brief cache entry used in compile engine */
class CCacheKey : public NodeRef {
public:
CCacheKey() {}
explicit CCacheKey(NodePtr<Node> n) : NodeRef(n) {}
const CCacheKeyNode* operator->() const {
return static_cast<CCacheKeyNode*>(node_.get());
}
// comparator
inline bool operator==(const CCacheKey& other) const {
CHECK(defined() && other.defined());
return (*this)->Equal(other.operator->());
}
using ContainerType = CCacheKeyNode;
};
/*! \brief Node container for compile cache. */
class CCacheValueNode : public Node {
public:
/*! \brief The corresponding function */
CachedFunc cached_func;
/*! \brief Result of Packed function generated by JIT */
PackedFunc packed_func;
/*! \brief usage statistics */
int use_count{0};
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("cached_func", &cached_func);
v->Visit("use_count", &use_count);
}
static constexpr const char* _type_key = "relay.CCacheValue";
TVM_DECLARE_NODE_TYPE_INFO(CCacheValueNode, tvm::Node);
};
/*! \brief cache entry used in compile engine */
class CCacheValue : public NodeRef {
public:
CCacheValue() {}
explicit CCacheValue(NodePtr<Node> n) : NodeRef(n) {}
CCacheValueNode* operator->() {
return static_cast<CCacheValueNode*>(node_.get());
}
const CCacheValueNode* operator->() const {
return static_cast<const CCacheValueNode*>(node_.get());
}
using ContainerType = CCacheValueNode;
};
/*!
* \brief Backend compilation engine for
* low level code generation.
*/
class CompileEngineNode : public Node {
public:
/*!
* \brief Get lowered result.
* \param key The key to the cached function.
* \return The result.
*/
virtual CachedFunc Lower(const CCacheKey& key) = 0;
/*!
* \brief Just in time compile to get a PackedFunc.
* \param key The key to the cached function.
* \return The result.
*/
virtual PackedFunc JIT(const CCacheKey& key) = 0;
/*! \brief clear the cache. */
virtual void Clear() = 0;
// VisitAttrs
void VisitAttrs(AttrVisitor*) final {}
static constexpr const char* _type_key = "relay.CompileEngine";
TVM_DECLARE_NODE_TYPE_INFO(CompileEngineNode, Node);
};
/*! \brier cache entry used in compile engine */
class CompileEngine : public NodeRef {
public:
CompileEngine() {}
explicit CompileEngine(NodePtr<Node> n) : NodeRef(n) {}
CompileEngineNode* operator->() {
return static_cast<CompileEngineNode*>(node_.get());
}
using ContainerType = CompileEngineNode;
/*! \brief The global compile engine. */
TVM_DLL static const CompileEngine& Global();
};
// implementations
inline size_t CCacheKeyNode::Hash() const {
if (hash_ != 0) return hash_;
// do structral hash, avoid 0.
hash_ = StructuralHash()(this->source_func);
hash_ = dmlc::HashCombine(
hash_, std::hash<std::string>()(target->str()));
if (hash_ == 0) hash_ = 1;
return hash_;
}
inline bool CCacheKeyNode::Equal(
const CCacheKeyNode* other) const {
if (Hash() != other->Hash()) return false;
return this->target->str() == other->target->str() &&
AlphaEqual(this->source_func, other->source_func);
}
} // namespace relay
} // namespace tvm
namespace std {
// overload hash
template<>
struct hash<::tvm::relay::CCacheKey> {
size_t operator()(const ::tvm::relay::CCacheKey& key) const {
CHECK(key.defined());
return key->Hash();
}
};
} // namespace std
#endif // TVM_RELAY_BACKEND_COMPILE_ENGINE_H_
......@@ -3,15 +3,12 @@
* \file src/tvm/relay/interpreter.cc
* \brief An interpreter for the Relay IR.
*/
#include <tvm/codegen.h>
#include <tvm/packed_func_ext.h>
#include <tvm/runtime/device_api.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/interpreter.h>
#include <tvm/relay/logging.h>
#include <tvm/relay/pass.h>
#include <tvm/relay/build_module.h>
#include "./ir/type_functor.h"
#include "compile_engine.h"
namespace tvm {
namespace relay {
......@@ -33,12 +30,12 @@ Closure ClosureNode::make(tvm::Map<Var, Value> env, Function func) {
}
TVM_REGISTER_API("relay._make.Closure")
.set_body([](TVMArgs args, TVMRetValue* ret) {
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = ClosureNode::make(args[0], args[1]);
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<ClosureNode>([](const ClosureNode* node, tvm::IRPrinter* p) {
.set_dispatch<ClosureNode>([](const ClosureNode* node, tvm::IRPrinter* p) {
p->stream << "ClosureNode(" << node->func << ")";
});
......@@ -49,13 +46,12 @@ TupleValue TupleValueNode::make(tvm::Array<Value> value) {
}
TVM_REGISTER_API("relay._make.TupleValue")
.set_body([](TVMArgs args, TVMRetValue* ret) {
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = TupleValueNode::make(args[0]);
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<TupleValueNode>([](const TupleValueNode* node,
tvm::IRPrinter* p) {
.set_dispatch<TupleValueNode>([](const TupleValueNode* node, tvm::IRPrinter* p) {
p->stream << "TupleValueNode(" << node->fields << ")";
});
......@@ -66,65 +62,18 @@ TensorValue TensorValueNode::make(runtime::NDArray data) {
}
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<TensorValueNode>([](const TensorValueNode* node,
tvm::IRPrinter* p) {
.set_dispatch<TensorValueNode>([](const TensorValueNode* node, tvm::IRPrinter* p) {
auto to_str = GetPackedFunc("relay._tensor_value_repr");
std::string data_str = to_str(GetRef<TensorValue>(node));
p->stream << "TensorValueNode(" << data_str << ")";
});
TensorValue TensorValueNode::FromType(const Type& t) {
if (auto tt_node = t.as<TensorTypeNode>()) {
std::vector<int64_t> dims;
for (auto dim : tt_node->shape) {
auto int_node = dim.as<tvm::ir::IntImm>();
CHECK(int_node) << "expected concrete dimensions";
dims.push_back(int_node->value);
}
DLDataType dtype;
DLContext context;
switch (tt_node->dtype.code()) {
case halideir_type_int:
dtype.code = kDLInt;
break;
case halideir_type_uint:
dtype.code = kDLUInt;
break;
case halideir_type_float:
dtype.code = kDLFloat;
break;
default:
throw dmlc::Error("can not convert HalideIR type into DLTensor dtype");
}
dtype.bits = tt_node->dtype.bits();
dtype.lanes = tt_node->dtype.lanes();
// TODO(@jroesch): Is this the right place to place the tensor?
context.device_type = DLDeviceType::kDLCPU;
context.device_id = 0;
runtime::NDArray data = NDArray::Empty(dims, dtype, context);
return TensorValueNode::make(data);
} else {
LOG(FATAL) << "expected a tensor type";
return TensorValue();
}
}
TVM_REGISTER_API("relay._make.TensorValue")
.set_body([](TVMArgs args, TVMRetValue* ret) {
.set_body([](TVMArgs args, TVMRetValue* ret) {
runtime::NDArray data = args[0];
*ret = TensorValueNode::make(data);
});
/* Evaluator Implementation. */
struct EvalError : dmlc::Error {
explicit EvalError(const std::string& msg) : Error(msg) {}
};
/*!
* \brief A stack frame in the Relay interpreter.
*
......@@ -175,70 +124,67 @@ struct Stack {
};
};
/*! \brief The equal comparator for expressions. */
struct ExprEqual {
bool operator()(const Expr& a, const Expr& b) const {
return AlphaEqual(a, b);
// NOTE: the current interpreter assumes A-normal form.
// which is better for execution.
//
// It will run duplicated computations when taking program that
// contains DAG in dataflow-form.
// Conversion to ANF is recommended before running the interpretation.
//
class Interpreter :
public ExprFunctor<Value(const Expr& n)> {
public:
Interpreter(Module mod,
DLContext context,
Target target)
: mod_(mod), context_(context), target_(target) {
engine_ = CompileEngine::Global();
}
};
struct Interpreter : ExprFunctor<Value(const Expr& n)> {
Module mod;
Stack stack;
using JitKey = Function;
using OpMap = std::unordered_map<JitKey, PackedFunc, StructuralHash, ExprEqual>;
OpMap operator_map_;
template <typename T>
T with_frame(const Frame& fr, const std::function<T()>& f) {
Stack::LocalFrame lf(stack, fr);
T WithFrame(const Frame& fr, const std::function<T()>& f) {
Stack::LocalFrame lf(stack_, fr);
return f();
}
Interpreter(Module mod) : mod(mod), operator_map_() {}
Interpreter(Module mod, OpMap operator_map) : mod(mod), operator_map_(operator_map) {}
void extend(const Var& id, Value v) {
this->stack.current_frame().locals.Set(id, v);
stack_.current_frame().locals.Set(id, v);
}
inline Value Lookup(const Var& local) {
return this->stack.Lookup(local);
return stack_.Lookup(local);
}
Value Eval(const Expr& expr) {
return (*this)(expr);
}
Value VisitExpr(const Expr& expr) override {
RELAY_LOG(INFO) << "VisitExpr: " << expr << std::endl;
Value VisitExpr(const Expr& expr) final {
auto ret = ExprFunctor<Value(const Expr& n)>::VisitExpr(expr);
return ret;
}
Value VisitExpr_(const VarNode* var_node) override {
Value VisitExpr_(const VarNode* var_node) final {
return Lookup(GetRef<Var>(var_node));
}
Value VisitExpr_(const GlobalVarNode* op) override {
return Eval(this->mod->Lookup(GetRef<GlobalVar>(op)));
Value VisitExpr_(const GlobalVarNode* op) final {
return Eval(mod_->Lookup(GetRef<GlobalVar>(op)));
}
Value VisitExpr_(const OpNode* id) override {
// TODO(@jroesch): Eta-expand and return in this case.
throw EvalError(
"internal error, need to wrap intrinsic into call synthetic call node "
"in "
"this case, eta expand");
LOG(FATAL) << "internal error, need to wrap intrinsic into call synthetic call node "
<< "in "
<< "this case, eta expand";
return Value();
}
Value VisitExpr_(const ConstantNode* op) override {
return TensorValueNode::make(op->data);
Value VisitExpr_(const ConstantNode* op) final {
return TensorValueNode::make(op->data.CopyTo(context_));
}
Value VisitExpr_(const TupleNode* op) override {
Value VisitExpr_(const TupleNode* op) final {
std::vector<Value> values;
for (const auto& field : op->fields) {
......@@ -249,7 +195,7 @@ struct Interpreter : ExprFunctor<Value(const Expr& n)> {
return TupleValueNode::make(values);
}
Value VisitExpr_(const FunctionNode* func_node) override {
Value VisitExpr_(const FunctionNode* func_node) final {
auto func = GetRef<Function>(func_node);
tvm::Map<Var, Value> captured_mod;
Array<Var> free_vars = FreeVars(func);
......@@ -261,50 +207,108 @@ struct Interpreter : ExprFunctor<Value(const Expr& n)> {
return ClosureNode::make(captured_mod, func);
}
inline Value InvokeCompiledOp(PackedFunc func, const Array<Value>& args,
Type ret_type) {
Value InvokePrimitiveOp(Function func,
const Array<Value>& args) {
// Marshal the arguments.
auto arg_len = args.size() + 1;
// Handle tuple input/output by flattening them.
size_t arg_len = 0;
for (size_t i = 0; i < args.size(); i++) {
if (args[i].as<TensorValueNode>()) {
++arg_len;
} else {
const auto* tvalue = args[i].as<TupleValueNode>();
arg_len += tvalue->fields.size();
}
}
size_t num_inputs = arg_len;
if (const auto* tuple_type = func->body->checked_type().as<TupleTypeNode>()) {
arg_len += tuple_type->fields.size();
} else {
CHECK(func->body->checked_type().as<TensorTypeNode>());
arg_len += 1;
}
std::vector<TVMValue> values(arg_len);
std::vector<int> codes(arg_len);
TVMArgsSetter setter(values.data(), codes.data());
TVMRetValue ret;
// We need real type information to properly allocate the structure.
for (size_t i = 0; i < args.size(); i++) {
if (const TensorValueNode* tv = args[i].as<TensorValueNode>()) {
auto fset_input = [&](size_t i, Value val) {
const TensorValueNode* tv = val.as<TensorValueNode>();
CHECK(tv != nullptr) << "expect Tensor argument";
setter(i, tv->data);
DLContext arg_ctx = tv->data->ctx;
CHECK(arg_ctx.device_type == context_.device_type &&
arg_ctx.device_id == context_.device_id)
<< "Interpreter expect context to be "
<< context_ << ", but get " << arg_ctx;
};
if (func->params.size() == 1 &&
func->params[0]->checked_type().as<TupleTypeNode>()) {
// handle tuple input.
const TupleValueNode* tuple = args[0].as<TupleValueNode>();
CHECK(tuple);
for (size_t i = 0; i < tuple->fields.size(); ++i) {
fset_input(i, tuple->fields[i]);
}
} else {
CHECK_EQ(num_inputs, args.size());
// Decide the target context.
// Primitive functions always sit in the same context.
for (size_t i = 0; i < args.size(); i++) {
fset_input(i, args[i]);
}
}
// TVM's calling convention is that the final argument is the output
// buffer. To preserve the illusion of being a functional language
// we need to allocate space for the output buffer based on the
// return type.
CHECK(ret_type.as<TensorTypeNode>());
auto out_tensor = TensorValueNode::FromType(ret_type);
auto fset_output = [&](size_t i, Type val_type) {
const TensorTypeNode* rtype = val_type.as<TensorTypeNode>();
CHECK(rtype != nullptr);
// Allocate output tensor.
std::vector<int64_t> shape;
for (auto dim : rtype->shape) {
const auto* ivalue = as_const_int(dim);
CHECK(ivalue) << "expected concrete dimensions";
shape.push_back(ivalue[0]);
}
DLDataType dtype = Type2TVMType(rtype->dtype);
auto out_tensor = TensorValueNode::make(
NDArray::Empty(shape, dtype, context_));
setter(num_inputs + i, out_tensor->data);
return out_tensor;
};
setter(arg_len - 1, out_tensor->data);
func.CallPacked(TVMArgs(values.data(), codes.data(), arg_len), &ret);
PackedFunc packed_func = engine_->JIT(CCacheKeyNode::make(func, target_));
TVMRetValue rv;
if (const TupleTypeNode* rtype = func->body->checked_type().as<TupleTypeNode>()) {
Array<Value> fields;
for (size_t i = 0; i < rtype->fields.size(); ++i) {
fields.push_back(fset_output(i, rtype->fields[i]));
}
packed_func.CallPacked(TVMArgs(values.data(), codes.data(), arg_len), &rv);
return TupleValueNode::make(fields);
} else {
Value out_tensor = fset_output(0, func->body->checked_type());
packed_func.CallPacked(TVMArgs(values.data(), codes.data(), arg_len), &rv);
return out_tensor;
}
Value Invoke(const Closure& closure, const tvm::Array<Value>& args) {
// Get a reference to the function inside the closure.
auto func = closure->func;
auto compiled = operator_map_.find(func);
tvm::Array<Function> funcs;
for (auto op : operator_map_) {
funcs.push_back(op.first);
}
// This case we know we have precompiled the operator.
if (compiled != operator_map_.end()) {
auto func_ty = func->func_type_annotation();
return InvokeCompiledOp(compiled->second, args, func_ty->ret_type);
// Check if function is a primitive function.
bool IsPrimitive(const Function& func) const {
NodeRef res = FunctionGetAttr(func, "Primitive");
const ir::IntImm* pval = res.as<ir::IntImm>();
return pval && pval->value != 0;
}
// Invoke the closure
Value Invoke(const Closure& closure, const tvm::Array<Value>& args) {
// Get a reference to the function inside the closure.
if (IsPrimitive(closure->func)) {
return InvokePrimitiveOp(closure->func, args);
}
auto func = closure->func;
// Allocate a frame with the parameters and free variables.
tvm::Map<Var, Value> locals;
......@@ -321,15 +325,14 @@ struct Interpreter : ExprFunctor<Value(const Expr& n)> {
locals.Set((*it).first, (*it).second);
}
return with_frame<Value>(Frame(locals), [&]() { return Eval(func->body); });
return WithFrame<Value>(Frame(locals), [&]() { return Eval(func->body); });
}
Value VisitExpr_(const CallNode* call) override {
Value VisitExpr_(const CallNode* call) final {
tvm::Array<Value> args;
for (auto arg : call->args) {
args.push_back(Eval(arg));
}
// We should not find operators after running fusion,
// and operator lowering.
//
......@@ -340,26 +343,25 @@ struct Interpreter : ExprFunctor<Value(const Expr& n)> {
<< "; operators should be removed by future passes; try "
"fusing and lowering";
}
// Now we just evaluate and expect to find a closure.
Value fn_val = Eval(call->op);
if (const ClosureNode* closure_node = fn_val.as<ClosureNode>()) {
auto closure = GetRef<Closure>(closure_node);
return this->Invoke(closure, args);
} else {
throw EvalError(
"internal error: type error, expected function value in the call "
"position");
LOG(FATAL) << "internal error: type error, expected function value in the call "
<< "position";
return Value();
}
}
Value VisitExpr_(const LetNode* op) override {
Value VisitExpr_(const LetNode* op) final {
auto value = Eval(op->value);
this->extend(op->var, value);
return Eval(op->body);
}
Value VisitExpr_(const TupleGetItemNode* op) override {
Value VisitExpr_(const TupleGetItemNode* op) final {
Value val = Eval(op->tuple);
auto product_node = val.as<TupleValueNode>();
CHECK(product_node)
......@@ -369,64 +371,56 @@ struct Interpreter : ExprFunctor<Value(const Expr& n)> {
return product_node->fields[op->index];
}
Value VisitExpr_(const IfNode* op) override {
Value VisitExpr_(const IfNode* op) final {
Value v = Eval(op->cond);
if (const TensorValueNode* bv = v.as<TensorValueNode>()) {
DLContext cpu_ctx;
cpu_ctx.device_type = kDLCPU;
cpu_ctx.device_id = 0;
NDArray cpu_array = bv->data.CopyTo(cpu_ctx);
CHECK_EQ(TVMType2Type(cpu_array->dtype), Bool());
// TODO(@jroesch, @MK): Refactor code into helper from DCE.
if (reinterpret_cast<uint8_t*>(bv->data->data)[0]) {
if (reinterpret_cast<uint8_t*>(cpu_array->data)[0]) {
return Eval(op->true_branch);
} else {
return Eval(op->false_branch);
}
} else {
throw EvalError("type error, type system should have caught this");
LOG(FATAL) << "type error, type system should have caught this";
return Value();
}
}
};
Interpreter::OpMap CompileOperators(const Module& mod, const Expr& e) {
Interpreter::OpMap op_map;
auto lowered_ops = LowerOps(mod, e);
RELAY_LOG(INFO) << "LoweredFuncs: " << lowered_ops << std::endl;
if (lowered_ops.size()) {
const PackedFunc* fbuild_ptr = Registry::Get("relay.op.compiler._build");
CHECK(fbuild_ptr) << "Could not find registered function: relay.op.compiler._build";
auto fbuild = *fbuild_ptr;
// Collect the set of lowered functions to build a module.
Array<LoweredFunc> lowered_funcs;
for (auto lop : lowered_ops) {
lowered_funcs.push_back(lop->lowered_func);
}
runtime::Module module = fbuild(lowered_funcs);
// Loop over the lowered operations to map them into the operator map.
for (auto lop : lowered_ops) {
Function func = lop->func;
LoweredFunc lf = lop->lowered_func;
RELAY_LOG(INFO) << "LoweredFunc: " << lf->name << std::endl;
auto op_impl = module.GetFunction(lf->name);
op_map.insert({func, op_impl});
}
}
private:
// module
Module mod_;
// For simplicity we only run the interpreter on a single context.
// Context to run the interpreter on.
DLContext context_;
// Target parameter being used by the interpreter.
Target target_;
// value stack.
Stack stack_;
// Backend compile engine.
CompileEngine engine_;
};
return op_map;
}
Value Evaluate(Module mod, Expr e) {
auto op_map = CompileOperators(mod, e);
Interpreter interp(mod, op_map);
return interp.Eval(e);
TypedPackedFunc<Value(Expr)>
CreateInterpreter(
Module mod,
DLContext context,
Target target) {
auto intrp = std::make_shared<Interpreter>(mod, context, target);
auto packed = [intrp](Expr expr) {
return intrp->Eval(expr);
};
return TypedPackedFunc<Value(Expr)>(packed);
}
TVM_REGISTER_API("relay._interpreter.evaluate")
.set_body([](TVMArgs args, TVMRetValue* ret) {
Module mod = args[0];
Expr expr = args[1];
*ret = Evaluate(mod, expr);
TVM_REGISTER_API("relay.backend.CreateInterpreter")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = CreateInterpreter(args[0], args[1], args[2]);
});
} // namespace relay
} // namespace tvm
......@@ -34,10 +34,12 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
TensorType ConstantNode::tensor_type() const {
auto dtype = TVMType2Type(data->dtype);
Array<tvm::Expr> shape;
for (int i = 0; i < data->ndim; i++) {
shape.push_back(tvm::ir::IntImm::make(HalideIR::Int(64), data->shape[i]));
CHECK_LE(data->shape[i], std::numeric_limits<int32_t>::max());
CHECK_GE(data->shape[i], std::numeric_limits<int32_t>::min());
shape.push_back(
tvm::ir::IntImm::make(Int(32), data->shape[i]));
}
return TensorTypeNode::make(shape, dtype);
......
......@@ -67,13 +67,15 @@ const GenericOpMap& Op::GetGenericAttr(const std::string& key) {
return *it->second.get();
}
void OpRegistry::UpdateAttr(const std::string& key, TVMRetValue value,
void OpRegistry::UpdateAttr(const std::string& key,
TVMRetValue value,
int plevel) {
OpManager* mgr = OpManager::Global();
std::lock_guard<std::mutex> lock(mgr->mutex);
std::unique_ptr<GenericOpMap>& op_map = mgr->attr[key];
if (op_map == nullptr) {
op_map.reset(new GenericOpMap());
op_map->attr_name_ = key;
}
uint32_t index = op_->index_;
if (op_map->data_.size() <= index) {
......@@ -112,7 +114,7 @@ TVM_REGISTER_API("relay.op._OpGetAttr")
});
TVM_REGISTER_API("relay.op._Register")
.set_body([](TVMArgs args, TVMRetValue* rv) {
.set_body([](TVMArgs args, TVMRetValue* rv) {
std::string op_name = args[0];
std::string attr_key = args[1];
runtime::TVMArgValue value = args[2];
......
......@@ -271,7 +271,7 @@ class TextPrinter :
TextValue VisitExpr_(const FunctionNode* op) final {
TextValue id = AllocTempVar();
std::ostringstream os;
os << id << " = function";
os << id << " = fn";
this->PrintFuncInternal(os.str(), GetRef<Function>(op));
this->PrintEndInst("\n");
return id;
......@@ -516,11 +516,14 @@ class TextPrinter :
stream_ << ",\n";
}
}
stream_ << ") ";
stream_ << ')';
if (fn->ret_type.defined()) {
stream_ << " -> ";
stream_ << '\n';
this->PrintIndent(decl_indent);
stream_ << "-> ";
this->PrintType(fn->ret_type, stream_);
}
stream_ << ' ';
this->PrintScope(fn->body);
}
/*!
......
......@@ -9,6 +9,7 @@
#include <tvm/relay/expr.h>
#include <tvm/relay/op.h>
#include <tvm/relay/op_attr_types.h>
#include <vector>
namespace tvm {
......@@ -44,7 +45,8 @@ std::vector<T> AsVector(const Array<T> &array) {
}); \
RELAY_REGISTER_OP(OpName) \
.set_num_inputs(1) \
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("data", "Tensor", "The input tensor.") \
.set_attr<TOpPattern>("TOpPattern", kElemWise)
/*! Quick helper macro
* - Expose a positional make function to construct the node.
......@@ -68,7 +70,8 @@ std::vector<T> AsVector(const Array<T> &array) {
.set_num_inputs(2) \
.add_argument("lhs", "Tensor", "The left hand side tensor.") \
.add_argument("rhs", "Tensor", "The right hand side tensor.") \
.add_type_rel("Broadcast", BroadcastRel)
.add_type_rel("Broadcast", BroadcastRel) \
.set_attr<TOpPattern>("TOpPattern", kBroadcast)
} // namespace relay
} // namespace tvm
......
......@@ -46,7 +46,7 @@ RELAY_REGISTER_BINARY_OP("relay.op._make.", "multiply")
.describe("Elementwise multiply with broadcasting")
.set_support_level(1);
RELAY_REGISTER_BINARY_OP("relay.op._make.", "pow")
RELAY_REGISTER_BINARY_OP("relay.op._make.", "power")
.describe("Elementwise power with broadcasting")
.set_support_level(4);
......@@ -65,7 +65,8 @@ RELAY_REGISTER_BINARY_OP("relay.op._make.", "mod")
.set_num_inputs(2) \
.add_argument("lhs", "Tensor", "The left hand side tensor.") \
.add_argument("rhs", "Tensor", "The right hand side tensor.") \
.add_type_rel("BroadcastComp", BroadcastCompRel)
.add_type_rel("BroadcastComp", BroadcastCompRel) \
.set_attr<TOpPattern>("TOpPattern", kBroadcast)
RELAY_REGISTER_CMP_OP("equal")
.describe("Elementwise equal compare with broadcasting")
......
......@@ -3,32 +3,32 @@
*
* \file src/tvm/relay/pass/fuse_ops.cc
*
* \brief Fuse Relay eligble sequences of Relay operators into a single one.
*
* \brief This is a backend-aware optimization pass.
* Fuse necessary ops into a single one.
*/
#include <tvm/ir_operator.h>
#include <tvm/relay/pass.h>
#include <tvm/runtime/module.h>
#include <tvm/lowered_func.h>
#include <tvm/operation.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/logging.h>
#include "../ir/type_functor.h"
namespace tvm {
namespace relay {
using namespace runtime;
struct AbstractFusableOps : ExprMutator {
Module mod;
Array<GlobalVar> fusable_funcs;
int counter = 0;
size_t expr_hash;
AbstractFusableOps(Module mod, size_t expr_hash) : mod(mod), expr_hash(expr_hash) {}
// Simple fuser that only makes each operator function as primitive.
class SimpleFuser : public ExprMutator {
public:
// Skip primitive function.
Expr VisitExpr_(const FunctionNode* fn_node) {
NodeRef res = FunctionGetAttr(GetRef<Function>(fn_node), "Primitive");
const ir::IntImm* pval = res.as<ir::IntImm>();
if (pval && pval->value != 0) {
return GetRef<Expr>(fn_node);
} else {
return ExprMutator::VisitExpr_(fn_node);
}
}
Expr VisitExpr_(const CallNode* call) {
if (auto op_node = call->op.as<OpNode>()) {
if (call->op.as<OpNode>()) {
// Placeholder fusion algorithm which abstracts
// single definitions into functions only.
Array<Var> params;
......@@ -37,50 +37,37 @@ struct AbstractFusableOps : ExprMutator {
int param_number = 0;
for (auto arg : call->args) {
auto name = std::string("p") + std::to_string(param_number++);
std::ostringstream os;
os << "p" << param_number++;
auto type = arg->checked_type();
auto var = VarNode::make(name, type);
auto var = VarNode::make(os.str(), type);
params.push_back(var);
inner_args.push_back(var);
args.push_back(VisitExpr(arg));
args.push_back(this->Mutate(arg));
}
auto body = CallNode::make(call->op, inner_args, call->attrs);
auto func = FunctionNode::make(params, body, call->checked_type(), {});
auto func = FunctionNode::make(
params, body, call->checked_type(), {});
func = FunctionSetAttr(func, "Primitive", tvm::Integer(1));
std::string func_name = "fused_";
func_name += op_node->name;
func_name += "_";
func_name += std::to_string(counter++);
func_name += "_";
func_name += std::to_string(expr_hash);
auto gv = GlobalVarNode::make(func_name);
mod->Add(gv, func);
fusable_funcs.push_back(gv);
return CallNode::make(gv, args, Attrs());
return CallNode::make(func, args, Attrs());
} else {
return ExprMutator::VisitExpr_(call);
}
}
};
Expr FuseOps(const Module& mod, const Expr& e) {
Expr FuseOps(const Expr& expr) {
// First we convert all chains of fusable ops into
// abstracted functions which we mark as primtive
// then we convert these primtive functions into
// new operators.
auto abstract = AbstractFusableOps(mod, StructuralHash()(e));
auto abstracted_e = abstract.VisitExpr(e);
RELAY_LOG(INFO) << "FuseOps: before=" << e
<< "Fuse: after=" << abstracted_e;
return abstracted_e;
return SimpleFuser().Mutate(expr);
}
TVM_REGISTER_API("relay._ir_pass.FuseOps")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = FuseOps(args[1], args[0]);
*ret = FuseOps(args[0]);
});
} // namespace relay
} // namespace tvm
/*!
* Copyright (c) 2018 by Contributors
*
* \file src/tvm/relay/pass/lower_ops.cc
*
* \brief Lower a Relay program to set of TVM operators.
*
*/
#include <tvm/lowered_func.h>
#include <tvm/operation.h>
#include <tvm/build_module.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/logging.h>
#include <tvm/relay/pass.h>
#include <tvm/runtime/module.h>
#include <tvm/relay/build_module.h>
#include "../ir/type_functor.h"
namespace tvm {
namespace relay {
using namespace runtime;
LoweredOp LoweredOpNode::make(Function func, LoweredFunc lowered_func) {
auto node = make_node<LoweredOpNode>();
node->func = func;
node->lowered_func = lowered_func;
return LoweredOp(node);
}
struct AbstractLocalFunctions : ExprMutator {
Module mod;
size_t expr_hash;
int counter = 0;
std::unordered_set<GlobalVar, NodeHash, NodeEqual> visited_funcs;
explicit AbstractLocalFunctions(Module mod)
: mod(mod), expr_hash(0), counter(0), visited_funcs() {}
Expr Abstract(const Expr& e) {
expr_hash = StructuralHash()(e);
return VisitExpr(e);
}
Expr VisitExpr_(const GlobalVarNode* gvar_node) final {
auto gvar = GetRef<GlobalVar>(gvar_node);
auto it = visited_funcs.find(gvar);
if (it == visited_funcs.end()) {
auto func = mod->Lookup(gvar);
visited_funcs.insert(gvar);
auto new_func = FunctionNode::make(
func->params,
VisitExpr(func->body),
func->ret_type,
func->type_params,
func->attrs);
mod->Update(gvar, new_func);
}
return gvar;
}
Expr VisitExpr_(const FunctionNode* func_node) final {
Function func = GetRef<Function>(func_node);
auto free_vars = FreeVars(func);
Array<Var> params;
for (auto free_var : free_vars) {
auto var = VarNode::make("free_var", free_var->checked_type());
params.push_back(var);
}
std::string abs_func = "abstracted_func_";
abs_func += std::to_string(counter++);
abs_func += std::to_string(expr_hash);
auto gv = GlobalVarNode::make(abs_func);
auto lifted_func = FunctionNode::make(params, func, Type(), {}, {});
mod->Add(gv, lifted_func);
Array<Expr> args;
for (auto free_var : free_vars) {
args.push_back(free_var);
}
return CallNode::make(gv, args, {});
}
};
struct LiveFunctions : ExprVisitor {
Module mod;
explicit LiveFunctions(Module mod) : mod(mod), global_funcs() {}
std::unordered_set<GlobalVar, NodeHash, NodeEqual> visited_funcs;
std::unordered_set<GlobalVar, NodeHash, NodeEqual> global_funcs;
void Live(const Expr& e) {
CHECK(!e.as<FunctionNode>())
<< "functions should of been transformed away by previous pass";
VisitExpr(e);
}
void VisitExpr_(const FunctionNode* func_node) {
LOG(FATAL) << "functions should of been transformed away by previous pass";
}
void VisitExpr_(const GlobalVarNode* var_node) final {
GlobalVar var = GetRef<GlobalVar>(var_node);
auto it = visited_funcs.find(var);
if (it == visited_funcs.end()) {
auto func = mod->Lookup(var);
visited_funcs.insert(var);
// The last pass has trasnformed functions of the form:
//
// let x = fn (p_1, ..., p_n) { ... };
// ...
//
// into, a top-level declaration:
//
// def abs_f(fv_1, ..., fv_n) {
// return (fn (p_1...,p_N) { ... };)
// }
//
// and:
//
// let x = abs_f(fv_1, ... fv_n);
//
// The only other case we can handle is
//
// fn foo(...) { body }
//
// We just search through the body in this case.
if (auto inner_func = func->body.as<FunctionNode>()) {
return VisitExpr(inner_func->body);
} else {
return VisitExpr(func->body);
}
}
}
void VisitExpr_(const CallNode* call) final {
RELAY_LOG(INFO) << "LiveOps: CallNode=" << GetRef<Call>(call);
if (auto gv_node = call->op.as<GlobalVarNode>()) {
GlobalVar gvar = GetRef<GlobalVar>(gv_node);
Function func = mod->Lookup(gvar);
auto attr = FunctionGetAttr(func, "Primitive");
if (attr.defined() && Downcast<Integer>(attr)->value == 1) {
global_funcs.insert(gvar);
} else {
VisitExpr(gvar);
}
// Finally we need to ensure to visit all the args no matter what.
for (auto arg : call->args) {
VisitExpr(arg);
}
} else {
return ExprVisitor::VisitExpr_(call);
}
}
};
using FCompute = TypedPackedFunc<Array<Tensor>(
const Attrs&, const Array<Tensor>&, Type, tvm::Target)>;
using FSchedule = TypedPackedFunc<Schedule(const Array<Tensor>&, tvm::Target)>;
/*! \brief Return the set of operators in their TVM format. */
Array<LoweredOp> LowerOps(const Module& mod, const Expr& e,
const std::string& target) {
RELAY_LOG(INFO) << "LowerOps: e=" << e;
auto flower_ptr = Registry::Get("relay.op.compiler._lower");
CHECK(flower_ptr);
PackedFunc flower = *flower_ptr;
auto abstracted_e = AbstractLocalFunctions(mod).Abstract(e);
auto live_funcs = LiveFunctions(mod);
live_funcs.VisitExpr(abstracted_e);
auto schedule_reg = Op::GetAttr<FSchedule>("FTVMSchedule");
auto compute_reg = Op::GetAttr<FCompute>("FTVMCompute");
Array<LoweredOp> lowered_funcs;
for (auto func_name : live_funcs.global_funcs) {
auto func = mod->Lookup(func_name);
auto call = Downcast<Call>(func->body);
auto op_node = call->op.as<OpNode>();
CHECK(op_node) << "violated invariant that primtive calls contain a single op call";
auto op = GetRef<Op>(op_node);
RELAY_LOG(INFO) << "LowerOps: Lowering " << op->name;
CHECK(IsPrimitiveOp(op)) << "failed to lower "
<< op->name << "can only lower primitve operations";
Array<Tensor> inputs;
std::string input_name = "in";
int i = 0;
for (auto type_arg : call->type_args) {
auto tt = Downcast<TensorType>(type_arg);
inputs.push_back(PlaceholderOpNode::make(input_name + std::to_string(i),
tt->shape, tt->dtype)
.output(0));
i++;
}
auto output_tt = call->checked_type();
auto target_node = Target::create(target);
Array<Tensor> outputs =
compute_reg[op](call->attrs, inputs, output_tt, target_node);
auto schedule = schedule_reg[op](outputs, target_node);
size_t hash = StructuralHash()(func);
LoweredFunc lf =
flower(op->name + std::to_string(hash), schedule, inputs, outputs);
func = FunctionSetAttr(func, "LoweredFunc", lf);
mod->Add(func_name, func, true);
lowered_funcs.push_back(LoweredOpNode::make(func, lf));
}
return lowered_funcs;
}
TVM_REGISTER_API("relay._ir_pass.LowerOps")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = LowerOps(args[0], args[1], args[2]);
});
} // namespace relay
} // namespace tvm
......@@ -22,27 +22,6 @@
namespace tvm {
namespace runtime {
/*!
* \brief The name of Device API factory.
* \param type The device type.
*/
inline std::string DeviceName(int type) {
switch (type) {
case kDLCPU: return "cpu";
case kDLGPU: return "gpu";
case kDLOpenCL: return "opencl";
case kDLSDAccel: return "sdaccel";
case kDLAOCL: return "aocl";
case kDLVulkan: return "vulkan";
case kDLMetal: return "metal";
case kDLVPI: return "vpi";
case kDLROCM: return "rocm";
case kOpenGL: return "opengl";
case kDLExtDev: return "ext_dev";
default: LOG(FATAL) << "unknown type =" << type; return "Unknown";
}
}
class DeviceAPIManager {
public:
static const int kMaxDeviceAPI = 32;
......
......@@ -187,8 +187,8 @@ void GraphRuntime::SetupStorage() {
CHECK_GE(storage_id, 0) << "Do not support runtime shape op";
DLDataType t = vtype[i];
size_t bits = t.bits * t.lanes;
CHECK_EQ(bits % 8U, 0U);
size_t bytes = (bits / 8U) * size;
CHECK(bits % 8U == 0U || bits ==1U);
size_t bytes = ((bits + 7U) / 8U) * size;
uint32_t sid = static_cast<uint32_t>(storage_id);
if (sid >= pool_entry.size()) {
......
import tvm
import tvm.testing
import numpy as np
from tvm import relay
def test_compile_engine():
engine = relay.backend.compile_engine.get()
def get_func(shape):
x = relay.var("x", shape=shape)
y = relay.add(x, x)
z = relay.add(y, x)
f = relay.ir_pass.infer_type(relay.Function([x], z))
return f
z1 = engine.lower(get_func((10,)), "llvm")
z2 = engine.lower(get_func((10,)), "llvm")
z3 = engine.lower(get_func(()), "llvm")
assert z1.same_as(z2)
assert not z3.same_as(z1)
if tvm.context("cuda").exist:
z4 = engine.lower(get_func(()), "cuda")
assert not z3.same_as(z4)
# Test JIT target
for target in ["llvm"]:
ctx = tvm.context(target)
if ctx.exist:
f = engine.jit(get_func((10,)), target)
x = tvm.nd.array(np.ones(10).astype("float32"), ctx=ctx)
y = tvm.nd.empty((10,), ctx=ctx)
f(x, y)
tvm.testing.assert_allclose(
y.asnumpy(), x.asnumpy() * 3)
engine.dump()
if __name__ == "__main__":
test_compile_engine()
import numpy as np
from tvm import relay
from tvm.relay import create_executor
from tvm.relay.ir_pass import infer_type
from tvm.relay.interpreter import Interpreter
from tvm.relay.scope_builder import ScopeBuilder
from tvm.relay.op import add
from tvm.relay.module import Module
......@@ -25,8 +23,8 @@ def check_rts(expr, args, expected_result, mod=None):
expected_result:
The expected result of running the expression.
"""
intrp = create_executor('graph', mod=mod)
graph = create_executor('graph', mod=mod)
intrp = relay.create_executor('debug', mod=mod)
graph = relay.create_executor('graph', mod=mod)
eval_result = intrp.evaluate(expr)(*args)
rts_result = graph.evaluate(expr)(*args)
np.testing.assert_allclose(eval_result.asnumpy(), rts_result.asnumpy())
......
import numpy as np
import tvm
import tvm.testing
from tvm import relay
from tvm.relay.interpreter import Value, TupleValue
from tvm.relay import op
from tvm.relay.backend.interpreter import Value, TupleValue
from tvm.relay.scope_builder import ScopeBuilder
from tvm.relay import testing, create_executor
def check_eval(expr, args, expected_result, mod=None, rtol=1e-07):
intrp = create_executor(mod=mod)
# TODO(tqchen) add more types once the schedule register is fixed.
for target in ["llvm"]:
ctx = tvm.context(target, 0)
if not ctx.exist:
return
intrp = create_executor(mod=mod, ctx=ctx, target=target)
result = intrp.evaluate(expr)(*args)
np.testing.assert_allclose(result.asnumpy(), expected_result, rtol=rtol)
# use tvm.testing which also set atol
tvm.testing.assert_allclose(
result.asnumpy(), expected_result, rtol=rtol)
def test_from_scalar():
......@@ -34,7 +41,7 @@ def test_id():
def test_add_const():
two = op.add(relay.const(1), relay.const(1))
two = relay.add(relay.const(1), relay.const(1))
func = relay.Function([], two)
check_eval(func, [], 2)
......@@ -42,7 +49,7 @@ def test_add_const():
def test_mul_param():
x = relay.var('x', shape=(10, 10))
y = relay.var('y', shape=(1, 10))
func = relay.Function([x, y], op.multiply(x, y))
func = relay.Function([x, y], relay.multiply(x, y))
x_data = np.random.rand(10, 10).astype('float32')
y_data = np.random.rand(1, 10).astype('float32')
check_eval(func, [x_data, y_data], x_data * y_data)
......@@ -53,7 +60,7 @@ def test_mul_param():
# def test_dense():
# x = relay.var('x', shape=(10, 10))
# w = relay.var('w', shape=(10, 10))
# y = op.nn.dense(x, w)
# y = relay.nn.dense(x, w)
# func = relay.Function([x, w], y)
# x_data = np.random.rand(10, 10).astype('float32')
# w_data = np.random.rand(10, 10).astype('float32')
......@@ -63,7 +70,7 @@ def test_mul_param():
# x = relay.var('x', shape=(10, 10))
# w = relay.var('w', shape=(10, 10))
# b = relay.var('b', shape=(10,))
# y = op.add(op.nn.dense(x, w), b)
# y = relay.add(relay.nn.dense(x, w), b)
# func = relay.Function([x, w, b], y)
# x_data = np.random.rand(10, 10).astype('float32')
# w_data = np.random.rand(10, 10).astype('float32')
......@@ -73,46 +80,49 @@ def test_mul_param():
def test_equal():
i = relay.var('i', shape=[], dtype='int32')
j = relay.var('i', shape=[], dtype='int32')
z = op.equal(i, j)
z = relay.equal(i, j)
func = relay.Function([i, j], z, ret_type=relay.TensorType([], 'bool'))
i_data = relay.const(0)
j_data = relay.const(0)
check_eval(func, [i_data, j_data], True)
def test_subtract():
i = relay.var('i', shape=[], dtype='int32')
sub = op.subtract(i, relay.const(1, dtype='int32'))
sub = relay.subtract(i, relay.const(1, dtype='int32'))
func = relay.Function([i], sub, ret_type=relay.TensorType([], 'int32'))
i_data = np.array(1, dtype='int32')
check_eval(func, [i_data], 0)
def test_simple_loop():
mod = relay.module.Module({})
sum_up = relay.GlobalVar('sum_up')
i = relay.var('i', shape=[], dtype='int32')
sb = ScopeBuilder()
with sb.if_scope(op.equal(i, relay.const(0, dtype='int32'))):
with sb.if_scope(relay.equal(i, relay.const(0, dtype='int32'))):
sb.ret(i)
with sb.else_scope():
one_less = op.subtract(i, relay.const(1, dtype='int32'))
one_less = relay.subtract(i, relay.const(1, dtype='int32'))
rec_call = relay.Call(sum_up, [one_less])
sb.ret(op.add(rec_call, i))
sb.ret(relay.add(rec_call, i))
func = relay.Function([i], sb.get(), ret_type=relay.TensorType([], 'int32'))
mod[sum_up] = func
i_data = np.array(10, dtype='int32')
check_eval(sum_up, [i_data], sum(range(1, 11)), mod=mod)
def test_loop():
mod = relay.module.Module({})
sum_up = relay.GlobalVar('sum_up')
i = relay.var('i', shape=[], dtype='int32')
accum = relay.var('accum', shape=[], dtype='int32')
sb = ScopeBuilder()
with sb.if_scope(op.equal(i, relay.const(0))):
with sb.if_scope(relay.equal(i, relay.const(0))):
sb.ret(accum)
with sb.else_scope():
one_less = op.subtract(i, relay.const(1))
new_accum = op.add(accum, i)
one_less = relay.subtract(i, relay.const(1))
new_accum = relay.add(accum, i)
sb.ret(relay.Call(sum_up, [one_less, new_accum]))
func = relay.Function([i, accum], sb.get())
mod[sum_up] = func
......@@ -120,19 +130,21 @@ def test_loop():
accum_data = np.array(0, dtype='int32')
check_eval(sum_up, [i_data, accum_data], sum(range(1, 11)), mod=mod)
def test_mlp():
pass
# net = testing.mlp.get_workload(1)
# import pdb; pdb.set_trace()
def test_binds():
x = relay.var("x")
y = relay.add(x, x)
intrp = create_executor("debug")
xx = np.ones((10, 20))
res = intrp.evaluate(y, binds={x: xx}).asnumpy()
tvm.testing.assert_allclose(xx + xx, res)
if __name__ == "__main__":
test_id()
test_add_const()
# test_dense()
# test_linear()
test_equal()
test_subtract()
test_simple_loop()
test_loop()
test_mlp()
test_binds()
......@@ -2,7 +2,7 @@ import math
import tvm
import numpy as np
from tvm import relay
from tvm.relay.interpreter import create_executor
from tvm.relay.testing import ctx_list
def sigmoid(x):
one = np.ones_like(x)
......@@ -27,11 +27,16 @@ def test_unary_op():
if ref is not None:
data = np.random.rand(*shape).astype(dtype)
intrp = create_executor()
op_res = intrp.evaluate(y, { x: relay.const(data) })
ref_res = ref(data)
func = relay.Function([x], y)
for target, ctx in ctx_list():
# use graph by execuor default for testing, as we need
# create function explicitly to avoid constant-folding.
intrp = relay.create_executor("graph", ctx=ctx, target=target)
op_res = intrp.evaluate(func)(data)
np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=0.01)
for opfunc, ref in [(tvm.relay.log, np.log),
(tvm.relay.exp, np.exp),
(tvm.relay.sqrt, np.sqrt),
......@@ -67,14 +72,17 @@ def test_binary_op():
z = opfunc(x, y)
x_data = np.random.rand(5, 10, 5).astype(t1.dtype)
y_data = np.random.rand(5, 10, 5).astype(t2.dtype)
intrp = create_executor()
op_res = intrp.evaluate(z, { x: relay.const(x_data), y: relay.const(y_data) })
ref_res = ref(x_data, y_data)
func = relay.Function([x, y], z)
for target, ctx in ctx_list():
# use graph by execuor default for testing, as we need
# create function explicitly to avoid constant-folding.
intrp = relay.create_executor("graph", ctx=ctx, target=target)
op_res = intrp.evaluate(func)(x_data, y_data)
np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=0.01)
for opfunc, ref in [(relay.add, np.add),
(relay.subtract, np.subtract),
(relay.mod, np.mod),
(relay.multiply, np.multiply),
(relay.divide, np.divide)]:
check_binary_op(opfunc, ref)
......@@ -116,7 +124,7 @@ def test_log_softmax():
assert yy.checked_type == relay.TensorType((n, d))
def test_concatenate_infer_type():
def test_concatenate():
n, t, d = tvm.var("n"), tvm.var("t"), 100
x = relay.var("x", shape=(n, t, d))
y = relay.var("y", shape=(n, t, d))
......@@ -134,15 +142,23 @@ def test_concatenate_infer_type():
zz = relay.ir_pass.infer_type(z)
assert zz.checked_type == relay.TensorType((n, t + t, 100))
# x = relay.var("x", shape=(10, 5))
# y = relay.var("y", shape=(10, 5))
# z = relay.concatenate((x, y), axis=1)
# intrp = create_executor()
# x_data = np.random.rand(10, 5).astype('float32')
# y_data = np.random.rand(10, 5).astype('float32')
# op_res = intrp.evaluate(z, { x: relay.const(x_data), y: relay.const(y_data) })
# ref_res = np.concatenate(x_data, y_data, axis=1)
# np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=0.01)
x = relay.var("x", shape=(10, 5))
y = relay.var("y", shape=(10, 5))
z = relay.concatenate((x, y), axis=1)
# Check result.
func = relay.Function([x, y], z)
x_data = np.random.rand(10, 5).astype('float32')
y_data = np.random.rand(10, 5).astype('float32')
ref_res = np.concatenate((x_data, y_data), axis=1)
for target, ctx in ctx_list():
intrp1 = relay.create_executor("graph", ctx=ctx, target=target)
intrp2 = relay.create_executor("debug", ctx=ctx, target=target)
op_res1 = intrp1.evaluate(func)(x_data, y_data)
tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=0.01)
op_res2 = intrp2.evaluate(func)(x_data, y_data)
tvm.testing.assert_allclose(op_res2.asnumpy(), ref_res, rtol=0.01)
def test_dropout():
n, t, d = tvm.var("n"), tvm.var("t"), tvm.var("d")
......@@ -206,7 +222,7 @@ if __name__ == "__main__":
test_unary_op()
test_binary_op()
test_expand_dims_infer_type()
test_concatenate_infer_type()
test_concatenate()
test_softmax()
test_log_softmax()
test_dropout()
......
import tvm
import numpy as np
from tvm import relay
from tvm.relay import create_executor
from tvm.relay.testing import ctx_list
def test_binary_op():
......@@ -24,12 +24,15 @@ def test_binary_op():
z = opfunc(x, y)
x_data = np.random.rand(5, 10, 5).astype(t1.dtype)
y_data = np.random.rand(5, 10, 5).astype(t2.dtype)
intrp = create_executor()
op_res = intrp.evaluate(z, { x: relay.const(x_data), y: relay.const(y_data) })
ref_res = ref(x_data, y_data)
np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=0.01)
func = relay.Function([x, y], z)
for opfunc, ref in [(relay.pow, np.power)]:
for target, ctx in ctx_list():
intrp = relay.create_executor("graph", ctx=ctx, target=target)
op_res = intrp.evaluate(func)(x_data, y_data)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res)
for opfunc, ref in [(relay.power, np.power)]:
check_binary_op(opfunc, ref)
......@@ -57,15 +60,19 @@ def test_cmp_type():
z = op(x, y)
x_data = np.random.rand(*x_shape).astype(t1.dtype)
y_data = np.random.rand(*y_shape).astype(t2.dtype)
intrp = create_executor()
op_res = intrp.evaluate(z, { x: relay.const(x_data), y: relay.const(y_data) })
ref_res = ref(x_data, y_data)
np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=0.01)
func = relay.Function([x, y], z)
for target, ctx in ctx_list():
intrp = relay.create_executor("graph", ctx=ctx, target=target)
op_res = intrp.evaluate(func)(x_data, y_data)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res)
def test_binary_int_broadcast():
for op, ref in [(relay.right_shift, np.right_shift),
(relay.left_shift, np.left_shift),
(relay.mod, np.mod),
(relay.maximum, np.maximum),
(relay.minimum, np.minimum)]:
x = relay.var("x", relay.TensorType((10, 4), "int32"))
......@@ -81,10 +88,14 @@ def test_binary_int_broadcast():
t2 = relay.TensorType(y_shape, 'int32')
x_data = np.random.rand(*x_shape).astype(t1.dtype)
y_data = np.random.rand(*y_shape).astype(t2.dtype)
intrp = create_executor()
op_res = intrp.evaluate(z, { x: relay.const(x_data), y: relay.const(y_data) })
func = relay.Function([x, y], z)
ref_res = ref(x_data, y_data)
np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=0.01)
for target, ctx in ctx_list():
intrp = relay.create_executor("graph", ctx=ctx, target=target)
op_res = intrp.evaluate(func)(x_data, y_data)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res)
def test_where():
cond = relay.var("cond", relay.TensorType((3, 4), "float32"))
......
import tvm
from tvm import relay
def test_fuse_simple():
"""Simple testcase."""
x = relay.var("x", shape=(10, 20))
y = relay.add(x, x)
z = relay.exp(y)
z = relay.ir_pass.infer_type(z)
zz = relay.ir_pass.fuse_ops(z)
zz = relay.ir_pass.fuse_ops(zz)
zz = relay.ir_pass.infer_type(zz)
zz.astext()
if __name__ == "__main__":
test_fuse_simple()
......@@ -3,10 +3,9 @@
from __future__ import absolute_import as _abs
import tvm
import topi
from . import tag
from . import cpp
@tvm.tag_scope(tag=tag.BROADCAST)
def expand_dims(a, axis, num_newaxis=1):
"""Expand the shape of an array.
......@@ -25,7 +24,6 @@ def expand_dims(a, axis, num_newaxis=1):
return cpp.expand_dims(a, axis, num_newaxis)
@tvm.tag_scope(tag=tag.BROADCAST)
def expand_like(a, shape_like, axis):
"""Expand an input array with the shape of second array.
This operation can always be composed of unsqueezing and
......@@ -79,7 +77,6 @@ def expand_like(a, shape_like, axis):
return tvm.compute(shape_like.shape, _compute)
@tvm.tag_scope(tag=tag.INJECTIVE)
def transpose(a, axes=None):
"""Permute the dimensions of an array.
......@@ -141,7 +138,6 @@ def strided_slice(a, begin, end, strides=None):
return cpp.strided_slice(a, begin, end, strides)
@tvm.tag_scope(tag=tag.INJECTIVE)
def reshape(a, newshape):
"""Reshape the array
......@@ -159,7 +155,6 @@ def reshape(a, newshape):
return cpp.reshape(a, newshape)
@tvm.tag_scope(tag=tag.INJECTIVE)
def squeeze(a, axis=None):
"""Remove single-dimensional entries from the shape of an array.
......@@ -178,7 +173,6 @@ def squeeze(a, axis=None):
return cpp.squeeze(a, axis)
@tvm.tag_scope(tag=tag.INJECTIVE)
def concatenate(a_tuple, axis=0):
"""Join a sequence of arrays along an existing axis.
......@@ -197,7 +191,6 @@ def concatenate(a_tuple, axis=0):
return cpp.concatenate(a_tuple, axis)
@tvm.tag_scope(tag=tag.INJECTIVE)
def split(ary, indices_or_sections, axis=0):
"""Split an array into multiple sub-arrays.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment