Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
70e11d32
Unverified
Commit
70e11d32
authored
Mar 12, 2020
by
Haichen Shen
Committed by
GitHub
Mar 12, 2020
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[Autotvm] Fix autotvm customized template (#5034)
* init * fix template * tweak naming
parent
681df4fc
Hide whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
94 additions
and
54 deletions
+94
-54
python/tvm/autotvm/__init__.py
+1
-1
python/tvm/autotvm/graph_tuner/base_graph_tuner.py
+1
-1
python/tvm/autotvm/task/__init__.py
+1
-2
python/tvm/autotvm/task/task.py
+77
-36
python/tvm/autotvm/task/topi_integration.py
+4
-4
tests/python/integration/test_tuning.py
+1
-1
tests/python/unittest/test_autotvm_common.py
+2
-2
tests/python/unittest/test_autotvm_dispatch_context.py
+1
-1
tutorials/autotvm/tune_conv2d_cuda.py
+1
-1
tutorials/autotvm/tune_simple_template.py
+2
-2
tutorials/optimize/opt_matmul_auto_tensorcore.py
+2
-2
vta/tutorials/autotvm/tune_relay_vta.py
+1
-1
No files found.
python/tvm/autotvm/__init__.py
View file @
70e11d32
...
@@ -42,7 +42,7 @@ from .measure import measure_option, MeasureInput, MeasureResult, MeasureErrorNo
...
@@ -42,7 +42,7 @@ from .measure import measure_option, MeasureInput, MeasureResult, MeasureErrorNo
LocalBuilder
,
LocalRunner
,
RPCRunner
LocalBuilder
,
LocalRunner
,
RPCRunner
from
.tuner
import
callback
from
.tuner
import
callback
from
.task
import
get_config
,
create
,
ConfigSpace
,
ConfigEntity
,
\
from
.task
import
get_config
,
create
,
ConfigSpace
,
ConfigEntity
,
\
register_topi_compute
,
register_topi_schedule
,
register_customized_task
,
\
register_topi_compute
,
register_topi_schedule
,
template
,
\
DispatchContext
,
FallbackContext
,
ApplyHistoryBest
as
apply_history_best
,
\
DispatchContext
,
FallbackContext
,
ApplyHistoryBest
as
apply_history_best
,
\
ApplyGraphBest
as
apply_graph_best
ApplyGraphBest
as
apply_graph_best
from
.env
import
GLOBAL_SCOPE
from
.env
import
GLOBAL_SCOPE
python/tvm/autotvm/graph_tuner/base_graph_tuner.py
View file @
70e11d32
...
@@ -42,7 +42,7 @@ def get_infer_layout(task_name):
...
@@ -42,7 +42,7 @@ def get_infer_layout(task_name):
return
topi
.
nn
.
depthwise_conv2d_infer_layout
return
topi
.
nn
.
depthwise_conv2d_infer_layout
raise
ValueError
(
"Cannot find infer layout for task
%
s"
%
task_name
)
raise
ValueError
(
"Cannot find infer layout for task
%
s"
%
task_name
)
@autotvm.
register_customized_task
(
"layout_transform"
)
@autotvm.
template
(
"layout_transform"
)
def
layout_transform
(
*
args
):
def
layout_transform
(
*
args
):
"""Autotvm layout transform template."""
"""Autotvm layout transform template."""
cfg
=
get_config
()
cfg
=
get_config
()
...
...
python/tvm/autotvm/task/__init__.py
View file @
70e11d32
...
@@ -22,8 +22,7 @@ This module defines the task data structure, as well as a collection(zoo)
...
@@ -22,8 +22,7 @@ This module defines the task data structure, as well as a collection(zoo)
of typical tasks of interest.
of typical tasks of interest.
"""
"""
from
.task
import
Task
,
create
,
get_config
,
args_to_workload
,
\
from
.task
import
Task
,
create
,
get_config
,
args_to_workload
,
template
register_customized_task
from
.space
import
ConfigSpace
,
ConfigEntity
from
.space
import
ConfigSpace
,
ConfigEntity
from
.code_hash
import
attach_code_hash
,
attach_code_hash_to_arg
from
.code_hash
import
attach_code_hash
,
attach_code_hash_to_arg
from
.dispatcher
import
DispatchContext
,
ApplyConfig
,
ApplyHistoryBest
,
\
from
.dispatcher
import
DispatchContext
,
ApplyConfig
,
ApplyHistoryBest
,
\
...
...
python/tvm/autotvm/task/task.py
View file @
70e11d32
...
@@ -186,25 +186,35 @@ class Task(object):
...
@@ -186,25 +186,35 @@ class Task(object):
TASK_TABLE
=
{}
TASK_TABLE
=
{}
class
TopiTemplate
(
object
):
class
TaskTemplate
(
object
):
"""Topi template that holds the topi compute and schedule function"""
"""
Task template is used to creates a tunable AutoTVM task.
It can be defined by a pair of compute and schedule function using
`_register_task_compute` and `_register_task_schedule`,
or by a customized task creation function that is more flexible using
`_register_customized_task`.
Note that when customized func is registered, compute and schedule function
will be ignored
"""
def
__init__
(
self
):
def
__init__
(
self
):
self
.
compute
=
None
self
.
f
compute
=
None
self
.
schedule
=
None
self
.
f
schedule
=
None
self
.
customized_func
=
None
self
.
fcustomized
=
None
def
__call__
(
self
,
*
args
,
**
kwargs
):
def
__call__
(
self
,
*
args
,
**
kwargs
):
args
=
deserialize_args
(
args
)
args
=
deserialize_args
(
args
)
if
self
.
customized_func
is
None
:
if
self
.
fcustomized
is
None
:
return
self
.
_default_func
(
*
args
,
**
kwargs
)
return
self
.
_default_func
(
*
args
,
**
kwargs
)
assert
callable
(
self
.
customized_func
)
assert
callable
(
self
.
fcustomized
)
return
self
.
customized_func
(
*
args
,
**
kwargs
)
return
self
.
fcustomized
(
*
args
,
**
kwargs
)
def
_default_func
(
self
,
*
args
,
**
kwargs
):
def
_default_func
(
self
,
*
args
,
**
kwargs
):
assert
callable
(
self
.
compute
)
and
callable
(
self
.
schedule
)
assert
callable
(
self
.
fcompute
)
and
callable
(
self
.
f
schedule
)
out
=
self
.
compute
(
*
args
,
**
kwargs
)
out
=
self
.
f
compute
(
*
args
,
**
kwargs
)
arg_bufs
=
[
out
]
+
self
.
get_inputs
(
out
)
arg_bufs
=
[
out
]
+
self
.
get_inputs
(
out
)
s
=
self
.
schedule
([
out
])
s
=
self
.
f
schedule
([
out
])
return
s
,
arg_bufs
return
s
,
arg_bufs
def
get_inputs
(
self
,
out
):
def
get_inputs
(
self
,
out
):
...
@@ -218,7 +228,7 @@ class TopiTemplate(object):
...
@@ -218,7 +228,7 @@ class TopiTemplate(object):
queue
.
extend
(
t
.
op
.
input_tensors
)
queue
.
extend
(
t
.
op
.
input_tensors
)
return
inputs
return
inputs
def
register_task_compute
(
name
,
func
=
None
):
def
_
register_task_compute
(
name
,
func
=
None
):
"""Register compute function to autotvm task
"""Register compute function to autotvm task
Parameters
Parameters
...
@@ -237,17 +247,17 @@ def register_task_compute(name, func=None):
...
@@ -237,17 +247,17 @@ def register_task_compute(name, func=None):
"""
"""
def
_do_reg
(
f
):
def
_do_reg
(
f
):
if
name
not
in
TASK_TABLE
:
if
name
not
in
TASK_TABLE
:
TASK_TABLE
[
name
]
=
T
opi
Template
()
TASK_TABLE
[
name
]
=
T
ask
Template
()
tmpl
=
TASK_TABLE
[
name
]
tmpl
=
TASK_TABLE
[
name
]
if
tmpl
.
compute
is
not
None
:
if
tmpl
.
f
compute
is
not
None
:
raise
ValueError
(
"Compute is already registered in autoTVM task
%
s"
%
name
)
raise
ValueError
(
"Compute is already registered in autoTVM task
%
s"
%
name
)
tmpl
.
compute
=
f
tmpl
.
f
compute
=
f
return
f
return
f
if
func
:
if
func
:
return
_do_reg
(
func
)
return
_do_reg
(
func
)
return
_do_reg
return
_do_reg
def
register_task_schedule
(
name
,
func
=
None
):
def
_
register_task_schedule
(
name
,
func
=
None
):
"""Register schedule function to autotvm task
"""Register schedule function to autotvm task
Parameters
Parameters
...
@@ -266,24 +276,19 @@ def register_task_schedule(name, func=None):
...
@@ -266,24 +276,19 @@ def register_task_schedule(name, func=None):
"""
"""
def
_do_reg
(
f
):
def
_do_reg
(
f
):
if
name
not
in
TASK_TABLE
:
if
name
not
in
TASK_TABLE
:
TASK_TABLE
[
name
]
=
T
opi
Template
()
TASK_TABLE
[
name
]
=
T
ask
Template
()
tmpl
=
TASK_TABLE
[
name
]
tmpl
=
TASK_TABLE
[
name
]
if
tmpl
.
schedule
is
not
None
:
if
tmpl
.
f
schedule
is
not
None
:
raise
ValueError
(
"Schedule is already registered in autoTVM task
%
s"
%
name
)
raise
ValueError
(
"Schedule is already registered in autoTVM task
%
s"
%
name
)
tmpl
.
schedule
=
f
tmpl
.
f
schedule
=
f
return
f
return
f
if
func
:
if
func
:
return
_do_reg
(
func
)
return
_do_reg
(
func
)
return
_do_reg
return
_do_reg
def
register_customized_task
(
name
,
func
=
None
):
def
_
register_customized_task
(
name
,
func
=
None
):
"""Register a customized function to AutoTVM task.
"""Register a customized function to AutoTVM task.
In most cases, you can just use register_topi_compute and register_topi_schedule
with the same task name to define an AutoTVM task. However, you can also
create a customized AutoTVM task that defines a tunable template or performs
extra layout transform before invoking compute/schedule function.
Parameters
Parameters
----------
----------
name: str
name: str
...
@@ -297,6 +302,37 @@ def register_customized_task(name, func=None):
...
@@ -297,6 +302,37 @@ def register_customized_task(name, func=None):
-------
-------
decorator: callable
decorator: callable
A decorator
A decorator
"""
def
_do_reg
(
f
):
if
name
not
in
TASK_TABLE
:
TASK_TABLE
[
name
]
=
TaskTemplate
()
tmpl
=
TASK_TABLE
[
name
]
if
tmpl
.
fcustomized
is
not
None
:
raise
ValueError
(
"Customized func is already registered in autoTVM task
%
s"
%
name
)
tmpl
.
fcustomized
=
f
return
f
if
func
:
return
_do_reg
(
func
)
return
_do_reg
def
template
(
task_name
,
func
=
None
):
"""Decorate a function as a tunable schedule template.
Parameters
----------
task_name: str
The task name
func: None or callable
A callable template function.
If it is None, return a decorator.
If is callable, decorate this function.
Returns
-------
func: callable
The decorated function
Examples
Examples
--------
--------
...
@@ -304,7 +340,7 @@ def register_customized_task(name, func=None):
...
@@ -304,7 +340,7 @@ def register_customized_task(name, func=None):
.. code-block:: python
.. code-block:: python
@autotvm.
register_customized_task
("matmul")
@autotvm.
template
("matmul")
def matmul(N, L, M, dtype):
def matmul(N, L, M, dtype):
A = te.placeholder((N, L), name='A', dtype=dtype)
A = te.placeholder((N, L), name='A', dtype=dtype)
B = te.placeholder((L, M), name='B', dtype=dtype)
B = te.placeholder((L, M), name='B', dtype=dtype)
...
@@ -331,17 +367,22 @@ def register_customized_task(name, func=None):
...
@@ -331,17 +367,22 @@ def register_customized_task(name, func=None):
return s, [A, B, C]
return s, [A, B, C]
"""
"""
def
_do_reg
(
f
):
def
_decorate
(
f
):
if
name
not
in
TASK_TABLE
:
def
wrapper
(
*
args
,
**
kwargs
):
TASK_TABLE
[
name
]
=
TopiTemplate
()
assert
not
kwargs
,
"Do not support kwargs in template function call"
tmpl
=
TASK_TABLE
[
name
]
workload
=
args_to_workload
(
args
,
task_name
)
if
tmpl
.
customized_func
is
not
None
:
tgt
=
_target
.
Target
.
current
()
raise
ValueError
(
"Customized func is already registered in autoTVM task
%
s"
%
name
)
cfg
=
DispatchContext
.
current
.
query
(
tgt
,
workload
)
tmpl
.
customized_func
=
f
with
ApplyConfig
(
cfg
):
return
f
return
f
(
*
args
,
**
kwargs
)
_register_customized_task
(
task_name
,
f
)
return
wrapper
if
func
:
if
func
:
return
_do_reg
(
func
)
return
_decorate
(
func
)
return
_do_reg
return
_decorate
def
create
(
task_name
,
args
,
target
,
target_host
=
None
):
def
create
(
task_name
,
args
,
target
,
target_host
=
None
):
"""Create a tuning task and initialize its search space
"""Create a tuning task and initialize its search space
...
...
python/tvm/autotvm/task/topi_integration.py
View file @
70e11d32
...
@@ -30,8 +30,8 @@ import tvm.te._ffi_api
...
@@ -30,8 +30,8 @@ import tvm.te._ffi_api
from
tvm
import
target
as
_target
from
tvm
import
target
as
_target
from
tvm.te
import
tensor
from
tvm.te
import
tensor
from
.task
import
args_to_workload
,
DispatchContext
,
\
from
.task
import
args_to_workload
,
serialize_args
,
DispatchContext
,
\
register_task_compute
,
register_task_schedule
,
serialize_args
_register_task_compute
,
_register_task_schedule
# Task extractor for relay program
# Task extractor for relay program
...
@@ -142,7 +142,7 @@ def register_topi_compute(task_name, func=None):
...
@@ -142,7 +142,7 @@ def register_topi_compute(task_name, func=None):
See tvm/topi/python/topi/arm_cpu/depthwise_conv2d.py for example usage.
See tvm/topi/python/topi/arm_cpu/depthwise_conv2d.py for example usage.
"""
"""
def
_decorate
(
topi_compute
):
def
_decorate
(
topi_compute
):
@register_task_compute
(
task_name
)
@
_
register_task_compute
(
task_name
)
def
wrapper
(
*
args
,
**
kwargs
):
def
wrapper
(
*
args
,
**
kwargs
):
"""wrapper function for topi compute"""
"""wrapper function for topi compute"""
assert
not
kwargs
,
"Do not support kwargs in template function call"
assert
not
kwargs
,
"Do not support kwargs in template function call"
...
@@ -212,7 +212,7 @@ def register_topi_schedule(task_name, func=None):
...
@@ -212,7 +212,7 @@ def register_topi_schedule(task_name, func=None):
See tvm/topi/python/topi/arm_cpu/depthwise_conv2d.py for example usage.
See tvm/topi/python/topi/arm_cpu/depthwise_conv2d.py for example usage.
"""
"""
def
_decorate
(
topi_schedule
):
def
_decorate
(
topi_schedule
):
@register_task_schedule
(
task_name
)
@
_
register_task_schedule
(
task_name
)
def
wrapper
(
outs
,
*
args
,
**
kwargs
):
def
wrapper
(
outs
,
*
args
,
**
kwargs
):
"""wrapper function for topi schedule"""
"""wrapper function for topi schedule"""
workload
=
get_workload
(
outs
)
workload
=
get_workload
(
outs
)
...
...
tests/python/integration/test_tuning.py
View file @
70e11d32
...
@@ -26,7 +26,7 @@ from tvm import te
...
@@ -26,7 +26,7 @@ from tvm import te
from
tvm
import
autotvm
from
tvm
import
autotvm
from
tvm.autotvm.tuner
import
RandomTuner
from
tvm.autotvm.tuner
import
RandomTuner
@autotvm.
register_customized_task
(
"testing/conv2d_no_batching"
)
@autotvm.
template
(
"testing/conv2d_no_batching"
)
def
conv2d_no_batching
(
N
,
H
,
W
,
CI
,
CO
,
KH
,
KW
):
def
conv2d_no_batching
(
N
,
H
,
W
,
CI
,
CO
,
KH
,
KW
):
"""An example template for testing"""
"""An example template for testing"""
assert
N
==
1
,
"Only consider batch_size = 1 in this template"
assert
N
==
1
,
"Only consider batch_size = 1 in this template"
...
...
tests/python/unittest/test_autotvm_common.py
View file @
70e11d32
...
@@ -37,7 +37,7 @@ class DummyRunner(Runner):
...
@@ -37,7 +37,7 @@ class DummyRunner(Runner):
def
get_build_kwargs
(
self
):
def
get_build_kwargs
(
self
):
return
{}
return
{}
@autotvm.
register_customized_task
(
"testing/matmul"
)
@autotvm.
template
(
"testing/matmul"
)
def
matmul
(
N
,
L
,
M
,
dtype
):
def
matmul
(
N
,
L
,
M
,
dtype
):
A
=
te
.
placeholder
((
N
,
L
),
name
=
'A'
,
dtype
=
dtype
)
A
=
te
.
placeholder
((
N
,
L
),
name
=
'A'
,
dtype
=
dtype
)
B
=
te
.
placeholder
((
L
,
M
),
name
=
'B'
,
dtype
=
dtype
)
B
=
te
.
placeholder
((
L
,
M
),
name
=
'B'
,
dtype
=
dtype
)
...
@@ -64,7 +64,7 @@ def matmul(N, L, M, dtype):
...
@@ -64,7 +64,7 @@ def matmul(N, L, M, dtype):
return
s
,
[
A
,
B
,
C
]
return
s
,
[
A
,
B
,
C
]
@autotvm.
register_customized_task
(
"testing/bad_matmul"
)
@autotvm.
template
(
"testing/bad_matmul"
)
def
bad_matmul
(
N
,
L
,
M
,
dtype
):
def
bad_matmul
(
N
,
L
,
M
,
dtype
):
if
'bad_device'
in
tvm
.
target
.
Target
.
current
()
.
keys
:
if
'bad_device'
in
tvm
.
target
.
Target
.
current
()
.
keys
:
A
=
te
.
placeholder
((
N
,
L
),
name
=
'A'
,
dtype
=
dtype
)
A
=
te
.
placeholder
((
N
,
L
),
name
=
'A'
,
dtype
=
dtype
)
...
...
tests/python/unittest/test_autotvm_dispatch_context.py
View file @
70e11d32
...
@@ -22,7 +22,7 @@ from tvm import autotvm
...
@@ -22,7 +22,7 @@ from tvm import autotvm
def
test_fallback
():
def
test_fallback
():
@autotvm.
register_customized_task
(
"testing/dispatch/
fallback"
)
@autotvm.
template
(
"testing/dispatch_
fallback"
)
def
simple_template
(
a
,
b
):
def
simple_template
(
a
,
b
):
cfg
=
autotvm
.
get_config
()
cfg
=
autotvm
.
get_config
()
assert
cfg
.
is_fallback
assert
cfg
.
is_fallback
...
...
tutorials/autotvm/tune_conv2d_cuda.py
View file @
70e11d32
...
@@ -79,7 +79,7 @@ from tvm import autotvm
...
@@ -79,7 +79,7 @@ from tvm import autotvm
# can be very large (at the level of 10^9 for some input shapes)
# can be very large (at the level of 10^9 for some input shapes)
#
#
@autotvm.
register_customized_task
(
"tutorial/conv2d_no_batching"
)
@autotvm.
template
(
"tutorial/conv2d_no_batching"
)
def
conv2d_no_batching
(
N
,
H
,
W
,
CO
,
CI
,
KH
,
KW
,
stride
,
padding
):
def
conv2d_no_batching
(
N
,
H
,
W
,
CO
,
CI
,
KH
,
KW
,
stride
,
padding
):
assert
N
==
1
,
"Only consider batch_size = 1 in this template"
assert
N
==
1
,
"Only consider batch_size = 1 in this template"
...
...
tutorials/autotvm/tune_simple_template.py
View file @
70e11d32
...
@@ -103,7 +103,7 @@ def matmul_v0(N, L, M, dtype):
...
@@ -103,7 +103,7 @@ def matmul_v0(N, L, M, dtype):
# In autotvm, we can define a tunable parameter, or a "knob" for such kind of value.
# In autotvm, we can define a tunable parameter, or a "knob" for such kind of value.
# Matmul V1: List candidate values
# Matmul V1: List candidate values
@autotvm.
register_customized_task
(
"tutorial/matmul_v1"
)
# 1. use a decorator
@autotvm.
template
(
"tutorial/matmul_v1"
)
# 1. use a decorator
def
matmul_v1
(
N
,
L
,
M
,
dtype
):
def
matmul_v1
(
N
,
L
,
M
,
dtype
):
A
=
te
.
placeholder
((
N
,
L
),
name
=
'A'
,
dtype
=
dtype
)
A
=
te
.
placeholder
((
N
,
L
),
name
=
'A'
,
dtype
=
dtype
)
B
=
te
.
placeholder
((
L
,
M
),
name
=
'B'
,
dtype
=
dtype
)
B
=
te
.
placeholder
((
L
,
M
),
name
=
'B'
,
dtype
=
dtype
)
...
@@ -183,7 +183,7 @@ def matmul_v1(N, L, M, dtype):
...
@@ -183,7 +183,7 @@ def matmul_v1(N, L, M, dtype):
# When the high level API cannot meet your requirement, you can always fall
# When the high level API cannot meet your requirement, you can always fall
# back to use low level API.
# back to use low level API.
@autotvm.
register_customized_task
(
"tutorial/matmul"
)
@autotvm.
template
(
"tutorial/matmul"
)
def
matmul
(
N
,
L
,
M
,
dtype
):
def
matmul
(
N
,
L
,
M
,
dtype
):
A
=
te
.
placeholder
((
N
,
L
),
name
=
'A'
,
dtype
=
dtype
)
A
=
te
.
placeholder
((
N
,
L
),
name
=
'A'
,
dtype
=
dtype
)
B
=
te
.
placeholder
((
L
,
M
),
name
=
'B'
,
dtype
=
dtype
)
B
=
te
.
placeholder
((
L
,
M
),
name
=
'B'
,
dtype
=
dtype
)
...
...
tutorials/optimize/opt_matmul_auto_tensorcore.py
View file @
70e11d32
...
@@ -95,7 +95,7 @@ def matmul_nn(A, B, L, dtype='float16', layout='NN'):
...
@@ -95,7 +95,7 @@ def matmul_nn(A, B, L, dtype='float16', layout='NN'):
#
#
# We use AutoTVM to search for best configurations in this schedule.
# We use AutoTVM to search for best configurations in this schedule.
@autotvm.
register_customized_task
(
"tutorial
/test_gemm"
)
@autotvm.
template
(
"tutorial/auto_tensorcore
/test_gemm"
)
def
test_gemm
(
N
,
L
,
M
,
dtype
,
layout
):
def
test_gemm
(
N
,
L
,
M
,
dtype
,
layout
):
if
(
layout
==
"NN"
):
if
(
layout
==
"NN"
):
shape_a
=
(
N
,
L
)
shape_a
=
(
N
,
L
)
...
@@ -265,7 +265,7 @@ elif dtype == 'int4' or dtype == 'int1':
...
@@ -265,7 +265,7 @@ elif dtype == 'int4' or dtype == 'int1':
assert
(
major
==
7
and
minor
==
5
and
layout
==
'TN'
)
assert
(
major
==
7
and
minor
==
5
and
layout
==
'TN'
)
def
tune_and_evaluate
(
M
,
N
,
L
,
dtype
,
layout
):
def
tune_and_evaluate
(
M
,
N
,
L
,
dtype
,
layout
):
task
=
autotvm
.
task
.
create
(
"tutorial/test_gemm"
,
args
=
(
N
,
L
,
M
,
dtype
,
layout
),
task
=
autotvm
.
task
.
create
(
"tutorial/
auto_tensorcore/
test_gemm"
,
args
=
(
N
,
L
,
M
,
dtype
,
layout
),
target
=
'cuda'
)
target
=
'cuda'
)
print
(
task
.
config_space
)
print
(
task
.
config_space
)
...
...
vta/tutorials/autotvm/tune_relay_vta.py
View file @
70e11d32
...
@@ -310,7 +310,7 @@ def register_vta_tuning_tasks():
...
@@ -310,7 +310,7 @@ def register_vta_tuning_tasks():
# init autotvm env to register VTA operator
# init autotvm env to register VTA operator
TaskExtractEnv
()
TaskExtractEnv
()
@autotvm.
register_customized_task
(
"conv2d_packed.vta"
)
@autotvm.
template
(
"conv2d_packed.vta"
)
def
_topi_nn_conv2d
(
*
args
,
**
kwargs
):
def
_topi_nn_conv2d
(
*
args
,
**
kwargs
):
assert
not
kwargs
,
"Do not support kwargs in template function call"
assert
not
kwargs
,
"Do not support kwargs in template function call"
A
,
W
=
args
[:
2
]
A
,
W
=
args
[:
2
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment