Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
887255a8
Commit
887255a8
authored
Jun 01, 2019
by
Zhi
Committed by
Jared Roesch
Jun 01, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[relay][heterogeneous] annotate using visitor (#3261)
* annotate using visitor * retrigger CI
parent
f6acf2e5
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
94 additions
and
43 deletions
+94
-43
src/relay/pass/device_annotation.cc
+9
-2
tests/python/relay/test_pass_annotation.py
+85
-41
No files found.
src/relay/pass/device_annotation.cc
View file @
887255a8
...
@@ -176,7 +176,11 @@ class RewriteAnnotation : public ExprMutator {
...
@@ -176,7 +176,11 @@ class RewriteAnnotation : public ExprMutator {
}
}
Expr
VisitExpr_
(
const
CallNode
*
call_node
)
final
{
Expr
VisitExpr_
(
const
CallNode
*
call_node
)
final
{
if
(
IsOnDeviceNode
(
call_node
)
||
IsDeviceCopyNode
(
call_node
))
{
if
(
IsOnDeviceNode
(
call_node
))
{
return
this
->
VisitExpr
(
call_node
->
args
[
0
]);
}
if
(
IsDeviceCopyNode
(
call_node
))
{
return
ExprMutator
::
VisitExpr_
(
call_node
);
return
ExprMutator
::
VisitExpr_
(
call_node
);
}
}
...
@@ -358,6 +362,9 @@ class DeviceInfo {
...
@@ -358,6 +362,9 @@ class DeviceInfo {
public
:
public
:
void
Visit
(
const
Expr
&
expr
)
{
void
Visit
(
const
Expr
&
expr
)
{
if
(
const
auto
*
fn
=
expr
.
as
<
FunctionNode
>
())
{
if
(
const
auto
*
fn
=
expr
.
as
<
FunctionNode
>
())
{
for
(
const
auto
&
param
:
fn
->
params
)
{
this
->
VisitExpr
(
param
);
}
this
->
VisitExpr
(
fn
->
body
);
this
->
VisitExpr
(
fn
->
body
);
}
else
{
}
else
{
this
->
VisitExpr
(
expr
);
this
->
VisitExpr
(
expr
);
...
@@ -402,7 +409,7 @@ class DeviceInfo {
...
@@ -402,7 +409,7 @@ class DeviceInfo {
}
}
void
VisitExpr_
(
const
VarNode
*
vn
)
final
{
void
VisitExpr_
(
const
VarNode
*
vn
)
final
{
post_dfs_order_
.
push_back
(
std
::
make_pair
(
vn
,
has_copy_
));
post_dfs_order_
.
push_back
(
std
::
make_pair
(
vn
,
has_copy_
));
}
}
void
VisitExpr_
(
const
LetNode
*
ln
)
final
{
void
VisitExpr_
(
const
LetNode
*
ln
)
final
{
...
...
tests/python/relay/test_pass_annotation.py
View file @
887255a8
...
@@ -21,6 +21,7 @@ import numpy as np
...
@@ -21,6 +21,7 @@ import numpy as np
import
tvm
import
tvm
from
tvm
import
relay
from
tvm
import
relay
from
tvm.contrib
import
graph_runtime
from
tvm.contrib
import
graph_runtime
from
tvm.relay.expr_functor
import
ExprMutator
def
test_redundant_annotation
():
def
test_redundant_annotation
():
...
@@ -34,11 +35,10 @@ def test_redundant_annotation():
...
@@ -34,11 +35,10 @@ def test_redundant_annotation():
add
=
relay
.
add
(
x
,
y
)
add
=
relay
.
add
(
x
,
y
)
_add1
=
relay
.
annotation
.
on_device
(
add
,
ctx2
)
_add1
=
relay
.
annotation
.
on_device
(
add
,
ctx2
)
_add2
=
relay
.
annotation
.
on_device
(
add
,
ctx2
)
_add2
=
relay
.
annotation
.
on_device
(
add
,
ctx2
)
sub
=
relay
.
subtract
(
add
,
z
)
sub1
=
relay
.
subtract
(
_add1
,
z
)
sub2
=
relay
.
subtract
(
_add2
,
z
)
func
=
relay
.
Function
([
x
,
y
,
z
],
func
=
relay
.
Function
([
x
,
y
,
z
],
relay
.
Tuple
([
sub1
,
sub2
]))
relay
.
Tuple
(
tvm
.
convert
([
_add1
,
_add2
,
sub
])))
func
=
relay
.
ir_pass
.
infer_type
(
func
)
func
=
relay
.
ir_pass
.
infer_type
(
func
)
func
=
relay
.
ir_pass
.
rewrite_annotated_ops
(
func
,
func
=
relay
.
ir_pass
.
rewrite_annotated_ops
(
func
,
ctx1
.
device_type
)
ctx1
.
device_type
)
...
@@ -46,9 +46,11 @@ def test_redundant_annotation():
...
@@ -46,9 +46,11 @@ def test_redundant_annotation():
def
expected
():
def
expected
():
add
=
relay
.
add
(
x
,
y
)
add
=
relay
.
add
(
x
,
y
)
copy_add_sub
=
relay
.
device_copy
(
add
,
ctx2
,
ctx1
)
copy_add_sub1
=
relay
.
device_copy
(
add
,
ctx2
,
ctx1
)
sub
=
relay
.
subtract
(
copy_add_sub
,
z
)
sub1
=
relay
.
subtract
(
copy_add_sub1
,
z
)
func
=
relay
.
Function
([
x
,
y
,
z
],
sub
)
copy_add_sub2
=
relay
.
device_copy
(
add
,
ctx2
,
ctx1
)
sub2
=
relay
.
subtract
(
copy_add_sub2
,
z
)
func
=
relay
.
Function
([
x
,
y
,
z
],
relay
.
Tuple
([
sub1
,
sub2
]))
return
func
return
func
annotated_func
=
relay
.
ir_pass
.
infer_type
(
annotated
())
annotated_func
=
relay
.
ir_pass
.
infer_type
(
annotated
())
...
@@ -66,10 +68,9 @@ def test_annotate_expr():
...
@@ -66,10 +68,9 @@ def test_annotate_expr():
def
annotated
():
def
annotated
():
add
=
relay
.
add
(
x
,
y
)
add
=
relay
.
add
(
x
,
y
)
_add
=
relay
.
annotation
.
on_device
(
add
,
ctx1
)
_add
=
relay
.
annotation
.
on_device
(
add
,
ctx1
)
sub
=
relay
.
subtract
(
add
,
z
)
sub
=
relay
.
subtract
(
_
add
,
z
)
_sub
=
relay
.
annotation
.
on_device
(
sub
,
ctx2
)
_sub
=
relay
.
annotation
.
on_device
(
sub
,
ctx2
)
expr
=
relay
.
Tuple
([
sub
,
_add
,
_sub
])
expr
=
relay
.
ir_pass
.
infer_type
(
_sub
)
expr
=
relay
.
ir_pass
.
infer_type
(
expr
)
expr
=
relay
.
ir_pass
.
rewrite_annotated_ops
(
expr
,
expr
=
relay
.
ir_pass
.
rewrite_annotated_ops
(
expr
,
ctx1
.
device_type
)
ctx1
.
device_type
)
return
expr
return
expr
...
@@ -95,12 +96,10 @@ def test_annotate_all():
...
@@ -95,12 +96,10 @@ def test_annotate_all():
def
annotated
():
def
annotated
():
add
=
relay
.
add
(
x
,
y
)
add
=
relay
.
add
(
x
,
y
)
_add
=
relay
.
annotation
.
on_device
(
add
,
ctx2
)
_add
=
relay
.
annotation
.
on_device
(
add
,
ctx2
)
sub
=
relay
.
subtract
(
add
,
z
)
sub
=
relay
.
subtract
(
_
add
,
z
)
_sub
=
relay
.
annotation
.
on_device
(
sub
,
ctx2
)
_sub
=
relay
.
annotation
.
on_device
(
sub
,
ctx2
)
func
=
relay
.
Function
([
x
,
y
,
z
],
func
=
relay
.
Function
([
x
,
y
,
z
],
_sub
)
relay
.
Tuple
(
tvm
.
convert
([
_add
,
_sub
,
sub
])))
func
=
relay
.
ir_pass
.
infer_type
(
func
)
func
=
relay
.
ir_pass
.
infer_type
(
func
)
func
=
relay
.
ir_pass
.
rewrite_annotated_ops
(
func
,
func
=
relay
.
ir_pass
.
rewrite_annotated_ops
(
func
,
ctx1
.
device_type
)
ctx1
.
device_type
)
...
@@ -168,6 +167,34 @@ def test_conv_network():
...
@@ -168,6 +167,34 @@ def test_conv_network():
dev1
=
tvm
.
context
(
1
)
dev1
=
tvm
.
context
(
1
)
dev2
=
tvm
.
context
(
2
)
dev2
=
tvm
.
context
(
2
)
def
original
():
conv2d_1
=
relay
.
nn
.
conv2d
(
data1
,
weight
,
channels
=
64
,
kernel_size
=
(
3
,
3
),
padding
=
(
1
,
1
))
conv2d_2
=
relay
.
nn
.
conv2d
(
data2
,
weight
,
channels
=
64
,
kernel_size
=
(
3
,
3
),
padding
=
(
1
,
1
))
add
=
relay
.
add
(
conv2d_1
,
conv2d_2
)
conv2d_3
=
relay
.
nn
.
conv2d
(
add
,
weight
,
channels
=
64
,
kernel_size
=
(
3
,
3
),
padding
=
(
1
,
1
))
func
=
relay
.
Function
([
data1
,
data2
,
weight
],
conv2d_3
)
func
=
relay
.
ir_pass
.
infer_type
(
func
)
func
=
relay
.
ir_pass
.
rewrite_annotated_ops
(
func
,
tvm
.
context
(
3
)
.
device_type
)
return
func
def
annotated
():
def
annotated
():
conv2d_1
=
relay
.
nn
.
conv2d
(
conv2d_1
=
relay
.
nn
.
conv2d
(
data1
,
data1
,
...
@@ -183,25 +210,40 @@ def test_conv_network():
...
@@ -183,25 +210,40 @@ def test_conv_network():
kernel_size
=
(
3
,
3
),
kernel_size
=
(
3
,
3
),
padding
=
(
1
,
1
))
padding
=
(
1
,
1
))
_conv2d_2
=
relay
.
annotation
.
on_device
(
conv2d_2
,
dev2
)
_conv2d_2
=
relay
.
annotation
.
on_device
(
conv2d_2
,
dev2
)
add
=
relay
.
add
(
conv2d_1
,
conv2d_2
)
add
=
relay
.
add
(
_conv2d_1
,
_
conv2d_2
)
_add
=
relay
.
annotation
.
on_device
(
add
,
dev1
)
_add
=
relay
.
annotation
.
on_device
(
add
,
dev1
)
conv2d_3
=
relay
.
nn
.
conv2d
(
conv2d_3
=
relay
.
nn
.
conv2d
(
add
,
_
add
,
weight
,
weight
,
channels
=
64
,
channels
=
64
,
kernel_size
=
(
3
,
3
),
kernel_size
=
(
3
,
3
),
padding
=
(
1
,
1
))
padding
=
(
1
,
1
))
_conv2d_3
=
relay
.
annotation
.
on_device
(
conv2d_3
,
dev2
)
_conv2d_3
=
relay
.
annotation
.
on_device
(
conv2d_3
,
dev2
)
func
=
relay
.
Function
([
data1
,
data2
,
weight
],
func
=
relay
.
Function
([
data1
,
data2
,
weight
],
_conv2d_3
)
relay
.
Tuple
(
tvm
.
convert
([
_conv2d_1
,
_conv2d_2
,
_conv2d_3
,
_add
,
conv2d_3
])))
func
=
relay
.
ir_pass
.
infer_type
(
func
)
func
=
relay
.
ir_pass
.
infer_type
(
func
)
func
=
relay
.
ir_pass
.
rewrite_annotated_ops
(
func
,
func
=
relay
.
ir_pass
.
rewrite_annotated_ops
(
func
,
tvm
.
context
(
3
)
.
device_type
)
tvm
.
context
(
3
)
.
device_type
)
return
func
return
func
class
ScheduleConv2d
(
ExprMutator
):
def
__init__
(
self
,
device
):
self
.
device
=
device
super
()
.
__init__
()
def
visit_call
(
self
,
expr
):
visit
=
super
()
.
visit_call
(
expr
)
if
expr
.
op
==
tvm
.
relay
.
op
.
get
(
"nn.conv2d"
):
return
relay
.
annotation
.
on_device
(
visit
,
self
.
device
)
else
:
return
visit
def
annotate_with_visitor
(
func
):
sched
=
ScheduleConv2d
(
dev2
)
func
=
sched
.
visit
(
func
)
func
=
relay
.
ir_pass
.
rewrite_annotated_ops
(
func
,
dev1
.
device_type
)
return
func
def
expected
():
def
expected
():
conv2d_1
=
relay
.
nn
.
conv2d
(
conv2d_1
=
relay
.
nn
.
conv2d
(
data1
,
data1
,
...
@@ -249,10 +291,19 @@ def test_conv_network():
...
@@ -249,10 +291,19 @@ def test_conv_network():
assert
len
(
set
(
device_types
))
==
2
assert
len
(
set
(
device_types
))
==
2
assert
set
(
device_types
)
==
{
1
,
2
}
assert
set
(
device_types
)
==
{
1
,
2
}
annotated_func
=
annotated
()
def
test_manual_annotation
():
expected_func
=
expected
()
annotated_func
=
annotated
()
check_annotated_graph
(
annotated_func
,
expected_func
)
expected_func
=
expected
()
check_storage_and_device_types
()
check_annotated_graph
(
annotated_func
,
expected_func
)
check_storage_and_device_types
()
def
test_visitor_annotation
():
annotated_func
=
annotate_with_visitor
(
original
())
expected_func
=
expected
()
check_annotated_graph
(
annotated_func
,
expected_func
)
test_manual_annotation
()
test_visitor_annotation
()
def
run_fusible_network
(
dev
,
tgt
):
def
run_fusible_network
(
dev
,
tgt
):
...
@@ -321,12 +372,11 @@ def run_fusible_network(dev, tgt):
...
@@ -321,12 +372,11 @@ def run_fusible_network(dev, tgt):
sqrt
=
relay
.
sqrt
(
add
)
sqrt
=
relay
.
sqrt
(
add
)
_sqrt
=
relay
.
annotation
.
on_device
(
sqrt
,
dev_ctx
)
_sqrt
=
relay
.
annotation
.
on_device
(
sqrt
,
dev_ctx
)
log
=
relay
.
log
(
add
)
log
=
relay
.
log
(
add
)
subtract
=
relay
.
subtract
(
sqrt
,
log
)
subtract
=
relay
.
subtract
(
_
sqrt
,
log
)
exp
=
relay
.
exp
(
subtract
)
exp
=
relay
.
exp
(
subtract
)
_exp
=
relay
.
annotation
.
on_device
(
exp
,
dev_ctx
)
_exp
=
relay
.
annotation
.
on_device
(
exp
,
dev_ctx
)
func
=
relay
.
Function
([
x
,
y
],
func
=
relay
.
Function
([
x
,
y
],
_exp
)
relay
.
Tuple
(
tvm
.
convert
([
_sqrt
,
_exp
,
exp
])))
func
=
relay
.
ir_pass
.
infer_type
(
func
)
func
=
relay
.
ir_pass
.
infer_type
(
func
)
func
=
relay
.
ir_pass
.
rewrite_annotated_ops
(
func
,
func
=
relay
.
ir_pass
.
rewrite_annotated_ops
(
func
,
cpu_ctx
.
device_type
)
cpu_ctx
.
device_type
)
...
@@ -364,19 +414,16 @@ def run_fusible_network(dev, tgt):
...
@@ -364,19 +414,16 @@ def run_fusible_network(dev, tgt):
def
annotated
():
def
annotated
():
add
=
relay
.
add
(
x
,
y
)
add
=
relay
.
add
(
x
,
y
)
_add
=
relay
.
annotation
.
on_device
(
add
,
dev_ctx
)
_add
=
relay
.
annotation
.
on_device
(
add
,
dev_ctx
)
sqrt
=
relay
.
sqrt
(
add
)
sqrt
=
relay
.
sqrt
(
_
add
)
_sqrt
=
relay
.
annotation
.
on_device
(
sqrt
,
dev_ctx
)
_sqrt
=
relay
.
annotation
.
on_device
(
sqrt
,
dev_ctx
)
log
=
relay
.
log
(
add
)
log
=
relay
.
log
(
_
add
)
_log
=
relay
.
annotation
.
on_device
(
log
,
dev_ctx
)
_log
=
relay
.
annotation
.
on_device
(
log
,
dev_ctx
)
subtract
=
relay
.
subtract
(
sqrt
,
log
)
subtract
=
relay
.
subtract
(
_sqrt
,
_
log
)
_subtract
=
relay
.
annotation
.
on_device
(
subtract
,
dev_ctx
)
_subtract
=
relay
.
annotation
.
on_device
(
subtract
,
dev_ctx
)
exp
=
relay
.
exp
(
subtract
)
exp
=
relay
.
exp
(
_
subtract
)
_exp
=
relay
.
annotation
.
on_device
(
exp
,
dev_ctx
)
_exp
=
relay
.
annotation
.
on_device
(
exp
,
dev_ctx
)
func
=
relay
.
Function
([
x
,
y
],
func
=
relay
.
Function
([
x
,
y
],
_exp
)
relay
.
Tuple
(
tvm
.
convert
([
_add
,
_sqrt
,
_log
,
_subtract
,
_exp
,
exp
])))
func
=
relay
.
ir_pass
.
infer_type
(
func
)
func
=
relay
.
ir_pass
.
infer_type
(
func
)
func
=
relay
.
ir_pass
.
rewrite_annotated_ops
(
func
,
func
=
relay
.
ir_pass
.
rewrite_annotated_ops
(
func
,
cpu_ctx
.
device_type
)
cpu_ctx
.
device_type
)
...
@@ -401,8 +448,7 @@ def run_fusible_network(dev, tgt):
...
@@ -401,8 +448,7 @@ def run_fusible_network(dev, tgt):
exp
=
relay
.
exp
(
subtract
)
exp
=
relay
.
exp
(
subtract
)
_exp
=
relay
.
annotation
.
on_device
(
exp
,
cpu_ctx
)
_exp
=
relay
.
annotation
.
on_device
(
exp
,
cpu_ctx
)
func
=
relay
.
Function
([
x
,
y
],
func
=
relay
.
Function
([
x
,
y
],
_exp
)
relay
.
Tuple
(
tvm
.
convert
([
_exp
,
exp
])))
func
=
relay
.
ir_pass
.
infer_type
(
func
)
func
=
relay
.
ir_pass
.
infer_type
(
func
)
func
=
relay
.
ir_pass
.
rewrite_annotated_ops
(
func
,
func
=
relay
.
ir_pass
.
rewrite_annotated_ops
(
func
,
dev_ctx
.
device_type
)
dev_ctx
.
device_type
)
...
@@ -472,11 +518,9 @@ def run_unpropagatable_graph(dev, tgt):
...
@@ -472,11 +518,9 @@ def run_unpropagatable_graph(dev, tgt):
_add
=
relay
.
annotation
.
on_device
(
add
,
dev_ctx
)
_add
=
relay
.
annotation
.
on_device
(
add
,
dev_ctx
)
mul
=
relay
.
multiply
(
c
,
d
)
mul
=
relay
.
multiply
(
c
,
d
)
_mul
=
relay
.
annotation
.
on_device
(
mul
,
cpu_ctx
)
_mul
=
relay
.
annotation
.
on_device
(
mul
,
cpu_ctx
)
sub
=
relay
.
subtract
(
add
,
mul
)
sub
=
relay
.
subtract
(
_add
,
_
mul
)
_sub
=
relay
.
annotation
.
on_device
(
sub
,
dev_ctx
)
_sub
=
relay
.
annotation
.
on_device
(
sub
,
dev_ctx
)
func
=
relay
.
Function
([
a
,
b
,
c
,
d
],
func
=
relay
.
Function
([
a
,
b
,
c
,
d
],
_sub
)
relay
.
Tuple
(
tvm
.
convert
([
_add
,
_mul
,
_sub
,
sub
])))
func
=
relay
.
ir_pass
.
infer_type
(
func
)
func
=
relay
.
ir_pass
.
infer_type
(
func
)
func
=
relay
.
ir_pass
.
rewrite_annotated_ops
(
func
,
func
=
relay
.
ir_pass
.
rewrite_annotated_ops
(
func
,
dev_ctx
.
device_type
)
dev_ctx
.
device_type
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment