Unverified Commit 0b4cc050 by Tianqi Chen Committed by GitHub

[RELAY][IR] Move type_annotation to Var, remove Param (#1900)

parent 53428606
...@@ -118,17 +118,27 @@ class Var; ...@@ -118,17 +118,27 @@ class Var;
/*! \brief Container for Var */ /*! \brief Container for Var */
class VarNode : public ExprNode { class VarNode : public ExprNode {
public: public:
/*! \brief The name of the variable, this only acts as a hint to the user, /*!
* \brief The name of the variable,
* this only acts as a hint to the user,
* and is not used for equality. * and is not used for equality.
*/ */
std::string name_hint; std::string name_hint;
/*!
* \brief type annotaion of the variable.
* This field records user provided type annotation of the Var.
* This field is optional and can be None.
*/
Type type_annotation;
void VisitAttrs(tvm::AttrVisitor* v) final { void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("name_hint", &name_hint); v->Visit("name_hint", &name_hint);
v->Visit("type_annotation", &type_annotation);
v->Visit("_checked_type_", &checked_type_); v->Visit("_checked_type_", &checked_type_);
} }
TVM_DLL static Var make(std::string name_hint); TVM_DLL static Var make(std::string name_hint,
Type type_annotation);
static constexpr const char* _type_key = "relay.Var"; static constexpr const char* _type_key = "relay.Var";
TVM_DECLARE_NODE_TYPE_INFO(VarNode, ExprNode); TVM_DECLARE_NODE_TYPE_INFO(VarNode, ExprNode);
...@@ -163,32 +173,6 @@ class GlobalVarNode : public ExprNode { ...@@ -163,32 +173,6 @@ class GlobalVarNode : public ExprNode {
RELAY_DEFINE_NODE_REF(GlobalVar, GlobalVarNode, Expr); RELAY_DEFINE_NODE_REF(GlobalVar, GlobalVarNode, Expr);
/*! /*!
* \brief Function parameter declaration.
*/
class Param;
/*! \brief A parameter. */
class ParamNode : public ExprNode {
public:
/*! \brief The variable */
Var var;
/*! \brief The type of the parameter */
Type type;
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("var", &var);
v->Visit("type", &type);
v->Visit("span", &span);
}
TVM_DLL static Param make(Var var, Type type);
static constexpr const char* _type_key = "relay.Param";
TVM_DECLARE_NODE_TYPE_INFO(ParamNode, ExprNode);
};
RELAY_DEFINE_NODE_REF(Param, ParamNode, Expr);
/*!
* \brief Function (subgraph in computational graph) * \brief Function (subgraph in computational graph)
*/ */
class Function; class Function;
...@@ -196,7 +180,7 @@ class Function; ...@@ -196,7 +180,7 @@ class Function;
class FunctionNode : public ExprNode { class FunctionNode : public ExprNode {
public: public:
/*! \brief Function parameters */ /*! \brief Function parameters */
tvm::Array<Param> params; tvm::Array<Var> params;
/*! \brief User annotated return type of the function. */ /*! \brief User annotated return type of the function. */
Type ret_type; Type ret_type;
/*! /*!
...@@ -224,10 +208,18 @@ class FunctionNode : public ExprNode { ...@@ -224,10 +208,18 @@ class FunctionNode : public ExprNode {
v->Visit("_checked_type_", &checked_type_); v->Visit("_checked_type_", &checked_type_);
} }
Type fn_type() const; /*!
* \brief Return the derived function annotation of this expression.
*
* \return The function type annotation.
* \note The function type annotation can contain IncompleteType.
*/
TVM_DLL FuncType func_type_annotation() const;
TVM_DLL static Function make(tvm::Array<Param> params, Type ret_type, TVM_DLL static Function make(tvm::Array<Var> params,
Expr body, tvm::Array<TypeParam> ty_params); Type ret_type,
Expr body,
tvm::Array<TypeParam> ty_params);
static constexpr const char* _type_key = "relay.Function"; static constexpr const char* _type_key = "relay.Function";
TVM_DECLARE_NODE_TYPE_INFO(FunctionNode, ExprNode); TVM_DECLARE_NODE_TYPE_INFO(FunctionNode, ExprNode);
...@@ -289,7 +281,7 @@ class CallNode : public ExprNode { ...@@ -289,7 +281,7 @@ class CallNode : public ExprNode {
TVM_DLL static Call make(Expr op, TVM_DLL static Call make(Expr op,
Array<Expr> args, Array<Expr> args,
Attrs attrs = Attrs(), Attrs attrs = Attrs(),
Array<Type> ty_args = Array<Type>()); Array<Type> type_args = Array<Type>());
static constexpr const char* _type_key = "relay.Call"; static constexpr const char* _type_key = "relay.Call";
TVM_DECLARE_NODE_TYPE_INFO(CallNode, ExprNode); TVM_DECLARE_NODE_TYPE_INFO(CallNode, ExprNode);
...@@ -318,19 +310,16 @@ class LetNode : public ExprNode { ...@@ -318,19 +310,16 @@ class LetNode : public ExprNode {
Expr value; Expr value;
/*! \brief The body of the let binding */ /*! \brief The body of the let binding */
Expr body; Expr body;
/*! \brief Type annotation of value, this can be null */
Type value_type;
void VisitAttrs(tvm::AttrVisitor* v) final { void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("var", &var); v->Visit("var", &var);
v->Visit("value", &value); v->Visit("value", &value);
v->Visit("body", &body); v->Visit("body", &body);
v->Visit("value_type", &value_type);
v->Visit("span", &span); v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_); v->Visit("_checked_type_", &checked_type_);
} }
TVM_DLL static Let make(Var var, Expr value, Expr body, Type value_type); TVM_DLL static Let make(Var var, Expr value, Expr body);
static constexpr const char* _type_key = "relay.Let"; static constexpr const char* _type_key = "relay.Let";
TVM_DECLARE_NODE_TYPE_INFO(LetNode, ExprNode); TVM_DECLARE_NODE_TYPE_INFO(LetNode, ExprNode);
...@@ -376,11 +365,11 @@ class IfNode : public ExprNode { ...@@ -376,11 +365,11 @@ class IfNode : public ExprNode {
RELAY_DEFINE_NODE_REF(If, IfNode, Expr); RELAY_DEFINE_NODE_REF(If, IfNode, Expr);
/*! \brief Get a field out of a tuple. */ /*! \brief Get index-th field out of a tuple. */
class TupleGetItem; class TupleGetItem;
class TupleGetItemNode : public ExprNode { class TupleGetItemNode : public ExprNode {
public: public:
/*! \brief The tuple */ /*! \brief The tuple Expression */
Expr tuple; Expr tuple;
/*! \brief which value to get */ /*! \brief which value to get */
int index; int index;
......
...@@ -80,7 +80,6 @@ class ExprFunctor<R(const Expr& n, Args...)> { ...@@ -80,7 +80,6 @@ class ExprFunctor<R(const Expr& n, Args...)> {
Args... args) EXPR_FUNCTOR_DEFAULT; Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const GlobalVarNode* op, virtual R VisitExpr_(const GlobalVarNode* op,
Args... args) EXPR_FUNCTOR_DEFAULT; Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const ParamNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const FunctionNode* op, virtual R VisitExpr_(const FunctionNode* op,
Args... args) EXPR_FUNCTOR_DEFAULT; Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const CallNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const CallNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
...@@ -103,7 +102,6 @@ class ExprFunctor<R(const Expr& n, Args...)> { ...@@ -103,7 +102,6 @@ class ExprFunctor<R(const Expr& n, Args...)> {
RELAY_EXPR_FUNCTOR_DISPATCH(TupleNode); RELAY_EXPR_FUNCTOR_DISPATCH(TupleNode);
RELAY_EXPR_FUNCTOR_DISPATCH(VarNode); RELAY_EXPR_FUNCTOR_DISPATCH(VarNode);
RELAY_EXPR_FUNCTOR_DISPATCH(GlobalVarNode); RELAY_EXPR_FUNCTOR_DISPATCH(GlobalVarNode);
RELAY_EXPR_FUNCTOR_DISPATCH(ParamNode);
RELAY_EXPR_FUNCTOR_DISPATCH(FunctionNode); RELAY_EXPR_FUNCTOR_DISPATCH(FunctionNode);
RELAY_EXPR_FUNCTOR_DISPATCH(CallNode); RELAY_EXPR_FUNCTOR_DISPATCH(CallNode);
RELAY_EXPR_FUNCTOR_DISPATCH(LetNode); RELAY_EXPR_FUNCTOR_DISPATCH(LetNode);
...@@ -127,7 +125,6 @@ class ExprVisitor : public ::tvm::relay::ExprFunctor<void(const Expr& n)> { ...@@ -127,7 +125,6 @@ class ExprVisitor : public ::tvm::relay::ExprFunctor<void(const Expr& n)> {
void VisitExpr_(const GlobalVarNode* op) override; void VisitExpr_(const GlobalVarNode* op) override;
void VisitExpr_(const ConstantNode* op) override; void VisitExpr_(const ConstantNode* op) override;
void VisitExpr_(const TupleNode* op) override; void VisitExpr_(const TupleNode* op) override;
void VisitExpr_(const ParamNode* op) override;
void VisitExpr_(const FunctionNode* op) override; void VisitExpr_(const FunctionNode* op) override;
void VisitExpr_(const CallNode* op) override; void VisitExpr_(const CallNode* op) override;
void VisitExpr_(const LetNode* op) override; void VisitExpr_(const LetNode* op) override;
...@@ -151,7 +148,6 @@ class ExprMutator ...@@ -151,7 +148,6 @@ class ExprMutator
Expr VisitExpr_(const GlobalVarNode* op) override; Expr VisitExpr_(const GlobalVarNode* op) override;
Expr VisitExpr_(const OpNode* op) override; Expr VisitExpr_(const OpNode* op) override;
Expr VisitExpr_(const TupleNode* op) override; Expr VisitExpr_(const TupleNode* op) override;
Expr VisitExpr_(const ParamNode* op) override;
Expr VisitExpr_(const FunctionNode* op) override; Expr VisitExpr_(const FunctionNode* op) override;
Expr VisitExpr_(const CallNode* call_node) override; Expr VisitExpr_(const CallNode* call_node) override;
Expr VisitExpr_(const LetNode* op) override; Expr VisitExpr_(const LetNode* op) override;
......
...@@ -34,7 +34,6 @@ Constant = expr.Constant ...@@ -34,7 +34,6 @@ Constant = expr.Constant
Tuple = expr.Tuple Tuple = expr.Tuple
Var = expr.Var Var = expr.Var
GlobalVar = expr.GlobalVar GlobalVar = expr.GlobalVar
Param = expr.Param
Function = expr.Function Function = expr.Function
Call = expr.Call Call = expr.Call
Let = expr.Let Let = expr.Let
......
...@@ -11,11 +11,11 @@ class Expr(NodeBase): ...@@ -11,11 +11,11 @@ class Expr(NodeBase):
"""The base type for all Relay expressions.""" """The base type for all Relay expressions."""
@property @property
def checked_type(self): def checked_type(self):
"""Get the checked type of relay. """Get the checked type of tvm.relay.Expr.
Returns Returns
------- -------
checked_type : relay.Type checked_type : tvm.relay.Type
The checked type. The checked type.
""" """
ret = self._checked_type_ ret = self._checked_type_
...@@ -25,70 +25,97 @@ class Expr(NodeBase): ...@@ -25,70 +25,97 @@ class Expr(NodeBase):
return ret return ret
def __call__(self, *args): def __call__(self, *args):
converted_args = []
for arg in args:
if isinstance(arg, Param):
converted_args.append(arg.var)
else:
converted_args.append(arg)
return Call(self, args, None, None) return Call(self, args, None, None)
@register_relay_node @register_relay_node
class Constant(Expr): class Constant(Expr):
"""A constant tensor in Relay, see tvm/relay/type.h for more details. """A constant expression in Relay.
"""
Parameters
----------
data : tvm.nd.NDArray
The data content of the constant expression.
"""
def __init__(self, data): def __init__(self, data):
self.__init_handle_by_constructor__(_make.Constant, data) self.__init_handle_by_constructor__(_make.Constant, data)
@register_relay_node @register_relay_node
class Tuple(Expr): class Tuple(Expr):
"""A hetereogenous sequence of values. """Tuple expression that groups several fields together.
see tvm/relay/type.h for more details.
"""
Parameters
----------
fields : List[tvm.relay.Expr]
The fields in the tuple.
"""
def __init__(self, fields): def __init__(self, fields):
self.__init_handle_by_constructor__(_make.Tuple, fields) self.__init_handle_by_constructor__(_make.Tuple, fields)
@register_relay_node @register_relay_node
class Var(Expr): class Var(Expr):
"""A local variable in Relay.""" """A local variable in Tvm.Relay.
def __init__(self, name_hint): Local variable can be used to declare input
self.__init_handle_by_constructor__(_make.Var, name_hint) arguments to a function, or intermediate variables.
Parameters
----------
name_hint: str
The name of the variable.
This name only acts as a hint, and is not used
for equality.
type_annotation: tvm.relay.Type, optional
The type annotation on the variable.
"""
def __init__(self, name_hint, type_annotation=None):
self.__init_handle_by_constructor__(
_make.Var, name_hint, type_annotation)
@register_relay_node @register_relay_node
class GlobalVar(Expr): class GlobalVar(Expr):
"""A global variable in Relay.""" """A global variable in Tvm.Relay.
GlobalVar is used to refer to the global functions
stored in the environment.
Parameters
----------
name_hint: str
The name of the variable.
"""
def __init__(self, name_hint): def __init__(self, name_hint):
self.__init_handle_by_constructor__(_make.GlobalVar, name_hint) self.__init_handle_by_constructor__(_make.GlobalVar, name_hint)
@register_relay_node @register_relay_node
class Param(Expr): class Function(Expr):
"""A function type in Relay, see tvm/relay/type.h for more details. """A function declaration expression.
"""
def __init__(self, var, ty): Parameters
self.__init_handle_by_constructor__(_make.Param, var, ty) ----------
params: List[tvm.relay.Var]
List of input parameters to the function.
ret_type: tvm.relay.Type
The return type annotation of the function.
@register_relay_node body: tvm.relay.Expr
class Function(Expr): The body of the function.
"""A function in Relay, see tvm/relay/expr.h for more details."""
type_params: Optional[List[tvm.relay.TypeParam]]
The additional type parameters, this is only
used in advanced usecase of template functions.
"""
def __init__(self, def __init__(self,
params, params,
ret_type, ret_type,
body, body,
type_params=None type_params=None):
):
if type_params is None: if type_params is None:
type_params = convert([]) type_params = convert([])
...@@ -98,39 +125,87 @@ class Function(Expr): ...@@ -98,39 +125,87 @@ class Function(Expr):
@register_relay_node @register_relay_node
class Call(Expr): class Call(Expr):
"""A function call in Relay, see tvm/relay/expr.h for more details.""" """Function call node in Relay.
Call node corresponds the operator application node
in computational graph terminology.
Parameters
----------
op: tvm.relay.Op or any tvm.relay.Expr with function type.
The operation to be called.
def __init__(self, op, args, attrs, ty_args=None): args: List[tvm.relay.Expr]
if not ty_args: The arguments to the call.
ty_args = []
attrs: Optional[tvm.Attrs]
Attributes to the call, can be None
type_args: Optional[List[tvm.relay.Type]]
The additional type arguments, this is only
used in advanced usecase of template functions.
"""
def __init__(self, op, args, attrs=None, type_args=None):
if not type_args:
type_args = []
self.__init_handle_by_constructor__( self.__init_handle_by_constructor__(
_make.Call, op, args, attrs, ty_args) _make.Call, op, args, attrs, type_args)
@register_relay_node @register_relay_node
class Let(Expr): class Let(Expr):
"""A variable bindings in Relay, see tvm/relay/expr.h for more details.""" """Let variable binding expression.
Parameters
----------
var: tvm.relay.Var
The local variable to be bound.
value: tvm.relay.Expr
The value to be bound.
def __init__(self, var, value, body, value_type=None): body: tvm.relay.Expr
The body of the let binding.
"""
def __init__(self, var, value, body):
self.__init_handle_by_constructor__( self.__init_handle_by_constructor__(
_make.Let, var, value, body, value_type) _make.Let, var, value, body)
@register_relay_node @register_relay_node
class If(Expr): class If(Expr):
"""A conditional expression in Relay, see tvm/relay/expr.h for more details.""" """A conditional expression in Relay.
Parameters
----------
cond: tvm.relay.Expr
The condition.
def __init__(self, cond, true_value, false_value): true_branch: tvm.relay.Expr
The expression evaluated when condition is true.
false_branch: tvm.relay.Expr
The expression evaluated when condition is false.
"""
def __init__(self, cond, true_branch, false_branch):
self.__init_handle_by_constructor__( self.__init_handle_by_constructor__(
_make.If, cond, true_value, false_value) _make.If, cond, true_branch, false_branch)
@register_relay_node @register_relay_node
class TupleGetItem(Expr): class TupleGetItem(Expr):
"""An expression that get field from tuple in Relay, see tvm/relay/expr.h for more details.""" """Get index-th item from a tuple.
Parameters
----------
tuple_value: tvm.relay.Expr
The input tuple expression.
def __init__(self, tuple_, index): index: int
The index.
"""
def __init__(self, tuple_value, index):
self.__init_handle_by_constructor__( self.__init_handle_by_constructor__(
_make.TupleGetItem, tuple_, index) _make.TupleGetItem, tuple_value, index)
debug_print = _expr._debug_print debug_print = _expr._debug_print
...@@ -7,7 +7,7 @@ from collections import OrderedDict ...@@ -7,7 +7,7 @@ from collections import OrderedDict
import numpy as np import numpy as np
import tvm import tvm
from .ty import Type, FuncType, TensorType from .ty import Type, FuncType, TensorType
from .expr import Expr, Constant, Let, Var, Param, Function, If from .expr import Expr, Constant, Let, Var, Function, If
from .env import Environment from .env import Environment
...@@ -98,7 +98,7 @@ class PartialFunc(object): ...@@ -98,7 +98,7 @@ class PartialFunc(object):
self.type_params = type_params self.type_params = type_params
def param_ids(self): def param_ids(self):
return [p.var for p in self.params] return [p for p in self.params]
def to_func(self): def to_func(self):
"""Converts a PartialFunc into a :py:class:`~relay.Function`.""" """Converts a PartialFunc into a :py:class:`~relay.Function`."""
...@@ -113,9 +113,8 @@ class PartialFunc(object): ...@@ -113,9 +113,8 @@ class PartialFunc(object):
def _mk_let(bindings, ret_value): def _mk_let(bindings, ret_value):
let_expr = ret_value let_expr = ret_value
for var, (value, ty) in reversed(list(bindings.items())): for var, value in reversed(list(bindings.items())):
let_expr = Let(var, value, let_expr, ty) let_expr = Let(var, value, let_expr)
return let_expr return let_expr
...@@ -168,15 +167,12 @@ class IRBuilder(object): ...@@ -168,15 +167,12 @@ class IRBuilder(object):
#pylint: disable=invalid-name #pylint: disable=invalid-name
def bind(self, name, value, ty): def bind(self, name, value, ty):
lv = Var(name) lv = Var(name, ty)
self.scopes[-1][name] = lv self.scopes[-1][name] = lv
self.bindings[-1][lv] = (value, ty) self.bindings[-1][lv] = value
return lv return lv
def let(self, name, value, value_type=None): def let(self, name, value, value_type=None):
if isinstance(value, Param):
value = value.var
if not isinstance(value, Expr): if not isinstance(value, Expr):
value = convert(value) value = convert(value)
...@@ -185,23 +181,18 @@ class IRBuilder(object): ...@@ -185,23 +181,18 @@ class IRBuilder(object):
def _convert_params(self, raw_params): def _convert_params(self, raw_params):
relay_params = [] relay_params = []
for raw_param in raw_params: for raw_param in raw_params:
if isinstance(raw_param, Param): if isinstance(raw_param, Var):
var = raw_param.var
param = raw_param param = raw_param
elif isinstance(raw_param, tuple): elif isinstance(raw_param, tuple):
var, ty = raw_param var, ty = raw_param
if isinstance(var, str):
var = Var(var)
ty = _convert_type(ty) ty = _convert_type(ty)
param = Param(var, ty) param = Var(var, ty)
elif isinstance(param, str): elif isinstance(raw_param, str):
var = Var(raw_param) param = Var(raw_param, None)
ty = None
param = Param(var, ty)
else: else:
raise Exception("unknown parameter type") raise Exception("unknown parameter type")
self.scopes[-1][var.name_hint] = var self.scopes[-1][param.name_hint] = param
relay_params.append(param) relay_params.append(param)
return relay_params return relay_params
...@@ -265,7 +256,7 @@ class IRBuilder(object): ...@@ -265,7 +256,7 @@ class IRBuilder(object):
else: else:
ty = _convert_type(ty) ty = _convert_type(ty)
return Param(Var(name), ty) return Var(name, ty)
def global_var(self, name): def global_var(self, name):
# type: (str) -> GlobalVar # type: (str) -> GlobalVar
......
...@@ -96,7 +96,9 @@ class TypeDocifier : private TypeFunctor<Doc(const Type& n)> { ...@@ -96,7 +96,9 @@ class TypeDocifier : private TypeFunctor<Doc(const Type& n)> {
} }
std::vector<Doc> DocifyTypeParam(const tvm::Array<TypeParam>& arr) { std::vector<Doc> DocifyTypeParam(const tvm::Array<TypeParam>& arr) {
return MapDocify<TypeParam>(arr, [=](const TypeParam& tp) { return Docify(tp); }); return MapDocify<TypeParam>(arr, [=](const TypeParam& tp) {
return Docify(tp);
});
} }
std::vector<Doc> DocifyTypeConstraint(const tvm::Array<TypeConstraint>& arr) { std::vector<Doc> DocifyTypeConstraint(const tvm::Array<TypeConstraint>& arr) {
...@@ -188,10 +190,11 @@ class ExprDocifier : private ExprFunctor<Doc(const Expr& n)> { ...@@ -188,10 +190,11 @@ class ExprDocifier : private ExprFunctor<Doc(const Expr& n)> {
return vec; return vec;
} }
std::vector<Doc> DocifyParamArray(const tvm::Array<Param>& arr) { std::vector<Doc> DocifyParamArray(const tvm::Array<Var>& arr) {
std::vector<Doc> vec; std::vector<Doc> vec;
for (size_t i = 0; i < arr.size(); ++i) { for (Var param : arr) {
vec.push_back(Docify(arr[i])); vec.emplace_back(TypeAnnotation(DocOfStr(VarName(param)),
param->type_annotation));
} }
return vec; return vec;
} }
...@@ -212,10 +215,6 @@ class ExprDocifier : private ExprFunctor<Doc(const Expr& n)> { ...@@ -212,10 +215,6 @@ class ExprDocifier : private ExprFunctor<Doc(const Expr& n)> {
return DocOfStr(g->name_hint); return DocOfStr(g->name_hint);
} }
Doc VisitExpr_(const ParamNode* p) final {
return TypeAnnotation(Docify(p->var), p->type);
}
Doc VisitExpr_(const FunctionNode* f) final { Doc VisitExpr_(const FunctionNode* f) final {
return Group(TypeAnnotation(Seq("(", DocifyParamArray(f->params), ")"), f->ret_type) + Sep() + return Group(TypeAnnotation(Seq("(", DocifyParamArray(f->params), ")"), f->ret_type) + Sep() +
DocOfStr("=>") + Sep() + DocOfStr("=>") + Sep() +
...@@ -227,7 +226,8 @@ class ExprDocifier : private ExprFunctor<Doc(const Expr& n)> { ...@@ -227,7 +226,8 @@ class ExprDocifier : private ExprFunctor<Doc(const Expr& n)> {
} }
Doc VisitExpr_(const LetNode* l) final { Doc VisitExpr_(const LetNode* l) final {
return Group(DocOfStr("let") + Sep() + TypeAnnotation(Docify(l->var), l->value_type) + Sep() + return Group(DocOfStr("let") + Sep() +
TypeAnnotation(Docify(l->var), l->var->type_annotation) + Sep() +
DocOfStr("=") + Sep() + Docify(l->value) + DocOfStr(";") + Endl() + DocOfStr("=") + Sep() + Docify(l->value) + DocOfStr(";") + Endl() +
Docify(l->body)); Docify(l->body));
} }
......
...@@ -54,20 +54,26 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) ...@@ -54,20 +54,26 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
p->stream << "Tuple(" << node->fields << ")"; p->stream << "Tuple(" << node->fields << ")";
}); });
Var VarNode::make(std::string name_hint) { Var VarNode::make(std::string name_hint, Type type_annotation) {
NodePtr<VarNode> n = make_node<VarNode>(); NodePtr<VarNode> n = make_node<VarNode>();
n->name_hint = std::move(name_hint); n->name_hint = std::move(name_hint);
n->type_annotation = std::move(type_annotation);
return Var(n); return Var(n);
} }
TVM_REGISTER_API("relay._make.Var") TVM_REGISTER_API("relay._make.Var")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = VarNode::make(args[0]); *ret = VarNode::make(args[0], args[1]);
}); });
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<VarNode>([](const VarNode *node, tvm::IRPrinter *p) { .set_dispatch<VarNode>([](const VarNode *node, tvm::IRPrinter *p) {
p->stream << "Var(" << node->name_hint << ")"; p->stream << "Var(" << node->name_hint;
if (node->type_annotation.defined()) {
p->stream << ", ty=";
p->print(node->type_annotation);
}
p->stream << ")";
}); });
GlobalVar GlobalVarNode::make(std::string name_hint) { GlobalVar GlobalVarNode::make(std::string name_hint) {
...@@ -86,24 +92,10 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) ...@@ -86,24 +92,10 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
p->stream << "GlobalVar(" << node->name_hint << ")"; p->stream << "GlobalVar(" << node->name_hint << ")";
}); });
Param ParamNode::make(Var var, Type type) {
NodePtr<ParamNode> n = make_node<ParamNode>();
n->var = std::move(var);
n->type = std::move(type);
return Param(n);
}
TVM_REGISTER_API("relay._make.Param")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = ParamNode::make(args[0], args[1]);
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<ParamNode>([](const ParamNode *node, tvm::IRPrinter *p) {
p->stream << "Param(" << node->var << ", " << node->type << ")";
});
Function FunctionNode::make(tvm::Array<Param> params, Type ret_type, Expr body, Function FunctionNode::make(tvm::Array<Var> params,
Type ret_type,
Expr body,
tvm::Array<TypeParam> type_params) { tvm::Array<TypeParam> type_params) {
NodePtr<FunctionNode> n = make_node<FunctionNode>(); NodePtr<FunctionNode> n = make_node<FunctionNode>();
n->params = std::move(params); n->params = std::move(params);
...@@ -113,12 +105,11 @@ Function FunctionNode::make(tvm::Array<Param> params, Type ret_type, Expr body, ...@@ -113,12 +105,11 @@ Function FunctionNode::make(tvm::Array<Param> params, Type ret_type, Expr body,
return Function(n); return Function(n);
} }
Type FunctionNode::fn_type() const { FuncType FunctionNode::func_type_annotation() const {
Array<Type> param_types; Array<Type> param_types;
for (auto param : this->params) { for (auto param : this->params) {
param_types.push_back(param->type); param_types.push_back(param->type_annotation);
} }
return FuncTypeNode::make(param_types, this->ret_type, this->type_params, {}); return FuncTypeNode::make(param_types, this->ret_type, this->type_params, {});
} }
...@@ -155,24 +146,23 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) ...@@ -155,24 +146,23 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
<< node->attrs << ", " << node->type_args << ")"; << node->attrs << ", " << node->type_args << ")";
}); });
Let LetNode::make(Var var, Expr value, Expr body, Type value_type) { Let LetNode::make(Var var, Expr value, Expr body) {
NodePtr<LetNode> n = make_node<LetNode>(); NodePtr<LetNode> n = make_node<LetNode>();
n->var = std::move(var); n->var = std::move(var);
n->value = std::move(value); n->value = std::move(value);
n->body = std::move(body); n->body = std::move(body);
n->value_type = std::move(value_type);
return Let(n); return Let(n);
} }
TVM_REGISTER_API("relay._make.Let") TVM_REGISTER_API("relay._make.Let")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = LetNode::make(args[0], args[1], args[2], args[3]); *ret = LetNode::make(args[0], args[1], args[2]);
}); });
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<LetNode>([](const LetNode *node, tvm::IRPrinter *p) { .set_dispatch<LetNode>([](const LetNode *node, tvm::IRPrinter *p) {
p->stream << "LetNode(" << node->var << ", " << node->value p->stream << "LetNode(" << node->var << ", " << node->value
<< ", " << node->body << ", " << node->value_type << ")"; << ", " << node->body << ")";
}); });
If IfNode::make(Expr cond, Expr true_branch, Expr false_branch) { If IfNode::make(Expr cond, Expr true_branch, Expr false_branch) {
......
...@@ -24,6 +24,16 @@ Expr ExprMutator::Mutate(const Expr& expr) { ...@@ -24,6 +24,16 @@ Expr ExprMutator::Mutate(const Expr& expr) {
} }
Expr ExprMutator::VisitExpr_(const VarNode* op) { Expr ExprMutator::VisitExpr_(const VarNode* op) {
// NOTE: var will only be mutated once
// Thanks to the memo and reused during rewriting if necessary.
// It is safe to assume that the
if (op->type_annotation.defined()) {
auto type = this->VisitType(op->type_annotation);
if (!op->type_annotation.same_as(type)) {
return VarNode::make(op->name_hint, type);
}
}
// default case return self.
return GetRef<Expr>(op); return GetRef<Expr>(op);
} }
...@@ -55,16 +65,6 @@ Expr ExprMutator::VisitExpr_(const TupleNode* op) { ...@@ -55,16 +65,6 @@ Expr ExprMutator::VisitExpr_(const TupleNode* op) {
} }
} }
Expr ExprMutator::VisitExpr_(const ParamNode* op) {
Var var = Downcast<Var>(this->Mutate(op->var));
auto type = this->VisitType(op->type);
if (op->var.same_as(var) && op->type.same_as(type)) {
return GetRef<Expr>(op);
} else {
return ParamNode::make(var, type);
}
}
Expr ExprMutator::VisitExpr_(const FunctionNode* op) { Expr ExprMutator::VisitExpr_(const FunctionNode* op) {
tvm::Array<TypeParam> ty_params; tvm::Array<TypeParam> ty_params;
bool all_ty_params_changed = true; bool all_ty_params_changed = true;
...@@ -75,10 +75,10 @@ Expr ExprMutator::VisitExpr_(const FunctionNode* op) { ...@@ -75,10 +75,10 @@ Expr ExprMutator::VisitExpr_(const FunctionNode* op) {
all_ty_params_changed &= new_ty_param.same_as(ty_param); all_ty_params_changed &= new_ty_param.same_as(ty_param);
} }
tvm::Array<Param> params; tvm::Array<Var> params;
bool all_params_changed = true; bool all_params_changed = true;
for (auto param : op->params) { for (auto param : op->params) {
Param new_param = Downcast<Param>(this->Mutate(param)); Var new_param = Downcast<Var>(this->Mutate(param));
params.push_back(new_param); params.push_back(new_param);
all_params_changed &= param.same_as(new_param); all_params_changed &= param.same_as(new_param);
} }
...@@ -123,17 +123,15 @@ Expr ExprMutator::VisitExpr_(const CallNode* call_node) { ...@@ -123,17 +123,15 @@ Expr ExprMutator::VisitExpr_(const CallNode* call_node) {
Expr ExprMutator::VisitExpr_(const LetNode* op) { Expr ExprMutator::VisitExpr_(const LetNode* op) {
Var var = Downcast<Var>(this->Mutate(op->var)); Var var = Downcast<Var>(this->Mutate(op->var));
auto type = this->VisitType(op->value_type);
auto value = this->Mutate(op->value); auto value = this->Mutate(op->value);
auto body = this->Mutate(op->body); auto body = this->Mutate(op->body);
if (var.same_as(op->var) && if (var.same_as(op->var) &&
type.same_as(op->value_type) &&
value.same_as(op->value) && value.same_as(op->value) &&
body.same_as(op->body)) { body.same_as(op->body)) {
return GetRef<Expr>(op); return GetRef<Expr>(op);
} else { } else {
return LetNode::make(var, value, body, type); return LetNode::make(var, value, body);
} }
} }
...@@ -162,6 +160,9 @@ Expr ExprMutator::VisitExpr_(const TupleGetItemNode* g) { ...@@ -162,6 +160,9 @@ Expr ExprMutator::VisitExpr_(const TupleGetItemNode* g) {
Type ExprMutator::VisitType(const Type& t) { return t; } Type ExprMutator::VisitType(const Type& t) { return t; }
void ExprVisitor::ExprVisitor::VisitExpr_(const VarNode* op) { void ExprVisitor::ExprVisitor::VisitExpr_(const VarNode* op) {
if (op->type_annotation.defined()) {
this->VisitType(op->type_annotation);
}
} }
void ExprVisitor::ExprVisitor::VisitExpr_(const GlobalVarNode* op) { void ExprVisitor::ExprVisitor::VisitExpr_(const GlobalVarNode* op) {
...@@ -176,10 +177,6 @@ void ExprVisitor::ExprVisitor::VisitExpr_(const TupleNode* op) { ...@@ -176,10 +177,6 @@ void ExprVisitor::ExprVisitor::VisitExpr_(const TupleNode* op) {
} }
} }
void ExprVisitor::ExprVisitor::VisitExpr_(const ParamNode* op) {
this->VisitExpr(op->var);
}
void ExprVisitor::ExprVisitor::VisitExpr_(const FunctionNode* op) { void ExprVisitor::ExprVisitor::VisitExpr_(const FunctionNode* op) {
for (auto param : op->params) { for (auto param : op->params) {
this->VisitExpr(param); this->VisitExpr(param);
......
...@@ -252,15 +252,6 @@ struct AlphaEq : ExprFunctor<void(const Expr&, const Expr&)> { ...@@ -252,15 +252,6 @@ struct AlphaEq : ExprFunctor<void(const Expr&, const Expr&)> {
} }
} }
void VisitExpr_(const ParamNode* p1, const Expr& e2) final {
if (const ParamNode* p2 = e2.as<ParamNode>()) {
eq_map.Set(p1->var, p2->var);
equal = equal && AlphaEqual(p1->type, p2->type);
} else {
equal = false;
}
}
void VisitExpr_(const FunctionNode* func1, const Expr& e2) final { void VisitExpr_(const FunctionNode* func1, const Expr& e2) final {
if (const FunctionNode* func2 = e2.as<FunctionNode>()) { if (const FunctionNode* func2 = e2.as<FunctionNode>()) {
if (func1->params.size() != func2->params.size()) { if (func1->params.size() != func2->params.size()) {
...@@ -273,9 +264,10 @@ struct AlphaEq : ExprFunctor<void(const Expr&, const Expr&)> { ...@@ -273,9 +264,10 @@ struct AlphaEq : ExprFunctor<void(const Expr&, const Expr&)> {
return; return;
} }
for (size_t i = 0U; i < func1->params.size(); i++) { for (size_t i = 0; i < func1->params.size(); ++i) {
this->VisitExpr(func1->params[i], func2->params[i]); MergeVarDecl(func1->params[i], func2->params[i]);
} }
if (!equal) return;
for (size_t i = 0U; i < func1->type_params.size(); i++) { for (size_t i = 0U; i < func1->type_params.size(); i++) {
equal = equal && AlphaEqual(func1->type_params[i], func2->type_params[i]); equal = equal && AlphaEqual(func1->type_params[i], func2->type_params[i]);
...@@ -332,19 +324,9 @@ struct AlphaEq : ExprFunctor<void(const Expr&, const Expr&)> { ...@@ -332,19 +324,9 @@ struct AlphaEq : ExprFunctor<void(const Expr&, const Expr&)> {
void VisitExpr_(const LetNode* op, const Expr& e2) final { void VisitExpr_(const LetNode* op, const Expr& e2) final {
if (const LetNode* let = e2.as<LetNode>()) { if (const LetNode* let = e2.as<LetNode>()) {
eq_map.Set(op->var, let->var); MergeVarDecl(op->var, let->var);
this->VisitExpr(op->value, let->value); this->VisitExpr(op->value, let->value);
this->VisitExpr(op->body, let->body); this->VisitExpr(op->body, let->body);
// value_type should match as well (including nulls)
if (op->value_type.defined() != let->value_type.defined()) {
equal = false;
return;
}
if (op->value_type.defined()) {
equal = equal && AlphaEqual(op->value_type, let->value_type);
}
} else { } else {
equal = false; equal = false;
} }
...@@ -388,6 +370,20 @@ struct AlphaEq : ExprFunctor<void(const Expr&, const Expr&)> { ...@@ -388,6 +370,20 @@ struct AlphaEq : ExprFunctor<void(const Expr&, const Expr&)> {
equal = false; equal = false;
} }
} }
private:
void MergeVarDecl(const Var& var1, const Var& var2) {
if (var1->type_annotation.defined() != var2->type_annotation.defined()) {
equal = false;
return;
}
if (var1->type_annotation.defined() &&
!AlphaEqual(var1->type_annotation, var2->type_annotation)) {
equal = false;
return;
}
eq_map.Set(var1, var2);
}
}; };
bool AlphaEqual(const Expr& e1, const Expr& e2) { bool AlphaEqual(const Expr& e1, const Expr& e2) {
......
...@@ -54,12 +54,7 @@ class CalcDep : private ExprMutator { ...@@ -54,12 +54,7 @@ class CalcDep : private ExprMutator {
} }
private: private:
struct Binder { using VarMap = std::unordered_map<Var, Expr, NodeHash, NodeEqual>;
Type t;
Expr e;
Binder(const Type& t, const Expr& e) : t(t), e(e) { }
};
using VarMap = std::unordered_map<Var, Binder, NodeHash, NodeEqual>;
VarMap var_map_; VarMap var_map_;
Expr VisitExpr_(const IfNode* i) final { Expr VisitExpr_(const IfNode* i) final {
...@@ -74,9 +69,7 @@ class CalcDep : private ExprMutator { ...@@ -74,9 +69,7 @@ class CalcDep : private ExprMutator {
} }
Expr VisitExpr_(const LetNode* l) final { Expr VisitExpr_(const LetNode* l) final {
var_map_.insert(std::pair<Var, Binder>(l->var, var_map_[l->var] = Eliminate(l->value);
Binder(l->value_type,
Eliminate(l->value))));
return VisitExpr(l->body); return VisitExpr(l->body);
} }
...@@ -92,15 +85,16 @@ class CalcDep : private ExprMutator { ...@@ -92,15 +85,16 @@ class CalcDep : private ExprMutator {
explicit GenLet(const VarMap& var_map) : var_map_(var_map) { } explicit GenLet(const VarMap& var_map) : var_map_(var_map) { }
friend CalcDep; friend CalcDep;
void VisitExpr_(const VarNode* vn) final { void VisitExpr_(const VarNode* vnode) final {
Var v = GetRef<Var>(vn); Var v = GetRef<Var>(vnode);
if (var_map_.count(v) != 0) { auto it = var_map_.find(v);
auto val = var_map_.at(v); if (it != var_map_.end()) {
var_map_.erase(v); Expr expr = it->second;
var_map_.erase(it);
// erase before visit to handle letrec // erase before visit to handle letrec
VisitExpr(val.e); VisitExpr(expr);
// visit before push back so the dependency of dependency is before the dependency // visit before push back so the dependency of dependency is before the dependency
lets_.Push(v, val.t, val.e); lets_.Push(v, expr);
} }
} }
}; };
......
...@@ -26,23 +26,22 @@ namespace relay { ...@@ -26,23 +26,22 @@ namespace relay {
*/ */
class LetList { class LetList {
public: public:
/*! \brief insert a binding. /*!
* \brief insert a binding.
* *
* \param pv the var of the binding. * \param pv the var of the binding.
* *
* \param ty the type of the binding.
*
* \param expr the value of the binding. * \param expr the value of the binding.
* *
* \return a Var that hold the inserted expr. * \return a Var that hold the inserted expr.
*/ */
Var Push(const Var& pv, const Type& ty, const Expr& expr) { Var Push(Var pv, Expr expr) {
std::tuple<Var, Type, Expr> tuple(pv, ty, expr); lets_.emplace_back(std::make_pair(pv, expr));
lets_.push_back(tuple);
return pv; return pv;
} }
/*! \brief insert a binding. /*!
* \brief insert a binding.
* *
* \param ty the type of the binding. * \param ty the type of the binding.
* *
...@@ -50,33 +49,23 @@ class LetList { ...@@ -50,33 +49,23 @@ class LetList {
* *
* \return a Var that hold the inserted expr. * \return a Var that hold the inserted expr.
*/ */
Var Push(const Type& ty, const Expr& expr) { Var Push(Type ty, Expr expr) {
return Push(VarNode::make("x"), ty, expr); return Push(VarNode::make("x", ty), expr);
}
/*! \brief insert a binding.
*
* \param pv the var of the binding.
*
* \param expr the value of the binding.
*
* \return a Var that hold the inserted expr.
*/
Var Push(const Var& pv, const Expr& expr) {
return Push(pv, IncompleteTypeNode::make(TypeParamNode::kType), expr);
} }
/*! \brief insert a binding. /*!
* \brief insert a binding.
* *
* \param expr the value of the binding. * \param expr the value of the binding.
* *
* \return a Var that hold the inserted expr. * \return a Var that hold the inserted expr.
*/ */
Var Push(const Expr& expr) { Var Push(Expr expr) {
return Push(IncompleteTypeNode::make(TypeParamNode::kType), expr); return Push(IncompleteTypeNode::make(TypeParamNode::kType), expr);
} }
/*! \brief wrap an expr around the LetList. /*!
* \brief wrap an expr around the LetList.
* *
* \param body the Expression to be wrapped around. * \param body the Expression to be wrapped around.
* *
...@@ -85,7 +74,7 @@ class LetList { ...@@ -85,7 +74,7 @@ class LetList {
Expr Get(const Expr& body) const { Expr Get(const Expr& body) const {
Expr ret = body; Expr ret = body;
for (auto rit = lets_.rbegin(); rit != lets_.rend(); ++rit) { for (auto rit = lets_.rbegin(); rit != lets_.rend(); ++rit) {
ret = LetNode::make(std::get<0>(*rit), std::get<2>(*rit), ret, std::get<1>(*rit)); ret = LetNode::make(std::get<0>(*rit), std::get<1>(*rit), ret);
} }
return ret; return ret;
} }
...@@ -118,7 +107,7 @@ class LetList { ...@@ -118,7 +107,7 @@ class LetList {
} }
private: private:
std::vector<std::tuple<Var, Type, Expr> > lets_; std::vector<std::pair<Var, Expr> > lets_;
}; };
} // namespace relay } // namespace relay
......
...@@ -87,15 +87,11 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> { ...@@ -87,15 +87,11 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
// Visitor logics // Visitor logics
Type VisitExpr_(const VarNode* op) final { Type VisitExpr_(const VarNode* op) final {
// The type of Var can already been lookedup in type_map_; if (op->type_annotation.defined()) {
LOG(FATAL) << "Cannot find binding for var " << GetRef<Var>(op); return op->type_annotation;
return Type(); } else {
return IncompleteTypeNode::make(TypeParamNode::kType);
} }
Type VisitExpr_(const ParamNode* op) final {
// directly handled by Funtion
LOG(FATAL) << "not reached";
return Type();
} }
Type VisitExpr_(const GlobalVarNode* op) final { Type VisitExpr_(const GlobalVarNode* op) final {
...@@ -139,11 +135,11 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> { ...@@ -139,11 +135,11 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
Type VisitExpr_(const LetNode* op) final { Type VisitExpr_(const LetNode* op) final {
Type vtype = GetType(op->value); Type vtype = GetType(op->value);
if (op->value_type.defined()) { if (op->var->type_annotation.defined()) {
vtype = Unify(vtype, op->value_type, op->span); vtype = Unify(vtype, op->var->type_annotation, op->span);
} }
CHECK(!type_map_.count(op->var)); CHECK(!type_map_.count(op->var));
// NOTE: no scoping is necessary becase var are unique in program // NOTE: no scoping is necessary because var are unique in program
type_map_[op->var] = vtype; type_map_[op->var] = vtype;
return GetType(op->body); return GetType(op->body);
} }
...@@ -256,8 +252,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> { ...@@ -256,8 +252,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
Type VisitExpr_(const FunctionNode* f) final { Type VisitExpr_(const FunctionNode* f) final {
for (auto param : f->params) { for (auto param : f->params) {
type_map_[param->var] = param->type; GetType(param);
type_map_[param] = param->type;
} }
Type rtype = GetType(f->body); Type rtype = GetType(f->body);
// Run solver using the currently known information // Run solver using the currently known information
...@@ -265,8 +260,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> { ...@@ -265,8 +260,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
// Trying to resolve // Trying to resolve
Array<Type> arg_types; Array<Type> arg_types;
for (size_t i = 0; i < f->params.size(); ++i) { for (size_t i = 0; i < f->params.size(); ++i) {
Param param = f->params[i]; Type atype = solver_.Resolve(GetType(f->params[i]));
Type atype = solver_.Resolve(param->type);
CHECK(atype.as<IncompleteTypeNode>() == nullptr) CHECK(atype.as<IncompleteTypeNode>() == nullptr)
<< "Cannot resolve type of " << i << "Cannot resolve type of " << i
<< "-th parameter of function at" << f->span; << "-th parameter of function at" << f->span;
...@@ -311,9 +305,6 @@ class TypeInferencer::Resolver : public ExprMutator { ...@@ -311,9 +305,6 @@ class TypeInferencer::Resolver : public ExprMutator {
return AttachCheckedType(op); return AttachCheckedType(op);
} }
Expr VisitExpr_(const ParamNode* op) final {
return ExprMutator::VisitExpr_(op);
}
Expr VisitExpr_(const FunctionNode* op) final { Expr VisitExpr_(const FunctionNode* op) final {
return AttachCheckedType(op); return AttachCheckedType(op);
...@@ -380,7 +371,7 @@ Expr InferType(const Environment& env, ...@@ -380,7 +371,7 @@ Expr InferType(const Environment& env,
const GlobalVar& var, const GlobalVar& var,
const Function& func) { const Function& func) {
Function func_copy = Function(make_node<FunctionNode>(*func.operator->())); Function func_copy = Function(make_node<FunctionNode>(*func.operator->()));
func_copy->checked_type_ = func_copy->fn_type(); func_copy->checked_type_ = func_copy->func_type_annotation();
env->functions.Set(var, func_copy); env->functions.Set(var, func_copy);
Expr func_ret = TypeInferencer(env).Infer(func_copy); Expr func_ret = TypeInferencer(env).Infer(func_copy);
auto map_node = env->functions.CopyOnWrite(); auto map_node = env->functions.CopyOnWrite();
......
...@@ -50,14 +50,17 @@ class FreeVar : public ExprVisitor { ...@@ -50,14 +50,17 @@ class FreeVar : public ExprVisitor {
if (bound_vars.count(var) == 0) { if (bound_vars.count(var) == 0) {
free_vars.insert(var); free_vars.insert(var);
} }
if (v->type_annotation.defined()) {
VisitType(v->type_annotation);
}
} }
void VisitExpr_(const FunctionNode *f) final { void VisitExpr_(const FunctionNode *f) final {
for (const auto& tp : f->type_params) { for (const auto& tp : f->type_params) {
bound_types.insert(tp); bound_types.insert(tp);
} }
for (const auto& p : f->params) { for (const auto& param : f->params) {
bound_vars.insert(p->var); bound_vars.insert(param);
} }
VisitExpr(f->body); VisitExpr(f->body);
VisitType(f->ret_type); VisitType(f->ret_type);
...@@ -67,7 +70,6 @@ class FreeVar : public ExprVisitor { ...@@ -67,7 +70,6 @@ class FreeVar : public ExprVisitor {
bound_vars.insert(l->var); bound_vars.insert(l->var);
VisitExpr(l->value); VisitExpr(l->value);
VisitExpr(l->body); VisitExpr(l->body);
VisitType(l->value_type);
} }
public: public:
......
...@@ -34,8 +34,8 @@ class WellFormedChecker : private ExprVisitor { ...@@ -34,8 +34,8 @@ class WellFormedChecker : private ExprVisitor {
} }
void VisitExpr_(const FunctionNode * f) final { void VisitExpr_(const FunctionNode * f) final {
for (const Param & p : f->params) { for (const Var & param : f->params) {
Check(p->var); Check(param);
} }
CheckWellFormed(f->body); CheckWellFormed(f->body);
} }
......
...@@ -14,7 +14,6 @@ def test_let(): ...@@ -14,7 +14,6 @@ def test_let():
assert var == prog.body assert var == prog.body
assert isinstance(value, Constant) assert isinstance(value, Constant)
assert value.data.asnumpy() == np.array(1) assert value.data.asnumpy() == np.array(1)
assert prog.value_type == None
if __name__ == "__main__": if __name__ == "__main__":
test_let() test_let()
...@@ -49,18 +49,11 @@ def test_global_var(): ...@@ -49,18 +49,11 @@ def test_global_var():
show(gv) show(gv)
def test_param():
lv = relay.Var('x')
ty = None
param = relay.Param(lv, ty)
show(lv)
def test_function(): def test_function():
param_names = ['a', 'b', 'c', 'd'] param_names = ['a', 'b', 'c', 'd']
params = tvm.convert([relay.Param(relay.Var(n), None) for n in param_names]) params = tvm.convert([relay.Var(n) for n in param_names])
ret_type = None ret_type = None
body = params[0].var body = params[0]
type_params = tvm.convert([]) type_params = tvm.convert([])
fn = relay.Function(params, ret_type, body, type_params) fn = relay.Function(params, ret_type, body, type_params)
show(fn) show(fn)
...@@ -76,11 +69,11 @@ def test_call(): ...@@ -76,11 +69,11 @@ def test_call():
def test_let(): def test_let():
lv = relay.Var('x')
ty = relay.ty.TensorType((10, 20), 'float32') ty = relay.ty.TensorType((10, 20), 'float32')
lv = relay.Var('x', ty)
arr = tvm.nd.array(10) arr = tvm.nd.array(10)
value = relay.Constant(arr) value = relay.Constant(arr)
let = relay.Let(lv, value, lv, ty) let = relay.Let(lv, value, lv)
show(let) show(let)
......
...@@ -99,10 +99,16 @@ def test_tuple(): ...@@ -99,10 +99,16 @@ def test_tuple():
def test_local_var(): def test_local_var():
name_hint = 's' name_hint = 's'
lv = relay.Var(name_hint) lv = relay.Var(name_hint)
lv.name_hint == name_hint assert lv.name_hint == name_hint
assert lv.type_annotation is None
# assert lv.span == None todo(@jroesch): what do we do about spans # assert lv.span == None todo(@jroesch): what do we do about spans
str(lv) str(lv)
t1 = relay.ty.TensorType((), "float")
lv = relay.Var(name_hint, t1)
assert lv.name_hint == name_hint
assert lv.type_annotation == t1
def test_global_var(): def test_global_var():
name_hint = 'g' name_hint = 'g'
...@@ -112,19 +118,9 @@ def test_global_var(): ...@@ -112,19 +118,9 @@ def test_global_var():
str(gv) str(gv)
def test_param():
lv = relay.Var('x')
ty = None
param = relay.Param(lv, ty)
assert param.var == lv
assert param.type == ty
assert param.span == None
str(param)
def test_function(): def test_function():
param_names = ['a', 'b', 'c', 'd'] param_names = ['a', 'b', 'c', 'd']
params = tvm.convert([relay.Param(relay.Var(n), None) for n in param_names]) params = tvm.convert([relay.Var(n) for n in param_names])
ret_type = None ret_type = None
body = None body = None
type_params = tvm.convert([]) type_params = tvm.convert([])
...@@ -154,10 +150,9 @@ def test_let(): ...@@ -154,10 +150,9 @@ def test_let():
value = relay.Constant(arr) value = relay.Constant(arr)
# I would prefer that the order of arguments # I would prefer that the order of arguments
# matches syntax let x: t = v in b # matches syntax let x: t = v in b
let = relay.Let(lv, value, lv, ty) let = relay.Let(lv, value, lv)
assert let.var == lv assert let.var == lv
assert let.value == value assert let.value == value
assert let.value_type == ty
assert let.body == lv assert let.body == lv
assert let.span == None assert let.span == None
str(let) str(let)
...@@ -194,7 +189,6 @@ if __name__ == "__main__": ...@@ -194,7 +189,6 @@ if __name__ == "__main__":
test_tuple() test_tuple()
test_local_var() test_local_var()
test_global_var() test_global_var()
test_param()
test_function() test_function()
test_call() test_call()
test_let() test_let()
......
...@@ -7,23 +7,22 @@ def test_well_formed(): ...@@ -7,23 +7,22 @@ def test_well_formed():
assert well_formed(x) assert well_formed(x)
v = relay.Constant(tvm.nd.array(10)) v = relay.Constant(tvm.nd.array(10))
ty = None ty = None
let = relay.Let(x, v, x, ty) let = relay.Let(x, v, x)
assert well_formed(let) assert well_formed(let)
assert not well_formed(relay.Let(x, v, let, ty)) assert not well_formed(relay.Let(x, v, let))
f = relay.Function([relay.Param(x, ty)], ty, x) f = relay.Function([x], ty, x)
assert well_formed(f) assert well_formed(f)
# this test should pass in case of weak uniqueness (only test for shadowing) # this test should pass in case of weak uniqueness (only test for shadowing)
# but we want all binder to be distinct from each other. # but we want all binder to be distinct from each other.
assert not well_formed(relay.Let(relay.Var("y"), f, assert not well_formed(relay.Let(relay.Var("y"), f,
relay.Let(relay.Var("z"), f, v, ty), ty)) relay.Let(relay.Var("z"), f, v)))
def test_tuple(): def test_tuple():
x = relay.Var('x') x = relay.Var('x')
assert well_formed(x) assert well_formed(x)
v = relay.Constant(tvm.nd.array(10)) v = relay.Constant(tvm.nd.array(10))
ty = None let = relay.Let(x, v, x)
let = relay.Let(x, v, x, ty)
assert well_formed(let) assert well_formed(let)
assert well_formed(relay.Tuple([v, v])) assert well_formed(relay.Tuple([v, v]))
assert not well_formed(relay.Tuple([let, let])) assert not well_formed(relay.Tuple([let, let]))
......
...@@ -27,6 +27,8 @@ def test_single_op(): ...@@ -27,6 +27,8 @@ def test_single_op():
tvm.relay.sigmoid, tvm.relay.tanh]: tvm.relay.sigmoid, tvm.relay.tanh]:
check_single_op(opfunc) check_single_op(opfunc)
def test_expand_dims_infer_type(): def test_expand_dims_infer_type():
ib = relay.ir_builder.IRBuilder() ib = relay.ir_builder.IRBuilder()
n, t, d = tvm.var("n"), tvm.var("t"), 100 n, t, d = tvm.var("n"), tvm.var("t"), 100
...@@ -75,12 +77,13 @@ def test_unary_op(): ...@@ -75,12 +77,13 @@ def test_unary_op():
ib = relay.ir_builder.IRBuilder() ib = relay.ir_builder.IRBuilder()
x = ib.param("x", relay.TensorType((10, 4), "int32")) x = ib.param("x", relay.TensorType((10, 4), "int32"))
with ib.function(x) as func: with ib.function(x) as func:
ib.ret(op(x.var)) ib.ret(op(x))
ib.ret(func) ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func()) func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type ftype = func.checked_type
assert ftype.ret_type == relay.TensorType((10, 4), "int32") assert ftype.ret_type == relay.TensorType((10, 4), "int32")
def test_binary_op(): def test_binary_op():
def check_binary_op(opfunc): def check_binary_op(opfunc):
""" """
...@@ -94,7 +97,7 @@ def test_binary_op(): ...@@ -94,7 +97,7 @@ def test_binary_op():
x = b.param('x', tensor_type(5, 5, 5)) x = b.param('x', tensor_type(5, 5, 5))
y = b.param('y', tensor_type(5, 5, 5)) y = b.param('y', tensor_type(5, 5, 5))
with b.function(x, y) as func: with b.function(x, y) as func:
b.ret(opfunc(x.var, y.var)) b.ret(opfunc(x, y))
b.ret(func) b.ret(func)
prog, env = b.get() prog, env = b.get()
ttype = tensor_type(5, 5, 5) ttype = tensor_type(5, 5, 5)
...@@ -118,7 +121,7 @@ def test_binary_broadcast_op(): ...@@ -118,7 +121,7 @@ def test_binary_broadcast_op():
x = b.param('x', tensor_type(10, 4)) x = b.param('x', tensor_type(10, 4))
y = b.param('y', tensor_type(5, 10, 1)) y = b.param('y', tensor_type(5, 10, 1))
with b.function(x, y) as func: with b.function(x, y) as func:
b.ret(opfunc(x.var, y.var)) b.ret(opfunc(x, y))
b.ret(func) b.ret(func)
prog, env = b.get() prog, env = b.get()
......
...@@ -11,7 +11,7 @@ def test_conv2d_infer_type(): ...@@ -11,7 +11,7 @@ def test_conv2d_infer_type():
w = ib.param("w", relay.ty.IncompleteType()) w = ib.param("w", relay.ty.IncompleteType())
with ib.function(x, w) as func: with ib.function(x, w) as func:
ib.ret(relay.nn.conv2d(x.var, w.var, ib.ret(relay.nn.conv2d(x, w,
kernel_size=(3, 3), kernel_size=(3, 3),
padding=(1, 1), padding=(1, 1),
channels=2)) channels=2))
...@@ -29,7 +29,7 @@ def test_conv2d_infer_type(): ...@@ -29,7 +29,7 @@ def test_conv2d_infer_type():
x = ib.param("x", relay.ty.TensorType((n, c, h, w), "int8")) x = ib.param("x", relay.ty.TensorType((n, c, h, w), "int8"))
w = ib.param("w", relay.ty.TensorType((2, 10, 3, 3), "int8")) w = ib.param("w", relay.ty.TensorType((2, 10, 3, 3), "int8"))
with ib.function(x, w) as func: with ib.function(x, w) as func:
ib.ret(relay.nn.conv2d(x.var, w.var, out_dtype="int32")) ib.ret(relay.nn.conv2d(x, w, out_dtype="int32"))
ib.ret(func) ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func()) func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type ftype = func.checked_type
...@@ -42,7 +42,7 @@ def test_conv2d_infer_type(): ...@@ -42,7 +42,7 @@ def test_conv2d_infer_type():
x = ib.param("x", relay.ty.TensorType((n, c, h, w), "int8")) x = ib.param("x", relay.ty.TensorType((n, c, h, w), "int8"))
w = ib.param("w", relay.ty.IncompleteType()) w = ib.param("w", relay.ty.IncompleteType())
with ib.function(x, w) as func: with ib.function(x, w) as func:
ib.ret(relay.nn.conv2d(x.var, w.var, ib.ret(relay.nn.conv2d(x, w,
kernel_size=(3, 3), kernel_size=(3, 3),
padding=(1, 1), padding=(1, 1),
channels=16, channels=16,
...@@ -65,7 +65,7 @@ def test_conv2d_transpose_infer_type(): ...@@ -65,7 +65,7 @@ def test_conv2d_transpose_infer_type():
w = ib.param("w", relay.ty.IncompleteType()) w = ib.param("w", relay.ty.IncompleteType())
with ib.function(x, w) as func: with ib.function(x, w) as func:
ib.ret(relay.nn.conv2d_transpose(x.var, w.var, ib.ret(relay.nn.conv2d_transpose(x, w,
kernel_size=(3, 3), kernel_size=(3, 3),
padding=(1, 1), padding=(1, 1),
channels=15)) channels=15))
...@@ -83,7 +83,7 @@ def test_conv2d_transpose_infer_type(): ...@@ -83,7 +83,7 @@ def test_conv2d_transpose_infer_type():
x = ib.param("x", relay.ty.TensorType((n, c, h, w), "float32")) x = ib.param("x", relay.ty.TensorType((n, c, h, w), "float32"))
w = ib.param("w", relay.ty.TensorType((12, 11, 5, 5), "float32")) w = ib.param("w", relay.ty.TensorType((12, 11, 5, 5), "float32"))
with ib.function(x, w) as func: with ib.function(x, w) as func:
ib.ret(relay.nn.conv2d_transpose(x.var, w.var, ib.ret(relay.nn.conv2d_transpose(x, w,
output_padding=(1, 1), output_padding=(1, 1),
channels=11, channels=11,
data_layout="NHWC")) data_layout="NHWC"))
...@@ -98,7 +98,7 @@ def test_upsampling_infer_type(): ...@@ -98,7 +98,7 @@ def test_upsampling_infer_type():
n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w") n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
x = ib.param("x", relay.ty.TensorType((n, c, h, w), "float32")) x = ib.param("x", relay.ty.TensorType((n, c, h, w), "float32"))
with ib.function(x) as func: with ib.function(x) as func:
ib.ret(relay.nn.upsampling(x.var, scale=2, layout="NCHW", method="BILINEAR")) ib.ret(relay.nn.upsampling(x, scale=2, layout="NCHW", method="BILINEAR"))
ib.ret(func) ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func()) func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type ftype = func.checked_type
...@@ -108,7 +108,7 @@ def test_upsampling_infer_type(): ...@@ -108,7 +108,7 @@ def test_upsampling_infer_type():
n, c = tvm.var("n"), tvm.var("c") n, c = tvm.var("n"), tvm.var("c")
x = ib.param("x", relay.ty.TensorType((n, c, 100, 200), "float32")) x = ib.param("x", relay.ty.TensorType((n, c, 100, 200), "float32"))
with ib.function(x) as func: with ib.function(x) as func:
ib.ret(relay.nn.upsampling(x.var, scale=2, layout="NCHW", method="BILINEAR")) ib.ret(relay.nn.upsampling(x, scale=2, layout="NCHW", method="BILINEAR"))
ib.ret(func) ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func()) func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type ftype = func.checked_type
...@@ -119,7 +119,7 @@ def _test_pool2d_infer_type(opfunc): ...@@ -119,7 +119,7 @@ def _test_pool2d_infer_type(opfunc):
n, c, h, w = tvm.var("n"), 10, 224, 224 n, c, h, w = tvm.var("n"), 10, 224, 224
x = ib.param("x", relay.ty.TensorType((n, c, h, w), "float32")) x = ib.param("x", relay.ty.TensorType((n, c, h, w), "float32"))
with ib.function(x) as func: with ib.function(x) as func:
ib.ret(opfunc(x.var, pool_size=(1, 1))) ib.ret(opfunc(x, pool_size=(1, 1)))
ib.ret(func) ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func()) func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type ftype = func.checked_type
...@@ -132,7 +132,7 @@ def _test_pool2d_infer_type(opfunc): ...@@ -132,7 +132,7 @@ def _test_pool2d_infer_type(opfunc):
n, c, h, w = tvm.var("n"), 10, 224, 224 n, c, h, w = tvm.var("n"), 10, 224, 224
x = ib.param("x", relay.ty.TensorType((n, c, h, w), "float32")) x = ib.param("x", relay.ty.TensorType((n, c, h, w), "float32"))
with ib.function(x) as func: with ib.function(x) as func:
ib.ret(opfunc(x.var, pool_size=(ph, pw), strides=(sh, sw))) ib.ret(opfunc(x, pool_size=(ph, pw), strides=(sh, sw)))
ib.ret(func) ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func()) func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type ftype = func.checked_type
...@@ -144,7 +144,7 @@ def _test_global_pool2d_infer_type(opfunc): ...@@ -144,7 +144,7 @@ def _test_global_pool2d_infer_type(opfunc):
n, c, h, w = tvm.var("n"), tvm.var("c"), 224, 224 n, c, h, w = tvm.var("n"), tvm.var("c"), 224, 224
x = ib.param("x", relay.ty.TensorType((n, h, w, c), "float32")) x = ib.param("x", relay.ty.TensorType((n, h, w, c), "float32"))
with ib.function(x) as func: with ib.function(x) as func:
ib.ret(opfunc(x.var, layout="NHWC")) ib.ret(opfunc(x, layout="NHWC"))
ib.ret(func) ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func()) func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type ftype = func.checked_type
...@@ -154,7 +154,7 @@ def _test_global_pool2d_infer_type(opfunc): ...@@ -154,7 +154,7 @@ def _test_global_pool2d_infer_type(opfunc):
n, c, h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w") n, c, h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
x = ib.param("x", relay.ty.TensorType((n, c, h, w), "float32")) x = ib.param("x", relay.ty.TensorType((n, c, h, w), "float32"))
with ib.function(x) as func: with ib.function(x) as func:
ib.ret(opfunc(x.var)) ib.ret(opfunc(x))
ib.ret(func) ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func()) func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type ftype = func.checked_type
...@@ -172,7 +172,7 @@ def test_flatten_infer_type(): ...@@ -172,7 +172,7 @@ def test_flatten_infer_type():
x = ib.param("x", relay.ty.TensorType((d1, d2, d3, d4), "float32")) x = ib.param("x", relay.ty.TensorType((d1, d2, d3, d4), "float32"))
with ib.function(x) as func: with ib.function(x) as func:
ib.ret(relay.nn.batch_flatten(x.var)) ib.ret(relay.nn.batch_flatten(x))
ib.ret(func) ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func()) func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type ftype = func.checked_type
...@@ -181,7 +181,7 @@ def test_flatten_infer_type(): ...@@ -181,7 +181,7 @@ def test_flatten_infer_type():
ib = relay.ir_builder.IRBuilder() ib = relay.ir_builder.IRBuilder()
x = ib.param("x", relay.ty.TensorType((3, 2, 4, 3), "float32")) x = ib.param("x", relay.ty.TensorType((3, 2, 4, 3), "float32"))
with ib.function(x) as func: with ib.function(x) as func:
ib.ret(relay.nn.batch_flatten(x.var)) ib.ret(relay.nn.batch_flatten(x))
ib.ret(func) ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func()) func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type ftype = func.checked_type
...@@ -190,7 +190,7 @@ def test_flatten_infer_type(): ...@@ -190,7 +190,7 @@ def test_flatten_infer_type():
ib = relay.ir_builder.IRBuilder() ib = relay.ir_builder.IRBuilder()
x = ib.param("x", relay.ty.TensorType((d1, 2, d3, 3), "float32")) x = ib.param("x", relay.ty.TensorType((d1, 2, d3, 3), "float32"))
with ib.function(x) as func: with ib.function(x) as func:
ib.ret(relay.nn.batch_flatten(x.var)) ib.ret(relay.nn.batch_flatten(x))
ib.ret(func) ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func()) func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type ftype = func.checked_type
...@@ -202,7 +202,7 @@ def test_pad_infer_type(): ...@@ -202,7 +202,7 @@ def test_pad_infer_type():
n, c, h, w = 1, 2, 3, 4 n, c, h, w = 1, 2, 3, 4
t = ib.param("t", relay.TensorType((n, c, h, w), "float32")) t = ib.param("t", relay.TensorType((n, c, h, w), "float32"))
with ib.function(t) as func: with ib.function(t) as func:
ib.ret(relay.nn.pad(t.var, ((1, 1), (2, 2), (3, 3), (4, 4)))) ib.ret(relay.nn.pad(t, ((1, 1), (2, 2), (3, 3), (4, 4))))
ib.ret(func) ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func()) func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type ftype = func.checked_type
...@@ -213,7 +213,7 @@ def test_pad_infer_type(): ...@@ -213,7 +213,7 @@ def test_pad_infer_type():
n, c, h, w = tvm.var("n"), 2, 3, tvm.var("w") n, c, h, w = tvm.var("n"), 2, 3, tvm.var("w")
t = ib.param("t", relay.TensorType((n, c, h, w), "float32")) t = ib.param("t", relay.TensorType((n, c, h, w), "float32"))
with ib.function(t) as func: with ib.function(t) as func:
ib.ret(relay.nn.pad(t.var, ((1, 1), (2, 2), (3, 3), (4, 4)))) ib.ret(relay.nn.pad(t, ((1, 1), (2, 2), (3, 3), (4, 4))))
ib.ret(func) ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func()) func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type ftype = func.checked_type
...@@ -227,4 +227,3 @@ if __name__ == "__main__": ...@@ -227,4 +227,3 @@ if __name__ == "__main__":
test_flatten_infer_type() test_flatten_infer_type()
test_pad_infer_type() test_pad_infer_type()
test_conv2d_transpose_infer_type() test_conv2d_transpose_infer_type()
...@@ -17,12 +17,13 @@ def test_zeros_ones(): ...@@ -17,12 +17,13 @@ def test_zeros_ones():
ftype = func.checked_type ftype = func.checked_type
assert ftype.ret_type == relay.TensorType((124, 50), "float64") assert ftype.ret_type == relay.TensorType((124, 50), "float64")
def test_unary_identity(): def test_unary_identity():
for op in [relay.zeros_like, relay.ones_like]: for op in [relay.zeros_like, relay.ones_like]:
ib = relay.ir_builder.IRBuilder() ib = relay.ir_builder.IRBuilder()
x = ib.param("x", relay.TensorType((8, 9, 4), "int32")) x = ib.param("x", relay.TensorType((8, 9, 4), "int32"))
with ib.function(x) as func: with ib.function(x) as func:
ib.ret(op(x.var)) ib.ret(op(x))
ib.ret(func) ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func()) func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type ftype = func.checked_type
...@@ -33,7 +34,7 @@ def test_clip_type(): ...@@ -33,7 +34,7 @@ def test_clip_type():
ib = relay.ir_builder.IRBuilder() ib = relay.ir_builder.IRBuilder()
a = ib.param("a", relay.TensorType((10, 4), "float32")) a = ib.param("a", relay.TensorType((10, 4), "float32"))
with ib.function(a) as func: with ib.function(a) as func:
ib.ret(relay.clip(a.var, 1., 4.)) ib.ret(relay.clip(a, 1., 4.))
ib.ret(func) ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func()) func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type ftype = func.checked_type
...@@ -106,7 +107,7 @@ def test_take_infer_type(): ...@@ -106,7 +107,7 @@ def test_take_infer_type():
x = ib.param("x", relay.ty.TensorType(dshape, "float32")) x = ib.param("x", relay.ty.TensorType(dshape, "float32"))
indices = ib.param("indices", relay.ty.TensorType(indices_shape, "int32")) indices = ib.param("indices", relay.ty.TensorType(indices_shape, "int32"))
with ib.function(x, indices) as func: with ib.function(x, indices) as func:
ib.ret(relay.take(x.var, indices.var, axis=axis)) ib.ret(relay.take(x, indices, axis=axis))
ib.ret(func) ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func()) func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type ftype = func.checked_type
...@@ -127,7 +128,7 @@ def test_full(): ...@@ -127,7 +128,7 @@ def test_full():
ib = relay.ir_builder.IRBuilder() ib = relay.ir_builder.IRBuilder()
x = ib.param("x", relay.TensorType((), "int8")) x = ib.param("x", relay.TensorType((), "int8"))
with ib.function(x) as func: with ib.function(x) as func:
ib.ret(relay.full(x.var, ())) ib.ret(relay.full(x, ()))
ib.ret(func) ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func()) func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type ftype = func.checked_type
...@@ -137,7 +138,7 @@ def test_full(): ...@@ -137,7 +138,7 @@ def test_full():
ib = relay.ir_builder.IRBuilder() ib = relay.ir_builder.IRBuilder()
x = ib.param("x", relay.TensorType((), "float32")) x = ib.param("x", relay.TensorType((), "float32"))
with ib.function(x) as func: with ib.function(x) as func:
ib.ret(relay.full(x.var, (1, 2), "int8")) ib.ret(relay.full(x, (1, 2), "int8"))
ib.ret(func) ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func()) func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type ftype = func.checked_type
...@@ -150,7 +151,7 @@ def test_full_like(): ...@@ -150,7 +151,7 @@ def test_full_like():
base = ib.param("base", relay.TensorType((1, 2, 3), "float32")) base = ib.param("base", relay.TensorType((1, 2, 3), "float32"))
fill = ib.param("fill", relay.TensorType((), "float32")) fill = ib.param("fill", relay.TensorType((), "float32"))
with ib.function(base, fill) as func: with ib.function(base, fill) as func:
ib.ret(relay.full_like(base.var, fill.var)) ib.ret(relay.full_like(base, fill))
ib.ret(func) ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func()) func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type ftype = func.checked_type
...@@ -162,7 +163,7 @@ def test_full_like(): ...@@ -162,7 +163,7 @@ def test_full_like():
base = ib.param("base", relay.TensorType((n, c, h, w), "float32")) base = ib.param("base", relay.TensorType((n, c, h, w), "float32"))
fill = ib.param("fill", relay.TensorType((), "float32")) fill = ib.param("fill", relay.TensorType((), "float32"))
with ib.function(base, fill) as func: with ib.function(base, fill) as func:
ib.ret(relay.full_like(base.var, fill.var)) ib.ret(relay.full_like(base, fill))
ib.ret(func) ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func()) func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type ftype = func.checked_type
......
...@@ -24,7 +24,7 @@ def test_cmp_type(): ...@@ -24,7 +24,7 @@ def test_cmp_type():
x = ib.param("x", relay.TensorType((10, 4), "float32")) x = ib.param("x", relay.TensorType((10, 4), "float32"))
y = ib.param("y", relay.TensorType((5, 10, 1), "float32")) y = ib.param("y", relay.TensorType((5, 10, 1), "float32"))
with ib.function(x, y) as func: with ib.function(x, y) as func:
ib.ret(op(x.var, y.var)) ib.ret(op(x, y))
ib.ret(func) ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func()) func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type ftype = func.checked_type
...@@ -39,7 +39,7 @@ def test_binary_broadcast(): ...@@ -39,7 +39,7 @@ def test_binary_broadcast():
x = ib.param("x", relay.TensorType((10, 4), "int32")) x = ib.param("x", relay.TensorType((10, 4), "int32"))
y = ib.param("y", relay.TensorType((5, 10, 1), "int32")) y = ib.param("y", relay.TensorType((5, 10, 1), "int32"))
with ib.function(x, y) as func: with ib.function(x, y) as func:
ib.ret(op(x.var, y.var)) ib.ret(op(x, y))
ib.ret(func) ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func()) func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type ftype = func.checked_type
...@@ -58,7 +58,7 @@ def test_binary_op(): ...@@ -58,7 +58,7 @@ def test_binary_op():
x = b.param('x', tensor_type(5, 5, 5)) x = b.param('x', tensor_type(5, 5, 5))
y = b.param('y', tensor_type(5, 5, 5)) y = b.param('y', tensor_type(5, 5, 5))
with b.function(x, y) as func: with b.function(x, y) as func:
b.ret(opfunc(x.var, y.var)) b.ret(opfunc(x, y))
b.ret(func) b.ret(func)
prog, env = b.get() prog, env = b.get()
ttype = tensor_type(5, 5, 5) ttype = tensor_type(5, 5, 5)
...@@ -81,7 +81,7 @@ def test_binary_broadcast_op(): ...@@ -81,7 +81,7 @@ def test_binary_broadcast_op():
x = b.param('x', tensor_type(10, 4)) x = b.param('x', tensor_type(10, 4))
y = b.param('y', tensor_type(5, 10, 1)) y = b.param('y', tensor_type(5, 10, 1))
with b.function(x, y) as func: with b.function(x, y) as func:
b.ret(opfunc(x.var, y.var)) b.ret(opfunc(x, y))
b.ret(func) b.ret(func)
prog, env = b.get() prog, env = b.get()
...@@ -103,7 +103,7 @@ def test_cmp_type(): ...@@ -103,7 +103,7 @@ def test_cmp_type():
x = ib.param("x", relay.TensorType((10, 4), "float32")) x = ib.param("x", relay.TensorType((10, 4), "float32"))
y = ib.param("y", relay.TensorType((5, 10, 1), "float32")) y = ib.param("y", relay.TensorType((5, 10, 1), "float32"))
with ib.function(x, y) as func: with ib.function(x, y) as func:
ib.ret(op(x.var, y.var)) ib.ret(op(x, y))
ib.ret(func) ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func()) func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type ftype = func.checked_type
...@@ -118,7 +118,7 @@ def test_binary_broadcast(): ...@@ -118,7 +118,7 @@ def test_binary_broadcast():
x = ib.param("x", relay.TensorType((10, 4), "int32")) x = ib.param("x", relay.TensorType((10, 4), "int32"))
y = ib.param("y", relay.TensorType((5, 10, 1), "int32")) y = ib.param("y", relay.TensorType((5, 10, 1), "int32"))
with ib.function(x, y) as func: with ib.function(x, y) as func:
ib.ret(op(x.var, y.var)) ib.ret(op(x, y))
ib.ret(func) ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func()) func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type ftype = func.checked_type
...@@ -131,7 +131,7 @@ def test_where(): ...@@ -131,7 +131,7 @@ def test_where():
x = ib.param("x", relay.TensorType((3, 4), "float32")) x = ib.param("x", relay.TensorType((3, 4), "float32"))
y = ib.param("y", relay.TensorType((3, 4), "float32")) y = ib.param("y", relay.TensorType((3, 4), "float32"))
with ib.function(cond, x, y) as func: with ib.function(cond, x, y) as func:
ib.ret(relay.where(cond.var, x.var, y.var)) ib.ret(relay.where(cond, x, y))
ib.ret(func) ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func()) func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type ftype = func.checked_type
......
...@@ -10,7 +10,7 @@ def test_resize_infer_type(): ...@@ -10,7 +10,7 @@ def test_resize_infer_type():
th, tw = tvm.var("th"), tvm.var("tw") th, tw = tvm.var("th"), tvm.var("tw")
with ib.function(x) as func: with ib.function(x) as func:
ib.ret(relay.image.resize(x.var, (th, tw))) ib.ret(relay.image.resize(x, (th, tw)))
ib.ret(func) ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func()) func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type ftype = func.checked_type
...@@ -19,7 +19,7 @@ def test_resize_infer_type(): ...@@ -19,7 +19,7 @@ def test_resize_infer_type():
ib = relay.ir_builder.IRBuilder() ib = relay.ir_builder.IRBuilder()
x = ib.param("x", relay.ty.TensorType((n, c, h, w), "int8")) x = ib.param("x", relay.ty.TensorType((n, c, h, w), "int8"))
with ib.function(x) as func: with ib.function(x) as func:
ib.ret(relay.image.resize(x.var, (100, 200), "NCHW", "BILINEAR", False)) ib.ret(relay.image.resize(x, (100, 200), "NCHW", "BILINEAR", False))
ib.ret(func) ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func()) func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type ftype = func.checked_type
......
import tvm import tvm
import numpy as np
from tvm import relay from tvm import relay
from tvm.relay.ir_pass import alpha_equal from tvm.relay.ir_pass import alpha_equal
from tvm.relay.ir_builder import convert from tvm.relay.ir_builder import convert
...@@ -179,9 +180,9 @@ def test_var_alpha_equal(): ...@@ -179,9 +180,9 @@ def test_var_alpha_equal():
assert not alpha_equal(v1, v2) assert not alpha_equal(v1, v2)
# let node allows for setting the eq_map # let node allows for setting the eq_map
l1 = relay.Let(v1, convert(1), v1, None) l1 = relay.Let(v1, convert(1), v1)
l2 = relay.Let(v2, convert(1), v2, None) l2 = relay.Let(v2, convert(1), v2)
l3 = relay.Let(v1, convert(1), v2, None) l3 = relay.Let(v1, convert(1), v2)
assert alpha_equal(l1, l2) assert alpha_equal(l1, l2)
assert not alpha_equal(l1, l3) assert not alpha_equal(l1, l3)
...@@ -209,10 +210,10 @@ def test_tuple_alpha_equal(): ...@@ -209,10 +210,10 @@ def test_tuple_alpha_equal():
assert alpha_equal(tup, same) assert alpha_equal(tup, same)
# use the eq_map # use the eq_map
let_tup = relay.Let(v1, tup, v1, None) let_tup = relay.Let(v1, tup, v1)
let_mapped = relay.Let(v2, relay.Tuple([v2, convert(2), convert(3), let_mapped = relay.Let(v2, relay.Tuple([v2, convert(2), convert(3),
relay.Tuple([convert(4)])]), relay.Tuple([convert(4)])]),
v2, None) v2)
assert alpha_equal(let_tup, let_mapped) assert alpha_equal(let_tup, let_mapped)
more_fields = relay.Tuple([v1, convert(2), convert(3), relay.Tuple([convert(4)]), v2]) more_fields = relay.Tuple([v1, convert(2), convert(3), relay.Tuple([convert(4)]), v2])
...@@ -242,61 +243,44 @@ def test_tuple_get_item_alpha_equal(): ...@@ -242,61 +243,44 @@ def test_tuple_get_item_alpha_equal():
assert alpha_equal(relay.TupleGetItem(x, 1), relay.TupleGetItem(x, 1)) assert alpha_equal(relay.TupleGetItem(x, 1), relay.TupleGetItem(x, 1))
def test_param_alpha_equal():
# only checks equality of the types
v1 = relay.Var("v1")
v2 = relay.Var("v2")
p1 = relay.Param(v1, relay.TensorType((1, 2, 3), "float32"))
p2 = relay.Param(v2, relay.TensorType((1, 2, 3), "float32"))
assert alpha_equal(p1, p2)
p3 = relay.Param(v1, relay.TensorType((4, 5, 6), "int8"))
assert not alpha_equal(p1, p3)
p4 = relay.Param(v1, relay.TupleType([relay.TensorType((1, 2, 3),
"float32")]))
assert not alpha_equal(p1, p4)
def test_function_alpha_equal(): def test_function_alpha_equal():
v1 = relay.Var("v1")
v2 = relay.Var("v2")
v3 = relay.Var("v3")
v4 = relay.Var("v4")
tt1 = relay.TensorType((1, 2, 3), "float32") tt1 = relay.TensorType((1, 2, 3), "float32")
tt2 = relay.TensorType((4, 5, 6), "int8") tt2 = relay.TensorType((4, 5, 6), "int8")
tt3 = relay.TupleType([tt1, tt2]) tt3 = relay.TupleType([tt1, tt2])
v1 = relay.Var("v1", tt1)
v2 = relay.Var("v2", tt2)
v3 = relay.Var("v3", tt3)
v4 = relay.Var("v4", tt2)
vret = relay.Constant(tvm.nd.array(np.ones(1)))
tp1 = relay.TypeParam("tp1", relay.Kind.Type) tp1 = relay.TypeParam("tp1", relay.Kind.Type)
tp2 = relay.TypeParam("tp2", relay.Kind.Type) tp2 = relay.TypeParam("tp2", relay.Kind.Type)
tp3 = relay.TypeParam("tp3", relay.Kind.Shape) tp3 = relay.TypeParam("tp3", relay.Kind.Shape)
tp4 = relay.TypeParam("tp4", relay.Kind.Shape) tp4 = relay.TypeParam("tp4", relay.Kind.Shape)
basic_args = [relay.Param(v3, tt1), relay.Param(v4, tt2)] basic_args = [relay.Var("v3", tt1), relay.Var("v4", tt2)]
basic_tps = [tp1, tp2] basic_tps = [tp1, tp2]
func = relay.Function([relay.Param(v1, tt1), relay.Param(v2, tt2)], func = relay.Function([v1, v2],
tt2, v2, basic_tps) tt2, v1, basic_tps)
mapped = relay.Function(basic_args, tt2, v4, basic_tps) mapped = relay.Function(basic_args, tt2, basic_args[0], basic_tps)
assert alpha_equal(func, mapped) assert alpha_equal(func, mapped)
fewer_params = relay.Function([relay.Param(v4, tt2)], tt2, v4, basic_tps) fewer_params = relay.Function([relay.Var("v4", tt2)], tt2, v4, basic_tps)
assert not alpha_equal(func, fewer_params) assert not alpha_equal(func, fewer_params)
more_params = relay.Function([relay.Param(v3, tt1), relay.Param(v4, tt2), more_params = relay.Function([relay.Var("v3", tt1),
relay.Param(v2, tt2)], tt2, v4, basic_tps) relay.Var("v4", tt2),
relay.Var("v2", tt2)], tt2, v4, basic_tps)
assert not alpha_equal(func, more_params) assert not alpha_equal(func, more_params)
params_unordered = relay.Function([relay.Param(v3, tt2), params_unordered = relay.Function([v2, v1],
relay.Param(v4, tt1)], tt2, v1, basic_tps)
tt1, v3, basic_tps)
assert not alpha_equal(func, params_unordered) assert not alpha_equal(func, params_unordered)
params_mismatch = relay.Function([relay.Param(v3, tt3), params_mismatch = relay.Function([v1, v3],
relay.Param(v4, tt2)], tt2, v1, basic_tps)
tt2, v4, basic_tps)
assert not alpha_equal(func, params_mismatch) assert not alpha_equal(func, params_mismatch)
# also would not typecheck # also would not typecheck
...@@ -376,7 +360,10 @@ def test_call_alpha_equal(): ...@@ -376,7 +360,10 @@ def test_call_alpha_equal():
def test_let_alpha_equal(): def test_let_alpha_equal():
tt1 = relay.TensorType((), "float32")
tt2 = relay.TensorType((), "int8")
v1 = relay.Var("v1") v1 = relay.Var("v1")
v1_wtype = relay.Var("v1", tt1)
v2 = relay.Var("v2") v2 = relay.Var("v2")
v3 = relay.Var("v3") v3 = relay.Var("v3")
...@@ -394,14 +381,13 @@ def test_let_alpha_equal(): ...@@ -394,14 +381,13 @@ def test_let_alpha_equal():
assert not alpha_equal(let, different_body) assert not alpha_equal(let, different_body)
# specified types must match # specified types must match
tt1 = relay.TensorType((), "float32")
tt2 = relay.TensorType((), "int8") let_with_type = relay.Let(v1_wtype, convert(2), v1_wtype)
let_with_type = relay.Let(v1, convert(2), v1, tt1) same_type = relay.Let(v1_wtype, convert(2), v1_wtype)
same_type = relay.Let(v1, convert(2), v1, tt1)
assert alpha_equal(let_with_type, same_type) assert alpha_equal(let_with_type, same_type)
assert not alpha_equal(let, let_with_type) assert not alpha_equal(let, let_with_type)
v2 = relay.Var("v1", tt2)
different_type = relay.Let(v1, convert(2), v1, tt2) different_type = relay.Let(v2, convert(2), v2)
assert not alpha_equal(let_with_type, different_type) assert not alpha_equal(let_with_type, different_type)
...@@ -437,16 +423,13 @@ if __name__ == "__main__": ...@@ -437,16 +423,13 @@ if __name__ == "__main__":
test_tensor_type_alpha_equal() test_tensor_type_alpha_equal()
test_incomplete_type_alpha_equal() test_incomplete_type_alpha_equal()
test_constant_alpha_equal() test_constant_alpha_equal()
test_type_param_alpha_equal()
test_func_type_alpha_equal() test_func_type_alpha_equal()
test_tuple_type_alpha_equal() test_tuple_type_alpha_equal()
test_type_relation_alpha_equal() test_type_relation_alpha_equal()
test_constant_alpha_equal() test_constant_alpha_equal()
test_var_alpha_equal()
test_global_var_alpha_equal() test_global_var_alpha_equal()
test_tuple_alpha_equal() test_tuple_alpha_equal()
test_tuple_get_item_alpha_equal() test_tuple_get_item_alpha_equal()
test_param_alpha_equal()
test_function_alpha_equal() test_function_alpha_equal()
test_call_alpha_equal() test_call_alpha_equal()
test_let_alpha_equal() test_let_alpha_equal()
......
...@@ -28,17 +28,17 @@ e = env() ...@@ -28,17 +28,17 @@ e = env()
def test_let(): def test_let():
orig = relay.Let(e.x, e.y, e.z, e.tt) orig = relay.Let(e.x, e.y, e.z)
assert alpha_equal(dead_code_elimination(orig), e.z) assert alpha_equal(dead_code_elimination(orig), e.z)
def test_used_let(): def test_used_let():
orig = relay.Let(e.a, e.b, relay.Let(e.c, e.d, e.c, e.tt), e.tt) orig = relay.Let(e.a, e.b, relay.Let(e.c, e.d, e.c))
assert alpha_equal(dead_code_elimination(orig), relay.Let(e.c, e.d, e.c, e.tt)) assert alpha_equal(dead_code_elimination(orig), relay.Let(e.c, e.d, e.c))
def test_chain_unused_let(): def test_chain_unused_let():
orig = relay.Let(e.a, e.b, relay.Let(e.c, e.d, e.e, e.tt), e.tt) orig = relay.Let(e.a, e.b, relay.Let(e.c, e.d, e.e))
assert alpha_equal(dead_code_elimination(orig), e.e) assert alpha_equal(dead_code_elimination(orig), e.e)
...@@ -56,19 +56,17 @@ def test_recursion(): ...@@ -56,19 +56,17 @@ def test_recursion():
f(2, 10000); f(2, 10000);
""" """
f = relay.Var("f") f = relay.Var("f")
n = relay.Var("n") n = relay.Var("n", e.int32)
np = relay.Param(n, e.int32) data = relay.Var("data", e.float32)
data = relay.Var("data")
datap = relay.Param(data, e.float32)
funcbody = relay.If(equal(n, convert(0)), data, f(subtract(n, convert(1.0)), log(data))) funcbody = relay.If(equal(n, convert(0)), data, f(subtract(n, convert(1.0)), log(data)))
value = relay.Function([np, datap], e.float32, funcbody, []) value = relay.Function([n, data], e.float32, funcbody, [])
orig = relay.Let(f, funcbody, f(convert(2.0), convert(10000.0)), e.float32) orig = relay.Let(f, funcbody, f(convert(2.0), convert(10000.0)))
assert alpha_equal(dead_code_elimination(orig), orig) assert alpha_equal(dead_code_elimination(orig), orig)
assert alpha_equal(dead_code_elimination(relay.Let(f, funcbody, e.three, e.float32)), e.three) assert alpha_equal(dead_code_elimination(relay.Let(f, funcbody, e.three)), e.three)
def test_op_let(): def test_op_let():
assert alpha_equal(dead_code_elimination(add(relay.Let(e.a, e.one, e.three, e.float32), e.two)), add(e.three, e.two)) assert alpha_equal(dead_code_elimination(add(relay.Let(e.a, e.one, e.three), e.two)), add(e.three, e.two))
def test_if(): def test_if():
...@@ -80,7 +78,7 @@ def test_tuple_get_item(): ...@@ -80,7 +78,7 @@ def test_tuple_get_item():
t = relay.Var('t') t = relay.Var('t')
g = relay.TupleGetItem(t, 0) g = relay.TupleGetItem(t, 0)
assert alpha_equal(dead_code_elimination(g), g) assert alpha_equal(dead_code_elimination(g), g)
assert alpha_equal(dead_code_elimination(relay.TupleGetItem(relay.Let(e.a, e.one, t, e.float32), 0)), g) assert alpha_equal(dead_code_elimination(relay.TupleGetItem(relay.Let(e.a, e.one, t), 0)), g)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -3,16 +3,17 @@ from tvm import relay ...@@ -3,16 +3,17 @@ from tvm import relay
from tvm.relay.ir_pass import free_vars, free_type_vars from tvm.relay.ir_pass import free_vars, free_type_vars
def test_free_vars(): def test_free_vars():
x = relay.Var("x") ty = relay.TensorType([], "int32")
x = relay.Var("x", ty)
fvx = free_vars(x) fvx = free_vars(x)
assert len(fvx) == 1 assert len(fvx) == 1
assert fvx[0] == x assert fvx[0] == x
v = relay.Constant(tvm.nd.array(10)) v = relay.Constant(tvm.nd.array(10))
ty = relay.TensorType([], "int32")
let = relay.Let(x, v, x, ty) let = relay.Let(x, v, x)
fvx = free_vars(let) fvx = free_vars(let)
assert len(free_vars(let)) == 0 assert len(free_vars(let)) == 0
f = relay.Function([relay.Param(x, ty)], ty, x) f = relay.Function([x], ty, x)
assert len(free_vars(f)) == 0 assert len(free_vars(f)) == 0
...@@ -29,9 +30,9 @@ def test_tuple(): ...@@ -29,9 +30,9 @@ def test_tuple():
def test_free_type_vars(): def test_free_type_vars():
tp = relay.TypeParam("") tp = relay.TypeParam("")
ty = relay.TupleType([tp, relay.TensorType([], "int32")]) ty = relay.TupleType([tp, relay.TensorType([], "int32")])
x = relay.Var("x") x = relay.Var("x", ty)
y = relay.Var("y") y = relay.Var("y")
let = relay.Let(x, y, x, ty) let = relay.Let(x, y, x)
fvl = free_vars(let) fvl = free_vars(let)
assert len(fvl) == 1 assert len(fvl) == 1
assert fvl[0] == y assert fvl[0] == y
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment