Project: wenyuanbo / tic

Commit 3b8ad0a2, authored Apr 17, 2017 by Tianqi Chen, committed by GitHub on Apr 17, 2017
[SCHEDULE] Normalize returns a new schedule (#94)
parent f9d604dd
Showing 12 changed files with 86 additions and 23 deletions:
include/tvm/schedule.h                                   +6  -1
python/tvm/build.py                                      +1  -1
python/tvm/schedule.py                                   +8  -3
src/api/api_lang.cc                                      +1  -1
src/schedule/schedule_dataflow_rewrite.cc                +5  -3
src/schedule/schedule_lang.cc                            +46 -0
tests/python/integration/test_dot.py                     +1  -1
tests/python/integration/test_gemm.py                    +1  -1
tests/python/unittest/test_codegen_device.py             +1  -0
tests/python/unittest/test_schedule_bound_inference.py   +12 -8
tests/python/unittest/test_schedule_schedule_ops.py      +3  -3
tests/verilog/integration/test_codegen_verilog.py        +1  -1
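The user-visible change in this commit is that Schedule.normalize() no longer rewrites the schedule in place; it returns a new, normalized Schedule, so callers must rebind the result. A minimal sketch of the old and new calling conventions, using the circa-2017 Python API that appears in the diffs below (the placeholder/compute definitions are illustrative and not part of the commit):

    import tvm

    n = tvm.var("n")
    A = tvm.placeholder((n,), name="A")
    B = tvm.compute((n,), lambda i: A[i] + 1, name="B")
    s = tvm.create_schedule(B.op)

    # Before this commit: normalize() mutated `s` in place.
    #     s.normalize()
    # After this commit: `s` is left untouched; rebind the returned schedule.
    s = s.normalize()
    bounds = tvm.schedule.InferBound(s)
    stmt = tvm.schedule.ScheduleOps(s, bounds)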
include/tvm/schedule.h

@@ -191,6 +191,11 @@ class Schedule : public NodeRef {
    */
   explicit Schedule(Array<Operation> ops);
+  /*!
+   * \brief Get a copy of current schedule.
+   * \return The copied schedule.
+   */
+  Schedule copy() const;
   /*!
    * \brief Get the stage corresponds to the op
    * \param op The operation.
    */
@@ -257,7 +262,7 @@ class Schedule : public NodeRef {
    *
    * \return A normalized schedule, can be same as current one.
    */
-  void normalize();
+  Schedule normalize();
   /*!
    * \brief access the internal node container
    * \return the pointer to the internal node container
python/tvm/build.py

@@ -57,7 +57,7 @@ def lower(sch,
         else:
             raise ValueError("args must be Tensor, Buffer or Var")
     # normalize schedule first
-    sch.normalize()
+    sch = sch.normalize()
     bounds = schedule.InferBound(sch)
     stmt = schedule.ScheduleOps(sch, bounds)
     stmt = ir_pass.StorageFlatten(stmt, binds)
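Any lowering helper written in the style of lower() above needs the same rebind; keeping the old handle would feed the un-normalized schedule into bound inference. A minimal sketch under that assumption (the helper name my_lower is hypothetical; the passes are the ones already used in this diff):

    import tvm

    def my_lower(sch, binds):
        # Hypothetical helper mirroring the updated python/tvm/build.py lower().
        sch = sch.normalize()                        # rebind: normalize() returns a new schedule
        bounds = tvm.schedule.InferBound(sch)
        stmt = tvm.schedule.ScheduleOps(sch, bounds)
        stmt = tvm.ir_pass.StorageFlatten(stmt, binds)
        return stmt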
python/tvm/schedule.py

@@ -78,12 +78,17 @@ class Schedule(NodeBase):
         return self.stage_map[k]
 
     def normalize(self):
-        """Build a normalized schedule.
+        """Build a normalized schedule from the current schedule.
 
         Insert necessary rebase to make certain iter var to start from 0.
         This is needed before bound inference and followup step.
+
+        Returns
+        -------
+        sch : Schedule
+            The normalized schedule.
         """
-        _api_internal._ScheduleNormalize(self)
+        return _api_internal._ScheduleNormalize(self)
 
     def create_group(self, outputs, inputs, include_inputs=False):
         """Create stage group by giving output and input boundary.
@@ -261,7 +266,7 @@ class Stage(NodeBase):
         threads : list of threads
             The threads to be launched.
         """
-        if isinstance(threads, _collections.IterVar):
+        if isinstance(threads, IterVar):
             threads = [threads]
         _api_internal._StageEnvThreads(self, threads)
src/api/api_lang.cc

@@ -311,7 +311,7 @@ TVM_REGISTER_API(_StageParallel)
 TVM_REGISTER_API(_ScheduleNormalize)
 .set_body([](TVMArgs args, TVMRetValue* ret) {
-    args[0].operator Schedule()
+    *ret = args[0].operator Schedule()
         .normalize();
   });
src/schedule/schedule_dataflow_rewrite.cc

@@ -242,9 +242,11 @@ void InjectInline(ScheduleNode* sch) {
   ReplaceDataFlow(sch->stages, &repl);
 }
 
-void Schedule::normalize() {
-  InjectInline(operator->());
-  RebaseNonZeroMinLoop(*this);
+Schedule Schedule::normalize() {
+  Schedule sn = copy();
+  InjectInline(sn.operator->());
+  RebaseNonZeroMinLoop(sn);
+  return sn;
 }
 
 // Handle reduction factor.
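Since the implementation now copies the schedule before injecting inlines and rebasing non-zero loop minimums, calling normalize() leaves the original schedule intact. A sketch of the observable effect from Python, with an illustrative compute definition that is not part of the commit (the split only gives the schedule something to preserve):

    import tvm

    m = tvm.var("m")
    A = tvm.placeholder((m,), name="A")
    B = tvm.compute((m,), lambda i: A[i] + 1, name="B")

    s = tvm.create_schedule(B.op)
    xo, xi = s[B].split(B.op.axis[0], factor=8)

    _ = s.normalize()                      # rewrites a copy; `s` keeps its stages and relations
    bounds = tvm.schedule.InferBound(s)    # bound inference on the original still succeeds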
src/schedule/schedule_lang.cc

@@ -355,6 +355,52 @@ Schedule::Schedule(Array<Operation> ops) {
   }
 }
 
+Stage CopyStage(const Stage& s) {
+  std::shared_ptr<StageNode> n =
+      std::make_shared<StageNode>(*s.operator->());
+  return Stage(n);
+}
+
+Schedule Schedule::copy() const {
+  // map of stages.
+  const ScheduleNode* self = operator->();
+  std::unordered_map<Stage, Stage, NodeHash, NodeEqual> smap;
+  std::shared_ptr<ScheduleNode> n = std::make_shared<ScheduleNode>();
+  n->outputs = self->outputs;
+  // Copy the stages.
+  for (Stage s : self->stages) {
+    Stage scopy = CopyStage(s);
+    smap[s] = scopy;
+    n->stages.push_back(scopy);
+  }
+  for (Stage g : self->groups) {
+    Stage gcopy = CopyStage(g);
+    smap[g] = gcopy;
+    n->groups.push_back(gcopy);
+  }
+  // Remaps the reference relations.
+  for (auto kv : self->stage_map) {
+    n->stage_map.Set(kv.first, smap.at(kv.second));
+  }
+  for (Stage s : n->stages) {
+    if (s->attach_stage.defined()) {
+      s->attach_stage = smap.at(s->attach_stage);
+    }
+    if (s->group.defined()) {
+      s->group = smap.at(s->group);
+    }
+  }
+  for (Stage s : n->groups) {
+    if (s->attach_stage.defined()) {
+      s->attach_stage = smap.at(s->attach_stage);
+    }
+    if (s->group.defined()) {
+      s->group = smap.at(s->group);
+    }
+  }
+  return Schedule(n);
+}
+
 Stage Schedule::operator[](const Operation& op) {
   auto it = (*this)->stage_map.find(op);
   CHECK(it != (*this)->stage_map.end())
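Schedule::copy() above clones every stage node and then remaps the stage_map, attach_stage, and group references onto the clones, so the copied schedule never points back into the original. A small stand-alone Python sketch of the same copy-and-remap pattern, written with hypothetical classes rather than TVM's (illustration only):

    # Hypothetical classes, illustrating the pattern used by Schedule::copy():
    # clone every node first, then rewrite cross references through a map from
    # original nodes to their clones.
    class Stage:
        def __init__(self, name):
            self.name = name
            self.attach_stage = None   # reference to another Stage, or None
            self.group = None          # reference to a group Stage, or None

    class Schedule:
        def __init__(self):
            self.stages = []
            self.groups = []

    def copy_schedule(sch):
        smap = {}                                  # original Stage -> cloned Stage
        new = Schedule()
        for s in sch.stages + sch.groups:
            c = Stage(s.name)
            c.attach_stage, c.group = s.attach_stage, s.group
            smap[id(s)] = c
        new.stages = [smap[id(s)] for s in sch.stages]
        new.groups = [smap[id(g)] for g in sch.groups]
        # Remap references so the copy points only at copied nodes.
        for c in new.stages + new.groups:
            if c.attach_stage is not None:
                c.attach_stage = smap[id(c.attach_stage)]
            if c.group is not None:
                c.group = smap[id(c.group)]
        return new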
tests/python/integration/test_dot.py

@@ -10,7 +10,7 @@ def lower(s, args, name="mydot"):
         buf = tvm.decl_buffer(x.shape, dtype=x.dtype, name=x.op.name)
         binds[x] = buf
         arg_list.append(buf)
-    s.normalize()
+    s = s.normalize()
     bounds = tvm.schedule.InferBound(s)
     stmt = tvm.schedule.ScheduleOps(s, bounds)
     stmt = tvm.ir_pass.StorageFlatten(stmt, binds)
tests/python/integration/test_gemm.py

@@ -60,7 +60,7 @@ def test_gemm():
     max_auto_unroll_step = 0
 
     # lowering test
-    s.normalize()
+    s = s.normalize()
 
     # one line to build the function.
     def check_device(device, host="stackvm"):
tests/python/unittest/test_codegen_device.py

@@ -16,6 +16,7 @@ def test_add_pipeline():
     s[C].bind(xi, tvm.thread_axis("blockIdx.x"))
 
     # compile to IR
+    s = s.normalize()
     bounds = tvm.schedule.InferBound(s)
     stmt = tvm.schedule.ScheduleOps(s, bounds)
     Ab = tvm.decl_buffer(A.shape, A.dtype, name='A')
tests/python/unittest/test_schedule_bound_inference.py

@@ -22,6 +22,8 @@ def test_bound2():
     A2 = tvm.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2')
     s = tvm.create_schedule(A2.op)
     xo, yo, xi, yi = s[A2].tile(A2.op.axis[0], A2.op.axis[1], 8, 8)
+    # test normalize not affecting schedule
+    _ = s.normalize()
     s[A1].compute_at(s[A2], yo)
     bounds = tvm.schedule.InferBound(s)
     assert isinstance(bounds, tvm.collections.Map)
@@ -41,6 +43,8 @@ def test_bound3():
     xi0, xi1 = s[A2].split(xi, nparts=16)
     s[A2].bind(xi0, tvm.thread_axis("threadIdx.x"))
     yo, yi = s[A2].split(A2.op.axis[1], 16)
+    # test normalize not affecting schedule
+    _ = s.normalize()
     s[A2].reorder(xo, xi0, yo, xi1, yi)
     s[A1].compute_at(s[A2], yo)
@@ -63,7 +67,7 @@ def test_bound_scan():
     XX = s.cache_read(X, "local", s_update)
     xo, xi = s[s_update].split(s_update.op.axis[1], factor=4)
     s[XX].compute_at(s[s_update], xo)
-    s.normalize()
+    s = s.normalize()
     bounds = tvm.schedule.InferBound(s)
     stmt = tvm.schedule.ScheduleOps(s, bounds)
     assert bounds[XX.op.axis[1]].extent.value == 4
@@ -77,7 +81,7 @@ def test_bound_conv1d():
     B = tvm.compute(n, computeB, name='B')
     s = tvm.create_schedule(B.op)
     s[A].compute_at(s[B], B.op.axis[0])
-    s.normalize()
+    s = s.normalize()
     bounds = tvm.schedule.InferBound(s)
     assert(bounds[A.op.axis[0]].extent.value == 3)
@@ -92,7 +96,7 @@ def test_bound_blur():
     B = tvm.compute((n-2, n-2), computeB, name='B')
     s = tvm.create_schedule(B.op)
     s[A].compute_at(s[B], B.op.axis[1])
-    s.normalize()
+    s = s.normalize()
     bounds = tvm.schedule.InferBound(s)
     assert(bounds[A.op.axis[0]].extent.value == 3)
     assert(bounds[A.op.axis[1]].extent.value == 3)
@@ -106,7 +110,7 @@ def test_bound_rfactor():
     s = tvm.create_schedule(B.op)
     kf, ki = s[B].split(k, nparts=4)
     BF = s.rfactor(B, kf)
-    s.normalize()
+    s = s.normalize()
     bounds = tvm.schedule.InferBound(s)
     assert(bounds[BF.op.axis[0]].extent.value == 4)
@@ -123,7 +127,7 @@ def test_bound_group_schedule():
     g.compute_at(s[x2], x2.op.axis[0])
     assert s[x1].group == g
     assert s[x].group == g
-    s.normalize()
+    s = s.normalize()
     bounds = tvm.schedule.InferBound(s)
     assert bounds[x.op.axis[0]].extent.value == 1
     assert bounds[x.op.axis[1]].extent == n
@@ -141,7 +145,7 @@ def test_bound_nest_group():
     assert s[x1].group == g2
     g2.compute_at(s[x2], x2.op.axis[0])
     g1.compute_at(s[x1], s[x1].op.axis[1])
-    s.normalize()
+    s = s.normalize()
     bounds = tvm.schedule.InferBound(s)
     assert bounds[x.op.axis[0]].extent.value == 1
     assert bounds[x.op.axis[1]].extent.value == 1
@@ -169,7 +173,7 @@ def test_bound_nest_thread():
     _, xi = s[A2].split(A2.op.axis[0], nparts=1)
     s[A2].bind(xi, thread_x)
     s[A1].compute_at(s[A3], tx)
-    s.normalize()
+    s = s.normalize()
     bounds = tvm.schedule.InferBound(s)
     assert(bounds[A1.op.axis[0]].extent.value == 1)
     assert(bounds[A2.op.axis[0]].extent.value == 32)
@@ -225,7 +229,7 @@ def test_gemm_bound():
     tx, xi = s[BB].split(xi, nparts=num_thread)
     s[BB].bind(ty, thread_y)
     s[BB].bind(tx, thread_x)
-    s.normalize()
+    s = s.normalize()
     bounds = tvm.schedule.InferBound(s)
     assert(bounds[BB.op.axis[0]].extent.value == 64)
     assert(bounds[AA.op.axis[0]].extent.value == 64)
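The test updates above follow two patterns: tests that keep scheduling or inferring bounds on the original schedule discard the result (_ = s.normalize()) to check that normalization has no side effects, while tests that lower the normalized schedule rebind it (s = s.normalize()). A self-contained sketch of both patterns, with an illustrative one-stage compute that is not part of the commit:

    import tvm

    m = tvm.var("m")
    A = tvm.placeholder((m,), name="A")
    B = tvm.compute((m,), lambda i: A[i] + 1, name="B")
    s = tvm.create_schedule(B.op)

    # Pattern 1: exercise normalize(), then keep working on the original schedule.
    _ = s.normalize()
    bounds = tvm.schedule.InferBound(s)

    # Pattern 2: rebind and lower the normalized schedule from here on.
    s = s.normalize()
    bounds = tvm.schedule.InferBound(s)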
tests/python/unittest/test_schedule_schedule_ops.py

@@ -51,7 +51,7 @@ def test_schedule_scan():
     assert tuple(res.shape) == (m, n)
     s = tvm.create_schedule(res.op)
-    s.normalize()
+    s = s.normalize()
     bounds = tvm.schedule.InferBound(s)
     assert(bounds[res.op.scan_axis].min.value == 1)
     stmt = tvm.schedule.ScheduleOps(s, bounds)
@@ -68,7 +68,7 @@ def test_auto_inline():
     s = tvm.create_schedule(T2.op)
     tvm.schedule.AutoInlineElemWise(s)
-    s.normalize()
+    s = s.normalize()
     bounds = tvm.schedule.InferBound(s)
     stmt = tvm.schedule.ScheduleOps(s, bounds)
@@ -83,7 +83,7 @@ def test_inline_mixed():
     xo, xi = s[C].split(C.op.axis[0], factor=8)
     s[A1].compute_at(s[C], xo)
     s[A2].compute_inline()
-    s.normalize()
+    s = s.normalize()
     bounds = tvm.schedule.InferBound(s)
     stmt = tvm.schedule.ScheduleOps(s, bounds)
     print(stmt)
tests/verilog/integration/test_codegen_verilog.py

@@ -11,7 +11,7 @@ def lower(s, args, name):
         buf = tvm.decl_buffer(x.shape, dtype=x.dtype, name=x.op.name)
         binds[x] = buf
         arg_list.append(buf)
-    s.normalize()
+    s = s.normalize()
     bounds = tvm.schedule.InferBound(s)
     stmt = tvm.schedule.ScheduleOps(s, bounds)
     stmt = tvm.ir_pass.StorageFlatten(stmt, binds)