Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
ec95675c
Unverified
Commit
ec95675c
authored
Mar 10, 2019
by
Tianqi Chen
Committed by
GitHub
Mar 10, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[ARITH] Analyzer RewriteSimplifier: add/sub/mul/div/mod (#2722)
parent
0bf64ee0
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
1016 additions
and
8 deletions
+1016
-8
include/tvm/arithmetic.h
+35
-0
python/tvm/arith.py
+16
-0
src/api/api_arith.cc
+4
-0
src/arithmetic/analyzer.cc
+9
-2
src/arithmetic/const_fold.h
+3
-1
src/arithmetic/pattern_match.h
+46
-5
src/arithmetic/rewrite_simplify.cc
+650
-0
tests/cpp/pattern_match_test.cc
+1
-0
tests/python/unittest/test_arith_rewrite_simplify.py
+252
-0
No files found.
include/tvm/arithmetic.h
View file @
ec95675c
...
...
@@ -193,6 +193,39 @@ class ModularSetAnalyzer {
};
/*!
* \brief Rewrite-rule based simplifier.
*/
class
RewriteSimplifier
{
public
:
/*!
* \brief analyze the expr
* \param expr The expression of interest.
* \return the result of the analysis.
*/
Expr
operator
()(
const
Expr
&
expr
);
/*!
* \brief Update binding of var to a new expression.
*
* \param var The variable of interest.
* \param new_expr
* \param override Whether do we allow override of existing information.
*/
void
Update
(
const
Var
&
var
,
const
Expr
&
new_expr
,
bool
override
=
false
);
private
:
friend
class
Analyzer
;
friend
class
ConstraintContext
;
explicit
RewriteSimplifier
(
Analyzer
*
parent
);
~
RewriteSimplifier
();
class
Impl
;
/*! \brief Internal impl */
Impl
*
impl_
;
};
/*!
* \brief A RAII constraint context.
*
* \code
...
...
@@ -242,6 +275,8 @@ class Analyzer {
ConstIntBoundAnalyzer
const_int_bound
;
/*! \brief sub-analyzer: modular set */
ModularSetAnalyzer
modular_set
;
/*! \brief sub-analyzer rewrite simplfy */
RewriteSimplifier
rewrite_simplify
;
/*! \brief constructor */
Analyzer
();
/*!
...
...
python/tvm/arith.py
View file @
ec95675c
...
...
@@ -96,6 +96,7 @@ class Analyzer:
self
.
_const_int_bound_update
=
_mod
(
"const_int_bound_update"
)
self
.
_bind
=
_mod
(
"bind"
)
self
.
_modular_set
=
_mod
(
"modular_set"
)
self
.
_rewrite_simplify
=
_mod
(
"rewrite_simplify"
)
self
.
_enter_constraint_context
=
_mod
(
"enter_constraint_context"
)
def
const_int_bound
(
self
,
expr
):
...
...
@@ -128,6 +129,21 @@ class Analyzer:
"""
return
self
.
_modular_set
(
expr
)
def
rewrite_simplify
(
self
,
expr
):
"""Simplify expression via rewriting rules.
Parameters
----------
expr : tvm.Expr
The expression.
Returns
-------
result : Expr
The result.
"""
return
self
.
_rewrite_simplify
(
expr
)
def
bind
(
self
,
var
,
expr
):
"""Bind a variable to the expression.
...
...
src/api/api_arith.cc
View file @
ec95675c
...
...
@@ -98,6 +98,10 @@ TVM_REGISTER_API("arith._CreateAnalyzer")
return
PackedFunc
([
self
](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
self
->
const_int_bound
.
Update
(
args
[
0
],
args
[
1
],
args
[
2
]);
});
}
else
if
(
name
==
"rewrite_simplify"
)
{
return
PackedFunc
([
self
](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
self
->
rewrite_simplify
(
args
[
0
]);
});
}
else
if
(
name
==
"bind"
)
{
return
PackedFunc
([
self
](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
auto
&
sptr
=
args
[
1
].
node_sptr
();
...
...
src/arithmetic/analyzer.cc
View file @
ec95675c
...
...
@@ -2,6 +2,7 @@
* Copyright (c) 2019 by Contributors
* \file tvm/arithmetic/analyzer.cc
*/
#include <tvm/ir.h>
#include <tvm/arithmetic.h>
namespace
tvm
{
...
...
@@ -9,19 +10,22 @@ namespace arith {
Analyzer
::
Analyzer
()
:
const_int_bound
(
this
),
modular_set
(
this
)
{
modular_set
(
this
),
rewrite_simplify
(
this
)
{
}
void
Analyzer
::
Bind
(
const
VarExpr
&
v
,
const
Expr
&
expr
)
{
Var
var
(
v
.
node_
);
this
->
const_int_bound
.
Update
(
var
,
this
->
const_int_bound
(
expr
));
this
->
modular_set
.
Update
(
var
,
this
->
modular_set
(
expr
));
this
->
rewrite_simplify
.
Update
(
var
,
this
->
rewrite_simplify
(
expr
));
}
void
Analyzer
::
Bind
(
const
VarExpr
&
v
,
const
Range
&
range
)
{
Var
var
(
v
.
node_
);
this
->
const_int_bound
.
Bind
(
var
,
range
);
// skip modular_set
// skip rewrite simplify
}
ConstraintContext
::
ConstraintContext
(
Analyzer
*
analyzer
,
const
Expr
&
constraint
)
{
...
...
@@ -36,7 +40,10 @@ ConstraintContext::ConstraintContext(Analyzer* analyzer, const Expr& constraint)
}
bool
Analyzer
::
CanProveGreaterEqual
(
const
Expr
&
expr
,
int64_t
lower_bound
)
{
auto
bd
=
this
->
const_int_bound
(
expr
);
if
(
const
auto
*
ptr
=
expr
.
as
<
ir
::
IntImm
>
())
{
return
ptr
->
value
>
lower_bound
;
}
auto
bd
=
this
->
const_int_bound
(
this
->
rewrite_simplify
(
expr
));
if
(
bd
->
min_value
>=
lower_bound
)
return
true
;
return
false
;
}
...
...
src/arithmetic/const_fold.h
View file @
ec95675c
...
...
@@ -23,7 +23,9 @@ namespace arith {
* \return nullptr if constant fold fails, otherwise return folded result.
*/
template
<
typename
Op
>
inline
Expr
TryConstFold
(
Expr
a
,
Expr
b
);
inline
Expr
TryConstFold
(
Expr
a
,
Expr
b
)
{
return
Expr
();
}
/*!
* \brief Try to run unary compute with constant folding.
...
...
src/arithmetic/pattern_match.h
View file @
ec95675c
...
...
@@ -49,6 +49,7 @@
#include <tvm/ir_pass.h>
#include <tuple>
#include "const_fold.h"
namespace
tvm
{
namespace
arith
{
...
...
@@ -242,7 +243,11 @@ class PBinaryExpr :
}
Expr
Eval
()
const
{
return
NodeType
::
make
(
a_
.
Eval
(),
b_
.
Eval
());
Expr
lhs
=
a_
.
Eval
();
Expr
rhs
=
b_
.
Eval
();
Expr
ret
=
TryConstFold
<
NodeType
>
(
lhs
,
rhs
);
if
(
ret
.
defined
())
return
ret
;
return
NodeType
::
make
(
lhs
,
rhs
);
}
private
:
...
...
@@ -250,12 +255,48 @@ class PBinaryExpr :
typename
TB
::
Nested
b_
;
};
template
<
typename
TA
>
class
PConstWithTypeLike
:
public
Pattern
<
PConstWithTypeLike
<
TA
>
>
{
public
:
PConstWithTypeLike
(
const
TA
&
ref
,
int64_t
value
)
:
ref_
(
ref
),
value_
(
value
)
{}
void
InitMatch_
()
const
{}
bool
Match_
(
const
NodeRef
&
node
)
const
{
if
(
const
ir
::
IntImm
*
ptr
=
node
.
as
<
ir
::
IntImm
>
())
{
return
ptr
->
value
==
value_
;
}
else
{
return
false
;
}
}
Expr
Eval
()
const
{
return
make_const
(
ref_
.
Eval
().
type
(),
value_
);
}
private
:
typename
TA
::
Nested
ref_
;
int64_t
value_
;
};
#define TVM_PATTERN_BINARY_OP(FuncName, NodeName) \
template<typename TA, typename TB> \
inline PBinaryExpr<NodeName, TA, TB> \
FuncName(const Pattern<TA>& a, const Pattern<TB>& b) { \
#define TVM_PATTERN_BINARY_OP(FuncName, NodeName)
\
template<typename TA, typename TB>
\
inline PBinaryExpr<NodeName, TA, TB>
\
FuncName(const Pattern<TA>& a, const Pattern<TB>& b) {
\
return PBinaryExpr<NodeName, TA, TB>(a.derived(), b.derived()); \
} \
template<typename TA> \
inline PBinaryExpr<NodeName, TA, PConstWithTypeLike<TA> > \
FuncName(const Pattern<TA>& a, int64_t b) { \
return FuncName(a, PConstWithTypeLike<TA>(a.derived(), b)); \
} \
template<typename TA> \
inline PBinaryExpr<NodeName, PConstWithTypeLike<TA>, TA> \
FuncName(int64_t b, const Pattern<TA>& a) { \
return FuncName(PConstWithTypeLike<TA>(a.derived(), b), a); \
}
// arithmetic expressions
...
...
src/arithmetic/rewrite_simplify.cc
0 → 100644
View file @
ec95675c
/*!
* Copyright (c) 2019 by Contributors
* \file rewrite_simplify.cc
* \brief Rewrite-rule based simplification.
*/
// Acknowledgement: Most rewrite-rules are from Halide.
#include <tvm/arithmetic.h>
#include <tvm/expr_operator.h>
#include <tvm/ir_mutator.h>
#include "const_fold.h"
#include "pattern_match.h"
namespace
tvm
{
namespace
arith
{
using
namespace
ir
;
// macro for doing simple rewrite
#define TVM_TRY_REWRITE(SrcExpr, ResExpr) \
if ((SrcExpr).Match(ret)) { \
return (ResExpr).Eval(); \
}
// macro for rewrite + recursively rewrite ResExpr
#define TVM_TRY_RECURSIVE_REWRITE(SrcExpr, ResExpr) \
if ((SrcExpr).Match(ret)) { \
return RecursiveRewrite((ResExpr).Eval()); \
}
// macro rewrite only if CondExor is true after match.
#define TVM_TRY_REWRITE_IF(SrcExpr, ResExpr, CondExpr) \
if ((SrcExpr).Match(ret) && (CondExpr)) { \
return (ResExpr).Eval(); \
}
// macro rewrite + recursive_rewrite only if CondExor is true after match.
#define TVM_TRY_RECURSIVE_REWRITE_IF(SrcExpr, ResExpr, CondExpr) \
if ((SrcExpr).Match(ret) && (CondExpr)) { \
return RecursiveRewrite((ResExpr).Eval()); \
}
// NOTE for developers:
//
// We mainly focus on index expression simplification.
// Besides the RewriteSimplifier, some cases can be better
// handled by CanonicalSimplifier.
//
class
RewriteSimplifier
::
Impl
:
public
IRMutator
{
public
:
explicit
Impl
(
Analyzer
*
parent
)
:
parent_
(
parent
)
{}
void
Update
(
const
Var
&
var
,
const
Expr
&
info
,
bool
override
)
{
if
(
!
override
)
{
CHECK
(
!
var_map_
.
count
(
var
));
}
var_map_
[
var
]
=
info
;
}
// Run simplification in post order
Expr
PostOrderSimplify
(
Expr
expr
,
int
max_iter
=
2
)
{
for
(
int
i
=
0
;
i
<
max_iter
;
++
i
)
{
Expr
new_expr
=
this
->
Mutate
(
expr
);
if
(
new_expr
.
same_as
(
expr
))
return
expr
;
expr
=
new_expr
;
}
return
expr
;
}
Expr
Mutate_
(
const
Add
*
op
,
const
Expr
&
self
)
final
;
Expr
Mutate_
(
const
Sub
*
op
,
const
Expr
&
self
)
final
;
Expr
Mutate_
(
const
Mul
*
op
,
const
Expr
&
self
)
final
;
Expr
Mutate_
(
const
Div
*
op
,
const
Expr
&
self
)
final
;
Expr
Mutate_
(
const
Mod
*
op
,
const
Expr
&
self
)
final
;
private
:
// reference to the main analyzer
Analyzer
*
parent_
;
// counter to record recursive rewrite depth.
int
recur_depth_
{
0
};
// internal variable map
std
::
unordered_map
<
Var
,
Expr
,
ExprHash
,
ExprEqual
>
var_map_
;
// maximum number of recursion allowed during a single pass.
static
const
constexpr
int
kMaxRecurDepth
=
5
;
// Whether x >= val
bool
CanProveGreaterEqual
(
const
Expr
&
x
,
int64_t
val
)
{
return
parent_
->
CanProveGreaterEqual
(
x
,
val
);
}
// Whether x == val
bool
CanProveEqual
(
const
Expr
&
x
,
int64_t
val
)
{
// TODO(tqchen) refer back to super-analyzer.
Expr
res
=
Mutate
(
x
);
if
(
const
auto
*
ptr
=
res
.
as
<
ir
::
IntImm
>
())
{
return
ptr
->
value
==
val
;
}
return
false
;
}
// Recursive rewrite x
// we limit maximum depth of recursive rewrite allowed to
// avoid infinite loop
Expr
RecursiveRewrite
(
const
Expr
&
x
)
{
if
(
recur_depth_
>=
kMaxRecurDepth
)
return
x
;
++
recur_depth_
;
Expr
res
=
Mutate
(
x
);
--
recur_depth_
;
return
res
;
}
template
<
typename
TA
>
PConstWithTypeLike
<
TA
>
ZeroWithTypeLike
(
const
Pattern
<
TA
>&
pattern
)
{
return
PConstWithTypeLike
<
TA
>
(
pattern
.
derived
(),
0
);
}
};
Expr
RewriteSimplifier
::
Impl
::
Mutate_
(
const
Add
*
op
,
const
Expr
&
self
)
{
Expr
ret
=
IRMutator
::
Mutate_
(
op
,
self
);
op
=
ret
.
as
<
Add
>
();
Expr
const_res
=
TryConstFold
<
Add
>
(
op
->
a
,
op
->
b
);
if
(
const_res
.
defined
())
return
const_res
;
// Pattern var to match any expression
PVar
<
Expr
>
x
,
y
,
z
,
b1
,
b2
,
s1
,
s2
;
// Pattern var match IntImm
PVar
<
Integer
>
c1
,
c2
,
c3
;
// Pattern var for lanes in broadcast and ramp
PVar
<
int
>
lanes
;
// Vector rules
if
(
op
->
type
.
lanes
()
!=
1
)
{
TVM_TRY_REWRITE
(
ramp
(
b1
,
s1
,
lanes
)
+
ramp
(
b2
,
s2
,
lanes
),
ramp
(
b1
+
b2
,
s1
+
s2
,
lanes
));
TVM_TRY_REWRITE
(
ramp
(
b1
,
s1
,
lanes
)
+
broadcast
(
x
,
lanes
),
ramp
(
b1
+
x
,
s1
,
lanes
));
TVM_TRY_REWRITE
(
broadcast
(
x
,
lanes
)
+
ramp
(
b1
,
s1
,
lanes
),
ramp
(
x
+
b1
,
s1
,
lanes
));
TVM_TRY_REWRITE
(
broadcast
(
x
,
lanes
)
+
broadcast
(
y
,
lanes
),
broadcast
(
x
+
y
,
lanes
));
}
if
(
IsIndexType
(
op
->
type
))
{
// Index rules
// cancelation rules
TVM_TRY_REWRITE
((
x
-
y
)
+
y
,
x
);
TVM_TRY_REWRITE
(
x
+
(
y
-
x
),
y
);
TVM_TRY_REWRITE
((
x
-
y
)
+
(
y
-
z
),
x
-
z
);
TVM_TRY_REWRITE
((
x
-
y
)
+
(
z
-
x
),
z
-
y
);
TVM_TRY_REWRITE
(
min
(
x
,
y
-
z
)
+
z
,
min
(
x
+
z
,
y
));
TVM_TRY_REWRITE
(
min
(
x
-
z
,
y
)
+
z
,
min
(
x
,
y
+
z
));
TVM_TRY_REWRITE
(
max
(
x
,
y
-
z
)
+
z
,
max
(
x
+
z
,
y
));
TVM_TRY_REWRITE
(
max
(
x
-
z
,
y
)
+
z
,
max
(
x
,
y
+
z
));
TVM_TRY_REWRITE
(
max
(
x
,
y
)
+
min
(
x
,
y
),
x
+
y
);
TVM_TRY_REWRITE
(
min
(
x
,
y
)
+
max
(
x
,
y
),
x
+
y
);
TVM_TRY_REWRITE
(
max
(
x
,
y
)
+
min
(
y
,
x
),
x
+
y
);
TVM_TRY_REWRITE
(
min
(
x
,
y
)
+
max
(
y
,
x
),
x
+
y
);
TVM_TRY_REWRITE_IF
(
min
(
x
,
y
+
c1
)
+
c2
,
min
(
x
+
c2
,
y
),
c1
.
Eval
()
->
value
==
-
c2
.
Eval
()
->
value
);
TVM_TRY_REWRITE_IF
(
min
(
x
+
c1
,
y
)
+
c2
,
min
(
x
,
y
+
c2
),
c1
.
Eval
()
->
value
==
-
c2
.
Eval
()
->
value
);
TVM_TRY_REWRITE_IF
(
max
(
x
,
y
+
c1
)
+
c2
,
max
(
x
+
c2
,
y
),
c1
.
Eval
()
->
value
==
-
c2
.
Eval
()
->
value
);
TVM_TRY_REWRITE_IF
(
max
(
x
+
c1
,
y
)
+
c2
,
max
(
x
,
y
+
c2
),
c1
.
Eval
()
->
value
==
-
c2
.
Eval
()
->
value
);
// constant folding
// NOTE: canonicalization might better at this.
TVM_TRY_REWRITE
((
x
+
c1
)
+
c2
,
x
+
(
c1
+
c2
));
// mul co-efficient folding
TVM_TRY_REWRITE
(
x
+
x
,
x
*
2
);
TVM_TRY_REWRITE
(
x
*
y
+
x
,
x
*
(
y
+
1
));
TVM_TRY_REWRITE
(
y
*
x
+
x
,
x
*
(
y
+
1
));
TVM_TRY_REWRITE
(
x
+
y
*
x
,
x
*
(
1
+
y
));
TVM_TRY_REWRITE
(
x
+
x
*
y
,
x
*
(
1
+
y
));
TVM_TRY_REWRITE
(
x
*
y
+
x
*
z
,
x
*
(
y
+
z
));
TVM_TRY_REWRITE
(
y
*
x
+
x
*
z
,
x
*
(
y
+
z
));
TVM_TRY_REWRITE
(
x
*
y
+
z
*
x
,
x
*
(
y
+
z
));
TVM_TRY_REWRITE
(
y
*
x
+
z
*
x
,
x
*
(
y
+
z
));
// modular-div simplification
// Always pre-condition on positive integer domain
TVM_TRY_REWRITE_IF
(
(
x
/
c1
)
*
c1
+
x
%
c1
,
x
,
CanProveGreaterEqual
(
x
.
Eval
(),
0
)
&&
c1
.
Eval
()
->
value
>
0
);
// canonicalization rule
// will try rewrite again after canonicalization.
TVM_TRY_RECURSIVE_REWRITE
(
x
+
(
c1
-
y
),
(
x
-
y
)
+
c1
);
TVM_TRY_RECURSIVE_REWRITE
(
x
+
c1
+
y
,
(
x
+
y
)
+
c1
);
TVM_TRY_RECURSIVE_REWRITE
(
x
+
(
c1
+
y
),
(
x
+
y
)
+
c1
);
TVM_TRY_RECURSIVE_REWRITE
((
y
%
c1
)
+
x
*
c1
,
x
*
c1
+
(
y
%
c1
));
}
// condition rules.
TVM_TRY_REWRITE
(
select
(
x
,
b1
,
b2
)
+
select
(
x
,
s1
,
s2
),
select
(
x
,
b1
+
s1
,
b2
+
s2
));
// default value
return
ret
;
}
Expr
RewriteSimplifier
::
Impl
::
Mutate_
(
const
Sub
*
op
,
const
Expr
&
self
)
{
Expr
ret
=
IRMutator
::
Mutate_
(
op
,
self
);
op
=
ret
.
as
<
Sub
>
();
Expr
const_res
=
TryConstFold
<
Sub
>
(
op
->
a
,
op
->
b
);
if
(
const_res
.
defined
())
return
const_res
;
// Pattern var to match any expression
PVar
<
Expr
>
x
,
y
,
z
,
b1
,
b2
,
s1
,
s2
;
// Pattern var match IntImm
PVar
<
Integer
>
c1
,
c2
,
c3
;
// Pattern var for lanes in broadcast and ramp
PVar
<
int
>
lanes
;
// Vector rules
if
(
op
->
type
.
lanes
()
!=
1
)
{
TVM_TRY_REWRITE
(
ramp
(
b1
,
s1
,
lanes
)
-
ramp
(
b2
,
s2
,
lanes
),
ramp
(
b1
-
b2
,
s1
-
s2
,
lanes
));
TVM_TRY_REWRITE
(
ramp
(
b1
,
s1
,
lanes
)
-
broadcast
(
x
,
lanes
),
ramp
(
b1
-
x
,
s1
,
lanes
));
TVM_TRY_REWRITE
(
broadcast
(
x
,
lanes
)
-
ramp
(
b1
,
s1
,
lanes
),
ramp
(
x
-
b1
,
0
-
s1
,
lanes
));
TVM_TRY_REWRITE
(
broadcast
(
x
,
lanes
)
-
broadcast
(
y
,
lanes
),
broadcast
(
x
-
y
,
lanes
));
}
if
(
IsIndexType
(
op
->
type
))
{
// Index rules
// cancelation rules
TVM_TRY_REWRITE
((
x
+
y
)
-
y
,
x
);
TVM_TRY_REWRITE
((
x
+
y
)
-
x
,
y
);
TVM_TRY_REWRITE
(
x
-
(
y
+
x
),
0
-
y
);
TVM_TRY_REWRITE
(
x
-
(
x
+
y
),
0
-
y
);
TVM_TRY_REWRITE
(
min
(
x
,
y
)
-
x
,
min
(
0
,
y
-
x
));
TVM_TRY_REWRITE
(
min
(
x
,
y
)
-
y
,
min
(
x
-
y
,
0
));
TVM_TRY_REWRITE
(
max
(
x
,
y
)
-
x
,
max
(
0
,
y
-
x
));
TVM_TRY_REWRITE
(
max
(
x
,
y
)
-
y
,
max
(
x
-
y
,
0
));
TVM_TRY_REWRITE
(
x
-
max
(
x
,
y
),
min
(
0
,
x
-
y
));
TVM_TRY_REWRITE
(
y
-
max
(
x
,
y
),
min
(
y
-
x
,
0
));
TVM_TRY_REWRITE
(
x
-
min
(
x
,
y
),
max
(
0
,
x
-
y
));
TVM_TRY_REWRITE
(
y
-
min
(
x
,
y
),
max
(
y
-
x
,
0
));
// mul co-efficient folding
TVM_TRY_REWRITE
(
x
-
x
,
ZeroWithTypeLike
(
x
));
TVM_TRY_REWRITE
(
x
*
y
-
x
,
x
*
(
y
-
1
));
TVM_TRY_REWRITE
(
y
*
x
-
x
,
x
*
(
y
-
1
));
TVM_TRY_REWRITE
(
x
-
y
*
x
,
x
*
(
1
-
y
));
TVM_TRY_REWRITE
(
x
-
x
*
y
,
x
*
(
1
-
y
));
TVM_TRY_REWRITE
(
x
*
y
-
x
*
z
,
x
*
(
y
-
z
));
TVM_TRY_REWRITE
(
y
*
x
-
x
*
z
,
x
*
(
y
-
z
));
TVM_TRY_REWRITE
(
x
*
y
-
z
*
x
,
x
*
(
y
-
z
));
TVM_TRY_REWRITE
(
y
*
x
-
z
*
x
,
x
*
(
y
-
z
));
// constant cancelation
TVM_TRY_REWRITE
((
x
+
c1
)
-
c2
,
x
+
(
c1
-
c2
));
TVM_TRY_REWRITE
((
c1
-
x
)
-
(
c2
-
y
),
(
y
-
x
)
+
(
c1
-
c2
));
// cancelization rule involving 4 operands
TVM_TRY_REWRITE
((
x
+
y
)
-
(
x
+
z
),
y
-
z
);
TVM_TRY_REWRITE
((
x
+
y
)
-
(
z
+
x
),
y
-
z
);
TVM_TRY_REWRITE
((
y
+
x
)
-
(
z
+
x
),
y
-
z
);
TVM_TRY_REWRITE
((
y
+
x
)
-
(
x
+
z
),
y
-
z
);
TVM_TRY_REWRITE
(
min
(
x
+
y
,
z
)
-
x
,
min
(
y
,
z
-
x
));
TVM_TRY_REWRITE
(
min
(
y
+
x
,
z
)
-
x
,
min
(
y
,
z
-
x
));
TVM_TRY_REWRITE
(
min
(
z
,
x
+
y
)
-
x
,
min
(
z
-
x
,
y
));
TVM_TRY_REWRITE
(
min
(
z
,
y
+
x
)
-
x
,
min
(
z
-
x
,
y
));
TVM_TRY_REWRITE
(
x
-
min
(
x
+
y
,
z
),
max
(
0
-
y
,
x
-
z
));
TVM_TRY_REWRITE
(
x
-
min
(
y
+
x
,
z
),
max
(
0
-
y
,
x
-
z
));
TVM_TRY_REWRITE
(
x
-
min
(
z
,
x
+
y
),
max
(
x
-
z
,
0
-
y
));
TVM_TRY_REWRITE
(
x
-
min
(
z
,
y
+
x
),
max
(
x
-
z
,
0
-
y
));
TVM_TRY_REWRITE
(
min
(
x
,
y
)
-
min
(
y
,
x
),
ZeroWithTypeLike
(
x
));
TVM_TRY_REWRITE
(
max
(
x
,
y
)
-
max
(
y
,
x
),
ZeroWithTypeLike
(
x
));
TVM_TRY_REWRITE_IF
(
min
(
b1
,
b2
)
-
min
(
s1
,
s2
),
b1
-
s1
,
CanProveEqual
(((
b1
-
s1
)
-
(
b2
-
s2
)).
Eval
(),
0
));
TVM_TRY_REWRITE_IF
(
min
(
b1
,
b2
)
-
min
(
s1
,
s2
),
b1
-
s2
,
CanProveEqual
(((
b1
-
s2
)
-
(
b2
-
s1
)).
Eval
(),
0
));
TVM_TRY_REWRITE_IF
(
max
(
b1
,
b2
)
-
max
(
s1
,
s2
),
b1
-
s1
,
CanProveEqual
(((
b1
-
s1
)
-
(
b2
-
s2
)).
Eval
(),
0
));
TVM_TRY_REWRITE_IF
(
max
(
b1
,
b2
)
-
max
(
s1
,
s2
),
b1
-
s2
,
CanProveEqual
(((
b1
-
s2
)
-
(
b2
-
s1
)).
Eval
(),
0
));
// modular-div simplification
// Always pre-condition on positive integer domain
TVM_TRY_REWRITE_IF
(
x
-
(
x
/
c1
)
*
c1
,
x
%
c1
,
CanProveGreaterEqual
(
x
.
Eval
(),
0
)
&&
c1
.
Eval
()
->
value
>
0
);
TVM_TRY_REWRITE_IF
((
x
/
c1
)
*
c1
-
x
,
0
-
(
x
%
c1
),
CanProveGreaterEqual
(
x
.
Eval
(),
0
)
&&
c1
.
Eval
()
->
value
>
0
);
TVM_TRY_REWRITE_IF
((
x
+
c1
)
/
c3
-
(
x
+
c2
)
/
c3
,
((
x
+
(
c1
%
c3
))
%
c3
+
(
c1
-
c2
))
/
c3
,
CanProveGreaterEqual
(
x
.
Eval
(),
-
c2
.
Eval
()
->
value
)
&&
c1
.
Eval
()
->
value
>=
c2
.
Eval
()
->
value
&&
c3
.
Eval
()
->
value
>
0
);
TVM_TRY_REWRITE_IF
((
x
+
c1
)
/
c3
-
x
/
c3
,
((
x
+
(
c1
%
c3
))
%
c3
+
c1
)
/
c3
,
CanProveGreaterEqual
(
x
.
Eval
(),
0
)
&&
c1
.
Eval
()
->
value
>=
0
&&
c3
.
Eval
()
->
value
>
0
);
// canonicalization rule
// will try rewrite again after canonicalization.
TVM_TRY_REWRITE
(
x
-
c1
,
x
+
(
0
-
c1
));
TVM_TRY_RECURSIVE_REWRITE
((
x
+
c1
)
-
y
,
(
x
-
y
)
+
c1
);
TVM_TRY_RECURSIVE_REWRITE
(
x
-
(
y
-
z
),
(
x
+
z
)
-
y
);
TVM_TRY_RECURSIVE_REWRITE
(
x
-
y
*
c1
,
x
+
y
*
(
0
-
c1
));
}
// condition rules.
TVM_TRY_REWRITE
(
select
(
x
,
b1
,
b2
)
-
select
(
x
,
s1
,
s2
),
select
(
x
,
b1
-
s1
,
b2
-
s2
));
TVM_TRY_REWRITE
(
select
(
x
,
y
,
z
)
-
z
,
select
(
x
,
y
-
z
,
ZeroWithTypeLike
(
z
)));
TVM_TRY_REWRITE
(
select
(
x
,
y
,
z
)
-
y
,
select
(
x
,
ZeroWithTypeLike
(
y
),
z
-
y
));
return
ret
;
}
Expr
RewriteSimplifier
::
Impl
::
Mutate_
(
const
Mul
*
op
,
const
Expr
&
self
)
{
Expr
ret
=
IRMutator
::
Mutate_
(
op
,
self
);
op
=
ret
.
as
<
Mul
>
();
Expr
const_res
=
TryConstFold
<
Mul
>
(
op
->
a
,
op
->
b
);
if
(
const_res
.
defined
())
return
const_res
;
// Pattern var to match any expression
PVar
<
Expr
>
x
,
y
,
z
,
b1
,
b2
,
s1
,
s2
;
// Pattern var match IntImm
PVar
<
Integer
>
c1
,
c2
;
// Pattern var for lanes in broadcast and ramp
PVar
<
int
>
lanes
;
// Vector rules
if
(
op
->
type
.
lanes
()
!=
1
)
{
TVM_TRY_REWRITE
(
broadcast
(
x
,
lanes
)
*
broadcast
(
y
,
lanes
),
broadcast
(
x
*
y
,
lanes
));
TVM_TRY_REWRITE
(
ramp
(
b1
,
s1
,
lanes
)
*
broadcast
(
x
,
lanes
),
ramp
(
b1
*
x
,
s1
*
x
,
lanes
));
TVM_TRY_REWRITE
(
broadcast
(
x
,
lanes
)
*
ramp
(
b1
,
s1
,
lanes
),
ramp
(
b1
*
x
,
s1
*
x
,
lanes
));
}
if
(
IsIndexType
(
op
->
type
))
{
// constant simplification rule
TVM_TRY_REWRITE
((
x
+
c1
)
*
c2
,
x
*
c2
+
c1
*
c2
);
TVM_TRY_REWRITE
((
x
*
c1
)
*
c2
,
x
*
(
c1
*
c2
));
TVM_TRY_REWRITE
(
min
(
x
,
y
)
*
max
(
x
,
y
),
x
*
y
);
TVM_TRY_REWRITE
(
max
(
x
,
y
)
*
min
(
x
,
y
),
x
*
y
);
// canonicalization
TVM_TRY_RECURSIVE_REWRITE
(
x
*
(
c1
*
y
),
(
x
*
y
)
*
c1
);
TVM_TRY_RECURSIVE_REWRITE_IF
(
(
x
-
y
)
*
c1
,
(
y
-
x
)
*
(
0
-
c1
),
c1
.
Eval
()
->
value
<
0
);
}
return
ret
;
}
Expr
RewriteSimplifier
::
Impl
::
Mutate_
(
const
Div
*
op
,
const
Expr
&
self
)
{
Expr
ret
=
IRMutator
::
Mutate_
(
op
,
self
);
op
=
ret
.
as
<
Div
>
();
Expr
const_res
=
TryConstFold
<
Div
>
(
op
->
a
,
op
->
b
);
if
(
const_res
.
defined
())
return
const_res
;
// Pattern var to match any expression
PVar
<
Expr
>
x
,
y
,
z
,
b1
;
// Pattern var match IntImm
PVar
<
Integer
>
c1
,
c2
,
c3
;
// Pattern var for lanes in broadcast and ramp
PVar
<
int
>
lanes
;
// Vector rules
if
(
op
->
type
.
lanes
()
!=
1
)
{
TVM_TRY_REWRITE
(
broadcast
(
x
,
lanes
)
/
broadcast
(
y
,
lanes
),
broadcast
(
x
/
y
,
lanes
));
// ramp / bcast
if
((
ramp
(
b1
,
c1
,
lanes
)
/
broadcast
(
c2
,
lanes
)).
Match
(
ret
))
{
int64_t
c1val
=
c1
.
Eval
()
->
value
;
int64_t
c2val
=
c2
.
Eval
()
->
value
;
if
(
c1val
%
c2val
==
0
)
{
return
ramp
(
b1
/
c2
,
c1
/
c2
,
lanes
).
Eval
();
}
// If all possible indices in ramp are the same.
if
(
CanProveGreaterEqual
(
b1
.
Eval
(),
0
))
{
ModularSet
bmod
=
parent_
->
modular_set
(
b1
.
Eval
());
int64_t
ramp_min
=
bmod
->
base
/
c2val
;
int64_t
ramp_max
=
(
bmod
->
base
+
(
lanes
.
Eval
()
-
1
)
*
c1val
)
/
c2val
;
if
(
bmod
->
coeff
%
c2val
==
0
&&
ramp_min
==
ramp_max
)
{
return
broadcast
(
b1
/
c2
,
lanes
).
Eval
();
}
}
}
}
if
(
IsIndexType
(
op
->
type
))
{
// Be-aware of the division rules:
// We adopt the default C division uses truncation instead of floordiv.
// This means most rules need to check non-negativeness of the operands.
// while it is always true for trunc div
// restrict to common case(positive div)
TVM_TRY_REWRITE_IF
((
x
/
c1
)
/
c2
,
x
/
(
c1
*
c2
),
c1
.
Eval
()
->
value
>
0
&&
c2
.
Eval
()
->
value
>
0
);
TVM_TRY_REWRITE_IF
((
x
/
c1
+
c2
)
/
c3
,
(
x
+
c1
*
c2
)
/
(
c1
*
c3
),
c1
.
Eval
()
->
value
>
0
&&
c2
.
Eval
()
->
value
>=
0
&&
c3
.
Eval
()
->
value
>
0
&&
CanProveGreaterEqual
(
x
.
Eval
(),
0
));
if
(((
x
*
c1
)
/
c2
).
Match
(
ret
))
{
int64_t
c1val
=
c1
.
Eval
()
->
value
;
int64_t
c2val
=
c2
.
Eval
()
->
value
;
if
(
c1val
>
0
&&
c2val
>
0
)
{
if
(
c1val
%
c2val
==
0
)
return
(
x
*
(
c1
/
c2
)).
Eval
();
if
(
c2val
%
c1val
==
0
)
return
(
x
/
(
c2
/
c1
)).
Eval
();
}
}
// Rules involving 2-operands.
TVM_TRY_REWRITE_IF
((
x
*
c1
+
y
)
/
c2
,
x
*
(
c1
/
c2
)
+
y
/
c2
,
c1
.
Eval
()
->
value
>=
0
&&
c2
.
Eval
()
->
value
>
0
&&
c1
.
Eval
()
->
value
%
c2
.
Eval
()
->
value
==
0
&&
CanProveGreaterEqual
(
x
.
Eval
(),
0
)
&&
CanProveGreaterEqual
(
y
.
Eval
(),
0
));
TVM_TRY_REWRITE_IF
(
min
(
x
*
c1
,
y
)
/
c2
,
min
(
x
*
(
c1
/
c2
),
y
/
c2
),
c1
.
Eval
()
->
value
>=
0
&&
c2
.
Eval
()
->
value
>
0
&&
c1
.
Eval
()
->
value
%
c2
.
Eval
()
->
value
==
0
&&
CanProveGreaterEqual
(
x
.
Eval
(),
0
)
&&
CanProveGreaterEqual
(
y
.
Eval
(),
0
));
TVM_TRY_REWRITE_IF
(
max
(
x
*
c1
,
y
)
/
c2
,
max
(
x
*
(
c1
/
c2
),
y
/
c2
),
c1
.
Eval
()
->
value
>=
0
&&
c2
.
Eval
()
->
value
>
0
&&
c1
.
Eval
()
->
value
%
c2
.
Eval
()
->
value
==
0
&&
CanProveGreaterEqual
(
x
.
Eval
(),
0
)
&&
CanProveGreaterEqual
(
y
.
Eval
(),
0
));
TVM_TRY_REWRITE_IF
((
y
+
x
*
c1
)
/
c2
,
y
/
c2
+
x
*
(
c1
/
c2
),
c1
.
Eval
()
->
value
>=
0
&&
c2
.
Eval
()
->
value
>
0
&&
c1
.
Eval
()
->
value
%
c2
.
Eval
()
->
value
==
0
&&
CanProveGreaterEqual
(
x
.
Eval
(),
0
)
&&
CanProveGreaterEqual
(
y
.
Eval
(),
0
));
TVM_TRY_REWRITE_IF
(
min
(
y
,
x
*
c1
)
/
c2
,
min
(
y
/
c2
,
x
*
(
c1
/
c2
)),
c1
.
Eval
()
->
value
>=
0
&&
c2
.
Eval
()
->
value
>
0
&&
c1
.
Eval
()
->
value
%
c2
.
Eval
()
->
value
==
0
&&
CanProveGreaterEqual
(
x
.
Eval
(),
0
)
&&
CanProveGreaterEqual
(
y
.
Eval
(),
0
));
TVM_TRY_REWRITE_IF
(
max
(
y
,
x
*
c1
)
/
c2
,
max
(
y
/
c2
,
x
*
(
c1
/
c2
)),
c1
.
Eval
()
->
value
>=
0
&&
c2
.
Eval
()
->
value
>
0
&&
c1
.
Eval
()
->
value
%
c2
.
Eval
()
->
value
==
0
&&
CanProveGreaterEqual
(
x
.
Eval
(),
0
)
&&
CanProveGreaterEqual
(
y
.
Eval
(),
0
));
// Rules involving 3-operands.
TVM_TRY_REWRITE_IF
((
x
*
c1
+
y
+
z
)
/
c2
,
x
*
(
c1
/
c2
)
+
(
y
+
z
)
/
c2
,
c1
.
Eval
()
->
value
>=
0
&&
c2
.
Eval
()
->
value
>
0
&&
c1
.
Eval
()
->
value
%
c2
.
Eval
()
->
value
==
0
&&
CanProveGreaterEqual
(
x
.
Eval
(),
0
)
&&
CanProveGreaterEqual
((
y
+
z
).
Eval
(),
0
));
TVM_TRY_REWRITE_IF
((
x
*
c1
-
y
+
z
)
/
c2
,
x
*
(
c1
/
c2
)
+
(
z
-
y
)
/
c2
,
c1
.
Eval
()
->
value
>=
0
&&
c2
.
Eval
()
->
value
>
0
&&
c1
.
Eval
()
->
value
%
c2
.
Eval
()
->
value
==
0
&&
CanProveGreaterEqual
(
x
.
Eval
(),
0
)
&&
CanProveGreaterEqual
((
z
-
y
).
Eval
(),
0
));
TVM_TRY_REWRITE_IF
((
x
*
c1
+
y
-
z
)
/
c2
,
x
*
(
c1
/
c2
)
+
(
y
-
z
)
/
c2
,
c1
.
Eval
()
->
value
>=
0
&&
c2
.
Eval
()
->
value
>
0
&&
c1
.
Eval
()
->
value
%
c2
.
Eval
()
->
value
==
0
&&
CanProveGreaterEqual
(
x
.
Eval
(),
0
)
&&
CanProveGreaterEqual
((
y
-
z
).
Eval
(),
0
));
TVM_TRY_REWRITE_IF
((
y
+
x
*
c1
+
z
)
/
c2
,
x
*
(
c1
/
c2
)
+
(
y
+
z
)
/
c2
,
c1
.
Eval
()
->
value
>
0
&&
c2
.
Eval
()
->
value
>
0
&&
c1
.
Eval
()
->
value
%
c2
.
Eval
()
->
value
==
0
&&
CanProveGreaterEqual
(
x
.
Eval
(),
0
)
&&
CanProveGreaterEqual
((
y
+
z
).
Eval
(),
0
));
TVM_TRY_REWRITE_IF
((
x
+
c1
)
/
c2
,
x
/
c2
+
c1
/
c2
,
c1
.
Eval
()
->
value
>
0
&&
c2
.
Eval
()
->
value
>
0
&&
c1
.
Eval
()
->
value
%
c2
.
Eval
()
->
value
==
0
&&
CanProveGreaterEqual
(
x
.
Eval
(),
0
));
TVM_TRY_REWRITE_IF
((
x
+
y
)
/
x
,
y
/
x
+
1
,
CanProveGreaterEqual
(
x
.
Eval
(),
0
)
&&
CanProveGreaterEqual
(
y
.
Eval
(),
0
));
TVM_TRY_REWRITE_IF
((
y
+
x
)
/
x
,
y
/
x
+
1
,
CanProveGreaterEqual
(
x
.
Eval
(),
0
)
&&
CanProveGreaterEqual
(
y
.
Eval
(),
0
));
TVM_TRY_REWRITE_IF
(((
x
+
y
)
+
z
)
/
x
,
(
y
+
z
)
/
x
+
1
,
CanProveGreaterEqual
(
x
.
Eval
(),
0
)
&&
CanProveGreaterEqual
((
y
+
z
).
Eval
(),
0
));
TVM_TRY_REWRITE_IF
(((
y
+
x
)
+
z
)
/
x
,
(
y
+
z
)
/
x
+
1
,
CanProveGreaterEqual
(
x
.
Eval
(),
0
)
&&
CanProveGreaterEqual
((
y
+
z
).
Eval
(),
0
));
TVM_TRY_REWRITE_IF
((
y
+
(
z
+
x
))
/
x
,
(
y
+
z
)
/
x
+
1
,
CanProveGreaterEqual
(
x
.
Eval
(),
0
)
&&
CanProveGreaterEqual
((
y
+
z
).
Eval
(),
0
));
TVM_TRY_REWRITE_IF
((
y
+
(
x
+
z
))
/
x
,
(
y
+
z
)
/
x
+
1
,
CanProveGreaterEqual
(
x
.
Eval
(),
0
)
&&
CanProveGreaterEqual
((
y
+
z
).
Eval
(),
0
));
TVM_TRY_REWRITE_IF
((
x
*
y
)
/
y
,
x
,
CanProveGreaterEqual
(
x
.
Eval
(),
0
)
&&
CanProveGreaterEqual
(
y
.
Eval
(),
0
));
TVM_TRY_REWRITE_IF
((
y
*
x
)
/
y
,
x
,
CanProveGreaterEqual
(
x
.
Eval
(),
0
)
&&
CanProveGreaterEqual
(
y
.
Eval
(),
0
));
TVM_TRY_REWRITE_IF
((
x
*
z
+
y
)
/
z
,
x
+
y
/
z
,
CanProveGreaterEqual
(
x
.
Eval
(),
0
)
&&
CanProveGreaterEqual
(
y
.
Eval
(),
0
)
&&
CanProveGreaterEqual
(
z
.
Eval
(),
0
));
TVM_TRY_REWRITE_IF
((
z
*
x
+
y
)
/
z
,
x
+
y
/
z
,
CanProveGreaterEqual
(
x
.
Eval
(),
0
)
&&
CanProveGreaterEqual
(
y
.
Eval
(),
0
)
&&
CanProveGreaterEqual
(
z
.
Eval
(),
0
));
TVM_TRY_REWRITE_IF
((
y
+
x
*
z
)
/
z
,
y
/
z
+
x
,
CanProveGreaterEqual
(
x
.
Eval
(),
0
)
&&
CanProveGreaterEqual
(
y
.
Eval
(),
0
)
&&
CanProveGreaterEqual
(
z
.
Eval
(),
0
));
TVM_TRY_REWRITE_IF
((
y
+
z
*
x
)
/
z
,
y
/
z
+
x
,
CanProveGreaterEqual
(
x
.
Eval
(),
0
)
&&
CanProveGreaterEqual
(
y
.
Eval
(),
0
)
&&
CanProveGreaterEqual
(
z
.
Eval
(),
0
));
}
return
ret
;
}
Expr
RewriteSimplifier
::
Impl
::
Mutate_
(
const
Mod
*
op
,
const
Expr
&
self
)
{
Expr
ret
=
IRMutator
::
Mutate_
(
op
,
self
);
op
=
ret
.
as
<
Mod
>
();
Expr
const_res
=
TryConstFold
<
Mod
>
(
op
->
a
,
op
->
b
);
if
(
const_res
.
defined
())
return
const_res
;
// Pattern var to match any expression
PVar
<
Expr
>
x
,
y
,
z
,
b1
;
// Pattern var match IntImm
PVar
<
Integer
>
c1
,
c2
,
c3
;
// Pattern var for lanes in broadcast and ramp
PVar
<
int
>
lanes
;
// Vector rules
if
(
op
->
type
.
lanes
()
!=
1
)
{
TVM_TRY_REWRITE
(
broadcast
(
x
,
lanes
)
%
broadcast
(
y
,
lanes
),
broadcast
(
x
%
y
,
lanes
));
// ramp % bcast
if
((
ramp
(
b1
,
c1
,
lanes
)
%
broadcast
(
c2
,
lanes
)).
Match
(
ret
))
{
int64_t
c1val
=
c1
.
Eval
()
->
value
;
int64_t
c2val
=
c2
.
Eval
()
->
value
;
if
(
c1val
%
c2val
==
0
)
{
return
broadcast
(
b1
%
c2
,
lanes
).
Eval
();
}
// If all possible indices in ramp are the same.
if
(
CanProveGreaterEqual
(
b1
.
Eval
(),
0
))
{
ModularSet
bmod
=
parent_
->
modular_set
(
b1
.
Eval
());
int64_t
ramp_min
=
bmod
->
base
/
c2val
;
int64_t
ramp_max
=
(
bmod
->
base
+
(
lanes
.
Eval
()
-
1
)
*
c1val
)
/
c2val
;
if
(
bmod
->
coeff
%
c2val
==
0
)
{
if
(
ramp_min
==
ramp_max
)
{
return
ramp
(
bmod
->
base
%
c2
,
c1
,
lanes
).
Eval
();
}
else
{
return
(
ramp
(
bmod
->
base
%
c2
,
c1
,
lanes
)
%
broadcast
(
c2
,
lanes
)).
Eval
();
}
}
}
}
}
if
(
IsIndexType
(
op
->
type
))
{
// Be-aware of the division rules:
// We adopt the default C division uses truncation instead of floordiv.
// This means most rules need to check non-negativeness of the operands.
TVM_TRY_REWRITE_IF
((
x
*
c1
)
%
c2
,
ZeroWithTypeLike
(
x
),
c2
.
Eval
()
->
value
!=
0
&&
c1
.
Eval
()
->
value
%
c2
.
Eval
()
->
value
==
0
);
TVM_TRY_REWRITE_IF
((
x
*
c1
+
y
)
%
c2
,
y
%
c2
,
c2
.
Eval
()
->
value
>
0
&&
c1
.
Eval
()
->
value
%
c2
.
Eval
()
->
value
==
0
&&
CanProveGreaterEqual
(
y
.
Eval
(),
0
));
TVM_TRY_REWRITE_IF
((
x
+
c1
)
%
c2
,
x
%
c2
,
c2
.
Eval
()
->
value
>
0
&&
c1
.
Eval
()
->
value
%
c2
.
Eval
()
->
value
==
0
&&
CanProveGreaterEqual
(
x
.
Eval
(),
0
));
TVM_TRY_REWRITE_IF
((
x
+
y
*
c1
)
%
c2
,
x
%
c2
,
c2
.
Eval
()
->
value
>
0
&&
c1
.
Eval
()
->
value
%
c2
.
Eval
()
->
value
==
0
&&
CanProveGreaterEqual
(
x
.
Eval
(),
0
)
&&
CanProveGreaterEqual
(
y
.
Eval
(),
0
));
// try modular analysis
if
((
x
%
c1
).
Match
(
ret
))
{
ModularSet
mod
=
parent_
->
modular_set
(
x
.
Eval
());
int64_t
c1val
=
c1
.
Eval
()
->
value
;
if
(
mod
->
coeff
%
c1val
==
0
&&
CanProveGreaterEqual
(
x
.
Eval
(),
0
))
{
return
(
mod
->
base
%
c1
).
Eval
();
}
}
}
return
ret
;
}
Expr
RewriteSimplifier
::
operator
()(
const
Expr
&
expr
)
{
return
impl_
->
PostOrderSimplify
(
expr
);
}
void
RewriteSimplifier
::
Update
(
const
Var
&
var
,
const
Expr
&
info
,
bool
override
)
{
impl_
->
Update
(
var
,
info
,
override
);
}
RewriteSimplifier
::
RewriteSimplifier
(
Analyzer
*
parent
)
:
impl_
(
new
Impl
(
parent
))
{
}
RewriteSimplifier
::~
RewriteSimplifier
()
{
delete
impl_
;
}
}
// namespace arith
}
// namespace tvm
tests/cpp/pattern_match_test.cc
View file @
ec95675c
...
...
@@ -117,6 +117,7 @@ TEST(Pattern, Integer) {
// special case container of Expr
CHECK
((
v
*
c
).
Match
(
tx
*
3
));
CHECK_EQ
(
c
.
Eval
()
->
value
,
3
);
CHECK
((
v
*
3
).
Match
(
tx
*
3
));
}
// cannot match c to ty
CHECK
(
!
(
v
*
c
).
Match
(
tx
*
ty
));
...
...
tests/python/unittest/test_arith_rewrite_simplify.py
0 → 100644
View file @
ec95675c
import
tvm
class
RewriteChecker
:
def
__init__
(
self
):
self
.
analyzer
=
tvm
.
arith
.
Analyzer
()
def
verify
(
self
,
data
,
expected
):
res
=
self
.
analyzer
.
rewrite_simplify
(
data
)
assert
tvm
.
ir_pass
.
Equal
(
res
,
expected
),
"data={}, res={}, expected={}"
.
format
(
data
,
res
,
expected
)
def
test_vector_simplify
():
ck
=
RewriteChecker
()
x
,
y
,
z
=
tvm
.
var
(
"x"
),
tvm
.
var
(
"y"
),
tvm
.
var
(
"z"
)
# Add rules
ck
.
verify
(
tvm
.
expr
.
Ramp
(
x
,
1
,
4
)
+
tvm
.
expr
.
Ramp
(
y
,
2
,
4
),
tvm
.
expr
.
Ramp
(
x
+
y
,
3
,
4
))
ck
.
verify
(
tvm
.
expr
.
Ramp
(
x
,
1
,
2
)
+
y
,
tvm
.
expr
.
Ramp
(
x
+
y
,
1
,
2
))
ck
.
verify
(
y
+
tvm
.
expr
.
Ramp
(
x
,
1
,
2
)
,
tvm
.
expr
.
Ramp
(
y
+
x
,
1
,
2
))
ck
.
verify
(
y
.
astype
(
"int32x2"
)
+
x
.
astype
(
"int32x2"
),
(
y
+
x
)
.
astype
(
"int32x2"
))
# Sub rules
ck
.
verify
(
tvm
.
expr
.
Ramp
(
x
,
4
,
4
)
-
tvm
.
expr
.
Ramp
(
y
,
2
,
4
),
tvm
.
expr
.
Ramp
(
x
-
y
,
2
,
4
))
ck
.
verify
(
tvm
.
expr
.
Ramp
(
x
,
1
,
2
)
-
y
,
tvm
.
expr
.
Ramp
(
x
-
y
,
1
,
2
))
ck
.
verify
(
y
-
tvm
.
expr
.
Ramp
(
x
,
1
,
2
)
,
tvm
.
expr
.
Ramp
(
y
-
x
,
-
1
,
2
))
ck
.
verify
(
y
.
astype
(
"int32x2"
)
-
x
.
astype
(
"int32x2"
),
(
y
-
x
)
.
astype
(
"int32x2"
))
# Mul rules
ck
.
verify
(
y
.
astype
(
"int32x2"
)
*
x
.
astype
(
"int32x2"
),
(
y
*
x
)
.
astype
(
"int32x2"
))
ck
.
verify
(
tvm
.
expr
.
Ramp
(
x
,
4
,
4
)
*
2
,
tvm
.
expr
.
Ramp
(
x
*
2
,
8
,
4
))
ck
.
verify
(
2
*
tvm
.
expr
.
Ramp
(
x
,
4
,
4
),
tvm
.
expr
.
Ramp
(
x
*
2
,
8
,
4
))
## Div rules
ck
.
verify
(
y
.
astype
(
"int32x2"
)
/
x
.
astype
(
"int32x2"
),
(
y
/
x
)
.
astype
(
"int32x2"
))
ck
.
verify
(
tvm
.
expr
.
Ramp
(
x
,
4
,
4
)
/
2
,
tvm
.
expr
.
Ramp
(
x
/
2
,
2
,
4
))
ck
.
analyzer
.
update
(
x
,
tvm
.
arith
.
ConstIntBound
(
0
,
1000
),
override
=
True
)
ck
.
verify
(
tvm
.
expr
.
Ramp
(
x
*
8
+
1
,
1
,
4
)
/
8
,
(
x
)
.
astype
(
"int32x4"
))
ck
.
verify
(
tvm
.
expr
.
Ramp
(
x
*
8
+
15
,
1
,
4
)
/
8
,
tvm
.
expr
.
Ramp
(
x
*
8
+
15
,
1
,
4
)
/
8
)
## Mod rules
ck
.
verify
(
y
.
astype
(
"int32x2"
)
%
x
.
astype
(
"int32x2"
),
(
y
%
x
)
.
astype
(
"int32x2"
))
ck
.
verify
(
tvm
.
expr
.
Ramp
(
x
,
4
,
4
)
%
2
,
tvm
.
expr
.
Broadcast
(
x
%
2
,
4
))
ck
.
analyzer
.
update
(
x
,
tvm
.
arith
.
ConstIntBound
(
0
,
1000
),
override
=
True
)
ck
.
verify
(
tvm
.
expr
.
Ramp
(
x
*
8
+
1
,
1
,
4
)
%
8
,
tvm
.
expr
.
Ramp
(
1
,
1
,
4
))
ck
.
verify
(
tvm
.
expr
.
Ramp
(
x
*
8
+
1
,
15
,
4
)
%
8
,
tvm
.
expr
.
Ramp
(
1
,
15
,
4
)
%
8
)
def
test_select_simplify
():
ck
=
RewriteChecker
()
x
,
y
,
z
=
tvm
.
var
(
"x"
),
tvm
.
var
(
"y"
),
tvm
.
var
(
"z"
)
# Add rules
ck
.
verify
(
tvm
.
expr
.
Select
(
x
>
0
,
y
,
0
)
+
tvm
.
expr
.
Select
(
x
>
0
,
1
,
z
),
tvm
.
expr
.
Select
(
x
>
0
,
y
+
1
,
z
))
ck
.
verify
(
tvm
.
expr
.
Select
(
x
>
0
,
y
,
1
)
-
tvm
.
expr
.
Select
(
x
>
0
,
1
,
z
),
tvm
.
expr
.
Select
(
x
>
0
,
y
+
(
-
1
),
1
-
z
))
ck
.
verify
(
tvm
.
expr
.
Select
(
x
>
0
,
y
,
z
)
-
y
,
tvm
.
expr
.
Select
(
x
>
0
,
0
,
z
-
y
))
ck
.
verify
(
tvm
.
expr
.
Select
(
x
>
0
,
y
,
z
)
-
z
,
tvm
.
expr
.
Select
(
x
>
0
,
y
-
z
,
0
))
def
test_add_index_simplify
():
ck
=
RewriteChecker
()
x
,
y
,
z
=
tvm
.
var
(
"x"
),
tvm
.
var
(
"y"
),
tvm
.
var
(
"z"
)
ck
.
verify
(
x
+
(
y
-
x
),
y
)
ck
.
verify
(
x
-
(
y
+
1
)
+
(
y
+
1
),
x
)
ck
.
verify
((
x
-
10
)
+
(
10
-
z
),
x
-
z
)
ck
.
verify
((
x
-
y
)
+
(
z
-
x
),
z
-
y
)
ck
.
verify
(
tvm
.
min
(
x
,
y
-
z
)
+
z
,
tvm
.
min
(
x
+
z
,
y
))
ck
.
verify
(
tvm
.
min
(
x
-
z
,
y
)
+
z
,
tvm
.
min
(
x
,
y
+
z
))
ck
.
verify
(
tvm
.
max
(
x
,
y
-
10
)
+
10
,
tvm
.
max
(
x
+
10
,
y
))
ck
.
verify
(
tvm
.
max
(
x
-
11
,
y
)
+
11
,
tvm
.
max
(
x
,
y
+
11
))
ck
.
verify
(
tvm
.
max
(
x
,
y
*
2
)
+
tvm
.
min
(
x
,
y
*
2
),
x
+
y
*
2
);
ck
.
verify
(
tvm
.
min
(
x
,
y
*
2
)
+
tvm
.
max
(
x
,
y
*
2
),
x
+
y
*
2
);
ck
.
verify
(
tvm
.
max
(
x
,
y
+
2
)
+
(
-
2
),
tvm
.
max
(
x
+
(
-
2
),
y
));
ck
.
verify
(
tvm
.
min
(
x
,
y
+
2
)
+
(
-
2
),
tvm
.
min
(
x
+
(
-
2
),
y
));
ck
.
verify
(
tvm
.
min
(
x
+
2
,
y
+
3
)
+
(
-
2
),
tvm
.
min
(
x
,
y
+
1
));
ck
.
verify
(
x
*
y
+
x
*
10
,
x
*
(
y
+
10
))
ck
.
verify
(
y
*
x
+
x
*
10
,
x
*
(
y
+
10
))
ck
.
verify
(
y
*
x
+
10
*
x
,
x
*
(
y
+
10
))
ck
.
verify
(
x
*
y
+
10
*
x
,
x
*
(
y
+
10
))
ck
.
verify
(
y
*
(
x
%
8
)
+
10
*
(
x
%
8
),
(
x
%
8
)
*
(
y
+
10
))
ck
.
analyzer
.
update
(
x
,
tvm
.
arith
.
ConstIntBound
(
0
,
1000
),
override
=
True
)
ck
.
verify
((
x
/
8
)
*
8
+
x
%
8
,
x
)
# canonicalization
ck
.
verify
(
x
+
2
+
3
+
4
+
x
,
x
*
2
+
9
);
ck
.
verify
(
x
+
2
+
3
+
4
+
x
*
3
,
x
*
4
+
9
);
# conservative bound
try
:
ck
.
analyzer
.
update
(
x
,
tvm
.
arith
.
ConstIntBound
(
-
1
,
1000
),
override
=
True
)
ck
.
verify
((
x
/
8
)
*
8
+
x
%
8
,
x
)
raise
RuntimeError
(
"bad"
)
except
AssertionError
:
pass
def
test_sub_index_simplify
():
ck
=
RewriteChecker
()
x
,
y
,
z
=
tvm
.
var
(
"x"
),
tvm
.
var
(
"y"
),
tvm
.
var
(
"z"
)
ck
.
verify
(
x
+
y
-
y
,
x
)
ck
.
verify
(
x
+
y
-
x
,
y
)
ck
.
verify
(
x
-
(
y
+
x
),
0
-
y
)
ck
.
verify
(
x
-
(
x
+
y
),
0
-
y
)
ck
.
verify
(
tvm
.
min
(
x
,
y
)
-
x
,
tvm
.
min
(
0
,
y
-
x
))
ck
.
verify
(
tvm
.
min
(
x
,
y
)
-
y
,
tvm
.
min
(
x
-
y
,
0
))
ck
.
verify
(
tvm
.
max
(
x
,
y
)
-
x
,
tvm
.
max
(
0
,
y
-
x
))
ck
.
verify
(
tvm
.
max
(
x
,
y
)
-
y
,
tvm
.
max
(
x
-
y
,
0
))
ck
.
verify
(
x
-
tvm
.
min
(
x
,
y
),
tvm
.
max
(
0
,
x
-
y
))
ck
.
verify
(
y
-
tvm
.
min
(
x
,
y
),
tvm
.
max
(
y
-
x
,
0
))
ck
.
verify
(
x
-
tvm
.
max
(
x
,
y
),
tvm
.
min
(
0
,
x
-
y
))
ck
.
verify
(
y
-
tvm
.
max
(
x
,
y
),
tvm
.
min
(
y
-
x
,
0
))
# mul co-efficient foldng
ck
.
verify
(
x
-
x
,
0
)
ck
.
verify
(
x
*
y
-
x
,
x
*
(
y
+
(
-
1
)))
ck
.
verify
(
x
*
y
-
10
*
x
,
x
*
(
y
+
(
-
10
)))
ck
.
verify
(
y
*
x
-
x
*
z
,
x
*
(
y
-
z
))
ck
.
verify
(
y
*
x
-
z
*
x
,
x
*
(
y
-
z
))
ck
.
verify
(
x
+
10
-
20
,
x
+
(
-
10
))
# 4-operands pattern
ck
.
verify
((
x
+
y
)
-
(
x
+
z
),
y
-
z
)
ck
.
verify
((
y
+
x
)
-
(
x
+
z
),
y
-
z
)
ck
.
verify
((
x
+
y
)
-
(
z
+
x
),
y
-
z
)
ck
.
verify
((
y
+
x
)
-
(
z
+
x
),
y
-
z
)
ck
.
verify
(
tvm
.
min
(
x
+
y
,
z
)
-
x
,
tvm
.
min
(
y
,
z
-
x
))
ck
.
verify
(
tvm
.
min
(
y
+
x
,
z
)
-
x
,
tvm
.
min
(
y
,
z
-
x
))
ck
.
verify
(
tvm
.
min
(
z
,
x
+
y
)
-
x
,
tvm
.
min
(
z
-
x
,
y
))
ck
.
verify
(
tvm
.
min
(
z
,
y
+
x
)
-
x
,
tvm
.
min
(
z
-
x
,
y
))
ck
.
verify
(
x
-
tvm
.
min
(
x
+
y
,
z
),
tvm
.
max
(
0
-
y
,
x
-
z
))
ck
.
verify
(
x
-
tvm
.
min
(
y
+
x
,
z
),
tvm
.
max
(
0
-
y
,
x
-
z
))
ck
.
verify
(
x
-
tvm
.
min
(
z
,
x
+
y
),
tvm
.
max
(
x
-
z
,
0
-
y
))
ck
.
verify
(
x
-
tvm
.
min
(
z
,
y
+
x
),
tvm
.
max
(
x
-
z
,
0
-
y
))
ck
.
verify
(
tvm
.
min
(
x
,
y
)
-
tvm
.
min
(
y
,
x
),
0
)
ck
.
verify
(
tvm
.
max
(
x
,
y
)
-
tvm
.
max
(
y
,
x
),
0
)
ck
.
verify
(
tvm
.
min
(
x
,
y
)
-
tvm
.
min
(
x
+
10
,
y
+
10
),
-
10
)
ck
.
verify
(
tvm
.
min
(
x
+
10
,
y
+
1
)
-
tvm
.
min
(
x
,
y
-
9
),
10
)
# div pattern
ck
.
analyzer
.
update
(
x
,
tvm
.
arith
.
ConstIntBound
(
0
,
1000
),
override
=
True
)
ck
.
verify
(
x
-
(
x
/
3
)
*
3
,
x
%
3
)
ck
.
verify
((
x
+
5
)
/
3
-
x
/
3
,
(((
x
+
2
)
%
3
)
+
5
)
/
3
)
def
test_mul_index_simplify
():
ck
=
RewriteChecker
()
x
,
y
,
z
=
tvm
.
var
(
"x"
),
tvm
.
var
(
"y"
),
tvm
.
var
(
"z"
)
ck
.
verify
((
x
+
2
)
*
3
,
x
*
3
+
6
)
ck
.
verify
((
x
*
2
)
*
3
,
x
*
6
)
ck
.
verify
(
tvm
.
min
(
x
,
y
)
*
tvm
.
max
(
x
,
y
),
x
*
y
)
ck
.
verify
(
tvm
.
max
(
x
,
y
)
*
tvm
.
min
(
x
,
y
),
x
*
y
)
ck
.
verify
((
x
-
y
)
*
(
-
2
),
(
y
-
x
)
*
2
)
def
test_div_index_simplify
():
ck
=
RewriteChecker
()
x
,
y
,
z
=
tvm
.
var
(
"x"
),
tvm
.
var
(
"y"
),
tvm
.
var
(
"z"
)
ck
.
analyzer
.
update
(
x
,
tvm
.
arith
.
ConstIntBound
(
0
,
1000
),
override
=
True
)
ck
.
analyzer
.
update
(
y
,
tvm
.
arith
.
ConstIntBound
(
0
,
1000
),
override
=
True
)
ck
.
analyzer
.
update
(
z
,
tvm
.
arith
.
ConstIntBound
(
0
,
1000
),
override
=
True
)
ck
.
verify
(
x
/
2
/
3
,
x
/
6
)
ck
.
verify
((
x
/
2
+
1
)
/
3
,
(
x
+
2
)
/
6
)
ck
.
verify
(
x
*
2
/
4
,
x
/
2
)
ck
.
verify
(
x
*
4
/
2
,
x
*
2
)
ck
.
verify
((
x
*
4
+
y
)
/
2
,
x
*
2
+
y
/
2
)
ck
.
verify
(
tvm
.
min
(
x
*
6
,
y
)
/
2
,
tvm
.
min
(
x
*
3
,
y
/
2
))
ck
.
verify
(
tvm
.
max
(
x
*
6
,
y
)
/
2
,
tvm
.
max
(
x
*
3
,
y
/
2
))
ck
.
verify
((
y
+
x
*
4
)
/
2
,
y
/
2
+
x
*
2
)
ck
.
verify
(
tvm
.
min
(
y
,
x
*
6
)
/
2
,
tvm
.
min
(
y
/
2
,
x
*
3
))
ck
.
verify
(
tvm
.
max
(
y
,
x
*
6
)
/
2
,
tvm
.
max
(
y
/
2
,
x
*
3
))
# 3-operands
ck
.
verify
((
x
*
6
+
y
+
z
)
/
2
,
x
*
3
+
(
y
+
z
)
/
2
)
ck
.
verify
((
x
*
6
-
y
+
(
y
+
3
))
/
2
,
x
*
3
+
1
)
ck
.
verify
((
x
*
6
+
(
y
+
3
)
-
y
)
/
2
,
x
*
3
+
1
)
ck
.
verify
((
y
+
x
*
6
+
z
)
/
2
,
x
*
3
+
(
y
+
z
)
/
2
)
ck
.
verify
((
x
+
4
)
/
2
,
x
/
2
+
2
)
ck
.
verify
((
x
+
y
)
/
x
,
y
/
x
+
1
)
ck
.
verify
((
y
+
x
)
/
x
,
y
/
x
+
1
)
ck
.
verify
(((
x
+
y
)
+
z
)
/
x
,
(
y
+
z
)
/
x
+
1
)
ck
.
verify
(((
y
+
x
)
+
z
)
/
x
,
(
y
+
z
)
/
x
+
1
)
ck
.
verify
((
y
+
(
x
+
z
))
/
x
,
(
y
+
z
)
/
x
+
1
)
ck
.
verify
((
y
+
(
z
+
x
))
/
x
,
(
y
+
z
)
/
x
+
1
)
ck
.
verify
((
x
*
y
)
/
y
,
x
)
ck
.
verify
((
y
*
x
)
/
y
,
x
)
ck
.
verify
((
x
*
z
+
y
)
/
z
,
x
+
y
/
z
)
ck
.
verify
((
z
*
x
+
y
)
/
z
,
x
+
y
/
z
)
ck
.
verify
((
y
+
x
*
z
)
/
z
,
y
/
z
+
x
)
ck
.
verify
((
y
+
z
*
x
)
/
z
,
y
/
z
+
x
)
def
test_mod_index_simplify
():
ck
=
RewriteChecker
()
x
,
y
,
z
=
tvm
.
var
(
"x"
),
tvm
.
var
(
"y"
),
tvm
.
var
(
"z"
)
ck
.
analyzer
.
update
(
x
,
tvm
.
arith
.
ConstIntBound
(
0
,
1000
),
override
=
True
)
ck
.
analyzer
.
update
(
y
,
tvm
.
arith
.
ConstIntBound
(
0
,
1000
),
override
=
True
)
ck
.
verify
(
x
*
10
%
2
,
0
)
ck
.
verify
((
x
*
10
+
y
)
%
2
,
y
%
2
)
ck
.
verify
((
x
+
10
)
%
2
,
x
%
2
)
ck
.
verify
((
x
+
y
*
10
)
%
2
,
x
%
2
)
ck
.
verify
((
x
*
10
+
1
+
y
*
2
+
2
)
%
2
,
1
)
if
__name__
==
"__main__"
:
test_mod_index_simplify
()
test_vector_simplify
()
test_add_index_simplify
()
test_sub_index_simplify
()
test_mul_index_simplify
()
test_div_index_simplify
()
test_select_simplify
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment