Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
6616355d
Commit
6616355d
authored
Oct 10, 2018
by
雾雨魔理沙
Committed by
Tianqi Chen
Oct 10, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[Relay] GetItem (#1861)
parent
4e309e67
Hide whitespace changes
Inline
Side-by-side
Showing
17 changed files
with
177 additions
and
13 deletions
+177
-13
include/tvm/relay/expr.h
+22
-2
include/tvm/relay/expr_functor.h
+4
-0
python/tvm/relay/__init__.py
+1
-0
python/tvm/relay/expr.py
+8
-0
src/relay/ir/debug_printer.cc
+4
-2
src/relay/ir/expr.cc
+16
-0
src/relay/ir/expr_functor.cc
+13
-2
src/relay/pass/alpha_eq.cc
+9
-0
src/relay/pass/type_functor.h
+4
-5
src/relay/pass/type_infer.cc
+18
-0
tests/python/relay/test_ir_debug_printer.py
+6
-1
tests/python/relay/test_ir_nodes.py
+8
-0
tests/python/relay/test_ir_well_formed.py
+17
-1
tests/python/relay/test_pass_alpha_equal.py
+8
-0
tests/python/relay/test_pass_dead_code_elimination.py
+16
-0
tests/python/relay/test_pass_free_vars.py
+11
-0
tests/python/relay/test_type_infer.py
+12
-0
No files found.
include/tvm/relay/expr.h
View file @
6616355d
...
@@ -360,8 +360,6 @@ class IfNode : public ExprNode {
...
@@ -360,8 +360,6 @@ class IfNode : public ExprNode {
/*! \brief The expression evaluated when condition is false */
/*! \brief The expression evaluated when condition is false */
Expr
false_branch
;
Expr
false_branch
;
IfNode
()
{}
void
VisitAttrs
(
tvm
::
AttrVisitor
*
v
)
final
{
void
VisitAttrs
(
tvm
::
AttrVisitor
*
v
)
final
{
v
->
Visit
(
"cond"
,
&
cond
);
v
->
Visit
(
"cond"
,
&
cond
);
v
->
Visit
(
"true_branch"
,
&
true_branch
);
v
->
Visit
(
"true_branch"
,
&
true_branch
);
...
@@ -378,6 +376,28 @@ class IfNode : public ExprNode {
...
@@ -378,6 +376,28 @@ class IfNode : public ExprNode {
RELAY_DEFINE_NODE_REF
(
If
,
IfNode
,
Expr
);
RELAY_DEFINE_NODE_REF
(
If
,
IfNode
,
Expr
);
/*! \brief Get a field out of a tuple. */
class
TupleGetItem
;
class
TupleGetItemNode
:
public
ExprNode
{
public
:
/*! \brief The tuple */
Expr
tuple
;
/*! \brief which value to get */
int
index
;
void
VisitAttrs
(
tvm
::
AttrVisitor
*
v
)
final
{
v
->
Visit
(
"tuple"
,
&
tuple
);
v
->
Visit
(
"index"
,
&
index
);
}
TVM_DLL
static
TupleGetItem
make
(
Expr
tuple
,
int
index
);
static
constexpr
const
char
*
_type_key
=
"relay.GetItem"
;
TVM_DECLARE_NODE_TYPE_INFO
(
TupleGetItemNode
,
ExprNode
);
};
RELAY_DEFINE_NODE_REF
(
TupleGetItem
,
TupleGetItemNode
,
Expr
);
/*! \brief Print a debug representation of the expression to the stream.
/*! \brief Print a debug representation of the expression to the stream.
* \param env The environment.
* \param env The environment.
* \param e The expression
* \param e The expression
...
...
include/tvm/relay/expr_functor.h
View file @
6616355d
...
@@ -89,6 +89,7 @@ class ExprFunctor<R(const Expr& n, Args...)> {
...
@@ -89,6 +89,7 @@ class ExprFunctor<R(const Expr& n, Args...)> {
Args
...
args
)
EXPR_FUNCTOR_DEFAULT
;
Args
...
args
)
EXPR_FUNCTOR_DEFAULT
;
virtual
R
VisitExpr_
(
const
OpNode
*
op
,
virtual
R
VisitExpr_
(
const
OpNode
*
op
,
Args
...
args
)
EXPR_FUNCTOR_DEFAULT
;
Args
...
args
)
EXPR_FUNCTOR_DEFAULT
;
virtual
R
VisitExpr_
(
const
TupleGetItemNode
*
op
,
Args
...
args
)
EXPR_FUNCTOR_DEFAULT
;
virtual
R
VisitExprDefault_
(
const
Node
*
op
,
Args
...)
{
virtual
R
VisitExprDefault_
(
const
Node
*
op
,
Args
...)
{
throw
Error
(
std
::
string
(
"Do not have a default for "
)
+
op
->
type_key
());
throw
Error
(
std
::
string
(
"Do not have a default for "
)
+
op
->
type_key
());
}
}
...
@@ -108,6 +109,7 @@ class ExprFunctor<R(const Expr& n, Args...)> {
...
@@ -108,6 +109,7 @@ class ExprFunctor<R(const Expr& n, Args...)> {
RELAY_EXPR_FUNCTOR_DISPATCH
(
LetNode
);
RELAY_EXPR_FUNCTOR_DISPATCH
(
LetNode
);
RELAY_EXPR_FUNCTOR_DISPATCH
(
IfNode
);
RELAY_EXPR_FUNCTOR_DISPATCH
(
IfNode
);
RELAY_EXPR_FUNCTOR_DISPATCH
(
OpNode
);
RELAY_EXPR_FUNCTOR_DISPATCH
(
OpNode
);
RELAY_EXPR_FUNCTOR_DISPATCH
(
TupleGetItemNode
);
return
vtable
;
return
vtable
;
}
}
};
};
...
@@ -131,6 +133,7 @@ class ExprVisitor : public ::tvm::relay::ExprFunctor<void(const Expr& n)> {
...
@@ -131,6 +133,7 @@ class ExprVisitor : public ::tvm::relay::ExprFunctor<void(const Expr& n)> {
void
VisitExpr_
(
const
LetNode
*
op
)
override
;
void
VisitExpr_
(
const
LetNode
*
op
)
override
;
void
VisitExpr_
(
const
IfNode
*
op
)
override
;
void
VisitExpr_
(
const
IfNode
*
op
)
override
;
void
VisitExpr_
(
const
OpNode
*
op
)
override
;
void
VisitExpr_
(
const
OpNode
*
op
)
override
;
void
VisitExpr_
(
const
TupleGetItemNode
*
op
)
override
;
virtual
void
VisitType
(
const
Type
&
t
);
virtual
void
VisitType
(
const
Type
&
t
);
};
};
...
@@ -153,6 +156,7 @@ class ExprMutator
...
@@ -153,6 +156,7 @@ class ExprMutator
Expr
VisitExpr_
(
const
CallNode
*
call_node
)
override
;
Expr
VisitExpr_
(
const
CallNode
*
call_node
)
override
;
Expr
VisitExpr_
(
const
LetNode
*
op
)
override
;
Expr
VisitExpr_
(
const
LetNode
*
op
)
override
;
Expr
VisitExpr_
(
const
IfNode
*
op
)
override
;
Expr
VisitExpr_
(
const
IfNode
*
op
)
override
;
Expr
VisitExpr_
(
const
TupleGetItemNode
*
op
)
override
;
/*! \brief Used to visit the types inside of expressions.
/*! \brief Used to visit the types inside of expressions.
*
*
* Can be overloaded to transform the types in arbitrary
* Can be overloaded to transform the types in arbitrary
...
...
python/tvm/relay/__init__.py
View file @
6616355d
...
@@ -39,3 +39,4 @@ Function = expr.Function
...
@@ -39,3 +39,4 @@ Function = expr.Function
Call
=
expr
.
Call
Call
=
expr
.
Call
Let
=
expr
.
Let
Let
=
expr
.
Let
If
=
expr
.
If
If
=
expr
.
If
TupleGetItem
=
expr
.
TupleGetItem
python/tvm/relay/expr.py
View file @
6616355d
...
@@ -125,4 +125,12 @@ class If(Expr):
...
@@ -125,4 +125,12 @@ class If(Expr):
self
.
__init_handle_by_constructor__
(
self
.
__init_handle_by_constructor__
(
_make
.
If
,
cond
,
true_value
,
false_value
)
_make
.
If
,
cond
,
true_value
,
false_value
)
@register_relay_node
class
TupleGetItem
(
Expr
):
"""An expression that get field from tuple in Relay, see tvm/relay/expr.h for more details."""
def
__init__
(
self
,
tuple_
,
index
):
self
.
__init_handle_by_constructor__
(
_make
.
TupleGetItem
,
tuple_
,
index
)
debug_print
=
_expr
.
_debug_print
debug_print
=
_expr
.
_debug_print
src/relay/ir/debug_printer.cc
View file @
6616355d
...
@@ -223,7 +223,6 @@ class ExprDocifier : private ExprFunctor<Doc(const Expr& n)> {
...
@@ -223,7 +223,6 @@ class ExprDocifier : private ExprFunctor<Doc(const Expr& n)> {
}
}
Doc
VisitExpr_
(
const
CallNode
*
c
)
final
{
Doc
VisitExpr_
(
const
CallNode
*
c
)
final
{
auto
args
=
DocifyExprArray
(
c
->
args
);
return
Docify
(
c
->
op
)
+
Seq
(
"<"
,
DocifyExprArray
(
c
->
args
),
">"
);
return
Docify
(
c
->
op
)
+
Seq
(
"<"
,
DocifyExprArray
(
c
->
args
),
">"
);
}
}
...
@@ -244,6 +243,10 @@ class ExprDocifier : private ExprFunctor<Doc(const Expr& n)> {
...
@@ -244,6 +243,10 @@ class ExprDocifier : private ExprFunctor<Doc(const Expr& n)> {
return
DocOfStr
(
o
->
name
);
return
DocOfStr
(
o
->
name
);
}
}
Doc
VisitExpr_
(
const
TupleGetItemNode
*
g
)
final
{
return
Docify
(
g
->
tuple
)
+
DocOfStr
(
std
::
string
(
"."
)
+
std
::
to_string
(
g
->
index
));
}
public
:
public
:
ExprDocifier
(
const
Environment
&
env
)
:
env
(
env
),
td
(
env
)
{
}
ExprDocifier
(
const
Environment
&
env
)
:
env
(
env
),
td
(
env
)
{
}
...
@@ -291,7 +294,6 @@ std::string PrintType(const Environment& env, const Type& t) {
...
@@ -291,7 +294,6 @@ std::string PrintType(const Environment& env, const Type& t) {
TVM_REGISTER_API
(
"relay._expr._debug_print"
)
TVM_REGISTER_API
(
"relay._expr._debug_print"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
NodeRef
x
=
args
[
1
];
NodeRef
x
=
args
[
1
];
std
::
cout
<<
x
<<
std
::
endl
;
if
(
x
.
as
<
TypeNode
>
())
{
if
(
x
.
as
<
TypeNode
>
())
{
*
ret
=
PrintType
(
args
[
0
],
Downcast
<
Type
>
(
x
));
*
ret
=
PrintType
(
args
[
0
],
Downcast
<
Type
>
(
x
));
}
else
{
}
else
{
...
...
src/relay/ir/expr.cc
View file @
6616355d
...
@@ -193,5 +193,21 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
...
@@ -193,5 +193,21 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
<<
", "
<<
node
->
false_branch
<<
")"
;
<<
", "
<<
node
->
false_branch
<<
")"
;
});
});
TupleGetItem
TupleGetItemNode
::
make
(
Expr
tuple
,
int
index
)
{
NodePtr
<
TupleGetItemNode
>
n
=
make_node
<
TupleGetItemNode
>
();
n
->
tuple
=
std
::
move
(
tuple
);
n
->
index
=
index
;
return
TupleGetItem
(
n
);
}
TVM_REGISTER_API
(
"relay._make.TupleGetItem"
).
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
TupleGetItemNode
::
make
(
args
[
0
],
args
[
1
]);
});
TVM_STATIC_IR_FUNCTOR_REGISTER
(
IRPrinter
,
vtable
)
.
set_dispatch
<
TupleGetItemNode
>
([](
const
TupleGetItemNode
*
node
,
tvm
::
IRPrinter
*
p
)
{
p
->
stream
<<
"TupleGetItemNode("
<<
node
->
tuple
<<
", "
<<
node
->
index
<<
")"
;
});
}
// namespace relay
}
// namespace relay
}
// namespace tvm
}
// namespace tvm
src/relay/ir/expr_functor.cc
View file @
6616355d
...
@@ -150,10 +150,17 @@ Expr ExprMutator::VisitExpr_(const IfNode* op) {
...
@@ -150,10 +150,17 @@ Expr ExprMutator::VisitExpr_(const IfNode* op) {
}
}
}
}
Type
ExprMutator
::
VisitType
(
const
Type
&
t
)
{
Expr
ExprMutator
::
VisitExpr_
(
const
TupleGetItemNode
*
g
)
{
return
t
;
auto
t
=
this
->
Mutate
(
g
->
tuple
);
if
(
g
->
tuple
==
t
)
{
return
GetRef
<
Expr
>
(
g
);
}
else
{
return
TupleGetItemNode
::
make
(
t
,
g
->
index
);
}
}
}
Type
ExprMutator
::
VisitType
(
const
Type
&
t
)
{
return
t
;
}
void
ExprVisitor
::
ExprVisitor
::
VisitExpr_
(
const
VarNode
*
op
)
{
void
ExprVisitor
::
ExprVisitor
::
VisitExpr_
(
const
VarNode
*
op
)
{
}
}
...
@@ -206,6 +213,10 @@ void ExprVisitor::VisitExpr_(const IfNode* op) {
...
@@ -206,6 +213,10 @@ void ExprVisitor::VisitExpr_(const IfNode* op) {
void
ExprVisitor
::
VisitExpr_
(
const
OpNode
*
op
)
{
return
;
}
void
ExprVisitor
::
VisitExpr_
(
const
OpNode
*
op
)
{
return
;
}
void
ExprVisitor
::
VisitExpr_
(
const
TupleGetItemNode
*
op
)
{
this
->
VisitExpr
(
op
->
tuple
);
}
void
ExprVisitor
::
VisitType
(
const
Type
&
t
)
{
return
;
}
void
ExprVisitor
::
VisitType
(
const
Type
&
t
)
{
return
;
}
}
// namespace relay
}
// namespace relay
...
...
src/relay/pass/alpha_eq.cc
View file @
6616355d
...
@@ -335,6 +335,15 @@ struct AlphaEq : ExprFunctor<void(const Expr&, const Expr&)> {
...
@@ -335,6 +335,15 @@ struct AlphaEq : ExprFunctor<void(const Expr&, const Expr&)> {
equal
=
false
;
equal
=
false
;
}
}
}
}
void
VisitExpr_
(
const
TupleGetItemNode
*
op
,
const
Expr
&
e2
)
final
{
if
(
const
TupleGetItemNode
*
proj
=
e2
.
as
<
TupleGetItemNode
>
())
{
this
->
VisitExpr
(
op
->
tuple
,
proj
->
tuple
);
equal
=
equal
&&
(
op
->
index
==
proj
->
index
);
}
else
{
equal
=
false
;
}
}
};
};
bool
AlphaEqual
(
const
Expr
&
e1
,
const
Expr
&
e2
)
{
bool
AlphaEqual
(
const
Expr
&
e1
,
const
Expr
&
e2
)
{
...
...
src/relay/pass/type_functor.h
View file @
6616355d
...
@@ -8,7 +8,6 @@
...
@@ -8,7 +8,6 @@
#include <tvm/node/ir_functor.h>
#include <tvm/node/ir_functor.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/error.h>
#include <string>
#include <string>
namespace
tvm
{
namespace
tvm
{
...
@@ -21,11 +20,11 @@ class TypeFunctor;
...
@@ -21,11 +20,11 @@ class TypeFunctor;
#define TYPE_FUNCTOR_DEFAULT \
#define TYPE_FUNCTOR_DEFAULT \
{ return VisitTypeDefault_(op, std::forward<Args>(args)...); }
{ return VisitTypeDefault_(op, std::forward<Args>(args)...); }
#define RELAY_TYPE_FUNCTOR_DISPATCH(OP) \
#define RELAY_TYPE_FUNCTOR_DISPATCH(OP)
\
vtable.template set_dispatch<OP>( \
vtable.template set_dispatch<OP>(
\
[](const NodeRef& n, TSelf* self, Args... args) { \
[](const NodeRef& n, TSelf* self, Args... args) {
\
return self->VisitType_(static_cast<const OP*>(n.node_.get()), \
return self->VisitType_(static_cast<const OP*>(n.node_.get()), \
std::forward<Args>(args)...); \
std::forward<Args>(args)...);
\
});
});
template
<
typename
R
,
typename
...
Args
>
template
<
typename
R
,
typename
...
Args
>
...
...
src/relay/pass/type_infer.cc
View file @
6616355d
...
@@ -119,6 +119,20 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
...
@@ -119,6 +119,20 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
return
TupleTypeNode
::
make
(
fields
);
return
TupleTypeNode
::
make
(
fields
);
}
}
Type
VisitExpr_
(
const
TupleGetItemNode
*
op
)
final
{
// TODO(M.K.)
// handle case where field type is not known
Type
tuple_type
=
GetType
(
op
->
tuple
);
auto
tuple_ty_node
=
tuple_type
.
as
<
TupleTypeNode
>
();
if
(
!
tuple_ty_node
)
{
LOG
(
FATAL
)
<<
"only expressions with tuple types is accepted"
<<
GetRef
<
TupleGetItem
>
(
op
);
}
if
(
static_cast
<
int
>
(
tuple_ty_node
->
fields
.
size
())
<=
op
->
index
)
{
LOG
(
FATAL
)
<<
"tuple not big enough"
<<
GetRef
<
TupleGetItem
>
(
op
);
}
return
tuple_ty_node
->
fields
[
op
->
index
];
}
Type
VisitExpr_
(
const
OpNode
*
op
)
final
{
Type
VisitExpr_
(
const
OpNode
*
op
)
final
{
return
op
->
op_type
;
return
op
->
op_type
;
}
}
...
@@ -293,6 +307,10 @@ class TypeInferencer::Resolver : public ExprMutator {
...
@@ -293,6 +307,10 @@ class TypeInferencer::Resolver : public ExprMutator {
return
AttachCheckedType
(
op
);
return
AttachCheckedType
(
op
);
}
}
Expr
VisitExpr_
(
const
TupleGetItemNode
*
op
)
final
{
return
AttachCheckedType
(
op
);
}
Expr
VisitExpr_
(
const
ParamNode
*
op
)
final
{
Expr
VisitExpr_
(
const
ParamNode
*
op
)
final
{
return
ExprMutator
::
VisitExpr_
(
op
);
return
ExprMutator
::
VisitExpr_
(
op
);
}
}
...
...
tests/python/relay/test_ir_debug_printer.py
View file @
6616355d
...
@@ -77,7 +77,7 @@ def test_call():
...
@@ -77,7 +77,7 @@ def test_call():
def
test_let
():
def
test_let
():
lv
=
relay
.
Var
(
'x'
)
lv
=
relay
.
Var
(
'x'
)
ty
=
relay
.
ty
.
TensorType
((
10
,
20
),
"float32"
)
ty
=
relay
.
ty
.
TensorType
((
10
,
20
),
'float32'
)
arr
=
tvm
.
nd
.
array
(
10
)
arr
=
tvm
.
nd
.
array
(
10
)
value
=
relay
.
Constant
(
arr
)
value
=
relay
.
Constant
(
arr
)
let
=
relay
.
Let
(
lv
,
value
,
lv
,
ty
)
let
=
relay
.
Let
(
lv
,
value
,
lv
,
ty
)
...
@@ -90,3 +90,8 @@ def test_if():
...
@@ -90,3 +90,8 @@ def test_if():
right
=
relay
.
Var
(
'right'
)
right
=
relay
.
Var
(
'right'
)
ife
=
relay
.
If
(
cond
,
left
,
right
)
ife
=
relay
.
If
(
cond
,
left
,
right
)
show
(
ife
)
show
(
ife
)
def
test_tuple_get_item
():
t
=
relay
.
Var
(
't'
)
g
=
relay
.
TupleGetItem
(
t
,
0
)
show
(
g
)
tests/python/relay/test_ir_nodes.py
View file @
6616355d
...
@@ -175,6 +175,13 @@ def test_if():
...
@@ -175,6 +175,13 @@ def test_if():
str
(
ife
)
str
(
ife
)
def
test_tuple_get_item
():
tup
=
relay
.
Var
(
"tuple"
)
get
=
relay
.
TupleGetItem
(
tup
,
1
)
assert
get
.
tuple
==
tup
assert
get
.
index
==
1
str
(
get
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
test_bad_constructor
()
test_bad_constructor
()
test_span
()
test_span
()
...
@@ -192,3 +199,4 @@ if __name__ == "__main__":
...
@@ -192,3 +199,4 @@ if __name__ == "__main__":
test_call
()
test_call
()
test_let
()
test_let
()
test_if
()
test_if
()
test_tuple_get_item
()
tests/python/relay/test_ir_well_formed.py
View file @
6616355d
...
@@ -3,7 +3,7 @@ from tvm import relay
...
@@ -3,7 +3,7 @@ from tvm import relay
from
tvm.relay.ir_pass
import
well_formed
from
tvm.relay.ir_pass
import
well_formed
def
test_well_formed
():
def
test_well_formed
():
x
=
relay
.
Var
(
"x"
)
x
=
relay
.
Var
(
'x'
)
assert
well_formed
(
x
)
assert
well_formed
(
x
)
v
=
relay
.
Constant
(
tvm
.
nd
.
array
(
10
))
v
=
relay
.
Constant
(
tvm
.
nd
.
array
(
10
))
ty
=
None
ty
=
None
...
@@ -16,3 +16,19 @@ def test_well_formed():
...
@@ -16,3 +16,19 @@ def test_well_formed():
# but we want all binder to be distinct from each other.
# but we want all binder to be distinct from each other.
assert
not
well_formed
(
relay
.
Let
(
relay
.
Var
(
"y"
),
f
,
assert
not
well_formed
(
relay
.
Let
(
relay
.
Var
(
"y"
),
f
,
relay
.
Let
(
relay
.
Var
(
"z"
),
f
,
v
,
ty
),
ty
))
relay
.
Let
(
relay
.
Var
(
"z"
),
f
,
v
,
ty
),
ty
))
def
test_tuple
():
x
=
relay
.
Var
(
'x'
)
assert
well_formed
(
x
)
v
=
relay
.
Constant
(
tvm
.
nd
.
array
(
10
))
ty
=
None
let
=
relay
.
Let
(
x
,
v
,
x
,
ty
)
assert
well_formed
(
let
)
assert
well_formed
(
relay
.
Tuple
([
v
,
v
]))
assert
not
well_formed
(
relay
.
Tuple
([
let
,
let
]))
def
test_tuple_get_item
():
t
=
relay
.
Var
(
't'
)
assert
well_formed
(
relay
.
TupleGetItem
(
t
,
2
))
tests/python/relay/test_pass_alpha_equal.py
View file @
6616355d
...
@@ -167,11 +167,19 @@ def test_type_relation_alpha_equal():
...
@@ -167,11 +167,19 @@ def test_type_relation_alpha_equal():
assert
bigger
!=
diff_num_inputs
assert
bigger
!=
diff_num_inputs
def
test_tuple_get_item_alpha_equal
():
x
=
relay
.
Var
(
'x'
)
y
=
relay
.
Var
(
'y'
)
assert
not
alpha_equal
(
relay
.
TupleGetItem
(
x
,
1
),
relay
.
TupleGetItem
(
y
,
1
))
assert
not
alpha_equal
(
relay
.
TupleGetItem
(
x
,
1
),
relay
.
TupleGetItem
(
x
,
2
))
assert
alpha_equal
(
relay
.
TupleGetItem
(
x
,
1
),
relay
.
TupleGetItem
(
x
,
1
))
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
test_tensor_type_alpha_equal
()
test_tensor_type_alpha_equal
()
test_incomplete_type_alpha_equal
()
test_incomplete_type_alpha_equal
()
test_constant_alpha_equal
()
test_type_param_alpha_equal
()
test_type_param_alpha_equal
()
test_func_type_alpha_equal
()
test_func_type_alpha_equal
()
test_tuple_type_alpha_equal
()
test_tuple_type_alpha_equal
()
test_type_relation_alpha_equal
()
test_type_relation_alpha_equal
()
test_tuple_get_item_alpha_equal
()
tests/python/relay/test_pass_dead_code_elimination.py
View file @
6616355d
...
@@ -4,6 +4,7 @@ from tvm.relay.ir_pass import dead_code_elimination, alpha_equal
...
@@ -4,6 +4,7 @@ from tvm.relay.ir_pass import dead_code_elimination, alpha_equal
from
tvm.relay.ir_builder
import
convert
,
IRBuilder
from
tvm.relay.ir_builder
import
convert
,
IRBuilder
from
tvm.relay.op
import
log
,
add
,
equal
,
subtract
from
tvm.relay.op
import
log
,
add
,
equal
,
subtract
class
env
:
class
env
:
def
__init__
(
self
):
def
__init__
(
self
):
self
.
a
=
relay
.
Var
(
"a"
)
self
.
a
=
relay
.
Var
(
"a"
)
...
@@ -22,20 +23,25 @@ class env:
...
@@ -22,20 +23,25 @@ class env:
self
.
two
=
convert
(
2.0
)
self
.
two
=
convert
(
2.0
)
self
.
three
=
convert
(
3.0
)
self
.
three
=
convert
(
3.0
)
e
=
env
()
e
=
env
()
def
test_let
():
def
test_let
():
orig
=
relay
.
Let
(
e
.
x
,
e
.
y
,
e
.
z
,
e
.
tt
)
orig
=
relay
.
Let
(
e
.
x
,
e
.
y
,
e
.
z
,
e
.
tt
)
assert
alpha_equal
(
dead_code_elimination
(
orig
),
e
.
z
)
assert
alpha_equal
(
dead_code_elimination
(
orig
),
e
.
z
)
def
test_used_let
():
def
test_used_let
():
orig
=
relay
.
Let
(
e
.
a
,
e
.
b
,
relay
.
Let
(
e
.
c
,
e
.
d
,
e
.
c
,
e
.
tt
),
e
.
tt
)
orig
=
relay
.
Let
(
e
.
a
,
e
.
b
,
relay
.
Let
(
e
.
c
,
e
.
d
,
e
.
c
,
e
.
tt
),
e
.
tt
)
assert
alpha_equal
(
dead_code_elimination
(
orig
),
relay
.
Let
(
e
.
c
,
e
.
d
,
e
.
c
,
e
.
tt
))
assert
alpha_equal
(
dead_code_elimination
(
orig
),
relay
.
Let
(
e
.
c
,
e
.
d
,
e
.
c
,
e
.
tt
))
def
test_chain_unused_let
():
def
test_chain_unused_let
():
orig
=
relay
.
Let
(
e
.
a
,
e
.
b
,
relay
.
Let
(
e
.
c
,
e
.
d
,
e
.
e
,
e
.
tt
),
e
.
tt
)
orig
=
relay
.
Let
(
e
.
a
,
e
.
b
,
relay
.
Let
(
e
.
c
,
e
.
d
,
e
.
e
,
e
.
tt
),
e
.
tt
)
assert
alpha_equal
(
dead_code_elimination
(
orig
),
e
.
e
)
assert
alpha_equal
(
dead_code_elimination
(
orig
),
e
.
e
)
# make sure we dont infinite loop
# make sure we dont infinite loop
def
test_recursion
():
def
test_recursion
():
"""
"""
...
@@ -60,14 +66,23 @@ def test_recursion():
...
@@ -60,14 +66,23 @@ def test_recursion():
assert
alpha_equal
(
dead_code_elimination
(
orig
),
orig
)
assert
alpha_equal
(
dead_code_elimination
(
orig
),
orig
)
assert
alpha_equal
(
dead_code_elimination
(
relay
.
Let
(
f
,
funcbody
,
e
.
three
,
e
.
float32
)),
e
.
three
)
assert
alpha_equal
(
dead_code_elimination
(
relay
.
Let
(
f
,
funcbody
,
e
.
three
,
e
.
float32
)),
e
.
three
)
def
test_op_let
():
def
test_op_let
():
assert
alpha_equal
(
dead_code_elimination
(
add
(
relay
.
Let
(
e
.
a
,
e
.
one
,
e
.
three
,
e
.
float32
),
e
.
two
)),
add
(
e
.
three
,
e
.
two
))
assert
alpha_equal
(
dead_code_elimination
(
add
(
relay
.
Let
(
e
.
a
,
e
.
one
,
e
.
three
,
e
.
float32
),
e
.
two
)),
add
(
e
.
three
,
e
.
two
))
def
test_if
():
def
test_if
():
orig
=
relay
.
If
(
convert
(
True
),
e
.
a
,
e
.
b
)
orig
=
relay
.
If
(
convert
(
True
),
e
.
a
,
e
.
b
)
assert
alpha_equal
(
dead_code_elimination
(
orig
),
e
.
a
)
assert
alpha_equal
(
dead_code_elimination
(
orig
),
e
.
a
)
def
test_tuple_get_item
():
t
=
relay
.
Var
(
't'
)
g
=
relay
.
TupleGetItem
(
t
,
0
)
assert
alpha_equal
(
dead_code_elimination
(
g
),
g
)
assert
alpha_equal
(
dead_code_elimination
(
relay
.
TupleGetItem
(
relay
.
Let
(
e
.
a
,
e
.
one
,
t
,
e
.
float32
),
0
)),
g
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
test_let
()
test_let
()
test_used_let
()
test_used_let
()
...
@@ -75,3 +90,4 @@ if __name__ == "__main__":
...
@@ -75,3 +90,4 @@ if __name__ == "__main__":
test_recursion
()
test_recursion
()
test_op_let
()
test_op_let
()
test_if
()
test_if
()
test_tuple_get_item
()
tests/python/relay/test_pass_free_vars.py
View file @
6616355d
...
@@ -15,6 +15,17 @@ def test_free_vars():
...
@@ -15,6 +15,17 @@ def test_free_vars():
f
=
relay
.
Function
([
relay
.
Param
(
x
,
ty
)],
ty
,
x
)
f
=
relay
.
Function
([
relay
.
Param
(
x
,
ty
)],
ty
,
x
)
assert
len
(
free_vars
(
f
))
==
0
assert
len
(
free_vars
(
f
))
==
0
def
test_tuple
():
t
=
relay
.
Var
(
't'
)
fv
=
free_vars
(
relay
.
Tuple
([
t
,
t
]))
assert
len
(
fv
)
==
1
assert
fv
[
0
]
==
t
fv
=
free_vars
(
relay
.
TupleGetItem
(
t
,
123
))
assert
len
(
fv
)
==
1
assert
fv
[
0
]
==
t
def
test_free_type_vars
():
def
test_free_type_vars
():
tp
=
relay
.
TypeParam
(
""
)
tp
=
relay
.
TypeParam
(
""
)
ty
=
relay
.
TupleType
([
tp
,
relay
.
TensorType
([],
"int32"
)])
ty
=
relay
.
TupleType
([
tp
,
relay
.
TensorType
([],
"int32"
)])
...
...
tests/python/relay/test_type_infer.py
View file @
6616355d
...
@@ -9,6 +9,7 @@ from tvm.relay.ir_builder import scalar_type, convert, tensor_type
...
@@ -9,6 +9,7 @@ from tvm.relay.ir_builder import scalar_type, convert, tensor_type
from
tvm.relay.env
import
Environment
from
tvm.relay.env
import
Environment
from
tvm.relay.op
import
log
,
add
,
equal
,
subtract
,
concatenate
from
tvm.relay.op
import
log
,
add
,
equal
,
subtract
,
concatenate
from
tvm.relay.expr
import
Function
from
tvm.relay.expr
import
Function
from
tvm
import
relay
def
assert_has_type
(
expr
,
typ
,
env
=
Environment
({})):
def
assert_has_type
(
expr
,
typ
,
env
=
Environment
({})):
checked_expr
=
infer_type
(
env
,
expr
)
checked_expr
=
infer_type
(
env
,
expr
)
...
@@ -110,6 +111,16 @@ def test_concat():
...
@@ -110,6 +111,16 @@ def test_concat():
fn_ty
=
func_type
([
tensor_type
(
3
,
2
),
tensor_type
(
2
,
2
)],
tensor_type
(
5
,
2
))
fn_ty
=
func_type
([
tensor_type
(
3
,
2
),
tensor_type
(
2
,
2
)],
tensor_type
(
5
,
2
))
assert_decl_has_type
(
ib
.
env
,
try_concat2
,
fn_ty
)
assert_decl_has_type
(
ib
.
env
,
try_concat2
,
fn_ty
)
def
test_tuple
():
ib
=
IRBuilder
()
dup
=
ib
.
global_var
(
'dup'
)
x
=
ib
.
param
(
'x'
)
with
ib
.
decl
(
dup
,
x
):
ib
.
ret
(
relay
.
Tuple
([
x
,
x
]))
# todo: why is this not generalized?
fn_ty
=
func_type
([
tensor_type
()],
relay
.
TupleType
([
tensor_type
(),
tensor_type
()]))
assert_decl_has_type
(
ib
.
env
,
dup
,
fn_ty
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
test_dual_op
()
test_dual_op
()
test_recursion
()
test_recursion
()
...
@@ -117,3 +128,4 @@ if __name__ == "__main__":
...
@@ -117,3 +128,4 @@ if __name__ == "__main__":
test_decl
()
test_decl
()
test_recursion
()
test_recursion
()
test_concat
()
test_concat
()
test_tuple
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment