Commit be260836 by Yao Wang, committed by Yizhi Liu

[AutoTVM] Improve graph tuner for multiple subgraphs (#3490)

* Improve boundary nodes in graph tuner

* Limit output node number

* Fix test

* Improve warning.

* Fix test
parent 6b5fbdad
@@ -18,10 +18,13 @@
"""Helper functions and global data"""
RULE_OUT_NODE_NAMES = ["Tuple", "TupleGetItem", "batch_flatten", "transpose", "reshape",
"multibox_prior", "multibox_transform_loc", "where",
"non_max_suppression", "strided_slice"]
# Operators dependent on original layouts.
LAYOUT_FIXED_OP = ["batch_flatten", "transpose", "reshape",
"multibox_prior", "multibox_transform_loc", "where",
"non_max_suppression", "strided_slice"]
# We set a large time to represent an invalid layout-transformation.
# This number is set to be 10e9 seconds to align with autotvm.
INVALID_LAYOUT_TIME = 10e9
MAX_OUTPUT_NODES = 16
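# A minimal sketch of how these constants are meant to be consumed
# (assumed usage for illustration, not code from this commit): a layout
# transform whose measured cost reaches INVALID_LAYOUT_TIME is treated
# as infeasible, and the DP tuner only aligns output states while the
# output count stays within MAX_OUTPUT_NODES (see the DPTuner hunk below).
def is_invalid_transform(cost):
    return cost >= INVALID_LAYOUT_TIME

def within_output_limit(output_idx_list):
    return len(output_idx_list) <= MAX_OUTPUT_NODES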
@@ -30,7 +30,7 @@ from tvm.autotvm.record import encode, load_from_file
from tvm.autotvm.measure import MeasureResult, MeasureInput
from ... import target as _target
from .utils import is_input_node, get_in_nodes, get_out_nodes, has_multiple_inputs, \
from .utils import is_boundary_node, get_in_nodes, get_out_nodes, has_multiple_inputs, \
bind_inputs, expr2graph
from ._base import INVALID_LAYOUT_TIME
@@ -170,7 +170,7 @@ class BaseGraphTuner(object):
node_entry["workloads"] = []
for input_idx in self._in_nodes_dict[idx]:
input_node = self._node_list[input_idx]
if not is_input_node(input_node, input_names):
if not is_boundary_node(input_node, input_names):
input_topi_op = input_node["topi_op"][0]
node_entry["topi_op"].append(input_topi_op)
# Only replace the first input tensor
@@ -249,7 +249,8 @@ class BaseGraphTuner(object):
target_input_pos = -1
if has_multiple_inputs(self._node_list, key, input_names):
for i, item in enumerate(val):
if not is_input_node(self._node_list[item], input_names):
node = self._node_list[item]
if not is_boundary_node(node, input_names):
target_input_idx = item
target_input_pos = i
break
@@ -257,7 +258,7 @@ class BaseGraphTuner(object):
for i, item in enumerate(val):
i_idx = item
in_node_entry = self._node_list[i_idx]
if is_input_node(in_node_entry, input_names):
if is_boundary_node(in_node_entry, input_names):
continue
if node_entry["op"] in self._target_ops:
@@ -18,7 +18,7 @@
"""Stage class for dynamic programming tuner"""
import numpy as np
from .utils import is_input_node
from .utils import is_boundary_node
class DPStage(object):
@@ -102,20 +102,13 @@ class DPStage(object):
def _create_op_states(self):
"""State creation routine for nodes with target_op."""
input_idx = -1
for index in self._global_in_nodes_dict[self._idx]:
input_idx = index
if not is_input_node(self._global_node_list[input_idx],
self._global_input_names):
break
if is_input_node(self._global_node_list[input_idx],
self._global_input_names):
input_idx = self._global_in_nodes_dict[self._idx][0]
input_node_entry = self._global_node_list[input_idx]
if is_boundary_node(input_node_entry, self._global_input_names):
self._full_states = np.array([record[1].costs[0]
for record in self._record_list])
self._states = self._full_states
else:
input_node_entry = self._global_node_list[input_idx]
input_stage = self._global_stage_dict[input_idx]
input_dep = input_stage.dep
input_states = input_stage.states
@@ -202,10 +195,10 @@
"""
full_input_node_list = list(self._global_in_nodes_dict[self._idx])
input_index_list = []
# Remove input and parameter nodes
# Remove input and boundary nodes
for input_idx in full_input_node_list:
if not is_input_node(self._global_node_list[input_idx],
self._global_input_names):
input_node = self._global_node_list[input_idx]
if not is_boundary_node(input_node, self._global_input_names):
input_index_list.append(input_idx)
# Generate new states
@@ -331,8 +324,9 @@ class DPStage(object):
for dep_idx in input_node_stage.dep:
if dep_idx not in aligned_node_list:
aligned_node_list.append(dep_idx)
aligned_shape = tuple([len(node_list[idx]["record_candidates"])
for idx in aligned_node_list])
aligned_shape = []
for idx in aligned_node_list:
aligned_shape.append(len(node_list[idx]["record_candidates"]))
for input_idx in input_index_list:
input_node_stage = stage_dict[input_idx]
input_node_shape_idx_list = [input_idx] + input_node_stage.dep
@@ -19,9 +19,10 @@
import sys
import numpy as np
from ._base import MAX_OUTPUT_NODES
from .base_graph_tuner import BaseGraphTuner
from .dynamic_programming_stage import DPStage
from .utils import has_multiple_inputs, is_input_node
from .utils import has_multiple_inputs, is_boundary_node
if sys.version_info[0] == 3:
import queue
@@ -88,6 +89,18 @@ class DPTuner(BaseGraphTuner):
for key, val in self._out_nodes_dict.items():
if not val:
output_idx_list.append(key)
# Restrict number of output nodes to avoid numpy reshape error
if len(output_idx_list) > MAX_OUTPUT_NODES:
msg = "The number of outputs in graph is larger than upper " \
"limit: %s vs %s. Usually this is caused by too many " \
"LAYOUT_FIXED_OP in graph. Switch to greedily select schedule." \
"No action required at this moment. We will continuously improve graph tuner" \
% (len(output_idx_list), MAX_OUTPUT_NODES)
self._logger.warning(msg)
self._optimal_record_dict = {key : 0 for key in self._in_nodes_dict}
return
states_list, aligned_node_list = DPStage.align_states(output_idx_list, self._stage_dict,
self._node_list)
num_states = states_list[0][3].size
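# Why the cap matters, as a standalone sketch (the candidate count is an
# illustrative assumption): align_states gives every aligned node its own
# numpy axis sized by its number of record candidates, so the state
# tensor's element count is the product of those counts, and numpy
# additionally refuses arrays with more than 32 dimensions.
import numpy as np
candidates_per_node = 8  # assumed for illustration
for n_nodes in (2, 8, 16):
    shape = (candidates_per_node,) * n_nodes
    print(n_nodes, "axes ->", np.prod(shape, dtype=np.int64), "elements")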
@@ -126,13 +139,15 @@ class DPTuner(BaseGraphTuner):
while not bfs_q.empty():
node_idx = bfs_q.get()
visited.add(node_idx)
if is_input_node(self._node_list[node_idx], input_names):
node = self._node_list[node_idx]
if is_boundary_node(node, input_names):
continue
optimal_sch_idx = optimal_record_dict[node_idx]
full_states = self._stage_dict[node_idx].full_states
if not has_multiple_inputs(self._node_list, node_idx, input_names):
input_idx = self._in_nodes_dict[node_idx][0]
if is_input_node(self._node_list[input_idx], input_names):
input_node = self._node_list[input_idx]
if is_boundary_node(input_node, input_names):
continue
if input_idx not in visited:
bfs_q.put(input_idx)
@@ -18,7 +18,7 @@
"""Partitioned Boolean Quadratic Programming Tuner"""
from ._base import INVALID_LAYOUT_TIME
from .base_graph_tuner import BaseGraphTuner
from .utils import is_input_node, has_multiple_inputs
from .utils import is_boundary_node, has_multiple_inputs
class PBQPTuner(BaseGraphTuner):
@@ -36,10 +36,11 @@ class PBQPTuner(BaseGraphTuner):
"""
super(PBQPTuner, self).__init__(*args, **kwargs)
# Remove input nodes
# Remove input and boundary nodes
input_names = self._input_shapes.keys()
for node_idx in self._out_nodes_dict:
if is_input_node(self._node_list[node_idx], input_names):
node = self._node_list[node_idx]
if is_boundary_node(node, input_names):
for out_node_idx in self._out_nodes_dict[node_idx]:
self._in_nodes_dict[out_node_idx].remove(node_idx)
@@ -250,10 +251,16 @@ class PBQPTuner(BaseGraphTuner):
target_input_pos = -1
if has_multiple_inputs(self._node_list, key, input_names):
for i, item in enumerate(val):
if not is_input_node(self._node_list[item], input_names):
node = self._node_list[item]
if not is_boundary_node(node, input_names):
target_input_idx = item
target_input_pos = i
break
# Skip when every input is a boundary node
if target_input_idx < 0:
continue
temp[(target_input_idx, key)] = []
record_candidates = self._node_list[target_input_idx]["record_candidates"]
for j in range(len(record_candidates)):
@@ -264,7 +271,8 @@
for j in range(target_input_pos + 1, len(val)):
input_idx = val[j]
if is_input_node(self._node_list[input_idx], input_names):
input_node = self._node_list[input_idx]
if is_boundary_node(input_node, input_names):
continue
temp[(input_idx, key)] = \
self._layout_transform_interlayer_cost[(input_idx, target_input_idx)]
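# The guard above follows a pattern this commit applies in several places:
# pick the first non-boundary input, and skip the entry entirely when every
# input is a boundary node. A self-contained sketch with made-up node dicts
# and a stubbed boundary test (not the tuner's real data structures):
def first_non_boundary(node_list, input_indices, is_boundary):
    for pos, idx in enumerate(input_indices):
        if not is_boundary(node_list[idx]):
            return idx, pos
    return -1, -1  # every input is a boundary node: caller skips

nodes = [{"op": "null", "name": "data"}, {"op": "reshape"}, {"op": "conv2d"}]
print(first_non_boundary(nodes, [0, 1, 2],
                         lambda n: n["op"] in ("null", "reshape")))  # (2, 2)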
@@ -23,4 +23,4 @@ from . import utils
from .traverse_graph import expr2graph, get_direct_ancestor, get_in_nodes, \
get_out_nodes
from .utils import has_multiple_inputs, is_input_node, bind_inputs
from .utils import has_multiple_inputs, is_boundary_node, bind_inputs
@@ -26,8 +26,7 @@ from tvm.relay.expr import Call, Function, TupleGetItem, Var, Constant, Tuple
from tvm.relay.ty import TupleType, TensorType
from tvm.autotvm.task import TaskExtractEnv
from .._base import RULE_OUT_NODE_NAMES
from .utils import has_multiple_inputs, is_input_node
from .utils import has_multiple_inputs, is_boundary_node
# Setup relay op base name -> topi compute functions
@@ -210,19 +209,9 @@ def get_direct_ancestor(node_list, visited_dict, target_ops, node_idx, input_names):
"""
if node_idx in visited_dict:
return visited_dict[node_idx]
if is_input_node(node_list[node_idx], input_names):
return [node_idx]
node = node_list[node_idx]
# Rule out injective operators
is_rule_out = False
for item_idx in node["inputs"]:
item = node_list[item_idx[0]]
if item["op"] in RULE_OUT_NODE_NAMES:
is_rule_out = True
break
if is_rule_out:
visited_dict[node_idx] = []
return []
if is_boundary_node(node, input_names):
return [node_idx]
node_direct_ancestor = []
for item_idx in node["inputs"]:
@@ -235,14 +224,12 @@
item_idx[0], input_names)
for tmp_item in tmp:
node_direct_ancestor.append(tmp_item)
if not has_multiple_inputs(node_list, node_idx, input_names) and node_direct_ancestor:
node_direct_ancestor = [node_direct_ancestor[0]]
visited_dict[node_idx] = node_direct_ancestor
return node_direct_ancestor
def get_in_nodes(node_list, target_ops, input_names):
"""Create a dictionary mapping from op_name nodes or multi_input
"""Create a dictionary mapping from op_name nodes or multi-input
nodes to closest input ancestors.
Parameters
@@ -265,7 +252,7 @@ def get_in_nodes(node_list, target_ops, input_names):
visited_dict = {}
in_node_dict = {}
for i, node in enumerate(node_list):
if node["op"] in RULE_OUT_NODE_NAMES:
if is_boundary_node(node, input_names):
continue
get_direct_ancestor(node_list, visited_dict, target_ops, i, input_names)
for key, val in visited_dict.items():
@@ -274,9 +261,33 @@
if node["op"] in target_ops or is_multiple_inputs:
in_node_dict[key] = val
# Remove empty nodes
has_empty_node = True
# Reduce boundary nodes
out_node_dict = get_out_nodes(in_node_dict)
has_reduced_node = True
while has_reduced_node:
boundary_nodes = []
for key, val in in_node_dict.items():
node = node_list[key]
is_boundary = True
# Target ops can't be boundary nodes
if node["op"] not in target_ops:
for input_idx in val:
in_node = node_list[input_idx]
if not is_boundary_node(in_node, input_names) and \
input_idx in in_node_dict:
is_boundary = False
else:
val.remove(input_idx)
if is_boundary:
boundary_nodes.append(key)
if boundary_nodes:
for idx in boundary_nodes:
del in_node_dict[idx]
else:
has_reduced_node = False
# Remove empty nodes to ignore pre-computed sub-graphs
has_empty_node = True
while has_empty_node:
empty_nodes = []
for key, val in in_node_dict.items():
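# The loop above runs to a fixed point: a non-target node whose surviving
# inputs are all boundary (or already dropped) becomes a boundary node
# itself and is removed, which may expose new boundary nodes on the next
# pass. The same idea over a toy dict (names and types are illustrative,
# not the tuner's real in_node_dict):
def prune_to_fixed_point(in_nodes, is_boundary_idx, is_target_idx):
    changed = True
    while changed:
        changed = False
        for key in list(in_nodes):
            if is_target_idx(key):
                continue
            live = [i for i in in_nodes[key]
                    if not is_boundary_idx(i) and i in in_nodes]
            if live:
                in_nodes[key] = live
            else:
                del in_nodes[key]  # key itself is now a boundary node
                changed = True
    return in_nodes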
@@ -19,6 +19,8 @@
from tvm import relay
from tvm.relay import transform
from .._base import LAYOUT_FIXED_OP
def has_multiple_inputs(node_list, node_idx, input_names):
"""Check whether a node has multiple input nodes
@@ -46,14 +48,16 @@ def has_multiple_inputs(node_list, node_idx, input_names):
in_idx = in_idx[0]
in_node = node_list[in_idx]
# Exclude parameter nodes
if in_node["op"] != "null" or is_input_node(in_node,
input_names):
if in_node["op"] != "null" or \
("name" in in_node and in_node["name"] in input_names):
num_inputs += 1
return num_inputs > 1
def is_input_node(node_entry, input_names):
"""Whether a node is an input node.
def is_boundary_node(node_entry, input_names):
"""Whether a node is a boundary node.
Currently, input nodes and nodes in LAYOUT_FIXED_OP are
counted as boundary nodes.
Parameters
----------
@@ -66,9 +70,11 @@ def is_input_node(node_entry, input_names):
Returns
-------
out : bool
whether node is a input node.
whether node is a boundary node.
"""
return "name" in node_entry and node_entry["name"] in input_names
out = node_entry["op"] in LAYOUT_FIXED_OP or \
("name" in node_entry and node_entry["name"] in input_names)
return out
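# Quick sanity check of the predicate above with hand-built node entries
# (illustrative stand-ins for the tuner's internal node dicts):
entries = [
    {"op": "null", "name": "data"},     # graph input -> boundary
    {"op": "reshape"},                  # in LAYOUT_FIXED_OP -> boundary
    {"op": "conv2d", "name": "conv0"},  # regular target op -> not boundary
]
for entry in entries:
    print(entry["op"], is_boundary_node(entry, ["data"]))
# prints: null True, reshape True, conv2d False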
def bind_inputs(expr, input_shapes=None, input_dtypes="float32"):
@@ -250,7 +250,143 @@ def test_PBQPTuner_run():
% (str(expected_out), str(out))
def test_many_sub_graphs():
target = "llvm"
dtype = "float32"
dshape = (1, 8, 8, 3)
layout = "NCHW"
target_ops = [relay.nn.conv2d]
data = relay.var("data", shape=dshape, dtype=dtype)
t0 = relay.transpose(data, (0, 3, 1, 2))
w0 = relay.var("w0_weight")
conv0 = relay.nn.conv2d(t0, w0, channels=16, kernel_size=(3, 3), padding=(1, 1))
t1 = relay.transpose(conv0, (0, 2, 3, 1))
w1 = relay.var("w1_weight")
t2 = relay.transpose(t1, (0, 3, 1, 2))
conv1 = relay.nn.conv2d(t2, w1, channels=32, kernel_size=(1, 1))
t3 = relay.transpose(conv1, (0, 2, 3, 1))
w2 = relay.var("w2_weight")
t4 = relay.transpose(t3, (0, 3, 1, 2))
conv2 = relay.nn.conv2d(t4, w2, channels=32, kernel_size=(3, 3), padding=(1, 1))
t5 = relay.transpose(conv2, (0, 2, 3, 1))
out = relay.add(t3, t5)
net = relay.Function(relay.analysis.free_vars(out), out)
net, params = relay.testing.create_workload(net)
tasks = autotvm.task.extract_from_program(net["main"],
target=target,
params=params,
ops=(relay.op.nn.conv2d,))
wkl_list = [
create_workload((1, 3, 8, 8), (16, 3, 3, 3), (1, 1), (1, 1), (1, 1), layout, layout, dtype, dtype),
create_workload((1, 16, 8, 8), (32, 16, 1, 1), (1, 1), (0, 0), (1, 1), layout, layout, dtype, dtype),
create_workload((1, 32, 8, 8), (32, 32, 3, 3), (1, 1), (1, 1), (1, 1), layout, layout, dtype, dtype),
]
costs = [0.04, 0.012, 0.03, 0.02, 0.02, 0.045]
config_list = []
cfg_dict = {"i": -1,
"c": None,
"e": [["tile_ic", "sp", [3, 1]],
["tile_oc", "sp", [4, 4]],
["tile_ow", "sp", [4, 2]],
["unroll_kw", "ot", True]],
"t": ""}
config_list.append(ConfigEntity.from_json_dict(cfg_dict))
cfg_dict = {"i": -1,
"c": None,
"e": [["tile_ic", "sp", [2, 8]],
["tile_oc", "sp", [1, 32]],
["tile_oh", "ot", 1],
["tile_ow", "sp", [4, 2]]],
"t": ""}
config_list.append(ConfigEntity.from_json_dict(cfg_dict))
cfg_dict = {"i": -1,
"c": None,
"e": [["tile_ic", "sp", [8, 4]],
["tile_oc", "sp", [4, 8]],
["tile_ow", "sp", [2, 4]],
["unroll_kw", "ot", False]],
"t": ""}
config_list.append(ConfigEntity.from_json_dict(cfg_dict))
cfg_dict = {"i": -1,
"c": None,
"e": [["tile_ic", "sp", [1, 3]],
["tile_oc", "sp", [2, 8]],
["tile_ow", "sp", [4, 2]],
["unroll_kw", "ot", True]],
"t": ""}
config_list.append(ConfigEntity.from_json_dict(cfg_dict))
cfg_dict = {"i": -1,
"c": None,
"e": [["tile_ic", "sp", [4, 4]],
["tile_oc", "sp", [2, 16]],
["tile_oh", "ot", 1],
["tile_ow", "sp", [4, 2]]],
"t": ""}
config_list.append(ConfigEntity.from_json_dict(cfg_dict))
cfg_dict = {"i": -1,
"c": None,
"e": [["tile_ic", "sp", [16, 2]],
["tile_oc", "sp", [8, 4]],
["tile_ow", "sp", [2, 4]],
["unroll_kw", "ot", False]],
"t": ""}
config_list.append(ConfigEntity.from_json_dict(cfg_dict))
records = []
wkl_list = wkl_list + wkl_list
tasks = tasks + tasks
for wkl, cost, config, task in zip(wkl_list, costs, config_list, tasks):
task.workload = wkl
ms_input = MeasureInput(target=target, task=task, config=config)
ms_output = MeasureResult(costs=(cost,), error_no=0, all_cost=-1, timestamp=-1)
records.append((ms_input, ms_output))
ltf_records = []
ltf_arg = [tvm.placeholder((1, 64, 16, 16, 8), dtype=dtype), "NCHW8c", "NCHW512c"]
ltf_arg = autotvm.task.topi_integration.serialize_args(ltf_arg)
ltf_wkl = ('layout_transform',) + autotvm.task.args_to_workload(ltf_arg)
ltf_task = copy.deepcopy(tasks[0])
ltf_task.workload = ltf_wkl
ms_input = MeasureInput(target=target, task=ltf_task, config=None)
ms_output = MeasureResult(costs=(1.91224744e-05,), error_no=0, all_cost=-1, timestamp=-1)
ltf_records.append((ms_input, ms_output))
ltf_keys = []
ltf_arg = [tvm.placeholder((1, 4, 8, 8, 4), dtype=dtype), "NCHW4c", "NCHW8c"]
ltf_arg = autotvm.task.topi_integration.serialize_args(ltf_arg)
ltf_wkl = ('layout_transform',) + autotvm.task.args_to_workload(ltf_arg)
ltf_keys.append(ltf_wkl)
ltf_arg = [tvm.placeholder((1, 1, 8, 8, 32), dtype=dtype), "NCHW32c", "NCHW4c"]
ltf_arg = autotvm.task.topi_integration.serialize_args(ltf_arg)
ltf_wkl = ('layout_transform',) + autotvm.task.args_to_workload(ltf_arg)
ltf_keys.append(ltf_wkl)
ltf_arg = [tvm.placeholder((1, 4, 8, 8, 8), dtype=dtype), "NCHW8c", "NCHW32c"]
ltf_arg = autotvm.task.topi_integration.serialize_args(ltf_arg)
ltf_wkl = ('layout_transform',) + autotvm.task.args_to_workload(ltf_arg)
ltf_keys.append(ltf_wkl)
executor = DPTuner(net, {"data": dshape}, records, target_ops, target)
executor.benchmark_layout_transform(layout_records=ltf_records, infer_layout=True)
executor.run()
out = [record[0].config for record in executor.get_optimal_records()]
expected_out = [records[3][0].config, records[1][0].config, records[2][0].config]
assert expected_out == out, "Output mismatch: expecting %s but got %s" \
% (str(expected_out), str(out))
executor = PBQPTuner(net, {"data": dshape}, records, target_ops, target)
executor.benchmark_layout_transform(layout_records=ltf_records, infer_layout=True)
executor.run()
out = [record[0].config for record in executor.get_optimal_records()]
expected_out = [records[3][0].config, records[1][0].config, records[2][0].config]
assert expected_out == out, "Output mismatch: expecting %s but got %s" \
% (str(expected_out), str(out))
if __name__=="__main__":
test_graph_tuner_layout_transform()
test_DPTuner_run()
test_PBQPTuner_run()
test_many_sub_graphs()