Commit 7751a6ba by Lianmin Zheng Committed by Tianqi Chen

[AUTOTVM] Fix GATuner and improve error message (#1605)

parent 54a115ef
......@@ -366,6 +366,8 @@ class ExternOpNode : public OperationNode {
v->Visit("tag", &tag);
v->Visit("attrs", &attrs);
v->Visit("inputs", &inputs);
v->Visit("input_placeholders", &input_placeholders);
v->Visit("output_placeholders", &output_placeholders);
v->Visit("body", &body);
}
EXPORT static Operation make(std::string name,
......
......@@ -394,6 +394,8 @@ def _measure_common(input_pack, build_func, build_kwargs, number, repeat,
msg = str(exc)
if "Stack trace returned" in msg:
msg = msg[:msg.index("Stack trace returned")]
if "CUDA Source" in msg:
msg = msg[:msg.index("CUDA Source")]
costs = (RuntimeError(msg),)
errno = MeasureErrorNo.RUNTIME_DEVICE
tstamp = time.time()
......
......@@ -4,12 +4,16 @@ Decorator and utilities for the integration with TOPI and NNVM
"""
import warnings
import logging
from ... import tensor, placeholder, target as _target
from ..util import get_const_tuple
from .task import create, register
from .dispatcher import ApplyHistoryBest
logger = logging.getLogger('autotvm')
def serialize_args(args):
"""serialize arguments of a topi function to a hashable tuple.
......@@ -176,9 +180,18 @@ def extract_from_graph(graph, shape, dtype, target, symbols, target_host=None):
# run compiler to collect all TOPI calls during compilation
env.reset()
# disable logger temporarily
old_state = logger.disabled
logger.disabled = True
# use a dummy target to do a fake compile for collecting topi calls
dummy_target = _target.create("opencl -device=dummy")
with ApplyHistoryBest([], allow_fallback=True):
nnvm.compiler.build(graph, target=dummy_target, shape=shape, dtype=dtype)
logger.disabled = old_state
tasks = []
for task_name, args in env.get_tasks():
tasks.append(create(task_name, args,
......
......@@ -368,7 +368,7 @@ def compute_flop(sch):
pass
else:
raise FlopCalculationError("Only support tvm.compute currently. "
"Other ops like tvm.scan is not supported")
"Other ops like tvm.scan/tvm.extern is not supported")
return ret
try:
......
......@@ -62,7 +62,7 @@ def register_topi_compute(topi_compute, target_keys, template_keys, func=None):
for target_key in targets:
if target_key not in _REGISTED_DISPATHCER:
_REGISTED_DISPATHCER[target_key] = {}
if topi_compute not in _REGISTED_DISPATHCER:
if topi_compute not in _REGISTED_DISPATHCER[target_key]:
@topi_compute.register(target_key)
@dispatcher
def config_dispatcher(*args, **kwargs):
......
......@@ -101,11 +101,17 @@ def progress_bar(total, prefix=''):
self.total = total
def __del__(self):
if logger.level < logging.DEBUG: # only print progress bar in non-debug mode
sys.stdout.write(' Done.\n')
ctx = _Context()
tic = time.time()
if logger.level < logging.DEBUG: # only print progress bar in non-debug mode
sys.stdout.write('\r%s Current/Best: %7.2f/%7.2f GFLOPS | Progress: (%d/%d) '
'| %.2f s' % (prefix, 0, 0, 0, total, time.time() - tic))
sys.stdout.flush()
def _callback(tuner, inputs, results):
ctx.ct += len(inputs)
......
......@@ -47,6 +47,7 @@ class GATuner(Tuner):
# random initialization
self.pop_size = min(self.pop_size, len(self.space))
self.elite_num = min(self.pop_size, self.elite_num)
for _ in range(self.pop_size):
tmp_gene = point2knob(np.random.randint(len(self.space)), self.dims)
while knob2point(tmp_gene, self.dims) in self.visited:
......@@ -70,9 +71,9 @@ class GATuner(Tuner):
y = inp.task.flop / np.mean(res.costs)
self.scores.append(y)
else:
self.scores.append(0)
self.scores.append(0.0)
if len(self.scores) >= len(self.genes):
if len(self.scores) >= len(self.genes) and len(self.visited) < len(self.space):
genes = self.genes + self.elites
scores = np.array(self.scores[:len(self.genes)] + self.elite_scores)
......@@ -85,7 +86,12 @@ class GATuner(Tuner):
# cross over
indices = np.arange(len(genes))
scores /= np.max(scores)
max_score = np.max(scores)
if max_score < 1e-8:
probs = np.empty_like(scores)
probs[:] = 1.0 / len(scores)
else:
scores /= max_score
probs = scores / np.sum(scores)
tmp_genes = []
for _ in range(self.pop_size):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment