Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
0b4cc050
Unverified
Commit
0b4cc050
authored
Oct 14, 2018
by
Tianqi Chen
Committed by
GitHub
Oct 14, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[RELAY][IR] Move type_annotation to Var, remove Param (#1900)
parent
53428606
Show whitespace changes
Inline
Side-by-side
Showing
26 changed files
with
367 additions
and
388 deletions
+367
-388
include/tvm/relay/expr.h
+28
-39
include/tvm/relay/expr_functor.h
+0
-4
python/tvm/relay/__init__.py
+0
-1
python/tvm/relay/expr.py
+117
-42
python/tvm/relay/ir_builder.py
+12
-21
src/relay/ir/debug_printer.cc
+9
-9
src/relay/ir/expr.cc
+18
-28
src/relay/ir/expr_functor.cc
+16
-19
src/relay/pass/alpha_eq.cc
+18
-22
src/relay/pass/dead_code.cc
+10
-16
src/relay/pass/let_list.h
+15
-26
src/relay/pass/type_infer.cc
+10
-19
src/relay/pass/util.cc
+5
-3
src/relay/pass/well_formed.cc
+2
-2
tests/python/relay/test_ir_builder.py
+0
-1
tests/python/relay/test_ir_debug_printer.py
+4
-11
tests/python/relay/test_ir_nodes.py
+9
-15
tests/python/relay/test_ir_well_formed.py
+5
-6
tests/python/relay/test_op_level1.py
+6
-3
tests/python/relay/test_op_level2.py
+16
-17
tests/python/relay/test_op_level3.py
+8
-7
tests/python/relay/test_op_level4.py
+7
-7
tests/python/relay/test_op_level5.py
+2
-2
tests/python/relay/test_pass_alpha_equal.py
+32
-49
tests/python/relay/test_pass_dead_code_elimination.py
+11
-13
tests/python/relay/test_pass_free_vars.py
+7
-6
No files found.
include/tvm/relay/expr.h
View file @
0b4cc050
...
@@ -118,17 +118,27 @@ class Var;
...
@@ -118,17 +118,27 @@ class Var;
/*! \brief Container for Var */
/*! \brief Container for Var */
class
VarNode
:
public
ExprNode
{
class
VarNode
:
public
ExprNode
{
public
:
public
:
/*! \brief The name of the variable, this only acts as a hint to the user,
/*!
* \brief The name of the variable,
* this only acts as a hint to the user,
* and is not used for equality.
* and is not used for equality.
*/
*/
std
::
string
name_hint
;
std
::
string
name_hint
;
/*!
* \brief type annotaion of the variable.
* This field records user provided type annotation of the Var.
* This field is optional and can be None.
*/
Type
type_annotation
;
void
VisitAttrs
(
tvm
::
AttrVisitor
*
v
)
final
{
void
VisitAttrs
(
tvm
::
AttrVisitor
*
v
)
final
{
v
->
Visit
(
"name_hint"
,
&
name_hint
);
v
->
Visit
(
"name_hint"
,
&
name_hint
);
v
->
Visit
(
"type_annotation"
,
&
type_annotation
);
v
->
Visit
(
"_checked_type_"
,
&
checked_type_
);
v
->
Visit
(
"_checked_type_"
,
&
checked_type_
);
}
}
TVM_DLL
static
Var
make
(
std
::
string
name_hint
);
TVM_DLL
static
Var
make
(
std
::
string
name_hint
,
Type
type_annotation
);
static
constexpr
const
char
*
_type_key
=
"relay.Var"
;
static
constexpr
const
char
*
_type_key
=
"relay.Var"
;
TVM_DECLARE_NODE_TYPE_INFO
(
VarNode
,
ExprNode
);
TVM_DECLARE_NODE_TYPE_INFO
(
VarNode
,
ExprNode
);
...
@@ -163,32 +173,6 @@ class GlobalVarNode : public ExprNode {
...
@@ -163,32 +173,6 @@ class GlobalVarNode : public ExprNode {
RELAY_DEFINE_NODE_REF
(
GlobalVar
,
GlobalVarNode
,
Expr
);
RELAY_DEFINE_NODE_REF
(
GlobalVar
,
GlobalVarNode
,
Expr
);
/*!
/*!
* \brief Function parameter declaration.
*/
class
Param
;
/*! \brief A parameter. */
class
ParamNode
:
public
ExprNode
{
public
:
/*! \brief The variable */
Var
var
;
/*! \brief The type of the parameter */
Type
type
;
void
VisitAttrs
(
tvm
::
AttrVisitor
*
v
)
final
{
v
->
Visit
(
"var"
,
&
var
);
v
->
Visit
(
"type"
,
&
type
);
v
->
Visit
(
"span"
,
&
span
);
}
TVM_DLL
static
Param
make
(
Var
var
,
Type
type
);
static
constexpr
const
char
*
_type_key
=
"relay.Param"
;
TVM_DECLARE_NODE_TYPE_INFO
(
ParamNode
,
ExprNode
);
};
RELAY_DEFINE_NODE_REF
(
Param
,
ParamNode
,
Expr
);
/*!
* \brief Function (subgraph in computational graph)
* \brief Function (subgraph in computational graph)
*/
*/
class
Function
;
class
Function
;
...
@@ -196,7 +180,7 @@ class Function;
...
@@ -196,7 +180,7 @@ class Function;
class
FunctionNode
:
public
ExprNode
{
class
FunctionNode
:
public
ExprNode
{
public
:
public
:
/*! \brief Function parameters */
/*! \brief Function parameters */
tvm
::
Array
<
Param
>
params
;
tvm
::
Array
<
Var
>
params
;
/*! \brief User annotated return type of the function. */
/*! \brief User annotated return type of the function. */
Type
ret_type
;
Type
ret_type
;
/*!
/*!
...
@@ -224,10 +208,18 @@ class FunctionNode : public ExprNode {
...
@@ -224,10 +208,18 @@ class FunctionNode : public ExprNode {
v
->
Visit
(
"_checked_type_"
,
&
checked_type_
);
v
->
Visit
(
"_checked_type_"
,
&
checked_type_
);
}
}
Type
fn_type
()
const
;
/*!
* \brief Return the derived function annotation of this expression.
*
* \return The function type annotation.
* \note The function type annotation can contain IncompleteType.
*/
TVM_DLL
FuncType
func_type_annotation
()
const
;
TVM_DLL
static
Function
make
(
tvm
::
Array
<
Param
>
params
,
Type
ret_type
,
TVM_DLL
static
Function
make
(
tvm
::
Array
<
Var
>
params
,
Expr
body
,
tvm
::
Array
<
TypeParam
>
ty_params
);
Type
ret_type
,
Expr
body
,
tvm
::
Array
<
TypeParam
>
ty_params
);
static
constexpr
const
char
*
_type_key
=
"relay.Function"
;
static
constexpr
const
char
*
_type_key
=
"relay.Function"
;
TVM_DECLARE_NODE_TYPE_INFO
(
FunctionNode
,
ExprNode
);
TVM_DECLARE_NODE_TYPE_INFO
(
FunctionNode
,
ExprNode
);
...
@@ -289,7 +281,7 @@ class CallNode : public ExprNode {
...
@@ -289,7 +281,7 @@ class CallNode : public ExprNode {
TVM_DLL
static
Call
make
(
Expr
op
,
TVM_DLL
static
Call
make
(
Expr
op
,
Array
<
Expr
>
args
,
Array
<
Expr
>
args
,
Attrs
attrs
=
Attrs
(),
Attrs
attrs
=
Attrs
(),
Array
<
Type
>
ty_args
=
Array
<
Type
>
());
Array
<
Type
>
ty
pe
_args
=
Array
<
Type
>
());
static
constexpr
const
char
*
_type_key
=
"relay.Call"
;
static
constexpr
const
char
*
_type_key
=
"relay.Call"
;
TVM_DECLARE_NODE_TYPE_INFO
(
CallNode
,
ExprNode
);
TVM_DECLARE_NODE_TYPE_INFO
(
CallNode
,
ExprNode
);
...
@@ -318,19 +310,16 @@ class LetNode : public ExprNode {
...
@@ -318,19 +310,16 @@ class LetNode : public ExprNode {
Expr
value
;
Expr
value
;
/*! \brief The body of the let binding */
/*! \brief The body of the let binding */
Expr
body
;
Expr
body
;
/*! \brief Type annotation of value, this can be null */
Type
value_type
;
void
VisitAttrs
(
tvm
::
AttrVisitor
*
v
)
final
{
void
VisitAttrs
(
tvm
::
AttrVisitor
*
v
)
final
{
v
->
Visit
(
"var"
,
&
var
);
v
->
Visit
(
"var"
,
&
var
);
v
->
Visit
(
"value"
,
&
value
);
v
->
Visit
(
"value"
,
&
value
);
v
->
Visit
(
"body"
,
&
body
);
v
->
Visit
(
"body"
,
&
body
);
v
->
Visit
(
"value_type"
,
&
value_type
);
v
->
Visit
(
"span"
,
&
span
);
v
->
Visit
(
"span"
,
&
span
);
v
->
Visit
(
"_checked_type_"
,
&
checked_type_
);
v
->
Visit
(
"_checked_type_"
,
&
checked_type_
);
}
}
TVM_DLL
static
Let
make
(
Var
var
,
Expr
value
,
Expr
body
,
Type
value_type
);
TVM_DLL
static
Let
make
(
Var
var
,
Expr
value
,
Expr
body
);
static
constexpr
const
char
*
_type_key
=
"relay.Let"
;
static
constexpr
const
char
*
_type_key
=
"relay.Let"
;
TVM_DECLARE_NODE_TYPE_INFO
(
LetNode
,
ExprNode
);
TVM_DECLARE_NODE_TYPE_INFO
(
LetNode
,
ExprNode
);
...
@@ -376,11 +365,11 @@ class IfNode : public ExprNode {
...
@@ -376,11 +365,11 @@ class IfNode : public ExprNode {
RELAY_DEFINE_NODE_REF
(
If
,
IfNode
,
Expr
);
RELAY_DEFINE_NODE_REF
(
If
,
IfNode
,
Expr
);
/*! \brief Get
a
field out of a tuple. */
/*! \brief Get
index-th
field out of a tuple. */
class
TupleGetItem
;
class
TupleGetItem
;
class
TupleGetItemNode
:
public
ExprNode
{
class
TupleGetItemNode
:
public
ExprNode
{
public
:
public
:
/*! \brief The tuple */
/*! \brief The tuple
Expression
*/
Expr
tuple
;
Expr
tuple
;
/*! \brief which value to get */
/*! \brief which value to get */
int
index
;
int
index
;
...
...
include/tvm/relay/expr_functor.h
View file @
0b4cc050
...
@@ -80,7 +80,6 @@ class ExprFunctor<R(const Expr& n, Args...)> {
...
@@ -80,7 +80,6 @@ class ExprFunctor<R(const Expr& n, Args...)> {
Args
...
args
)
EXPR_FUNCTOR_DEFAULT
;
Args
...
args
)
EXPR_FUNCTOR_DEFAULT
;
virtual
R
VisitExpr_
(
const
GlobalVarNode
*
op
,
virtual
R
VisitExpr_
(
const
GlobalVarNode
*
op
,
Args
...
args
)
EXPR_FUNCTOR_DEFAULT
;
Args
...
args
)
EXPR_FUNCTOR_DEFAULT
;
virtual
R
VisitExpr_
(
const
ParamNode
*
op
,
Args
...
args
)
EXPR_FUNCTOR_DEFAULT
;
virtual
R
VisitExpr_
(
const
FunctionNode
*
op
,
virtual
R
VisitExpr_
(
const
FunctionNode
*
op
,
Args
...
args
)
EXPR_FUNCTOR_DEFAULT
;
Args
...
args
)
EXPR_FUNCTOR_DEFAULT
;
virtual
R
VisitExpr_
(
const
CallNode
*
op
,
Args
...
args
)
EXPR_FUNCTOR_DEFAULT
;
virtual
R
VisitExpr_
(
const
CallNode
*
op
,
Args
...
args
)
EXPR_FUNCTOR_DEFAULT
;
...
@@ -103,7 +102,6 @@ class ExprFunctor<R(const Expr& n, Args...)> {
...
@@ -103,7 +102,6 @@ class ExprFunctor<R(const Expr& n, Args...)> {
RELAY_EXPR_FUNCTOR_DISPATCH
(
TupleNode
);
RELAY_EXPR_FUNCTOR_DISPATCH
(
TupleNode
);
RELAY_EXPR_FUNCTOR_DISPATCH
(
VarNode
);
RELAY_EXPR_FUNCTOR_DISPATCH
(
VarNode
);
RELAY_EXPR_FUNCTOR_DISPATCH
(
GlobalVarNode
);
RELAY_EXPR_FUNCTOR_DISPATCH
(
GlobalVarNode
);
RELAY_EXPR_FUNCTOR_DISPATCH
(
ParamNode
);
RELAY_EXPR_FUNCTOR_DISPATCH
(
FunctionNode
);
RELAY_EXPR_FUNCTOR_DISPATCH
(
FunctionNode
);
RELAY_EXPR_FUNCTOR_DISPATCH
(
CallNode
);
RELAY_EXPR_FUNCTOR_DISPATCH
(
CallNode
);
RELAY_EXPR_FUNCTOR_DISPATCH
(
LetNode
);
RELAY_EXPR_FUNCTOR_DISPATCH
(
LetNode
);
...
@@ -127,7 +125,6 @@ class ExprVisitor : public ::tvm::relay::ExprFunctor<void(const Expr& n)> {
...
@@ -127,7 +125,6 @@ class ExprVisitor : public ::tvm::relay::ExprFunctor<void(const Expr& n)> {
void
VisitExpr_
(
const
GlobalVarNode
*
op
)
override
;
void
VisitExpr_
(
const
GlobalVarNode
*
op
)
override
;
void
VisitExpr_
(
const
ConstantNode
*
op
)
override
;
void
VisitExpr_
(
const
ConstantNode
*
op
)
override
;
void
VisitExpr_
(
const
TupleNode
*
op
)
override
;
void
VisitExpr_
(
const
TupleNode
*
op
)
override
;
void
VisitExpr_
(
const
ParamNode
*
op
)
override
;
void
VisitExpr_
(
const
FunctionNode
*
op
)
override
;
void
VisitExpr_
(
const
FunctionNode
*
op
)
override
;
void
VisitExpr_
(
const
CallNode
*
op
)
override
;
void
VisitExpr_
(
const
CallNode
*
op
)
override
;
void
VisitExpr_
(
const
LetNode
*
op
)
override
;
void
VisitExpr_
(
const
LetNode
*
op
)
override
;
...
@@ -151,7 +148,6 @@ class ExprMutator
...
@@ -151,7 +148,6 @@ class ExprMutator
Expr
VisitExpr_
(
const
GlobalVarNode
*
op
)
override
;
Expr
VisitExpr_
(
const
GlobalVarNode
*
op
)
override
;
Expr
VisitExpr_
(
const
OpNode
*
op
)
override
;
Expr
VisitExpr_
(
const
OpNode
*
op
)
override
;
Expr
VisitExpr_
(
const
TupleNode
*
op
)
override
;
Expr
VisitExpr_
(
const
TupleNode
*
op
)
override
;
Expr
VisitExpr_
(
const
ParamNode
*
op
)
override
;
Expr
VisitExpr_
(
const
FunctionNode
*
op
)
override
;
Expr
VisitExpr_
(
const
FunctionNode
*
op
)
override
;
Expr
VisitExpr_
(
const
CallNode
*
call_node
)
override
;
Expr
VisitExpr_
(
const
CallNode
*
call_node
)
override
;
Expr
VisitExpr_
(
const
LetNode
*
op
)
override
;
Expr
VisitExpr_
(
const
LetNode
*
op
)
override
;
...
...
python/tvm/relay/__init__.py
View file @
0b4cc050
...
@@ -34,7 +34,6 @@ Constant = expr.Constant
...
@@ -34,7 +34,6 @@ Constant = expr.Constant
Tuple
=
expr
.
Tuple
Tuple
=
expr
.
Tuple
Var
=
expr
.
Var
Var
=
expr
.
Var
GlobalVar
=
expr
.
GlobalVar
GlobalVar
=
expr
.
GlobalVar
Param
=
expr
.
Param
Function
=
expr
.
Function
Function
=
expr
.
Function
Call
=
expr
.
Call
Call
=
expr
.
Call
Let
=
expr
.
Let
Let
=
expr
.
Let
...
...
python/tvm/relay/expr.py
View file @
0b4cc050
...
@@ -11,11 +11,11 @@ class Expr(NodeBase):
...
@@ -11,11 +11,11 @@ class Expr(NodeBase):
"""The base type for all Relay expressions."""
"""The base type for all Relay expressions."""
@property
@property
def
checked_type
(
self
):
def
checked_type
(
self
):
"""Get the checked type of
relay
.
"""Get the checked type of
tvm.relay.Expr
.
Returns
Returns
-------
-------
checked_type : relay.Type
checked_type :
tvm.
relay.Type
The checked type.
The checked type.
"""
"""
ret
=
self
.
_checked_type_
ret
=
self
.
_checked_type_
...
@@ -25,70 +25,97 @@ class Expr(NodeBase):
...
@@ -25,70 +25,97 @@ class Expr(NodeBase):
return
ret
return
ret
def
__call__
(
self
,
*
args
):
def
__call__
(
self
,
*
args
):
converted_args
=
[]
for
arg
in
args
:
if
isinstance
(
arg
,
Param
):
converted_args
.
append
(
arg
.
var
)
else
:
converted_args
.
append
(
arg
)
return
Call
(
self
,
args
,
None
,
None
)
return
Call
(
self
,
args
,
None
,
None
)
@register_relay_node
@register_relay_node
class
Constant
(
Expr
):
class
Constant
(
Expr
):
"""A constant tensor in Relay, see tvm/relay/type.h for more details.
"""A constant expression in Relay.
"""
Parameters
----------
data : tvm.nd.NDArray
The data content of the constant expression.
"""
def
__init__
(
self
,
data
):
def
__init__
(
self
,
data
):
self
.
__init_handle_by_constructor__
(
_make
.
Constant
,
data
)
self
.
__init_handle_by_constructor__
(
_make
.
Constant
,
data
)
@register_relay_node
@register_relay_node
class
Tuple
(
Expr
):
class
Tuple
(
Expr
):
"""A hetereogenous sequence of values.
"""Tuple expression that groups several fields together.
see tvm/relay/type.h for more details.
"""
Parameters
----------
fields : List[tvm.relay.Expr]
The fields in the tuple.
"""
def
__init__
(
self
,
fields
):
def
__init__
(
self
,
fields
):
self
.
__init_handle_by_constructor__
(
_make
.
Tuple
,
fields
)
self
.
__init_handle_by_constructor__
(
_make
.
Tuple
,
fields
)
@register_relay_node
@register_relay_node
class
Var
(
Expr
):
class
Var
(
Expr
):
"""A local variable in
Relay."""
"""A local variable in
Tvm.Relay.
def
__init__
(
self
,
name_hint
):
Local variable can be used to declare input
self
.
__init_handle_by_constructor__
(
_make
.
Var
,
name_hint
)
arguments to a function, or intermediate variables.
Parameters
----------
name_hint: str
The name of the variable.
This name only acts as a hint, and is not used
for equality.
type_annotation: tvm.relay.Type, optional
The type annotation on the variable.
"""
def
__init__
(
self
,
name_hint
,
type_annotation
=
None
):
self
.
__init_handle_by_constructor__
(
_make
.
Var
,
name_hint
,
type_annotation
)
@register_relay_node
@register_relay_node
class
GlobalVar
(
Expr
):
class
GlobalVar
(
Expr
):
"""A global variable in
Relay."""
"""A global variable in
Tvm.Relay.
GlobalVar is used to refer to the global functions
stored in the environment.
Parameters
----------
name_hint: str
The name of the variable.
"""
def
__init__
(
self
,
name_hint
):
def
__init__
(
self
,
name_hint
):
self
.
__init_handle_by_constructor__
(
_make
.
GlobalVar
,
name_hint
)
self
.
__init_handle_by_constructor__
(
_make
.
GlobalVar
,
name_hint
)
@register_relay_node
@register_relay_node
class
Param
(
Expr
):
class
Function
(
Expr
):
"""A function type in Relay, see tvm/relay/type.h for more details.
"""A function declaration expression.
"""
def
__init__
(
self
,
var
,
ty
):
Parameters
self
.
__init_handle_by_constructor__
(
_make
.
Param
,
var
,
ty
)
----------
params: List[tvm.relay.Var]
List of input parameters to the function.
ret_type: tvm.relay.Type
The return type annotation of the function.
@register_relay_node
body: tvm.relay.Expr
class
Function
(
Expr
):
The body of the function.
"""A function in Relay, see tvm/relay/expr.h for more details."""
type_params: Optional[List[tvm.relay.TypeParam]]
The additional type parameters, this is only
used in advanced usecase of template functions.
"""
def
__init__
(
self
,
def
__init__
(
self
,
params
,
params
,
ret_type
,
ret_type
,
body
,
body
,
type_params
=
None
type_params
=
None
):
):
if
type_params
is
None
:
if
type_params
is
None
:
type_params
=
convert
([])
type_params
=
convert
([])
...
@@ -98,39 +125,87 @@ class Function(Expr):
...
@@ -98,39 +125,87 @@ class Function(Expr):
@register_relay_node
@register_relay_node
class
Call
(
Expr
):
class
Call
(
Expr
):
"""A function call in Relay, see tvm/relay/expr.h for more details."""
"""Function call node in Relay.
Call node corresponds the operator application node
in computational graph terminology.
Parameters
----------
op: tvm.relay.Op or any tvm.relay.Expr with function type.
The operation to be called.
def
__init__
(
self
,
op
,
args
,
attrs
,
ty_args
=
None
):
args: List[tvm.relay.Expr]
if
not
ty_args
:
The arguments to the call.
ty_args
=
[]
attrs: Optional[tvm.Attrs]
Attributes to the call, can be None
type_args: Optional[List[tvm.relay.Type]]
The additional type arguments, this is only
used in advanced usecase of template functions.
"""
def
__init__
(
self
,
op
,
args
,
attrs
=
None
,
type_args
=
None
):
if
not
type_args
:
type_args
=
[]
self
.
__init_handle_by_constructor__
(
self
.
__init_handle_by_constructor__
(
_make
.
Call
,
op
,
args
,
attrs
,
ty_args
)
_make
.
Call
,
op
,
args
,
attrs
,
ty
pe
_args
)
@register_relay_node
@register_relay_node
class
Let
(
Expr
):
class
Let
(
Expr
):
"""A variable bindings in Relay, see tvm/relay/expr.h for more details."""
"""Let variable binding expression.
Parameters
----------
var: tvm.relay.Var
The local variable to be bound.
value: tvm.relay.Expr
The value to be bound.
def
__init__
(
self
,
var
,
value
,
body
,
value_type
=
None
):
body: tvm.relay.Expr
The body of the let binding.
"""
def
__init__
(
self
,
var
,
value
,
body
):
self
.
__init_handle_by_constructor__
(
self
.
__init_handle_by_constructor__
(
_make
.
Let
,
var
,
value
,
body
,
value_type
)
_make
.
Let
,
var
,
value
,
body
)
@register_relay_node
@register_relay_node
class
If
(
Expr
):
class
If
(
Expr
):
"""A conditional expression in Relay, see tvm/relay/expr.h for more details."""
"""A conditional expression in Relay.
Parameters
----------
cond: tvm.relay.Expr
The condition.
def
__init__
(
self
,
cond
,
true_value
,
false_value
):
true_branch: tvm.relay.Expr
The expression evaluated when condition is true.
false_branch: tvm.relay.Expr
The expression evaluated when condition is false.
"""
def
__init__
(
self
,
cond
,
true_branch
,
false_branch
):
self
.
__init_handle_by_constructor__
(
self
.
__init_handle_by_constructor__
(
_make
.
If
,
cond
,
true_value
,
false_value
)
_make
.
If
,
cond
,
true_branch
,
false_branch
)
@register_relay_node
@register_relay_node
class
TupleGetItem
(
Expr
):
class
TupleGetItem
(
Expr
):
"""An expression that get field from tuple in Relay, see tvm/relay/expr.h for more details."""
"""Get index-th item from a tuple.
Parameters
----------
tuple_value: tvm.relay.Expr
The input tuple expression.
def
__init__
(
self
,
tuple_
,
index
):
index: int
The index.
"""
def
__init__
(
self
,
tuple_value
,
index
):
self
.
__init_handle_by_constructor__
(
self
.
__init_handle_by_constructor__
(
_make
.
TupleGetItem
,
tuple_
,
index
)
_make
.
TupleGetItem
,
tuple_
value
,
index
)
debug_print
=
_expr
.
_debug_print
debug_print
=
_expr
.
_debug_print
python/tvm/relay/ir_builder.py
View file @
0b4cc050
...
@@ -7,7 +7,7 @@ from collections import OrderedDict
...
@@ -7,7 +7,7 @@ from collections import OrderedDict
import
numpy
as
np
import
numpy
as
np
import
tvm
import
tvm
from
.ty
import
Type
,
FuncType
,
TensorType
from
.ty
import
Type
,
FuncType
,
TensorType
from
.expr
import
Expr
,
Constant
,
Let
,
Var
,
Param
,
Function
,
If
from
.expr
import
Expr
,
Constant
,
Let
,
Var
,
Function
,
If
from
.env
import
Environment
from
.env
import
Environment
...
@@ -98,7 +98,7 @@ class PartialFunc(object):
...
@@ -98,7 +98,7 @@ class PartialFunc(object):
self
.
type_params
=
type_params
self
.
type_params
=
type_params
def
param_ids
(
self
):
def
param_ids
(
self
):
return
[
p
.
var
for
p
in
self
.
params
]
return
[
p
for
p
in
self
.
params
]
def
to_func
(
self
):
def
to_func
(
self
):
"""Converts a PartialFunc into a :py:class:`~relay.Function`."""
"""Converts a PartialFunc into a :py:class:`~relay.Function`."""
...
@@ -113,9 +113,8 @@ class PartialFunc(object):
...
@@ -113,9 +113,8 @@ class PartialFunc(object):
def
_mk_let
(
bindings
,
ret_value
):
def
_mk_let
(
bindings
,
ret_value
):
let_expr
=
ret_value
let_expr
=
ret_value
for
var
,
(
value
,
ty
)
in
reversed
(
list
(
bindings
.
items
())):
for
var
,
value
in
reversed
(
list
(
bindings
.
items
())):
let_expr
=
Let
(
var
,
value
,
let_expr
,
ty
)
let_expr
=
Let
(
var
,
value
,
let_expr
)
return
let_expr
return
let_expr
...
@@ -168,15 +167,12 @@ class IRBuilder(object):
...
@@ -168,15 +167,12 @@ class IRBuilder(object):
#pylint: disable=invalid-name
#pylint: disable=invalid-name
def
bind
(
self
,
name
,
value
,
ty
):
def
bind
(
self
,
name
,
value
,
ty
):
lv
=
Var
(
name
)
lv
=
Var
(
name
,
ty
)
self
.
scopes
[
-
1
][
name
]
=
lv
self
.
scopes
[
-
1
][
name
]
=
lv
self
.
bindings
[
-
1
][
lv
]
=
(
value
,
ty
)
self
.
bindings
[
-
1
][
lv
]
=
value
return
lv
return
lv
def
let
(
self
,
name
,
value
,
value_type
=
None
):
def
let
(
self
,
name
,
value
,
value_type
=
None
):
if
isinstance
(
value
,
Param
):
value
=
value
.
var
if
not
isinstance
(
value
,
Expr
):
if
not
isinstance
(
value
,
Expr
):
value
=
convert
(
value
)
value
=
convert
(
value
)
...
@@ -185,23 +181,18 @@ class IRBuilder(object):
...
@@ -185,23 +181,18 @@ class IRBuilder(object):
def
_convert_params
(
self
,
raw_params
):
def
_convert_params
(
self
,
raw_params
):
relay_params
=
[]
relay_params
=
[]
for
raw_param
in
raw_params
:
for
raw_param
in
raw_params
:
if
isinstance
(
raw_param
,
Param
):
if
isinstance
(
raw_param
,
Var
):
var
=
raw_param
.
var
param
=
raw_param
param
=
raw_param
elif
isinstance
(
raw_param
,
tuple
):
elif
isinstance
(
raw_param
,
tuple
):
var
,
ty
=
raw_param
var
,
ty
=
raw_param
if
isinstance
(
var
,
str
):
var
=
Var
(
var
)
ty
=
_convert_type
(
ty
)
ty
=
_convert_type
(
ty
)
param
=
Param
(
var
,
ty
)
param
=
Var
(
var
,
ty
)
elif
isinstance
(
param
,
str
):
elif
isinstance
(
raw_param
,
str
):
var
=
Var
(
raw_param
)
param
=
Var
(
raw_param
,
None
)
ty
=
None
param
=
Param
(
var
,
ty
)
else
:
else
:
raise
Exception
(
"unknown parameter type"
)
raise
Exception
(
"unknown parameter type"
)
self
.
scopes
[
-
1
][
var
.
name_hint
]
=
var
self
.
scopes
[
-
1
][
param
.
name_hint
]
=
param
relay_params
.
append
(
param
)
relay_params
.
append
(
param
)
return
relay_params
return
relay_params
...
@@ -265,7 +256,7 @@ class IRBuilder(object):
...
@@ -265,7 +256,7 @@ class IRBuilder(object):
else
:
else
:
ty
=
_convert_type
(
ty
)
ty
=
_convert_type
(
ty
)
return
Param
(
Var
(
name
)
,
ty
)
return
Var
(
name
,
ty
)
def
global_var
(
self
,
name
):
def
global_var
(
self
,
name
):
# type: (str) -> GlobalVar
# type: (str) -> GlobalVar
...
...
src/relay/ir/debug_printer.cc
View file @
0b4cc050
...
@@ -96,7 +96,9 @@ class TypeDocifier : private TypeFunctor<Doc(const Type& n)> {
...
@@ -96,7 +96,9 @@ class TypeDocifier : private TypeFunctor<Doc(const Type& n)> {
}
}
std
::
vector
<
Doc
>
DocifyTypeParam
(
const
tvm
::
Array
<
TypeParam
>&
arr
)
{
std
::
vector
<
Doc
>
DocifyTypeParam
(
const
tvm
::
Array
<
TypeParam
>&
arr
)
{
return
MapDocify
<
TypeParam
>
(
arr
,
[
=
](
const
TypeParam
&
tp
)
{
return
Docify
(
tp
);
});
return
MapDocify
<
TypeParam
>
(
arr
,
[
=
](
const
TypeParam
&
tp
)
{
return
Docify
(
tp
);
});
}
}
std
::
vector
<
Doc
>
DocifyTypeConstraint
(
const
tvm
::
Array
<
TypeConstraint
>&
arr
)
{
std
::
vector
<
Doc
>
DocifyTypeConstraint
(
const
tvm
::
Array
<
TypeConstraint
>&
arr
)
{
...
@@ -188,10 +190,11 @@ class ExprDocifier : private ExprFunctor<Doc(const Expr& n)> {
...
@@ -188,10 +190,11 @@ class ExprDocifier : private ExprFunctor<Doc(const Expr& n)> {
return
vec
;
return
vec
;
}
}
std
::
vector
<
Doc
>
DocifyParamArray
(
const
tvm
::
Array
<
Param
>&
arr
)
{
std
::
vector
<
Doc
>
DocifyParamArray
(
const
tvm
::
Array
<
Var
>&
arr
)
{
std
::
vector
<
Doc
>
vec
;
std
::
vector
<
Doc
>
vec
;
for
(
size_t
i
=
0
;
i
<
arr
.
size
();
++
i
)
{
for
(
Var
param
:
arr
)
{
vec
.
push_back
(
Docify
(
arr
[
i
]));
vec
.
emplace_back
(
TypeAnnotation
(
DocOfStr
(
VarName
(
param
)),
param
->
type_annotation
));
}
}
return
vec
;
return
vec
;
}
}
...
@@ -212,10 +215,6 @@ class ExprDocifier : private ExprFunctor<Doc(const Expr& n)> {
...
@@ -212,10 +215,6 @@ class ExprDocifier : private ExprFunctor<Doc(const Expr& n)> {
return
DocOfStr
(
g
->
name_hint
);
return
DocOfStr
(
g
->
name_hint
);
}
}
Doc
VisitExpr_
(
const
ParamNode
*
p
)
final
{
return
TypeAnnotation
(
Docify
(
p
->
var
),
p
->
type
);
}
Doc
VisitExpr_
(
const
FunctionNode
*
f
)
final
{
Doc
VisitExpr_
(
const
FunctionNode
*
f
)
final
{
return
Group
(
TypeAnnotation
(
Seq
(
"("
,
DocifyParamArray
(
f
->
params
),
")"
),
f
->
ret_type
)
+
Sep
()
+
return
Group
(
TypeAnnotation
(
Seq
(
"("
,
DocifyParamArray
(
f
->
params
),
")"
),
f
->
ret_type
)
+
Sep
()
+
DocOfStr
(
"=>"
)
+
Sep
()
+
DocOfStr
(
"=>"
)
+
Sep
()
+
...
@@ -227,7 +226,8 @@ class ExprDocifier : private ExprFunctor<Doc(const Expr& n)> {
...
@@ -227,7 +226,8 @@ class ExprDocifier : private ExprFunctor<Doc(const Expr& n)> {
}
}
Doc
VisitExpr_
(
const
LetNode
*
l
)
final
{
Doc
VisitExpr_
(
const
LetNode
*
l
)
final
{
return
Group
(
DocOfStr
(
"let"
)
+
Sep
()
+
TypeAnnotation
(
Docify
(
l
->
var
),
l
->
value_type
)
+
Sep
()
+
return
Group
(
DocOfStr
(
"let"
)
+
Sep
()
+
TypeAnnotation
(
Docify
(
l
->
var
),
l
->
var
->
type_annotation
)
+
Sep
()
+
DocOfStr
(
"="
)
+
Sep
()
+
Docify
(
l
->
value
)
+
DocOfStr
(
";"
)
+
Endl
()
+
DocOfStr
(
"="
)
+
Sep
()
+
Docify
(
l
->
value
)
+
DocOfStr
(
";"
)
+
Endl
()
+
Docify
(
l
->
body
));
Docify
(
l
->
body
));
}
}
...
...
src/relay/ir/expr.cc
View file @
0b4cc050
...
@@ -54,20 +54,26 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
...
@@ -54,20 +54,26 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
p
->
stream
<<
"Tuple("
<<
node
->
fields
<<
")"
;
p
->
stream
<<
"Tuple("
<<
node
->
fields
<<
")"
;
});
});
Var
VarNode
::
make
(
std
::
string
name_hint
)
{
Var
VarNode
::
make
(
std
::
string
name_hint
,
Type
type_annotation
)
{
NodePtr
<
VarNode
>
n
=
make_node
<
VarNode
>
();
NodePtr
<
VarNode
>
n
=
make_node
<
VarNode
>
();
n
->
name_hint
=
std
::
move
(
name_hint
);
n
->
name_hint
=
std
::
move
(
name_hint
);
n
->
type_annotation
=
std
::
move
(
type_annotation
);
return
Var
(
n
);
return
Var
(
n
);
}
}
TVM_REGISTER_API
(
"relay._make.Var"
)
TVM_REGISTER_API
(
"relay._make.Var"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
VarNode
::
make
(
args
[
0
]);
*
ret
=
VarNode
::
make
(
args
[
0
]
,
args
[
1
]
);
});
});
TVM_STATIC_IR_FUNCTOR_REGISTER
(
IRPrinter
,
vtable
)
TVM_STATIC_IR_FUNCTOR_REGISTER
(
IRPrinter
,
vtable
)
.
set_dispatch
<
VarNode
>
([](
const
VarNode
*
node
,
tvm
::
IRPrinter
*
p
)
{
.
set_dispatch
<
VarNode
>
([](
const
VarNode
*
node
,
tvm
::
IRPrinter
*
p
)
{
p
->
stream
<<
"Var("
<<
node
->
name_hint
<<
")"
;
p
->
stream
<<
"Var("
<<
node
->
name_hint
;
if
(
node
->
type_annotation
.
defined
())
{
p
->
stream
<<
", ty="
;
p
->
print
(
node
->
type_annotation
);
}
p
->
stream
<<
")"
;
});
});
GlobalVar
GlobalVarNode
::
make
(
std
::
string
name_hint
)
{
GlobalVar
GlobalVarNode
::
make
(
std
::
string
name_hint
)
{
...
@@ -86,24 +92,10 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
...
@@ -86,24 +92,10 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
p
->
stream
<<
"GlobalVar("
<<
node
->
name_hint
<<
")"
;
p
->
stream
<<
"GlobalVar("
<<
node
->
name_hint
<<
")"
;
});
});
Param
ParamNode
::
make
(
Var
var
,
Type
type
)
{
NodePtr
<
ParamNode
>
n
=
make_node
<
ParamNode
>
();
n
->
var
=
std
::
move
(
var
);
n
->
type
=
std
::
move
(
type
);
return
Param
(
n
);
}
TVM_REGISTER_API
(
"relay._make.Param"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
ParamNode
::
make
(
args
[
0
],
args
[
1
]);
});
TVM_STATIC_IR_FUNCTOR_REGISTER
(
IRPrinter
,
vtable
)
.
set_dispatch
<
ParamNode
>
([](
const
ParamNode
*
node
,
tvm
::
IRPrinter
*
p
)
{
p
->
stream
<<
"Param("
<<
node
->
var
<<
", "
<<
node
->
type
<<
")"
;
});
Function
FunctionNode
::
make
(
tvm
::
Array
<
Param
>
params
,
Type
ret_type
,
Expr
body
,
Function
FunctionNode
::
make
(
tvm
::
Array
<
Var
>
params
,
Type
ret_type
,
Expr
body
,
tvm
::
Array
<
TypeParam
>
type_params
)
{
tvm
::
Array
<
TypeParam
>
type_params
)
{
NodePtr
<
FunctionNode
>
n
=
make_node
<
FunctionNode
>
();
NodePtr
<
FunctionNode
>
n
=
make_node
<
FunctionNode
>
();
n
->
params
=
std
::
move
(
params
);
n
->
params
=
std
::
move
(
params
);
...
@@ -113,12 +105,11 @@ Function FunctionNode::make(tvm::Array<Param> params, Type ret_type, Expr body,
...
@@ -113,12 +105,11 @@ Function FunctionNode::make(tvm::Array<Param> params, Type ret_type, Expr body,
return
Function
(
n
);
return
Function
(
n
);
}
}
Type
FunctionNode
::
fn_type
()
const
{
FuncType
FunctionNode
::
func_type_annotation
()
const
{
Array
<
Type
>
param_types
;
Array
<
Type
>
param_types
;
for
(
auto
param
:
this
->
params
)
{
for
(
auto
param
:
this
->
params
)
{
param_types
.
push_back
(
param
->
type
);
param_types
.
push_back
(
param
->
type
_annotation
);
}
}
return
FuncTypeNode
::
make
(
param_types
,
this
->
ret_type
,
this
->
type_params
,
{});
return
FuncTypeNode
::
make
(
param_types
,
this
->
ret_type
,
this
->
type_params
,
{});
}
}
...
@@ -155,24 +146,23 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
...
@@ -155,24 +146,23 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
<<
node
->
attrs
<<
", "
<<
node
->
type_args
<<
")"
;
<<
node
->
attrs
<<
", "
<<
node
->
type_args
<<
")"
;
});
});
Let
LetNode
::
make
(
Var
var
,
Expr
value
,
Expr
body
,
Type
value_type
)
{
Let
LetNode
::
make
(
Var
var
,
Expr
value
,
Expr
body
)
{
NodePtr
<
LetNode
>
n
=
make_node
<
LetNode
>
();
NodePtr
<
LetNode
>
n
=
make_node
<
LetNode
>
();
n
->
var
=
std
::
move
(
var
);
n
->
var
=
std
::
move
(
var
);
n
->
value
=
std
::
move
(
value
);
n
->
value
=
std
::
move
(
value
);
n
->
body
=
std
::
move
(
body
);
n
->
body
=
std
::
move
(
body
);
n
->
value_type
=
std
::
move
(
value_type
);
return
Let
(
n
);
return
Let
(
n
);
}
}
TVM_REGISTER_API
(
"relay._make.Let"
)
TVM_REGISTER_API
(
"relay._make.Let"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
LetNode
::
make
(
args
[
0
],
args
[
1
],
args
[
2
],
args
[
3
]);
*
ret
=
LetNode
::
make
(
args
[
0
],
args
[
1
],
args
[
2
]);
});
});
TVM_STATIC_IR_FUNCTOR_REGISTER
(
IRPrinter
,
vtable
)
TVM_STATIC_IR_FUNCTOR_REGISTER
(
IRPrinter
,
vtable
)
.
set_dispatch
<
LetNode
>
([](
const
LetNode
*
node
,
tvm
::
IRPrinter
*
p
)
{
.
set_dispatch
<
LetNode
>
([](
const
LetNode
*
node
,
tvm
::
IRPrinter
*
p
)
{
p
->
stream
<<
"LetNode("
<<
node
->
var
<<
", "
<<
node
->
value
p
->
stream
<<
"LetNode("
<<
node
->
var
<<
", "
<<
node
->
value
<<
", "
<<
node
->
body
<<
", "
<<
node
->
value_type
<<
")"
;
<<
", "
<<
node
->
body
<<
")"
;
});
});
If
IfNode
::
make
(
Expr
cond
,
Expr
true_branch
,
Expr
false_branch
)
{
If
IfNode
::
make
(
Expr
cond
,
Expr
true_branch
,
Expr
false_branch
)
{
...
...
src/relay/ir/expr_functor.cc
View file @
0b4cc050
...
@@ -24,6 +24,16 @@ Expr ExprMutator::Mutate(const Expr& expr) {
...
@@ -24,6 +24,16 @@ Expr ExprMutator::Mutate(const Expr& expr) {
}
}
Expr
ExprMutator
::
VisitExpr_
(
const
VarNode
*
op
)
{
Expr
ExprMutator
::
VisitExpr_
(
const
VarNode
*
op
)
{
// NOTE: var will only be mutated once
// Thanks to the memo and reused during rewriting if necessary.
// It is safe to assume that the
if
(
op
->
type_annotation
.
defined
())
{
auto
type
=
this
->
VisitType
(
op
->
type_annotation
);
if
(
!
op
->
type_annotation
.
same_as
(
type
))
{
return
VarNode
::
make
(
op
->
name_hint
,
type
);
}
}
// default case return self.
return
GetRef
<
Expr
>
(
op
);
return
GetRef
<
Expr
>
(
op
);
}
}
...
@@ -55,16 +65,6 @@ Expr ExprMutator::VisitExpr_(const TupleNode* op) {
...
@@ -55,16 +65,6 @@ Expr ExprMutator::VisitExpr_(const TupleNode* op) {
}
}
}
}
Expr
ExprMutator
::
VisitExpr_
(
const
ParamNode
*
op
)
{
Var
var
=
Downcast
<
Var
>
(
this
->
Mutate
(
op
->
var
));
auto
type
=
this
->
VisitType
(
op
->
type
);
if
(
op
->
var
.
same_as
(
var
)
&&
op
->
type
.
same_as
(
type
))
{
return
GetRef
<
Expr
>
(
op
);
}
else
{
return
ParamNode
::
make
(
var
,
type
);
}
}
Expr
ExprMutator
::
VisitExpr_
(
const
FunctionNode
*
op
)
{
Expr
ExprMutator
::
VisitExpr_
(
const
FunctionNode
*
op
)
{
tvm
::
Array
<
TypeParam
>
ty_params
;
tvm
::
Array
<
TypeParam
>
ty_params
;
bool
all_ty_params_changed
=
true
;
bool
all_ty_params_changed
=
true
;
...
@@ -75,10 +75,10 @@ Expr ExprMutator::VisitExpr_(const FunctionNode* op) {
...
@@ -75,10 +75,10 @@ Expr ExprMutator::VisitExpr_(const FunctionNode* op) {
all_ty_params_changed
&=
new_ty_param
.
same_as
(
ty_param
);
all_ty_params_changed
&=
new_ty_param
.
same_as
(
ty_param
);
}
}
tvm
::
Array
<
Param
>
params
;
tvm
::
Array
<
Var
>
params
;
bool
all_params_changed
=
true
;
bool
all_params_changed
=
true
;
for
(
auto
param
:
op
->
params
)
{
for
(
auto
param
:
op
->
params
)
{
Param
new_param
=
Downcast
<
Param
>
(
this
->
Mutate
(
param
));
Var
new_param
=
Downcast
<
Var
>
(
this
->
Mutate
(
param
));
params
.
push_back
(
new_param
);
params
.
push_back
(
new_param
);
all_params_changed
&=
param
.
same_as
(
new_param
);
all_params_changed
&=
param
.
same_as
(
new_param
);
}
}
...
@@ -123,17 +123,15 @@ Expr ExprMutator::VisitExpr_(const CallNode* call_node) {
...
@@ -123,17 +123,15 @@ Expr ExprMutator::VisitExpr_(const CallNode* call_node) {
Expr
ExprMutator
::
VisitExpr_
(
const
LetNode
*
op
)
{
Expr
ExprMutator
::
VisitExpr_
(
const
LetNode
*
op
)
{
Var
var
=
Downcast
<
Var
>
(
this
->
Mutate
(
op
->
var
));
Var
var
=
Downcast
<
Var
>
(
this
->
Mutate
(
op
->
var
));
auto
type
=
this
->
VisitType
(
op
->
value_type
);
auto
value
=
this
->
Mutate
(
op
->
value
);
auto
value
=
this
->
Mutate
(
op
->
value
);
auto
body
=
this
->
Mutate
(
op
->
body
);
auto
body
=
this
->
Mutate
(
op
->
body
);
if
(
var
.
same_as
(
op
->
var
)
&&
if
(
var
.
same_as
(
op
->
var
)
&&
type
.
same_as
(
op
->
value_type
)
&&
value
.
same_as
(
op
->
value
)
&&
value
.
same_as
(
op
->
value
)
&&
body
.
same_as
(
op
->
body
))
{
body
.
same_as
(
op
->
body
))
{
return
GetRef
<
Expr
>
(
op
);
return
GetRef
<
Expr
>
(
op
);
}
else
{
}
else
{
return
LetNode
::
make
(
var
,
value
,
body
,
type
);
return
LetNode
::
make
(
var
,
value
,
body
);
}
}
}
}
...
@@ -162,6 +160,9 @@ Expr ExprMutator::VisitExpr_(const TupleGetItemNode* g) {
...
@@ -162,6 +160,9 @@ Expr ExprMutator::VisitExpr_(const TupleGetItemNode* g) {
Type
ExprMutator
::
VisitType
(
const
Type
&
t
)
{
return
t
;
}
Type
ExprMutator
::
VisitType
(
const
Type
&
t
)
{
return
t
;
}
void
ExprVisitor
::
ExprVisitor
::
VisitExpr_
(
const
VarNode
*
op
)
{
void
ExprVisitor
::
ExprVisitor
::
VisitExpr_
(
const
VarNode
*
op
)
{
if
(
op
->
type_annotation
.
defined
())
{
this
->
VisitType
(
op
->
type_annotation
);
}
}
}
void
ExprVisitor
::
ExprVisitor
::
VisitExpr_
(
const
GlobalVarNode
*
op
)
{
void
ExprVisitor
::
ExprVisitor
::
VisitExpr_
(
const
GlobalVarNode
*
op
)
{
...
@@ -176,10 +177,6 @@ void ExprVisitor::ExprVisitor::VisitExpr_(const TupleNode* op) {
...
@@ -176,10 +177,6 @@ void ExprVisitor::ExprVisitor::VisitExpr_(const TupleNode* op) {
}
}
}
}
void
ExprVisitor
::
ExprVisitor
::
VisitExpr_
(
const
ParamNode
*
op
)
{
this
->
VisitExpr
(
op
->
var
);
}
void
ExprVisitor
::
ExprVisitor
::
VisitExpr_
(
const
FunctionNode
*
op
)
{
void
ExprVisitor
::
ExprVisitor
::
VisitExpr_
(
const
FunctionNode
*
op
)
{
for
(
auto
param
:
op
->
params
)
{
for
(
auto
param
:
op
->
params
)
{
this
->
VisitExpr
(
param
);
this
->
VisitExpr
(
param
);
...
...
src/relay/pass/alpha_eq.cc
View file @
0b4cc050
...
@@ -252,15 +252,6 @@ struct AlphaEq : ExprFunctor<void(const Expr&, const Expr&)> {
...
@@ -252,15 +252,6 @@ struct AlphaEq : ExprFunctor<void(const Expr&, const Expr&)> {
}
}
}
}
void
VisitExpr_
(
const
ParamNode
*
p1
,
const
Expr
&
e2
)
final
{
if
(
const
ParamNode
*
p2
=
e2
.
as
<
ParamNode
>
())
{
eq_map
.
Set
(
p1
->
var
,
p2
->
var
);
equal
=
equal
&&
AlphaEqual
(
p1
->
type
,
p2
->
type
);
}
else
{
equal
=
false
;
}
}
void
VisitExpr_
(
const
FunctionNode
*
func1
,
const
Expr
&
e2
)
final
{
void
VisitExpr_
(
const
FunctionNode
*
func1
,
const
Expr
&
e2
)
final
{
if
(
const
FunctionNode
*
func2
=
e2
.
as
<
FunctionNode
>
())
{
if
(
const
FunctionNode
*
func2
=
e2
.
as
<
FunctionNode
>
())
{
if
(
func1
->
params
.
size
()
!=
func2
->
params
.
size
())
{
if
(
func1
->
params
.
size
()
!=
func2
->
params
.
size
())
{
...
@@ -273,9 +264,10 @@ struct AlphaEq : ExprFunctor<void(const Expr&, const Expr&)> {
...
@@ -273,9 +264,10 @@ struct AlphaEq : ExprFunctor<void(const Expr&, const Expr&)> {
return
;
return
;
}
}
for
(
size_t
i
=
0
U
;
i
<
func1
->
params
.
size
();
i
++
)
{
for
(
size_t
i
=
0
;
i
<
func1
->
params
.
size
();
++
i
)
{
this
->
VisitExpr
(
func1
->
params
[
i
],
func2
->
params
[
i
]);
MergeVarDecl
(
func1
->
params
[
i
],
func2
->
params
[
i
]);
}
}
if
(
!
equal
)
return
;
for
(
size_t
i
=
0U
;
i
<
func1
->
type_params
.
size
();
i
++
)
{
for
(
size_t
i
=
0U
;
i
<
func1
->
type_params
.
size
();
i
++
)
{
equal
=
equal
&&
AlphaEqual
(
func1
->
type_params
[
i
],
func2
->
type_params
[
i
]);
equal
=
equal
&&
AlphaEqual
(
func1
->
type_params
[
i
],
func2
->
type_params
[
i
]);
...
@@ -332,19 +324,9 @@ struct AlphaEq : ExprFunctor<void(const Expr&, const Expr&)> {
...
@@ -332,19 +324,9 @@ struct AlphaEq : ExprFunctor<void(const Expr&, const Expr&)> {
void
VisitExpr_
(
const
LetNode
*
op
,
const
Expr
&
e2
)
final
{
void
VisitExpr_
(
const
LetNode
*
op
,
const
Expr
&
e2
)
final
{
if
(
const
LetNode
*
let
=
e2
.
as
<
LetNode
>
())
{
if
(
const
LetNode
*
let
=
e2
.
as
<
LetNode
>
())
{
eq_map
.
Set
(
op
->
var
,
let
->
var
);
MergeVarDecl
(
op
->
var
,
let
->
var
);
this
->
VisitExpr
(
op
->
value
,
let
->
value
);
this
->
VisitExpr
(
op
->
value
,
let
->
value
);
this
->
VisitExpr
(
op
->
body
,
let
->
body
);
this
->
VisitExpr
(
op
->
body
,
let
->
body
);
// value_type should match as well (including nulls)
if
(
op
->
value_type
.
defined
()
!=
let
->
value_type
.
defined
())
{
equal
=
false
;
return
;
}
if
(
op
->
value_type
.
defined
())
{
equal
=
equal
&&
AlphaEqual
(
op
->
value_type
,
let
->
value_type
);
}
}
else
{
}
else
{
equal
=
false
;
equal
=
false
;
}
}
...
@@ -388,6 +370,20 @@ struct AlphaEq : ExprFunctor<void(const Expr&, const Expr&)> {
...
@@ -388,6 +370,20 @@ struct AlphaEq : ExprFunctor<void(const Expr&, const Expr&)> {
equal
=
false
;
equal
=
false
;
}
}
}
}
private
:
void
MergeVarDecl
(
const
Var
&
var1
,
const
Var
&
var2
)
{
if
(
var1
->
type_annotation
.
defined
()
!=
var2
->
type_annotation
.
defined
())
{
equal
=
false
;
return
;
}
if
(
var1
->
type_annotation
.
defined
()
&&
!
AlphaEqual
(
var1
->
type_annotation
,
var2
->
type_annotation
))
{
equal
=
false
;
return
;
}
eq_map
.
Set
(
var1
,
var2
);
}
};
};
bool
AlphaEqual
(
const
Expr
&
e1
,
const
Expr
&
e2
)
{
bool
AlphaEqual
(
const
Expr
&
e1
,
const
Expr
&
e2
)
{
...
...
src/relay/pass/dead_code.cc
View file @
0b4cc050
...
@@ -54,12 +54,7 @@ class CalcDep : private ExprMutator {
...
@@ -54,12 +54,7 @@ class CalcDep : private ExprMutator {
}
}
private
:
private
:
struct
Binder
{
using
VarMap
=
std
::
unordered_map
<
Var
,
Expr
,
NodeHash
,
NodeEqual
>
;
Type
t
;
Expr
e
;
Binder
(
const
Type
&
t
,
const
Expr
&
e
)
:
t
(
t
),
e
(
e
)
{
}
};
using
VarMap
=
std
::
unordered_map
<
Var
,
Binder
,
NodeHash
,
NodeEqual
>
;
VarMap
var_map_
;
VarMap
var_map_
;
Expr
VisitExpr_
(
const
IfNode
*
i
)
final
{
Expr
VisitExpr_
(
const
IfNode
*
i
)
final
{
...
@@ -74,9 +69,7 @@ class CalcDep : private ExprMutator {
...
@@ -74,9 +69,7 @@ class CalcDep : private ExprMutator {
}
}
Expr
VisitExpr_
(
const
LetNode
*
l
)
final
{
Expr
VisitExpr_
(
const
LetNode
*
l
)
final
{
var_map_
.
insert
(
std
::
pair
<
Var
,
Binder
>
(
l
->
var
,
var_map_
[
l
->
var
]
=
Eliminate
(
l
->
value
);
Binder
(
l
->
value_type
,
Eliminate
(
l
->
value
))));
return
VisitExpr
(
l
->
body
);
return
VisitExpr
(
l
->
body
);
}
}
...
@@ -92,15 +85,16 @@ class CalcDep : private ExprMutator {
...
@@ -92,15 +85,16 @@ class CalcDep : private ExprMutator {
explicit
GenLet
(
const
VarMap
&
var_map
)
:
var_map_
(
var_map
)
{
}
explicit
GenLet
(
const
VarMap
&
var_map
)
:
var_map_
(
var_map
)
{
}
friend
CalcDep
;
friend
CalcDep
;
void
VisitExpr_
(
const
VarNode
*
vn
)
final
{
void
VisitExpr_
(
const
VarNode
*
vnode
)
final
{
Var
v
=
GetRef
<
Var
>
(
vn
);
Var
v
=
GetRef
<
Var
>
(
vnode
);
if
(
var_map_
.
count
(
v
)
!=
0
)
{
auto
it
=
var_map_
.
find
(
v
);
auto
val
=
var_map_
.
at
(
v
);
if
(
it
!=
var_map_
.
end
())
{
var_map_
.
erase
(
v
);
Expr
expr
=
it
->
second
;
var_map_
.
erase
(
it
);
// erase before visit to handle letrec
// erase before visit to handle letrec
VisitExpr
(
val
.
e
);
VisitExpr
(
expr
);
// visit before push back so the dependency of dependency is before the dependency
// visit before push back so the dependency of dependency is before the dependency
lets_
.
Push
(
v
,
val
.
t
,
val
.
e
);
lets_
.
Push
(
v
,
expr
);
}
}
}
}
};
};
...
...
src/relay/pass/let_list.h
View file @
0b4cc050
...
@@ -26,23 +26,22 @@ namespace relay {
...
@@ -26,23 +26,22 @@ namespace relay {
*/
*/
class
LetList
{
class
LetList
{
public
:
public
:
/*! \brief insert a binding.
/*!
* \brief insert a binding.
*
*
* \param pv the var of the binding.
* \param pv the var of the binding.
*
*
* \param ty the type of the binding.
*
* \param expr the value of the binding.
* \param expr the value of the binding.
*
*
* \return a Var that hold the inserted expr.
* \return a Var that hold the inserted expr.
*/
*/
Var
Push
(
const
Var
&
pv
,
const
Type
&
ty
,
const
Expr
&
expr
)
{
Var
Push
(
Var
pv
,
Expr
expr
)
{
std
::
tuple
<
Var
,
Type
,
Expr
>
tuple
(
pv
,
ty
,
expr
);
lets_
.
emplace_back
(
std
::
make_pair
(
pv
,
expr
));
lets_
.
push_back
(
tuple
);
return
pv
;
return
pv
;
}
}
/*! \brief insert a binding.
/*!
* \brief insert a binding.
*
*
* \param ty the type of the binding.
* \param ty the type of the binding.
*
*
...
@@ -50,33 +49,23 @@ class LetList {
...
@@ -50,33 +49,23 @@ class LetList {
*
*
* \return a Var that hold the inserted expr.
* \return a Var that hold the inserted expr.
*/
*/
Var
Push
(
const
Type
&
ty
,
const
Expr
&
expr
)
{
Var
Push
(
Type
ty
,
Expr
expr
)
{
return
Push
(
VarNode
::
make
(
"x"
),
ty
,
expr
);
return
Push
(
VarNode
::
make
(
"x"
,
ty
),
expr
);
}
/*! \brief insert a binding.
*
* \param pv the var of the binding.
*
* \param expr the value of the binding.
*
* \return a Var that hold the inserted expr.
*/
Var
Push
(
const
Var
&
pv
,
const
Expr
&
expr
)
{
return
Push
(
pv
,
IncompleteTypeNode
::
make
(
TypeParamNode
::
kType
),
expr
);
}
}
/*! \brief insert a binding.
/*!
* \brief insert a binding.
*
*
* \param expr the value of the binding.
* \param expr the value of the binding.
*
*
* \return a Var that hold the inserted expr.
* \return a Var that hold the inserted expr.
*/
*/
Var
Push
(
const
Expr
&
expr
)
{
Var
Push
(
Expr
expr
)
{
return
Push
(
IncompleteTypeNode
::
make
(
TypeParamNode
::
kType
),
expr
);
return
Push
(
IncompleteTypeNode
::
make
(
TypeParamNode
::
kType
),
expr
);
}
}
/*! \brief wrap an expr around the LetList.
/*!
* \brief wrap an expr around the LetList.
*
*
* \param body the Expression to be wrapped around.
* \param body the Expression to be wrapped around.
*
*
...
@@ -85,7 +74,7 @@ class LetList {
...
@@ -85,7 +74,7 @@ class LetList {
Expr
Get
(
const
Expr
&
body
)
const
{
Expr
Get
(
const
Expr
&
body
)
const
{
Expr
ret
=
body
;
Expr
ret
=
body
;
for
(
auto
rit
=
lets_
.
rbegin
();
rit
!=
lets_
.
rend
();
++
rit
)
{
for
(
auto
rit
=
lets_
.
rbegin
();
rit
!=
lets_
.
rend
();
++
rit
)
{
ret
=
LetNode
::
make
(
std
::
get
<
0
>
(
*
rit
),
std
::
get
<
2
>
(
*
rit
),
ret
,
std
::
get
<
1
>
(
*
rit
)
);
ret
=
LetNode
::
make
(
std
::
get
<
0
>
(
*
rit
),
std
::
get
<
1
>
(
*
rit
),
ret
);
}
}
return
ret
;
return
ret
;
}
}
...
@@ -118,7 +107,7 @@ class LetList {
...
@@ -118,7 +107,7 @@ class LetList {
}
}
private
:
private
:
std
::
vector
<
std
::
tuple
<
Var
,
Type
,
Expr
>
>
lets_
;
std
::
vector
<
std
::
pair
<
Var
,
Expr
>
>
lets_
;
};
};
}
// namespace relay
}
// namespace relay
...
...
src/relay/pass/type_infer.cc
View file @
0b4cc050
...
@@ -87,15 +87,11 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
...
@@ -87,15 +87,11 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
// Visitor logics
// Visitor logics
Type
VisitExpr_
(
const
VarNode
*
op
)
final
{
Type
VisitExpr_
(
const
VarNode
*
op
)
final
{
// The type of Var can already been lookedup in type_map_;
if
(
op
->
type_annotation
.
defined
())
{
LOG
(
FATAL
)
<<
"Cannot find binding for var "
<<
GetRef
<
Var
>
(
op
);
return
op
->
type_annotation
;
return
Type
();
}
else
{
return
IncompleteTypeNode
::
make
(
TypeParamNode
::
kType
);
}
}
Type
VisitExpr_
(
const
ParamNode
*
op
)
final
{
// directly handled by Funtion
LOG
(
FATAL
)
<<
"not reached"
;
return
Type
();
}
}
Type
VisitExpr_
(
const
GlobalVarNode
*
op
)
final
{
Type
VisitExpr_
(
const
GlobalVarNode
*
op
)
final
{
...
@@ -139,11 +135,11 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
...
@@ -139,11 +135,11 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
Type
VisitExpr_
(
const
LetNode
*
op
)
final
{
Type
VisitExpr_
(
const
LetNode
*
op
)
final
{
Type
vtype
=
GetType
(
op
->
value
);
Type
vtype
=
GetType
(
op
->
value
);
if
(
op
->
va
lue_type
.
defined
())
{
if
(
op
->
va
r
->
type_annotation
.
defined
())
{
vtype
=
Unify
(
vtype
,
op
->
va
lue_type
,
op
->
span
);
vtype
=
Unify
(
vtype
,
op
->
va
r
->
type_annotation
,
op
->
span
);
}
}
CHECK
(
!
type_map_
.
count
(
op
->
var
));
CHECK
(
!
type_map_
.
count
(
op
->
var
));
// NOTE: no scoping is necessary becase var are unique in program
// NOTE: no scoping is necessary beca
u
se var are unique in program
type_map_
[
op
->
var
]
=
vtype
;
type_map_
[
op
->
var
]
=
vtype
;
return
GetType
(
op
->
body
);
return
GetType
(
op
->
body
);
}
}
...
@@ -256,8 +252,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
...
@@ -256,8 +252,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
Type
VisitExpr_
(
const
FunctionNode
*
f
)
final
{
Type
VisitExpr_
(
const
FunctionNode
*
f
)
final
{
for
(
auto
param
:
f
->
params
)
{
for
(
auto
param
:
f
->
params
)
{
type_map_
[
param
->
var
]
=
param
->
type
;
GetType
(
param
);
type_map_
[
param
]
=
param
->
type
;
}
}
Type
rtype
=
GetType
(
f
->
body
);
Type
rtype
=
GetType
(
f
->
body
);
// Run solver using the currently known information
// Run solver using the currently known information
...
@@ -265,8 +260,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
...
@@ -265,8 +260,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
// Trying to resolve
// Trying to resolve
Array
<
Type
>
arg_types
;
Array
<
Type
>
arg_types
;
for
(
size_t
i
=
0
;
i
<
f
->
params
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
f
->
params
.
size
();
++
i
)
{
Param
param
=
f
->
params
[
i
];
Type
atype
=
solver_
.
Resolve
(
GetType
(
f
->
params
[
i
]));
Type
atype
=
solver_
.
Resolve
(
param
->
type
);
CHECK
(
atype
.
as
<
IncompleteTypeNode
>
()
==
nullptr
)
CHECK
(
atype
.
as
<
IncompleteTypeNode
>
()
==
nullptr
)
<<
"Cannot resolve type of "
<<
i
<<
"Cannot resolve type of "
<<
i
<<
"-th parameter of function at"
<<
f
->
span
;
<<
"-th parameter of function at"
<<
f
->
span
;
...
@@ -311,9 +305,6 @@ class TypeInferencer::Resolver : public ExprMutator {
...
@@ -311,9 +305,6 @@ class TypeInferencer::Resolver : public ExprMutator {
return
AttachCheckedType
(
op
);
return
AttachCheckedType
(
op
);
}
}
Expr
VisitExpr_
(
const
ParamNode
*
op
)
final
{
return
ExprMutator
::
VisitExpr_
(
op
);
}
Expr
VisitExpr_
(
const
FunctionNode
*
op
)
final
{
Expr
VisitExpr_
(
const
FunctionNode
*
op
)
final
{
return
AttachCheckedType
(
op
);
return
AttachCheckedType
(
op
);
...
@@ -380,7 +371,7 @@ Expr InferType(const Environment& env,
...
@@ -380,7 +371,7 @@ Expr InferType(const Environment& env,
const
GlobalVar
&
var
,
const
GlobalVar
&
var
,
const
Function
&
func
)
{
const
Function
&
func
)
{
Function
func_copy
=
Function
(
make_node
<
FunctionNode
>
(
*
func
.
operator
->
()));
Function
func_copy
=
Function
(
make_node
<
FunctionNode
>
(
*
func
.
operator
->
()));
func_copy
->
checked_type_
=
func_copy
->
f
n_type
();
func_copy
->
checked_type_
=
func_copy
->
f
unc_type_annotation
();
env
->
functions
.
Set
(
var
,
func_copy
);
env
->
functions
.
Set
(
var
,
func_copy
);
Expr
func_ret
=
TypeInferencer
(
env
).
Infer
(
func_copy
);
Expr
func_ret
=
TypeInferencer
(
env
).
Infer
(
func_copy
);
auto
map_node
=
env
->
functions
.
CopyOnWrite
();
auto
map_node
=
env
->
functions
.
CopyOnWrite
();
...
...
src/relay/pass/util.cc
View file @
0b4cc050
...
@@ -50,14 +50,17 @@ class FreeVar : public ExprVisitor {
...
@@ -50,14 +50,17 @@ class FreeVar : public ExprVisitor {
if
(
bound_vars
.
count
(
var
)
==
0
)
{
if
(
bound_vars
.
count
(
var
)
==
0
)
{
free_vars
.
insert
(
var
);
free_vars
.
insert
(
var
);
}
}
if
(
v
->
type_annotation
.
defined
())
{
VisitType
(
v
->
type_annotation
);
}
}
}
void
VisitExpr_
(
const
FunctionNode
*
f
)
final
{
void
VisitExpr_
(
const
FunctionNode
*
f
)
final
{
for
(
const
auto
&
tp
:
f
->
type_params
)
{
for
(
const
auto
&
tp
:
f
->
type_params
)
{
bound_types
.
insert
(
tp
);
bound_types
.
insert
(
tp
);
}
}
for
(
const
auto
&
p
:
f
->
params
)
{
for
(
const
auto
&
p
aram
:
f
->
params
)
{
bound_vars
.
insert
(
p
->
var
);
bound_vars
.
insert
(
p
aram
);
}
}
VisitExpr
(
f
->
body
);
VisitExpr
(
f
->
body
);
VisitType
(
f
->
ret_type
);
VisitType
(
f
->
ret_type
);
...
@@ -67,7 +70,6 @@ class FreeVar : public ExprVisitor {
...
@@ -67,7 +70,6 @@ class FreeVar : public ExprVisitor {
bound_vars
.
insert
(
l
->
var
);
bound_vars
.
insert
(
l
->
var
);
VisitExpr
(
l
->
value
);
VisitExpr
(
l
->
value
);
VisitExpr
(
l
->
body
);
VisitExpr
(
l
->
body
);
VisitType
(
l
->
value_type
);
}
}
public
:
public
:
...
...
src/relay/pass/well_formed.cc
View file @
0b4cc050
...
@@ -34,8 +34,8 @@ class WellFormedChecker : private ExprVisitor {
...
@@ -34,8 +34,8 @@ class WellFormedChecker : private ExprVisitor {
}
}
void
VisitExpr_
(
const
FunctionNode
*
f
)
final
{
void
VisitExpr_
(
const
FunctionNode
*
f
)
final
{
for
(
const
Param
&
p
:
f
->
params
)
{
for
(
const
Var
&
param
:
f
->
params
)
{
Check
(
p
->
var
);
Check
(
p
aram
);
}
}
CheckWellFormed
(
f
->
body
);
CheckWellFormed
(
f
->
body
);
}
}
...
...
tests/python/relay/test_ir_builder.py
View file @
0b4cc050
...
@@ -14,7 +14,6 @@ def test_let():
...
@@ -14,7 +14,6 @@ def test_let():
assert
var
==
prog
.
body
assert
var
==
prog
.
body
assert
isinstance
(
value
,
Constant
)
assert
isinstance
(
value
,
Constant
)
assert
value
.
data
.
asnumpy
()
==
np
.
array
(
1
)
assert
value
.
data
.
asnumpy
()
==
np
.
array
(
1
)
assert
prog
.
value_type
==
None
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
test_let
()
test_let
()
tests/python/relay/test_ir_debug_printer.py
View file @
0b4cc050
...
@@ -49,18 +49,11 @@ def test_global_var():
...
@@ -49,18 +49,11 @@ def test_global_var():
show
(
gv
)
show
(
gv
)
def
test_param
():
lv
=
relay
.
Var
(
'x'
)
ty
=
None
param
=
relay
.
Param
(
lv
,
ty
)
show
(
lv
)
def
test_function
():
def
test_function
():
param_names
=
[
'a'
,
'b'
,
'c'
,
'd'
]
param_names
=
[
'a'
,
'b'
,
'c'
,
'd'
]
params
=
tvm
.
convert
([
relay
.
Param
(
relay
.
Var
(
n
),
None
)
for
n
in
param_names
])
params
=
tvm
.
convert
([
relay
.
Var
(
n
)
for
n
in
param_names
])
ret_type
=
None
ret_type
=
None
body
=
params
[
0
]
.
var
body
=
params
[
0
]
type_params
=
tvm
.
convert
([])
type_params
=
tvm
.
convert
([])
fn
=
relay
.
Function
(
params
,
ret_type
,
body
,
type_params
)
fn
=
relay
.
Function
(
params
,
ret_type
,
body
,
type_params
)
show
(
fn
)
show
(
fn
)
...
@@ -76,11 +69,11 @@ def test_call():
...
@@ -76,11 +69,11 @@ def test_call():
def
test_let
():
def
test_let
():
lv
=
relay
.
Var
(
'x'
)
ty
=
relay
.
ty
.
TensorType
((
10
,
20
),
'float32'
)
ty
=
relay
.
ty
.
TensorType
((
10
,
20
),
'float32'
)
lv
=
relay
.
Var
(
'x'
,
ty
)
arr
=
tvm
.
nd
.
array
(
10
)
arr
=
tvm
.
nd
.
array
(
10
)
value
=
relay
.
Constant
(
arr
)
value
=
relay
.
Constant
(
arr
)
let
=
relay
.
Let
(
lv
,
value
,
lv
,
ty
)
let
=
relay
.
Let
(
lv
,
value
,
lv
)
show
(
let
)
show
(
let
)
...
...
tests/python/relay/test_ir_nodes.py
View file @
0b4cc050
...
@@ -99,10 +99,16 @@ def test_tuple():
...
@@ -99,10 +99,16 @@ def test_tuple():
def
test_local_var
():
def
test_local_var
():
name_hint
=
's'
name_hint
=
's'
lv
=
relay
.
Var
(
name_hint
)
lv
=
relay
.
Var
(
name_hint
)
lv
.
name_hint
==
name_hint
assert
lv
.
name_hint
==
name_hint
assert
lv
.
type_annotation
is
None
# assert lv.span == None todo(@jroesch): what do we do about spans
# assert lv.span == None todo(@jroesch): what do we do about spans
str
(
lv
)
str
(
lv
)
t1
=
relay
.
ty
.
TensorType
((),
"float"
)
lv
=
relay
.
Var
(
name_hint
,
t1
)
assert
lv
.
name_hint
==
name_hint
assert
lv
.
type_annotation
==
t1
def
test_global_var
():
def
test_global_var
():
name_hint
=
'g'
name_hint
=
'g'
...
@@ -112,19 +118,9 @@ def test_global_var():
...
@@ -112,19 +118,9 @@ def test_global_var():
str
(
gv
)
str
(
gv
)
def
test_param
():
lv
=
relay
.
Var
(
'x'
)
ty
=
None
param
=
relay
.
Param
(
lv
,
ty
)
assert
param
.
var
==
lv
assert
param
.
type
==
ty
assert
param
.
span
==
None
str
(
param
)
def
test_function
():
def
test_function
():
param_names
=
[
'a'
,
'b'
,
'c'
,
'd'
]
param_names
=
[
'a'
,
'b'
,
'c'
,
'd'
]
params
=
tvm
.
convert
([
relay
.
Param
(
relay
.
Var
(
n
),
None
)
for
n
in
param_names
])
params
=
tvm
.
convert
([
relay
.
Var
(
n
)
for
n
in
param_names
])
ret_type
=
None
ret_type
=
None
body
=
None
body
=
None
type_params
=
tvm
.
convert
([])
type_params
=
tvm
.
convert
([])
...
@@ -154,10 +150,9 @@ def test_let():
...
@@ -154,10 +150,9 @@ def test_let():
value
=
relay
.
Constant
(
arr
)
value
=
relay
.
Constant
(
arr
)
# I would prefer that the order of arguments
# I would prefer that the order of arguments
# matches syntax let x: t = v in b
# matches syntax let x: t = v in b
let
=
relay
.
Let
(
lv
,
value
,
lv
,
ty
)
let
=
relay
.
Let
(
lv
,
value
,
lv
)
assert
let
.
var
==
lv
assert
let
.
var
==
lv
assert
let
.
value
==
value
assert
let
.
value
==
value
assert
let
.
value_type
==
ty
assert
let
.
body
==
lv
assert
let
.
body
==
lv
assert
let
.
span
==
None
assert
let
.
span
==
None
str
(
let
)
str
(
let
)
...
@@ -194,7 +189,6 @@ if __name__ == "__main__":
...
@@ -194,7 +189,6 @@ if __name__ == "__main__":
test_tuple
()
test_tuple
()
test_local_var
()
test_local_var
()
test_global_var
()
test_global_var
()
test_param
()
test_function
()
test_function
()
test_call
()
test_call
()
test_let
()
test_let
()
...
...
tests/python/relay/test_ir_well_formed.py
View file @
0b4cc050
...
@@ -7,23 +7,22 @@ def test_well_formed():
...
@@ -7,23 +7,22 @@ def test_well_formed():
assert
well_formed
(
x
)
assert
well_formed
(
x
)
v
=
relay
.
Constant
(
tvm
.
nd
.
array
(
10
))
v
=
relay
.
Constant
(
tvm
.
nd
.
array
(
10
))
ty
=
None
ty
=
None
let
=
relay
.
Let
(
x
,
v
,
x
,
ty
)
let
=
relay
.
Let
(
x
,
v
,
x
)
assert
well_formed
(
let
)
assert
well_formed
(
let
)
assert
not
well_formed
(
relay
.
Let
(
x
,
v
,
let
,
ty
))
assert
not
well_formed
(
relay
.
Let
(
x
,
v
,
let
))
f
=
relay
.
Function
([
relay
.
Param
(
x
,
ty
)
],
ty
,
x
)
f
=
relay
.
Function
([
x
],
ty
,
x
)
assert
well_formed
(
f
)
assert
well_formed
(
f
)
# this test should pass in case of weak uniqueness (only test for shadowing)
# this test should pass in case of weak uniqueness (only test for shadowing)
# but we want all binder to be distinct from each other.
# but we want all binder to be distinct from each other.
assert
not
well_formed
(
relay
.
Let
(
relay
.
Var
(
"y"
),
f
,
assert
not
well_formed
(
relay
.
Let
(
relay
.
Var
(
"y"
),
f
,
relay
.
Let
(
relay
.
Var
(
"z"
),
f
,
v
,
ty
),
ty
))
relay
.
Let
(
relay
.
Var
(
"z"
),
f
,
v
)
))
def
test_tuple
():
def
test_tuple
():
x
=
relay
.
Var
(
'x'
)
x
=
relay
.
Var
(
'x'
)
assert
well_formed
(
x
)
assert
well_formed
(
x
)
v
=
relay
.
Constant
(
tvm
.
nd
.
array
(
10
))
v
=
relay
.
Constant
(
tvm
.
nd
.
array
(
10
))
ty
=
None
let
=
relay
.
Let
(
x
,
v
,
x
)
let
=
relay
.
Let
(
x
,
v
,
x
,
ty
)
assert
well_formed
(
let
)
assert
well_formed
(
let
)
assert
well_formed
(
relay
.
Tuple
([
v
,
v
]))
assert
well_formed
(
relay
.
Tuple
([
v
,
v
]))
assert
not
well_formed
(
relay
.
Tuple
([
let
,
let
]))
assert
not
well_formed
(
relay
.
Tuple
([
let
,
let
]))
...
...
tests/python/relay/test_op_level1.py
View file @
0b4cc050
...
@@ -27,6 +27,8 @@ def test_single_op():
...
@@ -27,6 +27,8 @@ def test_single_op():
tvm
.
relay
.
sigmoid
,
tvm
.
relay
.
tanh
]:
tvm
.
relay
.
sigmoid
,
tvm
.
relay
.
tanh
]:
check_single_op
(
opfunc
)
check_single_op
(
opfunc
)
def
test_expand_dims_infer_type
():
def
test_expand_dims_infer_type
():
ib
=
relay
.
ir_builder
.
IRBuilder
()
ib
=
relay
.
ir_builder
.
IRBuilder
()
n
,
t
,
d
=
tvm
.
var
(
"n"
),
tvm
.
var
(
"t"
),
100
n
,
t
,
d
=
tvm
.
var
(
"n"
),
tvm
.
var
(
"t"
),
100
...
@@ -75,12 +77,13 @@ def test_unary_op():
...
@@ -75,12 +77,13 @@ def test_unary_op():
ib
=
relay
.
ir_builder
.
IRBuilder
()
ib
=
relay
.
ir_builder
.
IRBuilder
()
x
=
ib
.
param
(
"x"
,
relay
.
TensorType
((
10
,
4
),
"int32"
))
x
=
ib
.
param
(
"x"
,
relay
.
TensorType
((
10
,
4
),
"int32"
))
with
ib
.
function
(
x
)
as
func
:
with
ib
.
function
(
x
)
as
func
:
ib
.
ret
(
op
(
x
.
var
))
ib
.
ret
(
op
(
x
))
ib
.
ret
(
func
)
ib
.
ret
(
func
)
func
=
relay
.
ir_pass
.
infer_type
(
ib
.
env
,
func
.
to_func
())
func
=
relay
.
ir_pass
.
infer_type
(
ib
.
env
,
func
.
to_func
())
ftype
=
func
.
checked_type
ftype
=
func
.
checked_type
assert
ftype
.
ret_type
==
relay
.
TensorType
((
10
,
4
),
"int32"
)
assert
ftype
.
ret_type
==
relay
.
TensorType
((
10
,
4
),
"int32"
)
def
test_binary_op
():
def
test_binary_op
():
def
check_binary_op
(
opfunc
):
def
check_binary_op
(
opfunc
):
"""
"""
...
@@ -94,7 +97,7 @@ def test_binary_op():
...
@@ -94,7 +97,7 @@ def test_binary_op():
x
=
b
.
param
(
'x'
,
tensor_type
(
5
,
5
,
5
))
x
=
b
.
param
(
'x'
,
tensor_type
(
5
,
5
,
5
))
y
=
b
.
param
(
'y'
,
tensor_type
(
5
,
5
,
5
))
y
=
b
.
param
(
'y'
,
tensor_type
(
5
,
5
,
5
))
with
b
.
function
(
x
,
y
)
as
func
:
with
b
.
function
(
x
,
y
)
as
func
:
b
.
ret
(
opfunc
(
x
.
var
,
y
.
var
))
b
.
ret
(
opfunc
(
x
,
y
))
b
.
ret
(
func
)
b
.
ret
(
func
)
prog
,
env
=
b
.
get
()
prog
,
env
=
b
.
get
()
ttype
=
tensor_type
(
5
,
5
,
5
)
ttype
=
tensor_type
(
5
,
5
,
5
)
...
@@ -118,7 +121,7 @@ def test_binary_broadcast_op():
...
@@ -118,7 +121,7 @@ def test_binary_broadcast_op():
x
=
b
.
param
(
'x'
,
tensor_type
(
10
,
4
))
x
=
b
.
param
(
'x'
,
tensor_type
(
10
,
4
))
y
=
b
.
param
(
'y'
,
tensor_type
(
5
,
10
,
1
))
y
=
b
.
param
(
'y'
,
tensor_type
(
5
,
10
,
1
))
with
b
.
function
(
x
,
y
)
as
func
:
with
b
.
function
(
x
,
y
)
as
func
:
b
.
ret
(
opfunc
(
x
.
var
,
y
.
var
))
b
.
ret
(
opfunc
(
x
,
y
))
b
.
ret
(
func
)
b
.
ret
(
func
)
prog
,
env
=
b
.
get
()
prog
,
env
=
b
.
get
()
...
...
tests/python/relay/test_op_level2.py
View file @
0b4cc050
...
@@ -11,7 +11,7 @@ def test_conv2d_infer_type():
...
@@ -11,7 +11,7 @@ def test_conv2d_infer_type():
w
=
ib
.
param
(
"w"
,
relay
.
ty
.
IncompleteType
())
w
=
ib
.
param
(
"w"
,
relay
.
ty
.
IncompleteType
())
with
ib
.
function
(
x
,
w
)
as
func
:
with
ib
.
function
(
x
,
w
)
as
func
:
ib
.
ret
(
relay
.
nn
.
conv2d
(
x
.
var
,
w
.
var
,
ib
.
ret
(
relay
.
nn
.
conv2d
(
x
,
w
,
kernel_size
=
(
3
,
3
),
kernel_size
=
(
3
,
3
),
padding
=
(
1
,
1
),
padding
=
(
1
,
1
),
channels
=
2
))
channels
=
2
))
...
@@ -29,7 +29,7 @@ def test_conv2d_infer_type():
...
@@ -29,7 +29,7 @@ def test_conv2d_infer_type():
x
=
ib
.
param
(
"x"
,
relay
.
ty
.
TensorType
((
n
,
c
,
h
,
w
),
"int8"
))
x
=
ib
.
param
(
"x"
,
relay
.
ty
.
TensorType
((
n
,
c
,
h
,
w
),
"int8"
))
w
=
ib
.
param
(
"w"
,
relay
.
ty
.
TensorType
((
2
,
10
,
3
,
3
),
"int8"
))
w
=
ib
.
param
(
"w"
,
relay
.
ty
.
TensorType
((
2
,
10
,
3
,
3
),
"int8"
))
with
ib
.
function
(
x
,
w
)
as
func
:
with
ib
.
function
(
x
,
w
)
as
func
:
ib
.
ret
(
relay
.
nn
.
conv2d
(
x
.
var
,
w
.
var
,
out_dtype
=
"int32"
))
ib
.
ret
(
relay
.
nn
.
conv2d
(
x
,
w
,
out_dtype
=
"int32"
))
ib
.
ret
(
func
)
ib
.
ret
(
func
)
func
=
relay
.
ir_pass
.
infer_type
(
ib
.
env
,
func
.
to_func
())
func
=
relay
.
ir_pass
.
infer_type
(
ib
.
env
,
func
.
to_func
())
ftype
=
func
.
checked_type
ftype
=
func
.
checked_type
...
@@ -42,7 +42,7 @@ def test_conv2d_infer_type():
...
@@ -42,7 +42,7 @@ def test_conv2d_infer_type():
x
=
ib
.
param
(
"x"
,
relay
.
ty
.
TensorType
((
n
,
c
,
h
,
w
),
"int8"
))
x
=
ib
.
param
(
"x"
,
relay
.
ty
.
TensorType
((
n
,
c
,
h
,
w
),
"int8"
))
w
=
ib
.
param
(
"w"
,
relay
.
ty
.
IncompleteType
())
w
=
ib
.
param
(
"w"
,
relay
.
ty
.
IncompleteType
())
with
ib
.
function
(
x
,
w
)
as
func
:
with
ib
.
function
(
x
,
w
)
as
func
:
ib
.
ret
(
relay
.
nn
.
conv2d
(
x
.
var
,
w
.
var
,
ib
.
ret
(
relay
.
nn
.
conv2d
(
x
,
w
,
kernel_size
=
(
3
,
3
),
kernel_size
=
(
3
,
3
),
padding
=
(
1
,
1
),
padding
=
(
1
,
1
),
channels
=
16
,
channels
=
16
,
...
@@ -65,7 +65,7 @@ def test_conv2d_transpose_infer_type():
...
@@ -65,7 +65,7 @@ def test_conv2d_transpose_infer_type():
w
=
ib
.
param
(
"w"
,
relay
.
ty
.
IncompleteType
())
w
=
ib
.
param
(
"w"
,
relay
.
ty
.
IncompleteType
())
with
ib
.
function
(
x
,
w
)
as
func
:
with
ib
.
function
(
x
,
w
)
as
func
:
ib
.
ret
(
relay
.
nn
.
conv2d_transpose
(
x
.
var
,
w
.
var
,
ib
.
ret
(
relay
.
nn
.
conv2d_transpose
(
x
,
w
,
kernel_size
=
(
3
,
3
),
kernel_size
=
(
3
,
3
),
padding
=
(
1
,
1
),
padding
=
(
1
,
1
),
channels
=
15
))
channels
=
15
))
...
@@ -83,7 +83,7 @@ def test_conv2d_transpose_infer_type():
...
@@ -83,7 +83,7 @@ def test_conv2d_transpose_infer_type():
x
=
ib
.
param
(
"x"
,
relay
.
ty
.
TensorType
((
n
,
c
,
h
,
w
),
"float32"
))
x
=
ib
.
param
(
"x"
,
relay
.
ty
.
TensorType
((
n
,
c
,
h
,
w
),
"float32"
))
w
=
ib
.
param
(
"w"
,
relay
.
ty
.
TensorType
((
12
,
11
,
5
,
5
),
"float32"
))
w
=
ib
.
param
(
"w"
,
relay
.
ty
.
TensorType
((
12
,
11
,
5
,
5
),
"float32"
))
with
ib
.
function
(
x
,
w
)
as
func
:
with
ib
.
function
(
x
,
w
)
as
func
:
ib
.
ret
(
relay
.
nn
.
conv2d_transpose
(
x
.
var
,
w
.
var
,
ib
.
ret
(
relay
.
nn
.
conv2d_transpose
(
x
,
w
,
output_padding
=
(
1
,
1
),
output_padding
=
(
1
,
1
),
channels
=
11
,
channels
=
11
,
data_layout
=
"NHWC"
))
data_layout
=
"NHWC"
))
...
@@ -98,7 +98,7 @@ def test_upsampling_infer_type():
...
@@ -98,7 +98,7 @@ def test_upsampling_infer_type():
n
,
c
,
h
,
w
=
tvm
.
var
(
"n"
),
tvm
.
var
(
"c"
),
tvm
.
var
(
"h"
),
tvm
.
var
(
"w"
)
n
,
c
,
h
,
w
=
tvm
.
var
(
"n"
),
tvm
.
var
(
"c"
),
tvm
.
var
(
"h"
),
tvm
.
var
(
"w"
)
x
=
ib
.
param
(
"x"
,
relay
.
ty
.
TensorType
((
n
,
c
,
h
,
w
),
"float32"
))
x
=
ib
.
param
(
"x"
,
relay
.
ty
.
TensorType
((
n
,
c
,
h
,
w
),
"float32"
))
with
ib
.
function
(
x
)
as
func
:
with
ib
.
function
(
x
)
as
func
:
ib
.
ret
(
relay
.
nn
.
upsampling
(
x
.
var
,
scale
=
2
,
layout
=
"NCHW"
,
method
=
"BILINEAR"
))
ib
.
ret
(
relay
.
nn
.
upsampling
(
x
,
scale
=
2
,
layout
=
"NCHW"
,
method
=
"BILINEAR"
))
ib
.
ret
(
func
)
ib
.
ret
(
func
)
func
=
relay
.
ir_pass
.
infer_type
(
ib
.
env
,
func
.
to_func
())
func
=
relay
.
ir_pass
.
infer_type
(
ib
.
env
,
func
.
to_func
())
ftype
=
func
.
checked_type
ftype
=
func
.
checked_type
...
@@ -108,7 +108,7 @@ def test_upsampling_infer_type():
...
@@ -108,7 +108,7 @@ def test_upsampling_infer_type():
n
,
c
=
tvm
.
var
(
"n"
),
tvm
.
var
(
"c"
)
n
,
c
=
tvm
.
var
(
"n"
),
tvm
.
var
(
"c"
)
x
=
ib
.
param
(
"x"
,
relay
.
ty
.
TensorType
((
n
,
c
,
100
,
200
),
"float32"
))
x
=
ib
.
param
(
"x"
,
relay
.
ty
.
TensorType
((
n
,
c
,
100
,
200
),
"float32"
))
with
ib
.
function
(
x
)
as
func
:
with
ib
.
function
(
x
)
as
func
:
ib
.
ret
(
relay
.
nn
.
upsampling
(
x
.
var
,
scale
=
2
,
layout
=
"NCHW"
,
method
=
"BILINEAR"
))
ib
.
ret
(
relay
.
nn
.
upsampling
(
x
,
scale
=
2
,
layout
=
"NCHW"
,
method
=
"BILINEAR"
))
ib
.
ret
(
func
)
ib
.
ret
(
func
)
func
=
relay
.
ir_pass
.
infer_type
(
ib
.
env
,
func
.
to_func
())
func
=
relay
.
ir_pass
.
infer_type
(
ib
.
env
,
func
.
to_func
())
ftype
=
func
.
checked_type
ftype
=
func
.
checked_type
...
@@ -119,7 +119,7 @@ def _test_pool2d_infer_type(opfunc):
...
@@ -119,7 +119,7 @@ def _test_pool2d_infer_type(opfunc):
n
,
c
,
h
,
w
=
tvm
.
var
(
"n"
),
10
,
224
,
224
n
,
c
,
h
,
w
=
tvm
.
var
(
"n"
),
10
,
224
,
224
x
=
ib
.
param
(
"x"
,
relay
.
ty
.
TensorType
((
n
,
c
,
h
,
w
),
"float32"
))
x
=
ib
.
param
(
"x"
,
relay
.
ty
.
TensorType
((
n
,
c
,
h
,
w
),
"float32"
))
with
ib
.
function
(
x
)
as
func
:
with
ib
.
function
(
x
)
as
func
:
ib
.
ret
(
opfunc
(
x
.
var
,
pool_size
=
(
1
,
1
)))
ib
.
ret
(
opfunc
(
x
,
pool_size
=
(
1
,
1
)))
ib
.
ret
(
func
)
ib
.
ret
(
func
)
func
=
relay
.
ir_pass
.
infer_type
(
ib
.
env
,
func
.
to_func
())
func
=
relay
.
ir_pass
.
infer_type
(
ib
.
env
,
func
.
to_func
())
ftype
=
func
.
checked_type
ftype
=
func
.
checked_type
...
@@ -132,7 +132,7 @@ def _test_pool2d_infer_type(opfunc):
...
@@ -132,7 +132,7 @@ def _test_pool2d_infer_type(opfunc):
n
,
c
,
h
,
w
=
tvm
.
var
(
"n"
),
10
,
224
,
224
n
,
c
,
h
,
w
=
tvm
.
var
(
"n"
),
10
,
224
,
224
x
=
ib
.
param
(
"x"
,
relay
.
ty
.
TensorType
((
n
,
c
,
h
,
w
),
"float32"
))
x
=
ib
.
param
(
"x"
,
relay
.
ty
.
TensorType
((
n
,
c
,
h
,
w
),
"float32"
))
with
ib
.
function
(
x
)
as
func
:
with
ib
.
function
(
x
)
as
func
:
ib
.
ret
(
opfunc
(
x
.
var
,
pool_size
=
(
ph
,
pw
),
strides
=
(
sh
,
sw
)))
ib
.
ret
(
opfunc
(
x
,
pool_size
=
(
ph
,
pw
),
strides
=
(
sh
,
sw
)))
ib
.
ret
(
func
)
ib
.
ret
(
func
)
func
=
relay
.
ir_pass
.
infer_type
(
ib
.
env
,
func
.
to_func
())
func
=
relay
.
ir_pass
.
infer_type
(
ib
.
env
,
func
.
to_func
())
ftype
=
func
.
checked_type
ftype
=
func
.
checked_type
...
@@ -144,7 +144,7 @@ def _test_global_pool2d_infer_type(opfunc):
...
@@ -144,7 +144,7 @@ def _test_global_pool2d_infer_type(opfunc):
n
,
c
,
h
,
w
=
tvm
.
var
(
"n"
),
tvm
.
var
(
"c"
),
224
,
224
n
,
c
,
h
,
w
=
tvm
.
var
(
"n"
),
tvm
.
var
(
"c"
),
224
,
224
x
=
ib
.
param
(
"x"
,
relay
.
ty
.
TensorType
((
n
,
h
,
w
,
c
),
"float32"
))
x
=
ib
.
param
(
"x"
,
relay
.
ty
.
TensorType
((
n
,
h
,
w
,
c
),
"float32"
))
with
ib
.
function
(
x
)
as
func
:
with
ib
.
function
(
x
)
as
func
:
ib
.
ret
(
opfunc
(
x
.
var
,
layout
=
"NHWC"
))
ib
.
ret
(
opfunc
(
x
,
layout
=
"NHWC"
))
ib
.
ret
(
func
)
ib
.
ret
(
func
)
func
=
relay
.
ir_pass
.
infer_type
(
ib
.
env
,
func
.
to_func
())
func
=
relay
.
ir_pass
.
infer_type
(
ib
.
env
,
func
.
to_func
())
ftype
=
func
.
checked_type
ftype
=
func
.
checked_type
...
@@ -154,7 +154,7 @@ def _test_global_pool2d_infer_type(opfunc):
...
@@ -154,7 +154,7 @@ def _test_global_pool2d_infer_type(opfunc):
n
,
c
,
h
,
w
=
tvm
.
var
(
"n"
),
tvm
.
var
(
"c"
),
tvm
.
var
(
"h"
),
tvm
.
var
(
"w"
)
n
,
c
,
h
,
w
=
tvm
.
var
(
"n"
),
tvm
.
var
(
"c"
),
tvm
.
var
(
"h"
),
tvm
.
var
(
"w"
)
x
=
ib
.
param
(
"x"
,
relay
.
ty
.
TensorType
((
n
,
c
,
h
,
w
),
"float32"
))
x
=
ib
.
param
(
"x"
,
relay
.
ty
.
TensorType
((
n
,
c
,
h
,
w
),
"float32"
))
with
ib
.
function
(
x
)
as
func
:
with
ib
.
function
(
x
)
as
func
:
ib
.
ret
(
opfunc
(
x
.
var
))
ib
.
ret
(
opfunc
(
x
))
ib
.
ret
(
func
)
ib
.
ret
(
func
)
func
=
relay
.
ir_pass
.
infer_type
(
ib
.
env
,
func
.
to_func
())
func
=
relay
.
ir_pass
.
infer_type
(
ib
.
env
,
func
.
to_func
())
ftype
=
func
.
checked_type
ftype
=
func
.
checked_type
...
@@ -172,7 +172,7 @@ def test_flatten_infer_type():
...
@@ -172,7 +172,7 @@ def test_flatten_infer_type():
x
=
ib
.
param
(
"x"
,
relay
.
ty
.
TensorType
((
d1
,
d2
,
d3
,
d4
),
"float32"
))
x
=
ib
.
param
(
"x"
,
relay
.
ty
.
TensorType
((
d1
,
d2
,
d3
,
d4
),
"float32"
))
with
ib
.
function
(
x
)
as
func
:
with
ib
.
function
(
x
)
as
func
:
ib
.
ret
(
relay
.
nn
.
batch_flatten
(
x
.
var
))
ib
.
ret
(
relay
.
nn
.
batch_flatten
(
x
))
ib
.
ret
(
func
)
ib
.
ret
(
func
)
func
=
relay
.
ir_pass
.
infer_type
(
ib
.
env
,
func
.
to_func
())
func
=
relay
.
ir_pass
.
infer_type
(
ib
.
env
,
func
.
to_func
())
ftype
=
func
.
checked_type
ftype
=
func
.
checked_type
...
@@ -181,7 +181,7 @@ def test_flatten_infer_type():
...
@@ -181,7 +181,7 @@ def test_flatten_infer_type():
ib
=
relay
.
ir_builder
.
IRBuilder
()
ib
=
relay
.
ir_builder
.
IRBuilder
()
x
=
ib
.
param
(
"x"
,
relay
.
ty
.
TensorType
((
3
,
2
,
4
,
3
),
"float32"
))
x
=
ib
.
param
(
"x"
,
relay
.
ty
.
TensorType
((
3
,
2
,
4
,
3
),
"float32"
))
with
ib
.
function
(
x
)
as
func
:
with
ib
.
function
(
x
)
as
func
:
ib
.
ret
(
relay
.
nn
.
batch_flatten
(
x
.
var
))
ib
.
ret
(
relay
.
nn
.
batch_flatten
(
x
))
ib
.
ret
(
func
)
ib
.
ret
(
func
)
func
=
relay
.
ir_pass
.
infer_type
(
ib
.
env
,
func
.
to_func
())
func
=
relay
.
ir_pass
.
infer_type
(
ib
.
env
,
func
.
to_func
())
ftype
=
func
.
checked_type
ftype
=
func
.
checked_type
...
@@ -190,7 +190,7 @@ def test_flatten_infer_type():
...
@@ -190,7 +190,7 @@ def test_flatten_infer_type():
ib
=
relay
.
ir_builder
.
IRBuilder
()
ib
=
relay
.
ir_builder
.
IRBuilder
()
x
=
ib
.
param
(
"x"
,
relay
.
ty
.
TensorType
((
d1
,
2
,
d3
,
3
),
"float32"
))
x
=
ib
.
param
(
"x"
,
relay
.
ty
.
TensorType
((
d1
,
2
,
d3
,
3
),
"float32"
))
with
ib
.
function
(
x
)
as
func
:
with
ib
.
function
(
x
)
as
func
:
ib
.
ret
(
relay
.
nn
.
batch_flatten
(
x
.
var
))
ib
.
ret
(
relay
.
nn
.
batch_flatten
(
x
))
ib
.
ret
(
func
)
ib
.
ret
(
func
)
func
=
relay
.
ir_pass
.
infer_type
(
ib
.
env
,
func
.
to_func
())
func
=
relay
.
ir_pass
.
infer_type
(
ib
.
env
,
func
.
to_func
())
ftype
=
func
.
checked_type
ftype
=
func
.
checked_type
...
@@ -202,7 +202,7 @@ def test_pad_infer_type():
...
@@ -202,7 +202,7 @@ def test_pad_infer_type():
n
,
c
,
h
,
w
=
1
,
2
,
3
,
4
n
,
c
,
h
,
w
=
1
,
2
,
3
,
4
t
=
ib
.
param
(
"t"
,
relay
.
TensorType
((
n
,
c
,
h
,
w
),
"float32"
))
t
=
ib
.
param
(
"t"
,
relay
.
TensorType
((
n
,
c
,
h
,
w
),
"float32"
))
with
ib
.
function
(
t
)
as
func
:
with
ib
.
function
(
t
)
as
func
:
ib
.
ret
(
relay
.
nn
.
pad
(
t
.
var
,
((
1
,
1
),
(
2
,
2
),
(
3
,
3
),
(
4
,
4
))))
ib
.
ret
(
relay
.
nn
.
pad
(
t
,
((
1
,
1
),
(
2
,
2
),
(
3
,
3
),
(
4
,
4
))))
ib
.
ret
(
func
)
ib
.
ret
(
func
)
func
=
relay
.
ir_pass
.
infer_type
(
ib
.
env
,
func
.
to_func
())
func
=
relay
.
ir_pass
.
infer_type
(
ib
.
env
,
func
.
to_func
())
ftype
=
func
.
checked_type
ftype
=
func
.
checked_type
...
@@ -213,7 +213,7 @@ def test_pad_infer_type():
...
@@ -213,7 +213,7 @@ def test_pad_infer_type():
n
,
c
,
h
,
w
=
tvm
.
var
(
"n"
),
2
,
3
,
tvm
.
var
(
"w"
)
n
,
c
,
h
,
w
=
tvm
.
var
(
"n"
),
2
,
3
,
tvm
.
var
(
"w"
)
t
=
ib
.
param
(
"t"
,
relay
.
TensorType
((
n
,
c
,
h
,
w
),
"float32"
))
t
=
ib
.
param
(
"t"
,
relay
.
TensorType
((
n
,
c
,
h
,
w
),
"float32"
))
with
ib
.
function
(
t
)
as
func
:
with
ib
.
function
(
t
)
as
func
:
ib
.
ret
(
relay
.
nn
.
pad
(
t
.
var
,
((
1
,
1
),
(
2
,
2
),
(
3
,
3
),
(
4
,
4
))))
ib
.
ret
(
relay
.
nn
.
pad
(
t
,
((
1
,
1
),
(
2
,
2
),
(
3
,
3
),
(
4
,
4
))))
ib
.
ret
(
func
)
ib
.
ret
(
func
)
func
=
relay
.
ir_pass
.
infer_type
(
ib
.
env
,
func
.
to_func
())
func
=
relay
.
ir_pass
.
infer_type
(
ib
.
env
,
func
.
to_func
())
ftype
=
func
.
checked_type
ftype
=
func
.
checked_type
...
@@ -227,4 +227,3 @@ if __name__ == "__main__":
...
@@ -227,4 +227,3 @@ if __name__ == "__main__":
test_flatten_infer_type
()
test_flatten_infer_type
()
test_pad_infer_type
()
test_pad_infer_type
()
test_conv2d_transpose_infer_type
()
test_conv2d_transpose_infer_type
()
tests/python/relay/test_op_level3.py
View file @
0b4cc050
...
@@ -17,12 +17,13 @@ def test_zeros_ones():
...
@@ -17,12 +17,13 @@ def test_zeros_ones():
ftype
=
func
.
checked_type
ftype
=
func
.
checked_type
assert
ftype
.
ret_type
==
relay
.
TensorType
((
124
,
50
),
"float64"
)
assert
ftype
.
ret_type
==
relay
.
TensorType
((
124
,
50
),
"float64"
)
def
test_unary_identity
():
def
test_unary_identity
():
for
op
in
[
relay
.
zeros_like
,
relay
.
ones_like
]:
for
op
in
[
relay
.
zeros_like
,
relay
.
ones_like
]:
ib
=
relay
.
ir_builder
.
IRBuilder
()
ib
=
relay
.
ir_builder
.
IRBuilder
()
x
=
ib
.
param
(
"x"
,
relay
.
TensorType
((
8
,
9
,
4
),
"int32"
))
x
=
ib
.
param
(
"x"
,
relay
.
TensorType
((
8
,
9
,
4
),
"int32"
))
with
ib
.
function
(
x
)
as
func
:
with
ib
.
function
(
x
)
as
func
:
ib
.
ret
(
op
(
x
.
var
))
ib
.
ret
(
op
(
x
))
ib
.
ret
(
func
)
ib
.
ret
(
func
)
func
=
relay
.
ir_pass
.
infer_type
(
ib
.
env
,
func
.
to_func
())
func
=
relay
.
ir_pass
.
infer_type
(
ib
.
env
,
func
.
to_func
())
ftype
=
func
.
checked_type
ftype
=
func
.
checked_type
...
@@ -33,7 +34,7 @@ def test_clip_type():
...
@@ -33,7 +34,7 @@ def test_clip_type():
ib
=
relay
.
ir_builder
.
IRBuilder
()
ib
=
relay
.
ir_builder
.
IRBuilder
()
a
=
ib
.
param
(
"a"
,
relay
.
TensorType
((
10
,
4
),
"float32"
))
a
=
ib
.
param
(
"a"
,
relay
.
TensorType
((
10
,
4
),
"float32"
))
with
ib
.
function
(
a
)
as
func
:
with
ib
.
function
(
a
)
as
func
:
ib
.
ret
(
relay
.
clip
(
a
.
var
,
1.
,
4.
))
ib
.
ret
(
relay
.
clip
(
a
,
1.
,
4.
))
ib
.
ret
(
func
)
ib
.
ret
(
func
)
func
=
relay
.
ir_pass
.
infer_type
(
ib
.
env
,
func
.
to_func
())
func
=
relay
.
ir_pass
.
infer_type
(
ib
.
env
,
func
.
to_func
())
ftype
=
func
.
checked_type
ftype
=
func
.
checked_type
...
@@ -106,7 +107,7 @@ def test_take_infer_type():
...
@@ -106,7 +107,7 @@ def test_take_infer_type():
x
=
ib
.
param
(
"x"
,
relay
.
ty
.
TensorType
(
dshape
,
"float32"
))
x
=
ib
.
param
(
"x"
,
relay
.
ty
.
TensorType
(
dshape
,
"float32"
))
indices
=
ib
.
param
(
"indices"
,
relay
.
ty
.
TensorType
(
indices_shape
,
"int32"
))
indices
=
ib
.
param
(
"indices"
,
relay
.
ty
.
TensorType
(
indices_shape
,
"int32"
))
with
ib
.
function
(
x
,
indices
)
as
func
:
with
ib
.
function
(
x
,
indices
)
as
func
:
ib
.
ret
(
relay
.
take
(
x
.
var
,
indices
.
var
,
axis
=
axis
))
ib
.
ret
(
relay
.
take
(
x
,
indices
,
axis
=
axis
))
ib
.
ret
(
func
)
ib
.
ret
(
func
)
func
=
relay
.
ir_pass
.
infer_type
(
ib
.
env
,
func
.
to_func
())
func
=
relay
.
ir_pass
.
infer_type
(
ib
.
env
,
func
.
to_func
())
ftype
=
func
.
checked_type
ftype
=
func
.
checked_type
...
@@ -127,7 +128,7 @@ def test_full():
...
@@ -127,7 +128,7 @@ def test_full():
ib
=
relay
.
ir_builder
.
IRBuilder
()
ib
=
relay
.
ir_builder
.
IRBuilder
()
x
=
ib
.
param
(
"x"
,
relay
.
TensorType
((),
"int8"
))
x
=
ib
.
param
(
"x"
,
relay
.
TensorType
((),
"int8"
))
with
ib
.
function
(
x
)
as
func
:
with
ib
.
function
(
x
)
as
func
:
ib
.
ret
(
relay
.
full
(
x
.
var
,
()))
ib
.
ret
(
relay
.
full
(
x
,
()))
ib
.
ret
(
func
)
ib
.
ret
(
func
)
func
=
relay
.
ir_pass
.
infer_type
(
ib
.
env
,
func
.
to_func
())
func
=
relay
.
ir_pass
.
infer_type
(
ib
.
env
,
func
.
to_func
())
ftype
=
func
.
checked_type
ftype
=
func
.
checked_type
...
@@ -137,7 +138,7 @@ def test_full():
...
@@ -137,7 +138,7 @@ def test_full():
ib
=
relay
.
ir_builder
.
IRBuilder
()
ib
=
relay
.
ir_builder
.
IRBuilder
()
x
=
ib
.
param
(
"x"
,
relay
.
TensorType
((),
"float32"
))
x
=
ib
.
param
(
"x"
,
relay
.
TensorType
((),
"float32"
))
with
ib
.
function
(
x
)
as
func
:
with
ib
.
function
(
x
)
as
func
:
ib
.
ret
(
relay
.
full
(
x
.
var
,
(
1
,
2
),
"int8"
))
ib
.
ret
(
relay
.
full
(
x
,
(
1
,
2
),
"int8"
))
ib
.
ret
(
func
)
ib
.
ret
(
func
)
func
=
relay
.
ir_pass
.
infer_type
(
ib
.
env
,
func
.
to_func
())
func
=
relay
.
ir_pass
.
infer_type
(
ib
.
env
,
func
.
to_func
())
ftype
=
func
.
checked_type
ftype
=
func
.
checked_type
...
@@ -150,7 +151,7 @@ def test_full_like():
...
@@ -150,7 +151,7 @@ def test_full_like():
base
=
ib
.
param
(
"base"
,
relay
.
TensorType
((
1
,
2
,
3
),
"float32"
))
base
=
ib
.
param
(
"base"
,
relay
.
TensorType
((
1
,
2
,
3
),
"float32"
))
fill
=
ib
.
param
(
"fill"
,
relay
.
TensorType
((),
"float32"
))
fill
=
ib
.
param
(
"fill"
,
relay
.
TensorType
((),
"float32"
))
with
ib
.
function
(
base
,
fill
)
as
func
:
with
ib
.
function
(
base
,
fill
)
as
func
:
ib
.
ret
(
relay
.
full_like
(
base
.
var
,
fill
.
var
))
ib
.
ret
(
relay
.
full_like
(
base
,
fill
))
ib
.
ret
(
func
)
ib
.
ret
(
func
)
func
=
relay
.
ir_pass
.
infer_type
(
ib
.
env
,
func
.
to_func
())
func
=
relay
.
ir_pass
.
infer_type
(
ib
.
env
,
func
.
to_func
())
ftype
=
func
.
checked_type
ftype
=
func
.
checked_type
...
@@ -162,7 +163,7 @@ def test_full_like():
...
@@ -162,7 +163,7 @@ def test_full_like():
base
=
ib
.
param
(
"base"
,
relay
.
TensorType
((
n
,
c
,
h
,
w
),
"float32"
))
base
=
ib
.
param
(
"base"
,
relay
.
TensorType
((
n
,
c
,
h
,
w
),
"float32"
))
fill
=
ib
.
param
(
"fill"
,
relay
.
TensorType
((),
"float32"
))
fill
=
ib
.
param
(
"fill"
,
relay
.
TensorType
((),
"float32"
))
with
ib
.
function
(
base
,
fill
)
as
func
:
with
ib
.
function
(
base
,
fill
)
as
func
:
ib
.
ret
(
relay
.
full_like
(
base
.
var
,
fill
.
var
))
ib
.
ret
(
relay
.
full_like
(
base
,
fill
))
ib
.
ret
(
func
)
ib
.
ret
(
func
)
func
=
relay
.
ir_pass
.
infer_type
(
ib
.
env
,
func
.
to_func
())
func
=
relay
.
ir_pass
.
infer_type
(
ib
.
env
,
func
.
to_func
())
ftype
=
func
.
checked_type
ftype
=
func
.
checked_type
...
...
tests/python/relay/test_op_level4.py
View file @
0b4cc050
...
@@ -24,7 +24,7 @@ def test_cmp_type():
...
@@ -24,7 +24,7 @@ def test_cmp_type():
x
=
ib
.
param
(
"x"
,
relay
.
TensorType
((
10
,
4
),
"float32"
))
x
=
ib
.
param
(
"x"
,
relay
.
TensorType
((
10
,
4
),
"float32"
))
y
=
ib
.
param
(
"y"
,
relay
.
TensorType
((
5
,
10
,
1
),
"float32"
))
y
=
ib
.
param
(
"y"
,
relay
.
TensorType
((
5
,
10
,
1
),
"float32"
))
with
ib
.
function
(
x
,
y
)
as
func
:
with
ib
.
function
(
x
,
y
)
as
func
:
ib
.
ret
(
op
(
x
.
var
,
y
.
var
))
ib
.
ret
(
op
(
x
,
y
))
ib
.
ret
(
func
)
ib
.
ret
(
func
)
func
=
relay
.
ir_pass
.
infer_type
(
ib
.
env
,
func
.
to_func
())
func
=
relay
.
ir_pass
.
infer_type
(
ib
.
env
,
func
.
to_func
())
ftype
=
func
.
checked_type
ftype
=
func
.
checked_type
...
@@ -39,7 +39,7 @@ def test_binary_broadcast():
...
@@ -39,7 +39,7 @@ def test_binary_broadcast():
x
=
ib
.
param
(
"x"
,
relay
.
TensorType
((
10
,
4
),
"int32"
))
x
=
ib
.
param
(
"x"
,
relay
.
TensorType
((
10
,
4
),
"int32"
))
y
=
ib
.
param
(
"y"
,
relay
.
TensorType
((
5
,
10
,
1
),
"int32"
))
y
=
ib
.
param
(
"y"
,
relay
.
TensorType
((
5
,
10
,
1
),
"int32"
))
with
ib
.
function
(
x
,
y
)
as
func
:
with
ib
.
function
(
x
,
y
)
as
func
:
ib
.
ret
(
op
(
x
.
var
,
y
.
var
))
ib
.
ret
(
op
(
x
,
y
))
ib
.
ret
(
func
)
ib
.
ret
(
func
)
func
=
relay
.
ir_pass
.
infer_type
(
ib
.
env
,
func
.
to_func
())
func
=
relay
.
ir_pass
.
infer_type
(
ib
.
env
,
func
.
to_func
())
ftype
=
func
.
checked_type
ftype
=
func
.
checked_type
...
@@ -58,7 +58,7 @@ def test_binary_op():
...
@@ -58,7 +58,7 @@ def test_binary_op():
x
=
b
.
param
(
'x'
,
tensor_type
(
5
,
5
,
5
))
x
=
b
.
param
(
'x'
,
tensor_type
(
5
,
5
,
5
))
y
=
b
.
param
(
'y'
,
tensor_type
(
5
,
5
,
5
))
y
=
b
.
param
(
'y'
,
tensor_type
(
5
,
5
,
5
))
with
b
.
function
(
x
,
y
)
as
func
:
with
b
.
function
(
x
,
y
)
as
func
:
b
.
ret
(
opfunc
(
x
.
var
,
y
.
var
))
b
.
ret
(
opfunc
(
x
,
y
))
b
.
ret
(
func
)
b
.
ret
(
func
)
prog
,
env
=
b
.
get
()
prog
,
env
=
b
.
get
()
ttype
=
tensor_type
(
5
,
5
,
5
)
ttype
=
tensor_type
(
5
,
5
,
5
)
...
@@ -81,7 +81,7 @@ def test_binary_broadcast_op():
...
@@ -81,7 +81,7 @@ def test_binary_broadcast_op():
x
=
b
.
param
(
'x'
,
tensor_type
(
10
,
4
))
x
=
b
.
param
(
'x'
,
tensor_type
(
10
,
4
))
y
=
b
.
param
(
'y'
,
tensor_type
(
5
,
10
,
1
))
y
=
b
.
param
(
'y'
,
tensor_type
(
5
,
10
,
1
))
with
b
.
function
(
x
,
y
)
as
func
:
with
b
.
function
(
x
,
y
)
as
func
:
b
.
ret
(
opfunc
(
x
.
var
,
y
.
var
))
b
.
ret
(
opfunc
(
x
,
y
))
b
.
ret
(
func
)
b
.
ret
(
func
)
prog
,
env
=
b
.
get
()
prog
,
env
=
b
.
get
()
...
@@ -103,7 +103,7 @@ def test_cmp_type():
...
@@ -103,7 +103,7 @@ def test_cmp_type():
x
=
ib
.
param
(
"x"
,
relay
.
TensorType
((
10
,
4
),
"float32"
))
x
=
ib
.
param
(
"x"
,
relay
.
TensorType
((
10
,
4
),
"float32"
))
y
=
ib
.
param
(
"y"
,
relay
.
TensorType
((
5
,
10
,
1
),
"float32"
))
y
=
ib
.
param
(
"y"
,
relay
.
TensorType
((
5
,
10
,
1
),
"float32"
))
with
ib
.
function
(
x
,
y
)
as
func
:
with
ib
.
function
(
x
,
y
)
as
func
:
ib
.
ret
(
op
(
x
.
var
,
y
.
var
))
ib
.
ret
(
op
(
x
,
y
))
ib
.
ret
(
func
)
ib
.
ret
(
func
)
func
=
relay
.
ir_pass
.
infer_type
(
ib
.
env
,
func
.
to_func
())
func
=
relay
.
ir_pass
.
infer_type
(
ib
.
env
,
func
.
to_func
())
ftype
=
func
.
checked_type
ftype
=
func
.
checked_type
...
@@ -118,7 +118,7 @@ def test_binary_broadcast():
...
@@ -118,7 +118,7 @@ def test_binary_broadcast():
x
=
ib
.
param
(
"x"
,
relay
.
TensorType
((
10
,
4
),
"int32"
))
x
=
ib
.
param
(
"x"
,
relay
.
TensorType
((
10
,
4
),
"int32"
))
y
=
ib
.
param
(
"y"
,
relay
.
TensorType
((
5
,
10
,
1
),
"int32"
))
y
=
ib
.
param
(
"y"
,
relay
.
TensorType
((
5
,
10
,
1
),
"int32"
))
with
ib
.
function
(
x
,
y
)
as
func
:
with
ib
.
function
(
x
,
y
)
as
func
:
ib
.
ret
(
op
(
x
.
var
,
y
.
var
))
ib
.
ret
(
op
(
x
,
y
))
ib
.
ret
(
func
)
ib
.
ret
(
func
)
func
=
relay
.
ir_pass
.
infer_type
(
ib
.
env
,
func
.
to_func
())
func
=
relay
.
ir_pass
.
infer_type
(
ib
.
env
,
func
.
to_func
())
ftype
=
func
.
checked_type
ftype
=
func
.
checked_type
...
@@ -131,7 +131,7 @@ def test_where():
...
@@ -131,7 +131,7 @@ def test_where():
x
=
ib
.
param
(
"x"
,
relay
.
TensorType
((
3
,
4
),
"float32"
))
x
=
ib
.
param
(
"x"
,
relay
.
TensorType
((
3
,
4
),
"float32"
))
y
=
ib
.
param
(
"y"
,
relay
.
TensorType
((
3
,
4
),
"float32"
))
y
=
ib
.
param
(
"y"
,
relay
.
TensorType
((
3
,
4
),
"float32"
))
with
ib
.
function
(
cond
,
x
,
y
)
as
func
:
with
ib
.
function
(
cond
,
x
,
y
)
as
func
:
ib
.
ret
(
relay
.
where
(
cond
.
var
,
x
.
var
,
y
.
var
))
ib
.
ret
(
relay
.
where
(
cond
,
x
,
y
))
ib
.
ret
(
func
)
ib
.
ret
(
func
)
func
=
relay
.
ir_pass
.
infer_type
(
ib
.
env
,
func
.
to_func
())
func
=
relay
.
ir_pass
.
infer_type
(
ib
.
env
,
func
.
to_func
())
ftype
=
func
.
checked_type
ftype
=
func
.
checked_type
...
...
tests/python/relay/test_op_level5.py
View file @
0b4cc050
...
@@ -10,7 +10,7 @@ def test_resize_infer_type():
...
@@ -10,7 +10,7 @@ def test_resize_infer_type():
th
,
tw
=
tvm
.
var
(
"th"
),
tvm
.
var
(
"tw"
)
th
,
tw
=
tvm
.
var
(
"th"
),
tvm
.
var
(
"tw"
)
with
ib
.
function
(
x
)
as
func
:
with
ib
.
function
(
x
)
as
func
:
ib
.
ret
(
relay
.
image
.
resize
(
x
.
var
,
(
th
,
tw
)))
ib
.
ret
(
relay
.
image
.
resize
(
x
,
(
th
,
tw
)))
ib
.
ret
(
func
)
ib
.
ret
(
func
)
func
=
relay
.
ir_pass
.
infer_type
(
ib
.
env
,
func
.
to_func
())
func
=
relay
.
ir_pass
.
infer_type
(
ib
.
env
,
func
.
to_func
())
ftype
=
func
.
checked_type
ftype
=
func
.
checked_type
...
@@ -19,7 +19,7 @@ def test_resize_infer_type():
...
@@ -19,7 +19,7 @@ def test_resize_infer_type():
ib
=
relay
.
ir_builder
.
IRBuilder
()
ib
=
relay
.
ir_builder
.
IRBuilder
()
x
=
ib
.
param
(
"x"
,
relay
.
ty
.
TensorType
((
n
,
c
,
h
,
w
),
"int8"
))
x
=
ib
.
param
(
"x"
,
relay
.
ty
.
TensorType
((
n
,
c
,
h
,
w
),
"int8"
))
with
ib
.
function
(
x
)
as
func
:
with
ib
.
function
(
x
)
as
func
:
ib
.
ret
(
relay
.
image
.
resize
(
x
.
var
,
(
100
,
200
),
"NCHW"
,
"BILINEAR"
,
False
))
ib
.
ret
(
relay
.
image
.
resize
(
x
,
(
100
,
200
),
"NCHW"
,
"BILINEAR"
,
False
))
ib
.
ret
(
func
)
ib
.
ret
(
func
)
func
=
relay
.
ir_pass
.
infer_type
(
ib
.
env
,
func
.
to_func
())
func
=
relay
.
ir_pass
.
infer_type
(
ib
.
env
,
func
.
to_func
())
ftype
=
func
.
checked_type
ftype
=
func
.
checked_type
...
...
tests/python/relay/test_pass_alpha_equal.py
View file @
0b4cc050
import
tvm
import
tvm
import
numpy
as
np
from
tvm
import
relay
from
tvm
import
relay
from
tvm.relay.ir_pass
import
alpha_equal
from
tvm.relay.ir_pass
import
alpha_equal
from
tvm.relay.ir_builder
import
convert
from
tvm.relay.ir_builder
import
convert
...
@@ -179,9 +180,9 @@ def test_var_alpha_equal():
...
@@ -179,9 +180,9 @@ def test_var_alpha_equal():
assert
not
alpha_equal
(
v1
,
v2
)
assert
not
alpha_equal
(
v1
,
v2
)
# let node allows for setting the eq_map
# let node allows for setting the eq_map
l1
=
relay
.
Let
(
v1
,
convert
(
1
),
v1
,
None
)
l1
=
relay
.
Let
(
v1
,
convert
(
1
),
v1
)
l2
=
relay
.
Let
(
v2
,
convert
(
1
),
v2
,
None
)
l2
=
relay
.
Let
(
v2
,
convert
(
1
),
v2
)
l3
=
relay
.
Let
(
v1
,
convert
(
1
),
v2
,
None
)
l3
=
relay
.
Let
(
v1
,
convert
(
1
),
v2
)
assert
alpha_equal
(
l1
,
l2
)
assert
alpha_equal
(
l1
,
l2
)
assert
not
alpha_equal
(
l1
,
l3
)
assert
not
alpha_equal
(
l1
,
l3
)
...
@@ -209,10 +210,10 @@ def test_tuple_alpha_equal():
...
@@ -209,10 +210,10 @@ def test_tuple_alpha_equal():
assert
alpha_equal
(
tup
,
same
)
assert
alpha_equal
(
tup
,
same
)
# use the eq_map
# use the eq_map
let_tup
=
relay
.
Let
(
v1
,
tup
,
v1
,
None
)
let_tup
=
relay
.
Let
(
v1
,
tup
,
v1
)
let_mapped
=
relay
.
Let
(
v2
,
relay
.
Tuple
([
v2
,
convert
(
2
),
convert
(
3
),
let_mapped
=
relay
.
Let
(
v2
,
relay
.
Tuple
([
v2
,
convert
(
2
),
convert
(
3
),
relay
.
Tuple
([
convert
(
4
)])]),
relay
.
Tuple
([
convert
(
4
)])]),
v2
,
None
)
v2
)
assert
alpha_equal
(
let_tup
,
let_mapped
)
assert
alpha_equal
(
let_tup
,
let_mapped
)
more_fields
=
relay
.
Tuple
([
v1
,
convert
(
2
),
convert
(
3
),
relay
.
Tuple
([
convert
(
4
)]),
v2
])
more_fields
=
relay
.
Tuple
([
v1
,
convert
(
2
),
convert
(
3
),
relay
.
Tuple
([
convert
(
4
)]),
v2
])
...
@@ -242,61 +243,44 @@ def test_tuple_get_item_alpha_equal():
...
@@ -242,61 +243,44 @@ def test_tuple_get_item_alpha_equal():
assert
alpha_equal
(
relay
.
TupleGetItem
(
x
,
1
),
relay
.
TupleGetItem
(
x
,
1
))
assert
alpha_equal
(
relay
.
TupleGetItem
(
x
,
1
),
relay
.
TupleGetItem
(
x
,
1
))
def
test_param_alpha_equal
():
# only checks equality of the types
v1
=
relay
.
Var
(
"v1"
)
v2
=
relay
.
Var
(
"v2"
)
p1
=
relay
.
Param
(
v1
,
relay
.
TensorType
((
1
,
2
,
3
),
"float32"
))
p2
=
relay
.
Param
(
v2
,
relay
.
TensorType
((
1
,
2
,
3
),
"float32"
))
assert
alpha_equal
(
p1
,
p2
)
p3
=
relay
.
Param
(
v1
,
relay
.
TensorType
((
4
,
5
,
6
),
"int8"
))
assert
not
alpha_equal
(
p1
,
p3
)
p4
=
relay
.
Param
(
v1
,
relay
.
TupleType
([
relay
.
TensorType
((
1
,
2
,
3
),
"float32"
)]))
assert
not
alpha_equal
(
p1
,
p4
)
def
test_function_alpha_equal
():
def
test_function_alpha_equal
():
v1
=
relay
.
Var
(
"v1"
)
v2
=
relay
.
Var
(
"v2"
)
v3
=
relay
.
Var
(
"v3"
)
v4
=
relay
.
Var
(
"v4"
)
tt1
=
relay
.
TensorType
((
1
,
2
,
3
),
"float32"
)
tt1
=
relay
.
TensorType
((
1
,
2
,
3
),
"float32"
)
tt2
=
relay
.
TensorType
((
4
,
5
,
6
),
"int8"
)
tt2
=
relay
.
TensorType
((
4
,
5
,
6
),
"int8"
)
tt3
=
relay
.
TupleType
([
tt1
,
tt2
])
tt3
=
relay
.
TupleType
([
tt1
,
tt2
])
v1
=
relay
.
Var
(
"v1"
,
tt1
)
v2
=
relay
.
Var
(
"v2"
,
tt2
)
v3
=
relay
.
Var
(
"v3"
,
tt3
)
v4
=
relay
.
Var
(
"v4"
,
tt2
)
vret
=
relay
.
Constant
(
tvm
.
nd
.
array
(
np
.
ones
(
1
)))
tp1
=
relay
.
TypeParam
(
"tp1"
,
relay
.
Kind
.
Type
)
tp1
=
relay
.
TypeParam
(
"tp1"
,
relay
.
Kind
.
Type
)
tp2
=
relay
.
TypeParam
(
"tp2"
,
relay
.
Kind
.
Type
)
tp2
=
relay
.
TypeParam
(
"tp2"
,
relay
.
Kind
.
Type
)
tp3
=
relay
.
TypeParam
(
"tp3"
,
relay
.
Kind
.
Shape
)
tp3
=
relay
.
TypeParam
(
"tp3"
,
relay
.
Kind
.
Shape
)
tp4
=
relay
.
TypeParam
(
"tp4"
,
relay
.
Kind
.
Shape
)
tp4
=
relay
.
TypeParam
(
"tp4"
,
relay
.
Kind
.
Shape
)
basic_args
=
[
relay
.
Param
(
v3
,
tt1
),
relay
.
Param
(
v4
,
tt2
)]
basic_args
=
[
relay
.
Var
(
"v3"
,
tt1
),
relay
.
Var
(
"v4"
,
tt2
)]
basic_tps
=
[
tp1
,
tp2
]
basic_tps
=
[
tp1
,
tp2
]
func
=
relay
.
Function
([
relay
.
Param
(
v1
,
tt1
),
relay
.
Param
(
v2
,
tt2
)
],
func
=
relay
.
Function
([
v1
,
v2
],
tt2
,
v
2
,
basic_tps
)
tt2
,
v
1
,
basic_tps
)
mapped
=
relay
.
Function
(
basic_args
,
tt2
,
v4
,
basic_tps
)
mapped
=
relay
.
Function
(
basic_args
,
tt2
,
basic_args
[
0
]
,
basic_tps
)
assert
alpha_equal
(
func
,
mapped
)
assert
alpha_equal
(
func
,
mapped
)
fewer_params
=
relay
.
Function
([
relay
.
Param
(
v4
,
tt2
)],
tt2
,
v4
,
basic_tps
)
fewer_params
=
relay
.
Function
([
relay
.
Var
(
"v4"
,
tt2
)],
tt2
,
v4
,
basic_tps
)
assert
not
alpha_equal
(
func
,
fewer_params
)
assert
not
alpha_equal
(
func
,
fewer_params
)
more_params
=
relay
.
Function
([
relay
.
Param
(
v3
,
tt1
),
relay
.
Param
(
v4
,
tt2
),
more_params
=
relay
.
Function
([
relay
.
Var
(
"v3"
,
tt1
),
relay
.
Param
(
v2
,
tt2
)],
tt2
,
v4
,
basic_tps
)
relay
.
Var
(
"v4"
,
tt2
),
relay
.
Var
(
"v2"
,
tt2
)],
tt2
,
v4
,
basic_tps
)
assert
not
alpha_equal
(
func
,
more_params
)
assert
not
alpha_equal
(
func
,
more_params
)
params_unordered
=
relay
.
Function
([
relay
.
Param
(
v3
,
tt2
),
params_unordered
=
relay
.
Function
([
v2
,
v1
],
relay
.
Param
(
v4
,
tt1
)],
tt2
,
v1
,
basic_tps
)
tt1
,
v3
,
basic_tps
)
assert
not
alpha_equal
(
func
,
params_unordered
)
assert
not
alpha_equal
(
func
,
params_unordered
)
params_mismatch
=
relay
.
Function
([
relay
.
Param
(
v3
,
tt3
),
params_mismatch
=
relay
.
Function
([
v1
,
v3
],
relay
.
Param
(
v4
,
tt2
)],
tt2
,
v1
,
basic_tps
)
tt2
,
v4
,
basic_tps
)
assert
not
alpha_equal
(
func
,
params_mismatch
)
assert
not
alpha_equal
(
func
,
params_mismatch
)
# also would not typecheck
# also would not typecheck
...
@@ -376,7 +360,10 @@ def test_call_alpha_equal():
...
@@ -376,7 +360,10 @@ def test_call_alpha_equal():
def
test_let_alpha_equal
():
def
test_let_alpha_equal
():
tt1
=
relay
.
TensorType
((),
"float32"
)
tt2
=
relay
.
TensorType
((),
"int8"
)
v1
=
relay
.
Var
(
"v1"
)
v1
=
relay
.
Var
(
"v1"
)
v1_wtype
=
relay
.
Var
(
"v1"
,
tt1
)
v2
=
relay
.
Var
(
"v2"
)
v2
=
relay
.
Var
(
"v2"
)
v3
=
relay
.
Var
(
"v3"
)
v3
=
relay
.
Var
(
"v3"
)
...
@@ -394,14 +381,13 @@ def test_let_alpha_equal():
...
@@ -394,14 +381,13 @@ def test_let_alpha_equal():
assert
not
alpha_equal
(
let
,
different_body
)
assert
not
alpha_equal
(
let
,
different_body
)
# specified types must match
# specified types must match
tt1
=
relay
.
TensorType
((),
"float32"
)
tt2
=
relay
.
TensorType
((),
"int8"
)
let_with_type
=
relay
.
Let
(
v1_wtype
,
convert
(
2
),
v1_wtype
)
let_with_type
=
relay
.
Let
(
v1
,
convert
(
2
),
v1
,
tt1
)
same_type
=
relay
.
Let
(
v1_wtype
,
convert
(
2
),
v1_wtype
)
same_type
=
relay
.
Let
(
v1
,
convert
(
2
),
v1
,
tt1
)
assert
alpha_equal
(
let_with_type
,
same_type
)
assert
alpha_equal
(
let_with_type
,
same_type
)
assert
not
alpha_equal
(
let
,
let_with_type
)
assert
not
alpha_equal
(
let
,
let_with_type
)
v2
=
relay
.
Var
(
"v1"
,
tt2
)
different_type
=
relay
.
Let
(
v
1
,
convert
(
2
),
v1
,
tt
2
)
different_type
=
relay
.
Let
(
v
2
,
convert
(
2
),
v
2
)
assert
not
alpha_equal
(
let_with_type
,
different_type
)
assert
not
alpha_equal
(
let_with_type
,
different_type
)
...
@@ -437,16 +423,13 @@ if __name__ == "__main__":
...
@@ -437,16 +423,13 @@ if __name__ == "__main__":
test_tensor_type_alpha_equal
()
test_tensor_type_alpha_equal
()
test_incomplete_type_alpha_equal
()
test_incomplete_type_alpha_equal
()
test_constant_alpha_equal
()
test_constant_alpha_equal
()
test_type_param_alpha_equal
()
test_func_type_alpha_equal
()
test_func_type_alpha_equal
()
test_tuple_type_alpha_equal
()
test_tuple_type_alpha_equal
()
test_type_relation_alpha_equal
()
test_type_relation_alpha_equal
()
test_constant_alpha_equal
()
test_constant_alpha_equal
()
test_var_alpha_equal
()
test_global_var_alpha_equal
()
test_global_var_alpha_equal
()
test_tuple_alpha_equal
()
test_tuple_alpha_equal
()
test_tuple_get_item_alpha_equal
()
test_tuple_get_item_alpha_equal
()
test_param_alpha_equal
()
test_function_alpha_equal
()
test_function_alpha_equal
()
test_call_alpha_equal
()
test_call_alpha_equal
()
test_let_alpha_equal
()
test_let_alpha_equal
()
...
...
tests/python/relay/test_pass_dead_code_elimination.py
View file @
0b4cc050
...
@@ -28,17 +28,17 @@ e = env()
...
@@ -28,17 +28,17 @@ e = env()
def
test_let
():
def
test_let
():
orig
=
relay
.
Let
(
e
.
x
,
e
.
y
,
e
.
z
,
e
.
tt
)
orig
=
relay
.
Let
(
e
.
x
,
e
.
y
,
e
.
z
)
assert
alpha_equal
(
dead_code_elimination
(
orig
),
e
.
z
)
assert
alpha_equal
(
dead_code_elimination
(
orig
),
e
.
z
)
def
test_used_let
():
def
test_used_let
():
orig
=
relay
.
Let
(
e
.
a
,
e
.
b
,
relay
.
Let
(
e
.
c
,
e
.
d
,
e
.
c
,
e
.
tt
),
e
.
tt
)
orig
=
relay
.
Let
(
e
.
a
,
e
.
b
,
relay
.
Let
(
e
.
c
,
e
.
d
,
e
.
c
)
)
assert
alpha_equal
(
dead_code_elimination
(
orig
),
relay
.
Let
(
e
.
c
,
e
.
d
,
e
.
c
,
e
.
tt
))
assert
alpha_equal
(
dead_code_elimination
(
orig
),
relay
.
Let
(
e
.
c
,
e
.
d
,
e
.
c
))
def
test_chain_unused_let
():
def
test_chain_unused_let
():
orig
=
relay
.
Let
(
e
.
a
,
e
.
b
,
relay
.
Let
(
e
.
c
,
e
.
d
,
e
.
e
,
e
.
tt
),
e
.
tt
)
orig
=
relay
.
Let
(
e
.
a
,
e
.
b
,
relay
.
Let
(
e
.
c
,
e
.
d
,
e
.
e
)
)
assert
alpha_equal
(
dead_code_elimination
(
orig
),
e
.
e
)
assert
alpha_equal
(
dead_code_elimination
(
orig
),
e
.
e
)
...
@@ -56,19 +56,17 @@ def test_recursion():
...
@@ -56,19 +56,17 @@ def test_recursion():
f(2, 10000);
f(2, 10000);
"""
"""
f
=
relay
.
Var
(
"f"
)
f
=
relay
.
Var
(
"f"
)
n
=
relay
.
Var
(
"n"
)
n
=
relay
.
Var
(
"n"
,
e
.
int32
)
np
=
relay
.
Param
(
n
,
e
.
int32
)
data
=
relay
.
Var
(
"data"
,
e
.
float32
)
data
=
relay
.
Var
(
"data"
)
datap
=
relay
.
Param
(
data
,
e
.
float32
)
funcbody
=
relay
.
If
(
equal
(
n
,
convert
(
0
)),
data
,
f
(
subtract
(
n
,
convert
(
1.0
)),
log
(
data
)))
funcbody
=
relay
.
If
(
equal
(
n
,
convert
(
0
)),
data
,
f
(
subtract
(
n
,
convert
(
1.0
)),
log
(
data
)))
value
=
relay
.
Function
([
n
p
,
datap
],
e
.
float32
,
funcbody
,
[])
value
=
relay
.
Function
([
n
,
data
],
e
.
float32
,
funcbody
,
[])
orig
=
relay
.
Let
(
f
,
funcbody
,
f
(
convert
(
2.0
),
convert
(
10000.0
))
,
e
.
float32
)
orig
=
relay
.
Let
(
f
,
funcbody
,
f
(
convert
(
2.0
),
convert
(
10000.0
)))
assert
alpha_equal
(
dead_code_elimination
(
orig
),
orig
)
assert
alpha_equal
(
dead_code_elimination
(
orig
),
orig
)
assert
alpha_equal
(
dead_code_elimination
(
relay
.
Let
(
f
,
funcbody
,
e
.
three
,
e
.
float32
)),
e
.
three
)
assert
alpha_equal
(
dead_code_elimination
(
relay
.
Let
(
f
,
funcbody
,
e
.
three
)),
e
.
three
)
def
test_op_let
():
def
test_op_let
():
assert
alpha_equal
(
dead_code_elimination
(
add
(
relay
.
Let
(
e
.
a
,
e
.
one
,
e
.
three
,
e
.
float32
),
e
.
two
)),
add
(
e
.
three
,
e
.
two
))
assert
alpha_equal
(
dead_code_elimination
(
add
(
relay
.
Let
(
e
.
a
,
e
.
one
,
e
.
three
),
e
.
two
)),
add
(
e
.
three
,
e
.
two
))
def
test_if
():
def
test_if
():
...
@@ -80,7 +78,7 @@ def test_tuple_get_item():
...
@@ -80,7 +78,7 @@ def test_tuple_get_item():
t
=
relay
.
Var
(
't'
)
t
=
relay
.
Var
(
't'
)
g
=
relay
.
TupleGetItem
(
t
,
0
)
g
=
relay
.
TupleGetItem
(
t
,
0
)
assert
alpha_equal
(
dead_code_elimination
(
g
),
g
)
assert
alpha_equal
(
dead_code_elimination
(
g
),
g
)
assert
alpha_equal
(
dead_code_elimination
(
relay
.
TupleGetItem
(
relay
.
Let
(
e
.
a
,
e
.
one
,
t
,
e
.
float32
),
0
)),
g
)
assert
alpha_equal
(
dead_code_elimination
(
relay
.
TupleGetItem
(
relay
.
Let
(
e
.
a
,
e
.
one
,
t
),
0
)),
g
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
...
tests/python/relay/test_pass_free_vars.py
View file @
0b4cc050
...
@@ -3,16 +3,17 @@ from tvm import relay
...
@@ -3,16 +3,17 @@ from tvm import relay
from
tvm.relay.ir_pass
import
free_vars
,
free_type_vars
from
tvm.relay.ir_pass
import
free_vars
,
free_type_vars
def
test_free_vars
():
def
test_free_vars
():
x
=
relay
.
Var
(
"x"
)
ty
=
relay
.
TensorType
([],
"int32"
)
x
=
relay
.
Var
(
"x"
,
ty
)
fvx
=
free_vars
(
x
)
fvx
=
free_vars
(
x
)
assert
len
(
fvx
)
==
1
assert
len
(
fvx
)
==
1
assert
fvx
[
0
]
==
x
assert
fvx
[
0
]
==
x
v
=
relay
.
Constant
(
tvm
.
nd
.
array
(
10
))
v
=
relay
.
Constant
(
tvm
.
nd
.
array
(
10
))
ty
=
relay
.
TensorType
([],
"int32"
)
let
=
relay
.
Let
(
x
,
v
,
x
,
ty
)
let
=
relay
.
Let
(
x
,
v
,
x
)
fvx
=
free_vars
(
let
)
fvx
=
free_vars
(
let
)
assert
len
(
free_vars
(
let
))
==
0
assert
len
(
free_vars
(
let
))
==
0
f
=
relay
.
Function
([
relay
.
Param
(
x
,
ty
)
],
ty
,
x
)
f
=
relay
.
Function
([
x
],
ty
,
x
)
assert
len
(
free_vars
(
f
))
==
0
assert
len
(
free_vars
(
f
))
==
0
...
@@ -29,9 +30,9 @@ def test_tuple():
...
@@ -29,9 +30,9 @@ def test_tuple():
def
test_free_type_vars
():
def
test_free_type_vars
():
tp
=
relay
.
TypeParam
(
""
)
tp
=
relay
.
TypeParam
(
""
)
ty
=
relay
.
TupleType
([
tp
,
relay
.
TensorType
([],
"int32"
)])
ty
=
relay
.
TupleType
([
tp
,
relay
.
TensorType
([],
"int32"
)])
x
=
relay
.
Var
(
"x"
)
x
=
relay
.
Var
(
"x"
,
ty
)
y
=
relay
.
Var
(
"y"
)
y
=
relay
.
Var
(
"y"
)
let
=
relay
.
Let
(
x
,
y
,
x
,
ty
)
let
=
relay
.
Let
(
x
,
y
,
x
)
fvl
=
free_vars
(
let
)
fvl
=
free_vars
(
let
)
assert
len
(
fvl
)
==
1
assert
len
(
fvl
)
==
1
assert
fvl
[
0
]
==
y
assert
fvl
[
0
]
==
y
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment