Commit 6c43019b by Yao Wang Committed by Tianqi Chen

GraphTuner supports relay.module as input (#3434)

parent a074dafc
......@@ -141,6 +141,9 @@ class BaseGraphTuner(object):
self._logger.propagate = False
# Generate workload and schedule dictionaries.
if isinstance(graph, relay.Module):
graph = graph[graph.entry_func]
if isinstance(graph, relay.expr.Function):
node_dict = {}
graph = bind_inputs(graph, input_shapes, dtype)
......
......@@ -159,6 +159,8 @@ def test_DPTuner_run():
target_ops = [relay.nn.conv2d]
g, records, ltf_records, ltf_keys, tasks = _create_data(target, dshape, dtype, layout)
mod = relay.module.Module()
mod[mod.entry_func] = g
costs = [0.02, 0.02, 0.045]
config_list = []
cfg_dict = {"i": -1,
......@@ -190,7 +192,7 @@ def test_DPTuner_run():
ms_output = MeasureResult(costs=(cost,), error_no=0, all_cost=-1, timestamp=-1)
records.append((ms_input, ms_output))
executor = DPTuner(g, {"data": dshape}, records, target_ops, target, log_file=log_file)
executor = DPTuner(mod, {"data": dshape}, records, target_ops, target, log_file=log_file)
executor.benchmark_layout_transform(layout_records=ltf_records, infer_layout=True)
executor.run()
out = [record[0].config for record in executor.get_optimal_records()]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment