Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
0b4cc050
Unverified
Commit
0b4cc050
authored
Oct 14, 2018
by
Tianqi Chen
Committed by
GitHub
Oct 14, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[RELAY][IR] Move type_annotation to Var, remove Param (#1900)
parent
53428606
Hide whitespace changes
Inline
Side-by-side
Showing
26 changed files
with
375 additions
and
396 deletions
+375
-396
include/tvm/relay/expr.h
+29
-40
include/tvm/relay/expr_functor.h
+0
-4
python/tvm/relay/__init__.py
+0
-1
python/tvm/relay/expr.py
+117
-42
python/tvm/relay/ir_builder.py
+12
-21
src/relay/ir/debug_printer.cc
+9
-9
src/relay/ir/expr.cc
+18
-28
src/relay/ir/expr_functor.cc
+16
-19
src/relay/pass/alpha_eq.cc
+18
-22
src/relay/pass/dead_code.cc
+10
-16
src/relay/pass/let_list.h
+21
-32
src/relay/pass/type_infer.cc
+11
-20
src/relay/pass/util.cc
+5
-3
src/relay/pass/well_formed.cc
+2
-2
tests/python/relay/test_ir_builder.py
+0
-1
tests/python/relay/test_ir_debug_printer.py
+4
-11
tests/python/relay/test_ir_nodes.py
+9
-15
tests/python/relay/test_ir_well_formed.py
+5
-6
tests/python/relay/test_op_level1.py
+6
-3
tests/python/relay/test_op_level2.py
+16
-17
tests/python/relay/test_op_level3.py
+8
-7
tests/python/relay/test_op_level4.py
+7
-7
tests/python/relay/test_op_level5.py
+2
-2
tests/python/relay/test_pass_alpha_equal.py
+32
-49
tests/python/relay/test_pass_dead_code_elimination.py
+11
-13
tests/python/relay/test_pass_free_vars.py
+7
-6
No files found.
include/tvm/relay/expr.h
View file @
0b4cc050
...
...
@@ -118,17 +118,27 @@ class Var;
/*! \brief Container for Var */
class
VarNode
:
public
ExprNode
{
public
:
/*! \brief The name of the variable, this only acts as a hint to the user,
* and is not used for equality.
/*!
* \brief The name of the variable,
* this only acts as a hint to the user,
* and is not used for equality.
*/
std
::
string
name_hint
;
/*!
* \brief type annotaion of the variable.
* This field records user provided type annotation of the Var.
* This field is optional and can be None.
*/
Type
type_annotation
;
void
VisitAttrs
(
tvm
::
AttrVisitor
*
v
)
final
{
v
->
Visit
(
"name_hint"
,
&
name_hint
);
v
->
Visit
(
"type_annotation"
,
&
type_annotation
);
v
->
Visit
(
"_checked_type_"
,
&
checked_type_
);
}
TVM_DLL
static
Var
make
(
std
::
string
name_hint
);
TVM_DLL
static
Var
make
(
std
::
string
name_hint
,
Type
type_annotation
);
static
constexpr
const
char
*
_type_key
=
"relay.Var"
;
TVM_DECLARE_NODE_TYPE_INFO
(
VarNode
,
ExprNode
);
...
...
@@ -163,32 +173,6 @@ class GlobalVarNode : public ExprNode {
RELAY_DEFINE_NODE_REF
(
GlobalVar
,
GlobalVarNode
,
Expr
);
/*!
* \brief Function parameter declaration.
*/
class
Param
;
/*! \brief A parameter. */
class
ParamNode
:
public
ExprNode
{
public
:
/*! \brief The variable */
Var
var
;
/*! \brief The type of the parameter */
Type
type
;
void
VisitAttrs
(
tvm
::
AttrVisitor
*
v
)
final
{
v
->
Visit
(
"var"
,
&
var
);
v
->
Visit
(
"type"
,
&
type
);
v
->
Visit
(
"span"
,
&
span
);
}
TVM_DLL
static
Param
make
(
Var
var
,
Type
type
);
static
constexpr
const
char
*
_type_key
=
"relay.Param"
;
TVM_DECLARE_NODE_TYPE_INFO
(
ParamNode
,
ExprNode
);
};
RELAY_DEFINE_NODE_REF
(
Param
,
ParamNode
,
Expr
);
/*!
* \brief Function (subgraph in computational graph)
*/
class
Function
;
...
...
@@ -196,7 +180,7 @@ class Function;
class
FunctionNode
:
public
ExprNode
{
public
:
/*! \brief Function parameters */
tvm
::
Array
<
Param
>
params
;
tvm
::
Array
<
Var
>
params
;
/*! \brief User annotated return type of the function. */
Type
ret_type
;
/*!
...
...
@@ -224,10 +208,18 @@ class FunctionNode : public ExprNode {
v
->
Visit
(
"_checked_type_"
,
&
checked_type_
);
}
Type
fn_type
()
const
;
/*!
* \brief Return the derived function annotation of this expression.
*
* \return The function type annotation.
* \note The function type annotation can contain IncompleteType.
*/
TVM_DLL
FuncType
func_type_annotation
()
const
;
TVM_DLL
static
Function
make
(
tvm
::
Array
<
Param
>
params
,
Type
ret_type
,
Expr
body
,
tvm
::
Array
<
TypeParam
>
ty_params
);
TVM_DLL
static
Function
make
(
tvm
::
Array
<
Var
>
params
,
Type
ret_type
,
Expr
body
,
tvm
::
Array
<
TypeParam
>
ty_params
);
static
constexpr
const
char
*
_type_key
=
"relay.Function"
;
TVM_DECLARE_NODE_TYPE_INFO
(
FunctionNode
,
ExprNode
);
...
...
@@ -289,7 +281,7 @@ class CallNode : public ExprNode {
TVM_DLL
static
Call
make
(
Expr
op
,
Array
<
Expr
>
args
,
Attrs
attrs
=
Attrs
(),
Array
<
Type
>
ty_args
=
Array
<
Type
>
());
Array
<
Type
>
ty
pe
_args
=
Array
<
Type
>
());
static
constexpr
const
char
*
_type_key
=
"relay.Call"
;
TVM_DECLARE_NODE_TYPE_INFO
(
CallNode
,
ExprNode
);
...
...
@@ -318,19 +310,16 @@ class LetNode : public ExprNode {
Expr
value
;
/*! \brief The body of the let binding */
Expr
body
;
/*! \brief Type annotation of value, this can be null */
Type
value_type
;
void
VisitAttrs
(
tvm
::
AttrVisitor
*
v
)
final
{
v
->
Visit
(
"var"
,
&
var
);
v
->
Visit
(
"value"
,
&
value
);
v
->
Visit
(
"body"
,
&
body
);
v
->
Visit
(
"value_type"
,
&
value_type
);
v
->
Visit
(
"span"
,
&
span
);
v
->
Visit
(
"_checked_type_"
,
&
checked_type_
);
}
TVM_DLL
static
Let
make
(
Var
var
,
Expr
value
,
Expr
body
,
Type
value_type
);
TVM_DLL
static
Let
make
(
Var
var
,
Expr
value
,
Expr
body
);
static
constexpr
const
char
*
_type_key
=
"relay.Let"
;
TVM_DECLARE_NODE_TYPE_INFO
(
LetNode
,
ExprNode
);
...
...
@@ -376,11 +365,11 @@ class IfNode : public ExprNode {
RELAY_DEFINE_NODE_REF
(
If
,
IfNode
,
Expr
);
/*! \brief Get
a
field out of a tuple. */
/*! \brief Get
index-th
field out of a tuple. */
class
TupleGetItem
;
class
TupleGetItemNode
:
public
ExprNode
{
public
:
/*! \brief The tuple */
/*! \brief The tuple
Expression
*/
Expr
tuple
;
/*! \brief which value to get */
int
index
;
...
...
include/tvm/relay/expr_functor.h
View file @
0b4cc050
...
...
@@ -80,7 +80,6 @@ class ExprFunctor<R(const Expr& n, Args...)> {
Args
...
args
)
EXPR_FUNCTOR_DEFAULT
;
virtual
R
VisitExpr_
(
const
GlobalVarNode
*
op
,
Args
...
args
)
EXPR_FUNCTOR_DEFAULT
;
virtual
R
VisitExpr_
(
const
ParamNode
*
op
,
Args
...
args
)
EXPR_FUNCTOR_DEFAULT
;
virtual
R
VisitExpr_
(
const
FunctionNode
*
op
,
Args
...
args
)
EXPR_FUNCTOR_DEFAULT
;
virtual
R
VisitExpr_
(
const
CallNode
*
op
,
Args
...
args
)
EXPR_FUNCTOR_DEFAULT
;
...
...
@@ -103,7 +102,6 @@ class ExprFunctor<R(const Expr& n, Args...)> {
RELAY_EXPR_FUNCTOR_DISPATCH
(
TupleNode
);
RELAY_EXPR_FUNCTOR_DISPATCH
(
VarNode
);
RELAY_EXPR_FUNCTOR_DISPATCH
(
GlobalVarNode
);
RELAY_EXPR_FUNCTOR_DISPATCH
(
ParamNode
);
RELAY_EXPR_FUNCTOR_DISPATCH
(
FunctionNode
);
RELAY_EXPR_FUNCTOR_DISPATCH
(
CallNode
);
RELAY_EXPR_FUNCTOR_DISPATCH
(
LetNode
);
...
...
@@ -127,7 +125,6 @@ class ExprVisitor : public ::tvm::relay::ExprFunctor<void(const Expr& n)> {
void
VisitExpr_
(
const
GlobalVarNode
*
op
)
override
;
void
VisitExpr_
(
const
ConstantNode
*
op
)
override
;
void
VisitExpr_
(
const
TupleNode
*
op
)
override
;
void
VisitExpr_
(
const
ParamNode
*
op
)
override
;
void
VisitExpr_
(
const
FunctionNode
*
op
)
override
;
void
VisitExpr_
(
const
CallNode
*
op
)
override
;
void
VisitExpr_
(
const
LetNode
*
op
)
override
;
...
...
@@ -151,7 +148,6 @@ class ExprMutator
Expr
VisitExpr_
(
const
GlobalVarNode
*
op
)
override
;
Expr
VisitExpr_
(
const
OpNode
*
op
)
override
;
Expr
VisitExpr_
(
const
TupleNode
*
op
)
override
;
Expr
VisitExpr_
(
const
ParamNode
*
op
)
override
;
Expr
VisitExpr_
(
const
FunctionNode
*
op
)
override
;
Expr
VisitExpr_
(
const
CallNode
*
call_node
)
override
;
Expr
VisitExpr_
(
const
LetNode
*
op
)
override
;
...
...
python/tvm/relay/__init__.py
View file @
0b4cc050
...
...
@@ -34,7 +34,6 @@ Constant = expr.Constant
Tuple
=
expr
.
Tuple
Var
=
expr
.
Var
GlobalVar
=
expr
.
GlobalVar
Param
=
expr
.
Param
Function
=
expr
.
Function
Call
=
expr
.
Call
Let
=
expr
.
Let
...
...
python/tvm/relay/expr.py
View file @
0b4cc050
...
...
@@ -11,11 +11,11 @@ class Expr(NodeBase):
"""The base type for all Relay expressions."""
@property
def
checked_type
(
self
):
"""Get the checked type of
relay
.
"""Get the checked type of
tvm.relay.Expr
.
Returns
-------
checked_type : relay.Type
checked_type :
tvm.
relay.Type
The checked type.
"""
ret
=
self
.
_checked_type_
...
...
@@ -25,70 +25,97 @@ class Expr(NodeBase):
return
ret
def
__call__
(
self
,
*
args
):
converted_args
=
[]
for
arg
in
args
:
if
isinstance
(
arg
,
Param
):
converted_args
.
append
(
arg
.
var
)
else
:
converted_args
.
append
(
arg
)
return
Call
(
self
,
args
,
None
,
None
)
@register_relay_node
class
Constant
(
Expr
):
"""A constant tensor in Relay, see tvm/relay/type.h for more details.
"""
"""A constant expression in Relay.
Parameters
----------
data : tvm.nd.NDArray
The data content of the constant expression.
"""
def
__init__
(
self
,
data
):
self
.
__init_handle_by_constructor__
(
_make
.
Constant
,
data
)
@register_relay_node
class
Tuple
(
Expr
):
"""A hetereogenous sequence of values.
see tvm/relay/type.h for more details.
"""
"""Tuple expression that groups several fields together.
Parameters
----------
fields : List[tvm.relay.Expr]
The fields in the tuple.
"""
def
__init__
(
self
,
fields
):
self
.
__init_handle_by_constructor__
(
_make
.
Tuple
,
fields
)
@register_relay_node
class
Var
(
Expr
):
"""A local variable in
Relay."""
"""A local variable in
Tvm.Relay.
def
__init__
(
self
,
name_hint
):
self
.
__init_handle_by_constructor__
(
_make
.
Var
,
name_hint
)
Local variable can be used to declare input
arguments to a function, or intermediate variables.
Parameters
----------
name_hint: str
The name of the variable.
This name only acts as a hint, and is not used
for equality.
type_annotation: tvm.relay.Type, optional
The type annotation on the variable.
"""
def
__init__
(
self
,
name_hint
,
type_annotation
=
None
):
self
.
__init_handle_by_constructor__
(
_make
.
Var
,
name_hint
,
type_annotation
)
@register_relay_node
class
GlobalVar
(
Expr
):
"""A global variable in
Relay."""
"""A global variable in
Tvm.Relay.
GlobalVar is used to refer to the global functions
stored in the environment.
Parameters
----------
name_hint: str
The name of the variable.
"""
def
__init__
(
self
,
name_hint
):
self
.
__init_handle_by_constructor__
(
_make
.
GlobalVar
,
name_hint
)
@register_relay_node
class
Param
(
Expr
):
"""A function type in Relay, see tvm/relay/type.h for more details.
"""
class
Function
(
Expr
):
"""A function declaration expression.
def
__init__
(
self
,
var
,
ty
):
self
.
__init_handle_by_constructor__
(
_make
.
Param
,
var
,
ty
)
Parameters
----------
params: List[tvm.relay.Var]
List of input parameters to the function.
ret_type: tvm.relay.Type
The return type annotation of the function.
@register_relay_node
class
Function
(
Expr
):
"""A function in Relay, see tvm/relay/expr.h for more details."""
body: tvm.relay.Expr
The body of the function.
type_params: Optional[List[tvm.relay.TypeParam]]
The additional type parameters, this is only
used in advanced usecase of template functions.
"""
def
__init__
(
self
,
params
,
ret_type
,
body
,
type_params
=
None
):
type_params
=
None
):
if
type_params
is
None
:
type_params
=
convert
([])
...
...
@@ -98,39 +125,87 @@ class Function(Expr):
@register_relay_node
class
Call
(
Expr
):
"""A function call in Relay, see tvm/relay/expr.h for more details."""
"""Function call node in Relay.
Call node corresponds the operator application node
in computational graph terminology.
Parameters
----------
op: tvm.relay.Op or any tvm.relay.Expr with function type.
The operation to be called.
def
__init__
(
self
,
op
,
args
,
attrs
,
ty_args
=
None
):
if
not
ty_args
:
ty_args
=
[]
args: List[tvm.relay.Expr]
The arguments to the call.
attrs: Optional[tvm.Attrs]
Attributes to the call, can be None
type_args: Optional[List[tvm.relay.Type]]
The additional type arguments, this is only
used in advanced usecase of template functions.
"""
def
__init__
(
self
,
op
,
args
,
attrs
=
None
,
type_args
=
None
):
if
not
type_args
:
type_args
=
[]
self
.
__init_handle_by_constructor__
(
_make
.
Call
,
op
,
args
,
attrs
,
ty_args
)
_make
.
Call
,
op
,
args
,
attrs
,
ty
pe
_args
)
@register_relay_node
class
Let
(
Expr
):
"""A variable bindings in Relay, see tvm/relay/expr.h for more details."""
"""Let variable binding expression.
Parameters
----------
var: tvm.relay.Var
The local variable to be bound.
value: tvm.relay.Expr
The value to be bound.
def
__init__
(
self
,
var
,
value
,
body
,
value_type
=
None
):
body: tvm.relay.Expr
The body of the let binding.
"""
def
__init__
(
self
,
var
,
value
,
body
):
self
.
__init_handle_by_constructor__
(
_make
.
Let
,
var
,
value
,
body
,
value_type
)
_make
.
Let
,
var
,
value
,
body
)
@register_relay_node
class
If
(
Expr
):
"""A conditional expression in Relay, see tvm/relay/expr.h for more details."""
"""A conditional expression in Relay.
Parameters
----------
cond: tvm.relay.Expr
The condition.
def
__init__
(
self
,
cond
,
true_value
,
false_value
):
true_branch: tvm.relay.Expr
The expression evaluated when condition is true.
false_branch: tvm.relay.Expr
The expression evaluated when condition is false.
"""
def
__init__
(
self
,
cond
,
true_branch
,
false_branch
):
self
.
__init_handle_by_constructor__
(
_make
.
If
,
cond
,
true_value
,
false_value
)
_make
.
If
,
cond
,
true_branch
,
false_branch
)
@register_relay_node
class
TupleGetItem
(
Expr
):
"""An expression that get field from tuple in Relay, see tvm/relay/expr.h for more details."""
"""Get index-th item from a tuple.
Parameters
----------
tuple_value: tvm.relay.Expr
The input tuple expression.
def
__init__
(
self
,
tuple_
,
index
):
index: int
The index.
"""
def
__init__
(
self
,
tuple_value
,
index
):
self
.
__init_handle_by_constructor__
(
_make
.
TupleGetItem
,
tuple_
,
index
)
_make
.
TupleGetItem
,
tuple_
value
,
index
)
debug_print
=
_expr
.
_debug_print
python/tvm/relay/ir_builder.py
View file @
0b4cc050
...
...
@@ -7,7 +7,7 @@ from collections import OrderedDict
import
numpy
as
np
import
tvm
from
.ty
import
Type
,
FuncType
,
TensorType
from
.expr
import
Expr
,
Constant
,
Let
,
Var
,
Param
,
Function
,
If
from
.expr
import
Expr
,
Constant
,
Let
,
Var
,
Function
,
If
from
.env
import
Environment
...
...
@@ -98,7 +98,7 @@ class PartialFunc(object):
self
.
type_params
=
type_params
def
param_ids
(
self
):
return
[
p
.
var
for
p
in
self
.
params
]
return
[
p
for
p
in
self
.
params
]
def
to_func
(
self
):
"""Converts a PartialFunc into a :py:class:`~relay.Function`."""
...
...
@@ -113,9 +113,8 @@ class PartialFunc(object):
def
_mk_let
(
bindings
,
ret_value
):
let_expr
=
ret_value
for
var
,
(
value
,
ty
)
in
reversed
(
list
(
bindings
.
items
())):
let_expr
=
Let
(
var
,
value
,
let_expr
,
ty
)
for
var
,
value
in
reversed
(
list
(
bindings
.
items
())):
let_expr
=
Let
(
var
,
value
,
let_expr
)
return
let_expr
...
...
@@ -168,15 +167,12 @@ class IRBuilder(object):
#pylint: disable=invalid-name
def
bind
(
self
,
name
,
value
,
ty
):
lv
=
Var
(
name
)
lv
=
Var
(
name
,
ty
)
self
.
scopes
[
-
1
][
name
]
=
lv
self
.
bindings
[
-
1
][
lv
]
=
(
value
,
ty
)
self
.
bindings
[
-
1
][
lv
]
=
value
return
lv
def
let
(
self
,
name
,
value
,
value_type
=
None
):
if
isinstance
(
value
,
Param
):
value
=
value
.
var
if
not
isinstance
(
value
,
Expr
):
value
=
convert
(
value
)
...
...
@@ -185,23 +181,18 @@ class IRBuilder(object):
def
_convert_params
(
self
,
raw_params
):
relay_params
=
[]
for
raw_param
in
raw_params
:
if
isinstance
(
raw_param
,
Param
):
var
=
raw_param
.
var
if
isinstance
(
raw_param
,
Var
):
param
=
raw_param
elif
isinstance
(
raw_param
,
tuple
):
var
,
ty
=
raw_param
if
isinstance
(
var
,
str
):
var
=
Var
(
var
)
ty
=
_convert_type
(
ty
)
param
=
Param
(
var
,
ty
)
elif
isinstance
(
param
,
str
):
var
=
Var
(
raw_param
)
ty
=
None
param
=
Param
(
var
,
ty
)
param
=
Var
(
var
,
ty
)
elif
isinstance
(
raw_param
,
str
):
param
=
Var
(
raw_param
,
None
)
else
:
raise
Exception
(
"unknown parameter type"
)
self
.
scopes
[
-
1
][
var
.
name_hint
]
=
var
self
.
scopes
[
-
1
][
param
.
name_hint
]
=
param
relay_params
.
append
(
param
)
return
relay_params
...
...
@@ -265,7 +256,7 @@ class IRBuilder(object):
else
:
ty
=
_convert_type
(
ty
)
return
Param
(
Var
(
name
)
,
ty
)
return
Var
(
name
,
ty
)
def
global_var
(
self
,
name
):
# type: (str) -> GlobalVar
...
...
src/relay/ir/debug_printer.cc
View file @
0b4cc050
...
...
@@ -96,7 +96,9 @@ class TypeDocifier : private TypeFunctor<Doc(const Type& n)> {
}
std
::
vector
<
Doc
>
DocifyTypeParam
(
const
tvm
::
Array
<
TypeParam
>&
arr
)
{
return
MapDocify
<
TypeParam
>
(
arr
,
[
=
](
const
TypeParam
&
tp
)
{
return
Docify
(
tp
);
});
return
MapDocify
<
TypeParam
>
(
arr
,
[
=
](
const
TypeParam
&
tp
)
{
return
Docify
(
tp
);
});
}
std
::
vector
<
Doc
>
DocifyTypeConstraint
(
const
tvm
::
Array
<
TypeConstraint
>&
arr
)
{
...
...
@@ -188,10 +190,11 @@ class ExprDocifier : private ExprFunctor<Doc(const Expr& n)> {
return
vec
;
}
std
::
vector
<
Doc
>
DocifyParamArray
(
const
tvm
::
Array
<
Param
>&
arr
)
{
std
::
vector
<
Doc
>
DocifyParamArray
(
const
tvm
::
Array
<
Var
>&
arr
)
{
std
::
vector
<
Doc
>
vec
;
for
(
size_t
i
=
0
;
i
<
arr
.
size
();
++
i
)
{
vec
.
push_back
(
Docify
(
arr
[
i
]));
for
(
Var
param
:
arr
)
{
vec
.
emplace_back
(
TypeAnnotation
(
DocOfStr
(
VarName
(
param
)),
param
->
type_annotation
));
}
return
vec
;
}
...
...
@@ -212,10 +215,6 @@ class ExprDocifier : private ExprFunctor<Doc(const Expr& n)> {
return
DocOfStr
(
g
->
name_hint
);
}
Doc
VisitExpr_
(
const
ParamNode
*
p
)
final
{
return
TypeAnnotation
(
Docify
(
p
->
var
),
p
->
type
);
}
Doc
VisitExpr_
(
const
FunctionNode
*
f
)
final
{
return
Group
(
TypeAnnotation
(
Seq
(
"("
,
DocifyParamArray
(
f
->
params
),
")"
),
f
->
ret_type
)
+
Sep
()
+
DocOfStr
(
"=>"
)
+
Sep
()
+
...
...
@@ -227,7 +226,8 @@ class ExprDocifier : private ExprFunctor<Doc(const Expr& n)> {
}
Doc
VisitExpr_
(
const
LetNode
*
l
)
final
{
return
Group
(
DocOfStr
(
"let"
)
+
Sep
()
+
TypeAnnotation
(
Docify
(
l
->
var
),
l
->
value_type
)
+
Sep
()
+
return
Group
(
DocOfStr
(
"let"
)
+
Sep
()
+
TypeAnnotation
(
Docify
(
l
->
var
),
l
->
var
->
type_annotation
)
+
Sep
()
+
DocOfStr
(
"="
)
+
Sep
()
+
Docify
(
l
->
value
)
+
DocOfStr
(
";"
)
+
Endl
()
+
Docify
(
l
->
body
));
}
...
...
src/relay/ir/expr.cc
View file @
0b4cc050
...
...
@@ -54,20 +54,26 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
p
->
stream
<<
"Tuple("
<<
node
->
fields
<<
")"
;
});
Var
VarNode
::
make
(
std
::
string
name_hint
)
{
Var
VarNode
::
make
(
std
::
string
name_hint
,
Type
type_annotation
)
{
NodePtr
<
VarNode
>
n
=
make_node
<
VarNode
>
();
n
->
name_hint
=
std
::
move
(
name_hint
);
n
->
type_annotation
=
std
::
move
(
type_annotation
);
return
Var
(
n
);
}
TVM_REGISTER_API
(
"relay._make.Var"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
VarNode
::
make
(
args
[
0
]);
*
ret
=
VarNode
::
make
(
args
[
0
]
,
args
[
1
]
);
});
TVM_STATIC_IR_FUNCTOR_REGISTER
(
IRPrinter
,
vtable
)
.
set_dispatch
<
VarNode
>
([](
const
VarNode
*
node
,
tvm
::
IRPrinter
*
p
)
{
p
->
stream
<<
"Var("
<<
node
->
name_hint
<<
")"
;
p
->
stream
<<
"Var("
<<
node
->
name_hint
;
if
(
node
->
type_annotation
.
defined
())
{
p
->
stream
<<
", ty="
;
p
->
print
(
node
->
type_annotation
);
}
p
->
stream
<<
")"
;
});
GlobalVar
GlobalVarNode
::
make
(
std
::
string
name_hint
)
{
...
...
@@ -86,24 +92,10 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
p
->
stream
<<
"GlobalVar("
<<
node
->
name_hint
<<
")"
;
});
Param
ParamNode
::
make
(
Var
var
,
Type
type
)
{
NodePtr
<
ParamNode
>
n
=
make_node
<
ParamNode
>
();
n
->
var
=
std
::
move
(
var
);
n
->
type
=
std
::
move
(
type
);
return
Param
(
n
);
}
TVM_REGISTER_API
(
"relay._make.Param"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
ParamNode
::
make
(
args
[
0
],
args
[
1
]);
});
TVM_STATIC_IR_FUNCTOR_REGISTER
(
IRPrinter
,
vtable
)
.
set_dispatch
<
ParamNode
>
([](
const
ParamNode
*
node
,
tvm
::
IRPrinter
*
p
)
{
p
->
stream
<<
"Param("
<<
node
->
var
<<
", "
<<
node
->
type
<<
")"
;
});
Function
FunctionNode
::
make
(
tvm
::
Array
<
Param
>
params
,
Type
ret_type
,
Expr
body
,
Function
FunctionNode
::
make
(
tvm
::
Array
<
Var
>
params
,
Type
ret_type
,
Expr
body
,
tvm
::
Array
<
TypeParam
>
type_params
)
{
NodePtr
<
FunctionNode
>
n
=
make_node
<
FunctionNode
>
();
n
->
params
=
std
::
move
(
params
);
...
...
@@ -113,12 +105,11 @@ Function FunctionNode::make(tvm::Array<Param> params, Type ret_type, Expr body,
return
Function
(
n
);
}
Type
FunctionNode
::
fn_type
()
const
{
FuncType
FunctionNode
::
func_type_annotation
()
const
{
Array
<
Type
>
param_types
;
for
(
auto
param
:
this
->
params
)
{
param_types
.
push_back
(
param
->
type
);
param_types
.
push_back
(
param
->
type
_annotation
);
}
return
FuncTypeNode
::
make
(
param_types
,
this
->
ret_type
,
this
->
type_params
,
{});
}
...
...
@@ -155,24 +146,23 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
<<
node
->
attrs
<<
", "
<<
node
->
type_args
<<
")"
;
});
Let
LetNode
::
make
(
Var
var
,
Expr
value
,
Expr
body
,
Type
value_type
)
{
Let
LetNode
::
make
(
Var
var
,
Expr
value
,
Expr
body
)
{
NodePtr
<
LetNode
>
n
=
make_node
<
LetNode
>
();
n
->
var
=
std
::
move
(
var
);
n
->
value
=
std
::
move
(
value
);
n
->
body
=
std
::
move
(
body
);
n
->
value_type
=
std
::
move
(
value_type
);
return
Let
(
n
);
}
TVM_REGISTER_API
(
"relay._make.Let"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
LetNode
::
make
(
args
[
0
],
args
[
1
],
args
[
2
],
args
[
3
]);
});
*
ret
=
LetNode
::
make
(
args
[
0
],
args
[
1
],
args
[
2
]);
});
TVM_STATIC_IR_FUNCTOR_REGISTER
(
IRPrinter
,
vtable
)
.
set_dispatch
<
LetNode
>
([](
const
LetNode
*
node
,
tvm
::
IRPrinter
*
p
)
{
p
->
stream
<<
"LetNode("
<<
node
->
var
<<
", "
<<
node
->
value
<<
", "
<<
node
->
body
<<
", "
<<
node
->
value_type
<<
")"
;
<<
", "
<<
node
->
body
<<
")"
;
});
If
IfNode
::
make
(
Expr
cond
,
Expr
true_branch
,
Expr
false_branch
)
{
...
...
src/relay/ir/expr_functor.cc
View file @
0b4cc050
...
...
@@ -24,6 +24,16 @@ Expr ExprMutator::Mutate(const Expr& expr) {
}
Expr
ExprMutator
::
VisitExpr_
(
const
VarNode
*
op
)
{
// NOTE: var will only be mutated once
// Thanks to the memo and reused during rewriting if necessary.
// It is safe to assume that the
if
(
op
->
type_annotation
.
defined
())
{
auto
type
=
this
->
VisitType
(
op
->
type_annotation
);
if
(
!
op
->
type_annotation
.
same_as
(
type
))
{
return
VarNode
::
make
(
op
->
name_hint
,
type
);
}
}
// default case return self.
return
GetRef
<
Expr
>
(
op
);
}
...
...
@@ -55,16 +65,6 @@ Expr ExprMutator::VisitExpr_(const TupleNode* op) {
}
}
Expr
ExprMutator
::
VisitExpr_
(
const
ParamNode
*
op
)
{
Var
var
=
Downcast
<
Var
>
(
this
->
Mutate
(
op
->
var
));
auto
type
=
this
->
VisitType
(
op
->
type
);
if
(
op
->
var
.
same_as
(
var
)
&&
op
->
type
.
same_as
(
type
))
{
return
GetRef
<
Expr
>
(
op
);
}
else
{
return
ParamNode
::
make
(
var
,
type
);
}
}
Expr
ExprMutator
::
VisitExpr_
(
const
FunctionNode
*
op
)
{
tvm
::
Array
<
TypeParam
>
ty_params
;
bool
all_ty_params_changed
=
true
;
...
...
@@ -75,10 +75,10 @@ Expr ExprMutator::VisitExpr_(const FunctionNode* op) {
all_ty_params_changed
&=
new_ty_param
.
same_as
(
ty_param
);
}
tvm
::
Array
<
Param
>
params
;
tvm
::
Array
<
Var
>
params
;
bool
all_params_changed
=
true
;
for
(
auto
param
:
op
->
params
)
{
Param
new_param
=
Downcast
<
Param
>
(
this
->
Mutate
(
param
));
Var
new_param
=
Downcast
<
Var
>
(
this
->
Mutate
(
param
));
params
.
push_back
(
new_param
);
all_params_changed
&=
param
.
same_as
(
new_param
);
}
...
...
@@ -123,17 +123,15 @@ Expr ExprMutator::VisitExpr_(const CallNode* call_node) {
Expr
ExprMutator
::
VisitExpr_
(
const
LetNode
*
op
)
{
Var
var
=
Downcast
<
Var
>
(
this
->
Mutate
(
op
->
var
));
auto
type
=
this
->
VisitType
(
op
->
value_type
);
auto
value
=
this
->
Mutate
(
op
->
value
);
auto
body
=
this
->
Mutate
(
op
->
body
);
if
(
var
.
same_as
(
op
->
var
)
&&
type
.
same_as
(
op
->
value_type
)
&&
value
.
same_as
(
op
->
value
)
&&
body
.
same_as
(
op
->
body
))
{
return
GetRef
<
Expr
>
(
op
);
}
else
{
return
LetNode
::
make
(
var
,
value
,
body
,
type
);
return
LetNode
::
make
(
var
,
value
,
body
);
}
}
...
...
@@ -162,6 +160,9 @@ Expr ExprMutator::VisitExpr_(const TupleGetItemNode* g) {
Type
ExprMutator
::
VisitType
(
const
Type
&
t
)
{
return
t
;
}
void
ExprVisitor
::
ExprVisitor
::
VisitExpr_
(
const
VarNode
*
op
)
{
if
(
op
->
type_annotation
.
defined
())
{
this
->
VisitType
(
op
->
type_annotation
);
}
}
void
ExprVisitor
::
ExprVisitor
::
VisitExpr_
(
const
GlobalVarNode
*
op
)
{
...
...
@@ -176,10 +177,6 @@ void ExprVisitor::ExprVisitor::VisitExpr_(const TupleNode* op) {
}
}
void
ExprVisitor
::
ExprVisitor
::
VisitExpr_
(
const
ParamNode
*
op
)
{
this
->
VisitExpr
(
op
->
var
);
}
void
ExprVisitor
::
ExprVisitor
::
VisitExpr_
(
const
FunctionNode
*
op
)
{
for
(
auto
param
:
op
->
params
)
{
this
->
VisitExpr
(
param
);
...
...
src/relay/pass/alpha_eq.cc
View file @
0b4cc050
...
...
@@ -252,15 +252,6 @@ struct AlphaEq : ExprFunctor<void(const Expr&, const Expr&)> {
}
}
void
VisitExpr_
(
const
ParamNode
*
p1
,
const
Expr
&
e2
)
final
{
if
(
const
ParamNode
*
p2
=
e2
.
as
<
ParamNode
>
())
{
eq_map
.
Set
(
p1
->
var
,
p2
->
var
);
equal
=
equal
&&
AlphaEqual
(
p1
->
type
,
p2
->
type
);
}
else
{
equal
=
false
;
}
}
void
VisitExpr_
(
const
FunctionNode
*
func1
,
const
Expr
&
e2
)
final
{
if
(
const
FunctionNode
*
func2
=
e2
.
as
<
FunctionNode
>
())
{
if
(
func1
->
params
.
size
()
!=
func2
->
params
.
size
())
{
...
...
@@ -273,9 +264,10 @@ struct AlphaEq : ExprFunctor<void(const Expr&, const Expr&)> {
return
;
}
for
(
size_t
i
=
0
U
;
i
<
func1
->
params
.
size
();
i
++
)
{
this
->
VisitExpr
(
func1
->
params
[
i
],
func2
->
params
[
i
]);
for
(
size_t
i
=
0
;
i
<
func1
->
params
.
size
();
++
i
)
{
MergeVarDecl
(
func1
->
params
[
i
],
func2
->
params
[
i
]);
}
if
(
!
equal
)
return
;
for
(
size_t
i
=
0U
;
i
<
func1
->
type_params
.
size
();
i
++
)
{
equal
=
equal
&&
AlphaEqual
(
func1
->
type_params
[
i
],
func2
->
type_params
[
i
]);
...
...
@@ -332,19 +324,9 @@ struct AlphaEq : ExprFunctor<void(const Expr&, const Expr&)> {
void
VisitExpr_
(
const
LetNode
*
op
,
const
Expr
&
e2
)
final
{
if
(
const
LetNode
*
let
=
e2
.
as
<
LetNode
>
())
{
eq_map
.
Set
(
op
->
var
,
let
->
var
);
MergeVarDecl
(
op
->
var
,
let
->
var
);
this
->
VisitExpr
(
op
->
value
,
let
->
value
);
this
->
VisitExpr
(
op
->
body
,
let
->
body
);
// value_type should match as well (including nulls)
if
(
op
->
value_type
.
defined
()
!=
let
->
value_type
.
defined
())
{
equal
=
false
;
return
;
}
if
(
op
->
value_type
.
defined
())
{
equal
=
equal
&&
AlphaEqual
(
op
->
value_type
,
let
->
value_type
);
}
}
else
{
equal
=
false
;
}
...
...
@@ -388,6 +370,20 @@ struct AlphaEq : ExprFunctor<void(const Expr&, const Expr&)> {
equal
=
false
;
}
}
private
:
void
MergeVarDecl
(
const
Var
&
var1
,
const
Var
&
var2
)
{
if
(
var1
->
type_annotation
.
defined
()
!=
var2
->
type_annotation
.
defined
())
{
equal
=
false
;
return
;
}
if
(
var1
->
type_annotation
.
defined
()
&&
!
AlphaEqual
(
var1
->
type_annotation
,
var2
->
type_annotation
))
{
equal
=
false
;
return
;
}
eq_map
.
Set
(
var1
,
var2
);
}
};
bool
AlphaEqual
(
const
Expr
&
e1
,
const
Expr
&
e2
)
{
...
...
src/relay/pass/dead_code.cc
View file @
0b4cc050
...
...
@@ -54,12 +54,7 @@ class CalcDep : private ExprMutator {
}
private
:
struct
Binder
{
Type
t
;
Expr
e
;
Binder
(
const
Type
&
t
,
const
Expr
&
e
)
:
t
(
t
),
e
(
e
)
{
}
};
using
VarMap
=
std
::
unordered_map
<
Var
,
Binder
,
NodeHash
,
NodeEqual
>
;
using
VarMap
=
std
::
unordered_map
<
Var
,
Expr
,
NodeHash
,
NodeEqual
>
;
VarMap
var_map_
;
Expr
VisitExpr_
(
const
IfNode
*
i
)
final
{
...
...
@@ -74,9 +69,7 @@ class CalcDep : private ExprMutator {
}
Expr
VisitExpr_
(
const
LetNode
*
l
)
final
{
var_map_
.
insert
(
std
::
pair
<
Var
,
Binder
>
(
l
->
var
,
Binder
(
l
->
value_type
,
Eliminate
(
l
->
value
))));
var_map_
[
l
->
var
]
=
Eliminate
(
l
->
value
);
return
VisitExpr
(
l
->
body
);
}
...
...
@@ -92,15 +85,16 @@ class CalcDep : private ExprMutator {
explicit
GenLet
(
const
VarMap
&
var_map
)
:
var_map_
(
var_map
)
{
}
friend
CalcDep
;
void
VisitExpr_
(
const
VarNode
*
vn
)
final
{
Var
v
=
GetRef
<
Var
>
(
vn
);
if
(
var_map_
.
count
(
v
)
!=
0
)
{
auto
val
=
var_map_
.
at
(
v
);
var_map_
.
erase
(
v
);
void
VisitExpr_
(
const
VarNode
*
vnode
)
final
{
Var
v
=
GetRef
<
Var
>
(
vnode
);
auto
it
=
var_map_
.
find
(
v
);
if
(
it
!=
var_map_
.
end
())
{
Expr
expr
=
it
->
second
;
var_map_
.
erase
(
it
);
// erase before visit to handle letrec
VisitExpr
(
val
.
e
);
VisitExpr
(
expr
);
// visit before push back so the dependency of dependency is before the dependency
lets_
.
Push
(
v
,
val
.
t
,
val
.
e
);
lets_
.
Push
(
v
,
expr
);
}
}
};
...
...
src/relay/pass/let_list.h
View file @
0b4cc050
...
...
@@ -26,57 +26,46 @@ namespace relay {
*/
class
LetList
{
public
:
/*! \brief insert a binding.
/*!
* \brief insert a binding.
*
*
\param pv the var of the binding.
* \param pv the var of the binding.
*
*
\param ty the typ
e of the binding.
*
\param expr the valu
e of the binding.
*
* \param expr the value of the binding.
*
* \return a Var that hold the inserted expr.
* \return a Var that hold the inserted expr.
*/
Var
Push
(
const
Var
&
pv
,
const
Type
&
ty
,
const
Expr
&
expr
)
{
std
::
tuple
<
Var
,
Type
,
Expr
>
tuple
(
pv
,
ty
,
expr
);
lets_
.
push_back
(
tuple
);
Var
Push
(
Var
pv
,
Expr
expr
)
{
lets_
.
emplace_back
(
std
::
make_pair
(
pv
,
expr
));
return
pv
;
}
/*! \brief insert a binding.
/*!
* \brief insert a binding.
*
*
\param ty the type of the binding.
* \param ty the type of the binding.
*
*
\param expr the value of the binding.
* \param expr the value of the binding.
*
* \return a Var that hold the inserted expr.
*/
Var
Push
(
const
Type
&
ty
,
const
Expr
&
expr
)
{
return
Push
(
VarNode
::
make
(
"x"
),
ty
,
expr
);
}
/*! \brief insert a binding.
*
* \param pv the var of the binding.
*
* \param expr the value of the binding.
*
* \return a Var that hold the inserted expr.
* \return a Var that hold the inserted expr.
*/
Var
Push
(
const
Var
&
pv
,
const
Expr
&
expr
)
{
return
Push
(
pv
,
IncompleteTypeNode
::
make
(
TypeParamNode
::
kType
),
expr
);
Var
Push
(
Type
ty
,
Expr
expr
)
{
return
Push
(
VarNode
::
make
(
"x"
,
ty
),
expr
);
}
/*! \brief insert a binding.
/*!
* \brief insert a binding.
*
* \param expr the value of the binding.
*
* \return a Var that hold the inserted expr.
*/
Var
Push
(
const
Expr
&
expr
)
{
Var
Push
(
Expr
expr
)
{
return
Push
(
IncompleteTypeNode
::
make
(
TypeParamNode
::
kType
),
expr
);
}
/*! \brief wrap an expr around the LetList.
/*!
* \brief wrap an expr around the LetList.
*
* \param body the Expression to be wrapped around.
*
...
...
@@ -85,7 +74,7 @@ class LetList {
Expr
Get
(
const
Expr
&
body
)
const
{
Expr
ret
=
body
;
for
(
auto
rit
=
lets_
.
rbegin
();
rit
!=
lets_
.
rend
();
++
rit
)
{
ret
=
LetNode
::
make
(
std
::
get
<
0
>
(
*
rit
),
std
::
get
<
2
>
(
*
rit
),
ret
,
std
::
get
<
1
>
(
*
rit
)
);
ret
=
LetNode
::
make
(
std
::
get
<
0
>
(
*
rit
),
std
::
get
<
1
>
(
*
rit
),
ret
);
}
return
ret
;
}
...
...
@@ -118,7 +107,7 @@ class LetList {
}
private
:
std
::
vector
<
std
::
tuple
<
Var
,
Type
,
Expr
>
>
lets_
;
std
::
vector
<
std
::
pair
<
Var
,
Expr
>
>
lets_
;
};
}
// namespace relay
...
...
src/relay/pass/type_infer.cc
View file @
0b4cc050
...
...
@@ -87,15 +87,11 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
// Visitor logics
Type
VisitExpr_
(
const
VarNode
*
op
)
final
{
// The type of Var can already been lookedup in type_map_;
LOG
(
FATAL
)
<<
"Cannot find binding for var "
<<
GetRef
<
Var
>
(
op
);
return
Type
();
}
Type
VisitExpr_
(
const
ParamNode
*
op
)
final
{
// directly handled by Funtion
LOG
(
FATAL
)
<<
"not reached"
;
return
Type
();
if
(
op
->
type_annotation
.
defined
())
{
return
op
->
type_annotation
;
}
else
{
return
IncompleteTypeNode
::
make
(
TypeParamNode
::
kType
);
}
}
Type
VisitExpr_
(
const
GlobalVarNode
*
op
)
final
{
...
...
@@ -139,11 +135,11 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
Type
VisitExpr_
(
const
LetNode
*
op
)
final
{
Type
vtype
=
GetType
(
op
->
value
);
if
(
op
->
va
lue_type
.
defined
())
{
vtype
=
Unify
(
vtype
,
op
->
va
lue_type
,
op
->
span
);
if
(
op
->
va
r
->
type_annotation
.
defined
())
{
vtype
=
Unify
(
vtype
,
op
->
va
r
->
type_annotation
,
op
->
span
);
}
CHECK
(
!
type_map_
.
count
(
op
->
var
));
// NOTE: no scoping is necessary becase var are unique in program
// NOTE: no scoping is necessary beca
u
se var are unique in program
type_map_
[
op
->
var
]
=
vtype
;
return
GetType
(
op
->
body
);
}
...
...
@@ -256,8 +252,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
Type
VisitExpr_
(
const
FunctionNode
*
f
)
final
{
for
(
auto
param
:
f
->
params
)
{
type_map_
[
param
->
var
]
=
param
->
type
;
type_map_
[
param
]
=
param
->
type
;
GetType
(
param
);
}
Type
rtype
=
GetType
(
f
->
body
);
// Run solver using the currently known information
...
...
@@ -265,8 +260,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
// Trying to resolve
Array
<
Type
>
arg_types
;
for
(
size_t
i
=
0
;
i
<
f
->
params
.
size
();
++
i
)
{
Param
param
=
f
->
params
[
i
];
Type
atype
=
solver_
.
Resolve
(
param
->
type
);
Type
atype
=
solver_
.
Resolve
(
GetType
(
f
->
params
[
i
]));
CHECK
(
atype
.
as
<
IncompleteTypeNode
>
()
==
nullptr
)
<<
"Cannot resolve type of "
<<
i
<<
"-th parameter of function at"
<<
f
->
span
;
...
...
@@ -311,9 +305,6 @@ class TypeInferencer::Resolver : public ExprMutator {
return
AttachCheckedType
(
op
);
}
Expr
VisitExpr_
(
const
ParamNode
*
op
)
final
{
return
ExprMutator
::
VisitExpr_
(
op
);
}
Expr
VisitExpr_
(
const
FunctionNode
*
op
)
final
{
return
AttachCheckedType
(
op
);
...
...
@@ -380,7 +371,7 @@ Expr InferType(const Environment& env,
const
GlobalVar
&
var
,
const
Function
&
func
)
{
Function
func_copy
=
Function
(
make_node
<
FunctionNode
>
(
*
func
.
operator
->
()));
func_copy
->
checked_type_
=
func_copy
->
f
n_type
();
func_copy
->
checked_type_
=
func_copy
->
f
unc_type_annotation
();
env
->
functions
.
Set
(
var
,
func_copy
);
Expr
func_ret
=
TypeInferencer
(
env
).
Infer
(
func_copy
);
auto
map_node
=
env
->
functions
.
CopyOnWrite
();
...
...
src/relay/pass/util.cc
View file @
0b4cc050
...
...
@@ -50,14 +50,17 @@ class FreeVar : public ExprVisitor {
if
(
bound_vars
.
count
(
var
)
==
0
)
{
free_vars
.
insert
(
var
);
}
if
(
v
->
type_annotation
.
defined
())
{
VisitType
(
v
->
type_annotation
);
}
}
void
VisitExpr_
(
const
FunctionNode
*
f
)
final
{
for
(
const
auto
&
tp
:
f
->
type_params
)
{
bound_types
.
insert
(
tp
);
}
for
(
const
auto
&
p
:
f
->
params
)
{
bound_vars
.
insert
(
p
->
var
);
for
(
const
auto
&
p
aram
:
f
->
params
)
{
bound_vars
.
insert
(
p
aram
);
}
VisitExpr
(
f
->
body
);
VisitType
(
f
->
ret_type
);
...
...
@@ -67,7 +70,6 @@ class FreeVar : public ExprVisitor {
bound_vars
.
insert
(
l
->
var
);
VisitExpr
(
l
->
value
);
VisitExpr
(
l
->
body
);
VisitType
(
l
->
value_type
);
}
public
:
...
...
src/relay/pass/well_formed.cc
View file @
0b4cc050
...
...
@@ -34,8 +34,8 @@ class WellFormedChecker : private ExprVisitor {
}
void
VisitExpr_
(
const
FunctionNode
*
f
)
final
{
for
(
const
Param
&
p
:
f
->
params
)
{
Check
(
p
->
var
);
for
(
const
Var
&
param
:
f
->
params
)
{
Check
(
p
aram
);
}
CheckWellFormed
(
f
->
body
);
}
...
...
tests/python/relay/test_ir_builder.py
View file @
0b4cc050
...
...
@@ -14,7 +14,6 @@ def test_let():
assert
var
==
prog
.
body
assert
isinstance
(
value
,
Constant
)
assert
value
.
data
.
asnumpy
()
==
np
.
array
(
1
)
assert
prog
.
value_type
==
None
if
__name__
==
"__main__"
:
test_let
()
tests/python/relay/test_ir_debug_printer.py
View file @
0b4cc050
...
...
@@ -49,18 +49,11 @@ def test_global_var():
show
(
gv
)
def
test_param
():
lv
=
relay
.
Var
(
'x'
)
ty
=
None
param
=
relay
.
Param
(
lv
,
ty
)
show
(
lv
)
def
test_function
():
param_names
=
[
'a'
,
'b'
,
'c'
,
'd'
]
params
=
tvm
.
convert
([
relay
.
Param
(
relay
.
Var
(
n
),
None
)
for
n
in
param_names
])
params
=
tvm
.
convert
([
relay
.
Var
(
n
)
for
n
in
param_names
])
ret_type
=
None
body
=
params
[
0
]
.
var
body
=
params
[
0
]
type_params
=
tvm
.
convert
([])
fn
=
relay
.
Function
(
params
,
ret_type
,
body
,
type_params
)
show
(
fn
)
...
...
@@ -76,11 +69,11 @@ def test_call():
def
test_let
():
lv
=
relay
.
Var
(
'x'
)
ty
=
relay
.
ty
.
TensorType
((
10
,
20
),
'float32'
)
lv
=
relay
.
Var
(
'x'
,
ty
)
arr
=
tvm
.
nd
.
array
(
10
)
value
=
relay
.
Constant
(
arr
)
let
=
relay
.
Let
(
lv
,
value
,
lv
,
ty
)
let
=
relay
.
Let
(
lv
,
value
,
lv
)
show
(
let
)
...
...
tests/python/relay/test_ir_nodes.py
View file @
0b4cc050
...
...
@@ -99,10 +99,16 @@ def test_tuple():
def
test_local_var
():
name_hint
=
's'
lv
=
relay
.
Var
(
name_hint
)
lv
.
name_hint
==
name_hint
assert
lv
.
name_hint
==
name_hint
assert
lv
.
type_annotation
is
None
# assert lv.span == None todo(@jroesch): what do we do about spans
str
(
lv
)
t1
=
relay
.
ty
.
TensorType
((),
"float"
)
lv
=
relay
.
Var
(
name_hint
,
t1
)
assert
lv
.
name_hint
==
name_hint
assert
lv
.
type_annotation
==
t1
def
test_global_var
():
name_hint
=
'g'
...
...
@@ -112,19 +118,9 @@ def test_global_var():
str
(
gv
)
def
test_param
():
lv
=
relay
.
Var
(
'x'
)
ty
=
None
param
=
relay
.
Param
(
lv
,
ty
)
assert
param
.
var
==
lv
assert
param
.
type
==
ty
assert
param
.
span
==
None
str
(
param
)
def
test_function
():
param_names
=
[
'a'
,
'b'
,
'c'
,
'd'
]
params
=
tvm
.
convert
([
relay
.
Param
(
relay
.
Var
(
n
),
None
)
for
n
in
param_names
])
params
=
tvm
.
convert
([
relay
.
Var
(
n
)
for
n
in
param_names
])
ret_type
=
None
body
=
None
type_params
=
tvm
.
convert
([])
...
...
@@ -154,10 +150,9 @@ def test_let():
value
=
relay
.
Constant
(
arr
)
# I would prefer that the order of arguments
# matches syntax let x: t = v in b
let
=
relay
.
Let
(
lv
,
value
,
lv
,
ty
)
let
=
relay
.
Let
(
lv
,
value
,
lv
)
assert
let
.
var
==
lv
assert
let
.
value
==
value
assert
let
.
value_type
==
ty
assert
let
.
body
==
lv
assert
let
.
span
==
None
str
(
let
)
...
...
@@ -194,7 +189,6 @@ if __name__ == "__main__":
test_tuple
()
test_local_var
()
test_global_var
()
test_param
()
test_function
()
test_call
()
test_let
()
...
...
tests/python/relay/test_ir_well_formed.py
View file @
0b4cc050
...
...
@@ -7,23 +7,22 @@ def test_well_formed():
assert
well_formed
(
x
)
v
=
relay
.
Constant
(
tvm
.
nd
.
array
(
10
))
ty
=
None
let
=
relay
.
Let
(
x
,
v
,
x
,
ty
)
let
=
relay
.
Let
(
x
,
v
,
x
)
assert
well_formed
(
let
)
assert
not
well_formed
(
relay
.
Let
(
x
,
v
,
let
,
ty
))
f
=
relay
.
Function
([
relay
.
Param
(
x
,
ty
)
],
ty
,
x
)
assert
not
well_formed
(
relay
.
Let
(
x
,
v
,
let
))
f
=
relay
.
Function
([
x
],
ty
,
x
)
assert
well_formed
(
f
)
# this test should pass in case of weak uniqueness (only test for shadowing)
# but we want all binder to be distinct from each other.
assert
not
well_formed
(
relay
.
Let
(
relay
.
Var
(
"y"
),
f
,
relay
.
Let
(
relay
.
Var
(
"z"
),
f
,
v
,
ty
),
ty
))
relay
.
Let
(
relay
.
Var
(
"z"
),
f
,
v
)
))
def
test_tuple
():
x
=
relay
.
Var
(
'x'
)
assert
well_formed
(
x
)
v
=
relay
.
Constant
(
tvm
.
nd
.
array
(
10
))
ty
=
None
let
=
relay
.
Let
(
x
,
v
,
x
,
ty
)
let
=
relay
.
Let
(
x
,
v
,
x
)
assert
well_formed
(
let
)
assert
well_formed
(
relay
.
Tuple
([
v
,
v
]))
assert
not
well_formed
(
relay
.
Tuple
([
let
,
let
]))
...
...
tests/python/relay/test_op_level1.py
View file @
0b4cc050
...
...
@@ -27,6 +27,8 @@ def test_single_op():
tvm
.
relay
.
sigmoid
,
tvm
.
relay
.
tanh
]:
check_single_op
(
opfunc
)
def
test_expand_dims_infer_type
():
ib
=
relay
.
ir_builder
.
IRBuilder
()
n
,
t
,
d
=
tvm
.
var
(
"n"
),
tvm
.
var
(
"t"
),
100
...
...
@@ -75,12 +77,13 @@ def test_unary_op():
ib
=
relay
.
ir_builder
.
IRBuilder
()
x
=
ib
.
param
(
"x"
,
relay
.
TensorType
((
10
,
4
),
"int32"
))
with
ib
.
function
(
x
)
as
func
:
ib
.
ret
(
op
(
x
.
var
))
ib
.
ret
(
op
(
x
))
ib
.
ret
(
func
)
func
=
relay
.
ir_pass
.
infer_type
(
ib
.
env
,
func
.
to_func
())
ftype
=
func
.
checked_type
assert
ftype
.
ret_type
==
relay
.
TensorType
((
10
,
4
),
"int32"
)
def
test_binary_op
():
def
check_binary_op
(
opfunc
):
"""
...
...
@@ -94,7 +97,7 @@ def test_binary_op():
x
=
b
.
param
(
'x'
,
tensor_type
(
5
,
5
,
5
))
y
=
b
.
param
(
'y'
,
tensor_type
(
5
,
5
,
5
))
with
b
.
function
(
x
,
y
)
as
func
:
b
.
ret
(
opfunc
(
x
.
var
,
y
.
var
))
b
.
ret
(
opfunc
(
x
,
y
))
b
.
ret
(
func
)
prog
,
env
=
b
.
get
()
ttype
=
tensor_type
(
5
,
5
,
5
)
...
...
@@ -118,7 +121,7 @@ def test_binary_broadcast_op():
x
=
b
.
param
(
'x'
,
tensor_type
(
10
,
4
))
y
=
b
.
param
(
'y'
,
tensor_type
(
5
,
10
,
1
))
with
b
.
function
(
x
,
y
)
as
func
:
b
.
ret
(
opfunc
(
x
.
var
,
y
.
var
))
b
.
ret
(
opfunc
(
x
,
y
))
b
.
ret
(
func
)
prog
,
env
=
b
.
get
()
...
...
tests/python/relay/test_op_level2.py
View file @
0b4cc050
...
...
@@ -11,7 +11,7 @@ def test_conv2d_infer_type():
w
=
ib
.
param
(
"w"
,
relay
.
ty
.
IncompleteType
())
with
ib
.
function
(
x
,
w
)
as
func
:
ib
.
ret
(
relay
.
nn
.
conv2d
(
x
.
var
,
w
.
var
,
ib
.
ret
(
relay
.
nn
.
conv2d
(
x
,
w
,
kernel_size
=
(
3
,
3
),
padding
=
(
1
,
1
),
channels
=
2
))
...
...
@@ -29,7 +29,7 @@ def test_conv2d_infer_type():
x
=
ib
.
param
(
"x"
,
relay
.
ty
.
TensorType
((
n
,
c
,
h
,
w
),
"int8"
))
w
=
ib
.
param
(
"w"
,
relay
.
ty
.
TensorType
((
2
,
10
,
3
,
3
),
"int8"
))
with
ib
.
function
(
x
,
w
)
as
func
:
ib
.
ret
(
relay
.
nn
.
conv2d
(
x
.
var
,
w
.
var
,
out_dtype
=
"int32"
))
ib
.
ret
(
relay
.
nn
.
conv2d
(
x
,
w
,
out_dtype
=
"int32"
))
ib
.
ret
(
func
)
func
=
relay
.
ir_pass
.
infer_type
(
ib
.
env
,
func
.
to_func
())
ftype
=
func
.
checked_type
...
...
@@ -42,7 +42,7 @@ def test_conv2d_infer_type():
x
=
ib
.
param
(
"x"
,
relay
.
ty
.
TensorType
((
n
,
c
,
h
,
w
),
"int8"
))
w
=
ib
.
param
(
"w"
,
relay
.
ty
.
IncompleteType
())
with
ib
.
function
(
x
,
w
)
as
func
:
ib
.
ret
(
relay
.
nn
.
conv2d
(
x
.
var
,
w
.
var
,
ib
.
ret
(
relay
.
nn
.
conv2d
(
x
,
w
,
kernel_size
=
(
3
,
3
),
padding
=
(
1
,
1
),
channels
=
16
,
...
...
@@ -65,7 +65,7 @@ def test_conv2d_transpose_infer_type():
w
=
ib
.
param
(
"w"
,
relay
.
ty
.
IncompleteType
())
with
ib
.
function
(
x
,
w
)
as
func
:
ib
.
ret
(
relay
.
nn
.
conv2d_transpose
(
x
.
var
,
w
.
var
,
ib
.
ret
(
relay
.
nn
.
conv2d_transpose
(
x
,
w
,
kernel_size
=
(
3
,
3
),
padding
=
(
1
,
1
),
channels
=
15
))
...
...
@@ -83,7 +83,7 @@ def test_conv2d_transpose_infer_type():
x
=
ib
.
param
(
"x"
,
relay
.
ty
.
TensorType
((
n
,
c
,
h
,
w
),
"float32"
))
w
=
ib
.
param
(
"w"
,
relay
.
ty
.
TensorType
((
12
,
11
,
5
,
5
),
"float32"
))
with
ib
.
function
(
x
,
w
)
as
func
:
ib
.
ret
(
relay
.
nn
.
conv2d_transpose
(
x
.
var
,
w
.
var
,
ib
.
ret
(
relay
.
nn
.
conv2d_transpose
(
x
,
w
,
output_padding
=
(
1
,
1
),
channels
=
11
,
data_layout
=
"NHWC"
))
...
...
@@ -98,7 +98,7 @@ def test_upsampling_infer_type():
n
,
c
,
h
,
w
=
tvm
.
var
(
"n"
),
tvm
.
var
(
"c"
),
tvm
.
var
(
"h"
),
tvm
.
var
(
"w"
)
x
=
ib
.
param
(
"x"
,
relay
.
ty
.
TensorType
((
n
,
c
,
h
,
w
),
"float32"
))
with
ib
.
function
(
x
)
as
func
:
ib
.
ret
(
relay
.
nn
.
upsampling
(
x
.
var
,
scale
=
2
,
layout
=
"NCHW"
,
method
=
"BILINEAR"
))
ib
.
ret
(
relay
.
nn
.
upsampling
(
x
,
scale
=
2
,
layout
=
"NCHW"
,
method
=
"BILINEAR"
))
ib
.
ret
(
func
)
func
=
relay
.
ir_pass
.
infer_type
(
ib
.
env
,
func
.
to_func
())
ftype
=
func
.
checked_type
...
...
@@ -108,7 +108,7 @@ def test_upsampling_infer_type():
n
,
c
=
tvm
.
var
(
"n"
),
tvm
.
var
(
"c"
)
x
=
ib
.
param
(
"x"
,
relay
.
ty
.
TensorType
((
n
,
c
,
100
,
200
),
"float32"
))
with
ib
.
function
(
x
)
as
func
:
ib
.
ret
(
relay
.
nn
.
upsampling
(
x
.
var
,
scale
=
2
,
layout
=
"NCHW"
,
method
=
"BILINEAR"
))
ib
.
ret
(
relay
.
nn
.
upsampling
(
x
,
scale
=
2
,
layout
=
"NCHW"
,
method
=
"BILINEAR"
))
ib
.
ret
(
func
)
func
=
relay
.
ir_pass
.
infer_type
(
ib
.
env
,
func
.
to_func
())
ftype
=
func
.
checked_type
...
...
@@ -119,7 +119,7 @@ def _test_pool2d_infer_type(opfunc):
n
,
c
,
h
,
w
=
tvm
.
var
(
"n"
),
10
,
224
,
224
x
=
ib
.
param
(
"x"
,
relay
.
ty
.
TensorType
((
n
,
c
,
h
,
w
),
"float32"
))
with
ib
.
function
(
x
)
as
func
:
ib
.
ret
(
opfunc
(
x
.
var
,
pool_size
=
(
1
,
1
)))
ib
.
ret
(
opfunc
(
x
,
pool_size
=
(
1
,
1
)))
ib
.
ret
(
func
)
func
=
relay
.
ir_pass
.
infer_type
(
ib
.
env
,
func
.
to_func
())
ftype
=
func
.
checked_type
...
...
@@ -132,7 +132,7 @@ def _test_pool2d_infer_type(opfunc):
n
,
c
,
h
,
w
=
tvm
.
var
(
"n"
),
10
,
224
,
224
x
=
ib
.
param
(
"x"
,
relay
.
ty
.
TensorType
((
n
,
c
,
h
,
w
),
"float32"
))
with
ib
.
function
(
x
)
as
func
:
ib
.
ret
(
opfunc
(
x
.
var
,
pool_size
=
(
ph
,
pw
),
strides
=
(
sh
,
sw
)))
ib
.
ret
(
opfunc
(
x
,
pool_size
=
(
ph
,
pw
),
strides
=
(
sh
,
sw
)))
ib
.
ret
(
func
)
func
=
relay
.
ir_pass
.
infer_type
(
ib
.
env
,
func
.
to_func
())
ftype
=
func
.
checked_type
...
...
@@ -144,7 +144,7 @@ def _test_global_pool2d_infer_type(opfunc):
n
,
c
,
h
,
w
=
tvm
.
var
(
"n"
),
tvm
.
var
(
"c"
),
224
,
224
x
=
ib
.
param
(
"x"
,
relay
.
ty
.
TensorType
((
n
,
h
,
w
,
c
),
"float32"
))
with
ib
.
function
(
x
)
as
func
:
ib
.
ret
(
opfunc
(
x
.
var
,
layout
=
"NHWC"
))
ib
.
ret
(
opfunc
(
x
,
layout
=
"NHWC"
))
ib
.
ret
(
func
)
func
=
relay
.
ir_pass
.
infer_type
(
ib
.
env
,
func
.
to_func
())
ftype
=
func
.
checked_type
...
...
@@ -154,7 +154,7 @@ def _test_global_pool2d_infer_type(opfunc):
n
,
c
,
h
,
w
=
tvm
.
var
(
"n"
),
tvm
.
var
(
"c"
),
tvm
.
var
(
"h"
),
tvm
.
var
(
"w"
)
x
=
ib
.
param
(
"x"
,
relay
.
ty
.
TensorType
((
n
,
c
,
h
,
w
),
"float32"
))
with
ib
.
function
(
x
)
as
func
:
ib
.
ret
(
opfunc
(
x
.
var
))
ib
.
ret
(
opfunc
(
x
))
ib
.
ret
(
func
)
func
=
relay
.
ir_pass
.
infer_type
(
ib
.
env
,
func
.
to_func
())
ftype
=
func
.
checked_type
...
...
@@ -172,7 +172,7 @@ def test_flatten_infer_type():
x
=
ib
.
param
(
"x"
,
relay
.
ty
.
TensorType
((
d1
,
d2
,
d3
,
d4
),
"float32"
))
with
ib
.
function
(
x
)
as
func
:
ib
.
ret
(
relay
.
nn
.
batch_flatten
(
x
.
var
))
ib
.
ret
(
relay
.
nn
.
batch_flatten
(
x
))
ib
.
ret
(
func
)
func
=
relay
.
ir_pass
.
infer_type
(
ib
.
env
,
func
.
to_func
())
ftype
=
func
.
checked_type
...
...
@@ -181,7 +181,7 @@ def test_flatten_infer_type():
ib
=
relay
.
ir_builder
.
IRBuilder
()
x
=
ib
.
param
(
"x"
,
relay
.
ty
.
TensorType
((
3
,
2
,
4
,
3
),
"float32"
))
with
ib
.
function
(
x
)
as
func
:
ib
.
ret
(
relay
.
nn
.
batch_flatten
(
x
.
var
))
ib
.
ret
(
relay
.
nn
.
batch_flatten
(
x
))
ib
.
ret
(
func
)
func
=
relay
.
ir_pass
.
infer_type
(
ib
.
env
,
func
.
to_func
())
ftype
=
func
.
checked_type
...
...
@@ -190,7 +190,7 @@ def test_flatten_infer_type():
ib
=
relay
.
ir_builder
.
IRBuilder
()
x
=
ib
.
param
(
"x"
,
relay
.
ty
.
TensorType
((
d1
,
2
,
d3
,
3
),
"float32"
))
with
ib
.
function
(
x
)
as
func
:
ib
.
ret
(
relay
.
nn
.
batch_flatten
(
x
.
var
))
ib
.
ret
(
relay
.
nn
.
batch_flatten
(
x
))
ib
.
ret
(
func
)
func
=
relay
.
ir_pass
.
infer_type
(
ib
.
env
,
func
.
to_func
())
ftype
=
func
.
checked_type
...
...
@@ -202,7 +202,7 @@ def test_pad_infer_type():
n
,
c
,
h
,
w
=
1
,
2
,
3
,
4
t
=
ib
.
param
(
"t"
,
relay
.
TensorType
((
n
,
c
,
h
,
w
),
"float32"
))
with
ib
.
function
(
t
)
as
func
:
ib
.
ret
(
relay
.
nn
.
pad
(
t
.
var
,
((
1
,
1
),
(
2
,
2
),
(
3
,
3
),
(
4
,
4
))))
ib
.
ret
(
relay
.
nn
.
pad
(
t
,
((
1
,
1
),
(
2
,
2
),
(
3
,
3
),
(
4
,
4
))))
ib
.
ret
(
func
)
func
=
relay
.
ir_pass
.
infer_type
(
ib
.
env
,
func
.
to_func
())
ftype
=
func
.
checked_type
...
...
@@ -213,7 +213,7 @@ def test_pad_infer_type():
n
,
c
,
h
,
w
=
tvm
.
var
(
"n"
),
2
,
3
,
tvm
.
var
(
"w"
)
t
=
ib
.
param
(
"t"
,
relay
.
TensorType
((
n
,
c
,
h
,
w
),
"float32"
))
with
ib
.
function
(
t
)
as
func
:
ib
.
ret
(
relay
.
nn
.
pad
(
t
.
var
,
((
1
,
1
),
(
2
,
2
),
(
3
,
3
),
(
4
,
4
))))
ib
.
ret
(
relay
.
nn
.
pad
(
t
,
((
1
,
1
),
(
2
,
2
),
(
3
,
3
),
(
4
,
4
))))
ib
.
ret
(
func
)
func
=
relay
.
ir_pass
.
infer_type
(
ib
.
env
,
func
.
to_func
())
ftype
=
func
.
checked_type
...
...
@@ -227,4 +227,3 @@ if __name__ == "__main__":
test_flatten_infer_type
()
test_pad_infer_type
()
test_conv2d_transpose_infer_type
()
tests/python/relay/test_op_level3.py
View file @
0b4cc050
...
...
@@ -17,12 +17,13 @@ def test_zeros_ones():
ftype
=
func
.
checked_type
assert
ftype
.
ret_type
==
relay
.
TensorType
((
124
,
50
),
"float64"
)
def
test_unary_identity
():
for
op
in
[
relay
.
zeros_like
,
relay
.
ones_like
]:
ib
=
relay
.
ir_builder
.
IRBuilder
()
x
=
ib
.
param
(
"x"
,
relay
.
TensorType
((
8
,
9
,
4
),
"int32"
))
with
ib
.
function
(
x
)
as
func
:
ib
.
ret
(
op
(
x
.
var
))
ib
.
ret
(
op
(
x
))
ib
.
ret
(
func
)
func
=
relay
.
ir_pass
.
infer_type
(
ib
.
env
,
func
.
to_func
())
ftype
=
func
.
checked_type
...
...
@@ -33,7 +34,7 @@ def test_clip_type():
ib
=
relay
.
ir_builder
.
IRBuilder
()
a
=
ib
.
param
(
"a"
,
relay
.
TensorType
((
10
,
4
),
"float32"
))
with
ib
.
function
(
a
)
as
func
:
ib
.
ret
(
relay
.
clip
(
a
.
var
,
1.
,
4.
))
ib
.
ret
(
relay
.
clip
(
a
,
1.
,
4.
))
ib
.
ret
(
func
)
func
=
relay
.
ir_pass
.
infer_type
(
ib
.
env
,
func
.
to_func
())
ftype
=
func
.
checked_type
...
...
@@ -106,7 +107,7 @@ def test_take_infer_type():
x
=
ib
.
param
(
"x"
,
relay
.
ty
.
TensorType
(
dshape
,
"float32"
))
indices
=
ib
.
param
(
"indices"
,
relay
.
ty
.
TensorType
(
indices_shape
,
"int32"
))
with
ib
.
function
(
x
,
indices
)
as
func
:
ib
.
ret
(
relay
.
take
(
x
.
var
,
indices
.
var
,
axis
=
axis
))
ib
.
ret
(
relay
.
take
(
x
,
indices
,
axis
=
axis
))
ib
.
ret
(
func
)
func
=
relay
.
ir_pass
.
infer_type
(
ib
.
env
,
func
.
to_func
())
ftype
=
func
.
checked_type
...
...
@@ -127,7 +128,7 @@ def test_full():
ib
=
relay
.
ir_builder
.
IRBuilder
()
x
=
ib
.
param
(
"x"
,
relay
.
TensorType
((),
"int8"
))
with
ib
.
function
(
x
)
as
func
:
ib
.
ret
(
relay
.
full
(
x
.
var
,
()))
ib
.
ret
(
relay
.
full
(
x
,
()))
ib
.
ret
(
func
)
func
=
relay
.
ir_pass
.
infer_type
(
ib
.
env
,
func
.
to_func
())
ftype
=
func
.
checked_type
...
...
@@ -137,7 +138,7 @@ def test_full():
ib
=
relay
.
ir_builder
.
IRBuilder
()
x
=
ib
.
param
(
"x"
,
relay
.
TensorType
((),
"float32"
))
with
ib
.
function
(
x
)
as
func
:
ib
.
ret
(
relay
.
full
(
x
.
var
,
(
1
,
2
),
"int8"
))
ib
.
ret
(
relay
.
full
(
x
,
(
1
,
2
),
"int8"
))
ib
.
ret
(
func
)
func
=
relay
.
ir_pass
.
infer_type
(
ib
.
env
,
func
.
to_func
())
ftype
=
func
.
checked_type
...
...
@@ -150,7 +151,7 @@ def test_full_like():
base
=
ib
.
param
(
"base"
,
relay
.
TensorType
((
1
,
2
,
3
),
"float32"
))
fill
=
ib
.
param
(
"fill"
,
relay
.
TensorType
((),
"float32"
))
with
ib
.
function
(
base
,
fill
)
as
func
:
ib
.
ret
(
relay
.
full_like
(
base
.
var
,
fill
.
var
))
ib
.
ret
(
relay
.
full_like
(
base
,
fill
))
ib
.
ret
(
func
)
func
=
relay
.
ir_pass
.
infer_type
(
ib
.
env
,
func
.
to_func
())
ftype
=
func
.
checked_type
...
...
@@ -162,7 +163,7 @@ def test_full_like():
base
=
ib
.
param
(
"base"
,
relay
.
TensorType
((
n
,
c
,
h
,
w
),
"float32"
))
fill
=
ib
.
param
(
"fill"
,
relay
.
TensorType
((),
"float32"
))
with
ib
.
function
(
base
,
fill
)
as
func
:
ib
.
ret
(
relay
.
full_like
(
base
.
var
,
fill
.
var
))
ib
.
ret
(
relay
.
full_like
(
base
,
fill
))
ib
.
ret
(
func
)
func
=
relay
.
ir_pass
.
infer_type
(
ib
.
env
,
func
.
to_func
())
ftype
=
func
.
checked_type
...
...
tests/python/relay/test_op_level4.py
View file @
0b4cc050
...
...
@@ -24,7 +24,7 @@ def test_cmp_type():
x
=
ib
.
param
(
"x"
,
relay
.
TensorType
((
10
,
4
),
"float32"
))
y
=
ib
.
param
(
"y"
,
relay
.
TensorType
((
5
,
10
,
1
),
"float32"
))
with
ib
.
function
(
x
,
y
)
as
func
:
ib
.
ret
(
op
(
x
.
var
,
y
.
var
))
ib
.
ret
(
op
(
x
,
y
))
ib
.
ret
(
func
)
func
=
relay
.
ir_pass
.
infer_type
(
ib
.
env
,
func
.
to_func
())
ftype
=
func
.
checked_type
...
...
@@ -39,7 +39,7 @@ def test_binary_broadcast():
x
=
ib
.
param
(
"x"
,
relay
.
TensorType
((
10
,
4
),
"int32"
))
y
=
ib
.
param
(
"y"
,
relay
.
TensorType
((
5
,
10
,
1
),
"int32"
))
with
ib
.
function
(
x
,
y
)
as
func
:
ib
.
ret
(
op
(
x
.
var
,
y
.
var
))
ib
.
ret
(
op
(
x
,
y
))
ib
.
ret
(
func
)
func
=
relay
.
ir_pass
.
infer_type
(
ib
.
env
,
func
.
to_func
())
ftype
=
func
.
checked_type
...
...
@@ -58,7 +58,7 @@ def test_binary_op():
x
=
b
.
param
(
'x'
,
tensor_type
(
5
,
5
,
5
))
y
=
b
.
param
(
'y'
,
tensor_type
(
5
,
5
,
5
))
with
b
.
function
(
x
,
y
)
as
func
:
b
.
ret
(
opfunc
(
x
.
var
,
y
.
var
))
b
.
ret
(
opfunc
(
x
,
y
))
b
.
ret
(
func
)
prog
,
env
=
b
.
get
()
ttype
=
tensor_type
(
5
,
5
,
5
)
...
...
@@ -81,7 +81,7 @@ def test_binary_broadcast_op():
x
=
b
.
param
(
'x'
,
tensor_type
(
10
,
4
))
y
=
b
.
param
(
'y'
,
tensor_type
(
5
,
10
,
1
))
with
b
.
function
(
x
,
y
)
as
func
:
b
.
ret
(
opfunc
(
x
.
var
,
y
.
var
))
b
.
ret
(
opfunc
(
x
,
y
))
b
.
ret
(
func
)
prog
,
env
=
b
.
get
()
...
...
@@ -103,7 +103,7 @@ def test_cmp_type():
x
=
ib
.
param
(
"x"
,
relay
.
TensorType
((
10
,
4
),
"float32"
))
y
=
ib
.
param
(
"y"
,
relay
.
TensorType
((
5
,
10
,
1
),
"float32"
))
with
ib
.
function
(
x
,
y
)
as
func
:
ib
.
ret
(
op
(
x
.
var
,
y
.
var
))
ib
.
ret
(
op
(
x
,
y
))
ib
.
ret
(
func
)
func
=
relay
.
ir_pass
.
infer_type
(
ib
.
env
,
func
.
to_func
())
ftype
=
func
.
checked_type
...
...
@@ -118,7 +118,7 @@ def test_binary_broadcast():
x
=
ib
.
param
(
"x"
,
relay
.
TensorType
((
10
,
4
),
"int32"
))
y
=
ib
.
param
(
"y"
,
relay
.
TensorType
((
5
,
10
,
1
),
"int32"
))
with
ib
.
function
(
x
,
y
)
as
func
:
ib
.
ret
(
op
(
x
.
var
,
y
.
var
))
ib
.
ret
(
op
(
x
,
y
))
ib
.
ret
(
func
)
func
=
relay
.
ir_pass
.
infer_type
(
ib
.
env
,
func
.
to_func
())
ftype
=
func
.
checked_type
...
...
@@ -131,7 +131,7 @@ def test_where():
x
=
ib
.
param
(
"x"
,
relay
.
TensorType
((
3
,
4
),
"float32"
))
y
=
ib
.
param
(
"y"
,
relay
.
TensorType
((
3
,
4
),
"float32"
))
with
ib
.
function
(
cond
,
x
,
y
)
as
func
:
ib
.
ret
(
relay
.
where
(
cond
.
var
,
x
.
var
,
y
.
var
))
ib
.
ret
(
relay
.
where
(
cond
,
x
,
y
))
ib
.
ret
(
func
)
func
=
relay
.
ir_pass
.
infer_type
(
ib
.
env
,
func
.
to_func
())
ftype
=
func
.
checked_type
...
...
tests/python/relay/test_op_level5.py
View file @
0b4cc050
...
...
@@ -10,7 +10,7 @@ def test_resize_infer_type():
th
,
tw
=
tvm
.
var
(
"th"
),
tvm
.
var
(
"tw"
)
with
ib
.
function
(
x
)
as
func
:
ib
.
ret
(
relay
.
image
.
resize
(
x
.
var
,
(
th
,
tw
)))
ib
.
ret
(
relay
.
image
.
resize
(
x
,
(
th
,
tw
)))
ib
.
ret
(
func
)
func
=
relay
.
ir_pass
.
infer_type
(
ib
.
env
,
func
.
to_func
())
ftype
=
func
.
checked_type
...
...
@@ -19,7 +19,7 @@ def test_resize_infer_type():
ib
=
relay
.
ir_builder
.
IRBuilder
()
x
=
ib
.
param
(
"x"
,
relay
.
ty
.
TensorType
((
n
,
c
,
h
,
w
),
"int8"
))
with
ib
.
function
(
x
)
as
func
:
ib
.
ret
(
relay
.
image
.
resize
(
x
.
var
,
(
100
,
200
),
"NCHW"
,
"BILINEAR"
,
False
))
ib
.
ret
(
relay
.
image
.
resize
(
x
,
(
100
,
200
),
"NCHW"
,
"BILINEAR"
,
False
))
ib
.
ret
(
func
)
func
=
relay
.
ir_pass
.
infer_type
(
ib
.
env
,
func
.
to_func
())
ftype
=
func
.
checked_type
...
...
tests/python/relay/test_pass_alpha_equal.py
View file @
0b4cc050
import
tvm
import
numpy
as
np
from
tvm
import
relay
from
tvm.relay.ir_pass
import
alpha_equal
from
tvm.relay.ir_builder
import
convert
...
...
@@ -179,9 +180,9 @@ def test_var_alpha_equal():
assert
not
alpha_equal
(
v1
,
v2
)
# let node allows for setting the eq_map
l1
=
relay
.
Let
(
v1
,
convert
(
1
),
v1
,
None
)
l2
=
relay
.
Let
(
v2
,
convert
(
1
),
v2
,
None
)
l3
=
relay
.
Let
(
v1
,
convert
(
1
),
v2
,
None
)
l1
=
relay
.
Let
(
v1
,
convert
(
1
),
v1
)
l2
=
relay
.
Let
(
v2
,
convert
(
1
),
v2
)
l3
=
relay
.
Let
(
v1
,
convert
(
1
),
v2
)
assert
alpha_equal
(
l1
,
l2
)
assert
not
alpha_equal
(
l1
,
l3
)
...
...
@@ -209,10 +210,10 @@ def test_tuple_alpha_equal():
assert
alpha_equal
(
tup
,
same
)
# use the eq_map
let_tup
=
relay
.
Let
(
v1
,
tup
,
v1
,
None
)
let_tup
=
relay
.
Let
(
v1
,
tup
,
v1
)
let_mapped
=
relay
.
Let
(
v2
,
relay
.
Tuple
([
v2
,
convert
(
2
),
convert
(
3
),
relay
.
Tuple
([
convert
(
4
)])]),
v2
,
None
)
v2
)
assert
alpha_equal
(
let_tup
,
let_mapped
)
more_fields
=
relay
.
Tuple
([
v1
,
convert
(
2
),
convert
(
3
),
relay
.
Tuple
([
convert
(
4
)]),
v2
])
...
...
@@ -242,61 +243,44 @@ def test_tuple_get_item_alpha_equal():
assert
alpha_equal
(
relay
.
TupleGetItem
(
x
,
1
),
relay
.
TupleGetItem
(
x
,
1
))
def
test_param_alpha_equal
():
# only checks equality of the types
v1
=
relay
.
Var
(
"v1"
)
v2
=
relay
.
Var
(
"v2"
)
p1
=
relay
.
Param
(
v1
,
relay
.
TensorType
((
1
,
2
,
3
),
"float32"
))
p2
=
relay
.
Param
(
v2
,
relay
.
TensorType
((
1
,
2
,
3
),
"float32"
))
assert
alpha_equal
(
p1
,
p2
)
p3
=
relay
.
Param
(
v1
,
relay
.
TensorType
((
4
,
5
,
6
),
"int8"
))
assert
not
alpha_equal
(
p1
,
p3
)
p4
=
relay
.
Param
(
v1
,
relay
.
TupleType
([
relay
.
TensorType
((
1
,
2
,
3
),
"float32"
)]))
assert
not
alpha_equal
(
p1
,
p4
)
def
test_function_alpha_equal
():
v1
=
relay
.
Var
(
"v1"
)
v2
=
relay
.
Var
(
"v2"
)
v3
=
relay
.
Var
(
"v3"
)
v4
=
relay
.
Var
(
"v4"
)
tt1
=
relay
.
TensorType
((
1
,
2
,
3
),
"float32"
)
tt2
=
relay
.
TensorType
((
4
,
5
,
6
),
"int8"
)
tt3
=
relay
.
TupleType
([
tt1
,
tt2
])
v1
=
relay
.
Var
(
"v1"
,
tt1
)
v2
=
relay
.
Var
(
"v2"
,
tt2
)
v3
=
relay
.
Var
(
"v3"
,
tt3
)
v4
=
relay
.
Var
(
"v4"
,
tt2
)
vret
=
relay
.
Constant
(
tvm
.
nd
.
array
(
np
.
ones
(
1
)))
tp1
=
relay
.
TypeParam
(
"tp1"
,
relay
.
Kind
.
Type
)
tp2
=
relay
.
TypeParam
(
"tp2"
,
relay
.
Kind
.
Type
)
tp3
=
relay
.
TypeParam
(
"tp3"
,
relay
.
Kind
.
Shape
)
tp4
=
relay
.
TypeParam
(
"tp4"
,
relay
.
Kind
.
Shape
)
basic_args
=
[
relay
.
Param
(
v3
,
tt1
),
relay
.
Param
(
v4
,
tt2
)]
basic_args
=
[
relay
.
Var
(
"v3"
,
tt1
),
relay
.
Var
(
"v4"
,
tt2
)]
basic_tps
=
[
tp1
,
tp2
]
func
=
relay
.
Function
([
relay
.
Param
(
v1
,
tt1
),
relay
.
Param
(
v2
,
tt2
)
],
tt2
,
v
2
,
basic_tps
)
mapped
=
relay
.
Function
(
basic_args
,
tt2
,
v4
,
basic_tps
)
func
=
relay
.
Function
([
v1
,
v2
],
tt2
,
v
1
,
basic_tps
)
mapped
=
relay
.
Function
(
basic_args
,
tt2
,
basic_args
[
0
]
,
basic_tps
)
assert
alpha_equal
(
func
,
mapped
)
fewer_params
=
relay
.
Function
([
relay
.
Param
(
v4
,
tt2
)],
tt2
,
v4
,
basic_tps
)
fewer_params
=
relay
.
Function
([
relay
.
Var
(
"v4"
,
tt2
)],
tt2
,
v4
,
basic_tps
)
assert
not
alpha_equal
(
func
,
fewer_params
)
more_params
=
relay
.
Function
([
relay
.
Param
(
v3
,
tt1
),
relay
.
Param
(
v4
,
tt2
),
relay
.
Param
(
v2
,
tt2
)],
tt2
,
v4
,
basic_tps
)
more_params
=
relay
.
Function
([
relay
.
Var
(
"v3"
,
tt1
),
relay
.
Var
(
"v4"
,
tt2
),
relay
.
Var
(
"v2"
,
tt2
)],
tt2
,
v4
,
basic_tps
)
assert
not
alpha_equal
(
func
,
more_params
)
params_unordered
=
relay
.
Function
([
relay
.
Param
(
v3
,
tt2
),
relay
.
Param
(
v4
,
tt1
)],
tt1
,
v3
,
basic_tps
)
params_unordered
=
relay
.
Function
([
v2
,
v1
],
tt2
,
v1
,
basic_tps
)
assert
not
alpha_equal
(
func
,
params_unordered
)
params_mismatch
=
relay
.
Function
([
relay
.
Param
(
v3
,
tt3
),
relay
.
Param
(
v4
,
tt2
)],
tt2
,
v4
,
basic_tps
)
params_mismatch
=
relay
.
Function
([
v1
,
v3
],
tt2
,
v1
,
basic_tps
)
assert
not
alpha_equal
(
func
,
params_mismatch
)
# also would not typecheck
...
...
@@ -376,7 +360,10 @@ def test_call_alpha_equal():
def
test_let_alpha_equal
():
tt1
=
relay
.
TensorType
((),
"float32"
)
tt2
=
relay
.
TensorType
((),
"int8"
)
v1
=
relay
.
Var
(
"v1"
)
v1_wtype
=
relay
.
Var
(
"v1"
,
tt1
)
v2
=
relay
.
Var
(
"v2"
)
v3
=
relay
.
Var
(
"v3"
)
...
...
@@ -394,14 +381,13 @@ def test_let_alpha_equal():
assert
not
alpha_equal
(
let
,
different_body
)
# specified types must match
tt1
=
relay
.
TensorType
((),
"float32"
)
tt2
=
relay
.
TensorType
((),
"int8"
)
let_with_type
=
relay
.
Let
(
v1
,
convert
(
2
),
v1
,
tt1
)
same_type
=
relay
.
Let
(
v1
,
convert
(
2
),
v1
,
tt1
)
let_with_type
=
relay
.
Let
(
v1_wtype
,
convert
(
2
),
v1_wtype
)
same_type
=
relay
.
Let
(
v1_wtype
,
convert
(
2
),
v1_wtype
)
assert
alpha_equal
(
let_with_type
,
same_type
)
assert
not
alpha_equal
(
let
,
let_with_type
)
different_type
=
relay
.
Let
(
v
1
,
convert
(
2
),
v1
,
tt
2
)
v2
=
relay
.
Var
(
"v1"
,
tt2
)
different_type
=
relay
.
Let
(
v
2
,
convert
(
2
),
v
2
)
assert
not
alpha_equal
(
let_with_type
,
different_type
)
...
...
@@ -437,16 +423,13 @@ if __name__ == "__main__":
test_tensor_type_alpha_equal
()
test_incomplete_type_alpha_equal
()
test_constant_alpha_equal
()
test_type_param_alpha_equal
()
test_func_type_alpha_equal
()
test_tuple_type_alpha_equal
()
test_type_relation_alpha_equal
()
test_constant_alpha_equal
()
test_var_alpha_equal
()
test_global_var_alpha_equal
()
test_tuple_alpha_equal
()
test_tuple_get_item_alpha_equal
()
test_param_alpha_equal
()
test_function_alpha_equal
()
test_call_alpha_equal
()
test_let_alpha_equal
()
...
...
tests/python/relay/test_pass_dead_code_elimination.py
View file @
0b4cc050
...
...
@@ -28,17 +28,17 @@ e = env()
def
test_let
():
orig
=
relay
.
Let
(
e
.
x
,
e
.
y
,
e
.
z
,
e
.
tt
)
orig
=
relay
.
Let
(
e
.
x
,
e
.
y
,
e
.
z
)
assert
alpha_equal
(
dead_code_elimination
(
orig
),
e
.
z
)
def
test_used_let
():
orig
=
relay
.
Let
(
e
.
a
,
e
.
b
,
relay
.
Let
(
e
.
c
,
e
.
d
,
e
.
c
,
e
.
tt
),
e
.
tt
)
assert
alpha_equal
(
dead_code_elimination
(
orig
),
relay
.
Let
(
e
.
c
,
e
.
d
,
e
.
c
,
e
.
tt
))
orig
=
relay
.
Let
(
e
.
a
,
e
.
b
,
relay
.
Let
(
e
.
c
,
e
.
d
,
e
.
c
)
)
assert
alpha_equal
(
dead_code_elimination
(
orig
),
relay
.
Let
(
e
.
c
,
e
.
d
,
e
.
c
))
def
test_chain_unused_let
():
orig
=
relay
.
Let
(
e
.
a
,
e
.
b
,
relay
.
Let
(
e
.
c
,
e
.
d
,
e
.
e
,
e
.
tt
),
e
.
tt
)
orig
=
relay
.
Let
(
e
.
a
,
e
.
b
,
relay
.
Let
(
e
.
c
,
e
.
d
,
e
.
e
)
)
assert
alpha_equal
(
dead_code_elimination
(
orig
),
e
.
e
)
...
...
@@ -56,19 +56,17 @@ def test_recursion():
f(2, 10000);
"""
f
=
relay
.
Var
(
"f"
)
n
=
relay
.
Var
(
"n"
)
np
=
relay
.
Param
(
n
,
e
.
int32
)
data
=
relay
.
Var
(
"data"
)
datap
=
relay
.
Param
(
data
,
e
.
float32
)
n
=
relay
.
Var
(
"n"
,
e
.
int32
)
data
=
relay
.
Var
(
"data"
,
e
.
float32
)
funcbody
=
relay
.
If
(
equal
(
n
,
convert
(
0
)),
data
,
f
(
subtract
(
n
,
convert
(
1.0
)),
log
(
data
)))
value
=
relay
.
Function
([
n
p
,
datap
],
e
.
float32
,
funcbody
,
[])
orig
=
relay
.
Let
(
f
,
funcbody
,
f
(
convert
(
2.0
),
convert
(
10000.0
))
,
e
.
float32
)
value
=
relay
.
Function
([
n
,
data
],
e
.
float32
,
funcbody
,
[])
orig
=
relay
.
Let
(
f
,
funcbody
,
f
(
convert
(
2.0
),
convert
(
10000.0
)))
assert
alpha_equal
(
dead_code_elimination
(
orig
),
orig
)
assert
alpha_equal
(
dead_code_elimination
(
relay
.
Let
(
f
,
funcbody
,
e
.
three
,
e
.
float32
)),
e
.
three
)
assert
alpha_equal
(
dead_code_elimination
(
relay
.
Let
(
f
,
funcbody
,
e
.
three
)),
e
.
three
)
def
test_op_let
():
assert
alpha_equal
(
dead_code_elimination
(
add
(
relay
.
Let
(
e
.
a
,
e
.
one
,
e
.
three
,
e
.
float32
),
e
.
two
)),
add
(
e
.
three
,
e
.
two
))
assert
alpha_equal
(
dead_code_elimination
(
add
(
relay
.
Let
(
e
.
a
,
e
.
one
,
e
.
three
),
e
.
two
)),
add
(
e
.
three
,
e
.
two
))
def
test_if
():
...
...
@@ -80,7 +78,7 @@ def test_tuple_get_item():
t
=
relay
.
Var
(
't'
)
g
=
relay
.
TupleGetItem
(
t
,
0
)
assert
alpha_equal
(
dead_code_elimination
(
g
),
g
)
assert
alpha_equal
(
dead_code_elimination
(
relay
.
TupleGetItem
(
relay
.
Let
(
e
.
a
,
e
.
one
,
t
,
e
.
float32
),
0
)),
g
)
assert
alpha_equal
(
dead_code_elimination
(
relay
.
TupleGetItem
(
relay
.
Let
(
e
.
a
,
e
.
one
,
t
),
0
)),
g
)
if
__name__
==
"__main__"
:
...
...
tests/python/relay/test_pass_free_vars.py
View file @
0b4cc050
...
...
@@ -3,16 +3,17 @@ from tvm import relay
from
tvm.relay.ir_pass
import
free_vars
,
free_type_vars
def
test_free_vars
():
x
=
relay
.
Var
(
"x"
)
ty
=
relay
.
TensorType
([],
"int32"
)
x
=
relay
.
Var
(
"x"
,
ty
)
fvx
=
free_vars
(
x
)
assert
len
(
fvx
)
==
1
assert
fvx
[
0
]
==
x
v
=
relay
.
Constant
(
tvm
.
nd
.
array
(
10
))
ty
=
relay
.
TensorType
([],
"int32"
)
let
=
relay
.
Let
(
x
,
v
,
x
,
ty
)
let
=
relay
.
Let
(
x
,
v
,
x
)
fvx
=
free_vars
(
let
)
assert
len
(
free_vars
(
let
))
==
0
f
=
relay
.
Function
([
relay
.
Param
(
x
,
ty
)
],
ty
,
x
)
f
=
relay
.
Function
([
x
],
ty
,
x
)
assert
len
(
free_vars
(
f
))
==
0
...
...
@@ -29,9 +30,9 @@ def test_tuple():
def
test_free_type_vars
():
tp
=
relay
.
TypeParam
(
""
)
ty
=
relay
.
TupleType
([
tp
,
relay
.
TensorType
([],
"int32"
)])
x
=
relay
.
Var
(
"x"
)
x
=
relay
.
Var
(
"x"
,
ty
)
y
=
relay
.
Var
(
"y"
)
let
=
relay
.
Let
(
x
,
y
,
x
,
ty
)
let
=
relay
.
Let
(
x
,
y
,
x
)
fvl
=
free_vars
(
let
)
assert
len
(
fvl
)
==
1
assert
fvl
[
0
]
==
y
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment