Unverified Commit 7013fc9a by wpan11nv Committed by GitHub

[TOPI][CUDA] Enable vectorization on fp16 type (#4867)

- This allows better utilization of the memory bandwidth.

- Note that not all cases are vectorized for the fp16 data type. For
  instance, when the size is not a multiple of 1024, the inner loop
  extent may be an expression that cannot be vectorized. In this case,
  a small inner loop is still beneficial for latency hiding. (A sketch
  of the split-and-vectorize pattern follows below.)

Signed-off-by: Wei Pan <weip@nvidia.com>
parent b787ffa3
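
Below is a minimal sketch of the split-and-vectorize pattern this commit applies for fp16, written against the same TVM API vintage used in the diff. It is illustrative only: the shape, thread extent, and tensor names are assumptions, not part of the commit.

import tvm

# Elementwise fp16 op: B = A + 1 (illustrative shape).
n = 1024
A = tvm.placeholder((n,), name='A', dtype='float16')
B = tvm.compute((n,), lambda i: A[i] + tvm.const(1, A.dtype), name='B')

s = tvm.create_schedule(B.op)
fused = s[B].fuse(*s[B].op.axis)

# Split off a lane of width 4 and mark it vectorized, mirroring what the
# injective schedule now does for fp16 outputs.
vector_width = 4 if B.dtype == "float16" else 1
if vector_width > 1:
    fused, v = s[B].split(fused, factor=vector_width)
    s[B].vectorize(v)

# Bind the remaining iteration space to CUDA blocks/threads
# (thread extent of 64 is an arbitrary choice for the sketch).
bx, tx = s[B].split(fused, factor=64)
s[B].bind(bx, tvm.thread_axis("blockIdx.x"))
s[B].bind(tx, tvm.thread_axis("threadIdx.x"))

print(tvm.lower(s, [A, B], simple_mode=True))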
@@ -40,13 +40,20 @@ def schedule_injective_from_existing(sch, out):
     num_thread = tvm.target.Target.current(allow_none=False).max_num_threads
+    max_block = 256
+
+    # vectorize on fp16 data type. This allows to better utilize the memory
+    # bandwidth.
+    vector_width = 4 if out.dtype == "float16" else 1
+
     try:
         const_size = util.get_const_int(util.prod(out.shape))
-        max_block = 256
-        need_block_split = const_size > max_block * num_thread
+        need_block_split = const_size > max_block * num_thread * vector_width
     except ValueError:
         need_block_split = False
+
+    if vector_width > 1:
+        fused, v = sch[out].split(fused, vector_width)
+        sch[out].vectorize(v)
+
     if need_block_split:
         xo, xi = sch[out].split(fused, factor=num_thread * max_block)
         bx, tx = sch[out].split(xi, factor=num_thread)
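As a hedged usage sketch of the schedule change above (it assumes a CUDA-enabled TVM build and a GPU with fp16 support; the shapes and names are illustrative), one way to observe the effect is to build an elementwise fp16 op through the injective schedule and inspect the generated CUDA source for vectorized accesses:

import tvm
import topi

n, m = 128, 64
A = tvm.placeholder((n, m), name='A', dtype='float16')
B = tvm.compute((n, m), lambda i, j: A[i, j] + tvm.const(1, A.dtype), name='B')

with tvm.target.create("cuda"):
    s = topi.generic.schedule_injective(B)
fun = tvm.build(s, [A, B], "cuda")

# Vectorized fp16 loads/stores typically appear as wider accesses in the
# generated kernel body (the exact form depends on the code generator).
print(fun.imported_modules[0].get_source())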
@@ -20,11 +20,20 @@ import numpy as np
 import tvm
 import topi
 from topi.util import get_const_tuple
+from tvm.contrib.nvcc import parse_compute_version
 from common import get_all_backend
 
-def verify_relu(m, n):
-    A = tvm.placeholder((m, n), name='A')
+def skip_test(dtype, device):
+    if dtype == "float16" and device == "cuda":
+        major, minor = parse_compute_version(tvm.gpu(0).compute_version)
+        # fp16 starts from 5.3
+        if major < 5 or (major == 5 and minor < 3):
+            print("skip because gpu does not support fp16")
+            return True
+    return False
+
+def verify_relu(m, n, dtype="float32"):
+    A = tvm.placeholder((m, n), name='A', dtype=dtype)
     B = topi.nn.relu(A)
     a_np = np.random.uniform(low=-1.0, high=1.0, size=get_const_tuple(A.shape)).astype(A.dtype)
@@ -35,6 +44,8 @@ def verify_relu(m, n):
         if not ctx.exist:
             print("Skip because %s is not enabled" % device)
             return
+        if skip_test(dtype, device):
+            return
         print("Running on target: %s" % device)
         with tvm.target.create(device):
             s = topi.generic.schedule_elemwise(B)
@@ -87,12 +98,12 @@ def verify_prelu(x, w, axis, weight_reshape):
     tvm.testing.assert_allclose(b.asnumpy(), out_np, rtol=1e-5)
 
 def test_relu():
-    verify_relu(10, 128)
+    verify_relu(10, 128, "float32")
+    verify_relu(128, 64, "float16")
 
 def test_schedule_big_array():
     verify_relu(1024 * 100 , 512)
 
 def test_leaky_relu():
     verify_leaky_relu(100, 0.1)
@@ -19,6 +19,16 @@ import numpy as np
 import tvm
 import topi
 from tvm.contrib.pickle_memoize import memoize
+from tvm.contrib.nvcc import parse_compute_version
+
+def skip_test(dtype, device):
+    if dtype == "float16" and device == "cuda":
+        major, minor = parse_compute_version(tvm.gpu(0).compute_version)
+        # fp16 starts from 5.3
+        if major < 5 or (major == 5 and minor < 3):
+            print("skip because gpu does not support fp16")
+            return True
+    return False
 
 def verify_elemwise_sum(num_args, dtype):
     shape = (3,5,4)
@@ -84,18 +94,43 @@ def verify_full(shape, dtype, fill_value):
     for device in ["llvm"]:
         check_device(device)
 
+def verify_vectorization(n, m, dtype):
+    def check_device(device):
+        if not tvm.runtime.enabled(device):
+            print("Skip because %s is not enabled" % device)
+            return
+        if skip_test(dtype, device):
+            return
+        with tvm.target.create(device):
+            ctx = tvm.context(device, 0)
+            A = tvm.placeholder((n, m), name='A', dtype=dtype)
+            B = tvm.compute((n, m), lambda i, j:
+                            A[i, j] + tvm.const(1, A.dtype), name='B')
+            S = topi.generic.schedule_elemwise(B)
+            fun = tvm.build(S, [A, B], device)
+            np_A = tvm.nd.empty((n, m), A.dtype, ctx).copyfrom(
+                np.random.uniform(size=(n, m)))
+            np_B = tvm.nd.empty((n, m), B.dtype, ctx)
+            fun(np_A, np_B)
+            tvm.testing.assert_allclose(np_B.asnumpy(), np_A.asnumpy() + 1, rtol=1e-5)
+
+    for device in ["cuda"]:
+        check_device(device)
+
+def test_vectorization():
+    verify_vectorization(128, 64, "float16")
+
 def test_elemwise_sum():
     verify_elemwise_sum(1, "float32")
     verify_elemwise_sum(5, "float32")
     verify_elemwise_sum(4, "int32")
 
 def test_full():
     verify_full((3,4,5), "float32", 3.14)
     verify_full((10,), "int32", 7)
 
 if __name__ == "__main__":
     test_elemwise_sum()
     test_full()
+    test_vectorization()