Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
79e071c9
Unverified
Commit
79e071c9
authored
Jun 30, 2019
by
Tianqi Chen
Committed by
GitHub
Jun 30, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[ARITH][SCHEDULE] Update schedule to use the new analyzer (#3466)
parent
dfc4f972
Show whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
142 additions
and
30 deletions
+142
-30
include/tvm/arithmetic.h
+18
-0
include/tvm/expr.h
+1
-0
src/arithmetic/analyzer.cc
+28
-0
src/arithmetic/canonical_simplify.cc
+1
-0
src/arithmetic/const_int_bound.cc
+20
-5
src/arithmetic/rewrite_simplify.cc
+18
-1
src/arithmetic/rewrite_simplify.h
+5
-0
src/schedule/bound.cc
+15
-3
src/schedule/message_passing.cc
+18
-19
src/schedule/message_passing.h
+2
-0
src/schedule/schedule_dataflow_rewrite.cc
+7
-2
src/schedule/schedule_ops.cc
+0
-0
tests/python/unittest/test_arith_canonical_simplify.py
+6
-0
tests/python/unittest/test_arith_rewrite_simplify.py
+3
-0
No files found.
include/tvm/arithmetic.h
View file @
79e071c9
...
@@ -516,6 +516,24 @@ class Analyzer {
...
@@ -516,6 +516,24 @@ class Analyzer {
* \note Analyzer will call into sub-analyzers to get the result.
* \note Analyzer will call into sub-analyzers to get the result.
*/
*/
bool
CanProveGreaterEqual
(
const
Expr
&
expr
,
int64_t
lower_bound
);
bool
CanProveGreaterEqual
(
const
Expr
&
expr
,
int64_t
lower_bound
);
/*!
* \brief Whether can we prove condition.
*
* \param cond The expression to be proved.
* \return The result.
*
* \note Analyzer will call into sub-analyzers to get the result.
*/
bool
CanProve
(
const
Expr
&
cond
);
/*!
* \brief Simplify expr.
*
* \param expr The expression to be simplified.
* \return The result.
*
* \note Analyzer will call into sub-analyzers to get the result.
*/
Expr
Simplify
(
const
Expr
&
expr
);
};
};
//-----------------------------------------------
//-----------------------------------------------
...
...
include/tvm/expr.h
View file @
79e071c9
/*
/*
* Licensed to the Apache Software Foundation (ASF) under one
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* or more contributor license agreements. See the NOTICE file
...
...
src/arithmetic/analyzer.cc
View file @
79e071c9
...
@@ -23,6 +23,7 @@
...
@@ -23,6 +23,7 @@
*/
*/
#include <tvm/ir.h>
#include <tvm/ir.h>
#include <tvm/arithmetic.h>
#include <tvm/arithmetic.h>
#include <tvm/expr_operator.h>
namespace
tvm
{
namespace
tvm
{
namespace
arith
{
namespace
arith
{
...
@@ -49,8 +50,13 @@ void Analyzer::Bind(const VarExpr& v, const Expr& expr) {
...
@@ -49,8 +50,13 @@ void Analyzer::Bind(const VarExpr& v, const Expr& expr) {
}
}
void
Analyzer
::
Bind
(
const
VarExpr
&
v
,
const
Range
&
range
)
{
void
Analyzer
::
Bind
(
const
VarExpr
&
v
,
const
Range
&
range
)
{
CHECK
(
range
.
defined
());
Var
var
(
v
.
node_
);
Var
var
(
v
.
node_
);
this
->
const_int_bound
.
Bind
(
var
,
range
);
this
->
const_int_bound
.
Bind
(
var
,
range
);
if
(
is_one
(
range
->
extent
))
{
this
->
rewrite_simplify
.
Update
(
var
,
range
->
min
);
this
->
canonical_simplify
.
Update
(
var
,
range
->
min
);
}
// skip modular_set
// skip modular_set
// skip rewrite simplify
// skip rewrite simplify
}
}
...
@@ -82,5 +88,27 @@ bool Analyzer::CanProveGreaterEqual(const Expr& expr, int64_t lower_bound) {
...
@@ -82,5 +88,27 @@ bool Analyzer::CanProveGreaterEqual(const Expr& expr, int64_t lower_bound) {
return
false
;
return
false
;
}
}
bool
Analyzer
::
CanProve
(
const
Expr
&
expr
)
{
if
(
const
auto
*
ptr
=
expr
.
as
<
ir
::
UIntImm
>
())
{
return
ptr
->
value
!=
0
;
}
auto
res
=
this
->
rewrite_simplify
(
expr
);
if
(
const
auto
*
ptr
=
res
.
as
<
ir
::
UIntImm
>
())
{
return
ptr
->
value
!=
0
;
}
res
=
this
->
canonical_simplify
(
expr
);
if
(
const
auto
*
ptr
=
res
.
as
<
ir
::
UIntImm
>
())
{
return
ptr
->
value
!=
0
;
}
return
false
;
}
Expr
Analyzer
::
Simplify
(
const
Expr
&
expr
)
{
if
(
is_const
(
expr
))
return
expr
;
auto
res
=
this
->
rewrite_simplify
(
expr
);
res
=
this
->
canonical_simplify
(
res
);
return
res
;
}
}
// namespace arith
}
// namespace arith
}
// namespace tvm
}
// namespace tvm
src/arithmetic/canonical_simplify.cc
View file @
79e071c9
...
@@ -262,6 +262,7 @@ class SumExprNode : public CanonicalExprNode {
...
@@ -262,6 +262,7 @@ class SumExprNode : public CanonicalExprNode {
rhs
.
CopyOnWrite
()
->
scale
+=
lhs
->
scale
;
rhs
.
CopyOnWrite
()
->
scale
+=
lhs
->
scale
;
lhs
.
CopyOnWrite
()
->
scale
=
0
;
lhs
.
CopyOnWrite
()
->
scale
=
0
;
}
else
if
(
lhs
->
lower_factor
==
rhs
->
upper_factor
&&
}
else
if
(
lhs
->
lower_factor
==
rhs
->
upper_factor
&&
rhs
->
scale
!=
0
&&
lhs
->
scale
%
rhs
->
scale
==
0
&&
lhs
->
scale
%
rhs
->
scale
==
0
&&
lhs
->
lower_factor
==
(
lhs
->
scale
/
rhs
->
scale
)
*
rhs
->
lower_factor
)
{
lhs
->
lower_factor
==
(
lhs
->
scale
/
rhs
->
scale
)
*
rhs
->
lower_factor
)
{
// Rules used in the proof:
// Rules used in the proof:
...
...
src/arithmetic/const_int_bound.cc
View file @
79e071c9
...
@@ -42,11 +42,23 @@ ConstIntBound::ConstIntBound(
...
@@ -42,11 +42,23 @@ ConstIntBound::ConstIntBound(
node_
=
std
::
move
(
node
);
node_
=
std
::
move
(
node
);
}
}
inline
void
PrintBoundValue
(
std
::
ostream
&
os
,
int64_t
val
)
{
if
(
val
==
ConstIntBound
::
kPosInf
)
{
os
<<
"pos_inf"
;
}
else
if
(
val
==
ConstIntBound
::
kNegInf
)
{
os
<<
"neg_inf"
;
}
else
{
os
<<
val
;
}
}
TVM_STATIC_IR_FUNCTOR
(
IRPrinter
,
vtable
)
TVM_STATIC_IR_FUNCTOR
(
IRPrinter
,
vtable
)
.
set_dispatch
<
ConstIntBoundNode
>
([](
const
ConstIntBoundNode
*
op
,
IRPrinter
*
p
)
{
.
set_dispatch
<
ConstIntBoundNode
>
([](
const
ConstIntBoundNode
*
op
,
IRPrinter
*
p
)
{
p
->
stream
<<
"ConstIntBound"
p
->
stream
<<
"ConstIntBound["
;
<<
"["
<<
op
->
min_value
<<
", "
PrintBoundValue
(
p
->
stream
,
op
->
min_value
);
<<
op
->
max_value
<<
']'
;
p
->
stream
<<
','
;
PrintBoundValue
(
p
->
stream
,
op
->
max_value
);
p
->
stream
<<
']'
;
});
});
// internal entry for const int bound
// internal entry for const int bound
...
@@ -95,7 +107,10 @@ class ConstIntBoundAnalyzer::Impl :
...
@@ -95,7 +107,10 @@ class ConstIntBoundAnalyzer::Impl :
auto
it
=
var_map_
.
find
(
var
);
auto
it
=
var_map_
.
find
(
var
);
if
(
it
!=
var_map_
.
end
())
{
if
(
it
!=
var_map_
.
end
())
{
CHECK
(
it
->
second
==
info
)
CHECK
(
it
->
second
==
info
)
<<
"var
\'
"
<<
var
<<
"
\'
already updated."
;
<<
"Trying to update var
\'
"
<<
var
<<
"
\'
"
<<
" with a different const bound: "
<<
"original="
<<
ConstIntBound
(
it
->
second
.
min_value
,
it
->
second
.
max_value
)
<<
", new="
<<
ConstIntBound
(
info
.
min_value
,
info
.
max_value
);
}
}
}
}
var_map_
[
var
]
=
info
;
var_map_
[
var
]
=
info
;
...
...
src/arithmetic/rewrite_simplify.cc
View file @
79e071c9
...
@@ -105,7 +105,14 @@ TryCompare(const Expr& x, int64_t val) {
...
@@ -105,7 +105,14 @@ TryCompare(const Expr& x, int64_t val) {
void
RewriteSimplifier
::
Impl
::
void
RewriteSimplifier
::
Impl
::
Update
(
const
Var
&
var
,
const
Expr
&
info
,
bool
override
)
{
Update
(
const
Var
&
var
,
const
Expr
&
info
,
bool
override
)
{
if
(
!
override
)
{
if
(
!
override
)
{
CHECK
(
!
var_map_
.
count
(
var
));
auto
it
=
var_map_
.
find
(
var
);
if
(
it
!=
var_map_
.
end
())
{
CHECK
(
Equal
(
it
->
second
,
info
))
<<
"Trying to update var
\'
"
<<
var
<<
"
\'
"
<<
" with a different value: "
<<
"original="
<<
it
->
second
<<
", new="
<<
info
;
}
}
}
var_map_
[
var
]
=
info
;
var_map_
[
var
]
=
info
;
}
}
...
@@ -199,6 +206,9 @@ Mutate_(const Add* op, const Expr& self) {
...
@@ -199,6 +206,9 @@ Mutate_(const Add* op, const Expr& self) {
TVM_TRY_RECURSIVE_REWRITE
(
x
+
c1
+
y
,
(
x
+
y
)
+
c1
);
TVM_TRY_RECURSIVE_REWRITE
(
x
+
c1
+
y
,
(
x
+
y
)
+
c1
);
TVM_TRY_RECURSIVE_REWRITE
(
x
+
(
c1
+
y
),
(
x
+
y
)
+
c1
);
TVM_TRY_RECURSIVE_REWRITE
(
x
+
(
c1
+
y
),
(
x
+
y
)
+
c1
);
TVM_TRY_RECURSIVE_REWRITE
((
y
%
c1
)
+
x
*
c1
,
x
*
c1
+
(
y
%
c1
));
TVM_TRY_RECURSIVE_REWRITE
((
y
%
c1
)
+
x
*
c1
,
x
*
c1
+
(
y
%
c1
));
TVM_TRY_RECURSIVE_REWRITE
(
x
+
max
(
y
,
z
),
max
(
y
,
z
)
+
x
);
TVM_TRY_RECURSIVE_REWRITE
(
x
+
min
(
y
,
z
),
min
(
y
,
z
)
+
x
);
}
}
// condition rules.
// condition rules.
...
@@ -477,6 +487,10 @@ Mutate_(const Div* op, const Expr& self) {
...
@@ -477,6 +487,10 @@ Mutate_(const Div* op, const Expr& self) {
}
}
}
}
TVM_TRY_REWRITE
(
x
/
x
,
OneWithTypeLike
(
x
));
TVM_TRY_REWRITE
(
x
*
c1
/
x
,
c1
);
TVM_TRY_REWRITE
(
c1
*
x
/
x
,
c1
);
// Rules involving 2-operands.
// Rules involving 2-operands.
TVM_TRY_REWRITE_IF
((
x
*
c1
+
y
)
/
c2
,
x
*
(
c1
/
c2
)
+
y
/
c2
,
TVM_TRY_REWRITE_IF
((
x
*
c1
+
y
)
/
c2
,
x
*
(
c1
/
c2
)
+
y
/
c2
,
c1
.
Eval
()
->
value
>=
0
&&
c1
.
Eval
()
->
value
>=
0
&&
...
@@ -684,6 +698,9 @@ Mutate_(const Mod* op, const Expr& self) {
...
@@ -684,6 +698,9 @@ Mutate_(const Mod* op, const Expr& self) {
if
(
mod
->
coeff
%
c1val
==
0
&&
if
(
mod
->
coeff
%
c1val
==
0
&&
CanProveGreaterEqual
(
x
.
Eval
(),
0
))
{
CanProveGreaterEqual
(
x
.
Eval
(),
0
))
{
return
(
mod
->
base
%
c1
).
Eval
();
return
(
mod
->
base
%
c1
).
Eval
();
}
else
if
(
mod
->
coeff
%
c1val
==
0
&&
mod
->
base
%
c1val
==
0
)
{
return
make_zero
(
ret
.
type
());
}
}
}
}
}
}
...
...
src/arithmetic/rewrite_simplify.h
View file @
79e071c9
...
@@ -121,6 +121,11 @@ class RewriteSimplifier::Impl : public IRMutator {
...
@@ -121,6 +121,11 @@ class RewriteSimplifier::Impl : public IRMutator {
PConstWithTypeLike
<
TA
>
ZeroWithTypeLike
(
const
Pattern
<
TA
>&
pattern
)
{
PConstWithTypeLike
<
TA
>
ZeroWithTypeLike
(
const
Pattern
<
TA
>&
pattern
)
{
return
PConstWithTypeLike
<
TA
>
(
pattern
.
derived
(),
0
);
return
PConstWithTypeLike
<
TA
>
(
pattern
.
derived
(),
0
);
}
}
template
<
typename
TA
>
PConstWithTypeLike
<
TA
>
OneWithTypeLike
(
const
Pattern
<
TA
>&
pattern
)
{
return
PConstWithTypeLike
<
TA
>
(
pattern
.
derived
(),
1
);
}
};
};
...
...
src/schedule/bound.cc
View file @
79e071c9
...
@@ -213,6 +213,8 @@ Map<IterVar, Range> InferBound(const Schedule& sch) {
...
@@ -213,6 +213,8 @@ Map<IterVar, Range> InferBound(const Schedule& sch) {
// Prepare context
// Prepare context
GraphContext
ctx
;
GraphContext
ctx
;
Array
<
Operation
>
roots
;
Array
<
Operation
>
roots
;
arith
::
Analyzer
analyzer
;
for
(
Operation
op
:
sch
->
outputs
)
{
for
(
Operation
op
:
sch
->
outputs
)
{
roots
.
push_back
(
sch
->
stage_map
[
op
]
->
op
);
roots
.
push_back
(
sch
->
stage_map
[
op
]
->
op
);
}
}
...
@@ -233,16 +235,26 @@ Map<IterVar, Range> InferBound(const Schedule& sch) {
...
@@ -233,16 +235,26 @@ Map<IterVar, Range> InferBound(const Schedule& sch) {
for
(
size_t
i
=
sch
->
stages
.
size
();
i
!=
0
;
--
i
)
{
for
(
size_t
i
=
sch
->
stages
.
size
();
i
!=
0
;
--
i
)
{
const
Stage
&
stage
=
sch
->
stages
[
i
-
1
];
const
Stage
&
stage
=
sch
->
stages
[
i
-
1
];
InferRootBound
(
stage
,
ctx
,
&
ret
);
InferRootBound
(
stage
,
ctx
,
&
ret
);
// bind bound of root iter vars.
for
(
auto
iv
:
stage
->
op
->
root_iter_vars
())
{
auto
it
=
ret
.
find
(
iv
);
if
(
it
!=
ret
.
end
())
{
analyzer
.
Bind
(
iv
->
var
,
it
->
second
);
}
}
// pass down to get bound of all iter vars.
// pass down to get bound of all iter vars.
PassDownDomain
(
stage
,
&
ret
);
PassDownDomain
(
stage
,
&
ret
,
&
analyzer
);
for
(
IterVar
iv
:
stage
->
env_threads
)
{
for
(
IterVar
iv
:
stage
->
env_threads
)
{
CHECK
(
iv
->
dom
.
defined
());
CHECK
(
iv
->
dom
.
defined
());
ret
[
iv
]
=
iv
->
dom
;
ret
[
iv
]
=
iv
->
dom
;
}
}
}
}
for
(
auto
&
p
:
ret
)
{
for
(
auto
&
p
:
ret
)
{
ret
[
p
.
first
]
=
Range
::
make_by_min_extent
(
ir
::
Simplify
(
p
.
second
->
min
),
ret
[
p
.
first
]
=
Range
::
make_by_min_extent
(
ir
::
Simplify
(
p
.
second
->
extent
));
analyzer
.
Simplify
(
p
.
second
->
min
),
analyzer
.
Simplify
(
p
.
second
->
extent
));
}
}
return
Map
<
IterVar
,
Range
>
(
ret
.
begin
(),
ret
.
end
());
return
Map
<
IterVar
,
Range
>
(
ret
.
begin
(),
ret
.
end
());
}
}
...
...
src/schedule/message_passing.cc
View file @
79e071c9
...
@@ -34,24 +34,17 @@ namespace schedule {
...
@@ -34,24 +34,17 @@ namespace schedule {
using
namespace
ir
;
using
namespace
ir
;
using
namespace
arith
;
using
namespace
arith
;
// result = ceil((a / b)), both a and b are positive integer
inline
Expr
DivCeil
(
Expr
a
,
Expr
b
)
{
return
ir
::
Simplify
((
a
+
b
-
1
)
/
b
);
}
inline
bool
prove_equal
(
Expr
lhs
,
Expr
rhs
)
{
return
is_zero
(
ir
::
Simplify
(
lhs
-
rhs
));
}
void
Update
(
std
::
unordered_map
<
IterVar
,
Range
>*
p_state
,
void
Update
(
std
::
unordered_map
<
IterVar
,
Range
>*
p_state
,
const
IterVar
&
iv
,
const
IterVar
&
iv
,
Range
r
)
{
Range
r
,
Analyzer
*
analyzer
)
{
auto
it
=
p_state
->
find
(
iv
);
auto
it
=
p_state
->
find
(
iv
);
if
(
it
==
p_state
->
end
())
{
if
(
it
==
p_state
->
end
())
{
(
*
p_state
)[
iv
]
=
r
;
(
*
p_state
)[
iv
]
=
r
;
analyzer
->
Bind
(
iv
->
var
,
r
);
}
else
{
}
else
{
bool
match
=
is_zero
(
it
->
second
->
min
)
;
bool
match
=
is_zero
(
it
->
second
->
min
)
&&
if
(
!
prove_equal
(
r
->
extent
,
it
->
second
->
extent
))
match
=
false
;
analyzer
->
CanProve
(
r
->
extent
-
it
->
second
->
extent
==
0
)
;
CHECK
(
match
)
CHECK
(
match
)
<<
iv
<<
iv
<<
" domain already inferred,"
<<
" domain already inferred,"
...
@@ -62,7 +55,12 @@ void Update(std::unordered_map<IterVar, Range>* p_state,
...
@@ -62,7 +55,12 @@ void Update(std::unordered_map<IterVar, Range>* p_state,
void
PassDownDomain
(
const
Stage
&
stage
,
void
PassDownDomain
(
const
Stage
&
stage
,
std
::
unordered_map
<
IterVar
,
Range
>*
p_state
,
std
::
unordered_map
<
IterVar
,
Range
>*
p_state
,
arith
::
Analyzer
*
actx
,
bool
allow_missing
)
{
bool
allow_missing
)
{
auto
ceil_div
=
[
actx
](
Expr
a
,
Expr
b
)
{
return
actx
->
Simplify
((
a
+
(
b
-
1
))
/
b
);
};
auto
&
state
=
*
p_state
;
auto
&
state
=
*
p_state
;
// forwar iteration on relations
// forwar iteration on relations
for
(
IterVarRelation
rel
:
stage
->
relations
)
{
for
(
IterVarRelation
rel
:
stage
->
relations
)
{
...
@@ -74,15 +72,16 @@ void PassDownDomain(const Stage& stage,
...
@@ -74,15 +72,16 @@ void PassDownDomain(const Stage& stage,
CHECK
(
!
state
.
count
(
r
->
inner
));
CHECK
(
!
state
.
count
(
r
->
inner
));
const
Range
&
range_parent
=
state
.
at
(
r
->
parent
);
const
Range
&
range_parent
=
state
.
at
(
r
->
parent
);
if
(
r
->
factor
.
defined
())
{
if
(
r
->
factor
.
defined
())
{
Update
(
p_state
,
r
->
inner
,
Range
::
make_by_min_extent
(
0
,
r
->
factor
));
Update
(
p_state
,
r
->
inner
,
Range
::
make_by_min_extent
(
0
,
r
->
factor
),
actx
);
Update
(
p_state
,
r
->
outer
,
Update
(
p_state
,
r
->
outer
,
Range
::
make_by_min_extent
(
Range
::
make_by_min_extent
(
0
,
DivCeil
(
range_parent
->
extent
,
r
->
factor
))
);
0
,
ceil_div
(
range_parent
->
extent
,
r
->
factor
)),
actx
);
}
else
{
}
else
{
Update
(
p_state
,
r
->
outer
,
Range
::
make_by_min_extent
(
0
,
r
->
nparts
));
Update
(
p_state
,
r
->
outer
,
Range
::
make_by_min_extent
(
0
,
r
->
nparts
)
,
actx
);
Update
(
p_state
,
r
->
inner
,
Update
(
p_state
,
r
->
inner
,
Range
::
make_by_min_extent
(
Range
::
make_by_min_extent
(
0
,
DivCeil
(
range_parent
->
extent
,
r
->
nparts
))
);
0
,
ceil_div
(
range_parent
->
extent
,
r
->
nparts
)),
actx
);
}
}
}
else
if
(
const
FuseNode
*
r
=
rel
.
as
<
FuseNode
>
())
{
}
else
if
(
const
FuseNode
*
r
=
rel
.
as
<
FuseNode
>
())
{
if
(
!
state
.
count
(
r
->
outer
)
||
!
state
.
count
(
r
->
inner
))
{
if
(
!
state
.
count
(
r
->
outer
)
||
!
state
.
count
(
r
->
inner
))
{
...
@@ -100,9 +99,9 @@ void PassDownDomain(const Stage& stage,
...
@@ -100,9 +99,9 @@ void PassDownDomain(const Stage& stage,
}
}
Update
(
p_state
,
r
->
rebased
,
Update
(
p_state
,
r
->
rebased
,
Range
::
make_by_min_extent
(
Range
::
make_by_min_extent
(
0
,
state
.
at
(
r
->
parent
)
->
extent
));
0
,
state
.
at
(
r
->
parent
)
->
extent
)
,
actx
);
}
else
if
(
const
SingletonNode
*
s
=
rel
.
as
<
SingletonNode
>
())
{
}
else
if
(
const
SingletonNode
*
s
=
rel
.
as
<
SingletonNode
>
())
{
Update
(
p_state
,
s
->
iter
,
Range
::
make_by_min_extent
(
0
,
1
));
Update
(
p_state
,
s
->
iter
,
Range
::
make_by_min_extent
(
0
,
1
)
,
actx
);
}
else
{
}
else
{
LOG
(
FATAL
)
<<
"unknown relation type"
;
LOG
(
FATAL
)
<<
"unknown relation type"
;
}
}
...
@@ -111,7 +110,7 @@ void PassDownDomain(const Stage& stage,
...
@@ -111,7 +110,7 @@ void PassDownDomain(const Stage& stage,
for
(
auto
kv
:
stage
->
iter_var_attrs
)
{
for
(
auto
kv
:
stage
->
iter_var_attrs
)
{
if
(
kv
.
second
->
bind_thread
.
defined
())
{
if
(
kv
.
second
->
bind_thread
.
defined
())
{
CHECK
(
state
.
count
(
kv
.
first
));
CHECK
(
state
.
count
(
kv
.
first
));
Update
(
p_state
,
kv
.
second
->
bind_thread
,
state
.
at
(
kv
.
first
));
Update
(
p_state
,
kv
.
second
->
bind_thread
,
state
.
at
(
kv
.
first
)
,
actx
);
}
}
}
}
}
}
...
...
src/schedule/message_passing.h
View file @
79e071c9
...
@@ -43,11 +43,13 @@ namespace schedule {
...
@@ -43,11 +43,13 @@ namespace schedule {
*
*
* \param stage The stage to operate on.
* \param stage The stage to operate on.
* \param p_state The state of the message passing.
* \param p_state The state of the message passing.
* \param analyzer Analyzer context, storing information about bounds in p_state.
* \param allow_missing Whether allow missing value.
* \param allow_missing Whether allow missing value.
*/
*/
void
PassDownDomain
(
void
PassDownDomain
(
const
Stage
&
stage
,
const
Stage
&
stage
,
std
::
unordered_map
<
IterVar
,
Range
>*
p_state
,
std
::
unordered_map
<
IterVar
,
Range
>*
p_state
,
arith
::
Analyzer
*
analyzer
,
bool
allow_missing
=
false
);
bool
allow_missing
=
false
);
/*!
/*!
...
...
src/schedule/schedule_dataflow_rewrite.cc
View file @
79e071c9
...
@@ -203,14 +203,16 @@ void PrepareAxisMapping(Stage orig_stage,
...
@@ -203,14 +203,16 @@ void PrepareAxisMapping(Stage orig_stage,
auto
&
vsub
=
*
p_vsub
;
auto
&
vsub
=
*
p_vsub
;
auto
&
vsub2newvar
=
*
p_vsub2newvar
;
auto
&
vsub2newvar
=
*
p_vsub2newvar
;
auto
&
predicates
=
*
p_predicates
;
auto
&
predicates
=
*
p_predicates
;
arith
::
Analyzer
analyzer
;
for
(
IterVar
iv
:
op
->
reduce_axis
)
{
for
(
IterVar
iv
:
op
->
reduce_axis
)
{
red_axis
.
insert
(
iv
);
red_axis
.
insert
(
iv
);
}
}
for
(
IterVar
iv
:
op
->
axis
)
{
for
(
IterVar
iv
:
op
->
axis
)
{
dom_map
[
iv
]
=
iv
->
dom
;
dom_map
[
iv
]
=
iv
->
dom
;
analyzer
.
Bind
(
iv
->
var
,
iv
->
dom
);
}
}
schedule
::
PassDownDomain
(
orig_stage
,
&
dom_map
,
true
);
schedule
::
PassDownDomain
(
orig_stage
,
&
dom_map
,
&
analyzer
,
true
);
{
{
// The source->cache
// The source->cache
std
::
unordered_map
<
IterVar
,
Expr
>
value_map
;
std
::
unordered_map
<
IterVar
,
Expr
>
value_map
;
...
@@ -679,6 +681,8 @@ Array<Tensor> Schedule::rfactor(const Tensor& tensor,
...
@@ -679,6 +681,8 @@ Array<Tensor> Schedule::rfactor(const Tensor& tensor,
<<
"Factor axis touches normal axis."
;
<<
"Factor axis touches normal axis."
;
skip_bound_check
.
insert
(
iv
);
skip_bound_check
.
insert
(
iv
);
}
}
// get analyzer.
arith
::
Analyzer
analyzer
;
// Get the replace index
// Get the replace index
std
::
unordered_map
<
IterVar
,
Range
>
dom_map
;
std
::
unordered_map
<
IterVar
,
Range
>
dom_map
;
std
::
unordered_map
<
IterVar
,
Expr
>
value_map
;
std
::
unordered_map
<
IterVar
,
Expr
>
value_map
;
...
@@ -688,8 +692,9 @@ Array<Tensor> Schedule::rfactor(const Tensor& tensor,
...
@@ -688,8 +692,9 @@ Array<Tensor> Schedule::rfactor(const Tensor& tensor,
}
else
{
}
else
{
skip_bound_check
.
insert
(
iv
);
skip_bound_check
.
insert
(
iv
);
}
}
analyzer
.
Bind
(
iv
->
var
,
iv
->
dom
);
}
}
schedule
::
PassDownDomain
(
reduce_stage
,
&
dom_map
,
true
);
schedule
::
PassDownDomain
(
reduce_stage
,
&
dom_map
,
&
analyzer
,
true
);
for
(
IterVar
iv
:
reduce_stage
->
leaf_iter_vars
)
{
for
(
IterVar
iv
:
reduce_stage
->
leaf_iter_vars
)
{
if
(
touch_map
.
count
(
iv
))
{
if
(
touch_map
.
count
(
iv
))
{
Range
dom
=
dom_map
.
at
(
iv
);
Range
dom
=
dom_map
.
at
(
iv
);
...
...
src/schedule/schedule_ops.cc
View file @
79e071c9
tests/python/unittest/test_arith_canonical_simplify.py
View file @
79e071c9
...
@@ -198,6 +198,12 @@ def test_complex_cases():
...
@@ -198,6 +198,12 @@ def test_complex_cases():
ck
.
analyzer
.
update
(
y
,
tvm
.
arith
.
ConstIntBound
(
0
,
127
))
ck
.
analyzer
.
update
(
y
,
tvm
.
arith
.
ConstIntBound
(
0
,
127
))
ck
.
verify
(
res2
,
1
)
ck
.
verify
(
res2
,
1
)
ck
.
analyzer
.
update
(
y
,
tvm
.
arith
.
ConstIntBound
(
0
,
1024
),
True
)
res3
=
((((((((((
x
*
1024
)
+
y
)
/
65536
)
+
((((
x
*
1024
)
+
y
)
%
65536
)
/
256
))
+
((((
x
*
1024
)
+
y
)
%
256
)
/
16
))
+
(((
x
*
1024
)
+
y
)
%
16
))
-
(
y
/
256
))
-
((
y
%
256
)
/
16
))
-
(
y
%
16
))
-
(
x
*
4
))
ck
.
verify
(
res3
,
((((
x
*
1024
)
+
y
)
/
256
)
-
(
y
/
256
))
-
(
x
*
4
))
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
test_simplify_if_then_else
()
test_simplify_if_then_else
()
...
...
tests/python/unittest/test_arith_rewrite_simplify.py
View file @
79e071c9
...
@@ -271,6 +271,8 @@ def test_mul_index_simplify():
...
@@ -271,6 +271,8 @@ def test_mul_index_simplify():
def
test_div_index_simplify
():
def
test_div_index_simplify
():
ck
=
RewriteChecker
()
ck
=
RewriteChecker
()
x
,
y
,
z
=
tvm
.
var
(
"x"
),
tvm
.
var
(
"y"
),
tvm
.
var
(
"z"
)
x
,
y
,
z
=
tvm
.
var
(
"x"
),
tvm
.
var
(
"y"
),
tvm
.
var
(
"z"
)
ck
.
verify
(
x
/
x
,
1
)
ck
.
analyzer
.
update
(
x
,
tvm
.
arith
.
ConstIntBound
(
0
,
1000
),
override
=
True
)
ck
.
analyzer
.
update
(
x
,
tvm
.
arith
.
ConstIntBound
(
0
,
1000
),
override
=
True
)
ck
.
analyzer
.
update
(
y
,
tvm
.
arith
.
ConstIntBound
(
0
,
1000
),
override
=
True
)
ck
.
analyzer
.
update
(
y
,
tvm
.
arith
.
ConstIntBound
(
0
,
1000
),
override
=
True
)
ck
.
analyzer
.
update
(
z
,
tvm
.
arith
.
ConstIntBound
(
0
,
1000
),
override
=
True
)
ck
.
analyzer
.
update
(
z
,
tvm
.
arith
.
ConstIntBound
(
0
,
1000
),
override
=
True
)
...
@@ -311,6 +313,7 @@ def test_div_index_simplify():
...
@@ -311,6 +313,7 @@ def test_div_index_simplify():
ck
.
verify
((
y
+
z
*
x
)
/
z
,
y
/
z
+
x
)
ck
.
verify
((
y
+
z
*
x
)
/
z
,
y
/
z
+
x
)
def
test_mod_index_simplify
():
def
test_mod_index_simplify
():
ck
=
RewriteChecker
()
ck
=
RewriteChecker
()
x
,
y
,
nx
,
ny
,
z
=
tvm
.
var
(
"x"
),
tvm
.
var
(
"y"
),
tvm
.
var
(
"nx"
),
tvm
.
var
(
"ny"
),
tvm
.
var
(
"z"
)
x
,
y
,
nx
,
ny
,
z
=
tvm
.
var
(
"x"
),
tvm
.
var
(
"y"
),
tvm
.
var
(
"nx"
),
tvm
.
var
(
"ny"
),
tvm
.
var
(
"z"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment