Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
37e57548
Unverified
Commit
37e57548
authored
Apr 26, 2020
by
yongfeng-nv
Committed by
GitHub
Apr 26, 2020
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Improve IntervalSet's floormod (#5367)
parent
4a3fece7
Show whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
168 additions
and
41 deletions
+168
-41
include/tvm/arith/analyzer.h
+21
-4
include/tvm/arith/int_set.h
+8
-4
src/arith/analyzer.cc
+19
-10
src/arith/const_int_bound.cc
+10
-8
src/arith/int_set.cc
+10
-0
src/te/operation/compute_op.cc
+8
-6
src/te/schedule/bound.cc
+10
-5
src/te/schedule/message_passing.cc
+8
-4
tests/python/unittest/test_arith_intset.py
+14
-0
tests/python/unittest/test_te_schedule_bound_inference_tiling.py
+60
-0
No files found.
include/tvm/arith/analyzer.h
View file @
37e57548
...
...
@@ -138,8 +138,9 @@ class ConstIntBoundAnalyzer {
*
* \param var The variable.
* \param range The range we bind to.
* \param override Whether we allow overriding an existing var's range.
*/
TVM_DLL
void
Bind
(
const
Var
&
var
,
const
Range
&
range
);
TVM_DLL
void
Bind
(
const
Var
&
var
,
const
Range
&
range
,
bool
override
=
false
);
private
:
friend
class
Analyzer
;
...
...
@@ -411,8 +412,9 @@ class TVM_DLL Analyzer {
*
* \param var The variable.
* \param expr The expression we bind to.
* \param override Whether we allow overriding an existing var's expression.
*/
void
Bind
(
const
Var
&
var
,
const
PrimExpr
&
expr
);
void
Bind
(
const
Var
&
var
,
const
PrimExpr
&
expr
,
bool
override
=
false
);
/*!
* \brief Notify all the sub-analyzers that var
* is created and binded to a range.
...
...
@@ -421,14 +423,16 @@ class TVM_DLL Analyzer {
*
* \param var The variable.
* \param range The range we bind to.
* \param override Whether we allow overriding an existing var's expression.
*/
void
Bind
(
const
Var
&
var
,
const
Range
&
range
);
void
Bind
(
const
Var
&
var
,
const
Range
&
range
,
bool
override
=
false
);
/*!
* \brief Bind all the vars in the Map
*
* \param variables The {variable -> range} map.
* \param override Whether we allow overriding an existing var's expression.
*/
void
Bind
(
const
Map
<
Var
,
Range
>&
variables
);
void
Bind
(
const
Map
<
Var
,
Range
>&
variables
,
bool
override
=
false
);
/*!
* \brief Whether can we prove expr >= val.
...
...
@@ -443,6 +447,19 @@ class TVM_DLL Analyzer {
*/
bool
CanProveGreaterEqual
(
const
PrimExpr
&
expr
,
int64_t
lower_bound
);
/*!
* \brief Whether can we prove expr < val.
* Non-negative proof is very useful in integer analysis
* to lower divisions and mods given difference in trunc and ceil mode.
*
* \param expr The expression.
* \param upper_bound The upper bound.
* \return Whether we can prove it.
*
* \note Analyzer will call into sub-analyzers to get the result.
*/
bool
CanProveLess
(
const
PrimExpr
&
expr
,
int64_t
upper_bound
);
/*!
* \brief Whether can we prove condition.
*
* \param cond The expression to be proved.
...
...
include/tvm/arith/int_set.h
View file @
37e57548
...
...
@@ -153,6 +153,13 @@ class IntSet : public ObjectRef {
// Integer set legacy API.
//------------------------------------------------
/*!
* \brief Convert std::unordered_map<const VarNode*, IntSet> to Map<Var, IntSet>
*
* \param dom_map The domain map to convert.
* \return The converted map.
*/
Map
<
Var
,
IntSet
>
ConvertDomMap
(
const
std
::
unordered_map
<
const
VarNode
*
,
IntSet
>&
dom_map
);
/*!
* \brief Find an symbolic integer set that contains all possible values of
* e given the domain of each iteration variables.
*
...
...
@@ -160,8 +167,7 @@ class IntSet : public ObjectRef {
* \param dom_map The domain of each variable.
* \return An integer set that can cover all the possible values of e.
*/
IntSet
EvalSet
(
PrimExpr
e
,
const
Map
<
IterVar
,
IntSet
>&
dom_map
);
IntSet
EvalSet
(
PrimExpr
e
,
const
Map
<
IterVar
,
IntSet
>&
dom_map
);
/*!
* \brief Same as EvalSet, but takes unordered_map
*
...
...
@@ -171,7 +177,6 @@ IntSet EvalSet(PrimExpr e,
*/
IntSet
EvalSet
(
PrimExpr
e
,
const
std
::
unordered_map
<
const
tir
::
VarNode
*
,
IntSet
>&
dom_map
);
/*!
* \brief Find an symbolic integer set that contains is union over
* all the possible conditional values in dom_map.
...
...
@@ -202,7 +207,6 @@ IntSet EvalSet(IntSet s,
*/
IntSet
EvalSet
(
Range
r
,
const
std
::
unordered_map
<
const
VarNode
*
,
IntSet
>&
dom_map
);
/*! \brief Map from Expr to IntSet */
using
ExprIntSetMap
=
std
::
unordered_map
<
PrimExpr
,
IntSet
,
ObjectHash
,
ObjectEqual
>
;
/*!
...
...
src/arith/analyzer.cc
View file @
37e57548
...
...
@@ -36,31 +36,31 @@ Analyzer::Analyzer()
int_set
(
this
)
{
}
void
Analyzer
::
Bind
(
const
Var
&
var
,
const
PrimExpr
&
expr
)
{
void
Analyzer
::
Bind
(
const
Var
&
var
,
const
PrimExpr
&
expr
,
bool
override
)
{
PrimExpr
new_expr
=
expr
;
new_expr
=
this
->
canonical_simplify
(
new_expr
);
new_expr
=
this
->
rewrite_simplify
(
new_expr
);
this
->
const_int_bound
.
Update
(
var
,
this
->
const_int_bound
(
new_expr
));
this
->
modular_set
.
Update
(
var
,
this
->
modular_set
(
new_expr
));
this
->
rewrite_simplify
.
Update
(
var
,
new_expr
);
this
->
canonical_simplify
.
Update
(
var
,
new_expr
);
this
->
const_int_bound
.
Update
(
var
,
this
->
const_int_bound
(
new_expr
)
,
override
);
this
->
modular_set
.
Update
(
var
,
this
->
modular_set
(
new_expr
)
,
override
);
this
->
rewrite_simplify
.
Update
(
var
,
new_expr
,
override
);
this
->
canonical_simplify
.
Update
(
var
,
new_expr
,
override
);
}
void
Analyzer
::
Bind
(
const
Var
&
var
,
const
Range
&
range
)
{
void
Analyzer
::
Bind
(
const
Var
&
var
,
const
Range
&
range
,
bool
override
)
{
CHECK
(
range
.
defined
());
if
(
tir
::
is_one
(
range
->
extent
))
{
this
->
Bind
(
var
,
range
->
min
);
this
->
Bind
(
var
,
range
->
min
,
override
);
}
else
{
this
->
const_int_bound
.
Bind
(
var
,
range
);
this
->
const_int_bound
.
Bind
(
var
,
range
,
override
);
}
// skip modular_set
// skip rewrite simplify
}
void
Analyzer
::
Bind
(
const
Map
<
Var
,
Range
>&
variables
)
{
void
Analyzer
::
Bind
(
const
Map
<
Var
,
Range
>&
variables
,
bool
override
)
{
for
(
const
auto
&
iter
:
variables
)
{
this
->
Bind
(
iter
.
first
,
iter
.
second
);
this
->
Bind
(
iter
.
first
,
iter
.
second
,
override
);
}
}
...
...
@@ -92,6 +92,15 @@ bool Analyzer::CanProveGreaterEqual(const PrimExpr& expr, int64_t lower_bound) {
return
false
;
}
bool
Analyzer
::
CanProveLess
(
const
PrimExpr
&
expr
,
int64_t
upper_bound
)
{
if
(
const
auto
*
ptr
=
expr
.
as
<
tir
::
IntImmNode
>
())
{
return
ptr
->
value
<
upper_bound
;
}
auto
bd
=
this
->
const_int_bound
(
this
->
rewrite_simplify
(
expr
));
if
(
bd
->
max_value
<
upper_bound
)
return
true
;
return
false
;
}
bool
Analyzer
::
CanProve
(
const
PrimExpr
&
expr
)
{
if
(
const
auto
*
ptr
=
expr
.
as
<
IntImmNode
>
())
{
return
ptr
->
value
!=
0
;
...
...
src/arith/const_int_bound.cc
View file @
37e57548
...
...
@@ -99,13 +99,13 @@ class ConstIntBoundAnalyzer::Impl :
}
};
void
Bind
(
const
Var
&
var
,
const
Range
&
range
)
{
void
Bind
(
const
Var
&
var
,
const
Range
&
range
,
bool
override
)
{
Entry
a
=
VisitExpr
(
range
->
min
);
Entry
b
=
VisitExpr
(
range
->
extent
);
Entry
ret
;
ret
.
min_value
=
a
.
min_value
;
ret
.
max_value
=
InfAwareAdd
(
a
.
max_value
,
InfAwareAdd
(
b
.
max_value
,
-
1
));
Update
(
var
,
ret
,
fals
e
);
Update
(
var
,
ret
,
overrid
e
);
}
void
Update
(
const
Var
&
var
,
...
...
@@ -150,10 +150,12 @@ class ConstIntBoundAnalyzer::Impl :
const
PrimExprNode
*
op
=
expr
.
as
<
PrimExprNode
>
();
auto
val
=
bound_
->
find
(
op
);
if
(
val
!=
bound_
->
end
())
{
CHECK
(
val
->
second
->
min_value
==
res
.
min_value
&&
val
->
second
->
max_value
==
res
.
max_value
)
<<
"Detected bound for "
<<
expr
<<
"conflicts with memorization"
;
auto
everything
=
Everything
(
op
->
dtype
);
CHECK
(
(
val
->
second
->
min_value
==
res
.
min_value
&&
val
->
second
->
max_value
==
res
.
max_value
)
||
(
val
->
second
->
min_value
==
everything
.
min_value
&&
val
->
second
->
max_value
==
everything
.
max_value
))
<<
"Detected bound for "
<<
expr
<<
"conflicts with memorization"
;
}
(
*
bound_
)[
op
]
=
ConstIntBound
(
res
.
min_value
,
res
.
max_value
);
}
...
...
@@ -574,8 +576,8 @@ void ConstIntBoundAnalyzer::Update(const Var& var,
impl_
->
Update
(
var
,
info
,
override
);
}
void
ConstIntBoundAnalyzer
::
Bind
(
const
Var
&
var
,
const
Range
&
range
)
{
impl_
->
Bind
(
var
,
range
);
void
ConstIntBoundAnalyzer
::
Bind
(
const
Var
&
var
,
const
Range
&
range
,
bool
override
)
{
impl_
->
Bind
(
var
,
range
,
override
);
}
std
::
function
<
void
()
>
ConstIntBoundAnalyzer
::
EnterConstraint
(
const
PrimExpr
&
constraint
)
{
...
...
src/arith/int_set.cc
View file @
37e57548
...
...
@@ -311,6 +311,16 @@ inline IntervalSet Combine<tir::FloorModNode>(Analyzer* analyzer,
LOG
(
FATAL
)
<<
"Modular by zero in CombineInterval Mod"
;
}
if
(
analyzer
->
CanProveGreaterEqual
(
divisor
,
0
))
{
if
(
divisor
.
as
<
tir
::
IntImmNode
>
())
{
// a mod b = a - (a / b) * b if a_max / b == a_min / b
auto
qmax
=
floordiv
(
a
->
max_value
,
divisor
);
auto
qmin
=
floordiv
(
a
->
min_value
,
divisor
);
if
(
analyzer
->
CanProve
(
qmax
==
qmin
))
{
auto
tmax
=
a
->
max_value
-
divisor
*
qmin
;
auto
tmin
=
a
->
min_value
-
divisor
*
qmin
;
return
IntervalSet
(
tmin
,
tmax
);
}
}
return
IntervalSet
(
make_zero
(
divisor
.
dtype
()),
divisor
-
1
);
}
else
{
PrimExpr
bound
=
abs
(
divisor
)
-
1
;
...
...
src/te/operation/compute_op.cc
View file @
37e57548
...
...
@@ -231,7 +231,7 @@ void ComputeOpNode::PropBoundToInputs(
// undefined behaviour), so we can intersect the estimated set of the argument with the
// range expected by the tensor. However, intersection may result in overly complex
// expressions, so we perform a more relaxed form of intersection.
IntSet
arg_intset
=
EvalSet
(
call
->
args
[
i
],
dom_map
);
IntSet
arg_intset
=
analyzer
->
int_set
(
call
->
args
[
i
],
ConvertDomMap
(
dom_map
)
);
const
arith
::
IntervalSetNode
*
arg_interval
=
arg_intset
.
as
<
arith
::
IntervalSetNode
>
();
if
(
arg_interval
)
{
PrimExpr
shape_i_min_value
=
make_zero
(
t
->
shape
[
i
].
dtype
());
...
...
@@ -239,12 +239,14 @@ void ComputeOpNode::PropBoundToInputs(
PrimExpr
min_value
=
arg_interval
->
min_value
;
PrimExpr
max_value
=
arg_interval
->
max_value
;
// Prefer the shape bounds only when we can prove they are tighter.
if
(
arith
::
is_neg_inf
(
min_value
)
||
analyzer
->
CanProve
(
shape_i_min_value
>=
min_value
))
{
// We must update bound's ends in pairs. Here is an counter example: shape_i is
// [0, 0] and arg_interval is [threadIdx.y, threadIdx.y], where threadIdx.y's range is
// [0, 7]. If we allowed updating one end, the bound would become [threadIdx.y, 0],
// awkward for further analysis.
if
((
arith
::
is_pos_inf
(
max_value
)
&&
arith
::
is_neg_inf
(
min_value
))
||
(
analyzer
->
CanProve
(
shape_i_min_value
>=
min_value
)
&&
analyzer
->
CanProve
(
shape_i_max_value
<=
max_value
)))
{
min_value
=
shape_i_min_value
;
}
if
(
arith
::
is_pos_inf
(
max_value
)
||
analyzer
->
CanProve
(
shape_i_max_value
<=
max_value
))
{
max_value
=
shape_i_max_value
;
}
dom
.
data
[
i
].
push_back
(
IntSet
::
interval
(
min_value
,
max_value
));
...
...
src/te/schedule/bound.cc
View file @
37e57548
...
...
@@ -137,7 +137,7 @@ void InferRootBound(const Stage& stage,
Array
<
IterVar
>
stage_attach
=
ctx
.
attach_path
.
at
(
stage
->
op
);
// The parent set.
for
(
const
Operation
&
op
:
consumers
)
{
std
::
unordered_map
<
const
VarNode
*
,
IntSet
>
relax_set
;
Map
<
Var
,
IntSet
>
relax_set
;
std
::
unordered_map
<
IterVar
,
IntSet
>
up_state
;
bool
found_attach
=
false
;
CHECK
(
ctx
.
op2stage_
.
count
(
op
.
get
()));
...
...
@@ -176,9 +176,9 @@ void InferRootBound(const Stage& stage,
<<
"InferBound requires every leaf iter var's min equals 0, "
<<
"call schedule.normalize to achieve this."
;
if
(
NeedRelax
(
iv
,
found_attach
,
ctx
.
bind_map
,
scope
))
{
relax_set
[
iv
->
var
.
get
()]
=
IntSet
::
range
(
vrange
);
relax_set
.
Set
(
iv
->
var
,
IntSet
::
range
(
vrange
)
);
if
(
ctx
.
bind_map
.
count
(
iv
))
{
relax_set
[
ctx
.
bind_map
.
at
(
iv
)
->
var
.
get
()]
=
IntSet
::
range
(
vrange
);
relax_set
.
Set
(
ctx
.
bind_map
.
at
(
iv
)
->
var
,
IntSet
::
range
(
vrange
)
);
}
}
}
...
...
@@ -190,6 +190,9 @@ void InferRootBound(const Stage& stage,
// Relax if needed.
std
::
unordered_map
<
const
VarNode
*
,
IntSet
>
dom_map
;
arith
::
Analyzer
analyzer
;
for
(
auto
entry
:
*
rmap
)
{
analyzer
.
Bind
(
entry
.
first
->
var
,
entry
.
second
);
}
for
(
auto
iv
:
op
->
root_iter_vars
())
{
Range
r
;
if
(
up_state
.
count
(
iv
))
{
...
...
@@ -198,11 +201,13 @@ void InferRootBound(const Stage& stage,
r
=
iv
->
dom
;
}
if
(
relax_set
.
size
()
!=
0
)
{
dom_map
[
iv
->
var
.
get
()]
=
EvalSet
(
r
,
relax_set
);
dom_map
[
iv
->
var
.
get
()]
=
IntSet
::
interval
(
analyzer
.
int_set
(
r
->
min
,
relax_set
).
min
(),
analyzer
.
int_set
(
r
->
min
+
r
->
extent
-
1
,
relax_set
).
max
());
}
else
{
dom_map
[
iv
->
var
.
get
()]
=
IntSet
::
range
(
r
);
}
analyzer
.
Bind
(
iv
->
var
,
r
);
analyzer
.
Bind
(
iv
->
var
,
r
,
true
);
}
op
->
PropBoundToInputs
(
op
,
&
analyzer
,
dom_map
,
&
tmap
);
}
...
...
src/te/schedule/message_passing.cc
View file @
37e57548
...
...
@@ -579,11 +579,15 @@ std::vector<PrimExpr> MakeBoundCheck(
PassUpBoundCheck
(
stage
,
dom_map
,
&
bound_state
,
&
analyzer
);
std
::
vector
<
PrimExpr
>
preds
;
std
::
unordered_map
<
const
VarNode
*
,
IntSet
>
iset_dmap
;
Map
<
Var
,
IntSet
>
iset_dmap
;
// setup domain map for set analysis
for
(
const
auto
&
kv
:
dom_map
)
{
iset_dmap
[
kv
.
first
->
var
.
get
()]
=
IntSet
::
range
(
kv
.
second
);
iset_dmap
.
Set
(
kv
.
first
->
var
,
IntSet
::
range
(
kv
.
second
));
}
for
(
auto
entry
:
dom_map
)
{
analyzer
.
Bind
(
entry
.
first
->
var
,
entry
.
second
);
}
for
(
const
IterVar
&
iv
:
stage
->
all_iter_vars
)
{
...
...
@@ -591,7 +595,7 @@ std::vector<PrimExpr> MakeBoundCheck(
if
(
bound_state
.
at
(
iv
))
{
Range
dom
=
dom_map
.
at
(
iv
);
PrimExpr
value
=
value_map
.
at
(
iv
)
-
dom
->
min
;
PrimExpr
vmax
=
EvalS
et
(
value
,
iset_dmap
).
max
();
PrimExpr
vmax
=
analyzer
.
int_s
et
(
value
,
iset_dmap
).
max
();
if
(
vmax
.
dtype
()
!=
value
.
dtype
()
||
!
analyzer
.
CanProve
(
vmax
<
dom
->
extent
))
{
preds
.
emplace_back
(
value
<
dom
->
extent
);
}
...
...
@@ -603,7 +607,7 @@ std::vector<PrimExpr> MakeBoundCheck(
CHECK
(
iv
->
dom
.
defined
());
if
(
!
skip_ivar_domain
&&
!
IsRangeSame
(
iv
->
dom
,
dom
))
{
PrimExpr
value
=
value_map
.
at
(
iv
)
-
iv
->
dom
->
min
;
IntSet
s
=
EvalS
et
(
value
,
iset_dmap
);
IntSet
s
=
analyzer
.
int_s
et
(
value
,
iset_dmap
);
PrimExpr
vmin
=
s
.
min
();
PrimExpr
vmax
=
s
.
max
();
// The range of `value` resides in [vmin, vmax]
...
...
tests/python/unittest/test_arith_intset.py
View file @
37e57548
...
...
@@ -90,6 +90,20 @@ def test_mod():
flm
=
tvm
.
te
.
floormod
ck
.
verify
(
flm
(
x
,
10
),
{
x
:
tvm
.
arith
.
IntervalSet
(
-
10
,
10
)},
(
0
,
9
))
ck
.
verify
(
flm
(
x
,
10
),
{
x
:
tvm
.
arith
.
IntervalSet
(
3
,
5
)},
(
3
,
5
))
ck
.
verify
(
flm
(
x
,
10
),
{
x
:
tvm
.
arith
.
IntervalSet
(
13
,
15
)},
(
3
,
5
))
ck
.
verify
(
flm
(
x
,
10
),
{
x
:
tvm
.
arith
.
IntervalSet
(
3
,
15
)},
(
0
,
9
))
ck
.
verify
(
flm
(
x
,
10
),
{
x
:
tvm
.
arith
.
IntervalSet
(
3
,
11
)},
(
0
,
9
))
ck
.
verify
(
flm
(
x
,
10
),
{
x
:
tvm
.
arith
.
IntervalSet
(
1
,
21
)},
(
0
,
9
))
floordiv
=
tvm
.
te
.
floordiv
z
=
te
.
var
(
"z"
)
ck
.
analyzer
.
bind
(
x
,
tvm
.
ir
.
Range
.
make_by_min_extent
(
0
,
3
))
ck
.
verify
(
flm
(
y
,
8
),
{
y
:
tvm
.
arith
.
IntervalSet
(
z
*
8
+
x
*
4
,
z
*
8
+
x
*
4
+
3
)},
(
0
,
7
))
ck1
=
IntSetChecker
()
ck1
.
analyzer
.
bind
(
x
,
tvm
.
ir
.
Range
.
make_by_min_extent
(
0
,
2
))
ck1
.
verify
(
flm
(
y
,
8
),
{
y
:
tvm
.
arith
.
IntervalSet
(
z
*
8
+
x
*
4
,
z
*
8
+
x
*
4
+
3
)},
(
x
*
4
,
x
*
4
+
3
))
def
test_max_min
():
...
...
tests/python/unittest/test_te_schedule_bound_inference_tiling.py
0 → 100644
View file @
37e57548
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import
tvm
from
tvm
import
te
def
test_bound_tile_mod
():
def
compute
(
M_tiles
,
N_tiles
,
factor
,
dtype
):
# Algo
M
=
M_tiles
*
factor
N
=
N_tiles
*
factor
A
=
tvm
.
te
.
placeholder
((
N
,
M
),
name
=
'A'
,
dtype
=
dtype
)
C
=
tvm
.
te
.
compute
((
N
,
M
),
lambda
n
,
m
:
A
[
n
,
m
],
name
=
'C'
)
s
=
tvm
.
te
.
create_schedule
(
C
.
op
)
return
s
,
A
,
C
def
schedule
(
s
,
factor
,
padding
,
A
,
C
):
C_local
=
s
.
cache_write
(
C
,
"local"
)
n
,
m
=
C
.
op
.
axis
bn
,
bm
,
ni
,
mi
=
s
[
C
]
.
tile
(
n
,
m
,
factor
,
factor
)
nio
,
nii
=
s
[
C
]
.
split
(
ni
,
2
)
n
=
s
[
C
]
.
fuse
(
nii
,
mi
)
C_shared
=
s
.
cache_write
(
C
,
"shared"
)
bn
,
bm
,
ni
,
mi
=
C_shared
.
op
.
axis
s
[
C_shared
]
.
storage_align
(
ni
,
factor
*
2
,
padding
)
n
,
m
=
s
[
C
]
.
op
.
axis
bn
,
bm
,
ni
,
mi
=
s
[
C
]
.
tile
(
n
,
m
,
factor
,
factor
)
s
[
C
]
.
set_scope
(
"global"
)
niio
,
niii
=
s
[
C
]
.
split
(
ni
,
32
)
s
[
C_shared
]
.
compute_at
(
s
[
C
],
niio
)
return
s
s
,
A
,
C
=
compute
(
2
,
2
,
128
,
"float16"
)
s
=
schedule
(
s
,
128
,
8
,
A
,
C
)
bounds
=
tvm
.
te
.
schedule
.
InferBound
(
s
)
check
=
(
bounds
[
s
.
stages
[
2
]
.
op
.
axis
[
2
]]
.
extent
==
16
)
if
(
not
check
):
print
(
tvm
.
lower
(
s
,
[
A
,
C
],
simple_mode
=
True
))
assert
(
check
)
if
__name__
==
"__main__"
:
test_bound_tile_mod
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment