Commit ad28f5ca by Lianmin Zheng Committed by Tianqi Chen

[AUTOTVM] Misc bug fix (#1467)

parent 9026f3fc
...@@ -6,7 +6,6 @@ from collections import namedtuple ...@@ -6,7 +6,6 @@ from collections import namedtuple
import numpy as np import numpy as np
from ... import build, nd, target as _target from ... import build, nd, target as _target
from ...contrib.util import tempdir
from ...rpc.tracker import Tracker from ...rpc.tracker import Tracker
from ...rpc.server import Server from ...rpc.server import Server
...@@ -209,14 +208,12 @@ def create_measure_batch(task, options): ...@@ -209,14 +208,12 @@ def create_measure_batch(task, options):
kwargs['rpc_device_key'] = rpc_device_key kwargs['rpc_device_key'] = rpc_device_key
kwargs['rpc_tracker_addr'] = (tracker.host, tracker.port) kwargs['rpc_tracker_addr'] = (tracker.host, tracker.port)
kwargs['rpc_timeout'] = timeout kwargs['rpc_timeout'] = timeout
kwargs['tmp_dir'] = tempdir()
elif mode == 'rpc': elif mode == 'rpc':
fmeasure = measure_methods.measure_rpc fmeasure = measure_methods.measure_rpc
kwargs['rpc_device_key'] = rpc_device_key kwargs['rpc_device_key'] = rpc_device_key
kwargs['rpc_priority'] = rpc_priority kwargs['rpc_priority'] = rpc_priority
kwargs['rpc_timeout'] = rpc_timeout kwargs['rpc_timeout'] = rpc_timeout
kwargs['use_ndk'] = use_ndk kwargs['use_ndk'] = use_ndk
kwargs['tmp_dir'] = tempdir()
assert rpc_device_key, "In rpc mode, a rpc_device_key must be provided" assert rpc_device_key, "In rpc mode, a rpc_device_key must be provided"
elif mode == "custom": elif mode == "custom":
assert callable(custom_measure_batch), "In custom mode, custom_measure_func " \ assert callable(custom_measure_batch), "In custom mode, custom_measure_func " \
...@@ -243,7 +240,7 @@ def create_measure_batch(task, options): ...@@ -243,7 +240,7 @@ def create_measure_batch(task, options):
tvm_buf = [nd.array(x) for x in ref_input] tvm_buf = [nd.array(x) for x in ref_input]
func(*tvm_buf) func(*tvm_buf)
ref_output = [x.asnumpy() for x in tvm_buf] ref_output = [x.asnumpy() for x in tvm_buf]
kwargs['ref_input'], kwargs['ref_outpu'] = ref_input, ref_output kwargs['ref_input'], kwargs['ref_output'] = ref_input, ref_output
def measure_batch(measure_inputs): def measure_batch(measure_inputs):
"""measure the time cost for a batch of configs in real machines""" """measure the time cost for a batch of configs in real machines"""
......
...@@ -12,7 +12,7 @@ from random import getrandbits ...@@ -12,7 +12,7 @@ from random import getrandbits
import numpy as np import numpy as np
from ...contrib import ndk, nvcc from ...contrib import ndk, nvcc, util
from ... import rpc, ir_pass, build, build_config, nd, context, TVMError, register_func from ... import rpc, ir_pass, build, build_config, nd, context, TVMError, register_func
from ..util import get_const_tuple from ..util import get_const_tuple
...@@ -113,8 +113,8 @@ def _measure_generic(fbuild, input_pack, ref_input, ref_output): ...@@ -113,8 +113,8 @@ def _measure_generic(fbuild, input_pack, ref_input, ref_output):
if ref_input: if ref_input:
args = [nd.array(x, ctx) for x in ref_input] args = [nd.array(x, ctx) for x in ref_input]
else: else:
args = [nd.array(np.random.uniform(size=get_const_tuple(x.shape)).astype(x.dtype), args = [nd.empty(get_const_tuple(x.shape), dtype=x.dtype,
ctx) for x in arg_bufs] ctx=ctx) for x in arg_bufs]
costs = time_f(*args).results costs = time_f(*args).results
if len(costs) > 2: # remove largest and smallest value to reduce variance if len(costs) > 2: # remove largest and smallest value to reduce variance
costs = list(costs) costs = list(costs)
...@@ -173,7 +173,6 @@ def measure_rpc(input_pack, ...@@ -173,7 +173,6 @@ def measure_rpc(input_pack,
rpc_tracker_addr=None, rpc_tracker_addr=None,
rpc_priority=1, rpc_priority=1,
rpc_timeout=60, rpc_timeout=60,
tmp_dir=None,
**kwargs): **kwargs):
"""Measure the time cost on a device by rpc """Measure the time cost on a device by rpc
...@@ -198,9 +197,6 @@ def measure_rpc(input_pack, ...@@ -198,9 +197,6 @@ def measure_rpc(input_pack,
rpc_timeout: int, optional rpc_timeout: int, optional
timeout of the rpc session timeout of the rpc session
tmp_dir: tvm.contrib.util.TempDirectory, optional
directory to store temp file
kwargs: dict, optional kwargs: dict, optional
Additional key word arguments Additional key word arguments
...@@ -213,6 +209,7 @@ def measure_rpc(input_pack, ...@@ -213,6 +209,7 @@ def measure_rpc(input_pack,
""" Local build function.""" """ Local build function."""
func, args = _build_func(inp, build_option, kwargs) func, args = _build_func(inp, build_option, kwargs)
tmp_dir = util.tempdir()
if not kwargs.get('use_ndk', False): if not kwargs.get('use_ndk', False):
file_name = "tmp_func_%0x.tar" % getrandbits(64) file_name = "tmp_func_%0x.tar" % getrandbits(64)
path = tmp_dir.relpath(file_name) path = tmp_dir.relpath(file_name)
......
...@@ -6,7 +6,7 @@ This module defines the task data structure, as well as a collection(zoo) ...@@ -6,7 +6,7 @@ This module defines the task data structure, as well as a collection(zoo)
of typical tasks of interest. of typical tasks of interest.
""" """
from .task import Task, create, register, template, get_config from .task import Task, create, register, template, get_config, args_to_workload
from .space import ConfigSpace, ConfigEntity from .space import ConfigSpace, ConfigEntity
from .code_hash import attach_code_hash, attach_code_hash_to_arg from .code_hash import attach_code_hash, attach_code_hash_to_arg
from .dispatcher import DispatchContext, ApplyConfig, dispatcher from .dispatcher import DispatchContext, ApplyConfig, dispatcher
...@@ -68,6 +68,33 @@ class Task(object): ...@@ -68,6 +68,33 @@ class Task(object):
self.flop = config.flop self.flop = config.flop
return sch, arg_bufs return sch, arg_bufs
def __getstate__(self):
# custom pickle implementation is required for
# some unpickable local task functions.
# So we only pickle the name of the function
# and restore the function by name when unpickling it.
return {
"name": self.name,
"args": self.args,
"kwargs": self.kwargs,
"config_space": self.config_space,
"workload": self.workload,
"flop": self.flop,
"target": self.target,
"target_host": self.target_host
}
def __setstate__(self, state):
self.name = state["name"]
self.args = state["args"]
self.kwargs = state["kwargs"]
self.config_space = state["config_space"]
self.func = TASK_TABLE.get(state["name"], _raise_error)
self.workload = state["workload"]
self.flop = state["flop"]
self.target = state["target"]
self.target_host = state["target_host"]
def __repr__(self): def __repr__(self):
return "Task(func_name=%s, args=%s, kwargs=%s, workload=%s)" % ( return "Task(func_name=%s, args=%s, kwargs=%s, workload=%s)" % (
self.name, self.args, self.kwargs, self.workload self.name, self.args, self.kwargs, self.workload
......
...@@ -264,12 +264,23 @@ class ModelBasedTuner(Tuner): ...@@ -264,12 +264,23 @@ class ModelBasedTuner(Tuner):
self.train_ct += 1 self.train_ct += 1
def load_history(self, data_set): def load_history(self, data_set):
# filter data, only pick the data with a same task
data = []
for inp, res in data_set:
if inp.task.name == self.task.name and \
inp.config.template_key == self.task.config_space.template_key:
data.append((inp, res))
if not data:
return
# fit base model
base_model = self.cost_model.clone_new() base_model = self.cost_model.clone_new()
base_model.fit_log(data_set, self.plan_size) base_model.fit_log(data, self.plan_size)
# use base model to select initial points
if not self.trials: if not self.trials:
# no plan yet, use base model to select initial trials # no plan yet, use base model to select initial trials
maximums = self.model_optimizer.find_maximums(base_model, self.visited) maximums = self.model_optimizer.find_maximums(base_model, self.plan_size, self.visited)
self.trials = maximums self.trials = maximums
self.trial_pt = 0 self.trial_pt = 0
......
...@@ -30,7 +30,7 @@ class SimulatedAnnealingOptimizer(ModelOptimizer): ...@@ -30,7 +30,7 @@ class SimulatedAnnealingOptimizer(ModelOptimizer):
Print log every `verbose` iterations Print log every `verbose` iterations
""" """
def __init__(self, task, n_iter=500, temp=(1, 0), persistent=True, parallel_size=128, def __init__(self, task, n_iter=500, temp=(1, 0), persistent=True, parallel_size=128,
early_stop=30, verbose=50): early_stop=50, verbose=50):
super(SimulatedAnnealingOptimizer, self).__init__() super(SimulatedAnnealingOptimizer, self).__init__()
self.task = task self.task = task
...@@ -39,8 +39,8 @@ class SimulatedAnnealingOptimizer(ModelOptimizer): ...@@ -39,8 +39,8 @@ class SimulatedAnnealingOptimizer(ModelOptimizer):
self.n_iter = n_iter self.n_iter = n_iter
self.temp = temp self.temp = temp
self.persistent = persistent self.persistent = persistent
self.parallel_size = parallel_size self.parallel_size = min(parallel_size, len(self.task.config_space))
self.early_stop = early_stop self.early_stop = early_stop or 1e9
self.verbose = verbose self.verbose = verbose
self.points = None self.points = None
......
...@@ -27,6 +27,7 @@ class Tuner(object): ...@@ -27,6 +27,7 @@ class Tuner(object):
self.best_config = None self.best_config = None
self.best_flops = 0 self.best_flops = 0
self.best_measure_pair = None self.best_measure_pair = None
self.best_iter = 0
def has_next(self): def has_next(self):
"""Whether has next untried config in the space """Whether has next untried config in the space
...@@ -63,7 +64,7 @@ class Tuner(object): ...@@ -63,7 +64,7 @@ class Tuner(object):
""" """
pass pass
def tune(self, n_trial, measure_option, verbose=1, callbacks=()): def tune(self, n_trial, measure_option, early_stop=None, verbose=1, callbacks=()):
"""Begin tuning """Begin tuning
Parameters Parameters
...@@ -73,6 +74,8 @@ class Tuner(object): ...@@ -73,6 +74,8 @@ class Tuner(object):
measure_option: dict measure_option: dict
The options for how to measure generated code. The options for how to measure generated code.
You should use the return value ot autotvm.measure_option for this argument. You should use the return value ot autotvm.measure_option for this argument.
early_stop: int
Early stop the tuning when not finding better configs in this number of trials
verbose: int verbose: int
0: silent mode, no output 0: silent mode, no output
1: print every measurement result 1: print every measurement result
...@@ -84,6 +87,7 @@ class Tuner(object): ...@@ -84,6 +87,7 @@ class Tuner(object):
""" """
measure_batch = create_measure_batch(self.task, measure_option) measure_batch = create_measure_batch(self.task, measure_option)
parallel_num = getattr(measure_batch, 'parallel_num', 1) parallel_num = getattr(measure_batch, 'parallel_num', 1)
early_stop = early_stop or 1e9
i = 0 i = 0
while i < n_trial: while i < n_trial:
...@@ -107,6 +111,7 @@ class Tuner(object): ...@@ -107,6 +111,7 @@ class Tuner(object):
self.best_flops = flops self.best_flops = flops
self.best_config = config self.best_config = config
self.best_measure_pair = (inp, res) self.best_measure_pair = (inp, res)
self.best_iter = i + k
logging.info("No: %d\tGFLOPS: %.2f/%.2f\tresult: %s\t%s", logging.info("No: %d\tGFLOPS: %.2f/%.2f\tresult: %s\t%s",
i + k + 1, flops / 1e9, self.best_flops / 1e9, i + k + 1, flops / 1e9, self.best_flops / 1e9,
...@@ -119,6 +124,10 @@ class Tuner(object): ...@@ -119,6 +124,10 @@ class Tuner(object):
for callback in callbacks: for callback in callbacks:
callback(self, inputs, results) callback(self, inputs, results)
if i > self.best_iter + early_stop:
logging.info("Early stopped. Best iter: %d.", self.best_iter)
break
del measure_batch del measure_batch
def reset(self): def reset(self):
......
...@@ -111,6 +111,9 @@ class XGBoostCostModel(CostModel): ...@@ -111,6 +111,9 @@ class XGBoostCostModel(CostModel):
self.feature_extra_ct = 0 self.feature_extra_ct = 0
self.pool = None self.pool = None
self.base_model = None self.base_model = None
self.upper_model = None
self._sample_size = 0
self._reset_pool() self._reset_pool()
...@@ -127,20 +130,25 @@ class XGBoostCostModel(CostModel): ...@@ -127,20 +130,25 @@ class XGBoostCostModel(CostModel):
_extract_task = self.task _extract_task = self.task
self.pool = multiprocessing.Pool(self.num_threads) self.pool = multiprocessing.Pool(self.num_threads)
def _base_model_discount(self):
return 1.0 / (2 ** (self._sample_size / 50.0))
def fit(self, xs, ys, plan_size): def fit(self, xs, ys, plan_size):
tic = time.time() tic = time.time()
self._reset_pool() self._reset_pool()
x_train = self._get_feature(xs) x_train = self._get_feature(xs)
y_train = np.array(ys) y_train = np.array(ys)
y_train /= np.max(y_train) y_train = y_train / np.max(y_train)
valid_index = y_train > 1e-6 valid_index = y_train > 1e-6
index = np.random.permutation(len(x_train)) index = np.random.permutation(len(x_train))
dtrain = xgb.DMatrix(x_train[index], y_train[index]) dtrain = xgb.DMatrix(x_train[index], y_train[index])
self._sample_size = len(x_train)
if self.base_model: if self.base_model:
dtrain.set_base_margin(self.base_model.predict(xs, output_margin=True)) dtrain.set_base_margin(self._base_model_discount() *
self.base_model.predict(xs, output_margin=True))
self.bst = xgb.train(self.xgb_params, dtrain, self.bst = xgb.train(self.xgb_params, dtrain,
num_boost_round=8000, num_boost_round=8000,
...@@ -164,6 +172,7 @@ class XGBoostCostModel(CostModel): ...@@ -164,6 +172,7 @@ class XGBoostCostModel(CostModel):
self._reset_pool() self._reset_pool()
args = list(records) args = list(records)
logging.info("Load %d entries from history log file", len(args))
if self.fea_type == 'itervar': if self.fea_type == 'itervar':
feature_extract_func = _extract_itervar_feature_log feature_extract_func = _extract_itervar_feature_log
elif self.fea_type == 'knob': elif self.fea_type == 'knob':
...@@ -185,7 +194,7 @@ class XGBoostCostModel(CostModel): ...@@ -185,7 +194,7 @@ class XGBoostCostModel(CostModel):
plan_size *= 2 plan_size *= 2
self.bst = xgb.train(self.xgb_params, dtrain, self.bst = xgb.train(self.xgb_params, dtrain,
num_boost_round=200, num_boost_round=400,
callbacks=[custom_callback( callbacks=[custom_callback(
stopping_rounds=100, stopping_rounds=100,
metric='tr-a-recall@%d' % plan_size, metric='tr-a-recall@%d' % plan_size,
...@@ -203,12 +212,23 @@ class XGBoostCostModel(CostModel): ...@@ -203,12 +212,23 @@ class XGBoostCostModel(CostModel):
dtest = xgb.DMatrix(feas) dtest = xgb.DMatrix(feas)
if self.base_model: if self.base_model:
dtest.set_base_margin(self.base_model.predict(xs, output_margin=True)) dtest.set_base_margin(self._base_model_discount() *
self.base_model.predict(xs, output_margin=True))
return self.bst.predict(dtest, output_margin=output_margin) return self.bst.predict(dtest, output_margin=output_margin)
def load_basemodel(self, base_model): def load_basemodel(self, base_model):
self.base_model = base_model self.base_model = base_model
if isinstance(base_model, XGBoostCostModel):
# share feature cache
base_model.feature_cache = self.feature_cache
# close thread pool
if base_model.pool:
base_model.pool.terminate()
base_model.pool.join()
del base_model.pool
self.base_model.upper_model = self
def clone_new(self): def clone_new(self):
return XGBoostCostModel(self.task, self.fea_type, self.loss_type, return XGBoostCostModel(self.task, self.fea_type, self.loss_type,
...@@ -226,7 +246,8 @@ class XGBoostCostModel(CostModel): ...@@ -226,7 +246,8 @@ class XGBoostCostModel(CostModel):
need_extract = [x for x in indexes if x not in fea_cache] need_extract = [x for x in indexes if x not in fea_cache]
if need_extract: if need_extract:
feas = self.pool.map(self.feature_extract_func, need_extract) pool = self.pool if self.upper_model is None else self.upper_model.pool
feas = pool.map(self.feature_extract_func, need_extract)
for i, fea in zip(need_extract, feas): for i, fea in zip(need_extract, feas):
fea_cache[i] = fea fea_cache[i] = fea
......
...@@ -346,6 +346,7 @@ def generic_func(fdefault): ...@@ -346,6 +346,7 @@ def generic_func(fdefault):
return func(*args, **kwargs) return func(*args, **kwargs)
fdecorate = decorate(fdefault, dispatch_func) fdecorate = decorate(fdefault, dispatch_func)
fdecorate.register = register fdecorate.register = register
fdecorate.fdefault = fdefault
return fdecorate return fdecorate
......
import time
import numpy as np
import tvm
from tvm import autotvm
from tvm.autotvm import MeasureInput, MeasureResult
from tvm.autotvm.tuner.xgboost_cost_model import XGBoostCostModel
from test_autotvm_common import get_sample_task, get_sample_records
def test_fit():
task, target = get_sample_task()
records = get_sample_records(n=100)
base_model = XGBoostCostModel(task, feature_type='itervar', loss_type='rank')
base_model.fit_log(records, plan_size=32)
upper_model = XGBoostCostModel(task, feature_type='itervar', loss_type='rank')
upper_model.load_basemodel(base_model)
xs = np.arange(100)
ys = np.arange(100)
upper_model.fit(xs, ys, plan_size=32)
def test_tuner():
task, target = get_sample_task()
records = get_sample_records(n=100)
tuner = autotvm.tuner.XGBTuner(task)
tuner.load_history(records)
if __name__ == "__main__":
test_fit()
test_tuner()
""" """
How to get high performance convolution kernel on NVIDIA GPU by auto-tuning Tuning High Performance Convolution on NVIDIA GPUs
========================================================================= =========================================================================
**Author**: `Lianmin Zheng <https://https://github.com/merrymercy>`_ **Author**: `Lianmin Zheng <https://https://github.com/merrymercy>`_
...@@ -10,9 +10,11 @@ vendor provided library CuDNN in many cases. ...@@ -10,9 +10,11 @@ vendor provided library CuDNN in many cases.
import logging import logging
import sys import sys
import numpy as np
import tvm import tvm
import topi import topi
from topi.testing import conv2d_nchw_python
from tvm import autotvm from tvm import autotvm
...@@ -133,9 +135,10 @@ def conv2d_no_batching(N, H, W, CI, CO, KH, KW, stride, padding): ...@@ -133,9 +135,10 @@ def conv2d_no_batching(N, H, W, CI, CO, KH, KW, stride, padding):
# logging config (for printing tuning log to screen) # logging config (for printing tuning log to screen)
logging.basicConfig(level=logging.INFO, stream=sys.stdout) logging.basicConfig(level=logging.INFO, stream=sys.stdout)
# the last layer in resnet # the last layer in resnet
N, H, W, CO, CI, KH, KW, strides, padding = 1, 7, 7, 512, 512, 3, 3, (1, 1), (1, 1)
task = autotvm.task.create(conv2d_no_batching, task = autotvm.task.create(conv2d_no_batching,
args=(1, 7, 7, 512, 512, 3, 3, (1, 1), (1, 1)), args=(N, H, W, CO, CI, KH, KW, strides, padding),
target='cuda') target='cuda')
print(task.config_space) print(task.config_space)
...@@ -146,15 +149,43 @@ measure_option = autotvm.measure_option(mode='local', ...@@ -146,15 +149,43 @@ measure_option = autotvm.measure_option(mode='local',
parallel_num=8, parallel_num=8,
timeout=20) timeout=20)
# begin tuning, log records to file `cache.tsv` # begin tuning, log records to file `conv2d.tsv`
tuner = autotvm.tuner.XGBTuner(task) tuner = autotvm.tuner.XGBTuner(task)
tuner.tune(n_trial=20, tuner.tune(n_trial=20,
measure_option=measure_option, measure_option=measure_option,
callbacks=[autotvm.callback.log_to_file('cache.tsv')]) callbacks=[autotvm.callback.log_to_file('conv2d.log')])
# get best config from cache file #########################################################################
dispatch_context = autotvm.apply_history_best("cache.tsv") # Finally we can inspect the best config from log file, check correctness,
# and measure running time.
# inspect the best config
dispatch_context = autotvm.apply_history_best("conv2d.log")
best_config = dispatch_context.query(task.target, task.workload) best_config = dispatch_context.query(task.target, task.workload)
print("\nBest config:") print("\nBest config:")
print(best_config) print(best_config)
# apply history best from log file
with autotvm.apply_history_best('conv2d.log'):
with tvm.target.create("cuda"):
s, arg_bufs = conv2d_no_batching(N, H, W, CO, CI, KH, KW, strides, padding)
func = tvm.build(s, arg_bufs)
# check correctness
a_np = np.random.uniform(size=(N, CI, H, W)).astype(np.float32)
w_np = np.random.uniform(size=(CO, CI, KH, KW)).astype(np.float32)
c_np = conv2d_nchw_python(a_np, w_np, strides, padding)
ctx = tvm.gpu()
a_tvm = tvm.nd.array(a_np, ctx=ctx)
w_tvm = tvm.nd.array(w_np, ctx=ctx)
c_tvm = tvm.nd.empty(c_np.shape, ctx=ctx)
func(a_tvm, w_tvm, c_tvm)
np.testing.assert_allclose(c_np, c_tvm.asnumpy(), rtol=1e-2)
# Evaluate running time. Here we choose a large repeat number (200) to reduce the noise
# and the overhead of kernel launch. You can also use nvprof to validate the result.
evaluator = func.time_evaluator(func.entry_name, ctx, number=200)
print('Time cost of this operator: %f' % evaluator(a_tvm, w_tvm, c_tvm).mean)
...@@ -243,7 +243,7 @@ print(task.config_space) ...@@ -243,7 +243,7 @@ print(task.config_space)
# #
# We only make 10 trials in this tutorial for demonstration. In practice, # We only make 10 trials in this tutorial for demonstration. In practice,
# you can do more trials according to your time budget. # you can do more trials according to your time budget.
# We will log the tuning results into a cache file. This file can be # We will log the tuning results into a log file. This file can be
# used to get the best config later. # used to get the best config later.
# logging config (for printing tuning log to screen) # logging config (for printing tuning log to screen)
...@@ -253,11 +253,11 @@ logging.basicConfig(level=logging.INFO, stream=sys.stdout) ...@@ -253,11 +253,11 @@ logging.basicConfig(level=logging.INFO, stream=sys.stdout)
measure_option = autotvm.measure_option(mode='local', measure_option = autotvm.measure_option(mode='local',
number=5) number=5)
# begin tuning, log records to file `cache.tsv` # begin tuning, log records to file `matmul.log`
tuner = autotvm.tuner.RandomTuner(task) tuner = autotvm.tuner.RandomTuner(task)
tuner.tune(n_trial=10, tuner.tune(n_trial=10,
measure_option=measure_option, measure_option=measure_option,
callbacks=[autotvm.callback.log_to_file('cache.tsv')]) callbacks=[autotvm.callback.log_to_file('matmul.log')])
######################################################################### #########################################################################
# Finally we apply history best from the cache file and check its correctness. # Finally we apply history best from the cache file and check its correctness.
...@@ -267,7 +267,7 @@ tuner.tune(n_trial=10, ...@@ -267,7 +267,7 @@ tuner.tune(n_trial=10,
# with the same argument. # with the same argument.
# apply history best from log file # apply history best from log file
with autotvm.apply_history_best('cache.tsv'): with autotvm.apply_history_best('matmul.log'):
with tvm.target.create("llvm"): with tvm.target.create("llvm"):
s, arg_bufs = matmul(N, L, M, 'float32') s, arg_bufs = matmul(N, L, M, 'float32')
func = tvm.build(s, arg_bufs) func = tvm.build(s, arg_bufs)
...@@ -281,4 +281,3 @@ c_tvm = tvm.nd.empty(c_np.shape) ...@@ -281,4 +281,3 @@ c_tvm = tvm.nd.empty(c_np.shape)
func(tvm.nd.array(a_np), tvm.nd.array(b_np), c_tvm) func(tvm.nd.array(a_np), tvm.nd.array(b_np), c_tvm)
np.testing.assert_allclose(c_np, c_tvm.asnumpy(), rtol=1e-2) np.testing.assert_allclose(c_np, c_tvm.asnumpy(), rtol=1e-2)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment