Commit ad28f5ca by Lianmin Zheng Committed by Tianqi Chen

[AUTOTVM] Misc bug fix (#1467)

parent 9026f3fc
...@@ -6,7 +6,6 @@ from collections import namedtuple ...@@ -6,7 +6,6 @@ from collections import namedtuple
import numpy as np import numpy as np
from ... import build, nd, target as _target from ... import build, nd, target as _target
from ...contrib.util import tempdir
from ...rpc.tracker import Tracker from ...rpc.tracker import Tracker
from ...rpc.server import Server from ...rpc.server import Server
...@@ -209,14 +208,12 @@ def create_measure_batch(task, options): ...@@ -209,14 +208,12 @@ def create_measure_batch(task, options):
kwargs['rpc_device_key'] = rpc_device_key kwargs['rpc_device_key'] = rpc_device_key
kwargs['rpc_tracker_addr'] = (tracker.host, tracker.port) kwargs['rpc_tracker_addr'] = (tracker.host, tracker.port)
kwargs['rpc_timeout'] = timeout kwargs['rpc_timeout'] = timeout
kwargs['tmp_dir'] = tempdir()
elif mode == 'rpc': elif mode == 'rpc':
fmeasure = measure_methods.measure_rpc fmeasure = measure_methods.measure_rpc
kwargs['rpc_device_key'] = rpc_device_key kwargs['rpc_device_key'] = rpc_device_key
kwargs['rpc_priority'] = rpc_priority kwargs['rpc_priority'] = rpc_priority
kwargs['rpc_timeout'] = rpc_timeout kwargs['rpc_timeout'] = rpc_timeout
kwargs['use_ndk'] = use_ndk kwargs['use_ndk'] = use_ndk
kwargs['tmp_dir'] = tempdir()
assert rpc_device_key, "In rpc mode, a rpc_device_key must be provided" assert rpc_device_key, "In rpc mode, a rpc_device_key must be provided"
elif mode == "custom": elif mode == "custom":
assert callable(custom_measure_batch), "In custom mode, custom_measure_func " \ assert callable(custom_measure_batch), "In custom mode, custom_measure_func " \
...@@ -243,7 +240,7 @@ def create_measure_batch(task, options): ...@@ -243,7 +240,7 @@ def create_measure_batch(task, options):
tvm_buf = [nd.array(x) for x in ref_input] tvm_buf = [nd.array(x) for x in ref_input]
func(*tvm_buf) func(*tvm_buf)
ref_output = [x.asnumpy() for x in tvm_buf] ref_output = [x.asnumpy() for x in tvm_buf]
kwargs['ref_input'], kwargs['ref_outpu'] = ref_input, ref_output kwargs['ref_input'], kwargs['ref_output'] = ref_input, ref_output
def measure_batch(measure_inputs): def measure_batch(measure_inputs):
"""measure the time cost for a batch of configs in real machines""" """measure the time cost for a batch of configs in real machines"""
......
...@@ -12,7 +12,7 @@ from random import getrandbits ...@@ -12,7 +12,7 @@ from random import getrandbits
import numpy as np import numpy as np
from ...contrib import ndk, nvcc from ...contrib import ndk, nvcc, util
from ... import rpc, ir_pass, build, build_config, nd, context, TVMError, register_func from ... import rpc, ir_pass, build, build_config, nd, context, TVMError, register_func
from ..util import get_const_tuple from ..util import get_const_tuple
...@@ -113,8 +113,8 @@ def _measure_generic(fbuild, input_pack, ref_input, ref_output): ...@@ -113,8 +113,8 @@ def _measure_generic(fbuild, input_pack, ref_input, ref_output):
if ref_input: if ref_input:
args = [nd.array(x, ctx) for x in ref_input] args = [nd.array(x, ctx) for x in ref_input]
else: else:
args = [nd.array(np.random.uniform(size=get_const_tuple(x.shape)).astype(x.dtype), args = [nd.empty(get_const_tuple(x.shape), dtype=x.dtype,
ctx) for x in arg_bufs] ctx=ctx) for x in arg_bufs]
costs = time_f(*args).results costs = time_f(*args).results
if len(costs) > 2: # remove largest and smallest value to reduce variance if len(costs) > 2: # remove largest and smallest value to reduce variance
costs = list(costs) costs = list(costs)
...@@ -173,7 +173,6 @@ def measure_rpc(input_pack, ...@@ -173,7 +173,6 @@ def measure_rpc(input_pack,
rpc_tracker_addr=None, rpc_tracker_addr=None,
rpc_priority=1, rpc_priority=1,
rpc_timeout=60, rpc_timeout=60,
tmp_dir=None,
**kwargs): **kwargs):
"""Measure the time cost on a device by rpc """Measure the time cost on a device by rpc
...@@ -198,9 +197,6 @@ def measure_rpc(input_pack, ...@@ -198,9 +197,6 @@ def measure_rpc(input_pack,
rpc_timeout: int, optional rpc_timeout: int, optional
timeout of the rpc session timeout of the rpc session
tmp_dir: tvm.contrib.util.TempDirectory, optional
directory to store temp file
kwargs: dict, optional kwargs: dict, optional
Additional key word arguments Additional key word arguments
...@@ -213,6 +209,7 @@ def measure_rpc(input_pack, ...@@ -213,6 +209,7 @@ def measure_rpc(input_pack,
""" Local build function.""" """ Local build function."""
func, args = _build_func(inp, build_option, kwargs) func, args = _build_func(inp, build_option, kwargs)
tmp_dir = util.tempdir()
if not kwargs.get('use_ndk', False): if not kwargs.get('use_ndk', False):
file_name = "tmp_func_%0x.tar" % getrandbits(64) file_name = "tmp_func_%0x.tar" % getrandbits(64)
path = tmp_dir.relpath(file_name) path = tmp_dir.relpath(file_name)
......
...@@ -9,11 +9,12 @@ import multiprocessing ...@@ -9,11 +9,12 @@ import multiprocessing
import pickle import pickle
import json import json
import time import time
import os
from collections import OrderedDict from collections import OrderedDict
import numpy as np import numpy as np
from .. import target, build, lower from .. import build, lower, target as _target
from . import task from . import task
from .task import DispatchContext, ConfigEntity from .task import DispatchContext, ConfigEntity
...@@ -26,6 +27,11 @@ try: # convert unicode to str for python2 ...@@ -26,6 +27,11 @@ try: # convert unicode to str for python2
except NameError: except NameError:
_unicode = () _unicode = ()
try:
_long = long
except NameError:
_long = int
def measure_str_key(inp, include_config=True): def measure_str_key(inp, include_config=True):
""" get unique str key for MeasureInput """ get unique str key for MeasureInput
...@@ -111,7 +117,7 @@ def decode(row, protocol='json'): ...@@ -111,7 +117,7 @@ def decode(row, protocol='json'):
if protocol == 'json': if protocol == 'json':
row = json.loads(row) row = json.loads(row)
tgt, task_name, task_args, task_kwargs, workload, config = row['i'] tgt, task_name, task_args, task_kwargs, workload, config = row['i']
tgt = target.create(str(tgt)) tgt = _target.create(str(tgt))
def clean_json_to_python(x): def clean_json_to_python(x):
"""1. convert all list in x to tuple (hashable) """1. convert all list in x to tuple (hashable)
...@@ -121,6 +127,8 @@ def decode(row, protocol='json'): ...@@ -121,6 +127,8 @@ def decode(row, protocol='json'):
return tuple([clean_json_to_python(a) for a in x]) return tuple([clean_json_to_python(a) for a in x])
if isinstance(x, _unicode): if isinstance(x, _unicode):
return str(x) return str(x)
if isinstance(x, (_long, int)):
return int(x)
return x return x
tsk = task.Task(clean_json_to_python(task_name), clean_json_to_python(task_args)) tsk = task.Task(clean_json_to_python(task_name), clean_json_to_python(task_args))
...@@ -132,7 +140,7 @@ def decode(row, protocol='json'): ...@@ -132,7 +140,7 @@ def decode(row, protocol='json'):
return inp, result return inp, result
elif protocol == 'pickle': elif protocol == 'pickle':
items = row.split("\t") items = row.split("\t")
tgt = target.create(items[0]) tgt = _target.create(items[0])
task_tuple = pickle.loads(base64.b64decode(items[1].encode())) task_tuple = pickle.loads(base64.b64decode(items[1].encode()))
config = pickle.loads(base64.b64decode(items[2].encode())) config = pickle.loads(base64.b64decode(items[2].encode()))
result = pickle.loads(base64.b64decode(items[3].encode())) result = pickle.loads(base64.b64decode(items[3].encode()))
...@@ -168,36 +176,70 @@ class ApplyHistoryBest(DispatchContext): ...@@ -168,36 +176,70 @@ class ApplyHistoryBest(DispatchContext):
---------- ----------
records : str or iterator of (MeasureInput, MeasureResult) records : str or iterator of (MeasureInput, MeasureResult)
Collection of tuning records. Collection of tuning records.
if is str, then it should be the filename of a records log file. If is str, then it should be the filename of a records log file.
Each row of this file is an encoded record pair. Each row of this file is an encoded record pair.
otherwise, it is an iterator Otherwise, it is an iterator.
default: ConfigEntity, optional default: ConfigEntity, optional
default config to return when no history records The default config to return when no history records
""" """
def __init__(self, records, default=None): def __init__(self, records, default=None):
super(ApplyHistoryBest, self).__init__() super(ApplyHistoryBest, self).__init__()
self.best_by_targetkey = {}
self.best_by_model = {}
self._default = default
self.load(records)
def load(self, records):
"""Load records to this dispatch context
Parameters
----------
records : str or iterator of (MeasureInput, MeasureResult)
Collection of tuning records.
If is str, then it should be the filename of a records log file.
Each row of this file is an encoded record pair.
Otherwise, it is an iterator.
"""
if isinstance(records, str): if isinstance(records, str):
records = load_from_file(records) records = load_from_file(records)
if not records:
return
best_by_targetkey = self.best_by_targetkey
best_by_model = self.best_by_model
counter = 0 counter = 0
best_map = {}
for inp, res in records: for inp, res in records:
counter += 1 counter += 1
if res.error_no != 0: if res.error_no != 0:
continue continue
# use target keys in tvm target system as key to build best map
for k in inp.target.keys: for k in inp.target.keys:
key = (k, inp.task.workload) key = (k, inp.task.workload)
if key not in best_map: if key not in best_by_targetkey:
best_map[key] = (inp, res) best_by_targetkey[key] = (inp, res)
else: else:
_, other_res = best_map[key] _, other_res = best_by_targetkey[key]
if np.mean(other_res.costs) > np.mean(res.costs): if np.mean(other_res.costs) > np.mean(res.costs):
best_map[key] = (inp, res) best_by_targetkey[key] = (inp, res)
logging.info(
"Finish load %d records, %d entries selected", counter, len(best_map)) # use model as key to build best map
self._best_map = best_map for opt in inp.target.options:
self._default = default if opt.startswith("-model"):
model = opt[7:]
key = (model, inp.task.workload)
if key not in best_by_model:
best_by_model[key] = (inp, res)
else:
_, other_res = best_by_model[key]
if np.mean(other_res.costs) > np.mean(res.costs):
best_by_model[key] = (inp, res)
break
logging.info("Finish loading %d records", counter)
def query(self, target, workload): def query(self, target, workload):
if target is None: if target is None:
...@@ -205,29 +247,25 @@ class ApplyHistoryBest(DispatchContext): ...@@ -205,29 +247,25 @@ class ApplyHistoryBest(DispatchContext):
"Hint: If your target is llvm, use `with tvm.target.create('llvm'):`" "Hint: If your target is llvm, use `with tvm.target.create('llvm'):`"
" above the dispatcher call. So does other target. ") " above the dispatcher call. So does other target. ")
# first try matching by model
for opt in target.options:
if opt.startswith("-model"):
model = opt[7:]
key = (model, workload)
if key in self.best_by_model:
return self.best_by_model[key][0].config
# then try matching by target key
for k in target.keys: for k in target.keys:
key = (k, workload) key = (k, workload)
if key in self._best_map: if key in self.best_by_targetkey:
return self._best_map[key][0].config return self.best_by_targetkey[key][0].config
if self._default: if self._default:
return self._default return self._default
raise RuntimeError( raise RuntimeError(
"Cannot find config for target=%s, workload=%s" % (target, workload)) "Cannot find config for target=%s, workload=%s" % (target, workload))
def dump_best(self, out_file):
"""Dump the best records for each workload to a file
Parameters
----------
out_file: str
filename
"""
fout = open(out_file, 'a')
for val in self._best_map.values():
inp, res = val
fout.write(encode(inp, res) + '\n')
def split_workload(in_file, clean=True): def split_workload(in_file, clean=True):
"""Split a log file into separate files, each of which contains only a single workload """Split a log file into separate files, each of which contains only a single workload
...@@ -243,7 +281,7 @@ def split_workload(in_file, clean=True): ...@@ -243,7 +281,7 @@ def split_workload(in_file, clean=True):
tic = time.time() tic = time.time()
lines = list(open(in_file).readlines()) lines = list(open(in_file).readlines())
logging.info("start convert...") logging.info("start converting...")
pool = multiprocessing.Pool() pool = multiprocessing.Pool()
lines = pool.map(decode, lines) lines = pool.map(decode, lines)
logging.info("map done %.2f", time.time() - tic) logging.info("map done %.2f", time.time() - tic)
...@@ -279,23 +317,69 @@ def split_workload(in_file, clean=True): ...@@ -279,23 +317,69 @@ def split_workload(in_file, clean=True):
for inp, res in v: for inp, res in v:
fout.write(encode(inp, res) + '\n') fout.write(encode(inp, res) + '\n')
def pick_best(in_file, out_file):
"""
Pick best entries from a file and store it to another file.
This distill the useful log entries from a large log file.
Parameters
----------
in_file: str
The filename of input
out_file:
The filename of output
"""
best_context = ApplyHistoryBest(load_from_file(in_file))
best_set = set()
for v in best_context.best_by_model.values():
best_set.add(measure_str_key(v[0]))
for v in best_context.best_by_targetkey.values():
best_set.add(measure_str_key(v[0]))
logging.info("Extract %d best records from the log file", len(best_set))
fout = open(out_file, 'w')
for inp, res in load_from_file(in_file):
if measure_str_key(inp) in best_set:
fout.write(encode(inp, res) + "\n")
def load_op_param(rootpath=os.path.join(os.path.expanduser('~'), ".tvm", "op_params")):
"""Load pre-tuned parameters of operators.
This function will load all "*.log" file under root path and select best configs.
Parameters
----------
rootpath: str
The root path of stored parameters
"""
best_context = ApplyHistoryBest([])
for dirpath, _, filenames in os.walk(rootpath):
for filename in filenames:
if os.path.splitext(filename)[1] == '.log':
best_context.load(os.path.join(dirpath, filename))
assert not DispatchContext.current, "Cannot load pre-tuned parameters inside a dispatch context"
DispatchContext.current = best_context
""" """
Usage: Usage:
This record executable module has three modes. This record executable module has three modes.
* Print log file in readable format * Print log file in readable format
e.g. python -m autotvm.record --mode read --i collect_conv.tsv --begin 0 --end 5 --ir --code e.g. python -m autotvm.record --mode read --i collect_conv.log --begin 0 --end 5 --ir --code
* Extract history best from a large log file * Extract history best from a large log file
e.g. python -m autotvm.record --mode best --i collect.tsv e.g. python -m autotvm.record --mode pick --i collect.log
* Split a log file into separate files, each of which contains only a single wkl * Split a log file into separate files, each of which contains only a single wkl
e.g. python -m autotvm.record --mode split --i collect.tsv e.g. python -m autotvm.record --mode split --i collect.log
""" """
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--mode", choices=['read', 'best', 'split'], default='read') parser.add_argument("--mode", choices=['read', 'pick', 'split'], default='read')
parser.add_argument("--i", type=str, help="input file") parser.add_argument("--i", type=str, help="input file")
parser.add_argument("--o", type=str, default=None, help='output file') parser.add_argument("--o", type=str, default=None, help='output file')
parser.add_argument("--begin", type=int, default=0) parser.add_argument("--begin", type=int, default=0)
...@@ -306,10 +390,9 @@ if __name__ == '__main__': ...@@ -306,10 +390,9 @@ if __name__ == '__main__':
args = parser.parse_args() args = parser.parse_args()
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
if args.mode == 'best': if args.mode == 'pick':
args.o = args.o or args.i + ".best" args.o = args.o or args.i + ".best.log"
hist_best = ApplyHistoryBest(load_from_file(args.i)) pick_best(args.i, args.o)
hist_best.dump_best(args.o)
elif args.mode == 'read': elif args.mode == 'read':
for i, (inp, result) in enumerate(load_from_file(args.i)): for i, (inp, result) in enumerate(load_from_file(args.i)):
if args.begin <= i < args.end: if args.begin <= i < args.end:
......
...@@ -6,7 +6,7 @@ This module defines the task data structure, as well as a collection(zoo) ...@@ -6,7 +6,7 @@ This module defines the task data structure, as well as a collection(zoo)
of typical tasks of interest. of typical tasks of interest.
""" """
from .task import Task, create, register, template, get_config from .task import Task, create, register, template, get_config, args_to_workload
from .space import ConfigSpace, ConfigEntity from .space import ConfigSpace, ConfigEntity
from .code_hash import attach_code_hash, attach_code_hash_to_arg from .code_hash import attach_code_hash, attach_code_hash_to_arg
from .dispatcher import DispatchContext, ApplyConfig, dispatcher from .dispatcher import DispatchContext, ApplyConfig, dispatcher
...@@ -68,6 +68,33 @@ class Task(object): ...@@ -68,6 +68,33 @@ class Task(object):
self.flop = config.flop self.flop = config.flop
return sch, arg_bufs return sch, arg_bufs
def __getstate__(self):
# custom pickle implementation is required for
# some unpickable local task functions.
# So we only pickle the name of the function
# and restore the function by name when unpickling it.
return {
"name": self.name,
"args": self.args,
"kwargs": self.kwargs,
"config_space": self.config_space,
"workload": self.workload,
"flop": self.flop,
"target": self.target,
"target_host": self.target_host
}
def __setstate__(self, state):
self.name = state["name"]
self.args = state["args"]
self.kwargs = state["kwargs"]
self.config_space = state["config_space"]
self.func = TASK_TABLE.get(state["name"], _raise_error)
self.workload = state["workload"]
self.flop = state["flop"]
self.target = state["target"]
self.target_host = state["target_host"]
def __repr__(self): def __repr__(self):
return "Task(func_name=%s, args=%s, kwargs=%s, workload=%s)" % ( return "Task(func_name=%s, args=%s, kwargs=%s, workload=%s)" % (
self.name, self.args, self.kwargs, self.workload self.name, self.args, self.kwargs, self.workload
......
...@@ -264,12 +264,23 @@ class ModelBasedTuner(Tuner): ...@@ -264,12 +264,23 @@ class ModelBasedTuner(Tuner):
self.train_ct += 1 self.train_ct += 1
def load_history(self, data_set): def load_history(self, data_set):
# filter data, only pick the data with a same task
data = []
for inp, res in data_set:
if inp.task.name == self.task.name and \
inp.config.template_key == self.task.config_space.template_key:
data.append((inp, res))
if not data:
return
# fit base model
base_model = self.cost_model.clone_new() base_model = self.cost_model.clone_new()
base_model.fit_log(data_set, self.plan_size) base_model.fit_log(data, self.plan_size)
# use base model to select initial points
if not self.trials: if not self.trials:
# no plan yet, use base model to select initial trials # no plan yet, use base model to select initial trials
maximums = self.model_optimizer.find_maximums(base_model, self.visited) maximums = self.model_optimizer.find_maximums(base_model, self.plan_size, self.visited)
self.trials = maximums self.trials = maximums
self.trial_pt = 0 self.trial_pt = 0
......
...@@ -30,7 +30,7 @@ class SimulatedAnnealingOptimizer(ModelOptimizer): ...@@ -30,7 +30,7 @@ class SimulatedAnnealingOptimizer(ModelOptimizer):
Print log every `verbose` iterations Print log every `verbose` iterations
""" """
def __init__(self, task, n_iter=500, temp=(1, 0), persistent=True, parallel_size=128, def __init__(self, task, n_iter=500, temp=(1, 0), persistent=True, parallel_size=128,
early_stop=30, verbose=50): early_stop=50, verbose=50):
super(SimulatedAnnealingOptimizer, self).__init__() super(SimulatedAnnealingOptimizer, self).__init__()
self.task = task self.task = task
...@@ -39,8 +39,8 @@ class SimulatedAnnealingOptimizer(ModelOptimizer): ...@@ -39,8 +39,8 @@ class SimulatedAnnealingOptimizer(ModelOptimizer):
self.n_iter = n_iter self.n_iter = n_iter
self.temp = temp self.temp = temp
self.persistent = persistent self.persistent = persistent
self.parallel_size = parallel_size self.parallel_size = min(parallel_size, len(self.task.config_space))
self.early_stop = early_stop self.early_stop = early_stop or 1e9
self.verbose = verbose self.verbose = verbose
self.points = None self.points = None
......
...@@ -27,6 +27,7 @@ class Tuner(object): ...@@ -27,6 +27,7 @@ class Tuner(object):
self.best_config = None self.best_config = None
self.best_flops = 0 self.best_flops = 0
self.best_measure_pair = None self.best_measure_pair = None
self.best_iter = 0
def has_next(self): def has_next(self):
"""Whether has next untried config in the space """Whether has next untried config in the space
...@@ -63,7 +64,7 @@ class Tuner(object): ...@@ -63,7 +64,7 @@ class Tuner(object):
""" """
pass pass
def tune(self, n_trial, measure_option, verbose=1, callbacks=()): def tune(self, n_trial, measure_option, early_stop=None, verbose=1, callbacks=()):
"""Begin tuning """Begin tuning
Parameters Parameters
...@@ -73,6 +74,8 @@ class Tuner(object): ...@@ -73,6 +74,8 @@ class Tuner(object):
measure_option: dict measure_option: dict
The options for how to measure generated code. The options for how to measure generated code.
You should use the return value ot autotvm.measure_option for this argument. You should use the return value ot autotvm.measure_option for this argument.
early_stop: int
Early stop the tuning when not finding better configs in this number of trials
verbose: int verbose: int
0: silent mode, no output 0: silent mode, no output
1: print every measurement result 1: print every measurement result
...@@ -84,6 +87,7 @@ class Tuner(object): ...@@ -84,6 +87,7 @@ class Tuner(object):
""" """
measure_batch = create_measure_batch(self.task, measure_option) measure_batch = create_measure_batch(self.task, measure_option)
parallel_num = getattr(measure_batch, 'parallel_num', 1) parallel_num = getattr(measure_batch, 'parallel_num', 1)
early_stop = early_stop or 1e9
i = 0 i = 0
while i < n_trial: while i < n_trial:
...@@ -107,6 +111,7 @@ class Tuner(object): ...@@ -107,6 +111,7 @@ class Tuner(object):
self.best_flops = flops self.best_flops = flops
self.best_config = config self.best_config = config
self.best_measure_pair = (inp, res) self.best_measure_pair = (inp, res)
self.best_iter = i + k
logging.info("No: %d\tGFLOPS: %.2f/%.2f\tresult: %s\t%s", logging.info("No: %d\tGFLOPS: %.2f/%.2f\tresult: %s\t%s",
i + k + 1, flops / 1e9, self.best_flops / 1e9, i + k + 1, flops / 1e9, self.best_flops / 1e9,
...@@ -119,6 +124,10 @@ class Tuner(object): ...@@ -119,6 +124,10 @@ class Tuner(object):
for callback in callbacks: for callback in callbacks:
callback(self, inputs, results) callback(self, inputs, results)
if i > self.best_iter + early_stop:
logging.info("Early stopped. Best iter: %d.", self.best_iter)
break
del measure_batch del measure_batch
def reset(self): def reset(self):
......
...@@ -111,6 +111,9 @@ class XGBoostCostModel(CostModel): ...@@ -111,6 +111,9 @@ class XGBoostCostModel(CostModel):
self.feature_extra_ct = 0 self.feature_extra_ct = 0
self.pool = None self.pool = None
self.base_model = None self.base_model = None
self.upper_model = None
self._sample_size = 0
self._reset_pool() self._reset_pool()
...@@ -127,20 +130,25 @@ class XGBoostCostModel(CostModel): ...@@ -127,20 +130,25 @@ class XGBoostCostModel(CostModel):
_extract_task = self.task _extract_task = self.task
self.pool = multiprocessing.Pool(self.num_threads) self.pool = multiprocessing.Pool(self.num_threads)
def _base_model_discount(self):
return 1.0 / (2 ** (self._sample_size / 50.0))
def fit(self, xs, ys, plan_size): def fit(self, xs, ys, plan_size):
tic = time.time() tic = time.time()
self._reset_pool() self._reset_pool()
x_train = self._get_feature(xs) x_train = self._get_feature(xs)
y_train = np.array(ys) y_train = np.array(ys)
y_train /= np.max(y_train) y_train = y_train / np.max(y_train)
valid_index = y_train > 1e-6 valid_index = y_train > 1e-6
index = np.random.permutation(len(x_train)) index = np.random.permutation(len(x_train))
dtrain = xgb.DMatrix(x_train[index], y_train[index]) dtrain = xgb.DMatrix(x_train[index], y_train[index])
self._sample_size = len(x_train)
if self.base_model: if self.base_model:
dtrain.set_base_margin(self.base_model.predict(xs, output_margin=True)) dtrain.set_base_margin(self._base_model_discount() *
self.base_model.predict(xs, output_margin=True))
self.bst = xgb.train(self.xgb_params, dtrain, self.bst = xgb.train(self.xgb_params, dtrain,
num_boost_round=8000, num_boost_round=8000,
...@@ -164,6 +172,7 @@ class XGBoostCostModel(CostModel): ...@@ -164,6 +172,7 @@ class XGBoostCostModel(CostModel):
self._reset_pool() self._reset_pool()
args = list(records) args = list(records)
logging.info("Load %d entries from history log file", len(args))
if self.fea_type == 'itervar': if self.fea_type == 'itervar':
feature_extract_func = _extract_itervar_feature_log feature_extract_func = _extract_itervar_feature_log
elif self.fea_type == 'knob': elif self.fea_type == 'knob':
...@@ -185,7 +194,7 @@ class XGBoostCostModel(CostModel): ...@@ -185,7 +194,7 @@ class XGBoostCostModel(CostModel):
plan_size *= 2 plan_size *= 2
self.bst = xgb.train(self.xgb_params, dtrain, self.bst = xgb.train(self.xgb_params, dtrain,
num_boost_round=200, num_boost_round=400,
callbacks=[custom_callback( callbacks=[custom_callback(
stopping_rounds=100, stopping_rounds=100,
metric='tr-a-recall@%d' % plan_size, metric='tr-a-recall@%d' % plan_size,
...@@ -203,12 +212,23 @@ class XGBoostCostModel(CostModel): ...@@ -203,12 +212,23 @@ class XGBoostCostModel(CostModel):
dtest = xgb.DMatrix(feas) dtest = xgb.DMatrix(feas)
if self.base_model: if self.base_model:
dtest.set_base_margin(self.base_model.predict(xs, output_margin=True)) dtest.set_base_margin(self._base_model_discount() *
self.base_model.predict(xs, output_margin=True))
return self.bst.predict(dtest, output_margin=output_margin) return self.bst.predict(dtest, output_margin=output_margin)
def load_basemodel(self, base_model): def load_basemodel(self, base_model):
self.base_model = base_model self.base_model = base_model
if isinstance(base_model, XGBoostCostModel):
# share feature cache
base_model.feature_cache = self.feature_cache
# close thread pool
if base_model.pool:
base_model.pool.terminate()
base_model.pool.join()
del base_model.pool
self.base_model.upper_model = self
def clone_new(self): def clone_new(self):
return XGBoostCostModel(self.task, self.fea_type, self.loss_type, return XGBoostCostModel(self.task, self.fea_type, self.loss_type,
...@@ -226,7 +246,8 @@ class XGBoostCostModel(CostModel): ...@@ -226,7 +246,8 @@ class XGBoostCostModel(CostModel):
need_extract = [x for x in indexes if x not in fea_cache] need_extract = [x for x in indexes if x not in fea_cache]
if need_extract: if need_extract:
feas = self.pool.map(self.feature_extract_func, need_extract) pool = self.pool if self.upper_model is None else self.upper_model.pool
feas = pool.map(self.feature_extract_func, need_extract)
for i, fea in zip(need_extract, feas): for i, fea in zip(need_extract, feas):
fea_cache[i] = fea fea_cache[i] = fea
......
...@@ -346,6 +346,7 @@ def generic_func(fdefault): ...@@ -346,6 +346,7 @@ def generic_func(fdefault):
return func(*args, **kwargs) return func(*args, **kwargs)
fdecorate = decorate(fdefault, dispatch_func) fdecorate = decorate(fdefault, dispatch_func)
fdecorate.register = register fdecorate.register = register
fdecorate.fdefault = fdefault
return fdecorate return fdecorate
......
import time
import numpy as np
import tvm
from tvm import autotvm
from tvm.autotvm import MeasureInput, MeasureResult
from tvm.autotvm.tuner.xgboost_cost_model import XGBoostCostModel
from test_autotvm_common import get_sample_task, get_sample_records
def test_fit():
task, target = get_sample_task()
records = get_sample_records(n=100)
base_model = XGBoostCostModel(task, feature_type='itervar', loss_type='rank')
base_model.fit_log(records, plan_size=32)
upper_model = XGBoostCostModel(task, feature_type='itervar', loss_type='rank')
upper_model.load_basemodel(base_model)
xs = np.arange(100)
ys = np.arange(100)
upper_model.fit(xs, ys, plan_size=32)
def test_tuner():
task, target = get_sample_task()
records = get_sample_records(n=100)
tuner = autotvm.tuner.XGBTuner(task)
tuner.load_history(records)
if __name__ == "__main__":
test_fit()
test_tuner()
""" """
How to get high performance convolution kernel on NVIDIA GPU by auto-tuning Tuning High Performance Convolution on NVIDIA GPUs
========================================================================= =========================================================================
**Author**: `Lianmin Zheng <https://https://github.com/merrymercy>`_ **Author**: `Lianmin Zheng <https://https://github.com/merrymercy>`_
...@@ -10,9 +10,11 @@ vendor provided library CuDNN in many cases. ...@@ -10,9 +10,11 @@ vendor provided library CuDNN in many cases.
import logging import logging
import sys import sys
import numpy as np
import tvm import tvm
import topi import topi
from topi.testing import conv2d_nchw_python
from tvm import autotvm from tvm import autotvm
...@@ -134,8 +136,9 @@ def conv2d_no_batching(N, H, W, CI, CO, KH, KW, stride, padding): ...@@ -134,8 +136,9 @@ def conv2d_no_batching(N, H, W, CI, CO, KH, KW, stride, padding):
logging.basicConfig(level=logging.INFO, stream=sys.stdout) logging.basicConfig(level=logging.INFO, stream=sys.stdout)
# the last layer in resnet # the last layer in resnet
N, H, W, CO, CI, KH, KW, strides, padding = 1, 7, 7, 512, 512, 3, 3, (1, 1), (1, 1)
task = autotvm.task.create(conv2d_no_batching, task = autotvm.task.create(conv2d_no_batching,
args=(1, 7, 7, 512, 512, 3, 3, (1, 1), (1, 1)), args=(N, H, W, CO, CI, KH, KW, strides, padding),
target='cuda') target='cuda')
print(task.config_space) print(task.config_space)
...@@ -146,15 +149,43 @@ measure_option = autotvm.measure_option(mode='local', ...@@ -146,15 +149,43 @@ measure_option = autotvm.measure_option(mode='local',
parallel_num=8, parallel_num=8,
timeout=20) timeout=20)
# begin tuning, log records to file `cache.tsv` # begin tuning, log records to file `conv2d.tsv`
tuner = autotvm.tuner.XGBTuner(task) tuner = autotvm.tuner.XGBTuner(task)
tuner.tune(n_trial=20, tuner.tune(n_trial=20,
measure_option=measure_option, measure_option=measure_option,
callbacks=[autotvm.callback.log_to_file('cache.tsv')]) callbacks=[autotvm.callback.log_to_file('conv2d.log')])
# get best config from cache file #########################################################################
dispatch_context = autotvm.apply_history_best("cache.tsv") # Finally we can inspect the best config from log file, check correctness,
# and measure running time.
# inspect the best config
dispatch_context = autotvm.apply_history_best("conv2d.log")
best_config = dispatch_context.query(task.target, task.workload) best_config = dispatch_context.query(task.target, task.workload)
print("\nBest config:") print("\nBest config:")
print(best_config) print(best_config)
# apply history best from log file
with autotvm.apply_history_best('conv2d.log'):
with tvm.target.create("cuda"):
s, arg_bufs = conv2d_no_batching(N, H, W, CO, CI, KH, KW, strides, padding)
func = tvm.build(s, arg_bufs)
# check correctness
a_np = np.random.uniform(size=(N, CI, H, W)).astype(np.float32)
w_np = np.random.uniform(size=(CO, CI, KH, KW)).astype(np.float32)
c_np = conv2d_nchw_python(a_np, w_np, strides, padding)
ctx = tvm.gpu()
a_tvm = tvm.nd.array(a_np, ctx=ctx)
w_tvm = tvm.nd.array(w_np, ctx=ctx)
c_tvm = tvm.nd.empty(c_np.shape, ctx=ctx)
func(a_tvm, w_tvm, c_tvm)
np.testing.assert_allclose(c_np, c_tvm.asnumpy(), rtol=1e-2)
# Evaluate running time. Here we choose a large repeat number (200) to reduce the noise
# and the overhead of kernel launch. You can also use nvprof to validate the result.
evaluator = func.time_evaluator(func.entry_name, ctx, number=200)
print('Time cost of this operator: %f' % evaluator(a_tvm, w_tvm, c_tvm).mean)
...@@ -243,7 +243,7 @@ print(task.config_space) ...@@ -243,7 +243,7 @@ print(task.config_space)
# #
# We only make 10 trials in this tutorial for demonstration. In practice, # We only make 10 trials in this tutorial for demonstration. In practice,
# you can do more trials according to your time budget. # you can do more trials according to your time budget.
# We will log the tuning results into a cache file. This file can be # We will log the tuning results into a log file. This file can be
# used to get the best config later. # used to get the best config later.
# logging config (for printing tuning log to screen) # logging config (for printing tuning log to screen)
...@@ -253,11 +253,11 @@ logging.basicConfig(level=logging.INFO, stream=sys.stdout) ...@@ -253,11 +253,11 @@ logging.basicConfig(level=logging.INFO, stream=sys.stdout)
measure_option = autotvm.measure_option(mode='local', measure_option = autotvm.measure_option(mode='local',
number=5) number=5)
# begin tuning, log records to file `cache.tsv` # begin tuning, log records to file `matmul.log`
tuner = autotvm.tuner.RandomTuner(task) tuner = autotvm.tuner.RandomTuner(task)
tuner.tune(n_trial=10, tuner.tune(n_trial=10,
measure_option=measure_option, measure_option=measure_option,
callbacks=[autotvm.callback.log_to_file('cache.tsv')]) callbacks=[autotvm.callback.log_to_file('matmul.log')])
######################################################################### #########################################################################
# Finally we apply history best from the cache file and check its correctness. # Finally we apply history best from the cache file and check its correctness.
...@@ -267,7 +267,7 @@ tuner.tune(n_trial=10, ...@@ -267,7 +267,7 @@ tuner.tune(n_trial=10,
# with the same argument. # with the same argument.
# apply history best from log file # apply history best from log file
with autotvm.apply_history_best('cache.tsv'): with autotvm.apply_history_best('matmul.log'):
with tvm.target.create("llvm"): with tvm.target.create("llvm"):
s, arg_bufs = matmul(N, L, M, 'float32') s, arg_bufs = matmul(N, L, M, 'float32')
func = tvm.build(s, arg_bufs) func = tvm.build(s, arg_bufs)
...@@ -281,4 +281,3 @@ c_tvm = tvm.nd.empty(c_np.shape) ...@@ -281,4 +281,3 @@ c_tvm = tvm.nd.empty(c_np.shape)
func(tvm.nd.array(a_np), tvm.nd.array(b_np), c_tvm) func(tvm.nd.array(a_np), tvm.nd.array(b_np), c_tvm)
np.testing.assert_allclose(c_np, c_tvm.asnumpy(), rtol=1e-2) np.testing.assert_allclose(c_np, c_tvm.asnumpy(), rtol=1e-2)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment