Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
4273e461
Unverified
Commit
4273e461
authored
Jul 01, 2019
by
Tianqi Chen
Committed by
GitHub
Jul 01, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Migrate simplifier to new infra. (#3368)
parent
f2a6851a
Hide whitespace changes
Inline
Side-by-side
Showing
19 changed files
with
175 additions
and
154 deletions
+175
-154
CMakeLists.txt
+5
-1
include/tvm/arithmetic.h
+6
-3
include/tvm/ir_pass.h
+0
-1
src/arithmetic/analyzer.cc
+1
-0
src/arithmetic/bound_deducer.cc
+84
-61
src/arithmetic/const_fold.h
+4
-3
src/arithmetic/rewrite_simplify.cc
+0
-1
src/arithmetic/stmt_simplify.cc
+8
-33
src/lang/buffer.cc
+3
-2
src/op/scan_op.cc
+3
-3
src/pass/loop_partition.cc
+13
-8
src/pass/narrow_channel_access.cc
+4
-3
src/pass/storage_rewrite.cc
+5
-3
src/pass/vectorize_loop.cc
+6
-3
src/schedule/message_passing.cc
+9
-7
src/schedule/schedule_dataflow_rewrite.cc
+1
-1
tests/cpp/ir_simplify_test.cc
+2
-8
tests/python/unittest/test_arith_deduce_bound.py
+21
-10
tests/python/unittest/test_pass_basic.py
+0
-3
No files found.
CMakeLists.txt
View file @
4273e461
...
...
@@ -154,7 +154,11 @@ file(GLOB_RECURSE NNVM_COMPILER_SRCS
file
(
GLOB TOPI_SRCS
topi/src/*.cc
)
file
(
GLOB_RECURSE HALIDEIR_SRCS 3rdparty/HalideIR/src/*.cpp
)
file
(
GLOB_RECURSE HALIDEIR_SRCS
3rdparty/HalideIR/src/base/*.cpp
3rdparty/HalideIR/src/ir/*.cpp
3rdparty/HalideIR/src/tvm/*.cpp
)
list
(
APPEND COMPILER_SRCS
${
HALIDEIR_SRCS
}
)
file
(
GLOB RUNTIME_SRCS
src/runtime/*.cc
...
...
include/tvm/arithmetic.h
View file @
4273e461
...
...
@@ -623,12 +623,15 @@ IntSet Intersect(const Array<IntSet>& sets);
* give the domain of each variables. Return undefined IntSet to
* represent failure.
*
* \note The returned set may be smaller than set that
* contains all possible values of v that satisfies the bound.
*
* \param v The target variable to be deduced.
* \param cond The conditional expression.
* \param hint_map The domain of variable, used to help deduce.
* \param relax_map The domain of each variable, used to relax the domain,
* The deduce bound mus
h
implies e for all value in relax_map
* \return An integer set that
can cover all the possible values
.
* The deduce bound mus
t
implies e for all value in relax_map
* \return An integer set that
always satisfies the condition
.
*/
IntSet
DeduceBound
(
Expr
v
,
Expr
cond
,
const
Map
<
Var
,
IntSet
>&
hint_map
,
...
...
@@ -641,7 +644,7 @@ IntSet DeduceBound(Expr v, Expr cond,
* \param hint_map The domain of variable, used to help deduce.
* \param relax_map The domain of each variable, used to relax the domain,
* The deduce bound mush implies e for all value in relax_map
* \return An integer set that
can cover all the possible values
.
* \return An integer set that
always satisfies the condition
.
*/
IntSet
DeduceBound
(
Expr
v
,
Expr
cond
,
const
std
::
unordered_map
<
const
Variable
*
,
IntSet
>&
hint_map
,
...
...
include/tvm/ir_pass.h
View file @
4273e461
...
...
@@ -27,7 +27,6 @@
#ifndef TVM_IR_PASS_H_
#define TVM_IR_PASS_H_
#include <arithmetic/Simplify.h>
#include <unordered_map>
#include <unordered_set>
#include <vector>
...
...
src/arithmetic/analyzer.cc
View file @
4273e461
...
...
@@ -106,6 +106,7 @@ bool Analyzer::CanProve(const Expr& expr) {
Expr
Analyzer
::
Simplify
(
const
Expr
&
expr
)
{
if
(
is_const
(
expr
))
return
expr
;
auto
res
=
this
->
rewrite_simplify
(
expr
);
if
(
is_const
(
res
))
return
res
;
res
=
this
->
canonical_simplify
(
res
);
return
res
;
}
...
...
src/arithmetic/bound_deducer.cc
View file @
4273e461
...
...
@@ -84,11 +84,11 @@ class BoundDeducer: public IRVisitor {
void
Deduce
();
void
Visit
(
const
NodeRef
&
e
)
final
{
if
(
!
success
)
return
;
if
(
!
success
_
)
return
;
if
(
e
.
get
()
==
path_
[
iter_
++
])
{
IRVisitor
::
Visit
(
e
);
}
else
{
success
=
false
;
success
_
=
false
;
return
;
}
}
...
...
@@ -111,18 +111,18 @@ class BoundDeducer: public IRVisitor {
void
Visit_
(
const
Add
*
op
)
final
{
bool
left
=
op
->
a
.
get
()
==
path_
[
iter_
];
result
-=
left
?
op
->
b
:
op
->
a
;
result
_
-=
left
?
op
->
b
:
op
->
a
;
Visit
(
left
?
op
->
a
:
op
->
b
);
}
void
Visit_
(
const
Sub
*
op
)
final
{
bool
left
=
op
->
a
.
get
()
==
path_
[
iter_
];
if
(
left
)
{
result
+=
op
->
b
;
result
_
+=
op
->
b
;
}
else
{
result
-=
op
->
a
;
result
=
-
result
;
is_greater
=
!
is_greater
;
result
_
-=
op
->
a
;
result
_
=
-
result_
;
is_greater
_
=
!
is_greater_
;
}
Visit
(
left
?
op
->
a
:
op
->
b
);
}
...
...
@@ -130,43 +130,65 @@ class BoundDeducer: public IRVisitor {
void
Visit_
(
const
Mul
*
op
)
final
{
bool
left
=
op
->
a
.
get
()
==
path_
[
iter_
];
Expr
operand
=
left
?
op
->
b
:
op
->
a
;
Expr
target_var
=
left
?
op
->
a
:
op
->
b
;
SignType
sign
;
SignType
sign
_operand
;
if
(
operand
.
type
().
is_uint
())
{
sign
=
kPositive
;
sign
_operand
=
kPositive
;
}
else
{
sign
=
expr_map_
[
operand
].
sign_type
();
sign
_operand
=
expr_map_
[
operand
].
sign_type
();
}
if
(
sign
==
SignType
::
kNegative
)
{
is_greater
=
!
is_greater
;
}
else
if
(
sign
==
SignType
::
kUnknown
)
{
if
(
sign
_operand
==
SignType
::
kNegative
)
{
is_greater
_
=
!
is_greater_
;
}
else
if
(
sign
_operand
==
SignType
::
kUnknown
)
{
// unable to get the sign of operand
success
=
false
;
success
_
=
false
;
return
;
}
// always use relax bound
bool
divided
=
can_prove
(
result
%
operand
==
0
);
result
=
result
/
operand
;
// since system will round down when not divided
// eg. 2/4 -> 0; -2/4 -> -1
// no need fix for !is_greater:
// eg. a <= 2/4 -> a <= 0
// eg. a <= 0/4 -> a <= 0
// so just fix for not divided and is_greater
// eg. a >= 2/4 -> a >= 0 + 1
// eg. a >= 0/4 -> a >= 0
if
(
is_greater
&&
!
divided
)
{
result
+=
1
;
bool
divided
=
analyzer_
.
CanProve
(
result_
%
operand
==
0
);
result_
=
result_
/
operand
;
if
(
!
divided
)
{
// Handle non-divisible case
// NOTE: this accounts for truc div behavior.
bool
target_is_non_neg
=
expr_map_
[
target_var
].
can_prove_non_negative
();
if
(
is_greater_
)
{
result_
+=
1
;
}
else
{
// NOTE: this is a bit sutble hack.
//
// condition:
// - x * operand <= result
// - operand > 0
// - x >= 0
//
// Then it is fine to deduce that x <= result / operand.
// - if result > 0, this division round down
// - if result < 0, (result / operand) rounds up and may violate the constraint
// however, given that x is always non-negative,
// it is fine to have this relaxed bound, given that the user of deduce bound
// will respect the bound of x
//
// TODO(tvm-team): think about a better API to incorporate constraint of x.
// e.g. specify an interval of x and return a bound
// that is in the interval and satisfies the condition.
if
(
target_is_non_neg
&&
sign_operand
==
kPositive
)
{
// do nothing
}
else
{
result_
-=
1
;
}
}
}
Visit
(
left
?
op
->
a
:
op
->
b
);
}
Expr
result
;
bool
is_greater
{
true
};
bool
success
{
true
};
Expr
result
_
;
bool
is_greater
_
{
true
};
bool
success
_
{
true
};
private
:
void
Init
();
...
...
@@ -180,6 +202,8 @@ class BoundDeducer: public IRVisitor {
ExprIntSetMap
expr_map_
;
std
::
vector
<
const
Node
*>
path_
;
size_t
iter_
{
0
};
// internal analzyer
Analyzer
analyzer_
;
};
class
BoundDeduceInputChecker
:
public
IRVisitor
{
...
...
@@ -202,7 +226,7 @@ class BoundDeduceInputChecker: public IRVisitor {
void
BoundDeducer
::
Init
()
{
BoundDeduceInputChecker
checker
;
if
(
!
checker
.
Check
(
this
))
success
=
false
;
if
(
!
checker
.
Check
(
this
))
success
_
=
false
;
Transform
();
}
...
...
@@ -211,66 +235,65 @@ void BoundDeducer::Transform() {
if
(
const
LT
*
op
=
expr_
.
as
<
LT
>
())
{
if
(
GetPath
(
target_
,
op
->
a
).
empty
())
{
// a < b -> b >= a + 1
is_greater
=
true
;
is_greater
_
=
true
;
expr_
=
op
->
b
;
result
=
op
->
a
+
1
;
result
_
=
op
->
a
+
1
;
}
else
{
// a < b -> a <= b - 1
is_greater
=
false
;
is_greater
_
=
false
;
expr_
=
op
->
a
;
result
=
op
->
b
-
1
;
result
_
=
op
->
b
-
1
;
}
}
else
if
(
const
LE
*
op
=
expr_
.
as
<
LE
>
())
{
if
(
GetPath
(
target_
,
op
->
a
).
empty
())
{
// a <= b -> b >= a
is_greater
=
true
;
is_greater
_
=
true
;
expr_
=
op
->
b
;
result
=
op
->
a
;
result
_
=
op
->
a
;
}
else
{
is_greater
=
false
;
is_greater
_
=
false
;
expr_
=
op
->
a
;
result
=
op
->
b
;
result
_
=
op
->
b
;
}
}
else
if
(
const
GT
*
op
=
expr_
.
as
<
GT
>
())
{
if
(
GetPath
(
target_
,
op
->
a
).
empty
())
{
// a > b -> b <= a - 1
is_greater
=
false
;
is_greater
_
=
false
;
expr_
=
op
->
b
;
result
=
op
->
a
-
1
;
result
_
=
op
->
a
-
1
;
}
else
{
// a > b -> a >= b + 1
is_greater
=
true
;
is_greater
_
=
true
;
expr_
=
op
->
a
;
result
=
op
->
b
+
1
;
result
_
=
op
->
b
+
1
;
}
}
else
if
(
const
GE
*
op
=
expr_
.
as
<
GE
>
())
{
if
(
GetPath
(
target_
,
op
->
a
).
empty
())
{
// a >= b -> b <= a
is_greater
=
false
;
is_greater
_
=
false
;
expr_
=
op
->
b
;
result
=
op
->
a
;
result
_
=
op
->
a
;
}
else
{
is_greater
=
true
;
is_greater
_
=
true
;
expr_
=
op
->
a
;
result
=
op
->
b
;
result
_
=
op
->
b
;
}
}
else
{
success
=
false
;
success
_
=
false
;
}
}
void
BoundDeducer
::
Deduce
()
{
Init
();
if
(
!
success
)
return
;
if
(
!
success
_
)
return
;
Relax
();
if
(
!
success
)
return
;
if
(
!
success
_
)
return
;
// get the path
path_
=
GetPath
(
target_
,
expr_
);
if
(
!
path_
.
size
())
{
success
=
false
;
success
_
=
false
;
return
;
}
expr_map_
=
EvalSetForEachSubExpr
(
expr_
,
hint_map_
);
Visit
(
expr_
);
...
...
@@ -278,13 +301,13 @@ void BoundDeducer::Deduce() {
void
BoundDeducer
::
Relax
()
{
IntSet
a
=
EvalSet
(
expr_
,
relax_map_
);
IntSet
b
=
EvalSet
(
result
,
relax_map_
);
IntSet
b
=
EvalSet
(
result
_
,
relax_map_
);
if
(
a
.
is_everything
()
||
b
.
is_everything
())
{
success
=
false
;
success
_
=
false
;
return
;
}
expr_
=
is_greater
?
a
.
min
()
:
a
.
max
();
result
=
is_greater
?
b
.
max
()
:
b
.
min
();
expr_
=
is_greater
_
?
a
.
min
()
:
a
.
max
();
result
_
=
is_greater_
?
b
.
max
()
:
b
.
min
();
}
IntSet
DeduceBound
(
Expr
v
,
Expr
e
,
...
...
@@ -292,12 +315,12 @@ IntSet DeduceBound(Expr v, Expr e,
const
std
::
unordered_map
<
const
Variable
*
,
IntSet
>&
relax_map
)
{
BoundDeducer
d
(
v
,
e
,
hint_map
,
relax_map
);
d
.
Deduce
();
if
(
!
d
.
success
)
return
IntSet
::
nothing
();
if
(
!
d
.
success
_
)
return
IntSet
::
nothing
();
Expr
min
=
neg_inf
(),
max
=
pos_inf
();
if
(
d
.
is_greater
)
{
min
=
d
.
result
;
if
(
d
.
is_greater
_
)
{
min
=
d
.
result
_
;
}
else
{
max
=
d
.
result
;
max
=
d
.
result
_
;
}
return
IntSet
::
interval
(
min
,
max
);
}
...
...
src/arithmetic/const_fold.h
View file @
4273e461
...
...
@@ -155,9 +155,10 @@ template<>
inline
Expr
TryConstFold
<
ir
::
Div
>
(
Expr
a
,
Expr
b
)
{
TVM_ARITH_CONST_PROPAGATION
({
const
Type
&
rtype
=
a
.
type
();
// due to division and mod can have different modes
// only constant fold positive number where rule is fixed.
if
(
pa
&&
pb
&&
pa
->
value
>=
0
&&
pb
->
value
>
0
)
{
if
(
pa
&&
pb
)
{
// due to division and mod can have different modes
// NOTE: this will assumes truc div.
CHECK_NE
(
pb
->
value
,
0
)
<<
"Divide by zero"
;
return
IntImm
::
make
(
rtype
,
pa
->
value
/
pb
->
value
);
}
if
(
pa
)
{
...
...
src/arithmetic/rewrite_simplify.cc
View file @
4273e461
...
...
@@ -155,7 +155,6 @@ Mutate_(const Add* op, const Expr& self) {
TVM_TRY_REWRITE
(
max
(
x
,
y
-
z
)
+
z
,
max
(
x
+
z
,
y
));
TVM_TRY_REWRITE
(
max
(
x
-
z
,
y
)
+
z
,
max
(
x
,
y
+
z
));
TVM_TRY_REWRITE_IF
(
min
(
x
,
y
+
z
*
c1
)
+
z
*
c2
,
min
(
x
+
z
*
c2
,
y
),
c1
.
Eval
()
->
value
==
-
c2
.
Eval
()
->
value
);
TVM_TRY_REWRITE_IF
(
max
(
x
,
y
+
z
*
c1
)
+
z
*
c2
,
max
(
x
+
z
*
c2
,
y
),
...
...
src/arithmetic/stmt_simplify.cc
View file @
4273e461
...
...
@@ -28,7 +28,6 @@
#include <tvm/ir_mutator.h>
#include <tvm/expr_operator.h>
#include <tvm/arithmetic.h>
#include "arithmetic/Simplify.h"
namespace
tvm
{
namespace
arith
{
...
...
@@ -158,42 +157,18 @@ Expr CanonicalSimplify(Expr expr, Map<Var, Range> vrange) {
return
analyzer
.
canonical_simplify
(
expr
);
}
template
<
typename
T
>
T
Simplify_
(
T
a
,
Map
<
Var
,
Range
>
vrange
)
{
using
namespace
HalideIR
::
Internal
;
Scope
<
Interval
>
rscope
;
Expr
Simplify
(
Expr
expr
,
Map
<
Var
,
Range
>
vrange
)
{
arith
::
Analyzer
analyzer
;
for
(
auto
kv
:
vrange
)
{
Range
r
=
kv
.
second
;
rscope
.
push
(
kv
.
first
.
get
(),
Interval
(
r
->
min
,
simplify
(
r
->
min
+
r
->
extent
-
make_const
(
r
->
min
.
type
(),
1
))));
}
return
HalideIR
::
Internal
::
simplify
(
a
,
true
,
rscope
);
}
Expr
Simplify
(
Expr
a
,
Map
<
Var
,
Range
>
vrange
)
{
// Simplify top level reduce.
if
(
const
Reduce
*
r
=
a
.
as
<
Reduce
>
())
{
Array
<
Expr
>
new_source
;
for
(
auto
&
e
:
r
->
source
)
{
new_source
.
push_back
(
Simplify_
(
e
,
vrange
));
}
Expr
new_condition
=
Simplify_
(
r
->
condition
,
vrange
);
if
(
r
->
source
.
same_as
(
new_source
)
&&
r
->
condition
.
same_as
(
new_condition
))
{
return
a
;
}
else
{
return
Reduce
::
make
(
r
->
combiner
,
new_source
,
r
->
axis
,
new_condition
,
r
->
value_index
);
}
analyzer
.
Bind
(
kv
.
first
,
kv
.
second
);
}
return
Simplify_
(
a
,
vrange
);
expr
=
analyzer
.
Simplify
(
expr
);
return
expr
;
}
Stmt
Simplify
(
Stmt
a
,
Map
<
Var
,
Range
>
vrange
)
{
return
Simplify_
(
a
,
vrange
);
Stmt
Simplify
(
Stmt
stmt
,
Map
<
Var
,
Range
>
vrange
)
{
return
arith
::
CanonicalStmtSimplifier
().
CanonicalSimplify
(
stmt
,
vrange
);
}
}
// namespace ir
}
// namespace tvm
src/lang/buffer.cc
View file @
4273e461
...
...
@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...
...
@@ -26,6 +26,7 @@
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <iterator>
#include <stack>
#include "../arithmetic/compute_expr.h"
namespace
tvm
{
...
...
src/op/scan_op.cc
View file @
4273e461
...
...
@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...
...
@@ -80,7 +80,7 @@ Operation ScanOpNode::make(std::string name,
for
(
size_t
i
=
0
;
i
<
init
.
size
();
++
i
)
{
CHECK_EQ
(
init
[
i
]
->
dtype
,
state_placeholder
[
i
]
->
dtype
);
CHECK_EQ
(
init
[
i
]
->
dtype
,
update
[
i
]
->
dtype
);
CHECK
(
can_prove
(
init
[
i
]
->
shape
[
0
]
==
axis
->
dom
->
min
))
CHECK
(
prove_equal
(
init
[
i
]
->
shape
[
0
],
axis
->
dom
->
min
))
<<
"init.shape[0] need to match scan_axis.dom.min"
;
CHECK
(
prove_equal
(
state_placeholder
[
i
]
->
shape
[
0
],
axis
->
dom
->
min
+
axis
->
dom
->
extent
))
...
...
src/pass/loop_partition.cc
View file @
4273e461
...
...
@@ -466,8 +466,13 @@ Stmt LoopPartitioner::TryPartition(const Node* node,
Stmt
body
,
bool
partition_thread_scope
)
{
using
namespace
arith
;
// include hint of var.
hint_map_
.
insert
({
var
.
get
(),
IntSet
::
interval
(
min
,
max
)});
PartitionFinder
finder
(
var
,
hint_map_
,
relax_map_
);
finder
.
Visit
(
body
);
hint_map_
.
erase
(
var
.
get
());
if
(
finder
.
partitions
.
empty
())
return
Stmt
();
arith
::
IntervalSet
for_interval
(
min
,
max
);
...
...
@@ -504,9 +509,9 @@ Stmt LoopPartitioner::TryPartition(const Node* node,
bool
pre_stmt_recurse
=
true
;
if
(
middle_interval_i
->
HasLowerBound
())
{
body_begin
=
ir
::
Simplify
(
middle_interval
.
min
());
if
(
!
can_p
rove
(
body_begin
==
min
))
{
if
(
!
analyzer_
.
CanP
rove
(
body_begin
==
min
))
{
Expr
cond
=
(
body_begin
-
min
>=
0
);
if
(
!
can_p
rove
(
cond
))
{
if
(
!
analyzer_
.
CanP
rove
(
cond
))
{
LOG
(
WARNING
)
<<
"Cannot prove: "
<<
cond
<<
", when generating the pre doubt loop"
;
body_begin
=
Max
::
make
(
body_begin
,
min
);
...
...
@@ -529,10 +534,10 @@ Stmt LoopPartitioner::TryPartition(const Node* node,
bool
post_stmt_recurse
=
true
;
if
(
middle_interval_i
->
HasUpperBound
())
{
post_doubt_begin
=
ir
::
Simplify
(
middle_interval
.
max
()
+
1
);
if
(
!
can_p
rove
(
middle_interval
.
max
()
==
max
))
{
if
(
!
analyzer_
.
CanP
rove
(
middle_interval
.
max
()
==
max
))
{
// require the extent to be non-negative
Expr
cond
=
(
max
-
post_doubt_begin
+
1
>=
0
);
if
(
!
can_p
rove
(
cond
))
{
if
(
!
analyzer_
.
CanP
rove
(
cond
))
{
LOG
(
WARNING
)
<<
"Cannot prove: "
<<
cond
<<
", when generating the post doubt loop"
;
post_doubt_begin
=
Min
::
make
(
post_doubt_begin
,
max
);
...
...
@@ -554,7 +559,7 @@ Stmt LoopPartitioner::TryPartition(const Node* node,
// Generating code for middle subrange
if
(
!
partition_thread_scope
)
{
Stmt
mid_stmt
;
if
(
!
can_p
rove
(
body_begin
>=
post_doubt_begin
))
{
if
(
!
analyzer_
.
CanP
rove
(
body_begin
>=
post_doubt_begin
))
{
// [body_begin, post_doubt_begin)
Stmt
simplified_body
=
ConditionEliminator
(
cond_set
,
cond_value
).
Mutate
(
body
);
Stmt
new_body
=
Substitute
(
simplified_body
,
{{
Var
{
var
},
var
+
body_begin
}});
...
...
@@ -576,8 +581,8 @@ Stmt LoopPartitioner::TryPartition(const Node* node,
s
=
AppendStmts
(
s
,
post_stmt
);
}
else
{
Expr
cond
=
const_true
();
if
(
!
can_p
rove
(
body_begin
==
min
))
cond
=
cond
&&
(
var
>=
body_begin
);
if
(
!
can_p
rove
(
post_doubt_begin
==
(
max
+
1
)))
cond
=
cond
&&
(
var
<
post_doubt_begin
);
if
(
!
analyzer_
.
CanP
rove
(
body_begin
==
min
))
cond
=
cond
&&
(
var
>=
body_begin
);
if
(
!
analyzer_
.
CanP
rove
(
post_doubt_begin
==
(
max
+
1
)))
cond
=
cond
&&
(
var
<
post_doubt_begin
);
s
=
ThreadPartitionInserter
(
cond_set
,
cond
).
Mutate
(
stmt
);
}
s
=
ConvertSSA
(
s
);
...
...
@@ -587,7 +592,7 @@ Stmt LoopPartitioner::TryPartition(const Node* node,
inline
Stmt
LoopPartitioner
::
MakeFor
(
const
Node
*
node
,
Expr
extent
,
Stmt
body
)
{
const
For
*
for_node
=
static_cast
<
const
For
*>
(
node
);
CHECK
(
for_node
);
if
(
can_p
rove
(
extent
==
make_const
(
Int
(
32
),
1
)))
{
if
(
analyzer_
.
CanP
rove
(
extent
==
make_const
(
Int
(
32
),
1
)))
{
// If the loop extent is 1, do not create the loop anymore
return
Substitute
(
body
,
{{
Var
{
for_node
->
loop_var
},
make_const
(
Int
(
32
),
0
)}});
}
else
{
...
...
src/pass/narrow_channel_access.cc
View file @
4273e461
...
...
@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...
...
@@ -200,7 +200,7 @@ class ChannelAccessRewriter : public IRMutator {
Expr
base
=
linear_eq
[
1
];
if
(
!
is_zero
(
base
))
return
body
;
Expr
left
=
ir
::
Simplify
(
adv_op
->
value
-
coeff
*
for_op
->
extent
);
if
(
!
can_p
rove
(
left
>=
0
))
return
body
;
if
(
!
analyzer_
.
CanP
rove
(
left
>=
0
))
return
body
;
// rewrite access index.
ChannelAccessIndexRewriter
rw
(
ch
->
handle_var
.
get
(),
var
*
coeff
,
read_access
);
...
...
@@ -233,6 +233,7 @@ class ChannelAccessRewriter : public IRMutator {
return
body
;
}
arith
::
Analyzer
analyzer_
;
std
::
vector
<
RewriteEntry
>
tasks_
;
};
...
...
src/pass/storage_rewrite.cc
View file @
4273e461
...
...
@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...
...
@@ -606,7 +606,7 @@ class StoragePlanRewriter : public IRMutator {
}
// transform to alloc bytes
auto
type_bits
=
alloc_type
.
bits
()
*
alloc_type
.
lanes
();
bool
divided
=
can_p
rove
(
combo_size
%
type_bits
==
0
);
bool
divided
=
analyzer_
.
CanP
rove
(
combo_size
%
type_bits
==
0
);
combo_size
=
combo_size
/
type_bits
;
// round up for can not divided
if
(
!
divided
)
{
...
...
@@ -920,6 +920,8 @@ class StoragePlanRewriter : public IRMutator {
std
::
unordered_map
<
const
Variable
*
,
StorageEntry
*>
alloc_map_
;
// The allocations
std
::
vector
<
std
::
unique_ptr
<
StorageEntry
>
>
alloc_vec_
;
// analyzer
arith
::
Analyzer
analyzer_
;
};
// Turn alloc into vector alloc
...
...
src/pass/vectorize_loop.cc
View file @
4273e461
...
...
@@ -25,6 +25,7 @@
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_mutator.h>
#include <tvm/arithmetic.h>
#include <unordered_set>
#include <unordered_map>
#include <vector>
...
...
@@ -132,11 +133,11 @@ class Vectorizer : public IRMutator {
if
(
lanes
!=
1
)
{
const
Ramp
*
b_ramp
=
b
.
as
<
Ramp
>
();
const
Ramp
*
a_ramp
=
a
.
as
<
Ramp
>
();
if
(
a_ramp
&&
b
.
type
().
lanes
()
==
1
&&
can_p
rove
(
b
>
0
))
{
if
(
a_ramp
&&
b
.
type
().
lanes
()
==
1
&&
analyzer_
.
CanP
rove
(
b
>
0
))
{
return
Ramp
::
make
(
a_ramp
->
base
*
b
,
a_ramp
->
stride
*
b
,
a_ramp
->
lanes
);
}
if
(
b_ramp
&&
a
.
type
().
lanes
()
==
1
&&
can_p
rove
(
a
>
0
))
{
if
(
b_ramp
&&
a
.
type
().
lanes
()
==
1
&&
analyzer_
.
CanP
rove
(
a
>
0
))
{
return
Ramp
::
make
(
b_ramp
->
base
*
a
,
b_ramp
->
stride
*
a
,
b_ramp
->
lanes
);
}
...
...
@@ -186,7 +187,7 @@ class Vectorizer : public IRMutator {
Expr
stride
=
this
->
Mutate
(
op
->
stride
);
if
(
base
.
type
().
lanes
()
>
1
&&
stride
.
type
().
lanes
()
==
1
)
{
const
Ramp
*
base_ramp
=
base
.
as
<
Ramp
>
();
if
(
can_p
rove
(
base_ramp
->
stride
==
stride
*
make_const
(
stride
.
type
(),
op
->
lanes
)))
{
if
(
analyzer_
.
CanP
rove
(
base_ramp
->
stride
==
stride
*
make_const
(
stride
.
type
(),
op
->
lanes
)))
{
return
Ramp
::
make
(
base_ramp
->
base
,
stride
,
op
->
lanes
*
base_ramp
->
lanes
);
}
}
...
...
@@ -423,6 +424,8 @@ class Vectorizer : public IRMutator {
}
private
:
// analyzer
arith
::
Analyzer
analyzer_
;
// variable to be replaced
Var
var_
;
// the lanes.
...
...
src/schedule/message_passing.cc
View file @
4273e461
...
...
@@ -432,9 +432,9 @@ void PassDownBitMaskOr(const Stage& stage,
*/
void
PassUpBoundCheck
(
const
Stage
&
s
,
const
Map
<
IterVar
,
Range
>&
dom_map
,
std
::
unordered_map
<
IterVar
,
bool
>*
p_state
)
{
std
::
unordered_map
<
IterVar
,
bool
>*
p_state
,
arith
::
Analyzer
*
analyzer
)
{
auto
&
state
=
*
p_state
;
using
HalideIR
::
Internal
::
can_prove
;
for
(
size_t
i
=
s
->
relations
.
size
();
i
!=
0
;
--
i
)
{
IterVarRelation
rel
=
s
->
relations
[
i
-
1
];
if
(
const
SplitNode
*
s
=
rel
.
as
<
SplitNode
>
())
{
...
...
@@ -447,7 +447,7 @@ void PassUpBoundCheck(const Stage& s,
if
(
outer
||
inner
)
{
state
[
s
->
parent
]
=
true
;
}
else
{
if
(
can_p
rove
(
dom_map
.
at
(
s
->
parent
)
->
extent
==
factor
*
step
))
{
if
(
analyzer
->
CanP
rove
(
dom_map
.
at
(
s
->
parent
)
->
extent
==
factor
*
step
))
{
state
[
s
->
parent
]
=
false
;
}
else
{
state
[
s
->
parent
]
=
true
;
...
...
@@ -476,11 +476,13 @@ std::vector<Expr> MakeBoundCheck(
const
std
::
unordered_map
<
IterVar
,
Expr
>&
value_map
,
bool
skip_ivar_domain
,
const
std
::
unordered_set
<
IterVar
>&
skip_iter
)
{
Analyzer
analyzer
;
std
::
unordered_map
<
IterVar
,
bool
>
bound_state
;
for
(
IterVar
iv
:
stage
->
leaf_iter_vars
)
{
bound_state
[
iv
]
=
false
;
}
PassUpBoundCheck
(
stage
,
dom_map
,
&
bound_state
);
PassUpBoundCheck
(
stage
,
dom_map
,
&
bound_state
,
&
analyzer
);
std
::
vector
<
Expr
>
preds
;
std
::
unordered_map
<
const
Variable
*
,
IntSet
>
iset_dmap
;
...
...
@@ -496,7 +498,7 @@ std::vector<Expr> MakeBoundCheck(
Range
dom
=
dom_map
.
at
(
iv
);
Expr
value
=
ComputeExpr
<
Sub
>
(
value_map
.
at
(
iv
),
dom
->
min
);
Expr
vmax
=
EvalSet
(
value
,
iset_dmap
).
max
();
if
(
vmax
.
type
()
!=
value
.
type
()
||
!
can_p
rove
(
vmax
<
dom
->
extent
))
{
if
(
vmax
.
type
()
!=
value
.
type
()
||
!
analyzer
.
CanP
rove
(
vmax
<
dom
->
extent
))
{
preds
.
emplace_back
(
value
<
dom
->
extent
);
}
}
...
...
@@ -511,10 +513,10 @@ std::vector<Expr> MakeBoundCheck(
Expr
vmin
=
s
.
min
();
Expr
vmax
=
s
.
max
();
// The range of `value` resides in [vmin, vmax]
if
(
vmin
.
type
()
!=
value
.
type
()
||
!
can_p
rove
(
vmin
>=
0
))
{
if
(
vmin
.
type
()
!=
value
.
type
()
||
!
analyzer
.
CanP
rove
(
vmin
>=
0
))
{
preds
.
emplace_back
(
value
>=
0
);
}
if
(
vmax
.
type
()
!=
value
.
type
()
||
!
can_p
rove
(
vmax
<
iv
->
dom
->
extent
))
{
if
(
vmax
.
type
()
!=
value
.
type
()
||
!
analyzer
.
CanP
rove
(
vmax
<
iv
->
dom
->
extent
))
{
preds
.
emplace_back
(
value
<
iv
->
dom
->
extent
);
}
}
...
...
src/schedule/schedule_dataflow_rewrite.cc
View file @
4273e461
...
...
@@ -740,7 +740,7 @@ Array<Tensor> Schedule::rfactor(const Tensor& tensor,
const
Reduce
*
reduce
=
compute_op
->
body
[
idx
].
as
<
Reduce
>
();
CHECK
(
reduce
)
<<
"Can only rfactor non-inline reductions"
;
predicates
.
push_back
(
reduce
->
condition
);
Expr
predicate
=
likely
(
simplify
(
arith
::
ComputeReduce
<
ir
::
And
>
(
predicates
,
Expr
()
)));
Expr
predicate
=
likely
(
arith
::
ComputeReduce
<
ir
::
And
>
(
predicates
,
Expr
(
)));
std
::
unordered_map
<
const
Variable
*
,
Expr
>
vsub
;
...
...
tests/cpp/ir_simplify_test.cc
View file @
4273e461
...
...
@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...
...
@@ -21,12 +21,6 @@
#include <gtest/gtest.h>
#include <tvm/ir_pass.h>
#include <tvm/tvm.h>
#include <arithmetic/Simplify.h>
TEST
(
IRSIMPLIFY
,
Basic
)
{
using
namespace
HalideIR
::
Internal
;
simplify_test
();
}
TEST
(
IRSIMPLIFY
,
MinMax
)
{
auto
x
=
tvm
::
var
(
"x"
);
...
...
tests/python/unittest/test_arith_deduce_bound.py
View file @
4273e461
...
...
@@ -16,6 +16,14 @@
# under the License.
import
tvm
def
assert_expr_equal
(
a
,
b
):
res
=
tvm
.
ir_pass
.
Simplify
(
a
-
b
)
equal
=
isinstance
(
res
,
tvm
.
expr
.
IntImm
)
and
res
.
value
==
0
if
not
equal
:
raise
ValueError
(
"{} and {} are not equal"
.
format
(
a
,
b
))
def
test_deduce
():
a
=
tvm
.
var
(
'a'
)
b
=
tvm
.
var
(
'b'
)
...
...
@@ -29,31 +37,34 @@ def test_deduce():
e0
=
(
-
b
)
*
a
+
c
-
d
res0
=
tvm
.
arith
.
DeduceBound
(
a
,
e0
>=
0
,
{
b
:
b_s
,
c
:
c_s
,
d
:
d_s
},
{})
ans0
=
((
d
-
c
)
/
(
b
*-
1
))
assert
str
(
tvm
.
ir_pass
.
Simplify
(
res0
.
max_value
))
==
str
(
ans0
)
ans0
=
((
d
-
c
)
/
(
b
*-
1
)
+
(
-
1
)
)
assert
_expr_equal
(
res0
.
max_value
,
ans0
)
# expression containing variable a is on rhs
res0
=
tvm
.
arith
.
DeduceBound
(
a
,
zero
<=
e0
,
{
b
:
b_s
,
c
:
c_s
,
d
:
d_s
},
{})
assert
str
(
tvm
.
ir_pass
.
Simplify
(
res0
.
max_value
))
==
str
(
ans0
)
assert
_expr_equal
(
res0
.
max_value
,
ans0
)
e0
=
d
*
a
+
c
-
d
res0
=
tvm
.
arith
.
DeduceBound
(
a
,
e0
>=
0
,
{
b
:
b_s
,
c
:
c_s
,
d
:
d_s
},
{})
ans0
=
((
0
-
c
)
/
d
+
1
)
assert
str
(
tvm
.
ir_pass
.
Simplify
(
res0
.
max_value
))
==
str
(
ans0
)
ans0
=
((
d
-
c
)
/
d
-
1
)
assert
_expr_equal
(
res0
.
max_value
,
ans0
)
# expression containing variable a is on rhs
res0
=
tvm
.
arith
.
DeduceBound
(
a
,
zero
<=
e0
,
{
b
:
b_s
,
c
:
c_s
,
d
:
d_s
},
{})
assert
str
(
tvm
.
ir_pass
.
Simplify
(
res0
.
max_value
))
==
str
(
ans0
)
assert_expr_equal
(
res0
.
max_value
,
ans0
)
e1
=
(
a
*
4
+
b
<
c
)
res1
=
tvm
.
arith
.
DeduceBound
(
a
,
e1
,
{
b
:
b_s
,
c
:
c_s
,
d
:
d_s
},
{})
ans1
=
(((
c
-
b
)
+
-
1
)
/
4
)
assert
str
(
tvm
.
ir_pass
.
Simplify
(
res1
.
max_value
))
==
str
(
ans1
)
ans1
=
(((
c
-
b
)
+
-
1
)
/
4
-
1
)
assert_expr_equal
(
res1
.
max_value
,
ans1
)
# expression containing variable a is on rhs
e1
=
(
c
>
a
*
4
+
b
)
res1
=
tvm
.
arith
.
DeduceBound
(
a
,
e1
,
{
b
:
b_s
,
c
:
c_s
,
d
:
d_s
},
{})
assert
str
(
tvm
.
ir_pass
.
Simplify
(
res1
.
max_value
))
==
str
(
ans1
)
assert_expr_equal
(
res1
.
max_value
,
ans1
)
e2
=
(
tvm
.
max
(
5
,
a
*
4
)
<
0
)
res2
=
tvm
.
arith
.
DeduceBound
(
a
,
e2
,
{
b
:
b_s
,
c
:
c_s
,
d
:
d_s
},
{})
...
...
@@ -66,7 +77,6 @@ def test_deduce():
assert
str
(
res2
.
max_value
)
==
"neg_inf"
assert
str
(
res2
.
min_value
)
==
"pos_inf"
e3
=
(
-
b
)
+
a
*
c
-
d
res3
=
tvm
.
arith
.
DeduceBound
(
a
,
e3
>=
0
,
{
b
:
b_s
,
c
:
c_s
,
d
:
d_s
},
{
b
:
b_s
,
d
:
d_s
})
ans3
=
2
/
c
+
1
...
...
@@ -75,6 +85,7 @@ def test_deduce():
res3
=
tvm
.
arith
.
DeduceBound
(
a
,
zero
<=
e3
,
{
b
:
b_s
,
c
:
c_s
,
d
:
d_s
},
{
b
:
b_s
,
d
:
d_s
})
assert
str
(
tvm
.
ir_pass
.
Simplify
(
res3
.
min_value
))
==
str
(
ans3
)
def
test_check
():
a
=
tvm
.
var
(
'a'
)
b
=
tvm
.
var
(
'b'
)
...
...
tests/python/unittest/test_pass_basic.py
View file @
4273e461
...
...
@@ -24,9 +24,6 @@ def test_simplify():
assert
(
tvm
.
ir_pass
.
Equal
(
e2
,
x
*
8
))
e3
=
tvm
.
ir_pass
.
Simplify
(
x
-
x
/
3
*
3
)
assert
(
tvm
.
ir_pass
.
Equal
(
e3
,
tvm
.
make
.
Mod
(
x
,
3
)))
let
=
tvm
.
make
.
Let
(
x
,
1
,
x
+
3
)
e4
=
tvm
.
ir_pass
.
Simplify
(
let
)
assert
(
tvm
.
ir_pass
.
Equal
(
e4
,
4
))
def
test_verify_ssa
():
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment