Commit 81118023 by Jon Soifer Committed by Tianqi Chen

[Relay][TopHub] Add switch to disable TopHub download (#4015)

parent 7d911f46
...@@ -18,7 +18,8 @@ ...@@ -18,7 +18,8 @@
TopHub: Tensor Operator Hub TopHub: Tensor Operator Hub
To get the best performance, we typically need auto-tuning for the specific devices. To get the best performance, we typically need auto-tuning for the specific devices.
TVM releases pre-tuned parameters in TopHub for some common networks and hardware targets. TVM releases pre-tuned parameters in TopHub for some common networks and hardware targets.
TVM will download these parameters for you when you call nnvm.compiler.build_module . TVM will download these parameters for you when you call
nnvm.compiler.build_module or relay.build.
""" """
# pylint: disable=invalid-name # pylint: disable=invalid-name
...@@ -30,6 +31,16 @@ from .task import ApplyHistoryBest ...@@ -30,6 +31,16 @@ from .task import ApplyHistoryBest
from .. import target as _target from .. import target as _target
from ..contrib.download import download from ..contrib.download import download
from .record import load_from_file from .record import load_from_file
from .util import EmptyContext
# environment variable to read TopHub location
AUTOTVM_TOPHUB_LOC_VAR = "TOPHUB_LOCATION"
# default location of TopHub
AUTOTVM_TOPHUB_DEFAULT_LOC = "https://raw.githubusercontent.com/uwsampl/tvm-distro/master/tophub"
# value of AUTOTVM_TOPHUB_LOC_VAR to specify to not read from TopHub
AUTOTVM_TOPHUB_NONE_LOC = "NONE"
# root path to store TopHub files # root path to store TopHub files
AUTOTVM_TOPHUB_ROOT_PATH = os.path.join(os.path.expanduser('~'), ".tvm", "tophub") AUTOTVM_TOPHUB_ROOT_PATH = os.path.join(os.path.expanduser('~'), ".tvm", "tophub")
...@@ -61,6 +72,9 @@ def _alias(name): ...@@ -61,6 +72,9 @@ def _alias(name):
} }
return table.get(name, name) return table.get(name, name)
def _get_tophub_location():
location = os.getenv(AUTOTVM_TOPHUB_LOC_VAR, None)
return AUTOTVM_TOPHUB_DEFAULT_LOC if location is None else location
def context(target, extra_files=None): def context(target, extra_files=None):
"""Return the dispatch context with pre-tuned parameters. """Return the dispatch context with pre-tuned parameters.
...@@ -75,6 +89,10 @@ def context(target, extra_files=None): ...@@ -75,6 +89,10 @@ def context(target, extra_files=None):
extra_files: list of str, optional extra_files: list of str, optional
Extra log files to load Extra log files to load
""" """
tophub_location = _get_tophub_location()
if tophub_location == AUTOTVM_TOPHUB_NONE_LOC:
return EmptyContext()
best_context = ApplyHistoryBest([]) best_context = ApplyHistoryBest([])
targets = target if isinstance(target, (list, tuple)) else [target] targets = target if isinstance(target, (list, tuple)) else [target]
...@@ -94,7 +112,7 @@ def context(target, extra_files=None): ...@@ -94,7 +112,7 @@ def context(target, extra_files=None):
for name in possible_names: for name in possible_names:
name = _alias(name) name = _alias(name)
if name in all_packages: if name in all_packages:
if not check_backend(name): if not check_backend(tophub_location, name):
continue continue
filename = "%s_%s.log" % (name, PACKAGE_VERSION[name]) filename = "%s_%s.log" % (name, PACKAGE_VERSION[name])
...@@ -108,7 +126,7 @@ def context(target, extra_files=None): ...@@ -108,7 +126,7 @@ def context(target, extra_files=None):
return best_context return best_context
def check_backend(backend): def check_backend(tophub_location, backend):
"""Check whether have pre-tuned parameters of the certain target. """Check whether have pre-tuned parameters of the certain target.
If not, will download it. If not, will download it.
...@@ -135,18 +153,21 @@ def check_backend(backend): ...@@ -135,18 +153,21 @@ def check_backend(backend):
else: else:
import urllib2 import urllib2
try: try:
download_package(package_name) download_package(tophub_location, package_name)
return True return True
except urllib2.URLError as e: except urllib2.URLError as e:
logging.warning("Failed to download tophub package for %s: %s", backend, e) logging.warning("Failed to download tophub package for %s: %s", backend, e)
return False return False
def download_package(package_name): def download_package(tophub_location, package_name):
"""Download pre-tuned parameters of operators for a backend """Download pre-tuned parameters of operators for a backend
Parameters Parameters
---------- ----------
tophub_location: str
The location to download TopHub parameters from
package_name: str package_name: str
The name of package The name of package
""" """
...@@ -160,9 +181,9 @@ def download_package(package_name): ...@@ -160,9 +181,9 @@ def download_package(package_name):
if not os.path.isdir(path): if not os.path.isdir(path):
os.mkdir(path) os.mkdir(path)
logger.info("Download pre-tuned parameters package %s", package_name) download_url = "{0}/{1}".format(tophub_location, package_name)
download("https://raw.githubusercontent.com/uwsampl/tvm-distro/master/tophub/%s" logger.info("Download pre-tuned parameters package from %s", download_url)
% package_name, os.path.join(rootpath, package_name), True, verbose=0) download(download_url, os.path.join(rootpath, package_name), True, verbose=0)
# global cache for load_reference_log # global cache for load_reference_log
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment