Commit 4f120464 by Yao Wang Committed by Yizhi Liu

Improve graph tuner dealing with Tuple (#3649)

* Improve graph tuner dealing with Tuple

* Add test case

* Move some data out of _base.py

* Fix lint
parent 3d4ba8d3
...@@ -18,11 +18,6 @@ ...@@ -18,11 +18,6 @@
"""Helper functions and global data""" """Helper functions and global data"""
# Operators dependent on original layouts.
LAYOUT_FIXED_OP = ["batch_flatten", "transpose", "reshape",
"multibox_prior", "multibox_transform_loc", "where",
"non_max_suppression", "strided_slice"]
# We set a large time to represent an invalid layout-transformation. # We set a large time to represent an invalid layout-transformation.
# This number is set to be 10e9 seconds to align with autotvm. # This number is set to be 10e9 seconds to align with autotvm.
INVALID_LAYOUT_TIME = 10e9 INVALID_LAYOUT_TIME = 10e9
......
...@@ -444,6 +444,7 @@ class BaseGraphTuner(object): ...@@ -444,6 +444,7 @@ class BaseGraphTuner(object):
timeout=timeout) timeout=timeout)
measure_option = autotvm.measure_option(builder=builder, runner=runner) measure_option = autotvm.measure_option(builder=builder, runner=runner)
for args in args_list: for args in args_list:
data, in_layout, out_layout = args
args = serialize_args(args) args = serialize_args(args)
ltf_workload = ('layout_transform',) + autotvm.task.args_to_workload(args) ltf_workload = ('layout_transform',) + autotvm.task.args_to_workload(args)
if ltf_workload in self._layout_transform_perf_records: if ltf_workload in self._layout_transform_perf_records:
...@@ -454,7 +455,18 @@ class BaseGraphTuner(object): ...@@ -454,7 +455,18 @@ class BaseGraphTuner(object):
flops = 1 flops = 1
for i in input_shape: for i in input_shape:
flops *= i flops *= i
inferred_time = flops * avg_time
# Rule out invalid layout transformations
out = topi.layout_transform(data, in_layout, out_layout)
out_flops = 1
for i in topi.util.get_const_tuple(out.shape):
out_flops *= i
if flops != out_flops:
inferred_time = INVALID_LAYOUT_TIME
else:
inferred_time = flops * avg_time
record_input = MeasureInput(target=self._target, task=None, config=None) record_input = MeasureInput(target=self._target, task=None, config=None)
record_output = MeasureResult(costs=(inferred_time,), error_no=0, record_output = MeasureResult(costs=(inferred_time,), error_no=0,
all_cost=-1, timestamp=-1) all_cost=-1, timestamp=-1)
......
...@@ -26,7 +26,7 @@ from tvm.relay.expr import Call, Function, TupleGetItem, Var, Constant, Tuple ...@@ -26,7 +26,7 @@ from tvm.relay.expr import Call, Function, TupleGetItem, Var, Constant, Tuple
from tvm.relay.ty import TupleType, TensorType from tvm.relay.ty import TupleType, TensorType
from tvm.autotvm.task import TaskExtractEnv from tvm.autotvm.task import TaskExtractEnv
from .utils import has_multiple_inputs, is_boundary_node from .utils import has_multiple_inputs, is_boundary_node, is_skipped_node
# Setup relay op base name -> topi compute functions # Setup relay op base name -> topi compute functions
...@@ -252,7 +252,7 @@ def get_in_nodes(node_list, target_ops, input_names): ...@@ -252,7 +252,7 @@ def get_in_nodes(node_list, target_ops, input_names):
visited_dict = {} visited_dict = {}
in_node_dict = {} in_node_dict = {}
for i, node in enumerate(node_list): for i, node in enumerate(node_list):
if is_boundary_node(node, input_names): if is_boundary_node(node, input_names) or is_skipped_node(node):
continue continue
get_direct_ancestor(node_list, visited_dict, target_ops, i, input_names) get_direct_ancestor(node_list, visited_dict, target_ops, i, input_names)
for key, val in visited_dict.items(): for key, val in visited_dict.items():
...@@ -282,10 +282,12 @@ def get_in_nodes(node_list, target_ops, input_names): ...@@ -282,10 +282,12 @@ def get_in_nodes(node_list, target_ops, input_names):
boundary_nodes.append(key) boundary_nodes.append(key)
if boundary_nodes: if boundary_nodes:
for idx in boundary_nodes: for idx in boundary_nodes:
del in_node_dict[idx] if idx in in_node_dict:
del in_node_dict[idx]
else: else:
has_reduced_node = False has_reduced_node = False
# Remove empty nodes to ignore pre-computed sub-graph # Remove empty nodes to ignore pre-computed sub-graph
has_empty_node = True has_empty_node = True
while has_empty_node: while has_empty_node:
......
...@@ -19,8 +19,6 @@ ...@@ -19,8 +19,6 @@
from tvm import relay from tvm import relay
from tvm.relay import transform from tvm.relay import transform
from .._base import LAYOUT_FIXED_OP
def has_multiple_inputs(node_list, node_idx, input_names): def has_multiple_inputs(node_list, node_idx, input_names):
"""Check whether a node has multiple input nodes """Check whether a node has multiple input nodes
...@@ -72,11 +70,35 @@ def is_boundary_node(node_entry, input_names): ...@@ -72,11 +70,35 @@ def is_boundary_node(node_entry, input_names):
out : bool out : bool
whether node is a boundary node. whether node is a boundary node.
""" """
out = node_entry["op"] in LAYOUT_FIXED_OP or \ # Operators dependent on original layouts.
_LAYOUT_FIXED_OP = ["batch_flatten", "transpose", "reshape",
"multibox_prior", "multibox_transform_loc", "where",
"non_max_suppression", "strided_slice"]
out = node_entry["op"] in _LAYOUT_FIXED_OP or \
("name" in node_entry and node_entry["name"] in input_names) ("name" in node_entry and node_entry["name"] in input_names)
return out return out
def is_skipped_node(node_entry):
    """Whether a node is not counted.

    Parameters
    ----------
    node_entry : dict
        Node entry. Its "op" field is compared against the names of the
        operators that are not counted in graph tuner.

    Returns
    -------
    out : bool
        whether node is skipped.
    """
    # Operators not counted in graph tuner. Kept as an immutable tuple so
    # the set of names cannot be mutated by accident.
    skipped_ops = ("Tuple",)
    # A missing "op" key still raises KeyError, as before.
    return node_entry["op"] in skipped_ops
def bind_inputs(expr, input_shapes=None, input_dtypes="float32"): def bind_inputs(expr, input_shapes=None, input_dtypes="float32"):
"""Bind input variables of a relay function expression """Bind input variables of a relay function expression
to new shapes and/or dtypes. to new shapes and/or dtypes.
......
...@@ -354,25 +354,107 @@ def test_many_sub_graphs(): ...@@ -354,25 +354,107 @@ def test_many_sub_graphs():
ms_output = MeasureResult(costs=(1.91224744e-05,), error_no=0, all_cost=-1, timestamp=-1) ms_output = MeasureResult(costs=(1.91224744e-05,), error_no=0, all_cost=-1, timestamp=-1)
ltf_records.append((ms_input, ms_output)) ltf_records.append((ms_input, ms_output))
ltf_keys = [] executor = DPTuner(net, {"data": dshape}, records, target_ops, target)
ltf_arg = [tvm.placeholder((1, 4, 8, 8, 4), dtype=dtype), "NCHW4c", "NCHW8c"] executor.benchmark_layout_transform(layout_records=ltf_records, infer_layout=True)
ltf_arg = autotvm.task.topi_integration.serialize_args(ltf_arg) executor.run()
ltf_wkl = ('layout_transform',) + autotvm.task.args_to_workload(ltf_arg) out = [record[0].config for record in executor.get_optimal_records()]
ltf_keys.append(ltf_wkl) expected_out = [records[3][0].config, records[1][0].config, records[2][0].config]
ltf_arg = [tvm.placeholder((1, 1, 8, 8, 32), dtype=dtype), "NCHW32c", "NCHW4c"] assert expected_out == out, "Output mismatch: expecting %s but got %s" \
ltf_arg = autotvm.task.topi_integration.serialize_args(ltf_arg) % (str(expected_out), str(out))
ltf_wkl = ('layout_transform',) + autotvm.task.args_to_workload(ltf_arg)
ltf_keys.append(ltf_wkl) executor = PBQPTuner(net, {"data": dshape}, records, target_ops, target)
ltf_arg = [tvm.placeholder((1, 4, 8, 8, 8), dtype=dtype), "NCHW8c", "NCHW32c"] executor.benchmark_layout_transform(layout_records=ltf_records, infer_layout=True)
executor.run()
out = [record[0].config for record in executor.get_optimal_records()]
expected_out = [records[3][0].config, records[1][0].config, records[2][0].config]
assert expected_out == out, "Output mismatch: expecting %s but got %s" \
% (str(expected_out), str(out))
def test_tuple():
target = "llvm"
dtype = "float32"
dshape = (1, 5, 32, 32)
layout = "NCHW"
target_ops = [relay.nn.conv2d]
data = relay.var("data", shape=dshape, dtype=dtype)
w0 = relay.var("w0_weight")
conv0 = relay.nn.conv2d(data, w0, channels=2, kernel_size=(3, 3), padding=(1, 1))
w1 = relay.var("w1_weight")
conv1 = relay.nn.conv2d(data, w1, channels=3, kernel_size=(3, 3), padding=(1, 1))
out = relay.concatenate([conv0, conv1], axis=1)
net = relay.Function(relay.analysis.free_vars(out), out)
net, params = relay.testing.create_workload(net)
tasks = autotvm.task.extract_from_program(net["main"],
target=target,
params=params,
ops=(relay.op.nn.conv2d,))
wkl_list = [
create_workload((1, 5, 32, 32), (2, 5, 3, 3), (1, 1), (1, 1), (1, 1), layout, layout, dtype, dtype),
create_workload((1, 5, 32, 32), (3, 5, 3, 3), (1, 1), (1, 1), (1, 1), layout, layout, dtype, dtype),
]
costs = [0.01, 0.012, 0.03, 0.04]
config_list = []
cfg_dict = {"i": -1,
"c": None,
"e": [["tile_ic", "sp", [1, 5]],
["tile_oc", "sp", [1, 2]],
["tile_ow", "sp", [4, 8]],
["unroll_kw", "ot", True]],
"t": ""}
config_list.append(ConfigEntity.from_json_dict(cfg_dict))
cfg_dict = {"i": -1,
"c": None,
"e": [["tile_ic", "sp", [1, 5]],
["tile_oc", "sp", [1, 3]],
["tile_ow", "sp", [2, 16]],
["unroll_kw", "ot", False]],
"t": ""}
config_list.append(ConfigEntity.from_json_dict(cfg_dict))
cfg_dict = {"i": -1,
"c": None,
"e": [["tile_ic", "sp", [1, 5]],
["tile_oc", "sp", [2, 1]],
["tile_ow", "sp", [4, 8]],
["unroll_kw", "ot", True]],
"t": ""}
config_list.append(ConfigEntity.from_json_dict(cfg_dict))
cfg_dict = {"i": -1,
"c": None,
"e": [["tile_ic", "sp", [1, 5]],
["tile_oc", "sp", [3, 1]],
["tile_ow", "sp", [2, 16]],
["unroll_kw", "ot", False]],
"t": ""}
config_list.append(ConfigEntity.from_json_dict(cfg_dict))
records = []
wkl_list = wkl_list + wkl_list
tasks = tasks + tasks
for wkl, cost, config, task in zip(wkl_list, costs, config_list, tasks):
task.workload = wkl
ms_input = MeasureInput(target=target, task=task, config=config)
ms_output = MeasureResult(costs=(cost,), error_no=0, all_cost=-1, timestamp=-1)
records.append((ms_input, ms_output))
ltf_records = []
ltf_arg = [tvm.placeholder((1, 64, 16, 16, 8), dtype=dtype), "NCHW8c", "NCHW512c"]
ltf_arg = autotvm.task.topi_integration.serialize_args(ltf_arg) ltf_arg = autotvm.task.topi_integration.serialize_args(ltf_arg)
ltf_wkl = ('layout_transform',) + autotvm.task.args_to_workload(ltf_arg) ltf_wkl = ('layout_transform',) + autotvm.task.args_to_workload(ltf_arg)
ltf_keys.append(ltf_wkl) ltf_task = copy.deepcopy(tasks[0])
ltf_task.workload = ltf_wkl
ms_input = MeasureInput(target=target, task=ltf_task, config=None)
ms_output = MeasureResult(costs=(1.91224744e-05,), error_no=0, all_cost=-1, timestamp=-1)
ltf_records.append((ms_input, ms_output))
executor = DPTuner(net, {"data": dshape}, records, target_ops, target) executor = DPTuner(net, {"data": dshape}, records, target_ops, target)
executor.benchmark_layout_transform(layout_records=ltf_records, infer_layout=True) executor.benchmark_layout_transform(layout_records=ltf_records, infer_layout=True)
executor.run() executor.run()
out = [record[0].config for record in executor.get_optimal_records()] out = [record[0].config for record in executor.get_optimal_records()]
expected_out = [records[3][0].config, records[1][0].config, records[2][0].config] expected_out = [records[2][0].config, records[1][0].config]
assert expected_out == out, "Output mismatch: expecting %s but got %s" \ assert expected_out == out, "Output mismatch: expecting %s but got %s" \
% (str(expected_out), str(out)) % (str(expected_out), str(out))
...@@ -380,7 +462,7 @@ def test_many_sub_graphs(): ...@@ -380,7 +462,7 @@ def test_many_sub_graphs():
executor.benchmark_layout_transform(layout_records=ltf_records, infer_layout=True) executor.benchmark_layout_transform(layout_records=ltf_records, infer_layout=True)
executor.run() executor.run()
out = [record[0].config for record in executor.get_optimal_records()] out = [record[0].config for record in executor.get_optimal_records()]
expected_out = [records[3][0].config, records[1][0].config, records[2][0].config] expected_out = [records[2][0].config, records[1][0].config]
assert expected_out == out, "Output mismatch: expecting %s but got %s" \ assert expected_out == out, "Output mismatch: expecting %s but got %s" \
% (str(expected_out), str(out)) % (str(expected_out), str(out))
...@@ -390,3 +472,4 @@ if __name__=="__main__": ...@@ -390,3 +472,4 @@ if __name__=="__main__":
test_DPTuner_run() test_DPTuner_run()
test_PBQPTuner_run() test_PBQPTuner_run()
test_many_sub_graphs() test_many_sub_graphs()
test_tuple()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment