Unverified Commit 2e913f0b by zhen-jia Committed by GitHub

[Graph tuner]Add opt out operator for has_multiple_inputs for graph tuner (#5000)

* consider layout_transform in has_multiple_inputs

* refactor code

* remove debug info

* remove subclass assignment

* refactoring a little bit

* remove default value

* remove trailing whitespace

* modify test for has_multiple_inputs

Co-authored-by: Ubuntu <ubuntu@ip-172-31-40-194.us-west-2.compute.internal>
parent 64bc9978
...@@ -23,3 +23,5 @@ ...@@ -23,3 +23,5 @@
INVALID_LAYOUT_TIME = 10e9 INVALID_LAYOUT_TIME = 10e9
MAX_OUTPUT_NODES = 16 MAX_OUTPUT_NODES = 16
OPT_OUT_OP = ["layout_transform"]
...@@ -34,6 +34,7 @@ from .utils import is_boundary_node, get_in_nodes, get_out_nodes, has_multiple_i ...@@ -34,6 +34,7 @@ from .utils import is_boundary_node, get_in_nodes, get_out_nodes, has_multiple_i
bind_inputs, expr2graph bind_inputs, expr2graph
from ._base import INVALID_LAYOUT_TIME from ._base import INVALID_LAYOUT_TIME
from ._base import OPT_OUT_OP
def get_infer_layout(task_name): def get_infer_layout(task_name):
if task_name.startswith("conv2d"): if task_name.startswith("conv2d"):
...@@ -153,6 +154,7 @@ class BaseGraphTuner(object): ...@@ -153,6 +154,7 @@ class BaseGraphTuner(object):
self._in_nodes_dict = get_in_nodes(self._node_list, self._target_ops, input_shapes.keys()) self._in_nodes_dict = get_in_nodes(self._node_list, self._target_ops, input_shapes.keys())
self._out_nodes_dict = get_out_nodes(self._in_nodes_dict) self._out_nodes_dict = get_out_nodes(self._in_nodes_dict)
self._fetch_cfg() self._fetch_cfg()
self._opt_out_op = OPT_OUT_OP
# Setup infer_layout for elemwise-like nodes # Setup infer_layout for elemwise-like nodes
# Note: graph tuner currently only supports tuning of single input and single output # Note: graph tuner currently only supports tuning of single input and single output
...@@ -162,7 +164,7 @@ class BaseGraphTuner(object): ...@@ -162,7 +164,7 @@ class BaseGraphTuner(object):
# elemwise-like node, and use infer_layout function from input op to generate layouts. # elemwise-like node, and use infer_layout function from input op to generate layouts.
input_names = self._input_shapes.keys() input_names = self._input_shapes.keys()
for idx in sorted(self._in_nodes_dict.keys()): for idx in sorted(self._in_nodes_dict.keys()):
if has_multiple_inputs(self._node_list, idx, input_names): if has_multiple_inputs(self._node_list, idx, input_names, self._opt_out_op):
node_entry = self._node_list[idx] node_entry = self._node_list[idx]
node_entry["topi_op"] = [] node_entry["topi_op"] = []
node_entry["workloads"] = [] node_entry["workloads"] = []
...@@ -246,7 +248,7 @@ class BaseGraphTuner(object): ...@@ -246,7 +248,7 @@ class BaseGraphTuner(object):
node_entry = self._node_list[key] node_entry = self._node_list[key]
target_input_idx = -1 target_input_idx = -1
target_input_pos = -1 target_input_pos = -1
if has_multiple_inputs(self._node_list, key, input_names): if has_multiple_inputs(self._node_list, key, input_names, self._opt_out_op):
for i, item in enumerate(val): for i, item in enumerate(val):
node = self._node_list[item] node = self._node_list[item]
if not is_boundary_node(node, input_names): if not is_boundary_node(node, input_names):
......
...@@ -144,7 +144,7 @@ class DPTuner(BaseGraphTuner): ...@@ -144,7 +144,7 @@ class DPTuner(BaseGraphTuner):
continue continue
optimal_sch_idx = optimal_record_dict[node_idx] optimal_sch_idx = optimal_record_dict[node_idx]
full_states = self._stage_dict[node_idx].full_states full_states = self._stage_dict[node_idx].full_states
if not has_multiple_inputs(self._node_list, node_idx, input_names): if not has_multiple_inputs(self._node_list, node_idx, input_names, self._opt_out_op):
input_idx = self._in_nodes_dict[node_idx][0] input_idx = self._in_nodes_dict[node_idx][0]
input_node = self._node_list[input_idx] input_node = self._node_list[input_idx]
if is_boundary_node(input_node, input_names): if is_boundary_node(input_node, input_names):
......
...@@ -249,7 +249,7 @@ class PBQPTuner(BaseGraphTuner): ...@@ -249,7 +249,7 @@ class PBQPTuner(BaseGraphTuner):
for key, val in self._in_nodes_dict.items(): for key, val in self._in_nodes_dict.items():
target_input_idx = -1 target_input_idx = -1
target_input_pos = -1 target_input_pos = -1
if has_multiple_inputs(self._node_list, key, input_names): if has_multiple_inputs(self._node_list, key, input_names, self._opt_out_op):
for i, item in enumerate(val): for i, item in enumerate(val):
node = self._node_list[item] node = self._node_list[item]
if not is_boundary_node(node, input_names): if not is_boundary_node(node, input_names):
......
...@@ -26,7 +26,7 @@ from tvm.relay.ty import TupleType, TensorType ...@@ -26,7 +26,7 @@ from tvm.relay.ty import TupleType, TensorType
from tvm.autotvm.task import TaskExtractEnv from tvm.autotvm.task import TaskExtractEnv
from .utils import has_multiple_inputs, is_boundary_node, is_skipped_node from .utils import has_multiple_inputs, is_boundary_node, is_skipped_node
from .._base import OPT_OUT_OP
def expr2graph(expr, target_ops, node_dict, node_list): def expr2graph(expr, target_ops, node_dict, node_list):
"""Convert relay expr to graph data structure """Convert relay expr to graph data structure
...@@ -204,7 +204,8 @@ def get_direct_ancestor(node_list, visited_dict, target_ops, node_idx, input_nam ...@@ -204,7 +204,8 @@ def get_direct_ancestor(node_list, visited_dict, target_ops, node_idx, input_nam
node_direct_ancestor = [] node_direct_ancestor = []
for item_idx in node["inputs"]: for item_idx in node["inputs"]:
item = node_list[item_idx[0]] item = node_list[item_idx[0]]
is_multiple_inputs = has_multiple_inputs(node_list, item_idx[0], input_names) is_multiple_inputs = has_multiple_inputs(node_list, item_idx[0], \
input_names, OPT_OUT_OP)
if item["op"] in target_ops or is_multiple_inputs: if item["op"] in target_ops or is_multiple_inputs:
node_direct_ancestor.append(item_idx[0]) node_direct_ancestor.append(item_idx[0])
else: else:
...@@ -245,7 +246,8 @@ def get_in_nodes(node_list, target_ops, input_names): ...@@ -245,7 +246,8 @@ def get_in_nodes(node_list, target_ops, input_names):
get_direct_ancestor(node_list, visited_dict, target_ops, i, input_names) get_direct_ancestor(node_list, visited_dict, target_ops, i, input_names)
for key, val in visited_dict.items(): for key, val in visited_dict.items():
node = node_list[key] node = node_list[key]
is_multiple_inputs = has_multiple_inputs(node_list, key, input_names) is_multiple_inputs = has_multiple_inputs(node_list, key, \
input_names, OPT_OUT_OP)
if node["op"] in target_ops or is_multiple_inputs: if node["op"] in target_ops or is_multiple_inputs:
in_node_dict[key] = val in_node_dict[key] = val
......
...@@ -20,8 +20,7 @@ import tvm ...@@ -20,8 +20,7 @@ import tvm
from tvm import relay from tvm import relay
from tvm.relay import transform from tvm.relay import transform
def has_multiple_inputs(node_list, node_idx, input_names, opt_out_op):
def has_multiple_inputs(node_list, node_idx, input_names):
"""Check whether a node has multiple input nodes """Check whether a node has multiple input nodes
except variable nodes. except variable nodes.
...@@ -47,7 +46,14 @@ def has_multiple_inputs(node_list, node_idx, input_names): ...@@ -47,7 +46,14 @@ def has_multiple_inputs(node_list, node_idx, input_names):
in_idx = in_idx[0] in_idx = in_idx[0]
in_node = node_list[in_idx] in_node = node_list[in_idx]
# Exclude parameter nodes # Exclude parameter nodes
if in_node["op"] is not None or \ if(in_node["op"] is not None and in_node["op"].name in opt_out_op):
increase = False
for t_idx in in_node["inputs"]:
increase = has_multiple_inputs(node_list, t_idx[0], \
input_names, opt_out_op)
if increase:
num_inputs += 1
elif in_node["op"] is not None or \
("name" in in_node and in_node["name"] in input_names): ("name" in in_node and in_node["name"] in input_names):
num_inputs += 1 num_inputs += 1
return num_inputs > 1 return num_inputs > 1
......
...@@ -27,11 +27,12 @@ from tvm import autotvm, relay ...@@ -27,11 +27,12 @@ from tvm import autotvm, relay
from tvm.relay.testing import resnet from tvm.relay.testing import resnet
from tvm.autotvm.graph_tuner.utils import has_multiple_inputs, get_direct_ancestor, get_in_nodes, \ from tvm.autotvm.graph_tuner.utils import has_multiple_inputs, get_direct_ancestor, get_in_nodes, \
get_out_nodes, expr2graph, bind_inputs get_out_nodes, expr2graph, bind_inputs
from tvm.autotvm.graph_tuner._base import OPT_OUT_OP
from tvm.relay.expr import Call, TupleGetItem, Tuple, Var from tvm.relay.expr import Call, TupleGetItem, Tuple, Var
def verify_has_multiple_inputs(node_list, node_idx, input_names, expected_result): def verify_has_multiple_inputs(node_list, node_idx, input_names, expected_result):
out = has_multiple_inputs(node_list, node_idx, input_names) out = has_multiple_inputs(node_list, node_idx, input_names, OPT_OUT_OP)
assert out == expected_result, "Output mismatch: expecting checking %s to be %s but got %s." \ assert out == expected_result, "Output mismatch: expecting checking %s to be %s but got %s." \
% (node_list[node_idx]["op"], str(expected_result), str(out)) % (node_list[node_idx]["op"], str(expected_result), str(out))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment