Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
f23a7a54
Commit
f23a7a54
authored
6 years ago
by
Wuwei Lin
Committed by
Tianqi Chen
6 years ago
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[RELAY] Stop_fusion annotation (#2624)
parent
8b1d07ff
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
80 additions
and
0 deletions
+80
-0
python/tvm/relay/op/annotation/annotation.py
+16
-0
src/relay/op/annotation/annotation.cc
+26
-0
src/relay/pass/fuse_ops.cc
+4
-0
src/relay/pass/pattern_util.h
+2
-0
tests/python/relay/test_pass_fuse_ops.py
+32
-0
No files found.
python/tvm/relay/op/annotation/annotation.py
View file @
f23a7a54
...
...
@@ -29,3 +29,19 @@ def on_device(data, device):
raise
ValueError
(
"device is expected to be the type of TVMContext or "
"str, but received
%
s"
%
(
type
(
device
)))
return
_make
.
on_device
(
data
,
device
)
def
stop_fusion
(
data
):
"""Annotate an expression to prevent it being fused with previous expressions.
Parameters
----------
data : tvm.relay.Expr
The expression to be annotated.
Returns
-------
result : tvm.relay.Expr
The annotated expression.
"""
return
_make
.
stop_fusion
(
data
)
This diff is collapsed.
Click to expand it.
src/relay/op/annotation/annotation.cc
View file @
f23a7a54
...
...
@@ -9,6 +9,7 @@
#include <tvm/relay/expr.h>
#include <tvm/relay/op.h>
#include <tvm/relay/op_attr_types.h>
#include <topi/elemwise.h>
#include "../type_relations.h"
#include "../../pass/alter_op_layout.h"
...
...
@@ -37,6 +38,31 @@ RELAY_REGISTER_OP("on_device")
.
set_attr
<
FInferCorrectLayout
>
(
"FInferCorrectLayout"
,
ElemwiseArbitraryLayout
);
Expr
StopFusion
(
Expr
data
)
{
static
const
Op
&
op
=
Op
::
Get
(
"annotation.stop_fusion"
);
return
CallNode
::
make
(
op
,
{
data
},
Attrs
{},
{});
}
TVM_REGISTER_API
(
"relay.op.annotation._make.stop_fusion"
)
.
set_body_typed
<
Expr
(
Expr
)
>
([](
Expr
data
)
{
return
StopFusion
(
data
);
});
RELAY_REGISTER_OP
(
"annotation.stop_fusion"
)
.
describe
(
R"code(Annotate an expression to prevent it being fused with previous expressions.)code"
TVM_ADD_FILELINE
)
.
set_num_inputs
(
1
)
.
add_argument
(
"data"
,
"Tensor"
,
"The input data."
)
.
add_type_rel
(
"Identity"
,
IdentityRel
)
.
set_support_level
(
10
)
.
set_attr
<
TOpPattern
>
(
"TOpPattern"
,
kOpaque
)
.
set_attr
<
TOpIsStateful
>
(
"TOpIsStateful"
,
false
)
.
set_attr
<
FInferCorrectLayout
>
(
"FInferCorrectLayout"
,
ElemwiseArbitraryLayout
)
.
set_attr
<
FTVMCompute
>
(
"FTVMCompute"
,
[](
const
Attrs
&
attrs
,
const
Array
<
Tensor
>&
inputs
,
const
Type
&
out_dtype
,
const
Target
&
target
)
->
Array
<
Tensor
>
{
return
{
topi
::
identity
(
inputs
[
0
])};
});
}
// namespace relay
}
// namespace tvm
This diff is collapsed.
Click to expand it.
src/relay/pass/fuse_ops.cc
View file @
f23a7a54
...
...
@@ -741,10 +741,14 @@ class FuseMutator : private ExprMutator {
}
// Transform calls.
Expr
VisitExpr_
(
const
CallNode
*
call
)
{
static
const
Op
&
stop_fusion
=
Op
::
Get
(
"annotation.stop_fusion"
);
if
(
call
->
op
.
as
<
OpNode
>
())
{
// If it is a primitive op call
// then we must have a group assignment for it already.
CHECK
(
gmap_
.
count
(
call
));
if
(
call
->
op
.
same_as
(
stop_fusion
))
{
return
ExprMutator
::
VisitExpr
(
call
->
args
[
0
]);
}
auto
*
ret_group
=
gmap_
.
at
(
call
)
->
FindRoot
();
Array
<
Expr
>
new_args
=
GetNewArguments
(
call
->
args
,
ret_group
);
...
...
This diff is collapsed.
Click to expand it.
src/relay/pass/pattern_util.h
View file @
f23a7a54
...
...
@@ -329,6 +329,8 @@ Expr MakeConcatenate(Expr data, int axis);
Expr
MakeStridedSlice
(
Expr
data
,
Array
<
Integer
>
begin
,
Array
<
Integer
>
end
,
Array
<
Integer
>
strides
);
Expr
StopFusion
(
Expr
data
);
}
// namespace relay
}
// namespace tvm
#endif // TVM_RELAY_PASS_PATTERN_UTIL_H_
This diff is collapsed.
Click to expand it.
tests/python/relay/test_pass_fuse_ops.py
View file @
f23a7a54
...
...
@@ -220,9 +220,41 @@ def test_tuple_strided_slice():
print
(
zz
.
astext
())
def
test_stop_fusion
():
def
before
(
dshape
):
x
=
relay
.
var
(
"x"
,
shape
=
dshape
)
y
=
relay
.
add
(
x
,
relay
.
const
(
1
,
"float32"
))
y
=
relay
.
annotation
.
stop_fusion
(
y
)
z
=
relay
.
exp
(
y
)
return
relay
.
Function
([
x
],
z
)
def
expected
(
dshape
):
x
=
relay
.
var
(
"p0"
,
shape
=
dshape
)
y
=
relay
.
add
(
x
,
relay
.
const
(
1
,
"float32"
))
f1
=
relay
.
Function
([
x
],
y
)
x
=
relay
.
var
(
"p01"
,
shape
=
dshape
)
y
=
relay
.
exp
(
x
)
f2
=
relay
.
Function
([
x
],
y
)
x
=
relay
.
var
(
"x"
,
shape
=
dshape
)
y
=
relay
.
Call
(
f1
,
[
x
])
z
=
relay
.
Call
(
f2
,
[
y
])
return
relay
.
Function
([
x
],
z
)
dshape
=
(
10
,
20
)
z
=
before
(
dshape
)
z
=
relay
.
ir_pass
.
infer_type
(
z
)
z
=
relay
.
ir_pass
.
fuse_ops
(
z
)
z
=
relay
.
ir_pass
.
infer_type
(
z
)
after
=
relay
.
ir_pass
.
infer_type
(
expected
(
dshape
))
assert
relay
.
ir_pass
.
alpha_equal
(
z
,
after
)
if
__name__
==
"__main__"
:
test_fuse_simple
()
test_conv2d_fuse
()
test_concatenate
()
test_tuple_root
()
test_tuple_strided_slice
()
test_stop_fusion
()
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment