wenyuanbo / tic · Commits

Commit b374192b
Authored Jan 17, 2019 by Zhi, committed by Tianqi Chen on Jan 17, 2019
move fallback out of the build interface (#2456)
Parent: 985e7d72
Showing 2 changed files with 17 additions and 22 deletions (+17 -22):
  python/tvm/relay/build_module.py             +10 -14
  tests/python/relay/test_pass_annotation.py    +7  -8
python/tvm/relay/build_module.py @ b374192b
@@ -36,6 +36,7 @@ class BuildConfig(object):
     defaults = {
         "opt_level": 2,
         "add_pass": None,
+        "fallback_device": None,
     }
 
     def __init__(self, **kwargs):
@@ -96,6 +97,10 @@ def build_config(**kwargs):
     add_pass: set of str
         Optimization pass to be added regardless of optimization level.
 
+    fallback_device : str or tvm.TVMContext
+        The fallback device. It is also used as the default device for
+        operators without specified device during heterogeneous execution.
+
     Returns
     -------
     config: BuildConfig
@@ -192,8 +197,7 @@ def optimize(func, target, params=None):
     return func
 
 
-def build(func, target=None, target_host=None, params=None,
-          fallback_device=None):
+def build(func, target=None, target_host=None, params=None):
     """Build a function to run on TVM graph runtime.
 
     Parameters
@@ -219,10 +223,6 @@ def build(func, target=None, target_host=None, params=None,
         Input parameters to the graph that do not change
         during inference time. Used for constant folding.
 
-    fallback_device : str or tvm.TVMContext, optional.
-        The fallback device. It is also used as the default device for
-        operators with no specified device.
-
     Returns
     -------
     graph_json : str
@@ -239,8 +239,7 @@ def build(func, target=None, target_host=None, params=None,
         raise ValueError("Target is not set in env or passed as argument.")
 
     if isinstance(target, dict):
-        target, fallback_device = \
-            _update_heterogeneous_inputs(target, fallback_device)
+        target, fallback_device = _update_heterogeneous_inputs(target)
     elif isinstance(target, (str, _target.Target)):
         target = _target.create(target)
     else:
@@ -277,7 +276,7 @@ def build(func, target=None, target_host=None, params=None,
     return graph_json, mod, params
 
 
-def _update_heterogeneous_inputs(target, fallback_device=None):
+def _update_heterogeneous_inputs(target):
     """Update the target and fallback device required for heterogeneous
     compilation. CPU is used as the fallback device if it wasn't provided.
     Meanwhile, a CPU device type and "llvm" pair will be added to the target
@@ -288,10 +287,6 @@ def _update_heterogeneous_inputs(target, fallback_device=None):
     target : dict of str(i.e. device/context name) to str/tvm.target.Target.
         A dict contains context to target pairs.
 
-    fallback_device : str or tvm.TVMContext, optional.
-        The fallback device. It is also used as the default device for
-        operators with no specified device.
-
     Returns
     -------
     device_target : dict of int to tvm.target.Target.
@@ -305,6 +300,7 @@ def _update_heterogeneous_inputs(target, fallback_device=None):
                          "heterogeneous execution, but received %s."
                          % type(target))
 
+    fallback_device = BuildConfig.current.fallback_device
     if fallback_device is None:
         # cpu is used as the default fallback device when heterogeneous
         # execution is needed, but no fallback device is provided.
@@ -315,7 +311,7 @@ def _update_heterogeneous_inputs(target, fallback_device=None):
     elif isinstance(fallback_device, TVMContext):
         fallback_device = fallback_device.device_type
     else:
-        raise ValueError("fallback_device expects the type of str or" +
+        raise ValueError("fallback_device expects the type of str or " +
                          "TVMContext, but received %s." % type(fallback_device))
 
     device_target = {}
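After this change, the fallback device is selected through build_config rather than passed to relay.build. A minimal usage sketch of the new interface, assuming a TVM build of this era with both llvm and cuda codegen enabled; the Relay function below is a placeholder, not taken from this commit:

import tvm
from tvm import relay

# Placeholder graph: in practice this would be a device-annotated Relay function.
x = relay.var("x", shape=(1, 10))
y = relay.var("y", shape=(1, 10))
func = relay.Function([x, y], relay.add(x, y))

# A dict target triggers the heterogeneous build path.
targets = {"cpu": "llvm", "cuda": "cuda"}

# fallback_device is now a BuildConfig option; relay.build() no longer accepts it.
with relay.build_config(opt_level=1, fallback_device=tvm.cpu(0)):
    graph_json, lib, params = relay.build(func, targets)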
tests/python/relay/test_pass_annotation.py @ b374192b
@@ -3,7 +3,6 @@ import numpy as np
 
 import tvm
 from tvm import relay
-from tvm.relay import testing
 from tvm.contrib import graph_runtime
 
@@ -248,12 +247,14 @@ def test_fusible_network():
 
     def test_runtime(target, device, func, fallback_device=None):
         params = {"x": x_data, "y": y_data}
-        with relay.build_config(opt_level=1):
+        config = {"opt_level": 1}
+        if fallback_device:
+            config["fallback_device"] = fallback_device
+        with relay.build_config(**config):
             graph, lib, params = relay.build(
                 func,
                 target,
-                params=params,
-                fallback_device=fallback_device)
+                params=params)
         contexts = [tvm.cpu(0), tvm.context(device)]
         mod = graph_runtime.create(graph, lib, contexts)
         mod.set_input(**params)
@@ -367,13 +368,11 @@ def test_fusible_network():
         test_runtime(target, device, annotated_func, fallback_device)
 
     def test_fallback_all_operators(device, tgt):
-        target = {"cpu": "llvm", device: tgt}
-        fallback_device = tvm.cpu(0)
+        target = {device: tgt}
         annotated_func = get_func()
         expected_func = get_func()
         check_annotated_graph(annotated_func, expected_func)
-        test_runtime(target, device, annotated_func, fallback_device)
+        test_runtime(target, device, annotated_func)
 
     for dev, tgt in [("opencl", "opencl"), ("cuda", "cuda"),
                      ("opencl", str(tvm.target.intel_graphics()))]:
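For context, the runtime side is untouched by this commit: the test hands the compiled graph to graph_runtime with one TVM context per device, roughly as in the unchanged lines above (graph, lib, params, and device as defined in the test):

import tvm
from tvm.contrib import graph_runtime

# One context per device taking part in heterogeneous execution.
contexts = [tvm.cpu(0), tvm.context(device)]
mod = graph_runtime.create(graph, lib, contexts)
mod.set_input(**params)
mod.run()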