Commit ad28f5ca by Lianmin Zheng Committed by Tianqi Chen

[AUTOTVM] Misc bug fix (#1467)

parent 9026f3fc
......@@ -6,7 +6,6 @@ from collections import namedtuple
import numpy as np
from ... import build, nd, target as _target
from ...contrib.util import tempdir
from ...rpc.tracker import Tracker
from ...rpc.server import Server
......@@ -209,14 +208,12 @@ def create_measure_batch(task, options):
kwargs['rpc_device_key'] = rpc_device_key
kwargs['rpc_tracker_addr'] = (tracker.host, tracker.port)
kwargs['rpc_timeout'] = timeout
kwargs['tmp_dir'] = tempdir()
elif mode == 'rpc':
fmeasure = measure_methods.measure_rpc
kwargs['rpc_device_key'] = rpc_device_key
kwargs['rpc_priority'] = rpc_priority
kwargs['rpc_timeout'] = rpc_timeout
kwargs['use_ndk'] = use_ndk
kwargs['tmp_dir'] = tempdir()
assert rpc_device_key, "In rpc mode, a rpc_device_key must be provided"
elif mode == "custom":
assert callable(custom_measure_batch), "In custom mode, custom_measure_func " \
......@@ -243,7 +240,7 @@ def create_measure_batch(task, options):
tvm_buf = [nd.array(x) for x in ref_input]
func(*tvm_buf)
ref_output = [x.asnumpy() for x in tvm_buf]
kwargs['ref_input'], kwargs['ref_outpu'] = ref_input, ref_output
kwargs['ref_input'], kwargs['ref_output'] = ref_input, ref_output
def measure_batch(measure_inputs):
"""measure the time cost for a batch of configs in real machines"""
......
......@@ -12,7 +12,7 @@ from random import getrandbits
import numpy as np
from ...contrib import ndk, nvcc
from ...contrib import ndk, nvcc, util
from ... import rpc, ir_pass, build, build_config, nd, context, TVMError, register_func
from ..util import get_const_tuple
......@@ -113,8 +113,8 @@ def _measure_generic(fbuild, input_pack, ref_input, ref_output):
if ref_input:
args = [nd.array(x, ctx) for x in ref_input]
else:
args = [nd.array(np.random.uniform(size=get_const_tuple(x.shape)).astype(x.dtype),
ctx) for x in arg_bufs]
args = [nd.empty(get_const_tuple(x.shape), dtype=x.dtype,
ctx=ctx) for x in arg_bufs]
costs = time_f(*args).results
if len(costs) > 2: # remove largest and smallest value to reduce variance
costs = list(costs)
......@@ -173,7 +173,6 @@ def measure_rpc(input_pack,
rpc_tracker_addr=None,
rpc_priority=1,
rpc_timeout=60,
tmp_dir=None,
**kwargs):
"""Measure the time cost on a device by rpc
......@@ -198,9 +197,6 @@ def measure_rpc(input_pack,
rpc_timeout: int, optional
timeout of the rpc session
tmp_dir: tvm.contrib.util.TempDirectory, optional
directory to store temp file
kwargs: dict, optional
Additional key word arguments
......@@ -213,6 +209,7 @@ def measure_rpc(input_pack,
""" Local build function."""
func, args = _build_func(inp, build_option, kwargs)
tmp_dir = util.tempdir()
if not kwargs.get('use_ndk', False):
file_name = "tmp_func_%0x.tar" % getrandbits(64)
path = tmp_dir.relpath(file_name)
......
......@@ -9,11 +9,12 @@ import multiprocessing
import pickle
import json
import time
import os
from collections import OrderedDict
import numpy as np
from .. import target, build, lower
from .. import build, lower, target as _target
from . import task
from .task import DispatchContext, ConfigEntity
......@@ -26,6 +27,11 @@ try: # convert unicode to str for python2
except NameError:
_unicode = ()
try:
_long = long
except NameError:
_long = int
def measure_str_key(inp, include_config=True):
""" get unique str key for MeasureInput
......@@ -111,7 +117,7 @@ def decode(row, protocol='json'):
if protocol == 'json':
row = json.loads(row)
tgt, task_name, task_args, task_kwargs, workload, config = row['i']
tgt = target.create(str(tgt))
tgt = _target.create(str(tgt))
def clean_json_to_python(x):
"""1. convert all list in x to tuple (hashable)
......@@ -121,6 +127,8 @@ def decode(row, protocol='json'):
return tuple([clean_json_to_python(a) for a in x])
if isinstance(x, _unicode):
return str(x)
if isinstance(x, (_long, int)):
return int(x)
return x
tsk = task.Task(clean_json_to_python(task_name), clean_json_to_python(task_args))
......@@ -132,7 +140,7 @@ def decode(row, protocol='json'):
return inp, result
elif protocol == 'pickle':
items = row.split("\t")
tgt = target.create(items[0])
tgt = _target.create(items[0])
task_tuple = pickle.loads(base64.b64decode(items[1].encode()))
config = pickle.loads(base64.b64decode(items[2].encode()))
result = pickle.loads(base64.b64decode(items[3].encode()))
......@@ -168,36 +176,70 @@ class ApplyHistoryBest(DispatchContext):
----------
records : str or iterator of (MeasureInput, MeasureResult)
Collection of tuning records.
if is str, then it should be the filename of a records log file.
If is str, then it should be the filename of a records log file.
Each row of this file is an encoded record pair.
otherwise, it is an iterator
Otherwise, it is an iterator.
default: ConfigEntity, optional
default config to return when no history records
The default config to return when no history records
"""
def __init__(self, records, default=None):
super(ApplyHistoryBest, self).__init__()
self.best_by_targetkey = {}
self.best_by_model = {}
self._default = default
self.load(records)
def load(self, records):
"""Load records to this dispatch context
Parameters
----------
records : str or iterator of (MeasureInput, MeasureResult)
Collection of tuning records.
If is str, then it should be the filename of a records log file.
Each row of this file is an encoded record pair.
Otherwise, it is an iterator.
"""
if isinstance(records, str):
records = load_from_file(records)
if not records:
return
best_by_targetkey = self.best_by_targetkey
best_by_model = self.best_by_model
counter = 0
best_map = {}
for inp, res in records:
counter += 1
if res.error_no != 0:
continue
# use target keys in tvm target system as key to build best map
for k in inp.target.keys:
key = (k, inp.task.workload)
if key not in best_map:
best_map[key] = (inp, res)
if key not in best_by_targetkey:
best_by_targetkey[key] = (inp, res)
else:
_, other_res = best_map[key]
_, other_res = best_by_targetkey[key]
if np.mean(other_res.costs) > np.mean(res.costs):
best_map[key] = (inp, res)
logging.info(
"Finish load %d records, %d entries selected", counter, len(best_map))
self._best_map = best_map
self._default = default
best_by_targetkey[key] = (inp, res)
# use model as key to build best map
for opt in inp.target.options:
if opt.startswith("-model"):
model = opt[7:]
key = (model, inp.task.workload)
if key not in best_by_model:
best_by_model[key] = (inp, res)
else:
_, other_res = best_by_model[key]
if np.mean(other_res.costs) > np.mean(res.costs):
best_by_model[key] = (inp, res)
break
logging.info("Finish loading %d records", counter)
def query(self, target, workload):
if target is None:
......@@ -205,29 +247,25 @@ class ApplyHistoryBest(DispatchContext):
"Hint: If your target is llvm, use `with tvm.target.create('llvm'):`"
" above the dispatcher call. So does other target. ")
# first try matching by model
for opt in target.options:
if opt.startswith("-model"):
model = opt[7:]
key = (model, workload)
if key in self.best_by_model:
return self.best_by_model[key][0].config
# then try matching by target key
for k in target.keys:
key = (k, workload)
if key in self._best_map:
return self._best_map[key][0].config
if key in self.best_by_targetkey:
return self.best_by_targetkey[key][0].config
if self._default:
return self._default
raise RuntimeError(
"Cannot find config for target=%s, workload=%s" % (target, workload))
def dump_best(self, out_file):
"""Dump the best records for each workload to a file
Parameters
----------
out_file: str
filename
"""
fout = open(out_file, 'a')
for val in self._best_map.values():
inp, res = val
fout.write(encode(inp, res) + '\n')
def split_workload(in_file, clean=True):
"""Split a log file into separate files, each of which contains only a single workload
......@@ -243,7 +281,7 @@ def split_workload(in_file, clean=True):
tic = time.time()
lines = list(open(in_file).readlines())
logging.info("start convert...")
logging.info("start converting...")
pool = multiprocessing.Pool()
lines = pool.map(decode, lines)
logging.info("map done %.2f", time.time() - tic)
......@@ -279,23 +317,69 @@ def split_workload(in_file, clean=True):
for inp, res in v:
fout.write(encode(inp, res) + '\n')
def pick_best(in_file, out_file):
"""
Pick best entries from a file and store it to another file.
This distill the useful log entries from a large log file.
Parameters
----------
in_file: str
The filename of input
out_file:
The filename of output
"""
best_context = ApplyHistoryBest(load_from_file(in_file))
best_set = set()
for v in best_context.best_by_model.values():
best_set.add(measure_str_key(v[0]))
for v in best_context.best_by_targetkey.values():
best_set.add(measure_str_key(v[0]))
logging.info("Extract %d best records from the log file", len(best_set))
fout = open(out_file, 'w')
for inp, res in load_from_file(in_file):
if measure_str_key(inp) in best_set:
fout.write(encode(inp, res) + "\n")
def load_op_param(rootpath=os.path.join(os.path.expanduser('~'), ".tvm", "op_params")):
"""Load pre-tuned parameters of operators.
This function will load all "*.log" file under root path and select best configs.
Parameters
----------
rootpath: str
The root path of stored parameters
"""
best_context = ApplyHistoryBest([])
for dirpath, _, filenames in os.walk(rootpath):
for filename in filenames:
if os.path.splitext(filename)[1] == '.log':
best_context.load(os.path.join(dirpath, filename))
assert not DispatchContext.current, "Cannot load pre-tuned parameters inside a dispatch context"
DispatchContext.current = best_context
"""
Usage:
This record executable module has three modes.
* Print log file in readable format
e.g. python -m autotvm.record --mode read --i collect_conv.tsv --begin 0 --end 5 --ir --code
e.g. python -m autotvm.record --mode read --i collect_conv.log --begin 0 --end 5 --ir --code
* Extract history best from a large log file
e.g. python -m autotvm.record --mode best --i collect.tsv
e.g. python -m autotvm.record --mode pick --i collect.log
* Split a log file into separate files, each of which contains only a single wkl
e.g. python -m autotvm.record --mode split --i collect.tsv
e.g. python -m autotvm.record --mode split --i collect.log
"""
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("--mode", choices=['read', 'best', 'split'], default='read')
parser.add_argument("--mode", choices=['read', 'pick', 'split'], default='read')
parser.add_argument("--i", type=str, help="input file")
parser.add_argument("--o", type=str, default=None, help='output file')
parser.add_argument("--begin", type=int, default=0)
......@@ -306,10 +390,9 @@ if __name__ == '__main__':
args = parser.parse_args()
logging.basicConfig(level=logging.INFO)
if args.mode == 'best':
args.o = args.o or args.i + ".best"
hist_best = ApplyHistoryBest(load_from_file(args.i))
hist_best.dump_best(args.o)
if args.mode == 'pick':
args.o = args.o or args.i + ".best.log"
pick_best(args.i, args.o)
elif args.mode == 'read':
for i, (inp, result) in enumerate(load_from_file(args.i)):
if args.begin <= i < args.end:
......
......@@ -6,7 +6,7 @@ This module defines the task data structure, as well as a collection(zoo)
of typical tasks of interest.
"""
from .task import Task, create, register, template, get_config
from .task import Task, create, register, template, get_config, args_to_workload
from .space import ConfigSpace, ConfigEntity
from .code_hash import attach_code_hash, attach_code_hash_to_arg
from .dispatcher import DispatchContext, ApplyConfig, dispatcher
......@@ -68,6 +68,33 @@ class Task(object):
self.flop = config.flop
return sch, arg_bufs
def __getstate__(self):
# custom pickle implementation is required for
# some unpickable local task functions.
# So we only pickle the name of the function
# and restore the function by name when unpickling it.
return {
"name": self.name,
"args": self.args,
"kwargs": self.kwargs,
"config_space": self.config_space,
"workload": self.workload,
"flop": self.flop,
"target": self.target,
"target_host": self.target_host
}
def __setstate__(self, state):
self.name = state["name"]
self.args = state["args"]
self.kwargs = state["kwargs"]
self.config_space = state["config_space"]
self.func = TASK_TABLE.get(state["name"], _raise_error)
self.workload = state["workload"]
self.flop = state["flop"]
self.target = state["target"]
self.target_host = state["target_host"]
def __repr__(self):
return "Task(func_name=%s, args=%s, kwargs=%s, workload=%s)" % (
self.name, self.args, self.kwargs, self.workload
......
......@@ -264,12 +264,23 @@ class ModelBasedTuner(Tuner):
self.train_ct += 1
def load_history(self, data_set):
# filter data, only pick the data with a same task
data = []
for inp, res in data_set:
if inp.task.name == self.task.name and \
inp.config.template_key == self.task.config_space.template_key:
data.append((inp, res))
if not data:
return
# fit base model
base_model = self.cost_model.clone_new()
base_model.fit_log(data_set, self.plan_size)
base_model.fit_log(data, self.plan_size)
# use base model to select initial points
if not self.trials:
# no plan yet, use base model to select initial trials
maximums = self.model_optimizer.find_maximums(base_model, self.visited)
maximums = self.model_optimizer.find_maximums(base_model, self.plan_size, self.visited)
self.trials = maximums
self.trial_pt = 0
......
......@@ -30,7 +30,7 @@ class SimulatedAnnealingOptimizer(ModelOptimizer):
Print log every `verbose` iterations
"""
def __init__(self, task, n_iter=500, temp=(1, 0), persistent=True, parallel_size=128,
early_stop=30, verbose=50):
early_stop=50, verbose=50):
super(SimulatedAnnealingOptimizer, self).__init__()
self.task = task
......@@ -39,8 +39,8 @@ class SimulatedAnnealingOptimizer(ModelOptimizer):
self.n_iter = n_iter
self.temp = temp
self.persistent = persistent
self.parallel_size = parallel_size
self.early_stop = early_stop
self.parallel_size = min(parallel_size, len(self.task.config_space))
self.early_stop = early_stop or 1e9
self.verbose = verbose
self.points = None
......
......@@ -27,6 +27,7 @@ class Tuner(object):
self.best_config = None
self.best_flops = 0
self.best_measure_pair = None
self.best_iter = 0
def has_next(self):
"""Whether has next untried config in the space
......@@ -63,7 +64,7 @@ class Tuner(object):
"""
pass
def tune(self, n_trial, measure_option, verbose=1, callbacks=()):
def tune(self, n_trial, measure_option, early_stop=None, verbose=1, callbacks=()):
"""Begin tuning
Parameters
......@@ -73,6 +74,8 @@ class Tuner(object):
measure_option: dict
The options for how to measure generated code.
You should use the return value ot autotvm.measure_option for this argument.
early_stop: int
Early stop the tuning when not finding better configs in this number of trials
verbose: int
0: silent mode, no output
1: print every measurement result
......@@ -84,6 +87,7 @@ class Tuner(object):
"""
measure_batch = create_measure_batch(self.task, measure_option)
parallel_num = getattr(measure_batch, 'parallel_num', 1)
early_stop = early_stop or 1e9
i = 0
while i < n_trial:
......@@ -107,6 +111,7 @@ class Tuner(object):
self.best_flops = flops
self.best_config = config
self.best_measure_pair = (inp, res)
self.best_iter = i + k
logging.info("No: %d\tGFLOPS: %.2f/%.2f\tresult: %s\t%s",
i + k + 1, flops / 1e9, self.best_flops / 1e9,
......@@ -119,6 +124,10 @@ class Tuner(object):
for callback in callbacks:
callback(self, inputs, results)
if i > self.best_iter + early_stop:
logging.info("Early stopped. Best iter: %d.", self.best_iter)
break
del measure_batch
def reset(self):
......
......@@ -111,6 +111,9 @@ class XGBoostCostModel(CostModel):
self.feature_extra_ct = 0
self.pool = None
self.base_model = None
self.upper_model = None
self._sample_size = 0
self._reset_pool()
......@@ -127,20 +130,25 @@ class XGBoostCostModel(CostModel):
_extract_task = self.task
self.pool = multiprocessing.Pool(self.num_threads)
def _base_model_discount(self):
return 1.0 / (2 ** (self._sample_size / 50.0))
def fit(self, xs, ys, plan_size):
tic = time.time()
self._reset_pool()
x_train = self._get_feature(xs)
y_train = np.array(ys)
y_train /= np.max(y_train)
y_train = y_train / np.max(y_train)
valid_index = y_train > 1e-6
index = np.random.permutation(len(x_train))
dtrain = xgb.DMatrix(x_train[index], y_train[index])
self._sample_size = len(x_train)
if self.base_model:
dtrain.set_base_margin(self.base_model.predict(xs, output_margin=True))
dtrain.set_base_margin(self._base_model_discount() *
self.base_model.predict(xs, output_margin=True))
self.bst = xgb.train(self.xgb_params, dtrain,
num_boost_round=8000,
......@@ -164,6 +172,7 @@ class XGBoostCostModel(CostModel):
self._reset_pool()
args = list(records)
logging.info("Load %d entries from history log file", len(args))
if self.fea_type == 'itervar':
feature_extract_func = _extract_itervar_feature_log
elif self.fea_type == 'knob':
......@@ -185,7 +194,7 @@ class XGBoostCostModel(CostModel):
plan_size *= 2
self.bst = xgb.train(self.xgb_params, dtrain,
num_boost_round=200,
num_boost_round=400,
callbacks=[custom_callback(
stopping_rounds=100,
metric='tr-a-recall@%d' % plan_size,
......@@ -203,12 +212,23 @@ class XGBoostCostModel(CostModel):
dtest = xgb.DMatrix(feas)
if self.base_model:
dtest.set_base_margin(self.base_model.predict(xs, output_margin=True))
dtest.set_base_margin(self._base_model_discount() *
self.base_model.predict(xs, output_margin=True))
return self.bst.predict(dtest, output_margin=output_margin)
def load_basemodel(self, base_model):
self.base_model = base_model
if isinstance(base_model, XGBoostCostModel):
# share feature cache
base_model.feature_cache = self.feature_cache
# close thread pool
if base_model.pool:
base_model.pool.terminate()
base_model.pool.join()
del base_model.pool
self.base_model.upper_model = self
def clone_new(self):
return XGBoostCostModel(self.task, self.fea_type, self.loss_type,
......@@ -226,7 +246,8 @@ class XGBoostCostModel(CostModel):
need_extract = [x for x in indexes if x not in fea_cache]
if need_extract:
feas = self.pool.map(self.feature_extract_func, need_extract)
pool = self.pool if self.upper_model is None else self.upper_model.pool
feas = pool.map(self.feature_extract_func, need_extract)
for i, fea in zip(need_extract, feas):
fea_cache[i] = fea
......
......@@ -346,6 +346,7 @@ def generic_func(fdefault):
return func(*args, **kwargs)
fdecorate = decorate(fdefault, dispatch_func)
fdecorate.register = register
fdecorate.fdefault = fdefault
return fdecorate
......
import time
import numpy as np
import tvm
from tvm import autotvm
from tvm.autotvm import MeasureInput, MeasureResult
from tvm.autotvm.tuner.xgboost_cost_model import XGBoostCostModel
from test_autotvm_common import get_sample_task, get_sample_records
def test_fit():
task, target = get_sample_task()
records = get_sample_records(n=100)
base_model = XGBoostCostModel(task, feature_type='itervar', loss_type='rank')
base_model.fit_log(records, plan_size=32)
upper_model = XGBoostCostModel(task, feature_type='itervar', loss_type='rank')
upper_model.load_basemodel(base_model)
xs = np.arange(100)
ys = np.arange(100)
upper_model.fit(xs, ys, plan_size=32)
def test_tuner():
task, target = get_sample_task()
records = get_sample_records(n=100)
tuner = autotvm.tuner.XGBTuner(task)
tuner.load_history(records)
if __name__ == "__main__":
test_fit()
test_tuner()
"""
How to get high performance convolution kernel on NVIDIA GPU by auto-tuning
Tuning High Performance Convolution on NVIDIA GPUs
=========================================================================
**Author**: `Lianmin Zheng <https://https://github.com/merrymercy>`_
......@@ -10,9 +10,11 @@ vendor provided library CuDNN in many cases.
import logging
import sys
import numpy as np
import tvm
import topi
from topi.testing import conv2d_nchw_python
from tvm import autotvm
......@@ -133,9 +135,10 @@ def conv2d_no_batching(N, H, W, CI, CO, KH, KW, stride, padding):
# logging config (for printing tuning log to screen)
logging.basicConfig(level=logging.INFO, stream=sys.stdout)
# the last layer in resnet
# the last layer in resnet
N, H, W, CO, CI, KH, KW, strides, padding = 1, 7, 7, 512, 512, 3, 3, (1, 1), (1, 1)
task = autotvm.task.create(conv2d_no_batching,
args=(1, 7, 7, 512, 512, 3, 3, (1, 1), (1, 1)),
args=(N, H, W, CO, CI, KH, KW, strides, padding),
target='cuda')
print(task.config_space)
......@@ -146,15 +149,43 @@ measure_option = autotvm.measure_option(mode='local',
parallel_num=8,
timeout=20)
# begin tuning, log records to file `cache.tsv`
# begin tuning, log records to file `conv2d.tsv`
tuner = autotvm.tuner.XGBTuner(task)
tuner.tune(n_trial=20,
measure_option=measure_option,
callbacks=[autotvm.callback.log_to_file('cache.tsv')])
callbacks=[autotvm.callback.log_to_file('conv2d.log')])
# get best config from cache file
dispatch_context = autotvm.apply_history_best("cache.tsv")
#########################################################################
# Finally we can inspect the best config from log file, check correctness,
# and measure running time.
# inspect the best config
dispatch_context = autotvm.apply_history_best("conv2d.log")
best_config = dispatch_context.query(task.target, task.workload)
print("\nBest config:")
print(best_config)
# apply history best from log file
with autotvm.apply_history_best('conv2d.log'):
with tvm.target.create("cuda"):
s, arg_bufs = conv2d_no_batching(N, H, W, CO, CI, KH, KW, strides, padding)
func = tvm.build(s, arg_bufs)
# check correctness
a_np = np.random.uniform(size=(N, CI, H, W)).astype(np.float32)
w_np = np.random.uniform(size=(CO, CI, KH, KW)).astype(np.float32)
c_np = conv2d_nchw_python(a_np, w_np, strides, padding)
ctx = tvm.gpu()
a_tvm = tvm.nd.array(a_np, ctx=ctx)
w_tvm = tvm.nd.array(w_np, ctx=ctx)
c_tvm = tvm.nd.empty(c_np.shape, ctx=ctx)
func(a_tvm, w_tvm, c_tvm)
np.testing.assert_allclose(c_np, c_tvm.asnumpy(), rtol=1e-2)
# Evaluate running time. Here we choose a large repeat number (200) to reduce the noise
# and the overhead of kernel launch. You can also use nvprof to validate the result.
evaluator = func.time_evaluator(func.entry_name, ctx, number=200)
print('Time cost of this operator: %f' % evaluator(a_tvm, w_tvm, c_tvm).mean)
......@@ -243,7 +243,7 @@ print(task.config_space)
#
# We only make 10 trials in this tutorial for demonstration. In practice,
# you can do more trials according to your time budget.
# We will log the tuning results into a cache file. This file can be
# We will log the tuning results into a log file. This file can be
# used to get the best config later.
# logging config (for printing tuning log to screen)
......@@ -253,11 +253,11 @@ logging.basicConfig(level=logging.INFO, stream=sys.stdout)
measure_option = autotvm.measure_option(mode='local',
number=5)
# begin tuning, log records to file `cache.tsv`
# begin tuning, log records to file `matmul.log`
tuner = autotvm.tuner.RandomTuner(task)
tuner.tune(n_trial=10,
measure_option=measure_option,
callbacks=[autotvm.callback.log_to_file('cache.tsv')])
callbacks=[autotvm.callback.log_to_file('matmul.log')])
#########################################################################
# Finally we apply history best from the cache file and check its correctness.
......@@ -267,7 +267,7 @@ tuner.tune(n_trial=10,
# with the same argument.
# apply history best from log file
with autotvm.apply_history_best('cache.tsv'):
with autotvm.apply_history_best('matmul.log'):
with tvm.target.create("llvm"):
s, arg_bufs = matmul(N, L, M, 'float32')
func = tvm.build(s, arg_bufs)
......@@ -281,4 +281,3 @@ c_tvm = tvm.nd.empty(c_np.shape)
func(tvm.nd.array(a_np), tvm.nd.array(b_np), c_tvm)
np.testing.assert_allclose(c_np, c_tvm.asnumpy(), rtol=1e-2)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment