Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
a0062582
Commit
a0062582
authored
Dec 24, 2018
by
eqy
Committed by
Tianqi Chen
Dec 24, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[RELAY][AUTOTVM] Extract tuning tasks from Relay programs (#2181)
parent
3cf910c8
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
477 additions
and
207 deletions
+477
-207
python/tvm/autotvm/task/__init__.py
+1
-0
python/tvm/autotvm/task/nnvm_integration.py
+29
-202
python/tvm/autotvm/task/relay_integration.py
+200
-0
python/tvm/autotvm/task/topi_integration.py
+189
-3
tests/python/relay/test_autotvm_task_extraction.py
+56
-0
topi/python/topi/x86/conv2d.py
+1
-1
topi/python/topi/x86/depthwise_conv2d.py
+1
-1
No files found.
python/tvm/autotvm/task/__init__.py
View file @
a0062582
...
@@ -14,3 +14,4 @@ from .dispatcher import dispatcher, DispatchContext, ApplyConfig, ApplyHistoryBe
...
@@ -14,3 +14,4 @@ from .dispatcher import dispatcher, DispatchContext, ApplyConfig, ApplyHistoryBe
from
.topi_integration
import
register_topi_compute
,
register_topi_schedule
from
.topi_integration
import
register_topi_compute
,
register_topi_schedule
from
.nnvm_integration
import
extract_from_graph
,
extract_from_multiple_graph
from
.nnvm_integration
import
extract_from_graph
,
extract_from_multiple_graph
from
.relay_integration
import
extract_from_program
,
extract_from_multiple_program
python/tvm/autotvm/task/nnvm_integration.py
View file @
a0062582
...
@@ -7,208 +7,13 @@ import warnings
...
@@ -7,208 +7,13 @@ import warnings
import
logging
import
logging
from
...
import
t
ensor
,
placeholder
,
create_schedule
,
t
arget
as
_target
from
...
import
target
as
_target
from
.
.util
import
get_const_tupl
e
from
.
task
import
creat
e
from
.t
ask
import
create
,
register
from
.t
opi_integration
import
TaskExtractEnv
logger
=
logging
.
getLogger
(
'autotvm'
)
logger
=
logging
.
getLogger
(
'autotvm'
)
def
serialize_args
(
args
):
"""serialize arguments of a topi function to a hashable tuple.
Parameters
----------
args: list of hashable or Tensor
"""
ret
=
[]
for
t
in
args
:
if
isinstance
(
t
,
tensor
.
Tensor
):
ret
.
append
((
'TENSOR'
,
get_const_tuple
(
t
.
shape
),
t
.
dtype
))
else
:
ret
.
append
(
t
)
return
tuple
(
ret
)
def
deserialize_args
(
args
):
"""The inverse function of :code:`serialize_args`.
Parameters
----------
args: list of hashable or Tensor
"""
ret
=
[]
for
t
in
args
:
if
isinstance
(
t
,
tuple
)
and
t
[
0
]
==
'TENSOR'
:
ret
.
append
(
placeholder
(
shape
=
t
[
1
],
dtype
=
t
[
2
]))
else
:
ret
.
append
(
t
)
return
ret
# Task extractor for nnvm graph
class
TaskExtractEnv
:
"""Global environment for extracting tuning tasks from nnvm graph"""
current
=
None
def
__init__
(
self
):
import
topi
import
nnvm
# NOTE: To add more symbols, you only need to change the following lists
# nnvm symbol -> topi compute
self
.
symbol2topi
=
{
nnvm
.
sym
.
conv2d
:
[
topi
.
nn
.
conv2d
,
topi
.
nn
.
depthwise_conv2d_nchw
,
topi
.
nn
.
group_conv2d_nchw
],
nnvm
.
sym
.
conv2d_transpose
:
[
topi
.
nn
.
conv2d_transpose_nchw
],
nnvm
.
sym
.
dense
:
[
topi
.
nn
.
dense
],
}
# topi compute -> autotvm task name
self
.
topi_to_task
=
{
topi
.
nn
.
conv2d
:
"topi_nn_conv2d"
,
topi
.
nn
.
depthwise_conv2d_nchw
:
"topi_nn_depthwise_conv2d_nchw"
,
topi
.
nn
.
group_conv2d_nchw
:
"topi_nn_group_conv2d_nchw"
,
topi
.
nn
.
conv2d_transpose_nchw
:
"topi_nn_conv2d_transpose_nchw"
,
topi
.
nn
.
dense
:
"topi_nn_dense"
,
}
self
.
topi_to_schedule
=
{
topi
.
nn
.
conv2d
:
[
topi
.
generic
.
schedule_conv2d_nchw
,
topi
.
generic
.
schedule_conv2d_nhwc
],
topi
.
nn
.
depthwise_conv2d_nchw
:
[
topi
.
generic
.
schedule_depthwise_conv2d_nchw
,
topi
.
generic
.
schedule_depthwise_conv2d_nhwc
],
topi
.
nn
.
group_conv2d_nchw
:
[
topi
.
generic
.
schedule_group_conv2d_nchw
],
topi
.
nn
.
conv2d_transpose_nchw
:
[
topi
.
generic
.
schedule_conv2d_transpose_nchw
],
topi
.
nn
.
dense
:
[
topi
.
generic
.
schedule_dense
],
}
self
.
_register_tracing
()
self
.
_register_topi_task
()
self
.
task_collection
=
[]
self
.
wanted_topi_funcs
=
list
(
self
.
topi_to_task
.
keys
())
def
_register_tracing
(
self
):
"""Register tracing function to track the topi function call"""
# register topi compute for "tracing" target
for
topi_compute
in
self
.
topi_to_task
:
def
_local_scope
(
compute_func
):
"""start a scope to hold the local function in for loop"""
@compute_func.register
(
"tracing"
,
)
def
_tracing_topi_compute
(
*
args
,
**
kwargs
):
assert
not
kwargs
,
"Do not support extracting tuning tasks when"
\
"kwargs is used in TOPI function call."
\
"Please modify it to use only positional args."
if
compute_func
in
self
.
wanted_topi_funcs
:
# record this call
key
=
(
self
.
topi_to_task
[
compute_func
],
serialize_args
(
args
))
if
key
not
in
self
.
task_collection
:
self
.
task_collection
.
append
(
key
)
return
compute_func
.
fdefault
(
*
args
)
_local_scope
(
topi_compute
)
# register topi schedule for "tracing" target
for
topi_compute
in
self
.
topi_to_task
:
for
topi_schedule
in
self
.
topi_to_schedule
[
topi_compute
]:
def
_local_scope_
(
schedule_func
):
"""start a scope to hold the local function in for loop"""
@schedule_func.register
(
"tracing"
,
)
def
_tracing_topi_compute
(
outs
):
outs
=
[
outs
]
if
isinstance
(
outs
,
tensor
.
Tensor
)
else
outs
return
create_schedule
([
x
.
op
for
x
in
outs
])
_local_scope_
(
topi_schedule
)
def
_register_topi_task
(
self
):
"""register tuning wrapper for topi function"""
import
topi
# Tuning wrapper for topi functions
@register
(
"topi_nn_conv2d"
)
def
_topi_nn_conv2d
(
*
args
,
**
kwargs
):
assert
not
kwargs
,
"Do not support kwargs in template function call"
args
=
deserialize_args
(
args
)
A
,
W
=
args
[:
2
]
layout
=
args
[
-
2
]
assert
layout
==
'NCHW'
,
"only support NCHW currently"
C
=
topi
.
nn
.
conv2d
(
*
args
,
**
kwargs
)
s
=
topi
.
generic
.
schedule_conv2d_nchw
([
C
])
return
s
,
[
A
,
W
,
C
]
@register
(
"topi_nn_depthwise_conv2d_nchw"
)
def
_topi_nn_depthwise_conv2d_nchw
(
*
args
,
**
kwargs
):
assert
not
kwargs
,
"Do not support kwargs in template function call"
args
=
deserialize_args
(
args
)
A
,
W
=
args
[:
2
]
C
=
topi
.
nn
.
depthwise_conv2d_nchw
(
*
args
,
**
kwargs
)
s
=
topi
.
generic
.
schedule_depthwise_conv2d_nchw
([
C
])
return
s
,
[
A
,
W
,
C
]
@register
(
"topi_nn_group_conv2d_nchw"
)
def
_topi_nn_group_conv2d_nchw
(
*
args
,
**
kwargs
):
assert
not
kwargs
,
"Do not support kwargs in template function call"
args
=
deserialize_args
(
args
)
A
,
W
=
args
[:
2
]
C
=
topi
.
nn
.
group_conv2d_nchw
(
*
args
,
**
kwargs
)
s
=
topi
.
generic
.
schedule_group_conv2d_nchw
([
C
])
return
s
,
[
A
,
W
,
C
]
@register
(
"topi_nn_conv2d_transpose_nchw"
)
def
_topi_nn_conv2d_transpose_nchw
(
*
args
,
**
kwargs
):
assert
not
kwargs
,
"Do not support kwargs in template function call"
args
=
deserialize_args
(
args
)
A
,
W
=
args
[:
2
]
C
=
topi
.
nn
.
conv2d_transpose_nchw
(
*
args
,
**
kwargs
)
s
=
topi
.
generic
.
schedule_conv2d_transpose_nchw
([
C
])
return
s
,
[
A
,
W
,
C
]
@register
(
"topi_nn_dense"
)
def
_topi_nn_dense
(
*
args
,
**
kwargs
):
assert
not
kwargs
,
"Do not support kwargs in template function call"
args
=
deserialize_args
(
args
)
data
,
weight
,
bias
=
args
C
=
topi
.
nn
.
dense
(
*
args
,
**
kwargs
)
s
=
topi
.
generic
.
schedule_dense
([
C
])
if
bias
is
not
None
:
return
s
,
[
data
,
weight
,
bias
,
C
]
return
s
,
[
data
,
weight
,
C
]
def
reset
(
self
,
wanted_topi_funcs
):
"""Reset task collections
Parameters
----------
wanted_topi_funcs: List of function
The topi function to be extracted
"""
self
.
task_collection
=
[]
self
.
wanted_topi_funcs
=
wanted_topi_funcs
def
get_tasks
(
self
):
"""Get collected tasks
Returns
-------
tasks: List of tuple(name, args)
A list of tasks extracted from the nnvm graph
"""
return
self
.
task_collection
@staticmethod
def
get
():
"""Get the single instance of TaskExtractEnv
Returns
-------
env: TaskExtractEnv
The single instance of TaskExtractEnv
"""
if
not
TaskExtractEnv
.
current
:
TaskExtractEnv
.
current
=
TaskExtractEnv
()
return
TaskExtractEnv
.
current
def
extract_from_graph
(
graph
,
shape
,
dtype
,
target
,
symbols
,
target_host
=
None
):
def
extract_from_graph
(
graph
,
shape
,
dtype
,
target
,
symbols
,
target_host
=
None
):
""" Extract tuning tasks from a nnvm graph.
""" Extract tuning tasks from a nnvm graph.
...
@@ -237,13 +42,24 @@ def extract_from_graph(graph, shape, dtype, target, symbols, target_host=None):
...
@@ -237,13 +42,24 @@ def extract_from_graph(graph, shape, dtype, target, symbols, target_host=None):
collected tasks
collected tasks
"""
"""
import
nnvm.compiler
import
nnvm.compiler
import
nnvm
import
topi
env
=
TaskExtractEnv
.
get
()
env
=
TaskExtractEnv
.
get
()
#NOTE: To add more symbols, you only need to change the following lists
#nnvm symbol -> topi compute
SYMBOL2TOPI
=
{
nnvm
.
sym
.
conv2d
:
[
topi
.
nn
.
conv2d
,
topi
.
nn
.
depthwise_conv2d_nchw
,
topi
.
nn
.
group_conv2d_nchw
],
nnvm
.
sym
.
conv2d_transpose
:
[
topi
.
nn
.
conv2d_transpose_nchw
],
nnvm
.
sym
.
dense
:
[
topi
.
nn
.
dense
],
}
topi_funcs
=
[]
topi_funcs
=
[]
for
sym_name
in
symbols
:
for
sym_name
in
symbols
:
if
sym_name
in
env
.
symbol2topi
:
if
sym_name
in
SYMBOL2TOPI
:
topi_funcs
.
extend
(
env
.
symbol2topi
[
sym_name
])
topi_funcs
.
extend
(
SYMBOL2TOPI
[
sym_name
])
else
:
else
:
warnings
.
warn
(
"Symbol
%
s is not tunable, ignored"
%
sym_name
)
warnings
.
warn
(
"Symbol
%
s is not tunable, ignored"
%
sym_name
)
...
@@ -297,13 +113,24 @@ def extract_from_multiple_graph(graphs, shapes, dtypes, target, symbols, target_
...
@@ -297,13 +113,24 @@ def extract_from_multiple_graph(graphs, shapes, dtypes, target, symbols, target_
collected tasks
collected tasks
"""
"""
import
nnvm.compiler
import
nnvm.compiler
import
nnvm
import
topi
env
=
TaskExtractEnv
.
get
()
env
=
TaskExtractEnv
.
get
()
#NOTE: To add more symbols, you only need to change the following lists
#nnvm symbol -> topi compute
SYMBOL2TOPI
=
{
nnvm
.
sym
.
conv2d
:
[
topi
.
nn
.
conv2d
,
topi
.
nn
.
depthwise_conv2d_nchw
,
topi
.
nn
.
group_conv2d_nchw
],
nnvm
.
sym
.
conv2d_transpose
:
[
topi
.
nn
.
conv2d_transpose_nchw
],
nnvm
.
sym
.
dense
:
[
topi
.
nn
.
dense
],
}
topi_funcs
=
[]
topi_funcs
=
[]
for
sym_name
in
symbols
:
for
sym_name
in
symbols
:
if
sym_name
in
env
.
symbol2topi
:
if
sym_name
in
SYMBOL2TOPI
:
topi_funcs
.
extend
(
env
.
symbol2topi
[
sym_name
])
topi_funcs
.
extend
(
SYMBOL2TOPI
[
sym_name
])
else
:
else
:
warnings
.
warn
(
"Symbol
%
s is not tunable, ignored"
%
sym_name
)
warnings
.
warn
(
"Symbol
%
s is not tunable, ignored"
%
sym_name
)
...
...
python/tvm/autotvm/task/relay_integration.py
0 → 100644
View file @
a0062582
# pylint: disable=unused-variable,invalid-name
"""
Decorator and utilities for the integration with TOPI and Relay
99.9
%
copy-paste of implementation by @MerryMercy
"""
import
threading
import
warnings
import
logging
from
...
import
tensor
,
placeholder
,
target
as
_target
from
.task
import
create
from
.topi_integration
import
TaskExtractEnv
logger
=
logging
.
getLogger
(
'autotvm'
)
def
serialize_args
(
args
):
"""serialize arguments of a topi function to a hashable tuple.
Parameters
----------
args: list of hashable or Tensor
"""
ret
=
[]
for
t
in
args
:
if
isinstance
(
t
,
tensor
.
Tensor
):
ret
.
append
((
'TENSOR'
,
get_const_tuple
(
t
.
shape
),
t
.
dtype
))
else
:
ret
.
append
(
t
)
return
tuple
(
ret
)
def
deserialize_args
(
args
):
"""The inverse function of :code:`serialize_args`.
Parameters
----------
args: list of hashable or Tensor
"""
ret
=
[]
for
t
in
args
:
if
isinstance
(
t
,
tuple
)
and
t
[
0
]
==
'TENSOR'
:
ret
.
append
(
placeholder
(
shape
=
t
[
1
],
dtype
=
t
[
2
]))
else
:
ret
.
append
(
t
)
return
ret
def
extract_from_program
(
func
,
params
,
ops
,
target
,
target_host
=
None
):
""" Extract tuning tasks from a relay program.
This function collects tuning tasks by building the program
with a "tracing" target and tracing all the calls to topi.
Parameters
----------
func: relay.expr.Function
The func to tune
params: dict of str to numpy array
The associated parameters of the program
ops: List of relay op
List of relay ops to be tuned
dtype: str or dict of str to str
The input types to the program
target: tvm.target.Target
The compilation target
target_host: tvm.target.Target
The host compilation target
Returns
-------
task: Array of autotvm.task.Task
collected tasks
"""
env
=
TaskExtractEnv
.
get
()
import
tvm.relay.op
from
tvm
import
relay
import
topi
# NOTE: To add more ops, you only need to change the following lists
# relay op -> topi compute
OP2TOPI
=
{
tvm
.
relay
.
op
.
nn
.
conv2d
:
[
topi
.
nn
.
conv2d
,
topi
.
nn
.
depthwise_conv2d_nchw
,
topi
.
nn
.
group_conv2d_nchw
],
tvm
.
relay
.
op
.
nn
.
conv2d_transpose
:
[
topi
.
nn
.
conv2d_transpose_nchw
],
tvm
.
relay
.
op
.
nn
.
dense
:
[
topi
.
nn
.
dense
],
}
topi_funcs
=
[]
for
op_name
in
ops
:
if
op_name
in
OP2TOPI
:
topi_funcs
.
extend
(
OP2TOPI
[
op_name
])
else
:
warnings
.
warn
(
"Op
%
s is not tunable, ignored"
%
op_name
)
# run compiler to collect all TOPI calls during compilation
env
.
reset
(
topi_funcs
)
# disable logger temporarily
old_state
=
logger
.
disabled
logger
.
disabled
=
True
# use a "tracing" target to do a fake compile for collecting topi calls
tracing_target
=
_target
.
create
(
"llvm -device=tracing"
)
relay
.
backend
.
compile_engine
.
get
()
.
clear
()
# wrap build call in thread to avoid multiprocessing problems
build_thread
=
threading
.
Thread
(
target
=
relay
.
build
,
args
=
(
func
,
tracing_target
,
target_host
,
params
))
build_thread
.
start
()
build_thread
.
join
()
logger
.
disabled
=
old_state
# create tasks for target
tasks
=
[]
for
task_name
,
args
in
env
.
get_tasks
():
tasks
.
append
(
create
(
task_name
,
args
,
target
=
target
,
target_host
=
target_host
,
template_key
=
'direct'
))
return
tasks
def
extract_from_multiple_program
(
funcs
,
params
,
ops
,
target
,
target_host
=
None
):
""" Extract tuning tasks from multiple relay programs.
This function is the multiple program version of extract_from_program
Parameters
----------
funcs: List of relay.expr.Function
The list of functions to tune
params: List of dict of str to numpy array
The associated parameters of the programs
ops: List of relay op
List of relay ops to be tuned
target: tvm.target.Target
The compilation target
target_host: tvm.target.Target
The host compilation target
Returns
-------
task: Array of autotvm.task.Task
collected tasks
"""
env
=
TaskExtractEnv
.
get
()
import
tvm.relay.op
from
tvm
import
relay
import
topi
# NOTE: To add more ops, you only need to change the following lists
# relay op -> topi compute
OP2TOPI
=
{
tvm
.
relay
.
op
.
nn
.
conv2d
:
[
topi
.
nn
.
conv2d
,
topi
.
nn
.
depthwise_conv2d_nchw
,
topi
.
nn
.
group_conv2d_nchw
],
tvm
.
relay
.
op
.
nn
.
conv2d_transpose
:
[
topi
.
nn
.
conv2d_transpose_nchw
],
tvm
.
relay
.
op
.
nn
.
dense
:
[
topi
.
nn
.
dense
],
}
topi_funcs
=
[]
for
op_name
in
ops
:
if
op_name
in
OP2TOPI
:
topi_funcs
.
extend
(
OP2TOPI
[
op_name
])
else
:
warnings
.
warn
(
"Op
%
s is not tunable, ignored"
%
op_name
)
# run compiler to collect all TOPI calls during compilation
env
.
reset
(
topi_funcs
)
# disable logger temporarily
old_state
=
logger
.
disabled
logger
.
disabled
=
True
# use a "tracing" target to do a fake compile for collecting topi calls
tracing_target
=
_target
.
create
(
"llvm -device=tracing"
)
for
func
,
param
in
zip
(
funcs
,
params
):
# wrap build call in thread to avoid multiprocessing problems
build_thread
=
threading
.
Thread
(
target
=
relay
.
build
,
args
=
(
func
,
tracing_target
,
target_host
,
params
))
build_thread
.
start
()
build_thread
.
join
()
logger
.
disabled
=
old_state
# create tasks for target
tasks
=
[]
for
task_name
,
args
in
env
.
get_tasks
():
tasks
.
append
(
create
(
task_name
,
args
,
target
=
target
,
target_host
=
target_host
,
template_key
=
'direct'
))
return
tasks
python/tvm/autotvm/task/topi_integration.py
View file @
a0062582
...
@@ -11,16 +11,202 @@ tuple.
...
@@ -11,16 +11,202 @@ tuple.
See tvm/topi/python/topi/arm_cpu/depthwise_conv2d.py for example usage.
See tvm/topi/python/topi/arm_cpu/depthwise_conv2d.py for example usage.
"""
"""
from
...
import
_api_internal
,
tensor
from
...
import
_api_internal
,
tensor
,
placeholder
,
create_schedule
from
.task
import
args_to_workload
,
dispatcher
from
.task
import
args_to_workload
,
dispatcher
,
register
from
..util
import
get_const_tuple
# A table that records all registered dispatcher for all targets
# A table that records all registered dispatcher for all targets
_REGISTED_DISPATHCER
=
{
_REGISTED_DISPATHCER
=
{
}
}
def
serialize_args
(
args
):
"""serialize arguments of a topi function to a hashable tuple.
Parameters
----------
args: list of hashable or Tensor
"""
ret
=
[]
for
t
in
args
:
if
isinstance
(
t
,
tensor
.
Tensor
):
ret
.
append
((
'TENSOR'
,
get_const_tuple
(
t
.
shape
),
t
.
dtype
))
else
:
ret
.
append
(
t
)
return
tuple
(
ret
)
def
deserialize_args
(
args
):
"""The inverse function of :code:`serialize_args`.
Parameters
----------
args: list of hashable or Tensor
"""
ret
=
[]
for
t
in
args
:
if
isinstance
(
t
,
tuple
)
and
t
[
0
]
==
'TENSOR'
:
ret
.
append
(
placeholder
(
shape
=
t
[
1
],
dtype
=
t
[
2
]))
else
:
ret
.
append
(
t
)
return
ret
# Task extractor for nnvm graph, relay program
class
TaskExtractEnv
:
"""Global environment for extracting tuning tasks from nnvm graph"""
current
=
None
def
__init__
(
self
):
import
topi
# topi compute -> autotvm task name
self
.
topi_to_task
=
{
topi
.
nn
.
conv2d
:
"topi_nn_conv2d"
,
topi
.
nn
.
depthwise_conv2d_nchw
:
"topi_nn_depthwise_conv2d_nchw"
,
topi
.
nn
.
group_conv2d_nchw
:
"topi_nn_group_conv2d_nchw"
,
topi
.
nn
.
conv2d_transpose_nchw
:
"topi_nn_conv2d_transpose_nchw"
,
topi
.
nn
.
dense
:
"topi_nn_dense"
,
}
self
.
topi_to_schedule
=
{
topi
.
nn
.
conv2d
:
[
topi
.
generic
.
schedule_conv2d_nchw
,
topi
.
generic
.
schedule_conv2d_nhwc
],
topi
.
nn
.
depthwise_conv2d_nchw
:
[
topi
.
generic
.
schedule_depthwise_conv2d_nchw
,
topi
.
generic
.
schedule_depthwise_conv2d_nhwc
],
topi
.
nn
.
group_conv2d_nchw
:
[
topi
.
generic
.
schedule_group_conv2d_nchw
],
topi
.
nn
.
conv2d_transpose_nchw
:
[
topi
.
generic
.
schedule_conv2d_transpose_nchw
],
topi
.
nn
.
dense
:
[
topi
.
generic
.
schedule_dense
],
}
self
.
_register_tracing
()
self
.
_register_topi_task
()
self
.
task_collection
=
[]
self
.
wanted_topi_funcs
=
list
(
self
.
topi_to_task
.
keys
())
def
_register_tracing
(
self
):
"""Register tracing function to track the topi function call"""
# register topi compute for "tracing" target
for
topi_compute
in
self
.
topi_to_task
:
def
_local_scope
(
compute_func
):
"""start a scope to hold the local function in for loop"""
@compute_func.register
(
"tracing"
,
)
def
_tracing_topi_compute
(
*
args
,
**
kwargs
):
assert
not
kwargs
,
"Do not support extracting tuning tasks when"
\
"kwargs is used in TOPI function call."
\
"Please modify it to use only positional args."
if
compute_func
in
self
.
wanted_topi_funcs
:
# record this call
key
=
(
self
.
topi_to_task
[
compute_func
],
serialize_args
(
args
))
if
key
not
in
self
.
task_collection
:
self
.
task_collection
.
append
(
key
)
return
compute_func
.
fdefault
(
*
args
)
_local_scope
(
topi_compute
)
# register topi schedule for "tracing" target
for
topi_compute
in
self
.
topi_to_task
:
for
topi_schedule
in
self
.
topi_to_schedule
[
topi_compute
]:
def
_local_scope_
(
schedule_func
):
"""start a scope to hold the local function in for loop"""
@schedule_func.register
(
"tracing"
,
)
def
_tracing_topi_compute
(
outs
):
outs
=
[
outs
]
if
isinstance
(
outs
,
tensor
.
Tensor
)
else
outs
return
create_schedule
([
x
.
op
for
x
in
outs
])
_local_scope_
(
topi_schedule
)
def
_register_topi_task
(
self
):
"""register tuning wrapper for topi function"""
import
topi
# Tuning wrapper for topi functions
@register
(
"topi_nn_conv2d"
)
def
_topi_nn_conv2d
(
*
args
,
**
kwargs
):
assert
not
kwargs
,
"Do not support kwargs in template function call"
args
=
deserialize_args
(
args
)
A
,
W
=
args
[:
2
]
layout
=
args
[
-
2
]
assert
layout
==
'NCHW'
,
"only support NCHW currently"
C
=
topi
.
nn
.
conv2d
(
*
args
,
**
kwargs
)
s
=
topi
.
generic
.
schedule_conv2d_nchw
([
C
])
return
s
,
[
A
,
W
,
C
]
@register
(
"topi_nn_depthwise_conv2d_nchw"
)
def
_topi_nn_depthwise_conv2d_nchw
(
*
args
,
**
kwargs
):
assert
not
kwargs
,
"Do not support kwargs in template function call"
args
=
deserialize_args
(
args
)
A
,
W
=
args
[:
2
]
C
=
topi
.
nn
.
depthwise_conv2d_nchw
(
*
args
,
**
kwargs
)
s
=
topi
.
generic
.
schedule_depthwise_conv2d_nchw
([
C
])
return
s
,
[
A
,
W
,
C
]
@register
(
"topi_nn_group_conv2d_nchw"
)
def
_topi_nn_group_conv2d_nchw
(
*
args
,
**
kwargs
):
assert
not
kwargs
,
"Do not support kwargs in template function call"
args
=
deserialize_args
(
args
)
A
,
W
=
args
[:
2
]
C
=
topi
.
nn
.
group_conv2d_nchw
(
*
args
,
**
kwargs
)
s
=
topi
.
generic
.
schedule_group_conv2d_nchw
([
C
])
return
s
,
[
A
,
W
,
C
]
@register
(
"topi_nn_conv2d_transpose_nchw"
)
def
_topi_nn_conv2d_transpose_nchw
(
*
args
,
**
kwargs
):
assert
not
kwargs
,
"Do not support kwargs in template function call"
args
=
deserialize_args
(
args
)
A
,
W
=
args
[:
2
]
C
=
topi
.
nn
.
conv2d_transpose_nchw
(
*
args
,
**
kwargs
)
s
=
topi
.
generic
.
schedule_conv2d_transpose_nchw
([
C
])
return
s
,
[
A
,
W
,
C
]
@register
(
"topi_nn_dense"
)
def
_topi_nn_dense
(
*
args
,
**
kwargs
):
assert
not
kwargs
,
"Do not support kwargs in template function call"
args
=
deserialize_args
(
args
)
data
,
weight
,
bias
=
args
C
=
topi
.
nn
.
dense
(
*
args
,
**
kwargs
)
s
=
topi
.
generic
.
schedule_dense
([
C
])
if
bias
is
not
None
:
return
s
,
[
data
,
weight
,
bias
,
C
]
return
s
,
[
data
,
weight
,
C
]
def
reset
(
self
,
wanted_topi_funcs
):
"""Reset task collections
Parameters
----------
wanted_topi_funcs: List of function
The topi function to be extracted
"""
self
.
task_collection
=
[]
self
.
wanted_topi_funcs
=
wanted_topi_funcs
def
get_tasks
(
self
):
"""Get collected tasks
Returns
-------
tasks: List of tuple(name, args)
A list of tasks extracted from the nnvm graph
"""
return
self
.
task_collection
@staticmethod
def
get
():
"""Get the single instance of TaskExtractEnv
Returns
-------
env: TaskExtractEnv
The single instance of TaskExtractEnv
"""
if
not
TaskExtractEnv
.
current
:
TaskExtractEnv
.
current
=
TaskExtractEnv
()
return
TaskExtractEnv
.
current
def
register_topi_compute
(
topi_compute
,
target_keys
,
template_keys
,
func
=
None
):
def
register_topi_compute
(
topi_compute
,
target_keys
,
template_keys
,
func
=
None
):
"""Register a tunable template for a topi compute function.
"""Register a tunable template for a topi compute function.
...
...
tests/python/relay/test_autotvm_task_extraction.py
0 → 100644
View file @
a0062582
"""Test task extraction for autotvm"""
import
tvm.relay.testing
from
tvm
import
relay
from
tvm
import
autotvm
def
get_network
(
name
,
batch_size
):
"""Get the symbol definition and random weight of a network"""
input_shape
=
(
batch_size
,
3
,
224
,
224
)
if
name
==
'resnet-18'
:
net
,
params
=
relay
.
testing
.
resnet
.
get_workload
(
num_layers
=
18
,
batch_size
=
batch_size
)
elif
name
==
'mobilenet'
:
net
,
params
=
relay
.
testing
.
mobilenet
.
get_workload
(
batch_size
=
batch_size
)
elif
name
==
'dcgan'
:
net
,
params
=
relay
.
testing
.
dcgan
.
get_workload
(
batch_size
=
batch_size
)
input_shape
=
(
batch_size
,
100
)
else
:
raise
ValueError
(
"Unsupported network: "
+
name
)
return
net
,
params
,
input_shape
def
test_task_extraction
():
target
=
'llvm'
net
,
params
,
input_shape
=
get_network
(
'resnet-18'
,
batch_size
=
1
)
tasks
=
autotvm
.
task
.
extract_from_program
(
net
,
target
=
target
,
params
=
params
,
ops
=
(
relay
.
op
.
nn
.
conv2d
,))
assert
len
(
tasks
)
==
12
net
,
params
,
input_shape
=
get_network
(
'resnet-18'
,
batch_size
=
1
)
tasks
=
autotvm
.
task
.
extract_from_program
(
net
,
target
=
target
,
params
=
params
,
ops
=
(
relay
.
op
.
nn
.
dense
,))
assert
len
(
tasks
)
==
1
net
,
params
,
input_shape
=
get_network
(
'resnet-18'
,
batch_size
=
1
)
tasks
=
autotvm
.
task
.
extract_from_program
(
net
,
target
=
target
,
params
=
params
,
ops
=
(
relay
.
op
.
nn
.
conv2d
,
relay
.
op
.
nn
.
dense
))
assert
len
(
tasks
)
==
13
net
,
params
,
input_shape
=
get_network
(
'mobilenet'
,
batch_size
=
1
)
tasks
=
autotvm
.
task
.
extract_from_program
(
net
,
target
=
target
,
params
=
params
,
ops
=
(
relay
.
op
.
nn
.
conv2d
,
relay
.
op
.
nn
.
dense
))
assert
len
(
tasks
)
==
20
net
,
params
,
input_shape
=
get_network
(
'dcgan'
,
batch_size
=
1
)
tasks
=
autotvm
.
task
.
extract_from_program
(
net
,
target
=
target
,
params
=
params
,
ops
=
(
relay
.
op
.
nn
.
conv2d_transpose
,))
assert
len
(
tasks
)
==
4
if
__name__
==
'__main__'
:
test_task_extraction
()
topi/python/topi/x86/conv2d.py
View file @
a0062582
...
@@ -2,7 +2,7 @@
...
@@ -2,7 +2,7 @@
"""Conv2D schedule on x86"""
"""Conv2D schedule on x86"""
import
tvm
import
tvm
from
tvm
import
autotvm
from
tvm
import
autotvm
from
tvm.autotvm.task.
nnvm
_integration
import
deserialize_args
from
tvm.autotvm.task.
topi
_integration
import
deserialize_args
from
tvm.autotvm.task
import
get_config
from
tvm.autotvm.task
import
get_config
from
..
import
generic
,
tag
from
..
import
generic
,
tag
from
..
import
nn
from
..
import
nn
...
...
topi/python/topi/x86/depthwise_conv2d.py
View file @
a0062582
...
@@ -4,7 +4,7 @@ import tvm
...
@@ -4,7 +4,7 @@ import tvm
from
tvm
import
autotvm
from
tvm
import
autotvm
from
tvm.autotvm.task
import
get_config
from
tvm.autotvm.task
import
get_config
from
tvm.autotvm.task.space
import
SplitEntity
from
tvm.autotvm.task.space
import
SplitEntity
from
tvm.autotvm.task.
nnvm
_integration
import
deserialize_args
from
tvm.autotvm.task.
topi
_integration
import
deserialize_args
from
..
import
generic
,
tag
from
..
import
generic
,
tag
from
..nn.pad
import
pad
from
..nn.pad
import
pad
from
..util
import
get_const_tuple
from
..util
import
get_const_tuple
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment