Unverified Commit 0b4cc050 by Tianqi Chen Committed by GitHub

[RELAY][IR] Move type_annotation to Var, remove Param (#1900)

parent 53428606
......@@ -118,17 +118,27 @@ class Var;
/*! \brief Container for Var */
class VarNode : public ExprNode {
public:
/*! \brief The name of the variable, this only acts as a hint to the user,
* and is not used for equality.
/*!
* \brief The name of the variable,
* this only acts as a hint to the user,
* and is not used for equality.
*/
std::string name_hint;
/*!
* \brief type annotaion of the variable.
* This field records user provided type annotation of the Var.
* This field is optional and can be None.
*/
Type type_annotation;
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("name_hint", &name_hint);
v->Visit("type_annotation", &type_annotation);
v->Visit("_checked_type_", &checked_type_);
}
TVM_DLL static Var make(std::string name_hint);
TVM_DLL static Var make(std::string name_hint,
Type type_annotation);
static constexpr const char* _type_key = "relay.Var";
TVM_DECLARE_NODE_TYPE_INFO(VarNode, ExprNode);
......@@ -163,32 +173,6 @@ class GlobalVarNode : public ExprNode {
RELAY_DEFINE_NODE_REF(GlobalVar, GlobalVarNode, Expr);
/*!
* \brief Function parameter declaration.
*/
class Param;
/*! \brief A parameter. */
class ParamNode : public ExprNode {
public:
/*! \brief The variable */
Var var;
/*! \brief The type of the parameter */
Type type;
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("var", &var);
v->Visit("type", &type);
v->Visit("span", &span);
}
TVM_DLL static Param make(Var var, Type type);
static constexpr const char* _type_key = "relay.Param";
TVM_DECLARE_NODE_TYPE_INFO(ParamNode, ExprNode);
};
RELAY_DEFINE_NODE_REF(Param, ParamNode, Expr);
/*!
* \brief Function (subgraph in computational graph)
*/
class Function;
......@@ -196,7 +180,7 @@ class Function;
class FunctionNode : public ExprNode {
public:
/*! \brief Function parameters */
tvm::Array<Param> params;
tvm::Array<Var> params;
/*! \brief User annotated return type of the function. */
Type ret_type;
/*!
......@@ -224,10 +208,18 @@ class FunctionNode : public ExprNode {
v->Visit("_checked_type_", &checked_type_);
}
Type fn_type() const;
/*!
* \brief Return the derived function annotation of this expression.
*
* \return The function type annotation.
* \note The function type annotation can contain IncompleteType.
*/
TVM_DLL FuncType func_type_annotation() const;
TVM_DLL static Function make(tvm::Array<Param> params, Type ret_type,
Expr body, tvm::Array<TypeParam> ty_params);
TVM_DLL static Function make(tvm::Array<Var> params,
Type ret_type,
Expr body,
tvm::Array<TypeParam> ty_params);
static constexpr const char* _type_key = "relay.Function";
TVM_DECLARE_NODE_TYPE_INFO(FunctionNode, ExprNode);
......@@ -289,7 +281,7 @@ class CallNode : public ExprNode {
TVM_DLL static Call make(Expr op,
Array<Expr> args,
Attrs attrs = Attrs(),
Array<Type> ty_args = Array<Type>());
Array<Type> type_args = Array<Type>());
static constexpr const char* _type_key = "relay.Call";
TVM_DECLARE_NODE_TYPE_INFO(CallNode, ExprNode);
......@@ -318,19 +310,16 @@ class LetNode : public ExprNode {
Expr value;
/*! \brief The body of the let binding */
Expr body;
/*! \brief Type annotation of value, this can be null */
Type value_type;
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("var", &var);
v->Visit("value", &value);
v->Visit("body", &body);
v->Visit("value_type", &value_type);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}
TVM_DLL static Let make(Var var, Expr value, Expr body, Type value_type);
TVM_DLL static Let make(Var var, Expr value, Expr body);
static constexpr const char* _type_key = "relay.Let";
TVM_DECLARE_NODE_TYPE_INFO(LetNode, ExprNode);
......@@ -376,11 +365,11 @@ class IfNode : public ExprNode {
RELAY_DEFINE_NODE_REF(If, IfNode, Expr);
/*! \brief Get a field out of a tuple. */
/*! \brief Get index-th field out of a tuple. */
class TupleGetItem;
class TupleGetItemNode : public ExprNode {
public:
/*! \brief The tuple */
/*! \brief The tuple Expression */
Expr tuple;
/*! \brief which value to get */
int index;
......
......@@ -80,7 +80,6 @@ class ExprFunctor<R(const Expr& n, Args...)> {
Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const GlobalVarNode* op,
Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const ParamNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const FunctionNode* op,
Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const CallNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
......@@ -103,7 +102,6 @@ class ExprFunctor<R(const Expr& n, Args...)> {
RELAY_EXPR_FUNCTOR_DISPATCH(TupleNode);
RELAY_EXPR_FUNCTOR_DISPATCH(VarNode);
RELAY_EXPR_FUNCTOR_DISPATCH(GlobalVarNode);
RELAY_EXPR_FUNCTOR_DISPATCH(ParamNode);
RELAY_EXPR_FUNCTOR_DISPATCH(FunctionNode);
RELAY_EXPR_FUNCTOR_DISPATCH(CallNode);
RELAY_EXPR_FUNCTOR_DISPATCH(LetNode);
......@@ -127,7 +125,6 @@ class ExprVisitor : public ::tvm::relay::ExprFunctor<void(const Expr& n)> {
void VisitExpr_(const GlobalVarNode* op) override;
void VisitExpr_(const ConstantNode* op) override;
void VisitExpr_(const TupleNode* op) override;
void VisitExpr_(const ParamNode* op) override;
void VisitExpr_(const FunctionNode* op) override;
void VisitExpr_(const CallNode* op) override;
void VisitExpr_(const LetNode* op) override;
......@@ -151,7 +148,6 @@ class ExprMutator
Expr VisitExpr_(const GlobalVarNode* op) override;
Expr VisitExpr_(const OpNode* op) override;
Expr VisitExpr_(const TupleNode* op) override;
Expr VisitExpr_(const ParamNode* op) override;
Expr VisitExpr_(const FunctionNode* op) override;
Expr VisitExpr_(const CallNode* call_node) override;
Expr VisitExpr_(const LetNode* op) override;
......
......@@ -34,7 +34,6 @@ Constant = expr.Constant
Tuple = expr.Tuple
Var = expr.Var
GlobalVar = expr.GlobalVar
Param = expr.Param
Function = expr.Function
Call = expr.Call
Let = expr.Let
......
......@@ -11,11 +11,11 @@ class Expr(NodeBase):
"""The base type for all Relay expressions."""
@property
def checked_type(self):
"""Get the checked type of relay.
"""Get the checked type of tvm.relay.Expr.
Returns
-------
checked_type : relay.Type
checked_type : tvm.relay.Type
The checked type.
"""
ret = self._checked_type_
......@@ -25,70 +25,97 @@ class Expr(NodeBase):
return ret
def __call__(self, *args):
converted_args = []
for arg in args:
if isinstance(arg, Param):
converted_args.append(arg.var)
else:
converted_args.append(arg)
return Call(self, args, None, None)
@register_relay_node
class Constant(Expr):
"""A constant tensor in Relay, see tvm/relay/type.h for more details.
"""
"""A constant expression in Relay.
Parameters
----------
data : tvm.nd.NDArray
The data content of the constant expression.
"""
def __init__(self, data):
self.__init_handle_by_constructor__(_make.Constant, data)
@register_relay_node
class Tuple(Expr):
"""A hetereogenous sequence of values.
see tvm/relay/type.h for more details.
"""
"""Tuple expression that groups several fields together.
Parameters
----------
fields : List[tvm.relay.Expr]
The fields in the tuple.
"""
def __init__(self, fields):
self.__init_handle_by_constructor__(_make.Tuple, fields)
@register_relay_node
class Var(Expr):
"""A local variable in Relay."""
"""A local variable in Tvm.Relay.
def __init__(self, name_hint):
self.__init_handle_by_constructor__(_make.Var, name_hint)
Local variable can be used to declare input
arguments to a function, or intermediate variables.
Parameters
----------
name_hint: str
The name of the variable.
This name only acts as a hint, and is not used
for equality.
type_annotation: tvm.relay.Type, optional
The type annotation on the variable.
"""
def __init__(self, name_hint, type_annotation=None):
self.__init_handle_by_constructor__(
_make.Var, name_hint, type_annotation)
@register_relay_node
class GlobalVar(Expr):
"""A global variable in Relay."""
"""A global variable in Tvm.Relay.
GlobalVar is used to refer to the global functions
stored in the environment.
Parameters
----------
name_hint: str
The name of the variable.
"""
def __init__(self, name_hint):
self.__init_handle_by_constructor__(_make.GlobalVar, name_hint)
@register_relay_node
class Param(Expr):
"""A function type in Relay, see tvm/relay/type.h for more details.
"""
class Function(Expr):
"""A function declaration expression.
def __init__(self, var, ty):
self.__init_handle_by_constructor__(_make.Param, var, ty)
Parameters
----------
params: List[tvm.relay.Var]
List of input parameters to the function.
ret_type: tvm.relay.Type
The return type annotation of the function.
@register_relay_node
class Function(Expr):
"""A function in Relay, see tvm/relay/expr.h for more details."""
body: tvm.relay.Expr
The body of the function.
type_params: Optional[List[tvm.relay.TypeParam]]
The additional type parameters, this is only
used in advanced usecase of template functions.
"""
def __init__(self,
params,
ret_type,
body,
type_params=None
):
type_params=None):
if type_params is None:
type_params = convert([])
......@@ -98,39 +125,87 @@ class Function(Expr):
@register_relay_node
class Call(Expr):
"""A function call in Relay, see tvm/relay/expr.h for more details."""
"""Function call node in Relay.
Call node corresponds the operator application node
in computational graph terminology.
Parameters
----------
op: tvm.relay.Op or any tvm.relay.Expr with function type.
The operation to be called.
def __init__(self, op, args, attrs, ty_args=None):
if not ty_args:
ty_args = []
args: List[tvm.relay.Expr]
The arguments to the call.
attrs: Optional[tvm.Attrs]
Attributes to the call, can be None
type_args: Optional[List[tvm.relay.Type]]
The additional type arguments, this is only
used in advanced usecase of template functions.
"""
def __init__(self, op, args, attrs=None, type_args=None):
if not type_args:
type_args = []
self.__init_handle_by_constructor__(
_make.Call, op, args, attrs, ty_args)
_make.Call, op, args, attrs, type_args)
@register_relay_node
class Let(Expr):
"""A variable bindings in Relay, see tvm/relay/expr.h for more details."""
"""Let variable binding expression.
Parameters
----------
var: tvm.relay.Var
The local variable to be bound.
value: tvm.relay.Expr
The value to be bound.
def __init__(self, var, value, body, value_type=None):
body: tvm.relay.Expr
The body of the let binding.
"""
def __init__(self, var, value, body):
self.__init_handle_by_constructor__(
_make.Let, var, value, body, value_type)
_make.Let, var, value, body)
@register_relay_node
class If(Expr):
"""A conditional expression in Relay, see tvm/relay/expr.h for more details."""
"""A conditional expression in Relay.
Parameters
----------
cond: tvm.relay.Expr
The condition.
def __init__(self, cond, true_value, false_value):
true_branch: tvm.relay.Expr
The expression evaluated when condition is true.
false_branch: tvm.relay.Expr
The expression evaluated when condition is false.
"""
def __init__(self, cond, true_branch, false_branch):
self.__init_handle_by_constructor__(
_make.If, cond, true_value, false_value)
_make.If, cond, true_branch, false_branch)
@register_relay_node
class TupleGetItem(Expr):
"""An expression that get field from tuple in Relay, see tvm/relay/expr.h for more details."""
"""Get index-th item from a tuple.
Parameters
----------
tuple_value: tvm.relay.Expr
The input tuple expression.
def __init__(self, tuple_, index):
index: int
The index.
"""
def __init__(self, tuple_value, index):
self.__init_handle_by_constructor__(
_make.TupleGetItem, tuple_, index)
_make.TupleGetItem, tuple_value, index)
debug_print = _expr._debug_print
......@@ -7,7 +7,7 @@ from collections import OrderedDict
import numpy as np
import tvm
from .ty import Type, FuncType, TensorType
from .expr import Expr, Constant, Let, Var, Param, Function, If
from .expr import Expr, Constant, Let, Var, Function, If
from .env import Environment
......@@ -98,7 +98,7 @@ class PartialFunc(object):
self.type_params = type_params
def param_ids(self):
return [p.var for p in self.params]
return [p for p in self.params]
def to_func(self):
"""Converts a PartialFunc into a :py:class:`~relay.Function`."""
......@@ -113,9 +113,8 @@ class PartialFunc(object):
def _mk_let(bindings, ret_value):
let_expr = ret_value
for var, (value, ty) in reversed(list(bindings.items())):
let_expr = Let(var, value, let_expr, ty)
for var, value in reversed(list(bindings.items())):
let_expr = Let(var, value, let_expr)
return let_expr
......@@ -168,15 +167,12 @@ class IRBuilder(object):
#pylint: disable=invalid-name
def bind(self, name, value, ty):
lv = Var(name)
lv = Var(name, ty)
self.scopes[-1][name] = lv
self.bindings[-1][lv] = (value, ty)
self.bindings[-1][lv] = value
return lv
def let(self, name, value, value_type=None):
if isinstance(value, Param):
value = value.var
if not isinstance(value, Expr):
value = convert(value)
......@@ -185,23 +181,18 @@ class IRBuilder(object):
def _convert_params(self, raw_params):
relay_params = []
for raw_param in raw_params:
if isinstance(raw_param, Param):
var = raw_param.var
if isinstance(raw_param, Var):
param = raw_param
elif isinstance(raw_param, tuple):
var, ty = raw_param
if isinstance(var, str):
var = Var(var)
ty = _convert_type(ty)
param = Param(var, ty)
elif isinstance(param, str):
var = Var(raw_param)
ty = None
param = Param(var, ty)
param = Var(var, ty)
elif isinstance(raw_param, str):
param = Var(raw_param, None)
else:
raise Exception("unknown parameter type")
self.scopes[-1][var.name_hint] = var
self.scopes[-1][param.name_hint] = param
relay_params.append(param)
return relay_params
......@@ -265,7 +256,7 @@ class IRBuilder(object):
else:
ty = _convert_type(ty)
return Param(Var(name), ty)
return Var(name, ty)
def global_var(self, name):
# type: (str) -> GlobalVar
......
......@@ -96,7 +96,9 @@ class TypeDocifier : private TypeFunctor<Doc(const Type& n)> {
}
std::vector<Doc> DocifyTypeParam(const tvm::Array<TypeParam>& arr) {
return MapDocify<TypeParam>(arr, [=](const TypeParam& tp) { return Docify(tp); });
return MapDocify<TypeParam>(arr, [=](const TypeParam& tp) {
return Docify(tp);
});
}
std::vector<Doc> DocifyTypeConstraint(const tvm::Array<TypeConstraint>& arr) {
......@@ -188,10 +190,11 @@ class ExprDocifier : private ExprFunctor<Doc(const Expr& n)> {
return vec;
}
std::vector<Doc> DocifyParamArray(const tvm::Array<Param>& arr) {
std::vector<Doc> DocifyParamArray(const tvm::Array<Var>& arr) {
std::vector<Doc> vec;
for (size_t i = 0; i < arr.size(); ++i) {
vec.push_back(Docify(arr[i]));
for (Var param : arr) {
vec.emplace_back(TypeAnnotation(DocOfStr(VarName(param)),
param->type_annotation));
}
return vec;
}
......@@ -212,10 +215,6 @@ class ExprDocifier : private ExprFunctor<Doc(const Expr& n)> {
return DocOfStr(g->name_hint);
}
Doc VisitExpr_(const ParamNode* p) final {
return TypeAnnotation(Docify(p->var), p->type);
}
Doc VisitExpr_(const FunctionNode* f) final {
return Group(TypeAnnotation(Seq("(", DocifyParamArray(f->params), ")"), f->ret_type) + Sep() +
DocOfStr("=>") + Sep() +
......@@ -227,7 +226,8 @@ class ExprDocifier : private ExprFunctor<Doc(const Expr& n)> {
}
Doc VisitExpr_(const LetNode* l) final {
return Group(DocOfStr("let") + Sep() + TypeAnnotation(Docify(l->var), l->value_type) + Sep() +
return Group(DocOfStr("let") + Sep() +
TypeAnnotation(Docify(l->var), l->var->type_annotation) + Sep() +
DocOfStr("=") + Sep() + Docify(l->value) + DocOfStr(";") + Endl() +
Docify(l->body));
}
......
......@@ -54,20 +54,26 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
p->stream << "Tuple(" << node->fields << ")";
});
Var VarNode::make(std::string name_hint) {
Var VarNode::make(std::string name_hint, Type type_annotation) {
NodePtr<VarNode> n = make_node<VarNode>();
n->name_hint = std::move(name_hint);
n->type_annotation = std::move(type_annotation);
return Var(n);
}
TVM_REGISTER_API("relay._make.Var")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = VarNode::make(args[0]);
*ret = VarNode::make(args[0], args[1]);
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<VarNode>([](const VarNode *node, tvm::IRPrinter *p) {
p->stream << "Var(" << node->name_hint << ")";
p->stream << "Var(" << node->name_hint;
if (node->type_annotation.defined()) {
p->stream << ", ty=";
p->print(node->type_annotation);
}
p->stream << ")";
});
GlobalVar GlobalVarNode::make(std::string name_hint) {
......@@ -86,24 +92,10 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
p->stream << "GlobalVar(" << node->name_hint << ")";
});
Param ParamNode::make(Var var, Type type) {
NodePtr<ParamNode> n = make_node<ParamNode>();
n->var = std::move(var);
n->type = std::move(type);
return Param(n);
}
TVM_REGISTER_API("relay._make.Param")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = ParamNode::make(args[0], args[1]);
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<ParamNode>([](const ParamNode *node, tvm::IRPrinter *p) {
p->stream << "Param(" << node->var << ", " << node->type << ")";
});
Function FunctionNode::make(tvm::Array<Param> params, Type ret_type, Expr body,
Function FunctionNode::make(tvm::Array<Var> params,
Type ret_type,
Expr body,
tvm::Array<TypeParam> type_params) {
NodePtr<FunctionNode> n = make_node<FunctionNode>();
n->params = std::move(params);
......@@ -113,12 +105,11 @@ Function FunctionNode::make(tvm::Array<Param> params, Type ret_type, Expr body,
return Function(n);
}
Type FunctionNode::fn_type() const {
FuncType FunctionNode::func_type_annotation() const {
Array<Type> param_types;
for (auto param : this->params) {
param_types.push_back(param->type);
param_types.push_back(param->type_annotation);
}
return FuncTypeNode::make(param_types, this->ret_type, this->type_params, {});
}
......@@ -155,24 +146,23 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
<< node->attrs << ", " << node->type_args << ")";
});
Let LetNode::make(Var var, Expr value, Expr body, Type value_type) {
Let LetNode::make(Var var, Expr value, Expr body) {
NodePtr<LetNode> n = make_node<LetNode>();
n->var = std::move(var);
n->value = std::move(value);
n->body = std::move(body);
n->value_type = std::move(value_type);
return Let(n);
}
TVM_REGISTER_API("relay._make.Let")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = LetNode::make(args[0], args[1], args[2], args[3]);
});
*ret = LetNode::make(args[0], args[1], args[2]);
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<LetNode>([](const LetNode *node, tvm::IRPrinter *p) {
p->stream << "LetNode(" << node->var << ", " << node->value
<< ", " << node->body << ", " << node->value_type << ")";
<< ", " << node->body << ")";
});
If IfNode::make(Expr cond, Expr true_branch, Expr false_branch) {
......
......@@ -24,6 +24,16 @@ Expr ExprMutator::Mutate(const Expr& expr) {
}
Expr ExprMutator::VisitExpr_(const VarNode* op) {
// NOTE: var will only be mutated once
// Thanks to the memo and reused during rewriting if necessary.
// It is safe to assume that the
if (op->type_annotation.defined()) {
auto type = this->VisitType(op->type_annotation);
if (!op->type_annotation.same_as(type)) {
return VarNode::make(op->name_hint, type);
}
}
// default case return self.
return GetRef<Expr>(op);
}
......@@ -55,16 +65,6 @@ Expr ExprMutator::VisitExpr_(const TupleNode* op) {
}
}
Expr ExprMutator::VisitExpr_(const ParamNode* op) {
Var var = Downcast<Var>(this->Mutate(op->var));
auto type = this->VisitType(op->type);
if (op->var.same_as(var) && op->type.same_as(type)) {
return GetRef<Expr>(op);
} else {
return ParamNode::make(var, type);
}
}
Expr ExprMutator::VisitExpr_(const FunctionNode* op) {
tvm::Array<TypeParam> ty_params;
bool all_ty_params_changed = true;
......@@ -75,10 +75,10 @@ Expr ExprMutator::VisitExpr_(const FunctionNode* op) {
all_ty_params_changed &= new_ty_param.same_as(ty_param);
}
tvm::Array<Param> params;
tvm::Array<Var> params;
bool all_params_changed = true;
for (auto param : op->params) {
Param new_param = Downcast<Param>(this->Mutate(param));
Var new_param = Downcast<Var>(this->Mutate(param));
params.push_back(new_param);
all_params_changed &= param.same_as(new_param);
}
......@@ -123,17 +123,15 @@ Expr ExprMutator::VisitExpr_(const CallNode* call_node) {
Expr ExprMutator::VisitExpr_(const LetNode* op) {
Var var = Downcast<Var>(this->Mutate(op->var));
auto type = this->VisitType(op->value_type);
auto value = this->Mutate(op->value);
auto body = this->Mutate(op->body);
if (var.same_as(op->var) &&
type.same_as(op->value_type) &&
value.same_as(op->value) &&
body.same_as(op->body)) {
return GetRef<Expr>(op);
} else {
return LetNode::make(var, value, body, type);
return LetNode::make(var, value, body);
}
}
......@@ -162,6 +160,9 @@ Expr ExprMutator::VisitExpr_(const TupleGetItemNode* g) {
Type ExprMutator::VisitType(const Type& t) { return t; }
void ExprVisitor::ExprVisitor::VisitExpr_(const VarNode* op) {
if (op->type_annotation.defined()) {
this->VisitType(op->type_annotation);
}
}
void ExprVisitor::ExprVisitor::VisitExpr_(const GlobalVarNode* op) {
......@@ -176,10 +177,6 @@ void ExprVisitor::ExprVisitor::VisitExpr_(const TupleNode* op) {
}
}
void ExprVisitor::ExprVisitor::VisitExpr_(const ParamNode* op) {
this->VisitExpr(op->var);
}
void ExprVisitor::ExprVisitor::VisitExpr_(const FunctionNode* op) {
for (auto param : op->params) {
this->VisitExpr(param);
......
......@@ -252,15 +252,6 @@ struct AlphaEq : ExprFunctor<void(const Expr&, const Expr&)> {
}
}
void VisitExpr_(const ParamNode* p1, const Expr& e2) final {
if (const ParamNode* p2 = e2.as<ParamNode>()) {
eq_map.Set(p1->var, p2->var);
equal = equal && AlphaEqual(p1->type, p2->type);
} else {
equal = false;
}
}
void VisitExpr_(const FunctionNode* func1, const Expr& e2) final {
if (const FunctionNode* func2 = e2.as<FunctionNode>()) {
if (func1->params.size() != func2->params.size()) {
......@@ -273,9 +264,10 @@ struct AlphaEq : ExprFunctor<void(const Expr&, const Expr&)> {
return;
}
for (size_t i = 0U; i < func1->params.size(); i++) {
this->VisitExpr(func1->params[i], func2->params[i]);
for (size_t i = 0; i < func1->params.size(); ++i) {
MergeVarDecl(func1->params[i], func2->params[i]);
}
if (!equal) return;
for (size_t i = 0U; i < func1->type_params.size(); i++) {
equal = equal && AlphaEqual(func1->type_params[i], func2->type_params[i]);
......@@ -332,19 +324,9 @@ struct AlphaEq : ExprFunctor<void(const Expr&, const Expr&)> {
void VisitExpr_(const LetNode* op, const Expr& e2) final {
if (const LetNode* let = e2.as<LetNode>()) {
eq_map.Set(op->var, let->var);
MergeVarDecl(op->var, let->var);
this->VisitExpr(op->value, let->value);
this->VisitExpr(op->body, let->body);
// value_type should match as well (including nulls)
if (op->value_type.defined() != let->value_type.defined()) {
equal = false;
return;
}
if (op->value_type.defined()) {
equal = equal && AlphaEqual(op->value_type, let->value_type);
}
} else {
equal = false;
}
......@@ -388,6 +370,20 @@ struct AlphaEq : ExprFunctor<void(const Expr&, const Expr&)> {
equal = false;
}
}
private:
void MergeVarDecl(const Var& var1, const Var& var2) {
if (var1->type_annotation.defined() != var2->type_annotation.defined()) {
equal = false;
return;
}
if (var1->type_annotation.defined() &&
!AlphaEqual(var1->type_annotation, var2->type_annotation)) {
equal = false;
return;
}
eq_map.Set(var1, var2);
}
};
bool AlphaEqual(const Expr& e1, const Expr& e2) {
......
......@@ -54,12 +54,7 @@ class CalcDep : private ExprMutator {
}
private:
struct Binder {
Type t;
Expr e;
Binder(const Type& t, const Expr& e) : t(t), e(e) { }
};
using VarMap = std::unordered_map<Var, Binder, NodeHash, NodeEqual>;
using VarMap = std::unordered_map<Var, Expr, NodeHash, NodeEqual>;
VarMap var_map_;
Expr VisitExpr_(const IfNode* i) final {
......@@ -74,9 +69,7 @@ class CalcDep : private ExprMutator {
}
Expr VisitExpr_(const LetNode* l) final {
var_map_.insert(std::pair<Var, Binder>(l->var,
Binder(l->value_type,
Eliminate(l->value))));
var_map_[l->var] = Eliminate(l->value);
return VisitExpr(l->body);
}
......@@ -92,15 +85,16 @@ class CalcDep : private ExprMutator {
explicit GenLet(const VarMap& var_map) : var_map_(var_map) { }
friend CalcDep;
void VisitExpr_(const VarNode* vn) final {
Var v = GetRef<Var>(vn);
if (var_map_.count(v) != 0) {
auto val = var_map_.at(v);
var_map_.erase(v);
void VisitExpr_(const VarNode* vnode) final {
Var v = GetRef<Var>(vnode);
auto it = var_map_.find(v);
if (it != var_map_.end()) {
Expr expr = it->second;
var_map_.erase(it);
// erase before visit to handle letrec
VisitExpr(val.e);
VisitExpr(expr);
// visit before push back so the dependency of dependency is before the dependency
lets_.Push(v, val.t, val.e);
lets_.Push(v, expr);
}
}
};
......
......@@ -26,57 +26,46 @@ namespace relay {
*/
class LetList {
public:
/*! \brief insert a binding.
/*!
* \brief insert a binding.
*
* \param pv the var of the binding.
* \param pv the var of the binding.
*
* \param ty the type of the binding.
* \param expr the value of the binding.
*
* \param expr the value of the binding.
*
* \return a Var that hold the inserted expr.
* \return a Var that hold the inserted expr.
*/
Var Push(const Var& pv, const Type& ty, const Expr& expr) {
std::tuple<Var, Type, Expr> tuple(pv, ty, expr);
lets_.push_back(tuple);
Var Push(Var pv, Expr expr) {
lets_.emplace_back(std::make_pair(pv, expr));
return pv;
}
/*! \brief insert a binding.
/*!
* \brief insert a binding.
*
* \param ty the type of the binding.
* \param ty the type of the binding.
*
* \param expr the value of the binding.
* \param expr the value of the binding.
*
* \return a Var that hold the inserted expr.
*/
Var Push(const Type& ty, const Expr& expr) {
return Push(VarNode::make("x"), ty, expr);
}
/*! \brief insert a binding.
*
* \param pv the var of the binding.
*
* \param expr the value of the binding.
*
* \return a Var that hold the inserted expr.
* \return a Var that hold the inserted expr.
*/
Var Push(const Var& pv, const Expr& expr) {
return Push(pv, IncompleteTypeNode::make(TypeParamNode::kType), expr);
Var Push(Type ty, Expr expr) {
return Push(VarNode::make("x", ty), expr);
}
/*! \brief insert a binding.
/*!
* \brief insert a binding.
*
* \param expr the value of the binding.
*
* \return a Var that hold the inserted expr.
*/
Var Push(const Expr& expr) {
Var Push(Expr expr) {
return Push(IncompleteTypeNode::make(TypeParamNode::kType), expr);
}
/*! \brief wrap an expr around the LetList.
/*!
* \brief wrap an expr around the LetList.
*
* \param body the Expression to be wrapped around.
*
......@@ -85,7 +74,7 @@ class LetList {
Expr Get(const Expr& body) const {
Expr ret = body;
for (auto rit = lets_.rbegin(); rit != lets_.rend(); ++rit) {
ret = LetNode::make(std::get<0>(*rit), std::get<2>(*rit), ret, std::get<1>(*rit));
ret = LetNode::make(std::get<0>(*rit), std::get<1>(*rit), ret);
}
return ret;
}
......@@ -118,7 +107,7 @@ class LetList {
}
private:
std::vector<std::tuple<Var, Type, Expr> > lets_;
std::vector<std::pair<Var, Expr> > lets_;
};
} // namespace relay
......
......@@ -87,15 +87,11 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
// Visitor logics
Type VisitExpr_(const VarNode* op) final {
// The type of Var can already been lookedup in type_map_;
LOG(FATAL) << "Cannot find binding for var " << GetRef<Var>(op);
return Type();
}
Type VisitExpr_(const ParamNode* op) final {
// directly handled by Funtion
LOG(FATAL) << "not reached";
return Type();
if (op->type_annotation.defined()) {
return op->type_annotation;
} else {
return IncompleteTypeNode::make(TypeParamNode::kType);
}
}
Type VisitExpr_(const GlobalVarNode* op) final {
......@@ -139,11 +135,11 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
Type VisitExpr_(const LetNode* op) final {
Type vtype = GetType(op->value);
if (op->value_type.defined()) {
vtype = Unify(vtype, op->value_type, op->span);
if (op->var->type_annotation.defined()) {
vtype = Unify(vtype, op->var->type_annotation, op->span);
}
CHECK(!type_map_.count(op->var));
// NOTE: no scoping is necessary becase var are unique in program
// NOTE: no scoping is necessary because var are unique in program
type_map_[op->var] = vtype;
return GetType(op->body);
}
......@@ -256,8 +252,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
Type VisitExpr_(const FunctionNode* f) final {
for (auto param : f->params) {
type_map_[param->var] = param->type;
type_map_[param] = param->type;
GetType(param);
}
Type rtype = GetType(f->body);
// Run solver using the currently known information
......@@ -265,8 +260,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
// Trying to resolve
Array<Type> arg_types;
for (size_t i = 0; i < f->params.size(); ++i) {
Param param = f->params[i];
Type atype = solver_.Resolve(param->type);
Type atype = solver_.Resolve(GetType(f->params[i]));
CHECK(atype.as<IncompleteTypeNode>() == nullptr)
<< "Cannot resolve type of " << i
<< "-th parameter of function at" << f->span;
......@@ -311,9 +305,6 @@ class TypeInferencer::Resolver : public ExprMutator {
return AttachCheckedType(op);
}
Expr VisitExpr_(const ParamNode* op) final {
return ExprMutator::VisitExpr_(op);
}
Expr VisitExpr_(const FunctionNode* op) final {
return AttachCheckedType(op);
......@@ -380,7 +371,7 @@ Expr InferType(const Environment& env,
const GlobalVar& var,
const Function& func) {
Function func_copy = Function(make_node<FunctionNode>(*func.operator->()));
func_copy->checked_type_ = func_copy->fn_type();
func_copy->checked_type_ = func_copy->func_type_annotation();
env->functions.Set(var, func_copy);
Expr func_ret = TypeInferencer(env).Infer(func_copy);
auto map_node = env->functions.CopyOnWrite();
......
......@@ -50,14 +50,17 @@ class FreeVar : public ExprVisitor {
if (bound_vars.count(var) == 0) {
free_vars.insert(var);
}
if (v->type_annotation.defined()) {
VisitType(v->type_annotation);
}
}
void VisitExpr_(const FunctionNode *f) final {
for (const auto& tp : f->type_params) {
bound_types.insert(tp);
}
for (const auto& p : f->params) {
bound_vars.insert(p->var);
for (const auto& param : f->params) {
bound_vars.insert(param);
}
VisitExpr(f->body);
VisitType(f->ret_type);
......@@ -67,7 +70,6 @@ class FreeVar : public ExprVisitor {
bound_vars.insert(l->var);
VisitExpr(l->value);
VisitExpr(l->body);
VisitType(l->value_type);
}
public:
......
......@@ -34,8 +34,8 @@ class WellFormedChecker : private ExprVisitor {
}
void VisitExpr_(const FunctionNode * f) final {
for (const Param & p : f->params) {
Check(p->var);
for (const Var & param : f->params) {
Check(param);
}
CheckWellFormed(f->body);
}
......
......@@ -14,7 +14,6 @@ def test_let():
assert var == prog.body
assert isinstance(value, Constant)
assert value.data.asnumpy() == np.array(1)
assert prog.value_type == None
if __name__ == "__main__":
test_let()
......@@ -49,18 +49,11 @@ def test_global_var():
show(gv)
def test_param():
lv = relay.Var('x')
ty = None
param = relay.Param(lv, ty)
show(lv)
def test_function():
param_names = ['a', 'b', 'c', 'd']
params = tvm.convert([relay.Param(relay.Var(n), None) for n in param_names])
params = tvm.convert([relay.Var(n) for n in param_names])
ret_type = None
body = params[0].var
body = params[0]
type_params = tvm.convert([])
fn = relay.Function(params, ret_type, body, type_params)
show(fn)
......@@ -76,11 +69,11 @@ def test_call():
def test_let():
lv = relay.Var('x')
ty = relay.ty.TensorType((10, 20), 'float32')
lv = relay.Var('x', ty)
arr = tvm.nd.array(10)
value = relay.Constant(arr)
let = relay.Let(lv, value, lv, ty)
let = relay.Let(lv, value, lv)
show(let)
......
......@@ -99,10 +99,16 @@ def test_tuple():
def test_local_var():
name_hint = 's'
lv = relay.Var(name_hint)
lv.name_hint == name_hint
assert lv.name_hint == name_hint
assert lv.type_annotation is None
# assert lv.span == None todo(@jroesch): what do we do about spans
str(lv)
t1 = relay.ty.TensorType((), "float")
lv = relay.Var(name_hint, t1)
assert lv.name_hint == name_hint
assert lv.type_annotation == t1
def test_global_var():
name_hint = 'g'
......@@ -112,19 +118,9 @@ def test_global_var():
str(gv)
def test_param():
lv = relay.Var('x')
ty = None
param = relay.Param(lv, ty)
assert param.var == lv
assert param.type == ty
assert param.span == None
str(param)
def test_function():
param_names = ['a', 'b', 'c', 'd']
params = tvm.convert([relay.Param(relay.Var(n), None) for n in param_names])
params = tvm.convert([relay.Var(n) for n in param_names])
ret_type = None
body = None
type_params = tvm.convert([])
......@@ -154,10 +150,9 @@ def test_let():
value = relay.Constant(arr)
# I would prefer that the order of arguments
# matches syntax let x: t = v in b
let = relay.Let(lv, value, lv, ty)
let = relay.Let(lv, value, lv)
assert let.var == lv
assert let.value == value
assert let.value_type == ty
assert let.body == lv
assert let.span == None
str(let)
......@@ -194,7 +189,6 @@ if __name__ == "__main__":
test_tuple()
test_local_var()
test_global_var()
test_param()
test_function()
test_call()
test_let()
......
......@@ -7,23 +7,22 @@ def test_well_formed():
assert well_formed(x)
v = relay.Constant(tvm.nd.array(10))
ty = None
let = relay.Let(x, v, x, ty)
let = relay.Let(x, v, x)
assert well_formed(let)
assert not well_formed(relay.Let(x, v, let, ty))
f = relay.Function([relay.Param(x, ty)], ty, x)
assert not well_formed(relay.Let(x, v, let))
f = relay.Function([x], ty, x)
assert well_formed(f)
# this test should pass in case of weak uniqueness (only test for shadowing)
# but we want all binder to be distinct from each other.
assert not well_formed(relay.Let(relay.Var("y"), f,
relay.Let(relay.Var("z"), f, v, ty), ty))
relay.Let(relay.Var("z"), f, v)))
def test_tuple():
x = relay.Var('x')
assert well_formed(x)
v = relay.Constant(tvm.nd.array(10))
ty = None
let = relay.Let(x, v, x, ty)
let = relay.Let(x, v, x)
assert well_formed(let)
assert well_formed(relay.Tuple([v, v]))
assert not well_formed(relay.Tuple([let, let]))
......
......@@ -27,6 +27,8 @@ def test_single_op():
tvm.relay.sigmoid, tvm.relay.tanh]:
check_single_op(opfunc)
def test_expand_dims_infer_type():
ib = relay.ir_builder.IRBuilder()
n, t, d = tvm.var("n"), tvm.var("t"), 100
......@@ -75,12 +77,13 @@ def test_unary_op():
ib = relay.ir_builder.IRBuilder()
x = ib.param("x", relay.TensorType((10, 4), "int32"))
with ib.function(x) as func:
ib.ret(op(x.var))
ib.ret(op(x))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
assert ftype.ret_type == relay.TensorType((10, 4), "int32")
def test_binary_op():
def check_binary_op(opfunc):
"""
......@@ -94,7 +97,7 @@ def test_binary_op():
x = b.param('x', tensor_type(5, 5, 5))
y = b.param('y', tensor_type(5, 5, 5))
with b.function(x, y) as func:
b.ret(opfunc(x.var, y.var))
b.ret(opfunc(x, y))
b.ret(func)
prog, env = b.get()
ttype = tensor_type(5, 5, 5)
......@@ -118,7 +121,7 @@ def test_binary_broadcast_op():
x = b.param('x', tensor_type(10, 4))
y = b.param('y', tensor_type(5, 10, 1))
with b.function(x, y) as func:
b.ret(opfunc(x.var, y.var))
b.ret(opfunc(x, y))
b.ret(func)
prog, env = b.get()
......
......@@ -11,7 +11,7 @@ def test_conv2d_infer_type():
w = ib.param("w", relay.ty.IncompleteType())
with ib.function(x, w) as func:
ib.ret(relay.nn.conv2d(x.var, w.var,
ib.ret(relay.nn.conv2d(x, w,
kernel_size=(3, 3),
padding=(1, 1),
channels=2))
......@@ -29,7 +29,7 @@ def test_conv2d_infer_type():
x = ib.param("x", relay.ty.TensorType((n, c, h, w), "int8"))
w = ib.param("w", relay.ty.TensorType((2, 10, 3, 3), "int8"))
with ib.function(x, w) as func:
ib.ret(relay.nn.conv2d(x.var, w.var, out_dtype="int32"))
ib.ret(relay.nn.conv2d(x, w, out_dtype="int32"))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
......@@ -42,7 +42,7 @@ def test_conv2d_infer_type():
x = ib.param("x", relay.ty.TensorType((n, c, h, w), "int8"))
w = ib.param("w", relay.ty.IncompleteType())
with ib.function(x, w) as func:
ib.ret(relay.nn.conv2d(x.var, w.var,
ib.ret(relay.nn.conv2d(x, w,
kernel_size=(3, 3),
padding=(1, 1),
channels=16,
......@@ -65,7 +65,7 @@ def test_conv2d_transpose_infer_type():
w = ib.param("w", relay.ty.IncompleteType())
with ib.function(x, w) as func:
ib.ret(relay.nn.conv2d_transpose(x.var, w.var,
ib.ret(relay.nn.conv2d_transpose(x, w,
kernel_size=(3, 3),
padding=(1, 1),
channels=15))
......@@ -83,7 +83,7 @@ def test_conv2d_transpose_infer_type():
x = ib.param("x", relay.ty.TensorType((n, c, h, w), "float32"))
w = ib.param("w", relay.ty.TensorType((12, 11, 5, 5), "float32"))
with ib.function(x, w) as func:
ib.ret(relay.nn.conv2d_transpose(x.var, w.var,
ib.ret(relay.nn.conv2d_transpose(x, w,
output_padding=(1, 1),
channels=11,
data_layout="NHWC"))
......@@ -98,7 +98,7 @@ def test_upsampling_infer_type():
n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
x = ib.param("x", relay.ty.TensorType((n, c, h, w), "float32"))
with ib.function(x) as func:
ib.ret(relay.nn.upsampling(x.var, scale=2, layout="NCHW", method="BILINEAR"))
ib.ret(relay.nn.upsampling(x, scale=2, layout="NCHW", method="BILINEAR"))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
......@@ -108,7 +108,7 @@ def test_upsampling_infer_type():
n, c = tvm.var("n"), tvm.var("c")
x = ib.param("x", relay.ty.TensorType((n, c, 100, 200), "float32"))
with ib.function(x) as func:
ib.ret(relay.nn.upsampling(x.var, scale=2, layout="NCHW", method="BILINEAR"))
ib.ret(relay.nn.upsampling(x, scale=2, layout="NCHW", method="BILINEAR"))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
......@@ -119,7 +119,7 @@ def _test_pool2d_infer_type(opfunc):
n, c, h, w = tvm.var("n"), 10, 224, 224
x = ib.param("x", relay.ty.TensorType((n, c, h, w), "float32"))
with ib.function(x) as func:
ib.ret(opfunc(x.var, pool_size=(1, 1)))
ib.ret(opfunc(x, pool_size=(1, 1)))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
......@@ -132,7 +132,7 @@ def _test_pool2d_infer_type(opfunc):
n, c, h, w = tvm.var("n"), 10, 224, 224
x = ib.param("x", relay.ty.TensorType((n, c, h, w), "float32"))
with ib.function(x) as func:
ib.ret(opfunc(x.var, pool_size=(ph, pw), strides=(sh, sw)))
ib.ret(opfunc(x, pool_size=(ph, pw), strides=(sh, sw)))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
......@@ -144,7 +144,7 @@ def _test_global_pool2d_infer_type(opfunc):
n, c, h, w = tvm.var("n"), tvm.var("c"), 224, 224
x = ib.param("x", relay.ty.TensorType((n, h, w, c), "float32"))
with ib.function(x) as func:
ib.ret(opfunc(x.var, layout="NHWC"))
ib.ret(opfunc(x, layout="NHWC"))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
......@@ -154,7 +154,7 @@ def _test_global_pool2d_infer_type(opfunc):
n, c, h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
x = ib.param("x", relay.ty.TensorType((n, c, h, w), "float32"))
with ib.function(x) as func:
ib.ret(opfunc(x.var))
ib.ret(opfunc(x))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
......@@ -172,7 +172,7 @@ def test_flatten_infer_type():
x = ib.param("x", relay.ty.TensorType((d1, d2, d3, d4), "float32"))
with ib.function(x) as func:
ib.ret(relay.nn.batch_flatten(x.var))
ib.ret(relay.nn.batch_flatten(x))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
......@@ -181,7 +181,7 @@ def test_flatten_infer_type():
ib = relay.ir_builder.IRBuilder()
x = ib.param("x", relay.ty.TensorType((3, 2, 4, 3), "float32"))
with ib.function(x) as func:
ib.ret(relay.nn.batch_flatten(x.var))
ib.ret(relay.nn.batch_flatten(x))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
......@@ -190,7 +190,7 @@ def test_flatten_infer_type():
ib = relay.ir_builder.IRBuilder()
x = ib.param("x", relay.ty.TensorType((d1, 2, d3, 3), "float32"))
with ib.function(x) as func:
ib.ret(relay.nn.batch_flatten(x.var))
ib.ret(relay.nn.batch_flatten(x))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
......@@ -202,7 +202,7 @@ def test_pad_infer_type():
n, c, h, w = 1, 2, 3, 4
t = ib.param("t", relay.TensorType((n, c, h, w), "float32"))
with ib.function(t) as func:
ib.ret(relay.nn.pad(t.var, ((1, 1), (2, 2), (3, 3), (4, 4))))
ib.ret(relay.nn.pad(t, ((1, 1), (2, 2), (3, 3), (4, 4))))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
......@@ -213,7 +213,7 @@ def test_pad_infer_type():
n, c, h, w = tvm.var("n"), 2, 3, tvm.var("w")
t = ib.param("t", relay.TensorType((n, c, h, w), "float32"))
with ib.function(t) as func:
ib.ret(relay.nn.pad(t.var, ((1, 1), (2, 2), (3, 3), (4, 4))))
ib.ret(relay.nn.pad(t, ((1, 1), (2, 2), (3, 3), (4, 4))))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
......@@ -227,4 +227,3 @@ if __name__ == "__main__":
test_flatten_infer_type()
test_pad_infer_type()
test_conv2d_transpose_infer_type()
......@@ -17,12 +17,13 @@ def test_zeros_ones():
ftype = func.checked_type
assert ftype.ret_type == relay.TensorType((124, 50), "float64")
def test_unary_identity():
for op in [relay.zeros_like, relay.ones_like]:
ib = relay.ir_builder.IRBuilder()
x = ib.param("x", relay.TensorType((8, 9, 4), "int32"))
with ib.function(x) as func:
ib.ret(op(x.var))
ib.ret(op(x))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
......@@ -33,7 +34,7 @@ def test_clip_type():
ib = relay.ir_builder.IRBuilder()
a = ib.param("a", relay.TensorType((10, 4), "float32"))
with ib.function(a) as func:
ib.ret(relay.clip(a.var, 1., 4.))
ib.ret(relay.clip(a, 1., 4.))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
......@@ -106,7 +107,7 @@ def test_take_infer_type():
x = ib.param("x", relay.ty.TensorType(dshape, "float32"))
indices = ib.param("indices", relay.ty.TensorType(indices_shape, "int32"))
with ib.function(x, indices) as func:
ib.ret(relay.take(x.var, indices.var, axis=axis))
ib.ret(relay.take(x, indices, axis=axis))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
......@@ -127,7 +128,7 @@ def test_full():
ib = relay.ir_builder.IRBuilder()
x = ib.param("x", relay.TensorType((), "int8"))
with ib.function(x) as func:
ib.ret(relay.full(x.var, ()))
ib.ret(relay.full(x, ()))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
......@@ -137,7 +138,7 @@ def test_full():
ib = relay.ir_builder.IRBuilder()
x = ib.param("x", relay.TensorType((), "float32"))
with ib.function(x) as func:
ib.ret(relay.full(x.var, (1, 2), "int8"))
ib.ret(relay.full(x, (1, 2), "int8"))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
......@@ -150,7 +151,7 @@ def test_full_like():
base = ib.param("base", relay.TensorType((1, 2, 3), "float32"))
fill = ib.param("fill", relay.TensorType((), "float32"))
with ib.function(base, fill) as func:
ib.ret(relay.full_like(base.var, fill.var))
ib.ret(relay.full_like(base, fill))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
......@@ -162,7 +163,7 @@ def test_full_like():
base = ib.param("base", relay.TensorType((n, c, h, w), "float32"))
fill = ib.param("fill", relay.TensorType((), "float32"))
with ib.function(base, fill) as func:
ib.ret(relay.full_like(base.var, fill.var))
ib.ret(relay.full_like(base, fill))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
......
......@@ -24,7 +24,7 @@ def test_cmp_type():
x = ib.param("x", relay.TensorType((10, 4), "float32"))
y = ib.param("y", relay.TensorType((5, 10, 1), "float32"))
with ib.function(x, y) as func:
ib.ret(op(x.var, y.var))
ib.ret(op(x, y))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
......@@ -39,7 +39,7 @@ def test_binary_broadcast():
x = ib.param("x", relay.TensorType((10, 4), "int32"))
y = ib.param("y", relay.TensorType((5, 10, 1), "int32"))
with ib.function(x, y) as func:
ib.ret(op(x.var, y.var))
ib.ret(op(x, y))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
......@@ -58,7 +58,7 @@ def test_binary_op():
x = b.param('x', tensor_type(5, 5, 5))
y = b.param('y', tensor_type(5, 5, 5))
with b.function(x, y) as func:
b.ret(opfunc(x.var, y.var))
b.ret(opfunc(x, y))
b.ret(func)
prog, env = b.get()
ttype = tensor_type(5, 5, 5)
......@@ -81,7 +81,7 @@ def test_binary_broadcast_op():
x = b.param('x', tensor_type(10, 4))
y = b.param('y', tensor_type(5, 10, 1))
with b.function(x, y) as func:
b.ret(opfunc(x.var, y.var))
b.ret(opfunc(x, y))
b.ret(func)
prog, env = b.get()
......@@ -103,7 +103,7 @@ def test_cmp_type():
x = ib.param("x", relay.TensorType((10, 4), "float32"))
y = ib.param("y", relay.TensorType((5, 10, 1), "float32"))
with ib.function(x, y) as func:
ib.ret(op(x.var, y.var))
ib.ret(op(x, y))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
......@@ -118,7 +118,7 @@ def test_binary_broadcast():
x = ib.param("x", relay.TensorType((10, 4), "int32"))
y = ib.param("y", relay.TensorType((5, 10, 1), "int32"))
with ib.function(x, y) as func:
ib.ret(op(x.var, y.var))
ib.ret(op(x, y))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
......@@ -131,7 +131,7 @@ def test_where():
x = ib.param("x", relay.TensorType((3, 4), "float32"))
y = ib.param("y", relay.TensorType((3, 4), "float32"))
with ib.function(cond, x, y) as func:
ib.ret(relay.where(cond.var, x.var, y.var))
ib.ret(relay.where(cond, x, y))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
......
......@@ -10,7 +10,7 @@ def test_resize_infer_type():
th, tw = tvm.var("th"), tvm.var("tw")
with ib.function(x) as func:
ib.ret(relay.image.resize(x.var, (th, tw)))
ib.ret(relay.image.resize(x, (th, tw)))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
......@@ -19,7 +19,7 @@ def test_resize_infer_type():
ib = relay.ir_builder.IRBuilder()
x = ib.param("x", relay.ty.TensorType((n, c, h, w), "int8"))
with ib.function(x) as func:
ib.ret(relay.image.resize(x.var, (100, 200), "NCHW", "BILINEAR", False))
ib.ret(relay.image.resize(x, (100, 200), "NCHW", "BILINEAR", False))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
......
import tvm
import numpy as np
from tvm import relay
from tvm.relay.ir_pass import alpha_equal
from tvm.relay.ir_builder import convert
......@@ -179,9 +180,9 @@ def test_var_alpha_equal():
assert not alpha_equal(v1, v2)
# let node allows for setting the eq_map
l1 = relay.Let(v1, convert(1), v1, None)
l2 = relay.Let(v2, convert(1), v2, None)
l3 = relay.Let(v1, convert(1), v2, None)
l1 = relay.Let(v1, convert(1), v1)
l2 = relay.Let(v2, convert(1), v2)
l3 = relay.Let(v1, convert(1), v2)
assert alpha_equal(l1, l2)
assert not alpha_equal(l1, l3)
......@@ -209,10 +210,10 @@ def test_tuple_alpha_equal():
assert alpha_equal(tup, same)
# use the eq_map
let_tup = relay.Let(v1, tup, v1, None)
let_tup = relay.Let(v1, tup, v1)
let_mapped = relay.Let(v2, relay.Tuple([v2, convert(2), convert(3),
relay.Tuple([convert(4)])]),
v2, None)
v2)
assert alpha_equal(let_tup, let_mapped)
more_fields = relay.Tuple([v1, convert(2), convert(3), relay.Tuple([convert(4)]), v2])
......@@ -242,61 +243,44 @@ def test_tuple_get_item_alpha_equal():
assert alpha_equal(relay.TupleGetItem(x, 1), relay.TupleGetItem(x, 1))
def test_param_alpha_equal():
# only checks equality of the types
v1 = relay.Var("v1")
v2 = relay.Var("v2")
p1 = relay.Param(v1, relay.TensorType((1, 2, 3), "float32"))
p2 = relay.Param(v2, relay.TensorType((1, 2, 3), "float32"))
assert alpha_equal(p1, p2)
p3 = relay.Param(v1, relay.TensorType((4, 5, 6), "int8"))
assert not alpha_equal(p1, p3)
p4 = relay.Param(v1, relay.TupleType([relay.TensorType((1, 2, 3),
"float32")]))
assert not alpha_equal(p1, p4)
def test_function_alpha_equal():
v1 = relay.Var("v1")
v2 = relay.Var("v2")
v3 = relay.Var("v3")
v4 = relay.Var("v4")
tt1 = relay.TensorType((1, 2, 3), "float32")
tt2 = relay.TensorType((4, 5, 6), "int8")
tt3 = relay.TupleType([tt1, tt2])
v1 = relay.Var("v1", tt1)
v2 = relay.Var("v2", tt2)
v3 = relay.Var("v3", tt3)
v4 = relay.Var("v4", tt2)
vret = relay.Constant(tvm.nd.array(np.ones(1)))
tp1 = relay.TypeParam("tp1", relay.Kind.Type)
tp2 = relay.TypeParam("tp2", relay.Kind.Type)
tp3 = relay.TypeParam("tp3", relay.Kind.Shape)
tp4 = relay.TypeParam("tp4", relay.Kind.Shape)
basic_args = [relay.Param(v3, tt1), relay.Param(v4, tt2)]
basic_args = [relay.Var("v3", tt1), relay.Var("v4", tt2)]
basic_tps = [tp1, tp2]
func = relay.Function([relay.Param(v1, tt1), relay.Param(v2, tt2)],
tt2, v2, basic_tps)
mapped = relay.Function(basic_args, tt2, v4, basic_tps)
func = relay.Function([v1, v2],
tt2, v1, basic_tps)
mapped = relay.Function(basic_args, tt2, basic_args[0], basic_tps)
assert alpha_equal(func, mapped)
fewer_params = relay.Function([relay.Param(v4, tt2)], tt2, v4, basic_tps)
fewer_params = relay.Function([relay.Var("v4", tt2)], tt2, v4, basic_tps)
assert not alpha_equal(func, fewer_params)
more_params = relay.Function([relay.Param(v3, tt1), relay.Param(v4, tt2),
relay.Param(v2, tt2)], tt2, v4, basic_tps)
more_params = relay.Function([relay.Var("v3", tt1),
relay.Var("v4", tt2),
relay.Var("v2", tt2)], tt2, v4, basic_tps)
assert not alpha_equal(func, more_params)
params_unordered = relay.Function([relay.Param(v3, tt2),
relay.Param(v4, tt1)],
tt1, v3, basic_tps)
params_unordered = relay.Function([v2, v1],
tt2, v1, basic_tps)
assert not alpha_equal(func, params_unordered)
params_mismatch = relay.Function([relay.Param(v3, tt3),
relay.Param(v4, tt2)],
tt2, v4, basic_tps)
params_mismatch = relay.Function([v1, v3],
tt2, v1, basic_tps)
assert not alpha_equal(func, params_mismatch)
# also would not typecheck
......@@ -376,7 +360,10 @@ def test_call_alpha_equal():
def test_let_alpha_equal():
tt1 = relay.TensorType((), "float32")
tt2 = relay.TensorType((), "int8")
v1 = relay.Var("v1")
v1_wtype = relay.Var("v1", tt1)
v2 = relay.Var("v2")
v3 = relay.Var("v3")
......@@ -394,14 +381,13 @@ def test_let_alpha_equal():
assert not alpha_equal(let, different_body)
# specified types must match
tt1 = relay.TensorType((), "float32")
tt2 = relay.TensorType((), "int8")
let_with_type = relay.Let(v1, convert(2), v1, tt1)
same_type = relay.Let(v1, convert(2), v1, tt1)
let_with_type = relay.Let(v1_wtype, convert(2), v1_wtype)
same_type = relay.Let(v1_wtype, convert(2), v1_wtype)
assert alpha_equal(let_with_type, same_type)
assert not alpha_equal(let, let_with_type)
different_type = relay.Let(v1, convert(2), v1, tt2)
v2 = relay.Var("v1", tt2)
different_type = relay.Let(v2, convert(2), v2)
assert not alpha_equal(let_with_type, different_type)
......@@ -437,16 +423,13 @@ if __name__ == "__main__":
test_tensor_type_alpha_equal()
test_incomplete_type_alpha_equal()
test_constant_alpha_equal()
test_type_param_alpha_equal()
test_func_type_alpha_equal()
test_tuple_type_alpha_equal()
test_type_relation_alpha_equal()
test_constant_alpha_equal()
test_var_alpha_equal()
test_global_var_alpha_equal()
test_tuple_alpha_equal()
test_tuple_get_item_alpha_equal()
test_param_alpha_equal()
test_function_alpha_equal()
test_call_alpha_equal()
test_let_alpha_equal()
......
......@@ -28,17 +28,17 @@ e = env()
def test_let():
orig = relay.Let(e.x, e.y, e.z, e.tt)
orig = relay.Let(e.x, e.y, e.z)
assert alpha_equal(dead_code_elimination(orig), e.z)
def test_used_let():
orig = relay.Let(e.a, e.b, relay.Let(e.c, e.d, e.c, e.tt), e.tt)
assert alpha_equal(dead_code_elimination(orig), relay.Let(e.c, e.d, e.c, e.tt))
orig = relay.Let(e.a, e.b, relay.Let(e.c, e.d, e.c))
assert alpha_equal(dead_code_elimination(orig), relay.Let(e.c, e.d, e.c))
def test_chain_unused_let():
orig = relay.Let(e.a, e.b, relay.Let(e.c, e.d, e.e, e.tt), e.tt)
orig = relay.Let(e.a, e.b, relay.Let(e.c, e.d, e.e))
assert alpha_equal(dead_code_elimination(orig), e.e)
......@@ -56,19 +56,17 @@ def test_recursion():
f(2, 10000);
"""
f = relay.Var("f")
n = relay.Var("n")
np = relay.Param(n, e.int32)
data = relay.Var("data")
datap = relay.Param(data, e.float32)
n = relay.Var("n", e.int32)
data = relay.Var("data", e.float32)
funcbody = relay.If(equal(n, convert(0)), data, f(subtract(n, convert(1.0)), log(data)))
value = relay.Function([np, datap], e.float32, funcbody, [])
orig = relay.Let(f, funcbody, f(convert(2.0), convert(10000.0)), e.float32)
value = relay.Function([n, data], e.float32, funcbody, [])
orig = relay.Let(f, funcbody, f(convert(2.0), convert(10000.0)))
assert alpha_equal(dead_code_elimination(orig), orig)
assert alpha_equal(dead_code_elimination(relay.Let(f, funcbody, e.three, e.float32)), e.three)
assert alpha_equal(dead_code_elimination(relay.Let(f, funcbody, e.three)), e.three)
def test_op_let():
assert alpha_equal(dead_code_elimination(add(relay.Let(e.a, e.one, e.three, e.float32), e.two)), add(e.three, e.two))
assert alpha_equal(dead_code_elimination(add(relay.Let(e.a, e.one, e.three), e.two)), add(e.three, e.two))
def test_if():
......@@ -80,7 +78,7 @@ def test_tuple_get_item():
t = relay.Var('t')
g = relay.TupleGetItem(t, 0)
assert alpha_equal(dead_code_elimination(g), g)
assert alpha_equal(dead_code_elimination(relay.TupleGetItem(relay.Let(e.a, e.one, t, e.float32), 0)), g)
assert alpha_equal(dead_code_elimination(relay.TupleGetItem(relay.Let(e.a, e.one, t), 0)), g)
if __name__ == "__main__":
......
......@@ -3,16 +3,17 @@ from tvm import relay
from tvm.relay.ir_pass import free_vars, free_type_vars
def test_free_vars():
x = relay.Var("x")
ty = relay.TensorType([], "int32")
x = relay.Var("x", ty)
fvx = free_vars(x)
assert len(fvx) == 1
assert fvx[0] == x
v = relay.Constant(tvm.nd.array(10))
ty = relay.TensorType([], "int32")
let = relay.Let(x, v, x, ty)
let = relay.Let(x, v, x)
fvx = free_vars(let)
assert len(free_vars(let)) == 0
f = relay.Function([relay.Param(x, ty)], ty, x)
f = relay.Function([x], ty, x)
assert len(free_vars(f)) == 0
......@@ -29,9 +30,9 @@ def test_tuple():
def test_free_type_vars():
tp = relay.TypeParam("")
ty = relay.TupleType([tp, relay.TensorType([], "int32")])
x = relay.Var("x")
x = relay.Var("x", ty)
y = relay.Var("y")
let = relay.Let(x, y, x, ty)
let = relay.Let(x, y, x)
fvl = free_vars(let)
assert len(fvl) == 1
assert fvl[0] == y
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment