Commit bafc675c by Sergei Grechanik Committed by Tianqi Chen

[ARITH] Fix lowering of FloorMod (#4236)

parent a897d36d
......@@ -77,7 +77,7 @@ class IntrinInjecter : public arith::IRMutatorWithAnalyzer {
if (op == nullptr) return ret;
int shift;
const DataType& dtype = op->type;
CHECK(dtype.is_int() || !dtype.is_uint());
CHECK(dtype.is_int() || dtype.is_uint());
if (support_bitwise_op_ &&
is_const_power_of_two_integer(op->b, &shift)) {
......@@ -124,7 +124,7 @@ class IntrinInjecter : public arith::IRMutatorWithAnalyzer {
// Lower floordiv to native truncdiv.
int shift;
const DataType& dtype = op->type;
CHECK(dtype.is_int() || !dtype.is_uint());
CHECK(dtype.is_int() || dtype.is_uint());
if (support_bitwise_op_ &&
is_const_power_of_two_integer(op->b, &shift)) {
......@@ -136,8 +136,7 @@ class IntrinInjecter : public arith::IRMutatorWithAnalyzer {
if (analyzer_->CanProveGreaterEqual(op->b, 0)) {
// Common pass, positive divisor
if (analyzer_->CanProveGreaterEqual(op->a, 0) ||
analyzer_->CanProveGreaterEqual(e, 0)) {
if (analyzer_->CanProveGreaterEqual(op->a, 0)) {
return truncmod(op->a, op->b);
} else {
DLOG(INFO) << "LowerFloorMod: Cannot decide the sign of divident";
......
......@@ -406,40 +406,103 @@ def test_alignment():
assert "align 32" in l
def test_llvm_div():
"""Check that the semantics of div and mod is the same as in C/C++"""
def check_div(start, end, divisor, dtype):
T = tvm.compute((end - start,),
lambda i: tvm.div(tvm.expr.Cast(dtype, (start + i)), tvm.const(divisor, dtype)))
s = tvm.create_schedule([T.op])
f = tvm.build(s, [T], "llvm")
a = tvm.nd.empty((end - start,), dtype)
f(a)
ref = [int(float(i)/divisor) for i in range(start, end)]
tvm.testing.assert_allclose(a.asnumpy(), ref)
def check_mod(start, end, divisor, dtype):
tmod = tvm.truncmod
T = tvm.compute((end - start,),
lambda i: tmod(tvm.expr.Cast(dtype, (start + i)), tvm.const(divisor, dtype)))
s = tvm.create_schedule([T.op])
f = tvm.build(s, [T], "llvm")
a = tvm.nd.empty((end - start,), dtype)
f(a)
ref = [int(math.fmod(i, divisor)) for i in range(start, end)]
tvm.testing.assert_allclose(a.asnumpy(), ref)
def check_llvm(start, end, divisor, dtype):
check_div(start, end, divisor, dtype)
check_mod(start, end, divisor, dtype)
for d in range(-5, 6):
if d != 0:
# Note that 11 (and not e.g. 10) is used to avoid issues with the simplifier
check_llvm(-11, 11, d, 'int32')
check_llvm(-11, 11, d, 'int8')
if d > 0:
check_llvm(123, 133, d, 'uint8')
check_llvm(0, 256, d, 'uint8')
"""Check that the semantics of div and mod is correct"""
def check(start, end, dstart, dend, dtype, floor_div=False):
div = tvm.floordiv if floor_div else tvm.truncdiv
mod = tvm.floormod if floor_div else tvm.truncmod
# A are dividends, B are divisors. Note that we add 1 to include end in the range.
A = tvm.placeholder((end - start + 1,), name="A", dtype=dtype)
B = tvm.placeholder((dend - dstart + 1,), name="B", dtype=dtype)
# We clip values with min and max so that simplifiers know the ranges of values
clipa = lambda x: tvm.min(tvm.const(end, dtype), tvm.max(tvm.const(start, dtype), x))
clipb = lambda x: tvm.min(tvm.const(dend, dtype), tvm.max(tvm.const(dstart, dtype), x))
# If the range is just a single point, use the constant itself
if start == end:
clipa = lambda x: tvm.const(start, dtype)
if dstart == dend:
clipb = lambda x: tvm.const(dstart, dtype)
# D are division results and M are modulo results
[D, M] = tvm.compute((end - start + 1, dend - dstart + 1),
lambda i, j: (div(clipa(A[i]), clipb(B[j])),
mod(clipa(A[i]), clipb(B[j]))))
s = tvm.create_schedule([D.op, M.op])
f = tvm.build(s, [A, B, D, M], "llvm")
# Fill input arrays with values
A_arr = tvm.nd.empty((end - start + 1,), dtype)
B_arr = tvm.nd.empty((dend - dstart + 1,), dtype)
A_arr.copyfrom(np.arange(start, end + 1, dtype=dtype))
B_np = np.arange(dstart, dend + 1, dtype=dtype)
# If the range of the divisor contains 0, replace it with 1 to avoid division by zero
if dend >= 0 and dstart <= 0:
B_np[-dstart] = 1
B_arr.copyfrom(B_np)
D_arr = tvm.nd.empty((end - start + 1, dend - dstart + 1), dtype)
M_arr = tvm.nd.empty((end - start + 1, dend - dstart + 1), dtype)
# Run the function and convert the results to numpy
f(A_arr, B_arr, D_arr, M_arr)
D_arr = D_arr.asnumpy()
M_arr = M_arr.asnumpy()
# This helper just prints additional info on failure
def _show_info():
print("dtype: {}".format(dtype))
print("dividend range: [{}, {}]".format(start, end))
print("divisor range: [{}, {}]".format(dstart, dend))
lowered = tvm.lower(s, [A, B, D, M], simple_mode=True)
print("Lowered code:")
print(lowered)
# Check that the computed values are correct
for i in range(start, end + 1):
for j in range(dstart, dend + 1):
if j == 0:
continue
if floor_div:
dref = i // j
mref = i % j
else:
dref = int(float(i) / j)
mref = int(math.fmod(i, j))
if D_arr[i - start, j - dstart] != dref:
_show_info()
raise AssertionError("Incorrect division result: {}({}, {}) is {} "
"but should be {}".format(div.__name__, i, j,
D_arr[i - start, j - dstart],
dref))
if M_arr[i - start, j - dstart] != mref:
_show_info()
raise AssertionError("Incorrect modulo result: {}({}, {}) is {} "
"but should be {}".format(mod.__name__, i, j,
M_arr[i - start, j - dstart],
mref))
# Try different ranges to cover different cases
for start, end in [(-12, -12), (-11, -1), (-11, 0), (0, 0),
( 12, 12), ( 1, 11), ( 0, 11), (-11, 11)]:
for dstart, dend in [(-11, -1), (-11, 0), (-4, -4), (-2, -2),
( 1, 11), ( 0, 11), ( 4, 4), ( 2, 2), (-11, 11)]:
if end < start or dend < dstart or (dend == 0 and dstart == 0):
continue
check(start, end, dstart, dend, 'int32', floor_div=False)
check(start, end, dstart, dend, 'int32', floor_div=True)
check(start, end, dstart, dend, 'int8', floor_div=False)
check(start, end, dstart, dend, 'int8', floor_div=True)
if start >= 0 and dstart >= 0:
check(start, end, dstart, dend, 'uint32', floor_div=False)
check(start, end, dstart, dend, 'uint32', floor_div=True)
# Additional tests for uint8
for dstart, dend in [(0, 11), (1, 11), (2, 2), (4, 4)]:
check(123, 133, dstart, dend, 'uint8', floor_div=False)
check(123, 133, dstart, dend, 'uint8', floor_div=True)
check(0, 255, dstart, dend, 'uint8', floor_div=False)
check(0, 255, dstart, dend, 'uint8', floor_div=True)
def test_llvm_fp_math():
def check_llvm_reciprocal(n):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment