wenyuanbo / tic

Commit 12839e6d
authored Aug 28, 2018 by Lianmin Zheng, committed by Tianqi Chen, Aug 28, 2018
[AUTOTVM] Decouple build and run in measurement (#1661)
parent 38203a86
Showing 16 changed files with 880 additions and 760 deletions.

docs/api/python/autotvm.rst                       +5   -0
python/tvm/autotvm/__init__.py                    +2   -1
python/tvm/autotvm/measure/__init__.py            +4   -4
python/tvm/autotvm/measure/local_executor.py      +5   -9
python/tvm/autotvm/measure/measure.py             +173 -78
python/tvm/autotvm/measure/measure_methods.py     +486 -410
python/tvm/autotvm/tuner/ga_tuner.py              +1   -1
python/tvm/autotvm/tuner/sa_model_optimizer.py    +1   -1
tests/python/integration/test_tuning.py           +6   -35
tests/python/unittest/test_autotvm_common.py      +19  -0
tests/python/unittest/test_autotvm_database.py    +2   -149
tests/python/unittest/test_autotvm_measure.py     +97  -0
topi/recipe/gemm/gemm_int8.py                     +5   -2
tutorials/autotvm/tune_conv2d_cuda.py             +6   -6
tutorials/autotvm/tune_nnvm_arm.py                +62  -61
tutorials/autotvm/tune_simple_template.py         +6   -3
docs/api/python/autotvm.rst

@@ -16,6 +16,11 @@ tvm.autotvm.measure
 .. autofunction:: tvm.autotvm.measure.create_measure_batch

+.. autoclass:: tvm.autotvm.measure.measure_methods.LocalBuilder
+.. autoclass:: tvm.autotvm.measure.measure_methods.RPCRunner
+.. autoclass:: tvm.autotvm.measure.measure_methods.LocalRunner

 tvm.autotvm.tuner
 ~~~~~~~~~~~~~~~~~
 ...
python/tvm/autotvm/__init__.py

@@ -22,7 +22,8 @@ from . import env
 from . import tophub

 # some shortcuts
-from .measure import measure_option, MeasureInput, MeasureResult, MeasureErrorNo
+from .measure import measure_option, MeasureInput, MeasureResult, MeasureErrorNo, \
+    LocalBuilder, LocalRunner, RPCRunner
 from .tuner import callback
 from .task import template, get_config, create, ConfigSpace, ConfigEntity, \
     register_topi_compute, register_topi_schedule, \
 ...
python/tvm/autotvm/measure/__init__.py

 """Distributed executor infrastructure to scale up the tuning"""
-from .measure import MeasureInput, MeasureResult, MeasureErrorNo, measure_option
-from .measure_methods import request_remote, check_remote, create_measure_batch, rpc
-from .executor import Executor
+from .measure import MeasureInput, MeasureResult, MeasureErrorNo, measure_option, \
+    create_measure_batch
+from .measure_methods import LocalBuilder, LocalRunner, RPCRunner, request_remote
 from .local_executor import LocalExecutor
+from .executor import Future, Executor
python/tvm/autotvm/measure/local_executor.py

@@ -37,7 +37,8 @@ def _execute_func(func, queue, args, kwargs):
         res = exc
     queue.put(res)

-def timeout_monitor(queue, timeout, func, args, kwargs):
+
+def call_with_timeout(queue, timeout, func, args, kwargs):
     """A wrapper to support timeout of a function call"""

     # start a new process for timeout (cannot use thread because we have c function)
 ...
@@ -45,17 +46,12 @@ def timeout_monitor(queue, timeout, func, args, kwargs):
     p.start()
     p.join(timeout=timeout)

-    alive = p.is_alive()
+    queue.put(executor.TimeoutError())
+
     kill_child_processes(p.pid)
     p.terminate()
     p.join()

-    if alive:
-        queue.put(executor.TimeoutError())
-    else:
-        if queue.empty():
-            queue.put(executor.ExecutionError("Fatal error in local executor"))

 class LocalFuture(executor.Future):
     """Local wrapper for the future
 ...
@@ -134,7 +130,7 @@ class LocalExecutor(executor.Executor):
             return LocalFutureNoFork(func(*args, **kwargs))

         queue = Queue(2)
-        process = Process(target=timeout_monitor,
+        process = Process(target=call_with_timeout,
                           args=(queue, self.timeout, func, args, kwargs))
         process.start()
         return LocalFuture(process, queue)
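The renamed call_with_timeout helper follows a common pattern: run the payload in a child process, wait with join(timeout=...), make sure a TimeoutError sentinel is available on the result queue, then kill the child so the parent never hangs. A minimal standalone sketch of that pattern, independent of TVM and using a hypothetical slow_task payload (the names here are illustrative, not part of the commit):

import multiprocessing
import queue as queue_mod
import time


def _run_payload(q, func, args):
    """Child process body: put either the result or the raised exception on the queue."""
    try:
        q.put(func(*args))
    except Exception as exc:  # pylint: disable=broad-except
        q.put(exc)


def call_with_timeout(func, args, timeout):
    """Run func(*args) in a child process and return TimeoutError() on overrun."""
    q = multiprocessing.Queue(2)
    p = multiprocessing.Process(target=_run_payload, args=(q, func, args))
    p.start()
    p.join(timeout=timeout)
    p.terminate()
    p.join()
    try:
        # if the child finished in time, its result (or exception) is already queued
        return q.get(timeout=0.5)
    except queue_mod.Empty:
        return TimeoutError()


def slow_task(seconds):
    """Hypothetical payload used only for this demo."""
    time.sleep(seconds)
    return "done"


if __name__ == "__main__":
    print(call_with_timeout(slow_task, (0.1,), timeout=1.0))  # "done"
    print(call_with_timeout(slow_task, (5.0,), timeout=0.5))  # TimeoutError()

The TVM version puts the sentinel before killing the child rather than checking the queue afterwards, but the effect is the same: the future's get() always receives a result, an exception, or a TimeoutError.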
python/tvm/autotvm/measure/measure.py

 # pylint: disable=pointless-string-statement,consider-using-enumerate,invalid-name
 """User facing API for specifying how to measure the generated code"""
+import multiprocessing
 from collections import namedtuple

 class MeasureInput(namedtuple("MeasureInput", ["target", "task", "config"])):
 ...
@@ -16,6 +17,7 @@ class MeasureInput(namedtuple("MeasureInput", ["target", "task", "config"])):
         Specific configuration.
     """

 class MeasureResult(namedtuple("MeasureResult", ["costs", "error_no", "all_cost", "timestamp"])):
     """
     Stores all the results of a measurement
 ...
@@ -23,8 +25,8 @@ class MeasureResult(namedtuple("MeasureResult", ["costs", "error_no", "all_cost"
     Parameters
     ----------
     costs: Array of float or Array of Exception
-        If no error occurs for this measurement, it is an array of measured running times.
-        If some error occurs during the measurement, it is an array of the exception objections.
+        If no error occurs during measurement, it is an array of measured running times.
+        If an error occurs during measurement, it is an array of the exception objections.
     error_no: int
         Denote error type, defined by MeasureErrorNo
     all_cost: float
 ...
@@ -37,92 +39,185 @@ class MeasureResult(namedtuple("MeasureResult", ["costs", "error_no", "all_cost"

 class MeasureErrorNo(object):
     """Error type for MeasureResult"""
     NO_ERROR = 0              # no error
-    INSTANTIATION_ERROR = 1   # error when calling template function
+    INSTANTIATION_ERROR = 1   # actively detected error in instantiating a template with a config
     COMPILE_HOST = 2          # error when compiling code on host (e.g. tvm.build)
-    COMPILE_DEVICE = 3        # error when compiling code on device (e.g. opencl JIT on device)
+    COMPILE_DEVICE = 3        # error when compiling code on device (e.g. OpenCL JIT on the device)
     RUNTIME_DEVICE = 4        # error when run program on device
     WRONG_ANSWER = 5          # answer is wrong when compared to a golden output
-    FLEET_ERROR = 6           # error of measure infrastructure
+    BUILD_TIMEOUT = 6         # timeout during compilation
+    RUN_TIMEOUT = 7           # timeout during run
+    UNKNOWN_ERROR = 8         # unknown error

-def measure_option(measure_func,
-                   number=1, repeat=1,
-                   timeout=60, n_parallel=1, do_fork=True,
-                   build_func='default',
-                   check_correctness=False,
-                   replay_db=None):
-    """Configure how to do measurement
-    ...
-    """
-    return {
-        'measure_func': measure_func,
-        'number': number,
-        'repeat': repeat,
-        'timeout': timeout,
-        'n_parallel': n_parallel,
-        'do_fork': do_fork,
-        'build_func': build_func,
-        'check_correctness': check_correctness,
-        'replay_db': replay_db,
-    }

The old single-function option is replaced by the Builder/Runner abstraction and a
two-argument measure_option (new code shown below without diff markers):

class Builder(object):
    """Builder that builds programs in tuning

    Parameters
    ----------
    timeout: float, optional
        The timeout of a build task
    n_parallel: int, optional
        The number of tasks submitted in parallel
        By default it will use all cpu cores
    """
    def __init__(self, timeout=10, n_parallel=None):
        self.timeout = timeout
        self.n_parallel = n_parallel or multiprocessing.cpu_count()
        self.build_kwargs = {}
        self.task = None

    def set_task(self, task, build_kwargs=None):
        """Initialize for a new tuning task"""
        self.task = task
        self.build_kwargs = build_kwargs

    def build(self, measure_inputs):
        """Build programs. Returns a list of BuildResult."""
        raise NotImplementedError()


class Runner(object):
    """Runner that runs and measures the time cost of a generated program in tuning

    Parameters
    ----------
    timeout: float, optional
        The timeout of a run task
    n_parallel: int, optional
        The number of tasks submitted in parallel
        By default it will use all cpu cores
    """
    def __init__(self, timeout=5, n_parallel=None):
        self.timeout = timeout
        self.n_parallel = n_parallel or multiprocessing.cpu_count()
        self.task = None

    def set_task(self, task):
        """Initialize for a new tuning task"""
        self.task = task

    def get_build_kwargs(self):
        """Get device specific build arguments (e.g. maximum shared memory size)"""
        raise NotImplementedError()

    def run(self, measure_inputs, build_results):
        """Run and measure built programs. Returns a list of MeasureResult."""
        raise NotImplementedError()


def measure_option(builder, runner):
    """
    Set options for measure. To measure a config, we will build it and run it.
    So we have to set options for these two steps.
    They have their own options on timeout, parallel, etc.

    Parameters
    ----------
    builder: Builder
        Specify how to build programs
    runner: Runner
        Specify how to run programs
    """
    from .measure_methods import LocalBuilder, LocalRunner

    if isinstance(builder, str):
        if builder == 'local':
            builder = LocalBuilder()
        else:
            raise ValueError("Invalid builder: " + builder)

    if isinstance(runner, str):
        if runner == 'local':
            runner = LocalRunner()
        else:
            raise ValueError("Invalid runner: " + runner)

    opt = {
        'builder': builder,
        'runner': runner,
    }
    return opt


def create_measure_batch(task, option):
    """Get a standard measure_batch function.

    Parameters
    ----------
    task: tvm.autotvm.task.Task
        The tuning task
    option: dict
        The option for measuring generated code.
        You should use the return value of function :any:`measure_option` for this argument.

    Returns
    -------
    measure_batch: callable
        a callback function to measure a batch of configs
    """
    builder = option['builder']
    runner = option['runner']

    attach_objects = runner.set_task(task)

    # feed device related information from runner to builder
    # (e.g. max shared memory for validity checking)
    build_kwargs = runner.get_build_kwargs()
    builder.set_task(task, build_kwargs)

    def measure_batch(measure_inputs):
        build_results = builder.build(measure_inputs)
        results = runner.run(measure_inputs, build_results)
        return results

    measure_batch.n_parallel = builder.n_parallel
    measure_batch.attach_objects = attach_objects
    return measure_batch
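With the old single measure_func option replaced by a Builder plus Runner pair, a tuning script wires the two together through measure_option and hands the result to a tuner. The sketch below follows the style of the tune_simple_template tutorial that this commit also updates; the toy matmul template, sizes, and log file name are illustrative choices, not values taken from the diff, and it assumes the 2018-era TVM/AutoTVM API shown in this file:

import tvm
from tvm import autotvm


@autotvm.template
def matmul(N, L, M, dtype):
    """Toy tunable template, illustrative only (tutorial style)."""
    A = tvm.placeholder((N, L), name='A', dtype=dtype)
    B = tvm.placeholder((L, M), name='B', dtype=dtype)
    k = tvm.reduce_axis((0, L), name='k')
    C = tvm.compute((N, M), lambda i, j: tvm.sum(A[i, k] * B[k, j], axis=k), name='C')
    s = tvm.create_schedule(C.op)

    # define a small search space over the tiling of the output loops
    y, x = s[C].op.axis
    cfg = autotvm.get_config()
    cfg.define_split("tile_y", y, num_outputs=2)
    cfg.define_split("tile_x", x, num_outputs=2)
    ty, yi = cfg["tile_y"].apply(s, C, y)
    tx, xi = cfg["tile_x"].apply(s, C, x)
    s[C].reorder(ty, tx, yi, xi, k)
    return s, [A, B, C]


task = autotvm.task.create(matmul, args=(128, 128, 128, 'float32'), target='llvm')

# build and run options are now configured independently
measure_option = autotvm.measure_option(
    builder=autotvm.LocalBuilder(timeout=10),               # compiles candidate configs
    runner=autotvm.LocalRunner(number=4, repeat=3, timeout=4))  # times them on the local device

tuner = autotvm.tuner.RandomTuner(task)
tuner.tune(n_trial=10,
           measure_option=measure_option,
           callbacks=[autotvm.callback.log_to_file('matmul.tuning.log')])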
python/tvm/autotvm/measure/measure_methods.py

-# pylint: disable=consider-using-enumerate,invalid-name,too-many-function-args
+# pylint: disable=invalid-name,too-many-function-args,too-many-nested-blocks
 """
 Functions that run on executor for measurement.
-These functions are responsible for building tvm module, uploading it to
-remote devices, recording the running time costs and checking the correctness of output
+These functions are responsible for building the tvm module, uploading it to
+remote devices, recording the running time costs, and checking the correctness of the output.
 """

 import logging
+import shutil
 import os
+import threading
 import time
 from random import getrandbits
-import threading
+from collections import namedtuple
+import tempfile

 import numpy as np

-from ... import ir_pass, build, build_config, nd, context, TVMError, register_func, \
-    target as _target, rpc as _rpc
-from ...contrib import nvcc, util, ndk
+from ... import ir_pass, build, build_config, nd, TVMError, register_func, \
+    rpc as _rpc, target as _target
+from ...contrib import nvcc, ndk

 from ..util import get_const_tuple
 from ..env import AutotvmGlobalScope
 from ..task.space import InstantiationError
-from .measure import MeasureResult, MeasureErrorNo
+from .measure import MeasureResult, MeasureErrorNo, Builder, Runner
 from .local_executor import LocalExecutor

 logger = logging.getLogger('autotvm')

-class HashMismatchError(ValueError):
-    """Raised when the code hash of a submitted config doesn't match that on the
-    measure side """
-    pass

The old request_remote/check_remote/rpc/_measure_common/create_measure_batch helpers are
replaced by the classes and functions below (new code shown without diff markers):

class BuildResult(namedtuple("BuildResult", ('filename', 'arg_info', 'error', 'time_cost'))):
    """
    Stores all the necessary inputs for a measurement.

    Parameters
    ----------
    filename : str
        The filename of generated library
    arg_info : Tuple
        The shape and dtype information of tvm tensor arguments
    error : Exception
        The error happens during compilation.
    time_cost : float
        The time cost of building
    """


class LocalBuilder(Builder):
    """Run compilation on local machine

    Parameters
    ----------
    timeout: float
        The timeout of a compilation
    n_parallel: int
        The number of tasks run in parallel. "None" will use all cpu cores
    build_func: callable or str
        If is 'default', use default build function
        If is 'ndk', use function for android ndk
        If is callable, use it as custom build function
    """
    def __init__(self, timeout=10, n_parallel=None, build_func='default'):
        super(LocalBuilder, self).__init__(timeout, n_parallel)

        if isinstance(build_func, str):
            if build_func == 'default':
                build_func = default_build_func
            elif build_func == 'ndk':
                build_func = android_ndk_build_func
            else:
                raise ValueError("Invalid build_func" + build_func)

        self.build_func = build_func
        self.tmp_dir = tempfile.mkdtemp()
        self.executor = LocalExecutor(timeout=timeout)

    def build(self, measure_inputs):
        results = []

        for i in range(0, len(measure_inputs), self.n_parallel):
            futures = []
            for inp in measure_inputs[i:i + self.n_parallel]:
                ret = self.executor.submit(self.build_func, inp, self.tmp_dir,
                                           **self.build_kwargs)
                futures.append(ret)

            for future in futures:
                res = future.get()

                if isinstance(res, Exception):
                    # timeout or fleet error, return MeasureResult directly
                    results.append(MeasureResult((res,), MeasureErrorNo.BUILD_TIMEOUT,
                                                 self.timeout, time.time()))
                elif res.error is not None:
                    # instantiation error
                    if isinstance(res.error, InstantiationError):
                        results.append(MeasureResult((res.error,),
                                                     MeasureErrorNo.INSTANTIATION_ERROR,
                                                     res.time_cost, time.time()))
                    else:
                        if "InstantiationError" in str(res.error):
                            msg = str(res.error)
                            try:
                                msg = msg.split('\n')[-2].split(": ")[1]
                            except Exception:  # pylint: disable=broad-except
                                pass
                            results.append(MeasureResult((InstantiationError(msg),),
                                                         MeasureErrorNo.INSTANTIATION_ERROR,
                                                         res.time_cost, time.time()))
                        else:  # tvm error
                            results.append(MeasureResult((res.error,),
                                                         MeasureErrorNo.COMPILE_HOST,
                                                         res.time_cost, time.time()))
                else:
                    # return BuildResult
                    results.append(res)

        return results

    def __del__(self):
        shutil.rmtree(self.tmp_dir)


class RPCRunner(Runner):
    """Run generated code on remote devices.
    This function will ask a RPC Tracker to get device for measurement.

    Parameters
    ----------
    key: str
        The key of the device registered in the tracker
    host: str
        The host address of RPC Tracker
    port: int
        The port of RPC Tracker
    priority: int, optional
        The priority of this request, larger is more prior
    timeout: float, optional
        The timeout of a run
    n_parallel: int, optional
        The number of tasks run in parallel. "None" will use all cpu cores
    number : int, optional
        Number of times to do measurement for taking average
    repeat : int, optional
        Number of times to repeat the measurement.
        In total, the generated code will be run (1 + number x repeat) times,
        where the first one is warm up. The returned result contains `repeat` costs,
        each of which is the average of `number` test runs.
    min_repeat_ms : float, optional
        Minimum duration of a timer measurement in milliseconds.
        When the run time of a measurement trial falls below this time, the
        `number` parameter will be automatically increased.
        Set this to improve the accuracy of perf measurement, e.g., when timers
        are not precise enough to capture short-running tasks. This parameter is
        also critical when devices need a certain minimum running time to "warm
        up," such as GPUs that need time to reach a performance power state.
    cooldown_interval: float, optional
        The cool down interval between two measurements.
    check_correctness: bool, optional
        Whether check correctness after measurement. This will use llvm cpu target to
        call your template and get the reference output.
        This can work for TOPI templates, but may not work for your custom template.
    """
    def __init__(self,
                 key, host, port, priority=1,
                 timeout=10, n_parallel=None,
                 number=4, repeat=3, min_repeat_ms=0, cooldown_interval=0.1,
                 check_correctness=False):
        super(RPCRunner, self).__init__(timeout, n_parallel)

        self.key = key
        self.host = host
        self.port = port
        self.priority = priority
        self.timeout = timeout

        self.number = number
        self.repeat = repeat
        self.min_repeat_ms = min_repeat_ms
        self.cur_number = number

        self.ref_input = None
        self.ref_output = None
        self.check_correctness = check_correctness
        self.cooldown_interval = cooldown_interval

        self.executor = LocalExecutor()

    def set_task(self, task):
        self.task = task
        self.cur_number = self.number

        if check_remote(task.target, self.key, self.host, self.port):
            logger.info("Get devices for measurement successfully!")
        else:
            raise RuntimeError("Cannot get remote devices from the tracker. "
                               "Please check the status of tracker by "
                               "'python -m tvm.exec.query_rpc_tracker --port [THE PORT YOU USE]' "
                               "and make sure you have free devices on the queue status.")

        if self.check_correctness:
            # use llvm cpu to generate a reference input/output
            # this option works for tuning topi, but might not work for you custom op
            with _target.create("llvm"):
                s, arg_bufs = task.instantiate(task.config_space.get(0))
            self.ref_input = [np.random.uniform(size=get_const_tuple(x.shape)).astype(x.dtype)
                              for x in arg_bufs]
            func = build(s, arg_bufs, "llvm")
            tvm_buf = [nd.array(x) for x in self.ref_input]
            func(*tvm_buf)
            self.ref_output = [x.asnumpy() for x in tvm_buf]

    def get_build_kwargs(self):
        kwargs = {}
        if 'cuda' in self.task.target.keys or 'opencl' in self.task.target.keys:
            remote = request_remote(self.key, self.host, self.port)
            ctx = remote.context(str(self.task.target), 0)
            max_dims = ctx.max_thread_dimensions
            kwargs['check_gpu'] = {
                'max_shared_memory_per_block': ctx.max_shared_memory_per_block,
                'max_threads_per_block': ctx.max_threads_per_block,
                'max_thread_x': max_dims[0],
                'max_thread_y': max_dims[1],
                'max_thread_z': max_dims[2],
            }

            if 'cuda' in self.task.target.keys:
                kwargs["cuda_arch"] = "sm_" + "".join(ctx.compute_version.split('.'))

        return kwargs

    def run(self, measure_inputs, build_results):
        results = []
        remote_args = (self.key, self.host, self.port, self.priority, self.timeout)

        for i in range(0, len(measure_inputs), self.n_parallel):
            futures = []
            for measure_inp, build_res in zip(measure_inputs[i:i+self.n_parallel],
                                              build_results[i:i+self.n_parallel]):
                ret = self.executor.submit(run_through_rpc,
                                           measure_inp, build_res,
                                           self.cur_number, self.repeat,
                                           self.cooldown_interval, remote_args,
                                           self.ref_input, self.ref_output)
                futures.append(ret)

            for future in futures:
                res = future.get()
                if isinstance(res, Exception):   # executor error or timeout
                    results.append(MeasureResult((str(res),), MeasureErrorNo.RUN_TIMEOUT,
                                                 self.timeout, time.time()))
                else:
                    results.append(res)

        # If some runs were too fast, do remeasure for them
        # to meet the requirement of `min_repeat_ms`
        remeasure = np.zeros((len(measure_inputs),), dtype=np.bool)
        pre_number = next_number = self.cur_number
        min_repeat_duration = self.min_repeat_ms / 1000.0
        for i, res in enumerate(results):
            if res.error_no == MeasureErrorNo.NO_ERROR:
                if np.mean(res.costs) * pre_number <= min_repeat_duration:
                    next_number = max(next_number,
                                      int(np.ceil(min_repeat_duration / np.mean(res.costs))))
                    remeasure[i] = True

        if pre_number != next_number:
            self.cur_number = next_number
            msg = "increasing number to %d" % self.cur_number
            logger.info(msg)

            re_measure_inputs = [x for i, x in enumerate(measure_inputs) if remeasure[i]]
            re_build_results = [x for i, x in enumerate(build_results) if remeasure[i]]
            re_res = self.run(re_measure_inputs, re_build_results)
            ct = 0
            for i, rerun in enumerate(remeasure):
                if rerun:
                    results[i] = re_res[ct]
                    ct += 1

        return results


class LocalRunner(RPCRunner):
    """Run generated code on local devices.

    Parameters
    ----------
    timeout, number, repeat, min_repeat_ms, cooldown_interval, check_correctness:
        same measurement options as in RPCRunner

    Note
    ----
    This is a "fake" local mode. We start a silent rpc tracker and rpc server
    for the user. In this way we reuse timeout/isolation mechanism in RPC infrastructure.
    """
    def __init__(self, timeout=10,
                 number=4, repeat=3, min_repeat_ms=0, cooldown_interval=0.1,
                 check_correctness=False):
        super(LocalRunner, self).__init__('', None, None, 0,
                                          timeout=timeout, n_parallel=1,
                                          number=number, repeat=repeat,
                                          min_repeat_ms=min_repeat_ms,
                                          cooldown_interval=cooldown_interval,
                                          check_correctness=check_correctness)
        self.tracker = None
        self.server = None

    def set_task(self, task):
        self.task = task

        # start temporary rpc tracker and rpc server for the user
        from ...rpc.tracker import Tracker
        from ...rpc.server import Server
 ...
@@ -133,360 +343,215 @@ def create_measure_batch(task, option):
        server = Server(...,                   # server startup arguments elided in this view
                        key=device_key,
                        use_popen=True, silent=True,
                        tracker_addr=(tracker.host, tracker.port))
        self.key = device_key
        self.host = tracker.host
        self.port = tracker.port

        super(LocalRunner, self).set_task(task)
        return server, tracker


def _build_func_common(measure_input, check_gpu=None, cuda_arch=None, build_option=None):
    """Common part for building a configuration"""
    target, task, config = measure_input

    with target:
        s, args = task.instantiate(config)

        # check invalidity of template and code hash consistency
        if not config.valid():
            raise InstantiationError(config.errors)

        opts = build_option or {}
        if check_gpu:  # Add verify pass to filter out invalid configs in advance.
            opts["add_lower_pass"] = [(2, gpu_verify_pass(**check_gpu))]
        if cuda_arch:
            set_cuda_target_arch(cuda_arch)

        with build_config(**opts):
            func = build(s, args, target_host=task.target_host)
    return func, tuple((get_const_tuple(x.shape), x.dtype) for x in args)


def default_build_func(measure_input, tmp_dir, **kwargs):
    """Default build func. This can work for cuda, opencl, llvm backend"""
    tic = time.time()
    try:
        filename = os.path.join(tmp_dir, "tmp_func_%0x.tar" % getrandbits(64))
        func, arg_info = _build_func_common(measure_input, **kwargs)
        func.export_library(filename)
    except Exception as e:  # pylint: disable=broad-except
        return BuildResult(None, None, e, time.time() - tic)
    return BuildResult(filename, arg_info, None, time.time() - tic)


def android_ndk_build_func(measure_input, tmp_dir, **kwargs):
    """Build function for android device using ndk."""
    tic = time.time()
    try:
        filename = os.path.join(tmp_dir, "tmp_func_%0x.so" % getrandbits(64))
        func, arg_info = _build_func_common(measure_input, **kwargs)
        func.export_library(filename, ndk.create_shared)
    except Exception as e:  # pylint: disable=broad-except
        return BuildResult(None, None, e, time.time() - tic)
    return BuildResult(filename, arg_info, None, time.time() - tic)


def run_through_rpc(measure_input, build_result,
                    number, repeat, cooldown_interval,
                    remote_args, ref_input=None, ref_output=None):
    """Run a generated library through rpc

    Parameters
    ----------
    measure_input: MeasureInput
        The raw measure input
    build_result: BuildResult
        The result returned from Builder. This contains the path to the generated library.
    number, repeat: int
        Number of runs to average over and number of repeated measurements;
        the generated code is run (1 + number x repeat) times in total, the first run is warm up.
    cooldown_interval: float
        The cool down interval between two measurements
    remote_args: Tuple
        The argument for request_remote
    ref_input, ref_output: List of np.ndarray
        The reference input/output used for checking correctness
    """
    if isinstance(build_result, MeasureResult):
        return build_result

    tic = time.time()
    errno = MeasureErrorNo.NO_ERROR
    try:
        # upload built module
        remote = request_remote(*remote_args)
        remote.upload(build_result.filename)
        func = remote.load_module(os.path.split(build_result.filename)[1])
        ctx = remote.context(str(measure_input.target), 0)
        time_f = func.time_evaluator(
            func.entry_name, ctx, number=number, repeat=repeat)

        # set input
        if ref_input:
            args = [nd.array(x, ctx=ctx) for x in ref_input]
        else:
            args = [nd.empty(x[0], dtype=x[1], ctx=ctx) for x in build_result.arg_info]

        costs = time_f(*args).results
        if len(costs) > 2:  # remove largest and smallest value to reduce variance
            costs = list(costs)
            costs.sort()
            costs = tuple(costs[1:-1])

        # check correctness of output
        if ref_output:
            for expected, real in zip(ref_output, args):
                if not np.allclose(expected, real.asnumpy(), rtol=1e-4):
                    logger.warning("Wrong Answer!")
                    errno = MeasureErrorNo.WRONG_ANSWER
    except TVMError as exc:
        msg = str(exc)
        if "Stack trace returned" in msg:
            msg = msg[:msg.index("Stack trace returned")]
        if "CUDA Source" in msg:
            msg = msg[:msg.index("CUDA Source")]
        costs = (RuntimeError(msg[:1024]),)
        errno = MeasureErrorNo.RUNTIME_DEVICE
    tstamp = time.time()
    time.sleep(cooldown_interval)
    return MeasureResult(costs, errno, tstamp - tic + build_result.time_cost, tstamp)


def request_remote(device_key, host=None, port=None, priority=1, timeout=60):
    """Request a remote session

    Parameters
    ----------
    device_key: string
        The device key of registered device in tracker
    host: str, optional
        The host address of rpc tracker.
        If is none, will use environment variable "TVM_TRACKER_HOST"
    port: int, optional
        The port of rpc tracker.
        If is none, will use environment variable "TVM_TRACKER_PORT"
    priority: int, optional
        The priority of this request, larger is more prior
    timeout: float, optional
        The timeout of this session (units: second)

    Returns
    -------
    session: RPCSession
    """
    # connect to the tracker
    host = host or os.environ['TVM_TRACKER_HOST']
    port = port or int(os.environ['TVM_TRACKER_PORT'])

    tracker = _rpc.connect_tracker(host, port)
    remote = tracker.request(device_key, priority=priority,
                             session_timeout=timeout)
    return remote


def check_remote(target, device_key, host=None, port=None, priority=2, timeout=10):
    """Check the availability of a remote device

    Parameters
    ----------
    target: Target
        The wanted compilation target
    device_key: string
        device key of registered device in tracker
    host, port, priority: optional
        same as in request_remote
    timeout: float, optional
        The timeout of this check (units: seconds). If time is out, a RuntimeError
        will be raised.

    Returns
    -------
    available: bool
        True if can find available device
    """
    def _check():
        remote = request_remote(device_key, host, port, priority)
        remote.context(str(target))
    t = threading.Thread(target=_check,)
    t.start()
    t.join(timeout)
    return not t.is_alive()


 @register_func
 ...
@@ -496,6 +561,17 @@ def tvm_callback_cuda_compile(code):
     return ptx

+def set_cuda_target_arch(arch):
+    """set target architecture of nvcc compiler
+
+    Parameters
+    ----------
+    arch: str
+        The argument of nvcc -arch. (e.g. "sm_51", "sm_62")
+    """
+    AutotvmGlobalScope.current.cuda_target_arch = arch

 def gpu_verify_pass(**kwargs):
     """Verify the validity of a gpu kernel.
     This pass will check memory usage and number of threads per block.
 ...
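For remote targets, the same decoupled interface pairs LocalBuilder (compilation on the host) with RPCRunner (execution on a board requested from an RPC tracker). A hedged configuration sketch; the device key, tracker address, and the choice of the 'ndk' build function are placeholders, not values taken from this commit:

from tvm import autotvm

# cross-compile on the host, run on a remote device registered with an RPC tracker
measure_option = autotvm.measure_option(
    builder=autotvm.LocalBuilder(build_func='ndk'),   # 'ndk' exports a shared library for Android
    runner=autotvm.RPCRunner(
        'rk3399',                     # placeholder: device key registered in the tracker
        host='0.0.0.0', port=9190,    # placeholder: tracker address
        number=4, repeat=3,
        timeout=10))

# the resulting dict is consumed by create_measure_batch / Tuner.tune
# exactly as in the local-only configuration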
python/tvm/autotvm/tuner/ga_tuner.py

@@ -22,7 +22,7 @@ class GATuner(Tuner):
     mutation_prob: float
         probability of mutation of a knob in a gene
     """
-    def __init__(self, task, pop_size, elite_num=3, mutation_prob=0.1):
+    def __init__(self, task, pop_size=100, elite_num=3, mutation_prob=0.1):
         super(GATuner, self).__init__(task)

         # algorithm configurations
 ...
python/tvm/autotvm/tuner/sa_model_optimizer.py

@@ -87,7 +87,7 @@ class SimulatedAnnealingOptimizer(ModelOptimizer):
             new_scores = model.predict(new_points)

-            ac_prob = np.exp((new_scores - scores) / (t + 1e-2))
+            ac_prob = np.exp(np.minimum((new_scores - scores) / (t + 1e-5), 1))
             ac_index = np.random.random(len(ac_prob)) < ac_prob

             points[ac_index] = new_points[ac_index]
 ...
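The acceptance probability is compared against uniform random numbers, so any value at or above 1 is always accepted; clipping the exponent therefore does not change the acceptance decision but keeps np.exp away from overflow when the predicted improvement is large relative to the annealing temperature. A small numeric illustration with made-up scores (values are assumptions for the demo, not data from the commit):

import numpy as np

t = 1e-3                        # a low, late-stage annealing temperature
delta = np.array([0.5, -0.2])   # made-up score improvements (new_scores - scores)

old_prob = np.exp(delta / (t + 1e-2))                 # roughly [5e+19, 1e-08]; larger deltas or
                                                      # smaller temperatures overflow to inf
new_prob = np.exp(np.minimum(delta / (t + 1e-5), 1))  # roughly [2.72, 1e-86]; capped at e**1

accept = np.random.random(len(new_prob)) < new_prob   # same decision: big improvements always win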
tests/python/integration/test_tuning.py

@@ -103,34 +103,7 @@ def get_sample_task(target=tvm.target.cuda(), target_host=None):
         target=target, target_host=target_host)
     return task, target

-def test_task_tuner_without_measurement():
-    """test task and tuner without measurement"""
-    task, target = get_sample_task()
-
-    def custom_measure(input_pack, build_func, build_args, number, repeat,
-                       ref_input, ref_output):
-        from tvm.autotvm import MeasureResult
-
-        results = []
-        for inp in input_pack:
-            tic = time.time()
-            # do nothing
-            time.sleep(0.001)
-            results.append(MeasureResult([time.time() - tic], 0,
-                                         time.time() - tic, time.time()))
-        return results
-    measure_option = autotvm.measure_option(custom_measure)
-
-    logging.info("%s", task.config_space)
-
-    # new tuner and recorder
-    for tuner_class in [autotvm.tuner.RandomTuner, autotvm.tuner.GridSearchTuner]:
-        tuner = tuner_class(task)
-        tuner.tune(n_trial=10, measure_option=measure_option)
-        assert tuner.best_flops > 1
-
-def test_tuning_with_measure():
+def test_tuning():
     def check(target, target_host):
         ctx = tvm.context(target, 0)
         if not ctx.exist:
 ...
@@ -141,12 +114,12 @@ def test_tuning_with_measure():
         task, target = get_sample_task(target, target_host)
         logging.info("%s", task.config_space)

-        measure_option = autotvm.measure_option('local',
-                                                timeout=4,
-                                                number=2)
+        measure_option = autotvm.measure_option(
+            autotvm.LocalBuilder(),
+            autotvm.LocalRunner())

         tuner = RandomTuner(task)
-        tuner.tune(n_trial=10, measure_option=measure_option)
+        tuner.tune(n_trial=20, measure_option=measure_option)

     check("cuda", None)
     check("opencl", None)
 ...
@@ -155,6 +128,4 @@ if __name__ == "__main__":
     # only print log when invoked from main
     logging.basicConfig(level=logging.DEBUG)

-    test_task_tuner_without_measurement()
-    test_tuning_with_measure()
+    test_tuning()
tests/python/unittest/test_autotvm_common.py

@@ -32,6 +32,25 @@ def matmul(N, L, M, dtype):
     return s, [A, B, C]

+@autotvm.template
+def bad_matmul(N, L, M, dtype):
+    if 'bad_device' in tvm.target.current_target().keys:
+        A = tvm.placeholder((N, L), name='A', dtype=dtype)
+        B = tvm.placeholder((L, M), name='B', dtype=dtype)
+        k = tvm.reduce_axis((0, L-1), name='k')
+        C = tvm.compute((N, M), lambda i, j: tvm.sum(A[i, k] * B[k, j], axis=k), name='C')
+        s = tvm.create_schedule(C.op)
+
+        # schedule
+        y, x = s[C].op.axis
+        cfg = autotvm.get_config()
+        cfg.define_split("tile_y", y, num_outputs=2)
+        cfg.define_split("tile_x", x, num_outputs=2)
+        return s, [A, B, C]
+
+    return matmul(N, L, M, dtype)
+
 def get_sample_task(n=128):
     """return a sample task for testing"""
     target = tvm.target.create("llvm")
 ...
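bad_matmul deliberately declares a wrong reduction length, and only for targets that carry a 'bad_device' key, so a test can check that a broken config surfaces as an error code instead of crashing the tuning loop (the commit also adds tests/python/unittest/test_autotvm_measure.py, which is not shown on this page). A hedged sketch of such a check; the target string and test body here are illustrative assumptions, not the contents of that file:

import tvm
from tvm import autotvm
from test_autotvm_common import bad_matmul   # the template defined above

def check_bad_config_is_reported():
    # illustrative target string: it carries the 'bad_device' key so that
    # bad_matmul takes its deliberately broken branch
    target = tvm.target.create("llvm -device=bad_device")
    task = autotvm.task.create(bad_matmul, args=(16, 16, 16, 'float32'), target=target)

    measure_option = autotvm.measure_option(autotvm.LocalBuilder(),
                                            autotvm.LocalRunner(timeout=4))
    measure_batch = autotvm.measure.create_measure_batch(task, measure_option)

    inputs = [autotvm.MeasureInput(target, task, task.config_space.get(0))]
    results = measure_batch(inputs)

    # the wrong reduce axis should show up as a non-zero error_no, not a crash
    assert results[0].error_no != autotvm.MeasureErrorNo.NO_ERROR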
tests/python/unittest/test_autotvm_database.py  View file @ 12839e6d
 """Test database"""
 import copy
 import logging
-import time
-
-import numpy as np
-import tvm
-
-from tvm import autotvm
 from tvm.autotvm import database
-from tvm.autotvm.measure.measure_methods import HashMismatchError
-from tvm.autotvm.record import encode, MeasureInput, MeasureResult
+from tvm.autotvm.record import encode, MeasureResult

-from test_autotvm_common import get_sample_task, get_sample_records
+from test_autotvm_common import get_sample_records

 def test_save_load():
     logging.info("test basic db load/save ...")
...
@@ -35,66 +29,6 @@ def test_save_load():
 TRIAL_LIMIT = 2

-def test_db_filter():
-    logging.info("test db filter ...")
-
-    # Pick a GPU target because there are more likely to be failures/invalid configs
-    task, target = get_sample_task()
-
-    ctx = tvm.context(str(target))
-    if not ctx.exist:
-        logging.warning("Skip this test because there is no supported device for test")
-
-    batch_size = 2
-
-    measure_option = autotvm.measure_option('local', do_fork=False, timeout=2)
-    measure_batch = autotvm.measure.create_measure_batch(task, measure_option)
-
-    ct = 0
-    all_inputs = list()
-    all_results = list()
-    batches = list()
-    tuner = autotvm.tuner.RandomTuner(task)
-    while ct < TRIAL_LIMIT:
-        inputs = list()
-        for i in range(batch_size):
-            cfg = tuner.next_batch(1)[0]
-            inputs.append((MeasureInput(target, task, cfg)))
-            all_inputs.append(inputs[-1])
-        batches.append(inputs)
-        results = measure_batch(inputs)
-        all_results += results
-        ct += 1
-
-    del measure_batch
-
-    db = database.DummyDatabase()
-    db.flush()
-
-    # First setting, memoize one input at a time, check that each is saved and replayed
-    measure_option = autotvm.measure_option('local', do_fork=False, timeout=2, replay_db=db)
-    measure_batch = autotvm.measure.create_measure_batch(task, measure_option)
-
-    for i in range(len(all_inputs)+1):
-        db.flush()
-        for j in range(i):
-            db.save(all_inputs[j], all_results[j])
-
-        for k in range(len(batches)):
-            batch = batches[k]
-            batch_result = measure_batch(batch)
-            for l in range(batch_size):
-                all_idx = k * batch_size + l
-                assert batch_result[l] is not None
-                if all_idx < i:
-                    assert encode(batch[l], batch_result[l]) == encode(batch[l], all_results[all_idx]), \
-                        "(no retry) EXPECTED MATCH, GOT MISMATCH"
-                else:
-                    assert encode(batch[l], batch_result[l]) != encode(batch[l], all_results[all_idx]), \
-                        "(no retry) EXPECTED MISMATCH, GOT MATCH"
-
-    del measure_batch
-
 def test_db_hash():
     logging.info("test db hash check ...")
     inp1, res1 = get_sample_records(1)[0]
...
@@ -149,89 +83,8 @@ def test_db_latest_all():
     assert encode(inp1, load4[1]) == encode(inp1, res2)
     assert encode(inp1, load4[2]) == encode(inp1, res3)

-def test_db_save_replay():
-    logging.info("test db save (from measure_batch) and replay ...")
-    _db = database.DummyDatabase()
-    _db.flush()
-
-    task, target = get_sample_task()
-    ctx = tvm.context(str(target))
-    if not ctx.exist:
-        logging.warning("Skip this test because there is no supported device for test")
-
-    measure_option = autotvm.measure_option('local', do_fork=False, timeout=2, replay_db=_db)
-    measure_batch = autotvm.measure.create_measure_batch(task, measure_option)
-
-    batch_size = 2
-
-    ct = 0
-    all_inputs = list()
-    all_results = list()
-    batches = list()
-    tuner = autotvm.tuner.RandomTuner(task)
-    while ct < TRIAL_LIMIT:
-        inputs = list()
-        for i in range(batch_size):
-            cfg = tuner.next_batch(1)[0]
-            inputs.append((MeasureInput(target, task, cfg)))
-            all_inputs.append(inputs[-1])
-        batches.append(inputs)
-        results = measure_batch(inputs)
-        all_results += results
-        ct += 1
-
-    callback = autotvm.callback.log_to_database(_db)
-    callback(None, all_inputs, all_results)
-
-    assert len(_db.db.keys()) == batch_size * TRIAL_LIMIT, \
-        "%d vs %d" % (len(_db.db.keys()), batch_size * TRIAL_LIMIT)
-
-    all_results_2 = measure_batch(all_inputs)
-    all_results_3 = measure_batch(all_inputs)
-
-    for i in range(len(all_results)):
-        encr1 = encode(all_inputs[i], all_results[i])
-        encr2 = encode(all_inputs[i], all_results_2[i])
-        encr3 = encode(all_inputs[i], all_results_3[i])
-        assert encr1 == encr2, "EXPECTED MATCH WITH SAVE REPLAY (first replay), got MISMATCH"
-        assert encr2 == encr3, "EXPECTED MATCH WITH SAVE REPLAY (second replay), got MISMATCH"
-
-    del measure_batch
-
-def test_check_hashmismatch():
-    logging.info("test hash mismatch check")
-
-    task, target = get_sample_task()
-
-    ctx = tvm.context(str(target))
-    if not ctx.exist:
-        logging.warning("Skip this test because there is no supported device for test")
-
-    measure_option = autotvm.measure_option('local', do_fork=False)
-    measure_batch = autotvm.measure.create_measure_batch(task, measure_option)
-
-    inputs = list()
-    cfg = task.config_space.get(np.random.randint(len(task.config_space)))
-    # notvalidh is not a valid CRC32 hash (not hex)
-    cfg.code_hash = 'notvalidh'
-    inputs.append((MeasureInput(target, task, cfg)))
-
-    try:
-        results = measure_batch(inputs)
-        assert False, "HashMismatchError should be raised"
-    except HashMismatchError:
-        pass
-
-    del measure_batch
-
 if __name__ == '__main__':
     logging.basicConfig(level=logging.INFO)

     test_save_load()
-    test_db_filter()
     test_db_hash()
     test_db_latest_all()
-    test_db_save_replay()
-    test_check_hashmismatch()
tests/python/unittest/test_autotvm_measure.py  0 → 100644  View file @ 12839e6d
"""Test builder and runner"""
import logging
import time

import numpy as np

import tvm
from tvm import autotvm
from test_autotvm_common import get_sample_task, bad_matmul
from tvm.autotvm.measure.measure import Runner, MeasureResult, MeasureErrorNo


def test_task_tuner_without_measurement():
    """test task and tuner without measurement"""
    task, target = get_sample_task()

    class DummyRunner(Runner):
        def __init__(self):
            super(DummyRunner, self).__init__(1, 1)

        def run(self, measure_inputs, build_results):
            return [MeasureResult((np.random.random(),), 0, 0.2, time.time())
                    for _ in range(len(measure_inputs))]

        def get_build_kwargs(self):
            return {}

    measure_option = autotvm.measure_option(
        builder=autotvm.LocalBuilder(),
        runner=DummyRunner()
    )

    logging.info("%s", task.config_space)

    for tuner_class in [autotvm.tuner.RandomTuner,
                        autotvm.tuner.GridSearchTuner,
                        autotvm.tuner.GATuner,
                        autotvm.tuner.XGBTuner]:
        tuner = tuner_class(task)
        tuner.tune(n_trial=10, measure_option=measure_option)
        assert tuner.best_flops > 1


def test_check_correctness():
    task, target = get_sample_task()

    measure_option = autotvm.measure_option(
        builder=autotvm.LocalBuilder(),
        runner=autotvm.LocalRunner(check_correctness=True)
    )

    def _callback_correct(tuner, measure_inputs, measure_results):
        for inp, res in zip(measure_inputs, measure_results):
            assert res.error_no == 0

    tuner = autotvm.tuner.RandomTuner(task)
    tuner.tune(n_trial=2, measure_option=measure_option,
               callbacks=[_callback_correct])

    # a bad template
    n = 128
    target = tvm.target.create("llvm -device=bad_device")
    task = autotvm.task.create(bad_matmul, args=(n, n, n, 'float32'), target=target)

    def _callback_wrong(tuner, measure_inputs, measure_results):
        for inp, res in zip(measure_inputs, measure_results):
            assert res.error_no == MeasureErrorNo.WRONG_ANSWER

    tuner = autotvm.tuner.RandomTuner(task)
    tuner.tune(n_trial=2, measure_option=measure_option,
               callbacks=[_callback_wrong])


def test_min_repeat_ms():
    task, target = get_sample_task()

    measure_option = autotvm.measure_option(
        builder=autotvm.LocalBuilder(),
        runner=autotvm.LocalRunner(number=1, min_repeat_ms=100)
    )

    def _callback(tuner, measure_inputs, measure_results):
        for inp, res in zip(measure_inputs, measure_results):
            if res.error_no != 0:
                continue

            assert 1000 * np.mean(res.costs) * \
                   measure_option['runner'].cur_number >= 100

    tuner = autotvm.tuner.RandomTuner(task)
    tuner.tune(n_trial=5, measure_option=measure_option,
               callbacks=[_callback])


if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)

    test_task_tuner_without_measurement()
    test_check_correctness()
    test_min_repeat_ms()
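The DummyRunner above fabricates MeasureResult tuples, which also documents their shape: (costs, error_no, all_cost, timestamp). A hedged sketch of a tuning callback that turns such results into GFLOPS, assuming the standard `task.flop` field on the measured input; the callback name is illustrative:

    import numpy as np

    def print_gflops(tuner, measure_inputs, measure_results):
        # costs is a tuple of per-run times in seconds; error_no == 0 means success
        for inp, res in zip(measure_inputs, measure_results):
            if res.error_no == 0:
                gflops = inp.task.flop / np.mean(res.costs) / 1e9
                print("%s -> %.2f GFLOPS" % (inp.config, gflops))

    # usage: tuner.tune(n_trial=10, measure_option=measure_option, callbacks=[print_gflops])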
topi/recipe/gemm/gemm_int8.py  View file @ 12839e6d
...
@@ -137,12 +137,15 @@ if __name__ == '__main__':
     print(task.config_space)

-    measure_option = autotvm.measure_option(
-        measure_func='local', number=10, n_parallel=8, timeout=20)
+    measure_option = autotvm.measure_option(
+        builder=autotvm.LocalBuilder(),
+        runner=autotvm.LocalRunner(repeat=3, min_repeat_ms=100, timeout=4)
+    )

     log_name = 'gemm_int8.log'
     if DO_TUNING:
         tuner = autotvm.tuner.XGBTuner(task)
         tuner.tune(n_trial=1000, measure_option=measure_option,
                    callbacks=[autotvm.callback.log_to_file(log_name)])

         dispatch_context = autotvm.apply_history_best(log_name)
         best_config = dispatch_context.query(task.target, task.workload)
...
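After tuning, the recipe queries the best configuration through apply_history_best. The same log can also be scanned directly with the record reader; a sketch, assuming `gemm_int8.log` has been written by the tuner above:

    import numpy as np
    from tvm import autotvm

    best_cfg, best_cost = None, float("inf")
    for inp, res in autotvm.record.load_from_file('gemm_int8.log'):
        if res.error_no != 0:        # skip failed or invalid measurements
            continue
        cost = np.mean(res.costs)
        if cost < best_cost:
            best_cfg, best_cost = inp.config, cost
    print(best_cfg, best_cost)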
tutorials/autotvm/tune_conv2d_cuda.py  View file @ 12839e6d
...
@@ -164,12 +164,12 @@ task = autotvm.task.create(conv2d_no_batching,
                            target='cuda')
 print(task.config_space)

-# use local gpu, measure 5 times for every config to reduce variance
-# run 8 parallel threads for compilation
-measure_option = autotvm.measure_option('local',
-                                        number=5,
-                                        n_parallel=8,
-                                        timeout=20)
+# use local gpu, measure 10 times for every config to reduce variance
+# The timeout of compiling a program is 10 seconds, the timeout for running is 4 seconds
+measure_option = autotvm.measure_option(
+    builder=autotvm.LocalBuilder(),
+    runner=autotvm.LocalRunner(repeat=3, min_repeat_ms=100, timeout=4)
+)

 # begin tuning, log records to file `conv2d.log`
 tuner = autotvm.tuner.XGBTuner(task)
...
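Once `conv2d.log` has enough records, the tutorial applies them when instantiating the schedule. A condensed sketch of that later step, reusing the tutorial's `conv2d_no_batching` template; the shape variables (N, H, W, CO, CI, KH, KW, strides, padding) are assumed to be the ones used when the task was created:

    import tvm
    from tvm import autotvm

    # apply the best schedule found during tuning when building the kernel
    with autotvm.apply_history_best('conv2d.log'):
        with tvm.target.create("cuda"):
            s, arg_bufs = conv2d_no_batching(N, H, W, CO, CI, KH, KW, strides, padding)
            func = tvm.build(s, arg_bufs)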
tutorials/autotvm/tune_nnvm_arm.py  View file @ 12839e6d
...
@@ -65,15 +65,20 @@ def get_network(name, batch_size):
     input_shape = (batch_size, 3, 224, 224)
     output_shape = (batch_size, 1000)

-    if name == 'resnet-18':
-        net, params = nnvm.testing.resnet.get_workload(num_layers=18, batch_size=batch_size)
-    elif name == 'mobilenet':
-        net, params = nnvm.testing.mobilenet.get_workload(batch_size=batch_size)
-    elif name == 'squeezenet v1.1':
-        net, params = nnvm.testing.squeezenet.get_workload(batch_size=batch_size, version='1.1')
-    elif name == 'vgg-16':
-        net, params = nnvm.testing.vgg.get_workload(num_layers=16, batch_size=batch_size)
+    if "resnet" in name:
+        n_layer = int(name.split('-')[1])
+        net, params = nnvm.testing.resnet.get_workload(num_layers=n_layer, batch_size=batch_size)
+    elif "vgg" in name:
+        n_layer = int(name.split('-')[1])
+        net, params = nnvm.testing.vgg.get_workload(num_layers=n_layer, batch_size=batch_size)
+    elif name == 'mobilenet':
+        net, params = nnvm.testing.mobilenet.get_workload(batch_size=batch_size)
+    elif name == 'squeezenet_v1.1':
+        net, params = nnvm.testing.squeezenet.get_workload(batch_size=batch_size, version='1.1')
+    elif name == 'inception_v3':
+        input_shape = (1, 3, 299, 299)
+        net, params = nnvm.testing.inception_v3.get_workload(batch_size=batch_size)
     elif name == 'custom':
         # an example for custom network
         from nnvm.testing import utils
         net = nnvm.sym.Variable('data')
...
@@ -92,6 +97,7 @@ def get_network(name, batch_size):
     return net, params, input_shape, output_shape

 #################################################################
 # Start RPC Tracker
 # -----------------
...
@@ -158,6 +164,8 @@ def get_network(name, batch_size):
 #    rk3399       2      2     0
 #    rpi3b        11     11    0
 #    ----------------------------------
+#
+# You can register multiple devices to the tracker to accelerate the measurement in tuning.

 ###########################################
 # Set Tuning Options
...
@@ -184,34 +192,30 @@ log_file = "%s.%s.log" % (device_key, network)
 dtype = 'float32'

 tuning_option = {
     'log_filename': log_file,

     'tuner': 'xgb',
     'n_trial': 1000,
-    'early_stopping': 250,
+    'early_stopping': 400,

     'measure_option': autotvm.measure_option(
-        autotvm.measure.rpc(device_key, host='localhost', port=9190),
-        number=4,
-        n_parallel=1,
-        timeout=10,
-        build_func='ndk' if use_android else 'default',
-    ),
+        builder=autotvm.LocalBuilder(
+            build_func='ndk' if use_android else 'default'),
+        runner=autotvm.RPCRunner(
+            device_key, host='localhost', port=9190,
+            number=5,
+            timeout=4,
+        ),
+    ),
 }

 ####################################################################
 #
 # .. note:: How to set tuning options
 #
-#   In general, the default value provided here works well. It is the same
-#   value that we used to generate pre-tuned parameters.
+#   In general, the default value provided here works well.
+#   If you have multiple devices, you can set :code:`n_parallel` to
+#   the number of devices you have. (e.g. set it to 3 if you register 3 rk3399
+#   boards to the tracker).
 #   If you have large time budget, you can set :code:`n_trial`, :code:`early_stopping` larger,
 #   which makes the tuning run longer.
+#   If your device is very slow or a single conv2d operator in your network has large FLOPs,
+#   consider setting timeout larger.
 #

 ###################################################################
...
@@ -219,7 +223,7 @@ tuning_option = {
 # ------------
 # Now we can extract tuning tasks from the network and begin tuning.
 # Here we provide a simple utility function to tune a list of tasks.
-# This function is just an initial implementation which tune them in sequential order.
+# This function is just an initial implementation which tunes them in sequential order.
 # Later we will bring more sophisticated tuner scheduler.
 # You can skip the implementation of this function for this tutorial.
...
@@ -236,7 +240,9 @@ def tune_tasks(tasks,
         try:  # try winograd template
             tsk = autotvm.task.create(tasks[i].name, tasks[i].args,
                                       tasks[i].target, tasks[i].target_host, 'winograd')
-            tasks.append(tsk)
+            input_channel = tsk.workload[1][1]
+            if input_channel >= 64:
+                tasks[i] = tsk
         except Exception:
             pass
...
@@ -245,8 +251,8 @@ def tune_tasks(tasks,
     if os.path.exists(tmp_log_file):
         os.remove(tmp_log_file)

-    for i, tsk in enumerate(tasks):
+    for i, tsk in enumerate(reversed(tasks)):
         prefix = "[Task %2d/%2d] " % (i+1, len(tasks))

         # create tuner
         if tuner == 'xgb' or tuner == 'xgb-rank':
...
@@ -280,7 +286,7 @@ def tune_tasks(tasks,
 ########################################################################
 # Finally we launch tuning jobs and evaluate the end-to-end performance.

-def tune_and_evaluate():
+def tune_and_evaluate(tuning_opt):
     # extract workloads from nnvm graph
     print("Extract tasks...")
     net, params, input_shape, out_shape = get_network(network, batch_size=1)
...
@@ -290,19 +296,18 @@ def tune_and_evaluate():
     # run tuning tasks
     print("Tuning...")
-    tune_tasks(tasks, **tuning_option)
+    tune_tasks(tasks, **tuning_opt)

     # compile kernels with history best records
     with autotvm.apply_history_best(log_file):
         print("Compile...")
         with nnvm.compiler.build_config(opt_level=2, add_pass=['AlterOpLayout']):
             graph, lib, params = nnvm.compiler.build(
                 net, target=target,
                 shape={'data': input_shape}, params=params, dtype=dtype)

         # export library
         tmp = tempdir()
-        if tuning_option['measure_option']['build_func'] == 'ndk':  # for android
+        if use_android:
             from tvm.contrib import ndk
             filename = "net.so"
             lib.export_library(tmp.relpath(filename), ndk.create_shared)
...
@@ -312,8 +317,7 @@ def tune_and_evaluate():
     # upload module to device
     print("Upload...")
-    remote = autotvm.measure.request_remote(device_key,
-                                            tracker_addr=('localhost', 9190),
-                                            timeout=10000)
+    remote = autotvm.measure.request_remote(device_key, 'localhost', 9190,
+                                            timeout=10000)
     remote.upload(tmp.relpath(filename))
     rlib = remote.load_module(filename)
...
@@ -328,47 +332,44 @@ def tune_and_evaluate():
     # evaluate
     print("Evaluate inference time cost...")
-    ftimer = module.module.time_evaluator("run", ctx, number=1, repeat=10)
+    ftimer = module.module.time_evaluator("run", ctx, number=8, repeat=3)
     prof_res = np.array(ftimer().results) * 1000  # convert to millisecond
     print("Mean inference time (std dev): %.2f ms (%.2f ms)" %
           (np.mean(prof_res), np.std(prof_res)))

 # We do not run the tuning in our webpage server since it takes too long.
 # Uncomment the following line to run by yourself.

-# tune_and_evaluate()
+# tune_and_evaluate(tuning_option)

 ######################################################################
 # Sample Output
 # -------------
-# The tuning needs to train xgboost models and use them for prediction.
+# The tuning needs to compile many programs and extract feature from them.
 # So a high performance CPU is recommended.
-# It takes about 2 hours on a 32T AMD Ryzen CPU.
-# One sample output is
+# One sample output is listed below.
+# It takes about 2 hours on a 32T AMD Ryzen Threadripper.
 #
 # .. code-block:: bash
 #
 #    Extract tasks...
 #    Tuning...
-#    [Task  1/16]  Current/Best:  18.85/ 19.67 GFLOPS | Progress: (353/1000) | 387.05 s Done.
-#    [Task  2/16]  Current/Best:  16.10/ 23.50 GFLOPS | Progress: (444/1000) | 379.99 s Done.
-#    [Task  3/16]  Current/Best:   5.49/ 13.96 GFLOPS | Progress: (610/1000) | 485.87 s Done.
-#    [Task  4/16]  Current/Best:  10.07/ 20.48 GFLOPS | Progress: (430/1000) | 391.66 s Done.
-#    [Task  5/16]  Current/Best:  11.50/ 15.50 GFLOPS | Progress: (374/1000) | 356.03 s Done.
-#    [Task  6/16]  Current/Best:  10.76/ 23.77 GFLOPS | Progress: (526/1000) | 526.42 s Done.
-#    [Task  7/16]  Current/Best:  12.71/ 22.03 GFLOPS | Progress: (341/1000) | 322.96 s Done.
-#    [Task  8/16]  Current/Best:   8.60/ 17.91 GFLOPS | Progress: (272/1000) | 236.08 s Done.
-#    [Task  9/16]  Current/Best:  15.37/ 23.62 GFLOPS | Progress: (275/1000) | 275.18 s Done.
-#    [Task 10/16]  Current/Best:   6.62/ 23.01 GFLOPS | Progress: (330/1000) | 315.02 s Done.
-#    [Task 11/16]  Current/Best:   1.85/ 21.39 GFLOPS | Progress: (281/1000) | 239.19 s Done.
-#    [Task 12/16]  Current/Best:  15.41/ 24.02 GFLOPS | Progress: (258/1000) | 270.82 s Done.
-#    [Task 13/16]  Current/Best:  17.96/ 25.79 GFLOPS | Progress: (380/1000) | 738.29 s Done.
-#    [Task 14/16]  Current/Best:  14.81/ 31.17 GFLOPS | Progress: (413/1000) | 799.21 s Done.
-#    [Task 15/16]  Current/Best:  24.39/ 40.97 GFLOPS | Progress: (355/1000) | 700.25 s Done.
-#    [Task 16/16]  Current/Best:   9.42/ 49.90 GFLOPS | Progress: (348/1000) | 603.84 s Done.
+#    [Task  1/12]  Current/Best:  22.37/ 52.19 GFLOPS | Progress: (544/1000) | 406.59 s Done.
+#    [Task  2/12]  Current/Best:   6.51/ 18.77 GFLOPS | Progress: (608/1000) | 325.05 s Done.
+#    [Task  3/12]  Current/Best:   4.67/ 24.87 GFLOPS | Progress: (480/1000) | 372.31 s Done.
+#    [Task  4/12]  Current/Best:  11.35/ 46.83 GFLOPS | Progress: (736/1000) | 602.39 s Done.
+#    [Task  5/12]  Current/Best:   1.01/ 19.80 GFLOPS | Progress: (448/1000) | 262.16 s Done.
+#    [Task  6/12]  Current/Best:   2.47/ 23.76 GFLOPS | Progress: (672/1000) | 563.85 s Done.
+#    [Task  7/12]  Current/Best:  14.57/ 33.97 GFLOPS | Progress: (544/1000) | 465.15 s Done.
+#    [Task  8/12]  Current/Best:   1.13/ 17.65 GFLOPS | Progress: (576/1000) | 365.08 s Done.
+#    [Task  9/12]  Current/Best:  14.45/ 22.66 GFLOPS | Progress: (928/1000) | 724.25 s Done.
+#    [Task 10/12]  Current/Best:   3.22/ 15.36 GFLOPS | Progress: (864/1000) | 564.27 s Done.
+#    [Task 11/12]  Current/Best:  11.03/ 32.23 GFLOPS | Progress: (736/1000) | 635.15 s Done.
+#    [Task 12/12]  Current/Best:   8.00/ 21.65 GFLOPS | Progress: (1000/1000) | 1111.81 s Done.
 #    Compile...
 #    Upload...
 #    Evaluate inference time cost...
-#    Mean inference time (std dev): 157.29 ms (1.74 ms)
+#    Mean inference time (std dev): 162.59 ms (0.06 ms)
 ######################################################################
 #
...
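The switch to RPCRunner assumes an RPC tracker on localhost:9190 with boards registered under device_key. A sketch of sanity-checking that setup from Python before starting a long tuning run; the key 'rk3399' and the ports are the tutorial's defaults, and the connectivity check is only an illustrative assumption:

    from tvm import autotvm

    # request a session from the tracker (blocks until a device with this key
    # is available or the timeout expires)
    remote = autotvm.measure.request_remote('rk3399', 'localhost', 9190, timeout=10)
    print("device reachable:", remote.cpu(0).exist)

    measure_option = autotvm.measure_option(
        builder=autotvm.LocalBuilder(build_func='default'),
        runner=autotvm.RPCRunner('rk3399', host='localhost', port=9190,
                                 number=5, timeout=4))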
tutorials/autotvm/tune_simple_template.py  View file @ 12839e6d
...
@@ -271,9 +271,12 @@ print(task.config_space)
 logging.getLogger('autotvm').setLevel(logging.DEBUG)
 logging.getLogger('autotvm').addHandler(logging.StreamHandler(sys.stdout))

-# use local cpu, measure 5 times for every config to reduce variance
-measure_option = autotvm.measure_option('local',
-                                        number=5)
+# There are two steps for measuring a config: build and run.
+# By default, we use all cpu cores to compile program. Then measure them sequentially.
+# We measure 5 times and take average to reduce variance.
+measure_option = autotvm.measure_option(
+    builder='local',
+    runner=autotvm.LocalRunner(number=5))

 # begin tuning, log records to file `matmul.log`
 tuner = autotvm.tuner.RandomTuner(task)
...
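After the run, the tutorial applies the best record in `matmul.log` to compile and verify the kernel; a condensed sketch of that step, where the `matmul` template and the N, L, M shapes come from earlier in the tutorial:

    import tvm
    from tvm import autotvm

    # apply history best from the log file written by the tuner above
    with autotvm.apply_history_best('matmul.log'):
        with tvm.target.create("llvm"):
            s, arg_bufs = matmul(N, L, M, 'float32')
            func = tvm.build(s, arg_bufs)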