Commit 7751a6ba by Lianmin Zheng Committed by Tianqi Chen

[AUTOTVM] Fix GATuner and improve error message (#1605)

parent 54a115ef
...@@ -366,6 +366,8 @@ class ExternOpNode : public OperationNode { ...@@ -366,6 +366,8 @@ class ExternOpNode : public OperationNode {
v->Visit("tag", &tag); v->Visit("tag", &tag);
v->Visit("attrs", &attrs); v->Visit("attrs", &attrs);
v->Visit("inputs", &inputs); v->Visit("inputs", &inputs);
v->Visit("input_placeholders", &input_placeholders);
v->Visit("output_placeholders", &output_placeholders);
v->Visit("body", &body); v->Visit("body", &body);
} }
EXPORT static Operation make(std::string name, EXPORT static Operation make(std::string name,
......
...@@ -394,6 +394,8 @@ def _measure_common(input_pack, build_func, build_kwargs, number, repeat, ...@@ -394,6 +394,8 @@ def _measure_common(input_pack, build_func, build_kwargs, number, repeat,
msg = str(exc) msg = str(exc)
if "Stack trace returned" in msg: if "Stack trace returned" in msg:
msg = msg[:msg.index("Stack trace returned")] msg = msg[:msg.index("Stack trace returned")]
if "CUDA Source" in msg:
msg = msg[:msg.index("CUDA Source")]
costs = (RuntimeError(msg),) costs = (RuntimeError(msg),)
errno = MeasureErrorNo.RUNTIME_DEVICE errno = MeasureErrorNo.RUNTIME_DEVICE
tstamp = time.time() tstamp = time.time()
......
...@@ -4,12 +4,16 @@ Decorator and utilities for the integration with TOPI and NNVM ...@@ -4,12 +4,16 @@ Decorator and utilities for the integration with TOPI and NNVM
""" """
import warnings import warnings
import logging
from ... import tensor, placeholder, target as _target from ... import tensor, placeholder, target as _target
from ..util import get_const_tuple from ..util import get_const_tuple
from .task import create, register from .task import create, register
from .dispatcher import ApplyHistoryBest
logger = logging.getLogger('autotvm')
def serialize_args(args): def serialize_args(args):
"""serialize arguments of a topi function to a hashable tuple. """serialize arguments of a topi function to a hashable tuple.
...@@ -176,9 +180,18 @@ def extract_from_graph(graph, shape, dtype, target, symbols, target_host=None): ...@@ -176,9 +180,18 @@ def extract_from_graph(graph, shape, dtype, target, symbols, target_host=None):
# run compiler to collect all TOPI calls during compilation # run compiler to collect all TOPI calls during compilation
env.reset() env.reset()
# disable logger temporarily
old_state = logger.disabled
logger.disabled = True
# use a dummy target to do a fake compile for collecting topi calls
dummy_target = _target.create("opencl -device=dummy") dummy_target = _target.create("opencl -device=dummy")
with ApplyHistoryBest([], allow_fallback=True):
nnvm.compiler.build(graph, target=dummy_target, shape=shape, dtype=dtype) nnvm.compiler.build(graph, target=dummy_target, shape=shape, dtype=dtype)
logger.disabled = old_state
tasks = [] tasks = []
for task_name, args in env.get_tasks(): for task_name, args in env.get_tasks():
tasks.append(create(task_name, args, tasks.append(create(task_name, args,
......
...@@ -368,7 +368,7 @@ def compute_flop(sch): ...@@ -368,7 +368,7 @@ def compute_flop(sch):
pass pass
else: else:
raise FlopCalculationError("Only support tvm.compute currently. " raise FlopCalculationError("Only support tvm.compute currently. "
"Other ops like tvm.scan is not supported") "Other ops like tvm.scan/tvm.extern is not supported")
return ret return ret
try: try:
......
...@@ -62,7 +62,7 @@ def register_topi_compute(topi_compute, target_keys, template_keys, func=None): ...@@ -62,7 +62,7 @@ def register_topi_compute(topi_compute, target_keys, template_keys, func=None):
for target_key in targets: for target_key in targets:
if target_key not in _REGISTED_DISPATHCER: if target_key not in _REGISTED_DISPATHCER:
_REGISTED_DISPATHCER[target_key] = {} _REGISTED_DISPATHCER[target_key] = {}
if topi_compute not in _REGISTED_DISPATHCER: if topi_compute not in _REGISTED_DISPATHCER[target_key]:
@topi_compute.register(target_key) @topi_compute.register(target_key)
@dispatcher @dispatcher
def config_dispatcher(*args, **kwargs): def config_dispatcher(*args, **kwargs):
......
...@@ -101,11 +101,17 @@ def progress_bar(total, prefix=''): ...@@ -101,11 +101,17 @@ def progress_bar(total, prefix=''):
self.total = total self.total = total
def __del__(self): def __del__(self):
if logger.level < logging.DEBUG: # only print progress bar in non-debug mode
sys.stdout.write(' Done.\n') sys.stdout.write(' Done.\n')
ctx = _Context() ctx = _Context()
tic = time.time() tic = time.time()
if logger.level < logging.DEBUG: # only print progress bar in non-debug mode
sys.stdout.write('\r%s Current/Best: %7.2f/%7.2f GFLOPS | Progress: (%d/%d) '
'| %.2f s' % (prefix, 0, 0, 0, total, time.time() - tic))
sys.stdout.flush()
def _callback(tuner, inputs, results): def _callback(tuner, inputs, results):
ctx.ct += len(inputs) ctx.ct += len(inputs)
......
...@@ -47,6 +47,7 @@ class GATuner(Tuner): ...@@ -47,6 +47,7 @@ class GATuner(Tuner):
# random initialization # random initialization
self.pop_size = min(self.pop_size, len(self.space)) self.pop_size = min(self.pop_size, len(self.space))
self.elite_num = min(self.pop_size, self.elite_num)
for _ in range(self.pop_size): for _ in range(self.pop_size):
tmp_gene = point2knob(np.random.randint(len(self.space)), self.dims) tmp_gene = point2knob(np.random.randint(len(self.space)), self.dims)
while knob2point(tmp_gene, self.dims) in self.visited: while knob2point(tmp_gene, self.dims) in self.visited:
...@@ -70,9 +71,9 @@ class GATuner(Tuner): ...@@ -70,9 +71,9 @@ class GATuner(Tuner):
y = inp.task.flop / np.mean(res.costs) y = inp.task.flop / np.mean(res.costs)
self.scores.append(y) self.scores.append(y)
else: else:
self.scores.append(0) self.scores.append(0.0)
if len(self.scores) >= len(self.genes): if len(self.scores) >= len(self.genes) and len(self.visited) < len(self.space):
genes = self.genes + self.elites genes = self.genes + self.elites
scores = np.array(self.scores[:len(self.genes)] + self.elite_scores) scores = np.array(self.scores[:len(self.genes)] + self.elite_scores)
...@@ -85,7 +86,12 @@ class GATuner(Tuner): ...@@ -85,7 +86,12 @@ class GATuner(Tuner):
# cross over # cross over
indices = np.arange(len(genes)) indices = np.arange(len(genes))
scores /= np.max(scores) max_score = np.max(scores)
if max_score < 1e-8:
probs = np.empty_like(scores)
probs[:] = 1.0 / len(scores)
else:
scores /= max_score
probs = scores / np.sum(scores) probs = scores / np.sum(scores)
tmp_genes = [] tmp_genes = []
for _ in range(self.pop_size): for _ in range(self.pop_size):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment