Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
046e4ff0
Unverified
Commit
046e4ff0
authored
Mar 14, 2019
by
Tianqi Chen
Committed by
GitHub
Mar 14, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[ARITH] RewriteSimplifier: min/max, logical, select (#2768)
parent
6c60b8d3
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
919 additions
and
16 deletions
+919
-16
src/arithmetic/rewrite_simplify.cc
+652
-5
tests/python/unittest/test_arith_rewrite_simplify.py
+267
-11
No files found.
src/arithmetic/rewrite_simplify.cc
View file @
046e4ff0
...
...
@@ -75,8 +75,29 @@ class RewriteSimplifier::Impl : public IRMutator {
Expr
Mutate_
(
const
Mul
*
op
,
const
Expr
&
self
)
final
;
Expr
Mutate_
(
const
Div
*
op
,
const
Expr
&
self
)
final
;
Expr
Mutate_
(
const
Mod
*
op
,
const
Expr
&
self
)
final
;
Expr
Mutate_
(
const
Min
*
op
,
const
Expr
&
self
)
final
;
Expr
Mutate_
(
const
Max
*
op
,
const
Expr
&
self
)
final
;
Expr
Mutate_
(
const
EQ
*
op
,
const
Expr
&
self
)
final
;
Expr
Mutate_
(
const
NE
*
op
,
const
Expr
&
self
)
final
;
Expr
Mutate_
(
const
LT
*
op
,
const
Expr
&
self
)
final
;
Expr
Mutate_
(
const
LE
*
op
,
const
Expr
&
self
)
final
;
Expr
Mutate_
(
const
GT
*
op
,
const
Expr
&
self
)
final
;
Expr
Mutate_
(
const
GE
*
op
,
const
Expr
&
self
)
final
;
Expr
Mutate_
(
const
And
*
op
,
const
Expr
&
self
)
final
;
Expr
Mutate_
(
const
Or
*
op
,
const
Expr
&
self
)
final
;
Expr
Mutate_
(
const
Not
*
op
,
const
Expr
&
self
)
final
;
Expr
Mutate_
(
const
Select
*
op
,
const
Expr
&
self
)
final
;
Expr
Mutate_
(
const
Ramp
*
op
,
const
Expr
&
self
)
final
;
private
:
/*! \brief internal structure for comparison. */
enum
CompareResult
{
kUnknown
,
kEQ
,
kGT
,
kLT
,
kNE
};
// reference to the main analyzer
Analyzer
*
parent_
;
// counter to record recursive rewrite depth.
...
...
@@ -92,12 +113,36 @@ class RewriteSimplifier::Impl : public IRMutator {
// Whether x == val
bool
CanProveEqual
(
const
Expr
&
x
,
int64_t
val
)
{
// TODO(tqchen) refer back to super-analyzer.
Expr
res
=
Mutate
(
x
);
if
(
const
auto
*
ptr
=
res
.
as
<
ir
::
IntImm
>
())
{
return
ptr
->
value
==
val
;
return
TryCompare
(
x
,
val
)
==
kEQ
;
}
// try to prove x equals val
CompareResult
TryCompare
(
const
Expr
&
x
,
int64_t
val
)
{
Expr
diff
=
Mutate
(
x
);
if
(
const
auto
*
ptr
=
diff
.
as
<
IntImm
>
())
{
if
(
ptr
->
value
==
val
)
{
return
kEQ
;
}
else
if
(
ptr
->
value
>
val
)
{
return
kGT
;
}
else
if
(
ptr
->
value
<
val
)
{
return
kLT
;
}
}
if
(
val
==
0
)
{
ModularSet
dmod
=
parent_
->
modular_set
(
diff
);
if
(
dmod
->
base
!=
0
)
{
return
kNE
;
}
}
ConstIntBound
dbound
=
parent_
->
const_int_bound
(
diff
);
if
(
dbound
->
min_value
>
val
)
{
return
kGT
;
}
if
(
dbound
->
max_value
<
val
)
{
return
kLT
;
}
return
false
;
return
kUnknown
;
}
// Recursive rewrite x
// we limit maximum depth of recursive rewrite allowed to
// avoid infinite loop
...
...
@@ -557,7 +602,7 @@ Mutate_(const Mod* op, const Expr& self) {
// Pattern var to match any expression
PVar
<
Expr
>
x
,
y
,
z
,
b1
;
// Pattern var match IntImm
PVar
<
Integer
>
c1
,
c2
,
c3
;
PVar
<
Integer
>
c1
,
c2
;
// Pattern var for lanes in broadcast and ramp
PVar
<
int
>
lanes
;
...
...
@@ -626,6 +671,608 @@ Mutate_(const Mod* op, const Expr& self) {
return
ret
;
}
Expr
RewriteSimplifier
::
Impl
::
Mutate_
(
const
Min
*
op
,
const
Expr
&
self
)
{
Expr
ret
=
IRMutator
::
Mutate_
(
op
,
self
);
op
=
ret
.
as
<
Min
>
();
Expr
const_res
=
TryConstFold
<
Min
>
(
op
->
a
,
op
->
b
);
if
(
const_res
.
defined
())
return
const_res
;
// Pattern var to match any expression
PVar
<
Expr
>
x
,
y
,
z
,
s1
,
s2
;
// Pattern var match IntImm
PVar
<
Integer
>
c1
,
c2
;
PVar
<
int
>
lanes
;
// vector rule
if
(
op
->
type
.
lanes
()
!=
1
)
{
TVM_TRY_REWRITE
(
min
(
broadcast
(
x
,
lanes
),
broadcast
(
y
,
lanes
)),
broadcast
(
min
(
x
,
y
),
lanes
));
TVM_TRY_REWRITE
(
min
(
min
(
x
,
broadcast
(
y
,
lanes
)),
broadcast
(
z
,
lanes
)),
min
(
x
,
broadcast
(
min
(
y
,
z
),
lanes
)));
}
if
(
IsIndexType
(
op
->
type
))
{
TVM_TRY_REWRITE
(
min
(
x
,
x
),
x
);
// constant int bound
ConstIntBound
a_bound
=
parent_
->
const_int_bound
(
op
->
a
);
ConstIntBound
b_bound
=
parent_
->
const_int_bound
(
op
->
b
);
if
(
a_bound
->
max_value
<=
b_bound
->
min_value
)
{
return
op
->
a
;
}
if
(
b_bound
->
max_value
<=
a_bound
->
min_value
)
{
return
op
->
b
;
}
// constant comparison
if
(
min
(
x
+
c1
,
x
+
c2
).
Match
(
ret
))
{
if
(
c1
.
Eval
()
->
value
<
c2
.
Eval
()
->
value
)
{
return
(
x
+
c1
).
Eval
();
}
else
{
return
(
x
+
c2
).
Eval
();
}
}
if
(
min
(
x
+
c1
,
x
).
Match
(
ret
)
||
min
(
x
,
x
+
c1
).
Match
(
ret
))
{
if
(
c1
.
Eval
()
->
value
<
0
)
{
return
(
x
+
c1
).
Eval
();
}
else
{
return
x
.
Eval
();
}
}
if
(
min
(
c1
-
x
,
c2
-
x
).
Match
(
ret
))
{
if
(
c1
.
Eval
()
->
value
<
c2
.
Eval
()
->
value
)
{
return
(
c1
-
x
).
Eval
();
}
else
{
return
(
c2
-
x
).
Eval
();
}
}
// Divide up rounding
TVM_TRY_REWRITE_IF
(
min
(((
x
+
c1
)
/
c2
)
*
c2
,
x
),
x
,
c2
.
Eval
()
->
value
>
0
&&
c1
.
Eval
()
->
value
+
1
==
c2
.
Eval
()
->
value
);
TVM_TRY_REWRITE_IF
(
min
(((
x
+
c1
)
/
c2
)
*
c2
,
max
(
x
,
c2
)),
max
(
x
,
c2
),
c2
.
Eval
()
->
value
>
0
&&
c1
.
Eval
()
->
value
+
1
==
c2
.
Eval
()
->
value
&&
CanProveGreaterEqual
(
x
.
Eval
(),
0
));
TVM_TRY_REWRITE_IF
(
min
(
x
,
((
x
+
c1
)
/
c2
)
*
c2
),
x
,
c2
.
Eval
()
->
value
>
0
&&
c1
.
Eval
()
->
value
+
1
==
c2
.
Eval
()
->
value
);
TVM_TRY_REWRITE_IF
(
min
(
max
(
x
,
c2
),
((
x
+
c1
)
/
c2
)
*
c2
),
max
(
x
,
c2
),
c2
.
Eval
()
->
value
>
0
&&
c1
.
Eval
()
->
value
+
1
==
c2
.
Eval
()
->
value
&&
CanProveGreaterEqual
(
x
.
Eval
(),
0
));
TVM_TRY_REWRITE
(
min
(
max
(
x
,
y
),
min
(
x
,
y
)),
min
(
x
,
y
));
TVM_TRY_REWRITE
(
min
(
max
(
x
,
y
),
min
(
y
,
x
)),
min
(
x
,
y
));
TVM_TRY_REWRITE
(
min
(
min
(
x
,
y
),
max
(
x
,
y
)),
min
(
x
,
y
));
TVM_TRY_REWRITE
(
min
(
min
(
x
,
y
),
max
(
y
,
x
)),
min
(
x
,
y
));
TVM_TRY_REWRITE
(
min
(
max
(
x
,
y
),
x
),
x
);
TVM_TRY_REWRITE
(
min
(
max
(
x
,
y
),
y
),
y
);
TVM_TRY_REWRITE
(
min
(
min
(
x
,
y
),
x
),
min
(
x
,
y
));
TVM_TRY_REWRITE
(
min
(
min
(
x
,
y
),
y
),
min
(
x
,
y
));
TVM_TRY_REWRITE
(
min
(
x
,
max
(
x
,
y
)),
x
);
TVM_TRY_REWRITE
(
min
(
y
,
max
(
x
,
y
)),
y
);
TVM_TRY_REWRITE
(
min
(
x
,
min
(
x
,
y
)),
min
(
x
,
y
));
TVM_TRY_REWRITE
(
min
(
y
,
min
(
x
,
y
)),
min
(
x
,
y
));
TVM_TRY_REWRITE
(
min
(
min
(
min
(
x
,
y
),
z
),
y
),
min
(
min
(
x
,
y
),
z
));
TVM_TRY_REWRITE
(
min
(
min
(
min
(
min
(
x
,
y
),
z
),
s1
),
y
),
min
(
min
(
min
(
x
,
y
),
z
),
s1
));
TVM_TRY_REWRITE
(
min
(
min
(
min
(
min
(
min
(
x
,
y
),
z
),
s1
),
s2
),
y
),
min
(
min
(
min
(
min
(
x
,
y
),
z
),
s1
),
s2
));
TVM_TRY_REWRITE
(
min
(
max
(
x
,
y
),
max
(
x
,
z
)),
max
(
min
(
y
,
z
),
x
));
TVM_TRY_REWRITE
(
min
(
max
(
x
,
y
),
max
(
z
,
x
)),
max
(
min
(
y
,
z
),
x
));
TVM_TRY_REWRITE
(
min
(
max
(
y
,
x
),
max
(
x
,
z
)),
max
(
min
(
y
,
z
),
x
));
TVM_TRY_REWRITE
(
min
(
max
(
y
,
x
),
max
(
z
,
x
)),
max
(
min
(
y
,
z
),
x
));
TVM_TRY_REWRITE
(
min
(
min
(
x
,
y
),
min
(
x
,
z
)),
min
(
min
(
y
,
z
),
x
));
TVM_TRY_REWRITE
(
min
(
min
(
x
,
y
),
min
(
z
,
x
)),
min
(
min
(
y
,
z
),
x
));
TVM_TRY_REWRITE
(
min
(
min
(
y
,
x
),
min
(
x
,
z
)),
min
(
min
(
y
,
z
),
x
));
TVM_TRY_REWRITE
(
min
(
min
(
y
,
x
),
min
(
z
,
x
)),
min
(
min
(
y
,
z
),
x
));
TVM_TRY_REWRITE
(
min
(
y
+
x
,
z
+
x
),
min
(
y
,
z
)
+
x
);
TVM_TRY_REWRITE
(
min
(
y
+
x
,
x
+
z
),
min
(
y
,
z
)
+
x
);
TVM_TRY_REWRITE
(
min
(
x
+
y
,
x
+
z
),
min
(
y
,
z
)
+
x
);
TVM_TRY_REWRITE
(
min
(
x
+
y
,
z
+
x
),
min
(
y
,
z
)
+
x
);
// sub distribution
TVM_TRY_REWRITE
(
min
(
y
-
x
,
z
-
x
),
min
(
y
,
z
)
-
x
);
TVM_TRY_REWRITE
(
min
(
x
-
y
,
x
-
z
),
x
-
max
(
y
,
z
));
// constant folding rule.
TVM_TRY_REWRITE
(
min
(
min
(
x
,
c1
),
c2
),
min
(
x
,
min
(
c1
,
c2
)));
// scaling rule
if
(
min
(
x
/
c1
,
y
/
c1
).
Match
(
ret
))
{
if
(
c1
.
Eval
()
->
value
>
0
)
{
return
(
min
(
x
,
y
)
/
c1
).
Eval
();
}
else
{
return
(
max
(
x
,
y
)
/
c1
).
Eval
();
}
}
if
(
min
(
x
*
c1
,
y
*
c1
).
Match
(
ret
))
{
if
(
c1
.
Eval
()
->
value
>
0
)
{
return
(
min
(
x
,
y
)
*
c1
).
Eval
();
}
else
{
return
(
max
(
x
,
y
)
*
c1
).
Eval
();
}
}
if
(
min
(
x
*
c1
,
c2
).
Match
(
ret
))
{
int64_t
c1val
=
c1
.
Eval
()
->
value
;
int64_t
c2val
=
c2
.
Eval
()
->
value
;
if
(
c2val
%
c1val
==
0
)
{
if
(
c2val
/
c1val
>=
0
)
{
return
(
min
(
x
,
c2val
/
c1val
)
*
c1val
).
Eval
();
}
else
{
return
(
max
(
x
,
c2val
/
c1val
)
*
c1val
).
Eval
();
}
}
}
// canonicalization
TVM_TRY_RECURSIVE_REWRITE
(
min
(
min
(
x
,
c1
),
y
),
min
(
min
(
x
,
y
),
c1
));
TVM_TRY_RECURSIVE_REWRITE
(
min
(
c1
-
x
,
c2
),
c1
-
max
(
x
,
c2
-
c1
));
}
// condition rules.
TVM_TRY_REWRITE
(
min
(
select
(
x
,
y
,
z
),
select
(
x
,
s1
,
s2
)),
select
(
x
,
min
(
y
,
s1
),
min
(
z
,
s2
)));
return
ret
;
}
Expr
RewriteSimplifier
::
Impl
::
Mutate_
(
const
Max
*
op
,
const
Expr
&
self
)
{
Expr
ret
=
IRMutator
::
Mutate_
(
op
,
self
);
op
=
ret
.
as
<
Max
>
();
Expr
const_res
=
TryConstFold
<
Max
>
(
op
->
a
,
op
->
b
);
if
(
const_res
.
defined
())
return
const_res
;
// Pattern var to match any expression
PVar
<
Expr
>
x
,
y
,
z
,
s1
,
s2
;
// Pattern var match IntImm
PVar
<
Integer
>
c1
,
c2
;
PVar
<
int
>
lanes
;
// vector rule
if
(
op
->
type
.
lanes
()
!=
1
)
{
TVM_TRY_REWRITE
(
max
(
broadcast
(
x
,
lanes
),
broadcast
(
y
,
lanes
)),
broadcast
(
max
(
x
,
y
),
lanes
));
TVM_TRY_REWRITE
(
max
(
max
(
x
,
broadcast
(
y
,
lanes
)),
broadcast
(
z
,
lanes
)),
max
(
x
,
broadcast
(
max
(
y
,
z
),
lanes
)));
}
if
(
IsIndexType
(
op
->
type
))
{
TVM_TRY_REWRITE
(
max
(
x
,
x
),
x
);
// constant int bound
ConstIntBound
a_bound
=
parent_
->
const_int_bound
(
op
->
a
);
ConstIntBound
b_bound
=
parent_
->
const_int_bound
(
op
->
b
);
if
(
a_bound
->
min_value
>=
b_bound
->
max_value
)
{
return
op
->
a
;
}
if
(
b_bound
->
min_value
>=
a_bound
->
max_value
)
{
return
op
->
b
;
}
// constant comparison
if
(
max
(
x
+
c1
,
x
+
c2
).
Match
(
ret
))
{
if
(
c1
.
Eval
()
->
value
>
c2
.
Eval
()
->
value
)
{
return
(
x
+
c1
).
Eval
();
}
else
{
return
(
x
+
c2
).
Eval
();
}
}
if
(
max
(
x
+
c1
,
x
).
Match
(
ret
)
||
max
(
x
,
x
+
c1
).
Match
(
ret
))
{
if
(
c1
.
Eval
()
->
value
>
0
)
{
return
(
x
+
c1
).
Eval
();
}
else
{
return
x
.
Eval
();
}
}
if
(
max
(
c1
-
x
,
c2
-
x
).
Match
(
ret
))
{
if
(
c1
.
Eval
()
->
value
>
c2
.
Eval
()
->
value
)
{
return
(
c1
-
x
).
Eval
();
}
else
{
return
(
c2
-
x
).
Eval
();
}
}
// Divide up rounding
TVM_TRY_REWRITE_IF
(
max
(((
x
+
c1
)
/
c2
)
*
c2
,
x
),
((
x
+
c1
)
/
c2
)
*
c2
,
c2
.
Eval
()
->
value
>
0
&&
c1
.
Eval
()
->
value
+
1
==
c2
.
Eval
()
->
value
);
TVM_TRY_REWRITE_IF
(
max
(
x
,
((
x
+
c1
)
/
c2
)
*
c2
),
((
x
+
c1
)
/
c2
)
*
c2
,
c2
.
Eval
()
->
value
>
0
&&
c1
.
Eval
()
->
value
+
1
==
c2
.
Eval
()
->
value
);
TVM_TRY_REWRITE
(
max
(
min
(
x
,
y
),
max
(
x
,
y
)),
max
(
x
,
y
));
TVM_TRY_REWRITE
(
max
(
min
(
x
,
y
),
max
(
y
,
x
)),
max
(
x
,
y
));
TVM_TRY_REWRITE
(
max
(
max
(
x
,
y
),
min
(
x
,
y
)),
max
(
x
,
y
));
TVM_TRY_REWRITE
(
max
(
max
(
x
,
y
),
min
(
y
,
x
)),
max
(
x
,
y
));
TVM_TRY_REWRITE
(
max
(
min
(
x
,
y
),
x
),
x
);
TVM_TRY_REWRITE
(
max
(
min
(
x
,
y
),
y
),
y
);
TVM_TRY_REWRITE
(
max
(
max
(
x
,
y
),
x
),
max
(
x
,
y
));
TVM_TRY_REWRITE
(
max
(
max
(
x
,
y
),
y
),
max
(
x
,
y
));
TVM_TRY_REWRITE
(
max
(
x
,
min
(
x
,
y
)),
x
);
TVM_TRY_REWRITE
(
max
(
y
,
min
(
x
,
y
)),
y
);
TVM_TRY_REWRITE
(
max
(
x
,
max
(
x
,
y
)),
max
(
x
,
y
));
TVM_TRY_REWRITE
(
max
(
y
,
max
(
x
,
y
)),
max
(
x
,
y
));
TVM_TRY_REWRITE
(
max
(
max
(
max
(
x
,
y
),
z
),
y
),
max
(
max
(
x
,
y
),
z
));
TVM_TRY_REWRITE
(
max
(
max
(
max
(
max
(
x
,
y
),
z
),
s1
),
y
),
max
(
max
(
max
(
x
,
y
),
z
),
s1
));
TVM_TRY_REWRITE
(
max
(
max
(
max
(
max
(
max
(
x
,
y
),
z
),
s1
),
s2
),
y
),
max
(
max
(
max
(
max
(
x
,
y
),
z
),
s1
),
s2
));
// max/max cancelation
TVM_TRY_REWRITE
(
max
(
max
(
x
,
y
),
max
(
x
,
z
)),
max
(
max
(
y
,
z
),
x
));
TVM_TRY_REWRITE
(
max
(
max
(
x
,
y
),
max
(
z
,
x
)),
max
(
max
(
y
,
z
),
x
));
TVM_TRY_REWRITE
(
max
(
max
(
y
,
x
),
max
(
x
,
z
)),
max
(
max
(
y
,
z
),
x
));
TVM_TRY_REWRITE
(
max
(
max
(
y
,
x
),
max
(
z
,
x
)),
max
(
max
(
y
,
z
),
x
));
// max/min distribution
TVM_TRY_REWRITE
(
max
(
min
(
x
,
y
),
min
(
x
,
z
)),
min
(
max
(
y
,
z
),
x
));
TVM_TRY_REWRITE
(
max
(
min
(
x
,
y
),
min
(
z
,
x
)),
min
(
max
(
y
,
z
),
x
));
TVM_TRY_REWRITE
(
max
(
min
(
y
,
x
),
min
(
x
,
z
)),
min
(
max
(
y
,
z
),
x
));
TVM_TRY_REWRITE
(
max
(
min
(
y
,
x
),
min
(
z
,
x
)),
min
(
max
(
y
,
z
),
x
));
// add distribution
TVM_TRY_REWRITE
(
max
(
y
+
x
,
z
+
x
),
max
(
y
,
z
)
+
x
);
TVM_TRY_REWRITE
(
max
(
y
+
x
,
x
+
z
),
max
(
y
,
z
)
+
x
);
TVM_TRY_REWRITE
(
max
(
x
+
y
,
x
+
z
),
max
(
y
,
z
)
+
x
);
TVM_TRY_REWRITE
(
max
(
x
+
y
,
z
+
x
),
max
(
y
,
z
)
+
x
);
// sub distribution
TVM_TRY_REWRITE
(
max
(
y
-
x
,
z
-
x
),
max
(
y
,
z
)
-
x
);
TVM_TRY_REWRITE
(
max
(
x
-
y
,
x
-
z
),
x
-
min
(
y
,
z
));
// constant folding rule.
TVM_TRY_REWRITE
(
max
(
max
(
x
,
c1
),
c2
),
max
(
x
,
max
(
c1
,
c2
)));
// scaling rule
if
(
max
(
x
/
c1
,
y
/
c1
).
Match
(
ret
))
{
if
(
c1
.
Eval
()
->
value
>
0
)
{
return
(
max
(
x
,
y
)
/
c1
).
Eval
();
}
else
{
return
(
min
(
x
,
y
)
/
c1
).
Eval
();
}
}
if
(
max
(
x
*
c1
,
y
*
c1
).
Match
(
ret
))
{
if
(
c1
.
Eval
()
->
value
>
0
)
{
return
(
max
(
x
,
y
)
*
c1
).
Eval
();
}
else
{
return
(
min
(
x
,
y
)
*
c1
).
Eval
();
}
}
if
(
max
(
x
*
c1
,
c2
).
Match
(
ret
))
{
int64_t
c1val
=
c1
.
Eval
()
->
value
;
int64_t
c2val
=
c2
.
Eval
()
->
value
;
if
(
c2val
%
c1val
==
0
)
{
if
(
c2val
/
c1val
>=
0
)
{
return
(
max
(
x
,
c2val
/
c1val
)
*
c1val
).
Eval
();
}
else
{
return
(
min
(
x
,
c2val
/
c1val
)
*
c1val
).
Eval
();
}
}
}
// canonicalization
TVM_TRY_RECURSIVE_REWRITE
(
max
(
max
(
x
,
c1
),
y
),
max
(
max
(
x
,
y
),
c1
));
TVM_TRY_RECURSIVE_REWRITE
(
max
(
c1
-
x
,
c2
),
c1
-
min
(
x
,
c2
-
c1
));
}
// condition rules.
TVM_TRY_REWRITE
(
max
(
select
(
x
,
y
,
z
),
select
(
x
,
s1
,
s2
)),
select
(
x
,
max
(
y
,
s1
),
max
(
z
,
s2
)));
return
ret
;
}
Expr
RewriteSimplifier
::
Impl
::
Mutate_
(
const
EQ
*
op
,
const
Expr
&
self
)
{
Expr
ret
=
IRMutator
::
Mutate_
(
op
,
self
);
op
=
ret
.
as
<
EQ
>
();
Expr
const_res
=
TryConstFold
<
EQ
>
(
op
->
a
,
op
->
b
);
if
(
const_res
.
defined
())
return
const_res
;
// Pattern var to match any expression
PVar
<
Expr
>
x
,
y
;
// Pattern var match IntImm
PVar
<
Integer
>
c1
;
PVar
<
int
>
lanes
;
// vector rule
if
(
op
->
type
.
lanes
()
!=
1
)
{
TVM_TRY_REWRITE
(
broadcast
(
x
,
lanes
)
==
broadcast
(
y
,
lanes
),
broadcast
(
x
==
y
,
lanes
));
}
if
(
IsIndexType
(
op
->
a
.
type
()))
{
CompareResult
result
=
TryCompare
(
op
->
a
-
op
->
b
,
0
);
if
(
result
!=
kUnknown
)
{
if
(
result
==
kEQ
)
{
return
make_const
(
op
->
type
,
true
);
}
else
{
return
make_const
(
op
->
type
,
false
);
}
}
TVM_TRY_REWRITE
(
x
-
c1
==
0
,
x
==
c1
);
TVM_TRY_REWRITE
(
c1
-
x
==
0
,
x
==
c1
);
TVM_TRY_REWRITE
(
x
+
c1
==
0
,
x
==
0
-
c1
);
TVM_TRY_REWRITE
(
x
*
y
==
0
,
x
==
0
||
y
==
0
);
}
return
ret
;
}
Expr
RewriteSimplifier
::
Impl
::
Mutate_
(
const
NE
*
op
,
const
Expr
&
self
)
{
return
Mutate
(
Not
::
make
(
op
->
a
==
op
->
b
));
}
Expr
RewriteSimplifier
::
Impl
::
Mutate_
(
const
LE
*
op
,
const
Expr
&
self
)
{
return
Mutate
(
Not
::
make
(
op
->
b
<
op
->
a
));
}
Expr
RewriteSimplifier
::
Impl
::
Mutate_
(
const
GT
*
op
,
const
Expr
&
self
)
{
return
Mutate
(
op
->
b
<
op
->
a
);
}
Expr
RewriteSimplifier
::
Impl
::
Mutate_
(
const
GE
*
op
,
const
Expr
&
self
)
{
return
Mutate
(
Not
::
make
(
op
->
a
<
op
->
b
));
}
Expr
RewriteSimplifier
::
Impl
::
Mutate_
(
const
LT
*
op
,
const
Expr
&
self
)
{
Expr
ret
=
IRMutator
::
Mutate_
(
op
,
self
);
op
=
ret
.
as
<
LT
>
();
Expr
const_res
=
TryConstFold
<
LT
>
(
op
->
a
,
op
->
b
);
if
(
const_res
.
defined
())
return
const_res
;
// Pattern var to match any expression
PVar
<
Expr
>
x
,
y
,
z
,
s1
,
s2
;
// Pattern var match IntImm
PVar
<
Integer
>
c1
,
c2
;
PVar
<
int
>
lanes
;
// vector rule
if
(
op
->
type
.
lanes
()
!=
1
)
{
TVM_TRY_REWRITE
(
broadcast
(
x
,
lanes
)
<
broadcast
(
y
,
lanes
),
broadcast
(
x
<
y
,
lanes
));
TVM_TRY_REWRITE
(
ramp
(
x
,
s1
,
lanes
)
<
ramp
(
y
,
s1
,
lanes
),
broadcast
(
x
<
y
,
lanes
));
}
if
(
IsIndexType
(
op
->
a
.
type
()))
{
CompareResult
result
=
TryCompare
(
op
->
a
-
op
->
b
,
0
);
if
(
result
==
kLT
)
{
return
make_const
(
op
->
type
,
true
);
}
if
(
result
==
kEQ
||
result
==
kGT
)
{
return
make_const
(
op
->
type
,
false
);
}
TVM_TRY_REWRITE
(
x
+
y
<
x
+
z
,
y
<
z
);
TVM_TRY_REWRITE
(
x
+
y
<
z
+
x
,
y
<
z
);
TVM_TRY_REWRITE
(
y
+
x
<
x
+
z
,
y
<
z
);
TVM_TRY_REWRITE
(
y
+
x
<
z
+
x
,
y
<
z
);
TVM_TRY_REWRITE
(
y
-
x
<
z
-
x
,
y
<
z
);
TVM_TRY_REWRITE
(
x
-
y
<
x
-
z
,
z
<
y
);
TVM_TRY_REWRITE
(
x
<
x
+
z
,
0
<
z
);
TVM_TRY_REWRITE
(
x
<
z
+
x
,
0
<
z
);
TVM_TRY_REWRITE
(
x
<
x
-
z
,
z
<
0
);
TVM_TRY_REWRITE
(
c1
<
x
+
c2
,
c1
-
c2
<
x
);
TVM_TRY_REWRITE
(
c1
<
c2
-
x
,
x
<
c2
-
c1
);
TVM_TRY_REWRITE_IF
(
x
*
c1
<
y
*
c1
,
x
<
y
,
c1
.
Eval
()
->
value
>
0
);
TVM_TRY_REWRITE_IF
(
x
*
c1
<
y
*
c1
,
y
<
x
,
c1
.
Eval
()
->
value
<
0
);
// require c1 > 0 to work for any div mode
TVM_TRY_REWRITE_IF
(
x
*
c2
<
c1
,
x
<
(
c1
-
1
)
/
c2
+
1
,
c1
.
Eval
()
->
value
>
0
&&
c2
.
Eval
()
->
value
>
0
);
TVM_TRY_REWRITE_IF
(
x
/
c1
<
c2
,
x
<
c1
*
c2
,
c1
.
Eval
()
->
value
>
0
&&
c2
.
Eval
()
->
value
>
0
);
TVM_TRY_REWRITE_IF
(
c1
<
x
*
c2
,
c1
/
c2
<
x
,
c1
.
Eval
()
->
value
>=
0
&&
c2
.
Eval
()
->
value
>
0
);
TVM_TRY_REWRITE_IF
(
c1
<
x
/
c2
,
(
c1
+
1
)
*
c2
-
1
<
x
,
c1
.
Eval
()
->
value
>=
0
&&
c2
.
Eval
()
->
value
>
0
);
// division related simplificationx
// invariance for any div mod: x - (x / c1) * c1 == x % c1
TVM_TRY_REWRITE_IF
((
x
/
c1
)
*
c1
<
x
,
0
<
x
%
c1
,
c1
.
Eval
()
->
value
>
0
);
TVM_TRY_REWRITE_IF
((
x
/
c1
)
*
c1
<
x
+
y
,
0
<
x
%
c1
+
y
,
c1
.
Eval
()
->
value
>
0
);
TVM_TRY_REWRITE_IF
((
x
/
c1
)
*
c1
<
x
-
y
,
y
<
x
%
c1
,
c1
.
Eval
()
->
value
>
0
);
TVM_TRY_REWRITE_IF
(((
x
+
c2
)
/
c1
)
*
c1
<
x
,
c2
<
(
x
+
c2
)
%
c1
,
c1
.
Eval
()
->
value
>
0
);
TVM_TRY_REWRITE_IF
(((
x
+
c2
)
/
c1
)
*
c1
<
x
+
y
,
c2
<
(
x
+
c2
)
%
c1
+
y
,
c1
.
Eval
()
->
value
>
0
);
TVM_TRY_REWRITE_IF
(((
x
+
c2
)
/
c1
)
*
c1
<
x
-
y
,
y
<
(
x
+
c2
)
%
c1
+
(
0
-
c2
),
c1
.
Eval
()
->
value
>
0
);
// canonicalization rule
TVM_TRY_RECURSIVE_REWRITE
(
min
(
x
,
y
)
<
z
,
x
<
z
||
y
<
z
);
TVM_TRY_RECURSIVE_REWRITE
(
max
(
x
,
y
)
<
z
,
x
<
z
&&
y
<
z
);
TVM_TRY_RECURSIVE_REWRITE
(
z
<
min
(
x
,
y
),
z
<
x
&&
z
<
y
);
TVM_TRY_RECURSIVE_REWRITE
(
z
<
max
(
x
,
y
),
z
<
x
||
z
<
y
);
TVM_TRY_REWRITE
(
x
-
c1
<
0
,
x
<
c1
);
TVM_TRY_REWRITE
(
x
+
c1
<
c2
,
x
<
c2
-
c1
);
}
return
ret
;
}
Expr
RewriteSimplifier
::
Impl
::
Mutate_
(
const
Not
*
op
,
const
Expr
&
self
)
{
Expr
ret
=
IRMutator
::
Mutate_
(
op
,
self
);
op
=
ret
.
as
<
Not
>
();
Expr
const_res
=
TryConstFold
<
Not
>
(
op
->
a
);
if
(
const_res
.
defined
())
return
const_res
;
// Pattern var to match any expression
PVar
<
Expr
>
x
,
y
;
PVar
<
int
>
lanes
;
if
(
op
->
type
.
lanes
()
!=
1
)
{
TVM_TRY_REWRITE
(
!
broadcast
(
x
,
lanes
),
broadcast
(
!
x
,
lanes
));
}
TVM_TRY_REWRITE
(
!
(
!
x
),
x
);
TVM_TRY_REWRITE
(
!
(
x
<=
y
),
y
<
x
);
TVM_TRY_REWRITE
(
!
(
x
>=
y
),
x
<
y
);
TVM_TRY_REWRITE
(
!
(
x
<
y
),
y
<=
x
);
TVM_TRY_REWRITE
(
!
(
x
>
y
),
x
<=
y
);
TVM_TRY_REWRITE
(
!
(
x
==
y
),
x
!=
y
);
TVM_TRY_REWRITE
(
!
(
x
!=
y
),
x
==
y
);
TVM_TRY_RECURSIVE_REWRITE
(
!
(
x
||
y
),
(
!
x
)
&&
(
!
y
));
TVM_TRY_RECURSIVE_REWRITE
(
!
(
x
&&
y
),
(
!
x
)
||
(
!
y
));
return
ret
;
}
Expr
RewriteSimplifier
::
Impl
::
Mutate_
(
const
And
*
op
,
const
Expr
&
self
)
{
Expr
ret
=
IRMutator
::
Mutate_
(
op
,
self
);
op
=
ret
.
as
<
And
>
();
Expr
const_res
=
TryConstFold
<
And
>
(
op
->
a
,
op
->
b
);
if
(
const_res
.
defined
())
return
const_res
;
// Pattern var to match any expression
PVar
<
Expr
>
x
,
y
;
// Pattern var match IntImm
PVar
<
Integer
>
c1
,
c2
;
PVar
<
int
>
lanes
;
if
(
op
->
type
.
lanes
()
!=
1
)
{
TVM_TRY_REWRITE
(
broadcast
(
x
,
lanes
)
&&
broadcast
(
y
,
lanes
),
broadcast
(
x
&&
y
,
lanes
));
}
auto
cfalse
=
PConst
<
Expr
>
(
make_const
(
op
->
type
,
false
));
TVM_TRY_REWRITE
(
x
==
y
&&
x
!=
y
,
cfalse
);
TVM_TRY_REWRITE
(
x
!=
y
&&
x
==
y
,
cfalse
);
TVM_TRY_REWRITE
(
x
&&
!
x
,
cfalse
);
TVM_TRY_REWRITE
(
x
<=
y
&&
y
<
x
,
cfalse
);
TVM_TRY_REWRITE
(
y
<
x
&&
y
<=
x
,
cfalse
);
TVM_TRY_REWRITE_IF
(
x
<
c1
&&
c2
<
x
,
cfalse
,
c2
.
Eval
()
->
value
+
1
>=
c1
.
Eval
()
->
value
);
TVM_TRY_REWRITE_IF
(
c2
<
x
&&
x
<
c1
,
cfalse
,
c2
.
Eval
()
->
value
+
1
>=
c1
.
Eval
()
->
value
);
TVM_TRY_REWRITE_IF
(
x
<
c1
&&
c2
<=
x
,
cfalse
,
c2
.
Eval
()
->
value
>=
c1
.
Eval
()
->
value
);
TVM_TRY_REWRITE_IF
(
c2
<=
x
&&
x
<
c1
,
cfalse
,
c2
.
Eval
()
->
value
>=
c1
.
Eval
()
->
value
);
TVM_TRY_REWRITE_IF
(
x
<=
c1
&&
c2
<
x
,
cfalse
,
c2
.
Eval
()
->
value
>=
c1
.
Eval
()
->
value
);
TVM_TRY_REWRITE_IF
(
c2
<
x
&&
x
<=
c1
,
cfalse
,
c2
.
Eval
()
->
value
>=
c1
.
Eval
()
->
value
);
TVM_TRY_REWRITE_IF
(
x
<=
c1
&&
c2
<=
x
,
cfalse
,
c2
.
Eval
()
->
value
>
c1
.
Eval
()
->
value
);
TVM_TRY_REWRITE_IF
(
c2
<=
x
&&
x
<=
c1
,
cfalse
,
c2
.
Eval
()
->
value
>
c1
.
Eval
()
->
value
);
TVM_TRY_REWRITE
(
x
==
c1
&&
x
!=
c2
,
x
==
c1
&&
c1
!=
c2
);
TVM_TRY_REWRITE
(
x
!=
c2
&&
x
==
c1
,
x
==
c1
&&
c1
!=
c2
);
return
ret
;
}
Expr
RewriteSimplifier
::
Impl
::
Mutate_
(
const
Or
*
op
,
const
Expr
&
self
)
{
Expr
ret
=
IRMutator
::
Mutate_
(
op
,
self
);
op
=
ret
.
as
<
Or
>
();
Expr
const_res
=
TryConstFold
<
Or
>
(
op
->
a
,
op
->
b
);
if
(
const_res
.
defined
())
return
const_res
;
// Pattern var to match any expression
PVar
<
Expr
>
x
,
y
;
// Pattern var match IntImm
PVar
<
Integer
>
c1
,
c2
;
PVar
<
int
>
lanes
;
if
(
op
->
type
.
lanes
()
!=
1
)
{
TVM_TRY_REWRITE
(
broadcast
(
x
,
lanes
)
||
broadcast
(
y
,
lanes
),
broadcast
(
x
||
y
,
lanes
));
}
auto
ctrue
=
PConst
<
Expr
>
(
make_const
(
op
->
type
,
true
));
TVM_TRY_REWRITE
(
x
==
y
||
x
!=
y
,
ctrue
);
TVM_TRY_REWRITE
(
x
!=
y
||
x
==
y
,
ctrue
);
TVM_TRY_REWRITE
(
x
||
!
x
,
ctrue
);
TVM_TRY_REWRITE
(
x
<=
y
||
y
<
x
,
ctrue
);
TVM_TRY_REWRITE
(
y
<
x
||
y
<=
x
,
ctrue
);
TVM_TRY_REWRITE_IF
(
x
<
c1
||
c2
<
x
,
ctrue
,
c2
.
Eval
()
->
value
<
c1
.
Eval
()
->
value
);
TVM_TRY_REWRITE_IF
(
c2
<
x
||
x
<
c1
,
ctrue
,
c2
.
Eval
()
->
value
<
c1
.
Eval
()
->
value
);
TVM_TRY_REWRITE_IF
(
x
<=
c1
||
c2
<
x
,
ctrue
,
c2
.
Eval
()
->
value
<=
c1
.
Eval
()
->
value
);
TVM_TRY_REWRITE_IF
(
c2
<
x
||
x
<=
c1
,
ctrue
,
c2
.
Eval
()
->
value
<=
c1
.
Eval
()
->
value
);
TVM_TRY_REWRITE_IF
(
x
<
c1
||
c2
<=
x
,
ctrue
,
c2
.
Eval
()
->
value
<=
c1
.
Eval
()
->
value
);
TVM_TRY_REWRITE_IF
(
c2
<=
x
||
x
<
c1
,
ctrue
,
c2
.
Eval
()
->
value
<=
c1
.
Eval
()
->
value
);
TVM_TRY_REWRITE_IF
(
x
<=
c1
||
c2
<=
x
,
ctrue
,
c2
.
Eval
()
->
value
<=
c1
.
Eval
()
->
value
+
1
);
TVM_TRY_REWRITE_IF
(
c2
<=
x
||
x
<=
c1
,
ctrue
,
c2
.
Eval
()
->
value
<=
c1
.
Eval
()
->
value
+
1
);
TVM_TRY_REWRITE
(
x
!=
c1
||
x
==
c2
,
x
!=
c1
||
c1
==
c2
);
TVM_TRY_REWRITE
(
x
==
c2
||
x
!=
c1
,
x
!=
c1
||
c1
==
c2
);
return
ret
;
}
Expr
RewriteSimplifier
::
Impl
::
Mutate_
(
const
Ramp
*
op
,
const
Expr
&
self
)
{
Expr
ret
=
IRMutator
::
Mutate_
(
op
,
self
);
op
=
ret
.
as
<
Ramp
>
();
if
(
is_zero
(
op
->
stride
))
{
return
Broadcast
::
make
(
op
->
base
,
op
->
lanes
);
}
return
ret
;
}
Expr
RewriteSimplifier
::
Impl
::
Mutate_
(
const
Select
*
op
,
const
Expr
&
self
)
{
Expr
ret
=
IRMutator
::
Mutate_
(
op
,
self
);
op
=
ret
.
as
<
Select
>
();
if
(
is_zero
(
op
->
condition
))
{
return
op
->
false_value
;
}
if
(
is_one
(
op
->
condition
))
{
return
op
->
true_value
;
}
// Pattern var to match any expression
PVar
<
Expr
>
x
,
y
;
TVM_TRY_REWRITE
(
select
(
x
,
y
,
y
),
y
);
return
ret
;
}
Expr
RewriteSimplifier
::
operator
()(
const
Expr
&
expr
)
{
return
impl_
->
PostOrderSimplify
(
expr
);
...
...
tests/python/unittest/test_arith_rewrite_simplify.py
View file @
046e4ff0
...
...
@@ -6,8 +6,7 @@ class RewriteChecker:
def
verify
(
self
,
data
,
expected
):
res
=
self
.
analyzer
.
rewrite_simplify
(
data
)
assert
tvm
.
ir_pass
.
Equal
(
res
,
expected
),
"data={}, res={}, expected={}"
.
format
(
data
,
res
,
expected
)
assert
tvm
.
ir_pass
.
Equal
(
res
,
expected
),
"data={}, res={}, expected={}"
.
format
(
data
,
res
,
expected
)
def
test_vector_simplify
():
...
...
@@ -62,20 +61,57 @@ def test_vector_simplify():
ck
.
verify
(
tvm
.
expr
.
Ramp
(
x
*
8
+
1
,
15
,
4
)
%
8
,
tvm
.
expr
.
Ramp
(
1
,
15
,
4
)
%
8
)
# Min/Max rules
vx
=
tvm
.
var
(
"vx"
,
dtype
=
"int32x2"
)
vc
=
tvm
.
var
(
"vc"
,
dtype
=
"uint1"
)
ck
.
verify
(
tvm
.
min
(
y
.
astype
(
"int32x2"
),
x
.
astype
(
"int32x2"
)),
tvm
.
min
(
y
,
x
)
.
astype
(
"int32x2"
))
ck
.
verify
(
tvm
.
min
(
tvm
.
min
(
vx
,
y
.
astype
(
"int32x2"
)),
x
.
astype
(
"int32x2"
)),
tvm
.
min
(
vx
,
tvm
.
min
(
y
,
x
)
.
astype
(
"int32x2"
)))
ck
.
verify
(
tvm
.
max
(
y
.
astype
(
"int32x2"
),
x
.
astype
(
"int32x2"
)),
tvm
.
max
(
y
,
x
)
.
astype
(
"int32x2"
))
ck
.
verify
(
tvm
.
max
(
tvm
.
max
(
vx
,
y
.
astype
(
"int32x2"
)),
x
.
astype
(
"int32x2"
)),
tvm
.
max
(
vx
,
tvm
.
max
(
y
,
x
)
.
astype
(
"int32x2"
)))
## Logical rules
ck
.
verify
(
y
.
astype
(
"int32x2"
)
.
equal
(
x
.
astype
(
"int32x2"
)),
(
y
.
equal
(
x
))
.
astype
(
"uint1x2"
))
ck
.
verify
(
tvm
.
expr
.
NE
(
y
.
astype
(
"int32x2"
),
(
x
.
astype
(
"int32x2"
))),
(
tvm
.
expr
.
NE
(
y
,
x
))
.
astype
(
"uint1x2"
))
ck
.
verify
(
y
.
astype
(
"int32x2"
)
>
x
.
astype
(
"int32x2"
),
(
x
<
y
)
.
astype
(
"uint1x2"
))
ck
.
verify
(
y
.
astype
(
"int32x2"
)
>=
x
.
astype
(
"int32x2"
),
(
x
<=
y
)
.
astype
(
"uint1x2"
))
ck
.
verify
(
y
.
astype
(
"int32x2"
)
<
x
.
astype
(
"int32x2"
),
(
y
<
x
)
.
astype
(
"uint1x2"
))
ck
.
verify
(
y
.
astype
(
"int32x2"
)
<=
x
.
astype
(
"int32x2"
),
(
y
<=
x
)
.
astype
(
"uint1x2"
))
ck
.
verify
(
tvm
.
expr
.
And
(
y
.
astype
(
"int32x2"
)
<=
x
.
astype
(
"int32x2"
),
vc
.
astype
(
"uint1x2"
)),
(
tvm
.
expr
.
And
(
y
<=
x
,
vc
))
.
astype
(
"uint1x2"
))
ck
.
verify
(
tvm
.
expr
.
Or
(
y
.
astype
(
"int32x2"
)
<=
x
.
astype
(
"int32x2"
),
vc
.
astype
(
"uint1x2"
)),
(
tvm
.
expr
.
Or
(
y
<=
x
,
vc
))
.
astype
(
"uint1x2"
))
def
test_select_simplify
():
ck
=
RewriteChecker
()
x
,
y
,
z
=
tvm
.
var
(
"x"
),
tvm
.
var
(
"y"
),
tvm
.
var
(
"z"
)
# Add rules
ck
.
verify
(
tvm
.
expr
.
Select
(
x
>
0
,
y
,
0
)
+
tvm
.
expr
.
Select
(
x
>
0
,
1
,
z
),
tvm
.
expr
.
Select
(
x
>
0
,
y
+
1
,
z
))
ck
.
verify
(
tvm
.
expr
.
Select
(
x
>
0
,
y
,
1
)
-
tvm
.
expr
.
Select
(
x
>
0
,
1
,
z
),
tvm
.
expr
.
Select
(
x
>
0
,
y
+
(
-
1
),
1
-
z
))
ck
.
verify
(
tvm
.
expr
.
Select
(
x
>
0
,
y
,
z
)
-
y
,
tvm
.
expr
.
Select
(
x
>
0
,
0
,
z
-
y
))
ck
.
verify
(
tvm
.
expr
.
Select
(
x
>
0
,
y
,
z
)
-
z
,
tvm
.
expr
.
Select
(
x
>
0
,
y
-
z
,
0
))
ck
.
verify
(
tvm
.
expr
.
Select
(
x
<
0
,
y
,
0
)
+
tvm
.
expr
.
Select
(
x
<
0
,
1
,
z
),
tvm
.
expr
.
Select
(
x
<
0
,
y
+
1
,
z
))
ck
.
verify
(
tvm
.
expr
.
Select
(
x
<
0
,
y
,
1
)
-
tvm
.
expr
.
Select
(
x
<
0
,
1
,
z
),
tvm
.
expr
.
Select
(
x
<
0
,
y
+
(
-
1
),
1
-
z
))
ck
.
verify
(
tvm
.
expr
.
Select
(
x
<
0
,
y
,
z
)
-
y
,
tvm
.
expr
.
Select
(
x
<
0
,
0
,
z
-
y
))
ck
.
verify
(
tvm
.
expr
.
Select
(
x
<
0
,
y
,
z
)
-
z
,
tvm
.
expr
.
Select
(
x
<
0
,
y
-
z
,
0
))
ck
.
verify
(
tvm
.
min
(
tvm
.
expr
.
Select
(
x
<
0
,
y
,
0
),
tvm
.
expr
.
Select
(
x
<
0
,
1
,
z
)),
tvm
.
expr
.
Select
(
x
<
0
,
tvm
.
min
(
y
,
1
),
tvm
.
min
(
0
,
z
)))
ck
.
verify
(
tvm
.
max
(
tvm
.
expr
.
Select
(
x
<
0
,
y
,
0
),
tvm
.
expr
.
Select
(
x
<
0
,
1
,
z
)),
tvm
.
expr
.
Select
(
x
<
0
,
tvm
.
max
(
y
,
1
),
tvm
.
max
(
0
,
z
)))
ck
.
verify
(
tvm
.
expr
.
Select
(
x
*
3
+
1
!=
0
,
y
,
z
),
y
)
ck
.
verify
(
tvm
.
expr
.
Select
(
x
*
3
+
1
==
0
,
y
,
z
),
z
)
ck
.
verify
(
tvm
.
expr
.
Select
(
x
>
0
,
y
+
1
,
y
+
1
),
y
+
1
)
def
test_add_index_simplify
():
...
...
@@ -242,11 +278,231 @@ def test_mod_index_simplify():
ck
.
verify
((
x
*
10
+
1
+
y
*
2
+
2
)
%
2
,
1
)
def
test_min_index_simplify
():
ck
=
RewriteChecker
()
x
,
y
,
z
=
tvm
.
var
(
"x"
),
tvm
.
var
(
"y"
),
tvm
.
var
(
"z"
)
# const int bound
ck
.
verify
(
tvm
.
min
(
x
%
2
,
y
%
2
+
10
),
x
%
2
)
ck
.
verify
(
tvm
.
min
(
x
+
1
,
x
+
10
),
x
+
1
)
ck
.
verify
(
tvm
.
min
(
x
+
111
,
x
+
10
),
x
+
10
)
ck
.
verify
(
tvm
.
min
(
x
+
1
,
x
),
x
)
ck
.
verify
(
tvm
.
min
(
x
,
x
+
2
),
x
)
ck
.
verify
(
tvm
.
min
(
1
-
x
,
2
-
x
),
1
-
x
)
ck
.
verify
(
tvm
.
min
(
3
-
x
,
2
-
x
),
2
-
x
)
ck
.
verify
(
tvm
.
min
((
x
+
3
)
/
4
*
4
,
x
),
x
)
ck
.
analyzer
.
update
(
x
,
tvm
.
arith
.
ConstIntBound
(
0
,
1000
))
ck
.
verify
(
tvm
.
min
((
x
+
3
)
/
4
*
4
,
tvm
.
max
(
x
,
4
)),
tvm
.
max
(
x
,
4
))
ck
.
verify
(
tvm
.
min
(
x
,
(
x
+
3
)
/
4
*
4
),
x
)
ck
.
verify
(
tvm
.
min
(
tvm
.
max
(
x
,
4
),
(
x
+
3
)
/
4
*
4
),
tvm
.
max
(
x
,
4
))
ck
.
analyzer
.
update
(
x
,
tvm
.
arith
.
ConstIntBound
(
-
1000
,
1000
),
True
)
ck
.
verify
(
tvm
.
min
(
tvm
.
max
(
x
,
y
),
tvm
.
min
(
x
,
y
)),
tvm
.
min
(
x
,
y
))
ck
.
verify
(
tvm
.
min
(
tvm
.
max
(
x
,
y
),
tvm
.
min
(
y
,
x
)),
tvm
.
min
(
x
,
y
))
ck
.
verify
(
tvm
.
min
(
tvm
.
max
(
x
,
y
),
x
),
x
)
ck
.
verify
(
tvm
.
min
(
tvm
.
max
(
y
,
x
),
x
),
x
)
ck
.
verify
(
tvm
.
min
(
tvm
.
min
(
x
,
y
),
x
),
tvm
.
min
(
x
,
y
))
ck
.
verify
(
tvm
.
min
(
tvm
.
min
(
x
,
y
),
y
),
tvm
.
min
(
x
,
y
))
ck
.
verify
(
tvm
.
min
(
x
,
tvm
.
max
(
x
,
y
)),
x
)
ck
.
verify
(
tvm
.
min
(
x
,
tvm
.
max
(
y
,
x
)),
x
)
ck
.
verify
(
tvm
.
min
(
x
,
tvm
.
min
(
x
,
y
)),
tvm
.
min
(
x
,
y
))
ck
.
verify
(
tvm
.
min
(
y
,
tvm
.
min
(
x
,
y
)),
tvm
.
min
(
x
,
y
))
ck
.
verify
(
tvm
.
min
(
tvm
.
min
(
tvm
.
min
(
x
,
y
),
z
),
y
),
tvm
.
min
(
tvm
.
min
(
x
,
y
),
z
))
ck
.
verify
(
tvm
.
min
(
tvm
.
min
(
tvm
.
min
(
tvm
.
min
(
x
,
y
),
z
),
x
*
2
),
y
),
tvm
.
min
(
tvm
.
min
(
tvm
.
min
(
x
,
y
),
z
),
x
*
2
))
ck
.
verify
(
tvm
.
min
(
tvm
.
min
(
tvm
.
min
(
tvm
.
min
(
tvm
.
min
(
x
,
y
),
z
),
x
*
2
),
z
*
2
),
y
),
tvm
.
min
(
tvm
.
min
(
tvm
.
min
(
tvm
.
min
(
x
,
y
),
z
),
x
*
2
),
z
*
2
))
ck
.
verify
(
tvm
.
min
(
tvm
.
max
(
x
,
y
),
tvm
.
max
(
x
,
z
)),
tvm
.
max
(
tvm
.
min
(
y
,
z
),
x
))
ck
.
verify
(
tvm
.
min
(
tvm
.
max
(
x
,
y
),
tvm
.
max
(
z
,
x
)),
tvm
.
max
(
tvm
.
min
(
y
,
z
),
x
))
ck
.
verify
(
tvm
.
min
(
tvm
.
max
(
y
,
x
),
tvm
.
max
(
x
,
z
)),
tvm
.
max
(
tvm
.
min
(
y
,
z
),
x
))
ck
.
verify
(
tvm
.
min
(
tvm
.
max
(
y
,
x
),
tvm
.
max
(
z
,
x
)),
tvm
.
max
(
tvm
.
min
(
y
,
z
),
x
))
ck
.
verify
(
tvm
.
min
(
y
+
x
,
z
+
x
),
tvm
.
min
(
y
,
z
)
+
x
)
ck
.
verify
(
tvm
.
min
(
y
+
x
,
x
+
z
),
tvm
.
min
(
y
,
z
)
+
x
)
ck
.
verify
(
tvm
.
min
(
x
+
y
,
z
+
x
),
tvm
.
min
(
y
,
z
)
+
x
)
ck
.
verify
(
tvm
.
min
(
x
+
y
,
x
+
z
),
tvm
.
min
(
y
,
z
)
+
x
)
ck
.
verify
(
tvm
.
min
(
x
-
y
,
x
-
z
),
x
-
tvm
.
max
(
y
,
z
))
ck
.
verify
(
tvm
.
min
(
y
-
x
,
z
-
x
),
tvm
.
min
(
y
,
z
)
-
x
)
ck
.
verify
(
tvm
.
min
(
tvm
.
min
(
x
,
1
),
10
),
tvm
.
min
(
x
,
1
))
ck
.
verify
(
tvm
.
min
(
tvm
.
min
(
x
,
11
),
10
),
tvm
.
min
(
x
,
10
))
ck
.
verify
(
tvm
.
min
(
x
/
10
,
y
/
10
),
tvm
.
min
(
x
,
y
)
/
10
)
ck
.
verify
(
tvm
.
min
(
x
/
(
-
10
),
y
/
(
-
10
)),
tvm
.
max
(
x
,
y
)
/
(
-
10
))
ck
.
verify
(
tvm
.
min
(
x
*
3
,
9
),
tvm
.
min
(
x
,
3
)
*
3
)
def
test_max_index_simplify
():
ck
=
RewriteChecker
()
x
,
y
,
z
=
tvm
.
var
(
"x"
),
tvm
.
var
(
"y"
),
tvm
.
var
(
"z"
)
# const int bound
ck
.
verify
(
tvm
.
max
(
x
%
2
,
y
%
2
+
10
),
y
%
2
+
10
)
ck
.
verify
(
tvm
.
max
(
x
+
1
,
x
+
10
),
x
+
10
)
ck
.
verify
(
tvm
.
max
(
x
+
111
,
x
+
10
),
x
+
111
)
ck
.
verify
(
tvm
.
max
(
x
+
1
,
x
),
x
+
1
)
ck
.
verify
(
tvm
.
max
(
x
,
x
+
2
),
x
+
2
)
ck
.
verify
(
tvm
.
max
(
1
-
x
,
2
-
x
),
2
-
x
)
ck
.
verify
(
tvm
.
max
(
3
-
x
,
2
-
x
),
3
-
x
)
ck
.
verify
(
tvm
.
max
((
x
+
3
)
/
4
*
4
,
x
),
(
x
+
3
)
/
4
*
4
)
ck
.
verify
(
tvm
.
max
(
tvm
.
min
(
x
,
y
),
tvm
.
max
(
x
,
y
)),
tvm
.
max
(
x
,
y
))
ck
.
verify
(
tvm
.
max
(
tvm
.
min
(
x
,
y
),
tvm
.
max
(
y
,
x
)),
tvm
.
max
(
x
,
y
))
ck
.
verify
(
tvm
.
max
(
tvm
.
min
(
x
,
y
),
x
),
x
)
ck
.
verify
(
tvm
.
max
(
tvm
.
min
(
y
,
x
),
x
),
x
)
ck
.
verify
(
tvm
.
max
(
tvm
.
max
(
x
,
y
),
x
),
tvm
.
max
(
x
,
y
))
ck
.
verify
(
tvm
.
max
(
tvm
.
max
(
x
,
y
),
y
),
tvm
.
max
(
x
,
y
))
ck
.
verify
(
tvm
.
max
(
x
,
tvm
.
min
(
x
,
y
)),
x
)
ck
.
verify
(
tvm
.
max
(
x
,
tvm
.
min
(
y
,
x
)),
x
)
ck
.
verify
(
tvm
.
max
(
x
,
tvm
.
max
(
x
,
y
)),
tvm
.
max
(
x
,
y
))
ck
.
verify
(
tvm
.
max
(
y
,
tvm
.
max
(
x
,
y
)),
tvm
.
max
(
x
,
y
))
ck
.
verify
(
tvm
.
max
(
tvm
.
max
(
tvm
.
max
(
x
,
y
),
z
),
y
),
tvm
.
max
(
tvm
.
max
(
x
,
y
),
z
))
ck
.
verify
(
tvm
.
max
(
tvm
.
max
(
tvm
.
max
(
tvm
.
max
(
x
,
y
),
z
),
x
*
2
),
y
),
tvm
.
max
(
tvm
.
max
(
tvm
.
max
(
x
,
y
),
z
),
x
*
2
))
ck
.
verify
(
tvm
.
max
(
tvm
.
max
(
tvm
.
max
(
tvm
.
max
(
tvm
.
max
(
x
,
y
),
z
),
x
*
2
),
z
*
2
),
y
),
tvm
.
max
(
tvm
.
max
(
tvm
.
max
(
tvm
.
max
(
x
,
y
),
z
),
x
*
2
),
z
*
2
))
ck
.
verify
(
tvm
.
max
(
tvm
.
min
(
x
,
y
),
tvm
.
min
(
x
,
z
)),
tvm
.
min
(
tvm
.
max
(
y
,
z
),
x
))
ck
.
verify
(
tvm
.
max
(
tvm
.
min
(
x
,
y
),
tvm
.
min
(
z
,
x
)),
tvm
.
min
(
tvm
.
max
(
y
,
z
),
x
))
ck
.
verify
(
tvm
.
max
(
tvm
.
min
(
y
,
x
),
tvm
.
min
(
x
,
z
)),
tvm
.
min
(
tvm
.
max
(
y
,
z
),
x
))
ck
.
verify
(
tvm
.
max
(
tvm
.
min
(
y
,
x
),
tvm
.
min
(
z
,
x
)),
tvm
.
min
(
tvm
.
max
(
y
,
z
),
x
))
ck
.
verify
(
tvm
.
max
(
y
+
x
,
z
+
x
),
tvm
.
max
(
y
,
z
)
+
x
)
ck
.
verify
(
tvm
.
max
(
y
+
x
,
x
+
z
),
tvm
.
max
(
y
,
z
)
+
x
)
ck
.
verify
(
tvm
.
max
(
x
+
y
,
z
+
x
),
tvm
.
max
(
y
,
z
)
+
x
)
ck
.
verify
(
tvm
.
max
(
x
+
y
,
x
+
z
),
tvm
.
max
(
y
,
z
)
+
x
)
ck
.
verify
(
tvm
.
max
(
x
-
y
,
x
-
z
),
x
-
tvm
.
min
(
y
,
z
))
ck
.
verify
(
tvm
.
max
(
y
-
x
,
z
-
x
),
tvm
.
max
(
y
,
z
)
-
x
)
ck
.
verify
(
tvm
.
max
(
tvm
.
max
(
x
,
1
),
10
),
tvm
.
max
(
x
,
10
))
ck
.
verify
(
tvm
.
max
(
tvm
.
max
(
x
,
11
),
10
),
tvm
.
max
(
x
,
11
))
ck
.
verify
(
tvm
.
max
(
x
/
10
,
y
/
10
),
tvm
.
max
(
x
,
y
)
/
10
)
ck
.
verify
(
tvm
.
max
(
x
/
(
-
10
),
y
/
(
-
10
)),
tvm
.
min
(
x
,
y
)
/
(
-
10
))
ck
.
verify
(
tvm
.
max
(
x
*
3
,
9
),
tvm
.
max
(
x
,
3
)
*
3
)
def
test_cmp_simplify
():
ck
=
RewriteChecker
()
x
,
y
,
z
=
tvm
.
var
(
"x"
),
tvm
.
var
(
"y"
),
tvm
.
var
(
"z"
)
# const int bound
ck
.
verify
((
x
%
2
+
10
)
.
equal
(
0
),
tvm
.
const
(
0
,
"bool"
))
ck
.
verify
(
tvm
.
expr
.
NE
(
x
%
2
+
10
,
0
),
tvm
.
const
(
1
,
"bool"
))
ck
.
verify
(
x
%
2
+
10
>
1
,
tvm
.
const
(
1
,
"bool"
))
ck
.
verify
(
x
%
2
+
10
<=
1
,
tvm
.
const
(
0
,
"bool"
))
ck
.
verify
(
x
*
3
+
10
==
0
,
tvm
.
const
(
0
,
"bool"
))
ck
.
verify
(
x
*
3
+
10
!=
0
,
tvm
.
const
(
1
,
"bool"
))
# canonicalization
ck
.
verify
((
x
-
10
)
.
equal
(
0
),
x
.
equal
(
10
))
ck
.
verify
((
10
-
x
)
.
equal
(
0
),
x
.
equal
(
10
))
ck
.
verify
((
x
*
y
)
.
equal
(
0
),
tvm
.
expr
.
Or
(
x
.
equal
(
0
),
y
.
equal
(
0
)))
# cmp bound
ck
.
verify
(
x
+
y
<
x
+
z
,
y
<
z
)
ck
.
verify
(
x
+
y
<
z
+
x
,
y
<
z
)
ck
.
verify
(
y
+
x
<
x
+
z
,
y
<
z
)
ck
.
verify
(
y
+
x
<
z
+
x
,
y
<
z
)
ck
.
verify
(
y
-
x
<
z
-
x
,
y
<
z
)
ck
.
verify
(
x
-
y
<
x
-
z
,
z
<
y
)
ck
.
verify
(
x
<
z
+
x
,
tvm
.
expr
.
LT
(
0
,
z
))
ck
.
verify
(
x
<
x
+
z
,
tvm
.
expr
.
LT
(
0
,
z
))
ck
.
verify
(
100
<
x
+
1
,
tvm
.
expr
.
LT
(
99
,
x
))
ck
.
verify
(
1
<
100
-
x
,
tvm
.
expr
.
LT
(
x
,
99
))
ck
.
verify
(
x
*
3
<
y
*
3
,
x
<
y
)
ck
.
verify
(
x
*
(
-
3
)
<
y
*
(
-
3
),
y
<
x
)
ck
.
verify
(
x
*
3
>=
y
*
3
,
y
<=
x
)
ck
.
verify
(
x
*
4
>=
2
,
tvm
.
expr
.
LE
(
1
,
x
))
ck
.
verify
(
x
*
2
>=
50
,
tvm
.
expr
.
LE
(
25
,
x
))
ck
.
verify
(
x
/
2
<
3
,
x
<
6
)
ck
.
verify
(
x
*
4
<=
2
,
x
<=
0
)
ck
.
verify
(
3
<
x
/
2
,
tvm
.
expr
.
LT
(
7
,
x
))
ck
.
verify
(
x
/
4
*
4
<
x
,
tvm
.
expr
.
LT
(
0
,
x
%
4
))
ck
.
verify
(
x
/
4
*
4
>=
x
,
tvm
.
expr
.
LE
(
x
%
4
,
0
))
ck
.
verify
(
x
/
4
*
4
<
x
+
y
,
tvm
.
expr
.
LT
(
0
,
x
%
4
+
y
))
ck
.
verify
(
x
/
4
*
4
<
x
-
y
,
tvm
.
expr
.
LT
(
y
,
x
%
4
))
ck
.
verify
((
x
+
2
)
/
4
*
4
>=
x
,
tvm
.
expr
.
LE
((
x
+
2
)
%
4
,
2
))
ck
.
verify
((
x
+
2
)
/
4
*
4
>=
x
+
y
,
tvm
.
expr
.
LE
((
x
+
2
)
%
4
+
y
,
2
))
ck
.
verify
((
x
+
2
)
/
4
*
4
>=
x
-
y
,
tvm
.
expr
.
LE
((
x
+
2
)
%
4
+
(
-
2
),
y
))
ck
.
verify
(
tvm
.
min
(
x
,
11
)
<
10
,
x
<
10
)
ck
.
verify
(
tvm
.
min
(
x
,
8
)
<
10
,
tvm
.
const
(
1
,
"bool"
))
ck
.
verify
(
tvm
.
max
(
8
,
x
)
>
10
,
tvm
.
expr
.
LT
(
10
,
x
))
ck
.
verify
(
x
+
1
<
tvm
.
max
(
8
,
x
),
x
<
7
)
def
test_logical_simplify
():
ck
=
RewriteChecker
()
x
,
y
,
z
=
tvm
.
var
(
"x"
),
tvm
.
var
(
"y"
),
tvm
.
var
(
"z"
)
ck
.
verify
(
tvm
.
expr
.
And
(
tvm
.
expr
.
EQ
(
x
,
y
),
tvm
.
expr
.
NE
(
x
,
y
)),
tvm
.
const
(
False
,
"bool"
))
ck
.
verify
(
tvm
.
expr
.
And
(
tvm
.
expr
.
NE
(
x
,
y
),
tvm
.
expr
.
EQ
(
x
,
y
)),
tvm
.
const
(
False
,
"bool"
))
ck
.
verify
(
tvm
.
expr
.
And
(
x
>
1
,
tvm
.
expr
.
Not
(
x
>
1
)),
tvm
.
const
(
False
,
"bool"
))
ck
.
verify
(
tvm
.
expr
.
And
(
x
<=
y
,
y
<
x
),
tvm
.
const
(
False
,
"bool"
))
ck
.
verify
(
tvm
.
expr
.
And
(
y
<
x
,
y
<=
x
),
tvm
.
const
(
False
,
"bool"
))
ck
.
verify
(
tvm
.
expr
.
And
(
x
<
1
,
0
<
x
),
tvm
.
const
(
False
,
"bool"
))
ck
.
verify
(
tvm
.
expr
.
And
(
x
<
0
,
1
<
x
),
tvm
.
const
(
False
,
"bool"
))
ck
.
verify
(
tvm
.
expr
.
And
(
x
<
1
,
1
<=
x
),
tvm
.
const
(
False
,
"bool"
))
ck
.
verify
(
tvm
.
expr
.
And
(
x
<=
1
,
1
<
x
),
tvm
.
const
(
False
,
"bool"
))
ck
.
verify
(
tvm
.
expr
.
And
(
1
<=
x
,
x
<
1
),
tvm
.
const
(
False
,
"bool"
))
ck
.
verify
(
tvm
.
expr
.
And
(
1
<
x
,
x
<=
1
),
tvm
.
const
(
False
,
"bool"
))
ck
.
verify
(
tvm
.
expr
.
And
(
x
<=
1
,
2
<=
x
),
tvm
.
const
(
False
,
"bool"
))
ck
.
verify
(
tvm
.
expr
.
And
(
2
<=
x
,
x
<=
1
),
tvm
.
const
(
False
,
"bool"
))
ck
.
verify
(
tvm
.
expr
.
And
(
x
==
1
,
x
!=
2
),
x
==
1
)
ck
.
verify
(
tvm
.
expr
.
Or
(
tvm
.
expr
.
EQ
(
x
,
y
),
tvm
.
expr
.
NE
(
x
,
y
)),
tvm
.
const
(
True
,
"bool"
))
ck
.
verify
(
tvm
.
expr
.
Or
(
tvm
.
expr
.
NE
(
x
,
y
),
tvm
.
expr
.
EQ
(
x
,
y
)),
tvm
.
const
(
True
,
"bool"
))
ck
.
verify
(
tvm
.
expr
.
Or
(
x
>
y
,
tvm
.
expr
.
Not
(
x
<
y
)),
tvm
.
const
(
True
,
"bool"
))
ck
.
verify
(
tvm
.
expr
.
Or
(
x
<=
y
,
y
<
x
),
tvm
.
const
(
True
,
"bool"
))
ck
.
verify
(
tvm
.
expr
.
Or
(
y
<
x
,
y
<=
x
),
tvm
.
const
(
True
,
"bool"
))
ck
.
verify
(
tvm
.
expr
.
Or
(
x
<
1
,
0
<
x
),
tvm
.
const
(
True
,
"bool"
))
ck
.
verify
(
tvm
.
expr
.
Or
(
0
<
x
,
x
<
1
),
tvm
.
const
(
True
,
"bool"
))
ck
.
verify
(
tvm
.
expr
.
Or
(
x
<
1
,
1
<=
x
),
tvm
.
const
(
True
,
"bool"
))
ck
.
verify
(
tvm
.
expr
.
Or
(
x
<=
1
,
1
<
x
),
tvm
.
const
(
True
,
"bool"
))
ck
.
verify
(
tvm
.
expr
.
Or
(
1
<=
x
,
x
<
1
),
tvm
.
const
(
True
,
"bool"
))
ck
.
verify
(
tvm
.
expr
.
Or
(
1
<
x
,
x
<=
1
),
tvm
.
const
(
True
,
"bool"
))
ck
.
verify
(
tvm
.
expr
.
Or
(
x
<=
1
,
2
<=
x
),
tvm
.
const
(
True
,
"bool"
))
ck
.
verify
(
tvm
.
expr
.
Or
(
2
<=
x
,
x
<=
1
),
tvm
.
const
(
True
,
"bool"
))
ck
.
verify
(
tvm
.
expr
.
Or
(
x
!=
1
,
x
==
2
),
x
!=
1
)
if
__name__
==
"__main__"
:
test_
mod_index
_simplify
()
test_
cmp
_simplify
()
test_vector_simplify
()
test_add_index_simplify
()
test_sub_index_simplify
()
test_mul_index_simplify
()
test_div_index_simplify
()
test_max_index_simplify
()
test_min_index_simplify
()
test_mod_index_simplify
()
test_select_simplify
()
test_logical_simplify
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment