Commit ce72e9b5 by Xingyu Zhou Committed by Wuwei Lin

[codegen] Add multiple operands and function support when using fp16 compilation (#4056)

* overload half operators for cuda codegen

* add float16 te test_op_level1

* fix test_op_level1.py

* fix lint

* disable fp16 test if gpu does not support

* disable fp16 test if gpu does not support

* bypass float16 test if gpu does not support float16
parent d08ec106
...@@ -50,6 +50,20 @@ void CodeGenCUDA::AddFunction(LoweredFunc f) { ...@@ -50,6 +50,20 @@ void CodeGenCUDA::AddFunction(LoweredFunc f) {
std::string CodeGenCUDA::Finish() { std::string CodeGenCUDA::Finish() {
if (enable_fp16_) { if (enable_fp16_) {
decl_stream << "#include <cuda_fp16.h>\n"; decl_stream << "#include <cuda_fp16.h>\n";
decl_stream << "__device__ half max" \
"(const half a, const half b)\n"
"{\n return __hgt(__half(a), __half(b)) ? a : b;\n}\n";
decl_stream << "__device__ half min(const half a, const half b)\n"
"{\n return __hlt(__half(a), __half(b)) ? a : b;\n}\n";
decl_stream << "__device__ half operator+" \
"(const volatile __half &a, const volatile __half &b)\n"
"{\n return __hadd(a, b);\n}\n";
decl_stream << "__device__ half operator<=" \
"(const volatile __half &a, const volatile __half &b)\n"
"{\n return __hlt(a, b);\n}\n";
decl_stream << "__device__ half operator*" \
"(const volatile __half &a, const volatile __half &b)\n"
"{\n return __hmul(a, b);\n}\n";
} }
if (enable_int8_) { if (enable_int8_) {
......
...@@ -21,6 +21,7 @@ from tvm import relay ...@@ -21,6 +21,7 @@ from tvm import relay
from tvm.relay import transform from tvm.relay import transform
from tvm.relay.testing import ctx_list from tvm.relay.testing import ctx_list
import topi.testing import topi.testing
from tvm.contrib.nvcc import have_fp16
def run_infer_type(expr): def run_infer_type(expr):
mod = relay.Module.from_expr(expr) mod = relay.Module.from_expr(expr)
...@@ -42,11 +43,11 @@ def rsqrt(x): ...@@ -42,11 +43,11 @@ def rsqrt(x):
return one / np.sqrt(x) return one / np.sqrt(x)
def test_unary_op(): def test_unary_op():
def check_single_op(opfunc, ref): def check_single_op(opfunc, ref, dtype):
shape = (10, 4) shape = (10, 4)
dtype = 'float32' dtype = dtype
tp = relay.TensorType(shape, dtype) tp = relay.TensorType(shape)
x = relay.var("x", tp) x = relay.var("x", tp, dtype=dtype)
y = opfunc(x) y = opfunc(x)
# test printer # test printer
assert ("{}(%x)".format(y.op.name)) in y.astext() assert ("{}(%x)".format(y.op.name)) in y.astext()
...@@ -61,6 +62,8 @@ def test_unary_op(): ...@@ -61,6 +62,8 @@ def test_unary_op():
for target, ctx in ctx_list(): for target, ctx in ctx_list():
# use graph by execuor default for testing, as we need # use graph by execuor default for testing, as we need
# create function explicitly to avoid constant-folding. # create function explicitly to avoid constant-folding.
if dtype == 'float16' and target == 'cuda' and not have_fp16(tvm.gpu(0).compute_version):
continue
intrp = relay.create_executor("graph", ctx=ctx, target=target) intrp = relay.create_executor("graph", ctx=ctx, target=target)
op_res = intrp.evaluate(func)(data) op_res = intrp.evaluate(func)(data)
np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=0.01) np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=0.01)
...@@ -77,22 +80,23 @@ def test_unary_op(): ...@@ -77,22 +80,23 @@ def test_unary_op():
(tvm.relay.cos, np.cos), (tvm.relay.cos, np.cos),
(tvm.relay.sin, np.sin), (tvm.relay.sin, np.sin),
(tvm.relay.atan, np.arctan)]: (tvm.relay.atan, np.arctan)]:
check_single_op(opfunc, ref) for dtype in ['float16', 'float32']:
check_single_op(opfunc, ref, dtype)
def test_binary_op(): def test_binary_op():
def inst(vars, sh): def inst(vars, sh):
return [vars.get(s, s) for s in sh] return [vars.get(s, s) for s in sh]
def check_binary_op(opfunc, ref): def check_binary_op(opfunc, ref, dtype):
# TODO(@jroesch): this piece of code improperly uses type variables. # TODO(@jroesch): this piece of code improperly uses type variables.
n = tvm.var("n") n = tvm.var("n")
s1 = (5, n, 5) s1 = (5, n, 5)
s2 = (n, 1) s2 = (n, 1)
t1 = relay.TensorType(s1) t1 = relay.TensorType(s1)
t2 = relay.TensorType(s2) t2 = relay.TensorType(s2)
x = relay.var("x", t1) x = relay.var("x", t1, dtype=dtype)
y = relay.var("y", t2) y = relay.var("y", t2, dtype=dtype)
z = opfunc(x, y) z = opfunc(x, y)
# test printer # test printer
assert ("{}(%x, %y)".format(z.op.name)) in z.astext() assert ("{}(%x, %y)".format(z.op.name)) in z.astext()
...@@ -102,17 +106,19 @@ def test_binary_op(): ...@@ -102,17 +106,19 @@ def test_binary_op():
if ref is not None: if ref is not None:
t1 = relay.TensorType((5, 10, 5)) t1 = relay.TensorType((5, 10, 5))
t2 = relay.TensorType((5, 10, 5)) t2 = relay.TensorType((5, 10, 5))
x = relay.var("x", t1) x = relay.var("x", t1, dtype=dtype)
y = relay.var("y", t2) y = relay.var("y", t2, dtype=dtype)
z = opfunc(x, y) z = opfunc(x, y)
x_data = np.random.rand(5, 10, 5).astype(t1.dtype) x_data = np.random.rand(5, 10, 5).astype(dtype)
y_data = np.random.rand(5, 10, 5).astype(t2.dtype) y_data = np.random.rand(5, 10, 5).astype(dtype)
ref_res = ref(x_data, y_data) ref_res = ref(x_data, y_data)
func = relay.Function([x, y], z) func = relay.Function([x, y], z)
for target, ctx in ctx_list(): for target, ctx in ctx_list():
# use graph by execuor default for testing, as we need # use graph by execuor default for testing, as we need
# create function explicitly to avoid constant-folding. # create function explicitly to avoid constant-folding.
if dtype == 'float16' and target == 'cuda' and not have_fp16(tvm.gpu(0).compute_version):
continue
intrp = relay.create_executor("graph", ctx=ctx, target=target) intrp = relay.create_executor("graph", ctx=ctx, target=target)
op_res = intrp.evaluate(func)(x_data, y_data) op_res = intrp.evaluate(func)(x_data, y_data)
np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=0.01) np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=0.01)
...@@ -121,7 +127,8 @@ def test_binary_op(): ...@@ -121,7 +127,8 @@ def test_binary_op():
(relay.subtract, np.subtract), (relay.subtract, np.subtract),
(relay.multiply, np.multiply), (relay.multiply, np.multiply),
(relay.divide, np.divide)]: (relay.divide, np.divide)]:
check_binary_op(opfunc, ref) for dtype in ['float16', 'float32']:
check_binary_op(opfunc, ref, dtype)
def test_expand_dims(): def test_expand_dims():
...@@ -130,55 +137,65 @@ def test_expand_dims(): ...@@ -130,55 +137,65 @@ def test_expand_dims():
x = relay.Var("x", relay.TensorType(dshape, dtype)) x = relay.Var("x", relay.TensorType(dshape, dtype))
func = relay.Function([x], relay.expand_dims(x, axis, num_newaxis)) func = relay.Function([x], relay.expand_dims(x, axis, num_newaxis))
for target, ctx in ctx_list(): for target, ctx in ctx_list():
if dtype == 'float16' and target == 'cuda' and not have_fp16(tvm.gpu(0).compute_version):
continue
data = np.random.uniform(size=dshape).astype(dtype) data = np.random.uniform(size=dshape).astype(dtype)
ref_res = data.reshape(oshape) ref_res = data.reshape(oshape)
intrp = relay.create_executor("graph", ctx=ctx, target=target) intrp = relay.create_executor("graph", ctx=ctx, target=target)
op_res = intrp.evaluate(func)(data) op_res = intrp.evaluate(func)(data)
np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=0.01) np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=0.01)
for dtype in ['float16', 'float32']:
verify_expand_dims((3, 10), 'float32', (3, 10, 1, 1), 2, 2) verify_expand_dims((3, 10), dtype, (3, 10, 1, 1), 2, 2)
verify_expand_dims((3, 10), 'float32', (1, 3, 10), -3, 1) verify_expand_dims((3, 10), dtype, (1, 3, 10), -3, 1)
def test_bias_add(): def test_bias_add():
for dtype in ['float16', 'float32']:
xshape=(10, 2, 3, 4) xshape=(10, 2, 3, 4)
bshape=(2,) bshape=(2,)
dtype="float32" rtol = 1e-2 if dtype is 'float16' else 1e-5
x = relay.var("x", shape=xshape) x = relay.var("x", shape=xshape, dtype=dtype)
bias = relay.var("bias") bias = relay.var("bias", dtype=dtype)
z = relay.nn.bias_add(x, bias) z = relay.nn.bias_add(x, bias)
zz = run_infer_type(z) zz = run_infer_type(z)
assert "axis=" not in zz.astext() assert "axis=" not in zz.astext()
assert zz.args[1].checked_type == relay.TensorType(bshape) assert zz.args[1].checked_type == relay.TensorType(bshape, dtype)
func = relay.Function([x, bias], z) func = relay.Function([x, bias], z)
x_data = np.random.uniform(size=xshape).astype(dtype) x_data = np.random.uniform(size=xshape).astype(dtype)
y_data = np.random.uniform(size=bshape).astype(dtype) y_data = np.random.uniform(size=bshape).astype(dtype)
ref_res = x_data + y_data.reshape((2, 1, 1)) ref_res = x_data + y_data.reshape((2, 1, 1))
for target, ctx in ctx_list(): for target, ctx in ctx_list():
if dtype == 'float16' and target == 'cuda' and not have_fp16(tvm.gpu(0).compute_version):
continue
intrp = relay.create_executor("graph", ctx=ctx, target=target) intrp = relay.create_executor("graph", ctx=ctx, target=target)
op_res = intrp.evaluate(func)(x_data, y_data) op_res = intrp.evaluate(func)(x_data, y_data)
np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5) np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=rtol)
def test_expand_dims_infer_type(): def test_expand_dims_infer_type():
for dtype in ['float16', 'float32']:
n, t, d = tvm.var("n"), tvm.var("t"), 100 n, t, d = tvm.var("n"), tvm.var("t"), 100
x = relay.var("x", shape=(n, t, d)) x = relay.var("x", shape=(n, t, d), dtype=dtype)
y = relay.expand_dims(x, axis=2) y = relay.expand_dims(x, axis=2)
assert "axis=2" in y.astext() assert "axis=2" in y.astext()
yy = run_infer_type(y) yy = run_infer_type(y)
assert yy.checked_type == relay.TensorType((n, t, 1, 100)) assert yy.checked_type == relay.TensorType((n, t, 1, 100), dtype)
def test_softmax(): def test_softmax():
for dtype in ['float16', 'float32']:
# Softmax accuracy for float16 is poor
if dtype == 'float16':
return
shape = (10, 4) shape = (10, 4)
x = relay.var("x", shape=shape) x = relay.var("x", shape=shape, dtype=dtype)
y = relay.nn.softmax(x, axis=1) y = relay.nn.softmax(x, axis=1)
assert "nn.softmax" in y.astext() assert "nn.softmax" in y.astext()
yy = run_infer_type(y) yy = run_infer_type(y)
assert yy.checked_type == relay.TensorType(shape) assert yy.checked_type == relay.TensorType(shape, dtype)
func = relay.Function([x], y) func = relay.Function([x], y)
x_data = np.random.uniform(size=shape).astype("float32") x_data = np.random.uniform(size=shape).astype(dtype)
ref_res = topi.testing.softmax_python(x_data) ref_res = topi.testing.softmax_python(x_data)
for target, ctx in ctx_list(): for target, ctx in ctx_list():
intrp = relay.create_executor("graph", ctx=ctx, target=target) intrp = relay.create_executor("graph", ctx=ctx, target=target)
...@@ -187,14 +204,18 @@ def test_softmax(): ...@@ -187,14 +204,18 @@ def test_softmax():
def test_log_softmax(): def test_log_softmax():
for dtype in ['float16', 'float32']:
# Softmax accuracy for float16 is poor
if dtype == 'float16':
return
shape = (10, 4) shape = (10, 4)
x = relay.var("x", shape=shape) x = relay.var("x", shape=shape, dtype=dtype)
y = relay.nn.log_softmax(x, axis=1) y = relay.nn.log_softmax(x, axis=1)
assert "nn.log_softmax" in y.astext() assert "nn.log_softmax" in y.astext()
yy = run_infer_type(y) yy = run_infer_type(y)
assert yy.checked_type == relay.TensorType(shape) assert yy.checked_type == relay.TensorType(shape, dtype)
func = relay.Function([x], y) func = relay.Function([x], y)
x_data = np.random.uniform(size=shape).astype("float32") x_data = np.random.uniform(size=shape).astype(dtype)
ref_res = topi.testing.log_softmax_python(x_data) ref_res = topi.testing.log_softmax_python(x_data)
for target, ctx in ctx_list(): for target, ctx in ctx_list():
intrp = relay.create_executor("graph", ctx=ctx, target=target) intrp = relay.create_executor("graph", ctx=ctx, target=target)
...@@ -203,6 +224,7 @@ def test_log_softmax(): ...@@ -203,6 +224,7 @@ def test_log_softmax():
def test_concatenate(): def test_concatenate():
for dtype in ['float16', 'float32']:
n, t, d = tvm.var("n"), tvm.var("t"), 100 n, t, d = tvm.var("n"), tvm.var("t"), 100
x = relay.var("x", shape=(n, t, d)) x = relay.var("x", shape=(n, t, d))
y = relay.var("y", shape=(n, t, d)) y = relay.var("y", shape=(n, t, d))
...@@ -232,19 +254,21 @@ def test_concatenate(): ...@@ -232,19 +254,21 @@ def test_concatenate():
else: else:
assert False assert False
x = relay.var("x", shape=(10, 5)) x = relay.var("x", shape=(10, 5), dtype=dtype)
y = relay.var("y", shape=(10, 5)) y = relay.var("y", shape=(10, 5), dtype=dtype)
t = relay.var("z", shape=()) t = relay.var("z", shape=(), dtype=dtype)
z = relay.concatenate((x, y), axis=1) z = relay.concatenate((x, y), axis=1)
z = relay.add(z, t) z = relay.add(z, t)
# Check result. # Check result.
func = relay.Function([x, y, t], z) func = relay.Function([x, y, t], z)
x_data = np.random.rand(10, 5).astype('float32') x_data = np.random.rand(10, 5).astype(dtype)
y_data = np.random.rand(10, 5).astype('float32') y_data = np.random.rand(10, 5).astype(dtype)
t_data = np.random.uniform(size=()).astype('float32') t_data = np.random.uniform(size=()).astype(dtype)
ref_res = np.concatenate((x_data, y_data), axis=1) + t_data ref_res = np.concatenate((x_data, y_data), axis=1) + t_data
for target, ctx in ctx_list(): for target, ctx in ctx_list():
if dtype == 'float16' and target == 'cuda' and not have_fp16(tvm.gpu(0).compute_version):
continue
intrp1 = relay.create_executor("graph", ctx=ctx, target=target) intrp1 = relay.create_executor("graph", ctx=ctx, target=target)
intrp2 = relay.create_executor("debug", ctx=ctx, target=target) intrp2 = relay.create_executor("debug", ctx=ctx, target=target)
op_res1 = intrp1.evaluate(func)(x_data, y_data, t_data) op_res1 = intrp1.evaluate(func)(x_data, y_data, t_data)
...@@ -253,8 +277,9 @@ def test_concatenate(): ...@@ -253,8 +277,9 @@ def test_concatenate():
tvm.testing.assert_allclose(op_res2.asnumpy(), ref_res, rtol=0.01) tvm.testing.assert_allclose(op_res2.asnumpy(), ref_res, rtol=0.01)
def test_dropout(): def test_dropout():
for dtype in ['float16', 'float32']:
n, t, d = tvm.var("n"), tvm.var("t"), tvm.var("d") n, t, d = tvm.var("n"), tvm.var("t"), tvm.var("d")
input_ty = relay.TensorType((n, t, d), "float32") input_ty = relay.TensorType((n, t, d), dtype)
x = relay.var("x", input_ty) x = relay.var("x", input_ty)
y = relay.nn.dropout(x, rate=0.75) y = relay.nn.dropout(x, rate=0.75)
assert "rate=" in y.astext() assert "rate=" in y.astext()
...@@ -263,84 +288,89 @@ def test_dropout(): ...@@ -263,84 +288,89 @@ def test_dropout():
def test_batch_norm(): def test_batch_norm():
for dtype in ['float16', 'float32']:
# beta and gamma ignored # beta and gamma ignored
data = relay.var("data", relay.TensorType((3, 2, 1))) data = relay.var("data", relay.TensorType((3, 2, 1), dtype))
beta = relay.var("beta", relay.TensorType((2,))) beta = relay.var("beta", relay.TensorType((2,), dtype))
gamma = relay.var("gamma", relay.TensorType((2,))) gamma = relay.var("gamma", relay.TensorType((2,), dtype))
moving_mean = relay.var("moving_mean", relay.TensorType((2,))) moving_mean = relay.var("moving_mean", relay.TensorType((2,), dtype))
moving_var = relay.var("moving_var", relay.TensorType((2,))) moving_var = relay.var("moving_var", relay.TensorType((2,), dtype))
y = relay.nn.batch_norm(data, gamma, beta, moving_mean, moving_var, y = relay.nn.batch_norm(data, gamma, beta, moving_mean, moving_var,
center=False, scale=False) center=False, scale=False)
yy = run_infer_type(y.astuple()) yy = run_infer_type(y.astuple())
assert "center=" in yy.astext() assert "center=" in yy.astext()
assert yy.checked_type == relay.ty.TupleType(tvm.convert([ assert yy.checked_type == relay.ty.TupleType(tvm.convert([
relay.TensorType((3, 2, 1), "float32"), relay.TensorType((3, 2, 1), dtype),
relay.TensorType((2,), "float32"), relay.TensorType((2,), dtype),
relay.TensorType((2,), "float32") relay.TensorType((2,), dtype)
])) ]))
beta = relay.var("beta", relay.TensorType((3,))) beta = relay.var("beta", relay.TensorType((3,), dtype))
gamma = relay.var("gamma", relay.TensorType((3,))) gamma = relay.var("gamma", relay.TensorType((3,), dtype))
moving_mean = relay.var("moving_mean", relay.TensorType((3,))) moving_mean = relay.var("moving_mean", relay.TensorType((3,), dtype))
moving_var = relay.var("moving_var", relay.TensorType((3,))) moving_var = relay.var("moving_var", relay.TensorType((3,), dtype))
y = relay.nn.batch_norm(data, gamma, beta, moving_mean, moving_var, y = relay.nn.batch_norm(data, gamma, beta, moving_mean, moving_var,
axis=0, center=False, scale=False) axis=0, center=False, scale=False)
yy = run_infer_type(y.astuple()) yy = run_infer_type(y.astuple())
assert yy.checked_type == relay.ty.TupleType(tvm.convert([ assert yy.checked_type == relay.ty.TupleType(tvm.convert([
relay.ty.TensorType((3, 2, 1), "float32"), relay.ty.TensorType((3, 2, 1), dtype),
relay.ty.TensorType((3,), "float32"), relay.ty.TensorType((3,), dtype),
relay.ty.TensorType((3,), "float32") relay.ty.TensorType((3,), dtype)
])) ]))
# axis=-1 # axis=-1
data = relay.var("data", relay.TensorType((1, 2, 3))) data = relay.var("data", relay.TensorType((1, 2, 3), dtype))
beta = relay.var("beta", relay.TensorType((3,))) beta = relay.var("beta", relay.TensorType((3,), dtype))
gamma = relay.var("gamma", relay.TensorType((3,))) gamma = relay.var("gamma", relay.TensorType((3,), dtype))
moving_mean = relay.var("moving_mean", relay.TensorType((3,))) moving_mean = relay.var("moving_mean", relay.TensorType((3,), dtype))
moving_var = relay.var("moving_var", relay.TensorType((3,))) moving_var = relay.var("moving_var", relay.TensorType((3,), dtype))
y = relay.nn.batch_norm(data, gamma, beta, moving_mean, moving_var, y = relay.nn.batch_norm(data, gamma, beta, moving_mean, moving_var,
axis=-1, center=False, scale=False) axis=-1, center=False, scale=False)
yy = run_infer_type(y.astuple()) yy = run_infer_type(y.astuple())
assert yy.checked_type == relay.ty.TupleType(tvm.convert([ assert yy.checked_type == relay.ty.TupleType(tvm.convert([
relay.ty.TensorType((1, 2, 3), "float32"), relay.ty.TensorType((1, 2, 3), dtype),
relay.ty.TensorType((3,), "float32"), relay.ty.TensorType((3,), dtype),
relay.ty.TensorType((3,), "float32") relay.ty.TensorType((3,), dtype)
])) ]))
def test_dense(): def test_dense():
for dtype in ['float16', 'float32']:
# Dense accuracy for float16 is poor
if dtype == 'float16':
return
n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w") n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
x = relay.var("x", relay.TensorType((n, c, h, w), "float32")) x = relay.var("x", relay.TensorType((n, c, h, w), dtype))
w = relay.var("w", relay.TensorType((2, w), "float32")) w = relay.var("w", relay.TensorType((2, w), dtype))
y = relay.nn.dense(x, w, units=2) y = relay.nn.dense(x, w, units=2)
assert "units=2" in y.astext() assert "units=2" in y.astext()
yy = run_infer_type(y) yy = run_infer_type(y)
assert yy.checked_type == relay.TensorType((n, c, h, 2), "float32") assert yy.checked_type == relay.TensorType((n, c, h, 2), dtype)
n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), 2 n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), 2
x = relay.var("x", relay.TensorType((n, c, h, w), "float32")) x = relay.var("x", relay.TensorType((n, c, h, w), dtype))
wh, ww = tvm.var("wh"), tvm.var("ww") wh, ww = tvm.var("wh"), tvm.var("ww")
w = relay.var("w", relay.TensorType((ww, wh), "float32")) w = relay.var("w", relay.TensorType((ww, wh), dtype))
y = relay.nn.dense(x, w) y = relay.nn.dense(x, w)
yy = run_infer_type(y) yy = run_infer_type(y)
assert yy.checked_type == relay.TensorType((n, c, h, ww), "float32") assert yy.checked_type == relay.TensorType((n, c, h, ww), dtype)
n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), 2 n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), 2
x = relay.var("x", relay.TensorType((n, c, h, w), "float32")) x = relay.var("x", relay.TensorType((n, c, h, w), dtype))
w = relay.var("w", relay.IncompleteType()) w = relay.var("w", relay.IncompleteType())
y = relay.nn.dense(x, w, units=2) y = relay.nn.dense(x, w, units=2)
yy = run_infer_type(y) yy = run_infer_type(y)
assert yy.checked_type == relay.TensorType((n, c, h, 2), "float32") assert yy.checked_type == relay.TensorType((n, c, h, 2), dtype)
x = relay.var("x", shape=(10, 5)) x = relay.var("x", shape=(10, 5), dtype=dtype)
w = relay.var("w", shape=(2, 5)) w = relay.var("w", shape=(2, 5), dtype=dtype)
z = relay.nn.dense(x, w) z = relay.nn.dense(x, w)
# Check result. # Check result.
func = relay.Function([x, w], z) func = relay.Function([x, w], z)
x_data = np.random.rand(10, 5).astype('float32') x_data = np.random.rand(10, 5).astype(dtype)
w_data = np.random.rand(2, 5).astype('float32') w_data = np.random.rand(2, 5).astype(dtype)
ref_res = np.dot(x_data, w_data.T) ref_res = np.dot(x_data, w_data.T)
for target, ctx in ctx_list(): for target, ctx in ctx_list():
......
...@@ -19,6 +19,7 @@ import numpy as np ...@@ -19,6 +19,7 @@ import numpy as np
import tvm import tvm
import topi import topi
import topi.testing import topi.testing
from tvm.contrib.nvcc import have_fp16
from common import get_all_backend from common import get_all_backend
...@@ -53,6 +54,9 @@ def verify_reinterpret(in_shape, in_dtype, out_dtype, generator): ...@@ -53,6 +54,9 @@ def verify_reinterpret(in_shape, in_dtype, out_dtype, generator):
if not ctx.exist: if not ctx.exist:
print("Skip because %s is not enabled" % device) print("Skip because %s is not enabled" % device)
return return
if in_dtype == "float16" and device == 'cuda' and not have_fp16(ctx.compute_version):
print("Skip because %s does not have fp16 support" % device)
return
print("Running on target: %s" % device) print("Running on target: %s" % device)
with tvm.target.create(device): with tvm.target.create(device):
s = topi.generic.schedule_elemwise(B) s = topi.generic.schedule_elemwise(B)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment