Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
d15477cd
Commit
d15477cd
authored
Dec 01, 2018
by
Wuwei Lin
Committed by
Tianqi Chen
Nov 30, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[Relay][Pass] Fold constant tuple (#2201)
parent
e37dbd4e
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
54 additions
and
1 deletions
+54
-1
src/relay/pass/fold_constant.cc
+34
-1
tests/python/relay/test_pass_fold_constant.py
+20
-0
No files found.
src/relay/pass/fold_constant.cc
View file @
d15477cd
...
@@ -13,6 +13,36 @@ namespace relay {
...
@@ -13,6 +13,36 @@ namespace relay {
using
FInterpreter
=
runtime
::
TypedPackedFunc
<
Value
(
Expr
)
>
;
using
FInterpreter
=
runtime
::
TypedPackedFunc
<
Value
(
Expr
)
>
;
class
ConstantChecker
:
private
ExprVisitor
{
public
:
// Check whether an expression is constant. The results are memorized.
bool
Check
(
const
Expr
&
expr
)
{
if
(
expr
.
as
<
ConstantNode
>
())
{
return
true
;
}
const
auto
it
=
memo_
.
find
(
expr
);
if
(
it
!=
memo_
.
end
())
return
it
->
second
;
VisitExpr
(
expr
);
return
memo_
[
expr
];
// return memorized result or the default value false
}
private
:
std
::
unordered_map
<
Expr
,
bool
,
NodeHash
,
NodeEqual
>
memo_
;
void
VisitExpr_
(
const
TupleNode
*
n
)
final
{
bool
result
=
true
;
for
(
const
auto
&
field
:
n
->
fields
)
{
if
(
!
Check
(
field
))
{
result
=
false
;
break
;
}
}
memo_
[
GetRef
<
Tuple
>
(
n
)]
=
result
;
}
};
// TODO(tvm-team) consider combine dead-code with constant folder.
// TODO(tvm-team) consider combine dead-code with constant folder.
// or make a more powerful partial evaluator.
// or make a more powerful partial evaluator.
class
ConstantFolder
:
public
ExprMutator
{
class
ConstantFolder
:
public
ExprMutator
{
...
@@ -53,7 +83,7 @@ class ConstantFolder : public ExprMutator {
...
@@ -53,7 +83,7 @@ class ConstantFolder : public ExprMutator {
if
(
op_stateful
.
get
(
GetRef
<
Op
>
(
op
),
false
))
return
res
;
if
(
op_stateful
.
get
(
GetRef
<
Op
>
(
op
),
false
))
return
res
;
bool
all_const_args
=
true
;
bool
all_const_args
=
true
;
for
(
Expr
arg
:
call
->
args
)
{
for
(
Expr
arg
:
call
->
args
)
{
if
(
arg
.
as
<
ConstantNode
>
()
==
nullptr
)
{
if
(
!
checker_
.
Check
(
arg
)
)
{
all_const_args
=
false
;
all_const_args
=
false
;
}
}
}
}
...
@@ -77,6 +107,9 @@ class ConstantFolder : public ExprMutator {
...
@@ -77,6 +107,9 @@ class ConstantFolder : public ExprMutator {
private
:
private
:
// Internal interepreter.
// Internal interepreter.
FInterpreter
executor_
;
FInterpreter
executor_
;
// Internal constant checker
ConstantChecker
checker_
;
// Convert value to expression.
// Convert value to expression.
Expr
ValueToExpr
(
Value
value
)
{
Expr
ValueToExpr
(
Value
value
)
{
if
(
const
auto
*
val
=
value
.
as
<
TensorValueNode
>
())
{
if
(
const
auto
*
val
=
value
.
as
<
TensorValueNode
>
())
{
...
...
tests/python/relay/test_pass_fold_constant.py
View file @
d15477cd
...
@@ -76,7 +76,27 @@ def test_fold_tuple():
...
@@ -76,7 +76,27 @@ def test_fold_tuple():
assert
relay
.
ir_pass
.
graph_equal
(
zz
,
zexpected
)
assert
relay
.
ir_pass
.
graph_equal
(
zz
,
zexpected
)
def
test_fold_concat
():
c_data
=
np
.
array
([[
1
,
2
,
3
]])
.
astype
(
"float32"
)
def
before
():
a
=
relay
.
const
(
c_data
)
b
=
relay
.
const
(
c_data
)
y
=
relay
.
concatenate
((
a
,
b
),
axis
=
0
)
return
relay
.
Function
([],
y
)
def
expected
():
y_data
=
np
.
concatenate
((
c_data
,
c_data
),
axis
=
0
)
y
=
relay
.
const
(
y_data
)
return
relay
.
Function
([],
y
)
zz
=
relay
.
ir_pass
.
fold_constant
(
before
())
zexpected
=
expected
()
assert
relay
.
ir_pass
.
graph_equal
(
zz
,
zexpected
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
test_fold_const
()
test_fold_const
()
test_fold_let
()
test_fold_let
()
test_fold_tuple
()
test_fold_tuple
()
test_fold_concat
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment