Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
1ca0393e
Commit
1ca0393e
authored
Mar 04, 2019
by
Wuwei Lin
Committed by
Tianqi Chen
Mar 03, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[RELAY][PASS] Common subexpression elimination (#2639)
parent
6d460606
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
170 additions
and
0 deletions
+170
-0
python/tvm/relay/ir_pass.py
+20
-0
src/relay/pass/eliminate_common_subexpr.cc
+72
-0
src/relay/pass/pattern_util.h
+15
-0
tests/python/relay/test_pass_eliminate_common_subexpr.py
+63
-0
No files found.
python/tvm/relay/ir_pass.py
View file @
1ca0393e
...
@@ -564,3 +564,23 @@ def get_total_mac_number(expr):
...
@@ -564,3 +564,23 @@ def get_total_mac_number(expr):
The number of MACs (multiply-accumulate) of a model
The number of MACs (multiply-accumulate) of a model
"""
"""
return
_ir_pass
.
GetTotalMacNumber
(
expr
)
return
_ir_pass
.
GetTotalMacNumber
(
expr
)
def
eliminate_common_subexpr
(
expr
,
fskip
=
None
):
"""
Eliminate common subexpressions.
Parameters
----------
expr : tvm.relay.Expr
The input expression.
fskip: function
The callback function that decides whether an expression should be skipped.
Returns
-------
expr : tvm.relay.Expr
The output expression.
"""
return
_ir_pass
.
eliminate_common_subexpr
(
expr
,
fskip
)
src/relay/pass/eliminate_common_subexpr.cc
0 → 100644
View file @
1ca0393e
/*!
* Copyright (c) 2019 by Contributors
*
* \file eliminate_common_subexpr.cc
* \brief Combine common subexpressions.
*
* This is an optimization pass that eliminates common subexpressions. During the pass, it tries
* to replace an expression with a previously appeared expression with the same input and
* attributes. The fskip callback argument allows us to skip specific expressions.
*/
#include <tvm/relay/pass.h>
#include <tvm/relay/expr_functor.h>
#include <unordered_map>
#include "./pattern_util.h"
namespace
tvm
{
namespace
relay
{
class
CommonSubexprEliminator
:
public
ExprMutator
{
public
:
explicit
CommonSubexprEliminator
(
runtime
::
TypedPackedFunc
<
bool
(
Expr
)
>
fskip
)
:
fskip_
(
fskip
)
{}
Expr
VisitExpr_
(
const
CallNode
*
call
)
final
{
static
auto
op_stateful
=
Op
::
GetAttr
<
TOpIsStateful
>
(
"TOpIsStateful"
);
Expr
new_expr
=
ExprMutator
::
VisitExpr_
(
call
);
const
CallNode
*
new_call
=
new_expr
.
as
<
CallNode
>
();
CHECK
(
new_call
);
const
OpNode
*
op
=
new_call
->
op
.
as
<
OpNode
>
();
AttrsEqual
attrs_equal
;
if
(
new_call
->
args
.
size
()
==
0
||
op
==
nullptr
||
op_stateful
.
get
(
GetRef
<
Op
>
(
op
),
false
))
{
return
new_expr
;
}
if
(
fskip_
!=
nullptr
&&
fskip_
(
new_expr
))
{
return
new_expr
;
}
auto
it
=
expr_map_
.
find
(
new_call
->
op
);
if
(
it
!=
expr_map_
.
end
())
{
for
(
const
CallNode
*
candidate
:
it
->
second
)
{
bool
is_equivalent
=
true
;
if
(
!
attrs_equal
(
new_call
->
attrs
,
candidate
->
attrs
))
{
continue
;
}
for
(
size_t
i
=
0
;
i
<
new_call
->
args
.
size
();
i
++
)
{
if
(
!
new_call
->
args
[
i
].
same_as
(
candidate
->
args
[
i
])
&&
!
IsEqualScalar
(
new_call
->
args
[
i
],
candidate
->
args
[
i
]))
{
is_equivalent
=
false
;
break
;
}
}
if
(
!
is_equivalent
)
continue
;
return
GetRef
<
Call
>
(
candidate
);
}
}
expr_map_
[
new_call
->
op
].
push_back
(
new_call
);
return
new_expr
;
}
std
::
unordered_map
<
Expr
,
std
::
vector
<
const
CallNode
*>
,
NodeHash
,
NodeEqual
>
expr_map_
;
runtime
::
TypedPackedFunc
<
bool
(
Expr
)
>
fskip_
;
};
Expr
EliminateCommonSubexpr
(
const
Expr
&
expr
,
PackedFunc
callback
)
{
return
CommonSubexprEliminator
(
callback
)(
expr
);
}
TVM_REGISTER_API
(
"relay._ir_pass.eliminate_common_subexpr"
)
.
set_body_typed
<
Expr
(
Expr
,
PackedFunc
)
>
(
EliminateCommonSubexpr
);
}
// namespace relay
}
// namespace tvm
src/relay/pass/pattern_util.h
View file @
1ca0393e
...
@@ -191,6 +191,21 @@ inline Constant MakeConstantScalar(DataType dtype, T value) {
...
@@ -191,6 +191,21 @@ inline Constant MakeConstantScalar(DataType dtype, T value) {
return
ConstantNode
::
make
(
arr
);
return
ConstantNode
::
make
(
arr
);
}
}
/*!
* \brief Check if two expressions are equal scalars.
* \param a The expression to be checked.
* \param b The expression to be checked
* \return Whether two expressions are equal scalars.
*/
inline
bool
IsEqualScalar
(
const
Expr
&
a
,
const
Expr
&
b
)
{
const
auto
*
constant_a
=
a
.
as
<
ConstantNode
>
();
const
auto
*
constant_b
=
b
.
as
<
ConstantNode
>
();
if
(
!
constant_a
||
!
constant_b
||
!
constant_a
->
is_scalar
()
||
!
constant_b
->
is_scalar
())
{
return
false
;
}
return
AlphaEqual
(
a
,
b
);
}
inline
Expr
GetField
(
Expr
t
,
size_t
i
)
{
inline
Expr
GetField
(
Expr
t
,
size_t
i
)
{
return
TupleGetItemNode
::
make
(
t
,
i
);
return
TupleGetItemNode
::
make
(
t
,
i
);
}
}
...
...
tests/python/relay/test_pass_eliminate_common_subexpr.py
0 → 100644
View file @
1ca0393e
"""Test eliminate common subexpr pass"""
from
tvm
import
relay
from
tvm.relay.op
import
register_alter_op_layout
from
tvm.relay
import
ir_pass
def
test_simple
():
def
before
():
x
=
relay
.
var
(
"x"
,
shape
=
(
1
,
16
))
y1
=
relay
.
nn
.
relu
(
x
)
y2
=
relay
.
nn
.
relu
(
x
)
y1
=
relay
.
add
(
y1
,
relay
.
const
(
1.0
,
"float32"
))
y2
=
relay
.
add
(
y2
,
relay
.
const
(
1.0
,
"float32"
))
y
=
relay
.
add
(
y1
,
y2
)
f
=
relay
.
Function
([
x
],
y
)
return
f
def
expected
():
x
=
relay
.
var
(
"x"
,
shape
=
(
1
,
16
))
y
=
relay
.
nn
.
relu
(
x
)
y
=
relay
.
add
(
y
,
relay
.
const
(
1.0
,
"float32"
))
y
=
relay
.
add
(
y
,
y
)
f
=
relay
.
Function
([
x
],
y
)
return
f
z
=
before
()
z
=
ir_pass
.
eliminate_common_subexpr
(
z
)
assert
ir_pass
.
alpha_equal
(
z
,
expected
())
def
test_callback
():
def
before
():
x
=
relay
.
var
(
"x"
,
shape
=
(
1
,
16
))
y1
=
relay
.
nn
.
relu
(
x
)
y2
=
relay
.
nn
.
relu
(
x
)
y1
=
relay
.
add
(
y1
,
relay
.
const
(
1.0
,
"float32"
))
y2
=
relay
.
add
(
y2
,
relay
.
const
(
1.0
,
"float32"
))
y
=
relay
.
add
(
y1
,
y2
)
f
=
relay
.
Function
([
x
],
y
)
return
f
def
expected
():
x
=
relay
.
var
(
"x"
,
shape
=
(
1
,
16
))
y
=
relay
.
nn
.
relu
(
x
)
y1
=
relay
.
add
(
y
,
relay
.
const
(
1.0
,
"float32"
))
y2
=
relay
.
add
(
y
,
relay
.
const
(
1.0
,
"float32"
))
y
=
relay
.
add
(
y1
,
y2
)
f
=
relay
.
Function
([
x
],
y
)
return
f
def
fskip
(
expr
):
if
isinstance
(
expr
,
relay
.
expr
.
Call
)
and
expr
.
op
.
name
==
'add'
:
return
True
return
False
z
=
before
()
z
=
ir_pass
.
eliminate_common_subexpr
(
z
,
fskip
)
assert
ir_pass
.
alpha_equal
(
z
,
expected
())
if
__name__
==
"__main__"
:
test_simple
()
test_callback
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment