Commit ad28f5ca by Lianmin Zheng, committed by Tianqi Chen

[AUTOTVM] Misc bug fix (#1467)

parent 9026f3fc
......@@ -6,7 +6,6 @@ from collections import namedtuple
import numpy as np
from ... import build, nd, target as _target
from ...contrib.util import tempdir
from ...rpc.tracker import Tracker
from ...rpc.server import Server
......@@ -209,14 +208,12 @@ def create_measure_batch(task, options):
kwargs['rpc_device_key'] = rpc_device_key
kwargs['rpc_tracker_addr'] = (tracker.host, tracker.port)
kwargs['rpc_timeout'] = timeout
kwargs['tmp_dir'] = tempdir()
elif mode == 'rpc':
fmeasure = measure_methods.measure_rpc
kwargs['rpc_device_key'] = rpc_device_key
kwargs['rpc_priority'] = rpc_priority
kwargs['rpc_timeout'] = rpc_timeout
kwargs['use_ndk'] = use_ndk
kwargs['tmp_dir'] = tempdir()
assert rpc_device_key, "In rpc mode, an rpc_device_key must be provided"
elif mode == "custom":
assert callable(custom_measure_batch), "In custom mode, custom_measure_func " \
......@@ -243,7 +240,7 @@ def create_measure_batch(task, options):
tvm_buf = [nd.array(x) for x in ref_input]
func(*tvm_buf)
ref_output = [x.asnumpy() for x in tvm_buf]
kwargs['ref_input'], kwargs['ref_outpu'] = ref_input, ref_output
kwargs['ref_input'], kwargs['ref_output'] = ref_input, ref_output
def measure_batch(measure_inputs):
"""measure the time cost for a batch of configs in real machines"""
......
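For reference, a minimal sketch of configuring the two modes handled above. The local options mirror the tutorials later in this commit; the rpc option names are inferred from the kwargs in this diff, and the device key 'android' and use_ndk=True are made-up values for illustration.

from tvm import autotvm

# local mode: compile and run on the local machine (numbers are illustrative)
local_option = autotvm.measure_option(mode='local',
                                      number=5,
                                      parallel_num=8,
                                      timeout=20)

# rpc mode: an rpc_device_key must be provided, as asserted above
rpc_option = autotvm.measure_option(mode='rpc',
                                    rpc_device_key='android',
                                    rpc_timeout=60,
                                    use_ndk=True)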
......@@ -12,7 +12,7 @@ from random import getrandbits
import numpy as np
from ...contrib import ndk, nvcc
from ...contrib import ndk, nvcc, util
from ... import rpc, ir_pass, build, build_config, nd, context, TVMError, register_func
from ..util import get_const_tuple
......@@ -113,8 +113,8 @@ def _measure_generic(fbuild, input_pack, ref_input, ref_output):
if ref_input:
args = [nd.array(x, ctx) for x in ref_input]
else:
args = [nd.array(np.random.uniform(size=get_const_tuple(x.shape)).astype(x.dtype),
ctx) for x in arg_bufs]
args = [nd.empty(get_const_tuple(x.shape), dtype=x.dtype,
ctx=ctx) for x in arg_bufs]
costs = time_f(*args).results
if len(costs) > 2: # remove largest and smallest value to reduce variance
costs = list(costs)
......@@ -173,7 +173,6 @@ def measure_rpc(input_pack,
rpc_tracker_addr=None,
rpc_priority=1,
rpc_timeout=60,
tmp_dir=None,
**kwargs):
"""Measure the time cost on a device by rpc
......@@ -198,9 +197,6 @@ def measure_rpc(input_pack,
rpc_timeout: int, optional
timeout of the rpc session
tmp_dir: tvm.contrib.util.TempDirectory, optional
directory to store temp file
kwargs: dict, optional
Additional key word arguments
......@@ -213,6 +209,7 @@ def measure_rpc(input_pack,
""" Local build function."""
func, args = _build_func(inp, build_option, kwargs)
tmp_dir = util.tempdir()
if not kwargs.get('use_ndk', False):
file_name = "tmp_func_%0x.tar" % getrandbits(64)
path = tmp_dir.relpath(file_name)
......
......@@ -6,7 +6,7 @@ This module defines the task data structure, as well as a collection(zoo)
of typical tasks of interest.
"""
from .task import Task, create, register, template, get_config
from .task import Task, create, register, template, get_config, args_to_workload
from .space import ConfigSpace, ConfigEntity
from .code_hash import attach_code_hash, attach_code_hash_to_arg
from .dispatcher import DispatchContext, ApplyConfig, dispatcher
......@@ -68,6 +68,33 @@ class Task(object):
self.flop = config.flop
return sch, arg_bufs
def __getstate__(self):
# custom pickle implementation is required for
# some unpicklable local task functions.
# So we only pickle the name of the function
# and restore the function by name when unpickling it.
return {
"name": self.name,
"args": self.args,
"kwargs": self.kwargs,
"config_space": self.config_space,
"workload": self.workload,
"flop": self.flop,
"target": self.target,
"target_host": self.target_host
}
def __setstate__(self, state):
self.name = state["name"]
self.args = state["args"]
self.kwargs = state["kwargs"]
self.config_space = state["config_space"]
self.func = TASK_TABLE.get(state["name"], _raise_error)
self.workload = state["workload"]
self.flop = state["flop"]
self.target = state["target"]
self.target_host = state["target_host"]
def __repr__(self):
return "Task(func_name=%s, args=%s, kwargs=%s, workload=%s)" % (
self.name, self.args, self.kwargs, self.workload
......
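For reference, a quick round-trip sketch of the pickling behavior implemented above: only the task's name is stored, and the template function is restored from TASK_TABLE on load, so tasks built from local (unpicklable) functions can still be pickled. get_sample_task is assumed to be the helper used by the tests in this commit.

import pickle
from test_autotvm_common import get_sample_task  # test helper

task, target = get_sample_task()
blob = pickle.dumps(task)          # the template function is dropped, only its name is kept
restored = pickle.loads(blob)      # the function is looked up in TASK_TABLE by name
assert restored.name == task.name and restored.workload == task.workload
# if the name is unknown on the loading side, the restored func is a placeholder
# that raises an error when the task is instantiated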
......@@ -264,12 +264,23 @@ class ModelBasedTuner(Tuner):
self.train_ct += 1
def load_history(self, data_set):
# filter data, only keep records from the same task and template
data = []
for inp, res in data_set:
if inp.task.name == self.task.name and \
inp.config.template_key == self.task.config_space.template_key:
data.append((inp, res))
if not data:
return
# fit base model
base_model = self.cost_model.clone_new()
base_model.fit_log(data_set, self.plan_size)
base_model.fit_log(data, self.plan_size)
# use base model to select initial points
if not self.trials:
# no plan yet, use base model to select initial trials
maximums = self.model_optimizer.find_maximums(base_model, self.visited)
maximums = self.model_optimizer.find_maximums(base_model, self.plan_size, self.visited)
self.trials = maximums
self.trial_pt = 0
......
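For reference, a usage sketch of load_history with mixed records: entries whose task name or template key differ from the current task are filtered out before fitting the base cost model. It mirrors the new test added later in this commit; the helper imports come from the test utilities.

from tvm import autotvm
from test_autotvm_common import get_sample_task, get_sample_records  # test helpers

task, target = get_sample_task()
records = get_sample_records(n=100)   # list of (MeasureInput, MeasureResult)

tuner = autotvm.tuner.XGBTuner(task)
tuner.load_history(records)           # records from other tasks are dropped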
......@@ -30,7 +30,7 @@ class SimulatedAnnealingOptimizer(ModelOptimizer):
Print log every `verbose` iterations
"""
def __init__(self, task, n_iter=500, temp=(1, 0), persistent=True, parallel_size=128,
early_stop=30, verbose=50):
early_stop=50, verbose=50):
super(SimulatedAnnealingOptimizer, self).__init__()
self.task = task
......@@ -39,8 +39,8 @@ class SimulatedAnnealingOptimizer(ModelOptimizer):
self.n_iter = n_iter
self.temp = temp
self.persistent = persistent
self.parallel_size = parallel_size
self.early_stop = early_stop
self.parallel_size = min(parallel_size, len(self.task.config_space))
self.early_stop = early_stop or 1e9
self.verbose = verbose
self.points = None
......
......@@ -27,6 +27,7 @@ class Tuner(object):
self.best_config = None
self.best_flops = 0
self.best_measure_pair = None
self.best_iter = 0
def has_next(self):
"""Whether has next untried config in the space
......@@ -63,7 +64,7 @@ class Tuner(object):
"""
pass
def tune(self, n_trial, measure_option, verbose=1, callbacks=()):
def tune(self, n_trial, measure_option, early_stop=None, verbose=1, callbacks=()):
"""Begin tuning
Parameters
......@@ -73,6 +74,8 @@ class Tuner(object):
measure_option: dict
The options for how to measure generated code.
You should use the return value of autotvm.measure_option for this argument.
early_stop: int, optional
Stop tuning early if no better config is found within this number of trials
verbose: int
0: silent mode, no output
1: print every measurement result
......@@ -84,6 +87,7 @@ class Tuner(object):
"""
measure_batch = create_measure_batch(self.task, measure_option)
parallel_num = getattr(measure_batch, 'parallel_num', 1)
early_stop = early_stop or 1e9
i = 0
while i < n_trial:
......@@ -107,6 +111,7 @@ class Tuner(object):
self.best_flops = flops
self.best_config = config
self.best_measure_pair = (inp, res)
self.best_iter = i + k
logging.info("No: %d\tGFLOPS: %.2f/%.2f\tresult: %s\t%s",
i + k + 1, flops / 1e9, self.best_flops / 1e9,
......@@ -119,6 +124,10 @@ class Tuner(object):
for callback in callbacks:
callback(self, inputs, results)
if i > self.best_iter + early_stop:
logging.info("Early stopped. Best iter: %d.", self.best_iter)
break
del measure_batch
def reset(self):
......
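For reference, a small sketch of the new early_stop argument together with the best_iter bookkeeping above: tuning stops once no better config has been found for early_stop consecutive trials. Here matmul is assumed to be the template defined in the tune_simple_template tutorial touched below.

from tvm import autotvm

task = autotvm.task.create(matmul, args=(512, 512, 512, 'float32'), target='llvm')
measure_option = autotvm.measure_option(mode='local', number=5)

tuner = autotvm.tuner.XGBTuner(task)
tuner.tune(n_trial=1000,
           early_stop=200,   # stop after 200 trials without a better config
           measure_option=measure_option,
           callbacks=[autotvm.callback.log_to_file('matmul.log')])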
......@@ -111,6 +111,9 @@ class XGBoostCostModel(CostModel):
self.feature_extra_ct = 0
self.pool = None
self.base_model = None
self.upper_model = None
self._sample_size = 0
self._reset_pool()
......@@ -127,20 +130,25 @@ class XGBoostCostModel(CostModel):
_extract_task = self.task
self.pool = multiprocessing.Pool(self.num_threads)
def _base_model_discount(self):
return 1.0 / (2 ** (self._sample_size / 50.0))
def fit(self, xs, ys, plan_size):
tic = time.time()
self._reset_pool()
x_train = self._get_feature(xs)
y_train = np.array(ys)
y_train /= np.max(y_train)
y_train = y_train / np.max(y_train)
valid_index = y_train > 1e-6
index = np.random.permutation(len(x_train))
dtrain = xgb.DMatrix(x_train[index], y_train[index])
self._sample_size = len(x_train)
if self.base_model:
dtrain.set_base_margin(self.base_model.predict(xs, output_margin=True))
dtrain.set_base_margin(self._base_model_discount() *
self.base_model.predict(xs, output_margin=True))
self.bst = xgb.train(self.xgb_params, dtrain,
num_boost_round=8000,
......@@ -164,6 +172,7 @@ class XGBoostCostModel(CostModel):
self._reset_pool()
args = list(records)
logging.info("Load %d entries from history log file", len(args))
if self.fea_type == 'itervar':
feature_extract_func = _extract_itervar_feature_log
elif self.fea_type == 'knob':
......@@ -185,7 +194,7 @@ class XGBoostCostModel(CostModel):
plan_size *= 2
self.bst = xgb.train(self.xgb_params, dtrain,
num_boost_round=200,
num_boost_round=400,
callbacks=[custom_callback(
stopping_rounds=100,
metric='tr-a-recall@%d' % plan_size,
......@@ -203,12 +212,23 @@ class XGBoostCostModel(CostModel):
dtest = xgb.DMatrix(feas)
if self.base_model:
dtest.set_base_margin(self.base_model.predict(xs, output_margin=True))
dtest.set_base_margin(self._base_model_discount() *
self.base_model.predict(xs, output_margin=True))
return self.bst.predict(dtest, output_margin=output_margin)
def load_basemodel(self, base_model):
self.base_model = base_model
if isinstance(base_model, XGBoostCostModel):
# share feature cache
base_model.feature_cache = self.feature_cache
# close thread pool
if base_model.pool:
base_model.pool.terminate()
base_model.pool.join()
del base_model.pool
self.base_model.upper_model = self
def clone_new(self):
return XGBoostCostModel(self.task, self.fea_type, self.loss_type,
......@@ -226,7 +246,8 @@ class XGBoostCostModel(CostModel):
need_extract = [x for x in indexes if x not in fea_cache]
if need_extract:
feas = self.pool.map(self.feature_extract_func, need_extract)
pool = self.pool if self.upper_model is None else self.upper_model.pool
feas = pool.map(self.feature_extract_func, need_extract)
for i, fea in zip(need_extract, feas):
fea_cache[i] = fea
......
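For reference, the transfer-learning wiring that exercises the new discount: _base_model_discount() returns 1 / 2**(sample_size / 50), i.e. 1.0 with no fresh samples, 0.5 after 50 and 0.25 after 100, so the transferred prior fades as new measurements accumulate. The sketch mirrors the new test added below; the helper imports come from the test utilities.

import numpy as np
from tvm.autotvm.tuner.xgboost_cost_model import XGBoostCostModel
from test_autotvm_common import get_sample_task, get_sample_records  # test helpers

task, target = get_sample_task()
records = get_sample_records(n=100)

base_model = XGBoostCostModel(task, feature_type='itervar', loss_type='rank')
base_model.fit_log(records, plan_size=32)      # fit a prior on history

upper_model = XGBoostCostModel(task, feature_type='itervar', loss_type='rank')
upper_model.load_basemodel(base_model)         # share feature cache, stop the base pool
upper_model.fit(np.arange(100), np.arange(100), plan_size=32)  # base margin scaled by the discount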
......@@ -346,6 +346,7 @@ def generic_func(fdefault):
return func(*args, **kwargs)
fdecorate = decorate(fdefault, dispatch_func)
fdecorate.register = register
fdecorate.fdefault = fdefault
return fdecorate
......
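For reference, a minimal sketch of the new fdefault attribute on generic_func: it lets callers reach the generic implementation even when a target-specific override is registered. The schedule functions here are hypothetical placeholders.

import tvm

@tvm.target.generic_func
def my_schedule(outs):
    # generic fallback implementation (hypothetical)
    return None

@my_schedule.register("cuda")
def _my_schedule_cuda(outs):
    # cuda-specific override (hypothetical)
    return None

with tvm.target.create("cuda"):
    my_schedule(None)            # dispatches to the cuda override
    my_schedule.fdefault(None)   # the new attribute: call the generic default directly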
import time
import numpy as np
import tvm
from tvm import autotvm
from tvm.autotvm import MeasureInput, MeasureResult
from tvm.autotvm.tuner.xgboost_cost_model import XGBoostCostModel
from test_autotvm_common import get_sample_task, get_sample_records
def test_fit():
task, target = get_sample_task()
records = get_sample_records(n=100)
base_model = XGBoostCostModel(task, feature_type='itervar', loss_type='rank')
base_model.fit_log(records, plan_size=32)
upper_model = XGBoostCostModel(task, feature_type='itervar', loss_type='rank')
upper_model.load_basemodel(base_model)
xs = np.arange(100)
ys = np.arange(100)
upper_model.fit(xs, ys, plan_size=32)
def test_tuner():
task, target = get_sample_task()
records = get_sample_records(n=100)
tuner = autotvm.tuner.XGBTuner(task)
tuner.load_history(records)
if __name__ == "__main__":
test_fit()
test_tuner()
"""
How to get high performance convolution kernel on NVIDIA GPU by auto-tuning
Tuning High Performance Convolution on NVIDIA GPUs
=========================================================================
**Author**: `Lianmin Zheng <https://github.com/merrymercy>`_
......@@ -10,9 +10,11 @@ vendor provided library CuDNN in many cases.
import logging
import sys
import numpy as np
import tvm
import topi
from topi.testing import conv2d_nchw_python
from tvm import autotvm
......@@ -133,9 +135,10 @@ def conv2d_no_batching(N, H, W, CI, CO, KH, KW, stride, padding):
# logging config (for printing tuning log to screen)
logging.basicConfig(level=logging.INFO, stream=sys.stdout)
# the last layer in resnet
# the last layer in resnet
N, H, W, CO, CI, KH, KW, strides, padding = 1, 7, 7, 512, 512, 3, 3, (1, 1), (1, 1)
task = autotvm.task.create(conv2d_no_batching,
args=(1, 7, 7, 512, 512, 3, 3, (1, 1), (1, 1)),
args=(N, H, W, CO, CI, KH, KW, strides, padding),
target='cuda')
print(task.config_space)
......@@ -146,15 +149,43 @@ measure_option = autotvm.measure_option(mode='local',
parallel_num=8,
timeout=20)
# begin tuning, log records to file `cache.tsv`
# begin tuning, log records to file `conv2d.tsv`
tuner = autotvm.tuner.XGBTuner(task)
tuner.tune(n_trial=20,
measure_option=measure_option,
callbacks=[autotvm.callback.log_to_file('cache.tsv')])
callbacks=[autotvm.callback.log_to_file('conv2d.log')])
# get best config from cache file
dispatch_context = autotvm.apply_history_best("cache.tsv")
#########################################################################
# Finally we can inspect the best config from the log file, check correctness,
# and measure running time.
# inspect the best config
dispatch_context = autotvm.apply_history_best("conv2d.log")
best_config = dispatch_context.query(task.target, task.workload)
print("\nBest config:")
print(best_config)
# apply history best from log file
with autotvm.apply_history_best('conv2d.log'):
with tvm.target.create("cuda"):
s, arg_bufs = conv2d_no_batching(N, H, W, CO, CI, KH, KW, strides, padding)
func = tvm.build(s, arg_bufs)
# check correctness
a_np = np.random.uniform(size=(N, CI, H, W)).astype(np.float32)
w_np = np.random.uniform(size=(CO, CI, KH, KW)).astype(np.float32)
c_np = conv2d_nchw_python(a_np, w_np, strides, padding)
ctx = tvm.gpu()
a_tvm = tvm.nd.array(a_np, ctx=ctx)
w_tvm = tvm.nd.array(w_np, ctx=ctx)
c_tvm = tvm.nd.empty(c_np.shape, ctx=ctx)
func(a_tvm, w_tvm, c_tvm)
np.testing.assert_allclose(c_np, c_tvm.asnumpy(), rtol=1e-2)
# Evaluate running time. Here we choose a large repeat number (200) to reduce the noise
# and the overhead of kernel launch. You can also use nvprof to validate the result.
evaluator = func.time_evaluator(func.entry_name, ctx, number=200)
print('Time cost of this operator: %f' % evaluator(a_tvm, w_tvm, c_tvm).mean)
......@@ -243,7 +243,7 @@ print(task.config_space)
#
# We only make 10 trials in this tutorial for demonstration. In practice,
# you can do more trials according to your time budget.
# We will log the tuning results into a cache file. This file can be
# We will log the tuning results into a log file. This file can be
# used to get the best config later.
# logging config (for printing tuning log to screen)
......@@ -253,11 +253,11 @@ logging.basicConfig(level=logging.INFO, stream=sys.stdout)
measure_option = autotvm.measure_option(mode='local',
number=5)
# begin tuning, log records to file `cache.tsv`
# begin tuning, log records to file `matmul.log`
tuner = autotvm.tuner.RandomTuner(task)
tuner.tune(n_trial=10,
measure_option=measure_option,
callbacks=[autotvm.callback.log_to_file('cache.tsv')])
callbacks=[autotvm.callback.log_to_file('matmul.log')])
#########################################################################
# Finally we apply history best from the log file and check its correctness.
......@@ -267,7 +267,7 @@ tuner.tune(n_trial=10,
# with the same argument.
# apply history best from log file
with autotvm.apply_history_best('cache.tsv'):
with autotvm.apply_history_best('matmul.log'):
with tvm.target.create("llvm"):
s, arg_bufs = matmul(N, L, M, 'float32')
func = tvm.build(s, arg_bufs)
......@@ -281,4 +281,3 @@ c_tvm = tvm.nd.empty(c_np.shape)
func(tvm.nd.array(a_np), tvm.nd.array(b_np), c_tvm)
np.testing.assert_allclose(c_np, c_tvm.asnumpy(), rtol=1e-2)