Commit 4f120464 by Yao Wang Committed by Yizhi Liu

Improve graph tuner dealing with Tuple (#3649)

* Improve graph tuner dealing with Tuple

* Add test case

* Move some data out of _base.py

* Fix lint
parent 3d4ba8d3
......@@ -18,11 +18,6 @@
"""Helper functions and global data"""
# Operators dependent on original layouts.
LAYOUT_FIXED_OP = ["batch_flatten", "transpose", "reshape",
"multibox_prior", "multibox_transform_loc", "where",
"non_max_suppression", "strided_slice"]
# We set a large time to represent an invalid layout-transformation.
# This number is set to be 10e9 seconds to align with autotvm.
INVALID_LAYOUT_TIME = 10e9
......
......@@ -444,6 +444,7 @@ class BaseGraphTuner(object):
timeout=timeout)
measure_option = autotvm.measure_option(builder=builder, runner=runner)
for args in args_list:
data, in_layout, out_layout = args
args = serialize_args(args)
ltf_workload = ('layout_transform',) + autotvm.task.args_to_workload(args)
if ltf_workload in self._layout_transform_perf_records:
......@@ -454,7 +455,18 @@ class BaseGraphTuner(object):
flops = 1
for i in input_shape:
flops *= i
inferred_time = flops * avg_time
# Rule out invalid layout transformations
out = topi.layout_transform(data, in_layout, out_layout)
out_flops = 1
for i in topi.util.get_const_tuple(out.shape):
out_flops *= i
if flops != out_flops:
inferred_time = INVALID_LAYOUT_TIME
else:
inferred_time = flops * avg_time
record_input = MeasureInput(target=self._target, task=None, config=None)
record_output = MeasureResult(costs=(inferred_time,), error_no=0,
all_cost=-1, timestamp=-1)
......
......@@ -26,7 +26,7 @@ from tvm.relay.expr import Call, Function, TupleGetItem, Var, Constant, Tuple
from tvm.relay.ty import TupleType, TensorType
from tvm.autotvm.task import TaskExtractEnv
from .utils import has_multiple_inputs, is_boundary_node
from .utils import has_multiple_inputs, is_boundary_node, is_skipped_node
# Setup relay op base name -> topi compute functions
......@@ -252,7 +252,7 @@ def get_in_nodes(node_list, target_ops, input_names):
visited_dict = {}
in_node_dict = {}
for i, node in enumerate(node_list):
if is_boundary_node(node, input_names):
if is_boundary_node(node, input_names) or is_skipped_node(node):
continue
get_direct_ancestor(node_list, visited_dict, target_ops, i, input_names)
for key, val in visited_dict.items():
......@@ -282,10 +282,12 @@ def get_in_nodes(node_list, target_ops, input_names):
boundary_nodes.append(key)
if boundary_nodes:
for idx in boundary_nodes:
del in_node_dict[idx]
if idx in in_node_dict:
del in_node_dict[idx]
else:
has_reduced_node = False
# Remove empty nodes to ignore pre-computed sub-graph
has_empty_node = True
while has_empty_node:
......
......@@ -19,8 +19,6 @@
from tvm import relay
from tvm.relay import transform
from .._base import LAYOUT_FIXED_OP
def has_multiple_inputs(node_list, node_idx, input_names):
"""Check whether a node has multiple input nodes
......@@ -72,11 +70,35 @@ def is_boundary_node(node_entry, input_names):
out : bool
whether node is a boundary node.
"""
out = node_entry["op"] in LAYOUT_FIXED_OP or \
# Operators dependent on original layouts.
_LAYOUT_FIXED_OP = ["batch_flatten", "transpose", "reshape",
"multibox_prior", "multibox_transform_loc", "where",
"non_max_suppression", "strided_slice"]
out = node_entry["op"] in _LAYOUT_FIXED_OP or \
("name" in node_entry and node_entry["name"] in input_names)
return out
def is_skipped_node(node_entry):
"""Whether a node is not counted.
Parameters
----------
node_entry : dict
Node entry.
Returns
-------
out : bool
whether node is skipped.
"""
# Operators not counted in graph tuner.
_SKIPPED_OP = ["Tuple"]
return node_entry["op"] in _SKIPPED_OP
def bind_inputs(expr, input_shapes=None, input_dtypes="float32"):
"""Bind input variables of a relay function expression
to new shapes and/or dtypes.
......
......@@ -354,25 +354,107 @@ def test_many_sub_graphs():
ms_output = MeasureResult(costs=(1.91224744e-05,), error_no=0, all_cost=-1, timestamp=-1)
ltf_records.append((ms_input, ms_output))
ltf_keys = []
ltf_arg = [tvm.placeholder((1, 4, 8, 8, 4), dtype=dtype), "NCHW4c", "NCHW8c"]
ltf_arg = autotvm.task.topi_integration.serialize_args(ltf_arg)
ltf_wkl = ('layout_transform',) + autotvm.task.args_to_workload(ltf_arg)
ltf_keys.append(ltf_wkl)
ltf_arg = [tvm.placeholder((1, 1, 8, 8, 32), dtype=dtype), "NCHW32c", "NCHW4c"]
ltf_arg = autotvm.task.topi_integration.serialize_args(ltf_arg)
ltf_wkl = ('layout_transform',) + autotvm.task.args_to_workload(ltf_arg)
ltf_keys.append(ltf_wkl)
ltf_arg = [tvm.placeholder((1, 4, 8, 8, 8), dtype=dtype), "NCHW8c", "NCHW32c"]
executor = DPTuner(net, {"data": dshape}, records, target_ops, target)
executor.benchmark_layout_transform(layout_records=ltf_records, infer_layout=True)
executor.run()
out = [record[0].config for record in executor.get_optimal_records()]
expected_out = [records[3][0].config, records[1][0].config, records[2][0].config]
assert expected_out == out, "Output mismatch: expecting %s but got %s" \
% (str(expected_out), str(out))
executor = PBQPTuner(net, {"data": dshape}, records, target_ops, target)
executor.benchmark_layout_transform(layout_records=ltf_records, infer_layout=True)
executor.run()
out = [record[0].config for record in executor.get_optimal_records()]
expected_out = [records[3][0].config, records[1][0].config, records[2][0].config]
assert expected_out == out, "Output mismatch: expecting %s but got %s" \
% (str(expected_out), str(out))
def test_tuple():
target = "llvm"
dtype = "float32"
dshape = (1, 5, 32, 32)
layout = "NCHW"
target_ops = [relay.nn.conv2d]
data = relay.var("data", shape=dshape, dtype=dtype)
w0 = relay.var("w0_weight")
conv0 = relay.nn.conv2d(data, w0, channels=2, kernel_size=(3, 3), padding=(1, 1))
w1 = relay.var("w1_weight")
conv1 = relay.nn.conv2d(data, w1, channels=3, kernel_size=(3, 3), padding=(1, 1))
out = relay.concatenate([conv0, conv1], axis=1)
net = relay.Function(relay.analysis.free_vars(out), out)
net, params = relay.testing.create_workload(net)
tasks = autotvm.task.extract_from_program(net["main"],
target=target,
params=params,
ops=(relay.op.nn.conv2d,))
wkl_list = [
create_workload((1, 5, 32, 32), (2, 5, 3, 3), (1, 1), (1, 1), (1, 1), layout, layout, dtype, dtype),
create_workload((1, 5, 32, 32), (3, 5, 3, 3), (1, 1), (1, 1), (1, 1), layout, layout, dtype, dtype),
]
costs = [0.01, 0.012, 0.03, 0.04]
config_list = []
cfg_dict = {"i": -1,
"c": None,
"e": [["tile_ic", "sp", [1, 5]],
["tile_oc", "sp", [1, 2]],
["tile_ow", "sp", [4, 8]],
["unroll_kw", "ot", True]],
"t": ""}
config_list.append(ConfigEntity.from_json_dict(cfg_dict))
cfg_dict = {"i": -1,
"c": None,
"e": [["tile_ic", "sp", [1, 5]],
["tile_oc", "sp", [1, 3]],
["tile_ow", "sp", [2, 16]],
["unroll_kw", "ot", False]],
"t": ""}
config_list.append(ConfigEntity.from_json_dict(cfg_dict))
cfg_dict = {"i": -1,
"c": None,
"e": [["tile_ic", "sp", [1, 5]],
["tile_oc", "sp", [2, 1]],
["tile_ow", "sp", [4, 8]],
["unroll_kw", "ot", True]],
"t": ""}
config_list.append(ConfigEntity.from_json_dict(cfg_dict))
cfg_dict = {"i": -1,
"c": None,
"e": [["tile_ic", "sp", [1, 5]],
["tile_oc", "sp", [3, 1]],
["tile_ow", "sp", [2, 16]],
["unroll_kw", "ot", False]],
"t": ""}
config_list.append(ConfigEntity.from_json_dict(cfg_dict))
records = []
wkl_list = wkl_list + wkl_list
tasks = tasks + tasks
for wkl, cost, config, task in zip(wkl_list, costs, config_list, tasks):
task.workload = wkl
ms_input = MeasureInput(target=target, task=task, config=config)
ms_output = MeasureResult(costs=(cost,), error_no=0, all_cost=-1, timestamp=-1)
records.append((ms_input, ms_output))
ltf_records = []
ltf_arg = [tvm.placeholder((1, 64, 16, 16, 8), dtype=dtype), "NCHW8c", "NCHW512c"]
ltf_arg = autotvm.task.topi_integration.serialize_args(ltf_arg)
ltf_wkl = ('layout_transform',) + autotvm.task.args_to_workload(ltf_arg)
ltf_keys.append(ltf_wkl)
ltf_task = copy.deepcopy(tasks[0])
ltf_task.workload = ltf_wkl
ms_input = MeasureInput(target=target, task=ltf_task, config=None)
ms_output = MeasureResult(costs=(1.91224744e-05,), error_no=0, all_cost=-1, timestamp=-1)
ltf_records.append((ms_input, ms_output))
executor = DPTuner(net, {"data": dshape}, records, target_ops, target)
executor.benchmark_layout_transform(layout_records=ltf_records, infer_layout=True)
executor.run()
out = [record[0].config for record in executor.get_optimal_records()]
expected_out = [records[3][0].config, records[1][0].config, records[2][0].config]
expected_out = [records[2][0].config, records[1][0].config]
assert expected_out == out, "Output mismatch: expecting %s but got %s" \
% (str(expected_out), str(out))
......@@ -380,7 +462,7 @@ def test_many_sub_graphs():
executor.benchmark_layout_transform(layout_records=ltf_records, infer_layout=True)
executor.run()
out = [record[0].config for record in executor.get_optimal_records()]
expected_out = [records[3][0].config, records[1][0].config, records[2][0].config]
expected_out = [records[2][0].config, records[1][0].config]
assert expected_out == out, "Output mismatch: expecting %s but got %s" \
% (str(expected_out), str(out))
......@@ -390,3 +472,4 @@ if __name__=="__main__":
test_DPTuner_run()
test_PBQPTuner_run()
test_many_sub_graphs()
test_tuple()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment