Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
ec95675c
Unverified
Commit
ec95675c
authored
Mar 10, 2019
by
Tianqi Chen
Committed by
GitHub
Mar 10, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[ARITH] Analyzer RewriteSimplifier: add/sub/mul/div/mod (#2722)
parent
0bf64ee0
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
362 additions
and
4 deletions
+362
-4
include/tvm/arithmetic.h
+35
-0
python/tvm/arith.py
+16
-0
src/api/api_arith.cc
+4
-0
src/arithmetic/analyzer.cc
+9
-2
src/arithmetic/const_fold.h
+3
-1
src/arithmetic/pattern_match.h
+42
-1
src/arithmetic/rewrite_simplify.cc
+0
-0
tests/cpp/pattern_match_test.cc
+1
-0
tests/python/unittest/test_arith_rewrite_simplify.py
+252
-0
No files found.
include/tvm/arithmetic.h
View file @
ec95675c
...
...
@@ -193,6 +193,39 @@ class ModularSetAnalyzer {
};
/*!
* \brief Rewrite-rule based simplifier.
*/
class
RewriteSimplifier
{
public
:
/*!
* \brief analyze the expr
* \param expr The expression of interest.
* \return the result of the analysis.
*/
Expr
operator
()(
const
Expr
&
expr
);
/*!
* \brief Update binding of var to a new expression.
*
* \param var The variable of interest.
* \param new_expr
* \param override Whether do we allow override of existing information.
*/
void
Update
(
const
Var
&
var
,
const
Expr
&
new_expr
,
bool
override
=
false
);
private
:
friend
class
Analyzer
;
friend
class
ConstraintContext
;
explicit
RewriteSimplifier
(
Analyzer
*
parent
);
~
RewriteSimplifier
();
class
Impl
;
/*! \brief Internal impl */
Impl
*
impl_
;
};
/*!
* \brief A RAII constraint context.
*
* \code
...
...
@@ -242,6 +275,8 @@ class Analyzer {
ConstIntBoundAnalyzer
const_int_bound
;
/*! \brief sub-analyzer: modular set */
ModularSetAnalyzer
modular_set
;
/*! \brief sub-analyzer rewrite simplfy */
RewriteSimplifier
rewrite_simplify
;
/*! \brief constructor */
Analyzer
();
/*!
...
...
python/tvm/arith.py
View file @
ec95675c
...
...
@@ -96,6 +96,7 @@ class Analyzer:
self
.
_const_int_bound_update
=
_mod
(
"const_int_bound_update"
)
self
.
_bind
=
_mod
(
"bind"
)
self
.
_modular_set
=
_mod
(
"modular_set"
)
self
.
_rewrite_simplify
=
_mod
(
"rewrite_simplify"
)
self
.
_enter_constraint_context
=
_mod
(
"enter_constraint_context"
)
def
const_int_bound
(
self
,
expr
):
...
...
@@ -128,6 +129,21 @@ class Analyzer:
"""
return
self
.
_modular_set
(
expr
)
def
rewrite_simplify
(
self
,
expr
):
"""Simplify expression via rewriting rules.
Parameters
----------
expr : tvm.Expr
The expression.
Returns
-------
result : Expr
The result.
"""
return
self
.
_rewrite_simplify
(
expr
)
def
bind
(
self
,
var
,
expr
):
"""Bind a variable to the expression.
...
...
src/api/api_arith.cc
View file @
ec95675c
...
...
@@ -98,6 +98,10 @@ TVM_REGISTER_API("arith._CreateAnalyzer")
return
PackedFunc
([
self
](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
self
->
const_int_bound
.
Update
(
args
[
0
],
args
[
1
],
args
[
2
]);
});
}
else
if
(
name
==
"rewrite_simplify"
)
{
return
PackedFunc
([
self
](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
self
->
rewrite_simplify
(
args
[
0
]);
});
}
else
if
(
name
==
"bind"
)
{
return
PackedFunc
([
self
](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
auto
&
sptr
=
args
[
1
].
node_sptr
();
...
...
src/arithmetic/analyzer.cc
View file @
ec95675c
...
...
@@ -2,6 +2,7 @@
* Copyright (c) 2019 by Contributors
* \file tvm/arithmetic/analyzer.cc
*/
#include <tvm/ir.h>
#include <tvm/arithmetic.h>
namespace
tvm
{
...
...
@@ -9,19 +10,22 @@ namespace arith {
Analyzer
::
Analyzer
()
:
const_int_bound
(
this
),
modular_set
(
this
)
{
modular_set
(
this
),
rewrite_simplify
(
this
)
{
}
void
Analyzer
::
Bind
(
const
VarExpr
&
v
,
const
Expr
&
expr
)
{
Var
var
(
v
.
node_
);
this
->
const_int_bound
.
Update
(
var
,
this
->
const_int_bound
(
expr
));
this
->
modular_set
.
Update
(
var
,
this
->
modular_set
(
expr
));
this
->
rewrite_simplify
.
Update
(
var
,
this
->
rewrite_simplify
(
expr
));
}
void
Analyzer
::
Bind
(
const
VarExpr
&
v
,
const
Range
&
range
)
{
Var
var
(
v
.
node_
);
this
->
const_int_bound
.
Bind
(
var
,
range
);
// skip modular_set
// skip rewrite simplify
}
ConstraintContext
::
ConstraintContext
(
Analyzer
*
analyzer
,
const
Expr
&
constraint
)
{
...
...
@@ -36,7 +40,10 @@ ConstraintContext::ConstraintContext(Analyzer* analyzer, const Expr& constraint)
}
bool
Analyzer
::
CanProveGreaterEqual
(
const
Expr
&
expr
,
int64_t
lower_bound
)
{
auto
bd
=
this
->
const_int_bound
(
expr
);
if
(
const
auto
*
ptr
=
expr
.
as
<
ir
::
IntImm
>
())
{
return
ptr
->
value
>
lower_bound
;
}
auto
bd
=
this
->
const_int_bound
(
this
->
rewrite_simplify
(
expr
));
if
(
bd
->
min_value
>=
lower_bound
)
return
true
;
return
false
;
}
...
...
src/arithmetic/const_fold.h
View file @
ec95675c
...
...
@@ -23,7 +23,9 @@ namespace arith {
* \return nullptr if constant fold fails, otherwise return folded result.
*/
template
<
typename
Op
>
inline
Expr
TryConstFold
(
Expr
a
,
Expr
b
);
inline
Expr
TryConstFold
(
Expr
a
,
Expr
b
)
{
return
Expr
();
}
/*!
* \brief Try to run unary compute with constant folding.
...
...
src/arithmetic/pattern_match.h
View file @
ec95675c
...
...
@@ -49,6 +49,7 @@
#include <tvm/ir_pass.h>
#include <tuple>
#include "const_fold.h"
namespace
tvm
{
namespace
arith
{
...
...
@@ -242,7 +243,11 @@ class PBinaryExpr :
}
Expr
Eval
()
const
{
return
NodeType
::
make
(
a_
.
Eval
(),
b_
.
Eval
());
Expr
lhs
=
a_
.
Eval
();
Expr
rhs
=
b_
.
Eval
();
Expr
ret
=
TryConstFold
<
NodeType
>
(
lhs
,
rhs
);
if
(
ret
.
defined
())
return
ret
;
return
NodeType
::
make
(
lhs
,
rhs
);
}
private
:
...
...
@@ -250,12 +255,48 @@ class PBinaryExpr :
typename
TB
::
Nested
b_
;
};
template
<
typename
TA
>
class
PConstWithTypeLike
:
public
Pattern
<
PConstWithTypeLike
<
TA
>
>
{
public
:
PConstWithTypeLike
(
const
TA
&
ref
,
int64_t
value
)
:
ref_
(
ref
),
value_
(
value
)
{}
void
InitMatch_
()
const
{}
bool
Match_
(
const
NodeRef
&
node
)
const
{
if
(
const
ir
::
IntImm
*
ptr
=
node
.
as
<
ir
::
IntImm
>
())
{
return
ptr
->
value
==
value_
;
}
else
{
return
false
;
}
}
Expr
Eval
()
const
{
return
make_const
(
ref_
.
Eval
().
type
(),
value_
);
}
private
:
typename
TA
::
Nested
ref_
;
int64_t
value_
;
};
#define TVM_PATTERN_BINARY_OP(FuncName, NodeName) \
template<typename TA, typename TB> \
inline PBinaryExpr<NodeName, TA, TB> \
FuncName(const Pattern<TA>& a, const Pattern<TB>& b) { \
return PBinaryExpr<NodeName, TA, TB>(a.derived(), b.derived()); \
} \
template<typename TA> \
inline PBinaryExpr<NodeName, TA, PConstWithTypeLike<TA> > \
FuncName(const Pattern<TA>& a, int64_t b) { \
return FuncName(a, PConstWithTypeLike<TA>(a.derived(), b)); \
} \
template<typename TA> \
inline PBinaryExpr<NodeName, PConstWithTypeLike<TA>, TA> \
FuncName(int64_t b, const Pattern<TA>& a) { \
return FuncName(PConstWithTypeLike<TA>(a.derived(), b), a); \
}
// arithmetic expressions
...
...
src/arithmetic/rewrite_simplify.cc
0 → 100644
View file @
ec95675c
This diff is collapsed.
Click to expand it.
tests/cpp/pattern_match_test.cc
View file @
ec95675c
...
...
@@ -117,6 +117,7 @@ TEST(Pattern, Integer) {
// special case container of Expr
CHECK
((
v
*
c
).
Match
(
tx
*
3
));
CHECK_EQ
(
c
.
Eval
()
->
value
,
3
);
CHECK
((
v
*
3
).
Match
(
tx
*
3
));
}
// cannot match c to ty
CHECK
(
!
(
v
*
c
).
Match
(
tx
*
ty
));
...
...
tests/python/unittest/test_arith_rewrite_simplify.py
0 → 100644
View file @
ec95675c
import
tvm
class
RewriteChecker
:
def
__init__
(
self
):
self
.
analyzer
=
tvm
.
arith
.
Analyzer
()
def
verify
(
self
,
data
,
expected
):
res
=
self
.
analyzer
.
rewrite_simplify
(
data
)
assert
tvm
.
ir_pass
.
Equal
(
res
,
expected
),
"data={}, res={}, expected={}"
.
format
(
data
,
res
,
expected
)
def
test_vector_simplify
():
ck
=
RewriteChecker
()
x
,
y
,
z
=
tvm
.
var
(
"x"
),
tvm
.
var
(
"y"
),
tvm
.
var
(
"z"
)
# Add rules
ck
.
verify
(
tvm
.
expr
.
Ramp
(
x
,
1
,
4
)
+
tvm
.
expr
.
Ramp
(
y
,
2
,
4
),
tvm
.
expr
.
Ramp
(
x
+
y
,
3
,
4
))
ck
.
verify
(
tvm
.
expr
.
Ramp
(
x
,
1
,
2
)
+
y
,
tvm
.
expr
.
Ramp
(
x
+
y
,
1
,
2
))
ck
.
verify
(
y
+
tvm
.
expr
.
Ramp
(
x
,
1
,
2
)
,
tvm
.
expr
.
Ramp
(
y
+
x
,
1
,
2
))
ck
.
verify
(
y
.
astype
(
"int32x2"
)
+
x
.
astype
(
"int32x2"
),
(
y
+
x
)
.
astype
(
"int32x2"
))
# Sub rules
ck
.
verify
(
tvm
.
expr
.
Ramp
(
x
,
4
,
4
)
-
tvm
.
expr
.
Ramp
(
y
,
2
,
4
),
tvm
.
expr
.
Ramp
(
x
-
y
,
2
,
4
))
ck
.
verify
(
tvm
.
expr
.
Ramp
(
x
,
1
,
2
)
-
y
,
tvm
.
expr
.
Ramp
(
x
-
y
,
1
,
2
))
ck
.
verify
(
y
-
tvm
.
expr
.
Ramp
(
x
,
1
,
2
)
,
tvm
.
expr
.
Ramp
(
y
-
x
,
-
1
,
2
))
ck
.
verify
(
y
.
astype
(
"int32x2"
)
-
x
.
astype
(
"int32x2"
),
(
y
-
x
)
.
astype
(
"int32x2"
))
# Mul rules
ck
.
verify
(
y
.
astype
(
"int32x2"
)
*
x
.
astype
(
"int32x2"
),
(
y
*
x
)
.
astype
(
"int32x2"
))
ck
.
verify
(
tvm
.
expr
.
Ramp
(
x
,
4
,
4
)
*
2
,
tvm
.
expr
.
Ramp
(
x
*
2
,
8
,
4
))
ck
.
verify
(
2
*
tvm
.
expr
.
Ramp
(
x
,
4
,
4
),
tvm
.
expr
.
Ramp
(
x
*
2
,
8
,
4
))
## Div rules
ck
.
verify
(
y
.
astype
(
"int32x2"
)
/
x
.
astype
(
"int32x2"
),
(
y
/
x
)
.
astype
(
"int32x2"
))
ck
.
verify
(
tvm
.
expr
.
Ramp
(
x
,
4
,
4
)
/
2
,
tvm
.
expr
.
Ramp
(
x
/
2
,
2
,
4
))
ck
.
analyzer
.
update
(
x
,
tvm
.
arith
.
ConstIntBound
(
0
,
1000
),
override
=
True
)
ck
.
verify
(
tvm
.
expr
.
Ramp
(
x
*
8
+
1
,
1
,
4
)
/
8
,
(
x
)
.
astype
(
"int32x4"
))
ck
.
verify
(
tvm
.
expr
.
Ramp
(
x
*
8
+
15
,
1
,
4
)
/
8
,
tvm
.
expr
.
Ramp
(
x
*
8
+
15
,
1
,
4
)
/
8
)
## Mod rules
ck
.
verify
(
y
.
astype
(
"int32x2"
)
%
x
.
astype
(
"int32x2"
),
(
y
%
x
)
.
astype
(
"int32x2"
))
ck
.
verify
(
tvm
.
expr
.
Ramp
(
x
,
4
,
4
)
%
2
,
tvm
.
expr
.
Broadcast
(
x
%
2
,
4
))
ck
.
analyzer
.
update
(
x
,
tvm
.
arith
.
ConstIntBound
(
0
,
1000
),
override
=
True
)
ck
.
verify
(
tvm
.
expr
.
Ramp
(
x
*
8
+
1
,
1
,
4
)
%
8
,
tvm
.
expr
.
Ramp
(
1
,
1
,
4
))
ck
.
verify
(
tvm
.
expr
.
Ramp
(
x
*
8
+
1
,
15
,
4
)
%
8
,
tvm
.
expr
.
Ramp
(
1
,
15
,
4
)
%
8
)
def
test_select_simplify
():
ck
=
RewriteChecker
()
x
,
y
,
z
=
tvm
.
var
(
"x"
),
tvm
.
var
(
"y"
),
tvm
.
var
(
"z"
)
# Add rules
ck
.
verify
(
tvm
.
expr
.
Select
(
x
>
0
,
y
,
0
)
+
tvm
.
expr
.
Select
(
x
>
0
,
1
,
z
),
tvm
.
expr
.
Select
(
x
>
0
,
y
+
1
,
z
))
ck
.
verify
(
tvm
.
expr
.
Select
(
x
>
0
,
y
,
1
)
-
tvm
.
expr
.
Select
(
x
>
0
,
1
,
z
),
tvm
.
expr
.
Select
(
x
>
0
,
y
+
(
-
1
),
1
-
z
))
ck
.
verify
(
tvm
.
expr
.
Select
(
x
>
0
,
y
,
z
)
-
y
,
tvm
.
expr
.
Select
(
x
>
0
,
0
,
z
-
y
))
ck
.
verify
(
tvm
.
expr
.
Select
(
x
>
0
,
y
,
z
)
-
z
,
tvm
.
expr
.
Select
(
x
>
0
,
y
-
z
,
0
))
def
test_add_index_simplify
():
ck
=
RewriteChecker
()
x
,
y
,
z
=
tvm
.
var
(
"x"
),
tvm
.
var
(
"y"
),
tvm
.
var
(
"z"
)
ck
.
verify
(
x
+
(
y
-
x
),
y
)
ck
.
verify
(
x
-
(
y
+
1
)
+
(
y
+
1
),
x
)
ck
.
verify
((
x
-
10
)
+
(
10
-
z
),
x
-
z
)
ck
.
verify
((
x
-
y
)
+
(
z
-
x
),
z
-
y
)
ck
.
verify
(
tvm
.
min
(
x
,
y
-
z
)
+
z
,
tvm
.
min
(
x
+
z
,
y
))
ck
.
verify
(
tvm
.
min
(
x
-
z
,
y
)
+
z
,
tvm
.
min
(
x
,
y
+
z
))
ck
.
verify
(
tvm
.
max
(
x
,
y
-
10
)
+
10
,
tvm
.
max
(
x
+
10
,
y
))
ck
.
verify
(
tvm
.
max
(
x
-
11
,
y
)
+
11
,
tvm
.
max
(
x
,
y
+
11
))
ck
.
verify
(
tvm
.
max
(
x
,
y
*
2
)
+
tvm
.
min
(
x
,
y
*
2
),
x
+
y
*
2
);
ck
.
verify
(
tvm
.
min
(
x
,
y
*
2
)
+
tvm
.
max
(
x
,
y
*
2
),
x
+
y
*
2
);
ck
.
verify
(
tvm
.
max
(
x
,
y
+
2
)
+
(
-
2
),
tvm
.
max
(
x
+
(
-
2
),
y
));
ck
.
verify
(
tvm
.
min
(
x
,
y
+
2
)
+
(
-
2
),
tvm
.
min
(
x
+
(
-
2
),
y
));
ck
.
verify
(
tvm
.
min
(
x
+
2
,
y
+
3
)
+
(
-
2
),
tvm
.
min
(
x
,
y
+
1
));
ck
.
verify
(
x
*
y
+
x
*
10
,
x
*
(
y
+
10
))
ck
.
verify
(
y
*
x
+
x
*
10
,
x
*
(
y
+
10
))
ck
.
verify
(
y
*
x
+
10
*
x
,
x
*
(
y
+
10
))
ck
.
verify
(
x
*
y
+
10
*
x
,
x
*
(
y
+
10
))
ck
.
verify
(
y
*
(
x
%
8
)
+
10
*
(
x
%
8
),
(
x
%
8
)
*
(
y
+
10
))
ck
.
analyzer
.
update
(
x
,
tvm
.
arith
.
ConstIntBound
(
0
,
1000
),
override
=
True
)
ck
.
verify
((
x
/
8
)
*
8
+
x
%
8
,
x
)
# canonicalization
ck
.
verify
(
x
+
2
+
3
+
4
+
x
,
x
*
2
+
9
);
ck
.
verify
(
x
+
2
+
3
+
4
+
x
*
3
,
x
*
4
+
9
);
# conservative bound
try
:
ck
.
analyzer
.
update
(
x
,
tvm
.
arith
.
ConstIntBound
(
-
1
,
1000
),
override
=
True
)
ck
.
verify
((
x
/
8
)
*
8
+
x
%
8
,
x
)
raise
RuntimeError
(
"bad"
)
except
AssertionError
:
pass
def
test_sub_index_simplify
():
ck
=
RewriteChecker
()
x
,
y
,
z
=
tvm
.
var
(
"x"
),
tvm
.
var
(
"y"
),
tvm
.
var
(
"z"
)
ck
.
verify
(
x
+
y
-
y
,
x
)
ck
.
verify
(
x
+
y
-
x
,
y
)
ck
.
verify
(
x
-
(
y
+
x
),
0
-
y
)
ck
.
verify
(
x
-
(
x
+
y
),
0
-
y
)
ck
.
verify
(
tvm
.
min
(
x
,
y
)
-
x
,
tvm
.
min
(
0
,
y
-
x
))
ck
.
verify
(
tvm
.
min
(
x
,
y
)
-
y
,
tvm
.
min
(
x
-
y
,
0
))
ck
.
verify
(
tvm
.
max
(
x
,
y
)
-
x
,
tvm
.
max
(
0
,
y
-
x
))
ck
.
verify
(
tvm
.
max
(
x
,
y
)
-
y
,
tvm
.
max
(
x
-
y
,
0
))
ck
.
verify
(
x
-
tvm
.
min
(
x
,
y
),
tvm
.
max
(
0
,
x
-
y
))
ck
.
verify
(
y
-
tvm
.
min
(
x
,
y
),
tvm
.
max
(
y
-
x
,
0
))
ck
.
verify
(
x
-
tvm
.
max
(
x
,
y
),
tvm
.
min
(
0
,
x
-
y
))
ck
.
verify
(
y
-
tvm
.
max
(
x
,
y
),
tvm
.
min
(
y
-
x
,
0
))
# mul co-efficient foldng
ck
.
verify
(
x
-
x
,
0
)
ck
.
verify
(
x
*
y
-
x
,
x
*
(
y
+
(
-
1
)))
ck
.
verify
(
x
*
y
-
10
*
x
,
x
*
(
y
+
(
-
10
)))
ck
.
verify
(
y
*
x
-
x
*
z
,
x
*
(
y
-
z
))
ck
.
verify
(
y
*
x
-
z
*
x
,
x
*
(
y
-
z
))
ck
.
verify
(
x
+
10
-
20
,
x
+
(
-
10
))
# 4-operands pattern
ck
.
verify
((
x
+
y
)
-
(
x
+
z
),
y
-
z
)
ck
.
verify
((
y
+
x
)
-
(
x
+
z
),
y
-
z
)
ck
.
verify
((
x
+
y
)
-
(
z
+
x
),
y
-
z
)
ck
.
verify
((
y
+
x
)
-
(
z
+
x
),
y
-
z
)
ck
.
verify
(
tvm
.
min
(
x
+
y
,
z
)
-
x
,
tvm
.
min
(
y
,
z
-
x
))
ck
.
verify
(
tvm
.
min
(
y
+
x
,
z
)
-
x
,
tvm
.
min
(
y
,
z
-
x
))
ck
.
verify
(
tvm
.
min
(
z
,
x
+
y
)
-
x
,
tvm
.
min
(
z
-
x
,
y
))
ck
.
verify
(
tvm
.
min
(
z
,
y
+
x
)
-
x
,
tvm
.
min
(
z
-
x
,
y
))
ck
.
verify
(
x
-
tvm
.
min
(
x
+
y
,
z
),
tvm
.
max
(
0
-
y
,
x
-
z
))
ck
.
verify
(
x
-
tvm
.
min
(
y
+
x
,
z
),
tvm
.
max
(
0
-
y
,
x
-
z
))
ck
.
verify
(
x
-
tvm
.
min
(
z
,
x
+
y
),
tvm
.
max
(
x
-
z
,
0
-
y
))
ck
.
verify
(
x
-
tvm
.
min
(
z
,
y
+
x
),
tvm
.
max
(
x
-
z
,
0
-
y
))
ck
.
verify
(
tvm
.
min
(
x
,
y
)
-
tvm
.
min
(
y
,
x
),
0
)
ck
.
verify
(
tvm
.
max
(
x
,
y
)
-
tvm
.
max
(
y
,
x
),
0
)
ck
.
verify
(
tvm
.
min
(
x
,
y
)
-
tvm
.
min
(
x
+
10
,
y
+
10
),
-
10
)
ck
.
verify
(
tvm
.
min
(
x
+
10
,
y
+
1
)
-
tvm
.
min
(
x
,
y
-
9
),
10
)
# div pattern
ck
.
analyzer
.
update
(
x
,
tvm
.
arith
.
ConstIntBound
(
0
,
1000
),
override
=
True
)
ck
.
verify
(
x
-
(
x
/
3
)
*
3
,
x
%
3
)
ck
.
verify
((
x
+
5
)
/
3
-
x
/
3
,
(((
x
+
2
)
%
3
)
+
5
)
/
3
)
def
test_mul_index_simplify
():
ck
=
RewriteChecker
()
x
,
y
,
z
=
tvm
.
var
(
"x"
),
tvm
.
var
(
"y"
),
tvm
.
var
(
"z"
)
ck
.
verify
((
x
+
2
)
*
3
,
x
*
3
+
6
)
ck
.
verify
((
x
*
2
)
*
3
,
x
*
6
)
ck
.
verify
(
tvm
.
min
(
x
,
y
)
*
tvm
.
max
(
x
,
y
),
x
*
y
)
ck
.
verify
(
tvm
.
max
(
x
,
y
)
*
tvm
.
min
(
x
,
y
),
x
*
y
)
ck
.
verify
((
x
-
y
)
*
(
-
2
),
(
y
-
x
)
*
2
)
def
test_div_index_simplify
():
ck
=
RewriteChecker
()
x
,
y
,
z
=
tvm
.
var
(
"x"
),
tvm
.
var
(
"y"
),
tvm
.
var
(
"z"
)
ck
.
analyzer
.
update
(
x
,
tvm
.
arith
.
ConstIntBound
(
0
,
1000
),
override
=
True
)
ck
.
analyzer
.
update
(
y
,
tvm
.
arith
.
ConstIntBound
(
0
,
1000
),
override
=
True
)
ck
.
analyzer
.
update
(
z
,
tvm
.
arith
.
ConstIntBound
(
0
,
1000
),
override
=
True
)
ck
.
verify
(
x
/
2
/
3
,
x
/
6
)
ck
.
verify
((
x
/
2
+
1
)
/
3
,
(
x
+
2
)
/
6
)
ck
.
verify
(
x
*
2
/
4
,
x
/
2
)
ck
.
verify
(
x
*
4
/
2
,
x
*
2
)
ck
.
verify
((
x
*
4
+
y
)
/
2
,
x
*
2
+
y
/
2
)
ck
.
verify
(
tvm
.
min
(
x
*
6
,
y
)
/
2
,
tvm
.
min
(
x
*
3
,
y
/
2
))
ck
.
verify
(
tvm
.
max
(
x
*
6
,
y
)
/
2
,
tvm
.
max
(
x
*
3
,
y
/
2
))
ck
.
verify
((
y
+
x
*
4
)
/
2
,
y
/
2
+
x
*
2
)
ck
.
verify
(
tvm
.
min
(
y
,
x
*
6
)
/
2
,
tvm
.
min
(
y
/
2
,
x
*
3
))
ck
.
verify
(
tvm
.
max
(
y
,
x
*
6
)
/
2
,
tvm
.
max
(
y
/
2
,
x
*
3
))
# 3-operands
ck
.
verify
((
x
*
6
+
y
+
z
)
/
2
,
x
*
3
+
(
y
+
z
)
/
2
)
ck
.
verify
((
x
*
6
-
y
+
(
y
+
3
))
/
2
,
x
*
3
+
1
)
ck
.
verify
((
x
*
6
+
(
y
+
3
)
-
y
)
/
2
,
x
*
3
+
1
)
ck
.
verify
((
y
+
x
*
6
+
z
)
/
2
,
x
*
3
+
(
y
+
z
)
/
2
)
ck
.
verify
((
x
+
4
)
/
2
,
x
/
2
+
2
)
ck
.
verify
((
x
+
y
)
/
x
,
y
/
x
+
1
)
ck
.
verify
((
y
+
x
)
/
x
,
y
/
x
+
1
)
ck
.
verify
(((
x
+
y
)
+
z
)
/
x
,
(
y
+
z
)
/
x
+
1
)
ck
.
verify
(((
y
+
x
)
+
z
)
/
x
,
(
y
+
z
)
/
x
+
1
)
ck
.
verify
((
y
+
(
x
+
z
))
/
x
,
(
y
+
z
)
/
x
+
1
)
ck
.
verify
((
y
+
(
z
+
x
))
/
x
,
(
y
+
z
)
/
x
+
1
)
ck
.
verify
((
x
*
y
)
/
y
,
x
)
ck
.
verify
((
y
*
x
)
/
y
,
x
)
ck
.
verify
((
x
*
z
+
y
)
/
z
,
x
+
y
/
z
)
ck
.
verify
((
z
*
x
+
y
)
/
z
,
x
+
y
/
z
)
ck
.
verify
((
y
+
x
*
z
)
/
z
,
y
/
z
+
x
)
ck
.
verify
((
y
+
z
*
x
)
/
z
,
y
/
z
+
x
)
def
test_mod_index_simplify
():
ck
=
RewriteChecker
()
x
,
y
,
z
=
tvm
.
var
(
"x"
),
tvm
.
var
(
"y"
),
tvm
.
var
(
"z"
)
ck
.
analyzer
.
update
(
x
,
tvm
.
arith
.
ConstIntBound
(
0
,
1000
),
override
=
True
)
ck
.
analyzer
.
update
(
y
,
tvm
.
arith
.
ConstIntBound
(
0
,
1000
),
override
=
True
)
ck
.
verify
(
x
*
10
%
2
,
0
)
ck
.
verify
((
x
*
10
+
y
)
%
2
,
y
%
2
)
ck
.
verify
((
x
+
10
)
%
2
,
x
%
2
)
ck
.
verify
((
x
+
y
*
10
)
%
2
,
x
%
2
)
ck
.
verify
((
x
*
10
+
1
+
y
*
2
+
2
)
%
2
,
1
)
if
__name__
==
"__main__"
:
test_mod_index_simplify
()
test_vector_simplify
()
test_add_index_simplify
()
test_sub_index_simplify
()
test_mul_index_simplify
()
test_div_index_simplify
()
test_select_simplify
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment