Commit c3cac464 by masahi, committed by Tianqi Chen

enable rocm target for topi/recipes. add timing util to gemm test. (#554)

parent 592a1f65
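Every test touched below follows the same refactor: the backend-specific context constructors are replaced by the generic tvm.context helper, so each test can loop over the cuda, opencl, and rocm targets uniformly. A minimal sketch of that pattern, not part of the commit itself (the device list mirrors the tests below):

import tvm

# Before this commit the tests hard-coded two backends:
#     ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0)
# tvm.context resolves any backend string to a device context instead.
for device in ['cuda', 'opencl', 'rocm']:
    if not tvm.module.enabled(device):
        print("Skip because %s is not enabled" % device)
        continue
    ctx = tvm.context(device, 0)  # one call covers every backend
    print("%s is enabled, context created" % device)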
@@ -69,7 +69,7 @@ def test_depthwise_conv2d_nchw():
         if not tvm.module.enabled(device):
             print("Skip because %s is not enabled" % device)
             return
-        ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0)
+        ctx = tvm.context(device, 0)
         # Build the kernel
         f1 = tvm.build(s1, [Input, Filter, DepthwiseConv2d], device)
         f2 = tvm.build(s2, [Input, Filter, Scale, Shift, ScaleShift], device)
@@ -111,12 +111,13 @@ def test_depthwise_conv2d_nchw():
         np.testing.assert_allclose(relu_tvm.asnumpy(), relu_scipy, rtol=1e-5)
         print("success")
 
-    with tvm.build_config(auto_unroll_max_step=32,
-                          auto_unroll_min_depth=0,
-                          unroll_explicit=False,
-                          detect_global_barrier=False,
-                          restricted_func=True):
-        check_device("cuda")
+    for device in ['cuda', 'opencl', 'rocm']:
+        with tvm.build_config(auto_unroll_max_step=32,
+                              auto_unroll_min_depth=0,
+                              unroll_explicit=device == 'rocm',
+                              detect_global_barrier=False,
+                              restricted_func=True):
+            check_device(device)
 
 def test_depthwise_conv2d_nhwc():
     """You may test different settings."""
@@ -159,7 +160,7 @@ def test_depthwise_conv2d_nhwc():
         if not tvm.module.enabled(device):
             print("Skip because %s is not enabled" % device)
             return
-        ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0)
+        ctx = tvm.context(device, 0)
         # Build the kernel
         f1 = tvm.build(s1, [Input, Filter, DepthwiseConv2d], device)
         f2 = tvm.build(s2, [Input, Filter, Scale, Shift, ScaleShift], device)
@@ -200,12 +201,13 @@ def test_depthwise_conv2d_nhwc():
         np.testing.assert_allclose(relu_tvm.asnumpy(), relu_scipy, rtol=1e-5)
         print("success")
 
-    with tvm.build_config(auto_unroll_max_step=32,
-                          auto_unroll_min_depth=0,
-                          unroll_explicit=False,
-                          detect_global_barrier=False,
-                          restricted_func=True):
-        check_device("cuda")
+    for device in ['cuda', 'opencl', 'rocm']:
+        with tvm.build_config(auto_unroll_max_step=32,
+                              auto_unroll_min_depth=0,
+                              unroll_explicit=device == 'rocm',
+                              detect_global_barrier=False,
+                              restricted_func=True):
+            check_device(device)
 
 if __name__ == "__main__":
     test_depthwise_conv2d_nchw()
...
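The same per-device tvm.build_config scope recurs throughout the commit, with explicit loop unrolling switched on only for the rocm backend. A hypothetical helper capturing that recurring pattern (the name build_for is illustrative, not from the commit):

import tvm

def build_for(schedule, args, device):
    # Illustrative only: wrap tvm.build in the per-device config used
    # by these tests; rocm gets explicit unrolling, other backends do not.
    with tvm.build_config(auto_unroll_max_step=32,
                          auto_unroll_min_depth=0,
                          unroll_explicit=(device == 'rocm')):
        return tvm.build(schedule, args, device)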
@@ -5,7 +5,7 @@ import scipy.signal
 import tvm
 from tvm.contrib import nvcc
 import topi
-from topi.nn.util import get_const_tuple
+from topi.util import get_const_tuple
 
 TASK = "conv2d_hwcn_map"
 USE_MANUAL_CODE = False
@@ -55,14 +55,14 @@ def test_conv2d_hwcn_map():
         if not tvm.module.enabled(device):
             print("Skip because %s is not enabled" % device)
             return
-        ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0)
+        ctx = tvm.context(device, 0)
         a = tvm.nd.array(a_np, ctx)
         w = tvm.nd.array(w_np, ctx)
         b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
         c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx)
         with tvm.build_config(auto_unroll_max_step=32,
                               auto_unroll_min_depth=0,
-                              unroll_explicit=False):
+                              unroll_explicit=device == 'rocm'):
             func1 = tvm.build(s1, [A, W, B], device)
             func1(a, w, b)
             np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
@@ -70,7 +70,7 @@ def test_conv2d_hwcn_map():
             func2(a, w, c)
             np.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)
 
-    for device in ['cuda', 'opencl']:
+    for device in ['cuda', 'opencl', 'rocm']:
         check_device(device)
...
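The conv2d test also imports get_const_tuple from its new location, topi.util. The helper turns a tensor shape made of TVM expressions into a tuple of plain Python ints so numpy can allocate a matching host buffer; a small sketch of that usage, with a made-up placeholder standing in for the test's output tensor:

import numpy as np
import tvm
from topi.util import get_const_tuple

# Hypothetical output tensor; in the test, B comes from the conv2d schedule.
B = tvm.placeholder((1, 64, 56, 56), name='B', dtype='float32')
b_np = np.zeros(get_const_tuple(B.shape), dtype=B.dtype)  # shape as an int tuple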
@@ -100,11 +100,12 @@ def test_gemm():
     s[BB].double_buffer()
     # correctness
     def check_device(device):
+        print("Device %s" % device)
         if not tvm.module.enabled(device):
             print("Skip because %s is not enabled" % device)
             return
         f = tvm.build(s, [A, B, C], device)
-        ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0)
+        ctx = tvm.context(device, 0)
         # launch the kernel.
         n, m, l = nn, nn, nn
         a_np = np.random.uniform(size=(n, l)).astype(A.dtype)
@@ -117,10 +118,18 @@ def test_gemm():
         np.testing.assert_allclose(
             c.asnumpy(), np.dot(b_np.T, a_np), rtol=1e-5)
 
-    with tvm.build_config(auto_unroll_max_step=32,
-                          auto_unroll_min_depth=0,
-                          unroll_explicit=False):
-        check_device("cuda")
+        num_flops = 2 * nn * nn * nn
+        num_runs = 10
+        timer_f = f.time_evaluator(f.entry_name, ctx, number=num_runs)
+        t = timer_f(a, b, c).mean
+        GFLOPS = num_flops / (t * 1e3) / 1e6
+        print("average time cost of %d runs = %g ms, %g GFLOPS." % (num_runs, t * 1e3, GFLOPS))
+
+    for device in ['cuda', 'opencl', 'rocm']:
+        with tvm.build_config(auto_unroll_max_step=32,
+                              auto_unroll_min_depth=0,
+                              unroll_explicit=device == 'rocm'):
+            check_device(device)
 
 if __name__ == "__main__":
     test_gemm()
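The timing utility added to the gemm test reports an average over ten runs and converts it to GFLOPS. For a square n-by-n matrix product, each of the n^2 outputs takes n multiplies and n adds, hence num_flops = 2 * n^3; dividing by the mean time in seconds and by 1e9 gives GFLOPS, which is exactly what the commit's num_flops / (t * 1e3) / 1e6 computes. A standalone restatement of the measurement, assuming f, ctx, a, b, c, and nn are the built kernel, device context, device arrays, and matrix size from the test:

# Sketch of the timing block above, with the unit conversion spelled out.
num_flops = 2 * nn ** 3                  # n^2 outputs, n multiply-adds each
num_runs = 10
timer_f = f.time_evaluator(f.entry_name, ctx, number=num_runs)
t = timer_f(a, b, c).mean                # mean wall time in seconds
gflops = num_flops / t / 1e9             # same value as num_flops / (t * 1e3) / 1e6
print("average time cost of %d runs = %g ms, %g GFLOPS." % (num_runs, t * 1e3, gflops))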