Commit b374192b authored by Zhi, committed by Tianqi Chen

move fallback out of the build interface (#2456)

parent 985e7d72
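This commit moves the `fallback_device` option for heterogeneous execution from a keyword argument of `relay.build` into `BuildConfig`, so callers now set it through the `relay.build_config` scope. A minimal before/after sketch of a call site (the tiny Relay function and the target dict are illustrative, not taken from this commit):

    import tvm
    from tvm import relay

    # An illustrative Relay function and heterogeneous target dict.
    x = relay.var("x", shape=(2, 2))
    func = relay.Function([x], relay.add(x, x))
    target = {"cpu": "llvm"}

    # Before this commit, the fallback device was passed to build() directly:
    #   graph, lib, params = relay.build(func, target,
    #                                    fallback_device=tvm.cpu(0))

    # After this commit, it is carried by the enclosing build config scope:
    with relay.build_config(opt_level=1, fallback_device=tvm.cpu(0)):
        graph, lib, params = relay.build(func, target)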
@@ -36,6 +36,7 @@ class BuildConfig(object):
     defaults = {
         "opt_level": 2,
         "add_pass": None,
+        "fallback_device": None,
     }
     def __init__(self, **kwargs):
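For context, `BuildConfig` fills unspecified options from `defaults`, which is why adding the `"fallback_device": None` entry above is enough to expose the new key through `relay.build_config`. A rough sketch of that defaults-merging pattern (simplified, not the actual TVM class):

    class BuildConfigSketch(object):
        """Defaults merged with per-instance overrides (illustrative)."""
        defaults = {"opt_level": 2, "add_pass": None, "fallback_device": None}

        def __init__(self, **kwargs):
            self._attrs = dict(self.defaults)
            for k, v in kwargs.items():
                if k not in self.defaults:
                    raise ValueError("invalid config key %s" % k)
                self._attrs[k] = v

        def __getattr__(self, name):
            # Fall back to the merged option table for unknown attributes.
            return self._attrs[name]

    cfg = BuildConfigSketch(fallback_device="cpu")
    assert cfg.opt_level == 2 and cfg.fallback_device == "cpu"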
@@ -96,6 +97,10 @@ def build_config(**kwargs):
     add_pass: set of str
         Optimization pass to be added regardless of optimization level.
+    fallback_device : str or tvm.TVMContext
+        The fallback device. It is also used as the default device for
+        operators without specified device during heterogeneous execution.
     Returns
     -------
     config: BuildConfig
@@ -192,8 +197,7 @@ def optimize(func, target, params=None):
     return func
-def build(func, target=None, target_host=None, params=None,
-          fallback_device=None):
+def build(func, target=None, target_host=None, params=None):
     """Build a function to run on TVM graph runtime.
     Parameters
@@ -219,10 +223,6 @@ def build(func, target=None, target_host=None, params=None,
         Input parameters to the graph that do not change
         during inference time. Used for constant folding.
-    fallback_device : str or tvm.TVMContext, optional.
-        The fallback device. It is also used as the default device for
-        operators with no specified device.
     Returns
     -------
     graph_json : str
@@ -239,8 +239,7 @@ def build(func, target=None, target_host=None, params=None,
         raise ValueError("Target is not set in env or passed as argument.")
     if isinstance(target, dict):
-        target, fallback_device = \
-            _update_heterogeneous_inputs(target, fallback_device)
+        target, fallback_device = _update_heterogeneous_inputs(target)
     elif isinstance(target, (str, _target.Target)):
         target = _target.create(target)
     else:
@@ -277,7 +276,7 @@ def build(func, target=None, target_host=None, params=None,
     return graph_json, mod, params
-def _update_heterogeneous_inputs(target, fallback_device=None):
+def _update_heterogeneous_inputs(target):
     """Update the target and fallback device required for heterogeneous
     compilation. CPU is used as the fallback device if it wasn't provided.
     Meanwhile, a CPU device type and "llvm" pair will be added to the target
@@ -288,10 +287,6 @@ def _update_heterogeneous_inputs(target, fallback_device=None):
     target : dict of str(i.e. device/context name) to str/tvm.target.Target.
         A dict contains context to target pairs.
-    fallback_device : str or tvm.TVMContext, optional.
-        The fallback device. It is also used as the default device for
-        operators with no specified device.
     Returns
     -------
     device_target : dict of int to tvm.target.Target.
@@ -305,6 +300,7 @@ def _update_heterogeneous_inputs(target, fallback_device=None):
                          "heterogeneous execution, but received %s."
                          % type(target))
+    fallback_device = BuildConfig.current.fallback_device
     if fallback_device is None:
         # cpu is used as the default fallback device when heterogeneous
         # execution is needed, but no fallback device is provided.
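With this hunk, `_update_heterogeneous_inputs` reads the fallback from `BuildConfig.current`, i.e. the config installed by the innermost active `with relay.build_config(...)` block. A self-contained sketch of that scoped-singleton pattern (simplified, not TVM's implementation):

    class Config(object):
        """Minimal scoped-config sketch: entering a `with` block makes this
        instance current; exiting restores the previous one."""
        current = None

        def __init__(self, **kwargs):
            self.fallback_device = kwargs.get("fallback_device")
            self._old = None

        def __enter__(self):
            self._old, Config.current = Config.current, self
            return self

        def __exit__(self, exc_type, exc_value, traceback):
            Config.current = self._old

    Config.current = Config()  # module-level default scope

    with Config(fallback_device="cpu"):
        assert Config.current.fallback_device == "cpu"
    assert Config.current.fallback_device is None  # default restored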
@@ -315,7 +311,7 @@ def _update_heterogeneous_inputs(target, fallback_device=None):
     elif isinstance(fallback_device, TVMContext):
         fallback_device = fallback_device.device_type
     else:
-        raise ValueError("fallback_device expects the type of str or" +
+        raise ValueError("fallback_device expects the type of str or " +
                         "TVMContext, but received %s." % type(fallback_device))
     device_target = {}
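Taken together, the branches in this hunk and the previous one resolve the configured fallback to an integer device type. A standalone sketch of that resolution logic (the device-code table is an illustrative subset of TVM's device codes; the TVMContext branch is duck-typed to keep the sketch self-contained):

    # Illustrative subset of TVM's device-type codes.
    DEVICE_TYPE = {"cpu": 1, "cuda": 2, "opencl": 4}

    def resolve_fallback(fallback_device):
        """Mirror the branches in _update_heterogeneous_inputs (sketch)."""
        if fallback_device is None:
            # cpu is the default fallback for heterogeneous execution.
            return DEVICE_TYPE["cpu"]
        if isinstance(fallback_device, str):
            return DEVICE_TYPE[fallback_device]
        if hasattr(fallback_device, "device_type"):  # e.g. a tvm.TVMContext
            return fallback_device.device_type
        raise ValueError("fallback_device expects the type of str or "
                         "TVMContext, but received %s." % type(fallback_device))

    assert resolve_fallback(None) == 1
    assert resolve_fallback("cuda") == 2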
......
@@ -3,7 +3,6 @@ import numpy as np
 import tvm
 from tvm import relay
-from tvm.relay import testing
 from tvm.contrib import graph_runtime
@@ -248,12 +247,14 @@ def test_fusible_network():
     def test_runtime(target, device, func, fallback_device=None):
         params = {"x": x_data, "y": y_data}
-        with relay.build_config(opt_level=1):
+        config = {"opt_level": 1}
+        if fallback_device:
+            config["fallback_device"] = fallback_device
+        with relay.build_config(**config):
             graph, lib, params = relay.build(
                 func,
                 target,
-                params=params,
-                fallback_device=fallback_device)
+                params=params)
         contexts = [tvm.cpu(0), tvm.context(device)]
         mod = graph_runtime.create(graph, lib, contexts)
         mod.set_input(**params)
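The updated test passes a list of contexts to `graph_runtime.create`, with the CPU context first as the fallback. For reference, a minimal single-context end-to-end run with the same era's API (a trivial function standing in for the test's annotated one):

    import numpy as np
    import tvm
    from tvm import relay
    from tvm.contrib import graph_runtime

    # A trivial function standing in for the test's annotated one.
    x = relay.var("x", shape=(2, 2))
    func = relay.Function([x], relay.add(x, x))
    with relay.build_config(opt_level=1):
        graph, lib, params = relay.build(func, "llvm")

    mod = graph_runtime.create(graph, lib, tvm.cpu(0))
    mod.set_input("x", np.ones((2, 2), dtype="float32"))
    mod.run()
    print(mod.get_output(0).asnumpy())  # expect all 2.0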
@@ -367,13 +368,11 @@ def test_fusible_network():
         test_runtime(target, device, annotated_func, fallback_device)
     def test_fallback_all_operators(device, tgt):
-        target = {"cpu": "llvm", device: tgt}
-        fallback_device = tvm.cpu(0)
+        target = {device: tgt}
         annotated_func = get_func()
         expected_func = get_func()
         check_annotated_graph(annotated_func, expected_func)
-        test_runtime(target, device, annotated_func, fallback_device)
+        test_runtime(target, device, annotated_func)
     for dev, tgt in [("opencl", "opencl"), ("cuda", "cuda"),
                      ("opencl", str(tvm.target.intel_graphics()))]:
......