Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
48c92376
Unverified
Commit
48c92376
authored
May 03, 2019
by
Tianqi Chen
Committed by
GitHub
May 03, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[ARITH] Constraint-aware ConstIntBound, Enhance CanonicalSimplify (#3132)
parent
8fb7f820
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
173 additions
and
10 deletions
+173
-10
src/arithmetic/canonical_simplify.cc
+13
-0
src/arithmetic/const_int_bound.cc
+72
-1
src/arithmetic/modular_set.cc
+0
-0
src/arithmetic/rewrite_simplify.cc
+56
-8
tests/python/unittest/test_arith_canonical_simplify.py
+32
-1
No files found.
src/arithmetic/canonical_simplify.cc
View file @
48c92376
...
@@ -453,6 +453,9 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl {
...
@@ -453,6 +453,9 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl {
if
(
const
auto
*
op
=
expr
.
as
<
SplitExprNode
>
())
{
if
(
const
auto
*
op
=
expr
.
as
<
SplitExprNode
>
())
{
return
GetRef
<
SplitExpr
>
(
op
);
return
GetRef
<
SplitExpr
>
(
op
);
}
}
if
(
const
auto
*
op
=
expr
.
as
<
SumExprNode
>
())
{
if
(
op
->
base
==
0
&&
op
->
args
.
size
()
==
1
)
return
op
->
args
[
0
];
}
if
(
const
auto
*
op
=
expr
.
as_derived
<
CanonicalExprNode
>
())
{
if
(
const
auto
*
op
=
expr
.
as_derived
<
CanonicalExprNode
>
())
{
expr
=
op
->
Normalize
();
expr
=
op
->
Normalize
();
}
}
...
@@ -764,6 +767,16 @@ Mutate_(const Mod* op, const Expr& self) {
...
@@ -764,6 +767,16 @@ Mutate_(const Mod* op, const Expr& self) {
}
}
}
}
}
}
// Simplify the offset constant if necessary.
// (x - 5) % 3 => (x - 2) % 3 if x - 5 >= 0
auto
cbound
=
parent_
->
const_int_bound
(
Normalize
(
a
));
int64_t
new_base
=
psum
->
base
%
cval
;
if
(
cbound
->
min_value
>=
0
&&
cbound
->
min_value
-
psum
->
base
+
new_base
>=
0
)
{
SumExpr
sum_expr
(
std
::
move
(
a
.
node_
));
sum_expr
.
CopyOnWrite
()
->
base
=
new_base
;
return
SplitModConst
(
ToSplitExpr
(
std
::
move
(
sum_expr
)),
cval
);
}
}
else
{
}
else
{
// if a >= 0 && a < cval, then result == 0
// if a >= 0 && a < cval, then result == 0
auto
cbound
=
parent_
->
const_int_bound
(
Normalize
(
a
));
auto
cbound
=
parent_
->
const_int_bound
(
Normalize
(
a
));
...
...
src/arithmetic/const_int_bound.cc
View file @
48c92376
...
@@ -25,6 +25,7 @@
...
@@ -25,6 +25,7 @@
#include <tvm/ir_functor_ext.h>
#include <tvm/ir_functor_ext.h>
#include <algorithm>
#include <algorithm>
#include "int_op_overflow.h"
#include "int_op_overflow.h"
#include "pattern_match.h"
namespace
tvm
{
namespace
tvm
{
namespace
arith
{
namespace
arith
{
...
@@ -65,6 +66,19 @@ struct ConstIntBoundAnalyzer::Entry {
...
@@ -65,6 +66,19 @@ struct ConstIntBoundAnalyzer::Entry {
class
ConstIntBoundAnalyzer
::
Impl
:
class
ConstIntBoundAnalyzer
::
Impl
:
public
ExprFunctor
<
ConstIntBoundAnalyzer
::
Entry
(
const
Expr
&
)
>
{
public
ExprFunctor
<
ConstIntBoundAnalyzer
::
Entry
(
const
Expr
&
)
>
{
public
:
public
:
/*! \brief additional bound info about expr \in bound */
struct
BoundInfo
{
/*! \brief The expr */
Expr
expr
;
/*! \brief The additional bound */
Entry
bound
;
BoundInfo
()
{}
BoundInfo
(
Expr
expr
,
Entry
bound
)
:
expr
(
expr
),
bound
(
bound
)
{
}
};
void
Bind
(
const
Var
&
var
,
const
Range
&
range
)
{
void
Bind
(
const
Var
&
var
,
const
Range
&
range
)
{
Entry
a
=
VisitExpr
(
range
->
min
);
Entry
a
=
VisitExpr
(
range
->
min
);
Entry
b
=
VisitExpr
(
range
->
extent
);
Entry
b
=
VisitExpr
(
range
->
extent
);
...
@@ -99,6 +113,18 @@ class ConstIntBoundAnalyzer::Impl :
...
@@ -99,6 +113,18 @@ class ConstIntBoundAnalyzer::Impl :
static_cast
<
const
ir
::
BaseExprNode
*>
(
op
)
->
type
);
static_cast
<
const
ir
::
BaseExprNode
*>
(
op
)
->
type
);
}
}
Entry
VisitExpr
(
const
Expr
&
expr
)
final
{
Entry
res
=
ExprFunctor
::
VisitExpr
(
expr
);
// a linear search over additional info
// assume we won't have a lot of conditions
for
(
const
BoundInfo
&
info
:
additional_info_
)
{
if
(
ir
::
Equal
(
expr
,
info
.
expr
))
{
res
=
Intersect
(
res
,
info
.
bound
);
}
}
return
res
;
}
Entry
VisitExpr_
(
const
Cast
*
op
)
final
{
Entry
VisitExpr_
(
const
Cast
*
op
)
final
{
Entry
a
=
VisitExpr
(
op
->
value
);
Entry
a
=
VisitExpr
(
op
->
value
);
Entry
b
=
Everything
(
op
->
type
);
Entry
b
=
Everything
(
op
->
type
);
...
@@ -243,9 +269,24 @@ class ConstIntBoundAnalyzer::Impl :
...
@@ -243,9 +269,24 @@ class ConstIntBoundAnalyzer::Impl :
}
}
}
}
std
::
function
<
void
()
>
EnterConstraint
(
const
Expr
&
constraint
)
{
std
::
vector
<
BoundInfo
>
info
=
DetectBoundInfo
(
constraint
);
if
(
info
.
size
()
==
0
)
return
nullptr
;
size_t
old_size
=
additional_info_
.
size
();
additional_info_
.
insert
(
additional_info_
.
end
(),
info
.
begin
(),
info
.
end
());
size_t
new_size
=
old_size
+
info
.
size
();
auto
frecover
=
[
old_size
,
new_size
,
this
]()
{
CHECK_EQ
(
additional_info_
.
size
(),
new_size
);
additional_info_
.
resize
(
old_size
);
};
return
frecover
;
}
private
:
private
:
// internal variable map
// internal variable map
std
::
unordered_map
<
Var
,
Entry
,
ExprHash
,
ExprEqual
>
var_map_
;
std
::
unordered_map
<
Var
,
Entry
,
ExprHash
,
ExprEqual
>
var_map_
;
// additional bound info
std
::
vector
<
BoundInfo
>
additional_info_
;
// constants: the limit value means umlimited
// constants: the limit value means umlimited
// NOTE: kNegInf/kPosInf are used to represent infinity.
// NOTE: kNegInf/kPosInf are used to represent infinity.
static
const
constexpr
int64_t
kNegInf
=
ConstIntBoundNode
::
kNegInf
;
static
const
constexpr
int64_t
kNegInf
=
ConstIntBoundNode
::
kNegInf
;
...
@@ -387,6 +428,36 @@ class ConstIntBoundAnalyzer::Impl :
...
@@ -387,6 +428,36 @@ class ConstIntBoundAnalyzer::Impl :
}
}
return
ret
;
return
ret
;
}
}
/*!
* \brief Detect additional constant bound from cond, if any
* \param cond The constraint condition.
* \return List of detected bounds.
*/
static
std
::
vector
<
BoundInfo
>
DetectBoundInfo
(
const
Expr
&
cond
)
{
PVar
<
Expr
>
x
,
y
;
PVar
<
Integer
>
c
;
// NOTE: canonical form always use <= or <
if
((
c
<=
x
).
Match
(
cond
))
{
return
{
BoundInfo
(
x
.
Eval
(),
MakeBound
(
c
.
Eval
()
->
value
,
kPosInf
))};
}
if
((
c
<
x
).
Match
(
cond
))
{
return
{
BoundInfo
(
x
.
Eval
(),
MakeBound
(
c
.
Eval
()
->
value
+
1
,
kPosInf
))};
}
if
((
x
<=
c
).
Match
(
cond
))
{
return
{
BoundInfo
(
x
.
Eval
(),
MakeBound
(
kNegInf
,
c
.
Eval
()
->
value
))};
}
if
((
x
<
c
).
Match
(
cond
))
{
return
{
BoundInfo
(
x
.
Eval
(),
MakeBound
(
kNegInf
,
c
.
Eval
()
->
value
-
1
))};
}
if
((
x
&&
y
).
Match
(
cond
))
{
auto
ret1
=
DetectBoundInfo
(
x
.
Eval
());
auto
ret2
=
DetectBoundInfo
(
y
.
Eval
());
ret1
.
insert
(
ret1
.
end
(),
ret2
.
begin
(),
ret2
.
end
());
return
ret1
;
}
return
{};
}
};
};
ConstIntBound
ConstIntBoundAnalyzer
::
operator
()(
const
Expr
&
expr
)
{
ConstIntBound
ConstIntBoundAnalyzer
::
operator
()(
const
Expr
&
expr
)
{
...
@@ -405,7 +476,7 @@ void ConstIntBoundAnalyzer::Bind(const Var& var, const Range& range) {
...
@@ -405,7 +476,7 @@ void ConstIntBoundAnalyzer::Bind(const Var& var, const Range& range) {
}
}
std
::
function
<
void
()
>
ConstIntBoundAnalyzer
::
EnterConstraint
(
const
Expr
&
constraint
)
{
std
::
function
<
void
()
>
ConstIntBoundAnalyzer
::
EnterConstraint
(
const
Expr
&
constraint
)
{
return
nullptr
;
return
impl_
->
EnterConstraint
(
constraint
)
;
}
}
ConstIntBoundAnalyzer
::
ConstIntBoundAnalyzer
(
Analyzer
*
parent
)
ConstIntBoundAnalyzer
::
ConstIntBoundAnalyzer
(
Analyzer
*
parent
)
...
...
src/arithmetic/modular_set.cc
View file @
48c92376
src/arithmetic/rewrite_simplify.cc
View file @
48c92376
...
@@ -1197,14 +1197,32 @@ Mutate_(const Or* op, const Expr& self) {
...
@@ -1197,14 +1197,32 @@ Mutate_(const Or* op, const Expr& self) {
Expr
RewriteSimplifier
::
Impl
::
Expr
RewriteSimplifier
::
Impl
::
Mutate_
(
const
Select
*
op
,
const
Expr
&
self
)
{
Mutate_
(
const
Select
*
op
,
const
Expr
&
self
)
{
Expr
ret
=
IRMutator
::
Mutate_
(
op
,
self
);
Expr
cond
=
Mutate
(
op
->
condition
);
op
=
ret
.
as
<
Select
>
();
Expr
true_value
,
false_value
;
if
(
is_zero
(
op
->
condition
))
{
{
return
op
->
false_value
;
ConstraintContext
constraint
(
parent_
,
cond
);
}
true_value
=
Mutate
(
op
->
true_value
);
if
(
is_one
(
op
->
condition
))
{
}
return
op
->
true_value
;
{
ConstraintContext
constraint
(
parent_
,
Mutate
(
Not
::
make
(
cond
)));
false_value
=
Mutate
(
op
->
false_value
);
}
if
(
is_zero
(
cond
))
{
return
false_value
;
}
if
(
is_one
(
cond
))
{
return
true_value
;
}
// normal path
Expr
ret
;
if
(
cond
.
same_as
(
op
->
condition
)
&&
true_value
.
same_as
(
op
->
true_value
)
&&
false_value
.
same_as
(
op
->
false_value
))
{
ret
=
self
;
}
else
{
ret
=
Select
::
make
(
cond
,
true_value
,
false_value
);
}
}
op
=
ret
.
as
<
Select
>
();
// Pattern var to match any expression
// Pattern var to match any expression
PVar
<
Expr
>
x
,
y
;
PVar
<
Expr
>
x
,
y
;
TVM_TRY_REWRITE
(
select
(
x
,
y
,
y
),
y
);
TVM_TRY_REWRITE
(
select
(
x
,
y
,
y
),
y
);
...
@@ -1213,7 +1231,37 @@ Mutate_(const Select* op, const Expr& self) {
...
@@ -1213,7 +1231,37 @@ Mutate_(const Select* op, const Expr& self) {
Expr
RewriteSimplifier
::
Impl
::
Expr
RewriteSimplifier
::
Impl
::
Mutate_
(
const
Call
*
op
,
const
Expr
&
self
)
{
Mutate_
(
const
Call
*
op
,
const
Expr
&
self
)
{
Expr
ret
=
IRMutator
::
Mutate_
(
op
,
self
);
// add condition context to if_then_else
Expr
ret
;
if
(
op
->
is_intrinsic
(
ir
::
intrinsic
::
tvm_if_then_else
))
{
Expr
cond
=
Mutate
(
op
->
args
[
0
]);
Expr
true_value
,
false_value
;
{
ConstraintContext
constraint
(
parent_
,
cond
);
true_value
=
Mutate
(
op
->
args
[
1
]);
}
{
ConstraintContext
constraint
(
parent_
,
Mutate
(
Not
::
make
(
cond
)));
false_value
=
Mutate
(
op
->
args
[
2
]);
}
if
(
is_zero
(
cond
))
{
return
false_value
;
}
if
(
is_one
(
cond
))
{
return
true_value
;
}
if
(
cond
.
same_as
(
op
->
args
[
0
])
&&
true_value
.
same_as
(
op
->
args
[
1
])
&&
false_value
.
same_as
(
op
->
args
[
2
]))
{
ret
=
self
;
}
else
{
ret
=
Call
::
make
(
op
->
type
,
op
->
name
,
{
cond
,
true_value
,
false_value
},
op
->
call_type
);
}
}
else
{
ret
=
IRMutator
::
Mutate_
(
op
,
self
);
}
op
=
ret
.
as
<
Call
>
();
op
=
ret
.
as
<
Call
>
();
if
(
op
->
is_intrinsic
(
Call
::
likely
)
&&
is_const
(
op
->
args
[
0
]))
{
if
(
op
->
is_intrinsic
(
Call
::
likely
)
&&
is_const
(
op
->
args
[
0
]))
{
return
op
->
args
[
0
];
return
op
->
args
[
0
];
...
...
tests/python/unittest/test_arith_canonical_simplify.py
View file @
48c92376
...
@@ -22,7 +22,7 @@ class CanonicalChecker:
...
@@ -22,7 +22,7 @@ class CanonicalChecker:
def
verify
(
self
,
data
,
expected
):
def
verify
(
self
,
data
,
expected
):
res
=
self
.
analyzer
.
canonical_simplify
(
data
)
res
=
self
.
analyzer
.
canonical_simplify
(
data
)
assert
tvm
.
ir_pass
.
Equal
(
res
,
expected
),
"
data={}, res={},
expected={}"
.
format
(
data
,
res
,
expected
)
assert
tvm
.
ir_pass
.
Equal
(
res
,
expected
),
"
\n
data={}
\n
res={}
\n
expected={}"
.
format
(
data
,
res
,
expected
)
def
test_mul_sum_simplify
():
def
test_mul_sum_simplify
():
...
@@ -157,7 +157,38 @@ def test_reduce_simplify():
...
@@ -157,7 +157,38 @@ def test_reduce_simplify():
ck
.
verify
(
tvm
.
sum
(
k
/
10
,
k
),
tvm
.
sum
(
tvm
.
const
(
0
,
"int32"
),
k
))
ck
.
verify
(
tvm
.
sum
(
k
/
10
,
k
),
tvm
.
sum
(
tvm
.
const
(
0
,
"int32"
),
k
))
def
test_simplify_if_then_else
():
ck
=
CanonicalChecker
()
x
=
tvm
.
var
(
"x"
)
y
=
tvm
.
var
(
"y"
)
# simplification that takes condition into account.
res
=
tvm
.
if_then_else
((
x
*
4
+
y
)
>=
466036
,
tvm
.
if_then_else
(
24512
<=
((((
x
*
4
)
+
y
)
-
466036
)
%
24528
),
(((((
x
*
4
)
+
y
)
-
466036
)
%
24528
)
-
24512
)
%
16
,
x
),
y
)
expected
=
tvm
.
if_then_else
(
tvm
.
expr
.
LE
(
466036
,
(
x
*
4
+
y
)),
tvm
.
if_then_else
(
tvm
.
expr
.
LE
(
24512
,
((((
x
*
4
)
+
y
)
-
4
)
%
24528
)),
(((
x
*
4
)
+
y
)
-
4
)
%
16
,
x
),
y
)
ck
.
verify
(
res
,
expected
)
# can only simplify if condition
res
=
tvm
.
expr
.
Select
(
tvm
.
all
(
x
>=
-
1
,
y
>=
0
),
(
x
+
y
+
100
)
%
3
,
(
x
+
100
)
%
3
)
expected
=
tvm
.
expr
.
Select
(
tvm
.
all
(
x
>=
-
1
,
y
>=
0
),
(
x
+
y
+
1
)
%
3
,
(
x
+
100
)
%
3
)
ck
.
verify
(
res
,
ck
.
analyzer
.
canonical_simplify
(
expected
))
res
=
tvm
.
expr
.
Select
(
x
>=
10
,
tvm
.
if_then_else
(
x
/
3
>
2
,
x
,
0
),
0
)
expected
=
tvm
.
expr
.
Select
(
x
>=
10
,
x
,
0
)
ck
.
verify
(
res
,
ck
.
analyzer
.
canonical_simplify
(
expected
))
res
=
tvm
.
expr
.
Select
(
x
>=
10
,
tvm
.
if_then_else
(
x
/
3
<
2
,
x
,
0
),
0
)
ck
.
verify
(
res
,
0
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
test_simplify_if_then_else
()
test_div_simplify
()
test_div_simplify
()
test_reduce_simplify
()
test_reduce_simplify
()
test_reduce_combiner_simplify
()
test_reduce_combiner_simplify
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment