Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
fd96d285
Commit
fd96d285
authored
Aug 03, 2017
by
Tianqi Chen
Committed by
GitHub
Aug 03, 2017
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[PASS] More storage sync. (#297)
parent
581be165
Show whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
255 additions
and
11 deletions
+255
-11
include/tvm/ir.h
+2
-0
include/tvm/ir_pass.h
+8
-0
python/tvm/build_module.py
+1
-0
python/tvm/make.py
+20
-0
src/api/api_pass.cc
+1
-0
src/pass/storage_access.cc
+11
-4
src/pass/storage_access.h
+6
-4
src/pass/storage_sync.cc
+186
-2
tests/python/unittest/test_pass_storage_sync.py
+20
-1
No files found.
include/tvm/ir.h
View file @
fd96d285
...
...
@@ -143,6 +143,8 @@ namespace attr {
constexpr
const
char
*
thread_extent
=
"thread_extent"
;
/*! \brief Mark launching of a virtual thread. */
constexpr
const
char
*
virtual_thread
=
"virtual_thread"
;
/*! \brief Mark region is processed by a co-proccesor */
constexpr
const
char
*
coproc_scope
=
"coproc_scope"
;
/*! \brief Mark the scope as volatile access for certain handle. */
constexpr
const
char
*
volatile_scope
=
"volatile_scope"
;
/*!
...
...
include/tvm/ir_pass.h
View file @
fd96d285
...
...
@@ -250,6 +250,14 @@ Stmt StorageRewrite(Stmt stmt);
Stmt
LoopPartition
(
Stmt
stmt
);
/*!
* \brief Detect and insert sync points to co-processor.
*
* \param stmt The stmt to be trasnformed
* \return Transformed stmt.
*/
Stmt
CoProcSync
(
Stmt
stmt
);
/*!
* \brief Make an user callable API LoweredFunc.
*
* The main task of this function is to create code to :
...
...
python/tvm/build_module.py
View file @
fd96d285
...
...
@@ -193,6 +193,7 @@ def lower(sch,
stmt
=
ir_pass
.
VectorizeLoop
(
stmt
)
stmt
=
ir_pass
.
InjectVirtualThread
(
stmt
)
stmt
=
ir_pass
.
StorageRewrite
(
stmt
)
stmt
=
ir_pass
.
CoProcSync
(
stmt
)
cfg
=
BuildConfig
.
current
stmt
=
ir_pass
.
UnrollLoop
(
stmt
,
...
...
python/tvm/make.py
View file @
fd96d285
...
...
@@ -77,4 +77,24 @@ def stmt_seq(*args):
ret
=
value
if
ret
is
None
else
Block
(
ret
,
value
)
return
ret
if
ret
else
Evaluate
(
0
)
def
stmt_list
(
stmt
):
"""Make list of stmt from blocks.
Parameters
----------
stmt : A block statement
Returns
-------
stmt_list : list of Stmt
The unpacked list of statements
"""
if
isinstance
(
stmt
,
_stmt
.
Block
):
return
stmt_list
(
stmt
.
first
)
+
stmt_list
(
stmt
.
rest
)
elif
isinstance
(
stmt
,
_stmt
.
ProducerConsumer
):
return
stmt_list
(
stmt
.
body
)
return
[
stmt
]
_init_api
(
"tvm.make"
)
src/api/api_pass.cc
View file @
fd96d285
...
...
@@ -94,6 +94,7 @@ REGISTER_PASS5(MakeAPI);
REGISTER_PASS2
(
BindDeviceType
);
REGISTER_PASS1
(
SplitHostDevice
);
REGISTER_PASS1
(
StorageRewrite
);
REGISTER_PASS1
(
CoProcSync
);
REGISTER_PASS1
(
InjectVirtualThread
);
REGISTER_PASS1
(
InjectPrefetch
);
REGISTER_PASS1
(
LoopPartition
);
...
...
src/pass/storage_access.cc
View file @
fd96d285
...
...
@@ -14,7 +14,7 @@ void StorageAccessVisitor::Visit_(const Load* op) {
CHECK
(
allow_append_
);
AccessEntry
e
;
e
.
threads
=
env_threads
();
e
.
buffer
=
buf
;
e
.
buffer
=
op
->
buffer_var
;
e
.
dtype
=
op
->
type
.
element_of
();
e
.
touched
=
arith
::
IntSet
::
vector
(
op
->
index
);
e
.
type
=
kRead
;
...
...
@@ -34,7 +34,7 @@ void StorageAccessVisitor::Visit_(const Store* op) {
if
(
Enabled
(
buf
,
scope
))
{
AccessEntry
e
;
e
.
threads
=
env_threads
();
e
.
buffer
=
buf
;
e
.
buffer
=
op
->
buffer_var
;
e
.
dtype
=
op
->
value
.
type
().
element_of
();
e
.
touched
=
arith
::
IntSet
::
vector
(
op
->
index
);
e
.
type
=
kWrite
;
...
...
@@ -69,6 +69,11 @@ void StorageAccessVisitor::Visit_(const AttrStmt* op) {
storage_scope_
[
buf
]
=
StorageScope
::
make
(
op
->
value
.
as
<
StringImm
>
()
->
value
);
IRVisitor
::
Visit_
(
op
);
}
else
if
(
op
->
attr_key
==
attr
::
coproc_scope
)
{
IterVar
iv
(
op
->
node
.
node_
);
env_threads_
.
push_back
(
iv
);
IRVisitor
::
Visit_
(
op
);
env_threads_
.
CopyOnWrite
()
->
data
.
pop_back
();
}
else
if
(
op
->
attr_key
==
attr
::
thread_extent
)
{
IterVar
iv
(
op
->
node
.
node_
);
env_threads_
.
push_back
(
iv
);
...
...
@@ -102,11 +107,13 @@ void StorageAccessVisitor::Visit_(const For* op) {
relax_map
[
op
->
loop_var
.
get
()]
=
arith
::
IntSet
::
range
(
Range
::
make_by_min_extent
(
op
->
min
,
op
->
extent
));
for
(
AccessEntry
&
e
:
s
.
access
)
{
if
(
e
.
buffer
!=
nullptr
)
{
if
(
e
.
buffer
.
defined
()
)
{
CHECK
(
e
.
touched
.
defined
());
e
.
touched
=
arith
::
EvalSet
(
e
.
touched
,
relax_map
);
}
}
}
if
(
!
s
.
access
.
empty
())
{
scope_
.
back
().
emplace_back
(
std
::
move
(
s
));
}
}
...
...
@@ -148,7 +155,7 @@ void StorageAccessVisitor::Visit_(const Call* op) {
AccessEntry
e
;
e
.
threads
=
env_threads
();
e
.
dtype
=
dtype
;
e
.
buffer
=
buffer
;
e
.
buffer
=
VarExpr
(
op
->
args
[
1
].
node_
)
;
e
.
touched
=
arith
::
IntSet
::
range
(
Range
::
make_by_min_extent
(
offset
,
extent
));
e
.
scope
=
scope
;
...
...
src/pass/storage_access.h
View file @
fd96d285
...
...
@@ -27,14 +27,16 @@ class StorageAccessVisitor : public IRVisitor {
kRead
,
kWrite
,
kSync
,
kAlloc
kAlloc
,
// acquired version of read, only need to handle WAR dep.
kReadAcquire
};
/*! \brief An access entry */
struct
AccessEntry
{
/*! \brief The thread index that access this entry */
Array
<
IterVar
>
threads
;
/*! \brief The buffer variable, if any */
const
Variable
*
buffer
{
nullptr
}
;
VarExpr
buffer
;
/*! \brief The access data type */
Type
dtype
;
/*! \brief The touched access range */
...
...
@@ -104,6 +106,8 @@ class StorageAccessVisitor : public IRVisitor {
* \return The scope of the final buffer array.
*/
StorageScope
GetScope
(
const
Variable
*
buf
)
const
;
// access scope
std
::
vector
<
std
::
vector
<
StmtEntry
>
>
scope_
;
private
:
// whether access appending is enabled.
...
...
@@ -116,8 +120,6 @@ class StorageAccessVisitor : public IRVisitor {
StmtEntry
curr_stmt_
;
// The involving threads
Array
<
IterVar
>
env_threads_
;
// access scope
std
::
vector
<
std
::
vector
<
StmtEntry
>
>
scope_
;
// The storage scope of each buffer
std
::
unordered_map
<
const
Variable
*
,
StorageScope
>
storage_scope_
;
};
...
...
src/pass/storage_sync.cc
View file @
fd96d285
...
...
@@ -37,7 +37,7 @@ class ThreadSyncPlanner : public StorageAccessVisitor {
// if it is a loop, rotate two times to consider effect of loop.
size_t
max_seq
=
seq
.
size
();
if
(
loop
!=
0
)
max_seq
*=
2
;
if
(
loop
!=
nullptr
)
max_seq
*=
2
;
// simulation based approach to find dependenceies
for
(
size_t
i
=
0
;
i
<
max_seq
;
++
i
)
{
const
StmtEntry
&
s
=
seq
[
i
%
seq
.
size
()];
...
...
@@ -125,7 +125,7 @@ class ThreadSyncPlanner : public StorageAccessVisitor {
bool
FindConflict
(
const
std
::
vector
<
AccessEntry
>&
vec
,
const
AccessEntry
&
e
)
{
for
(
const
AccessEntry
&
x
:
vec
)
{
if
(
x
.
buffer
==
e
.
buffer
)
{
if
(
x
.
buffer
.
same_as
(
e
.
buffer
)
)
{
// Assumes no race between threads
// Same index value means no conflicts
// TODO(tqchen) more standard set based testing.
...
...
@@ -296,5 +296,189 @@ LoweredFunc ThreadSync(LoweredFunc f, std::string storage_scope) {
return
LoweredFunc
(
n
);
}
// Visitor to find touched set by co-processor scope.
class
CoProcTouchedBuffer
:
public
IRVisitor
{
public
:
void
Visit_
(
const
Load
*
op
)
final
{
if
(
in_scope_
)
{
touched_
.
insert
(
op
->
buffer_var
.
get
());
}
IRVisitor
::
Visit_
(
op
);
}
void
Visit_
(
const
Store
*
op
)
final
{
if
(
in_scope_
)
{
touched_
.
insert
(
op
->
buffer_var
.
get
());
}
IRVisitor
::
Visit_
(
op
);
}
void
Visit_
(
const
Call
*
op
)
final
{
if
(
op
->
is_intrinsic
(
intrinsic
::
tvm_access_ptr
))
{
const
Variable
*
buffer
=
op
->
args
[
1
].
as
<
Variable
>
();
touched_
.
insert
(
buffer
);
}
IRVisitor
::
Visit_
(
op
);
}
void
Visit_
(
const
AttrStmt
*
op
)
final
{
if
(
op
->
attr_key
==
attr
::
coproc_scope
&&
!
in_scope_
)
{
in_scope_
=
true
;
IRVisitor
::
Visit_
(
op
);
in_scope_
=
false
;
}
else
{
IRVisitor
::
Visit_
(
op
);
}
}
std
::
unordered_set
<
const
Variable
*>
touched_
;
private
:
bool
in_scope_
{
false
};
};
// Synchronization planning with co-processor.
class
CoProcSyncPlanner
:
public
StorageAccessVisitor
{
public
:
void
Plan
(
const
Stmt
&
stmt
)
{
CoProcTouchedBuffer
visitor
;
visitor
.
Visit
(
stmt
);
touched_
=
std
::
move
(
visitor
.
touched_
);
if
(
!
touched_
.
empty
())
{
this
->
Visit
(
stmt
);
PlanWriteSync
(
scope_
.
back
(),
nullptr
,
true
);
}
}
// Write synchronization to be inserted before or after stmt.
std
::
unordered_map
<
const
Node
*
,
std
::
vector
<
Stmt
>
>
write_sync_
;
protected
:
bool
Enabled
(
const
Variable
*
buf
,
const
StorageScope
&
scope
)
const
final
{
return
touched_
.
count
(
buf
)
&&
scope
==
global_scope_
;
}
// Plan the sync
std
::
vector
<
AccessEntry
>
Summarize
(
std
::
vector
<
StmtEntry
>
seq
,
const
For
*
loop
)
final
{
return
PlanWriteSync
(
seq
,
loop
,
false
);
}
private
:
// Plan write synchronization if write is not coherent
std
::
vector
<
AccessEntry
>
PlanWriteSync
(
std
::
vector
<
StmtEntry
>
seq
,
const
For
*
loop
,
bool
force_sync_at_end
)
{
// detect write barriers
// access by the co-processor.
std
::
vector
<
AccessEntry
>
co_access
;
bool
contain_sync
=
false
;
auto
find_conflict
=
[
&
](
const
AccessEntry
&
acc
)
{
for
(
const
AccessEntry
&
x
:
co_access
)
{
if
(
x
.
buffer
.
same_as
(
acc
.
buffer
)
&&
((
acc
.
type
==
kRead
&&
x
.
type
==
kWrite
)
||
acc
.
type
==
kWrite
))
{
return
true
;
}
}
return
false
;
};
for
(
size_t
i
=
0
;
i
<
seq
.
size
();
++
i
)
{
const
StmtEntry
&
s
=
seq
[
i
];
bool
sync_write
=
false
;
for
(
const
AccessEntry
&
acc
:
s
.
access
)
{
if
(
acc
.
threads
.
size
()
==
0
&&
find_conflict
(
acc
))
{
sync_write
=
true
;
break
;
}
if
(
acc
.
type
==
kSync
)
{
co_access
.
clear
();
contain_sync
=
true
;
}
}
if
(
sync_write
)
{
CHECK_NE
(
i
,
0U
);
write_sync_
[
seq
[
i
-
1
].
stmt
]
=
GetWriteSync
(
co_access
);
co_access
.
clear
();
contain_sync
=
true
;
}
for
(
const
AccessEntry
&
acc
:
s
.
access
)
{
if
(
acc
.
threads
.
size
()
!=
0
)
{
co_access
.
push_back
(
acc
);
}
}
}
bool
sync_at_end
=
force_sync_at_end
;
if
(
loop
!=
nullptr
&&
!
sync_at_end
)
{
// loop carray dependency
for
(
size_t
i
=
0
;
i
<
seq
.
size
();
++
i
)
{
const
StmtEntry
&
s
=
seq
[
i
];
for
(
const
AccessEntry
&
acc
:
s
.
access
)
{
if
(
acc
.
threads
.
size
()
==
0
&&
find_conflict
(
acc
))
{
sync_at_end
=
true
;
break
;
}
}
if
(
write_sync_
.
count
(
s
.
stmt
)
||
sync_at_end
)
break
;
}
}
if
(
sync_at_end
&&
co_access
.
size
()
!=
0
)
{
CHECK_NE
(
seq
.
size
(),
0
);
contain_sync
=
true
;
write_sync_
[
seq
.
back
().
stmt
]
=
GetWriteSync
(
co_access
);
co_access
.
clear
();
}
if
(
contain_sync
)
{
AccessEntry
e
;
e
.
type
=
kSync
;
e
.
scope
=
global_scope_
;
co_access
.
insert
(
co_access
.
begin
(),
e
);
}
return
co_access
;
}
// Add write Synchronization
std
::
vector
<
Stmt
>
GetWriteSync
(
const
std
::
vector
<
AccessEntry
>&
co_access
)
{
// Does not consider memory coherence, need runtime.
CHECK_NE
(
co_access
.
size
(),
0U
);
CHECK_EQ
(
co_access
[
0
].
threads
.
size
(),
1U
);
std
::
string
sync_name
=
co_access
[
0
].
threads
[
0
]
->
var
->
name_hint
+
".coproc_sync"
;
std
::
vector
<
Stmt
>
stmts
;
stmts
.
emplace_back
(
Evaluate
::
make
(
Call
::
make
(
Int
(
32
),
sync_name
,
{},
Call
::
Intrinsic
)));
return
stmts
;
}
std
::
unordered_set
<
const
Variable
*>
touched_
;
StorageScope
global_scope_
=
StorageScope
::
make
(
"global"
);
};
class
CoProcSyncInserter
:
public
IRMutator
{
public
:
explicit
CoProcSyncInserter
(
const
std
::
unordered_map
<
const
Node
*
,
std
::
vector
<
Stmt
>
>&
write_sync
)
:
write_sync_
(
write_sync
)
{}
Stmt
Mutate
(
Stmt
stmt
)
final
{
stmt
=
IRMutator
::
Mutate
(
stmt
);
auto
it
=
write_sync_
.
find
(
stmt
.
get
());
if
(
it
!=
write_sync_
.
end
())
{
stmt
=
Block
::
make
(
stmt
,
MergeSeq
(
it
->
second
));
}
return
stmt
;
}
private
:
const
std
::
unordered_map
<
const
Node
*
,
std
::
vector
<
Stmt
>
>&
write_sync_
;
};
Stmt
CoProcSync
(
Stmt
stmt
)
{
CoProcSyncPlanner
planner
;
planner
.
Plan
(
stmt
);
if
(
planner
.
write_sync_
.
size
()
!=
0
)
{
return
CoProcSyncInserter
(
planner
.
write_sync_
).
Mutate
(
stmt
);
}
else
{
return
stmt
;
}
}
}
// namespace ir
}
// namespace tvm
tests/python/unittest/test_pass_storage_sync.py
View file @
fd96d285
...
...
@@ -24,7 +24,26 @@ def test_storage_sync():
flist
=
tvm
.
ir_pass
.
SplitHostDevice
(
f
)
f
=
flist
[
1
]
f
=
tvm
.
ir_pass
.
ThreadSync
(
f
,
"shared"
)
print
(
f
.
body
)
body_list
=
tvm
.
make
.
stmt_list
(
f
.
body
.
body
.
body
.
body
)
assert
(
body_list
[
1
]
.
value
.
name
==
"tvm_storage_sync"
)
def
test_coproc_sync
():
ib
=
tvm
.
ir_builder
.
create
()
n
=
tvm
.
var
(
"n"
)
cp
=
tvm
.
thread_axis
((
0
,
1
),
"cop"
)
A
=
ib
.
allocate
(
"float32"
,
n
,
name
=
"A"
,
scope
=
"global"
)
with
ib
.
for_range
(
0
,
n
,
name
=
"i"
)
as
i
:
A
[
i
]
=
A
[
i
]
+
1
with
ib
.
for_range
(
0
,
10
,
name
=
"j"
)
as
j
:
ib
.
scope_attr
(
cp
,
"coproc_scope"
,
1
)
A
[
j
]
=
A
[
j
]
+
2
body
=
ib
.
get
()
body
=
tvm
.
ir_pass
.
CoProcSync
(
body
)
body
=
body
.
body
.
body
.
body
assert
(
tvm
.
make
.
stmt_list
(
body
)[
-
1
]
.
value
.
name
==
"cop.coproc_sync"
)
if
__name__
==
"__main__"
:
test_coproc_sync
()
test_storage_sync
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment