Commit a7e35fc3 by Andrew Tulloch Committed by Tianqi Chen

Fix vmlal.s16 code generation for int8 x int8 -> int32 (#2748)

parent 2239508b
...@@ -50,7 +50,23 @@ class IntrinInjecter : public IRMutator { ...@@ -50,7 +50,23 @@ class IntrinInjecter : public IRMutator {
// on ARM. // on ARM.
if (const Broadcast* bcast = e.as<Broadcast>()) { if (const Broadcast* bcast = e.as<Broadcast>()) {
if (const Cast* cast = bcast->value.as<Cast>()) { if (const Cast* cast = bcast->value.as<Cast>()) {
if (cast->type.bits() == cast->value.type().bits() * 2) { auto should_swap = [&]() {
// Maintain behaviour (int8 -> int16, fp16 -> fp32).
if (cast->type.bits() == cast->value.type().bits() * 2) {
return true;
}
// Check both operands are integer-like.
if (!cast->type.is_uint() && !cast->type.is_int()) {
return false;
}
if (!cast->value.type().is_uint() && !cast->value.type().is_int()) {
return false;
}
// If both are integer-like, swap if we have a widening cast.
return cast->type.bits() > cast->value.type().bits();
};
if (should_swap()) {
Expr new_bcast = Broadcast::make(cast->value, bcast->lanes); Expr new_bcast = Broadcast::make(cast->value, bcast->lanes);
return Cast::make(bcast->type, new_bcast); return Cast::make(bcast->type, new_bcast);
} }
......
...@@ -26,5 +26,49 @@ def test_popcount(): ...@@ -26,5 +26,49 @@ def test_popcount():
check_correct_assembly('uint32', 2, 2) check_correct_assembly('uint32', 2, 2)
check_correct_assembly('uint64', 2, 3) check_correct_assembly('uint64', 2, 3)
def test_vmlal_s16():
target = 'llvm -target=armv7l-none-linux-gnueabihf -mcpu=cortex-a53 -mattr=+neon'
def check_correct_assembly(N):
K = tvm.var("K")
A = tvm.placeholder((K, N), dtype="int8", name='A')
B = tvm.placeholder((K, N), dtype="int8", name='A')
k = tvm.reduce_axis((0, K))
C = tvm.compute((N, ), lambda n: tvm.sum(
A[k, n].astype("int32") * B[k, n].astype("int32"), axis=[k]), name='C')
s = tvm.create_schedule(C.op)
s[C].vectorize(s[C].op.axis[0])
f = tvm.build(s, [A, B, C], target)
# Verify we see the correct number of vmlal.s16 instructions
assembly = f.get_source('asm')
matches = re.findall("vmlal.s16", assembly)
assert (len(matches) == N // 4)
check_correct_assembly(4)
check_correct_assembly(8)
check_correct_assembly(16)
def check_broadcast_correct_assembly(N):
K = tvm.var("K")
A = tvm.placeholder((K, N), dtype="int8", name='A')
B = tvm.placeholder((K,), dtype="int8", name='A')
k = tvm.reduce_axis((0, K))
C = tvm.compute((N, ), lambda n: tvm.sum(
A[k, n].astype("int32") * B[k].astype("int32"),
axis=[k]), name='C')
s = tvm.create_schedule(C.op)
s[C].vectorize(s[C].op.axis[0])
f = tvm.build(s, [A, B, C], target)
# Verify we see the correct number of vmlal.s16 instructions
assembly = f.get_source('asm')
matches = re.findall("vmlal.s16", assembly)
assert len(matches) == N // 4
check_broadcast_correct_assembly(8)
check_broadcast_correct_assembly(16)
check_broadcast_correct_assembly(32)
check_broadcast_correct_assembly(64)
if __name__ == "__main__": if __name__ == "__main__":
test_popcount() test_popcount()
test_vmlal_s16()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment