Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
adf39837
Commit
adf39837
authored
Sep 02, 2017
by
Tianqi Chen
Committed by
GitHub
Sep 02, 2017
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[PASS] Improve double buffer (#413)
parent
5072efae
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
53 additions
and
19 deletions
+53
-19
include/tvm/ir_pass.h
+2
-2
python/tvm/build_module.py
+5
-4
src/pass/inject_double_buffer.cc
+44
-11
tests/python/unittest/test_pass_inject_double_buffer.py
+2
-2
No files found.
include/tvm/ir_pass.h
View file @
adf39837
...
@@ -234,10 +234,10 @@ Stmt InjectPrefetch(Stmt stmt);
...
@@ -234,10 +234,10 @@ Stmt InjectPrefetch(Stmt stmt);
/*!
/*!
* \brief Inject double buffer into stmt.
* \brief Inject double buffer into stmt.
* \param stmt The statment to be transformed.
* \param stmt The statment to be transformed.
* \param split_loop
Whether split the loop containing double buffering
.
* \param split_loop
Loop splitting factor
.
* \return Transformed stmt.
* \return Transformed stmt.
*/
*/
Stmt
InjectDoubleBuffer
(
Stmt
stmt
,
bool
split_loop
);
Stmt
InjectDoubleBuffer
(
Stmt
stmt
,
int
split_loop
);
/*!
/*!
* \brief Rewrite storage allocation pattern.
* \brief Rewrite storage allocation pattern.
...
...
python/tvm/build_module.py
View file @
adf39837
...
@@ -33,7 +33,7 @@ class BuildConfig(object):
...
@@ -33,7 +33,7 @@ class BuildConfig(object):
"offset_factor"
:
0
,
"offset_factor"
:
0
,
"data_alignment"
:
-
1
,
"data_alignment"
:
-
1
,
"restricted_func"
:
True
,
"restricted_func"
:
True
,
"double_buffer_split_loop"
:
True
,
"double_buffer_split_loop"
:
1
,
"add_lower_pass"
:
None
"add_lower_pass"
:
None
}
}
def
__init__
(
self
,
**
kwargs
):
def
__init__
(
self
,
**
kwargs
):
...
@@ -99,9 +99,10 @@ def build_config(**kwargs):
...
@@ -99,9 +99,10 @@ def build_config(**kwargs):
not to overlap. This enables more optimization.
not to overlap. This enables more optimization.
Corresponds to restricted keyword in C99
Corresponds to restricted keyword in C99
double_buffer_split_loop: bool, default=True
double_buffer_split_loop: int, default=2
Whether split the loop containing double buffer so
Whether split the loop with factor. If it is zero, no splitting will happen.
that the buffer fetching won't contain condition.
It it is bigger than one, the logic will do a split with factor equals the integer
and unroll the inner loop. This allows the buffer fetching won't contain condition.
add_lower_pass: list of tuiple (phase, function(Stmt->Stmt)), default=None
add_lower_pass: list of tuiple (phase, function(Stmt->Stmt)), default=None
phase contains an integer on which optimization pass we apply the pass.
phase contains an integer on which optimization pass we apply the pass.
...
...
src/pass/inject_double_buffer.cc
View file @
adf39837
...
@@ -34,9 +34,21 @@ class DoubleBufferDetector : public IRVisitor {
...
@@ -34,9 +34,21 @@ class DoubleBufferDetector : public IRVisitor {
std
::
unordered_set
<
const
Variable
*>
touched_
;
std
::
unordered_set
<
const
Variable
*>
touched_
;
};
};
class
StripDoubleBufferWrite
:
public
IRMutator
{
public
:
Stmt
Mutate_
(
const
AttrStmt
*
op
,
const
Stmt
&
s
)
final
{
if
(
op
->
attr_key
==
attr
::
double_buffer_write
)
{
return
Mutate
(
op
->
body
);
}
else
{
return
IRMutator
::
Mutate_
(
op
,
s
);
}
}
};
class
DoubleBufferInjector
:
public
IRMutator
{
class
DoubleBufferInjector
:
public
IRMutator
{
public
:
public
:
explicit
DoubleBufferInjector
(
bool
split_loop
)
explicit
DoubleBufferInjector
(
int
split_loop
)
:
split_loop_
(
split_loop
)
{}
:
split_loop_
(
split_loop
)
{}
Stmt
Inject
(
const
Stmt
&
stmt
)
{
Stmt
Inject
(
const
Stmt
&
stmt
)
{
...
@@ -97,17 +109,38 @@ class DoubleBufferInjector : public IRMutator {
...
@@ -97,17 +109,38 @@ class DoubleBufferInjector : public IRMutator {
auto
it
=
loop_pre_
.
find
(
op
);
auto
it
=
loop_pre_
.
find
(
op
);
if
(
it
!=
loop_pre_
.
end
())
{
if
(
it
!=
loop_pre_
.
end
())
{
const
For
*
old_loop
=
stmt
.
as
<
For
>
();
const
For
*
old_loop
=
stmt
.
as
<
For
>
();
if
(
split_loop_
)
{
if
(
split_loop_
!=
0
)
{
// Explicitly unroll the loop
CHECK
(
split_loop_
%
2
==
0
||
split_loop_
==
1
)
<<
"It is better to split with multiple of 2"
;
CHECK
(
is_zero
(
old_loop
->
min
));
Expr
zero
=
old_loop
->
min
;
Expr
new_ext
=
arith
::
ComputeExpr
<
Sub
>
(
Expr
new_ext
=
arith
::
ComputeExpr
<
Sub
>
(
old_loop
->
extent
,
make_const
(
old_loop
->
loop_var
.
type
(),
1
));
old_loop
->
extent
,
make_const
(
old_loop
->
loop_var
.
type
(),
1
));
Stmt
loop
=
For
::
make
(
Expr
factor
=
make_const
(
new_ext
.
type
(),
split_loop_
);
old_loop
->
loop_var
,
old_loop
->
min
,
new_ext
,
Expr
outer_ext
=
arith
::
ComputeExpr
<
Div
>
(
new_ext
,
factor
);
old_loop
->
for_type
,
old_loop
->
device_api
,
Expr
tail_base
=
arith
::
ComputeExpr
<
Mul
>
(
outer_ext
,
factor
);
old_loop
->
body
);
Var
outer_var
(
old_loop
->
loop_var
->
name_hint
+
".outer"
,
old_loop
->
loop_var
.
type
()
);
std
::
unordered_map
<
const
Variable
*
,
Expr
>
vmap
;
std
::
unordered_map
<
const
Variable
*
,
Expr
>
vmap
;
vmap
[
old_loop
->
loop_var
.
get
()]
=
new_ext
;
std
::
vector
<
Stmt
>
loop_seq
;
Stmt
end
=
Substitute
(
old_loop
->
body
,
vmap
);
for
(
size_t
i
=
0
;
i
<
split_loop_
;
++
i
)
{
stmt
=
Block
::
make
(
loop
,
end
);
vmap
[
old_loop
->
loop_var
.
get
()]
=
outer_var
*
factor
+
make_const
(
factor
.
type
(),
i
);
loop_seq
.
emplace_back
(
Substitute
(
old_loop
->
body
,
vmap
));
}
Stmt
loop
=
For
::
make
(
outer_var
,
zero
,
outer_ext
,
old_loop
->
for_type
,
old_loop
->
device_api
,
MergeSeq
(
loop_seq
));
// tail
std
::
vector
<
Stmt
>
tail_seq
;
Stmt
tail_body
=
StripDoubleBufferWrite
().
Mutate
(
old_loop
->
body
);
for
(
size_t
i
=
0
;
i
<
split_loop_
;
++
i
)
{
Expr
idx
=
tail_base
+
make_const
(
tail_base
.
type
(),
i
);
vmap
[
old_loop
->
loop_var
.
get
()]
=
idx
;
tail_seq
.
emplace_back
(
IfThenElse
::
make
(
idx
<
old_loop
->
extent
,
Substitute
(
tail_body
,
vmap
)));
}
stmt
=
Block
::
make
(
loop
,
MergeSeq
(
tail_seq
));
}
}
stmt
=
Block
::
make
(
MergeSeq
(
it
->
second
),
stmt
);
stmt
=
Block
::
make
(
MergeSeq
(
it
->
second
),
stmt
);
}
}
...
@@ -205,7 +238,7 @@ class DoubleBufferInjector : public IRMutator {
...
@@ -205,7 +238,7 @@ class DoubleBufferInjector : public IRMutator {
std
::
string
scope
;
std
::
string
scope
;
};
};
// Whether split loop
// Whether split loop
bool
split_loop_
;
int
split_loop_
;
// Whether we are inside double buffer scope.
// Whether we are inside double buffer scope.
bool
in_double_buffer_scope_
{
false
};
bool
in_double_buffer_scope_
{
false
};
// The current loop next
// The current loop next
...
@@ -219,7 +252,7 @@ class DoubleBufferInjector : public IRMutator {
...
@@ -219,7 +252,7 @@ class DoubleBufferInjector : public IRMutator {
};
};
Stmt
InjectDoubleBuffer
(
Stmt
stmt
,
bool
split_loop
)
{
Stmt
InjectDoubleBuffer
(
Stmt
stmt
,
int
split_loop
)
{
return
DoubleBufferInjector
(
split_loop
).
Inject
(
stmt
);
return
DoubleBufferInjector
(
split_loop
).
Inject
(
stmt
);
}
}
}
// namespace ir
}
// namespace ir
...
...
tests/python/unittest/test_pass_inject_double_buffer.py
View file @
adf39837
...
@@ -19,7 +19,7 @@ def test_double_buffer():
...
@@ -19,7 +19,7 @@ def test_double_buffer():
C
[
j
]
=
B
[
j
]
+
1
C
[
j
]
=
B
[
j
]
+
1
stmt
=
ib
.
get
()
stmt
=
ib
.
get
()
stmt
=
tvm
.
ir_pass
.
InjectDoubleBuffer
(
stmt
,
True
)
stmt
=
tvm
.
ir_pass
.
InjectDoubleBuffer
(
stmt
,
2
)
stmt
=
tvm
.
ir_pass
.
Simplify
(
stmt
)
stmt
=
tvm
.
ir_pass
.
Simplify
(
stmt
)
assert
isinstance
(
stmt
.
body
.
body
,
tvm
.
stmt
.
Allocate
)
assert
isinstance
(
stmt
.
body
.
body
,
tvm
.
stmt
.
Allocate
)
assert
stmt
.
body
.
body
.
extents
[
0
]
.
value
==
2
assert
stmt
.
body
.
body
.
extents
[
0
]
.
value
==
2
...
@@ -30,7 +30,7 @@ def test_double_buffer():
...
@@ -30,7 +30,7 @@ def test_double_buffer():
if
isinstance
(
op
,
tvm
.
expr
.
Call
)
and
op
.
name
==
"tvm_storage_sync"
:
if
isinstance
(
op
,
tvm
.
expr
.
Call
)
and
op
.
name
==
"tvm_storage_sync"
:
count
[
0
]
+=
1
count
[
0
]
+=
1
tvm
.
ir_pass
.
PostOrderVisit
(
f
.
body
,
count_sync
)
tvm
.
ir_pass
.
PostOrderVisit
(
f
.
body
,
count_sync
)
assert
count
[
0
]
==
2
assert
count
[
0
]
==
4
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment