Commit df3c996b by Tianqi Chen Committed by GitHub

[TEST] Add memoize to save test data (#424)

* [TEST] Add memoize to save test data

* Update comment

* mark py version
parent 071b138f
......@@ -127,6 +127,7 @@ jvm/*/*/target/
*.perspectivev3
!default.perspectivev3
xcuserdata/
.pkl_memoize_*
.emscripten*
.m2
......
......@@ -253,7 +253,7 @@ def register_extension(cls):
.. code-block:: python
@tvm.register_dltensor
@tvm.register_extension
class MyTensor(object):
def __init__(self):
self.handle = _LIB.NewDLTensor()
......
"""Memoize result of function via pickle, used for cache testcases."""
# pylint: disable=broad-except,superfluous-parens
import os
import sys
import atexit
from decorator import decorate
from .._ffi.base import string_types
try:
import cPickle as pickle
except ImportError:
import pickle
class Cache(object):
"""A cache object for result cache.
Parameters
----------
key: str
The file key to the function
"""
cache_by_key = {}
def __init__(self, key):
cache_dir = ".pkl_memoize_py{0}".format(sys.version_info[0])
if not os.path.exists(cache_dir):
os.mkdir(cache_dir)
self.path = os.path.join(cache_dir, key)
if os.path.exists(self.path):
try:
self.cache = pickle.load(open(self.path, "rb"))
except Exception:
self.cache = {}
else:
self.cache = {}
self.dirty = False
def save(self):
if self.dirty:
print("Save memoize result to %s" % self.path)
with open(self.path, "wb") as out_file:
pickle.dump(self.cache, out_file, pickle.HIGHEST_PROTOCOL)
@atexit.register
def _atexit():
"""Save handler."""
for value in Cache.cache_by_key.values():
value.save()
def memoize(key):
"""Memoize the result of function and reuse multiple times.
Parameters
----------
key: str
The unique key to the file
Returns
-------
fmemoize : function
The decorator function to perform memoization.
"""
def _register(f):
"""Registration function"""
allow_types = (string_types, int, float)
fkey = key + "." + f.__name__ + ".pkl"
if fkey not in Cache.cache_by_key:
Cache.cache_by_key[fkey] = Cache(fkey)
cache = Cache.cache_by_key[fkey]
cargs = tuple(x.cell_contents for x in f.__closure__)
cargs = (len(cargs),) + cargs
def _memoized_f(func, *args, **kwargs):
assert not kwargs, "Only allow positional call"
key = cargs + args
for arg in key:
if isinstance(arg, tuple):
for x in arg:
assert isinstance(x, allow_types)
else:
assert isinstance(arg, allow_types)
if key in cache.cache:
print("Use memoize {0}{1}".format(fkey, key))
return cache.cache[key]
res = func(*args)
cache.cache[key] = res
cache.dirty = True
return res
return decorate(f, _memoized_f)
return _register
......@@ -3,10 +3,11 @@ import os
import numpy as np
import tvm
import topi
from tvm.contrib.pickle_memoize import memoize
from topi.util import get_const_tuple
def verify_conv2d_hwcn_map(batch, in_channel, in_size, num_filter, kernel, stride, padding):
def verify_conv2d_hwcn(batch, in_channel, in_size, num_filter, kernel, stride, padding):
in_height = in_width = in_size
A = tvm.placeholder((in_height, in_width, in_channel, batch), name='A')
......@@ -16,10 +17,18 @@ def verify_conv2d_hwcn_map(batch, in_channel, in_size, num_filter, kernel, strid
s1 = topi.cuda.schedule_conv2d_hwcn([B])
s2 = topi.cuda.schedule_conv2d_hwcn([C])
a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype)
w_np = np.random.uniform(size=get_const_tuple(W.shape)).astype(W.dtype)
b_np = topi.testing.conv2d_hwcn_python(a_np, w_np, stride, padding)
c_np = np.maximum(b_np, 0)
a_shape = get_const_tuple(A.shape)
w_shape = get_const_tuple(W.shape)
dtype = A.dtype
@memoize("topi.tests.test_topi_conv2d_hwcn.verify_hwcn")
def get_ref_data():
a_np = np.random.uniform(size=a_shape).astype(dtype)
w_np = np.random.uniform(size=w_shape).astype(dtype)
b_np = topi.testing.conv2d_hwcn_python(a_np, w_np, stride, padding)
c_np = np.maximum(b_np, 0)
return a_np, w_np, b_np, c_np
a_np, w_np, b_np, c_np = get_ref_data()
def check_device(device):
if not tvm.module.enabled(device):
......@@ -44,16 +53,16 @@ def verify_conv2d_hwcn_map(batch, in_channel, in_size, num_filter, kernel, strid
check_device(device)
def test_conv2d_hwcn_map():
verify_conv2d_hwcn_map(1, 256, 32, 256, 3, 1, "SAME")
verify_conv2d_hwcn_map(1, 256, 32, 256, 3, 1, "SAME")
verify_conv2d_hwcn_map(4, 128, 16, 128, 5, 2, "SAME")
verify_conv2d_hwcn_map(4, 128, 16, 256, 5, 2, "SAME")
verify_conv2d_hwcn_map(1, 256, 32, 256, 3, 1, "VALID")
verify_conv2d_hwcn_map(1, 256, 32, 256, 3, 1, "VALID")
verify_conv2d_hwcn_map(4, 128, 16, 128, 5, 2, "VALID")
verify_conv2d_hwcn_map(4, 128, 16, 256, 5, 2, "VALID")
def test_conv2d_hwcn():
verify_conv2d_hwcn(1, 256, 32, 256, 3, 1, "SAME")
verify_conv2d_hwcn(1, 256, 32, 256, 3, 1, "SAME")
verify_conv2d_hwcn(4, 128, 16, 128, 5, 2, "SAME")
verify_conv2d_hwcn(4, 128, 16, 256, 5, 2, "SAME")
verify_conv2d_hwcn(1, 256, 32, 256, 3, 1, "VALID")
verify_conv2d_hwcn(1, 256, 32, 256, 3, 1, "VALID")
verify_conv2d_hwcn(4, 128, 16, 128, 5, 2, "VALID")
verify_conv2d_hwcn(4, 128, 16, 256, 5, 2, "VALID")
if __name__ == "__main__":
test_conv2d_hwcn_map()
test_conv2d_hwcn()
......@@ -3,6 +3,7 @@ import os
import numpy as np
import tvm
import topi
from tvm.contrib.pickle_memoize import memoize
from topi.util import get_const_tuple
......@@ -16,10 +17,19 @@ def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, p
s1 = topi.cuda.schedule_conv2d_nchw([B])
s2 = topi.cuda.schedule_conv2d_nchw([C])
a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype)
w_np = np.random.uniform(size=get_const_tuple(W.shape)).astype(W.dtype)
b_np = topi.testing.conv2d_nchw_python(a_np, w_np, stride, padding)
c_np = np.maximum(b_np, 0)
a_shape = get_const_tuple(A.shape)
w_shape = get_const_tuple(W.shape)
dtype = A.dtype
@memoize("topi.tests.test_topi_conv2d.verify_con2d_nchw")
def get_ref_data():
a_np = np.random.uniform(size=a_shape).astype(dtype)
w_np = np.random.uniform(size=w_shape).astype(dtype)
b_np = topi.testing.conv2d_nchw_python(a_np, w_np, stride, padding)
c_np = np.maximum(b_np, 0)
return a_np, w_np, b_np, c_np
a_np, w_np, b_np, c_np = get_ref_data()
def check_device(device):
if not tvm.module.enabled(device):
......
......@@ -3,6 +3,7 @@ import os
import numpy as np
import tvm
import topi
from tvm.contrib.pickle_memoize import memoize
from topi.util import get_const_tuple
......@@ -13,11 +14,21 @@ def verify_convolution(batch, in_size, in_channel, num_filter, kernel, stride, p
A = tvm.placeholder((batch, in_channel, in_height, in_width), name='A')
W = tvm.placeholder((num_filter, in_channel, kernel, kernel), name='W')
B = topi.nn.convolution(A, W, stride, padding)
s = topi.rasp.schedule_convolution([B])
a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype)
w_np = np.random.uniform(size=get_const_tuple(W.shape)).astype(W.dtype)
b_np = topi.testing.conv2d_nchw_python(a_np, w_np, stride, padding)
a_shape = get_const_tuple(A.shape)
w_shape = get_const_tuple(W.shape)
dtype = A.dtype
@memoize("topi.tests.test_topi_convolution.verify_convolution")
def get_ref_data():
a_np = np.random.uniform(size=a_shape).astype(dtype)
w_np = np.random.uniform(size=w_shape).astype(dtype)
b_np = topi.testing.conv2d_nchw_python(a_np, w_np, stride, padding)
return a_np, w_np, b_np
a_np, w_np, b_np = get_ref_data()
ctx = tvm.cpu(0)
a = tvm.nd.array(a_np, ctx)
......
......@@ -3,8 +3,10 @@ import topi
import numpy as np
from scipy import signal
from topi.util import get_const_tuple
from tvm.contrib.pickle_memoize import memoize
from topi.cuda.depthwise_conv2d import schedule_depthwise_conv2d_nchw, schedule_depthwise_conv2d_nhwc
def depthwise_conv2d_with_workload_nchw(batch, in_channel, in_height, channel_multiplier, filter_height, stride_h, padding):
in_width = in_height
filter_channel = in_channel
......@@ -25,11 +27,6 @@ def depthwise_conv2d_with_workload_nchw(batch, in_channel, in_height, channel_mu
s2 = schedule_depthwise_conv2d_nchw(ScaleShift)
s3 = schedule_depthwise_conv2d_nchw(Relu)
input_np = np.random.uniform(size=get_const_tuple(Input.shape)).astype(Input.dtype)
filter_np = np.random.uniform(size=get_const_tuple(Filter.shape)).astype(Filter.dtype)
scale_np = np.random.uniform(size=get_const_tuple(Scale.shape)).astype(Scale.dtype)
shift_np = np.random.uniform(size=get_const_tuple(Shift.shape)).astype(Shift.dtype)
def check_device(device):
if not tvm.module.enabled(device):
print("Skip because %s is not enabled" % device)
......@@ -39,7 +36,35 @@ def depthwise_conv2d_with_workload_nchw(batch, in_channel, in_height, channel_mu
f1 = tvm.build(s1, [Input, Filter, DepthwiseConv2d], device)
f2 = tvm.build(s2, [Input, Filter, Scale, Shift, ScaleShift], device)
f3 = tvm.build(s3, [Input, Filter, Scale, Shift, Relu], device)
# prepare data
# Prepare pod type for test data closure
dtype = Input.dtype
input_shape = get_const_tuple(Input.shape)
filter_shape = get_const_tuple(Filter.shape)
scale_shape = get_const_tuple(Scale.shape)
shift_shape = get_const_tuple(Shift.shape)
scale_shift_shape = get_const_tuple(ScaleShift.shape)
# Use memoize, pickle the test data for next time use.
@memoize("topi.tests.test_topi_depthwise_conv2d.nchw")
def get_ref_data():
input_np = np.random.uniform(size=input_shape).astype(dtype)
filter_np = np.random.uniform(size=filter_shape).astype(dtype)
scale_np = np.random.uniform(size=scale_shape).astype(dtype)
shift_np = np.random.uniform(size=shift_shape).astype(dtype)
# correctness with scipy
depthwise_conv2d_scipy = topi.testing.depthwise_conv2d_python_nchw(
input_np, filter_np, stride=[stride_h, stride_w], padding=padding)
scale_shift_scipy = np.zeros(shape=scale_shift_shape)
for c in range(in_channel * channel_multiplier):
scale_shift_scipy[:,c,:,:] = depthwise_conv2d_scipy[:,c,:,:] * scale_np[c] + shift_np[c]
relu_scipy = np.maximum(scale_shift_scipy, 0)
return (input_np, filter_np, scale_np, shift_np,
depthwise_conv2d_scipy, scale_shift_scipy, relu_scipy)
# Get the test data
(input_np, filter_np, scale_np, shift_np,
depthwise_conv2d_scipy, scale_shift_scipy, relu_scipy) = get_ref_data()
input_tvm = tvm.nd.array(input_np, ctx)
filter_tvm = tvm.nd.array(filter_np, ctx)
scale_tvm = tvm.nd.array(scale_np, ctx)
......@@ -56,12 +81,6 @@ def depthwise_conv2d_with_workload_nchw(batch, in_channel, in_height, channel_mu
# launch kernel 3 (depthwise_conv2d + scale_shift + relu)
timer_3 = f3.time_evaluator(f3.entry_name, ctx, number=1)
tcost_3 = timer_3(input_tvm, filter_tvm, scale_tvm, shift_tvm, relu_tvm).mean
# correctness with scipy
depthwise_conv2d_scipy = topi.testing.depthwise_conv2d_python_nchw(input_np, filter_np, stride=[stride_h, stride_w], padding=padding)
scale_shift_scipy = np.zeros(shape=get_const_tuple(ScaleShift.shape))
for c in range(in_channel * channel_multiplier):
scale_shift_scipy[:,c,:,:] = depthwise_conv2d_scipy[:,c,:,:] * scale_np[c] + shift_np[c]
relu_scipy = np.maximum(scale_shift_scipy, 0)
np.testing.assert_allclose(depthwise_conv2d_tvm.asnumpy(), depthwise_conv2d_scipy, rtol=1e-5)
np.testing.assert_allclose(scale_shift_tvm.asnumpy(), scale_shift_scipy, rtol=1e-5)
np.testing.assert_allclose(relu_tvm.asnumpy(), relu_scipy, rtol=1e-5)
......@@ -90,11 +109,6 @@ def depthwise_conv2d_with_workload_nhwc(batch, in_channel, in_height, channel_mu
s2 = schedule_depthwise_conv2d_nhwc(ScaleShift)
s3 = schedule_depthwise_conv2d_nhwc(Relu)
input_np = np.random.uniform(size=get_const_tuple(Input.shape)).astype(Input.dtype)
filter_np = np.random.uniform(size=get_const_tuple(Filter.shape)).astype(Filter.dtype)
scale_np = np.random.uniform(size=get_const_tuple(Scale.shape)).astype(Scale.dtype)
shift_np = np.random.uniform(size=get_const_tuple(Shift.shape)).astype(Shift.dtype)
def check_device(device):
if not tvm.module.enabled(device):
print("Skip because %s is not enabled" % device)
......@@ -104,6 +118,35 @@ def depthwise_conv2d_with_workload_nhwc(batch, in_channel, in_height, channel_mu
f1 = tvm.build(s1, [Input, Filter, DepthwiseConv2d], device)
f2 = tvm.build(s2, [Input, Filter, Scale, Shift, ScaleShift], device)
f3 = tvm.build(s3, [Input, Filter, Scale, Shift, Relu], device)
# Prepare pod type for test data closure
dtype = Input.dtype
input_shape = get_const_tuple(Input.shape)
filter_shape = get_const_tuple(Filter.shape)
scale_shape = get_const_tuple(Scale.shape)
shift_shape = get_const_tuple(Shift.shape)
scale_shift_shape = get_const_tuple(ScaleShift.shape)
# Use memoize, pickle the test data for next time use.
@memoize("topi.tests.test_topi_depthwise_conv2d.nhwc")
def get_ref_data():
input_np = np.random.uniform(size=input_shape).astype(dtype)
filter_np = np.random.uniform(size=filter_shape).astype(dtype)
scale_np = np.random.uniform(size=scale_shape).astype(dtype)
shift_np = np.random.uniform(size=shift_shape).astype(dtype)
# correctness with scipy
depthwise_conv2d_scipy = topi.testing.depthwise_conv2d_python_nhwc(
input_np, filter_np, stride=[stride_h, stride_w], padding=padding)
scale_shift_scipy = np.zeros(shape=scale_shift_shape)
for c in range(in_channel * channel_multiplier):
scale_shift_scipy[:,:,:,c] = depthwise_conv2d_scipy[:,:,:,c] * scale_np[c] + shift_np[c]
relu_scipy = np.maximum(scale_shift_scipy, 0)
return (input_np, filter_np, scale_np, shift_np,
depthwise_conv2d_scipy, scale_shift_scipy, relu_scipy)
# Get the test data
(input_np, filter_np, scale_np, shift_np,
depthwise_conv2d_scipy, scale_shift_scipy, relu_scipy) = get_ref_data()
# prepare data
input_tvm = tvm.nd.array(input_np, ctx)
filter_tvm = tvm.nd.array(filter_np, ctx)
......@@ -121,11 +164,6 @@ def depthwise_conv2d_with_workload_nhwc(batch, in_channel, in_height, channel_mu
# launch kernel 3 (depthwise_conv2d + scale_shift + relu)
timer_3 = f3.time_evaluator(f3.entry_name, ctx, number=1)
tcost_3 = timer_3(input_tvm, filter_tvm, scale_tvm, shift_tvm, relu_tvm).mean
# correctness with scipy
depthwise_conv2d_scipy = topi.testing.depthwise_conv2d_python_nhwc(input_np, filter_np, stride=[stride_h, stride_w], padding=padding)
scale_shift_scipy = np.zeros(shape=get_const_tuple(ScaleShift.shape))
for c in range(in_channel * channel_multiplier):
scale_shift_scipy[:,:,:,c] = depthwise_conv2d_scipy[:,:,:,c] * scale_np[c] + shift_np[c]
relu_scipy = np.maximum(scale_shift_scipy, 0)
np.testing.assert_allclose(depthwise_conv2d_tvm.asnumpy(), depthwise_conv2d_scipy, rtol=1e-5)
np.testing.assert_allclose(scale_shift_tvm.asnumpy(), scale_shift_scipy, rtol=1e-5)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment