Commit 155e955f by Wuwei Lin, committed by Tianqi Chen

Fix int8x4 broadcast value codegen in cuda (#1959)

parent 3bfa5fc0
@@ -273,6 +273,16 @@ void CodeGenCUDA::VisitExpr_(const Ramp* op, std::ostream& os) {
 }
 
 void CodeGenCUDA::VisitExpr_(const Broadcast* op, std::ostream& os) {   // NOLINT(*)
+  if (op->type.is_int() && op->type.bits() == 8 && op->lanes == 4) {
+    // make_int8x4
+    const int64_t *p = as_const_int(op->value);
+    CHECK(p);
+    int64_t v = *p & 0xFF;
+    v = (v << 24) | (v << 16) | (v << 8) | v;
+    os << "(int)" << v;
+    return;
+  }
+
   std::string v = PrintExpr(op->value);
   os << "make_";
   PrintType(op->type, os);
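The fix works because TVM's CUDA codegen represents an int8x4 vector as a single 32-bit int, so a broadcast constant must be emitted as its low byte replicated across all four lanes rather than as a make_int8x4(...) call. The & 0xFF mask is what makes negative values safe: as_const_int returns a sign-extended int64_t, and masking keeps only the low byte so the sign bits do not smear into the other lanes. Below is a minimal standalone sketch of the packing arithmetic in Python (the helper name pack_int8x4 is hypothetical, chosen to mirror the comment in the diff):

    def pack_int8x4(value):
        # Keep only the low byte; negative inputs arrive sign-extended,
        # and the mask stops the sign bits leaking into other lanes.
        v = value & 0xFF
        # Replicate the byte into all four bytes of a 32-bit word.
        return (v << 24) | (v << 16) | (v << 8) | v

    # The three constants exercised by the new test below:
    assert pack_int8x4(0xAB) == 0xABABABAB
    assert pack_int8x4(0) == 0x00000000
    assert pack_int8x4(-3) == 0xFDFDFDFD  # -3 as int8 is 0xFD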
@@ -87,7 +87,30 @@ def test_cuda_vectorize_load():
     check_cuda("int8", 64, 8)
     check_cuda("int8", 64, 16)
 
+def test_cuda_make_int8x4():
+    def check_cuda(n, value):
+        if not tvm.gpu(0).exist or not tvm.module.enabled("cuda"):
+            print("skip because cuda is not enabled..")
+            return
+        lanes = 4
+        dtype = 'int8'
+        ctx = tvm.gpu(0)
+        A = tvm.compute((n, lanes), lambda i,j: tvm.const(value, dtype=dtype))
+        s = tvm.create_schedule(A.op)
+        y, x = s[A].op.axis
+        s[A].vectorize(x)
+        s[A].bind(y, tvm.thread_axis("blockIdx.x"))
+        fun = tvm.build(s, [A], "cuda", name="make_int8x4")
+        np_a = np.full((n, lanes), value, dtype=dtype)
+        a = tvm.nd.empty(np_a.shape, dtype, ctx)
+        fun(a)
+        np.testing.assert_equal(a.asnumpy(), np_a)
+    check_cuda(64, 0xAB)
+    check_cuda(64, 0)
+    check_cuda(64, -3)
+
 if __name__ == "__main__":
     test_cuda_vectorize_add()
     test_cuda_multiply_add()
     test_cuda_vectorize_load()
+    test_cuda_make_int8x4()
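Beyond the numerical check, one way to confirm the codegen change is to dump the generated CUDA source and look for the packed constant. This is a sketch under the same pre-1.0 TVM API as the test above; fun.imported_modules[0].get_source() is the usual way to reach the device-side source in that API:

    import tvm

    # Rebuild the kernel from the test above for value = -3 and inspect the
    # emitted CUDA source. With this fix, the int8x4 broadcast should appear
    # as a packed 32-bit constant rather than a call to a nonexistent
    # make_int8x4() helper.
    n, lanes, value = 64, 4, -3
    A = tvm.compute((n, lanes), lambda i, j: tvm.const(value, dtype='int8'))
    s = tvm.create_schedule(A.op)
    y, x = s[A].op.axis
    s[A].vectorize(x)
    s[A].bind(y, tvm.thread_axis("blockIdx.x"))
    fun = tvm.build(s, [A], "cuda", name="make_int8x4")

    cuda_src = fun.imported_modules[0].get_source()
    print(cuda_src)
    # -3 & 0xFF == 0xFD; replicated four times this is 0xFDFDFDFD, i.e.
    # 4261281277, so the kernel text should contain "(int)4261281277".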