Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
6616355d
Commit
6616355d
authored
Oct 10, 2018
by
雾雨魔理沙
Committed by
Tianqi Chen
Oct 10, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[Relay] GetItem (#1861)
parent
4e309e67
Hide whitespace changes
Inline
Side-by-side
Showing
17 changed files
with
177 additions
and
13 deletions
+177
-13
include/tvm/relay/expr.h
+22
-2
include/tvm/relay/expr_functor.h
+4
-0
python/tvm/relay/__init__.py
+1
-0
python/tvm/relay/expr.py
+8
-0
src/relay/ir/debug_printer.cc
+4
-2
src/relay/ir/expr.cc
+16
-0
src/relay/ir/expr_functor.cc
+13
-2
src/relay/pass/alpha_eq.cc
+9
-0
src/relay/pass/type_functor.h
+4
-5
src/relay/pass/type_infer.cc
+18
-0
tests/python/relay/test_ir_debug_printer.py
+6
-1
tests/python/relay/test_ir_nodes.py
+8
-0
tests/python/relay/test_ir_well_formed.py
+17
-1
tests/python/relay/test_pass_alpha_equal.py
+8
-0
tests/python/relay/test_pass_dead_code_elimination.py
+16
-0
tests/python/relay/test_pass_free_vars.py
+11
-0
tests/python/relay/test_type_infer.py
+12
-0
No files found.
include/tvm/relay/expr.h
View file @
6616355d
...
...
@@ -360,8 +360,6 @@ class IfNode : public ExprNode {
/*! \brief The expression evaluated when condition is false */
Expr
false_branch
;
IfNode
()
{}
void
VisitAttrs
(
tvm
::
AttrVisitor
*
v
)
final
{
v
->
Visit
(
"cond"
,
&
cond
);
v
->
Visit
(
"true_branch"
,
&
true_branch
);
...
...
@@ -378,6 +376,28 @@ class IfNode : public ExprNode {
RELAY_DEFINE_NODE_REF
(
If
,
IfNode
,
Expr
);
/*! \brief Get a field out of a tuple. */
class
TupleGetItem
;
class
TupleGetItemNode
:
public
ExprNode
{
public
:
/*! \brief The tuple */
Expr
tuple
;
/*! \brief which value to get */
int
index
;
void
VisitAttrs
(
tvm
::
AttrVisitor
*
v
)
final
{
v
->
Visit
(
"tuple"
,
&
tuple
);
v
->
Visit
(
"index"
,
&
index
);
}
TVM_DLL
static
TupleGetItem
make
(
Expr
tuple
,
int
index
);
static
constexpr
const
char
*
_type_key
=
"relay.GetItem"
;
TVM_DECLARE_NODE_TYPE_INFO
(
TupleGetItemNode
,
ExprNode
);
};
RELAY_DEFINE_NODE_REF
(
TupleGetItem
,
TupleGetItemNode
,
Expr
);
/*! \brief Print a debug representation of the expression to the stream.
* \param env The environment.
* \param e The expression
...
...
include/tvm/relay/expr_functor.h
View file @
6616355d
...
...
@@ -89,6 +89,7 @@ class ExprFunctor<R(const Expr& n, Args...)> {
Args
...
args
)
EXPR_FUNCTOR_DEFAULT
;
virtual
R
VisitExpr_
(
const
OpNode
*
op
,
Args
...
args
)
EXPR_FUNCTOR_DEFAULT
;
virtual
R
VisitExpr_
(
const
TupleGetItemNode
*
op
,
Args
...
args
)
EXPR_FUNCTOR_DEFAULT
;
virtual
R
VisitExprDefault_
(
const
Node
*
op
,
Args
...)
{
throw
Error
(
std
::
string
(
"Do not have a default for "
)
+
op
->
type_key
());
}
...
...
@@ -108,6 +109,7 @@ class ExprFunctor<R(const Expr& n, Args...)> {
RELAY_EXPR_FUNCTOR_DISPATCH
(
LetNode
);
RELAY_EXPR_FUNCTOR_DISPATCH
(
IfNode
);
RELAY_EXPR_FUNCTOR_DISPATCH
(
OpNode
);
RELAY_EXPR_FUNCTOR_DISPATCH
(
TupleGetItemNode
);
return
vtable
;
}
};
...
...
@@ -131,6 +133,7 @@ class ExprVisitor : public ::tvm::relay::ExprFunctor<void(const Expr& n)> {
void
VisitExpr_
(
const
LetNode
*
op
)
override
;
void
VisitExpr_
(
const
IfNode
*
op
)
override
;
void
VisitExpr_
(
const
OpNode
*
op
)
override
;
void
VisitExpr_
(
const
TupleGetItemNode
*
op
)
override
;
virtual
void
VisitType
(
const
Type
&
t
);
};
...
...
@@ -153,6 +156,7 @@ class ExprMutator
Expr
VisitExpr_
(
const
CallNode
*
call_node
)
override
;
Expr
VisitExpr_
(
const
LetNode
*
op
)
override
;
Expr
VisitExpr_
(
const
IfNode
*
op
)
override
;
Expr
VisitExpr_
(
const
TupleGetItemNode
*
op
)
override
;
/*! \brief Used to visit the types inside of expressions.
*
* Can be overloaded to transform the types in arbitrary
...
...
python/tvm/relay/__init__.py
View file @
6616355d
...
...
@@ -39,3 +39,4 @@ Function = expr.Function
Call
=
expr
.
Call
Let
=
expr
.
Let
If
=
expr
.
If
TupleGetItem
=
expr
.
TupleGetItem
python/tvm/relay/expr.py
View file @
6616355d
...
...
@@ -125,4 +125,12 @@ class If(Expr):
self
.
__init_handle_by_constructor__
(
_make
.
If
,
cond
,
true_value
,
false_value
)
@register_relay_node
class
TupleGetItem
(
Expr
):
"""An expression that get field from tuple in Relay, see tvm/relay/expr.h for more details."""
def
__init__
(
self
,
tuple_
,
index
):
self
.
__init_handle_by_constructor__
(
_make
.
TupleGetItem
,
tuple_
,
index
)
debug_print
=
_expr
.
_debug_print
src/relay/ir/debug_printer.cc
View file @
6616355d
...
...
@@ -223,7 +223,6 @@ class ExprDocifier : private ExprFunctor<Doc(const Expr& n)> {
}
Doc
VisitExpr_
(
const
CallNode
*
c
)
final
{
auto
args
=
DocifyExprArray
(
c
->
args
);
return
Docify
(
c
->
op
)
+
Seq
(
"<"
,
DocifyExprArray
(
c
->
args
),
">"
);
}
...
...
@@ -244,6 +243,10 @@ class ExprDocifier : private ExprFunctor<Doc(const Expr& n)> {
return
DocOfStr
(
o
->
name
);
}
Doc
VisitExpr_
(
const
TupleGetItemNode
*
g
)
final
{
return
Docify
(
g
->
tuple
)
+
DocOfStr
(
std
::
string
(
"."
)
+
std
::
to_string
(
g
->
index
));
}
public
:
ExprDocifier
(
const
Environment
&
env
)
:
env
(
env
),
td
(
env
)
{
}
...
...
@@ -291,7 +294,6 @@ std::string PrintType(const Environment& env, const Type& t) {
TVM_REGISTER_API
(
"relay._expr._debug_print"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
NodeRef
x
=
args
[
1
];
std
::
cout
<<
x
<<
std
::
endl
;
if
(
x
.
as
<
TypeNode
>
())
{
*
ret
=
PrintType
(
args
[
0
],
Downcast
<
Type
>
(
x
));
}
else
{
...
...
src/relay/ir/expr.cc
View file @
6616355d
...
...
@@ -193,5 +193,21 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
<<
", "
<<
node
->
false_branch
<<
")"
;
});
TupleGetItem
TupleGetItemNode
::
make
(
Expr
tuple
,
int
index
)
{
NodePtr
<
TupleGetItemNode
>
n
=
make_node
<
TupleGetItemNode
>
();
n
->
tuple
=
std
::
move
(
tuple
);
n
->
index
=
index
;
return
TupleGetItem
(
n
);
}
TVM_REGISTER_API
(
"relay._make.TupleGetItem"
).
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
TupleGetItemNode
::
make
(
args
[
0
],
args
[
1
]);
});
TVM_STATIC_IR_FUNCTOR_REGISTER
(
IRPrinter
,
vtable
)
.
set_dispatch
<
TupleGetItemNode
>
([](
const
TupleGetItemNode
*
node
,
tvm
::
IRPrinter
*
p
)
{
p
->
stream
<<
"TupleGetItemNode("
<<
node
->
tuple
<<
", "
<<
node
->
index
<<
")"
;
});
}
// namespace relay
}
// namespace tvm
src/relay/ir/expr_functor.cc
View file @
6616355d
...
...
@@ -150,10 +150,17 @@ Expr ExprMutator::VisitExpr_(const IfNode* op) {
}
}
Type
ExprMutator
::
VisitType
(
const
Type
&
t
)
{
return
t
;
Expr
ExprMutator
::
VisitExpr_
(
const
TupleGetItemNode
*
g
)
{
auto
t
=
this
->
Mutate
(
g
->
tuple
);
if
(
g
->
tuple
==
t
)
{
return
GetRef
<
Expr
>
(
g
);
}
else
{
return
TupleGetItemNode
::
make
(
t
,
g
->
index
);
}
}
Type
ExprMutator
::
VisitType
(
const
Type
&
t
)
{
return
t
;
}
void
ExprVisitor
::
ExprVisitor
::
VisitExpr_
(
const
VarNode
*
op
)
{
}
...
...
@@ -206,6 +213,10 @@ void ExprVisitor::VisitExpr_(const IfNode* op) {
void
ExprVisitor
::
VisitExpr_
(
const
OpNode
*
op
)
{
return
;
}
void
ExprVisitor
::
VisitExpr_
(
const
TupleGetItemNode
*
op
)
{
this
->
VisitExpr
(
op
->
tuple
);
}
void
ExprVisitor
::
VisitType
(
const
Type
&
t
)
{
return
;
}
}
// namespace relay
...
...
src/relay/pass/alpha_eq.cc
View file @
6616355d
...
...
@@ -335,6 +335,15 @@ struct AlphaEq : ExprFunctor<void(const Expr&, const Expr&)> {
equal
=
false
;
}
}
void
VisitExpr_
(
const
TupleGetItemNode
*
op
,
const
Expr
&
e2
)
final
{
if
(
const
TupleGetItemNode
*
proj
=
e2
.
as
<
TupleGetItemNode
>
())
{
this
->
VisitExpr
(
op
->
tuple
,
proj
->
tuple
);
equal
=
equal
&&
(
op
->
index
==
proj
->
index
);
}
else
{
equal
=
false
;
}
}
};
bool
AlphaEqual
(
const
Expr
&
e1
,
const
Expr
&
e2
)
{
...
...
src/relay/pass/type_functor.h
View file @
6616355d
...
...
@@ -8,7 +8,6 @@
#include <tvm/node/ir_functor.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/error.h>
#include <string>
namespace
tvm
{
...
...
@@ -21,11 +20,11 @@ class TypeFunctor;
#define TYPE_FUNCTOR_DEFAULT \
{ return VisitTypeDefault_(op, std::forward<Args>(args)...); }
#define RELAY_TYPE_FUNCTOR_DISPATCH(OP) \
vtable.template set_dispatch<OP>( \
[](const NodeRef& n, TSelf* self, Args... args) { \
#define RELAY_TYPE_FUNCTOR_DISPATCH(OP)
\
vtable.template set_dispatch<OP>(
\
[](const NodeRef& n, TSelf* self, Args... args) {
\
return self->VisitType_(static_cast<const OP*>(n.node_.get()), \
std::forward<Args>(args)...); \
std::forward<Args>(args)...);
\
});
template
<
typename
R
,
typename
...
Args
>
...
...
src/relay/pass/type_infer.cc
View file @
6616355d
...
...
@@ -119,6 +119,20 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
return
TupleTypeNode
::
make
(
fields
);
}
Type
VisitExpr_
(
const
TupleGetItemNode
*
op
)
final
{
// TODO(M.K.)
// handle case where field type is not known
Type
tuple_type
=
GetType
(
op
->
tuple
);
auto
tuple_ty_node
=
tuple_type
.
as
<
TupleTypeNode
>
();
if
(
!
tuple_ty_node
)
{
LOG
(
FATAL
)
<<
"only expressions with tuple types is accepted"
<<
GetRef
<
TupleGetItem
>
(
op
);
}
if
(
static_cast
<
int
>
(
tuple_ty_node
->
fields
.
size
())
<=
op
->
index
)
{
LOG
(
FATAL
)
<<
"tuple not big enough"
<<
GetRef
<
TupleGetItem
>
(
op
);
}
return
tuple_ty_node
->
fields
[
op
->
index
];
}
Type
VisitExpr_
(
const
OpNode
*
op
)
final
{
return
op
->
op_type
;
}
...
...
@@ -293,6 +307,10 @@ class TypeInferencer::Resolver : public ExprMutator {
return
AttachCheckedType
(
op
);
}
Expr
VisitExpr_
(
const
TupleGetItemNode
*
op
)
final
{
return
AttachCheckedType
(
op
);
}
Expr
VisitExpr_
(
const
ParamNode
*
op
)
final
{
return
ExprMutator
::
VisitExpr_
(
op
);
}
...
...
tests/python/relay/test_ir_debug_printer.py
View file @
6616355d
...
...
@@ -77,7 +77,7 @@ def test_call():
def
test_let
():
lv
=
relay
.
Var
(
'x'
)
ty
=
relay
.
ty
.
TensorType
((
10
,
20
),
"float32"
)
ty
=
relay
.
ty
.
TensorType
((
10
,
20
),
'float32'
)
arr
=
tvm
.
nd
.
array
(
10
)
value
=
relay
.
Constant
(
arr
)
let
=
relay
.
Let
(
lv
,
value
,
lv
,
ty
)
...
...
@@ -90,3 +90,8 @@ def test_if():
right
=
relay
.
Var
(
'right'
)
ife
=
relay
.
If
(
cond
,
left
,
right
)
show
(
ife
)
def
test_tuple_get_item
():
t
=
relay
.
Var
(
't'
)
g
=
relay
.
TupleGetItem
(
t
,
0
)
show
(
g
)
tests/python/relay/test_ir_nodes.py
View file @
6616355d
...
...
@@ -175,6 +175,13 @@ def test_if():
str
(
ife
)
def
test_tuple_get_item
():
tup
=
relay
.
Var
(
"tuple"
)
get
=
relay
.
TupleGetItem
(
tup
,
1
)
assert
get
.
tuple
==
tup
assert
get
.
index
==
1
str
(
get
)
if
__name__
==
"__main__"
:
test_bad_constructor
()
test_span
()
...
...
@@ -192,3 +199,4 @@ if __name__ == "__main__":
test_call
()
test_let
()
test_if
()
test_tuple_get_item
()
tests/python/relay/test_ir_well_formed.py
View file @
6616355d
...
...
@@ -3,7 +3,7 @@ from tvm import relay
from
tvm.relay.ir_pass
import
well_formed
def
test_well_formed
():
x
=
relay
.
Var
(
"x"
)
x
=
relay
.
Var
(
'x'
)
assert
well_formed
(
x
)
v
=
relay
.
Constant
(
tvm
.
nd
.
array
(
10
))
ty
=
None
...
...
@@ -16,3 +16,19 @@ def test_well_formed():
# but we want all binder to be distinct from each other.
assert
not
well_formed
(
relay
.
Let
(
relay
.
Var
(
"y"
),
f
,
relay
.
Let
(
relay
.
Var
(
"z"
),
f
,
v
,
ty
),
ty
))
def
test_tuple
():
x
=
relay
.
Var
(
'x'
)
assert
well_formed
(
x
)
v
=
relay
.
Constant
(
tvm
.
nd
.
array
(
10
))
ty
=
None
let
=
relay
.
Let
(
x
,
v
,
x
,
ty
)
assert
well_formed
(
let
)
assert
well_formed
(
relay
.
Tuple
([
v
,
v
]))
assert
not
well_formed
(
relay
.
Tuple
([
let
,
let
]))
def
test_tuple_get_item
():
t
=
relay
.
Var
(
't'
)
assert
well_formed
(
relay
.
TupleGetItem
(
t
,
2
))
tests/python/relay/test_pass_alpha_equal.py
View file @
6616355d
...
...
@@ -167,11 +167,19 @@ def test_type_relation_alpha_equal():
assert
bigger
!=
diff_num_inputs
def
test_tuple_get_item_alpha_equal
():
x
=
relay
.
Var
(
'x'
)
y
=
relay
.
Var
(
'y'
)
assert
not
alpha_equal
(
relay
.
TupleGetItem
(
x
,
1
),
relay
.
TupleGetItem
(
y
,
1
))
assert
not
alpha_equal
(
relay
.
TupleGetItem
(
x
,
1
),
relay
.
TupleGetItem
(
x
,
2
))
assert
alpha_equal
(
relay
.
TupleGetItem
(
x
,
1
),
relay
.
TupleGetItem
(
x
,
1
))
if
__name__
==
"__main__"
:
test_tensor_type_alpha_equal
()
test_incomplete_type_alpha_equal
()
test_constant_alpha_equal
()
test_type_param_alpha_equal
()
test_func_type_alpha_equal
()
test_tuple_type_alpha_equal
()
test_type_relation_alpha_equal
()
test_tuple_get_item_alpha_equal
()
tests/python/relay/test_pass_dead_code_elimination.py
View file @
6616355d
...
...
@@ -4,6 +4,7 @@ from tvm.relay.ir_pass import dead_code_elimination, alpha_equal
from
tvm.relay.ir_builder
import
convert
,
IRBuilder
from
tvm.relay.op
import
log
,
add
,
equal
,
subtract
class
env
:
def
__init__
(
self
):
self
.
a
=
relay
.
Var
(
"a"
)
...
...
@@ -22,20 +23,25 @@ class env:
self
.
two
=
convert
(
2.0
)
self
.
three
=
convert
(
3.0
)
e
=
env
()
def
test_let
():
orig
=
relay
.
Let
(
e
.
x
,
e
.
y
,
e
.
z
,
e
.
tt
)
assert
alpha_equal
(
dead_code_elimination
(
orig
),
e
.
z
)
def
test_used_let
():
orig
=
relay
.
Let
(
e
.
a
,
e
.
b
,
relay
.
Let
(
e
.
c
,
e
.
d
,
e
.
c
,
e
.
tt
),
e
.
tt
)
assert
alpha_equal
(
dead_code_elimination
(
orig
),
relay
.
Let
(
e
.
c
,
e
.
d
,
e
.
c
,
e
.
tt
))
def
test_chain_unused_let
():
orig
=
relay
.
Let
(
e
.
a
,
e
.
b
,
relay
.
Let
(
e
.
c
,
e
.
d
,
e
.
e
,
e
.
tt
),
e
.
tt
)
assert
alpha_equal
(
dead_code_elimination
(
orig
),
e
.
e
)
# make sure we dont infinite loop
def
test_recursion
():
"""
...
...
@@ -60,14 +66,23 @@ def test_recursion():
assert
alpha_equal
(
dead_code_elimination
(
orig
),
orig
)
assert
alpha_equal
(
dead_code_elimination
(
relay
.
Let
(
f
,
funcbody
,
e
.
three
,
e
.
float32
)),
e
.
three
)
def
test_op_let
():
assert
alpha_equal
(
dead_code_elimination
(
add
(
relay
.
Let
(
e
.
a
,
e
.
one
,
e
.
three
,
e
.
float32
),
e
.
two
)),
add
(
e
.
three
,
e
.
two
))
def
test_if
():
orig
=
relay
.
If
(
convert
(
True
),
e
.
a
,
e
.
b
)
assert
alpha_equal
(
dead_code_elimination
(
orig
),
e
.
a
)
def
test_tuple_get_item
():
t
=
relay
.
Var
(
't'
)
g
=
relay
.
TupleGetItem
(
t
,
0
)
assert
alpha_equal
(
dead_code_elimination
(
g
),
g
)
assert
alpha_equal
(
dead_code_elimination
(
relay
.
TupleGetItem
(
relay
.
Let
(
e
.
a
,
e
.
one
,
t
,
e
.
float32
),
0
)),
g
)
if
__name__
==
"__main__"
:
test_let
()
test_used_let
()
...
...
@@ -75,3 +90,4 @@ if __name__ == "__main__":
test_recursion
()
test_op_let
()
test_if
()
test_tuple_get_item
()
tests/python/relay/test_pass_free_vars.py
View file @
6616355d
...
...
@@ -15,6 +15,17 @@ def test_free_vars():
f
=
relay
.
Function
([
relay
.
Param
(
x
,
ty
)],
ty
,
x
)
assert
len
(
free_vars
(
f
))
==
0
def
test_tuple
():
t
=
relay
.
Var
(
't'
)
fv
=
free_vars
(
relay
.
Tuple
([
t
,
t
]))
assert
len
(
fv
)
==
1
assert
fv
[
0
]
==
t
fv
=
free_vars
(
relay
.
TupleGetItem
(
t
,
123
))
assert
len
(
fv
)
==
1
assert
fv
[
0
]
==
t
def
test_free_type_vars
():
tp
=
relay
.
TypeParam
(
""
)
ty
=
relay
.
TupleType
([
tp
,
relay
.
TensorType
([],
"int32"
)])
...
...
tests/python/relay/test_type_infer.py
View file @
6616355d
...
...
@@ -9,6 +9,7 @@ from tvm.relay.ir_builder import scalar_type, convert, tensor_type
from
tvm.relay.env
import
Environment
from
tvm.relay.op
import
log
,
add
,
equal
,
subtract
,
concatenate
from
tvm.relay.expr
import
Function
from
tvm
import
relay
def
assert_has_type
(
expr
,
typ
,
env
=
Environment
({})):
checked_expr
=
infer_type
(
env
,
expr
)
...
...
@@ -110,6 +111,16 @@ def test_concat():
fn_ty
=
func_type
([
tensor_type
(
3
,
2
),
tensor_type
(
2
,
2
)],
tensor_type
(
5
,
2
))
assert_decl_has_type
(
ib
.
env
,
try_concat2
,
fn_ty
)
def
test_tuple
():
ib
=
IRBuilder
()
dup
=
ib
.
global_var
(
'dup'
)
x
=
ib
.
param
(
'x'
)
with
ib
.
decl
(
dup
,
x
):
ib
.
ret
(
relay
.
Tuple
([
x
,
x
]))
# todo: why is this not generalized?
fn_ty
=
func_type
([
tensor_type
()],
relay
.
TupleType
([
tensor_type
(),
tensor_type
()]))
assert_decl_has_type
(
ib
.
env
,
dup
,
fn_ty
)
if
__name__
==
"__main__"
:
test_dual_op
()
test_recursion
()
...
...
@@ -117,3 +128,4 @@ if __name__ == "__main__":
test_decl
()
test_recursion
()
test_concat
()
test_tuple
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment