Commit 81118023 by Jon Soifer Committed by Tianqi Chen

[Relay][TopHub] Add switch to disable TopHub download (#4015)

parent 7d911f46
......@@ -18,7 +18,8 @@
TopHub: Tensor Operator Hub
To get the best performance, we typically need auto-tuning for the specific devices.
TVM releases pre-tuned parameters in TopHub for some common networks and hardware targets.
TVM will download these parameters for you when you call nnvm.compiler.build_module .
TVM will download these parameters for you when you call
nnvm.compiler.build_module or relay.build.
"""
# pylint: disable=invalid-name
......@@ -30,6 +31,16 @@ from .task import ApplyHistoryBest
from .. import target as _target
from ..contrib.download import download
from .record import load_from_file
from .util import EmptyContext
# environment variable to read TopHub location
AUTOTVM_TOPHUB_LOC_VAR = "TOPHUB_LOCATION"
# default location of TopHub
AUTOTVM_TOPHUB_DEFAULT_LOC = "https://raw.githubusercontent.com/uwsampl/tvm-distro/master/tophub"
# value of AUTOTVM_TOPHUB_LOC_VAR to specify to not read from TopHub
AUTOTVM_TOPHUB_NONE_LOC = "NONE"
# root path to store TopHub files
AUTOTVM_TOPHUB_ROOT_PATH = os.path.join(os.path.expanduser('~'), ".tvm", "tophub")
......@@ -61,6 +72,9 @@ def _alias(name):
}
return table.get(name, name)
def _get_tophub_location():
location = os.getenv(AUTOTVM_TOPHUB_LOC_VAR, None)
return AUTOTVM_TOPHUB_DEFAULT_LOC if location is None else location
def context(target, extra_files=None):
"""Return the dispatch context with pre-tuned parameters.
......@@ -75,6 +89,10 @@ def context(target, extra_files=None):
extra_files: list of str, optional
Extra log files to load
"""
tophub_location = _get_tophub_location()
if tophub_location == AUTOTVM_TOPHUB_NONE_LOC:
return EmptyContext()
best_context = ApplyHistoryBest([])
targets = target if isinstance(target, (list, tuple)) else [target]
......@@ -94,7 +112,7 @@ def context(target, extra_files=None):
for name in possible_names:
name = _alias(name)
if name in all_packages:
if not check_backend(name):
if not check_backend(tophub_location, name):
continue
filename = "%s_%s.log" % (name, PACKAGE_VERSION[name])
......@@ -108,7 +126,7 @@ def context(target, extra_files=None):
return best_context
def check_backend(backend):
def check_backend(tophub_location, backend):
"""Check whether have pre-tuned parameters of the certain target.
If not, will download it.
......@@ -135,18 +153,21 @@ def check_backend(backend):
else:
import urllib2
try:
download_package(package_name)
download_package(tophub_location, package_name)
return True
except urllib2.URLError as e:
logging.warning("Failed to download tophub package for %s: %s", backend, e)
return False
def download_package(package_name):
def download_package(tophub_location, package_name):
"""Download pre-tuned parameters of operators for a backend
Parameters
----------
tophub_location: str
The location to download TopHub parameters from
package_name: str
The name of package
"""
......@@ -160,9 +181,9 @@ def download_package(package_name):
if not os.path.isdir(path):
os.mkdir(path)
logger.info("Download pre-tuned parameters package %s", package_name)
download("https://raw.githubusercontent.com/uwsampl/tvm-distro/master/tophub/%s"
% package_name, os.path.join(rootpath, package_name), True, verbose=0)
download_url = "{0}/{1}".format(tophub_location, package_name)
logger.info("Download pre-tuned parameters package from %s", download_url)
download(download_url, os.path.join(rootpath, package_name), True, verbose=0)
# global cache for load_reference_log
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment