Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
173b4fc4
Unverified
Commit
173b4fc4
authored
Mar 12, 2020
by
pankratz
Committed by
GitHub
Mar 12, 2020
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Fixed div by zero core dump. Fixed rounding intrinsics on int crash (#5026)
parent
ec86d7f1
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
45 additions
and
2 deletions
+45
-2
src/arith/const_fold.h
+2
-0
src/tir/ir/op.cc
+15
-0
tests/python/unittest/test_lang_basic.py
+17
-2
tests/python/unittest/test_tvm_intrin.py
+11
-0
No files found.
src/arith/const_fold.h
View file @
173b4fc4
...
...
@@ -181,6 +181,7 @@ inline PrimExpr TryConstFold<tir::ModNode>(PrimExpr a, PrimExpr b) {
TVM_INDEX_CONST_PROPAGATION
({
const
DataType
&
rtype
=
a
.
dtype
();
if
(
pa
&&
pb
)
{
CHECK_NE
(
pb
->
value
,
0
)
<<
"Divide by zero"
;
return
IntImm
(
rtype
,
pa
->
value
%
pb
->
value
);
}
if
(
pa
)
{
...
...
@@ -226,6 +227,7 @@ inline PrimExpr TryConstFold<tir::FloorModNode>(PrimExpr a, PrimExpr b) {
TVM_INDEX_CONST_PROPAGATION
({
const
DataType
&
rtype
=
a
.
dtype
();
if
(
pa
&&
pb
)
{
CHECK_NE
(
pb
->
value
,
0
)
<<
"Divide by zero"
;
return
IntImm
(
rtype
,
floormod
(
pa
->
value
,
pb
->
value
));
}
if
(
pa
)
{
...
...
src/tir/ir/op.cc
View file @
173b4fc4
...
...
@@ -606,6 +606,9 @@ PrimExpr fmod(PrimExpr x, PrimExpr y) {
}
PrimExpr
floor
(
PrimExpr
x
)
{
if
(
x
.
dtype
().
is_int
()
||
x
.
dtype
().
is_uint
())
{
return
x
;
}
using
tir
::
FloatImmNode
;
const
FloatImmNode
*
fx
=
x
.
as
<
FloatImmNode
>
();
if
(
fx
)
return
FloatImm
(
x
.
dtype
(),
std
::
floor
(
fx
->
value
));
...
...
@@ -613,6 +616,9 @@ PrimExpr floor(PrimExpr x) {
}
PrimExpr
ceil
(
PrimExpr
x
)
{
if
(
x
.
dtype
().
is_int
()
||
x
.
dtype
().
is_uint
())
{
return
x
;
}
using
tir
::
FloatImmNode
;
const
FloatImmNode
*
fx
=
x
.
as
<
FloatImmNode
>
();
if
(
fx
)
return
FloatImm
(
x
.
dtype
(),
std
::
ceil
(
fx
->
value
));
...
...
@@ -620,6 +626,9 @@ PrimExpr ceil(PrimExpr x) {
}
PrimExpr
round
(
PrimExpr
x
)
{
if
(
x
.
dtype
().
is_int
()
||
x
.
dtype
().
is_uint
())
{
return
x
;
}
using
tir
::
FloatImmNode
;
const
FloatImmNode
*
fx
=
x
.
as
<
FloatImmNode
>
();
if
(
fx
)
return
FloatImm
(
x
.
dtype
(),
std
::
nearbyint
(
fx
->
value
));
...
...
@@ -627,6 +636,9 @@ PrimExpr round(PrimExpr x) {
}
PrimExpr
nearbyint
(
PrimExpr
x
)
{
if
(
x
.
dtype
().
is_int
()
||
x
.
dtype
().
is_uint
())
{
return
x
;
}
using
tir
::
FloatImmNode
;
const
FloatImmNode
*
fx
=
x
.
as
<
FloatImmNode
>
();
if
(
fx
)
return
FloatImm
(
x
.
dtype
(),
std
::
nearbyint
(
fx
->
value
));
...
...
@@ -634,6 +646,9 @@ PrimExpr nearbyint(PrimExpr x) {
}
PrimExpr
trunc
(
PrimExpr
x
)
{
if
(
x
.
dtype
().
is_int
()
||
x
.
dtype
().
is_uint
())
{
return
x
;
}
using
tir
::
FloatImmNode
;
const
FloatImmNode
*
fx
=
x
.
as
<
FloatImmNode
>
();
if
(
fx
)
{
...
...
tests/python/unittest/test_lang_basic.py
View file @
173b4fc4
...
...
@@ -187,14 +187,14 @@ def test_bitwise():
assert
(
x
>>
tvm
.
tir
.
const
(
1
,
"int32x2"
))
.
dtype
==
"int32x2"
assert
(
te
.
var
(
"z"
,
"int8x2"
)
<<
tvm
.
tir
.
const
(
1
,
"int8x2"
))
.
dtype
==
"int8x2"
def
test_float_bitwise
():
t
=
tvm
.
tir
.
const
(
1.5
,
dtype
=
'float32'
)
for
test
in
[
lambda
lhs
,
rhs
:
lhs
<<
rhs
,
lambda
lhs
,
rhs
:
lhs
>>
rhs
,
lambda
lhs
,
rhs
:
lhs
|
rhs
,
lambda
lhs
,
rhs
:
lhs
^
rhs
,
lambda
lhs
,
rhs
:
lhs
&
rhs
]:
lambda
lhs
,
rhs
:
lhs
&
rhs
]:
try
:
test
(
t
,
10.0
)
assert
False
...
...
@@ -206,6 +206,20 @@ def test_float_bitwise():
except
RuntimeError
:
pass
def
test_divide_by_zero
():
for
test
in
[
lambda
lhs
,
rhs
:
tvm
.
tir
.
floormod
(
lhs
,
rhs
),
lambda
lhs
,
rhs
:
tvm
.
tir
.
floordiv
(
lhs
,
rhs
),
lambda
lhs
,
rhs
:
tvm
.
tir
.
truncmod
(
lhs
,
rhs
),
lambda
lhs
,
rhs
:
tvm
.
tir
.
truncdiv
(
lhs
,
rhs
),
lambda
lhs
,
rhs
:
tvm
.
tir
.
div
(
lhs
,
rhs
)]:
try
:
test
(
tvm
.
tir
.
const
(
5
,
'int32'
),
tvm
.
tir
.
const
(
0
,
'int32'
))
assert
False
except
tvm
.
TVMError
:
pass
def
test_isnan
():
x
=
te
.
var
(
'x'
,
'float32'
)
assert
str
(
tvm
.
tir
.
isnan
(
x
))
==
'isnan(x)'
...
...
@@ -250,6 +264,7 @@ if __name__ == "__main__":
test_all
()
test_bitwise
()
test_float_bitwise
()
test_divide_by_zero
()
test_isnan
()
test_equality
()
test_equality_string_imm
()
tests/python/unittest/test_tvm_intrin.py
View file @
173b4fc4
...
...
@@ -44,6 +44,16 @@ def test_nearbyint():
tvm
.
testing
.
assert_allclose
(
a_rounded
.
asnumpy
(),
np
.
rint
(
a
.
asnumpy
()))
def
test_round_intrinsics_on_int
():
i
=
tvm
.
te
.
var
(
"i"
,
'int32'
)
for
op
in
[
tvm
.
tir
.
round
,
tvm
.
tir
.
trunc
,
tvm
.
tir
.
ceil
,
tvm
.
tir
.
floor
,
tvm
.
tir
.
nearbyint
]:
assert
op
(
tvm
.
tir
.
const
(
10
,
'int32'
))
.
value
==
10
assert
op
(
tvm
.
tir
.
const
(
True
,
'bool'
))
.
value
==
True
assert
op
(
i
)
.
same_as
(
i
)
assert
tvm
.
tir
.
isnan
(
tvm
.
tir
.
const
(
10
,
'int32'
))
.
value
==
False
def
test_unary_intrin
():
test_funcs
=
[
...
...
@@ -75,3 +85,4 @@ def test_unary_intrin():
if
__name__
==
"__main__"
:
test_nearbyint
()
test_unary_intrin
()
test_round_intrinsics_on_int
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment