Unverified Commit 8876eac8 by Tianqi Chen Committed by GitHub

[RELAY] IR builder stablize refactor, clean pass (#1934)

parent 4300bbc2
......@@ -254,7 +254,7 @@ struct PadAttrs : public tvm::AttrsNode<PadAttrs> {
struct LeakyReluAttrs : public tvm::AttrsNode<LeakyReluAttrs> {
double alpha;
TVM_DECLARE_ATTRS(DenseAttrs, "relay.attrs.LeakyReluAttrs") {
TVM_DECLARE_ATTRS(LeakyReluAttrs, "relay.attrs.LeakyReluAttrs") {
TVM_ATTR_FIELD(alpha).set_lower_bound(0.0).set_default(0.25)
.describe("Slope coefficient for the negative half axis.");
}
......
......@@ -47,12 +47,13 @@ class EnvironmentNode : public RelayNode {
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("functions", &functions);
v->Visit("global_map_", &global_map_);
v->Visit("global_var_map_", &global_var_map_);
}
TVM_DLL static Environment make(tvm::Map<GlobalVar, Function> global_funcs);
/*! \brief Add a function to the global environment.
/*!
* \brief Add a function to the global environment.
* \param var The name of the global function.
* \param func The function.
* \param update Controls whether you can replace a definition in the
......@@ -60,39 +61,46 @@ class EnvironmentNode : public RelayNode {
*/
void Add(const GlobalVar& var, const Function& func, bool update = false);
/*! \brief Update a function in the global environment.
/*!
* \brief Update a function in the global environment.
* \param var The name of the global function to update.
* \param func The new function.
*/
void Update(const GlobalVar& var, const Function& func);
/*! \brief Remove a function from the global environment.
/*!
* \brief Remove a function from the global environment.
* \param var The name of the global function to update.
*/
void Remove(const GlobalVar& var);
/*! \brief Lookup a global function by its variable.
/*!
* \brief Lookup a global function by its variable.
* \param str The unique string specifying the global variable.
* \returns The global variable.
*/
GlobalVar GetGlobalVar(const std::string& str);
/*! \brief Lookup a global function by its variable.
/*!
* \brief Lookup a global function by its variable.
* \param var The global var to lookup.
* \returns The function named by the variable argument.
*/
Function Lookup(const GlobalVar& var);
/*! \brief Lookup a global function by its string name
/*!
* \brief Lookup a global function by its string name
* \param name The name of the function.
* \returns The function named by the argument.
*/
Function Lookup(const std::string& name);
/*! \brief Combine with another Environment.
/*!
* \brief Update the functions inside this environment by
* functions in another environment.
* \param other The other environment.
*/
void Merge(const Environment& other);
void Update(const Environment& other);
static constexpr const char* _type_key = "relay.Environment";
TVM_DECLARE_NODE_TYPE_INFO(EnvironmentNode, Node);
......@@ -101,7 +109,7 @@ class EnvironmentNode : public RelayNode {
/*! \brief A map from string names to global variables that
* ensures global uniqueness.
*/
tvm::Map<std::string, GlobalVar> global_map_;
tvm::Map<std::string, GlobalVar> global_var_map_;
};
struct Environment : public NodeRef {
......
......@@ -197,7 +197,7 @@ class FunctionNode : public ExprNode {
*
* \note This is usually empty for non-polymorphic functions.
*/
tvm::Array<TypeParam> type_params;
tvm::Array<TypeVar> type_params;
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("params", &params);
......@@ -219,7 +219,7 @@ class FunctionNode : public ExprNode {
TVM_DLL static Function make(tvm::Array<Var> params,
Expr body,
Type ret_type,
tvm::Array<TypeParam> ty_params);
tvm::Array<TypeVar> ty_params);
static constexpr const char* _type_key = "relay.Function";
TVM_DECLARE_NODE_TYPE_INFO(FunctionNode, ExprNode);
......@@ -375,13 +375,14 @@ class TupleGetItemNode : public ExprNode {
int index;
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("tuple", &tuple);
v->Visit("tuple_value", &tuple);
v->Visit("index", &index);
v->Visit("_checked_type_", &checked_type_);
}
TVM_DLL static TupleGetItem make(Expr tuple, int index);
static constexpr const char * _type_key = "relay.GetItem";
static constexpr const char * _type_key = "relay.TupleGetItem";
TVM_DECLARE_NODE_TYPE_INFO(TupleGetItemNode, ExprNode);
};
......
......@@ -371,14 +371,14 @@ inline OpRegistry& OpRegistry::add_type_rel(
env_type_rel_func = env_func;
}
Array<TypeParam> type_params;
Array<TypeVar> type_params;
Array<Type> arg_types;
// Add inputs.
std::string input_name_prefix = "in";
for (int i = 0; i < get()->num_inputs; i++) {
auto name = input_name_prefix + std::to_string(i);
auto param = TypeParamNode::make(name, TypeParamNode::Kind::kType);
auto param = TypeVarNode::make(name, TypeVarNode::Kind::kType);
type_params.push_back(param);
arg_types.push_back(param);
}
......@@ -386,7 +386,7 @@ inline OpRegistry& OpRegistry::add_type_rel(
Array<Type> ty_call_args = arg_types;
// Add output type.
auto out_param = TypeParamNode::make("out", TypeParamNode::Kind::kType);
auto out_param = TypeVarNode::make("out", TypeVarNode::Kind::kType);
type_params.push_back(out_param);
// this will trigger copy on write.
ty_call_args.push_back(out_param);
......
......@@ -12,21 +12,30 @@
namespace tvm {
namespace relay {
/*! \brief Infer the type of an expression with the provided environment.
/*!
* \brief Infer the type of an expression.
*
* The result of type checking is a new expression with unambiguous
* type information filled in, as well as its checked_type field
* populated with the result type.
*
* \param env The environment used for global settings and referencing
* global functions.
*
* \param e The expression to type check.
* \param expr The expression to type check.
* \param env The environment used for referencing global functions, can be None.
*
* \return A type checked expression with its checked_type field populated.
*/
Expr InferType(const Environment& env, const Expr& e);
Expr InferType(const Environment& env, const GlobalVar& var, const Function& f);
Expr InferType(const Expr& expr, const Environment& env);
/*!
* \brief Infer the type of a function as if it is mapped to var in the env.
*
* \param f the function.
* \param env The environment used for referencing global functions.
* \param var The global variable corresponding to the function.
*
* \return A type checked Function with its checked_type field populated.
* \note this function mutates env and is not thread-safe.
*/
Function InferType(const Function& f, const Environment& env, const GlobalVar& var);
/*!
* \brief Check that types are well kinded by applying "kinding rules".
......@@ -111,7 +120,7 @@ tvm::Array<Var> FreeVariables(const Expr& e);
*
* \return the set of free type variables.
*/
tvm::Array<TypeParam> FreeTypeVariables(const Expr& e);
tvm::Array<TypeVar> FreeTypeVariables(const Expr& e);
/*! \brief Get free type parameters from type t.
*
......@@ -121,7 +130,7 @@ tvm::Array<TypeParam> FreeTypeVariables(const Expr& e);
*
* \return the set of free type variables.
*/
tvm::Array<TypeParam> FreeTypeVariables(const Type& t);
tvm::Array<TypeVar> FreeTypeVariables(const Type& t);
/*! \brief Remove expressions which do not affect the program result.
*
......
......@@ -98,7 +98,7 @@ RELAY_DEFINE_NODE_REF(TensorType, TensorTypeNode, Type);
* This can be viewed as a template parameter in a C++ template function.
*
* For example, in the following pseudo code,
* the TypeParam of f is TypeParam(kind=kShapeVar, var=n).
* the TypeVar of f is TypeVar(kind=kShapeVar, var=n).
* This function can take in a Tensor with shape=(3, 3) and
* return a Tensor with shape=(9,)
*
......@@ -108,13 +108,13 @@ RELAY_DEFINE_NODE_REF(TensorType, TensorTypeNode, Type);
* f(x : Tensor[i32, (n, n)]) -> Tensor[i32, (n * n)]
*
* \endcode
* \sa TypeParamNode The actual container class of TypeParam
* \sa TypeVarNode The actual container class of TypeVar
*/
class TypeParam;
/*! \brief TypeParam container node */
class TypeParamNode : public TypeNode {
class TypeVar;
/*! \brief TypeVar container node */
class TypeVarNode : public TypeNode {
public:
/*! \brief possible kinds of TypeParam */
/*! \brief possible kinds of TypeVar */
enum Kind : int {
/*! \brief template variable in shape expression */
kType = 0,
......@@ -136,13 +136,13 @@ class TypeParamNode : public TypeNode {
v->Visit("span", &span);
}
TVM_DLL static TypeParam make(std::string name, Kind kind);
TVM_DLL static TypeVar make(std::string name, Kind kind);
static constexpr const char* _type_key = "relay.TypeParam";
TVM_DECLARE_NODE_TYPE_INFO(TypeParamNode, TypeNode);
static constexpr const char* _type_key = "relay.TypeVar";
TVM_DECLARE_NODE_TYPE_INFO(TypeVarNode, TypeNode);
};
RELAY_DEFINE_NODE_REF(TypeParam, TypeParamNode, Type);
RELAY_DEFINE_NODE_REF(TypeVar, TypeVarNode, Type);
/*!
* \brief IncompleteType.
......@@ -150,20 +150,20 @@ RELAY_DEFINE_NODE_REF(TypeParam, TypeParamNode, Type);
*
* If we view the type relations as "computational graph of types",
* then IncompleteType represents intermediate values of the graph,
* TypeParam represents the input to the graph.
* TypeVar represents the input to the graph.
*/
class IncompleteType;
/*! \brief IncompleteType container node */
class IncompleteTypeNode : public TypeNode {
public:
TypeParamNode::Kind kind;
TypeVarNode::Kind kind;
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("kind", &kind);
}
TVM_DLL static IncompleteType make(TypeParamNode::Kind kind);
TVM_DLL static IncompleteType make(TypeVarNode::Kind kind);
static constexpr const char* _type_key = "relay.IncompleteType";
TVM_DECLARE_NODE_TYPE_INFO(IncompleteTypeNode, TypeNode);
......@@ -192,7 +192,7 @@ class FuncType;
* Relay supports polymorphic function types.
* This can be roughly viewed as a template function in C++.
*
* \sa TypeParam, TypeConstraint
* \sa TypeVar, TypeConstraint
*/
class FuncTypeNode : public TypeNode {
public:
......@@ -203,7 +203,7 @@ class FuncTypeNode : public TypeNode {
// The following fields are used in polymorphic(template) functions
// For normal functions, the following two fields will be empty.
/*! \brief The type parameters of the function */
tvm::Array<TypeParam> type_params;
tvm::Array<TypeVar> type_params;
/*!
* \brief potential constraint the type need to obey
* \note this field is reserved for further purposes.
......@@ -220,7 +220,7 @@ class FuncTypeNode : public TypeNode {
TVM_DLL static FuncType make(tvm::Array<Type> arg_types,
Type ret_type,
tvm::Array<TypeParam> type_params,
tvm::Array<TypeVar> type_params,
tvm::Array<TypeConstraint> type_constraints);
static constexpr const char* _type_key = "relay.FuncType";
......
......@@ -5,7 +5,6 @@ from . import ty
from . import expr
from . import env
from . import ir_pass
from . import ir_builder
# Root operators
from .op import Op
......@@ -16,6 +15,8 @@ from . import nn
from . import vision
from . import image
from .scope_builder import ScopeBuilder
# Span
Span = base.Span
......@@ -27,11 +28,12 @@ Type = ty.Type
TupleType = ty.TupleType
TensorType = ty.TensorType
Kind = ty.Kind
TypeParam = ty.TypeParam
TypeVar = ty.TypeVar
TypeConstraint = ty.TypeConstraint
FuncType = ty.FuncType
TypeRelation = ty.TypeRelation
IncompleteType = ty.IncompleteType
scalar_type = ty.scalar_type
# Expr
Constant = expr.Constant
......
# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable, wildcard-import
"""A global environment storing everything needed to interpret or compile a Relay program."""
from .base import register_relay_node, RelayNode
from .._ffi import base as _base
from . import _make
from . import _env
from . import expr as _expr
@register_relay_node
class Environment(RelayNode):
"""The global Relay environment containing functions,
options and more.
"""
def __init__(self, funcs=None):
"""Construct an environment.
Parameters
------
funcs : optional, dict
Map of global var to Function
"""The global Relay environment containing collection of functions.
Returns
------
env: A new environment containing :py:class:`~relay.env.Environment`.
"""
funcs = funcs if funcs else {}
self.__init_handle_by_constructor__(_make.Environment, funcs)
Each global function is identified by a unique tvm.relay.GlobalVar.
tvm.relay.GlobalVar and Environment are necessary to enable
recursion in functions while avoiding cyclic references between them.
def add(self, var, func):
Parameters
----------
functions : dict, optional.
Map of global var to Function
"""
def __init__(self, functions=None):
if functions is None:
functions = {}
elif isinstance(functions, dict):
mapped_funcs = {}
for k, v in functions.items():
if isinstance(k, _base.string_types):
k = _expr.GlobalVar(k)
if not isinstance(k, _expr.GlobalVar):
raise TypeError("Expect functions to be Dict[GlobalVar, Function]")
mapped_funcs[k] = v
functions = mapped_funcs
self.__init_handle_by_constructor__(_make.Environment, functions)
def __setitem__(self, var, func):
"""Add a function to the environment.
Parameters
......@@ -36,50 +45,55 @@ class Environment(RelayNode):
func: Function
The function.
"""
if isinstance(var, str):
var = _env.Environment_GetGlobalVar(self, var)
if isinstance(var, _base.string_types):
var = _expr.GlobalVar(var)
_env.Environment_Add(self, var, func)
def merge(self, other):
"""Merge two environments.
def __getitem__(self, var):
"""Lookup a global function by name or by variable.
Parameters
----------
other: Environment
The environment to merge into the current Environment.
var: str or GlobalVar
The name or global variable.
Returns
-------
func: Function
The function referenced by :code:`var`.
"""
return _env.Environment_Merge(self, other)
if isinstance(var, _base.string_types):
return _env.Environment_Lookup_str(self, var)
else:
return _env.Environment_Lookup(self, var)
def global_var(self, name):
"""Get a global variable by name.
def update(self, other):
"""Insert functions in another Environment to current one.
Parameters
----------
name: str
The name of the global variable.
Returns
-------
global_var: GlobalVar
The global variable mapped to :code:`name`.
other: Environment
The environment to merge into the current Environment.
"""
return _env.Environment_GetGlobalVar(self, name)
if isinstance(other, dict):
other = Environment(other)
return _env.Environment_Update(self, other)
def __getitem__(self, var):
"""Lookup a global function by name or by variable.
def get_global_var(self, name):
"""Get a global variable in the function by name.
Parameters
----------
var: str or GlobalVar
The name or global variable.
name: str
The name of the global variable.
Returns
-------
func: Function
The function referenced by :code:`var`.
global_var: GlobalVar
The global variable mapped to :code:`name`.
Raises
------
tvm.TVMError if we cannot find the corresponding global var.
"""
if isinstance(var, str):
return _env.Environment_Lookup_str(self, var)
else:
return _env.Environment_Lookup(self, var)
return _env.Environment_GetGlobalVar(self, name)
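A minimal usage sketch of the refactored Environment API above (not part of this commit; it assumes the constructor, item access, update and get_global_var behave as documented here):

import tvm.relay as relay

# Build a trivial identity function and register it under a string name;
# the constructor maps string keys to GlobalVar internally.
x = relay.var("x", relay.TensorType((10,), "float32"))
env = relay.Environment({"ident": relay.Function([x], x)})

# Lookup works by name or by the GlobalVar the name maps to.
gv = env.get_global_var("ident")
ident = env[gv]          # same function as env["ident"]

# update() merges definitions from a dict or another Environment.
y = relay.var("y", relay.TensorType((10,), "float32"))
env.update({"ident2": relay.Function([y], y)})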
......@@ -28,9 +28,6 @@ class Expr(RelayNode):
" the checked_type for this node")
return ret
def __call__(self, *args):
return Call(self, args, None, None)
@register_relay_node
class Constant(Expr):
......@@ -57,6 +54,14 @@ class Tuple(Expr):
def __init__(self, fields):
self.__init_handle_by_constructor__(_make.Tuple, fields)
def __getitem__(self, index):
if index >= len(self):
raise IndexError("Tuple index out of range")
return self.fields[index]
def __len__(self):
return len(self.fields)
@register_relay_node
class Var(Expr):
......@@ -95,6 +100,16 @@ class GlobalVar(Expr):
def __init__(self, name_hint):
self.__init_handle_by_constructor__(_make.GlobalVar, name_hint)
def __call__(self, *args):
"""Invoke the gobal function.
Parameters
----------
args: List[relay.Expr]
Arguments.
"""
return Call(self, args, None, None)
@register_relay_node
class Function(Expr):
......@@ -126,6 +141,16 @@ class Function(Expr):
self.__init_handle_by_constructor__(
_make.Function, params, body, ret_type, type_params)
def __call__(self, *args):
"""Invoke the gobal function.
Parameters
----------
args: List[relay.Expr]
Arguments.
"""
return Call(self, args, None, None)
@register_relay_node
class Call(Expr):
......@@ -238,11 +263,17 @@ class TupleWrapper(_node.NodeGeneric):
return self.tuple_value
def __getitem__(self, key):
return self.tuple_value.fields[key]
def __getitem__(self, index):
if index >= len(self):
raise IndexError("Tuple index out of range")
return TupleGetItem(self.tuple_value, index)
def __len__(self):
return len(self.tuple_value.fields)
return self.size
def __repr__(self):
return ("TupleWrapper(" + self.tuple_value.__repr__() +
", " + self.size + ")")
def var(name_hint,
......@@ -304,13 +335,27 @@ def const(value, dtype=None):
dtype: str, optional
The data type of the value.
Note
----
When dtype is None, we use the following rule:
- int maps to "int32"
- float maps to "float32"
- bool maps to "bool"
- other using the same default rule as numpy.
"""
if isinstance(value, _base.numeric_types):
value = _np.array(value, dtype=dtype)
elif isinstance(value, (bool, list)):
if isinstance(value, (_base.numeric_types, (bool, list))):
value = _np.array(value, dtype=dtype)
# convert default to int32 and float32
if dtype is None:
if value.dtype == "float64":
value = value.astype("float32")
elif value.dtype == "int64":
value = value.astype("int32")
if isinstance(value, (_np.ndarray, _np.generic)):
value = _nd.array(value)
if not isinstance(value, _nd.NDArray):
raise ValueError("value has to be scalar or NDArray")
return Constant(value)
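A short sketch of the defaulting rules listed in the note above (not part of this commit):

import tvm.relay as relay

c_int = relay.const(1)       # Python int   -> int32 constant
c_flt = relay.const(1.0)     # Python float -> float32 constant
c_bool = relay.const(True)   # Python bool  -> bool constant
c_i64 = relay.const(1, dtype="int64")   # explicit dtype overrides the default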
......@@ -2,37 +2,39 @@
# pylint: disable=unidiomatic-typecheck
"""The set of passes for Relay.
Exposes an interface for configuring the passes and scripting
them in Python.
Exposes an interface for configuring the passes and
scripting them in Python.
"""
from . import _ir_pass
from . import _make
# pylint: disable=invalid-name
def infer_type(env, expr):
def infer_type(expr, env=None):
"""Infer the type of expr under the context of env.
Parameters
----------
env : relay.Environment
expr: tvm.relay.Expr
The input expression.
env: Optional[tvm.relay.Environment]
The global environment.
expr : relay.Expr
The input expression.
Returns
-------
checked_expr : relay.Expr
checked_expr : tvm.relay.Expr
The checked expression.
"""
return _ir_pass.infer_type(env, expr)
return _ir_pass.infer_type(expr, env)
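A minimal sketch of the new calling convention (not part of this commit), with the expression first and the environment optional:

import tvm.relay as relay
from tvm.relay import ir_pass

x = relay.var("x", relay.TensorType((10,), "float32"))
f = relay.Function([x], x)
f_typed = ir_pass.infer_type(f)       # env is optional now
# checked_type is a FuncType mapping Tensor[(10,), float32] to itself.
print(f_typed.checked_type)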
def well_formed(e):
def well_formed(expr):
"""Check that each Var is only bound once (well formed).
Parameters
----------
e: relay.Expr
expr: tvm.relay.Expr
The input expression
Returns
......@@ -40,7 +42,8 @@ def well_formed(e):
well_form : bool
whether the input expression is well formed
"""
return _ir_pass.well_formed(e)
return _ir_pass.well_formed(expr)
def check_kind(t, env=None):
"""Check that the type is well kinded.
......@@ -48,10 +51,10 @@ def check_kind(t, env=None):
Parameters
----------
t: relay.Type
t: tvm.relay.Type
The type to check
env: relay.Environment, optional
env: tvm.relay.Environment, optional
The global environment
Returns
......@@ -71,61 +74,65 @@ def check_kind(t, env=None):
else:
return _ir_pass.check_kind(t)
def free_vars(e):
"""Get free variables from expression e.
Parameters
----------
e: relay.Expr
e: tvm.relay.Expr
The input expression
Returns
-------
free : List[relay.Var]
the list of free variables
free : List[tvm.relay.Var]
The list of free variables
"""
return _ir_pass.free_vars(e)
def free_type_vars(e):
def free_type_vars(expr):
"""Get free type variables from expression/type e
Parameters
----------
e: relay.Expr/relay.Type
The input expression/type
expr: Union[tvm.relay.Expr,tvm.relay.Type]
The input expression/type
Returns
-------
free : List[relay.TypeParam]
the list of free type variables
free : List[tvm.relay.TypeVar]
The list of free type variables
"""
return _ir_pass.free_type_vars(e)
return _ir_pass.free_type_vars(expr)
def dead_code_elimination(e):
def dead_code_elimination(expr):
""" Remove expressions which does not effect the program result (dead code).
Parameters
----------
e: relay.Expr
The input Expression
e: tvm.relay.Expr
The input Expression
Returns
-------
result: relay.Expr
An expression which is semantically equal to the input expression,
but with dead code removed.
result: tvm.relay.Expr
An expression which is semantically equal to the input expression,
but with dead code removed.
"""
return _ir_pass.dead_code_elimination(e)
return _ir_pass.dead_code_elimination(expr)
def alpha_equal(lhs, rhs):
"""Compare two Relay expr for structural equivalence (alpha equivalence).
Parameters
----------
lhs: relay.Expr
lhs: tvm.relay.Expr
One of the input Expression.
rhs: relay.Expr
rhs: tvm.relay.Expr
One of the input Expression.
Returns
......
"""The scope builder interface """
from __future__ import absolute_import
from . import expr as _expr
from .._ffi import base as _base
class WithScope(object):
"""A wrapper for builder methods which introduce scoping.
Parameters
----------
enter_value: object
The value returned by enter.
"""
def __init__(self, enter_value, exit_cb):
self._enter_value = enter_value
self._exit_cb = exit_cb
def __enter__(self):
return self._enter_value
def __exit__(self, ptype, value, trace):
if value:
raise value
else:
self._exit_cb()
def _make_lets(bindings, ret_value):
"""Make a nested let expressions.
Parameters
----------
bindings: List[Tuple[tvm.relay.Var,tvm.relay.Expr]]
The sequence of let bindings
ret_value: tvm.relay.Expr
The final value of the expression.
Returns
-------
lets: tvm.relay.Expr
A nested let expression.
"""
if ret_value is None:
raise RuntimeError("ret is not called in this scope")
if isinstance(ret_value, _expr.If) and ret_value.false_branch is None:
raise RuntimeError("Creating an If expression without else.")
let_expr = ret_value
for var, value in reversed(bindings):
let_expr = _expr.Let(var, value, let_expr)
return let_expr
class ScopeBuilder(object):
"""Scope builder class.
Enables users to build up a nested
scope (let, if) expression easily.
Examples
--------
.. code-block:: python
sb = relay.ScopeBuilder()
cond = relay.var("cond", 'bool')
x = relay.var("x")
y = relay.var("y")
with sb.if_scope(cond):
one = relay.const(1, "float32")
t1 = sb.let("t1", relay.add(x, one))
sb.ret(t1)
with sb.else_scope():
sb.ret(y)
print(sb.get().astext())
"""
def __init__(self):
self._bindings = [[]]
self._ret_values = [None]
def _enter_scope(self):
self._bindings.append([])
self._ret_values.append(None)
def _exit_scope(self):
bindings = self._bindings.pop()
ret_value = self._ret_values.pop()
return bindings, ret_value
def let(self, var, value):
"""Create a new let binding.
Parameters
----------
var: Union[str, Tuple[str, relay.Type], tvm.relay.Var]
The variable or name of variable.
value: tvm.relay.Expr
The value to be bound
"""
if isinstance(var, (tuple, list)):
if len(var) > 2:
raise ValueError("Expect var to be Tuple[str, relay.Type]")
var = _expr.var(*var)
elif isinstance(var, _base.string_types):
var = _expr.var(var)
self._bindings[-1].append((var, value))
return var
def if_scope(self, cond):
"""Create a new if scope.
Parameters
----------
cond: tvm.relay.Expr
The condition
Returns
-------
scope: WithScope
The if scope.
Note
----
The user must follow this with an else scope.
"""
self._enter_scope()
def _on_exit():
bindings, ret_value = self._exit_scope()
if self._ret_values[-1] is not None:
raise RuntimeError("result already returned before if scope")
true_branch = _make_lets(bindings, ret_value)
self._ret_values[-1] = _expr.If(cond, true_branch, None)
return WithScope(None, _on_exit)
def else_scope(self):
"""Create a new else scope.
Returns
-------
scope: WithScope
The else scope.
"""
self._enter_scope()
def _on_exit():
bindings, ret_value = self._exit_scope()
partial_if = self._ret_values[-1]
no_else = (not isinstance(partial_if, _expr.If) or
partial_if.false_branch is not None)
if no_else:
raise RuntimeError("else scope must follows")
false_branch = _make_lets(bindings, ret_value)
self._ret_values[-1] = _expr.If(
partial_if.cond,
partial_if.true_branch,
false_branch)
return WithScope(None, _on_exit)
def ret(self, value):
"""Set the return value of this scope.
Parameters
----------
value: tvm.relay.Expr
The return value.
"""
if self._ret_values[-1] is not None:
raise RuntimeError("ret value is already set in this scope.")
self._ret_values[-1] = value
def get(self):
"""Get the generated result.
Returns
-------
value: tvm.relay.Expr
The final result of the expression.
"""
if len(self._bindings) != 1:
raise RuntimeError("can only call get at the outmost scope")
return _make_lets(self._bindings[-1], self._ret_values[-1])
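A short usage sketch of the builder outside of an if/else (not part of this commit), showing the string and (name, type) forms accepted by let and the nested Let expression produced by get:

import tvm.relay as relay

sb = relay.ScopeBuilder()
a = sb.let("a", relay.const(1.0))                                  # bind by name
b = sb.let(("b", relay.scalar_type("float32")), relay.add(a, a))   # bind by (name, type)
sb.ret(relay.multiply(a, b))
prog = sb.get()   # Let(a, 1.0, Let(b, a + a, a * b))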
......@@ -56,7 +56,7 @@ class Kind(IntEnum):
Shape = 3
@register_relay_node
class TypeParam(Type):
class TypeVar(Type):
"""A type parameter used for generic types in Relay,
see tvm/relay/type.h for more details.
......@@ -66,7 +66,7 @@ class TypeParam(Type):
"""
def __init__(self, var, kind=Kind.Type):
"""Construct a TypeParam.
"""Construct a TypeVar.
Parameters
----------
......@@ -78,10 +78,10 @@ class TypeParam(Type):
Returns
-------
type_param: TypeParam
type_param: TypeVar
The type parameter.
"""
self.__init_handle_by_constructor__(_make.TypeParam, var, kind)
self.__init_handle_by_constructor__(_make.TypeVar, var, kind)
@register_relay_node
......@@ -122,26 +122,30 @@ class FuncType(Type):
We informally write them as:
`forall (type_params), (arg_types) -> ret_type where type_constraints`
Parameters
----------
arg_types: List[tvm.relay.Type]
The argument types
ret_type: tvm.relay.Type
The return type.
type_params: List[tvm.relay.TypeVar]
The type parameters
type_constraints: List[tvm.relay.TypeConstraint]
The type constraints.
"""
def __init__(self,
arg_types,
ret_type,
type_params,
type_constraints):
"""Construct a function type.
Parameters
----------
arg_types: list of Type
ret_type: Type
type_params: list of TypeParam
type_constraints: list of TypeConstraint
Returns
-------
func_type: FuncType
The function type.
"""
type_params=None,
type_constraints=None):
if type_params is None:
type_params = []
if type_constraints is None:
type_constraints = []
self.__init_handle_by_constructor__(
_make.FuncType, arg_types, ret_type, type_params, type_constraints)
......@@ -175,3 +179,21 @@ class TypeRelation(TypeConstraint):
def __init__(self, func, args, num_inputs, attrs):
self.__init_handle_by_constructor__(_make.TypeRelation,
func, args, num_inputs, attrs)
def scalar_type(dtype):
"""Creates a scalar type.
This function returns TensorType((), dtype)
Parameters
----------
dtype : str
The content data type.
Returns
-------
s_type: tvm.relay.TensorType
The result type.
"""
return TensorType((), dtype)
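A brief sketch (not part of this commit) tying together the renamed TypeVar, the now-optional FuncType arguments, and scalar_type:

import tvm.relay as relay

# forall t. (t) -> t, a polymorphic identity type.
t = relay.TypeVar("t", relay.Kind.Type)
id_ty = relay.FuncType([t], t, type_params=[t])

# Monomorphic types no longer need to pass empty type_params/type_constraints.
f32 = relay.scalar_type("float32")
mono_ty = relay.FuncType([f32], f32)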
......@@ -16,87 +16,71 @@ using namespace runtime;
Environment EnvironmentNode::make(tvm::Map<GlobalVar, Function> global_funcs) {
auto n = make_node<EnvironmentNode>();
n->functions = std::move(global_funcs);
for (const auto& kv : n->functions) {
// set global var map
CHECK(!n->global_var_map_.count(kv.first->name_hint))
<< "Duplicate global function name " << kv.first->name_hint;
n->global_var_map_.Set(kv.first->name_hint, kv.first);
}
return Environment(n);
}
GlobalVar EnvironmentNode::GetGlobalVar(const std::string &str) {
auto global_id = global_map_.find(str);
if (global_id != global_map_.end()) {
return (*global_id).second;
} else {
auto id = GlobalVarNode::make(str);
this->global_map_.Set(str, id);
return id;
}
GlobalVar EnvironmentNode::GetGlobalVar(const std::string& name) {
auto it = global_var_map_.find(name);
CHECK(it != global_var_map_.end())
<< "Cannot find global var " << name << " in the Environment";
return (*it).second;
}
/*!
* \brief Add a new item to the global environment
* \note if the update flag is not set, adding a duplicate
* definition will trigger an exception; otherwise we will
* update the definition if and only if it is type compatible.
*/
void EnvironmentNode::Add(const GlobalVar &var,
const Function &func,
void EnvironmentNode::Add(const GlobalVar& var,
const Function& func,
bool update) {
// Type check the item before we add it to the environment.
auto env = GetRef<Environment>(this);
Expr checked_expr = InferType(env, var, func);
if (const FunctionNode *func_node = checked_expr.as<FunctionNode>()) {
auto checked_func = GetRef<Function>(func_node);
auto type = checked_func->checked_type();
CHECK(type.as<IncompleteTypeNode>() == nullptr);
if (functions.find(var) != functions.end()) {
if (!update) {
throw dmlc::Error("already have definition for XXXX.");
}
auto old_type = functions[var].as<FunctionNode>()->checked_type();
if (!AlphaEqual(type, old_type)) {
throw dmlc::Error(
"Environment#update changes type, not possible in this mode.");
}
this->functions.Set(var, checked_func);
} else {
this->functions.Set(var, checked_func);
}
} else {
LOG(FATAL) << "internal error: unknown item type, unreachable code";
Function checked_func = InferType(func, env, var);
auto type = checked_func->checked_type();
CHECK(type.as<IncompleteTypeNode>() == nullptr);
if (functions.find(var) != functions.end()) {
CHECK(update)
<< "Already have definition for " << var->name_hint;
auto old_type = functions[var].as<FunctionNode>()->checked_type();
CHECK(AlphaEqual(type, old_type))
<< "Environment#update changes type, not possible in this mode.";
}
this->functions.Set(var, checked_func);
// set global var map
CHECK(!global_var_map_.count(var->name_hint))
<< "Duplicate global function name " << var->name_hint;
global_var_map_.Set(var->name_hint, var);
}
void EnvironmentNode::Update(const GlobalVar &var, const Function &func) {
void EnvironmentNode::Update(const GlobalVar& var, const Function& func) {
this->Add(var, func, true);
}
void EnvironmentNode::Remove(const GlobalVar & var) {
void EnvironmentNode::Remove(const GlobalVar& var) {
auto functions_node = this->functions.CopyOnWrite();
functions_node->data.erase(var.node_);
auto gvar_node = global_var_map_.CopyOnWrite();
gvar_node->data.erase(var->name_hint);
}
Function EnvironmentNode::Lookup(const GlobalVar &var) {
auto func = functions.find(var);
if (func != functions.end()) {
return (*func).second;
} else {
throw Error(std::string("there is no definition of ") + var->name_hint);
}
Function EnvironmentNode::Lookup(const GlobalVar& var) {
auto it = functions.find(var);
CHECK(it != functions.end())
<< "There is no definition of " << var->name_hint;
return (*it).second;
}
Function EnvironmentNode::Lookup(const std::string &str) {
GlobalVar id = this->GetGlobalVar(str);
Function EnvironmentNode::Lookup(const std::string &name) {
GlobalVar id = this->GetGlobalVar(name);
return this->Lookup(id);
}
void EnvironmentNode::Merge(const Environment &env) {
void EnvironmentNode::Update(const Environment &env) {
for (auto pair : env->functions) {
this->functions.Set(pair.first, pair.second);
this->Update(pair.first, pair.second);
}
}
......@@ -134,10 +118,10 @@ TVM_REGISTER_API("relay._env.Environment_Lookup_str")
*ret = env->Lookup(var);
});
TVM_REGISTER_API("relay._env.Environment_Merge")
TVM_REGISTER_API("relay._env.Environment_Update")
.set_body([](TVMArgs args, TVMRetValue *ret) {
Environment env = args[0];
env->Merge(args[1]);
env->Update(args[1]);
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
......
......@@ -104,7 +104,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
Function FunctionNode::make(tvm::Array<Var> params,
Expr body,
Type ret_type,
tvm::Array<TypeParam> type_params) {
tvm::Array<TypeVar> type_params) {
NodePtr<FunctionNode> n = make_node<FunctionNode>();
n->params = std::move(params);
n->body = std::move(body);
......
......@@ -66,11 +66,11 @@ Expr ExprMutator::VisitExpr_(const TupleNode* op) {
}
Expr ExprMutator::VisitExpr_(const FunctionNode* op) {
tvm::Array<TypeParam> ty_params;
tvm::Array<TypeVar> ty_params;
bool all_ty_params_changed = true;
for (auto ty_param : op->type_params) {
TypeParam new_ty_param = Downcast<TypeParam>(VisitType(ty_param));
TypeVar new_ty_param = Downcast<TypeVar>(VisitType(ty_param));
ty_params.push_back(new_ty_param);
all_ty_params_changed &= new_ty_param.same_as(ty_param);
}
......
......@@ -217,6 +217,8 @@ class TextPrinter :
return ConstScalar(dtype, static_cast<const float*>(op->data->data));
} else if (dtype == Float(64)) {
return ConstScalar(dtype, static_cast<const double*>(op->data->data));
} else if (dtype == Bool()) {
return ConstScalar(dtype, static_cast<const uint8_t*>(op->data->data));
}
}
// default fall-back, record it as meta node.
......@@ -638,8 +640,14 @@ class TextPrinter :
* \return The corresponding name.
*/
TextValue AllocVarName(const Var& var) {
std::string name = GetUniqueName('%' + var->name_hint);
TextValue val(name);
std::string name = var->name_hint;
// always make sure first name is alpha
if (name.length() != 0 && !std::isalpha(name[0])) {
name = "%v" + name;
} else {
name = "%" + name;
}
TextValue val(GetUniqueName(name));
CHECK(!memo_.count(var));
memo_[var] = val;
return val;
......
......@@ -36,30 +36,30 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
p->stream << "TensorType(" << node->shape << ", " << node->dtype << ")";
});
TypeParam TypeParamNode::make(std::string name, TypeParamNode::Kind kind) {
NodePtr<TypeParamNode> n = make_node<TypeParamNode>();
TypeVar TypeVarNode::make(std::string name, TypeVarNode::Kind kind) {
NodePtr<TypeVarNode> n = make_node<TypeVarNode>();
n->var = tvm::Var(name);
n->kind = std::move(kind);
return TypeParam(n);
return TypeVar(n);
}
TVM_REGISTER_NODE_TYPE(TypeParamNode);
TVM_REGISTER_NODE_TYPE(TypeVarNode);
TVM_REGISTER_API("relay._make.TypeParam")
TVM_REGISTER_API("relay._make.TypeVar")
.set_body([](TVMArgs args, TVMRetValue *ret) {
int kind = args[1];
*ret =
TypeParamNode::make(args[0], static_cast<TypeParamNode::Kind>(kind));
TypeVarNode::make(args[0], static_cast<TypeVarNode::Kind>(kind));
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<TypeParamNode>([](const TypeParamNode *node,
.set_dispatch<TypeVarNode>([](const TypeVarNode *node,
tvm::IRPrinter *p) {
p->stream << "TypeParamNode(" << node->var->name_hint << ", "
p->stream << "TypeVarNode(" << node->var->name_hint << ", "
<< node->kind << ")";
});
IncompleteType IncompleteTypeNode::make(TypeParamNode::Kind kind) {
IncompleteType IncompleteTypeNode::make(TypeVarNode::Kind kind) {
auto n = make_node<IncompleteTypeNode>();
n->kind = std::move(kind);
return IncompleteType(n);
......@@ -70,7 +70,7 @@ TVM_REGISTER_NODE_TYPE(IncompleteTypeNode);
TVM_REGISTER_API("relay._make.IncompleteType")
.set_body([](TVMArgs args, TVMRetValue* ret) {
int kind = args[0];
*ret = IncompleteTypeNode::make(static_cast<TypeParamNode::Kind>(kind));
*ret = IncompleteTypeNode::make(static_cast<TypeVarNode::Kind>(kind));
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
......@@ -82,7 +82,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
FuncType FuncTypeNode::make(tvm::Array<Type> arg_types,
Type ret_type,
tvm::Array<TypeParam> type_params,
tvm::Array<TypeVar> type_params,
tvm::Array<TypeConstraint> type_constraints) {
NodePtr<FuncTypeNode> n = make_node<FuncTypeNode>();
n->arg_types = std::move(arg_types);
......
......@@ -78,6 +78,7 @@ RELAY_REGISTER_OP("image.resize")
for layout NHWC
(batch_size, size[0], size[1], channels)
)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.ResizeAttrs")
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(5)
......
......@@ -247,6 +247,8 @@ RELAY_REGISTER_UNARY_OP("relay.op.nn._make.", "relu")
// Positional relay function to create LRN operator used by frontend FFI.
TVM_REGISTER_NODE_TYPE(LRNAttrs);
Expr MakeLRN(Expr data,
IndexExpr size,
IndexExpr axis,
......@@ -290,6 +292,8 @@ centered at that value (zero padding is added where necessary).
// Positional relay function to create L2Normalize operator used by frontend FFI.
TVM_REGISTER_NODE_TYPE(L2NormalizeAttrs);
Expr MakeL2Normalize(Expr data,
double eps,
Array<IndexExpr> axis) {
......@@ -315,6 +319,7 @@ Normalizes along dimension axis using an L2 norm
- **data**: The input tensor.
)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.L2NormalizeAttrs")
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(2)
......
......@@ -77,6 +77,7 @@ RELAY_REGISTER_OP("nn.pad")
.describe(R"code(Pad for n-D tensor.
)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.PadAttrs")
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(2)
......
......@@ -12,6 +12,7 @@ namespace tvm {
namespace relay {
TVM_REGISTER_NODE_TYPE(MaxPool2DAttrs);
TVM_REGISTER_NODE_TYPE(AvgPool2DAttrs);
template <typename AttrTtype>
bool Pool2DRel(const Array<Type>& types,
......@@ -115,6 +116,7 @@ RELAY_REGISTER_OP("nn.max_pool2d")
equation.
)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.MaxPool2DAttrs")
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(2)
......@@ -169,6 +171,7 @@ Average pooling operation for one dimensional data.
equation.
)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.AvgPool2DAttrs")
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(2)
......@@ -232,6 +235,7 @@ RELAY_REGISTER_OP("nn.global_avg_pool2d")
(batch_size, channels, 1, 1) if `layout` is `NCHW`.
)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.GlobalPool2DAttrs")
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(2)
......@@ -261,6 +265,7 @@ RELAY_REGISTER_OP("nn.global_max_pool2d")
(batch_size, channels, 1, 1) if `layout` is `NCHW`.
)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.GlobalPool2DAttrs")
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(2)
......
......@@ -78,6 +78,7 @@ RELAY_REGISTER_OP("nn.upsampling")
(batch_size, in_height*scale, in_width*scale, channels)
)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.UpSamplingAttrs")
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(2)
......
......@@ -199,7 +199,7 @@ RELAY_REGISTER_REDUCE_OP("argmax")
values over a given axis.
)code" TVM_ADD_FILELINE)
.set_num_inputs(1)
.set_attrs_type_key("relay.attrs.ReduceAttrs")
.set_support_level(4)
.add_type_rel("ArgReduce", ArgReduceRel);
......@@ -209,7 +209,7 @@ RELAY_REGISTER_REDUCE_OP("argmin")
values over a given axis.
)code" TVM_ADD_FILELINE)
.set_num_inputs(1)
.set_attrs_type_key("relay.attrs.ReduceAttrs")
.set_support_level(4)
.add_type_rel("ArgReduce", ArgReduceRel);
......
......@@ -144,12 +144,14 @@ RELAY_REGISTER_OP("concatenate")
- **axis** : The axis along which the tensors are concatenated.
)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.ConcatenateAttrs")
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input list of tensors.")
.set_support_level(1)
.add_type_rel("Concatenate", ConcatenateRel);
/* relay.transpose */
TVM_REGISTER_NODE_TYPE(TransposeAttrs);
bool TransposeRel(const Array<Type>& types,
int num_inputs,
......@@ -224,12 +226,15 @@ RELAY_REGISTER_OP("transpose")
)code" TVM_ADD_FILELINE)
.set_num_inputs(1)
.set_attrs_type_key("relay.attrs.TransposeAttrs")
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(3)
.add_type_rel("Transpose", TransposeRel);
/* relay.reshape */
TVM_REGISTER_NODE_TYPE(ReshapeAttrs);
bool ReshapeRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
......@@ -310,6 +315,7 @@ Example::
)code" TVM_ADD_FILELINE)
.set_num_inputs(1)
.set_attrs_type_key("relay.attrs.ReshapeAttrs")
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(3)
.add_type_rel("Reshape", ReshapeRel);
......@@ -397,12 +403,14 @@ Examples::
[ 4., 3.]]
)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.TakeAttrs")
.set_num_inputs(2)
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("indices", "Tensor", "The indices tensor.")
.set_support_level(2)
.add_type_rel("Take", TakeRel);
// Init ops
TVM_REGISTER_NODE_TYPE(InitOpAttrs);
bool FullRel(const Array<Type>& types,
......@@ -448,6 +456,7 @@ RELAY_REGISTER_OP("full")
.describe(R"code(Fill array with scalar value.
)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.InitOpAttrs")
.set_num_inputs(1)
.add_argument("fill_value", "double", "The value to fill.")
.set_support_level(3)
......@@ -634,6 +643,10 @@ Examples::
.set_support_level(4)
.add_type_rel("Where", WhereRel);
// Squeeze
TVM_REGISTER_NODE_TYPE(SqueezeAttrs);
Expr MakeSqueeze(Expr data,
Array<IndexExpr> axes) {
auto attrs = make_node<SqueezeAttrs>();
......
......@@ -7,6 +7,7 @@
#include <tvm/relay/expr.h>
#include <tvm/relay/logging.h>
#include <tvm/relay/op.h>
#include <tvm/ir_pass.h>
#include <numeric>
#include "./type_relations.h"
......@@ -21,14 +22,6 @@ TensorType ToTensorType(const Type& t) {
}
}
// TODO(@jroesch) what size value do we extract, 64bit or 32bit?
int ToInt(const tvm::Expr& e) {
CHECK(e.defined());
auto imm = e.as<tvm::ir::IntImm>();
CHECK(imm) << "TYPE: " << imm << imm->type << std::endl;
return imm->value;
}
bool IdentityRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
......@@ -39,72 +32,54 @@ bool IdentityRel(const Array<Type>& types,
return true;
}
bool EqualCheck(const IndexExpr& lhs,
const IndexExpr& rhs) {
IndexExpr diff = lhs - rhs;
if (const int64_t* pdiff = as_const_int(diff)) {
return pdiff[0] == 0;
}
// symbolic
diff = tvm::ir::CanonicalSimplify(diff);
if (const int64_t* pdiff = as_const_int(diff)) {
return pdiff[0] == 0;
}
return false;
}
bool EqualConstInt(const IndexExpr& lhs, int64_t value) {
if (const int64_t* pvalue = as_const_int(lhs)) {
return pvalue[0] == value;
}
return false;
}
Type ConcreteBroadcast(const TensorType& t1,
const TensorType& t2,
DataType output_dtype) {
RELAY_LOG(INFO) << "ConcreteBroadcast: t1=" << t1 << " t2=" << t2
<< std::endl;
auto sh1 = t1->shape;
auto sh2 = t2->shape;
RELAY_LOG(INFO) << "ConcreteBroadcast: sh1=" << sh1 << " sh2=" << sh2
<< std::endl;
if (sh1.size() == 0 && sh2.size() == 0) {
return TensorTypeNode::make({}, output_dtype);
// We have non-zero shapes so broadcast rules apply.
} else {
auto suffix_len = static_cast<int>(std::min(sh1.size(), sh2.size()));
auto full_len = static_cast<int>(std::max(sh1.size(), sh2.size()));
auto rev_sh1 = sh1.rbegin();
auto rev_sh2 = sh2.rbegin();
while (rev_sh1 != sh1.rend() && rev_sh2 != sh2.rend()) {
auto dim1 = ToInt(*rev_sh1);
auto dim2 = ToInt(*rev_sh2);
if ((dim1 != dim2) && ((dim1 != 1) && (dim2 != 1))) {
CHECK(false) << "Dimension mistmatch "
<< "dim1: " << dim1 << " dim2: " << dim2 << std::endl;
}
rev_sh1++;
rev_sh2++;
}
Array<IndexExpr> larger;
Array<IndexExpr> smaller;
for (int i = 0; i < (full_len - suffix_len); i++) {
smaller.push_back(make_const(tvm::Int(64), 1));
}
if (sh1.size() < sh2.size()) {
for (auto sh : sh1) {
smaller.push_back(sh);
}
larger = sh2;
} else if (sh1.size() > sh2.size()) {
for (auto sh : sh1) {
larger.push_back(sh);
}
smaller = sh2;
std::vector<IndexExpr> oshape;
size_t ndim1 = t1->shape.size();
size_t ndim2 = t2->shape.size();
size_t i = 1;
for (; i <= std::min(ndim1, ndim2); ++i) {
IndexExpr s1 = t1->shape[ndim1 - i];
IndexExpr s2 = t2->shape[ndim2 - i];
if (EqualCheck(s1, s2)) {
oshape.push_back(s1);
} else if (EqualConstInt(s1, 1)) {
oshape.push_back(s2);
} else if (EqualConstInt(s2, 1)) {
oshape.push_back(s1);
} else {
larger = sh1;
smaller = sh2;
LOG(FATAL) << "Incompatible broadcast type " << t1 << " and " << t2;
}
CHECK_EQ(larger.size(), smaller.size());
Array<IndexExpr> out_shape;
for (size_t i = 0; i < smaller.size(); i++) {
auto left = smaller[i].as<tvm::ir::IntImm>();
auto right = larger[i].as<tvm::ir::IntImm>();
CHECK(left);
CHECK(right);
int64_t dim = std::max(left->value, right->value);
out_shape.push_back(make_const(tvm::Int(64), dim));
}
return TensorTypeNode::make(out_shape, output_dtype);
}
size_t max_ndim = std::max(ndim1, ndim2);
auto& rshape = (ndim1 > ndim2) ? t1->shape : t2->shape;
for (; i <= max_ndim; ++i) {
oshape.push_back(rshape[max_ndim - i]);
}
return TensorTypeNode::make(Array<IndexExpr>(
oshape.rbegin(), oshape.rend()), output_dtype);
}
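A Python sketch of the numpy-style rule the rewritten ConcreteBroadcast follows, restricted to concrete integer shapes (not part of this commit):

def broadcast_shape(s1, s2):
    """Right-aligned broadcasting: equal dims pass, a dim of 1 yields to the other."""
    out = []
    for i in range(1, min(len(s1), len(s2)) + 1):
        d1, d2 = s1[-i], s2[-i]
        if d1 == d2 or d2 == 1:
            out.append(d1)
        elif d1 == 1:
            out.append(d2)
        else:
            raise ValueError("Incompatible broadcast shapes %s and %s" % (s1, s2))
    # The longer shape contributes its remaining leading dimensions unchanged.
    longer = s1 if len(s1) > len(s2) else s2
    out.extend(reversed(longer[:abs(len(s1) - len(s2))]))
    return list(reversed(out))

assert broadcast_shape((3, 1, 5), (4, 5)) == [3, 4, 5]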
bool BroadcastRel(const Array<Type>& types,
......@@ -141,71 +116,5 @@ bool BroadcastCompRel(const Array<Type>& types,
return false;
}
/*! \brief Handle concrete concat case from known input to output. */
inline Type ConcreteConcatRel(const Type& input_type) {
if (auto tuple_node = input_type.as<TupleTypeNode>()) {
// NB: For now the axis argument is hardwired to be 0.
std::vector<int> dims;
DataType dtype;
CHECK_LT(1, tuple_node->fields.size());
bool skip_first = true;
// Collect the suffix dimensions since axis is zero.
// TODO(@jroesch): This is a demonstration of how
// to do varargs. It requires a little more work to
// fully type the behavior of concat.
auto first = Downcast<TensorType>(tuple_node->fields[0]);
dtype = first->dtype;
for (auto dim_expr : first->shape) {
if (!skip_first) {
dims.push_back(ToInt(dim_expr));
} else {
skip_first = false;
}
}
std::vector<int> axis_dims;
for (auto field_ty : tuple_node->fields) {
auto ttype = Downcast<TensorType>(field_ty);
for (size_t i = 0; i < ttype->shape.size(); i++) {
if (i != 0) {
CHECK_EQ(ToInt(dims[i - 1]), ToInt(ttype->shape[i]));
} else {
axis_dims.push_back(ToInt(ttype->shape[i]));
}
}
}
auto out_axis_dim = std::accumulate(axis_dims.begin(), axis_dims.end(), 0);
Array<tvm::Expr> out_shape = { make_const(Int(64), out_axis_dim) };
for (auto dim : dims) {
out_shape.push_back(make_const(Int(64), dim));
}
return TensorTypeNode::make(out_shape, dtype);
} else {
throw TypeRelationError("concat can only be used with a tuple as its argument");
}
}
bool ConcatRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 2);
if (types[0].as<TupleTypeNode>()) {
reporter->Assign(types[1], ConcreteConcatRel(types[0]));
return true;
}
return false;
}
} // namespace relay
} // namespace tvm
......@@ -13,17 +13,6 @@
namespace tvm {
namespace relay {
/*! \brief The error raised by a type relation.
*
* This error is how a type relation signals that it has failed.
*
*/
struct TypeRelationError : Error {
explicit TypeRelationError(const std::string& msg)
: Error(msg) {}
};
/*!
* \brief The identity type relation, all the types are equal.
*
......@@ -72,22 +61,6 @@ bool BroadcastCompRel(const Array<Type>& types,
const Attrs& attrs,
const TypeReporter& reporter);
/*!
* \brief The concat type relation, implements the concatenating
* rule over the list of input types producing one concatenated
* type.
*
* \param types The input and output types to the relation.
* \param num_inputs The number of input arguments.
* \param attrs The attributes
* \param reporter The reporter.
* \return true whether relation has been resolved.
*/
bool ConcatRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter);
} // namespace relay
} // namespace tvm
......
......@@ -63,6 +63,7 @@ TVM_REGISTER_API("relay.op.vision._make.multibox_prior")
RELAY_REGISTER_OP("vision.multibox_prior")
.describe(R"doc("Generate prior(anchor) boxes from data, sizes and ratios."
)doc" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.MultiBoxPriorAttrs")
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(4)
......
......@@ -34,7 +34,7 @@ bool SameNDArray(const NDArray& lhs, const NDArray& rhs) {
}
struct TypeAlphaEq : TypeVisitor<const Type&> {
tvm::Map<TypeParam, TypeParam> eq_map;
tvm::Map<TypeVar, TypeVar> eq_map;
bool equal;
TypeAlphaEq() : eq_map(), equal(true) {}
......@@ -76,10 +76,10 @@ struct TypeAlphaEq : TypeVisitor<const Type&> {
}
}
void VisitType_(const TypeParamNode* ti1, const Type& t2) final {
if (const TypeParamNode* ti2 = t2.as<TypeParamNode>()) {
auto tid1 = GetRef<TypeParam>(ti1);
auto tid2 = GetRef<TypeParam>(ti2);
void VisitType_(const TypeVarNode* ti1, const Type& t2) final {
if (const TypeVarNode* ti2 = t2.as<TypeVarNode>()) {
auto tid1 = GetRef<TypeVar>(ti1);
auto tid2 = GetRef<TypeVar>(ti2);
// We handle open terms with this rule assuming variables are identical.
//
......
......@@ -20,7 +20,9 @@ bool IsBoolLit(const Expr& e, bool b) {
if (const ConstantNode* c = e.as<ConstantNode>()) {
if (c->is_scalar()) {
auto dt = c->tensor_type()->dtype;
if (dt == UInt(8)) {
if (dt == Bool()) {
return *reinterpret_cast<const uint8_t*>(c->data->data) == b;
} else if (dt == UInt(8)) {
return *reinterpret_cast<const uint8_t*>(c->data->data) == b;
} else if (dt == UInt(16)) {
return *reinterpret_cast<const uint16_t*>(c->data->data) == b;
......
......@@ -20,7 +20,7 @@ namespace tvm {
namespace relay {
using namespace tvm::runtime;
using Kind = TypeParamNode::Kind;
using Kind = TypeVarNode::Kind;
struct KindChecker : TypeVisitor<> {
bool valid;
......@@ -33,7 +33,7 @@ struct KindChecker : TypeVisitor<> {
return tv->kind == k;
}
if (const TypeParamNode *tp = t.as<TypeParamNode>()) {
if (const TypeVarNode *tp = t.as<TypeVarNode>()) {
return tp->kind == k;
}
......
......@@ -61,7 +61,7 @@ class LetList {
* \return a Var that hold the inserted expr.
*/
Var Push(Expr expr) {
return Push(IncompleteTypeNode::make(TypeParamNode::kType), expr);
return Push(IncompleteTypeNode::make(TypeVarNode::kType), expr);
}
/*!
......
......@@ -61,7 +61,7 @@ class TypeFunctor<R(const Type& n, Args...)> {
// Functions that can be overriden by subclass
virtual R VisitType_(const TensorTypeNode* op,
Args... args) TYPE_FUNCTOR_DEFAULT;
virtual R VisitType_(const TypeParamNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
virtual R VisitType_(const TypeVarNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
virtual R VisitType_(const TypeConstraintNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
virtual R VisitType_(const FuncTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
virtual R VisitType_(const TypeRelationNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
......@@ -79,7 +79,7 @@ class TypeFunctor<R(const Type& n, Args...)> {
FType vtable;
// Set dispatch
RELAY_TYPE_FUNCTOR_DISPATCH(TensorTypeNode);
RELAY_TYPE_FUNCTOR_DISPATCH(TypeParamNode);
RELAY_TYPE_FUNCTOR_DISPATCH(TypeVarNode);
RELAY_TYPE_FUNCTOR_DISPATCH(TypeConstraintNode);
RELAY_TYPE_FUNCTOR_DISPATCH(FuncTypeNode);
RELAY_TYPE_FUNCTOR_DISPATCH(TypeRelationNode);
......
......@@ -28,6 +28,39 @@
namespace tvm {
namespace relay {
// Necessary deferred relation for TupleGetItem
struct TupleGetItemAttrs : public tvm::AttrsNode<TupleGetItemAttrs> {
int index;
TVM_DECLARE_ATTRS(TupleGetItemAttrs, "relay.attrs.TupleGetItemAttrs") {
TVM_ATTR_FIELD(index);
}
};
bool TupleGetItemRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 2);
if (types[0].as<IncompleteTypeNode>()) return false;
const auto* data = types[0].as<TupleTypeNode>();
CHECK(data != nullptr)
<< "TupleGetItem expect input type to be TupleType "
<< " get " << types[0] << " instead";
const auto* param = attrs.as<TupleGetItemAttrs>();
CHECK(param != nullptr);
CHECK_GE(param->index, 0);
CHECK_LT(param->index, data->fields.size());
reporter->Assign(types[1], data->fields[param->index]);
return true;
}
TVM_REGISTER_NODE_TYPE(TupleGetItemAttrs);
TVM_REGISTER_API("tvm.relay.type_relation.TupleGetItem")
.set_body_typed<bool(const Array<Type>&, int, const Attrs&, const TypeReporter&)>(
TupleGetItemRel);
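A small sketch (not part of this commit) of what the deferred relation buys on the Python side: TupleGetItem nodes now go through the solver, so they type-check like any other call:

import tvm.relay as relay
from tvm.relay import ir_pass

x = relay.var("x", relay.TensorType((3,), "float32"))
y = relay.var("y", relay.TensorType((5,), "float32"))
item = relay.TupleGetItem(relay.Tuple([x, y]), 1)
f = ir_pass.infer_type(relay.Function([x, y], item))
# The relation assigns field 1 of the tuple type: Tensor[(5,), float32].
print(f.checked_type.ret_type)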
//
// The inference algorithm can roughly be divided into three stages:
// - Populate the constraints by visiting the expression (TypeInferencer.GetType)
......@@ -38,8 +71,7 @@ namespace relay {
class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
public:
// constructors
TypeInferencer()
: env_(EnvironmentNode::make({})) {
TypeInferencer() {
}
explicit TypeInferencer(Environment env)
: env_(env) {
......@@ -58,6 +90,8 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
std::unordered_map<Expr, Type, NodeHash, NodeEqual> type_map_;
// The solver used by the inferencer.
TypeSolver solver_;
// relation function
TypeRelationFn tuple_getitem_rel_;
// Unify two types
Type Unify(const Type& t1, const Type& t2, const Span& span) {
// TODO(tqchen, jroesch): propagate span to solver
......@@ -90,12 +124,14 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
if (op->type_annotation.defined()) {
return op->type_annotation;
} else {
return IncompleteTypeNode::make(TypeParamNode::kType);
return IncompleteTypeNode::make(TypeVarNode::kType);
}
}
Type VisitExpr_(const GlobalVarNode* op) final {
GlobalVar var = GetRef<GlobalVar>(op);
CHECK(env_.defined())
<< "Cannot do type inference without a global variable";
Expr e = env_->Lookup(var);
return e->checked_type();
}
......@@ -116,17 +152,17 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
}
Type VisitExpr_(const TupleGetItemNode* op) final {
// TODO(M.K.)
// handle case where field type is not known
Type tuple_type = GetType(op->tuple);
auto tuple_ty_node = tuple_type.as<TupleTypeNode>();
if (!tuple_ty_node) {
LOG(FATAL) << "only expressions with tuple types is accepted" << GetRef<TupleGetItem>(op);
}
if (static_cast<int>(tuple_ty_node->fields.size()) <= op->index) {
LOG(FATAL) << "tuple not big enough" << GetRef<TupleGetItem>(op);
if (!tuple_getitem_rel_.defined()) {
tuple_getitem_rel_ = TypeRelationFn(
EnvFunc::Get("tvm.relay.type_relation.TupleGetItem").node_);
}
return tuple_ty_node->fields[op->index];
Type tuple_type = GetType(op->tuple);
Type rtype = IncompleteTypeNode::make(TypeVarNode::Kind::kType);
auto attrs = make_node<TupleGetItemAttrs>();
attrs->index = op->index;
solver_.AddConstraint(TypeRelationNode::make(
tuple_getitem_rel_, {tuple_type, rtype}, 1, Attrs(attrs)));
return rtype;
}
Type VisitExpr_(const OpNode* op) final {
......@@ -169,7 +205,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
for (size_t i = 0; i < op->type_params.size(); ++i) {
if (!op->type_params[i].same_as(rel->args[i])) return Type();
}
Type rtype = IncompleteTypeNode::make(TypeParamNode::Kind::kType);
Type rtype = IncompleteTypeNode::make(TypeVarNode::Kind::kType);
arg_types.push_back(rtype);
// we can do simple replacement here
solver_.AddConstraint(TypeRelationNode::make(
......@@ -179,7 +215,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
// instantiate the function type with fresh
FuncType Instantiate(const FuncTypeNode* fn_ty, Array<Type>* ty_args) {
tvm::Map<TypeParam, Type> subst_map;
tvm::Map<TypeVar, Type> subst_map;
// Build a substitution map from the function type and type arguments.
// Eventually allow the type vars to be passed in.
......@@ -196,7 +232,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
// This is a temporary workaround to check recursive functions whose
// return type is not yet known.
if (!ret_type.defined()) {
ret_type = IncompleteTypeNode::make(TypeParamNode::Kind::kType);
ret_type = IncompleteTypeNode::make(TypeVarNode::Kind::kType);
}
Type inst_ty = FuncTypeNode::make(fn_ty->arg_types,
ret_type, {},
......@@ -305,7 +341,6 @@ class TypeInferencer::Resolver : public ExprMutator {
return AttachCheckedType(op);
}
Expr VisitExpr_(const FunctionNode* op) final {
return AttachCheckedType(op);
}
......@@ -363,20 +398,21 @@ Expr TypeInferencer::Infer(Expr expr) {
return Resolver(type_map_, &solver_).VisitExpr(expr);
}
Expr InferType(const Environment& env, const Expr& expr) {
Expr InferType(const Expr& expr, const Environment& env) {
return TypeInferencer(env).Infer(expr);
}
Expr InferType(const Environment& env,
const GlobalVar& var,
const Function& func) {
Function InferType(const Function& func,
const Environment& env,
const GlobalVar& var) {
Function func_copy = Function(make_node<FunctionNode>(*func.operator->()));
func_copy->checked_type_ = func_copy->func_type_annotation();
env->functions.Set(var, func_copy);
Expr func_ret = TypeInferencer(env).Infer(func_copy);
auto map_node = env->functions.CopyOnWrite();
map_node->data.erase(var.node_);
return func_ret;
return Downcast<Function>(func_ret);
}
TVM_REGISTER_API("relay._ir_pass.infer_type")
......
......@@ -10,13 +10,13 @@ namespace tvm {
namespace relay {
struct TypeSubstV : TypeMutator {
tvm::Map<TypeParam, Type> subst_map;
tvm::Map<TypeVar, Type> subst_map;
explicit TypeSubstV(tvm::Map<TypeParam, Type> subst_map)
explicit TypeSubstV(tvm::Map<TypeVar, Type> subst_map)
: subst_map(subst_map) {}
Type VisitType_(const TypeParamNode* op) override {
auto id = GetRef<TypeParam>(op);
Type VisitType_(const TypeVarNode* op) override {
auto id = GetRef<TypeVar>(op);
if (subst_map.find(id) != subst_map.end()) {
return this->subst_map[id];
} else {
......@@ -25,12 +25,12 @@ struct TypeSubstV : TypeMutator {
}
};
Type TypeSubst(const Type& type, const TypeParam& target, const Type& subst) {
Type TypeSubst(const Type& type, const TypeVar& target, const Type& subst) {
TypeSubstV ty_sub({ {target, subst} });
return ty_sub.VisitType(type);
}
Type TypeSubst(const Type& type, tvm::Map<TypeParam, Type> subst_map) {
Type TypeSubst(const Type& type, tvm::Map<TypeVar, Type> subst_map) {
TypeSubstV ty_sub(subst_map);
return ty_sub.VisitType(type);
}
......
......@@ -11,8 +11,8 @@
namespace tvm {
namespace relay {
Type TypeSubst(const Type& type, const TypeParam& target, const Type& subst);
Type TypeSubst(const Type& type, tvm::Map<TypeParam, Type> subst_map);
Type TypeSubst(const Type& type, const TypeVar& target, const Type& subst);
Type TypeSubst(const Type& type, tvm::Map<TypeVar, Type> subst_map);
} // namespace relay
} // namespace tvm
......
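TypeSubst itself is an internal C++ helper; as a rough conceptual sketch only (a toy model, not the TVM API), the visitor behaves like a recursive map lookup over type variables, leaving unbound variables untouched:

def type_subst(ty, subst_map):
    # Toy model: a "type" is either a type-variable name (str) or a tuple of
    # component types; the real TypeSubstV walks TVM type nodes instead.
    if isinstance(ty, str):
        return subst_map.get(ty, ty)   # substitute if bound, otherwise keep
    return tuple(type_subst(t, subst_map) for t in ty)

assert type_subst(("a", ("b", "a")), {"a": "int32"}) == ("int32", ("b", "int32"))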
......@@ -19,7 +19,7 @@ namespace relay {
*/
template <typename... Args>
struct TypeVisitor : ::tvm::relay::TypeFunctor<void(const Type& n, Args...)> {
void VisitType_(const TypeParamNode* op, Args... args) override {}
void VisitType_(const TypeVarNode* op, Args... args) override {}
void VisitType_(const FuncTypeNode* op, Args... args) override {
for (auto type_param : op->type_params) {
......@@ -60,16 +60,16 @@ struct TypeMutator : TypeFunctor<Type(const Type& n)> {
return TensorTypeNode::make(op->shape, op->dtype);
}
Type VisitType_(const TypeParamNode* op) override {
return GetRef<TypeParam>(op);
Type VisitType_(const TypeVarNode* op) override {
return GetRef<TypeVar>(op);
}
Type VisitType_(const FuncTypeNode* op) override {
Array<TypeParam> type_params;
Array<TypeVar> type_params;
for (auto type_param : op->type_params) {
auto new_type_param = VisitType(type_param);
if (const TypeParamNode* tin = new_type_param.as<TypeParamNode>()) {
type_params.push_back(GetRef<TypeParam>(tin));
if (const TypeVarNode* tin = new_type_param.as<TypeVarNode>()) {
type_params.push_back(GetRef<TypeVar>(tin));
} else {
CHECK(false) << new_type_param << std::endl;
}
......
......@@ -14,14 +14,14 @@ namespace relay {
class FreeVar;
class FreeTypeVar : private TypeVisitor<> {
std::unordered_set<TypeParam, NodeHash, NodeEqual> * free_vars;
std::unordered_set<TypeParam, NodeHash, NodeEqual> * bound_vars;
FreeTypeVar(std::unordered_set<TypeParam, NodeHash, NodeEqual> * free_vars,
std::unordered_set<TypeParam, NodeHash, NodeEqual> * bound_vars) :
std::unordered_set<TypeVar, NodeHash, NodeEqual> * free_vars;
std::unordered_set<TypeVar, NodeHash, NodeEqual> * bound_vars;
FreeTypeVar(std::unordered_set<TypeVar, NodeHash, NodeEqual> * free_vars,
std::unordered_set<TypeVar, NodeHash, NodeEqual> * bound_vars) :
free_vars(free_vars), bound_vars(bound_vars) { }
void VisitType_(const TypeParamNode* tp) final {
auto var = GetRef<TypeParam>(tp);
void VisitType_(const TypeVarNode* tp) final {
auto var = GetRef<TypeVar>(tp);
if (bound_vars->count(var) == 0) {
free_vars->insert(var);
}
......@@ -75,8 +75,8 @@ class FreeVar : public ExprVisitor {
public:
std::unordered_set<Var, NodeHash, NodeEqual> free_vars;
std::unordered_set<Var, NodeHash, NodeEqual> bound_vars;
std::unordered_set<TypeParam, NodeHash, NodeEqual> free_types;
std::unordered_set<TypeParam, NodeHash, NodeEqual> bound_types;
std::unordered_set<TypeVar, NodeHash, NodeEqual> free_types;
std::unordered_set<TypeVar, NodeHash, NodeEqual> bound_types;
void VisitType(const Type& t) final {
FreeTypeVar(&free_types, &bound_types)(t);
......@@ -89,16 +89,16 @@ tvm::Array<Var> FreeVariables(const Expr& e) {
return tvm::Array<Var>(fv.free_vars.begin(), fv.free_vars.end());
}
tvm::Array<TypeParam> FreeTypeVariables(const Expr& e) {
tvm::Array<TypeVar> FreeTypeVariables(const Expr& e) {
FreeVar fv;
fv.VisitExpr(e);
return tvm::Array<TypeParam>(fv.free_types.begin(), fv.free_types.end());
return tvm::Array<TypeVar>(fv.free_types.begin(), fv.free_types.end());
}
tvm::Array<TypeParam> FreeTypeVariables(const Type& t) {
tvm::Array<TypeVar> FreeTypeVariables(const Type& t) {
FreeVar fv;
fv.VisitType(t);
return tvm::Array<TypeParam>(fv.free_types.begin(), fv.free_types.end());
return tvm::Array<TypeVar>(fv.free_types.begin(), fv.free_types.end());
}
TVM_REGISTER_API("relay._ir_pass.free_vars")
......
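A hedged usage sketch of the corresponding Python helpers (the relay.ir_pass module path and the free_type_vars name are assumed from the registration above and the test files later in this diff):

from tvm import relay

tp = relay.TypeVar("tp")
ty = relay.TupleType([tp, relay.TensorType([], "int32")])
x = relay.Var("x", ty)
body = relay.Tuple([x, x])
assert len(relay.ir_pass.free_vars(body)) == 1       # the free value variable x
assert len(relay.ir_pass.free_type_vars(body)) == 1  # the free type variable tp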
import numpy as np
from tvm.relay.expr import Let, Constant
from tvm.relay.ir_builder import IRBuilder
def test_let():
b = IRBuilder()
x = b.let('x', 1)
b.ret(x)
prog, _ = b.get()
assert isinstance(prog, Let)
var = prog.var
value = prog.value
assert var.name_hint == 'x'
assert var == prog.body
assert isinstance(value, Constant)
assert value.data.asnumpy() == np.array(1)
if __name__ == "__main__":
test_let()
......@@ -34,7 +34,7 @@ def test_tensor_type():
def test_type_param():
tp = relay.TypeParam('name', relay.Kind.Type)
tp = relay.TypeVar('name', relay.Kind.Type)
assert tp.kind == relay.Kind.Type
# assert tp.span # TODO allow us to set span
str(tp)
......@@ -56,7 +56,7 @@ def test_func_type():
def test_tuple_type():
tp = relay.TypeParam('tp', relay.Kind.Type)
tp = relay.TypeVar('tp', relay.Kind.Type)
tf = relay.FuncType(tvm.convert([]), None, tvm.convert([]), tvm.convert([]))
tt = relay.TensorType(tvm.convert([1, 2, 3]), 'float32')
fields = tvm.convert([tp, tf, tt])
......@@ -66,7 +66,7 @@ def test_tuple_type():
def test_type_relation():
tp = relay.TypeParam('tp', relay.Kind.Type)
tp = relay.TypeVar('tp', relay.Kind.Type)
tf = relay.FuncType(tvm.convert([]), None, tvm.convert([]), tvm.convert([]))
tt = relay.TensorType(tvm.convert([1, 2, 3]), 'float32')
args = tvm.convert([tf, tt, tp])
......@@ -173,7 +173,7 @@ def test_if():
def test_tuple_get_item():
tup = relay.Var("tuple")
get = relay.TupleGetItem(tup, 1)
assert get.tuple == tup
assert get.tuple_value == tup
assert get.index == 1
str(get)
......
......@@ -27,7 +27,7 @@ def test_env():
z = relay.add(z, z)
f = relay.Function([x, y], z)
env = relay.Environment()
env.add("myf", f)
env["myf"] = f
text = env.astext()
assert "def @myf" in text
assert "%1 = add(%0, %0) # ty=float32" in text
......@@ -70,15 +70,18 @@ def test_let_if_scope():
x = relay.var("x", "float32")
y = relay.var("y", "float32")
cond = relay.var("cond", "bool")
v1 = relay.var("v")
v2 = relay.var("v", "float32")
then_branch = relay.Let(
v1, relay.const(1, "float32"),
relay.Let(v2, x, relay.subtract(v1, v2)))
v3 = relay.var("v")
let2 = relay.Let(v3, y, v3)
else_branch = relay.add(let2, let2)
result = relay.If(cond, then_branch, else_branch)
sb = relay.ScopeBuilder()
with sb.if_scope(cond):
v1 = sb.let("v", relay.const(1, "float32"))
v2 = sb.let("v", x)
sb.ret(relay.subtract(v1, v2))
with sb.else_scope():
v3 = relay.var("v")
let2 = relay.Let(v3, y, v3)
sb.ret(relay.add(let2, let2))
result = sb.get()
f = relay.Function([x, y, cond], result)
text = f.astext()
assert text.count("{") == 4
......@@ -86,10 +89,17 @@ def test_let_if_scope():
show(f.astext())
def test_variable_name():
# avoid a purely numeric name even if the name hint is purely numeric
v1 = relay.var("1")
assert "%v1" in v1.astext()
if __name__ == "__main__":
do_print[0] = True
test_let_if_scope()
test_func()
test_env()
test_meta_data()
test_call_attrs()
test_let_if_scope()
test_variable_name()
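For orientation, a minimal sketch of the dictionary-style Environment registration that the rewritten test_env above relies on (assuming a build containing this commit):

from tvm import relay

x = relay.var("x", "float32")
f = relay.Function([x], relay.add(x, x))
env = relay.Environment()
env["double"] = f              # replaces the older env.add("double", f)
assert "def @double" in env.astext()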
import tvm
import numpy as np
from tvm import relay
from tvm.relay.ir_pass import infer_type
from tvm.relay.ir_builder import IRBuilder, func_type
from tvm.relay.ir_builder import scalar_type, convert, tensor_type
from tvm.relay.env import Environment
def assert_has_type(expr, typ, env=Environment({})):
checked_expr = infer_type(env, expr)
checked_type = checked_expr.checked_type
if checked_type != typ:
raise RuntimeError("Type mismatch %s vs %s" % (
checked_type, typ))
def test_binary_op():
def check_binary_op(opfunc):
"""
Program:
fn (x, y) {
return x <op> y;
}
"""
b = IRBuilder()
x = b.param('x', tensor_type(5, 5, 5))
y = b.param('y', tensor_type(5, 5, 5))
with b.function(x, y) as func:
b.ret(opfunc(x, y))
b.ret(func)
prog, env = b.get()
ttype = tensor_type(5, 5, 5)
expected_ty = func_type([ttype, ttype], ttype)
assert_has_type(func.to_func(), expected_ty)
n = tvm.var("n")
t1 = relay.TensorType((5, n, 5))
t2 = relay.TensorType((n, 1))
x = relay.var("x", t1)
y = relay.var("y", t2)
z = opfunc(x, y)
# test printer
assert ("%0 = {}(%x, %y)".format(z.op.name)) in z.astext()
assert relay.ir_pass.infer_type(z).checked_type == t1
for opfunc in [relay.pow]:
check_binary_op(opfunc)
def test_binary_broadcast_op():
def check_binary_broadcast_op(opfunc):
"""
Program:
fn (x: Tensor[(10, 4), f32], y: Tensor[(5, 10, 1), f32]) -> Tensor[(5, 10, 4), f32] {
return x <op> y;
}
"""
b = IRBuilder()
x = b.param('x', tensor_type(10, 4))
y = b.param('y', tensor_type(5, 10, 1))
with b.function(x, y) as func:
b.ret(opfunc(x, y))
b.ret(func)
prog, env = b.get()
expected_ty = func_type([tensor_type(10, 4), tensor_type(5, 10, 1)],
tensor_type(5, 10, 4))
assert_has_type(func.to_func(), expected_ty)
for opfunc in [relay.pow]:
check_binary_broadcast_op(opfunc)
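The expected output shape in the docstring above follows NumPy-style broadcasting; a quick sanity check with NumPy (illustrative only, independent of Relay):

import numpy as np

# (10, 4) against (5, 10, 1): trailing dimensions align and size-1 axes stretch
out = np.zeros((10, 4)) + np.zeros((5, 10, 1))
assert out.shape == (5, 10, 4)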
def test_cmp_type():
for op in (relay.greater,
relay.greater_equal,
......@@ -68,138 +26,59 @@ def test_cmp_type():
relay.less_equal,
relay.equal,
relay.not_equal):
ib = relay.ir_builder.IRBuilder()
x = ib.param("x", relay.TensorType((10, 4), "float32"))
y = ib.param("y", relay.TensorType((5, 10, 1), "float32"))
with ib.function(x, y) as func:
ib.ret(op(x, y))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
assert ftype.ret_type == relay.TensorType((5, 10, 4), "uint1")
x = relay.var("x", relay.TensorType((10, 4), "float32"))
y = relay.var("y", relay.TensorType((5, 10, 1), "float32"))
z = op(x, y)
z.astext()
zz = relay.ir_pass.infer_type(z)
assert zz.checked_type == relay.TensorType((5, 10, 4), "bool")
def test_binary_broadcast():
def test_binary_int_broadcast():
for op in [relay.right_shift,
relay.left_shift,
relay.maximum,
relay.minimum]:
ib = relay.ir_builder.IRBuilder()
x = ib.param("x", relay.TensorType((10, 4), "int32"))
y = ib.param("y", relay.TensorType((5, 10, 1), "int32"))
with ib.function(x, y) as func:
ib.ret(op(x, y))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
assert ftype.ret_type == relay.TensorType((5, 10, 4), "int32")
def test_argmax():
ib = relay.ir_builder.IRBuilder()
n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
x = ib.param("x", relay.ty.TensorType((n, c , h, w), "float32"))
with ib.function(x) as func:
ib.ret(relay.argmax(x, axis=(1,)))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
assert ftype.ret_type == relay.ty.TensorType((n, h, w), "int32")
ib = relay.ir_builder.IRBuilder()
n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
x = ib.param("x", relay.ty.TensorType((n, c , h, w), "float32"))
with ib.function(x) as func:
ib.ret(relay.argmax(x, axis=(2,), keepdims=True))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
assert ftype.ret_type == relay.ty.TensorType((n, c , 1, w), "int32")
ib = relay.ir_builder.IRBuilder()
n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
x = ib.param("x", relay.ty.TensorType((n, c , h, w), "float32"))
with ib.function(x) as func:
ib.ret(relay.argmax(x, axis=(2,), keepdims=True, exclude=True))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
assert ftype.ret_type == relay.ty.TensorType((1, 1 , h, 1), "int32")
def test_argmin():
ib = relay.ir_builder.IRBuilder()
n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
x = ib.param("x", relay.ty.TensorType((n, c , h, w), "float32"))
with ib.function(x) as func:
ib.ret(relay.argmax(x, axis=(1,)))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
assert ftype.ret_type == relay.ty.TensorType((n, h, w), "int32")
ib = relay.ir_builder.IRBuilder()
n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
x = ib.param("x", relay.ty.TensorType((n, c , h, w), "float32"))
with ib.function(x) as func:
ib.ret(relay.argmin(x, axis=(2,), keepdims=True))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
assert ftype.ret_type == relay.ty.TensorType((n, c , 1, w), "int32")
ib = relay.ir_builder.IRBuilder()
n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
x = ib.param("x", relay.ty.TensorType((n, c , h, w), "float32"))
with ib.function(x) as func:
ib.ret(relay.argmin(x, axis=(2,), keepdims=True, exclude=True))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
assert ftype.ret_type == relay.ty.TensorType((1, 1 , h, 1), "int32")
ib = relay.ir_builder.IRBuilder()
n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
x = ib.param("x", relay.ty.TensorType((n, c , h, w), "float32"))
with ib.function(x) as func:
ib.ret(relay.argmin(x, axis=(2,1), keepdims=True, exclude=True))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
assert ftype.ret_type == relay.ty.TensorType((1, c , h, 1), "int32")
ib = relay.ir_builder.IRBuilder()
n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
x = ib.param("x", relay.ty.TensorType((n, c , h, w), "float32"))
with ib.function(x) as func:
ib.ret(relay.argmin(x, axis=None, keepdims=True, exclude=True))
ib.ret(func)
x = relay.var("x", relay.TensorType((10, 4), "int32"))
y = relay.var("y", relay.TensorType((5, 10, 1), "int32"))
z = op(x, y)
zz = relay.ir_pass.infer_type(z)
assert zz.checked_type == relay.TensorType((5, 10, 4), "int32")
def test_arg_reduce():
for op in [relay.argmax, relay.argmin]:
n, c , h, w = 10, 20, 3, 4
x = relay.var("x", relay.ty.TensorType((n, c , h, w), "float32"))
z = op(x, axis=(1,))
assert "axis=" in z.astext()
zz = relay.ir_pass.infer_type(z)
assert zz.checked_type == relay.ty.TensorType((n, h, w), "int32")
n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
x = relay.var("x", relay.ty.TensorType((n, c , h, w), "float32"))
z = op(x, axis=(2,), keepdims=True)
zz = relay.ir_pass.infer_type(z)
assert zz.checked_type == relay.ty.TensorType((n, c , 1, w), "int32")
n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
x = relay.var("x", relay.ty.TensorType((n, c , h, w), "float32"))
z = op(x, axis=(2,), keepdims=True, exclude=True)
zz = relay.ir_pass.infer_type(z)
assert zz.checked_type == relay.ty.TensorType((1, 1 , h, 1), "int32")
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
assert ftype.ret_type == relay.ty.TensorType((1, 1 , 1, 1), "int32")
def test_where():
ib = relay.ir_builder.IRBuilder()
cond = ib.param("cond", relay.TensorType((3, 4), "float32"))
x = ib.param("x", relay.TensorType((3, 4), "float32"))
y = ib.param("y", relay.TensorType((3, 4), "float32"))
with ib.function(cond, x, y) as func:
ib.ret(relay.where(cond, x, y))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
assert ftype.ret_type == relay.TensorType((3, 4), "float32")
cond = relay.var("cond", relay.TensorType((3, 4), "float32"))
x = relay.var("x", relay.TensorType((3, 4), "float32"))
y = relay.var("y", relay.TensorType((3, 4), "float32"))
z = relay.where(cond, x, y)
zz = relay.ir_pass.infer_type(z)
assert zz.checked_type == relay.TensorType((3, 4), "float32")
if __name__ == "__main__":
test_binary_op()
test_binary_broadcast_op()
test_cmp_type()
test_binary_broadcast()
test_binary_int_broadcast()
test_where()
test_multibox_prior()
test_argmax()
test_argmin()
test_arg_reduce()
......@@ -4,26 +4,18 @@ import tvm
from tvm import relay
def test_resize_infer_type():
ib = relay.ir_builder.IRBuilder()
n, c, h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
x = ib.param("x", relay.ty.TensorType((n, c, h, w), "int8"))
x = relay.var("x", relay.TensorType((n, c, h, w), "int8"))
th, tw = tvm.var("th"), tvm.var("tw")
z = relay.image.resize(x, (th, tw))
zz = relay.ir_pass.infer_type(z)
assert zz.checked_type == relay.TensorType((n, c, th, tw), "int8")
with ib.function(x) as func:
ib.ret(relay.image.resize(x, (th, tw)))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
assert ftype.ret_type == relay.ty.TensorType((n, c, th, tw), "int8")
ib = relay.ir_builder.IRBuilder()
x = ib.param("x", relay.ty.TensorType((n, c, h, w), "int8"))
with ib.function(x) as func:
ib.ret(relay.image.resize(x, (100, 200), "NCHW", "BILINEAR", False))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
assert ftype.ret_type == relay.ty.TensorType((n, c, 100, 200), "int8")
x = relay.var("x", relay.TensorType((n, c, h, w), "int8"))
z = relay.image.resize(x, (100, 200), "NCHW", "BILINEAR", False)
assert "size=" in z.astext()
zz = relay.ir_pass.infer_type(z)
assert zz.checked_type == relay.TensorType((n, c, 100, 200), "int8")
......@@ -34,29 +26,21 @@ def test_multibox_prior():
offsets = (0.2, 0.3)
clip = True
ib = relay.ir_builder.IRBuilder()
n, c, h, w = tvm.var("n"), 3, 56, 56
x = ib.param("x", relay.ty.TensorType((n, c, h, w), "float32"))
x = relay.var("x", relay.TensorType((n, c, h, w), "float32"))
with ib.function(x) as func:
ib.ret(relay.vision.multibox_prior(x, sizes, ratios,
steps, offsets, clip))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
assert ftype.ret_type == relay.ty.TensorType(
z = relay.vision.multibox_prior(x, sizes, ratios,
steps, offsets, clip)
assert "sizes=" in z.astext()
zz = relay.ir_pass.infer_type(z)
assert zz.checked_type == relay.TensorType(
(1, h * w * (len(sizes) + len(ratios) - 1), 4), "float32")
ib = relay.ir_builder.IRBuilder()
n, c, h, w = tvm.var("n"), 24, 32, 32
x = ib.param("x", relay.ty.TensorType((n, c, h, w), "float32"))
with ib.function(x) as func:
ib.ret(relay.vision.multibox_prior(x))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
assert ftype.ret_type == relay.ty.TensorType(
x = relay.var("x", relay.TensorType((n, c, h, w), "float32"))
z = relay.vision.multibox_prior(x)
zz = relay.ir_pass.infer_type(z)
assert zz.checked_type == relay.TensorType(
(1, h * w, 4), "float32")
......
......@@ -4,7 +4,7 @@ from tvm.relay.ir_pass import check_kind
def test_tuple_kind():
# only contain type kinds
tp = relay.TypeParam('tp', relay.Kind.Type)
tp = relay.TypeVar('tp', relay.Kind.Type)
tt = relay.TensorType(tvm.convert([1, 2, 3]), 'float32')
tf = relay.FuncType(tvm.convert([]), tt, tvm.convert([]), tvm.convert([]))
fields = tvm.convert([tp, tf, tt])
......@@ -15,8 +15,8 @@ def test_tuple_kind():
def test_func_kind():
# only contain type kinds
tp1 = relay.TypeParam('tp1', relay.Kind.Type)
tp2 = relay.TypeParam('tp2', relay.Kind.Type)
tp1 = relay.TypeVar('tp1', relay.Kind.Type)
tp2 = relay.TypeVar('tp2', relay.Kind.Type)
shape = tvm.convert([1, 2, 3])
dtype = 'float32'
......@@ -35,7 +35,7 @@ def test_func_kind():
def test_relation_kind():
# only have type kinds for arguments
tp = relay.TypeParam('tp', relay.Kind.Type)
tp = relay.TypeVar('tp', relay.Kind.Type)
tt = relay.TensorType(tvm.convert([1, 2, 3]), 'float32')
tf = relay.FuncType(tvm.convert([]), tt, tvm.convert([]), tvm.convert([]))
args = tvm.convert([tf, tt, tp])
......@@ -45,9 +45,9 @@ def test_relation_kind():
def test_invalid_tuple_kind():
tp1 = relay.TypeParam('tp1', relay.Kind.Shape)
tp2 = relay.TypeParam('tp2', relay.Kind.BaseType)
tp3 = relay.TypeParam('tp3', relay.Kind.ShapeVar)
tp1 = relay.TypeVar('tp1', relay.Kind.Shape)
tp2 = relay.TypeVar('tp2', relay.Kind.BaseType)
tp3 = relay.TypeVar('tp3', relay.Kind.ShapeVar)
fields = tvm.convert([tp1, tp2, tp3])
tup_ty = relay.TupleType(fields)
......@@ -55,9 +55,9 @@ def test_invalid_tuple_kind():
def test_invalid_func_kind():
tp1 = relay.TypeParam('tp1', relay.Kind.Shape)
tp2 = relay.TypeParam('tp2', relay.Kind.BaseType)
tp3 = relay.TypeParam('tp3', relay.Kind.ShapeVar)
tp1 = relay.TypeVar('tp1', relay.Kind.Shape)
tp2 = relay.TypeVar('tp2', relay.Kind.BaseType)
tp3 = relay.TypeVar('tp3', relay.Kind.ShapeVar)
type_params = tvm.convert([tp1, tp2, tp3])
type_constraints = tvm.convert([])
......@@ -69,9 +69,9 @@ def test_invalid_func_kind():
def test_invalid_relation_kind():
tp1 = relay.TypeParam('tp1', relay.Kind.Shape)
tp2 = relay.TypeParam('tp2', relay.Kind.BaseType)
tp3 = relay.TypeParam('tp3', relay.Kind.ShapeVar)
tp1 = relay.TypeVar('tp1', relay.Kind.Shape)
tp2 = relay.TypeVar('tp2', relay.Kind.BaseType)
tp3 = relay.TypeVar('tp3', relay.Kind.ShapeVar)
args = tvm.convert([tp1, tp2, tp3])
tr = relay.TypeRelation(None, args, 2, None)
......@@ -79,19 +79,19 @@ def test_invalid_relation_kind():
def test_func_with_invalid_ret_type():
tp1 = relay.TypeParam('tp1', relay.Kind.Type)
tp2 = relay.TypeParam('tp2', relay.Kind.Shape)
tp1 = relay.TypeVar('tp1', relay.Kind.Type)
tp2 = relay.TypeVar('tp2', relay.Kind.Shape)
tf = relay.FuncType(tvm.convert([tp1]), tp2, tvm.convert([tp1, tp2]), tvm.convert([]))
def test_func_with_invalid_arg_types():
tp1 = relay.TypeParam('tp1', relay.Kind.Shape)
tp2 = relay.TypeParam('tp2', relay.Kind.Type)
tp1 = relay.TypeVar('tp1', relay.Kind.Shape)
tp2 = relay.TypeVar('tp2', relay.Kind.Type)
tf = relay.FuncType(tvm.convert([tp1]), tp2, tvm.convert([tp1, tp2]), tvm.convert([]))
def test_func_with_invalid_tuple():
tp1 = relay.TypeParam('tp1', relay.Kind.Shape)
tp1 = relay.TypeVar('tp1', relay.Kind.Shape)
ret_type = relay.TupleType(tvm.convert([tp1, tp1, tp1]))
......@@ -100,9 +100,9 @@ def test_func_with_invalid_tuple():
def test_func_with_invalid_relation():
tp1 = relay.TypeParam('tp1', relay.Kind.Type)
tp2 = relay.TypeParam('tp2', relay.Kind.Shape)
tp3 = relay.TypeParam('tp3', relay.Kind.ShapeVar)
tp1 = relay.TypeVar('tp1', relay.Kind.Type)
tp2 = relay.TypeVar('tp2', relay.Kind.Shape)
tp3 = relay.TypeVar('tp3', relay.Kind.ShapeVar)
tr = relay.TypeRelation(None, tvm.convert([tp2, tp3]), 1, None)
......@@ -113,7 +113,7 @@ def test_func_with_invalid_relation():
def test_tuple_with_invalid_func():
tensor_type = relay.TensorType(tvm.convert([1, 2, 3]), 'float32')
tp1 = relay.TypeParam('tp1', relay.Kind.Shape)
tp1 = relay.TypeVar('tp1', relay.Kind.Shape)
tf = relay.FuncType(tvm.convert([]), tp1, tvm.convert([tp1]), tvm.convert([]))
tup_ty = relay.TupleType(tvm.convert([tensor_type, tf]))
......
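A hedged sketch of how these kind checks are driven from Python (check_kind is the helper imported at the top of this file; this mirrors the well-kinded tuple case, and the exact failure behaviour for the invalid cases is not assumed here):

import tvm
from tvm import relay
from tvm.relay.ir_pass import check_kind

tp = relay.TypeVar('tp', relay.Kind.Type)
tt = relay.TensorType(tvm.convert([1, 2, 3]), 'float32')
tup_ty = relay.TupleType(tvm.convert([tp, tt]))
check_kind(tup_ty)   # members of kind Type form a well-kinded tuple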
import tvm
from tvm import relay
from tvm.relay.ir_pass import dead_code_elimination, alpha_equal
from tvm.relay.ir_builder import convert, IRBuilder
from tvm.relay.op import log, add, equal, subtract
......@@ -19,9 +18,9 @@ class env:
self.tt = relay.TensorType(self.shape, "float32")
self.int32 = relay.TensorType([], "int32")
self.float32 = relay.TensorType([], "float32")
self.one = convert(1.0)
self.two = convert(2.0)
self.three = convert(3.0)
self.one = relay.const(1.0)
self.two = relay.const(2.0)
self.three = relay.const(3.0)
e = env()
......@@ -58,9 +57,12 @@ def test_recursion():
f = relay.Var("f")
n = relay.Var("n", e.int32)
data = relay.Var("data", e.float32)
funcbody = relay.If(equal(n, convert(0)), data, f(subtract(n, convert(1.0)), log(data)))
funcbody = relay.If(equal(n, relay.const(0)),
data,
relay.Call(f, [subtract(n, relay.const(1.0)),
log(data)]))
value = relay.Function([n, data], funcbody, e.float32, [])
orig = relay.Let(f, funcbody, f(convert(2.0), convert(10000.0)))
orig = relay.Let(f, funcbody, relay.Call(f, [relay.const(2.0), relay.const(10000.0)]))
assert alpha_equal(dead_code_elimination(orig), orig)
assert alpha_equal(dead_code_elimination(relay.Let(f, funcbody, e.three)), e.three)
......@@ -70,8 +72,10 @@ def test_op_let():
def test_if():
orig = relay.If(convert(True), e.a, e.b)
assert alpha_equal(dead_code_elimination(orig), e.a)
cond = relay.const(True)
orig = relay.If(cond, e.a, e.b)
y = dead_code_elimination(orig)
assert alpha_equal(y, e.a)
def test_tuple_get_item():
......@@ -82,10 +86,10 @@ def test_tuple_get_item():
if __name__ == "__main__":
test_if()
test_let()
test_used_let()
test_chain_unused_let()
test_recursion()
test_op_let()
test_if()
test_tuple_get_item()
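For orientation, a minimal sketch of the pass these tests exercise (names as imported at the top of this file; dead_code_elimination drops a let binding whose variable the body never uses):

from tvm import relay
from tvm.relay.ir_pass import dead_code_elimination, alpha_equal

x = relay.Var("x")
unused = relay.Var("unused")
prog = relay.Let(unused, relay.log(x), x)   # 'unused' is bound but never read
assert alpha_equal(dead_code_elimination(prog), x)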
......@@ -28,7 +28,7 @@ def test_tuple():
def test_free_type_vars():
tp = relay.TypeParam("")
tp = relay.TypeVar("")
ty = relay.TupleType([tp, relay.TensorType([], "int32")])
x = relay.Var("x", ty)
y = relay.Var("y")
......
......@@ -4,34 +4,17 @@
import tvm
import numpy as np
from tvm.relay.ir_pass import infer_type
from tvm.relay.ir_builder import IRBuilder, func_type
from tvm.relay.ir_builder import scalar_type, convert, tensor_type
from tvm.relay.env import Environment
from tvm.relay.op import log, add, equal, subtract, concatenate
from tvm.relay.expr import Function
from tvm import relay
def assert_has_type(expr, typ, env=Environment({})):
checked_expr = infer_type(env, expr)
checked_type = checked_expr.checked_type
if checked_type != typ:
raise RuntimeError("Type mismatch %s vs %s" % (
checked_type, typ))
def assert_decl_has_type(env, name, typ):
func = env[name]
assert func.checked_type == typ
def test_monomorphic_let():
"Program: let x = 1; return x"
b = IRBuilder()
x = b.let('x', 1.0, value_type=scalar_type('float64'))
b.ret(x)
sb = relay.ScopeBuilder()
x = sb.let('x', relay.const(1.0, "float64"))
sb.ret(x)
xchecked = relay.ir_pass.infer_type(sb.get())
assert xchecked.checked_type == relay.scalar_type("float64")
prog, env = b.get()
assert_has_type(prog, scalar_type('float64'))
def test_dual_op():
"""Program:
......@@ -41,31 +24,29 @@ def test_dual_op():
return t1;
}
"""
b = IRBuilder()
with b.function(('x', tensor_type(10, 10))) as func:
x, = func.param_ids()
t1 = b.let('t1', log(x))
t2 = b.let('t2', add(t1, x))
b.ret(t2)
assert_has_type(func.to_func(),
func_type([tensor_type(10, 10)], tensor_type(10, 10)))
tp = relay.TensorType((10, 10), "float32")
x = relay.var("x", tp)
sb = relay.ScopeBuilder()
t1 = sb.let("t1", relay.log(x))
t2 = sb.let("t2", relay.add(t1, x))
sb.ret(t2)
f = relay.Function([x], sb.get())
fchecked = relay.ir_pass.infer_type(f)
assert fchecked.checked_type == relay.FuncType([tp], tp)
def test_decl():
"""Program:
def f(x : Tensor[f32, (10, 10)]) {
let lx = log(x);
return lx;
def f(x : Tensor[(10, 10), f32]) {
return log(x);
}
"""
b = IRBuilder()
x = b.param('x')
with b.decl('f', x):
lx = b.let('lx', log(x))
b.ret(lx)
_, env = b.get()
assert_decl_has_type(env, 'f', func_type(['float32'], 'float32'))
sb = relay.ScopeBuilder()
tp = relay.TensorType((10, 10))
x = relay.var("x", tp)
f = relay.Function([x], relay.log(x))
fchecked = relay.ir_pass.infer_type(f)
assert fchecked.checked_type == relay.FuncType([tp], tp)
def test_recursion():
......@@ -78,54 +59,44 @@ def test_recursion():
return f(n - 1, log(data));
}
}
f(2, 10000);
"""
b = IRBuilder()
f = b.global_var('f')
n = b.param('n', ty='int32')
data = b.param('data', ty='float32')
with b.decl(f, n, data):
with b.if_scope(equal(n, convert(0))):
b.ret(data)
with b.else_scope():
b.ret(f(subtract(n, convert(1)), log(data)))
b.ret(f(convert(2.0), convert(10000.0)))
assert_decl_has_type(b.env, 'f', func_type(
['int32', 'float32'], 'float32'))
# TODO(@jroesch): need evaluator or new runtime
# to execute this.
sb = relay.ScopeBuilder()
f = relay.GlobalVar("f")
ti32 = relay.scalar_type("int32")
tf32 = relay.scalar_type("float32")
n = relay.var("n", ti32)
data = relay.var("data", tf32)
with sb.if_scope(relay.equal(n, relay.const(0, ti32))):
sb.ret(data)
with sb.else_scope():
sb.ret(f(relay.subtract(n, relay.const(1, ti32)), relay.log(data)))
env = relay.Environment()
env[f] = relay.Function([n, data], sb.get())
assert "%3 = @f(%1, %2)" in env.astext()
assert env[f].checked_type == relay.FuncType([ti32, tf32], tf32)
def test_concat():
"""
Program:
def try_concat2(x: Float(3, 2), y: Float(2, 2)) -> Float(5, 2) {
return concatenate((x, y), axis=0);
}
"""
ib = IRBuilder()
try_concat2 = ib.global_var('try_concat2')
x = ib.param('x', ty=tensor_type(3, 2))
y = ib.param('y', ty=tensor_type(2, 2))
with ib.decl(try_concat2, x, y):
ib.ret(concatenate((x, y), axis=0))
fn_ty = func_type([tensor_type(3, 2), tensor_type(2, 2)], tensor_type(5, 2))
assert_decl_has_type(ib.env, try_concat2, fn_ty)
def test_tuple():
ib = IRBuilder()
dup = ib.global_var('dup')
x = ib.param('x')
with ib.decl(dup, x):
ib.ret(relay.Tuple([x, x]))
# todo: why is this not generalized?
fn_ty = func_type([tensor_type()], relay.TupleType([tensor_type(), tensor_type()]))
assert_decl_has_type(ib.env, dup, fn_ty)
tp = relay.TensorType((10,))
x = relay.var("x", tp)
res = relay.Tuple([x, x])
assert (relay.ir_pass.infer_type(res).checked_type ==
relay.TupleType([tp, tp]))
def test_free_expr():
x = relay.var("x", "float32")
y = relay.add(x, x)
yy = relay.ir_pass.infer_type(y)
assert yy.checked_type == relay.scalar_type("float32")
if __name__ == "__main__":
test_free_expr()
test_dual_op()
test_recursion()
test_monomorphic_let()
test_decl()
test_recursion()
test_concat()
test_tuple()
import tvm
from tvm import relay
from tvm.relay.ir_builder import scalar_type, convert, tensor_type
def make_rel(name, args, num_inputs=None, attrs=None):
......