Unverified Commit feda150e by Cody Yu Committed by GitHub

[AutoTVM] Support range in index based tuners (#4870)

* Support range in index based tuners

* Address comments

* Remove __*state__

* trigger CI
parent a5e54b1d
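A quick sketch of how the new range_idx argument can be used (a hypothetical snippet, not part of this commit; it reuses get_sample_task and DummyRunner from the test utilities changed below): two index-based tuners can be pointed at disjoint, inclusive index ranges of the same config space, for example to split one tuning job across workers.

from test_autotvm_common import DummyRunner, get_sample_task
from tvm import autotvm

# Build the small sample matmul task used by the tests below.
task, _ = get_sample_task()
n = len(task.config_space)

# range_idx is an inclusive (start, end) tuple; each tuner only draws
# config indices that fall inside its own range.
first_half = autotvm.tuner.GridSearchTuner(task, range_idx=(0, n // 2 - 1))
second_half = autotvm.tuner.RandomTuner(task, range_idx=(n // 2, n - 1))

measure_option = autotvm.measure_option(builder=autotvm.LocalBuilder(), runner=DummyRunner())
first_half.tune(n_trial=n // 2, measure_option=measure_option)
second_half.tune(n_trial=n - n // 2, measure_option=measure_option)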
@@ -25,6 +25,6 @@ from . import callback
from .tuner import Tuner
from .gridsearch_tuner import GridSearchTuner, RandomTuner
from .index_based_tuner import GridSearchTuner, RandomTuner
from .ga_tuner import GATuner
from .xgboost_tuner import XGBTuner
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=abstract-method
"""Grid search tuner and random tuner"""
import numpy as np
from .tuner import Tuner
class GridSearchTuner(Tuner):
"""Enumerate the search space in a grid search order"""
def __init__(self, task):
super(GridSearchTuner, self).__init__(task)
self.counter = 0
def next_batch(self, batch_size):
ret = []
for _ in range(batch_size):
if self.counter >= len(self.task.config_space):
continue
index = self.counter
ret.append(self.task.config_space.get(index))
self.counter = self.counter + 1
return ret
def has_next(self):
return self.counter < len(self.task.config_space)
def load_history(self, data_set):
pass
def __getstate__(self):
return {"counter": self.counter}
def __setstate__(self, state):
self.counter = state['counter']
class RandomTuner(Tuner):
"""Enumerate the search space in a random order"""
def __init__(self, task):
super(RandomTuner, self).__init__(task)
self.visited = set()
def next_batch(self, batch_size):
ret = []
counter = 0
while counter < batch_size:
if len(self.visited) >= len(self.task.config_space):
break
index = np.random.randint(len(self.task.config_space))
while index in self.visited:
index = np.random.randint(len(self.task.config_space))
ret.append(self.task.config_space.get(index))
self.visited.add(index)
counter += 1
return ret
def has_next(self):
return len(self.visited) < len(self.task.config_space)
def load_history(self, data_set):
pass
    def __getstate__(self):
        return {"visited": self.visited}
    def __setstate__(self, state):
        self.visited = state['visited']
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=abstract-method
"""Grid search tuner and random tuner"""
import numpy as np
from .tuner import Tuner
class IndexBaseTuner(Tuner):
"""Base class for index based tuner
    This type of tuner determines the next batch of configs based on config indices.
Parameters
----------
task: autotvm.task.Task
The tuning task
range_idx: Optional[Tuple[int, int]]
        An inclusive (start, end) tuple of indices that this tuner can select from
"""
def __init__(self, task, range_idx=None):
super(IndexBaseTuner, self).__init__(task)
assert range_idx is None or isinstance(range_idx, tuple), \
"range_idx must be None or (int, int)"
self.range_length = len(self.task.config_space)
self.index_offset = 0
if range_idx is not None:
            assert range_idx[1] > range_idx[0], "End index must be greater than start index"
            assert range_idx[0] >= 0, "Start index must be non-negative"
self.range_length = range_idx[1] - range_idx[0] + 1
self.index_offset = range_idx[0]
self.counter = 0
def has_next(self):
return self.counter < self.range_length
def load_history(self, data_set):
pass
class GridSearchTuner(IndexBaseTuner):
"""Enumerate the search space in a grid search order"""
def next_batch(self, batch_size):
ret = []
for _ in range(batch_size):
if self.counter >= self.range_length:
break
index = self.counter + self.index_offset
ret.append(self.task.config_space.get(index))
self.counter = self.counter + 1
return ret
class RandomTuner(IndexBaseTuner):
"""Enumerate the search space in a random order
Parameters
----------
task: autotvm.task.Task
Tuning Task
range_idx: Optional[Tuple[int, int]]
        An inclusive (start, end) tuple of indices to sample from
"""
def __init__(self, task, range_idx=None):
super(RandomTuner, self).__init__(task, range_idx)
# Use a dict to mimic a range(n) list without storing rand_state[i] = i entries so that
# we can generate non-repetitive random indices.
self.rand_state = {}
self.rand_max = self.range_length
self.visited = []
def next_batch(self, batch_size):
ret = []
for _ in range(batch_size):
if self.rand_max == 0:
break
            # Draw a random indirect index into the remaining range.
index_ = np.random.randint(self.rand_max)
self.rand_max -= 1
# Use the indirect index to get a direct index.
index = self.rand_state.get(index_, index_) + self.index_offset
ret.append(self.task.config_space.get(index))
self.visited.append(index)
            # Move the direct index stored at the last unused slot into the slot just
            # drawn, so it can still be selected; then retire the last slot.
self.rand_state[index_] = self.rand_state.get(self.rand_max, self.rand_max)
self.rand_state.pop(self.rand_max, None)
self.counter += 1
return ret
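The RandomTuner above draws each index at most once without materializing a full permutation: rand_state is a sparse stand-in for the list range(rand_max), and every draw moves the value stored at the last unused slot into the slot just drawn, in the style of an incremental Fisher-Yates shuffle. A standalone sketch of the same trick (illustrative names, not part of the commit):

import numpy as np

def sample_without_replacement(n, k):
    """Draw k distinct indices from range(n) using O(k) extra memory."""
    rand_state = {}      # only stores slots that differ from the identity mapping
    rand_max = n
    out = []
    for _ in range(min(k, n)):
        i = np.random.randint(rand_max)    # indirect index into the shrinking range
        rand_max -= 1
        out.append(rand_state.get(i, i))   # resolve it to a direct index
        # The direct index that lived at the last unused slot takes over slot i,
        # so it can still be drawn later; the last slot itself is retired.
        rand_state[i] = rand_state.get(rand_max, rand_max)
        rand_state.pop(rand_max, None)
    return out

print(sample_without_replacement(10, 5))   # e.g. [3, 7, 0, 9, 2]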
@@ -17,9 +17,24 @@
"""Common utilities for testing autotvm"""
import time
import numpy as np
import tvm
from tvm import autotvm
from tvm.autotvm import MeasureInput, MeasureResult
from tvm.autotvm.measure.measure import Runner
class DummyRunner(Runner):
    """Runner that returns random MeasureResults without building or running anything, so tuners can be tested without hardware."""
    def __init__(self):
super(DummyRunner, self).__init__(1, 1)
def run(self, measure_inputs, build_results):
return [MeasureResult((np.random.random(),), 0, 0.2, time.time())
for _ in range(len(measure_inputs))]
def get_build_kwargs(self):
return {}
@autotvm.template
def matmul(N, L, M, dtype):
@@ -82,4 +97,3 @@ def get_sample_records(n):
inps.append(MeasureInput(target, tsk, tsk.config_space.get(i)))
ress.append(MeasureResult((i+1,), 0, i, time.time()))
return list(zip(inps, ress))
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Test index based tuners"""
from test_autotvm_common import DummyRunner, get_sample_task
from tvm import autotvm
from tvm.autotvm.tuner import GridSearchTuner, RandomTuner
def test_gridsearch_tuner():
"""Test GridSearchTuner"""
task, _ = get_sample_task()
measure_option = autotvm.measure_option(builder=autotvm.LocalBuilder(), runner=DummyRunner())
    # Without a range index, range_length should be the length of the config space
tuner = autotvm.tuner.GridSearchTuner(task)
assert tuner.range_length == len(task.config_space)
assert tuner.index_offset == 0
    # With a range index, range_length should be the length of the specified range
tuner = autotvm.tuner.GridSearchTuner(task, range_idx=(8, 15))
assert tuner.range_length == 8
assert tuner.index_offset == 8
# Tuner should only focus on the specified range
tuner.tune(n_trial=8, measure_option=measure_option)
assert tuner.counter == 8
assert not tuner.has_next()
def test_random_tuner():
"""Test RandomTuner"""
task, _ = get_sample_task()
measure_option = autotvm.measure_option(builder=autotvm.LocalBuilder(), runner=DummyRunner())
tuner = autotvm.tuner.RandomTuner(task, range_idx=(8, 15))
assert tuner.range_length == 8
assert tuner.index_offset == 8
# Tuner should only focus on the specified range and should visit all indices
tuner.tune(n_trial=8, measure_option=measure_option)
assert tuner.counter == 8
assert not tuner.has_next()
    visited = set()
    for idx in tuner.visited:
        assert idx not in visited
        assert 8 <= idx <= 15
        visited.add(idx)
if __name__ == '__main__':
test_gridsearch_tuner()
test_random_tuner()
\ No newline at end of file
@@ -21,24 +21,14 @@ import time
import numpy as np
import tvm
from test_autotvm_common import DummyRunner, bad_matmul, get_sample_task
from tvm import autotvm
from test_autotvm_common import get_sample_task, bad_matmul
from tvm.autotvm.measure.measure import Runner, MeasureResult, MeasureErrorNo
from tvm.autotvm.measure.measure import MeasureErrorNo, MeasureResult
def test_task_tuner_without_measurement():
"""test task and tuner without measurement"""
task, target = get_sample_task()
class DummyRunner(Runner):
def __init__(self):
super(DummyRunner, self).__init__(1, 1)
def run(self, measure_inputs, build_results):
return [MeasureResult((np.random.random(),), 0, 0.2, time.time())
for _ in range(len(measure_inputs))]
def get_build_kwargs(self):
return {}
task, _ = get_sample_task()
measure_option = autotvm.measure_option(
builder=autotvm.LocalBuilder(),
@@ -64,7 +54,7 @@ def test_check_correctness():
)
def _callback_correct(tuner, measure_inputs, measure_results):
for inp, res in zip(measure_inputs, measure_results):
for _, res in zip(measure_inputs, measure_results):
assert res.error_no == 0
tuner = autotvm.tuner.RandomTuner(task)
@@ -77,7 +67,7 @@ def test_check_correctness():
task = autotvm.task.create(bad_matmul, args=(n, n, n, 'float32'), target=target)
def _callback_wrong(tuner, measure_inputs, measure_results):
for inp, res in zip(measure_inputs, measure_results):
for _, res in zip(measure_inputs, measure_results):
assert res.error_no == MeasureErrorNo.WRONG_ANSWER
tuner = autotvm.tuner.RandomTuner(task)
@@ -90,4 +80,3 @@ if __name__ == '__main__':
test_task_tuner_without_measurement()
test_check_correctness()