Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
4273e461
Unverified
Commit
4273e461
authored
Jul 01, 2019
by
Tianqi Chen
Committed by
GitHub
Jul 01, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Migrate simplifier to new infra. (#3368)
parent
f2a6851a
Hide whitespace changes
Inline
Side-by-side
Showing
19 changed files
with
175 additions
and
154 deletions
+175
-154
CMakeLists.txt
+5
-1
include/tvm/arithmetic.h
+6
-3
include/tvm/ir_pass.h
+0
-1
src/arithmetic/analyzer.cc
+1
-0
src/arithmetic/bound_deducer.cc
+84
-61
src/arithmetic/const_fold.h
+4
-3
src/arithmetic/rewrite_simplify.cc
+0
-1
src/arithmetic/stmt_simplify.cc
+8
-33
src/lang/buffer.cc
+3
-2
src/op/scan_op.cc
+3
-3
src/pass/loop_partition.cc
+13
-8
src/pass/narrow_channel_access.cc
+4
-3
src/pass/storage_rewrite.cc
+5
-3
src/pass/vectorize_loop.cc
+6
-3
src/schedule/message_passing.cc
+9
-7
src/schedule/schedule_dataflow_rewrite.cc
+1
-1
tests/cpp/ir_simplify_test.cc
+2
-8
tests/python/unittest/test_arith_deduce_bound.py
+21
-10
tests/python/unittest/test_pass_basic.py
+0
-3
No files found.
CMakeLists.txt
View file @
4273e461
...
@@ -154,7 +154,11 @@ file(GLOB_RECURSE NNVM_COMPILER_SRCS
...
@@ -154,7 +154,11 @@ file(GLOB_RECURSE NNVM_COMPILER_SRCS
file
(
GLOB TOPI_SRCS
file
(
GLOB TOPI_SRCS
topi/src/*.cc
topi/src/*.cc
)
)
file
(
GLOB_RECURSE HALIDEIR_SRCS 3rdparty/HalideIR/src/*.cpp
)
file
(
GLOB_RECURSE HALIDEIR_SRCS
3rdparty/HalideIR/src/base/*.cpp
3rdparty/HalideIR/src/ir/*.cpp
3rdparty/HalideIR/src/tvm/*.cpp
)
list
(
APPEND COMPILER_SRCS
${
HALIDEIR_SRCS
}
)
list
(
APPEND COMPILER_SRCS
${
HALIDEIR_SRCS
}
)
file
(
GLOB RUNTIME_SRCS
file
(
GLOB RUNTIME_SRCS
src/runtime/*.cc
src/runtime/*.cc
...
...
include/tvm/arithmetic.h
View file @
4273e461
...
@@ -623,12 +623,15 @@ IntSet Intersect(const Array<IntSet>& sets);
...
@@ -623,12 +623,15 @@ IntSet Intersect(const Array<IntSet>& sets);
* give the domain of each variables. Return undefined IntSet to
* give the domain of each variables. Return undefined IntSet to
* represent failure.
* represent failure.
*
*
* \note The returned set may be smaller than set that
* contains all possible values of v that satisfies the bound.
*
* \param v The target variable to be deduced.
* \param v The target variable to be deduced.
* \param cond The conditional expression.
* \param cond The conditional expression.
* \param hint_map The domain of variable, used to help deduce.
* \param hint_map The domain of variable, used to help deduce.
* \param relax_map The domain of each variable, used to relax the domain,
* \param relax_map The domain of each variable, used to relax the domain,
* The deduce bound mus
h
implies e for all value in relax_map
* The deduce bound mus
t
implies e for all value in relax_map
* \return An integer set that
can cover all the possible values
.
* \return An integer set that
always satisfies the condition
.
*/
*/
IntSet
DeduceBound
(
Expr
v
,
Expr
cond
,
IntSet
DeduceBound
(
Expr
v
,
Expr
cond
,
const
Map
<
Var
,
IntSet
>&
hint_map
,
const
Map
<
Var
,
IntSet
>&
hint_map
,
...
@@ -641,7 +644,7 @@ IntSet DeduceBound(Expr v, Expr cond,
...
@@ -641,7 +644,7 @@ IntSet DeduceBound(Expr v, Expr cond,
* \param hint_map The domain of variable, used to help deduce.
* \param hint_map The domain of variable, used to help deduce.
* \param relax_map The domain of each variable, used to relax the domain,
* \param relax_map The domain of each variable, used to relax the domain,
* The deduce bound mush implies e for all value in relax_map
* The deduce bound mush implies e for all value in relax_map
* \return An integer set that
can cover all the possible values
.
* \return An integer set that
always satisfies the condition
.
*/
*/
IntSet
DeduceBound
(
Expr
v
,
Expr
cond
,
IntSet
DeduceBound
(
Expr
v
,
Expr
cond
,
const
std
::
unordered_map
<
const
Variable
*
,
IntSet
>&
hint_map
,
const
std
::
unordered_map
<
const
Variable
*
,
IntSet
>&
hint_map
,
...
...
include/tvm/ir_pass.h
View file @
4273e461
...
@@ -27,7 +27,6 @@
...
@@ -27,7 +27,6 @@
#ifndef TVM_IR_PASS_H_
#ifndef TVM_IR_PASS_H_
#define TVM_IR_PASS_H_
#define TVM_IR_PASS_H_
#include <arithmetic/Simplify.h>
#include <unordered_map>
#include <unordered_map>
#include <unordered_set>
#include <unordered_set>
#include <vector>
#include <vector>
...
...
src/arithmetic/analyzer.cc
View file @
4273e461
...
@@ -106,6 +106,7 @@ bool Analyzer::CanProve(const Expr& expr) {
...
@@ -106,6 +106,7 @@ bool Analyzer::CanProve(const Expr& expr) {
Expr
Analyzer
::
Simplify
(
const
Expr
&
expr
)
{
Expr
Analyzer
::
Simplify
(
const
Expr
&
expr
)
{
if
(
is_const
(
expr
))
return
expr
;
if
(
is_const
(
expr
))
return
expr
;
auto
res
=
this
->
rewrite_simplify
(
expr
);
auto
res
=
this
->
rewrite_simplify
(
expr
);
if
(
is_const
(
res
))
return
res
;
res
=
this
->
canonical_simplify
(
res
);
res
=
this
->
canonical_simplify
(
res
);
return
res
;
return
res
;
}
}
...
...
src/arithmetic/bound_deducer.cc
View file @
4273e461
...
@@ -84,11 +84,11 @@ class BoundDeducer: public IRVisitor {
...
@@ -84,11 +84,11 @@ class BoundDeducer: public IRVisitor {
void
Deduce
();
void
Deduce
();
void
Visit
(
const
NodeRef
&
e
)
final
{
void
Visit
(
const
NodeRef
&
e
)
final
{
if
(
!
success
)
return
;
if
(
!
success
_
)
return
;
if
(
e
.
get
()
==
path_
[
iter_
++
])
{
if
(
e
.
get
()
==
path_
[
iter_
++
])
{
IRVisitor
::
Visit
(
e
);
IRVisitor
::
Visit
(
e
);
}
else
{
}
else
{
success
=
false
;
success
_
=
false
;
return
;
return
;
}
}
}
}
...
@@ -111,18 +111,18 @@ class BoundDeducer: public IRVisitor {
...
@@ -111,18 +111,18 @@ class BoundDeducer: public IRVisitor {
void
Visit_
(
const
Add
*
op
)
final
{
void
Visit_
(
const
Add
*
op
)
final
{
bool
left
=
op
->
a
.
get
()
==
path_
[
iter_
];
bool
left
=
op
->
a
.
get
()
==
path_
[
iter_
];
result
-=
left
?
op
->
b
:
op
->
a
;
result
_
-=
left
?
op
->
b
:
op
->
a
;
Visit
(
left
?
op
->
a
:
op
->
b
);
Visit
(
left
?
op
->
a
:
op
->
b
);
}
}
void
Visit_
(
const
Sub
*
op
)
final
{
void
Visit_
(
const
Sub
*
op
)
final
{
bool
left
=
op
->
a
.
get
()
==
path_
[
iter_
];
bool
left
=
op
->
a
.
get
()
==
path_
[
iter_
];
if
(
left
)
{
if
(
left
)
{
result
+=
op
->
b
;
result
_
+=
op
->
b
;
}
else
{
}
else
{
result
-=
op
->
a
;
result
_
-=
op
->
a
;
result
=
-
result
;
result
_
=
-
result_
;
is_greater
=
!
is_greater
;
is_greater
_
=
!
is_greater_
;
}
}
Visit
(
left
?
op
->
a
:
op
->
b
);
Visit
(
left
?
op
->
a
:
op
->
b
);
}
}
...
@@ -130,43 +130,65 @@ class BoundDeducer: public IRVisitor {
...
@@ -130,43 +130,65 @@ class BoundDeducer: public IRVisitor {
void
Visit_
(
const
Mul
*
op
)
final
{
void
Visit_
(
const
Mul
*
op
)
final
{
bool
left
=
op
->
a
.
get
()
==
path_
[
iter_
];
bool
left
=
op
->
a
.
get
()
==
path_
[
iter_
];
Expr
operand
=
left
?
op
->
b
:
op
->
a
;
Expr
operand
=
left
?
op
->
b
:
op
->
a
;
Expr
target_var
=
left
?
op
->
a
:
op
->
b
;
SignType
sign
;
SignType
sign
_operand
;
if
(
operand
.
type
().
is_uint
())
{
if
(
operand
.
type
().
is_uint
())
{
sign
=
kPositive
;
sign
_operand
=
kPositive
;
}
else
{
}
else
{
sign
=
expr_map_
[
operand
].
sign_type
();
sign
_operand
=
expr_map_
[
operand
].
sign_type
();
}
}
if
(
sign
==
SignType
::
kNegative
)
{
if
(
sign
_operand
==
SignType
::
kNegative
)
{
is_greater
=
!
is_greater
;
is_greater
_
=
!
is_greater_
;
}
else
if
(
sign
==
SignType
::
kUnknown
)
{
}
else
if
(
sign
_operand
==
SignType
::
kUnknown
)
{
// unable to get the sign of operand
// unable to get the sign of operand
success
=
false
;
success
_
=
false
;
return
;
return
;
}
}
// always use relax bound
// always use relax bound
bool
divided
=
can_prove
(
result
%
operand
==
0
);
bool
divided
=
analyzer_
.
CanProve
(
result_
%
operand
==
0
);
result
=
result
/
operand
;
// since system will round down when not divided
result_
=
result_
/
operand
;
// eg. 2/4 -> 0; -2/4 -> -1
// no need fix for !is_greater:
if
(
!
divided
)
{
// eg. a <= 2/4 -> a <= 0
// Handle non-divisible case
// eg. a <= 0/4 -> a <= 0
// NOTE: this accounts for truc div behavior.
// so just fix for not divided and is_greater
bool
target_is_non_neg
=
expr_map_
[
target_var
].
can_prove_non_negative
();
// eg. a >= 2/4 -> a >= 0 + 1
// eg. a >= 0/4 -> a >= 0
if
(
is_greater_
)
{
if
(
is_greater
&&
!
divided
)
{
result_
+=
1
;
result
+=
1
;
}
else
{
// NOTE: this is a bit sutble hack.
//
// condition:
// - x * operand <= result
// - operand > 0
// - x >= 0
//
// Then it is fine to deduce that x <= result / operand.
// - if result > 0, this division round down
// - if result < 0, (result / operand) rounds up and may violate the constraint
// however, given that x is always non-negative,
// it is fine to have this relaxed bound, given that the user of deduce bound
// will respect the bound of x
//
// TODO(tvm-team): think about a better API to incorporate constraint of x.
// e.g. specify an interval of x and return a bound
// that is in the interval and satisfies the condition.
if
(
target_is_non_neg
&&
sign_operand
==
kPositive
)
{
// do nothing
}
else
{
result_
-=
1
;
}
}
}
}
Visit
(
left
?
op
->
a
:
op
->
b
);
Visit
(
left
?
op
->
a
:
op
->
b
);
}
}
Expr
result
;
Expr
result
_
;
bool
is_greater
{
true
};
bool
is_greater
_
{
true
};
bool
success
{
true
};
bool
success
_
{
true
};
private
:
private
:
void
Init
();
void
Init
();
...
@@ -180,6 +202,8 @@ class BoundDeducer: public IRVisitor {
...
@@ -180,6 +202,8 @@ class BoundDeducer: public IRVisitor {
ExprIntSetMap
expr_map_
;
ExprIntSetMap
expr_map_
;
std
::
vector
<
const
Node
*>
path_
;
std
::
vector
<
const
Node
*>
path_
;
size_t
iter_
{
0
};
size_t
iter_
{
0
};
// internal analzyer
Analyzer
analyzer_
;
};
};
class
BoundDeduceInputChecker
:
public
IRVisitor
{
class
BoundDeduceInputChecker
:
public
IRVisitor
{
...
@@ -202,7 +226,7 @@ class BoundDeduceInputChecker: public IRVisitor {
...
@@ -202,7 +226,7 @@ class BoundDeduceInputChecker: public IRVisitor {
void
BoundDeducer
::
Init
()
{
void
BoundDeducer
::
Init
()
{
BoundDeduceInputChecker
checker
;
BoundDeduceInputChecker
checker
;
if
(
!
checker
.
Check
(
this
))
success
=
false
;
if
(
!
checker
.
Check
(
this
))
success
_
=
false
;
Transform
();
Transform
();
}
}
...
@@ -211,66 +235,65 @@ void BoundDeducer::Transform() {
...
@@ -211,66 +235,65 @@ void BoundDeducer::Transform() {
if
(
const
LT
*
op
=
expr_
.
as
<
LT
>
())
{
if
(
const
LT
*
op
=
expr_
.
as
<
LT
>
())
{
if
(
GetPath
(
target_
,
op
->
a
).
empty
())
{
if
(
GetPath
(
target_
,
op
->
a
).
empty
())
{
// a < b -> b >= a + 1
// a < b -> b >= a + 1
is_greater
=
true
;
is_greater
_
=
true
;
expr_
=
op
->
b
;
expr_
=
op
->
b
;
result
=
op
->
a
+
1
;
result
_
=
op
->
a
+
1
;
}
else
{
}
else
{
// a < b -> a <= b - 1
// a < b -> a <= b - 1
is_greater
=
false
;
is_greater
_
=
false
;
expr_
=
op
->
a
;
expr_
=
op
->
a
;
result
=
op
->
b
-
1
;
result
_
=
op
->
b
-
1
;
}
}
}
else
if
(
const
LE
*
op
=
expr_
.
as
<
LE
>
())
{
}
else
if
(
const
LE
*
op
=
expr_
.
as
<
LE
>
())
{
if
(
GetPath
(
target_
,
op
->
a
).
empty
())
{
if
(
GetPath
(
target_
,
op
->
a
).
empty
())
{
// a <= b -> b >= a
// a <= b -> b >= a
is_greater
=
true
;
is_greater
_
=
true
;
expr_
=
op
->
b
;
expr_
=
op
->
b
;
result
=
op
->
a
;
result
_
=
op
->
a
;
}
else
{
}
else
{
is_greater
=
false
;
is_greater
_
=
false
;
expr_
=
op
->
a
;
expr_
=
op
->
a
;
result
=
op
->
b
;
result
_
=
op
->
b
;
}
}
}
else
if
(
const
GT
*
op
=
expr_
.
as
<
GT
>
())
{
}
else
if
(
const
GT
*
op
=
expr_
.
as
<
GT
>
())
{
if
(
GetPath
(
target_
,
op
->
a
).
empty
())
{
if
(
GetPath
(
target_
,
op
->
a
).
empty
())
{
// a > b -> b <= a - 1
// a > b -> b <= a - 1
is_greater
=
false
;
is_greater
_
=
false
;
expr_
=
op
->
b
;
expr_
=
op
->
b
;
result
=
op
->
a
-
1
;
result
_
=
op
->
a
-
1
;
}
else
{
}
else
{
// a > b -> a >= b + 1
// a > b -> a >= b + 1
is_greater
=
true
;
is_greater
_
=
true
;
expr_
=
op
->
a
;
expr_
=
op
->
a
;
result
=
op
->
b
+
1
;
result
_
=
op
->
b
+
1
;
}
}
}
else
if
(
const
GE
*
op
=
expr_
.
as
<
GE
>
())
{
}
else
if
(
const
GE
*
op
=
expr_
.
as
<
GE
>
())
{
if
(
GetPath
(
target_
,
op
->
a
).
empty
())
{
if
(
GetPath
(
target_
,
op
->
a
).
empty
())
{
// a >= b -> b <= a
// a >= b -> b <= a
is_greater
=
false
;
is_greater
_
=
false
;
expr_
=
op
->
b
;
expr_
=
op
->
b
;
result
=
op
->
a
;
result
_
=
op
->
a
;
}
else
{
}
else
{
is_greater
=
true
;
is_greater
_
=
true
;
expr_
=
op
->
a
;
expr_
=
op
->
a
;
result
=
op
->
b
;
result
_
=
op
->
b
;
}
}
}
else
{
}
else
{
success
=
false
;
success
_
=
false
;
}
}
}
}
void
BoundDeducer
::
Deduce
()
{
void
BoundDeducer
::
Deduce
()
{
Init
();
Init
();
if
(
!
success
)
return
;
if
(
!
success
_
)
return
;
Relax
();
Relax
();
if
(
!
success
)
return
;
if
(
!
success
_
)
return
;
// get the path
// get the path
path_
=
GetPath
(
target_
,
expr_
);
path_
=
GetPath
(
target_
,
expr_
);
if
(
!
path_
.
size
())
{
if
(
!
path_
.
size
())
{
success
=
false
;
success
_
=
false
;
return
;
return
;
}
}
expr_map_
=
EvalSetForEachSubExpr
(
expr_
,
hint_map_
);
expr_map_
=
EvalSetForEachSubExpr
(
expr_
,
hint_map_
);
Visit
(
expr_
);
Visit
(
expr_
);
...
@@ -278,13 +301,13 @@ void BoundDeducer::Deduce() {
...
@@ -278,13 +301,13 @@ void BoundDeducer::Deduce() {
void
BoundDeducer
::
Relax
()
{
void
BoundDeducer
::
Relax
()
{
IntSet
a
=
EvalSet
(
expr_
,
relax_map_
);
IntSet
a
=
EvalSet
(
expr_
,
relax_map_
);
IntSet
b
=
EvalSet
(
result
,
relax_map_
);
IntSet
b
=
EvalSet
(
result
_
,
relax_map_
);
if
(
a
.
is_everything
()
||
b
.
is_everything
())
{
if
(
a
.
is_everything
()
||
b
.
is_everything
())
{
success
=
false
;
success
_
=
false
;
return
;
return
;
}
}
expr_
=
is_greater
?
a
.
min
()
:
a
.
max
();
expr_
=
is_greater
_
?
a
.
min
()
:
a
.
max
();
result
=
is_greater
?
b
.
max
()
:
b
.
min
();
result
_
=
is_greater_
?
b
.
max
()
:
b
.
min
();
}
}
IntSet
DeduceBound
(
Expr
v
,
Expr
e
,
IntSet
DeduceBound
(
Expr
v
,
Expr
e
,
...
@@ -292,12 +315,12 @@ IntSet DeduceBound(Expr v, Expr e,
...
@@ -292,12 +315,12 @@ IntSet DeduceBound(Expr v, Expr e,
const
std
::
unordered_map
<
const
Variable
*
,
IntSet
>&
relax_map
)
{
const
std
::
unordered_map
<
const
Variable
*
,
IntSet
>&
relax_map
)
{
BoundDeducer
d
(
v
,
e
,
hint_map
,
relax_map
);
BoundDeducer
d
(
v
,
e
,
hint_map
,
relax_map
);
d
.
Deduce
();
d
.
Deduce
();
if
(
!
d
.
success
)
return
IntSet
::
nothing
();
if
(
!
d
.
success
_
)
return
IntSet
::
nothing
();
Expr
min
=
neg_inf
(),
max
=
pos_inf
();
Expr
min
=
neg_inf
(),
max
=
pos_inf
();
if
(
d
.
is_greater
)
{
if
(
d
.
is_greater
_
)
{
min
=
d
.
result
;
min
=
d
.
result
_
;
}
else
{
}
else
{
max
=
d
.
result
;
max
=
d
.
result
_
;
}
}
return
IntSet
::
interval
(
min
,
max
);
return
IntSet
::
interval
(
min
,
max
);
}
}
...
...
src/arithmetic/const_fold.h
View file @
4273e461
...
@@ -155,9 +155,10 @@ template<>
...
@@ -155,9 +155,10 @@ template<>
inline
Expr
TryConstFold
<
ir
::
Div
>
(
Expr
a
,
Expr
b
)
{
inline
Expr
TryConstFold
<
ir
::
Div
>
(
Expr
a
,
Expr
b
)
{
TVM_ARITH_CONST_PROPAGATION
({
TVM_ARITH_CONST_PROPAGATION
({
const
Type
&
rtype
=
a
.
type
();
const
Type
&
rtype
=
a
.
type
();
// due to division and mod can have different modes
if
(
pa
&&
pb
)
{
// only constant fold positive number where rule is fixed.
// due to division and mod can have different modes
if
(
pa
&&
pb
&&
pa
->
value
>=
0
&&
pb
->
value
>
0
)
{
// NOTE: this will assumes truc div.
CHECK_NE
(
pb
->
value
,
0
)
<<
"Divide by zero"
;
return
IntImm
::
make
(
rtype
,
pa
->
value
/
pb
->
value
);
return
IntImm
::
make
(
rtype
,
pa
->
value
/
pb
->
value
);
}
}
if
(
pa
)
{
if
(
pa
)
{
...
...
src/arithmetic/rewrite_simplify.cc
View file @
4273e461
...
@@ -155,7 +155,6 @@ Mutate_(const Add* op, const Expr& self) {
...
@@ -155,7 +155,6 @@ Mutate_(const Add* op, const Expr& self) {
TVM_TRY_REWRITE
(
max
(
x
,
y
-
z
)
+
z
,
max
(
x
+
z
,
y
));
TVM_TRY_REWRITE
(
max
(
x
,
y
-
z
)
+
z
,
max
(
x
+
z
,
y
));
TVM_TRY_REWRITE
(
max
(
x
-
z
,
y
)
+
z
,
max
(
x
,
y
+
z
));
TVM_TRY_REWRITE
(
max
(
x
-
z
,
y
)
+
z
,
max
(
x
,
y
+
z
));
TVM_TRY_REWRITE_IF
(
min
(
x
,
y
+
z
*
c1
)
+
z
*
c2
,
min
(
x
+
z
*
c2
,
y
),
TVM_TRY_REWRITE_IF
(
min
(
x
,
y
+
z
*
c1
)
+
z
*
c2
,
min
(
x
+
z
*
c2
,
y
),
c1
.
Eval
()
->
value
==
-
c2
.
Eval
()
->
value
);
c1
.
Eval
()
->
value
==
-
c2
.
Eval
()
->
value
);
TVM_TRY_REWRITE_IF
(
max
(
x
,
y
+
z
*
c1
)
+
z
*
c2
,
max
(
x
+
z
*
c2
,
y
),
TVM_TRY_REWRITE_IF
(
max
(
x
,
y
+
z
*
c1
)
+
z
*
c2
,
max
(
x
+
z
*
c2
,
y
),
...
...
src/arithmetic/stmt_simplify.cc
View file @
4273e461
...
@@ -28,7 +28,6 @@
...
@@ -28,7 +28,6 @@
#include <tvm/ir_mutator.h>
#include <tvm/ir_mutator.h>
#include <tvm/expr_operator.h>
#include <tvm/expr_operator.h>
#include <tvm/arithmetic.h>
#include <tvm/arithmetic.h>
#include "arithmetic/Simplify.h"
namespace
tvm
{
namespace
tvm
{
namespace
arith
{
namespace
arith
{
...
@@ -158,42 +157,18 @@ Expr CanonicalSimplify(Expr expr, Map<Var, Range> vrange) {
...
@@ -158,42 +157,18 @@ Expr CanonicalSimplify(Expr expr, Map<Var, Range> vrange) {
return
analyzer
.
canonical_simplify
(
expr
);
return
analyzer
.
canonical_simplify
(
expr
);
}
}
template
<
typename
T
>
Expr
Simplify
(
Expr
expr
,
Map
<
Var
,
Range
>
vrange
)
{
T
Simplify_
(
T
a
,
Map
<
Var
,
Range
>
vrange
)
{
arith
::
Analyzer
analyzer
;
using
namespace
HalideIR
::
Internal
;
Scope
<
Interval
>
rscope
;
for
(
auto
kv
:
vrange
)
{
for
(
auto
kv
:
vrange
)
{
Range
r
=
kv
.
second
;
analyzer
.
Bind
(
kv
.
first
,
kv
.
second
);
rscope
.
push
(
kv
.
first
.
get
(),
Interval
(
r
->
min
,
simplify
(
r
->
min
+
r
->
extent
-
make_const
(
r
->
min
.
type
(),
1
))));
}
return
HalideIR
::
Internal
::
simplify
(
a
,
true
,
rscope
);
}
Expr
Simplify
(
Expr
a
,
Map
<
Var
,
Range
>
vrange
)
{
// Simplify top level reduce.
if
(
const
Reduce
*
r
=
a
.
as
<
Reduce
>
())
{
Array
<
Expr
>
new_source
;
for
(
auto
&
e
:
r
->
source
)
{
new_source
.
push_back
(
Simplify_
(
e
,
vrange
));
}
Expr
new_condition
=
Simplify_
(
r
->
condition
,
vrange
);
if
(
r
->
source
.
same_as
(
new_source
)
&&
r
->
condition
.
same_as
(
new_condition
))
{
return
a
;
}
else
{
return
Reduce
::
make
(
r
->
combiner
,
new_source
,
r
->
axis
,
new_condition
,
r
->
value_index
);
}
}
}
return
Simplify_
(
a
,
vrange
);
expr
=
analyzer
.
Simplify
(
expr
);
return
expr
;
}
}
Stmt
Simplify
(
Stmt
a
,
Map
<
Var
,
Range
>
vrange
)
{
Stmt
Simplify
(
Stmt
stmt
,
Map
<
Var
,
Range
>
vrange
)
{
return
Simplify_
(
a
,
vrange
);
return
arith
::
CanonicalStmtSimplifier
().
CanonicalSimplify
(
stmt
,
vrange
);
}
}
}
// namespace ir
}
// namespace ir
}
// namespace tvm
}
// namespace tvm
src/lang/buffer.cc
View file @
4273e461
...
@@ -6,9 +6,9 @@
...
@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...
@@ -26,6 +26,7 @@
...
@@ -26,6 +26,7 @@
#include <tvm/ir.h>
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_pass.h>
#include <iterator>
#include <iterator>
#include <stack>
#include "../arithmetic/compute_expr.h"
#include "../arithmetic/compute_expr.h"
namespace
tvm
{
namespace
tvm
{
...
...
src/op/scan_op.cc
View file @
4273e461
...
@@ -6,9 +6,9 @@
...
@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...
@@ -80,7 +80,7 @@ Operation ScanOpNode::make(std::string name,
...
@@ -80,7 +80,7 @@ Operation ScanOpNode::make(std::string name,
for
(
size_t
i
=
0
;
i
<
init
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
init
.
size
();
++
i
)
{
CHECK_EQ
(
init
[
i
]
->
dtype
,
state_placeholder
[
i
]
->
dtype
);
CHECK_EQ
(
init
[
i
]
->
dtype
,
state_placeholder
[
i
]
->
dtype
);
CHECK_EQ
(
init
[
i
]
->
dtype
,
update
[
i
]
->
dtype
);
CHECK_EQ
(
init
[
i
]
->
dtype
,
update
[
i
]
->
dtype
);
CHECK
(
can_prove
(
init
[
i
]
->
shape
[
0
]
==
axis
->
dom
->
min
))
CHECK
(
prove_equal
(
init
[
i
]
->
shape
[
0
],
axis
->
dom
->
min
))
<<
"init.shape[0] need to match scan_axis.dom.min"
;
<<
"init.shape[0] need to match scan_axis.dom.min"
;
CHECK
(
prove_equal
(
CHECK
(
prove_equal
(
state_placeholder
[
i
]
->
shape
[
0
],
axis
->
dom
->
min
+
axis
->
dom
->
extent
))
state_placeholder
[
i
]
->
shape
[
0
],
axis
->
dom
->
min
+
axis
->
dom
->
extent
))
...
...
src/pass/loop_partition.cc
View file @
4273e461
...
@@ -466,8 +466,13 @@ Stmt LoopPartitioner::TryPartition(const Node* node,
...
@@ -466,8 +466,13 @@ Stmt LoopPartitioner::TryPartition(const Node* node,
Stmt
body
,
Stmt
body
,
bool
partition_thread_scope
)
{
bool
partition_thread_scope
)
{
using
namespace
arith
;
using
namespace
arith
;
// include hint of var.
hint_map_
.
insert
({
var
.
get
(),
IntSet
::
interval
(
min
,
max
)});
PartitionFinder
finder
(
var
,
hint_map_
,
relax_map_
);
PartitionFinder
finder
(
var
,
hint_map_
,
relax_map_
);
finder
.
Visit
(
body
);
finder
.
Visit
(
body
);
hint_map_
.
erase
(
var
.
get
());
if
(
finder
.
partitions
.
empty
())
return
Stmt
();
if
(
finder
.
partitions
.
empty
())
return
Stmt
();
arith
::
IntervalSet
for_interval
(
min
,
max
);
arith
::
IntervalSet
for_interval
(
min
,
max
);
...
@@ -504,9 +509,9 @@ Stmt LoopPartitioner::TryPartition(const Node* node,
...
@@ -504,9 +509,9 @@ Stmt LoopPartitioner::TryPartition(const Node* node,
bool
pre_stmt_recurse
=
true
;
bool
pre_stmt_recurse
=
true
;
if
(
middle_interval_i
->
HasLowerBound
())
{
if
(
middle_interval_i
->
HasLowerBound
())
{
body_begin
=
ir
::
Simplify
(
middle_interval
.
min
());
body_begin
=
ir
::
Simplify
(
middle_interval
.
min
());
if
(
!
can_p
rove
(
body_begin
==
min
))
{
if
(
!
analyzer_
.
CanP
rove
(
body_begin
==
min
))
{
Expr
cond
=
(
body_begin
-
min
>=
0
);
Expr
cond
=
(
body_begin
-
min
>=
0
);
if
(
!
can_p
rove
(
cond
))
{
if
(
!
analyzer_
.
CanP
rove
(
cond
))
{
LOG
(
WARNING
)
<<
"Cannot prove: "
<<
cond
LOG
(
WARNING
)
<<
"Cannot prove: "
<<
cond
<<
", when generating the pre doubt loop"
;
<<
", when generating the pre doubt loop"
;
body_begin
=
Max
::
make
(
body_begin
,
min
);
body_begin
=
Max
::
make
(
body_begin
,
min
);
...
@@ -529,10 +534,10 @@ Stmt LoopPartitioner::TryPartition(const Node* node,
...
@@ -529,10 +534,10 @@ Stmt LoopPartitioner::TryPartition(const Node* node,
bool
post_stmt_recurse
=
true
;
bool
post_stmt_recurse
=
true
;
if
(
middle_interval_i
->
HasUpperBound
())
{
if
(
middle_interval_i
->
HasUpperBound
())
{
post_doubt_begin
=
ir
::
Simplify
(
middle_interval
.
max
()
+
1
);
post_doubt_begin
=
ir
::
Simplify
(
middle_interval
.
max
()
+
1
);
if
(
!
can_p
rove
(
middle_interval
.
max
()
==
max
))
{
if
(
!
analyzer_
.
CanP
rove
(
middle_interval
.
max
()
==
max
))
{
// require the extent to be non-negative
// require the extent to be non-negative
Expr
cond
=
(
max
-
post_doubt_begin
+
1
>=
0
);
Expr
cond
=
(
max
-
post_doubt_begin
+
1
>=
0
);
if
(
!
can_p
rove
(
cond
))
{
if
(
!
analyzer_
.
CanP
rove
(
cond
))
{
LOG
(
WARNING
)
<<
"Cannot prove: "
<<
cond
LOG
(
WARNING
)
<<
"Cannot prove: "
<<
cond
<<
", when generating the post doubt loop"
;
<<
", when generating the post doubt loop"
;
post_doubt_begin
=
Min
::
make
(
post_doubt_begin
,
max
);
post_doubt_begin
=
Min
::
make
(
post_doubt_begin
,
max
);
...
@@ -554,7 +559,7 @@ Stmt LoopPartitioner::TryPartition(const Node* node,
...
@@ -554,7 +559,7 @@ Stmt LoopPartitioner::TryPartition(const Node* node,
// Generating code for middle subrange
// Generating code for middle subrange
if
(
!
partition_thread_scope
)
{
if
(
!
partition_thread_scope
)
{
Stmt
mid_stmt
;
Stmt
mid_stmt
;
if
(
!
can_p
rove
(
body_begin
>=
post_doubt_begin
))
{
if
(
!
analyzer_
.
CanP
rove
(
body_begin
>=
post_doubt_begin
))
{
// [body_begin, post_doubt_begin)
// [body_begin, post_doubt_begin)
Stmt
simplified_body
=
ConditionEliminator
(
cond_set
,
cond_value
).
Mutate
(
body
);
Stmt
simplified_body
=
ConditionEliminator
(
cond_set
,
cond_value
).
Mutate
(
body
);
Stmt
new_body
=
Substitute
(
simplified_body
,
{{
Var
{
var
},
var
+
body_begin
}});
Stmt
new_body
=
Substitute
(
simplified_body
,
{{
Var
{
var
},
var
+
body_begin
}});
...
@@ -576,8 +581,8 @@ Stmt LoopPartitioner::TryPartition(const Node* node,
...
@@ -576,8 +581,8 @@ Stmt LoopPartitioner::TryPartition(const Node* node,
s
=
AppendStmts
(
s
,
post_stmt
);
s
=
AppendStmts
(
s
,
post_stmt
);
}
else
{
}
else
{
Expr
cond
=
const_true
();
Expr
cond
=
const_true
();
if
(
!
can_p
rove
(
body_begin
==
min
))
cond
=
cond
&&
(
var
>=
body_begin
);
if
(
!
analyzer_
.
CanP
rove
(
body_begin
==
min
))
cond
=
cond
&&
(
var
>=
body_begin
);
if
(
!
can_p
rove
(
post_doubt_begin
==
(
max
+
1
)))
cond
=
cond
&&
(
var
<
post_doubt_begin
);
if
(
!
analyzer_
.
CanP
rove
(
post_doubt_begin
==
(
max
+
1
)))
cond
=
cond
&&
(
var
<
post_doubt_begin
);
s
=
ThreadPartitionInserter
(
cond_set
,
cond
).
Mutate
(
stmt
);
s
=
ThreadPartitionInserter
(
cond_set
,
cond
).
Mutate
(
stmt
);
}
}
s
=
ConvertSSA
(
s
);
s
=
ConvertSSA
(
s
);
...
@@ -587,7 +592,7 @@ Stmt LoopPartitioner::TryPartition(const Node* node,
...
@@ -587,7 +592,7 @@ Stmt LoopPartitioner::TryPartition(const Node* node,
inline
Stmt
LoopPartitioner
::
MakeFor
(
const
Node
*
node
,
Expr
extent
,
Stmt
body
)
{
inline
Stmt
LoopPartitioner
::
MakeFor
(
const
Node
*
node
,
Expr
extent
,
Stmt
body
)
{
const
For
*
for_node
=
static_cast
<
const
For
*>
(
node
);
const
For
*
for_node
=
static_cast
<
const
For
*>
(
node
);
CHECK
(
for_node
);
CHECK
(
for_node
);
if
(
can_p
rove
(
extent
==
make_const
(
Int
(
32
),
1
)))
{
if
(
analyzer_
.
CanP
rove
(
extent
==
make_const
(
Int
(
32
),
1
)))
{
// If the loop extent is 1, do not create the loop anymore
// If the loop extent is 1, do not create the loop anymore
return
Substitute
(
body
,
{{
Var
{
for_node
->
loop_var
},
make_const
(
Int
(
32
),
0
)}});
return
Substitute
(
body
,
{{
Var
{
for_node
->
loop_var
},
make_const
(
Int
(
32
),
0
)}});
}
else
{
}
else
{
...
...
src/pass/narrow_channel_access.cc
View file @
4273e461
...
@@ -6,9 +6,9 @@
...
@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...
@@ -200,7 +200,7 @@ class ChannelAccessRewriter : public IRMutator {
...
@@ -200,7 +200,7 @@ class ChannelAccessRewriter : public IRMutator {
Expr
base
=
linear_eq
[
1
];
Expr
base
=
linear_eq
[
1
];
if
(
!
is_zero
(
base
))
return
body
;
if
(
!
is_zero
(
base
))
return
body
;
Expr
left
=
ir
::
Simplify
(
adv_op
->
value
-
coeff
*
for_op
->
extent
);
Expr
left
=
ir
::
Simplify
(
adv_op
->
value
-
coeff
*
for_op
->
extent
);
if
(
!
can_p
rove
(
left
>=
0
))
return
body
;
if
(
!
analyzer_
.
CanP
rove
(
left
>=
0
))
return
body
;
// rewrite access index.
// rewrite access index.
ChannelAccessIndexRewriter
rw
(
ChannelAccessIndexRewriter
rw
(
ch
->
handle_var
.
get
(),
var
*
coeff
,
read_access
);
ch
->
handle_var
.
get
(),
var
*
coeff
,
read_access
);
...
@@ -233,6 +233,7 @@ class ChannelAccessRewriter : public IRMutator {
...
@@ -233,6 +233,7 @@ class ChannelAccessRewriter : public IRMutator {
return
body
;
return
body
;
}
}
arith
::
Analyzer
analyzer_
;
std
::
vector
<
RewriteEntry
>
tasks_
;
std
::
vector
<
RewriteEntry
>
tasks_
;
};
};
...
...
src/pass/storage_rewrite.cc
View file @
4273e461
...
@@ -6,9 +6,9 @@
...
@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...
@@ -606,7 +606,7 @@ class StoragePlanRewriter : public IRMutator {
...
@@ -606,7 +606,7 @@ class StoragePlanRewriter : public IRMutator {
}
}
// transform to alloc bytes
// transform to alloc bytes
auto
type_bits
=
alloc_type
.
bits
()
*
alloc_type
.
lanes
();
auto
type_bits
=
alloc_type
.
bits
()
*
alloc_type
.
lanes
();
bool
divided
=
can_p
rove
(
combo_size
%
type_bits
==
0
);
bool
divided
=
analyzer_
.
CanP
rove
(
combo_size
%
type_bits
==
0
);
combo_size
=
combo_size
/
type_bits
;
combo_size
=
combo_size
/
type_bits
;
// round up for can not divided
// round up for can not divided
if
(
!
divided
)
{
if
(
!
divided
)
{
...
@@ -920,6 +920,8 @@ class StoragePlanRewriter : public IRMutator {
...
@@ -920,6 +920,8 @@ class StoragePlanRewriter : public IRMutator {
std
::
unordered_map
<
const
Variable
*
,
StorageEntry
*>
alloc_map_
;
std
::
unordered_map
<
const
Variable
*
,
StorageEntry
*>
alloc_map_
;
// The allocations
// The allocations
std
::
vector
<
std
::
unique_ptr
<
StorageEntry
>
>
alloc_vec_
;
std
::
vector
<
std
::
unique_ptr
<
StorageEntry
>
>
alloc_vec_
;
// analyzer
arith
::
Analyzer
analyzer_
;
};
};
// Turn alloc into vector alloc
// Turn alloc into vector alloc
...
...
src/pass/vectorize_loop.cc
View file @
4273e461
...
@@ -25,6 +25,7 @@
...
@@ -25,6 +25,7 @@
#include <tvm/ir.h>
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_mutator.h>
#include <tvm/arithmetic.h>
#include <unordered_set>
#include <unordered_set>
#include <unordered_map>
#include <unordered_map>
#include <vector>
#include <vector>
...
@@ -132,11 +133,11 @@ class Vectorizer : public IRMutator {
...
@@ -132,11 +133,11 @@ class Vectorizer : public IRMutator {
if
(
lanes
!=
1
)
{
if
(
lanes
!=
1
)
{
const
Ramp
*
b_ramp
=
b
.
as
<
Ramp
>
();
const
Ramp
*
b_ramp
=
b
.
as
<
Ramp
>
();
const
Ramp
*
a_ramp
=
a
.
as
<
Ramp
>
();
const
Ramp
*
a_ramp
=
a
.
as
<
Ramp
>
();
if
(
a_ramp
&&
b
.
type
().
lanes
()
==
1
&&
can_p
rove
(
b
>
0
))
{
if
(
a_ramp
&&
b
.
type
().
lanes
()
==
1
&&
analyzer_
.
CanP
rove
(
b
>
0
))
{
return
Ramp
::
make
(
return
Ramp
::
make
(
a_ramp
->
base
*
b
,
a_ramp
->
stride
*
b
,
a_ramp
->
lanes
);
a_ramp
->
base
*
b
,
a_ramp
->
stride
*
b
,
a_ramp
->
lanes
);
}
}
if
(
b_ramp
&&
a
.
type
().
lanes
()
==
1
&&
can_p
rove
(
a
>
0
))
{
if
(
b_ramp
&&
a
.
type
().
lanes
()
==
1
&&
analyzer_
.
CanP
rove
(
a
>
0
))
{
return
Ramp
::
make
(
return
Ramp
::
make
(
b_ramp
->
base
*
a
,
b_ramp
->
stride
*
a
,
b_ramp
->
lanes
);
b_ramp
->
base
*
a
,
b_ramp
->
stride
*
a
,
b_ramp
->
lanes
);
}
}
...
@@ -186,7 +187,7 @@ class Vectorizer : public IRMutator {
...
@@ -186,7 +187,7 @@ class Vectorizer : public IRMutator {
Expr
stride
=
this
->
Mutate
(
op
->
stride
);
Expr
stride
=
this
->
Mutate
(
op
->
stride
);
if
(
base
.
type
().
lanes
()
>
1
&&
stride
.
type
().
lanes
()
==
1
)
{
if
(
base
.
type
().
lanes
()
>
1
&&
stride
.
type
().
lanes
()
==
1
)
{
const
Ramp
*
base_ramp
=
base
.
as
<
Ramp
>
();
const
Ramp
*
base_ramp
=
base
.
as
<
Ramp
>
();
if
(
can_p
rove
(
base_ramp
->
stride
==
stride
*
make_const
(
stride
.
type
(),
op
->
lanes
)))
{
if
(
analyzer_
.
CanP
rove
(
base_ramp
->
stride
==
stride
*
make_const
(
stride
.
type
(),
op
->
lanes
)))
{
return
Ramp
::
make
(
base_ramp
->
base
,
stride
,
op
->
lanes
*
base_ramp
->
lanes
);
return
Ramp
::
make
(
base_ramp
->
base
,
stride
,
op
->
lanes
*
base_ramp
->
lanes
);
}
}
}
}
...
@@ -423,6 +424,8 @@ class Vectorizer : public IRMutator {
...
@@ -423,6 +424,8 @@ class Vectorizer : public IRMutator {
}
}
private
:
private
:
// analyzer
arith
::
Analyzer
analyzer_
;
// variable to be replaced
// variable to be replaced
Var
var_
;
Var
var_
;
// the lanes.
// the lanes.
...
...
src/schedule/message_passing.cc
View file @
4273e461
...
@@ -432,9 +432,9 @@ void PassDownBitMaskOr(const Stage& stage,
...
@@ -432,9 +432,9 @@ void PassDownBitMaskOr(const Stage& stage,
*/
*/
void
PassUpBoundCheck
(
const
Stage
&
s
,
void
PassUpBoundCheck
(
const
Stage
&
s
,
const
Map
<
IterVar
,
Range
>&
dom_map
,
const
Map
<
IterVar
,
Range
>&
dom_map
,
std
::
unordered_map
<
IterVar
,
bool
>*
p_state
)
{
std
::
unordered_map
<
IterVar
,
bool
>*
p_state
,
arith
::
Analyzer
*
analyzer
)
{
auto
&
state
=
*
p_state
;
auto
&
state
=
*
p_state
;
using
HalideIR
::
Internal
::
can_prove
;
for
(
size_t
i
=
s
->
relations
.
size
();
i
!=
0
;
--
i
)
{
for
(
size_t
i
=
s
->
relations
.
size
();
i
!=
0
;
--
i
)
{
IterVarRelation
rel
=
s
->
relations
[
i
-
1
];
IterVarRelation
rel
=
s
->
relations
[
i
-
1
];
if
(
const
SplitNode
*
s
=
rel
.
as
<
SplitNode
>
())
{
if
(
const
SplitNode
*
s
=
rel
.
as
<
SplitNode
>
())
{
...
@@ -447,7 +447,7 @@ void PassUpBoundCheck(const Stage& s,
...
@@ -447,7 +447,7 @@ void PassUpBoundCheck(const Stage& s,
if
(
outer
||
inner
)
{
if
(
outer
||
inner
)
{
state
[
s
->
parent
]
=
true
;
state
[
s
->
parent
]
=
true
;
}
else
{
}
else
{
if
(
can_p
rove
(
dom_map
.
at
(
s
->
parent
)
->
extent
==
factor
*
step
))
{
if
(
analyzer
->
CanP
rove
(
dom_map
.
at
(
s
->
parent
)
->
extent
==
factor
*
step
))
{
state
[
s
->
parent
]
=
false
;
state
[
s
->
parent
]
=
false
;
}
else
{
}
else
{
state
[
s
->
parent
]
=
true
;
state
[
s
->
parent
]
=
true
;
...
@@ -476,11 +476,13 @@ std::vector<Expr> MakeBoundCheck(
...
@@ -476,11 +476,13 @@ std::vector<Expr> MakeBoundCheck(
const
std
::
unordered_map
<
IterVar
,
Expr
>&
value_map
,
const
std
::
unordered_map
<
IterVar
,
Expr
>&
value_map
,
bool
skip_ivar_domain
,
bool
skip_ivar_domain
,
const
std
::
unordered_set
<
IterVar
>&
skip_iter
)
{
const
std
::
unordered_set
<
IterVar
>&
skip_iter
)
{
Analyzer
analyzer
;
std
::
unordered_map
<
IterVar
,
bool
>
bound_state
;
std
::
unordered_map
<
IterVar
,
bool
>
bound_state
;
for
(
IterVar
iv
:
stage
->
leaf_iter_vars
)
{
for
(
IterVar
iv
:
stage
->
leaf_iter_vars
)
{
bound_state
[
iv
]
=
false
;
bound_state
[
iv
]
=
false
;
}
}
PassUpBoundCheck
(
stage
,
dom_map
,
&
bound_state
);
PassUpBoundCheck
(
stage
,
dom_map
,
&
bound_state
,
&
analyzer
);
std
::
vector
<
Expr
>
preds
;
std
::
vector
<
Expr
>
preds
;
std
::
unordered_map
<
const
Variable
*
,
IntSet
>
iset_dmap
;
std
::
unordered_map
<
const
Variable
*
,
IntSet
>
iset_dmap
;
...
@@ -496,7 +498,7 @@ std::vector<Expr> MakeBoundCheck(
...
@@ -496,7 +498,7 @@ std::vector<Expr> MakeBoundCheck(
Range
dom
=
dom_map
.
at
(
iv
);
Range
dom
=
dom_map
.
at
(
iv
);
Expr
value
=
ComputeExpr
<
Sub
>
(
value_map
.
at
(
iv
),
dom
->
min
);
Expr
value
=
ComputeExpr
<
Sub
>
(
value_map
.
at
(
iv
),
dom
->
min
);
Expr
vmax
=
EvalSet
(
value
,
iset_dmap
).
max
();
Expr
vmax
=
EvalSet
(
value
,
iset_dmap
).
max
();
if
(
vmax
.
type
()
!=
value
.
type
()
||
!
can_p
rove
(
vmax
<
dom
->
extent
))
{
if
(
vmax
.
type
()
!=
value
.
type
()
||
!
analyzer
.
CanP
rove
(
vmax
<
dom
->
extent
))
{
preds
.
emplace_back
(
value
<
dom
->
extent
);
preds
.
emplace_back
(
value
<
dom
->
extent
);
}
}
}
}
...
@@ -511,10 +513,10 @@ std::vector<Expr> MakeBoundCheck(
...
@@ -511,10 +513,10 @@ std::vector<Expr> MakeBoundCheck(
Expr
vmin
=
s
.
min
();
Expr
vmin
=
s
.
min
();
Expr
vmax
=
s
.
max
();
Expr
vmax
=
s
.
max
();
// The range of `value` resides in [vmin, vmax]
// The range of `value` resides in [vmin, vmax]
if
(
vmin
.
type
()
!=
value
.
type
()
||
!
can_p
rove
(
vmin
>=
0
))
{
if
(
vmin
.
type
()
!=
value
.
type
()
||
!
analyzer
.
CanP
rove
(
vmin
>=
0
))
{
preds
.
emplace_back
(
value
>=
0
);
preds
.
emplace_back
(
value
>=
0
);
}
}
if
(
vmax
.
type
()
!=
value
.
type
()
||
!
can_p
rove
(
vmax
<
iv
->
dom
->
extent
))
{
if
(
vmax
.
type
()
!=
value
.
type
()
||
!
analyzer
.
CanP
rove
(
vmax
<
iv
->
dom
->
extent
))
{
preds
.
emplace_back
(
value
<
iv
->
dom
->
extent
);
preds
.
emplace_back
(
value
<
iv
->
dom
->
extent
);
}
}
}
}
...
...
src/schedule/schedule_dataflow_rewrite.cc
View file @
4273e461
...
@@ -740,7 +740,7 @@ Array<Tensor> Schedule::rfactor(const Tensor& tensor,
...
@@ -740,7 +740,7 @@ Array<Tensor> Schedule::rfactor(const Tensor& tensor,
const
Reduce
*
reduce
=
compute_op
->
body
[
idx
].
as
<
Reduce
>
();
const
Reduce
*
reduce
=
compute_op
->
body
[
idx
].
as
<
Reduce
>
();
CHECK
(
reduce
)
<<
"Can only rfactor non-inline reductions"
;
CHECK
(
reduce
)
<<
"Can only rfactor non-inline reductions"
;
predicates
.
push_back
(
reduce
->
condition
);
predicates
.
push_back
(
reduce
->
condition
);
Expr
predicate
=
likely
(
simplify
(
arith
::
ComputeReduce
<
ir
::
And
>
(
predicates
,
Expr
()
)));
Expr
predicate
=
likely
(
arith
::
ComputeReduce
<
ir
::
And
>
(
predicates
,
Expr
(
)));
std
::
unordered_map
<
const
Variable
*
,
Expr
>
vsub
;
std
::
unordered_map
<
const
Variable
*
,
Expr
>
vsub
;
...
...
tests/cpp/ir_simplify_test.cc
View file @
4273e461
...
@@ -6,9 +6,9 @@
...
@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...
@@ -21,12 +21,6 @@
...
@@ -21,12 +21,6 @@
#include <gtest/gtest.h>
#include <gtest/gtest.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_pass.h>
#include <tvm/tvm.h>
#include <tvm/tvm.h>
#include <arithmetic/Simplify.h>
TEST
(
IRSIMPLIFY
,
Basic
)
{
using
namespace
HalideIR
::
Internal
;
simplify_test
();
}
TEST
(
IRSIMPLIFY
,
MinMax
)
{
TEST
(
IRSIMPLIFY
,
MinMax
)
{
auto
x
=
tvm
::
var
(
"x"
);
auto
x
=
tvm
::
var
(
"x"
);
...
...
tests/python/unittest/test_arith_deduce_bound.py
View file @
4273e461
...
@@ -16,6 +16,14 @@
...
@@ -16,6 +16,14 @@
# under the License.
# under the License.
import
tvm
import
tvm
def
assert_expr_equal
(
a
,
b
):
res
=
tvm
.
ir_pass
.
Simplify
(
a
-
b
)
equal
=
isinstance
(
res
,
tvm
.
expr
.
IntImm
)
and
res
.
value
==
0
if
not
equal
:
raise
ValueError
(
"{} and {} are not equal"
.
format
(
a
,
b
))
def
test_deduce
():
def
test_deduce
():
a
=
tvm
.
var
(
'a'
)
a
=
tvm
.
var
(
'a'
)
b
=
tvm
.
var
(
'b'
)
b
=
tvm
.
var
(
'b'
)
...
@@ -29,31 +37,34 @@ def test_deduce():
...
@@ -29,31 +37,34 @@ def test_deduce():
e0
=
(
-
b
)
*
a
+
c
-
d
e0
=
(
-
b
)
*
a
+
c
-
d
res0
=
tvm
.
arith
.
DeduceBound
(
a
,
e0
>=
0
,
{
b
:
b_s
,
c
:
c_s
,
d
:
d_s
},
{})
res0
=
tvm
.
arith
.
DeduceBound
(
a
,
e0
>=
0
,
{
b
:
b_s
,
c
:
c_s
,
d
:
d_s
},
{})
ans0
=
((
d
-
c
)
/
(
b
*-
1
))
ans0
=
((
d
-
c
)
/
(
b
*-
1
)
+
(
-
1
)
)
assert
str
(
tvm
.
ir_pass
.
Simplify
(
res0
.
max_value
))
==
str
(
ans0
)
assert
_expr_equal
(
res0
.
max_value
,
ans0
)
# expression containing variable a is on rhs
# expression containing variable a is on rhs
res0
=
tvm
.
arith
.
DeduceBound
(
a
,
zero
<=
e0
,
{
b
:
b_s
,
c
:
c_s
,
d
:
d_s
},
{})
res0
=
tvm
.
arith
.
DeduceBound
(
a
,
zero
<=
e0
,
{
b
:
b_s
,
c
:
c_s
,
d
:
d_s
},
{})
assert
str
(
tvm
.
ir_pass
.
Simplify
(
res0
.
max_value
))
==
str
(
ans0
)
assert
_expr_equal
(
res0
.
max_value
,
ans0
)
e0
=
d
*
a
+
c
-
d
e0
=
d
*
a
+
c
-
d
res0
=
tvm
.
arith
.
DeduceBound
(
a
,
e0
>=
0
,
{
b
:
b_s
,
c
:
c_s
,
d
:
d_s
},
{})
res0
=
tvm
.
arith
.
DeduceBound
(
a
,
e0
>=
0
,
{
b
:
b_s
,
c
:
c_s
,
d
:
d_s
},
{})
ans0
=
((
0
-
c
)
/
d
+
1
)
ans0
=
((
d
-
c
)
/
d
-
1
)
assert
str
(
tvm
.
ir_pass
.
Simplify
(
res0
.
max_value
))
==
str
(
ans0
)
assert
_expr_equal
(
res0
.
max_value
,
ans0
)
# expression containing variable a is on rhs
# expression containing variable a is on rhs
res0
=
tvm
.
arith
.
DeduceBound
(
a
,
zero
<=
e0
,
{
b
:
b_s
,
c
:
c_s
,
d
:
d_s
},
{})
res0
=
tvm
.
arith
.
DeduceBound
(
a
,
zero
<=
e0
,
{
b
:
b_s
,
c
:
c_s
,
d
:
d_s
},
{})
assert
str
(
tvm
.
ir_pass
.
Simplify
(
res0
.
max_value
))
==
str
(
ans0
)
assert_expr_equal
(
res0
.
max_value
,
ans0
)
e1
=
(
a
*
4
+
b
<
c
)
e1
=
(
a
*
4
+
b
<
c
)
res1
=
tvm
.
arith
.
DeduceBound
(
a
,
e1
,
{
b
:
b_s
,
c
:
c_s
,
d
:
d_s
},
{})
res1
=
tvm
.
arith
.
DeduceBound
(
a
,
e1
,
{
b
:
b_s
,
c
:
c_s
,
d
:
d_s
},
{})
ans1
=
(((
c
-
b
)
+
-
1
)
/
4
)
ans1
=
(((
c
-
b
)
+
-
1
)
/
4
-
1
)
assert
str
(
tvm
.
ir_pass
.
Simplify
(
res1
.
max_value
))
==
str
(
ans1
)
assert_expr_equal
(
res1
.
max_value
,
ans1
)
# expression containing variable a is on rhs
# expression containing variable a is on rhs
e1
=
(
c
>
a
*
4
+
b
)
e1
=
(
c
>
a
*
4
+
b
)
res1
=
tvm
.
arith
.
DeduceBound
(
a
,
e1
,
{
b
:
b_s
,
c
:
c_s
,
d
:
d_s
},
{})
res1
=
tvm
.
arith
.
DeduceBound
(
a
,
e1
,
{
b
:
b_s
,
c
:
c_s
,
d
:
d_s
},
{})
assert
str
(
tvm
.
ir_pass
.
Simplify
(
res1
.
max_value
))
==
str
(
ans1
)
assert_expr_equal
(
res1
.
max_value
,
ans1
)
e2
=
(
tvm
.
max
(
5
,
a
*
4
)
<
0
)
e2
=
(
tvm
.
max
(
5
,
a
*
4
)
<
0
)
res2
=
tvm
.
arith
.
DeduceBound
(
a
,
e2
,
{
b
:
b_s
,
c
:
c_s
,
d
:
d_s
},
{})
res2
=
tvm
.
arith
.
DeduceBound
(
a
,
e2
,
{
b
:
b_s
,
c
:
c_s
,
d
:
d_s
},
{})
...
@@ -66,7 +77,6 @@ def test_deduce():
...
@@ -66,7 +77,6 @@ def test_deduce():
assert
str
(
res2
.
max_value
)
==
"neg_inf"
assert
str
(
res2
.
max_value
)
==
"neg_inf"
assert
str
(
res2
.
min_value
)
==
"pos_inf"
assert
str
(
res2
.
min_value
)
==
"pos_inf"
e3
=
(
-
b
)
+
a
*
c
-
d
e3
=
(
-
b
)
+
a
*
c
-
d
res3
=
tvm
.
arith
.
DeduceBound
(
a
,
e3
>=
0
,
{
b
:
b_s
,
c
:
c_s
,
d
:
d_s
},
{
b
:
b_s
,
d
:
d_s
})
res3
=
tvm
.
arith
.
DeduceBound
(
a
,
e3
>=
0
,
{
b
:
b_s
,
c
:
c_s
,
d
:
d_s
},
{
b
:
b_s
,
d
:
d_s
})
ans3
=
2
/
c
+
1
ans3
=
2
/
c
+
1
...
@@ -75,6 +85,7 @@ def test_deduce():
...
@@ -75,6 +85,7 @@ def test_deduce():
res3
=
tvm
.
arith
.
DeduceBound
(
a
,
zero
<=
e3
,
{
b
:
b_s
,
c
:
c_s
,
d
:
d_s
},
{
b
:
b_s
,
d
:
d_s
})
res3
=
tvm
.
arith
.
DeduceBound
(
a
,
zero
<=
e3
,
{
b
:
b_s
,
c
:
c_s
,
d
:
d_s
},
{
b
:
b_s
,
d
:
d_s
})
assert
str
(
tvm
.
ir_pass
.
Simplify
(
res3
.
min_value
))
==
str
(
ans3
)
assert
str
(
tvm
.
ir_pass
.
Simplify
(
res3
.
min_value
))
==
str
(
ans3
)
def
test_check
():
def
test_check
():
a
=
tvm
.
var
(
'a'
)
a
=
tvm
.
var
(
'a'
)
b
=
tvm
.
var
(
'b'
)
b
=
tvm
.
var
(
'b'
)
...
...
tests/python/unittest/test_pass_basic.py
View file @
4273e461
...
@@ -24,9 +24,6 @@ def test_simplify():
...
@@ -24,9 +24,6 @@ def test_simplify():
assert
(
tvm
.
ir_pass
.
Equal
(
e2
,
x
*
8
))
assert
(
tvm
.
ir_pass
.
Equal
(
e2
,
x
*
8
))
e3
=
tvm
.
ir_pass
.
Simplify
(
x
-
x
/
3
*
3
)
e3
=
tvm
.
ir_pass
.
Simplify
(
x
-
x
/
3
*
3
)
assert
(
tvm
.
ir_pass
.
Equal
(
e3
,
tvm
.
make
.
Mod
(
x
,
3
)))
assert
(
tvm
.
ir_pass
.
Equal
(
e3
,
tvm
.
make
.
Mod
(
x
,
3
)))
let
=
tvm
.
make
.
Let
(
x
,
1
,
x
+
3
)
e4
=
tvm
.
ir_pass
.
Simplify
(
let
)
assert
(
tvm
.
ir_pass
.
Equal
(
e4
,
4
))
def
test_verify_ssa
():
def
test_verify_ssa
():
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment