Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
1e270aa4
Commit
1e270aa4
authored
Feb 19, 2019
by
Zhi
Committed by
Tianqi Chen
Feb 19, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[Relay]fix heterogenous annotation bug (#2622)
parent
f23a7a54
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
62 additions
and
14 deletions
+62
-14
src/relay/pass/device_annotation.cc
+33
-8
tests/python/relay/test_pass_annotation.py
+29
-6
No files found.
src/relay/pass/device_annotation.cc
View file @
1e270aa4
...
@@ -337,12 +337,17 @@ class DeviceInfo {
...
@@ -337,12 +337,17 @@ class DeviceInfo {
private
:
private
:
class
PostDfsOrderVisitor
:
private
ExprVisitor
{
class
PostDfsOrderVisitor
:
private
ExprVisitor
{
public
:
public
:
void
Visit
(
const
Expr
&
expr
)
{
this
->
VisitExpr
(
expr
);
}
void
Visit
(
const
Expr
&
expr
)
{
if
(
const
auto
*
fn
=
expr
.
as
<
FunctionNode
>
())
{
this
->
VisitExpr
(
fn
->
body
);
}
else
{
this
->
VisitExpr
(
expr
);
}
}
private
:
private
:
// Post order traversal.
// Post order traversal.
void
VisitExpr_
(
const
FunctionNode
*
fn
)
final
{
void
VisitExpr_
(
const
FunctionNode
*
fn
)
final
{
ExprVisitor
::
VisitExpr_
(
fn
);
// TODO(zhiics) Skip annotation of function node for now.
// TODO(zhiics) Skip annotation of function node for now.
}
}
...
@@ -356,7 +361,7 @@ class DeviceInfo {
...
@@ -356,7 +361,7 @@ class DeviceInfo {
ExprVisitor
::
VisitExpr_
(
call
);
ExprVisitor
::
VisitExpr_
(
call
);
post_dfs_order_
.
push_back
(
call
);
post_dfs_order_
.
push_back
(
call
);
if
(
Is
DeviceCopyNode
(
call
))
{
if
(
Get
DeviceCopyNode
(
call
))
{
num_device_copy_ops_
++
;
num_device_copy_ops_
++
;
}
}
}
}
...
@@ -389,6 +394,26 @@ class DeviceInfo {
...
@@ -389,6 +394,26 @@ class DeviceInfo {
friend
DeviceInfo
;
friend
DeviceInfo
;
};
};
/*
* \brief Returns a device copy node based on the current expr node. It
* returns a device copy node either the current expr node is a device copy
* node or the current expr node is a function node whose body is a device
* copy node (i.e. the fused function of a device copy call node).
*/
static
const
ExprNode
*
GetDeviceCopyNode
(
const
ExprNode
*
node
)
{
if
(
IsDeviceCopyNode
(
node
))
{
return
node
;
}
else
if
(
const
auto
*
call_node
=
dynamic_cast
<
const
CallNode
*>
(
node
))
{
if
(
const
auto
*
fn
=
call_node
->
op
.
as
<
FunctionNode
>
())
{
const
ExprNode
*
body
=
fn
->
body
.
operator
->
();
if
(
IsDeviceCopyNode
(
body
))
{
return
body
;
}
}
}
return
nullptr
;
}
void
PropagateDeviceId
()
{
void
PropagateDeviceId
()
{
// Bottom-up propagation.
// Bottom-up propagation.
BottomUpPropagation
();
BottomUpPropagation
();
...
@@ -401,11 +426,11 @@ class DeviceInfo {
...
@@ -401,11 +426,11 @@ class DeviceInfo {
int
cur_dev_type
=
-
1
;
int
cur_dev_type
=
-
1
;
for
(
auto
it
=
post_visitor_
.
post_dfs_order_
.
crbegin
();
for
(
auto
it
=
post_visitor_
.
post_dfs_order_
.
crbegin
();
it
!=
post_visitor_
.
post_dfs_order_
.
crend
();
++
it
)
{
it
!=
post_visitor_
.
post_dfs_order_
.
crend
();
++
it
)
{
if
(
Is
DeviceCopyNode
(
*
it
))
{
if
(
const
auto
*
node
=
Get
DeviceCopyNode
(
*
it
))
{
last_copy_node
=
dynamic_cast
<
const
CallNode
*>
(
*
it
);
last_copy_node
=
dynamic_cast
<
const
CallNode
*>
(
node
);
const
auto
*
attrs
=
last_copy_node
->
attrs
.
as
<
DeviceCopyAttrs
>
();
const
auto
*
attrs
=
last_copy_node
->
attrs
.
as
<
DeviceCopyAttrs
>
();
cur_dev_type
=
attrs
->
src_dev_type
;
cur_dev_type
=
attrs
->
src_dev_type
;
device_map_
.
Set
(
GetRef
<
Expr
>
(
last_copy_node
),
attrs
->
dst_dev_type
);
device_map_
.
Set
(
GetRef
<
Expr
>
(
*
it
),
attrs
->
dst_dev_type
);
}
else
if
(
last_copy_node
)
{
}
else
if
(
last_copy_node
)
{
Expr
expr
=
GetRef
<
Expr
>
(
*
it
);
Expr
expr
=
GetRef
<
Expr
>
(
*
it
);
CHECK_EQ
(
device_map_
.
count
(
expr
),
0U
);
CHECK_EQ
(
device_map_
.
count
(
expr
),
0U
);
...
@@ -418,8 +443,8 @@ class DeviceInfo {
...
@@ -418,8 +443,8 @@ class DeviceInfo {
const
CallNode
*
last_copy_node
=
nullptr
;
const
CallNode
*
last_copy_node
=
nullptr
;
int
cur_dev_type
=
-
1
;
int
cur_dev_type
=
-
1
;
for
(
const
auto
&
it
:
post_visitor_
.
post_dfs_order_
)
{
for
(
const
auto
&
it
:
post_visitor_
.
post_dfs_order_
)
{
if
(
Is
DeviceCopyNode
(
it
))
{
if
(
const
auto
*
node
=
Get
DeviceCopyNode
(
it
))
{
last_copy_node
=
dynamic_cast
<
const
CallNode
*>
(
it
);
last_copy_node
=
dynamic_cast
<
const
CallNode
*>
(
node
);
const
auto
*
attrs
=
last_copy_node
->
attrs
.
as
<
DeviceCopyAttrs
>
();
const
auto
*
attrs
=
last_copy_node
->
attrs
.
as
<
DeviceCopyAttrs
>
();
cur_dev_type
=
attrs
->
dst_dev_type
;
cur_dev_type
=
attrs
->
dst_dev_type
;
}
else
if
(
last_copy_node
)
{
}
else
if
(
last_copy_node
)
{
...
...
tests/python/relay/test_pass_annotation.py
View file @
1e270aa4
"""Unit tests for heterogeneous compilation and execution."""
"""Unit tests for heterogeneous compilation and execution."""
import
json
import
numpy
as
np
import
numpy
as
np
import
tvm
import
tvm
...
@@ -72,6 +73,7 @@ def test_annotate_all():
...
@@ -72,6 +73,7 @@ def test_annotate_all():
annotated_func
=
relay
.
ir_pass
.
infer_type
(
annotated
())
annotated_func
=
relay
.
ir_pass
.
infer_type
(
annotated
())
expected_func
=
relay
.
ir_pass
.
infer_type
(
expected
())
expected_func
=
relay
.
ir_pass
.
infer_type
(
expected
())
assert
relay
.
ir_pass
.
alpha_equal
(
annotated_func
,
expected_func
)
def
test_annotate_none
():
def
test_annotate_none
():
ctx1
=
tvm
.
context
(
1
)
ctx1
=
tvm
.
context
(
1
)
...
@@ -203,7 +205,7 @@ def test_conv_network():
...
@@ -203,7 +205,7 @@ def test_conv_network():
for
did
in
storage_dev_type
[
1
]:
for
did
in
storage_dev_type
[
1
]:
device_types
.
append
(
did
.
value
)
device_types
.
append
(
did
.
value
)
assert
len
(
storage_ids
)
==
10
assert
len
(
storage_ids
)
==
10
assert
len
(
set
(
storage_ids
))
==
7
assert
len
(
set
(
storage_ids
))
==
8
assert
len
(
set
(
device_types
))
==
2
assert
len
(
set
(
device_types
))
==
2
assert
set
(
device_types
)
==
{
1
,
2
}
assert
set
(
device_types
)
==
{
1
,
2
}
...
@@ -245,7 +247,8 @@ def test_fusible_network():
...
@@ -245,7 +247,8 @@ def test_fusible_network():
func
=
relay
.
Function
([
x
,
y
],
exp
)
func
=
relay
.
Function
([
x
,
y
],
exp
)
return
func
return
func
def
test_runtime
(
target
,
device
,
func
,
fallback_device
=
None
):
def
test_runtime
(
target
,
device
,
func
,
fallback_device
=
None
,
expected_index
=
None
):
params
=
{
"x"
:
x_data
,
"y"
:
y_data
}
params
=
{
"x"
:
x_data
,
"y"
:
y_data
}
config
=
{
"opt_level"
:
1
}
config
=
{
"opt_level"
:
1
}
if
fallback_device
:
if
fallback_device
:
...
@@ -256,6 +259,10 @@ def test_fusible_network():
...
@@ -256,6 +259,10 @@ def test_fusible_network():
target
,
target
,
params
=
params
)
params
=
params
)
contexts
=
[
tvm
.
cpu
(
0
),
tvm
.
context
(
device
)]
contexts
=
[
tvm
.
cpu
(
0
),
tvm
.
context
(
device
)]
graph_json
=
json
.
loads
(
graph
)
if
"device_index"
in
graph_json
[
"attrs"
]:
device_index
=
graph_json
[
"attrs"
][
"device_index"
][
1
]
assert
device_index
==
expected_index
mod
=
graph_runtime
.
create
(
graph
,
lib
,
contexts
)
mod
=
graph_runtime
.
create
(
graph
,
lib
,
contexts
)
mod
.
set_input
(
**
params
)
mod
.
set_input
(
**
params
)
mod
.
run
()
mod
.
run
()
...
@@ -302,8 +309,10 @@ def test_fusible_network():
...
@@ -302,8 +309,10 @@ def test_fusible_network():
annotated_func
=
annotated
()
annotated_func
=
annotated
()
expected_func
=
expected
()
expected_func
=
expected
()
expected_index
=
[
1
,
1
,
1
,
2
,
2
,
1
,
1
,
2
,
2
]
check_annotated_graph
(
annotated_func
,
expected_func
)
check_annotated_graph
(
annotated_func
,
expected_func
)
test_runtime
(
target
,
device
,
annotated_func
,
fallback_device
)
test_runtime
(
target
,
device
,
annotated_func
,
fallback_device
,
expected_index
)
def
test_fuse_all
(
device
,
tgt
):
def
test_fuse_all
(
device
,
tgt
):
"""Fuse all operators."""
"""Fuse all operators."""
...
@@ -344,6 +353,7 @@ def test_fusible_network():
...
@@ -344,6 +353,7 @@ def test_fusible_network():
fallback_device
=
tvm
.
context
(
"cpu"
)
fallback_device
=
tvm
.
context
(
"cpu"
)
target
=
{
"cpu"
:
"llvm"
,
device
:
tgt
}
target
=
{
"cpu"
:
"llvm"
,
device
:
tgt
}
cpu_ctx
=
fallback_device
cpu_ctx
=
fallback_device
dev_ctx
=
tvm
.
context
(
device
)
def
annotated
():
def
annotated
():
add
=
relay
.
add
(
x
,
y
)
add
=
relay
.
add
(
x
,
y
)
...
@@ -357,15 +367,28 @@ def test_fusible_network():
...
@@ -357,15 +367,28 @@ def test_fusible_network():
relay
.
Tuple
(
tvm
.
convert
([
_exp
,
exp
])))
relay
.
Tuple
(
tvm
.
convert
([
_exp
,
exp
])))
func
=
relay
.
ir_pass
.
infer_type
(
func
)
func
=
relay
.
ir_pass
.
infer_type
(
func
)
func
=
relay
.
ir_pass
.
rewrite_annotated_ops
(
func
,
func
=
relay
.
ir_pass
.
rewrite_annotated_ops
(
func
,
cpu
_ctx
.
device_type
)
dev
_ctx
.
device_type
)
func
=
relay
.
ir_pass
.
infer_type
(
func
)
func
=
relay
.
ir_pass
.
infer_type
(
func
)
return
relay
.
Function
(
relay
.
ir_pass
.
free_vars
(
func
.
body
[
1
]),
return
relay
.
Function
(
relay
.
ir_pass
.
free_vars
(
func
.
body
[
1
]),
func
.
body
[
1
])
func
.
body
[
1
])
def
expected
():
add
=
relay
.
add
(
x
,
y
)
sqrt
=
relay
.
sqrt
(
add
)
log
=
relay
.
log
(
add
)
subtract
=
relay
.
subtract
(
sqrt
,
log
)
copy_sub_exp
=
relay
.
device_copy
(
subtract
,
dev_ctx
,
cpu_ctx
)
exp
=
relay
.
exp
(
copy_sub_exp
)
func
=
relay
.
Function
([
x
,
y
],
exp
)
return
func
annotated_func
=
annotated
()
annotated_func
=
annotated
()
expected_func
=
get_func
()
expected_func
=
expected
()
expected_index
=
[
2
,
2
,
2
,
1
,
1
]
check_annotated_graph
(
annotated_func
,
expected_func
)
check_annotated_graph
(
annotated_func
,
expected_func
)
test_runtime
(
target
,
device
,
annotated_func
,
fallback_device
)
test_runtime
(
target
,
device
,
annotated_func
,
fallback_device
,
expected_index
)
def
test_fallback_all_operators
(
device
,
tgt
):
def
test_fallback_all_operators
(
device
,
tgt
):
target
=
{
device
:
tgt
}
target
=
{
device
:
tgt
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment