Commit e37dbd4e by Sergei Grechanik Committed by Tianqi Chen

[TVM] Fix llvm codegen (div by power of 2) (#2204)

parent ed942616
......@@ -788,11 +788,7 @@ DEFINE_CODEGEN_CMP_OP(GE);
llvm::Value* CodeGenLLVM::VisitExpr_(const Div* op) {
llvm::Value* a = MakeValue(op->a);
llvm::Value* b = MakeValue(op->b);
int shift;
if ((op->type.is_int() || op->type.is_uint()) &&
is_const_power_of_two_integer(op->b, &shift)) {
return builder_->CreateAShr(a, shift);
} else if (op->type.is_int()) {
if (op->type.is_int()) {
return builder_->CreateSDiv(a, b);
} else if (op->type.is_uint()) {
return builder_->CreateUDiv(a, b);
......
......@@ -2,6 +2,7 @@ import tvm
from tvm.contrib import util, clang
import numpy as np
import ctypes
import math
def test_llvm_intrin():
ib = tvm.ir_builder.create()
......@@ -386,6 +387,40 @@ def test_alignment():
if "align" in l and "4 x float" in l:
assert "align 32" in l
def test_llvm_div():
"""Check that the semantics of div and mod is the same as in C/C++"""
def check_div(start, end, divisor, dtype):
T = tvm.compute((end - start,),
lambda i: tvm.expr.Cast(dtype, (start + i)) / tvm.const(divisor, dtype))
s = tvm.create_schedule([T.op])
f = tvm.build(s, [T], "llvm")
a = tvm.nd.empty((end - start,), dtype)
f(a)
ref = [int(float(i)/divisor) for i in range(start, end)]
tvm.testing.assert_allclose(a.asnumpy(), ref)
def check_mod(start, end, divisor, dtype):
T = tvm.compute((end - start,),
lambda i: tvm.expr.Cast(dtype, (start + i)) % tvm.const(divisor, dtype))
s = tvm.create_schedule([T.op])
f = tvm.build(s, [T], "llvm")
a = tvm.nd.empty((end - start,), dtype)
f(a)
ref = [int(math.fmod(i, divisor)) for i in range(start, end)]
tvm.testing.assert_allclose(a.asnumpy(), ref)
def check_llvm(start, end, divisor, dtype):
check_div(start, end, divisor, dtype)
check_mod(start, end, divisor, dtype)
for d in range(-5, 6):
if d != 0:
# Note that 11 (and not e.g. 10) is used to avoid issues with the simplifier
check_llvm(-11, 11, d, 'int32')
check_llvm(-11, 11, d, 'int8')
if d > 0:
check_llvm(123, 133, d, 'uint8')
check_llvm(0, 256, d, 'uint8')
if __name__ == "__main__":
test_llvm_import()
......@@ -403,3 +438,4 @@ if __name__ == "__main__":
test_llvm_madd_pipeline()
test_llvm_temp_space()
test_llvm_lookup_intrin()
test_llvm_div()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment