Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
f1aabedc
Commit
f1aabedc
authored
Nov 11, 2017
by
Tianqi Chen
Committed by
GitHub
Nov 11, 2017
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[PASS] Update coproc sync (#634)
parent
32b0fff2
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
303 additions
and
6 deletions
+303
-6
python/tvm/build_module.py
+7
-3
src/pass/coproc_sync.cc
+267
-1
src/pass/storage_rewrite.cc
+7
-1
src/runtime/thread_storage_scope.h
+1
-1
tests/python/unittest/test_pass_storage_sync.py
+21
-0
No files found.
python/tvm/build_module.py
View file @
f1aabedc
...
...
@@ -201,7 +201,8 @@ def lower(sch,
add_lower_pass
=
cfg
.
add_lower_pass
if
cfg
.
add_lower_pass
else
[]
lower_phase0
=
[
x
[
1
]
for
x
in
add_lower_pass
if
x
[
0
]
==
0
]
lower_phase1
=
[
x
[
1
]
for
x
in
add_lower_pass
if
x
[
0
]
==
1
]
lower_phase2
=
[
x
[
1
]
for
x
in
add_lower_pass
if
x
[
0
]
>
1
]
lower_phase2
=
[
x
[
1
]
for
x
in
add_lower_pass
if
x
[
0
]
==
2
]
lower_phase3
=
[
x
[
1
]
for
x
in
add_lower_pass
if
x
[
0
]
>
2
]
# normalize schedule first
sch
=
sch
.
normalize
()
# Phase 0
...
...
@@ -213,6 +214,9 @@ def lower(sch,
# Phase 1
stmt
=
ir_pass
.
StorageFlatten
(
stmt
,
binds
,
64
)
stmt
=
ir_pass
.
CanonicalSimplify
(
stmt
)
for
f
in
lower_phase1
:
stmt
=
f
(
stmt
)
# Phase 2
if
not
simple_mode
:
stmt
=
ir_pass
.
LoopPartition
(
stmt
)
stmt
=
ir_pass
.
VectorizeLoop
(
stmt
)
...
...
@@ -224,14 +228,14 @@ def lower(sch,
cfg
.
auto_unroll_max_step
,
cfg
.
auto_unroll_max_depth
,
cfg
.
unroll_explicit
)
for
f
in
lower_phase
1
:
for
f
in
lower_phase
2
:
stmt
=
f
(
stmt
)
# Phase 2
stmt
=
ir_pass
.
Simplify
(
stmt
)
stmt
=
ir_pass
.
LowerStorageAccessInfo
(
stmt
)
stmt
=
ir_pass
.
RemoveNoOp
(
stmt
)
stmt
=
ir_pass
.
RewriteUnsafeSelect
(
stmt
)
for
f
in
lower_phase
2
:
for
f
in
lower_phase
3
:
stmt
=
f
(
stmt
)
if
simple_mode
:
return
stmt
...
...
src/pass/coproc_sync.cc
View file @
f1aabedc
...
...
@@ -338,6 +338,256 @@ class CoProcBarrierDetector : public StorageAccessVisitor {
};
class
CoProcInstDepDetector
:
public
IRVisitor
{
public
:
explicit
CoProcInstDepDetector
(
const
IterVar
&
coproc_axis
,
const
std
::
string
&
coproc_name
)
:
coproc_axis_
(
coproc_axis
)
{
sync_push_name_
=
coproc_name
+
".coproc_dep_push"
;
sync_pop_name_
=
coproc_name
+
".coproc_dep_pop"
;
}
void
Plan
(
Stmt
stmt
)
{
this
->
Visit
(
stmt
);
if
(
last_state_
.
node
!=
nullptr
)
{
MatchFixEnterPop
(
first_state_
);
MatchFixExitPush
(
last_state_
);
}
}
void
Visit_
(
const
AttrStmt
*
op
)
final
{
if
(
op
->
attr_key
==
attr
::
coproc_scope
&&
op
->
node
.
same_as
(
coproc_axis_
))
{
const
IntImm
*
ctx_id
=
op
->
value
.
as
<
IntImm
>
();
CHECK
(
ctx_id
!=
nullptr
);
curr_state_
.
clear
();
curr_state_
.
node
=
op
->
body
.
get
();
curr_state_
.
enter_ctx
.
insert
(
ctx_id
->
value
);
curr_state_
.
exit_ctx
.
insert
(
ctx_id
->
value
);
UpdateState
();
}
else
{
IRVisitor
::
Visit_
(
op
);
}
}
void
Visit_
(
const
For
*
op
)
final
{
SyncState
temp_first
,
temp_last
;
std
::
swap
(
first_state_
,
temp_first
);
std
::
swap
(
last_state_
,
temp_last
);
this
->
Visit
(
op
->
body
);
curr_state_
.
clear
();
if
(
last_state_
.
node
!=
nullptr
)
{
curr_state_
.
node
=
op
;
CHECK
(
first_state_
.
node
!=
nullptr
);
// loop carry dependency
InjectSync
(
last_state_
,
first_state_
,
&
(
curr_state_
.
exit_push
),
&
(
curr_state_
.
enter_pop
));
curr_state_
.
enter_ctx
=
first_state_
.
enter_ctx
;
curr_state_
.
exit_ctx
=
last_state_
.
enter_ctx
;
}
std
::
swap
(
first_state_
,
temp_first
);
std
::
swap
(
last_state_
,
temp_last
);
if
(
curr_state_
.
node
!=
nullptr
)
{
UpdateState
();
}
}
void
Visit_
(
const
IfThenElse
*
op
)
final
{
SyncState
temp_first
,
temp_last
,
curr_state
;
std
::
swap
(
first_state_
,
temp_first
);
std
::
swap
(
last_state_
,
temp_last
);
{
// then stmt
this
->
Visit
(
op
->
then_case
);
if
(
last_state_
.
node
!=
nullptr
)
{
curr_state
.
node
=
op
;
MatchFixEnterPop
(
first_state_
);
MatchFixExitPush
(
last_state_
);
curr_state
.
enter_ctx
.
insert
(
first_state_
.
enter_ctx
.
begin
(),
first_state_
.
enter_ctx
.
end
());
curr_state
.
exit_ctx
.
insert
(
last_state_
.
exit_ctx
.
begin
(),
last_state_
.
exit_ctx
.
end
());
}
first_state_
.
clear
();
last_state_
.
clear
();
}
if
(
op
->
else_case
.
defined
())
{
this
->
Visit
(
op
->
else_case
);
if
(
last_state_
.
node
!=
nullptr
)
{
curr_state
.
node
=
op
;
MatchFixEnterPop
(
first_state_
);
MatchFixExitPush
(
last_state_
);
curr_state
.
enter_ctx
.
insert
(
first_state_
.
enter_ctx
.
begin
(),
first_state_
.
enter_ctx
.
end
());
curr_state
.
exit_ctx
.
insert
(
last_state_
.
exit_ctx
.
begin
(),
last_state_
.
exit_ctx
.
end
());
}
}
// update in the trace.
std
::
swap
(
first_state_
,
temp_first
);
std
::
swap
(
last_state_
,
temp_last
);
std
::
swap
(
curr_state_
,
curr_state
);
if
(
curr_state_
.
node
!=
nullptr
)
{
UpdateState
();
}
}
// insert before is stored in reverse order
// the first element is closest to the node.
std
::
unordered_map
<
const
Node
*
,
std
::
vector
<
Stmt
>
>
insert_before_
;
std
::
unordered_map
<
const
Node
*
,
std
::
vector
<
Stmt
>
>
insert_after_
;
private
:
// state in the sync entry
struct
SyncState
{
// The statement of the state.
const
Node
*
node
{
nullptr
};
// Set of all possible contexts in the entering moment.
std
::
unordered_set
<
int
>
enter_ctx
;
// Set of all possible contexts in the exit moment.
std
::
unordered_set
<
int
>
exit_ctx
;
// existing pop performed at enter
std
::
vector
<
std
::
pair
<
int
,
int
>
>
enter_pop
;
// existing push peformed at exit
std
::
vector
<
std
::
pair
<
int
,
int
>
>
exit_push
;
// clear the state
void
clear
()
{
node
=
nullptr
;
enter_ctx
.
clear
();
exit_ctx
.
clear
();
enter_pop
.
clear
();
exit_push
.
clear
();
}
};
// inject proper sync into the pair
// record the push/pop sequence that could be possibly un-matched.
// return the push/pop message at enter/exit of the Block
// after considering the existing unmatcheded events and added events
void
InjectSync
(
const
SyncState
&
prev
,
const
SyncState
&
next
,
std
::
vector
<
std
::
pair
<
int
,
int
>
>*
prev_exit_push
,
std
::
vector
<
std
::
pair
<
int
,
int
>
>*
next_enter_pop
)
{
prev_exit_push
->
clear
();
next_enter_pop
->
clear
();
// quick path
if
(
prev
.
exit_push
.
size
()
==
0
&&
next
.
enter_pop
.
size
()
==
0
&&
prev
.
exit_ctx
.
size
()
==
1
&&
next
.
enter_ctx
.
size
()
==
1
)
{
int
from
=
*
prev
.
exit_ctx
.
begin
();
int
to
=
*
next
.
enter_ctx
.
begin
();
if
(
from
!=
to
)
{
insert_after_
[
prev
.
node
].
emplace_back
(
MakePush
(
from
,
to
));
insert_before_
[
next
.
node
].
emplace_back
(
MakePop
(
from
,
to
));
prev_exit_push
->
emplace_back
(
std
::
make_pair
(
from
,
to
));
next_enter_pop
->
emplace_back
(
std
::
make_pair
(
from
,
to
));
}
return
;
}
// complicate path.
std
::
vector
<
std
::
pair
<
int
,
int
>
>
vpush
=
prev
.
exit_push
;
std
::
vector
<
std
::
pair
<
int
,
int
>
>
vpop
=
next
.
enter_pop
;
std
::
vector
<
std
::
pair
<
int
,
int
>
>
pending
;
for
(
int
from
:
prev
.
exit_ctx
)
{
for
(
int
to
:
next
.
enter_ctx
)
{
if
(
from
!=
to
)
{
pending
.
emplace_back
(
std
::
make_pair
(
from
,
to
));
}
}
}
// policy 1
std
::
vector
<
Stmt
>
prev_after
,
next_before
;
for
(
const
std
::
pair
<
int
,
int
>&
p
:
pending
)
{
if
(
std
::
find
(
prev
.
exit_push
.
begin
(),
prev
.
exit_push
.
end
(),
p
)
==
prev
.
exit_push
.
end
())
{
vpush
.
push_back
(
p
);
prev_after
.
emplace_back
(
MakePush
(
p
.
first
,
p
.
second
));
}
if
(
std
::
find
(
next
.
enter_pop
.
begin
(),
next
.
enter_pop
.
end
(),
p
)
==
next
.
enter_pop
.
end
())
{
vpop
.
push_back
(
p
);
next_before
.
emplace_back
(
MakePop
(
p
.
first
,
p
.
second
));
}
}
// fix pending
for
(
const
std
::
pair
<
int
,
int
>&
p
:
vpush
)
{
if
(
std
::
find
(
vpop
.
begin
(),
vpop
.
end
(),
p
)
==
vpop
.
end
())
{
prev_after
.
emplace_back
(
MakePop
(
p
.
first
,
p
.
second
));
}
else
{
prev_exit_push
->
push_back
(
p
);
}
}
for
(
const
std
::
pair
<
int
,
int
>&
p
:
vpop
)
{
if
(
std
::
find
(
vpush
.
begin
(),
vpush
.
end
(),
p
)
==
vpush
.
end
())
{
next_before
.
emplace_back
(
MakePush
(
p
.
first
,
p
.
second
));
}
else
{
next_enter_pop
->
push_back
(
p
);
}
}
if
(
prev_after
.
size
()
!=
0
)
{
auto
&
v1
=
insert_after_
[
prev
.
node
];
v1
.
insert
(
v1
.
end
(),
prev_after
.
begin
(),
prev_after
.
end
());
}
if
(
next_before
.
size
()
!=
0
)
{
auto
&
v2
=
insert_before_
[
next
.
node
];
v2
.
insert
(
v2
.
end
(),
next_before
.
begin
(),
next_before
.
end
());
}
}
void
MatchFixEnterPop
(
const
SyncState
&
state
)
{
if
(
state
.
enter_pop
.
size
()
==
0
)
return
;
auto
&
vec
=
insert_before_
[
state
.
node
];
for
(
const
std
::
pair
<
int
,
int
>&
p
:
state
.
enter_pop
)
{
vec
.
push_back
(
MakePush
(
p
.
first
,
p
.
second
));
}
}
void
MatchFixExitPush
(
const
SyncState
&
state
)
{
if
(
state
.
exit_push
.
size
()
==
0
)
return
;
auto
&
vec
=
insert_after_
[
state
.
node
];
for
(
const
std
::
pair
<
int
,
int
>&
p
:
state
.
exit_push
)
{
vec
.
push_back
(
MakePop
(
p
.
first
,
p
.
second
));
}
}
void
UpdateState
()
{
if
(
last_state_
.
node
!=
nullptr
)
{
std
::
vector
<
std
::
pair
<
int
,
int
>
>
t1
,
t2
;
InjectSync
(
last_state_
,
curr_state_
,
&
t1
,
&
t2
);
std
::
swap
(
last_state_
,
curr_state_
);
}
else
{
CHECK
(
first_state_
.
node
==
nullptr
);
first_state_
=
curr_state_
;
last_state_
=
curr_state_
;
}
}
Stmt
MakePush
(
int
from
,
int
to
)
{
return
Evaluate
::
make
(
Call
::
make
(
Int
(
32
),
sync_push_name_
,
{
make_const
(
Int
(
32
),
from
),
make_const
(
Int
(
32
),
to
)},
Call
::
Intrinsic
));
}
Stmt
MakePop
(
int
from
,
int
to
)
{
return
Evaluate
::
make
(
Call
::
make
(
Int
(
32
),
sync_pop_name_
,
{
make_const
(
Int
(
32
),
from
),
make_const
(
Int
(
32
),
to
)},
Call
::
Intrinsic
));
}
// sync states.
SyncState
first_state_
,
last_state_
,
curr_state_
;
// Variables
IterVar
coproc_axis_
;
std
::
string
sync_push_name_
,
sync_pop_name_
;
};
class
CoProcSyncInserter
:
public
IRMutator
{
public
:
Stmt
Insert
(
Stmt
stmt
)
{
...
...
@@ -372,6 +622,18 @@ class CoProcSyncInserter : public IRMutator {
auto
&
vec
=
insert_after_
[
kv
.
first
];
vec
.
insert
(
vec
.
end
(),
kv
.
second
.
begin
(),
kv
.
second
.
end
());
}
// Detect barrier
CoProcInstDepDetector
sync_detector
(
*
visitor
.
coproc_
.
begin
(),
coproc_name
);
sync_detector
.
Plan
(
stmt
);
for
(
const
auto
&
kv
:
sync_detector
.
insert_before_
)
{
auto
&
vec
=
insert_before_
[
kv
.
first
];
vec
.
insert
(
vec
.
end
(),
kv
.
second
.
begin
(),
kv
.
second
.
end
());
}
for
(
const
auto
&
kv
:
sync_detector
.
insert_after_
)
{
auto
&
vec
=
insert_after_
[
kv
.
first
];
vec
.
insert
(
vec
.
end
(),
kv
.
second
.
begin
(),
kv
.
second
.
end
());
}
return
Mutate
(
stmt
);
}
...
...
@@ -379,7 +641,8 @@ class CoProcSyncInserter : public IRMutator {
Stmt
before
,
after
;
auto
it
=
insert_before_
.
find
(
stmt
.
get
());
if
(
it
!=
insert_before_
.
end
())
{
before
=
MergeSeq
(
it
->
second
);
before
=
MergeSeq
(
std
::
vector
<
Stmt
>
(
it
->
second
.
rbegin
(),
it
->
second
.
rend
()));
}
it
=
insert_after_
.
find
(
stmt
.
get
());
if
(
it
!=
insert_after_
.
end
())
{
...
...
@@ -396,10 +659,13 @@ class CoProcSyncInserter : public IRMutator {
}
private
:
// insert before is stored in reverse order
// the first element is closest to the node.
std
::
unordered_map
<
const
Node
*
,
std
::
vector
<
Stmt
>
>
insert_before_
;
std
::
unordered_map
<
const
Node
*
,
std
::
vector
<
Stmt
>
>
insert_after_
;
};
Stmt
CoProcSync
(
Stmt
stmt
)
{
return
CoProcSyncInserter
().
Insert
(
stmt
);
}
...
...
src/pass/storage_rewrite.cc
View file @
f1aabedc
...
...
@@ -189,7 +189,7 @@ class StoragePlanRewriter : public IRMutator {
if
(
attach_map_
.
count
(
nullptr
))
{
std
::
vector
<
Stmt
>
nest
;
for
(
StorageEntry
*
e
:
attach_map_
.
at
(
nullptr
))
{
CHECK_EQ
(
e
->
scope
.
rank
,
0
);
//
CHECK_EQ(e->scope.rank, 0);
if
(
e
->
new_alloc
.
defined
())
{
nest
.
emplace_back
(
AttrStmt
::
make
(
e
->
alloc_var
,
attr
::
storage_scope
,
...
...
@@ -395,6 +395,12 @@ class StoragePlanRewriter : public IRMutator {
e
->
new_alloc
=
Allocate
::
make
(
e
->
alloc_var
,
alloc_type
,
e
->
allocs
[
0
]
->
extents
,
e
->
allocs
[
0
]
->
condition
,
Evaluate
::
make
(
0
));
if
(
e
->
scope
.
tag
.
length
()
!=
0
)
{
MemoryInfo
info
=
GetMemoryInfo
(
e
->
scope
.
to_string
());
uint64_t
total_elem
=
e
->
const_nbits
/
e
->
elem_type
.
bits
();
CHECK_LE
(
total_elem
*
e
->
elem_type
.
bits
(),
info
->
max_num_bits
)
<<
"Allocation exceed bound of memory tag "
<<
e
->
scope
.
to_string
();
}
}
else
{
// Build a merged allocation
Expr
combo_size
;
...
...
src/runtime/thread_storage_scope.h
View file @
f1aabedc
...
...
@@ -71,7 +71,7 @@ struct ThreadScope {
*/
static
ThreadScope
make
(
const
std
::
string
&
s
)
{
ThreadScope
r
;
if
(
s
==
"vthread"
)
{
if
(
s
==
"vthread"
||
s
==
"cthread"
)
{
// virtual thread at the same level as local
r
.
rank
=
1
;
r
.
dim_index
=
-
1
;
...
...
tests/python/unittest/test_pass_storage_sync.py
View file @
f1aabedc
...
...
@@ -58,6 +58,27 @@ def test_coproc_sync():
assert
(
blist
[
-
1
]
.
value
.
args
[
3
]
.
value
==
10
)
def
test_coproc_sync2
():
ib
=
tvm
.
ir_builder
.
create
()
n
=
tvm
.
var
(
"n"
)
cp
=
tvm
.
thread_axis
((
0
,
1
),
"cop"
)
ty
=
tvm
.
thread_axis
(
"cthread"
)
A
=
ib
.
allocate
(
"float32"
,
128
,
name
=
"A"
)
ib
.
scope_attr
(
ty
,
"virtual_thread"
,
2
)
with
ib
.
new_scope
():
ib
.
scope_attr
(
cp
,
"coproc_scope"
,
2
)
A
[
ty
]
=
0.0
with
ib
.
for_range
(
0
,
n
,
name
=
"i"
)
as
i
:
with
ib
.
new_scope
():
ib
.
scope_attr
(
cp
,
"coproc_scope"
,
1
)
A
[
ty
]
=
1.0
with
ib
.
new_scope
():
ib
.
scope_attr
(
cp
,
"coproc_scope"
,
2
)
A
[
ty
]
=
1.0
stmt
=
ib
.
get
()
stmt
=
tvm
.
ir_pass
.
CoProcSync
(
stmt
)
if
__name__
==
"__main__"
:
test_coproc_sync
()
test_storage_sync
()
test_coproc_sync2
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment