Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
a45d3b01
Commit
a45d3b01
authored
Aug 31, 2017
by
Tianqi Chen
Committed by
GitHub
Aug 31, 2017
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[PASS] InjectDoubleBuffer (#405)
parent
b8c8aadf
Hide whitespace changes
Inline
Side-by-side
Showing
18 changed files
with
421 additions
and
14 deletions
+421
-14
include/tvm/ir.h
+8
-0
include/tvm/ir_pass.h
+8
-0
include/tvm/schedule.h
+8
-0
python/tvm/build_module.py
+7
-1
python/tvm/ir_builder.py
+15
-0
python/tvm/schedule.py
+9
-0
src/api/api_lang.cc
+7
-2
src/api/api_pass.cc
+1
-0
src/pass/inject_double_buffer.cc
+226
-0
src/pass/storage_access.cc
+18
-0
src/pass/storage_access.h
+4
-0
src/pass/storage_flatten.cc
+13
-0
src/pass/storage_sync.cc
+42
-8
src/schedule/schedule_lang.cc
+7
-0
src/schedule/schedule_ops.cc
+6
-1
tests/python/integration/test_gemm.py
+3
-2
tests/python/unittest/test_pass_inject_double_buffer.py
+37
-0
topi/recipe/gemm/cuda_gemm_square.py
+2
-0
No files found.
include/tvm/ir.h
View file @
a45d3b01
...
@@ -178,6 +178,14 @@ constexpr const char* pragma_scope = "pragma_scope";
...
@@ -178,6 +178,14 @@ constexpr const char* pragma_scope = "pragma_scope";
* run prefetch of Tensor on the current loop scope
* run prefetch of Tensor on the current loop scope
*/
*/
constexpr
const
char
*
prefetch_scope
=
"prefetch_scope"
;
constexpr
const
char
*
prefetch_scope
=
"prefetch_scope"
;
/*!
* \brief Marks production of double buffer data
*/
constexpr
const
char
*
double_buffer_scope
=
"double_buffer_scope"
;
/*!
* \brief Marks region used by double buffer write
*/
constexpr
const
char
*
double_buffer_write
=
"double_buffer_write"
;
/*! \brief Mark of scan update scope */
/*! \brief Mark of scan update scope */
constexpr
const
char
*
scan_update_scope
=
"scan_update_scope"
;
constexpr
const
char
*
scan_update_scope
=
"scan_update_scope"
;
/*! \brief Mark of scan init scope */
/*! \brief Mark of scan init scope */
...
...
include/tvm/ir_pass.h
View file @
a45d3b01
...
@@ -232,6 +232,14 @@ Stmt InjectVirtualThread(Stmt stmt);
...
@@ -232,6 +232,14 @@ Stmt InjectVirtualThread(Stmt stmt);
Stmt
InjectPrefetch
(
Stmt
stmt
);
Stmt
InjectPrefetch
(
Stmt
stmt
);
/*!
/*!
* \brief Inject double buffer into stmt.
* \param stmt The statement to be transformed.
* \param split_loop Whether to split the loop containing double buffering.
* \return Transformed stmt.
*/
Stmt
InjectDoubleBuffer
(
Stmt
stmt
,
bool
split_loop
);
/*!
* \brief Rewrite storage allocation pattern.
* \brief Rewrite storage allocation pattern.
* Moves the allocation to outer most possible scope.
* Moves the allocation to outer most possible scope.
* Trying to share space between allocations to make
* Trying to share space between allocations to make
...
...
include/tvm/schedule.h
View file @
a45d3b01
...
@@ -209,6 +209,11 @@ class Stage : public NodeRef {
...
@@ -209,6 +209,11 @@ class Stage : public NodeRef {
*/
*/
Stage
&
storage_align
(
IterVar
axis
,
int
factor
,
int
offset
);
//NOLINT(*)
Stage
&
storage_align
(
IterVar
axis
,
int
factor
,
int
offset
);
//NOLINT(*)
/*!
/*!
* \brief Compute current stage with double buffering.
* \return reference to self.
*/
Stage
&
double_buffer
();
// NOLINT(*)
/*!
* \brief whether the stage has been scheduled.
* \brief whether the stage has been scheduled.
* \return whether the stage has been scheduled.
* \return whether the stage has been scheduled.
*/
*/
...
@@ -408,6 +413,8 @@ class StageNode : public Node {
...
@@ -408,6 +413,8 @@ class StageNode : public Node {
std
::
string
scope
;
std
::
string
scope
;
/*! \brief Whether this is an output stage */
/*! \brief Whether this is an output stage */
bool
is_output
{
false
};
bool
is_output
{
false
};
/*! \brief Whether apply double buffer optimization to this stage */
bool
double_buffer
{
false
};
/*!
/*!
* \brief The parent group of the current stage.
* \brief The parent group of the current stage.
* The stage cannot be assigned to stages outside the group.
* The stage cannot be assigned to stages outside the group.
...
@@ -429,6 +436,7 @@ class StageNode : public Node {
...
@@ -429,6 +436,7 @@ class StageNode : public Node {
v
->
Visit
(
"attach_stage"
,
&
attach_stage
);
v
->
Visit
(
"attach_stage"
,
&
attach_stage
);
v
->
Visit
(
"scope"
,
&
scope
);
v
->
Visit
(
"scope"
,
&
scope
);
v
->
Visit
(
"is_output"
,
&
is_output
);
v
->
Visit
(
"is_output"
,
&
is_output
);
v
->
Visit
(
"double_buffer"
,
&
double_buffer
);
v
->
Visit
(
"group"
,
&
group
);
v
->
Visit
(
"group"
,
&
group
);
v
->
Visit
(
"num_child_stages"
,
&
num_child_stages
);
v
->
Visit
(
"num_child_stages"
,
&
num_child_stages
);
}
}
...
...
python/tvm/build_module.py
View file @
a45d3b01
...
@@ -33,6 +33,7 @@ class BuildConfig(object):
...
@@ -33,6 +33,7 @@ class BuildConfig(object):
"offset_factor"
:
0
,
"offset_factor"
:
0
,
"data_alignment"
:
-
1
,
"data_alignment"
:
-
1
,
"restricted_func"
:
True
,
"restricted_func"
:
True
,
"double_buffer_split_loop"
:
True
,
"add_lower_pass"
:
None
"add_lower_pass"
:
None
}
}
def
__init__
(
self
,
**
kwargs
):
def
__init__
(
self
,
**
kwargs
):
...
@@ -97,6 +98,10 @@ def build_config(**kwargs):
...
@@ -97,6 +98,10 @@ def build_config(**kwargs):
not to overlap. This enables more optimization.
not to overlap. This enables more optimization.
Corresponds to restricted keyword in C99
Corresponds to restricted keyword in C99
double_buffer_split_loop: bool, default=True
Whether to split the loop containing the double buffer so
that the buffer fetching does not contain a condition.
add_lower_pass: list of function(Stmt->Stmt), default=None
add_lower_pass: list of function(Stmt->Stmt), default=None
Additional lowering passes to be applied before make_api.
Additional lowering passes to be applied before make_api.
...
@@ -187,6 +192,7 @@ def lower(sch,
...
@@ -187,6 +192,7 @@ def lower(sch,
Then the Stmt before make api is returned.
Then the Stmt before make api is returned.
"""
"""
binds
,
arg_list
=
get_binds
(
args
,
binds
)
binds
,
arg_list
=
get_binds
(
args
,
binds
)
cfg
=
BuildConfig
.
current
# normalize schedule first
# normalize schedule first
sch
=
sch
.
normalize
()
sch
=
sch
.
normalize
()
bounds
=
schedule
.
InferBound
(
sch
)
bounds
=
schedule
.
InferBound
(
sch
)
...
@@ -198,8 +204,8 @@ def lower(sch,
...
@@ -198,8 +204,8 @@ def lower(sch,
stmt
=
ir_pass
.
LoopPartition
(
stmt
)
stmt
=
ir_pass
.
LoopPartition
(
stmt
)
stmt
=
ir_pass
.
VectorizeLoop
(
stmt
)
stmt
=
ir_pass
.
VectorizeLoop
(
stmt
)
stmt
=
ir_pass
.
InjectVirtualThread
(
stmt
)
stmt
=
ir_pass
.
InjectVirtualThread
(
stmt
)
stmt
=
ir_pass
.
InjectDoubleBuffer
(
stmt
,
cfg
.
double_buffer_split_loop
)
stmt
=
ir_pass
.
StorageRewrite
(
stmt
)
stmt
=
ir_pass
.
StorageRewrite
(
stmt
)
cfg
=
BuildConfig
.
current
stmt
=
ir_pass
.
UnrollLoop
(
stmt
=
ir_pass
.
UnrollLoop
(
stmt
,
stmt
,
cfg
.
auto_unroll_max_step
,
cfg
.
auto_unroll_max_step
,
...
...
python/tvm/ir_builder.py
View file @
a45d3b01
...
@@ -268,6 +268,21 @@ class IRBuilder(object):
...
@@ -268,6 +268,21 @@ class IRBuilder(object):
self
.
emit
(
_make
.
IfThenElse
(
prev
.
condition
,
prev
.
then_case
,
self
.
_pop_seq
()))
self
.
emit
(
_make
.
IfThenElse
(
prev
.
condition
,
prev
.
then_case
,
self
.
_pop_seq
()))
return
WithScope
(
None
,
_exit_cb
)
return
WithScope
(
None
,
_exit_cb
)
def new_scope(self):
    """Open a nested statement scope.

    Statements emitted while the scope is active are collected
    separately; this is useful to bound the extent of attr and
    allocate statements.

    Returns
    -------
    new_scope : WithScope
        The scope object; on exit the collected statements are
        emitted into the enclosing sequence as a single statement.
    """
    def _close_scope():
        # Fold everything gathered inside the scope back into the parent.
        self.emit(self._pop_seq())

    self._seq_stack.append([])
    return WithScope(None, _close_scope)
def
allocate
(
self
,
dtype
,
shape
,
name
=
"buf"
,
scope
=
None
):
def
allocate
(
self
,
dtype
,
shape
,
name
=
"buf"
,
scope
=
None
):
"""Create a allocate statement.
"""Create a allocate statement.
...
...
python/tvm/schedule.py
View file @
a45d3b01
...
@@ -589,4 +589,13 @@ class Stage(NodeBase):
...
@@ -589,4 +589,13 @@ class Stage(NodeBase):
"""
"""
_api_internal
.
_StageStorageAlign
(
self
,
axis
,
factor
,
offset
)
_api_internal
.
_StageStorageAlign
(
self
,
axis
,
factor
,
offset
)
def double_buffer(self):
    """Mark this stage to be computed with double buffering.

    Only intermediate stages can be double buffered. The storage
    cost of the stage is doubled; in exchange the technique can
    help hide load latency.
    """
    _api_internal._StageDoubleBuffer(self)
_init_api
(
"tvm.schedule"
)
_init_api
(
"tvm.schedule"
)
src/api/api_lang.cc
View file @
a45d3b01
...
@@ -385,13 +385,18 @@ TVM_REGISTER_API("_StagePragma")
...
@@ -385,13 +385,18 @@ TVM_REGISTER_API("_StagePragma")
TVM_REGISTER_API
(
"_StagePrefetch"
)
TVM_REGISTER_API
(
"_StagePrefetch"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
args
[
0
].
operator
Stage
()
args
[
0
].
operator
Stage
()
.
prefetch
(
args
[
1
],
args
[
2
],
args
[
3
]);
.
prefetch
(
args
[
1
],
args
[
2
],
args
[
3
]);
});
});
TVM_REGISTER_API
(
"_StageStorageAlign"
)
TVM_REGISTER_API
(
"_StageStorageAlign"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
args
[
0
].
operator
Stage
()
args
[
0
].
operator
Stage
()
.
storage_align
(
args
[
1
],
args
[
2
],
args
[
3
]);
.
storage_align
(
args
[
1
],
args
[
2
],
args
[
3
]);
});
// Register "_StageDoubleBuffer": converts args[0] to a Stage and calls
// Stage::double_buffer() on it. Nothing is written to ret.
TVM_REGISTER_API("_StageDoubleBuffer")
.set_body([](TVMArgs args, TVMRetValue *ret) {
    args[0].operator Stage().double_buffer();
  });
});
TVM_REGISTER_API
(
"_ScheduleNormalize"
)
TVM_REGISTER_API
(
"_ScheduleNormalize"
)
...
...
src/api/api_pass.cc
View file @
a45d3b01
...
@@ -101,6 +101,7 @@ REGISTER_PASS1(CoProcSync);
...
@@ -101,6 +101,7 @@ REGISTER_PASS1(CoProcSync);
REGISTER_PASS1
(
LowerStorageAccessInfo
);
REGISTER_PASS1
(
LowerStorageAccessInfo
);
REGISTER_PASS1
(
InjectVirtualThread
);
REGISTER_PASS1
(
InjectVirtualThread
);
REGISTER_PASS1
(
InjectPrefetch
);
REGISTER_PASS1
(
InjectPrefetch
);
REGISTER_PASS2
(
InjectDoubleBuffer
);
REGISTER_PASS1
(
LoopPartition
);
REGISTER_PASS1
(
LoopPartition
);
REGISTER_PASS1
(
RemoveNoOp
);
REGISTER_PASS1
(
RemoveNoOp
);
REGISTER_PASS2
(
SplitPipeline
);
REGISTER_PASS2
(
SplitPipeline
);
...
...
src/pass/inject_double_buffer.cc
0 → 100644
View file @
a45d3b01
/*!
* Copyright (c) 2017 by Contributors
*
* \brief Inject double buffering optimization for data fetch.
* \file inject_double_buffer.cc
*/
#include <tvm/ir_pass.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_mutator.h>
#include "./ir_util.h"
#include "../arithmetic/compute_expr.h"
namespace
tvm
{
namespace
ir
{
/*!
 * \brief Detect the buffer variables that are candidates for double buffering.
 *
 *  A variable becomes a candidate when it is the node of a
 *  double_buffer_scope attribute. It is dropped again as soon as the
 *  variable itself is visited as an expression.
 */
class DoubleBufferDetector : public IRVisitor {
 public:
  void Visit_(const AttrStmt* op) final {
    // Record the annotated buffer (if any) before descending; the
    // statement is traversed in both cases.
    const bool marks_double_buffer =
        (op->attr_key == attr::double_buffer_scope);
    if (marks_double_buffer) {
      touched_.insert(op->node.as<Variable>());
    }
    IRVisitor::Visit_(op);
  }

  void Visit_(const Variable* op) final {
    // Erasing an absent key is a no-op, so no membership test is needed.
    touched_.erase(op);
  }

  // The set of candidate variables that survived the traversal.
  std::unordered_set<const Variable*> touched_;
};
/*!
 * \brief Mutator that rewrites the buffers found by DoubleBufferDetector
 *  into two-slot (double) buffers.
 *
 *  For every such buffer:
 *  - the allocation is removed from its original place, given an extra
 *    leading extent of 2 and re-created around the loop that encloses the
 *    double_buffer_scope attribute;
 *  - stores are redirected to slot `switch_write_var`, loads to slot
 *    `loop_var % 2`;
 *  - the producer is emitted once before the loop (iteration 0, slot 0)
 *    and, inside the loop, for iteration `loop_var + 1` guarded by
 *    `loop_var + 1 < extent`.
 *
 *  The rewrite is only valid for loops that start at 0; this is checked.
 */
class DoubleBufferInjector : public IRMutator {
 public:
  /*!
   * \param split_loop Whether to peel the last iteration of the loop that
   *  contains the double buffering.
   */
  explicit DoubleBufferInjector(bool split_loop)
      : split_loop_(split_loop) {}

  /*!
   * \brief Run the transformation.
   * \param stmt The statement to transform.
   * \return The input itself when no candidate buffer exists, otherwise
   *  the mutated statement passed through ConvertSSA.
   */
  Stmt Inject(const Stmt& stmt) {
    DoubleBufferDetector detector;
    detector.Visit(stmt);
    if (detector.touched_.empty()) return stmt;
    for (const Variable* v : detector.touched_) {
      dbuffer_info_[v] = StorageEntry();
    }
    return ConvertSSA(this->Mutate(stmt));
  }

  Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
    if (op->attr_key == attr::storage_scope) {
      const Variable* buf = op->node.as<Variable>();
      auto it = dbuffer_info_.find(buf);
      if (it == dbuffer_info_.end()) {
        return IRMutator::Mutate_(op, s);
      }
      // Remember the scope and drop the attribute here; it is re-created
      // next to the hoisted allocation in Mutate_(const Allocate*).
      it->second.scope = op->value.as<StringImm>()->value;
      return Mutate(op->body);
    }
    if (op->attr_key == attr::double_buffer_scope) {
      return MakeProducer(op, s);
    }
    return IRMutator::Mutate_(op, s);
  }

  Stmt Mutate_(const Allocate* op, const Stmt& s) final {
    auto it = dbuffer_info_.find(op->buffer_var.get());
    if (it == dbuffer_info_.end()) {
      return IRMutator::Mutate_(op, s);
    }
    // The size of one slot must be known before the body is mutated,
    // because the Load/Store rewrites in the body read it.
    it->second.size = arith::ComputeReduce<Mul>(op->extents);
    Stmt stmt = IRMutator::Mutate_(op, s);
    op = stmt.as<Allocate>();
    // New shape: [2, original extents...].
    Array<Expr> new_extents{make_const(op->extents[0].type(), 2)};
    for (Expr extent : op->extents) {
      new_extents.push_back(extent);
    }
    // The loop is filled in by MakeProducer while the body is mutated.
    CHECK(it->second.loop != nullptr);
    auto& alloc_nest = loop_allocs_[it->second.loop];
    alloc_nest.emplace_back(AttrStmt::make(
        op->buffer_var, attr::storage_scope,
        StringImm::make(it->second.scope),
        Evaluate::make(0)));
    alloc_nest.emplace_back(Allocate::make(
        op->buffer_var, op->type, new_extents, op->condition,
        Evaluate::make(0)));
    // The in-place allocation disappears; only its body remains.
    return op->body;
  }

  Stmt Mutate_(const For* op, const Stmt& s) final {
    loop_nest_.push_back(op);
    Stmt stmt = IRMutator::Mutate_(op, s);
    auto pre_it = loop_pre_.find(op);
    if (pre_it != loop_pre_.end()) {
      const For* old_loop = stmt.as<For>();
      if (split_loop_) {
        // Peel the last iteration: run the loop for extent - 1 iterations,
        // then the body once more with loop_var := extent - 1.
        Expr new_ext = arith::ComputeExpr<Sub>(
            old_loop->extent, make_const(old_loop->loop_var.type(), 1));
        Stmt loop = For::make(
            old_loop->loop_var, old_loop->min, new_ext,
            old_loop->for_type, old_loop->device_api,
            old_loop->body);
        std::unordered_map<const Variable*, Expr> vmap;
        vmap[old_loop->loop_var.get()] = new_ext;
        Stmt end = Substitute(old_loop->body, vmap);
        stmt = Block::make(loop, end);
      }
      // Producers for iteration 0 run before the loop.
      stmt = Block::make(MergeSeq(pre_it->second), stmt);
    }
    auto alloc_it = loop_allocs_.find(op);
    if (alloc_it != loop_allocs_.end()) {
      // Wrap the hoisted storage_scope attrs / allocations around the loop.
      stmt = MergeNest(alloc_it->second, stmt);
    }
    loop_nest_.pop_back();
    return stmt;
  }

  Stmt Mutate_(const Store* op, const Stmt& s) final {
    Stmt stmt = IRMutator::Mutate_(op, s);
    op = stmt.as<Store>();
    auto it = dbuffer_info_.find(op->buffer_var.get());
    if (it == dbuffer_info_.end()) {
      return stmt;
    }
    const StorageEntry& entry = it->second;
    // Writes to a double buffer are only legal inside its producer scope.
    CHECK(in_double_buffer_scope_);
    CHECK(entry.size.defined());
    return Store::make(op->buffer_var,
                       op->value,
                       entry.switch_write_var * entry.size + op->index,
                       op->predicate);
  }

  Expr Mutate_(const Load* op, const Expr& e) final {
    Expr expr = IRMutator::Mutate_(op, e);
    op = expr.as<Load>();
    auto it = dbuffer_info_.find(op->buffer_var.get());
    if (it == dbuffer_info_.end()) {
      return expr;
    }
    const StorageEntry& entry = it->second;
    CHECK(entry.size.defined());
    CHECK(entry.switch_read_var.defined());
    return Load::make(op->type,
                      op->buffer_var,
                      entry.switch_read_var * entry.size + op->index,
                      op->predicate);
  }

  Expr Mutate_(const Variable* op, const Expr& e) final {
    // A double buffered variable must never be reached as a bare variable.
    CHECK(!dbuffer_info_.count(op));
    return e;
  }

 private:
  /*!
   * \brief Rewrite the body of a double_buffer_scope attribute.
   *
   *  Registers the prologue copy of the producer (loop_var := 0, slot 0)
   *  in loop_pre_ and returns the in-loop copy that produces the data of
   *  iteration loop_var + 1, guarded by `loop_var + 1 < extent`.
   *
   * \param op The double_buffer_scope attribute.
   * \param s The original statement (unused).
   */
  Stmt MakeProducer(const AttrStmt* op, const Stmt& s) {
    const VarExpr buffer(op->node.node_);
    CHECK_NE(loop_nest_.size(), 0U)
        << "Double buffer scope must be inside a loop";
    auto it = dbuffer_info_.find(buffer.get());
    if (it == dbuffer_info_.end()) {
      LOG(WARNING) << "Skip double buffer scope " << op->node;
      return Mutate(op->body);
    }
    StorageEntry& e = it->second;
    e.loop = loop_nest_.back();
    // Everything below (prologue at loop_var := 0, guard against extent,
    // tail at extent - 1) is only correct for a loop that starts at 0.
    // Fail loudly instead of generating wrong code.
    CHECK(Equal(e.loop->min, make_const(e.loop->min.type(), 0)))
        << "Double buffering requires the enclosing loop to start at 0";
    Expr zero = make_const(e.loop->loop_var.type(), 0);
    Expr one = make_const(e.loop->loop_var.type(), 1);
    Expr two = make_const(e.loop->loop_var.type(), 2);
    Expr loop_shift = e.loop->loop_var + one;
    // Placeholder for the write slot; substituted below.
    e.switch_write_var = Var(e.loop->loop_var->name_hint + ".db",
                             e.loop->loop_var.type());
    e.switch_read_var = e.loop->loop_var % two;
    // Restore (rather than clear) the flag so an enclosing producer scope
    // stays active after this one returns.
    const bool outer_scope_flag = in_double_buffer_scope_;
    in_double_buffer_scope_ = true;
    Stmt body = Mutate(op->body);
    in_double_buffer_scope_ = outer_scope_flag;
    std::unordered_map<const Variable*, Expr> vmap;
    // Prologue: produce iteration 0 into slot 0 before the loop.
    vmap[e.switch_write_var.get()] = zero;
    vmap[e.loop->loop_var.get()] = zero;
    loop_pre_[e.loop].emplace_back(Substitute(body, vmap));
    // In-loop: produce iteration loop_var + 1 into the other slot.
    vmap[e.loop->loop_var.get()] = loop_shift;
    vmap[e.switch_write_var.get()] = loop_shift % two;
    body = Substitute(body, vmap);
    body = AttrStmt::make(buffer, attr::double_buffer_write, 1, body);
    body = IfThenElse::make(loop_shift < e.loop->extent, body);
    return body;
  }
  // Storage entry for those who need double buffering.
  struct StorageEntry {
    // The size of one slot of the buffer
    Expr size;
    // The loop enclosing the double buffer scope
    const For* loop{nullptr};
    // The switch variable for writing.
    VarExpr switch_write_var;
    // The switch expression for reading.
    Expr switch_read_var;
    // The storage scope.
    std::string scope;
  };
  // Whether split loop
  bool split_loop_;
  // Whether we are inside double buffer scope.
  bool in_double_buffer_scope_{false};
  // The current loop nest
  std::vector<const For*> loop_nest_;
  // The allocs to be wrapped around the loop
  std::unordered_map<const For*, std::vector<Stmt> > loop_allocs_;
  // The stmt to be prepended before the loop
  std::unordered_map<const For*, std::vector<Stmt> > loop_pre_;
  // Per-buffer double buffering information
  std::unordered_map<const Variable*, StorageEntry> dbuffer_info_;
};
/*!
 * \brief Entry point of the double buffer injection pass.
 * \param stmt The statement to be transformed.
 * \param split_loop Forwarded to DoubleBufferInjector.
 * \return The result of DoubleBufferInjector::Inject on stmt.
 */
Stmt InjectDoubleBuffer(Stmt stmt, bool split_loop) {
  DoubleBufferInjector injector(split_loop);
  return injector.Inject(stmt);
}
}
// namespace ir
}
// namespace tvm
src/pass/storage_access.cc
View file @
a45d3b01
...
@@ -74,6 +74,24 @@ void StorageAccessVisitor::Visit_(const AttrStmt* op) {
...
@@ -74,6 +74,24 @@ void StorageAccessVisitor::Visit_(const AttrStmt* op) {
storage_scope_
[
buf
]
=
storage_scope_
[
buf
]
=
StorageScope
::
make
(
op
->
value
.
as
<
StringImm
>
()
->
value
);
StorageScope
::
make
(
op
->
value
.
as
<
StringImm
>
()
->
value
);
IRVisitor
::
Visit_
(
op
);
IRVisitor
::
Visit_
(
op
);
}
else
if
(
op
->
attr_key
==
attr
::
double_buffer_write
)
{
CHECK
(
double_buffer_write_
==
nullptr
);
double_buffer_write_
=
op
->
node
.
as
<
Variable
>
();
scope_
.
push_back
(
std
::
vector
<
StmtEntry
>
());
IRVisitor
::
Visit_
(
op
);
StmtEntry
s
;
s
.
stmt
=
op
;
s
.
access
=
Summarize
(
std
::
move
(
scope_
.
back
()),
nullptr
);
scope_
.
pop_back
();
if
(
!
s
.
access
.
empty
())
{
for
(
AccessEntry
&
e
:
s
.
access
)
{
if
(
e
.
type
==
kWrite
&&
e
.
buffer
.
get
()
==
double_buffer_write_
)
{
e
.
double_buffer_write
=
true
;
}
}
scope_
.
back
().
emplace_back
(
std
::
move
(
s
));
}
double_buffer_write_
=
nullptr
;
}
else
if
(
op
->
attr_key
==
attr
::
coproc_scope
)
{
}
else
if
(
op
->
attr_key
==
attr
::
coproc_scope
)
{
IterVar
iv
(
op
->
node
.
node_
);
IterVar
iv
(
op
->
node
.
node_
);
env_threads_
.
push_back
(
iv
);
env_threads_
.
push_back
(
iv
);
...
...
src/pass/storage_access.h
View file @
a45d3b01
...
@@ -45,6 +45,8 @@ class StorageAccessVisitor : public IRVisitor {
...
@@ -45,6 +45,8 @@ class StorageAccessVisitor : public IRVisitor {
AccessType
type
;
AccessType
type
;
/*! \brief The storage scope */
/*! \brief The storage scope */
StorageScope
scope
;
StorageScope
scope
;
/*! \brief Whether the access is double buffer write */
bool
double_buffer_write
{
false
};
};
};
/*! \brief Access pattern about a single statement */
/*! \brief Access pattern about a single statement */
struct
StmtEntry
{
struct
StmtEntry
{
...
@@ -116,6 +118,8 @@ class StorageAccessVisitor : public IRVisitor {
...
@@ -116,6 +118,8 @@ class StorageAccessVisitor : public IRVisitor {
bool
in_device_env_
{
false
};
bool
in_device_env_
{
false
};
// Whether we are inside condition.
// Whether we are inside condition.
int
condition_counter_
{
0
};
int
condition_counter_
{
0
};
// The current double buffer write scope.
const
Variable
*
double_buffer_write_
{
nullptr
};
// the current free stmt entry.
// the current free stmt entry.
StmtEntry
curr_stmt_
;
StmtEntry
curr_stmt_
;
// The involving threads
// The involving threads
...
...
src/pass/storage_flatten.cc
View file @
a45d3b01
...
@@ -4,6 +4,7 @@
...
@@ -4,6 +4,7 @@
*/
*/
#include <tvm/ir.h>
#include <tvm/ir.h>
#include <tvm/expr.h>
#include <tvm/expr.h>
#include <tvm/operation.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_operator.h>
#include <tvm/ir_operator.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_pass.h>
...
@@ -53,6 +54,18 @@ class StorageFlattener : public IRMutator {
...
@@ -53,6 +54,18 @@ class StorageFlattener : public IRMutator {
if
(
op
->
attr_key
==
attr
::
realize_scope
)
{
if
(
op
->
attr_key
==
attr
::
realize_scope
)
{
storage_scope_
[
op
->
node
.
get
()]
=
op
->
value
.
as
<
StringImm
>
()
->
value
;
storage_scope_
[
op
->
node
.
get
()]
=
op
->
value
.
as
<
StringImm
>
()
->
value
;
return
this
->
Mutate
(
op
->
body
);
return
this
->
Mutate
(
op
->
body
);
}
else
if
(
op
->
attr_key
==
attr
::
double_buffer_scope
)
{
Operation
func
(
op
->
node
.
node_
);
Stmt
body
=
Mutate
(
op
->
body
);
for
(
int
i
=
0
;
i
<
func
->
num_outputs
();
++
i
)
{
TensorKey
key
{
func
,
i
};
auto
it
=
buf_map_
.
find
(
key
);
CHECK
(
it
!=
buf_map_
.
end
())
<<
"Cannot find allocated buffer for "
<<
key
.
f
;
body
=
AttrStmt
::
make
(
it
->
second
.
buffer
->
data
,
op
->
attr_key
,
op
->
value
,
body
);
}
return
body
;
}
else
if
(
op
->
attr_key
==
attr
::
thread_extent
)
{
}
else
if
(
op
->
attr_key
==
attr
::
thread_extent
)
{
IterVar
iv
(
op
->
node
.
node_
);
IterVar
iv
(
op
->
node
.
node_
);
ThreadScope
ts
=
ThreadScope
::
make
(
iv
->
thread_tag
);
ThreadScope
ts
=
ThreadScope
::
make
(
iv
->
thread_tag
);
...
...
src/pass/storage_sync.cc
View file @
a45d3b01
...
@@ -34,13 +34,10 @@ class ThreadSyncPlanner : public StorageAccessVisitor {
...
@@ -34,13 +34,10 @@ class ThreadSyncPlanner : public StorageAccessVisitor {
// Unsynced reads and writes
// Unsynced reads and writes
std
::
vector
<
AccessEntry
>
reads
;
std
::
vector
<
AccessEntry
>
reads
;
std
::
vector
<
AccessEntry
>
writes
;
std
::
vector
<
AccessEntry
>
writes
;
// if it is a loop, rotate two times to consider effect of loop.
// if it is a loop, rotate two times to consider effect of loop.
size_t
max_seq
=
seq
.
size
();
if
(
loop
!=
nullptr
)
max_seq
*=
2
;
// simulation based approach to find dependencies
// simulation based approach to find dependencies
for
(
size_t
i
=
0
;
i
<
max_seq
;
++
i
)
{
for
(
size_t
i
=
0
;
i
<
seq
.
size
()
;
++
i
)
{
const
StmtEntry
&
s
=
seq
[
i
%
seq
.
size
()
];
const
StmtEntry
&
s
=
seq
[
i
];
// check if sync before statement is needed.
// check if sync before statement is needed.
bool
sync_before_stmt
=
(
syncs_inserted_
.
count
(
s
.
stmt
)
!=
0
);
bool
sync_before_stmt
=
(
syncs_inserted_
.
count
(
s
.
stmt
)
!=
0
);
// Apply the syncs added already.
// Apply the syncs added already.
...
@@ -50,11 +47,11 @@ class ThreadSyncPlanner : public StorageAccessVisitor {
...
@@ -50,11 +47,11 @@ class ThreadSyncPlanner : public StorageAccessVisitor {
}
}
for
(
const
AccessEntry
&
acc
:
s
.
access
)
{
for
(
const
AccessEntry
&
acc
:
s
.
access
)
{
if
(
acc
.
type
==
kRead
)
{
if
(
acc
.
type
==
kRead
)
{
if
(
FindConflict
(
writes
,
acc
))
{
if
(
FindConflict
(
writes
,
acc
,
false
))
{
sync_before_stmt
=
true
;
break
;
sync_before_stmt
=
true
;
break
;
}
}
}
else
if
(
acc
.
type
==
kWrite
)
{
}
else
if
(
acc
.
type
==
kWrite
)
{
if
(
FindConflict
(
reads
,
acc
))
{
if
(
FindConflict
(
reads
,
acc
,
false
))
{
sync_before_stmt
=
true
;
break
;
sync_before_stmt
=
true
;
break
;
}
}
}
else
if
(
acc
.
type
==
kSync
)
{
}
else
if
(
acc
.
type
==
kSync
)
{
...
@@ -81,6 +78,33 @@ class ThreadSyncPlanner : public StorageAccessVisitor {
...
@@ -81,6 +78,33 @@ class ThreadSyncPlanner : public StorageAccessVisitor {
syncs_inserted_
.
insert
(
s
.
stmt
);
syncs_inserted_
.
insert
(
s
.
stmt
);
}
}
}
}
if
(
loop
!=
nullptr
)
{
for
(
size_t
i
=
0
;
i
<
seq
.
size
();
++
i
)
{
const
StmtEntry
&
s
=
seq
[
i
];
if
(
syncs_inserted_
.
count
(
s
.
stmt
)
!=
0
)
break
;
if
(
reads
.
empty
()
&&
writes
.
empty
())
break
;
bool
sync_before_stmt
=
false
;
for
(
const
AccessEntry
&
acc
:
s
.
access
)
{
if
(
acc
.
type
==
kRead
)
{
if
(
FindConflict
(
writes
,
acc
,
true
))
{
sync_before_stmt
=
true
;
break
;
}
}
else
if
(
acc
.
type
==
kWrite
)
{
if
(
FindConflict
(
reads
,
acc
,
true
))
{
sync_before_stmt
=
true
;
break
;
}
}
else
if
(
acc
.
type
==
kSync
)
{
reads
.
clear
();
writes
.
clear
();
}
}
if
(
sync_before_stmt
)
{
CHECK_EQ
(
condition_counter
(),
0
)
<<
"Cannot insert syncs inside condition"
;
syncs_inserted_
.
insert
(
s
.
stmt
);
break
;
}
}
}
// return the exposed entries, remove unnecessary ones.
// return the exposed entries, remove unnecessary ones.
int
sync_count
=
0
;
int
sync_count
=
0
;
// head are before first sync, tail are after last sync
// head are before first sync, tail are after last sync
...
@@ -117,13 +141,20 @@ class ThreadSyncPlanner : public StorageAccessVisitor {
...
@@ -117,13 +141,20 @@ class ThreadSyncPlanner : public StorageAccessVisitor {
}
}
}
}
head
.
insert
(
head
.
end
(),
tail
.
begin
(),
tail
.
end
());
head
.
insert
(
head
.
end
(),
tail
.
begin
(),
tail
.
end
());
if
(
loop
!=
nullptr
)
{
// clear double buffer flag after a loop is finished.
for
(
AccessEntry
&
e
:
head
)
{
e
.
double_buffer_write
=
false
;
}
}
return
head
;
return
head
;
}
}
private
:
private
:
// find conflicting entry in vec.
// find conflicting entry in vec.
bool
FindConflict
(
const
std
::
vector
<
AccessEntry
>&
vec
,
bool
FindConflict
(
const
std
::
vector
<
AccessEntry
>&
vec
,
const
AccessEntry
&
e
)
{
const
AccessEntry
&
e
,
bool
loop_carry
)
{
for
(
const
AccessEntry
&
x
:
vec
)
{
for
(
const
AccessEntry
&
x
:
vec
)
{
if
(
x
.
buffer
.
same_as
(
e
.
buffer
))
{
if
(
x
.
buffer
.
same_as
(
e
.
buffer
))
{
// Assumes no race between threads
// Assumes no race between threads
...
@@ -134,6 +165,9 @@ class ThreadSyncPlanner : public StorageAccessVisitor {
...
@@ -134,6 +165,9 @@ class ThreadSyncPlanner : public StorageAccessVisitor {
if
(
Equal
(
e
.
touched
.
point_value
(),
if
(
Equal
(
e
.
touched
.
point_value
(),
x
.
touched
.
point_value
()))
continue
;
x
.
touched
.
point_value
()))
continue
;
}
}
if
(
x
.
double_buffer_write
&&
e
.
type
==
kRead
&&
!
loop_carry
)
continue
;
return
true
;
return
true
;
}
}
}
}
...
...
src/schedule/schedule_lang.cc
View file @
a45d3b01
...
@@ -385,6 +385,13 @@ Stage& Stage::storage_align(IterVar axis, int factor, int offset) {
...
@@ -385,6 +385,13 @@ Stage& Stage::storage_align(IterVar axis, int factor, int offset) {
return
*
this
;
return
*
this
;
}
}
// Flag this stage for double buffering; output stages are rejected.
Stage& Stage::double_buffer() {
  StageNode* node = operator->();
  CHECK(!node->is_output) << "Cannot apply double buffer on output";
  node->double_buffer = true;
  return *this;
}
Stage
CopyStage
(
const
Stage
&
s
)
{
Stage
CopyStage
(
const
Stage
&
s
)
{
std
::
shared_ptr
<
StageNode
>
n
=
std
::
shared_ptr
<
StageNode
>
n
=
std
::
make_shared
<
StageNode
>
(
*
s
.
operator
->
());
std
::
make_shared
<
StageNode
>
(
*
s
.
operator
->
());
...
...
src/schedule/schedule_ops.cc
View file @
a45d3b01
...
@@ -27,6 +27,10 @@ Stmt MakePipeline(const Stage& s,
...
@@ -27,6 +27,10 @@ Stmt MakePipeline(const Stage& s,
if
(
producer
.
defined
())
{
if
(
producer
.
defined
())
{
producer
=
ProducerConsumer
::
make
(
s
->
op
,
true
,
producer
);
producer
=
ProducerConsumer
::
make
(
s
->
op
,
true
,
producer
);
}
}
if
(
s
->
double_buffer
)
{
producer
=
AttrStmt
::
make
(
s
->
op
,
ir
::
attr
::
double_buffer_scope
,
1
,
producer
);
}
Stmt
pipeline
=
producer
;
Stmt
pipeline
=
producer
;
if
(
consumer
.
defined
()
&&
!
is_no_op
(
consumer
))
{
if
(
consumer
.
defined
()
&&
!
is_no_op
(
consumer
))
{
...
@@ -170,7 +174,8 @@ class SchedulePostProc : public IRMutator {
...
@@ -170,7 +174,8 @@ class SchedulePostProc : public IRMutator {
thread_extent_scope_
.
erase
(
op
->
node
.
get
());
thread_extent_scope_
.
erase
(
op
->
node
.
get
());
return
ret
;
return
ret
;
}
}
}
else
if
(
op
->
attr_key
==
ir
::
attr
::
realize_scope
)
{
}
else
if
(
op
->
attr_key
==
ir
::
attr
::
realize_scope
||
op
->
attr_key
==
ir
::
attr
::
double_buffer_scope
)
{
auto
it
=
replace_op_
.
find
(
op
->
node
.
get
());
auto
it
=
replace_op_
.
find
(
op
->
node
.
get
());
if
(
it
!=
replace_op_
.
end
())
{
if
(
it
!=
replace_op_
.
end
())
{
if
(
it
->
second
.
defined
())
{
if
(
it
->
second
.
defined
())
{
...
...
tests/python/integration/test_gemm.py
View file @
a45d3b01
...
@@ -47,7 +47,8 @@ def test_gemm():
...
@@ -47,7 +47,8 @@ def test_gemm():
s
[
CC
]
.
compute_at
(
s
[
C
],
tx
)
s
[
CC
]
.
compute_at
(
s
[
C
],
tx
)
s
[
AA
]
.
compute_at
(
s
[
CC
],
k
)
s
[
AA
]
.
compute_at
(
s
[
CC
],
k
)
s
[
BB
]
.
compute_at
(
s
[
CC
],
k
)
s
[
BB
]
.
compute_at
(
s
[
CC
],
k
)
s
[
AA
]
.
double_buffer
()
s
[
BB
]
.
double_buffer
()
ty
,
xi
=
s
[
AA
]
.
split
(
s
[
AA
]
.
op
.
axis
[
0
],
nparts
=
num_thread
)
ty
,
xi
=
s
[
AA
]
.
split
(
s
[
AA
]
.
op
.
axis
[
0
],
nparts
=
num_thread
)
tx
,
xi
=
s
[
AA
]
.
split
(
xi
,
nparts
=
num_thread
)
tx
,
xi
=
s
[
AA
]
.
split
(
xi
,
nparts
=
num_thread
)
s
[
AA
]
.
bind
(
ty
,
thread_y
)
s
[
AA
]
.
bind
(
ty
,
thread_y
)
...
@@ -84,10 +85,10 @@ def test_gemm():
...
@@ -84,10 +85,10 @@ def test_gemm():
np
.
testing
.
assert_allclose
(
np
.
testing
.
assert_allclose
(
c
.
asnumpy
(),
np
.
dot
(
a_np
,
b_np
.
T
),
rtol
=
1e-5
)
c
.
asnumpy
(),
np
.
dot
(
a_np
,
b_np
.
T
),
rtol
=
1e-5
)
check_device
(
"nvptx -mcpu=sm_20"
)
check_device
(
"metal"
)
check_device
(
"metal"
)
check_device
(
"opencl"
)
check_device
(
"opencl"
)
check_device
(
"cuda"
)
check_device
(
"cuda"
)
#check_device("nvptx -mcpu=sm_20")
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
test_gemm
()
test_gemm
()
tests/python/unittest/test_pass_inject_double_buffer.py
0 → 100644
View file @
a45d3b01
import
tvm
def test_double_buffer():
    """Check InjectDoubleBuffer on a hand-built loop.

    Builds a loop of ``n`` iterations that copies ``m`` values of A into a
    shared buffer B (marked with ``double_buffer_scope``) and then reads B
    into C. After the pass the allocation of B must sit outside the loop
    with a leading extent of 2, and ThreadSync must insert exactly two
    ``tvm_storage_sync`` calls.
    """
    n = 100
    m = 4
    tx = tvm.thread_axis("threadIdx.x")
    ib = tvm.ir_builder.create()
    A = ib.pointer("float32", name="A")
    # Distinct name: A and C are both arguments of the generated API.
    C = ib.pointer("float32", name="C")
    ib.scope_attr(tx, "thread_extent", 1)
    with ib.for_range(0, n) as i:
        B = ib.allocate("float32", m, name="B", scope="shared")
        with ib.new_scope():
            ib.scope_attr(B.asnode(), "double_buffer_scope", 1)
            with ib.for_range(0, m) as j:
                B[j] = A[i * m + j]
        with ib.for_range(0, m) as j:
            C[j] = B[j] + 1

    stmt = ib.get()
    stmt = tvm.ir_pass.InjectDoubleBuffer(stmt, True)
    stmt = tvm.ir_pass.Simplify(stmt)
    # The allocation is hoisted out of the loop and doubled.
    assert isinstance(stmt.body.body, tvm.stmt.Allocate)
    assert stmt.body.body.extents[0].value == 2
    f = tvm.ir_pass.MakeAPI(stmt, "db", [A.asnode(), C.asnode()], 2, True)
    f = tvm.ir_pass.ThreadSync(f, "shared")
    # Mutable cell so the nested callback can update the counter.
    count = [0]

    def count_sync(op):
        if isinstance(op, tvm.expr.Call) and op.name == "tvm_storage_sync":
            count[0] += 1

    tvm.ir_pass.PostOrderVisit(f.body, count_sync)
    assert count[0] == 2
if
__name__
==
"__main__"
:
test_double_buffer
()
topi/recipe/gemm/cuda_gemm_square.py
View file @
a45d3b01
...
@@ -96,6 +96,8 @@ def test_gemm():
...
@@ -96,6 +96,8 @@ def test_gemm():
s
[
BB
]
.
bind
(
ty
,
thread_y
)
s
[
BB
]
.
bind
(
ty
,
thread_y
)
s
[
BB
]
.
bind
(
tx
,
thread_x
)
s
[
BB
]
.
bind
(
tx
,
thread_x
)
s
[
BB
]
.
vectorize
(
xi
)
s
[
BB
]
.
vectorize
(
xi
)
s
[
AA
]
.
double_buffer
()
s
[
BB
]
.
double_buffer
()
# correctness
# correctness
def
check_device
(
device
):
def
check_device
(
device
):
if
not
tvm
.
module
.
enabled
(
device
):
if
not
tvm
.
module
.
enabled
(
device
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment