Commit 10ea05e6 by Jared Roesch, committed by Tianqi Chen

[RELAY][RUNTIME] Add Relay interpreter and compiler for TVM runtime system. (#1954)

parent 07399e02
...@@ -22,8 +22,15 @@ namespace tvm {
 * You can find more about Relay by reading the language reference.
 */
namespace relay {
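/*!
 * \brief Invoke the interactive debugger hook registered as "relay.debug"
 *  (the Python hook is registered later in this diff) with the current
 *  file, line, and any extra arguments.
 */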
#define RELAY_DEBUG(...) \
{ auto fdebug = runtime::Registry::Get("relay.debug"); \
CHECK(fdebug) << "Could not find Relay Python debugger function."; \
(*fdebug)("RELAY_DEBUG", __FILE__, __LINE__, __VA_ARGS__); \
}
/*!
 * \brief We always use NodeRef for referencing nodes.
 *
 * By default, NodeRef is a std::shared_ptr of node.
 */
......
/*!
* Copyright (c) 2018 by Contributors
* \file tvm/relay/build_module.h
* \brief The passes and data structures needed to build a
* tvm::Module from a Relay program.
*/
#ifndef TVM_RELAY_BUILD_MODULE_H_
#define TVM_RELAY_BUILD_MODULE_H_
#include <tvm/lowered_func.h>
#include <tvm/relay/environment.h>
#include <tvm/relay/expr.h>
#include <string>
namespace tvm {
namespace relay {
/*! \brief A lowered Relay operation.
*
* A lowered operation is a pair containing the "primitive" function used
* to produce the lowered function as well as the lowered function itself.
*/
class LoweredOp;
/*! \brief The container of LoweredOp. */
class LoweredOpNode : public Node {
public:
/*!
* \brief The primitive function to be lowered.
*
* A primitive function consists only of calls to relay::Op which
* can be fused.
*/
Function func;
/*!
* \brief The lowered function.
*/
LoweredFunc lowered_func;
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("func", &func);
v->Visit("lowered_func", &lowered_func);
}
TVM_DLL static LoweredOp make(
Function func,
LoweredFunc lowered_func);
static constexpr const char* _type_key = "relay.LoweredOp";
TVM_DECLARE_NODE_TYPE_INFO(LoweredOpNode, Node);
};
RELAY_DEFINE_NODE_REF(LoweredOp, LoweredOpNode, NodeRef);
/*!
* \brief Lower the operations contained in a Relay expression.
*
 * The lowering pass will only lower functions marked as primitive;
 * the FuseOps pass provides this marking when run before LowerOps.
*
* \note This will do a reachability analysis and lower all definitions
* reachable from the provided expression.
*
* \param env The environment.
* \param expr The expression with operations to be lowered.
* \param target The target to lower the functions to.
*
* \return The set of lowered operations.
*/
Array<LoweredOp> LowerOps(const Environment& env, const Expr& expr,
const std::string& target = "llvm");
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_BUILD_MODULE_H_
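The lowering pipeline described by this header is exposed to Python in this commit through ir_pass.fuse_ops and ir_pass.lower_ops. A minimal sketch of how the two passes chain together, assuming a fresh Environment and an elementwise add (the shapes are illustrative):

from tvm import relay
from tvm.relay import ir_pass
from tvm.relay.op import add

env = relay.env.Environment({})
x = relay.var("x", shape=(10, 5))
y = relay.var("y", shape=(10, 5))
func = relay.Function([x, y], add(x, y))
# Type checking first, so FuseOps/LowerOps can read checked types.
func = ir_pass.infer_type(func, env=env)
# Abstract fusable call chains into functions marked "Primitive" in env.
fused = ir_pass.fuse_ops(env, func)
# Lower every reachable primitive function; returns an Array of LoweredOp.
lowered = ir_pass.lower_ops(env, fused, target="llvm")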
...@@ -213,12 +213,18 @@ class FunctionNode : public ExprNode {
   */
  tvm::Array<TypeVar> type_params;

  /*!
   * \brief The attributes which store metadata about functions.
   */
  tvm::Attrs attrs;

  void VisitAttrs(tvm::AttrVisitor* v) final {
    v->Visit("params", &params);
    v->Visit("body", &body);
    v->Visit("ret_type", &ret_type);
    v->Visit("type_params", &type_params);
    v->Visit("span", &span);
    v->Visit("attrs", &attrs);
    v->Visit("_checked_type_", &checked_type_);
  }
...@@ -233,7 +239,8 @@ class FunctionNode : public ExprNode {
  TVM_DLL static Function make(tvm::Array<Var> params,
                               Expr body,
                               Type ret_type,
                               tvm::Array<TypeVar> ty_params,
                               tvm::Attrs attrs = Attrs());

  static constexpr const char* _type_key = "relay.Function";
  TVM_DECLARE_NODE_TYPE_INFO(FunctionNode, ExprNode);
...@@ -241,6 +248,11 @@ class FunctionNode : public ExprNode {
RELAY_DEFINE_NODE_REF(Function, FunctionNode, Expr);

TVM_DLL NodeRef FunctionGetAttr(const Function& func, const std::string& key);
TVM_DLL Function FunctionSetAttr(const Function& func, const std::string& key, const NodeRef& data);

/*!
 * \brief Call corresponds to operator invocation.
 *  Corresponds to the operator in computational graph terminology.
......
/*!
* Copyright (c) 2018 by Contributors
* \file tvm/relay/interpreter.h
* \brief An interpreter for Relay.
*
* This file implements a simple reference interpreter for Relay programs.
 * Given a Relay environment and a Relay expression, it produces a value.
*
* The interpreter's values are a naive representation of the values that
* can be produced by a Relay program and are exposed via tvm::Node's
* system to Python for introspection and debugging.
*
* The interpreter's intent is to serve as a reference semantics for the Relay IR,
* as well as for debugging and testing.
*/
#ifndef TVM_RELAY_INTERPRETER_H_
#define TVM_RELAY_INTERPRETER_H_
#include <tvm/relay/environment.h>
#include <tvm/relay/expr.h>
namespace tvm {
namespace relay {
/*!
* \brief A Relay value.
*/
class Value;
/*! \brief Evaluate an expression using the interpreter producing a value.
*
* The resulting value can be passed to Python, making it easy to use
* for testing and debugging.
*
 * The interpreter handles program fragments that are not supported by the
 * TVM runtime; although it is naively implemented, it uses TVM operators
 * to evaluate all operator calls.
*
* Our intent is that this will never be the most efficient implementation of
* Relay's semantics, but a readable and clear one.
*/
Value Evaluate(Environment env, Expr e);
/*! \brief The base container type of Relay values. */
class ValueNode : public RelayNode {
public:
static constexpr const char* _type_key = "relay.Value";
TVM_DECLARE_BASE_NODE_INFO(ValueNode, RelayNode);
};
class Value : public NodeRef {
public:
Value() {}
explicit Value(NodePtr<Node> n) : NodeRef(n) {}
const ValueNode* operator->() const {
return static_cast<const ValueNode*>(node_.get());
}
using ContainerType = ValueNode;
};
/*! \brief A Relay closure, i.e. a scope and a function. */
class Closure;
/*! \brief The container type of Closures. */
class ClosureNode : public ValueNode {
public:
/*! \brief The set of free variables in the closure.
*
* These are the captured variables which are required for
* evaluation when we call the closure.
*/
tvm::Map<Var, Value> env;
/*! \brief The function which implements the closure.
*
* \note May reference the variables contained in the env.
*/
Function func;
ClosureNode() {}
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("env", &env);
v->Visit("func", &func);
}
TVM_DLL static Closure make(tvm::Map<Var, Value> env, Function func);
static constexpr const char* _type_key = "relay.Closure";
TVM_DECLARE_NODE_TYPE_INFO(ClosureNode, ValueNode);
};
RELAY_DEFINE_NODE_REF(Closure, ClosureNode, Value);
/*! \brief A tuple value. */
class TupleValue;
/*! \brief Tuple (x, ... y). */
struct TupleValueNode : ValueNode {
tvm::Array<Value> fields;
TupleValueNode() {}
void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("fields", &fields); }
TVM_DLL static TupleValue make(tvm::Array<Value> value);
static constexpr const char* _type_key = "relay.TupleValue";
TVM_DECLARE_NODE_TYPE_INFO(TupleValueNode, ValueNode);
};
RELAY_DEFINE_NODE_REF(TupleValue, TupleValueNode, Value);
/*! \brief A tensor value. */
class TensorValue;
/*! \brief The tensor value container, wrapping an NDArray. */
struct TensorValueNode : ValueNode {
runtime::NDArray data;
TensorValueNode() {}
void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("data", &data); }
/*! \brief Build a value from an NDArray. */
TVM_DLL static TensorValue make(runtime::NDArray data);
/*! \brief Construct an empty tensor value from t. */
TVM_DLL static TensorValue FromType(const Type& t);
static constexpr const char* _type_key = "relay.TensorValue";
TVM_DECLARE_NODE_TYPE_INFO(TensorValueNode, ValueNode);
};
RELAY_DEFINE_NODE_REF(TensorValue, TensorValueNode, Value);
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_INTERPRETER_H_
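For reference, Evaluate is wrapped on the Python side by relay.interpreter.evaluate, added later in this diff. A small sketch mirroring the tests further down (the shapes and data are illustrative):

import numpy as np
from tvm import relay
from tvm.relay.interpreter import evaluate
from tvm.relay.op import add

env = relay.env.Environment({})
x = relay.var("x", shape=(10, 5))
y = relay.var("y", shape=(10, 5))
func = relay.Function([x, y], add(x, y))

x_data = np.random.rand(10, 5).astype("float32")
y_data = np.random.rand(10, 5).astype("float32")
# evaluate type checks and fuses the applied expression, then interprets it,
# returning a TensorValue.
result = evaluate(env, func, x_data, y_data)
np.testing.assert_allclose(result.asnumpy(), x_data + y_data)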
...@@ -8,6 +8,7 @@
#include <tvm/relay/environment.h>
#include <tvm/relay/expr.h>
#include <string>

namespace tvm {
namespace relay {
...@@ -20,7 +21,8 @@ namespace relay {
 * populated with the result type.
 *
 * \param expr The expression to type check.
 * \param env The environment used for referencing global functions, can be
 * None.
 *
 * \return A type checked expression with its checked_type field populated.
 */
...@@ -35,7 +37,8 @@ Expr InferType(const Expr& expr, const Environment& env);
 * \return A type checked Function with its checked_type field populated.
 * \note this function mutates env and is not thread-safe.
 */
Function InferType(const Function& f, const Environment& env,
                   const GlobalVar& var);

/*!
 * \brief Check that types are well kinded by applying "kinding rules".
...@@ -94,28 +97,30 @@ bool AlphaEqual(const Type& t1, const Type& t2);
 *
 * For example, the expression `let x = 1 in let x = 2 in 3` binds x twice.
 *
 * `let f = (\x -> x) in let g = (\x -> x + 1) in f(g(2))` also binds x twice,
 * although x is not shadowed.
 *
 * \param expr the expression to check.
 *
 * \return true iff every Var in expr is bound at most once.
 */
bool WellFormed(const Expr& expr);
/*! \brief Get the free variables of an expression.
 *
 * Free variables are variables that are not bound by a
 * let or a function parameter in the context.
 *
 * \param expr the expression.
 *
 * \return List of free vars, in the PostDFS order of the expression.
 */
tvm::Array<Var> FreeVars(const Expr& expr);
/*! \brief Get free TypeVars from expression expr.
 *
 * Free type parameters are type parameters that are not bound by a function
 * type in the context.
 *
 * \param expr the expression.
 *
...@@ -125,10 +130,12 @@ tvm::Array<TypeVar> FreeTypeVars(const Expr& expr);
/*! \brief Remove expressions which do not affect the program result.
 *
 * It will remove let bindings which are not referenced, and branches that will
 * not be entered.
 *
 * For example, this pass should turn `let a = 1 in 2` into `2`, as the value of
 * the expression does not depend on a. Another example is that `if (true) then 1
 * else 2` will be optimized into 1.
 *
 * \param e the expression to optimize.
 *
...@@ -136,7 +143,9 @@ tvm::Array<TypeVar> FreeTypeVars(const Expr& expr);
 */
Expr DeadCodeElimination(const Expr& e);
/*! \brief A hashing structure in the style of std::hash. */
struct StructuralHash {
  /*! \brief Hash a Relay type.
   *
   * Implements structural hashing of a Relay type.
   *
...@@ -144,9 +153,9 @@ Expr DeadCodeElimination(const Expr& e);
   *
   * \return the hash value.
   */
  size_t operator()(const Type& type) const;

  /*! \brief Hash a Relay expression.
   *
   * Implements structural hashing of a Relay expression.
   *
...@@ -154,9 +163,10 @@ size_t StructuralHash(const Type& type);
   *
   * \return the hash value.
   */
  size_t operator()(const Expr& expr) const;
};

}  // namespace relay
}  // namespace tvm
#endif  // TVM_RELAY_PASS_H_
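Structural hashing is also reachable from Python through ir_pass.structural_hash (the wrapper is shown further down in this diff). A small sketch, assuming that structurally equal types hash to the same value:

from tvm import relay
from tvm.relay.ir_pass import structural_hash

t1 = relay.TensorType((10, 5), "float32")
t2 = relay.TensorType((10, 5), "float32")
# Structural hashing ignores object identity, so equal types collide.
assert structural_hash(t1) == structural_hash(t2)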
# pylint: disable=wildcard-import, redefined-builtin, invalid-name
"""The Relay IR namespace containing the IR definition and compiler."""
from __future__ import absolute_import
from ..api import register_func
from . import base
from . import ty
from . import expr
...@@ -15,6 +17,7 @@ from . import nn
from . import vision
from . import image
from .scope_builder import ScopeBuilder

# Span
...@@ -46,6 +49,21 @@ Let = expr.Let
If = expr.If
TupleGetItem = expr.TupleGetItem

# helper functions
var = expr.var
const = expr.const
@register_func("relay._tensor_value_repr")
def _tensor_value_repr(tv):
return str(tv.data.asnumpy())
@register_func("relay._constant_repr")
def _tensor_constant_repr(tv):
return str(tv.data.asnumpy())
# pylint: disable=unused-argument
@register_func("relay.debug")
def _debug(*args):
import pdb
pdb.set_trace()
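The RELAY_DEBUG macro added earlier in this diff looks this hook up by name in the registry, so it can be replaced from user code. A hedged sketch; the printing body and the use of the override flag are illustrative assumptions:

from tvm.api import register_func

@register_func("relay.debug", override=True)
def _my_debug(header, filename, lineno, *args):
    # Invoked from C++ as ("RELAY_DEBUG", __FILE__, __LINE__, ...).
    print(header, filename, lineno, args)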
"""The interface to the Evaluator exposed from C++."""
from tvm._ffi.function import _init_api
_init_api("relay._interpreter", __name__)
...@@ -45,9 +45,12 @@ class Environment(RelayNode):
        func: Function
            The function.
        """
        return self._add(var, func)

    def _add(self, var, func, update=False):
        if isinstance(var, _base.string_types):
            var = _expr.GlobalVar(var)
        return _env.Environment_Add(self, var, func, update)

    def __getitem__(self, var):
        """Lookup a global function by name or by variable.
......
#pylint: disable=no-else-return
"""An interface to the Realy interpreter."""
from __future__ import absolute_import
import numpy as np
from .. import register_func, nd
from .base import NodeBase, register_relay_node
from . import _make
from . import _interpreter
from . import ir_pass
from .expr import Call, Constant, GlobalVar
from . import const
from .._ffi.base import integer_types
class Value(NodeBase):
"""Base class of all values.
"""
    @staticmethod
    @register_func("relay.from_scalar")
    def from_scalar(i, dtype=None):
        """Convert a Python scalar to a Relay scalar."""
        if dtype is None:
            if isinstance(i, bool):
                # Check bool before the integer types: bool is a subclass of int.
                dtype = 'uint8'
            elif isinstance(i, integer_types):
                dtype = 'int32'
            elif isinstance(i, float):
                dtype = 'float32'
            else:
                raise Exception("unable to infer dtype {0}".format(type(i)))
        return TensorValue(nd.array(np.array(i, dtype=dtype)))
@register_relay_node
class TupleValue(Value):
def __init__(self, *fields):
self.__init_handle_by_constructor__(
_make.TupleValue, fields)
def __getitem__(self, field_no):
return self.fields[field_no]
@register_relay_node
class Closure(Value):
pass
@register_relay_node
class TensorValue(Value):
"""A Tensor value produced by the evaluator."""
    def __init__(self, data):
        """Allocate a new TensorValue and copy `data` into the new array."""
if isinstance(data, np.ndarray):
data = nd.array(data)
self.__init_handle_by_constructor__(
_make.TensorValue, data)
def as_ndarray(self):
"""Convert a Relay TensorValue into a tvm.ndarray."""
return self.data
def asnumpy(self):
"""Convert a Relay TensorValue into a numpy.ndarray."""
return self.data.asnumpy()
def __eq__(self, other):
return self.data == other.data
def _arg_to_ast(arg):
if isinstance(arg, TensorValue):
return Constant(arg.data)
elif isinstance(arg, np.ndarray):
return Constant(nd.array(arg))
elif isinstance(arg, Constant):
return arg
else:
return const(arg)
def apply_passes(expr, env=None):
ck_expr = ir_pass.infer_type(expr, env=env)
fused_expr = ir_pass.fuse_ops(env, ck_expr)
return fused_expr
def evaluate(env, expr, *args):
"""
Evaluate a Relay expression on the interpreter.
Parameters
----------
env: tvm.relay.Environment
The global environment used.
expr: tvm.relay.Expr
The expression to evaluate.
args: list of tvm.relay.Expr
The arguments to apply to the expression, only works
if the expression has a function type.
Returns
-------
value: tvm.relay.eval.Value
The value produced by evaluating the expression.
"""
# assert len(args) == 0
relay_args = []
for arg in args:
relay_args.append(_arg_to_ast(arg))
# TODO: We need to move this optimization code into the optimizer/pass manager
if isinstance(expr, GlobalVar):
func = env[expr]
func = apply_passes(func, env)
env._add(expr, func, True)
opt_expr = Call(expr, relay_args)
# import pdb; pdb.set_trace()
return _interpreter.evaluate(env, opt_expr)
else:
expr = Call(expr, relay_args)
opt_expr = apply_passes(expr, env)
return _interpreter.evaluate(env, opt_expr)
...@@ -240,3 +240,9 @@ def structural_hash(value):
        msg = ("found value of type {0} expected" +
               "relay.Expr or relay.Type").format(type(value))
        raise TypeError(msg)
def fuse_ops(expr, env):
return _ir_pass.FuseOps(env, expr)
def lower_ops(env, expr, target='llvm'):
return _ir_pass.LowerOps(env, expr, target)
#pylint: disable=invalid-name, unused-argument
"""Backend compiler related feature registration"""
import tvm
import topi
from . import register
def add_compute(attrs, inputs, output_type, target):
assert len(inputs) == 2
return [topi.add(inputs[0], inputs[1])]
def add_schedule(outputs, target):
assert len(outputs) == 1
return tvm.create_schedule(outputs[0].op)
register("add", "FTVMCompute", add_compute)
register("add", "FTVMSchedule", add_schedule)
def subtract_compute(attrs, inputs, output_type, target):
assert len(inputs) == 2
return [topi.subtract(inputs[0], inputs[1])]
def subtract_schedule(outputs, target):
assert len(outputs) == 1
return tvm.create_schedule(outputs[0].op)
register("subtract", "FTVMCompute", subtract_compute)
register("subtract", "FTVMSchedule", subtract_schedule)
def multiply_compute(attrs, inputs, output_type, target):
assert len(inputs) == 2
return [topi.multiply(inputs[0], inputs[1])]
def multiply_schedule(outputs, target):
assert len(outputs) == 1
return tvm.create_schedule(outputs[0].op)
register("multiply", "FTVMCompute", multiply_compute)
register("multiply", "FTVMSchedule", multiply_schedule)
def equal_compute(attrs, inputs, output_type, target):
assert len(inputs) == 2
return [topi.equal(inputs[0], inputs[1])]
def equal_schedule(outputs, target):
assert len(outputs) == 1
return tvm.create_schedule(outputs[0].op)
register("equal", "FTVMCompute", equal_compute)
register("equal", "FTVMSchedule", equal_schedule)
#pylint: disable=invalid-name, unused-argument
"""Backend compiler related feature registration"""
import tvm
import topi
from .. import register
def dense_compiler(attrs, inputs, output_type, target):
assert len(inputs) == 2
return [topi.nn.dense(inputs[0], inputs[1])]
def dense_schedule(outputs, target):
assert len(outputs) == 1
return tvm.create_schedule(outputs[0].op)
register("nn.dense", "FTVMCompute", dense_compiler)
register("nn.dense", "FTVMSchedule", dense_schedule)
...@@ -3,7 +3,8 @@ from ..._ffi.function import _init_api
from ..base import register_relay_node
from ..expr import Expr
from ...api import register_func
from ...build_module import lower, build

@register_relay_node
class Op(Expr):
...@@ -75,3 +76,11 @@ def register(op_name, attr_key, value=None, level=10):
_init_api("relay.op", __name__)
@register_func("relay.op.compiler._lower")
def _lower(name, schedule, inputs, outputs):
return lower(schedule, list(inputs) + list(outputs), name=name)
@register_func("relay.op.compiler._build")
def _build(lowered_funcs):
return build(lowered_funcs, target="llvm")
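These two hooks are what the C++ LowerOps pass calls to turn a schedule into runnable code. Roughly, for a single add call, the work they delegate to looks like the sketch below; the placeholder shapes and the function name are illustrative assumptions:

import tvm
import topi

a = tvm.placeholder((10, 5), name="in0")
b = tvm.placeholder((10, 5), name="in1")
out = topi.add(a, b)                  # what FTVMCompute for "add" returns
sched = tvm.create_schedule(out.op)   # what FTVMSchedule for "add" returns
# _lower: schedule + inputs + outputs -> LoweredFunc
lf = tvm.lower(sched, [a, b, out], name="fused_add_example")
# _build: LoweredFunc(s) -> runtime module
mod = tvm.build(lf, target="llvm")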
...@@ -17,6 +17,7 @@
"""
a simple multilayer perceptron
"""
from __future__ import absolute_import
from tvm import relay
from .init import create_workload
......
...@@ -66,3 +66,5 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
}  // namespace relay
}  // namespace tvm
...@@ -49,9 +49,16 @@ void EnvironmentNode::Add(const GlobalVar& var,
        << "Environment#update changes type, not possible in this mode.";
  }
  this->functions.Set(var, checked_func);
  // set global var map
  auto it = global_var_map_.find(var->name_hint);
  if (it != global_var_map_.end()) {
    CHECK_EQ((*it).second, var);
  } else {
    CHECK(!global_var_map_.count(var->name_hint))
        << "Duplicate global function name " << var->name_hint;
  }
  global_var_map_.Set(var->name_hint, var);
}
...@@ -94,7 +101,7 @@ TVM_REGISTER_API("relay._make.Environment")
TVM_REGISTER_API("relay._env.Environment_Add")
.set_body([](TVMArgs args, TVMRetValue *ret) {
    Environment env = args[0];
    env->Add(args[1], args[2], args[3]);
  });

TVM_REGISTER_API("relay._env.Environment_GetGlobalVar")
......
...@@ -26,7 +26,10 @@ TVM_REGISTER_API("relay._make.Constant")
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<ConstantNode>([](const ConstantNode* node, tvm::IRPrinter* p) {
    const PackedFunc* fprint = Registry::Get("relay._constant_repr");
    CHECK(fprint) << "unable to find printing function for constants";
    std::string data = (*fprint)(GetRef<Constant>(node));
    p->stream << "Constant(" << data << ")";
  });

TensorType ConstantNode::tensor_type() const {
...@@ -104,12 +107,14 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
Function FunctionNode::make(tvm::Array<Var> params,
                            Expr body,
                            Type ret_type,
                            tvm::Array<TypeVar> type_params,
                            tvm::Attrs attrs) {
  NodePtr<FunctionNode> n = make_node<FunctionNode>();
  n->params = std::move(params);
  n->body = std::move(body);
  n->ret_type = std::move(ret_type);
  n->type_params = std::move(type_params);
  n->attrs = std::move(attrs);
  return Function(n);
}
...@@ -121,6 +126,39 @@ FuncType FunctionNode::func_type_annotation() const {
  return FuncTypeNode::make(param_types, this->ret_type, this->type_params, {});
}
NodeRef FunctionGetAttr(const Function& func, const std::string& key) {
if (!func->attrs.defined()) { return NodeRef(); }
const DictAttrsNode* dict_attrs = func->attrs.as<DictAttrsNode>();
CHECK(dict_attrs);
auto it = dict_attrs->dict.find(key);
if (it != dict_attrs->dict.end()) {
return (*it).second;
} else {
return NodeRef();
}
}
Function FunctionSetAttr(const Function& func, const std::string& key, const NodeRef& data) {
const DictAttrsNode* dattrs = func->attrs.as<DictAttrsNode>();
Attrs func_attrs;
if (dattrs) {
Map<std::string, NodeRef> dict = dattrs->dict;
dict.Set(key, data);
func_attrs = DictAttrsNode::make(dict);
} else {
Map<std::string, NodeRef> dict = {{key, data}};
func_attrs = DictAttrsNode::make(dict);
}
return FunctionNode::make(
func->params,
func->body,
func->ret_type,
func->type_params,
func_attrs);
}
TVM_REGISTER_NODE_TYPE(FunctionNode);

TVM_REGISTER_API("relay._make.Function")
...@@ -132,7 +170,8 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<FunctionNode>([](const FunctionNode* node,
                               tvm::IRPrinter* p) {
    p->stream << "FunctionNode(" << node->params << ", " << node->ret_type
              << ", " << node->body << ", " << node->type_params << ", "
              << node->attrs << ")";
  });

Call CallNode::make(Expr op, Array<Expr> args, Attrs attrs,
......
...@@ -92,7 +92,7 @@ Expr ExprMutator::VisitExpr_(const FunctionNode* op) {
      body.same_as(op->body)) {
    return GetRef<Expr>(op);
  } else {
    return FunctionNode::make(params, body, ret_type, ty_params, op->attrs);
  }
}
...@@ -198,6 +198,7 @@ void ExprVisitor::ExprVisitor::VisitExpr_(const FunctionNode* op) {
void ExprVisitor::VisitExpr_(const CallNode* op) {
  this->VisitExpr(op->op);

  for (auto ty_arg : op->type_args) {
    this->VisitType(ty_arg);
  }
......
...@@ -285,11 +285,11 @@ class RelayHashHandler:
  int var_counter = 0;
};

size_t StructuralHash::operator()(const Type& type) const {
  return RelayHashHandler().TypeHash(type);
}

size_t StructuralHash::operator()(const Expr& expr) const {
  return RelayHashHandler().ExprHash(expr);
}
......
/*!
* Copyright (c) 2018 by Contributors
*
* \file src/tvm/relay/pass/fuse_ops.cc
*
 * \brief Fuse eligible sequences of Relay operators into single operators.
*
*/
#include <tvm/relay/pass.h>
#include <tvm/runtime/module.h>
#include <tvm/lowered_func.h>
#include <tvm/operation.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/logging.h>
#include "../ir/type_functor.h"
namespace tvm {
namespace relay {
using namespace runtime;
struct AbstractFusableOps : ExprMutator {
Environment env;
Array<GlobalVar> fusable_funcs;
int counter = 0;
size_t expr_hash;
AbstractFusableOps(Environment env, size_t expr_hash) : env(env), expr_hash(expr_hash) {}
Expr VisitExpr_(const CallNode* call) {
if (auto op_node = call->op.as<OpNode>()) {
// Placeholder fusion algorithm which abstracts
// single definitions into functions only.
Array<Var> params;
Array<Expr> inner_args;
Array<Expr> args;
int param_number = 0;
for (auto arg : call->args) {
auto name = std::string("p") + std::to_string(param_number++);
auto type = arg->checked_type();
auto var = VarNode::make(name, type);
params.push_back(var);
inner_args.push_back(var);
args.push_back(VisitExpr(arg));
}
auto body = CallNode::make(call->op, inner_args, call->attrs);
auto func = FunctionNode::make(params, body, call->checked_type(), {});
func = FunctionSetAttr(func, "Primitive", tvm::Integer(1));
std::string func_name = "fused_";
func_name += op_node->name;
func_name += "_";
func_name += std::to_string(counter++);
func_name += "_";
func_name += std::to_string(expr_hash);
auto gv = GlobalVarNode::make(func_name);
env->Add(gv, func);
fusable_funcs.push_back(gv);
return CallNode::make(gv, args, Attrs());
} else {
return ExprMutator::VisitExpr_(call);
}
}
};
Expr FuseOps(const Environment& env, const Expr& e) {
  // First we convert all chains of fusable ops into
  // abstracted functions which we mark as primitive,
  // then we convert these primitive functions into
  // new operators.
auto abstract = AbstractFusableOps(env, StructuralHash()(e));
auto abstracted_e = abstract.VisitExpr(e);
RELAY_LOG(INFO) << "FuseOps: before=" << e
<< "Fuse: after=" << abstracted_e;
return abstracted_e;
}
TVM_REGISTER_API("relay._ir_pass.FuseOps")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = FuseOps(args[1], args[0]);
});
} // namespace relay
} // namespace tvm
/*!
* Copyright (c) 2018 by Contributors
*
* \file src/tvm/relay/pass/lower_ops.cc
*
 * \brief Lower a Relay program to a set of TVM operators.
*
*/
#include <tvm/lowered_func.h>
#include <tvm/operation.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/logging.h>
#include <tvm/relay/pass.h>
#include <tvm/runtime/module.h>
#include <tvm/relay/build_module.h>
#include "../ir/type_functor.h"
namespace tvm {
namespace relay {
using namespace runtime;
LoweredOp LoweredOpNode::make(Function func, LoweredFunc lowered_func) {
auto node = make_node<LoweredOpNode>();
node->func = func;
node->lowered_func = lowered_func;
return LoweredOp(node);
}
struct AbstractLocalFunctions : ExprMutator {
Environment env;
size_t expr_hash;
int counter = 0;
std::unordered_set<GlobalVar, NodeHash, NodeEqual> visited_funcs;
explicit AbstractLocalFunctions(Environment env)
: env(env), expr_hash(0), counter(0), visited_funcs() {}
Expr Abstract(const Expr& e) {
expr_hash = StructuralHash()(e);
return VisitExpr(e);
}
Expr VisitExpr_(const GlobalVarNode* gvar_node) final {
auto gvar = GetRef<GlobalVar>(gvar_node);
auto it = visited_funcs.find(gvar);
if (it == visited_funcs.end()) {
auto func = env->Lookup(gvar);
visited_funcs.insert(gvar);
auto new_func = FunctionNode::make(
func->params,
VisitExpr(func->body),
func->ret_type,
func->type_params,
func->attrs);
env->Update(gvar, new_func);
}
return gvar;
}
Expr VisitExpr_(const FunctionNode* func_node) final {
Function func = GetRef<Function>(func_node);
auto free_vars = FreeVars(func);
Array<Var> params;
for (auto free_var : free_vars) {
auto var = VarNode::make("free_var", free_var->checked_type());
params.push_back(var);
}
std::string abs_func = "abstracted_func_";
abs_func += std::to_string(counter++);
abs_func += std::to_string(expr_hash);
auto gv = GlobalVarNode::make(abs_func);
auto lifted_func = FunctionNode::make(params, func, Type(), {}, {});
env->Add(gv, lifted_func);
Array<Expr> args;
for (auto free_var : free_vars) {
args.push_back(free_var);
}
return CallNode::make(gv, args, {});
}
};
struct LiveFunctions : ExprVisitor {
Environment env;
explicit LiveFunctions(Environment env) : env(env), global_funcs() {}
std::unordered_set<GlobalVar, NodeHash, NodeEqual> visited_funcs;
std::unordered_set<GlobalVar, NodeHash, NodeEqual> global_funcs;
void Live(const Expr& e) {
CHECK(!e.as<FunctionNode>())
<< "functions should of been transformed away by previous pass";
VisitExpr(e);
}
void VisitExpr_(const FunctionNode* func_node) {
LOG(FATAL) << "functions should of been transformed away by previous pass";
}
void VisitExpr_(const GlobalVarNode* var_node) final {
GlobalVar var = GetRef<GlobalVar>(var_node);
auto it = visited_funcs.find(var);
if (it == visited_funcs.end()) {
auto func = env->Lookup(var);
visited_funcs.insert(var);
// The last pass has transformed functions of the form:
//
// let x = fn (p_1, ..., p_n) { ... };
// ...
//
// into a top-level declaration:
//
// def abs_f(fv_1, ..., fv_n) {
// return (fn (p_1...,p_N) { ... };)
// }
//
// and:
//
// let x = abs_f(fv_1, ... fv_n);
//
// The only other case we can handle is
//
// fn foo(...) { body }
//
// We just search through the body in this case.
if (auto inner_func = func->body.as<FunctionNode>()) {
return VisitExpr(inner_func->body);
} else {
return VisitExpr(func->body);
}
}
}
void VisitExpr_(const CallNode* call) final {
RELAY_LOG(INFO) << "LiveOps: CallNode=" << GetRef<Call>(call);
if (auto gv_node = call->op.as<GlobalVarNode>()) {
GlobalVar gvar = GetRef<GlobalVar>(gv_node);
Function func = env->Lookup(gvar);
auto attr = FunctionGetAttr(func, "Primitive");
if (attr.defined() && Downcast<Integer>(attr)->value == 1) {
global_funcs.insert(gvar);
} else {
VisitExpr(gvar);
}
// Finally we need to ensure to visit all the args no matter what.
for (auto arg : call->args) {
VisitExpr(arg);
}
} else {
return ExprVisitor::VisitExpr_(call);
}
}
};
using FCompute = TypedPackedFunc<Array<Tensor>(
const Attrs&, const Array<Tensor>&, Type, std::string)>;
using FSchedule = TypedPackedFunc<Schedule(const Array<Tensor>&, std::string)>;
/*! \brief Return the set of operators in their TVM format. */
Array<LoweredOp> LowerOps(const Environment& env, const Expr& e,
const std::string& target) {
RELAY_LOG(INFO) << "LowerOps: e=" << e;
auto flower_ptr = Registry::Get("relay.op.compiler._lower");
CHECK(flower_ptr);
PackedFunc flower = *flower_ptr;
auto abstracted_e = AbstractLocalFunctions(env).Abstract(e);
auto live_funcs = LiveFunctions(env);
live_funcs.VisitExpr(abstracted_e);
auto schedule_reg = Op::GetAttr<FSchedule>("FTVMSchedule");
auto compute_reg = Op::GetAttr<FCompute>("FTVMCompute");
Array<LoweredOp> lowered_funcs;
for (auto func_name : live_funcs.global_funcs) {
auto func = env->Lookup(func_name);
auto call = Downcast<Call>(func->body);
auto op_node = call->op.as<OpNode>();
CHECK(op_node) << "violated invariant that primtiive calls contain a single op call";
auto op = GetRef<Op>(op_node);
RELAY_LOG(INFO) << "LowerOps: Lowering " << op->name;
CHECK(IsPrimitiveOp(op)) << "failed to lower "
<< op->name << "can only lower primitve operations";
Array<Tensor> inputs;
std::string input_name = "in";
int i = 0;
for (auto type_arg : call->type_args) {
auto tt = Downcast<TensorType>(type_arg);
inputs.push_back(PlaceholderOpNode::make(input_name + std::to_string(i),
tt->shape, tt->dtype)
.output(0));
i++;
}
auto output_tt = op->op_type->ret_type;
Array<Tensor> outputs =
compute_reg[op](call->attrs, inputs, output_tt, target);
auto schedule = schedule_reg[op](outputs, target);
size_t hash = StructuralHash()(func);
LoweredFunc lf =
flower(op->name + std::to_string(hash), schedule, inputs, outputs);
func = FunctionSetAttr(func, "LoweredFunc", lf);
env->Add(func_name, func, true);
lowered_funcs.push_back(LoweredOpNode::make(func, lf));
}
return lowered_funcs;
}
TVM_REGISTER_API("relay._ir_pass.LowerOps")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = LowerOps(args[0], args[1], args[2]);
});
} // namespace relay
} // namespace tvm
...@@ -298,8 +298,8 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
    auto* fn_ty_node = ftype.as<FuncTypeNode>();

    CHECK(fn_ty_node != nullptr)
        << "only expressions with function types can be called, found "
        << ftype << " at " << call->span;

    Array<Type> type_args;
    FuncType fn_ty = Instantiate(fn_ty_node, &type_args);
...@@ -505,12 +505,16 @@ Expr TypeInferencer::Infer(Expr expr) {
  // Step 1: Solve the constraints.
  solver_.Solve();

  // Step 2: Attach resolved types to checked_type field.
  auto resolved_expr = Resolver(type_map_, &solver_).VisitExpr(expr);
  CHECK(WellFormed(resolved_expr));
  return resolved_expr;
}

Expr InferType(const Expr& expr, const Environment& env) {
  auto e = TypeInferencer(env).Infer(expr);
  CHECK(WellFormed(e));
  return e;
}

Function InferType(const Function& func,
...@@ -522,6 +526,7 @@ Function InferType(const Function& func,
  Expr func_ret = TypeInferencer(env).Infer(func_copy);
  auto map_node = env->functions.CopyOnWrite();
  map_node->data.erase(var.node_);
  CHECK(WellFormed(func_ret));
  return Downcast<Function>(func_ret);
}
......
...@@ -3,7 +3,7 @@
 *
 * \file util.cc
 *
 * \brief Utility functions for Relay.
 */
#include <tvm/relay/pass.h>
#include <tvm/relay/expr_functor.h>
......
import numpy as np
from tvm import relay
from tvm.relay.ir_pass import infer_type
from tvm.relay.interpreter import evaluate
from tvm.relay.graph_runtime_codegen import graph_evaluate
from tvm.relay.scope_builder import ScopeBuilder
from tvm.relay.op import add
from tvm.relay.env import Environment
# @tq, @jr should we put this in testing ns?
def check_rts(env, expr, args, expected_result):
"""
Check that evaluating `expr` applied to the arguments produces
`expected_result` on both the evaluator and the TVM runtime.
Parameters
----------
expr:
The expression to evaluate
args: list of Expr
The arguments to supply the expr.
expected_result:
The expected result of running the expression.
"""
eval_result = evaluate(env, expr, *args)
rts_result = graph_evaluate(env, expr, *args)
np.testing.assert_allclose(eval_result.asnumpy(), rts_result.asnumpy())
def test_add_op_scalar():
"""
Program:
fn (x, y) {
return x + y;
}
"""
env = Environment()
x = relay.var('x', shape=())
y = relay.var('y', shape=())
func = relay.Function([x, y], add(x, y))
x_data = np.array(10.0, dtype='float32')
y_data = np.array(1.0, dtype='float32')
check_rts(env, func, [x_data, y_data], x_data + y_data)
def test_add_op_tensor():
"""
Program:
fn (x, y) {
return x + y;
}
"""
env = Environment()
x = relay.var('x', shape=(10, 5))
y = relay.var('y', shape=(10, 5))
func = relay.Function([x, y], add(x, y))
x_data = np.random.rand(10, 5).astype('float32')
y_data = np.random.rand(10, 5).astype('float32')
check_rts(env, func, [x_data, y_data], x_data + y_data)
def test_add_op_broadcast():
"""
Program:
fn (x, y) {
return x + y;
}
"""
env = Environment()
x = relay.var('x', shape=(10, 5))
y = relay.var('y', shape=(1, 5))
func = relay.Function([x, y], add(x, y))
x_data = np.random.rand(10, 5).astype('float32')
y_data = np.random.rand(1, 5).astype('float32')
check_rts(env, func, [x_data, y_data], x_data + y_data)
if __name__ == "__main__":
test_add_op_scalar()
test_add_op_tensor()
test_add_op_broadcast()
import numpy as np
import tvm
from tvm import relay
from tvm.relay.interpreter import Value, TupleValue, evaluate
from tvm.relay import op
from tvm.relay.scope_builder import ScopeBuilder
from tvm.relay import testing
def check_eval(expr, args, expected_result, env=None, rtol=1e-07):
if env is None:
env = relay.env.Environment({})
result = evaluate(env, expr, *args)
np.testing.assert_allclose(result.asnumpy(), expected_result, rtol=rtol)
def test_from_scalar():
np.testing.assert_allclose(Value.from_scalar(1, 'int32').asnumpy(), 1)
np.testing.assert_allclose(Value.from_scalar(10.0, 'float32').asnumpy(), 10.0)
np.testing.assert_allclose(Value.from_scalar(True).asnumpy(), True)
def test_tuple_value():
tv = TupleValue(Value.from_scalar(
1), Value.from_scalar(2), Value.from_scalar(3))
np.testing.assert_allclose(tv[0].asnumpy(), 1)
np.testing.assert_allclose(tv[1].asnumpy(), 2)
np.testing.assert_allclose(tv[2].asnumpy(), 3)
def test_id():
x = relay.var('x', 'float32')
ident = relay.Function([x], x)
env = relay.env.Environment({})
res = evaluate(env, ident, 1.0)
check_eval(ident, [1.0], 1.0)
def test_add_const():
two = op.add(relay.const(1), relay.const(1))
func = relay.Function([], two)
check_eval(func, [], 2)
def test_mul_param():
x = relay.var('x', shape=(10, 10))
y = relay.var('y', shape=(1, 10))
func = relay.Function([x, y], op.multiply(x, y))
x_data = np.random.rand(10, 10).astype('float32')
y_data = np.random.rand(1, 10).astype('float32')
check_eval(func, [x_data, y_data], x_data * y_data)
# failing due to numeric issues
# def test_dense():
# x = relay.var('x', shape=(10, 10))
# w = relay.var('w', shape=(10, 10))
# y = op.nn.dense(x, w)
# func = relay.Function([x, w], y)
# x_data = np.random.rand(10, 10).astype('float32')
# w_data = np.random.rand(10, 10).astype('float32')
# check_eval(func, [x_data, w_data], x_data @ w_data, rtol=0.1)
# def test_linear():
# x = relay.var('x', shape=(10, 10))
# w = relay.var('w', shape=(10, 10))
# b = relay.var('b', shape=(10,))
# y = op.add(op.nn.dense(x, w), b)
# func = relay.Function([x, w, b], y)
# x_data = np.random.rand(10, 10).astype('float32')
# w_data = np.random.rand(10, 10).astype('float32')
# b_data = np.random.rand(10).astype('float32')
# check_eval(func, [x_data, w_data, b_data], x_data @ w_data + b_data)
def test_equal():
i = relay.var('i', shape=[], dtype='int32')
j = relay.var('i', shape=[], dtype='int32')
z = op.equal(i, j)
func = relay.Function([i, j], z, ret_type=relay.TensorType([], 'bool'))
i_data = relay.const(0)
j_data = relay.const(0)
check_eval(func, [i_data, j_data], True)
def test_subtract():
i = relay.var('i', shape=[], dtype='int32')
sub = op.subtract(i, relay.const(1, dtype='int32'))
func = relay.Function([i], sub, ret_type=relay.TensorType([], 'int32'))
i_data = np.array(1, dtype='int32')
check_eval(func, [i_data], 0)
def test_simple_loop():
env = relay.env.Environment({})
sum_up = relay.GlobalVar('sum_up')
i = relay.var('i', shape=[], dtype='int32')
sb = ScopeBuilder()
with sb.if_scope(op.equal(i, relay.const(0, dtype='int32'))):
sb.ret(i)
with sb.else_scope():
one_less = op.subtract(i, relay.const(1, dtype='int32'))
rec_call = relay.Call(sum_up, [one_less])
sb.ret(op.add(rec_call, i))
func = relay.Function([i], sb.get(), ret_type=relay.TensorType([], 'int32'))
env[sum_up] = func
i_data = np.array(10, dtype='int32')
check_eval(sum_up, [i_data], sum(range(1, 11)), env=env)
def test_loop():
env = relay.env.Environment({})
sum_up = relay.GlobalVar('sum_up')
i = relay.var('i', shape=[], dtype='int32')
accum = relay.var('accum', shape=[], dtype='int32')
sb = ScopeBuilder()
with sb.if_scope(op.equal(i, relay.const(0))):
sb.ret(accum)
with sb.else_scope():
one_less = op.subtract(i, relay.const(1))
new_accum = op.add(accum, i)
sb.ret(relay.Call(sum_up, [one_less, new_accum]))
func = relay.Function([i, accum], sb.get())
env[sum_up] = func
i_data = np.array(10, dtype='int32')
accum_data = np.array(0, dtype='int32')
check_eval(sum_up, [i_data, accum_data], sum(range(1, 11)), env=env)
def test_mlp():
pass
# net = testing.mlp.get_workload(1)
# import pdb; pdb.set_trace()
if __name__ == "__main__":
test_id()
test_add_const()
# test_dense()
# test_linear()
test_equal()
test_subtract()
test_simple_loop()
test_loop()
test_mlp()
...@@ -5,6 +5,16 @@ import tvm
import numpy as np
from tvm.relay.ir_pass import infer_type
from tvm import relay
from tvm.relay import op
from tvm.relay.scope_builder import ScopeBuilder


def assert_has_type(expr, typ, env=relay.env.Environment({})):
    checked_expr = infer_type(expr, env)
    checked_type = checked_expr.checked_type
    if checked_type != typ:
        raise RuntimeError("Type mismatch %s vs %s" % (
            checked_type, typ))


def test_monomorphic_let():
...@@ -16,6 +26,31 @@ def test_monomorphic_let():
    assert xchecked.checked_type == relay.scalar_type("float64")
def test_single_op():
"Program: fn (x : float32) { let t1 = f(x); t1 }"
x = relay.var('x', shape=[])
func = relay.Function([x], op.log(x))
ttype = relay.TensorType([], dtype='float32')
assert_has_type(func, relay.FuncType([ttype], ttype))
def test_add_broadcast_op():
"""
Program:
fn (x: Tensor[(10, 4), f32], y: Tensor[(5, 10, 1), f32]) -> Tensor[(5, 10, 4), f32] {
return x + y;
}
"""
pass
# x = relay.var('x', shape=(10, 4))
# y = relay.var('y', shape=(5, 10, 1))
# z = x + y
# func = relay.Function([x, y], z)
# ttype = relay.TensorType((5, 5, 5), 'float32')
# expected_ty = relay.FuncType([ttype, ttype], ttype)
# assert_has_type(func.to_func(), expected_ty)
def test_dual_op():
    """Program:
       fn (x : Tensor[f32, (10, 10)]) {
...@@ -41,7 +76,6 @@ def test_decl():
         return log(x);
       }
    """
    tp = relay.TensorType((10, 10))
    x = relay.var("x", tp)
    f = relay.Function([x], relay.log(x))
...@@ -76,6 +110,24 @@ def test_recursion():
    assert "%3 = @f(%1, %2)" in env.astext()
    assert env[f].checked_type == relay.FuncType([ti32, tf32], tf32)
# This currently fails and should pass under the type system.
#
# This test is to illustrate problem with our weak form of
# unification.
#
def test_incomplete_call():
sb = ScopeBuilder()
x = relay.var('x', dtype='int32')
f = relay.var('f')
func = relay.Function([x, f], relay.Call(f, [x]))
try:
relay.ir_pass.infer_type(func)
assert False
except tvm.TVMError as e:
assert True
def test_tuple():
    tp = relay.TensorType((10,))
...@@ -84,13 +136,13 @@ def test_tuple():
    assert (relay.ir_pass.infer_type(res).checked_type ==
            relay.TupleType([tp, tp]))


def test_free_expr():
    x = relay.var("x", "float32")
    y = relay.add(x, x)
    yy = relay.ir_pass.infer_type(y)
    assert yy.checked_type == relay.scalar_type("float32")


def test_type_args():
    x = relay.var("x", shape=(10, 10))
    y = relay.var("y", shape=(1, 10))
...@@ -107,6 +159,7 @@ def test_type_args():
    assert sh2[0].value == 1
    assert sh2[1].value == 10


def test_self_reference():
    """
    Program:
...@@ -117,30 +170,40 @@ def test_self_reference():
    a = relay.TypeVar("a")
    x = relay.var("x", a)
    sb = relay.ScopeBuilder()
    f = relay.Function([x], x)
    fx = relay.Call(f, [x])

    assert relay.ir_pass.infer_type(x).checked_type == a
    assert relay.ir_pass.infer_type(f).checked_type == relay.FuncType([a], a)
    assert relay.ir_pass.infer_type(fx).checked_type == a


def test_global_var_cow_issue():
    env = relay.env.Environment({})
    gv = relay.GlobalVar("foo")
    x = relay.var('x', shape=[])
    func = relay.Function([x], relay.Call(gv, [x]),
                          relay.TensorType([], 'float32'))
    env[gv] = func
    # They should both point to the same global variable if global variables are
    # stable across type checking.
    assert gv == func.body.op


def test_equal():
    i = relay.var('i', shape=[], dtype='int32')
    eq = op.equal(i, relay.const(0, dtype='int32'))
    # This should fail ....
    func = relay.Function([i], eq, ret_type=relay.TensorType([], 'int32'))


if __name__ == "__main__":
    test_free_expr()
    test_dual_op()
    test_single_op()
    test_recursion()
    test_monomorphic_let()
    test_decl()
    test_recursion()
    test_tuple()
    test_incomplete_call()
    test_free_expr()
    test_type_args()
    test_self_reference()
......
#!/bin/bash
export PYTHONPATH=python:topi/python:apps/extension/python
export LD_LIBRARY_PATH=build:${LD_LIBRARY_PATH}
rm -rf python/tvm/*.pyc python/tvm/*/*.pyc python/tvm/*/*/*.pyc
......