Commit 8876eac8 by Tianqi Chen (committed by GitHub)

[RELAY] IR builder stablize refactor, clean pass (#1934)

parent 4300bbc2
@@ -254,7 +254,7 @@ struct PadAttrs : public tvm::AttrsNode<PadAttrs> {
 struct LeakyReluAttrs : public tvm::AttrsNode<LeakyReluAttrs> {
   double alpha;
 
-  TVM_DECLARE_ATTRS(DenseAttrs, "relay.attrs.LeakyReluAttrs") {
+  TVM_DECLARE_ATTRS(LeakyReluAttrs, "relay.attrs.LeakyReluAttrs") {
     TVM_ATTR_FIELD(alpha).set_lower_bound(0.0).set_default(0.25)
       .describe("Slope coefficient for the negative half axis.");
   }
...
@@ -47,12 +47,13 @@ class EnvironmentNode : public RelayNode {
   void VisitAttrs(tvm::AttrVisitor* v) final {
     v->Visit("functions", &functions);
-    v->Visit("global_map_", &global_map_);
+    v->Visit("global_var_map_", &global_var_map_);
   }
 
   TVM_DLL static Environment make(tvm::Map<GlobalVar, Function> global_funcs);
 
-  /*! \brief Add a function to the global environment.
+  /*!
+   * \brief Add a function to the global environment.
    * \param var The name of the global function.
    * \param func The function.
    * \param update Controls whether you can replace a definition in the
@@ -60,39 +61,46 @@ class EnvironmentNode : public RelayNode {
    */
   void Add(const GlobalVar& var, const Function& func, bool update = false);
 
-  /*! \brief Update a function in the global environment.
+  /*!
+   * \brief Update a function in the global environment.
    * \param var The name of the global function to update.
    * \param func The new function.
    */
   void Update(const GlobalVar& var, const Function& func);
 
-  /*! \brief Remove a function from the global environment.
+  /*!
+   * \brief Remove a function from the global environment.
    * \param var The name of the global function to update.
    */
   void Remove(const GlobalVar& var);
 
-  /*! \brief Lookup a global function by its variable.
+  /*!
+   * \brief Lookup a global function by its variable.
    * \param str The unique string specifying the global variable.
    * \returns The global variable.
    */
   GlobalVar GetGlobalVar(const std::string& str);
 
-  /*! \brief Lookup a global function by its variable.
+  /*!
+   * \brief Lookup a global function by its variable.
    * \param var The global var to lookup.
    * \returns The function named by the variable argument.
   */
   Function Lookup(const GlobalVar& var);
 
-  /*! \brief Lookup a global function by its string name
+  /*!
+   * \brief Lookup a global function by its string name
    * \param name The name of the function.
    * \returns The function named by the argument.
    */
   Function Lookup(const std::string& name);
 
-  /*! \brief Combine with another Environment.
+  /*!
+   * \brief Update the functions inside this environment by
+   * functions in another environment.
    * \param other The other environment.
    */
-  void Merge(const Environment& other);
+  void Update(const Environment& other);
 
   static constexpr const char* _type_key = "relay.Environment";
   TVM_DECLARE_NODE_TYPE_INFO(EnvironmentNode, Node);
@@ -101,7 +109,7 @@ class EnvironmentNode : public RelayNode {
   /*! \brief A map from string names to global variables that
    * ensures global uniqueness.
    */
-  tvm::Map<std::string, GlobalVar> global_map_;
+  tvm::Map<std::string, GlobalVar> global_var_map_;
 };
 
 struct Environment : public NodeRef {
...
@@ -197,7 +197,7 @@ class FunctionNode : public ExprNode {
    *
    * \note This can be usually empty for non-polymorphic functions.
    */
-  tvm::Array<TypeParam> type_params;
+  tvm::Array<TypeVar> type_params;
 
   void VisitAttrs(tvm::AttrVisitor* v) final {
     v->Visit("params", &params);
@@ -219,7 +219,7 @@ class FunctionNode : public ExprNode {
   TVM_DLL static Function make(tvm::Array<Var> params,
                                Expr body,
                                Type ret_type,
-                               tvm::Array<TypeParam> ty_params);
+                               tvm::Array<TypeVar> ty_params);
 
   static constexpr const char* _type_key = "relay.Function";
   TVM_DECLARE_NODE_TYPE_INFO(FunctionNode, ExprNode);
@@ -375,13 +375,14 @@ class TupleGetItemNode : public ExprNode {
   int index;
 
   void VisitAttrs(tvm::AttrVisitor* v) final {
-    v->Visit("tuple", &tuple);
+    v->Visit("tuple_value", &tuple);
     v->Visit("index", &index);
+    v->Visit("_checked_type_", &checked_type_);
   }
 
   TVM_DLL static TupleGetItem make(Expr tuple, int index);
 
-  static constexpr const char * _type_key = "relay.GetItem";
+  static constexpr const char * _type_key = "relay.TupleGetItem";
   TVM_DECLARE_NODE_TYPE_INFO(TupleGetItemNode, ExprNode);
 };
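The TupleGetItemNode hunk renames the node's type key from relay.GetItem to relay.TupleGetItem, exposes the tuple operand as tuple_value, and visits the checked type. A minimal Python sketch of how the node looks from the frontend after this change (assuming the usual top-level relay re-exports):

from tvm import relay

# illustrative only: index into a 2-tuple expression
a = relay.var("a")
b = relay.var("b")
item = relay.TupleGetItem(relay.Tuple([a, b]), 0)
print(item.index)        # 0
print(item.tuple_value)  # the Tuple operand, under the renamed attribute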
...
@@ -371,14 +371,14 @@ inline OpRegistry& OpRegistry::add_type_rel(
     env_type_rel_func = env_func;
   }
 
-  Array<TypeParam> type_params;
+  Array<TypeVar> type_params;
   Array<Type> arg_types;
 
   // Add inputs.
   std::string input_name_prefix = "in";
   for (int i = 0; i < get()->num_inputs; i++) {
     auto name = input_name_prefix + std::to_string(i);
-    auto param = TypeParamNode::make(name, TypeParamNode::Kind::kType);
+    auto param = TypeVarNode::make(name, TypeVarNode::Kind::kType);
     type_params.push_back(param);
     arg_types.push_back(param);
   }
@@ -386,7 +386,7 @@ inline OpRegistry& OpRegistry::add_type_rel(
   Array<Type> ty_call_args = arg_types;
 
   // Add output type.
-  auto out_param = TypeParamNode::make("out", TypeParamNode::Kind::kType);
+  auto out_param = TypeVarNode::make("out", TypeVarNode::Kind::kType);
   type_params.push_back(out_param);
   // this will trigger copy on write.
   ty_call_args.push_back(out_param);
...
@@ -12,21 +12,30 @@
 namespace tvm {
 namespace relay {
 
-/*! \brief Infer the type of an expression with the provided environment.
+/*!
+ * \brief Infer the type of an expression.
  *
  * The result of type checking is a new expression with unambigous
  * type information filled in, as well as it's checked type field
  * populated with the result type.
  *
- * \param env The environment used for global settings and referencing
- *  global functions.
- *
- * \param e The expression to type check.
+ * \param expr The expression to type check.
+ * \param env The environment used for referencing global functions, can be None.
  *
 * \return A type checked expression with its checked_type field populated.
 */
-Expr InferType(const Environment& env, const Expr& e);
-Expr InferType(const Environment& env, const GlobalVar& var, const Function& f);
+Expr InferType(const Expr& expr, const Environment& env);
+
+/*!
+ * \brief Infer the type of a function as if it is mapped to var in the env.
+ *
+ * \param f the function.
+ * \param env The environment used for referencing global functions.
+ * \param var The global variable corresponding to the function.
+ *
+ * \return A type checked Function with its checked_type field populated.
+ * \note this function mutates env and is not thread-safe.
+ */
+Function InferType(const Function& f, const Environment& env, const GlobalVar& var);
 
 /*!
  * \brief Check that types are well kinded by applying "kinding rules".
@@ -111,7 +120,7 @@ tvm::Array<Var> FreeVariables(const Expr& e);
  *
 * \return the set of free type variables.
 */
-tvm::Array<TypeParam> FreeTypeVariables(const Expr& e);
+tvm::Array<TypeVar> FreeTypeVariables(const Expr& e);
 
 /*! \brief Get free type parameters from type t.
  *
@@ -121,7 +130,7 @@ tvm::Array<TypeParam> FreeTypeVariables(const Expr& e);
  *
 * \return the set of free type variables.
 */
-tvm::Array<TypeParam> FreeTypeVariables(const Type& t);
+tvm::Array<TypeVar> FreeTypeVariables(const Type& t);
 
 /*! \brief Remove expressions which does not effect the program result.
  *
...
@@ -98,7 +98,7 @@ RELAY_DEFINE_NODE_REF(TensorType, TensorTypeNode, Type);
 * This can be viewed as template parameter in c++ template function.
 *
 * For example, in the following pesudo code,
- * the TypeParam of f is TypeParam(kind=kShapeVar, var=n).
+ * the TypeVar of f is TypeVar(kind=kShapeVar, var=n).
 * This function can take in a Tensor with shape=(3, 3) and
 * returns a Tensor with shape=(9,)
 *
@@ -108,13 +108,13 @@ RELAY_DEFINE_NODE_REF(TensorType, TensorTypeNode, Type);
 *  f(x : Tensor[i32, (n, n)]) -> Tensor[i32, (n * n)]
 *
 * \endcode
- * \sa TypeParamNode The actual container class of TypeParam
+ * \sa TypeVarNode The actual container class of TypeVar
 */
-class TypeParam;
-/*! \brief TypeParam container node */
-class TypeParamNode : public TypeNode {
+class TypeVar;
+/*! \brief TypeVar container node */
+class TypeVarNode : public TypeNode {
  public:
-  /*! \brief possible kinds of TypeParam */
+  /*! \brief possible kinds of TypeVar */
   enum Kind : int {
     /*! \brief template variable in shape expression */
     kType = 0,
@@ -136,13 +136,13 @@ class TypeParamNode : public TypeNode {
     v->Visit("span", &span);
   }
 
-  TVM_DLL static TypeParam make(std::string name, Kind kind);
+  TVM_DLL static TypeVar make(std::string name, Kind kind);
 
-  static constexpr const char* _type_key = "relay.TypeParam";
-  TVM_DECLARE_NODE_TYPE_INFO(TypeParamNode, TypeNode);
+  static constexpr const char* _type_key = "relay.TypeVar";
+  TVM_DECLARE_NODE_TYPE_INFO(TypeVarNode, TypeNode);
 };
 
-RELAY_DEFINE_NODE_REF(TypeParam, TypeParamNode, Type);
+RELAY_DEFINE_NODE_REF(TypeVar, TypeVarNode, Type);
 
 /*!
 * \brief IncompleteType.
@@ -150,20 +150,20 @@ RELAY_DEFINE_NODE_REF(TypeParam, TypeParamNode, Type);
 *
 * If we view the type relations as "computational graph of types",
 * then IncompleteType represents intermediate values of the graph,
- * TypeParam represents the input to the graph.
+ * TypeVar represents the input to the graph.
 */
 class IncompleteType;
 
 /*! \brief IncompleteType container node */
 class IncompleteTypeNode : public TypeNode {
  public:
-  TypeParamNode::Kind kind;
+  TypeVarNode::Kind kind;
 
   void VisitAttrs(tvm::AttrVisitor* v) final {
     v->Visit("kind", &kind);
   }
 
-  TVM_DLL static IncompleteType make(TypeParamNode::Kind kind);
+  TVM_DLL static IncompleteType make(TypeVarNode::Kind kind);
 
   static constexpr const char* _type_key = "relay.IncompleteType";
   TVM_DECLARE_NODE_TYPE_INFO(IncompleteTypeNode, TypeNode);
@@ -192,7 +192,7 @@ class FuncType;
 * Relay support polymorphic function type.
 * This can be roughly viewed as template function in C++.
 *
- * \sa TypeParam, TypeConstraint
+ * \sa TypeVar, TypeConstraint
 */
 class FuncTypeNode : public TypeNode {
  public:
@@ -203,7 +203,7 @@ class FuncTypeNode : public TypeNode {
   // The following fields are used in polymorphic(template) functions
   // For normal functions, the following two fields will be empty.
   /*! \brief The type parameters of the function */
-  tvm::Array<TypeParam> type_params;
+  tvm::Array<TypeVar> type_params;
   /*!
   * \brief potential constraint the type need to obey
   * \note this field is reserved for futher purposes.
@@ -220,7 +220,7 @@ class FuncTypeNode : public TypeNode {
   TVM_DLL static FuncType make(tvm::Array<Type> arg_types,
                                Type ret_type,
-                               tvm::Array<TypeParam> type_params,
+                               tvm::Array<TypeVar> type_params,
                                tvm::Array<TypeConstraint> type_constraints);
 
   static constexpr const char* _type_key = "relay.FuncType";
...
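Since every TypeParam in the headers is now a TypeVar, polymorphic function types are spelled with the renamed class on the Python side as well. A hedged sketch of the identity type forall (a), (a) -> a, assuming the usual top-level relay exports:

from tvm import relay

# illustrative sketch of a polymorphic function type
a = relay.TypeVar("a", relay.Kind.Type)
id_type = relay.FuncType([a], a, type_params=[a])
print(id_type)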
@@ -5,7 +5,6 @@ from . import ty
 from . import expr
 from . import env
 from . import ir_pass
-from . import ir_builder
 
 # Root operators
 from .op import Op
@@ -16,6 +15,8 @@ from . import nn
 from . import vision
 from . import image
 
+from .scope_builder import ScopeBuilder
+
 # Span
 Span = base.Span
@@ -27,11 +28,12 @@ Type = ty.Type
 TupleType = ty.TupleType
 TensorType = ty.TensorType
 Kind = ty.Kind
-TypeParam = ty.TypeParam
+TypeVar = ty.TypeVar
 TypeConstraint = ty.TypeConstraint
 FuncType = ty.FuncType
 TypeRelation = ty.TypeRelation
 IncompleteType = ty.IncompleteType
+scalar_type = ty.scalar_type
 
 # Expr
 Constant = expr.Constant
...
 # pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable, wildcard-import
 """A global environment storing everything needed to interpret or compile a Relay program."""
 from .base import register_relay_node, RelayNode
+from .._ffi import base as _base
 from . import _make
 from . import _env
+from . import expr as _expr
 
 @register_relay_node
 class Environment(RelayNode):
-    """The global Relay environment containing functions,
-    options and more.
-    """
-
-    def __init__(self, funcs=None):
-        """Construct an environment.
-
-        Parameters
-        ------
-        funcs : optional, dict
-            Map of global var to Function
-
-        Returns
-        ------
-        env: A new environment containing :py:class:`~relay.env.Environment`.
-        """
-        funcs = funcs if funcs else {}
-        self.__init_handle_by_constructor__(_make.Environment, funcs)
-
-    def add(self, var, func):
+    """The global Relay environment containing collection of functions.
+
+    Each global function is identified by an unique tvm.relay.GlobalVar.
+    tvm.relay.GlobalVar and Environment is necessary in order to enable
+    recursions in function to avoid cyclic reference in the function.
+
+    Parameters
+    ----------
+    functions : dict, optional.
+        Map of global var to Function
+    """
+    def __init__(self, functions=None):
+        if functions is None:
+            functions = {}
+        elif isinstance(functions, dict):
+            mapped_funcs = {}
+            for k, v in functions.items():
+                if isinstance(k, _base.string_types):
+                    k = _expr.GlobalVar(k)
+                if not isinstance(k, _expr.GlobalVar):
+                    raise TypeError("Expect functions to be Dict[GlobalVar, Function]")
+                mapped_funcs[k] = v
+            functions = mapped_funcs
+        self.__init_handle_by_constructor__(_make.Environment, functions)
+
+    def __setitem__(self, var, func):
         """Add a function to the environment.
 
         Parameters
@@ -36,23 +45,42 @@ class Environment(RelayNode):
         func: Function
             The function.
         """
-        if isinstance(var, str):
-            var = _env.Environment_GetGlobalVar(self, var)
+        if isinstance(var, _base.string_types):
+            var = _expr.GlobalVar(var)
         _env.Environment_Add(self, var, func)
 
-    def merge(self, other):
-        """Merge two environments.
+    def __getitem__(self, var):
+        """Lookup a global function by name or by variable.
+
+        Parameters
+        ----------
+        var: str or GlobalVar
+            The name or global variable.
+
+        Returns
+        -------
+        func: Function
+            The function referenced by :code:`var`.
+        """
+        if isinstance(var, _base.string_types):
+            return _env.Environment_Lookup_str(self, var)
+        else:
+            return _env.Environment_Lookup(self, var)
+
+    def update(self, other):
+        """Insert functions in another Environment to current one.
 
         Parameters
         ----------
         other: Environment
             The environment to merge into the current Environment.
         """
-        return _env.Environment_Merge(self, other)
+        if isinstance(other, dict):
+            other = Environment(other)
+        return _env.Environment_Update(self, other)
 
-    def global_var(self, name):
-        """Get a global variable by name.
+    def get_global_var(self, name):
+        """Get a global variable in the function by name.
 
         Parameters
         ----------
@@ -63,23 +91,9 @@ class Environment(RelayNode):
        -------
        global_var: GlobalVar
            The global variable mapped to :code:`name`.
-        """
-        return _env.Environment_GetGlobalVar(self, name)
-
-    def __getitem__(self, var):
-        """Lookup a global function by name or by variable.
-
-        Parameters
-        ----------
-        var: str or GlobalVar
-            The name or global variable.
 
-        Returns
-        -------
-        func: Function
-            The function referenced by :code:`var`.
+        Raises
+        ------
+        tvm.TVMError if we cannot find corresponding global var.
         """
-        if isinstance(var, str):
-            return _env.Environment_Lookup_str(self, var)
-        else:
-            return _env.Environment_Lookup(self, var)
+        return _env.Environment_GetGlobalVar(self, name)
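Taken together, the Python Environment now behaves like a dictionary keyed by global names or GlobalVars. A short usage sketch under the new API (the function f here is illustrative, and the top-level relay re-exports are assumed):

from tvm import relay

x = relay.var("x", shape=(10,))
f = relay.Function([x], x)

env = relay.Environment({"f": f})  # str keys are promoted to GlobalVar
env["g"] = f                       # __setitem__ type checks and adds
gv = env.get_global_var("f")       # name -> GlobalVar, raises if missing
print(env[gv])                     # __getitem__ accepts GlobalVar or str
env.update({"h": f})               # merge a dict or another Environment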
@@ -28,9 +28,6 @@ class Expr(RelayNode):
                 " the checked_type for this node")
         return ret
 
-    def __call__(self, *args):
-        return Call(self, args, None, None)
-
 
 @register_relay_node
 class Constant(Expr):
@@ -57,6 +54,14 @@ class Tuple(Expr):
     def __init__(self, fields):
         self.__init_handle_by_constructor__(_make.Tuple, fields)
 
+    def __getitem__(self, index):
+        if index >= len(self):
+            raise IndexError("Tuple index out of range")
+        return self.fields[index]
+
+    def __len__(self):
+        return len(self.fields)
+
 
 @register_relay_node
 class Var(Expr):
@@ -95,6 +100,16 @@ class GlobalVar(Expr):
     def __init__(self, name_hint):
         self.__init_handle_by_constructor__(_make.GlobalVar, name_hint)
 
+    def __call__(self, *args):
+        """Invoke the global function.
+
+        Parameters
+        ----------
+        args: List[relay.Expr]
+            Arguments.
+        """
+        return Call(self, args, None, None)
+
 
 @register_relay_node
 class Function(Expr):
@@ -126,6 +141,16 @@ class Function(Expr):
         self.__init_handle_by_constructor__(
             _make.Function, params, body, ret_type, type_params)
 
+    def __call__(self, *args):
+        """Invoke the global function.
+
+        Parameters
+        ----------
+        args: List[relay.Expr]
+            Arguments.
+        """
+        return Call(self, args, None, None)
+
 
 @register_relay_node
 class Call(Expr):
@@ -238,11 +263,17 @@ class TupleWrapper(_node.NodeGeneric):
         return self.tuple_value
 
-    def __getitem__(self, key):
-        return self.tuple_value.fields[key]
+    def __getitem__(self, index):
+        if index >= len(self):
+            raise IndexError("Tuple index out of range")
+        return TupleGetItem(self.tuple_value, index)
 
     def __len__(self):
-        return len(self.tuple_value.fields)
+        return self.size
+
+    def __repr__(self):
+        return ("TupleWrapper(" + self.tuple_value.__repr__() +
+                ", " + str(self.size) + ")")
 
 
 def var(name_hint,
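With the TupleWrapper hunk above, indexing no longer peeks into the underlying fields eagerly; it builds a TupleGetItem expression and bounds-checks against the recorded size. A hedged sketch of the observable behavior (the TupleWrapper constructor signature is an assumption inferred from the fields it uses):

from tvm import relay
from tvm.relay.expr import TupleWrapper

a = relay.var("a")
b = relay.var("b")
w = TupleWrapper(relay.Tuple([a, b]), 2)  # assumed (tuple_value, size) ctor
print(len(w))  # 2, taken from the recorded size
print(w[0])    # now a relay.TupleGetItem(w.tuple_value, 0) expression
print(w)       # TupleWrapper(..., 2) via the new __repr__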
@@ -304,13 +335,27 @@ def const(value, dtype=None):
     dtype: str, optional
         The data type of the value.
 
+    Note
+    ----
+    When dtype is None, we use the following rule:
+
+    - int maps to "int32"
+    - float maps to "float32"
+    - bool maps to "bool"
+    - other using the same default rule as numpy.
     """
-    if isinstance(value, _base.numeric_types):
-        value = _np.array(value, dtype=dtype)
-    elif isinstance(value, (bool, list)):
+    if isinstance(value, (_base.numeric_types, (bool, list))):
         value = _np.array(value, dtype=dtype)
+        # convert default to int32 and float32
+        if dtype is None:
+            if value.dtype == "float64":
+                value = value.astype("float32")
+            elif value.dtype == "int64":
+                value = value.astype("int32")
     if isinstance(value, (_np.ndarray, _np.generic)):
         value = _nd.array(value)
     if not isinstance(value, _nd.NDArray):
         raise ValueError("value has to be scalar or NDArray")
     return Constant(value)
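The narrowing above matters because numpy defaults Python floats and ints to 64-bit, so without it every relay.const would silently become float64 or int64. A quick sketch of what the rule implies (the .data attribute of a Constant is the underlying NDArray):

from tvm import relay

c1 = relay.const(1.0)             # numpy infers float64, narrowed to float32
c2 = relay.const(1)               # numpy infers int64, narrowed to int32
c3 = relay.const(1.0, "float64")  # an explicit dtype is left untouched
print(c1.data.dtype, c2.data.dtype, c3.data.dtype)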
@@ -2,37 +2,39 @@
 # pylint: disable=unidiomatic-typecheck
 """The set of passes for Relay.
 
-Exposes an interface for configuring the passes and scripting
-them in Python.
+Exposes an interface for configuring the passes and
+scripting them in Python.
 """
 from . import _ir_pass
 from . import _make
 
 # pylint: disable=invalid-name
-def infer_type(env, expr):
+def infer_type(expr, env=None):
     """Infer the type of expr under the context of env.
 
     Parameters
     ----------
-    env : relay.Environment
-        The global environment.
+    expr: tvm.relay.Expr
+        The input expression.
 
-    expr : relay.Expr
-        The input expression.
+    env: Optional[tvm.relay.Environment]
+        The global environment.
 
     Returns
     -------
-    checked_expr : relay.Expr
+    checked_expr : tvm.relay.Expr
         The checked expression.
     """
-    return _ir_pass.infer_type(env, expr)
+    return _ir_pass.infer_type(expr, env)
 
-def well_formed(e):
+
+def well_formed(expr):
     """Check that each Var is only bound once (well formed).
 
     Parameters
     ----------
-    e: relay.Expr
+    expr: tvm.relay.Expr
         The input expression
 
     Returns
@@ -40,7 +42,8 @@ def well_formed(e):
     well_form : bool
         whether the input expression is well formed
     """
-    return _ir_pass.well_formed(e)
+    return _ir_pass.well_formed(expr)
+
 
 def check_kind(t, env=None):
     """Check that the type is well kinded.
@@ -48,10 +51,10 @@ def check_kind(t, env=None):
     Parameters
     ----------
-    t: relay.Type
+    t: tvm.relay.Type
         The type to check
 
-    env: relay.Environment, optional
+    env: tvm.relay.Environment, optional
         The global environment
 
     Returns
@@ -71,61 +74,65 @@ def check_kind(t, env=None):
     else:
         return _ir_pass.check_kind(t)
 
+
 def free_vars(e):
     """Get free variables from expression e.
 
     Parameters
     ----------
-    e: relay.Expr
+    e: tvm.relay.Expr
         The input expression
 
     Returns
     -------
-    free : List[relay.Var]
-        the list of free variables
+    free : List[tvm.relay.Var]
+        The list of free variables
     """
     return _ir_pass.free_vars(e)
 
-def free_type_vars(e):
+
+def free_type_vars(expr):
     """Get free type variables from expression/type e
 
     Parameters
     ----------
-    e: relay.Expr/relay.Type
+    expr: Union[tvm.relay.Expr,tvm.relay.Type]
         The input expression/type
 
     Returns
     -------
-    free : List[relay.TypeParam]
-        the list of free type variables
+    free : List[tvm.relay.TypeParam]
+        The list of free type variables
     """
-    return _ir_pass.free_type_vars(e)
+    return _ir_pass.free_type_vars(expr)
 
-def dead_code_elimination(e):
+
+def dead_code_elimination(expr):
     """ Remove expressions which does not effect the program result (dead code).
 
     Parameters
     ----------
-    e: relay.Expr
+    e: tvm.relay.Expr
         The input Expression
 
     Returns
     -------
-    result: relay.Expr
+    result: tvm.relay.Expr
         An expression which is semantically equal to the input expression,
         but with dead code removed.
     """
-    return _ir_pass.dead_code_elimination(e)
+    return _ir_pass.dead_code_elimination(expr)
 
+
 def alpha_equal(lhs, rhs):
     """Compare two Relay expr for structural equivalence (alpha equivalence).
 
     Parameters
     ----------
-    lhs: relay.Expr
+    lhs: tvm.relay.Expr
         One of the input Expression.
 
-    rhs: relay.Expr
+    rhs: tvm.relay.Expr
         One of the input Expression.
 
     Returns
...
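Because the environment argument of infer_type is now optional and last, closed expressions can be type checked without constructing an Environment at all. A brief sketch of the new calling convention (op wrappers like relay.nn.relu are assumed, and the printed type is indicative):

from tvm import relay
from tvm.relay import ir_pass

x = relay.var("x", shape=(3, 3))
f = relay.Function([x], relay.nn.relu(x))
f = ir_pass.infer_type(f)  # env defaults to None for closed functions
print(f.checked_type)      # fn (Tensor[(3, 3), float32]) -> Tensor[(3, 3), float32]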
"""The scope builder interface """
from __future__ import absolute_import
from . import expr as _expr
from .._ffi import base as _base
class WithScope(object):
"""A wrapper for builder methods which introduce scoping.
Parameters
----------
enter_value: object
The value returned by enter.
"""
def __init__(self, enter_value, exit_cb):
self._enter_value = enter_value
self._exit_cb = exit_cb
def __enter__(self):
return self._enter_value
def __exit__(self, ptype, value, trace):
if value:
raise value
else:
self._exit_cb()
def _make_lets(bindings, ret_value):
"""Make a nested let expressions.
Parameters
----------
bindings: List[Tuple[tvm.relay.Var,tvm.relay.Expr]]
The sequence of let bindings
ret_value: tvm.relay.Expr
The final value of the expression.
Returns
-------
lets: tvm.relay.Expr
A nested let expression.
"""
if ret_value is None:
raise RuntimeError("ret is not called in this scope")
if isinstance(ret_value, _expr.If) and ret_value.false_branch is None:
raise RuntimeError("Creating an If expression without else.")
let_expr = ret_value
for var, value in reversed(bindings):
let_expr = _expr.Let(var, value, let_expr)
return let_expr
class ScopeBuilder(object):
"""Scope builder class.
Enables users to build up a nested
scope(let, if) expression easily.
Examples
--------
..code-block: python
sb = relay.ScopeBuilder()
cond = relay.var("cond", 'bool')
x = relay.var("x")
y = relay.var("y")
with sb.if_scope(cond):
one = relay.const(1, "float32")
t1 = sb.let(t1, relay.add(x, one))
sb.ret(t1)
with sb.else_scope():
sb.ret(y)
print(sb.get().astext())
"""
def __init__(self):
self._bindings = [[]]
self._ret_values = [None]
def _enter_scope(self):
self._bindings.append([])
self._ret_values.append(None)
def _exit_scope(self):
bindings = self._bindings.pop()
ret_value = self._ret_values.pop()
return bindings, ret_value
def let(self, var, value):
"""Create a new let binding.
Parameters
----------
var: Union[Tuple[str, relay.Type], tvm.relay.Var]
The variable or name of variable.
value: tvm.relay.Expr
The value to be binded
"""
if isinstance(var, (tuple, list)):
if len(var) > 2:
raise ValueError("Expect var to be Tuple[str, relay.Type]")
var = _expr.var(*var)
elif isinstance(var, _base.string_types):
var = _expr.var(var)
self._bindings[-1].append((var, value))
return var
def if_scope(self, cond):
"""Create a new if scope.
Parameters
----------
cond: tvm.relay.Expr
The condition
Returns
-------
scope: WithScope
The if scope.
Note
----
The user must follows with an else scope.
"""
self._enter_scope()
def _on_exit():
bindings, ret_value = self._exit_scope()
if self._ret_values[-1] is not None:
raise RuntimeError("result already returned before if scope")
true_branch = _make_lets(bindings, ret_value)
self._ret_values[-1] = _expr.If(cond, true_branch, None)
return WithScope(None, _on_exit)
def else_scope(self):
"""Create a new else scope.
Returns
-------
scope: WithScope
The if scope.
"""
self._enter_scope()
def _on_exit():
bindings, ret_value = self._exit_scope()
partial_if = self._ret_values[-1]
no_else = (not isinstance(partial_if, _expr.If) or
partial_if.false_branch is not None)
if no_else:
raise RuntimeError("else scope must follows")
false_branch = _make_lets(bindings, ret_value)
self._ret_values[-1] = _expr.If(
partial_if.cond,
partial_if.true_branch,
false_branch)
return WithScope(None, _on_exit)
def ret(self, value):
"""Set the return value of this scope.
Parameters
----------
value: tvm.relay.Expr
The return value.
"""
if self._ret_values[-1] is not None:
raise RuntimeError("ret value is already set in this scope.")
self._ret_values[-1] = value
def get(self):
"""Get the generated result.
Returns
-------
value: tvm.relay.Expr
The final result of the expression.
"""
if len(self._bindings) != 1:
raise RuntimeError("can only call get at the outmost scope")
return _make_lets(self._bindings[-1], self._ret_values[-1])
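As a quick sanity check of the new builder, the following hedged sketch assembles the same if/else shape as the docstring example, with one let binding in the true branch:

from tvm import relay

sb = relay.ScopeBuilder()
cond = relay.var("cond", "bool")
x = relay.var("x")
with sb.if_scope(cond):
    t1 = sb.let("t1", relay.add(x, relay.const(1.0)))
    sb.ret(t1)
with sb.else_scope():
    sb.ret(x)
body = sb.get()  # an If whose branches are nested Let expressions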
@@ -56,7 +56,7 @@ class Kind(IntEnum):
     Shape = 3
 
 @register_relay_node
-class TypeParam(Type):
+class TypeVar(Type):
     """A type parameter used for generic types in Relay,
     see tvm/relay/type.h for more details.
@@ -66,7 +66,7 @@ class TypeParam(Type):
     """
     def __init__(self, var, kind=Kind.Type):
-        """Construct a TypeParam.
+        """Construct a TypeVar.
 
         Parameters
         ----------
@@ -78,10 +78,10 @@ class TypeParam(Type):
         Returns
         -------
-        type_param: TypeParam
+        type_param: TypeVar
             The type parameter.
         """
-        self.__init_handle_by_constructor__(_make.TypeParam, var, kind)
+        self.__init_handle_by_constructor__(_make.TypeVar, var, kind)
 
 
 @register_relay_node
@@ -122,26 +122,30 @@ class FuncType(Type):
     We informally write them as:
     `forall (type_params), (arg_types) -> ret_type where type_constraints`
-    """
-    def __init__(self,
-                 arg_types,
-                 ret_type,
-                 type_params,
-                 type_constraints):
-        """Construct a function type.
-
-        Parameters
-        ----------
-        arg_types: list of Type
-        ret_type: Type
-        type_params: list of TypeParam
-        type_constraints: list of TypeConstraint
-
-        Returns
-        -------
-        func_type: FuncType
-            The function type.
-        """
+
+    Parameters
+    ----------
+    arg_types: List[tvm.relay.Type]
+        The argument types
+
+    ret_type: tvm.relay.Type
+        The return type.
+
+    type_params: List[tvm.relay.TypeVar]
+        The type parameters
+
+    type_constraints: List[tvm.relay.TypeConstraint]
+        The type constraints.
+    """
+    def __init__(self,
+                 arg_types,
+                 ret_type,
+                 type_params=None,
+                 type_constraints=None):
+        if type_params is None:
+            type_params = []
+        if type_constraints is None:
+            type_constraints = []
         self.__init_handle_by_constructor__(
             _make.FuncType, arg_types, ret_type, type_params, type_constraints)
@@ -175,3 +179,21 @@ class TypeRelation(TypeConstraint):
     def __init__(self, func, args, num_inputs, attrs):
         self.__init_handle_by_constructor__(_make.TypeRelation,
                                             func, args, num_inputs, attrs)
+
+
+def scalar_type(dtype):
+    """Creates a scalar type.
+
+    This function returns TensorType((), dtype)
+
+    Parameters
+    ----------
+    dtype : str
+        The content data type.
+
+    Returns
+    -------
+    s_type: tvm.relay.TensorType
+        The result type.
+    """
+    return TensorType((), dtype)
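With type_params and type_constraints defaulting to empty lists, monomorphic function types become much shorter to write, and scalar_type gives a compact spelling for rank-0 tensors. A hedged sketch combining the two:

from tvm import relay

f32 = relay.scalar_type("float32")          # same as relay.TensorType((), "float32")
bin_type = relay.FuncType([f32, f32], f32)  # no type params or constraints needed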
@@ -16,87 +16,71 @@ using namespace runtime;
 Environment EnvironmentNode::make(tvm::Map<GlobalVar, Function> global_funcs) {
   auto n = make_node<EnvironmentNode>();
   n->functions = std::move(global_funcs);
+  for (const auto& kv : n->functions) {
+    // set global var map
+    CHECK(!n->global_var_map_.count(kv.first->name_hint))
+        << "Duplicate global function name " << kv.first->name_hint;
+    n->global_var_map_.Set(kv.first->name_hint, kv.first);
+  }
   return Environment(n);
 }
 
-GlobalVar EnvironmentNode::GetGlobalVar(const std::string &str) {
-  auto global_id = global_map_.find(str);
-  if (global_id != global_map_.end()) {
-    return (*global_id).second;
-  } else {
-    auto id = GlobalVarNode::make(str);
-    this->global_map_.Set(str, id);
-    return id;
-  }
+GlobalVar EnvironmentNode::GetGlobalVar(const std::string& name) {
+  auto it = global_var_map_.find(name);
+  CHECK(it != global_var_map_.end())
+      << "Cannot find global var " << name << " in the Environment";
+  return (*it).second;
 }
 
-/*!
- * \brief Add a new item to the global environment
- * \note if the update flag is not set adding a duplicate
- * definition will trigger an exception, otherwise we will
- * update the definition if and only if it is type compatible.
- */
-void EnvironmentNode::Add(const GlobalVar &var,
-                          const Function &func,
+void EnvironmentNode::Add(const GlobalVar& var,
+                          const Function& func,
                           bool update) {
   // Type check the item before we add it to the environment.
   auto env = GetRef<Environment>(this);
-  Expr checked_expr = InferType(env, var, func);
-  if (const FunctionNode *func_node = checked_expr.as<FunctionNode>()) {
-    auto checked_func = GetRef<Function>(func_node);
-    auto type = checked_func->checked_type();
-    CHECK(type.as<IncompleteTypeNode>() == nullptr);
-    if (functions.find(var) != functions.end()) {
-      if (!update) {
-        throw dmlc::Error("already have definition for XXXX.");
-      }
-      auto old_type = functions[var].as<FunctionNode>()->checked_type();
-      if (!AlphaEqual(type, old_type)) {
-        throw dmlc::Error(
-            "Environment#update changes type, not possible in this mode.");
-      }
-      this->functions.Set(var, checked_func);
-    } else {
-      this->functions.Set(var, checked_func);
-    }
-  } else {
-    LOG(FATAL) << "internal error: unknown item type, unreachable code";
-  }
+  Function checked_func = InferType(func, env, var);
+  auto type = checked_func->checked_type();
+  CHECK(type.as<IncompleteTypeNode>() == nullptr);
+  if (functions.find(var) != functions.end()) {
+    CHECK(update)
+        << "Already have definition for " << var->name_hint;
+    auto old_type = functions[var].as<FunctionNode>()->checked_type();
+    CHECK(AlphaEqual(type, old_type))
+        << "Environment#update changes type, not possible in this mode.";
+  }
+  this->functions.Set(var, checked_func);
+  // set global var map
+  CHECK(!global_var_map_.count(var->name_hint))
+      << "Duplicate global function name " << var->name_hint;
+  global_var_map_.Set(var->name_hint, var);
 }
 
-void EnvironmentNode::Update(const GlobalVar &var, const Function &func) {
+void EnvironmentNode::Update(const GlobalVar& var, const Function& func) {
   this->Add(var, func, true);
 }
 
-void EnvironmentNode::Remove(const GlobalVar & var) {
+void EnvironmentNode::Remove(const GlobalVar& var) {
   auto functions_node = this->functions.CopyOnWrite();
   functions_node->data.erase(var.node_);
+  auto gvar_node = global_var_map_.CopyOnWrite();
+  gvar_node->data.erase(var->name_hint);
 }
 
-Function EnvironmentNode::Lookup(const GlobalVar &var) {
-  auto func = functions.find(var);
-  if (func != functions.end()) {
-    return (*func).second;
-  } else {
-    throw Error(std::string("there is no definition of ") + var->name_hint);
-  }
+Function EnvironmentNode::Lookup(const GlobalVar& var) {
+  auto it = functions.find(var);
+  CHECK(it != functions.end())
+      << "There is no definition of " << var->name_hint;
+  return (*it).second;
 }
 
-Function EnvironmentNode::Lookup(const std::string &str) {
-  GlobalVar id = this->GetGlobalVar(str);
+Function EnvironmentNode::Lookup(const std::string &name) {
+  GlobalVar id = this->GetGlobalVar(name);
   return this->Lookup(id);
 }
 
-void EnvironmentNode::Merge(const Environment &env) {
+void EnvironmentNode::Update(const Environment &env) {
   for (auto pair : env->functions) {
-    this->functions.Set(pair.first, pair.second);
+    this->Update(pair.first, pair.second);
   }
 }
@@ -134,10 +118,10 @@ TVM_REGISTER_API("relay._env.Environment_Lookup_str")
     *ret = env->Lookup(var);
   });
 
-TVM_REGISTER_API("relay._env.Environment_Merge")
+TVM_REGISTER_API("relay._env.Environment_Update")
 .set_body([](TVMArgs args, TVMRetValue *ret) {
     Environment env = args[0];
-    env->Merge(args[1]);
+    env->Update(args[1]);
   });
 
 TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
...
@@ -104,7 +104,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
 Function FunctionNode::make(tvm::Array<Var> params,
                             Expr body,
                             Type ret_type,
-                            tvm::Array<TypeParam> type_params) {
+                            tvm::Array<TypeVar> type_params) {
   NodePtr<FunctionNode> n = make_node<FunctionNode>();
   n->params = std::move(params);
   n->body = std::move(body);
...
@@ -66,11 +66,11 @@ Expr ExprMutator::VisitExpr_(const TupleNode* op) {
 }
 
 Expr ExprMutator::VisitExpr_(const FunctionNode* op) {
-  tvm::Array<TypeParam> ty_params;
+  tvm::Array<TypeVar> ty_params;
   bool all_ty_params_changed = true;
 
   for (auto ty_param : op->type_params) {
-    TypeParam new_ty_param = Downcast<TypeParam>(VisitType(ty_param));
+    TypeVar new_ty_param = Downcast<TypeVar>(VisitType(ty_param));
     ty_params.push_back(new_ty_param);
     all_ty_params_changed &= new_ty_param.same_as(ty_param);
   }
...
@@ -217,6 +217,8 @@ class TextPrinter :
       return ConstScalar(dtype, static_cast<const float*>(op->data->data));
     } else if (dtype == Float(64)) {
       return ConstScalar(dtype, static_cast<const double*>(op->data->data));
+    } else if (dtype == Bool()) {
+      return ConstScalar(dtype, static_cast<const uint8_t*>(op->data->data));
     }
   }
   // default fall-back, record it as meta node.
@@ -638,8 +640,14 @@ class TextPrinter :
    * \return The corresponding name.
    */
  TextValue AllocVarName(const Var& var) {
-    std::string name = GetUniqueName('%' + var->name_hint);
-    TextValue val(name);
+    std::string name = var->name_hint;
+    // always make sure first name is alpha
+    if (name.length() != 0 && !std::isalpha(name[0])) {
+      name = "%v" + name;
+    } else {
+      name = "%" + name;
+    }
+    TextValue val(GetUniqueName(name));
     CHECK(!memo_.count(var));
     memo_[var] = val;
     return val;
...
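The AllocVarName change guarantees every printed variable name begins with a letter after the % sigil. A Python rendering of the same rule, purely for illustration (the real logic lives in the C++ TextPrinter):

def alloc_var_name(name_hint):
    # names that do not start with a letter get a "%v" prefix,
    # so a hint of "1" prints as "%v1" instead of "%1"
    if name_hint and not name_hint[0].isalpha():
        return "%v" + name_hint
    return "%" + name_hint

assert alloc_var_name("x") == "%x"
assert alloc_var_name("1") == "%v1"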
@@ -36,30 +36,30 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
     p->stream << "TensorType(" << node->shape << ", " << node->dtype << ")";
   });
 
-TypeParam TypeParamNode::make(std::string name, TypeParamNode::Kind kind) {
-  NodePtr<TypeParamNode> n = make_node<TypeParamNode>();
+TypeVar TypeVarNode::make(std::string name, TypeVarNode::Kind kind) {
+  NodePtr<TypeVarNode> n = make_node<TypeVarNode>();
   n->var = tvm::Var(name);
   n->kind = std::move(kind);
-  return TypeParam(n);
+  return TypeVar(n);
 }
 
-TVM_REGISTER_NODE_TYPE(TypeParamNode);
+TVM_REGISTER_NODE_TYPE(TypeVarNode);
 
-TVM_REGISTER_API("relay._make.TypeParam")
+TVM_REGISTER_API("relay._make.TypeVar")
 .set_body([](TVMArgs args, TVMRetValue *ret) {
     int kind = args[1];
     *ret =
-        TypeParamNode::make(args[0], static_cast<TypeParamNode::Kind>(kind));
+        TypeVarNode::make(args[0], static_cast<TypeVarNode::Kind>(kind));
   });
 
 TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
-.set_dispatch<TypeParamNode>([](const TypeParamNode *node,
-                                tvm::IRPrinter *p) {
-    p->stream << "TypeParamNode(" << node->var->name_hint << ", "
+.set_dispatch<TypeVarNode>([](const TypeVarNode *node,
+                              tvm::IRPrinter *p) {
+    p->stream << "TypeVarNode(" << node->var->name_hint << ", "
               << node->kind << ")";
   });
 
-IncompleteType IncompleteTypeNode::make(TypeParamNode::Kind kind) {
+IncompleteType IncompleteTypeNode::make(TypeVarNode::Kind kind) {
   auto n = make_node<IncompleteTypeNode>();
   n->kind = std::move(kind);
   return IncompleteType(n);
@@ -70,7 +70,7 @@ TVM_REGISTER_NODE_TYPE(IncompleteTypeNode);
 TVM_REGISTER_API("relay._make.IncompleteType")
 .set_body([](TVMArgs args, TVMRetValue* ret) {
     int kind = args[0];
-    *ret = IncompleteTypeNode::make(static_cast<TypeParamNode::Kind>(kind));
+    *ret = IncompleteTypeNode::make(static_cast<TypeVarNode::Kind>(kind));
   });
 
 TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
@@ -82,7 +82,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
 FuncType FuncTypeNode::make(tvm::Array<Type> arg_types,
                             Type ret_type,
-                            tvm::Array<TypeParam> type_params,
+                            tvm::Array<TypeVar> type_params,
                             tvm::Array<TypeConstraint> type_constraints) {
   NodePtr<FuncTypeNode> n = make_node<FuncTypeNode>();
   n->arg_types = std::move(arg_types);
...
@@ -78,6 +78,7 @@ RELAY_REGISTER_OP("image.resize")
 for layout NHWC
 (batch_size, size[0], size[1], channels)
 )code" TVM_ADD_FILELINE)
+.set_attrs_type_key("relay.attrs.ResizeAttrs")
 .set_num_inputs(1)
 .add_argument("data", "Tensor", "The input tensor.")
 .set_support_level(5)
...
@@ -247,6 +247,8 @@ RELAY_REGISTER_UNARY_OP("relay.op.nn._make.", "relu")
 // Positional relay function to create LRN operator used by frontend FFI.
+TVM_REGISTER_NODE_TYPE(LRNAttrs);
+
 Expr MakeLRN(Expr data,
              IndexExpr size,
              IndexExpr axis,
@@ -290,6 +292,8 @@ centered at that value (zero padding is added where necessary).
 // Positional relay function to create L2Normalize operator used by frontend FFI.
+TVM_REGISTER_NODE_TYPE(L2NormalizeAttrs);
+
 Expr MakeL2Normalize(Expr data,
                      double eps,
                      Array<IndexExpr> axis) {
@@ -315,6 +319,7 @@ Normalizes along dimension axis using an L2 norm
 - **data**: The input tensor.
 )code" TVM_ADD_FILELINE)
+.set_attrs_type_key("relay.attrs.L2NormalizeAttrs")
 .set_num_inputs(1)
 .add_argument("data", "Tensor", "The input tensor.")
 .set_support_level(2)
...
@@ -77,6 +77,7 @@ RELAY_REGISTER_OP("nn.pad")
 .describe(R"code(Pad for n-D tensor.
 
 )code" TVM_ADD_FILELINE)
+.set_attrs_type_key("relay.attrs.PadAttrs")
 .set_num_inputs(1)
 .add_argument("data", "Tensor", "The input tensor.")
 .set_support_level(2)
...
@@ -12,6 +12,7 @@ namespace tvm {
 namespace relay {
 
 TVM_REGISTER_NODE_TYPE(MaxPool2DAttrs);
+TVM_REGISTER_NODE_TYPE(AvgPool2DAttrs);
 
 template <typename AttrTtype>
 bool Pool2DRel(const Array<Type>& types,
@@ -115,6 +116,7 @@ RELAY_REGISTER_OP("nn.max_pool2d")
 equation.
 
 )code" TVM_ADD_FILELINE)
+.set_attrs_type_key("relay.attrs.MaxPool2DAttrs")
 .set_num_inputs(1)
 .add_argument("data", "Tensor", "The input tensor.")
 .set_support_level(2)
@@ -169,6 +171,7 @@ Average pooling operation for one dimensional data.
 equation.
 
 )code" TVM_ADD_FILELINE)
+.set_attrs_type_key("relay.attrs.AvgPool2DAttrs")
 .set_num_inputs(1)
 .add_argument("data", "Tensor", "The input tensor.")
 .set_support_level(2)
@@ -232,6 +235,7 @@ RELAY_REGISTER_OP("nn.global_avg_pool2d")
 (batch_size, channels, 1, 1) if `layout` is `NCHW`.
 
 )code" TVM_ADD_FILELINE)
+.set_attrs_type_key("relay.attrs.GlobalPool2DAttrs")
 .set_num_inputs(1)
 .add_argument("data", "Tensor", "The input tensor.")
 .set_support_level(2)
@@ -261,6 +265,7 @@ RELAY_REGISTER_OP("nn.global_max_pool2d")
 (batch_size, channels, 1, 1) if `layout` is `NCHW`.
 
 )code" TVM_ADD_FILELINE)
+.set_attrs_type_key("relay.attrs.GlobalPool2DAttrs")
 .set_num_inputs(1)
 .add_argument("data", "Tensor", "The input tensor.")
 .set_support_level(2)
...
@@ -78,6 +78,7 @@ RELAY_REGISTER_OP("nn.upsampling")
 (batch_size, in_height*scale, in_width*scale, channels)
 
 )code" TVM_ADD_FILELINE)
+.set_attrs_type_key("relay.attrs.UpSamplingAttrs")
 .set_num_inputs(1)
 .add_argument("data", "Tensor", "The input tensor.")
 .set_support_level(2)
...
@@ -199,7 +199,7 @@ RELAY_REGISTER_REDUCE_OP("argmax")
 values over a given axis.
 
 )code" TVM_ADD_FILELINE)
-.set_num_inputs(1)
+.set_attrs_type_key("relay.attrs.ReduceAttrs")
 .set_support_level(4)
 .add_type_rel("ArgReduce", ArgReduceRel);
 
@@ -209,7 +209,7 @@ RELAY_REGISTER_REDUCE_OP("argmin")
 values over a given axis.
 
 )code" TVM_ADD_FILELINE)
-.set_num_inputs(1)
+.set_attrs_type_key("relay.attrs.ReduceAttrs")
 .set_support_level(4)
 .add_type_rel("ArgReduce", ArgReduceRel);
...
...@@ -144,12 +144,14 @@ RELAY_REGISTER_OP("concatenate") ...@@ -144,12 +144,14 @@ RELAY_REGISTER_OP("concatenate")
- **axis** : The axis along which the tensors are concatenated. - **axis** : The axis along which the tensors are concatenated.
)code" TVM_ADD_FILELINE) )code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.ConcatenateAttrs")
.set_num_inputs(1) .set_num_inputs(1)
.add_argument("data", "Tensor", "The input list of tensors.") .add_argument("data", "Tensor", "The input list of tensors.")
.set_support_level(1) .set_support_level(1)
.add_type_rel("Concatenate", ConcatenateRel); .add_type_rel("Concatenate", ConcatenateRel);
/* relay.transpose */ /* relay.transpose */
TVM_REGISTER_NODE_TYPE(TransposeAttrs);
bool TransposeRel(const Array<Type>& types, bool TransposeRel(const Array<Type>& types,
int num_inputs, int num_inputs,
...@@ -224,12 +226,15 @@ RELAY_REGISTER_OP("transpose") ...@@ -224,12 +226,15 @@ RELAY_REGISTER_OP("transpose")
)code" TVM_ADD_FILELINE) )code" TVM_ADD_FILELINE)
.set_num_inputs(1) .set_num_inputs(1)
.set_attrs_type_key("relay.attrs.TransposeAttrs")
.add_argument("data", "Tensor", "The input tensor.") .add_argument("data", "Tensor", "The input tensor.")
.set_support_level(3) .set_support_level(3)
.add_type_rel("Transpose", TransposeRel); .add_type_rel("Transpose", TransposeRel);
/* relay.reshape */ /* relay.reshape */
TVM_REGISTER_NODE_TYPE(ReshapeAttrs);
bool ReshapeRel(const Array<Type>& types, bool ReshapeRel(const Array<Type>& types,
int num_inputs, int num_inputs,
const Attrs& attrs, const Attrs& attrs,
...@@ -310,6 +315,7 @@ Example:: ...@@ -310,6 +315,7 @@ Example::
)code" TVM_ADD_FILELINE) )code" TVM_ADD_FILELINE)
.set_num_inputs(1) .set_num_inputs(1)
.set_attrs_type_key("relay.attrs.ReshapeAttrs")
.add_argument("data", "Tensor", "The input tensor.") .add_argument("data", "Tensor", "The input tensor.")
.set_support_level(3) .set_support_level(3)
.add_type_rel("Reshape", ReshapeRel); .add_type_rel("Reshape", ReshapeRel);
...@@ -397,12 +403,14 @@ Examples:: ...@@ -397,12 +403,14 @@ Examples::
[ 4., 3.]] [ 4., 3.]]
)code" TVM_ADD_FILELINE) )code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.TakeAttrs")
.set_num_inputs(2) .set_num_inputs(2)
.add_argument("data", "Tensor", "The input tensor.") .add_argument("data", "Tensor", "The input tensor.")
.add_argument("indices", "Tensor", "The indices tensor.") .add_argument("indices", "Tensor", "The indices tensor.")
.set_support_level(2) .set_support_level(2)
.add_type_rel("Take", TakeRel); .add_type_rel("Take", TakeRel);
// Init ops
TVM_REGISTER_NODE_TYPE(InitOpAttrs); TVM_REGISTER_NODE_TYPE(InitOpAttrs);
bool FullRel(const Array<Type>& types, bool FullRel(const Array<Type>& types,
...@@ -448,6 +456,7 @@ RELAY_REGISTER_OP("full") ...@@ -448,6 +456,7 @@ RELAY_REGISTER_OP("full")
.describe(R"code(Fill array with scalar value. .describe(R"code(Fill array with scalar value.
)code" TVM_ADD_FILELINE) )code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.InitOpAttrs")
.set_num_inputs(1) .set_num_inputs(1)
.add_argument("fill_value", "double", "The value to fill.") .add_argument("fill_value", "double", "The value to fill.")
.set_support_level(3) .set_support_level(3)
...@@ -634,6 +643,10 @@ Examples:: ...@@ -634,6 +643,10 @@ Examples::
.set_support_level(4) .set_support_level(4)
.add_type_rel("Where", WhereRel); .add_type_rel("Where", WhereRel);
// Squeeze
TVM_REGISTER_NODE_TYPE(SqueezeAttrs);
Expr MakeSqueeze(Expr data, Expr MakeSqueeze(Expr data,
Array<IndexExpr> axes) { Array<IndexExpr> axes) {
auto attrs = make_node<SqueezeAttrs>(); auto attrs = make_node<SqueezeAttrs>();
......
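A sketch of squeeze under the new SqueezeAttrs registration; the axes keyword is assumed from MakeSqueeze's signature above:

from tvm import relay

x = relay.var("x", relay.TensorType((1, 5, 1, 4), "float32"))
z = relay.squeeze(x, axes=[0, 2])        # drop the two unit dimensions
zz = relay.ir_pass.infer_type(z)
assert zz.checked_type == relay.TensorType((5, 4), "float32")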
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
#include <tvm/relay/expr.h> #include <tvm/relay/expr.h>
#include <tvm/relay/logging.h> #include <tvm/relay/logging.h>
#include <tvm/relay/op.h> #include <tvm/relay/op.h>
#include <tvm/ir_pass.h>
#include <numeric> #include <numeric>
#include "./type_relations.h" #include "./type_relations.h"
...@@ -21,14 +22,6 @@ TensorType ToTensorType(const Type& t) { ...@@ -21,14 +22,6 @@ TensorType ToTensorType(const Type& t) {
} }
} }
// TODO(@jroesch) what size value do we extract, 64bit or 32bit?
int ToInt(const tvm::Expr& e) {
CHECK(e.defined());
auto imm = e.as<tvm::ir::IntImm>();
CHECK(imm) << "TYPE: " << imm << imm->type << std::endl;
return imm->value;
}
bool IdentityRel(const Array<Type>& types, bool IdentityRel(const Array<Type>& types,
int num_inputs, int num_inputs,
const Attrs& attrs, const Attrs& attrs,
...@@ -39,72 +32,54 @@ bool IdentityRel(const Array<Type>& types, ...@@ -39,72 +32,54 @@ bool IdentityRel(const Array<Type>& types,
return true; return true;
} }
Type ConcreteBroadcast(const TensorType& t1, bool EqualCheck(const IndexExpr& lhs,
const TensorType& t2, const IndexExpr& rhs) {
DataType output_dtype) { IndexExpr diff = lhs - rhs;
RELAY_LOG(INFO) << "ConcreteBroadcast: t1=" << t1 << " t2=" << t2 if (const int64_t* pdiff = as_const_int(diff)) {
<< std::endl; return pdiff[0] == 0;
auto sh1 = t1->shape;
auto sh2 = t2->shape;
RELAY_LOG(INFO) << "ConcreteBroadcast: sh1=" << sh1 << " sh2=" << sh2
<< std::endl;
if (sh1.size() == 0 && sh2.size() == 0) {
return TensorTypeNode::make({}, output_dtype);
// We have non-zero shapes so broadcast rules apply.
} else {
auto suffix_len = static_cast<int>(std::min(sh1.size(), sh2.size()));
auto full_len = static_cast<int>(std::max(sh1.size(), sh2.size()));
auto rev_sh1 = sh1.rbegin();
auto rev_sh2 = sh2.rbegin();
while (rev_sh1 != sh1.rend() && rev_sh2 != sh2.rend()) {
auto dim1 = ToInt(*rev_sh1);
auto dim2 = ToInt(*rev_sh2);
if ((dim1 != dim2) && ((dim1 != 1) && (dim2 != 1))) {
CHECK(false) << "Dimension mistmatch "
<< "dim1: " << dim1 << " dim2: " << dim2 << std::endl;
} }
rev_sh1++; // symbolic
rev_sh2++; diff = tvm::ir::CanonicalSimplify(diff);
if (const int64_t* pdiff = as_const_int(diff)) {
return pdiff[0] == 0;
} }
return false;
}
Array<IndexExpr> larger; bool EqualConstInt(const IndexExpr& lhs, int64_t value) {
Array<IndexExpr> smaller; if (const int64_t* pvalue = as_const_int(lhs)) {
return pvalue[0] == value;
for (int i = 0; i < (full_len - suffix_len); i++) {
smaller.push_back(make_const(tvm::Int(64), 1));
} }
return false;
}
if (sh1.size() < sh2.size()) { Type ConcreteBroadcast(const TensorType& t1,
for (auto sh : sh1) { const TensorType& t2,
smaller.push_back(sh); DataType output_dtype) {
} std::vector<IndexExpr> oshape;
larger = sh2; size_t ndim1 = t1->shape.size();
} else if (sh1.size() > sh2.size()) { size_t ndim2 = t2->shape.size();
for (auto sh : sh1) { size_t i = 1;
larger.push_back(sh); for (; i <= std::min(ndim1, ndim2); ++i) {
} IndexExpr s1 = t1->shape[ndim1 - i];
smaller = sh2; IndexExpr s2 = t2->shape[ndim2 - i];
if (EqualCheck(s1, s2)) {
oshape.push_back(s1);
} else if (EqualConstInt(s1, 1)) {
oshape.push_back(s2);
} else if (EqualConstInt(s2, 1)) {
oshape.push_back(s1);
} else { } else {
larger = sh1; LOG(FATAL) << "Incompatible broadcast type " << t1 << " and " << t2;
smaller = sh2;
} }
CHECK_EQ(larger.size(), smaller.size());
Array<IndexExpr> out_shape;
for (size_t i = 0; i < smaller.size(); i++) {
auto left = smaller[i].as<tvm::ir::IntImm>();
auto right = larger[i].as<tvm::ir::IntImm>();
CHECK(left);
CHECK(right);
int64_t dim = std::max(left->value, right->value);
out_shape.push_back(make_const(tvm::Int(64), dim));
} }
size_t max_ndim = std::max(ndim1, ndim2);
return TensorTypeNode::make(out_shape, output_dtype); auto& rshape = (ndim1 > ndim2) ? t1->shape : t2->shape;
for (; i <= max_ndim; ++i) {
oshape.push_back(rshape[max_ndim - i]);
} }
return TensorTypeNode::make(Array<IndexExpr>(
oshape.rbegin(), oshape.rend()), output_dtype);
} }
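The rewritten ConcreteBroadcast walks both shapes from the right: symbolically equal dimensions (checked via CanonicalSimplify on their difference) are kept, dimensions equal to 1 broadcast against the other side, and the leading prefix of the longer shape is copied through. A minimal sketch with illustrative shapes, including a symbolic dimension that exercises EqualCheck:

import tvm
from tvm import relay

n = tvm.var("n")
x = relay.var("x", relay.TensorType((n, 4), "float32"))
y = relay.var("y", relay.TensorType((5, n, 1), "float32"))
z = relay.add(x, y)                      # (n, 4) + (5, n, 1) -> (5, n, 4)
zz = relay.ir_pass.infer_type(z)
assert zz.checked_type == relay.TensorType((5, n, 4), "float32")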
bool BroadcastRel(const Array<Type>& types, bool BroadcastRel(const Array<Type>& types,
...@@ -141,71 +116,5 @@ bool BroadcastCompRel(const Array<Type>& types, ...@@ -141,71 +116,5 @@ bool BroadcastCompRel(const Array<Type>& types,
return false; return false;
} }
/*! \brief Handle concrete concat case from known input to output. */
inline Type ConcreteConcatRel(const Type& input_type) {
if (auto tuple_node = input_type.as<TupleTypeNode>()) {
// NB: For now the axis argument is hardwired to be 0.
std::vector<int> dims;
DataType dtype;
CHECK_LT(1, tuple_node->fields.size());
bool skip_first = true;
// Collect the suffix dimensions since axis is zero.
// TODO(@jroesch): This is a demonstration of how
// to do varargs. It requires a little more work to
// fully type the behavior of concat.
auto first = Downcast<TensorType>(tuple_node->fields[0]);
dtype = first->dtype;
for (auto dim_expr : first->shape) {
if (!skip_first) {
dims.push_back(ToInt(dim_expr));
} else {
skip_first = false;
}
}
std::vector<int> axis_dims;
for (auto field_ty : tuple_node->fields) {
auto ttype = Downcast<TensorType>(field_ty);
for (size_t i = 0; i < ttype->shape.size(); i++) {
if (i != 0) {
CHECK_EQ(ToInt(dims[i - 1]), ToInt(ttype->shape[i]));
} else {
axis_dims.push_back(ToInt(ttype->shape[i]));
}
}
}
auto out_axis_dim = std::accumulate(axis_dims.begin(), axis_dims.end(), 0);
Array<tvm::Expr> out_shape = { make_const(Int(64), out_axis_dim) };
for (auto dim : dims) {
out_shape.push_back(make_const(Int(64), dim));
}
return TensorTypeNode::make(out_shape, dtype);
} else {
throw TypeRelationError("concat can only be used with a tuple as its argument");
}
}
bool ConcatRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 2);
if (types[0].as<TupleTypeNode>()) {
reporter->Assign(types[1], ConcreteConcatRel(types[0]));
return true;
}
return false;
}
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
...@@ -13,17 +13,6 @@ ...@@ -13,17 +13,6 @@
namespace tvm { namespace tvm {
namespace relay { namespace relay {
/*! \brief The error raised by a type relation.
*
* This error is how a type relation signals that it has failed.
*
*/
struct TypeRelationError : Error {
explicit TypeRelationError(const std::string& msg)
: Error(msg) {}
};
/*! /*!
* \brief The identity type relation, all the types are equal. * \brief The identity type relation, all the types are equal.
* *
...@@ -72,22 +61,6 @@ bool BroadcastCompRel(const Array<Type>& types, ...@@ -72,22 +61,6 @@ bool BroadcastCompRel(const Array<Type>& types,
const Attrs& attrs, const Attrs& attrs,
const TypeReporter& reporter); const TypeReporter& reporter);
/*!
* \brief The concat type relation, implements the concatenating
* rule over the list of input types producing one concatenated
* type.
*
* \param types The input and output types to the relation.
* \param num_inputs The number of input arguments.
* \param attrs The attributes
* \param reporter The reporter.
* \return true whether relation has been resolved.
*/
bool ConcatRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter);
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
......
...@@ -63,6 +63,7 @@ TVM_REGISTER_API("relay.op.vision._make.multibox_prior") ...@@ -63,6 +63,7 @@ TVM_REGISTER_API("relay.op.vision._make.multibox_prior")
RELAY_REGISTER_OP("vision.multibox_prior") RELAY_REGISTER_OP("vision.multibox_prior")
.describe(R"doc("Generate prior(anchor) boxes from data, sizes and ratios." .describe(R"doc("Generate prior(anchor) boxes from data, sizes and ratios."
)doc" TVM_ADD_FILELINE) )doc" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.MultiBoxPriorAttrs")
.set_num_inputs(1) .set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor.") .add_argument("data", "Tensor", "The input tensor.")
.set_support_level(4) .set_support_level(4)
......
...@@ -34,7 +34,7 @@ bool SameNDArray(const NDArray& lhs, const NDArray& rhs) { ...@@ -34,7 +34,7 @@ bool SameNDArray(const NDArray& lhs, const NDArray& rhs) {
} }
struct TypeAlphaEq : TypeVisitor<const Type&> { struct TypeAlphaEq : TypeVisitor<const Type&> {
tvm::Map<TypeParam, TypeParam> eq_map; tvm::Map<TypeVar, TypeVar> eq_map;
bool equal; bool equal;
TypeAlphaEq() : eq_map(), equal(true) {} TypeAlphaEq() : eq_map(), equal(true) {}
...@@ -76,10 +76,10 @@ struct TypeAlphaEq : TypeVisitor<const Type&> { ...@@ -76,10 +76,10 @@ struct TypeAlphaEq : TypeVisitor<const Type&> {
} }
} }
void VisitType_(const TypeParamNode* ti1, const Type& t2) final { void VisitType_(const TypeVarNode* ti1, const Type& t2) final {
if (const TypeParamNode* ti2 = t2.as<TypeParamNode>()) { if (const TypeVarNode* ti2 = t2.as<TypeVarNode>()) {
auto tid1 = GetRef<TypeParam>(ti1); auto tid1 = GetRef<TypeVar>(ti1);
auto tid2 = GetRef<TypeParam>(ti2); auto tid2 = GetRef<TypeVar>(ti2);
// We handle open terms with this rule assuming variables are identical. // We handle open terms with this rule assuming variables are identical.
// //
......
...@@ -20,7 +20,9 @@ bool IsBoolLit(const Expr& e, bool b) { ...@@ -20,7 +20,9 @@ bool IsBoolLit(const Expr& e, bool b) {
if (const ConstantNode* c = e.as<ConstantNode>()) { if (const ConstantNode* c = e.as<ConstantNode>()) {
if (c->is_scalar()) { if (c->is_scalar()) {
auto dt = c->tensor_type()->dtype; auto dt = c->tensor_type()->dtype;
if (dt == UInt(8)) { if (dt == Bool()) {
return *reinterpret_cast<const uint8_t*>(c->data->data) == b;
} else if (dt == UInt(8)) {
return *reinterpret_cast<const uint8_t*>(c->data->data) == b; return *reinterpret_cast<const uint8_t*>(c->data->data) == b;
} else if (dt == UInt(16)) { } else if (dt == UInt(16)) {
return *reinterpret_cast<const uint16_t*>(c->data->data) == b; return *reinterpret_cast<const uint16_t*>(c->data->data) == b;
......
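The extra dt == Bool() branch matters because relay.const(True) yields a boolean tensor rather than an explicit uint8, so the old check missed it. A sketch of the folding this enables, mirroring test_if further down in this diff:

from tvm import relay
from tvm.relay.ir_pass import dead_code_elimination, alpha_equal

a = relay.var("a", "float32")
b = relay.var("b", "float32")
y = dead_code_elimination(relay.If(relay.const(True), a, b))
assert alpha_equal(y, a)                 # the constant-true branch is selected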
...@@ -20,7 +20,7 @@ namespace tvm { ...@@ -20,7 +20,7 @@ namespace tvm {
namespace relay { namespace relay {
using namespace tvm::runtime; using namespace tvm::runtime;
using Kind = TypeParamNode::Kind; using Kind = TypeVarNode::Kind;
struct KindChecker : TypeVisitor<> { struct KindChecker : TypeVisitor<> {
bool valid; bool valid;
...@@ -33,7 +33,7 @@ struct KindChecker : TypeVisitor<> { ...@@ -33,7 +33,7 @@ struct KindChecker : TypeVisitor<> {
return tv->kind == k; return tv->kind == k;
} }
if (const TypeParamNode *tp = t.as<TypeParamNode>()) { if (const TypeVarNode *tp = t.as<TypeVarNode>()) {
return tp->kind == k; return tp->kind == k;
} }
......
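With TypeParam renamed to TypeVar, the kind checker now dispatches on TypeVarNode. A sketch using the same constructions as the kind tests below, assuming check_kind returns a boolean validity flag:

import tvm
from tvm import relay
from tvm.relay.ir_pass import check_kind

tp = relay.TypeVar('tp', relay.Kind.Type)
tt = relay.TensorType(tvm.convert([1, 2, 3]), 'float32')
assert check_kind(relay.TupleType(tvm.convert([tp, tt])))   # all fields are type-kinded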
...@@ -61,7 +61,7 @@ class LetList { ...@@ -61,7 +61,7 @@ class LetList {
* \return a Var that hold the inserted expr. * \return a Var that hold the inserted expr.
*/ */
Var Push(Expr expr) { Var Push(Expr expr) {
return Push(IncompleteTypeNode::make(TypeParamNode::kType), expr); return Push(IncompleteTypeNode::make(TypeVarNode::kType), expr);
} }
/*! /*!
......
...@@ -61,7 +61,7 @@ class TypeFunctor<R(const Type& n, Args...)> { ...@@ -61,7 +61,7 @@ class TypeFunctor<R(const Type& n, Args...)> {
// Functions that can be overridden by subclass // Functions that can be overridden by subclass
virtual R VisitType_(const TensorTypeNode* op, virtual R VisitType_(const TensorTypeNode* op,
Args... args) TYPE_FUNCTOR_DEFAULT; Args... args) TYPE_FUNCTOR_DEFAULT;
virtual R VisitType_(const TypeParamNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; virtual R VisitType_(const TypeVarNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
virtual R VisitType_(const TypeConstraintNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; virtual R VisitType_(const TypeConstraintNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
virtual R VisitType_(const FuncTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; virtual R VisitType_(const FuncTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
virtual R VisitType_(const TypeRelationNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; virtual R VisitType_(const TypeRelationNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
...@@ -79,7 +79,7 @@ class TypeFunctor<R(const Type& n, Args...)> { ...@@ -79,7 +79,7 @@ class TypeFunctor<R(const Type& n, Args...)> {
FType vtable; FType vtable;
// Set dispatch // Set dispatch
RELAY_TYPE_FUNCTOR_DISPATCH(TensorTypeNode); RELAY_TYPE_FUNCTOR_DISPATCH(TensorTypeNode);
RELAY_TYPE_FUNCTOR_DISPATCH(TypeParamNode); RELAY_TYPE_FUNCTOR_DISPATCH(TypeVarNode);
RELAY_TYPE_FUNCTOR_DISPATCH(TypeConstraintNode); RELAY_TYPE_FUNCTOR_DISPATCH(TypeConstraintNode);
RELAY_TYPE_FUNCTOR_DISPATCH(FuncTypeNode); RELAY_TYPE_FUNCTOR_DISPATCH(FuncTypeNode);
RELAY_TYPE_FUNCTOR_DISPATCH(TypeRelationNode); RELAY_TYPE_FUNCTOR_DISPATCH(TypeRelationNode);
......
...@@ -28,6 +28,39 @@ ...@@ -28,6 +28,39 @@
namespace tvm { namespace tvm {
namespace relay { namespace relay {
// Necessary deferred relation for TupleGetItem
struct TupleGetItemAttrs : public tvm::AttrsNode<TupleGetItemAttrs> {
int index;
TVM_DECLARE_ATTRS(TupleGetItemAttrs, "relay.attrs.TupleGetItemAttrs") {
TVM_ATTR_FIELD(index);
}
};
bool TupleGetItemRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 2);
if (types[0].as<IncompleteTypeNode>()) return false;
const auto* data = types[0].as<TupleTypeNode>();
CHECK(data != nullptr)
<< "TupleGetItem expect input type to be TupleType "
<< " get " << types[0] << " instead";
const auto* param = attrs.as<TupleGetItemAttrs>();
CHECK(param != nullptr);
CHECK_GE(param->index, 0);
CHECK_LT(param->index, data->fields.size());
reporter->Assign(types[1], data->fields[param->index]);
return true;
}
TVM_REGISTER_NODE_TYPE(TupleGetItemAttrs);
TVM_REGISTER_API("tvm.relay.type_relation.TupleGetItem")
.set_body_typed<bool(const Array<Type>&, int, const Attrs&, const TypeReporter&)>(
TupleGetItemRel);
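TupleGetItem's result type is now deferred: the inferencer allocates an IncompleteType, records a TypeRelation carrying the index, and lets the solver fill in the field type once the tuple type is known. A minimal sketch of the user-visible effect, with illustrative element types:

from tvm import relay

x = relay.var("x", relay.TensorType((10,), "float32"))
y = relay.var("y", relay.TensorType([], "int32"))
g = relay.TupleGetItem(relay.Tuple([x, y]), 1)
gg = relay.ir_pass.infer_type(g)
assert gg.checked_type == relay.TensorType([], "int32")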
// //
// The inference algorithm can roughly be divided into three stages: // The inference algorithm can roughly be divided into three stages:
// - Populate the constraints by visiting the expression (TypeInferencer.GetType) // - Populate the constraints by visiting the expression (TypeInferencer.GetType)
...@@ -38,8 +71,7 @@ namespace relay { ...@@ -38,8 +71,7 @@ namespace relay {
class TypeInferencer : private ExprFunctor<Type(const Expr&)> { class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
public: public:
// constructors // constructors
TypeInferencer() TypeInferencer() {
: env_(EnvironmentNode::make({})) {
} }
explicit TypeInferencer(Environment env) explicit TypeInferencer(Environment env)
: env_(env) { : env_(env) {
...@@ -58,6 +90,8 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> { ...@@ -58,6 +90,8 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
std::unordered_map<Expr, Type, NodeHash, NodeEqual> type_map_; std::unordered_map<Expr, Type, NodeHash, NodeEqual> type_map_;
// The solver used by the inferencer. // The solver used by the inferencer.
TypeSolver solver_; TypeSolver solver_;
// relation function
TypeRelationFn tuple_getitem_rel_;
// Unify two types // Unify two types
Type Unify(const Type& t1, const Type& t2, const Span& span) { Type Unify(const Type& t1, const Type& t2, const Span& span) {
// TODO(tqchen, jroesch): propagate span to solver // TODO(tqchen, jroesch): propagate span to solver
...@@ -90,12 +124,14 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> { ...@@ -90,12 +124,14 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
if (op->type_annotation.defined()) { if (op->type_annotation.defined()) {
return op->type_annotation; return op->type_annotation;
} else { } else {
return IncompleteTypeNode::make(TypeParamNode::kType); return IncompleteTypeNode::make(TypeVarNode::kType);
} }
} }
Type VisitExpr_(const GlobalVarNode* op) final { Type VisitExpr_(const GlobalVarNode* op) final {
GlobalVar var = GetRef<GlobalVar>(op); GlobalVar var = GetRef<GlobalVar>(op);
CHECK(env_.defined())
<< "Cannot do type inference without a global variable";
Expr e = env_->Lookup(var); Expr e = env_->Lookup(var);
return e->checked_type(); return e->checked_type();
} }
...@@ -116,17 +152,17 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> { ...@@ -116,17 +152,17 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
} }
Type VisitExpr_(const TupleGetItemNode* op) final { Type VisitExpr_(const TupleGetItemNode* op) final {
// TODO(M.K.) if (!tuple_getitem_rel_.defined()) {
// handle case where field type is not known tuple_getitem_rel_ = TypeRelationFn(
Type tuple_type = GetType(op->tuple); EnvFunc::Get("tvm.relay.type_relation.TupleGetItem").node_);
auto tuple_ty_node = tuple_type.as<TupleTypeNode>();
if (!tuple_ty_node) {
LOG(FATAL) << "only expressions with tuple types is accepted" << GetRef<TupleGetItem>(op);
}
if (static_cast<int>(tuple_ty_node->fields.size()) <= op->index) {
LOG(FATAL) << "tuple not big enough" << GetRef<TupleGetItem>(op);
} }
return tuple_ty_node->fields[op->index]; Type tuple_type = GetType(op->tuple);
Type rtype = IncompleteTypeNode::make(TypeVarNode::Kind::kType);
auto attrs = make_node<TupleGetItemAttrs>();
attrs->index = op->index;
solver_.AddConstraint(TypeRelationNode::make(
tuple_getitem_rel_, {tuple_type, rtype}, 1, Attrs(attrs)));
return rtype;
} }
Type VisitExpr_(const OpNode* op) final { Type VisitExpr_(const OpNode* op) final {
...@@ -169,7 +205,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> { ...@@ -169,7 +205,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
for (size_t i = 0; i < op->type_params.size(); ++i) { for (size_t i = 0; i < op->type_params.size(); ++i) {
if (!op->type_params[i].same_as(rel->args[i])) return Type(); if (!op->type_params[i].same_as(rel->args[i])) return Type();
} }
Type rtype = IncompleteTypeNode::make(TypeParamNode::Kind::kType); Type rtype = IncompleteTypeNode::make(TypeVarNode::Kind::kType);
arg_types.push_back(rtype); arg_types.push_back(rtype);
// we can do simple replacement here // we can do simple replacement here
solver_.AddConstraint(TypeRelationNode::make( solver_.AddConstraint(TypeRelationNode::make(
...@@ -179,7 +215,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> { ...@@ -179,7 +215,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
// instantiate the function type with fresh // instantiate the function type with fresh
FuncType Instantiate(const FuncTypeNode* fn_ty, Array<Type>* ty_args) { FuncType Instantiate(const FuncTypeNode* fn_ty, Array<Type>* ty_args) {
tvm::Map<TypeParam, Type> subst_map; tvm::Map<TypeVar, Type> subst_map;
// Build a substitution map up from the function type and type arguments. // Build a substitution map up from the function type and type arguments.
// Eventually allow the type vars to be passed in. // Eventually allow the type vars to be passed in.
...@@ -196,7 +232,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> { ...@@ -196,7 +232,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
// This is a temporary work around to check recursive functions whose // This is a temporary work around to check recursive functions whose
// return type is not yet known. // return type is not yet known.
if (!ret_type.defined()) { if (!ret_type.defined()) {
ret_type = IncompleteTypeNode::make(TypeParamNode::Kind::kType); ret_type = IncompleteTypeNode::make(TypeVarNode::Kind::kType);
} }
Type inst_ty = FuncTypeNode::make(fn_ty->arg_types, Type inst_ty = FuncTypeNode::make(fn_ty->arg_types,
ret_type, {}, ret_type, {},
...@@ -305,7 +341,6 @@ class TypeInferencer::Resolver : public ExprMutator { ...@@ -305,7 +341,6 @@ class TypeInferencer::Resolver : public ExprMutator {
return AttachCheckedType(op); return AttachCheckedType(op);
} }
Expr VisitExpr_(const FunctionNode* op) final { Expr VisitExpr_(const FunctionNode* op) final {
return AttachCheckedType(op); return AttachCheckedType(op);
} }
...@@ -363,20 +398,21 @@ Expr TypeInferencer::Infer(Expr expr) { ...@@ -363,20 +398,21 @@ Expr TypeInferencer::Infer(Expr expr) {
return Resolver(type_map_, &solver_).VisitExpr(expr); return Resolver(type_map_, &solver_).VisitExpr(expr);
} }
Expr InferType(const Environment& env, const Expr& expr) {
Expr InferType(const Expr& expr, const Environment& env) {
return TypeInferencer(env).Infer(expr); return TypeInferencer(env).Infer(expr);
} }
Expr InferType(const Environment& env, Function InferType(const Function& func,
const GlobalVar& var, const Environment& env,
const Function& func) { const GlobalVar& var) {
Function func_copy = Function(make_node<FunctionNode>(*func.operator->())); Function func_copy = Function(make_node<FunctionNode>(*func.operator->()));
func_copy->checked_type_ = func_copy->func_type_annotation(); func_copy->checked_type_ = func_copy->func_type_annotation();
env->functions.Set(var, func_copy); env->functions.Set(var, func_copy);
Expr func_ret = TypeInferencer(env).Infer(func_copy); Expr func_ret = TypeInferencer(env).Infer(func_copy);
auto map_node = env->functions.CopyOnWrite(); auto map_node = env->functions.CopyOnWrite();
map_node->data.erase(var.node_); map_node->data.erase(var.node_);
return func_ret; return Downcast<Function>(func_ret);
} }
TVM_REGISTER_API("relay._ir_pass.infer_type") TVM_REGISTER_API("relay._ir_pass.infer_type")
......
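Note the flipped argument order: the expression (or function) now comes first and the Environment second. On the Python side an environment is not required for expressions that reference no globals; a sketch, identical in spirit to test_free_expr below:

from tvm import relay

x = relay.var("x", "float32")
y = relay.add(x, x)
yy = relay.ir_pass.infer_type(y)         # no Environment passed here
assert yy.checked_type == relay.scalar_type("float32")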
...@@ -10,13 +10,13 @@ namespace tvm { ...@@ -10,13 +10,13 @@ namespace tvm {
namespace relay { namespace relay {
struct TypeSubstV : TypeMutator { struct TypeSubstV : TypeMutator {
tvm::Map<TypeParam, Type> subst_map; tvm::Map<TypeVar, Type> subst_map;
explicit TypeSubstV(tvm::Map<TypeParam, Type> subst_map) explicit TypeSubstV(tvm::Map<TypeVar, Type> subst_map)
: subst_map(subst_map) {} : subst_map(subst_map) {}
Type VisitType_(const TypeParamNode* op) override { Type VisitType_(const TypeVarNode* op) override {
auto id = GetRef<TypeParam>(op); auto id = GetRef<TypeVar>(op);
if (subst_map.find(id) != subst_map.end()) { if (subst_map.find(id) != subst_map.end()) {
return this->subst_map[id]; return this->subst_map[id];
} else { } else {
...@@ -25,12 +25,12 @@ struct TypeSubstV : TypeMutator { ...@@ -25,12 +25,12 @@ struct TypeSubstV : TypeMutator {
} }
}; };
Type TypeSubst(const Type& type, const TypeParam& target, const Type& subst) { Type TypeSubst(const Type& type, const TypeVar& target, const Type& subst) {
TypeSubstV ty_sub({ {target, subst} }); TypeSubstV ty_sub({ {target, subst} });
return ty_sub.VisitType(type); return ty_sub.VisitType(type);
} }
Type TypeSubst(const Type& type, tvm::Map<TypeParam, Type> subst_map) { Type TypeSubst(const Type& type, tvm::Map<TypeVar, Type> subst_map) {
TypeSubstV ty_sub(subst_map); TypeSubstV ty_sub(subst_map);
return ty_sub.VisitType(type); return ty_sub.VisitType(type);
} }
......
...@@ -11,8 +11,8 @@ ...@@ -11,8 +11,8 @@
namespace tvm { namespace tvm {
namespace relay { namespace relay {
Type TypeSubst(const Type& type, const TypeParam& target, const Type& subst); Type TypeSubst(const Type& type, const TypeVar& target, const Type& subst);
Type TypeSubst(const Type& type, tvm::Map<TypeParam, Type> subst_map); Type TypeSubst(const Type& type, tvm::Map<TypeVar, Type> subst_map);
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
......
...@@ -19,7 +19,7 @@ namespace relay { ...@@ -19,7 +19,7 @@ namespace relay {
*/ */
template <typename... Args> template <typename... Args>
struct TypeVisitor : ::tvm::relay::TypeFunctor<void(const Type& n, Args...)> { struct TypeVisitor : ::tvm::relay::TypeFunctor<void(const Type& n, Args...)> {
void VisitType_(const TypeParamNode* op, Args... args) override {} void VisitType_(const TypeVarNode* op, Args... args) override {}
void VisitType_(const FuncTypeNode* op, Args... args) override { void VisitType_(const FuncTypeNode* op, Args... args) override {
for (auto type_param : op->type_params) { for (auto type_param : op->type_params) {
...@@ -60,16 +60,16 @@ struct TypeMutator : TypeFunctor<Type(const Type& n)> { ...@@ -60,16 +60,16 @@ struct TypeMutator : TypeFunctor<Type(const Type& n)> {
return TensorTypeNode::make(op->shape, op->dtype); return TensorTypeNode::make(op->shape, op->dtype);
} }
Type VisitType_(const TypeParamNode* op) override { Type VisitType_(const TypeVarNode* op) override {
return GetRef<TypeParam>(op); return GetRef<TypeVar>(op);
} }
Type VisitType_(const FuncTypeNode* op) override { Type VisitType_(const FuncTypeNode* op) override {
Array<TypeParam> type_params; Array<TypeVar> type_params;
for (auto type_param : op->type_params) { for (auto type_param : op->type_params) {
auto new_type_param = VisitType(type_param); auto new_type_param = VisitType(type_param);
if (const TypeParamNode* tin = new_type_param.as<TypeParamNode>()) { if (const TypeVarNode* tin = new_type_param.as<TypeVarNode>()) {
type_params.push_back(GetRef<TypeParam>(tin)); type_params.push_back(GetRef<TypeVar>(tin));
} else { } else {
CHECK(false) << new_type_param << std::endl; CHECK(false) << new_type_param << std::endl;
} }
......
...@@ -14,14 +14,14 @@ namespace relay { ...@@ -14,14 +14,14 @@ namespace relay {
class FreeVar; class FreeVar;
class FreeTypeVar : private TypeVisitor<> { class FreeTypeVar : private TypeVisitor<> {
std::unordered_set<TypeParam, NodeHash, NodeEqual> * free_vars; std::unordered_set<TypeVar, NodeHash, NodeEqual> * free_vars;
std::unordered_set<TypeParam, NodeHash, NodeEqual> * bound_vars; std::unordered_set<TypeVar, NodeHash, NodeEqual> * bound_vars;
FreeTypeVar(std::unordered_set<TypeParam, NodeHash, NodeEqual> * free_vars, FreeTypeVar(std::unordered_set<TypeVar, NodeHash, NodeEqual> * free_vars,
std::unordered_set<TypeParam, NodeHash, NodeEqual> * bound_vars) : std::unordered_set<TypeVar, NodeHash, NodeEqual> * bound_vars) :
free_vars(free_vars), bound_vars(bound_vars) { } free_vars(free_vars), bound_vars(bound_vars) { }
void VisitType_(const TypeParamNode* tp) final { void VisitType_(const TypeVarNode* tp) final {
auto var = GetRef<TypeParam>(tp); auto var = GetRef<TypeVar>(tp);
if (bound_vars->count(var) == 0) { if (bound_vars->count(var) == 0) {
free_vars->insert(var); free_vars->insert(var);
} }
...@@ -75,8 +75,8 @@ class FreeVar : public ExprVisitor { ...@@ -75,8 +75,8 @@ class FreeVar : public ExprVisitor {
public: public:
std::unordered_set<Var, NodeHash, NodeEqual> free_vars; std::unordered_set<Var, NodeHash, NodeEqual> free_vars;
std::unordered_set<Var, NodeHash, NodeEqual> bound_vars; std::unordered_set<Var, NodeHash, NodeEqual> bound_vars;
std::unordered_set<TypeParam, NodeHash, NodeEqual> free_types; std::unordered_set<TypeVar, NodeHash, NodeEqual> free_types;
std::unordered_set<TypeParam, NodeHash, NodeEqual> bound_types; std::unordered_set<TypeVar, NodeHash, NodeEqual> bound_types;
void VisitType(const Type& t) final { void VisitType(const Type& t) final {
FreeTypeVar(&free_types, &bound_types)(t); FreeTypeVar(&free_types, &bound_types)(t);
...@@ -89,16 +89,16 @@ tvm::Array<Var> FreeVariables(const Expr& e) { ...@@ -89,16 +89,16 @@ tvm::Array<Var> FreeVariables(const Expr& e) {
return tvm::Array<Var>(fv.free_vars.begin(), fv.free_vars.end()); return tvm::Array<Var>(fv.free_vars.begin(), fv.free_vars.end());
} }
tvm::Array<TypeParam> FreeTypeVariables(const Expr& e) { tvm::Array<TypeVar> FreeTypeVariables(const Expr& e) {
FreeVar fv; FreeVar fv;
fv.VisitExpr(e); fv.VisitExpr(e);
return tvm::Array<TypeParam>(fv.free_types.begin(), fv.free_types.end()); return tvm::Array<TypeVar>(fv.free_types.begin(), fv.free_types.end());
} }
tvm::Array<TypeParam> FreeTypeVariables(const Type& t) { tvm::Array<TypeVar> FreeTypeVariables(const Type& t) {
FreeVar fv; FreeVar fv;
fv.VisitType(t); fv.VisitType(t);
return tvm::Array<TypeParam>(fv.free_types.begin(), fv.free_types.end()); return tvm::Array<TypeVar>(fv.free_types.begin(), fv.free_types.end());
} }
TVM_REGISTER_API("relay._ir_pass.free_vars") TVM_REGISTER_API("relay._ir_pass.free_vars")
......
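The free-variable utilities now traffic in TypeVar rather than TypeParam. A sketch, assuming the Python wrapper relay.ir_pass.free_type_vars over the API registered above:

from tvm import relay

tp = relay.TypeVar("tp")
x = relay.Var("x", relay.TupleType([tp, relay.TensorType([], "int32")]))
assert len(relay.ir_pass.free_type_vars(x)) == 1   # tp occurs free in x's annotation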
import numpy as np
from tvm.relay.expr import Let, Constant
from tvm.relay.ir_builder import IRBuilder
def test_let():
b = IRBuilder()
x = b.let('x', 1)
b.ret(x)
prog, _ = b.get()
assert isinstance(prog, Let)
var = prog.var
value = prog.value
assert var.name_hint == 'x'
assert var == prog.body
assert isinstance(value, Constant)
assert value.data.asnumpy() == np.array(1)
if __name__ == "__main__":
test_let()
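This IRBuilder-based test is deleted along with the builder itself; the replacement idiom in this PR is ScopeBuilder, as in test_monomorphic_let further down. A sketch of the equivalent program:

from tvm import relay

sb = relay.ScopeBuilder()
x = sb.let("x", relay.const(1.0, "float64"))
sb.ret(x)
prog = sb.get()
assert isinstance(prog, relay.Let)       # let x = 1.0; x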
...@@ -34,7 +34,7 @@ def test_tensor_type(): ...@@ -34,7 +34,7 @@ def test_tensor_type():
def test_type_param(): def test_type_param():
tp = relay.TypeParam('name', relay.Kind.Type) tp = relay.TypeVar('name', relay.Kind.Type)
assert tp.kind == relay.Kind.Type assert tp.kind == relay.Kind.Type
# assert tp.span # TODO allow us to set span # assert tp.span # TODO allow us to set span
str(tp) str(tp)
...@@ -56,7 +56,7 @@ def test_func_type(): ...@@ -56,7 +56,7 @@ def test_func_type():
def test_tuple_type(): def test_tuple_type():
tp = relay.TypeParam('tp', relay.Kind.Type) tp = relay.TypeVar('tp', relay.Kind.Type)
tf = relay.FuncType(tvm.convert([]), None, tvm.convert([]), tvm.convert([])) tf = relay.FuncType(tvm.convert([]), None, tvm.convert([]), tvm.convert([]))
tt = relay.TensorType(tvm.convert([1, 2, 3]), 'float32') tt = relay.TensorType(tvm.convert([1, 2, 3]), 'float32')
fields = tvm.convert([tp, tf, tt]) fields = tvm.convert([tp, tf, tt])
...@@ -66,7 +66,7 @@ def test_tuple_type(): ...@@ -66,7 +66,7 @@ def test_tuple_type():
def test_type_relation(): def test_type_relation():
tp = relay.TypeParam('tp', relay.Kind.Type) tp = relay.TypeVar('tp', relay.Kind.Type)
tf = relay.FuncType(tvm.convert([]), None, tvm.convert([]), tvm.convert([])) tf = relay.FuncType(tvm.convert([]), None, tvm.convert([]), tvm.convert([]))
tt = relay.TensorType(tvm.convert([1, 2, 3]), 'float32') tt = relay.TensorType(tvm.convert([1, 2, 3]), 'float32')
args = tvm.convert([tf, tt, tp]) args = tvm.convert([tf, tt, tp])
...@@ -173,7 +173,7 @@ def test_if(): ...@@ -173,7 +173,7 @@ def test_if():
def test_tuple_get_item(): def test_tuple_get_item():
tup = relay.Var("tuple") tup = relay.Var("tuple")
get = relay.TupleGetItem(tup, 1) get = relay.TupleGetItem(tup, 1)
assert get.tuple == tup assert get.tuple_value == tup
assert get.index == 1 assert get.index == 1
str(get) str(get)
......
...@@ -27,7 +27,7 @@ def test_env(): ...@@ -27,7 +27,7 @@ def test_env():
z = relay.add(z, z) z = relay.add(z, z)
f = relay.Function([x, y], z) f = relay.Function([x, y], z)
env = relay.Environment() env = relay.Environment()
env.add("myf", f) env["myf"] = f
text = env.astext() text = env.astext()
assert "def @myf" in text assert "def @myf" in text
assert "%1 = add(%0, %0) # ty=float32" in text assert "%1 = add(%0, %0) # ty=float32" in text
...@@ -70,15 +70,18 @@ def test_let_if_scope(): ...@@ -70,15 +70,18 @@ def test_let_if_scope():
x = relay.var("x", "float32") x = relay.var("x", "float32")
y = relay.var("y", "float32") y = relay.var("y", "float32")
cond = relay.var("cond", "bool") cond = relay.var("cond", "bool")
v1 = relay.var("v")
v2 = relay.var("v", "float32") sb = relay.ScopeBuilder()
then_branch = relay.Let( with sb.if_scope(cond):
v1, relay.const(1, "float32"), v1 = sb.let("v", relay.const(1, "float32"))
relay.Let(v2, x, relay.subtract(v1, v2))) v2 = sb.let("v", x)
sb.ret(relay.subtract(v1, v2))
with sb.else_scope():
v3 = relay.var("v") v3 = relay.var("v")
let2 = relay.Let(v3, y, v3) let2 = relay.Let(v3, y, v3)
else_branch = relay.add(let2, let2) sb.ret(relay.add(let2, let2))
result = relay.If(cond, then_branch, else_branch) result = sb.get()
f = relay.Function([x, y, cond], result) f = relay.Function([x, y, cond], result)
text = f.astext() text = f.astext()
assert text.count("{") == 4 assert text.count("{") == 4
...@@ -86,10 +89,17 @@ def test_let_if_scope(): ...@@ -86,10 +89,17 @@ def test_let_if_scope():
show(f.astext()) show(f.astext())
def test_variable_name():
# avoid pure number even if the namehint is pure number
v1 = relay.var("1")
assert "%v1" in v1.astext()
if __name__ == "__main__": if __name__ == "__main__":
do_print[0] = True do_print[0] = True
test_let_if_scope()
test_func() test_func()
test_env() test_env()
test_meta_data() test_meta_data()
test_call_attrs() test_call_attrs()
test_let_if_scope()
test_variable_name()
import tvm import tvm
import numpy as np import numpy as np
from tvm import relay from tvm import relay
from tvm.relay.ir_pass import infer_type
from tvm.relay.ir_builder import IRBuilder, func_type
from tvm.relay.ir_builder import scalar_type, convert, tensor_type
from tvm.relay.env import Environment
def assert_has_type(expr, typ, env=Environment({})):
checked_expr = infer_type(env, expr)
checked_type = checked_expr.checked_type
if checked_type != typ:
raise RuntimeError("Type mismatch %s vs %s" % (
checked_type, typ))
def test_binary_op(): def test_binary_op():
def check_binary_op(opfunc): def check_binary_op(opfunc):
""" n = tvm.var("n")
Program: t1 = relay.TensorType((5, n, 5))
fn (x, y) { t2 = relay.TensorType((n, 1))
return x <op> y; x = relay.var("x", t1)
} y = relay.var("y", t2)
""" z = opfunc(x, y)
b = IRBuilder() # test printer
assert ("%0 = {}(%x, %y)".format(z.op.name)) in z.astext()
x = b.param('x', tensor_type(5, 5, 5)) assert relay.ir_pass.infer_type(z).checked_type == t1
y = b.param('y', tensor_type(5, 5, 5))
with b.function(x, y) as func:
b.ret(opfunc(x, y))
b.ret(func)
prog, env = b.get()
ttype = tensor_type(5, 5, 5)
expected_ty = func_type([ttype, ttype], ttype)
assert_has_type(func.to_func(), expected_ty)
for opfunc in [relay.pow]: for opfunc in [relay.pow]:
check_binary_op(opfunc) check_binary_op(opfunc)
def test_binary_broadcast_op():
def check_binary_broadcast_op(opfunc):
"""
Program:
fn (x: Tensor[(10, 4), f32], y: Tensor[(5, 10, 1), f32]) -> Tensor[(5, 10, 4), f32] {
return x <op> y;
}
"""
b = IRBuilder()
x = b.param('x', tensor_type(10, 4))
y = b.param('y', tensor_type(5, 10, 1))
with b.function(x, y) as func:
b.ret(opfunc(x, y))
b.ret(func)
prog, env = b.get()
expected_ty = func_type([tensor_type(10, 4), tensor_type(5, 10, 1)],
tensor_type(5, 10, 4))
assert_has_type(func.to_func(), expected_ty)
for opfunc in [relay.pow]:
check_binary_broadcast_op(opfunc)
def test_cmp_type(): def test_cmp_type():
for op in (relay.greater, for op in (relay.greater,
relay.greater_equal, relay.greater_equal,
...@@ -68,138 +26,59 @@ def test_cmp_type(): ...@@ -68,138 +26,59 @@ def test_cmp_type():
relay.less_equal, relay.less_equal,
relay.equal, relay.equal,
relay.not_equal): relay.not_equal):
ib = relay.ir_builder.IRBuilder() x = relay.var("x", relay.TensorType((10, 4), "float32"))
x = ib.param("x", relay.TensorType((10, 4), "float32")) y = relay.var("y", relay.TensorType((5, 10, 1), "float32"))
y = ib.param("y", relay.TensorType((5, 10, 1), "float32")) z = op(x, y)
with ib.function(x, y) as func: z.astext()
ib.ret(op(x, y)) zz = relay.ir_pass.infer_type(z)
ib.ret(func) assert zz.checked_type == relay.TensorType((5, 10, 4), "bool")
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
assert ftype.ret_type == relay.TensorType((5, 10, 4), "uint1")
def test_binary_broadcast(): def test_binary_int_broadcast():
for op in [relay.right_shift, for op in [relay.right_shift,
relay.left_shift, relay.left_shift,
relay.maximum, relay.maximum,
relay.minimum]: relay.minimum]:
ib = relay.ir_builder.IRBuilder() x = relay.var("x", relay.TensorType((10, 4), "int32"))
x = ib.param("x", relay.TensorType((10, 4), "int32")) y = relay.var("y", relay.TensorType((5, 10, 1), "int32"))
y = ib.param("y", relay.TensorType((5, 10, 1), "int32")) z = op(x, y)
with ib.function(x, y) as func: zz = relay.ir_pass.infer_type(z)
ib.ret(op(x, y)) assert zz.checked_type == relay.TensorType((5, 10, 4), "int32")
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type def test_arg_reduce():
assert ftype.ret_type == relay.TensorType((5, 10, 4), "int32") for op in [relay.argmax, relay.argmin]:
n, c , h, w = 10, 20, 3, 4
def test_argmax(): x = relay.var("x", relay.ty.TensorType((n, c , h, w), "float32"))
ib = relay.ir_builder.IRBuilder() z = relay.argmax(x, axis=(1,))
n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w") "axis=" in z.astext()
x = ib.param("x", relay.ty.TensorType((n, c , h, w), "float32")) zz = relay.ir_pass.infer_type(z)
with ib.function(x) as func: assert zz.checked_type == relay.ty.TensorType((n, h, w), "int32")
ib.ret(relay.argmax(x, axis=(1,)))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
assert ftype.ret_type == relay.ty.TensorType((n, h, w), "int32")
ib = relay.ir_builder.IRBuilder()
n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
x = ib.param("x", relay.ty.TensorType((n, c , h, w), "float32"))
with ib.function(x) as func:
ib.ret(relay.argmax(x, axis=(2,), keepdims=True))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
assert ftype.ret_type == relay.ty.TensorType((n, c , 1, w), "int32")
ib = relay.ir_builder.IRBuilder()
n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w") n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
x = ib.param("x", relay.ty.TensorType((n, c , h, w), "float32")) x = relay.var("x", relay.ty.TensorType((n, c , h, w), "float32"))
with ib.function(x) as func: z = relay.argmax(x, axis=(2,), keepdims=True)
ib.ret(relay.argmax(x, axis=(2,), keepdims=True, exclude=True)) zz = relay.ir_pass.infer_type(z)
ib.ret(func) assert zz.checked_type == relay.ty.TensorType((n, c , 1, w), "int32")
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
assert ftype.ret_type == relay.ty.TensorType((1, 1 , h, 1), "int32")
def test_argmin():
ib = relay.ir_builder.IRBuilder()
n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
x = ib.param("x", relay.ty.TensorType((n, c , h, w), "float32"))
with ib.function(x) as func:
ib.ret(relay.argmax(x, axis=(1,)))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
assert ftype.ret_type == relay.ty.TensorType((n, h, w), "int32")
ib = relay.ir_builder.IRBuilder()
n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
x = ib.param("x", relay.ty.TensorType((n, c , h, w), "float32"))
with ib.function(x) as func:
ib.ret(relay.argmin(x, axis=(2,), keepdims=True))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
assert ftype.ret_type == relay.ty.TensorType((n, c , 1, w), "int32")
ib = relay.ir_builder.IRBuilder()
n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
x = ib.param("x", relay.ty.TensorType((n, c , h, w), "float32"))
with ib.function(x) as func:
ib.ret(relay.argmin(x, axis=(2,), keepdims=True, exclude=True))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
assert ftype.ret_type == relay.ty.TensorType((1, 1 , h, 1), "int32")
ib = relay.ir_builder.IRBuilder()
n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
x = ib.param("x", relay.ty.TensorType((n, c , h, w), "float32"))
with ib.function(x) as func:
ib.ret(relay.argmin(x, axis=(2,1), keepdims=True, exclude=True))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
assert ftype.ret_type == relay.ty.TensorType((1, c , h, 1), "int32")
ib = relay.ir_builder.IRBuilder()
n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w") n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
x = ib.param("x", relay.ty.TensorType((n, c , h, w), "float32")) x = relay.var("x", relay.ty.TensorType((n, c , h, w), "float32"))
with ib.function(x) as func: z = relay.argmax(x, axis=(2,), keepdims=True, exclude=True)
ib.ret(relay.argmin(x, axis=None, keepdims=True, exclude=True)) zz = relay.ir_pass.infer_type(z)
ib.ret(func) assert zz.checked_type == relay.ty.TensorType((1, 1 , h, 1), "int32")
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
assert ftype.ret_type == relay.ty.TensorType((1, 1 , 1, 1), "int32")
def test_where(): def test_where():
ib = relay.ir_builder.IRBuilder() cond = relay.var("cond", relay.TensorType((3, 4), "float32"))
cond = ib.param("cond", relay.TensorType((3, 4), "float32")) x = relay.var("x", relay.TensorType((3, 4), "float32"))
x = ib.param("x", relay.TensorType((3, 4), "float32")) y = relay.var("y", relay.TensorType((3, 4), "float32"))
y = ib.param("y", relay.TensorType((3, 4), "float32")) z = relay.where(cond, x, y)
with ib.function(cond, x, y) as func: zz = relay.ir_pass.infer_type(z)
ib.ret(relay.where(cond, x, y)) assert zz.checked_type == relay.TensorType((3, 4), "float32")
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
assert ftype.ret_type == relay.TensorType((3, 4), "float32")
if __name__ == "__main__": if __name__ == "__main__":
test_binary_op() test_binary_op()
test_binary_broadcast_op()
test_cmp_type() test_cmp_type()
test_binary_broadcast() test_binary_int_broadcast()
test_where() test_where()
test_multibox_prior() test_arg_reduce()
test_argmax()
test_argmin()
...@@ -4,26 +4,18 @@ import tvm ...@@ -4,26 +4,18 @@ import tvm
from tvm import relay from tvm import relay
def test_resize_infer_type(): def test_resize_infer_type():
ib = relay.ir_builder.IRBuilder()
n, c, h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w") n, c, h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
x = ib.param("x", relay.ty.TensorType((n, c, h, w), "int8")) x = relay.var("x", relay.TensorType((n, c, h, w), "int8"))
th, tw = tvm.var("th"), tvm.var("tw") th, tw = tvm.var("th"), tvm.var("tw")
z = relay.image.resize(x, (th, tw))
zz = relay.ir_pass.infer_type(z)
assert zz.checked_type == relay.TensorType((n, c, th, tw), "int8")
with ib.function(x) as func: x = relay.var("x", relay.TensorType((n, c, h, w), "int8"))
ib.ret(relay.image.resize(x, (th, tw))) z = relay.image.resize(x, (100, 200), "NCHW", "BILINEAR", False)
ib.ret(func) assert "size=" in z.astext()
func = relay.ir_pass.infer_type(ib.env, func.to_func()) zz = relay.ir_pass.infer_type(z)
ftype = func.checked_type assert zz.checked_type == relay.TensorType((n, c, 100, 200), "int8")
assert ftype.ret_type == relay.ty.TensorType((n, c, th, tw), "int8")
ib = relay.ir_builder.IRBuilder()
x = ib.param("x", relay.ty.TensorType((n, c, h, w), "int8"))
with ib.function(x) as func:
ib.ret(relay.image.resize(x, (100, 200), "NCHW", "BILINEAR", False))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
assert ftype.ret_type == relay.ty.TensorType((n, c, 100, 200), "int8")
...@@ -34,29 +26,21 @@ def test_multibox_prior(): ...@@ -34,29 +26,21 @@ def test_multibox_prior():
offsets = (0.2, 0.3) offsets = (0.2, 0.3)
clip = True clip = True
ib = relay.ir_builder.IRBuilder()
n, c, h, w = tvm.var("n"), 3, 56, 56 n, c, h, w = tvm.var("n"), 3, 56, 56
x = ib.param("x", relay.ty.TensorType((n, c, h, w), "float32")) x = relay.var("x", relay.TensorType((n, c, h, w), "float32"))
with ib.function(x) as func: z = relay.vision.multibox_prior(x, sizes, ratios,
ib.ret(relay.vision.multibox_prior(x, sizes, ratios, steps, offsets, clip)
steps, offsets, clip)) assert "sizes=" in z.astext()
ib.ret(func) zz = relay.ir_pass.infer_type(z)
func = relay.ir_pass.infer_type(ib.env, func.to_func()) assert zz.checked_type == relay.TensorType(
ftype = func.checked_type
assert ftype.ret_type == relay.ty.TensorType(
(1, h * w * (len(sizes) + len(ratios) - 1), 4), "float32") (1, h * w * (len(sizes) + len(ratios) - 1), 4), "float32")
ib = relay.ir_builder.IRBuilder()
n, c, h, w = tvm.var("n"), 24, 32, 32 n, c, h, w = tvm.var("n"), 24, 32, 32
x = ib.param("x", relay.ty.TensorType((n, c, h, w), "float32")) x = relay.var("x", relay.TensorType((n, c, h, w), "float32"))
z = relay.vision.multibox_prior(x)
with ib.function(x) as func: zz = relay.ir_pass.infer_type(z)
ib.ret(relay.vision.multibox_prior(x)) assert zz.checked_type == relay.TensorType(
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
assert ftype.ret_type == relay.ty.TensorType(
(1, h * w, 4), "float32") (1, h * w, 4), "float32")
......
...@@ -4,7 +4,7 @@ from tvm.relay.ir_pass import check_kind ...@@ -4,7 +4,7 @@ from tvm.relay.ir_pass import check_kind
def test_tuple_kind(): def test_tuple_kind():
# only contain type kinds # only contain type kinds
tp = relay.TypeParam('tp', relay.Kind.Type) tp = relay.TypeVar('tp', relay.Kind.Type)
tt = relay.TensorType(tvm.convert([1, 2, 3]), 'float32') tt = relay.TensorType(tvm.convert([1, 2, 3]), 'float32')
tf = relay.FuncType(tvm.convert([]), tt, tvm.convert([]), tvm.convert([])) tf = relay.FuncType(tvm.convert([]), tt, tvm.convert([]), tvm.convert([]))
fields = tvm.convert([tp, tf, tt]) fields = tvm.convert([tp, tf, tt])
...@@ -15,8 +15,8 @@ def test_tuple_kind(): ...@@ -15,8 +15,8 @@ def test_tuple_kind():
def test_func_kind(): def test_func_kind():
# only contain type kinds # only contain type kinds
tp1 = relay.TypeParam('tp1', relay.Kind.Type) tp1 = relay.TypeVar('tp1', relay.Kind.Type)
tp2 = relay.TypeParam('tp2', relay.Kind.Type) tp2 = relay.TypeVar('tp2', relay.Kind.Type)
shape = tvm.convert([1, 2, 3]) shape = tvm.convert([1, 2, 3])
dtype = 'float32' dtype = 'float32'
...@@ -35,7 +35,7 @@ def test_func_kind(): ...@@ -35,7 +35,7 @@ def test_func_kind():
def test_relation_kind(): def test_relation_kind():
# only have type kinds for arguments # only have type kinds for arguments
tp = relay.TypeParam('tp', relay.Kind.Type) tp = relay.TypeVar('tp', relay.Kind.Type)
tt = relay.TensorType(tvm.convert([1, 2, 3]), 'float32') tt = relay.TensorType(tvm.convert([1, 2, 3]), 'float32')
tf = relay.FuncType(tvm.convert([]), tt, tvm.convert([]), tvm.convert([])) tf = relay.FuncType(tvm.convert([]), tt, tvm.convert([]), tvm.convert([]))
args = tvm.convert([tf, tt, tp]) args = tvm.convert([tf, tt, tp])
...@@ -45,9 +45,9 @@ def test_relation_kind(): ...@@ -45,9 +45,9 @@ def test_relation_kind():
def test_invalid_tuple_kind(): def test_invalid_tuple_kind():
tp1 = relay.TypeParam('tp1', relay.Kind.Shape) tp1 = relay.TypeVar('tp1', relay.Kind.Shape)
tp2 = relay.TypeParam('tp2', relay.Kind.BaseType) tp2 = relay.TypeVar('tp2', relay.Kind.BaseType)
tp3 = relay.TypeParam('tp3', relay.Kind.ShapeVar) tp3 = relay.TypeVar('tp3', relay.Kind.ShapeVar)
fields = tvm.convert([tp1, tp2, tp3]) fields = tvm.convert([tp1, tp2, tp3])
tup_ty = relay.TupleType(fields) tup_ty = relay.TupleType(fields)
...@@ -55,9 +55,9 @@ def test_invalid_tuple_kind(): ...@@ -55,9 +55,9 @@ def test_invalid_tuple_kind():
def test_invalid_func_kind(): def test_invalid_func_kind():
tp1 = relay.TypeParam('tp1', relay.Kind.Shape) tp1 = relay.TypeVar('tp1', relay.Kind.Shape)
tp2 = relay.TypeParam('tp2', relay.Kind.BaseType) tp2 = relay.TypeVar('tp2', relay.Kind.BaseType)
tp3 = relay.TypeParam('tp3', relay.Kind.ShapeVar) tp3 = relay.TypeVar('tp3', relay.Kind.ShapeVar)
type_params = tvm.convert([tp1, tp2, tp3]) type_params = tvm.convert([tp1, tp2, tp3])
type_constraints = tvm.convert([]) type_constraints = tvm.convert([])
...@@ -69,9 +69,9 @@ def test_invalid_func_kind(): ...@@ -69,9 +69,9 @@ def test_invalid_func_kind():
def test_invalid_relation_kind(): def test_invalid_relation_kind():
tp1 = relay.TypeParam('tp1', relay.Kind.Shape) tp1 = relay.TypeVar('tp1', relay.Kind.Shape)
tp2 = relay.TypeParam('tp2', relay.Kind.BaseType) tp2 = relay.TypeVar('tp2', relay.Kind.BaseType)
tp3 = relay.TypeParam('tp3', relay.Kind.ShapeVar) tp3 = relay.TypeVar('tp3', relay.Kind.ShapeVar)
args = tvm.convert([tp1, tp2, tp3]) args = tvm.convert([tp1, tp2, tp3])
tr = relay.TypeRelation(None, args, 2, None) tr = relay.TypeRelation(None, args, 2, None)
...@@ -79,19 +79,19 @@ def test_invalid_relation_kind(): ...@@ -79,19 +79,19 @@ def test_invalid_relation_kind():
def test_func_with_invalid_ret_type(): def test_func_with_invalid_ret_type():
tp1 = relay.TypeParam('tp1', relay.Kind.Type) tp1 = relay.TypeVar('tp1', relay.Kind.Type)
tp2 = relay.TypeParam('tp2', relay.Kind.Shape) tp2 = relay.TypeVar('tp2', relay.Kind.Shape)
tf = relay.FuncType(tvm.convert([tp1]), tp2, tvm.convert([tp1, tp2]), tvm.convert([])) tf = relay.FuncType(tvm.convert([tp1]), tp2, tvm.convert([tp1, tp2]), tvm.convert([]))
def test_func_with_invalid_arg_types(): def test_func_with_invalid_arg_types():
tp1 = relay.TypeParam('tp1', relay.Kind.Shape) tp1 = relay.TypeVar('tp1', relay.Kind.Shape)
tp2 = relay.TypeParam('tp2', relay.Kind.Type) tp2 = relay.TypeVar('tp2', relay.Kind.Type)
tf = relay.FuncType(tvm.convert([tp1]), tp2, tvm.convert([tp1, tp2]), tvm.convert([])) tf = relay.FuncType(tvm.convert([tp1]), tp2, tvm.convert([tp1, tp2]), tvm.convert([]))
def test_func_with_invalid_tuple(): def test_func_with_invalid_tuple():
tp1 = relay.TypeParam('tp1', relay.Kind.Shape) tp1 = relay.TypeVar('tp1', relay.Kind.Shape)
ret_type = relay.TupleType(tvm.convert([tp1, tp1, tp1])) ret_type = relay.TupleType(tvm.convert([tp1, tp1, tp1]))
...@@ -100,9 +100,9 @@ def test_func_with_invalid_tuple(): ...@@ -100,9 +100,9 @@ def test_func_with_invalid_tuple():
def test_func_with_invalid_relation(): def test_func_with_invalid_relation():
tp1 = relay.TypeParam('tp1', relay.Kind.Type) tp1 = relay.TypeVar('tp1', relay.Kind.Type)
tp2 = relay.TypeParam('tp2', relay.Kind.Shape) tp2 = relay.TypeVar('tp2', relay.Kind.Shape)
tp3 = relay.TypeParam('tp3', relay.Kind.ShapeVar) tp3 = relay.TypeVar('tp3', relay.Kind.ShapeVar)
tr = relay.TypeRelation(None, tvm.convert([tp2, tp3]), 1, None) tr = relay.TypeRelation(None, tvm.convert([tp2, tp3]), 1, None)
...@@ -113,7 +113,7 @@ def test_func_with_invalid_relation(): ...@@ -113,7 +113,7 @@ def test_func_with_invalid_relation():
def test_tuple_with_invalid_func(): def test_tuple_with_invalid_func():
tensor_type = relay.TensorType(tvm.convert([1, 2, 3]), 'float32') tensor_type = relay.TensorType(tvm.convert([1, 2, 3]), 'float32')
tp1 = relay.TypeParam('tp1', relay.Kind.Shape) tp1 = relay.TypeVar('tp1', relay.Kind.Shape)
tf = relay.FuncType(tvm.convert([]), tp1, tvm.convert([tp1]), tvm.convert([])) tf = relay.FuncType(tvm.convert([]), tp1, tvm.convert([tp1]), tvm.convert([]))
tup_ty = relay.TupleType(tvm.convert([tensor_type, tf])) tup_ty = relay.TupleType(tvm.convert([tensor_type, tf]))
......
import tvm import tvm
from tvm import relay from tvm import relay
from tvm.relay.ir_pass import dead_code_elimination, alpha_equal from tvm.relay.ir_pass import dead_code_elimination, alpha_equal
from tvm.relay.ir_builder import convert, IRBuilder
from tvm.relay.op import log, add, equal, subtract from tvm.relay.op import log, add, equal, subtract
...@@ -19,9 +18,9 @@ class env: ...@@ -19,9 +18,9 @@ class env:
self.tt = relay.TensorType(self.shape, "float32") self.tt = relay.TensorType(self.shape, "float32")
self.int32 = relay.TensorType([], "int32") self.int32 = relay.TensorType([], "int32")
self.float32 = relay.TensorType([], "float32") self.float32 = relay.TensorType([], "float32")
self.one = convert(1.0) self.one = relay.const(1.0)
self.two = convert(2.0) self.two = relay.const(2.0)
self.three = convert(3.0) self.three = relay.const(3.0)
e = env() e = env()
...@@ -58,9 +57,12 @@ def test_recursion(): ...@@ -58,9 +57,12 @@ def test_recursion():
f = relay.Var("f") f = relay.Var("f")
n = relay.Var("n", e.int32) n = relay.Var("n", e.int32)
data = relay.Var("data", e.float32) data = relay.Var("data", e.float32)
funcbody = relay.If(equal(n, convert(0)), data, f(subtract(n, convert(1.0)), log(data))) funcbody = relay.If(equal(n, relay.const(0)),
data,
relay.Call(f, [subtract(n, relay.const(1.0)),
log(data)]))
value = relay.Function([n, data], funcbody, e.float32, []) value = relay.Function([n, data], funcbody, e.float32, [])
orig = relay.Let(f, funcbody, f(convert(2.0), convert(10000.0))) orig = relay.Let(f, funcbody, relay.Call(f, [relay.const(2.0), relay.const(10000.0)]))
assert alpha_equal(dead_code_elimination(orig), orig) assert alpha_equal(dead_code_elimination(orig), orig)
assert alpha_equal(dead_code_elimination(relay.Let(f, funcbody, e.three)), e.three) assert alpha_equal(dead_code_elimination(relay.Let(f, funcbody, e.three)), e.three)
...@@ -70,8 +72,10 @@ def test_op_let(): ...@@ -70,8 +72,10 @@ def test_op_let():
def test_if(): def test_if():
orig = relay.If(convert(True), e.a, e.b) cond = relay.const(True)
assert alpha_equal(dead_code_elimination(orig), e.a) orig = relay.If(cond, e.a, e.b)
y = dead_code_elimination(orig)
assert alpha_equal(y, e.a)
def test_tuple_get_item(): def test_tuple_get_item():
...@@ -82,10 +86,10 @@ def test_tuple_get_item(): ...@@ -82,10 +86,10 @@ def test_tuple_get_item():
if __name__ == "__main__": if __name__ == "__main__":
test_if()
test_let() test_let()
test_used_let() test_used_let()
test_chain_unused_let() test_chain_unused_let()
test_recursion() test_recursion()
test_op_let() test_op_let()
test_if()
test_tuple_get_item() test_tuple_get_item()
...@@ -28,7 +28,7 @@ def test_tuple(): ...@@ -28,7 +28,7 @@ def test_tuple():
def test_free_type_vars(): def test_free_type_vars():
tp = relay.TypeParam("") tp = relay.TypeVar("")
ty = relay.TupleType([tp, relay.TensorType([], "int32")]) ty = relay.TupleType([tp, relay.TensorType([], "int32")])
x = relay.Var("x", ty) x = relay.Var("x", ty)
y = relay.Var("y") y = relay.Var("y")
......
...@@ -4,34 +4,17 @@ ...@@ -4,34 +4,17 @@
import tvm import tvm
import numpy as np import numpy as np
from tvm.relay.ir_pass import infer_type from tvm.relay.ir_pass import infer_type
from tvm.relay.ir_builder import IRBuilder, func_type
from tvm.relay.ir_builder import scalar_type, convert, tensor_type
from tvm.relay.env import Environment
from tvm.relay.op import log, add, equal, subtract, concatenate
from tvm.relay.expr import Function
from tvm import relay from tvm import relay
def assert_has_type(expr, typ, env=Environment({})):
checked_expr = infer_type(env, expr)
checked_type = checked_expr.checked_type
if checked_type != typ:
raise RuntimeError("Type mismatch %s vs %s" % (
checked_type, typ))
def assert_decl_has_type(env, name, typ):
func = env[name]
assert func.checked_type == typ
def test_monomorphic_let(): def test_monomorphic_let():
"Program: let x = 1; return x" "Program: let x = 1; return x"
b = IRBuilder() sb = relay.ScopeBuilder()
x = b.let('x', 1.0, value_type=scalar_type('float64')) x = sb.let('x', relay.const(1.0, "float64"))
b.ret(x) sb.ret(x)
xchecked = relay.ir_pass.infer_type(sb.get())
assert xchecked.checked_type == relay.scalar_type("float64")
prog, env = b.get()
assert_has_type(prog, scalar_type('float64'))
def test_dual_op(): def test_dual_op():
"""Program: """Program:
...@@ -41,31 +24,29 @@ def test_dual_op(): ...@@ -41,31 +24,29 @@ def test_dual_op():
return t1; return t1;
} }
""" """
b = IRBuilder() tp = relay.TensorType((10, 10), "float32")
with b.function(('x', tensor_type(10, 10))) as func: x = relay.var("x", tp)
x, = func.param_ids() sb = relay.ScopeBuilder()
t1 = b.let('t1', log(x)) t1 = sb.let("t1", relay.log(x))
t2 = b.let('t2', add(t1, x)) t2 = sb.let("t2", relay.add(t1, x))
b.ret(t2) sb.ret(t2)
f = relay.Function([x], sb.get())
assert_has_type(func.to_func(), fchecked = relay.ir_pass.infer_type(f)
func_type([tensor_type(10, 10)], tensor_type(10, 10))) assert fchecked.checked_type == relay.FuncType([tp], tp)
def test_decl(): def test_decl():
"""Program: """Program:
def f(x : Tensor[f32, (10, 10)]) { def f(x : Tensor[(10, 10), f32]) {
let lx = log(x); return log(x);
return lx;
} }
""" """
b = IRBuilder() sb = relay.ScopeBuilder()
x = b.param('x') tp = relay.TensorType((10, 10))
with b.decl('f', x): x = relay.var("x", tp)
lx = b.let('lx', log(x)) f = relay.Function([x], relay.log(x))
b.ret(lx) fchecked = relay.ir_pass.infer_type(f)
_, env = b.get() assert fchecked.checked_type == relay.FuncType([tp], tp)
assert_decl_has_type(env, 'f', func_type(['float32'], 'float32'))
def test_recursion(): def test_recursion():
...@@ -78,54 +59,44 @@ def test_recursion(): ...@@ -78,54 +59,44 @@ def test_recursion():
return f(n - 1, log(data)); return f(n - 1, log(data));
} }
} }
f(2, 10000);
""" """
b = IRBuilder() sb = relay.ScopeBuilder()
f = b.global_var('f') f = relay.GlobalVar("f")
n = b.param('n', ty='int32') ti32 = relay.scalar_type("int32")
data = b.param('data', ty='float32') tf32 = relay.scalar_type("float32")
with b.decl(f, n, data): n = relay.var("n", ti32)
with b.if_scope(equal(n, convert(0))): data = relay.var("data", tf32)
b.ret(data)
with b.else_scope(): with sb.if_scope(relay.equal(n, relay.const(0, ti32))):
b.ret(f(subtract(n, convert(1)), log(data))) sb.ret(data)
b.ret(f(convert(2.0), convert(10000.0))) with sb.else_scope():
assert_decl_has_type(b.env, 'f', func_type( sb.ret(f(relay.subtract(n, relay.const(1, ti32)), relay.log(data)))
['int32', 'float32'], 'float32')) env = relay.Environment()
# TODO(@jroesch): need evaluator or new runtime env[f] = relay.Function([n, data], sb.get())
# to execute this. assert "%3 = @f(%1, %2)" in env.astext()
assert env[f].checked_type == relay.FuncType([ti32, tf32], tf32)
def test_concat():
"""
Program:
def try_concat2(x: Float(3, 2), y: Float(2, 2)) -> Float(5, 2) {
return concatenate((x, y), axis=0);
}
"""
ib = IRBuilder()
try_concat2 = ib.global_var('try_concat2')
x = ib.param('x', ty=tensor_type(3, 2))
y = ib.param('y', ty=tensor_type(2, 2))
with ib.decl(try_concat2, x, y):
ib.ret(concatenate((x, y), axis=0))
fn_ty = func_type([tensor_type(3, 2), tensor_type(2, 2)], tensor_type(5, 2))
assert_decl_has_type(ib.env, try_concat2, fn_ty)
def test_tuple(): def test_tuple():
ib = IRBuilder() tp = relay.TensorType((10,))
dup = ib.global_var('dup') x = relay.var("x", tp)
x = ib.param('x') res = relay.Tuple([x, x])
with ib.decl(dup, x): assert (relay.ir_pass.infer_type(res).checked_type ==
ib.ret(relay.Tuple([x, x])) relay.TupleType([tp, tp]))
# todo: why is this not generalized?
fn_ty = func_type([tensor_type()], relay.TupleType([tensor_type(), tensor_type()]))
assert_decl_has_type(ib.env, dup, fn_ty) def test_free_expr():
x = relay.var("x", "float32")
y = relay.add(x, x)
yy = relay.ir_pass.infer_type(y)
assert yy.checked_type == relay.scalar_type("float32")
if __name__ == "__main__": if __name__ == "__main__":
test_free_expr()
test_dual_op() test_dual_op()
test_recursion() test_recursion()
test_monomorphic_let() test_monomorphic_let()
test_decl() test_decl()
test_recursion() test_recursion()
test_concat()
test_tuple() test_tuple()
import tvm import tvm
from tvm import relay from tvm import relay
from tvm.relay.ir_builder import scalar_type, convert, tensor_type
def make_rel(name, args, num_inputs=None, attrs=None): def make_rel(name, args, num_inputs=None, attrs=None):
......