Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
ec95675c
Unverified
Commit
ec95675c
authored
Mar 10, 2019
by
Tianqi Chen
Committed by
GitHub
Mar 10, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[ARITH] Analyzer RewriteSimplifier: add/sub/mul/div/mod (#2722)
parent
0bf64ee0
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
366 additions
and
8 deletions
+366
-8
include/tvm/arithmetic.h
+35
-0
python/tvm/arith.py
+16
-0
src/api/api_arith.cc
+4
-0
src/arithmetic/analyzer.cc
+9
-2
src/arithmetic/const_fold.h
+3
-1
src/arithmetic/pattern_match.h
+46
-5
src/arithmetic/rewrite_simplify.cc
+0
-0
tests/cpp/pattern_match_test.cc
+1
-0
tests/python/unittest/test_arith_rewrite_simplify.py
+252
-0
No files found.
include/tvm/arithmetic.h
View file @
ec95675c
...
@@ -193,6 +193,39 @@ class ModularSetAnalyzer {
...
@@ -193,6 +193,39 @@ class ModularSetAnalyzer {
};
};
/*!
/*!
* \brief Rewrite-rule based simplifier.
*/
class
RewriteSimplifier
{
public
:
/*!
* \brief analyze the expr
* \param expr The expression of interest.
* \return the result of the analysis.
*/
Expr
operator
()(
const
Expr
&
expr
);
/*!
* \brief Update binding of var to a new expression.
*
* \param var The variable of interest.
* \param new_expr
* \param override Whether do we allow override of existing information.
*/
void
Update
(
const
Var
&
var
,
const
Expr
&
new_expr
,
bool
override
=
false
);
private
:
friend
class
Analyzer
;
friend
class
ConstraintContext
;
explicit
RewriteSimplifier
(
Analyzer
*
parent
);
~
RewriteSimplifier
();
class
Impl
;
/*! \brief Internal impl */
Impl
*
impl_
;
};
/*!
* \brief A RAII constraint context.
* \brief A RAII constraint context.
*
*
* \code
* \code
...
@@ -242,6 +275,8 @@ class Analyzer {
...
@@ -242,6 +275,8 @@ class Analyzer {
ConstIntBoundAnalyzer
const_int_bound
;
ConstIntBoundAnalyzer
const_int_bound
;
/*! \brief sub-analyzer: modular set */
/*! \brief sub-analyzer: modular set */
ModularSetAnalyzer
modular_set
;
ModularSetAnalyzer
modular_set
;
/*! \brief sub-analyzer rewrite simplfy */
RewriteSimplifier
rewrite_simplify
;
/*! \brief constructor */
/*! \brief constructor */
Analyzer
();
Analyzer
();
/*!
/*!
...
...
python/tvm/arith.py
View file @
ec95675c
...
@@ -96,6 +96,7 @@ class Analyzer:
...
@@ -96,6 +96,7 @@ class Analyzer:
self
.
_const_int_bound_update
=
_mod
(
"const_int_bound_update"
)
self
.
_const_int_bound_update
=
_mod
(
"const_int_bound_update"
)
self
.
_bind
=
_mod
(
"bind"
)
self
.
_bind
=
_mod
(
"bind"
)
self
.
_modular_set
=
_mod
(
"modular_set"
)
self
.
_modular_set
=
_mod
(
"modular_set"
)
self
.
_rewrite_simplify
=
_mod
(
"rewrite_simplify"
)
self
.
_enter_constraint_context
=
_mod
(
"enter_constraint_context"
)
self
.
_enter_constraint_context
=
_mod
(
"enter_constraint_context"
)
def
const_int_bound
(
self
,
expr
):
def
const_int_bound
(
self
,
expr
):
...
@@ -128,6 +129,21 @@ class Analyzer:
...
@@ -128,6 +129,21 @@ class Analyzer:
"""
"""
return
self
.
_modular_set
(
expr
)
return
self
.
_modular_set
(
expr
)
def
rewrite_simplify
(
self
,
expr
):
"""Simplify expression via rewriting rules.
Parameters
----------
expr : tvm.Expr
The expression.
Returns
-------
result : Expr
The result.
"""
return
self
.
_rewrite_simplify
(
expr
)
def
bind
(
self
,
var
,
expr
):
def
bind
(
self
,
var
,
expr
):
"""Bind a variable to the expression.
"""Bind a variable to the expression.
...
...
src/api/api_arith.cc
View file @
ec95675c
...
@@ -98,6 +98,10 @@ TVM_REGISTER_API("arith._CreateAnalyzer")
...
@@ -98,6 +98,10 @@ TVM_REGISTER_API("arith._CreateAnalyzer")
return
PackedFunc
([
self
](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
return
PackedFunc
([
self
](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
self
->
const_int_bound
.
Update
(
args
[
0
],
args
[
1
],
args
[
2
]);
self
->
const_int_bound
.
Update
(
args
[
0
],
args
[
1
],
args
[
2
]);
});
});
}
else
if
(
name
==
"rewrite_simplify"
)
{
return
PackedFunc
([
self
](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
self
->
rewrite_simplify
(
args
[
0
]);
});
}
else
if
(
name
==
"bind"
)
{
}
else
if
(
name
==
"bind"
)
{
return
PackedFunc
([
self
](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
return
PackedFunc
([
self
](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
auto
&
sptr
=
args
[
1
].
node_sptr
();
auto
&
sptr
=
args
[
1
].
node_sptr
();
...
...
src/arithmetic/analyzer.cc
View file @
ec95675c
...
@@ -2,6 +2,7 @@
...
@@ -2,6 +2,7 @@
* Copyright (c) 2019 by Contributors
* Copyright (c) 2019 by Contributors
* \file tvm/arithmetic/analyzer.cc
* \file tvm/arithmetic/analyzer.cc
*/
*/
#include <tvm/ir.h>
#include <tvm/arithmetic.h>
#include <tvm/arithmetic.h>
namespace
tvm
{
namespace
tvm
{
...
@@ -9,19 +10,22 @@ namespace arith {
...
@@ -9,19 +10,22 @@ namespace arith {
Analyzer
::
Analyzer
()
Analyzer
::
Analyzer
()
:
const_int_bound
(
this
),
:
const_int_bound
(
this
),
modular_set
(
this
)
{
modular_set
(
this
),
rewrite_simplify
(
this
)
{
}
}
void
Analyzer
::
Bind
(
const
VarExpr
&
v
,
const
Expr
&
expr
)
{
void
Analyzer
::
Bind
(
const
VarExpr
&
v
,
const
Expr
&
expr
)
{
Var
var
(
v
.
node_
);
Var
var
(
v
.
node_
);
this
->
const_int_bound
.
Update
(
var
,
this
->
const_int_bound
(
expr
));
this
->
const_int_bound
.
Update
(
var
,
this
->
const_int_bound
(
expr
));
this
->
modular_set
.
Update
(
var
,
this
->
modular_set
(
expr
));
this
->
modular_set
.
Update
(
var
,
this
->
modular_set
(
expr
));
this
->
rewrite_simplify
.
Update
(
var
,
this
->
rewrite_simplify
(
expr
));
}
}
void
Analyzer
::
Bind
(
const
VarExpr
&
v
,
const
Range
&
range
)
{
void
Analyzer
::
Bind
(
const
VarExpr
&
v
,
const
Range
&
range
)
{
Var
var
(
v
.
node_
);
Var
var
(
v
.
node_
);
this
->
const_int_bound
.
Bind
(
var
,
range
);
this
->
const_int_bound
.
Bind
(
var
,
range
);
// skip modular_set
// skip modular_set
// skip rewrite simplify
}
}
ConstraintContext
::
ConstraintContext
(
Analyzer
*
analyzer
,
const
Expr
&
constraint
)
{
ConstraintContext
::
ConstraintContext
(
Analyzer
*
analyzer
,
const
Expr
&
constraint
)
{
...
@@ -36,7 +40,10 @@ ConstraintContext::ConstraintContext(Analyzer* analyzer, const Expr& constraint)
...
@@ -36,7 +40,10 @@ ConstraintContext::ConstraintContext(Analyzer* analyzer, const Expr& constraint)
}
}
bool
Analyzer
::
CanProveGreaterEqual
(
const
Expr
&
expr
,
int64_t
lower_bound
)
{
bool
Analyzer
::
CanProveGreaterEqual
(
const
Expr
&
expr
,
int64_t
lower_bound
)
{
auto
bd
=
this
->
const_int_bound
(
expr
);
if
(
const
auto
*
ptr
=
expr
.
as
<
ir
::
IntImm
>
())
{
return
ptr
->
value
>
lower_bound
;
}
auto
bd
=
this
->
const_int_bound
(
this
->
rewrite_simplify
(
expr
));
if
(
bd
->
min_value
>=
lower_bound
)
return
true
;
if
(
bd
->
min_value
>=
lower_bound
)
return
true
;
return
false
;
return
false
;
}
}
...
...
src/arithmetic/const_fold.h
View file @
ec95675c
...
@@ -23,7 +23,9 @@ namespace arith {
...
@@ -23,7 +23,9 @@ namespace arith {
* \return nullptr if constant fold fails, otherwise return folded result.
* \return nullptr if constant fold fails, otherwise return folded result.
*/
*/
template
<
typename
Op
>
template
<
typename
Op
>
inline
Expr
TryConstFold
(
Expr
a
,
Expr
b
);
inline
Expr
TryConstFold
(
Expr
a
,
Expr
b
)
{
return
Expr
();
}
/*!
/*!
* \brief Try to run unary compute with constant folding.
* \brief Try to run unary compute with constant folding.
...
...
src/arithmetic/pattern_match.h
View file @
ec95675c
...
@@ -49,6 +49,7 @@
...
@@ -49,6 +49,7 @@
#include <tvm/ir_pass.h>
#include <tvm/ir_pass.h>
#include <tuple>
#include <tuple>
#include "const_fold.h"
namespace
tvm
{
namespace
tvm
{
namespace
arith
{
namespace
arith
{
...
@@ -242,7 +243,11 @@ class PBinaryExpr :
...
@@ -242,7 +243,11 @@ class PBinaryExpr :
}
}
Expr
Eval
()
const
{
Expr
Eval
()
const
{
return
NodeType
::
make
(
a_
.
Eval
(),
b_
.
Eval
());
Expr
lhs
=
a_
.
Eval
();
Expr
rhs
=
b_
.
Eval
();
Expr
ret
=
TryConstFold
<
NodeType
>
(
lhs
,
rhs
);
if
(
ret
.
defined
())
return
ret
;
return
NodeType
::
make
(
lhs
,
rhs
);
}
}
private
:
private
:
...
@@ -250,12 +255,48 @@ class PBinaryExpr :
...
@@ -250,12 +255,48 @@ class PBinaryExpr :
typename
TB
::
Nested
b_
;
typename
TB
::
Nested
b_
;
};
};
template
<
typename
TA
>
class
PConstWithTypeLike
:
public
Pattern
<
PConstWithTypeLike
<
TA
>
>
{
public
:
PConstWithTypeLike
(
const
TA
&
ref
,
int64_t
value
)
:
ref_
(
ref
),
value_
(
value
)
{}
void
InitMatch_
()
const
{}
bool
Match_
(
const
NodeRef
&
node
)
const
{
if
(
const
ir
::
IntImm
*
ptr
=
node
.
as
<
ir
::
IntImm
>
())
{
return
ptr
->
value
==
value_
;
}
else
{
return
false
;
}
}
Expr
Eval
()
const
{
return
make_const
(
ref_
.
Eval
().
type
(),
value_
);
}
private
:
typename
TA
::
Nested
ref_
;
int64_t
value_
;
};
#define TVM_PATTERN_BINARY_OP(FuncName, NodeName) \
#define TVM_PATTERN_BINARY_OP(FuncName, NodeName)
\
template<typename TA, typename TB> \
template<typename TA, typename TB>
\
inline PBinaryExpr<NodeName, TA, TB> \
inline PBinaryExpr<NodeName, TA, TB>
\
FuncName(const Pattern<TA>& a, const Pattern<TB>& b) { \
FuncName(const Pattern<TA>& a, const Pattern<TB>& b) {
\
return PBinaryExpr<NodeName, TA, TB>(a.derived(), b.derived()); \
return PBinaryExpr<NodeName, TA, TB>(a.derived(), b.derived()); \
} \
template<typename TA> \
inline PBinaryExpr<NodeName, TA, PConstWithTypeLike<TA> > \
FuncName(const Pattern<TA>& a, int64_t b) { \
return FuncName(a, PConstWithTypeLike<TA>(a.derived(), b)); \
} \
template<typename TA> \
inline PBinaryExpr<NodeName, PConstWithTypeLike<TA>, TA> \
FuncName(int64_t b, const Pattern<TA>& a) { \
return FuncName(PConstWithTypeLike<TA>(a.derived(), b), a); \
}
}
// arithmetic expressions
// arithmetic expressions
...
...
src/arithmetic/rewrite_simplify.cc
0 → 100644
View file @
ec95675c
This diff is collapsed.
Click to expand it.
tests/cpp/pattern_match_test.cc
View file @
ec95675c
...
@@ -117,6 +117,7 @@ TEST(Pattern, Integer) {
...
@@ -117,6 +117,7 @@ TEST(Pattern, Integer) {
// special case container of Expr
// special case container of Expr
CHECK
((
v
*
c
).
Match
(
tx
*
3
));
CHECK
((
v
*
c
).
Match
(
tx
*
3
));
CHECK_EQ
(
c
.
Eval
()
->
value
,
3
);
CHECK_EQ
(
c
.
Eval
()
->
value
,
3
);
CHECK
((
v
*
3
).
Match
(
tx
*
3
));
}
}
// cannot match c to ty
// cannot match c to ty
CHECK
(
!
(
v
*
c
).
Match
(
tx
*
ty
));
CHECK
(
!
(
v
*
c
).
Match
(
tx
*
ty
));
...
...
tests/python/unittest/test_arith_rewrite_simplify.py
0 → 100644
View file @
ec95675c
import
tvm
class
RewriteChecker
:
def
__init__
(
self
):
self
.
analyzer
=
tvm
.
arith
.
Analyzer
()
def
verify
(
self
,
data
,
expected
):
res
=
self
.
analyzer
.
rewrite_simplify
(
data
)
assert
tvm
.
ir_pass
.
Equal
(
res
,
expected
),
"data={}, res={}, expected={}"
.
format
(
data
,
res
,
expected
)
def
test_vector_simplify
():
ck
=
RewriteChecker
()
x
,
y
,
z
=
tvm
.
var
(
"x"
),
tvm
.
var
(
"y"
),
tvm
.
var
(
"z"
)
# Add rules
ck
.
verify
(
tvm
.
expr
.
Ramp
(
x
,
1
,
4
)
+
tvm
.
expr
.
Ramp
(
y
,
2
,
4
),
tvm
.
expr
.
Ramp
(
x
+
y
,
3
,
4
))
ck
.
verify
(
tvm
.
expr
.
Ramp
(
x
,
1
,
2
)
+
y
,
tvm
.
expr
.
Ramp
(
x
+
y
,
1
,
2
))
ck
.
verify
(
y
+
tvm
.
expr
.
Ramp
(
x
,
1
,
2
)
,
tvm
.
expr
.
Ramp
(
y
+
x
,
1
,
2
))
ck
.
verify
(
y
.
astype
(
"int32x2"
)
+
x
.
astype
(
"int32x2"
),
(
y
+
x
)
.
astype
(
"int32x2"
))
# Sub rules
ck
.
verify
(
tvm
.
expr
.
Ramp
(
x
,
4
,
4
)
-
tvm
.
expr
.
Ramp
(
y
,
2
,
4
),
tvm
.
expr
.
Ramp
(
x
-
y
,
2
,
4
))
ck
.
verify
(
tvm
.
expr
.
Ramp
(
x
,
1
,
2
)
-
y
,
tvm
.
expr
.
Ramp
(
x
-
y
,
1
,
2
))
ck
.
verify
(
y
-
tvm
.
expr
.
Ramp
(
x
,
1
,
2
)
,
tvm
.
expr
.
Ramp
(
y
-
x
,
-
1
,
2
))
ck
.
verify
(
y
.
astype
(
"int32x2"
)
-
x
.
astype
(
"int32x2"
),
(
y
-
x
)
.
astype
(
"int32x2"
))
# Mul rules
ck
.
verify
(
y
.
astype
(
"int32x2"
)
*
x
.
astype
(
"int32x2"
),
(
y
*
x
)
.
astype
(
"int32x2"
))
ck
.
verify
(
tvm
.
expr
.
Ramp
(
x
,
4
,
4
)
*
2
,
tvm
.
expr
.
Ramp
(
x
*
2
,
8
,
4
))
ck
.
verify
(
2
*
tvm
.
expr
.
Ramp
(
x
,
4
,
4
),
tvm
.
expr
.
Ramp
(
x
*
2
,
8
,
4
))
## Div rules
ck
.
verify
(
y
.
astype
(
"int32x2"
)
/
x
.
astype
(
"int32x2"
),
(
y
/
x
)
.
astype
(
"int32x2"
))
ck
.
verify
(
tvm
.
expr
.
Ramp
(
x
,
4
,
4
)
/
2
,
tvm
.
expr
.
Ramp
(
x
/
2
,
2
,
4
))
ck
.
analyzer
.
update
(
x
,
tvm
.
arith
.
ConstIntBound
(
0
,
1000
),
override
=
True
)
ck
.
verify
(
tvm
.
expr
.
Ramp
(
x
*
8
+
1
,
1
,
4
)
/
8
,
(
x
)
.
astype
(
"int32x4"
))
ck
.
verify
(
tvm
.
expr
.
Ramp
(
x
*
8
+
15
,
1
,
4
)
/
8
,
tvm
.
expr
.
Ramp
(
x
*
8
+
15
,
1
,
4
)
/
8
)
## Mod rules
ck
.
verify
(
y
.
astype
(
"int32x2"
)
%
x
.
astype
(
"int32x2"
),
(
y
%
x
)
.
astype
(
"int32x2"
))
ck
.
verify
(
tvm
.
expr
.
Ramp
(
x
,
4
,
4
)
%
2
,
tvm
.
expr
.
Broadcast
(
x
%
2
,
4
))
ck
.
analyzer
.
update
(
x
,
tvm
.
arith
.
ConstIntBound
(
0
,
1000
),
override
=
True
)
ck
.
verify
(
tvm
.
expr
.
Ramp
(
x
*
8
+
1
,
1
,
4
)
%
8
,
tvm
.
expr
.
Ramp
(
1
,
1
,
4
))
ck
.
verify
(
tvm
.
expr
.
Ramp
(
x
*
8
+
1
,
15
,
4
)
%
8
,
tvm
.
expr
.
Ramp
(
1
,
15
,
4
)
%
8
)
def
test_select_simplify
():
ck
=
RewriteChecker
()
x
,
y
,
z
=
tvm
.
var
(
"x"
),
tvm
.
var
(
"y"
),
tvm
.
var
(
"z"
)
# Add rules
ck
.
verify
(
tvm
.
expr
.
Select
(
x
>
0
,
y
,
0
)
+
tvm
.
expr
.
Select
(
x
>
0
,
1
,
z
),
tvm
.
expr
.
Select
(
x
>
0
,
y
+
1
,
z
))
ck
.
verify
(
tvm
.
expr
.
Select
(
x
>
0
,
y
,
1
)
-
tvm
.
expr
.
Select
(
x
>
0
,
1
,
z
),
tvm
.
expr
.
Select
(
x
>
0
,
y
+
(
-
1
),
1
-
z
))
ck
.
verify
(
tvm
.
expr
.
Select
(
x
>
0
,
y
,
z
)
-
y
,
tvm
.
expr
.
Select
(
x
>
0
,
0
,
z
-
y
))
ck
.
verify
(
tvm
.
expr
.
Select
(
x
>
0
,
y
,
z
)
-
z
,
tvm
.
expr
.
Select
(
x
>
0
,
y
-
z
,
0
))
def
test_add_index_simplify
():
ck
=
RewriteChecker
()
x
,
y
,
z
=
tvm
.
var
(
"x"
),
tvm
.
var
(
"y"
),
tvm
.
var
(
"z"
)
ck
.
verify
(
x
+
(
y
-
x
),
y
)
ck
.
verify
(
x
-
(
y
+
1
)
+
(
y
+
1
),
x
)
ck
.
verify
((
x
-
10
)
+
(
10
-
z
),
x
-
z
)
ck
.
verify
((
x
-
y
)
+
(
z
-
x
),
z
-
y
)
ck
.
verify
(
tvm
.
min
(
x
,
y
-
z
)
+
z
,
tvm
.
min
(
x
+
z
,
y
))
ck
.
verify
(
tvm
.
min
(
x
-
z
,
y
)
+
z
,
tvm
.
min
(
x
,
y
+
z
))
ck
.
verify
(
tvm
.
max
(
x
,
y
-
10
)
+
10
,
tvm
.
max
(
x
+
10
,
y
))
ck
.
verify
(
tvm
.
max
(
x
-
11
,
y
)
+
11
,
tvm
.
max
(
x
,
y
+
11
))
ck
.
verify
(
tvm
.
max
(
x
,
y
*
2
)
+
tvm
.
min
(
x
,
y
*
2
),
x
+
y
*
2
);
ck
.
verify
(
tvm
.
min
(
x
,
y
*
2
)
+
tvm
.
max
(
x
,
y
*
2
),
x
+
y
*
2
);
ck
.
verify
(
tvm
.
max
(
x
,
y
+
2
)
+
(
-
2
),
tvm
.
max
(
x
+
(
-
2
),
y
));
ck
.
verify
(
tvm
.
min
(
x
,
y
+
2
)
+
(
-
2
),
tvm
.
min
(
x
+
(
-
2
),
y
));
ck
.
verify
(
tvm
.
min
(
x
+
2
,
y
+
3
)
+
(
-
2
),
tvm
.
min
(
x
,
y
+
1
));
ck
.
verify
(
x
*
y
+
x
*
10
,
x
*
(
y
+
10
))
ck
.
verify
(
y
*
x
+
x
*
10
,
x
*
(
y
+
10
))
ck
.
verify
(
y
*
x
+
10
*
x
,
x
*
(
y
+
10
))
ck
.
verify
(
x
*
y
+
10
*
x
,
x
*
(
y
+
10
))
ck
.
verify
(
y
*
(
x
%
8
)
+
10
*
(
x
%
8
),
(
x
%
8
)
*
(
y
+
10
))
ck
.
analyzer
.
update
(
x
,
tvm
.
arith
.
ConstIntBound
(
0
,
1000
),
override
=
True
)
ck
.
verify
((
x
/
8
)
*
8
+
x
%
8
,
x
)
# canonicalization
ck
.
verify
(
x
+
2
+
3
+
4
+
x
,
x
*
2
+
9
);
ck
.
verify
(
x
+
2
+
3
+
4
+
x
*
3
,
x
*
4
+
9
);
# conservative bound
try
:
ck
.
analyzer
.
update
(
x
,
tvm
.
arith
.
ConstIntBound
(
-
1
,
1000
),
override
=
True
)
ck
.
verify
((
x
/
8
)
*
8
+
x
%
8
,
x
)
raise
RuntimeError
(
"bad"
)
except
AssertionError
:
pass
def
test_sub_index_simplify
():
ck
=
RewriteChecker
()
x
,
y
,
z
=
tvm
.
var
(
"x"
),
tvm
.
var
(
"y"
),
tvm
.
var
(
"z"
)
ck
.
verify
(
x
+
y
-
y
,
x
)
ck
.
verify
(
x
+
y
-
x
,
y
)
ck
.
verify
(
x
-
(
y
+
x
),
0
-
y
)
ck
.
verify
(
x
-
(
x
+
y
),
0
-
y
)
ck
.
verify
(
tvm
.
min
(
x
,
y
)
-
x
,
tvm
.
min
(
0
,
y
-
x
))
ck
.
verify
(
tvm
.
min
(
x
,
y
)
-
y
,
tvm
.
min
(
x
-
y
,
0
))
ck
.
verify
(
tvm
.
max
(
x
,
y
)
-
x
,
tvm
.
max
(
0
,
y
-
x
))
ck
.
verify
(
tvm
.
max
(
x
,
y
)
-
y
,
tvm
.
max
(
x
-
y
,
0
))
ck
.
verify
(
x
-
tvm
.
min
(
x
,
y
),
tvm
.
max
(
0
,
x
-
y
))
ck
.
verify
(
y
-
tvm
.
min
(
x
,
y
),
tvm
.
max
(
y
-
x
,
0
))
ck
.
verify
(
x
-
tvm
.
max
(
x
,
y
),
tvm
.
min
(
0
,
x
-
y
))
ck
.
verify
(
y
-
tvm
.
max
(
x
,
y
),
tvm
.
min
(
y
-
x
,
0
))
# mul co-efficient foldng
ck
.
verify
(
x
-
x
,
0
)
ck
.
verify
(
x
*
y
-
x
,
x
*
(
y
+
(
-
1
)))
ck
.
verify
(
x
*
y
-
10
*
x
,
x
*
(
y
+
(
-
10
)))
ck
.
verify
(
y
*
x
-
x
*
z
,
x
*
(
y
-
z
))
ck
.
verify
(
y
*
x
-
z
*
x
,
x
*
(
y
-
z
))
ck
.
verify
(
x
+
10
-
20
,
x
+
(
-
10
))
# 4-operands pattern
ck
.
verify
((
x
+
y
)
-
(
x
+
z
),
y
-
z
)
ck
.
verify
((
y
+
x
)
-
(
x
+
z
),
y
-
z
)
ck
.
verify
((
x
+
y
)
-
(
z
+
x
),
y
-
z
)
ck
.
verify
((
y
+
x
)
-
(
z
+
x
),
y
-
z
)
ck
.
verify
(
tvm
.
min
(
x
+
y
,
z
)
-
x
,
tvm
.
min
(
y
,
z
-
x
))
ck
.
verify
(
tvm
.
min
(
y
+
x
,
z
)
-
x
,
tvm
.
min
(
y
,
z
-
x
))
ck
.
verify
(
tvm
.
min
(
z
,
x
+
y
)
-
x
,
tvm
.
min
(
z
-
x
,
y
))
ck
.
verify
(
tvm
.
min
(
z
,
y
+
x
)
-
x
,
tvm
.
min
(
z
-
x
,
y
))
ck
.
verify
(
x
-
tvm
.
min
(
x
+
y
,
z
),
tvm
.
max
(
0
-
y
,
x
-
z
))
ck
.
verify
(
x
-
tvm
.
min
(
y
+
x
,
z
),
tvm
.
max
(
0
-
y
,
x
-
z
))
ck
.
verify
(
x
-
tvm
.
min
(
z
,
x
+
y
),
tvm
.
max
(
x
-
z
,
0
-
y
))
ck
.
verify
(
x
-
tvm
.
min
(
z
,
y
+
x
),
tvm
.
max
(
x
-
z
,
0
-
y
))
ck
.
verify
(
tvm
.
min
(
x
,
y
)
-
tvm
.
min
(
y
,
x
),
0
)
ck
.
verify
(
tvm
.
max
(
x
,
y
)
-
tvm
.
max
(
y
,
x
),
0
)
ck
.
verify
(
tvm
.
min
(
x
,
y
)
-
tvm
.
min
(
x
+
10
,
y
+
10
),
-
10
)
ck
.
verify
(
tvm
.
min
(
x
+
10
,
y
+
1
)
-
tvm
.
min
(
x
,
y
-
9
),
10
)
# div pattern
ck
.
analyzer
.
update
(
x
,
tvm
.
arith
.
ConstIntBound
(
0
,
1000
),
override
=
True
)
ck
.
verify
(
x
-
(
x
/
3
)
*
3
,
x
%
3
)
ck
.
verify
((
x
+
5
)
/
3
-
x
/
3
,
(((
x
+
2
)
%
3
)
+
5
)
/
3
)
def
test_mul_index_simplify
():
ck
=
RewriteChecker
()
x
,
y
,
z
=
tvm
.
var
(
"x"
),
tvm
.
var
(
"y"
),
tvm
.
var
(
"z"
)
ck
.
verify
((
x
+
2
)
*
3
,
x
*
3
+
6
)
ck
.
verify
((
x
*
2
)
*
3
,
x
*
6
)
ck
.
verify
(
tvm
.
min
(
x
,
y
)
*
tvm
.
max
(
x
,
y
),
x
*
y
)
ck
.
verify
(
tvm
.
max
(
x
,
y
)
*
tvm
.
min
(
x
,
y
),
x
*
y
)
ck
.
verify
((
x
-
y
)
*
(
-
2
),
(
y
-
x
)
*
2
)
def
test_div_index_simplify
():
ck
=
RewriteChecker
()
x
,
y
,
z
=
tvm
.
var
(
"x"
),
tvm
.
var
(
"y"
),
tvm
.
var
(
"z"
)
ck
.
analyzer
.
update
(
x
,
tvm
.
arith
.
ConstIntBound
(
0
,
1000
),
override
=
True
)
ck
.
analyzer
.
update
(
y
,
tvm
.
arith
.
ConstIntBound
(
0
,
1000
),
override
=
True
)
ck
.
analyzer
.
update
(
z
,
tvm
.
arith
.
ConstIntBound
(
0
,
1000
),
override
=
True
)
ck
.
verify
(
x
/
2
/
3
,
x
/
6
)
ck
.
verify
((
x
/
2
+
1
)
/
3
,
(
x
+
2
)
/
6
)
ck
.
verify
(
x
*
2
/
4
,
x
/
2
)
ck
.
verify
(
x
*
4
/
2
,
x
*
2
)
ck
.
verify
((
x
*
4
+
y
)
/
2
,
x
*
2
+
y
/
2
)
ck
.
verify
(
tvm
.
min
(
x
*
6
,
y
)
/
2
,
tvm
.
min
(
x
*
3
,
y
/
2
))
ck
.
verify
(
tvm
.
max
(
x
*
6
,
y
)
/
2
,
tvm
.
max
(
x
*
3
,
y
/
2
))
ck
.
verify
((
y
+
x
*
4
)
/
2
,
y
/
2
+
x
*
2
)
ck
.
verify
(
tvm
.
min
(
y
,
x
*
6
)
/
2
,
tvm
.
min
(
y
/
2
,
x
*
3
))
ck
.
verify
(
tvm
.
max
(
y
,
x
*
6
)
/
2
,
tvm
.
max
(
y
/
2
,
x
*
3
))
# 3-operands
ck
.
verify
((
x
*
6
+
y
+
z
)
/
2
,
x
*
3
+
(
y
+
z
)
/
2
)
ck
.
verify
((
x
*
6
-
y
+
(
y
+
3
))
/
2
,
x
*
3
+
1
)
ck
.
verify
((
x
*
6
+
(
y
+
3
)
-
y
)
/
2
,
x
*
3
+
1
)
ck
.
verify
((
y
+
x
*
6
+
z
)
/
2
,
x
*
3
+
(
y
+
z
)
/
2
)
ck
.
verify
((
x
+
4
)
/
2
,
x
/
2
+
2
)
ck
.
verify
((
x
+
y
)
/
x
,
y
/
x
+
1
)
ck
.
verify
((
y
+
x
)
/
x
,
y
/
x
+
1
)
ck
.
verify
(((
x
+
y
)
+
z
)
/
x
,
(
y
+
z
)
/
x
+
1
)
ck
.
verify
(((
y
+
x
)
+
z
)
/
x
,
(
y
+
z
)
/
x
+
1
)
ck
.
verify
((
y
+
(
x
+
z
))
/
x
,
(
y
+
z
)
/
x
+
1
)
ck
.
verify
((
y
+
(
z
+
x
))
/
x
,
(
y
+
z
)
/
x
+
1
)
ck
.
verify
((
x
*
y
)
/
y
,
x
)
ck
.
verify
((
y
*
x
)
/
y
,
x
)
ck
.
verify
((
x
*
z
+
y
)
/
z
,
x
+
y
/
z
)
ck
.
verify
((
z
*
x
+
y
)
/
z
,
x
+
y
/
z
)
ck
.
verify
((
y
+
x
*
z
)
/
z
,
y
/
z
+
x
)
ck
.
verify
((
y
+
z
*
x
)
/
z
,
y
/
z
+
x
)
def
test_mod_index_simplify
():
ck
=
RewriteChecker
()
x
,
y
,
z
=
tvm
.
var
(
"x"
),
tvm
.
var
(
"y"
),
tvm
.
var
(
"z"
)
ck
.
analyzer
.
update
(
x
,
tvm
.
arith
.
ConstIntBound
(
0
,
1000
),
override
=
True
)
ck
.
analyzer
.
update
(
y
,
tvm
.
arith
.
ConstIntBound
(
0
,
1000
),
override
=
True
)
ck
.
verify
(
x
*
10
%
2
,
0
)
ck
.
verify
((
x
*
10
+
y
)
%
2
,
y
%
2
)
ck
.
verify
((
x
+
10
)
%
2
,
x
%
2
)
ck
.
verify
((
x
+
y
*
10
)
%
2
,
x
%
2
)
ck
.
verify
((
x
*
10
+
1
+
y
*
2
+
2
)
%
2
,
1
)
if
__name__
==
"__main__"
:
test_mod_index_simplify
()
test_vector_simplify
()
test_add_index_simplify
()
test_sub_index_simplify
()
test_mul_index_simplify
()
test_div_index_simplify
()
test_select_simplify
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment