Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
515d4b6f
Commit
515d4b6f
authored
Apr 11, 2018
by
Tianqi Chen
Committed by
GitHub
Apr 11, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[PASS] More simplifier for mod and div (#1100)
* [PASS] More simplifier for mod and div * fix testcase
parent
1f0ca085
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
100 additions
and
13 deletions
+100
-13
src/arithmetic/canonical.cc
+76
-11
src/pass/lower_warp_memory.cc
+3
-1
tests/python/unittest/test_arith_simplify.py
+20
-0
tests/python/unittest/test_pass_simplify.py
+1
-1
No files found.
src/arithmetic/canonical.cc
View file @
515d4b6f
...
...
@@ -312,10 +312,24 @@ class Canonical::Internal : public IRMutator {
return
e
;
}
}
//
binary ops
//
Div operator
Expr
Mutate_
(
const
Div
*
op
,
const
Expr
&
e
)
final
{
if
(
!
EnableOpt
(
op
->
type
))
{
return
Binary
(
op
,
e
);
}
CacheEntry
a
=
Produce
(
op
->
a
);
CacheEntry
b
=
Produce
(
op
->
b
);
if
(
a
.
has_side_effect
||
b
.
has_side_effect
)
{
return
Binary_
(
op
,
e
,
a
.
value
,
b
.
value
);
}
if
(
is_const
(
a
.
value
)
&&
is_const
(
b
.
value
))
{
return
ComputeExpr
<
Div
>
(
a
.
value
,
b
.
value
);
}
else
if
(
is_const
(
b
.
value
))
{
return
SumDivConst
(
a
.
AsSum
(),
b
.
value
);
}
else
{
return
Binary
(
op
,
e
);
}
}
// Mod operator
Expr
Mutate_
(
const
Mod
*
op
,
const
Expr
&
e
)
final
{
if
(
!
EnableOpt
(
op
->
type
))
{
...
...
@@ -445,29 +459,80 @@ class Canonical::Internal : public IRMutator {
}
return
value
;
}
// Detect if a = x * coeff + y, where y \in [0, coeff), x >= 0
// return true if such detection is successful
// return false if it is not.
std
::
vector
<
ComExpr
>
TryLinearEquation
(
const
ComExpr
&
a
,
const
Expr
&
coeff
)
{
Type
type
=
coeff
.
type
();
int64_t
value
=
GetConstIntValue
(
coeff
);
if
(
value
<
0
)
return
{};
std
::
shared_ptr
<
ComExprNode
>
xnode
=
std
::
make_shared
<
ComExprNode
>
();
std
::
shared_ptr
<
ComExprNode
>
ynode
=
std
::
make_shared
<
ComExprNode
>
();
if
(
a
->
base
%
value
==
0
)
{
xnode
->
base
=
a
->
base
;
}
else
{
ynode
->
base
=
a
->
base
;
}
for
(
const
auto
&
e
:
a
->
elem
)
{
if
(
e
.
scale
%
value
==
0
)
{
xnode
->
elem
.
push_back
(
e
);
}
else
{
ynode
->
elem
.
push_back
(
e
);
}
}
Expr
yres
=
Sum2Expr
(
ComExpr
(
ynode
),
type
);
IntSet
yset
=
EvalSet
(
yres
,
var_range_
);
// This relies on the integer division rounds down
// Most cases it is good for integer division.
if
(
yset
.
min
().
type
()
==
type
&&
can_prove
(
yset
.
min
()
>=
make_zero
(
type
))
&&
yset
.
max
().
type
()
==
type
&&
can_prove
(
yset
.
max
()
<
coeff
))
{
xnode
->
base
/=
value
;
for
(
auto
&
e
:
xnode
->
elem
)
{
e
.
scale
/=
value
;
}
return
{
ComExpr
(
xnode
),
ComExpr
(
ynode
)};
}
else
{
return
{};
}
}
// subroutine to do produce a % v
Expr
SumModConst
(
ComExpr
a
,
Expr
v
)
{
std
::
vector
<
ComExpr
>
pair
=
TryLinearEquation
(
a
,
v
);
if
(
pair
.
size
()
==
0
)
{
int64_t
value
=
GetConstIntValue
(
v
);
std
::
shared_ptr
<
ComExprNode
>
n
=
std
::
make_shared
<
ComExprNode
>
();
int
mod_level
=
0
;
n
->
base
=
a
->
base
%
value
;
if
(
n
->
base
!=
0
)
mod_level
=
1
;
for
(
auto
e
:
a
->
elem
)
{
if
(
e
.
scale
%
value
==
0
)
continue
;
e
.
scale
=
e
.
scale
%
value
;
if
(
!
EvalSet
(
v
-
e
.
value
,
var_range_
).
can_prove_positive
())
{
mod_level
=
2
;
}
else
{
++
mod_level
;
}
n
->
elem
.
push_back
(
e
);
}
// cannot remove mode because there are more than two parts
if
(
mod_level
>=
2
)
{
Expr
ret
=
Sum2Expr
(
ComExpr
(
n
),
v
.
type
())
%
v
;
return
Binary
(
ret
.
as
<
Mod
>
(),
ret
);
}
ret_entry_
.
sum
=
ComExpr
(
n
);
ret_entry_
.
sum
=
pair
[
1
];
ret_entry_
.
max_level
=
stack_
.
back
().
max_level
;
ret_entry_
.
has_side_effect
=
stack_
.
back
().
has_side_effect
;
auto
it
=
cache_sum_
.
find
(
ret_entry_
.
sum
);
if
(
it
!=
cache_sum_
.
end
())
{
ret_entry_
=
it
->
second
;
}
else
{
ret_entry_
.
value
=
Sum2Expr
(
ret_entry_
.
sum
,
v
.
type
());
cache_sum_
[
ret_entry_
.
sum
]
=
ret_entry_
;
}
return
ret_entry_
.
value
;
}
// subroutine to do produce a % v
Expr
SumDivConst
(
ComExpr
a
,
Expr
v
)
{
std
::
vector
<
ComExpr
>
pair
=
TryLinearEquation
(
a
,
v
);
if
(
pair
.
size
()
==
0
)
{
Expr
ret
=
Sum2Expr
(
a
,
v
.
type
())
/
v
;
return
Binary
(
ret
.
as
<
Div
>
(),
ret
);
}
ret_entry_
.
sum
=
pair
[
0
];
ret_entry_
.
max_level
=
stack_
.
back
().
max_level
;
ret_entry_
.
has_side_effect
=
stack_
.
back
().
has_side_effect
;
auto
it
=
cache_sum_
.
find
(
ret_entry_
.
sum
);
...
...
src/pass/lower_warp_memory.cc
View file @
515d4b6f
...
...
@@ -279,7 +279,9 @@ class WarpMemoryRewriter : private IRMutator {
Stmt
Rewrite
(
Stmt
stmt
)
{
if
(
warp_size_
==
1
)
return
stmt
;
return
this
->
Mutate
(
stmt
);
stmt
=
this
->
Mutate
(
stmt
);
stmt
=
CanonicalSimplify
(
stmt
);
return
stmt
;
}
private
:
...
...
tests/python/unittest/test_arith_simplify.py
View file @
515d4b6f
...
...
@@ -37,6 +37,26 @@ def test_simplify_mod():
assert
index
==
j
def
test_modular
():
rx
=
tvm
.
var
(
"rx"
)
ry
=
tvm
.
var
(
"ry"
)
y
=
tvm
.
var
(
"y"
)
x
=
tvm
.
var
(
"x"
)
vmap
=
{
rx
:
tvm
.
Range
(
tvm
.
const
(
0
),
tvm
.
const
(
3
)),
ry
:
tvm
.
Range
(
tvm
.
const
(
0
),
tvm
.
const
(
3
)),
y
:
tvm
.
Range
(
tvm
.
const
(
0
),
tvm
.
const
(
2
)),
x
:
tvm
.
Range
(
tvm
.
const
(
0
),
tvm
.
const
(
14
))}
idx
=
ry
*
16
+
rx
+
y
*
16
+
x
z1
=
tvm
.
ir_pass
.
CanonicalSimplify
(
idx
//
16
,
vmap
)
z2
=
tvm
.
ir_pass
.
CanonicalSimplify
(
idx
%
16
,
vmap
)
assert
tvm
.
ir_pass
.
CanonicalSimplify
(
z1
-
(
ry
+
y
))
.
value
==
0
assert
tvm
.
ir_pass
.
CanonicalSimplify
(
z2
-
(
rx
+
x
))
.
value
==
0
if
__name__
==
"__main__"
:
test_simplify_mod
()
test_modular
()
test_simplify
()
tests/python/unittest/test_pass_simplify.py
View file @
515d4b6f
...
...
@@ -33,7 +33,6 @@ def test_bound():
ret
=
tvm
.
ir_pass
.
Simplify
(
m
%
10
,
vrange
)
assert
ret
==
m
def
test_canonical
():
x
=
tvm
.
var
(
"x"
)
z
=
tvm
.
const
(
3
)
...
...
@@ -54,6 +53,7 @@ def test_canonical():
assert
(
tvm
.
ir_pass
.
Equal
(
ret1
,
ret2
))
if
__name__
==
"__main__"
:
test_modular
()
test_bound
()
test_basic
()
test_simplify
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment