Unverified Commit 2e913f0b by zhen-jia Committed by GitHub

[Graph tuner] Add opt-out operator support to has_multiple_inputs in the graph tuner (#5000)

* consider layout_transform in has_multiple_inputs

* refactor code

* remove debug info

* remove subclass assignment

* refactoring a little bit

* remove default value

* remove trailing whitespace

* modify test for has_multiple_inputs

Co-authored-by: Ubuntu <ubuntu@ip-172-31-40-194.us-west-2.compute.internal>
parent 64bc9978
......@@ -23,3 +23,5 @@
INVALID_LAYOUT_TIME = 10e9
MAX_OUTPUT_NODES = 16
OPT_OUT_OP = ["layout_transform"]
......@@ -34,6 +34,7 @@ from .utils import is_boundary_node, get_in_nodes, get_out_nodes, has_multiple_i
bind_inputs, expr2graph
from ._base import INVALID_LAYOUT_TIME
from ._base import OPT_OUT_OP
def get_infer_layout(task_name):
if task_name.startswith("conv2d"):
......@@ -153,6 +154,7 @@ class BaseGraphTuner(object):
self._in_nodes_dict = get_in_nodes(self._node_list, self._target_ops, input_shapes.keys())
self._out_nodes_dict = get_out_nodes(self._in_nodes_dict)
self._fetch_cfg()
self._opt_out_op = OPT_OUT_OP
# Setup infer_layout for elemwise-like nodes
# Note: graph tuner currently only supports tuning of single input and single output
......@@ -162,7 +164,7 @@ class BaseGraphTuner(object):
# elemwise-like node, and use infer_layout function from input op to generate layouts.
input_names = self._input_shapes.keys()
for idx in sorted(self._in_nodes_dict.keys()):
if has_multiple_inputs(self._node_list, idx, input_names):
if has_multiple_inputs(self._node_list, idx, input_names, self._opt_out_op):
node_entry = self._node_list[idx]
node_entry["topi_op"] = []
node_entry["workloads"] = []
......@@ -246,7 +248,7 @@ class BaseGraphTuner(object):
node_entry = self._node_list[key]
target_input_idx = -1
target_input_pos = -1
if has_multiple_inputs(self._node_list, key, input_names):
if has_multiple_inputs(self._node_list, key, input_names, self._opt_out_op):
for i, item in enumerate(val):
node = self._node_list[item]
if not is_boundary_node(node, input_names):
......
......@@ -144,7 +144,7 @@ class DPTuner(BaseGraphTuner):
continue
optimal_sch_idx = optimal_record_dict[node_idx]
full_states = self._stage_dict[node_idx].full_states
if not has_multiple_inputs(self._node_list, node_idx, input_names):
if not has_multiple_inputs(self._node_list, node_idx, input_names, self._opt_out_op):
input_idx = self._in_nodes_dict[node_idx][0]
input_node = self._node_list[input_idx]
if is_boundary_node(input_node, input_names):
......
......@@ -249,7 +249,7 @@ class PBQPTuner(BaseGraphTuner):
for key, val in self._in_nodes_dict.items():
target_input_idx = -1
target_input_pos = -1
if has_multiple_inputs(self._node_list, key, input_names):
if has_multiple_inputs(self._node_list, key, input_names, self._opt_out_op):
for i, item in enumerate(val):
node = self._node_list[item]
if not is_boundary_node(node, input_names):
......
......@@ -26,7 +26,7 @@ from tvm.relay.ty import TupleType, TensorType
from tvm.autotvm.task import TaskExtractEnv
from .utils import has_multiple_inputs, is_boundary_node, is_skipped_node
from .._base import OPT_OUT_OP
def expr2graph(expr, target_ops, node_dict, node_list):
"""Convert relay expr to graph data structure
......@@ -204,7 +204,8 @@ def get_direct_ancestor(node_list, visited_dict, target_ops, node_idx, input_nam
node_direct_ancestor = []
for item_idx in node["inputs"]:
item = node_list[item_idx[0]]
is_multiple_inputs = has_multiple_inputs(node_list, item_idx[0], input_names)
is_multiple_inputs = has_multiple_inputs(node_list, item_idx[0], \
input_names, OPT_OUT_OP)
if item["op"] in target_ops or is_multiple_inputs:
node_direct_ancestor.append(item_idx[0])
else:
......@@ -245,7 +246,8 @@ def get_in_nodes(node_list, target_ops, input_names):
get_direct_ancestor(node_list, visited_dict, target_ops, i, input_names)
for key, val in visited_dict.items():
node = node_list[key]
is_multiple_inputs = has_multiple_inputs(node_list, key, input_names)
is_multiple_inputs = has_multiple_inputs(node_list, key, \
input_names, OPT_OUT_OP)
if node["op"] in target_ops or is_multiple_inputs:
in_node_dict[key] = val
......
......@@ -20,8 +20,7 @@ import tvm
from tvm import relay
from tvm.relay import transform
def has_multiple_inputs(node_list, node_idx, input_names, opt_out_op):
    """Check whether a node has multiple input nodes,
    excluding variable (parameter) nodes.

    Parameters
    ----------
    node_list : list of dict of str to object
        List of all nodes in the graph.

    node_idx : int
        Index of the node to be checked.

    input_names : list of str
        Names of the graph-level input nodes.

    opt_out_op : list of str
        Operator names to opt out of direct counting
        (e.g. ["layout_transform"]). Such a node is "looked through":
        it is counted as an input only if one of its own inputs is
        itself a multi-input node.

    Returns
    -------
    out : bool
        Whether the specified node has more than one non-parameter input.
    """
    num_inputs = 0
    node = node_list[node_idx]
    for in_idx in node["inputs"]:
        in_idx = in_idx[0]
        in_node = node_list[in_idx]
        # Exclude parameter nodes
        if in_node["op"] is not None and in_node["op"].name in opt_out_op:
            # Opted-out op: recurse into its inputs instead of counting it
            # directly. Using any() considers every input; the original
            # kept only the result of the last iteration, which happens to
            # coincide for single-input ops like layout_transform but is
            # wrong for multi-input opt-out ops.
            increase = any(
                has_multiple_inputs(node_list, t_idx[0], input_names, opt_out_op)
                for t_idx in in_node["inputs"]
            )
            if increase:
                num_inputs += 1
        elif in_node["op"] is not None or \
                ("name" in in_node and in_node["name"] in input_names):
            num_inputs += 1
    return num_inputs > 1
......
......@@ -27,11 +27,12 @@ from tvm import autotvm, relay
from tvm.relay.testing import resnet
from tvm.autotvm.graph_tuner.utils import has_multiple_inputs, get_direct_ancestor, get_in_nodes, \
get_out_nodes, expr2graph, bind_inputs
from tvm.autotvm.graph_tuner._base import OPT_OUT_OP
from tvm.relay.expr import Call, TupleGetItem, Tuple, Var
def verify_has_multiple_inputs(node_list, node_idx, input_names, expected_result):
    """Assert that has_multiple_inputs(..., OPT_OUT_OP) yields *expected_result*
    for the node at *node_idx*.

    The diff rendering left both the old call (without OPT_OUT_OP) and the
    updated one in place; only the updated call is kept here, since the first
    assignment was dead code immediately overwritten by the second.
    """
    out = has_multiple_inputs(node_list, node_idx, input_names, OPT_OUT_OP)
    assert out == expected_result, "Output mismatch: expecting checking %s to be %s but got %s." \
        % (node_list[node_idx]["op"], str(expected_result), str(out))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment