Commit b374192b by Zhi, committed by Tianqi Chen

move fallback out of the build interface (#2456)

parent 985e7d72
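
For users, this change replaces the fallback_device argument of relay.build with a fallback_device option on relay.build_config. A minimal before/after sketch, assuming the Relay API of this era (the function and target below are illustrative, not taken from the diff):

    import tvm
    from tvm import relay

    # Illustrative inputs, not part of this commit.
    x = relay.var("x", shape=(10,))
    func = relay.Function([x], relay.add(x, x))
    target = {"cpu": "llvm"}

    # Before this commit, the fallback went straight into build():
    #   graph, lib, params = relay.build(func, target,
    #                                    fallback_device=tvm.cpu(0))

    # After this commit, it is set on the build_config scope instead:
    with relay.build_config(opt_level=1, fallback_device=tvm.cpu(0)):
        graph, lib, params = relay.build(func, target)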
@@ -36,6 +36,7 @@ class BuildConfig(object):
     defaults = {
         "opt_level": 2,
         "add_pass": None,
+        "fallback_device": None,
     }

     def __init__(self, **kwargs):
@@ -96,6 +97,10 @@ def build_config(**kwargs):
     add_pass: set of str
         Optimization pass to be added regardless of optimization level.

+    fallback_device : str or tvm.TVMContext
+        The fallback device. It is also used as the default device for
+        operators without specified device during heterogeneous execution.
+
     Returns
     -------
     config: BuildConfig
@@ -192,8 +197,7 @@ def optimize(func, target, params=None):
     return func


-def build(func, target=None, target_host=None, params=None,
-          fallback_device=None):
+def build(func, target=None, target_host=None, params=None):
    """Build a function to run on TVM graph runtime.

    Parameters
@@ -219,10 +223,6 @@ def build(func, target=None, target_host=None, params=None,
        Input parameters to the graph that do not change
        during inference time. Used for constant folding.

-    fallback_device : str or tvm.TVMContext, optional.
-        The fallback device. It is also used as the default device for
-        operators with no specified device.
-
    Returns
    -------
    graph_json : str
@@ -239,8 +239,7 @@ def build(func, target=None, target_host=None, params=None,
        raise ValueError("Target is not set in env or passed as argument.")

    if isinstance(target, dict):
-        target, fallback_device = \
-            _update_heterogeneous_inputs(target, fallback_device)
+        target, fallback_device = _update_heterogeneous_inputs(target)
    elif isinstance(target, (str, _target.Target)):
        target = _target.create(target)
    else:
@@ -277,7 +276,7 @@ def build(func, target=None, target_host=None, params=None,
    return graph_json, mod, params


-def _update_heterogeneous_inputs(target, fallback_device=None):
+def _update_heterogeneous_inputs(target):
    """Update the target and fallback device required for heterogeneous
    compilation. CPU is used as the fallback device if it wasn't provided.
    Meanwhile, a CPU device type and "llvm" pair will be added to the target
@@ -288,10 +287,6 @@ def _update_heterogeneous_inputs(target, fallback_device=None):
    target : dict of str(i.e. device/context name) to str/tvm.target.Target.
        A dict contains context to target pairs.

-    fallback_device : str or tvm.TVMContext, optional.
-        The fallback device. It is also used as the default device for
-        operators with no specified device.
-
    Returns
    -------
    device_target : dict of int to tvm.target.Target.
......@@ -305,6 +300,7 @@ def _update_heterogeneous_inputs(target, fallback_device=None):
"heterogeneous execution, but received %s."
% type(target))
fallback_device = BuildConfig.current.fallback_device
if fallback_device is None:
# cpu is used as the default fallback device when heterogeneous
# execution is needed, but no fallback device is provided.
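
The line added above is the heart of the change: _update_heterogeneous_inputs no longer receives the fallback as an argument but reads it from BuildConfig.current, which the relay.build_config context manager installs. A minimal sketch of that scoped-config idiom, using a hypothetical ScopedConfig class rather than TVM's verbatim BuildConfig:

    class ScopedConfig(object):
        # Hypothetical stand-in for BuildConfig; shows the idiom only.
        current = None
        defaults = {"opt_level": 2, "add_pass": None, "fallback_device": None}

        def __init__(self, **kwargs):
            self._old = None
            for key, val in dict(ScopedConfig.defaults, **kwargs).items():
                setattr(self, key, val)

        def __enter__(self):
            # Remember the enclosing scope, then make this config the one
            # that callees see when they read ScopedConfig.current.
            self._old = ScopedConfig.current
            ScopedConfig.current = self
            return self

        def __exit__(self, *args):
            ScopedConfig.current = self._old

    ScopedConfig.current = ScopedConfig()  # module-level default scope

    with ScopedConfig(fallback_device="cpu"):
        assert ScopedConfig.current.fallback_device == "cpu"
    assert ScopedConfig.current.fallback_device is None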
@@ -315,7 +311,7 @@ def _update_heterogeneous_inputs(target, fallback_device=None):
    elif isinstance(fallback_device, TVMContext):
        fallback_device = fallback_device.device_type
    else:
-        raise ValueError("fallback_device expects the type of str or" +
+        raise ValueError("fallback_device expects the type of str or " +
                         "TVMContext, but received %s." % type(fallback_device))

    device_target = {}
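
For reference, the branch above normalizes whatever the config carries (None, a device name string, or a TVMContext) into the integer device type that the rest of the pass uses. A self-contained sketch of the same logic, with resolve_fallback as a hypothetical helper name (TVM performs these checks inline):

    import tvm

    def resolve_fallback(fallback_device):
        # None -> default to CPU; str -> look up via tvm.context;
        # TVMContext -> take its device_type; anything else is an error.
        if fallback_device is None:
            fallback_device = tvm.cpu(0)
        if isinstance(fallback_device, str):
            return tvm.context(fallback_device).device_type
        if isinstance(fallback_device, tvm.TVMContext):
            return fallback_device.device_type
        raise ValueError("fallback_device expects the type of str or "
                         "TVMContext, but received %s." % type(fallback_device))

    assert resolve_fallback(None) == resolve_fallback("cpu") == 1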
@@ -3,7 +3,6 @@ import numpy as np
 import tvm
 from tvm import relay
 from tvm.relay import testing
 from tvm.contrib import graph_runtime
@@ -248,12 +247,14 @@ def test_fusible_network():

     def test_runtime(target, device, func, fallback_device=None):
         params = {"x": x_data, "y": y_data}
-        with relay.build_config(opt_level=1):
+        config = {"opt_level": 1}
+        if fallback_device:
+            config["fallback_device"] = fallback_device
+        with relay.build_config(**config):
             graph, lib, params = relay.build(
                 func,
                 target,
-                params=params,
-                fallback_device=fallback_device)
+                params=params)
         contexts = [tvm.cpu(0), tvm.context(device)]
         mod = graph_runtime.create(graph, lib, contexts)
         mod.set_input(**params)
@@ -367,13 +368,11 @@ def test_fusible_network():
         test_runtime(target, device, annotated_func, fallback_device)

     def test_fallback_all_operators(device, tgt):
-        target = {"cpu": "llvm", device: tgt}
-        fallback_device = tvm.cpu(0)
+        target = {device: tgt}
         annotated_func = get_func()
         expected_func = get_func()
         check_annotated_graph(annotated_func, expected_func)
-        test_runtime(target, device, annotated_func, fallback_device)
+        test_runtime(target, device, annotated_func)

     for dev, tgt in [("opencl", "opencl"), ("cuda", "cuda"),
                      ("opencl", str(tvm.target.intel_graphics()))]: