Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
48c92376
Unverified
Commit
48c92376
authored
May 03, 2019
by
Tianqi Chen
Committed by
GitHub
May 03, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[ARITH] Constraint-aware ConstIntBound, Enhance CanonicalSimplify (#3132)
parent
8fb7f820
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
180 additions
and
17 deletions
+180
-17
src/arithmetic/canonical_simplify.cc
+15
-2
src/arithmetic/const_int_bound.cc
+74
-3
src/arithmetic/modular_set.cc
+2
-2
src/arithmetic/rewrite_simplify.cc
+57
-9
tests/python/unittest/test_arith_canonical_simplify.py
+32
-1
No files found.
src/arithmetic/canonical_simplify.cc
View file @
48c92376
...
...
@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...
...
@@ -453,6 +453,9 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl {
if
(
const
auto
*
op
=
expr
.
as
<
SplitExprNode
>
())
{
return
GetRef
<
SplitExpr
>
(
op
);
}
if
(
const
auto
*
op
=
expr
.
as
<
SumExprNode
>
())
{
if
(
op
->
base
==
0
&&
op
->
args
.
size
()
==
1
)
return
op
->
args
[
0
];
}
if
(
const
auto
*
op
=
expr
.
as_derived
<
CanonicalExprNode
>
())
{
expr
=
op
->
Normalize
();
}
...
...
@@ -764,6 +767,16 @@ Mutate_(const Mod* op, const Expr& self) {
}
}
}
// Simplify the offset constant if necessary.
// (x - 5) % 3 => (x - 2) % 3 if x - 5 >= 0
auto
cbound
=
parent_
->
const_int_bound
(
Normalize
(
a
));
int64_t
new_base
=
psum
->
base
%
cval
;
if
(
cbound
->
min_value
>=
0
&&
cbound
->
min_value
-
psum
->
base
+
new_base
>=
0
)
{
SumExpr
sum_expr
(
std
::
move
(
a
.
node_
));
sum_expr
.
CopyOnWrite
()
->
base
=
new_base
;
return
SplitModConst
(
ToSplitExpr
(
std
::
move
(
sum_expr
)),
cval
);
}
}
else
{
// if a >= 0 && a < cval, then result == 0
auto
cbound
=
parent_
->
const_int_bound
(
Normalize
(
a
));
...
...
src/arithmetic/const_int_bound.cc
View file @
48c92376
...
...
@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...
...
@@ -25,6 +25,7 @@
#include <tvm/ir_functor_ext.h>
#include <algorithm>
#include "int_op_overflow.h"
#include "pattern_match.h"
namespace
tvm
{
namespace
arith
{
...
...
@@ -65,6 +66,19 @@ struct ConstIntBoundAnalyzer::Entry {
class
ConstIntBoundAnalyzer
::
Impl
:
public
ExprFunctor
<
ConstIntBoundAnalyzer
::
Entry
(
const
Expr
&
)
>
{
public
:
/*! \brief additional bound info about expr \in bound */
struct
BoundInfo
{
/*! \brief The expr */
Expr
expr
;
/*! \brief The additional bound */
Entry
bound
;
BoundInfo
()
{}
BoundInfo
(
Expr
expr
,
Entry
bound
)
:
expr
(
expr
),
bound
(
bound
)
{
}
};
void
Bind
(
const
Var
&
var
,
const
Range
&
range
)
{
Entry
a
=
VisitExpr
(
range
->
min
);
Entry
b
=
VisitExpr
(
range
->
extent
);
...
...
@@ -99,6 +113,18 @@ class ConstIntBoundAnalyzer::Impl :
static_cast
<
const
ir
::
BaseExprNode
*>
(
op
)
->
type
);
}
Entry
VisitExpr
(
const
Expr
&
expr
)
final
{
Entry
res
=
ExprFunctor
::
VisitExpr
(
expr
);
// a linear search over additional info
// assume we won't have a lot of conditions
for
(
const
BoundInfo
&
info
:
additional_info_
)
{
if
(
ir
::
Equal
(
expr
,
info
.
expr
))
{
res
=
Intersect
(
res
,
info
.
bound
);
}
}
return
res
;
}
Entry
VisitExpr_
(
const
Cast
*
op
)
final
{
Entry
a
=
VisitExpr
(
op
->
value
);
Entry
b
=
Everything
(
op
->
type
);
...
...
@@ -243,9 +269,24 @@ class ConstIntBoundAnalyzer::Impl :
}
}
std
::
function
<
void
()
>
EnterConstraint
(
const
Expr
&
constraint
)
{
std
::
vector
<
BoundInfo
>
info
=
DetectBoundInfo
(
constraint
);
if
(
info
.
size
()
==
0
)
return
nullptr
;
size_t
old_size
=
additional_info_
.
size
();
additional_info_
.
insert
(
additional_info_
.
end
(),
info
.
begin
(),
info
.
end
());
size_t
new_size
=
old_size
+
info
.
size
();
auto
frecover
=
[
old_size
,
new_size
,
this
]()
{
CHECK_EQ
(
additional_info_
.
size
(),
new_size
);
additional_info_
.
resize
(
old_size
);
};
return
frecover
;
}
private
:
// internal variable map
std
::
unordered_map
<
Var
,
Entry
,
ExprHash
,
ExprEqual
>
var_map_
;
// additional bound info
std
::
vector
<
BoundInfo
>
additional_info_
;
// constants: the limit value means umlimited
// NOTE: kNegInf/kPosInf are used to represent infinity.
static
const
constexpr
int64_t
kNegInf
=
ConstIntBoundNode
::
kNegInf
;
...
...
@@ -387,6 +428,36 @@ class ConstIntBoundAnalyzer::Impl :
}
return
ret
;
}
/*!
* \brief Detect additional constant bound from cond, if any
* \param cond The constraint condition.
* \return List of detected bounds.
*/
static
std
::
vector
<
BoundInfo
>
DetectBoundInfo
(
const
Expr
&
cond
)
{
PVar
<
Expr
>
x
,
y
;
PVar
<
Integer
>
c
;
// NOTE: canonical form always use <= or <
if
((
c
<=
x
).
Match
(
cond
))
{
return
{
BoundInfo
(
x
.
Eval
(),
MakeBound
(
c
.
Eval
()
->
value
,
kPosInf
))};
}
if
((
c
<
x
).
Match
(
cond
))
{
return
{
BoundInfo
(
x
.
Eval
(),
MakeBound
(
c
.
Eval
()
->
value
+
1
,
kPosInf
))};
}
if
((
x
<=
c
).
Match
(
cond
))
{
return
{
BoundInfo
(
x
.
Eval
(),
MakeBound
(
kNegInf
,
c
.
Eval
()
->
value
))};
}
if
((
x
<
c
).
Match
(
cond
))
{
return
{
BoundInfo
(
x
.
Eval
(),
MakeBound
(
kNegInf
,
c
.
Eval
()
->
value
-
1
))};
}
if
((
x
&&
y
).
Match
(
cond
))
{
auto
ret1
=
DetectBoundInfo
(
x
.
Eval
());
auto
ret2
=
DetectBoundInfo
(
y
.
Eval
());
ret1
.
insert
(
ret1
.
end
(),
ret2
.
begin
(),
ret2
.
end
());
return
ret1
;
}
return
{};
}
};
ConstIntBound
ConstIntBoundAnalyzer
::
operator
()(
const
Expr
&
expr
)
{
...
...
@@ -405,7 +476,7 @@ void ConstIntBoundAnalyzer::Bind(const Var& var, const Range& range) {
}
std
::
function
<
void
()
>
ConstIntBoundAnalyzer
::
EnterConstraint
(
const
Expr
&
constraint
)
{
return
nullptr
;
return
impl_
->
EnterConstraint
(
constraint
)
;
}
ConstIntBoundAnalyzer
::
ConstIntBoundAnalyzer
(
Analyzer
*
parent
)
...
...
src/arithmetic/modular_set.cc
View file @
48c92376
...
...
@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...
...
src/arithmetic/rewrite_simplify.cc
View file @
48c92376
...
...
@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...
...
@@ -1197,14 +1197,32 @@ Mutate_(const Or* op, const Expr& self) {
Expr
RewriteSimplifier
::
Impl
::
Mutate_
(
const
Select
*
op
,
const
Expr
&
self
)
{
Expr
ret
=
IRMutator
::
Mutate_
(
op
,
self
);
op
=
ret
.
as
<
Select
>
();
if
(
is_zero
(
op
->
condition
))
{
return
op
->
false_value
;
Expr
cond
=
Mutate
(
op
->
condition
);
Expr
true_value
,
false_value
;
{
ConstraintContext
constraint
(
parent_
,
cond
);
true_value
=
Mutate
(
op
->
true_value
);
}
{
ConstraintContext
constraint
(
parent_
,
Mutate
(
Not
::
make
(
cond
)));
false_value
=
Mutate
(
op
->
false_value
);
}
if
(
is_zero
(
cond
))
{
return
false_value
;
}
if
(
is_one
(
op
->
condition
))
{
return
op
->
true_value
;
if
(
is_one
(
cond
))
{
return
true_value
;
}
// normal path
Expr
ret
;
if
(
cond
.
same_as
(
op
->
condition
)
&&
true_value
.
same_as
(
op
->
true_value
)
&&
false_value
.
same_as
(
op
->
false_value
))
{
ret
=
self
;
}
else
{
ret
=
Select
::
make
(
cond
,
true_value
,
false_value
);
}
op
=
ret
.
as
<
Select
>
();
// Pattern var to match any expression
PVar
<
Expr
>
x
,
y
;
TVM_TRY_REWRITE
(
select
(
x
,
y
,
y
),
y
);
...
...
@@ -1213,7 +1231,37 @@ Mutate_(const Select* op, const Expr& self) {
Expr
RewriteSimplifier
::
Impl
::
Mutate_
(
const
Call
*
op
,
const
Expr
&
self
)
{
Expr
ret
=
IRMutator
::
Mutate_
(
op
,
self
);
// add condition context to if_then_else
Expr
ret
;
if
(
op
->
is_intrinsic
(
ir
::
intrinsic
::
tvm_if_then_else
))
{
Expr
cond
=
Mutate
(
op
->
args
[
0
]);
Expr
true_value
,
false_value
;
{
ConstraintContext
constraint
(
parent_
,
cond
);
true_value
=
Mutate
(
op
->
args
[
1
]);
}
{
ConstraintContext
constraint
(
parent_
,
Mutate
(
Not
::
make
(
cond
)));
false_value
=
Mutate
(
op
->
args
[
2
]);
}
if
(
is_zero
(
cond
))
{
return
false_value
;
}
if
(
is_one
(
cond
))
{
return
true_value
;
}
if
(
cond
.
same_as
(
op
->
args
[
0
])
&&
true_value
.
same_as
(
op
->
args
[
1
])
&&
false_value
.
same_as
(
op
->
args
[
2
]))
{
ret
=
self
;
}
else
{
ret
=
Call
::
make
(
op
->
type
,
op
->
name
,
{
cond
,
true_value
,
false_value
},
op
->
call_type
);
}
}
else
{
ret
=
IRMutator
::
Mutate_
(
op
,
self
);
}
op
=
ret
.
as
<
Call
>
();
if
(
op
->
is_intrinsic
(
Call
::
likely
)
&&
is_const
(
op
->
args
[
0
]))
{
return
op
->
args
[
0
];
...
...
tests/python/unittest/test_arith_canonical_simplify.py
View file @
48c92376
...
...
@@ -22,7 +22,7 @@ class CanonicalChecker:
def
verify
(
self
,
data
,
expected
):
res
=
self
.
analyzer
.
canonical_simplify
(
data
)
assert
tvm
.
ir_pass
.
Equal
(
res
,
expected
),
"
data={}, res={},
expected={}"
.
format
(
data
,
res
,
expected
)
assert
tvm
.
ir_pass
.
Equal
(
res
,
expected
),
"
\n
data={}
\n
res={}
\n
expected={}"
.
format
(
data
,
res
,
expected
)
def
test_mul_sum_simplify
():
...
...
@@ -157,7 +157,38 @@ def test_reduce_simplify():
ck
.
verify
(
tvm
.
sum
(
k
/
10
,
k
),
tvm
.
sum
(
tvm
.
const
(
0
,
"int32"
),
k
))
def
test_simplify_if_then_else
():
ck
=
CanonicalChecker
()
x
=
tvm
.
var
(
"x"
)
y
=
tvm
.
var
(
"y"
)
# simplification that takes condition into account.
res
=
tvm
.
if_then_else
((
x
*
4
+
y
)
>=
466036
,
tvm
.
if_then_else
(
24512
<=
((((
x
*
4
)
+
y
)
-
466036
)
%
24528
),
(((((
x
*
4
)
+
y
)
-
466036
)
%
24528
)
-
24512
)
%
16
,
x
),
y
)
expected
=
tvm
.
if_then_else
(
tvm
.
expr
.
LE
(
466036
,
(
x
*
4
+
y
)),
tvm
.
if_then_else
(
tvm
.
expr
.
LE
(
24512
,
((((
x
*
4
)
+
y
)
-
4
)
%
24528
)),
(((
x
*
4
)
+
y
)
-
4
)
%
16
,
x
),
y
)
ck
.
verify
(
res
,
expected
)
# can only simplify if condition
res
=
tvm
.
expr
.
Select
(
tvm
.
all
(
x
>=
-
1
,
y
>=
0
),
(
x
+
y
+
100
)
%
3
,
(
x
+
100
)
%
3
)
expected
=
tvm
.
expr
.
Select
(
tvm
.
all
(
x
>=
-
1
,
y
>=
0
),
(
x
+
y
+
1
)
%
3
,
(
x
+
100
)
%
3
)
ck
.
verify
(
res
,
ck
.
analyzer
.
canonical_simplify
(
expected
))
res
=
tvm
.
expr
.
Select
(
x
>=
10
,
tvm
.
if_then_else
(
x
/
3
>
2
,
x
,
0
),
0
)
expected
=
tvm
.
expr
.
Select
(
x
>=
10
,
x
,
0
)
ck
.
verify
(
res
,
ck
.
analyzer
.
canonical_simplify
(
expected
))
res
=
tvm
.
expr
.
Select
(
x
>=
10
,
tvm
.
if_then_else
(
x
/
3
<
2
,
x
,
0
),
0
)
ck
.
verify
(
res
,
0
)
if
__name__
==
"__main__"
:
test_simplify_if_then_else
()
test_div_simplify
()
test_reduce_simplify
()
test_reduce_combiner_simplify
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment