Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
bafc675c
Commit
bafc675c
authored
Nov 01, 2019
by
Sergei Grechanik
Committed by
Tianqi Chen
Nov 01, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[ARITH] Fix lowering of FloorMod (#4236)
parent
a897d36d
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
100 additions
and
38 deletions
+100
-38
src/pass/lower_intrin.cc
+3
-4
tests/python/unittest/test_codegen_llvm.py
+97
-34
No files found.
src/pass/lower_intrin.cc
View file @
bafc675c
...
...
@@ -77,7 +77,7 @@ class IntrinInjecter : public arith::IRMutatorWithAnalyzer {
if
(
op
==
nullptr
)
return
ret
;
int
shift
;
const
DataType
&
dtype
=
op
->
type
;
CHECK
(
dtype
.
is_int
()
||
!
dtype
.
is_uint
());
CHECK
(
dtype
.
is_int
()
||
dtype
.
is_uint
());
if
(
support_bitwise_op_
&&
is_const_power_of_two_integer
(
op
->
b
,
&
shift
))
{
...
...
@@ -124,7 +124,7 @@ class IntrinInjecter : public arith::IRMutatorWithAnalyzer {
// Lower floordiv to native truncdiv.
int
shift
;
const
DataType
&
dtype
=
op
->
type
;
CHECK
(
dtype
.
is_int
()
||
!
dtype
.
is_uint
());
CHECK
(
dtype
.
is_int
()
||
dtype
.
is_uint
());
if
(
support_bitwise_op_
&&
is_const_power_of_two_integer
(
op
->
b
,
&
shift
))
{
...
...
@@ -136,8 +136,7 @@ class IntrinInjecter : public arith::IRMutatorWithAnalyzer {
if
(
analyzer_
->
CanProveGreaterEqual
(
op
->
b
,
0
))
{
// Common pass, positive divisor
if
(
analyzer_
->
CanProveGreaterEqual
(
op
->
a
,
0
)
||
analyzer_
->
CanProveGreaterEqual
(
e
,
0
))
{
if
(
analyzer_
->
CanProveGreaterEqual
(
op
->
a
,
0
))
{
return
truncmod
(
op
->
a
,
op
->
b
);
}
else
{
DLOG
(
INFO
)
<<
"LowerFloorMod: Cannot decide the sign of divident"
;
...
...
tests/python/unittest/test_codegen_llvm.py
View file @
bafc675c
...
...
@@ -406,40 +406,103 @@ def test_alignment():
assert
"align 32"
in
l
def
test_llvm_div
():
"""Check that the semantics of div and mod is the same as in C/C++"""
def
check_div
(
start
,
end
,
divisor
,
dtype
):
T
=
tvm
.
compute
((
end
-
start
,),
lambda
i
:
tvm
.
div
(
tvm
.
expr
.
Cast
(
dtype
,
(
start
+
i
)),
tvm
.
const
(
divisor
,
dtype
)))
s
=
tvm
.
create_schedule
([
T
.
op
])
f
=
tvm
.
build
(
s
,
[
T
],
"llvm"
)
a
=
tvm
.
nd
.
empty
((
end
-
start
,),
dtype
)
f
(
a
)
ref
=
[
int
(
float
(
i
)
/
divisor
)
for
i
in
range
(
start
,
end
)]
tvm
.
testing
.
assert_allclose
(
a
.
asnumpy
(),
ref
)
def
check_mod
(
start
,
end
,
divisor
,
dtype
):
tmod
=
tvm
.
truncmod
T
=
tvm
.
compute
((
end
-
start
,),
lambda
i
:
tmod
(
tvm
.
expr
.
Cast
(
dtype
,
(
start
+
i
)),
tvm
.
const
(
divisor
,
dtype
)))
s
=
tvm
.
create_schedule
([
T
.
op
])
f
=
tvm
.
build
(
s
,
[
T
],
"llvm"
)
a
=
tvm
.
nd
.
empty
((
end
-
start
,),
dtype
)
f
(
a
)
ref
=
[
int
(
math
.
fmod
(
i
,
divisor
))
for
i
in
range
(
start
,
end
)]
tvm
.
testing
.
assert_allclose
(
a
.
asnumpy
(),
ref
)
def
check_llvm
(
start
,
end
,
divisor
,
dtype
):
check_div
(
start
,
end
,
divisor
,
dtype
)
check_mod
(
start
,
end
,
divisor
,
dtype
)
for
d
in
range
(
-
5
,
6
):
if
d
!=
0
:
# Note that 11 (and not e.g. 10) is used to avoid issues with the simplifier
check_llvm
(
-
11
,
11
,
d
,
'int32'
)
check_llvm
(
-
11
,
11
,
d
,
'int8'
)
if
d
>
0
:
check_llvm
(
123
,
133
,
d
,
'uint8'
)
check_llvm
(
0
,
256
,
d
,
'uint8'
)
"""Check that the semantics of div and mod is correct"""
def
check
(
start
,
end
,
dstart
,
dend
,
dtype
,
floor_div
=
False
):
div
=
tvm
.
floordiv
if
floor_div
else
tvm
.
truncdiv
mod
=
tvm
.
floormod
if
floor_div
else
tvm
.
truncmod
# A are dividends, B are divisors. Note that we add 1 to make include end in the range.
A
=
tvm
.
placeholder
((
end
-
start
+
1
,),
name
=
"A"
,
dtype
=
dtype
)
B
=
tvm
.
placeholder
((
dend
-
dstart
+
1
,),
name
=
"B"
,
dtype
=
dtype
)
# We clip values with min and max so that simplifiers know the ranges of values
clipa
=
lambda
x
:
tvm
.
min
(
tvm
.
const
(
end
,
dtype
),
tvm
.
max
(
tvm
.
const
(
start
,
dtype
),
x
))
clipb
=
lambda
x
:
tvm
.
min
(
tvm
.
const
(
dend
,
dtype
),
tvm
.
max
(
tvm
.
const
(
dstart
,
dtype
),
x
))
# If the range is just a single point, use the constant itself
if
start
==
end
:
clipa
=
lambda
x
:
tvm
.
const
(
start
,
dtype
)
if
dstart
==
dend
:
clipb
=
lambda
x
:
tvm
.
const
(
dstart
,
dtype
)
# D are division results and M are modulo results
[
D
,
M
]
=
tvm
.
compute
((
end
-
start
+
1
,
dend
-
dstart
+
1
),
lambda
i
,
j
:
(
div
(
clipa
(
A
[
i
]),
clipb
(
B
[
j
])),
mod
(
clipa
(
A
[
i
]),
clipb
(
B
[
j
]))))
s
=
tvm
.
create_schedule
([
D
.
op
,
M
.
op
])
f
=
tvm
.
build
(
s
,
[
A
,
B
,
D
,
M
],
"llvm"
)
# Fill input arrays with values
A_arr
=
tvm
.
nd
.
empty
((
end
-
start
+
1
,),
dtype
)
B_arr
=
tvm
.
nd
.
empty
((
dend
-
dstart
+
1
,),
dtype
)
A_arr
.
copyfrom
(
np
.
arange
(
start
,
end
+
1
,
dtype
=
dtype
))
B_np
=
np
.
arange
(
dstart
,
dend
+
1
,
dtype
=
dtype
)
# If the range of the divisor contains 0, replace it with 1 to avoid division by zero
if
dend
>=
0
and
dstart
<=
0
:
B_np
[
-
dstart
]
=
1
B_arr
.
copyfrom
(
B_np
)
D_arr
=
tvm
.
nd
.
empty
((
end
-
start
+
1
,
dend
-
dstart
+
1
),
dtype
)
M_arr
=
tvm
.
nd
.
empty
((
end
-
start
+
1
,
dend
-
dstart
+
1
),
dtype
)
# Run the function and convert the results to numpy
f
(
A_arr
,
B_arr
,
D_arr
,
M_arr
)
D_arr
=
D_arr
.
asnumpy
()
M_arr
=
M_arr
.
asnumpy
()
# This helper just prints additional info on failure
def
_show_info
():
print
(
"dtype: {}"
.
format
(
dtype
))
print
(
"dividend range: [{}, {}]"
.
format
(
start
,
end
))
print
(
"divisor range: [{}, {}]"
.
format
(
dstart
,
dend
))
lowered
=
tvm
.
lower
(
s
,
[
A
,
B
,
D
,
M
],
simple_mode
=
True
)
print
(
"Lowered code:"
)
print
(
lowered
)
# Check that the computed values are correct
for
i
in
range
(
start
,
end
+
1
):
for
j
in
range
(
dstart
,
dend
+
1
):
if
j
==
0
:
continue
if
floor_div
:
dref
=
i
//
j
mref
=
i
%
j
else
:
dref
=
int
(
float
(
i
)
/
j
)
mref
=
int
(
math
.
fmod
(
i
,
j
))
if
D_arr
[
i
-
start
,
j
-
dstart
]
!=
dref
:
_show_info
()
raise
AssertionError
(
"Incorrect division result: {}({}, {}) is {} "
"but should be {}"
.
format
(
div
.
__name__
,
i
,
j
,
D_arr
[
i
-
start
,
j
-
dstart
],
dref
))
if
M_arr
[
i
-
start
,
j
-
dstart
]
!=
mref
:
_show_info
()
raise
AssertionError
(
"Incorrect modulo result: {}({}, {}) is {} "
"but should be {}"
.
format
(
mod
.
__name__
,
i
,
j
,
M_arr
[
i
-
start
,
j
-
dstart
],
mref
))
# Try different ranges to cover different cases
for
start
,
end
in
[(
-
12
,
-
12
),
(
-
11
,
-
1
),
(
-
11
,
0
),
(
0
,
0
),
(
12
,
12
),
(
1
,
11
),
(
0
,
11
),
(
-
11
,
11
)]:
for
dstart
,
dend
in
[(
-
11
,
-
1
),
(
-
11
,
0
),
(
-
4
,
-
4
),
(
-
2
,
-
2
),
(
1
,
11
),
(
0
,
11
),
(
4
,
4
),
(
2
,
2
),
(
-
11
,
11
)]:
if
end
<
start
or
dend
<
dstart
or
(
dend
==
0
and
dstart
==
0
):
continue
check
(
start
,
end
,
dstart
,
dend
,
'int32'
,
floor_div
=
False
)
check
(
start
,
end
,
dstart
,
dend
,
'int32'
,
floor_div
=
True
)
check
(
start
,
end
,
dstart
,
dend
,
'int8'
,
floor_div
=
False
)
check
(
start
,
end
,
dstart
,
dend
,
'int8'
,
floor_div
=
True
)
if
start
>=
0
and
dstart
>=
0
:
check
(
start
,
end
,
dstart
,
dend
,
'uint32'
,
floor_div
=
False
)
check
(
start
,
end
,
dstart
,
dend
,
'uint32'
,
floor_div
=
True
)
# Additional tests for uint8
for
dstart
,
dend
in
[(
0
,
11
),
(
1
,
11
),
(
2
,
2
),
(
4
,
4
)]:
check
(
123
,
133
,
dstart
,
dend
,
'uint8'
,
floor_div
=
False
)
check
(
123
,
133
,
dstart
,
dend
,
'uint8'
,
floor_div
=
True
)
check
(
0
,
255
,
dstart
,
dend
,
'uint8'
,
floor_div
=
False
)
check
(
0
,
255
,
dstart
,
dend
,
'uint8'
,
floor_div
=
True
)
def
test_llvm_fp_math
():
def
check_llvm_reciprocal
(
n
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment