wenyuanbo / tic / Commits

Commit 9ec40edd, authored Apr 10, 2017 by Tianqi Chen, committed by GitHub on Apr 10, 2017
[SCHEDULE] Fix cross thread schedule after refactor (#85)
Parent: ea9c1c59
Showing 5 changed files with 52 additions and 35 deletions:
src/op/compute_op.cc                        +20 −22
src/pass/lower_thread_allreduce.cc          +12 −8
src/schedule/bound.cc                       +5  −1
src/schedule/schedule_dataflow_rewrite.cc   +10 −0
tests/python/integration/test_reduce.py     +5  −4
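For context: a cross-thread reduction is a schedule in which a stage's reduce axis is bound to a GPU thread index, so the threads of a block cooperatively produce one output value through an all-reduce. Below is a minimal sketch of the pattern this commit repairs, condensed from tests/python/integration/test_reduce.py further down; the sizes and tensor names are illustrative, and the calls use the April 2017 Python API that the test itself exercises.

import tvm

nthread = 16
n = 1027
A = tvm.placeholder((n,), name="A")
k = tvm.reduce_axis((0, n), name="k")
B = tvm.compute((1,), lambda i: tvm.sum(A[k], axis=k), name="B")

s = tvm.Schedule(B.op)
ko, kf = s[B].split(k, factor=nthread)
BF = s.rfactor(B, kf)               # factor out per-thread partial sums
# Binding the remaining reduce axis to threadIdx.x requests a cross-thread
# reduction; the detection logic patched below keys off this binding.
s[B].bind(s[B].op.reduce_axis[0], tvm.thread_axis("threadIdx.x"))
s[BF].compute_at(s[B], s[B].op.reduce_axis[0])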
src/op/compute_op.cc
@@ -194,29 +194,25 @@ Stmt Substitute(Stmt s,
 // Cross Thread reduction marker.
 bool IsCrossThreadReduction(const ComputeOpNode* self,
                             const Stage& stage) {
-  std::unordered_set<IterVar> rebase_thread;
-  for (IterVarRelation rel : stage->relations) {
-    if (const RebaseNode* s = rel.as<RebaseNode>()) {
-      if (s->parent->iter_type == kCommReduce &&
-          s->rebased->iter_type == kThreadIndex) {
-        rebase_thread.insert(s->rebased);
-      }
-    }
-  }
-  if (rebase_thread.size() == 0) return false;
   // Verify correctness of leaf nest.
-  bool reduce_start = false;
+  int normal_red = 0, thread_red = 0;
   for (IterVar iv : stage->leaf_iter_vars) {
     if (iv->iter_type == kCommReduce) {
-      LOG(FATAL) << "Cannot mix cross thread reduce with normal reduce";
-    } else if (rebase_thread.count(iv)) {
-      reduce_start = true;
+      auto it = stage->iter_var_attrs.find(iv);
+      if (it != stage->iter_var_attrs.end() &&
+          (*it).second->bind_thread.defined()) {
+        ++thread_red;
+      } else {
+        ++normal_red;
+      }
     } else {
-      CHECK(!reduce_start)
+      CHECK_EQ(thread_red, 0)
           << "Cross thread reduce cannot swap with normal data axis";
     }
   }
-  return true;
+  CHECK(normal_red == 0 || thread_red == 0)
+      << "Cannot mix normal reduction with thread reduce";
+  return thread_red != 0;
 }

 Stmt MakeCrossThreadReduction(
@@ -246,12 +242,14 @@ Stmt MakeCrossThreadReduction(
   freduce_args.push_back(cond);
   std::vector<Expr> thread_head_check;
-  for (IterVarRelation rel : stage->relations) {
-    if (const RebaseNode* s = rel.as<RebaseNode>()) {
-      if (s->parent->iter_type == kCommReduce &&
-          s->rebased->iter_type == kThreadIndex) {
-        freduce_args.push_back(s->rebased->var);
-        thread_head_check.push_back(s->rebased->var == 0);
+  for (IterVar iv : stage->leaf_iter_vars) {
+    if (iv->iter_type == kCommReduce) {
+      auto it = stage->iter_var_attrs.find(iv);
+      if (it != stage->iter_var_attrs.end() &&
+          (*it).second->bind_thread.defined()) {
+        IterVar tv = (*it).second->bind_thread;
+        freduce_args.push_back(tv->var);
+        thread_head_check.push_back(tv->var == 0);
       }
     }
   }
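After this change, IsCrossThreadReduction keys off the bind_thread attribute of each leaf reduce axis rather than rebase relations, and the new CHECKs reject stages whose reduce leaves are only partially thread-bound. A hypothetical schedule that the added check would now reject at lowering time, reusing A, k, and B from the sketch near the top:

# Hypothetical illustration of the new error path, not part of the commit.
s2 = tvm.Schedule(B.op)
ko, ki = s2[B].split(k, factor=4)
s2[B].bind(ko, tvm.thread_axis("threadIdx.x"))
# ko is thread-bound while ki stays a serial reduce leaf, so lowering hits
# the new CHECK: "Cannot mix normal reduction with thread reduce".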
src/pass/lower_thread_allreduce.cc
@@ -99,13 +99,14 @@ class ThreadAllreduceBuilder : public IRMutator {
           cond, value, Reduce::InitValue(op_code, value.type()));
     }
-    std::unordered_set<const Variable*> reduce_index_;
+    std::unordered_set<const Variable*> reduce_set;
     for (size_t i = 3; i < call->args.size(); ++i) {
       const Variable* v = call->args[i].as<Variable>();
       CHECK(v);
-      reduce_index_.insert(v);
+      reduce_set.insert(v);
     }
     size_t nmatch = 0;
+    std::unordered_set<const Variable*> visited;
     std::vector<ThreadEntry> vred, vpar;
     for (const AttrStmt* attr : thread_extents_) {
       ThreadEntry e;
@@ -118,15 +119,18 @@ class ThreadAllreduceBuilder : public IRMutator {
       CHECK_GE(e.scope.dim_index, 0)
           << "vthread do not work with cross thread reduction";
       if (e.scope.rank == 1) {
-        if (reduce_index_.count(iv->var.get())) {
-          vred.push_back(e);
-          ++nmatch;
-        } else {
-          vpar.push_back(e);
+        if (!visited.count(iv->var.get())) {
+          visited.insert(iv->var.get());
+          if (reduce_set.count(iv->var.get())) {
+            vred.push_back(e);
+            ++nmatch;
+          } else {
+            vpar.push_back(e);
+          }
         }
       }
     }
-    CHECK_EQ(nmatch, reduce_index_.size())
+    CHECK_EQ(nmatch, reduce_set.size())
         << "Not all reduce index are presented in the context";
     std::sort(vred.begin(), vred.end());
     std::sort(vpar.begin(), vpar.end());
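The visited set added here prevents classifying the same thread variable twice when it appears in thread_extents_ more than once, for example when the rfactor stage is computed inside the same threadIdx launch scope as the final stage. A toy Python model of the dedup logic (not TVM API):

# Toy model: duplicate launch scopes for threadIdx.x must not double-count.
reduce_set = {"threadIdx.x"}
thread_extents = ["threadIdx.x", "threadIdx.y", "threadIdx.x"]

vred, vpar, visited = [], [], set()
for v in thread_extents:
    if v in visited:
        continue
    visited.add(v)
    (vred if v in reduce_set else vpar).append(v)

assert vred == ["threadIdx.x"] and vpar == ["threadIdx.y"]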
src/schedule/bound.cc
@@ -128,7 +128,11 @@ void InferRootBound(const Stage& stage,
       CHECK(is_zero(vrange->min))
           << "InferBound requires every leaf iter var's min equals 0, "
           << " call schedule.normalize to achieve this. ";
-      up_state[iv] = IntSet::single_point(iv->var);
+      if (ctx.bind_map.count(iv)) {
+        up_state[iv] = IntSet::single_point(ctx.bind_map.at(iv)->var);
+      } else {
+        up_state[iv] = IntSet::single_point(iv->var);
+      }
     } else {
       up_state[iv] = IntSet::range(vrange);
     }
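This hunk makes InferRootBound treat a leaf iter var that is bound to a thread as a single point at the bound thread's variable rather than at its own var, which the bind-heavy schedules above rely on. A sketch of the user-level path that exercises it, continuing the schedule from the sketch near the top; the assumption is that bound inference was exposed to Python as tvm.schedule.InferBound at this time:

s = s.normalize()                     # leaf iter var mins must be 0, per the CHECK above
bounds = tvm.schedule.InferBound(s)   # assumed entry point; walks InferRootBound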
src/schedule/schedule_dataflow_rewrite.cc
@@ -161,6 +161,12 @@ void RebaseNonZeroMinLoop(const Schedule& sch) {
     ArrayNode* leaf_vars = s->leaf_iter_vars.CopyOnWrite();
     for (IterVar iv : root_iter_vars) {
       size_t idx = FindNodeRef(leaf_vars, iv);
+      auto it = s->iter_var_attrs.find(iv);
+      // don't need to rebase paths that are bound to threads.
+      if (it != s->iter_var_attrs.end() &&
+          (*it).second->bind_thread.defined()) {
+        continue;
+      }
       if (idx < leaf_vars->data.size()) {
         // insert rebase
         IterVar rebased = IterVarNode::make(
@@ -364,6 +370,10 @@ Tensor Schedule::rfactor(const Tensor& tensor,
   stages->data.insert(stages->data.begin() + stage_pos,
                       factor_stage.node_);
   (*this)->stage_map.Set(factor_op, factor_stage);
+  factor_stage->group = reduce_stage->group;
+  if (factor_stage->group.defined()) {
+    ++factor_stage->group->num_child_stages;
+  }
   // Replace the old reduction.
   IterVar repl_red_axis = reduce_axis(
       dom_map.at(axis), axis->var->name_hint + ".v");
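The rfactor hunk makes the newly created factor stage inherit the original reduction stage's group and bumps that group's child-stage count, keeping grouped schedules consistent. In user terms, continuing the sketch near the top:

# With this fix, if the stage of B belongs to a schedule group, the stage
# created by this call for BF joins the same group, and the group's
# num_child_stages counter is incremented at creation time.
BF = s.rfactor(B, kf)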
tests/python/integration/test_reduce.py
@@ -90,10 +90,11 @@ def test_rfactor_threads():
     s = tvm.Schedule(B.op)
     ko, kf = s[B].split(k, factor=nthread)
     BF = s.rfactor(B, kf)
-    bx, tx = s[B].split(s[B].op.axis[0], factor=nthread)
+    bx, ty = s[B].split(s[B].op.axis[0], factor=nthread)
     s[B].bind(bx, tvm.thread_axis("blockIdx.x"))
-    s[B].bind(tx, tvm.thread_axis("threadIdx.y"))
-    s[B].bind(s[B].op.reduce_axis[0], tvm.thread_axis("threadIdx.x"))
+    s[B].bind(ty, tvm.thread_axis("threadIdx.y"))
+    tx = s[B].op.reduce_axis[0]
+    s[B].bind(tx, tvm.thread_axis("threadIdx.x"))
     s[BF].compute_at(s[B], tx)
     # one line to build the function.
@@ -124,6 +125,6 @@ def test_rfactor_threads():
     check_target("opencl")

 if __name__ == "__main__":
-    test_rfactor()
     test_rfactor_threads()
+    test_rfactor()
     test_sum()
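For reference, the "# one line to build the function." context above is followed in the test by a build-and-verify step roughly like this sketch; the tvm.build and tvm.nd spelling reflects the early 2017 API and may differ in detail:

import numpy as np

# Sketch of the usual check_target body in these integration tests.
fsum = tvm.build(s, args=[A, B], target="cuda", name="mysum")
ctx = tvm.gpu(0)
a = tvm.nd.array(np.random.uniform(size=(1027,)).astype(A.dtype), ctx)
b = tvm.nd.array(np.zeros(1, dtype=B.dtype), ctx)
fsum(a, b)
np.testing.assert_allclose(b.asnumpy(), np.sum(a.asnumpy()), rtol=1e-4)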