Commit c8a0f524 by Yao Wang, committed by Yizhi Liu

[AutoTVM]Core functionality for Graph tuner (#2184)

* Add graph tuning

* Add tests

* Fix tests

* Fix pylint

* Small fix for docstring

* Minor fix

* Support fetching workload from relay expr

* Simplify benchmark layout transformation

* Add relay support

* Fix infer layout func name

* Refactor internal data representation

* Fix issues

* Add PBQP solver

* Fix layout transform check

* Add PBQPTuner test

* Fix lint

* Update tutorial

* Fix tutorial

* Fix lint

* Add relay test

* Remove nnvm since nnvm graph can be converted to relay function

* Modify benchmark layout wrt new layout_transform api

* Fix lint

* Update docstring for DP tuner

* Refactor traverse graph

* Support graph tuning for multiple target operators

* Fix fetching workloads

* Add x86 depthwise_conv2d infer_layout

* Fix x86 depthwise_conv2d autotvm

* Fix PBQP tuner

* Fix DP tuner

* Generate dummy layout transform record

* Update tutorial

* Modify layout records name

* Add ASF header

* Add ASF header for testing files

* Fix test

* Fix topi fetching

* Some refactors

* Fix lint

* Fix tutorial

* Rename test files

* Fix doc typo

* Add test case note link
parent 4767554c
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Autotvm graph tuner API."""
from __future__ import absolute_import as _abs
from . import _base
from . import base_graph_tuner
from .base_graph_tuner import BaseGraphTuner
from .dynamic_programming_tuner import DPTuner
from .pbqp_tuner import PBQPTuner
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
"""Helper functions and global data"""
RULE_OUT_NODE_NAMES = ["Tuple", "TupleGetItem", "batch_flatten", "transpose", "reshape",
"multibox_prior", "multibox_transform_loc", "where",
"non_max_suppression", "strided_slice"]
# We set a large time to represent an invalid layout transformation.
# This number is set to 10e9 seconds to align with autotvm.
INVALID_LAYOUT_TIME = 10e9
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=too-many-arguments,too-many-locals,too-many-statements,too-many-instance-attributes,too-many-branches,too-many-nested-blocks,invalid-name,unused-argument,unused-variable,no-member,no-value-for-parameter
"""Base class for graph tuner."""
import logging
from abc import abstractmethod
import numpy as np
import topi
import tvm
from tvm import autotvm, relay
from tvm.autotvm.task import get_config
from tvm.autotvm.task.topi_integration import deserialize_args, serialize_args
from tvm.autotvm.record import encode, load_from_file
from tvm.autotvm.measure import MeasureResult, MeasureInput
from ... import target as _target
from .utils import is_input_node, get_in_nodes, get_out_nodes, has_multiple_inputs, \
bind_inputs, expr2graph
from ._base import INVALID_LAYOUT_TIME
# Set up topi_op_name -> infer_layout function mapping.
# NOTE: To add more ops, extend the following dictionary.
OP2LAYOUT = {
"topi_nn_conv2d": topi.nn.conv2d_infer_layout,
"topi_nn_depthwise_conv2d_nchw": topi.nn.depthwise_conv2d_infer_layout,
}
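# For example, supporting dense would (hypothetically) add an entry like:
#
#     OP2LAYOUT["topi_nn_dense"] = topi.nn.dense_infer_layout
#
# (topi.nn.dense_infer_layout is an assumed name used for illustration only.)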
@autotvm.template
def layout_transform(*args):
"""Autotvm layout transform template."""
args = deserialize_args(args)
cfg = get_config()
cfg.add_flop(-1)
data = args[0]
out = topi.layout_transform(*args)
sch = topi.generic.schedule_injective([out])
return sch, [data, out]
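# A minimal sketch of how this template is consumed later in
# benchmark_layout_transform: the serialized placeholder/layout arguments are
# turned into an autotvm task whose single config is measured once, e.g.
#
#     args = serialize_args([data_placeholder, in_layout, out_layout])
#     task = autotvm.task.create(layout_transform, args=args, target=self._target)
#
# (here self._target stands for the tuner's compilation target.)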
class BaseGraphTuner(object):
"""Class to search schedules considering both kernel execution time and
layout transformation time.
Before creating a Graph Executor instance, schedule candidates for all kernels in
graph should be provided through tensor-level tuning.
"""
def __init__(self, graph, input_shapes, records, target_ops,
target, max_sch_num=20, dtype="float32", verbose=True,
log_file="graph_tuner.log", log_level=logging.DEBUG,
name="graph_tuner"):
"""Create a GlobalTuner instance. Local schedule searching for all nodes with
target_op in the input graph and layout transformation benchmark need to be
executed before initialization.
graph : tvm.relay.Expr.Function
Input graph
input_shapes : dict of str to tuple.
Input shapes of graph
records : str or iterator of (MeasureInput, MeasureResult)
Collection of kernel level tuning records.
If it is str, then it should be the filename of a records log file.
Each row of this file is an encoded record pair.
Otherwise, it is an iterator.
target_ops : List of str
Target tuning operators.
target : str or tvm.target
Compilation target.
max_sch_num : int, optional
Maximum number of schedule candidates for each workload.
dtype : str, optional
Data type.
log_file : str, optional
graph tuner log file name
name : str, optional
Name of global tuner.
"""
self._node_list = []
self._layout_transform_perf_records = {}
self._layout_transform_interlayer_cost = {}
self._input_shapes = input_shapes
self._target_ops = [op.__name__ for op in target_ops]
self._name = name
self._max_sch_num = max_sch_num
self._optimal_sch_dict = {}
self._records = records
self._dtype = dtype
if isinstance(target, str):
target = _target.create(target)
self._target = target
self._optimal_record_dict = {}
# Set up logger
self._verbose = verbose
self._logger = logging.getLogger(name + "_logger")
need_file_handler = need_console_handler = True
for handler in self._logger.handlers:
if handler.__class__.__name__ == 'FileHandler':
need_file_handler = False
if handler.__class__.__name__ == 'StreamHandler':
need_console_handler = False
self._log_level = log_level
self._log_file = log_file
self._formatter = logging.Formatter('%(asctime)s %(levelname)s %(message)s')
self._logger.setLevel(log_level)
if need_file_handler:
file_handler = logging.FileHandler(log_file)
file_handler.setFormatter(self._formatter)
self._logger.addHandler(file_handler)
if self._verbose and need_console_handler:
console_handler = logging.StreamHandler()
console_handler.setFormatter(self._formatter)
self._logger.addHandler(console_handler)
self._logger.setLevel(log_level)
self._logger.propagate = False
# Generate workload and schedule dictionaries.
if isinstance(graph, relay.expr.Function):
node_dict = {}
graph = bind_inputs(graph, input_shapes, dtype)
expr2graph(graph, self._target_ops, node_dict, self._node_list)
else:
raise RuntimeError("Unsupported graph type: %s" % str(type(graph)))
self._graph = graph
self._in_nodes_dict = get_in_nodes(self._node_list, self._target_ops, input_shapes.keys())
self._out_nodes_dict = get_out_nodes(self._in_nodes_dict)
self._fetch_cfg()
# Set up infer_layout for elemwise-like nodes.
# Note: graph tuner currently only supports ops with a single input and a single
# output as target ops, such as conv2d, dense and conv2d_transpose. In this case, we
# can reuse the infer_layout function from target ops for elemwise-like nodes. The
# behavior is to modify the first tensor shape of the input workload to the output
# shape of the elemwise-like node, and use the infer_layout function from the input
# op to generate layouts.
input_names = self._input_shapes.keys()
for idx in sorted(self._in_nodes_dict.keys()):
if has_multiple_inputs(self._node_list, idx, input_names):
node_entry = self._node_list[idx]
node_entry["topi_op"] = []
node_entry["workloads"] = []
for input_idx in self._in_nodes_dict[idx]:
input_node = self._node_list[input_idx]
if not is_input_node(input_node, input_names):
input_topi_op = input_node["topi_op"][0]
node_entry["topi_op"].append(input_topi_op)
# Only replace the first input tensor
input_workload = input_node["workloads"][0]
first_tensor = input_workload[1]
dtype = first_tensor[-1]
new_shape = tuple([val.value for val in node_entry["types"][0].shape])
actual_workload = (input_workload[0],) + \
((new_shape + (dtype,)),) + input_workload[2:]
node_entry["workloads"].append(actual_workload)
if "record_candidates" not in node_entry:
node_entry["record_candidates"] = input_node["record_candidates"]
else:
node_entry["topi_op"].append(None)
node_entry["workloads"].append(None)
def _fetch_cfg(self):
"""Read and pre-process input schedules."""
if isinstance(self._records, str):
records = load_from_file(self._records)
else:
records = self._records
cfg_dict = {}
for record in records:
in_measure, _ = record
workload = in_measure.task.workload
if workload not in cfg_dict:
cfg_dict[workload] = []
cfg_dict[workload].append(record)
cache_dict = {}
for key in self._in_nodes_dict:
node_entry = self._node_list[key]
if node_entry["op"] not in self._target_ops:
continue
workload = node_entry["workloads"][0]
if workload in cache_dict:
node_entry["record_candidates"] = cache_dict[workload]
continue
record_candidates = []
infer_layout_func = OP2LAYOUT[node_entry["topi_op"][0]]
layout_tracking_dict = {}
for record in cfg_dict[workload]:
in_measure, out_measure = record
workload = in_measure.task.workload
cfg = in_measure.config
# For multiple cfgs which produce the same in/out layouts,
# only the most efficient one is preserved.
with self._target:
layouts = infer_layout_func(workload, cfg)
if layouts in layout_tracking_dict:
cost = out_measure.costs[0]
current_best_cost = layout_tracking_dict[layouts][1].costs[0]
if cost < current_best_cost:
layout_tracking_dict[layouts] = record
else:
layout_tracking_dict[layouts] = record
sorted_records = sorted(layout_tracking_dict.values(),
key=lambda item: item[1].costs[0])
for i in range(min(self._max_sch_num, len(sorted_records))):
record_candidates.append(sorted_records[i])
node_entry["record_candidates"] = record_candidates
cache_dict[workload] = record_candidates
def _iterate_layout_transform(self, callback):
"""Iterate all possible layout transformations and execute callback for each
iteration. callback function accepts 6 arguments: from_node_idx, to_node_idx,
from_sch_idx, to_sch_idx, args which represent the argument list of layout
transformation and is_valid showing whether this is a valid layout transformation.
"""
input_names = self._input_shapes.keys()
for key, val in self._in_nodes_dict.items():
node_entry = self._node_list[key]
target_input_idx = -1
target_input_pos = -1
if has_multiple_inputs(self._node_list, key, input_names):
for i, item in enumerate(val):
if not is_input_node(self._node_list[item], input_names):
target_input_idx = item
target_input_pos = i
break
for i, item in enumerate(val):
i_idx = item
in_node_entry = self._node_list[i_idx]
if is_input_node(in_node_entry, input_names):
continue
if node_entry["op"] in self._target_ops:
o_idx = key
o_infer_layout_func = OP2LAYOUT[node_entry["topi_op"][0]]
o_wkl = node_entry["workloads"][0]
i_topi_op = in_node_entry["topi_op"][0]
i_wkl = in_node_entry["workloads"][0]
pivot = 0
while not i_wkl:
pivot += 1
i_topi_op = in_node_entry["topi_op"][pivot]
i_wkl = in_node_entry["workloads"][pivot]
i_infer_layout_func = OP2LAYOUT[i_topi_op]
else:
o_idx = target_input_idx
if i <= target_input_pos:
continue
o_infer_layout_func = OP2LAYOUT[node_entry["topi_op"][0]]
o_wkl = node_entry["workloads"][target_input_pos]
i_infer_layout_func = OP2LAYOUT[node_entry["topi_op"][i]]
i_wkl = node_entry["workloads"][i]
for m, i_record in enumerate(in_node_entry["record_candidates"]):
for n, o_record in enumerate(node_entry["record_candidates"]):
i_cfg, o_cfg = i_record[0].config, o_record[0].config
with self._target:
i_input_info, i_output_info = i_infer_layout_func(i_wkl, i_cfg)
o_input_info, o_output_info = o_infer_layout_func(o_wkl, o_cfg)
if len(i_input_info) > 1 or len(i_output_info) > 1 or \
len(o_input_info) > 1 or len(o_output_info) > 1:
raise RuntimeError("Graph tuner only supports target operator "
"with single input and single output. "
"Please check target_ops argument.")
in_shape, in_layout = i_output_info[0]
if node_entry["op"] in self._target_ops:
_, out_layout = o_input_info[0]
else:
_, out_layout = o_output_info[0]
data_placeholder = tvm.placeholder(in_shape, name="data",
dtype=self._dtype)
args = [data_placeholder, in_layout, out_layout]
callback(i_idx, o_idx, m, n, args)
def _create_matrix_callback(self, from_node_idx, to_node_idx, from_sch_idx,
to_sch_idx, args):
"""Create dictionary containing matrix format of layout transformation
between nodes."""
sargs = serialize_args(args)
in_layout, out_layout = args[1], args[2]
ltf_workload = ('layout_transform',) + autotvm.task.args_to_workload(sargs)
idx_pair_key = (from_node_idx, to_node_idx)
if in_layout == out_layout:
layout_transform_time = 0
else:
layout_transform_time = \
self._layout_transform_perf_records[ltf_workload][1].costs[0]
if idx_pair_key not in self._layout_transform_interlayer_cost:
self._layout_transform_interlayer_cost[idx_pair_key] = []
if len(self._layout_transform_interlayer_cost[idx_pair_key]) <= from_sch_idx:
self._layout_transform_interlayer_cost[idx_pair_key].append([])
self._layout_transform_interlayer_cost[idx_pair_key][from_sch_idx]\
.append(layout_transform_time)
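# The resulting nested structure is indexed as
#     self._layout_transform_interlayer_cost[(from_idx, to_idx)][from_sch_idx][to_sch_idx]
# For example (illustrative numbers), a 2x2 matrix [[0, 5e-06], [3e-06, 0]]
# means schedule 0 -> schedule 0 needs no transform, while schedule 0 ->
# schedule 1 costs 5e-06 seconds.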
def benchmark_layout_transform(self, min_exec_num=100, timeout=10,
use_rpc=False, device_key=None, host="localhost",
port=9190, n_parallel=1, build_func='default',
layout_records=None, target_host=None, infer_layout=False):
"""Benchmark all possible layout transformation in the graph,
given a set of schedule candidates for each workload of target operator.
Parameters
----------
min_exec_num : int, optional
Minimum number of execution. Final execution time is the average of
all execution time.
timeout : int, optional
Time out for each execution.
use_rpc : boolean, optional
Whether to use rpc mode for benchmarking.
device_key : str, optional
Remote device key which can be queried by
python -m tvm.exec.query_rpc_tracker --host=0.0.0.0 --port=9190
host : str, optional
IP address used to create RPC tracker on host machine.
port : int, optional
Port number used to create RPC tracker on host machine.
n_parallel: int, optional
The number of measurement task that can run in parallel.
Set this according to the number of cpu cores (for compilation) and
the number of devices you have (for measuring generate code).
build_func: str or callable, optional
'default': call default builder. This works for normal target (llvm, cuda)
'ndk': use Android NDK to create shared library. Use this for android target.
callable: customized build function for other backends (e.g. VTA).
See autotvm/measure/measure_methods.py::default_build_func for example.
layout_records : str or iterator of (MeasureInput, MeasureResult). optional
Collection of layout_transform benchmarking records.
If is str, then it should be the filename of a records log file.
Each row of this file is an encoded record pair.
Otherwise, it is an iterator.
If this argument is set, graph tuner will first check whether layout_transform
workload already exists in records and skip benchmarking if possible.
target_host : str, optional
str or :any:`tvm.target.Target` optional
Host compilation target, if target is device.
When TVM compiles device specific program such as CUDA,
we also need host(CPU) side code to interact with the driver
setup the dimensions and parameters correctly.
target_host is used to specify the host side codegen target.
By default, llvm is used if it is enabled,
otherwise a stackvm intepreter is used.
infer_layout : bool, optional
Whether to infer layout transformation time if it doesn't exist in records, instead
of benchmarking on target device.
This might bring performance loss comparing to benchmarking layout transformation.
"""
self._logger.info("Start to benchmark layout transformation...")
if layout_records is None and infer_layout:
raise RuntimeError("Requires some records to infer layout transformation time.")
if isinstance(layout_records, str):
layout_records = load_from_file(layout_records)
if not layout_records and infer_layout:
raise RuntimeError("Records must be non-empty to infer layout transformation time.")
num_flops, total_time = 0, 0
if layout_records is not None:
for record in layout_records:
ltf_wkl = record[0].task.workload
self._layout_transform_perf_records[ltf_wkl] = record
input_shape = ltf_wkl[1][1]
flops = np.prod(input_shape)
num_flops += flops
total_time += record[1].costs[0]
avg_time = total_time / num_flops if num_flops > 0 else 0
args_list = []
def _fetch_args_callback(from_node_idx, to_node_idx, from_sch_idx,
to_sch_idx, args):
"""Callback function to fetch layout transform args"""
_, in_layout, out_layout = args
if in_layout != out_layout:
args_list.append(args)
self._iterate_layout_transform(_fetch_args_callback)
def _log_to_list(record_list):
"""Callback to log result to a list."""
def _callback(_, inputs, results):
"""Callback implementation"""
record_list.append((inputs[0], results[0]))
return _callback
builder = autotvm.LocalBuilder(n_parallel=n_parallel, build_func=build_func)
runner = autotvm.LocalRunner(number=min_exec_num, repeat=1, timeout=timeout)
if use_rpc:
if device_key is None:
raise RuntimeError("device_key need to be set to use rpc tracker mode.")
runner = autotvm.measure.RPCRunner(device_key, host, port, n_parallel=n_parallel,
number=min_exec_num, repeat=1,
timeout=timeout)
measure_option = autotvm.measure_option(builder=builder, runner=runner)
for args in args_list:
args = serialize_args(args)
ltf_workload = ('layout_transform',) + autotvm.task.args_to_workload(args)
if ltf_workload in self._layout_transform_perf_records:
continue
if infer_layout:
input_shape = ltf_workload[1][1]
flops = np.prod(input_shape)
inferred_time = flops * avg_time
record_input = MeasureInput(target=self._target, task=None, config=None)
record_output = MeasureResult(costs=(inferred_time,), error_no=0,
all_cost=-1, timestamp=-1)
self._layout_transform_perf_records[ltf_workload] = (record_input, record_output)
continue
records = []
task = autotvm.task.create(layout_transform, args=args, target=self._target,
target_host=target_host)
task.workload = ltf_workload
tuner = autotvm.tuner.GridSearchTuner(task)
tuner.tune(n_trial=1, measure_option=measure_option,
callbacks=[_log_to_list(records)])
if not isinstance(records[0][1].costs[0], float):
records[0] = (records[0][0], records[0][1]._replace(costs=(INVALID_LAYOUT_TIME,)))
self._layout_transform_perf_records[ltf_workload] = records[0]
self._iterate_layout_transform(self._create_matrix_callback)
self._logger.info("Benchmarking layout transformation successful.")
@property
def layout_transform_perf_records(self):
"""Get layout transformation dictionary for input graph.
Returns
-------
layout_transform_perf_records : dict of tuple to (MeasureInput, MeasureResult)
Layout transformation dictionary for input graph.
"""
return self._layout_transform_perf_records
def get_optimal_records(self):
"""Convert optimal record dictionary to a list of records
with ascending order of node index in graph.
Returns
-------
sch_list : list of tuple
List of records with ascending order of node index in graph.
"""
ordered_index_list = sorted(self._optimal_record_dict.keys())
ret = []
for index in ordered_index_list:
node_entry = self._node_list[index]
if node_entry["op"] not in self._target_ops:
continue
ret.append(node_entry["record_candidates"][self._optimal_record_dict[index]])
return ret
def write_opt_sch2record_file(self, record_file="graph_opt_schedule.log"):
"""Write graph level optimal schedules into file.
Parameters
----------
record_file : str, optional
Output schedule file.
"""
with open(record_file, "a") as out_file:
records = self.get_optimal_records()
for record in records:
out_file.write(encode(record[0], record[1]) + "\n")
msg = "Writing optimal schedules to %s successfully." % record_file
self._logger.info(msg)
@abstractmethod
def run(self, **kwargs):
"""Run graph tuning."""
pass
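# A typical end-to-end flow with a concrete subclass (a sketch; the log file
# names, input shape and relay.nn.conv2d as the target-op callable are
# illustrative assumptions, not part of this module):
#
#     from tvm import relay
#     from tvm.autotvm.graph_tuner import DPTuner
#     tuner = DPTuner(func, {"data": (1, 3, 224, 224)}, "conv2d_records.log",
#                     [relay.nn.conv2d], "llvm")
#     tuner.benchmark_layout_transform(min_exec_num=100)
#     tuner.run()
#     tuner.write_opt_sch2record_file("graph_opt_schedule.log")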
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=too-many-instance-attributes,too-many-branches,too-many-statements,too-many-arguments,too-many-locals,invalid-name
"""Stage class for dynamic programming tuner"""
import numpy as np
from .utils import is_input_node
class DPStage(object):
"""Class to represent node in Markov decision process. A stage has states
to represent different schedules of the current node. Since in this problem
the action is the schedule selected for current node, action can be fully
represented by states. No extra attribute needs for action.
In most cases, instance of this class should be created through DPTuner.
"""
def __init__(self, idx, input_shapes, node_list,
counted_nodes_set, layout_transform_interlayer_cost,
stage_dict, in_nodes_dict, out_nodes_dict,
dep_dict, target_ops, dtype="float32"):
"""Initialize a stage and create all states.
Parameters
----------
idx : int
Index for current node.
input_shapes : dict of string to tuple of int
Input shapes for current graph.
node_list : list of dict
List of all nodes for current graph.
counted_nodes_set : set of int
Global set recording whether the execution time of a node has been counted.
layout_transform_interlayer_cost : dict of tuple to list
Dictionary mapping a node index pair to the layout transformation time between them.
stage_dict : dict of int to Stage
Global dictionary for all stages, mapping node index to stage.
in_nodes_dict : dict of int to list of int
Dictionary mapping a node index to its input node indices.
out_nodes_dict : dict of int to list of int
Dictionary mapping a node index to its output node indices.
dep_dict : dict of int to set of int
Dictionary mapping a node index to its dependent node indices.
target_ops : list of str
Target operators.
dtype : str, optional
Data type.
"""
self._global_input_shapes = input_shapes
self._global_input_names = input_shapes.keys()
self._global_node_list = node_list
self._global_counted_nodes_set = counted_nodes_set
self._global_layout_transform_interlayer_cost = layout_transform_interlayer_cost
self._global_stage_dict = stage_dict
self._global_in_nodes_dict = in_nodes_dict
self._global_out_nodes_dict = out_nodes_dict
self._global_dep_dict = dep_dict
self._idx = idx
self._node_entry = self._global_node_list[idx]
self._target_ops = target_ops
self._wkl = self._node_entry["workloads"][0]
self._record_list = self._node_entry["record_candidates"]
self._dep = []
self._dtype = dtype
self._states = None
self._full_states = None
self._full_states_idx = None
self._create_states()
def _create_states(self):
"""Create states."""
node = self._global_node_list[self._idx]
if node["op"] in self._target_ops:
self._create_op_states()
else:
self._create_multi_inputs_states()
def _create_op_states(self):
"""State creation routine for nodes with target_op."""
input_idx = -1
for index in self._global_in_nodes_dict[self._idx]:
input_idx = index
if not is_input_node(self._global_node_list[input_idx],
self._global_input_names):
break
if is_input_node(self._global_node_list[input_idx],
self._global_input_names):
self._full_states = np.array([record[1].costs[0]
for record in self._record_list])
self._states = self._full_states
else:
input_node_entry = self._global_node_list[input_idx]
input_stage = self._global_stage_dict[input_idx]
input_dep = input_stage.dep
input_states = input_stage.states
input_flatten_states = input_states.flatten()
input_record_list = input_node_entry["record_candidates"]
num_schedules = len(self._record_list)
num_input_schedules = len(input_record_list)
num_input_states = input_flatten_states.shape[0]
full_states_shape = tuple([num_schedules, num_input_schedules] +
[len(self._global_node_list[dep_idx]["record_candidates"])
for dep_idx in input_dep])
self._full_states = np.zeros(full_states_shape).flatten().astype("float32")
self._full_states_idx = [self._idx, input_idx] + input_dep
dep_multiplier = 1
for i in range(2, len(full_states_shape)):
dep_multiplier *= full_states_shape[i]
input_node_time_counted = input_idx in self._global_counted_nodes_set
for i in range(num_schedules):
current_sch_time = float(self._record_list[i][1].costs[0])
for j in range(num_input_states):
input_sch_idx = j // dep_multiplier
layout_transform_time = \
self._global_layout_transform_interlayer_cost \
[(input_idx, self._idx)][input_sch_idx][i]
if input_node_time_counted:
total_time = current_sch_time + layout_transform_time
else:
total_time = \
current_sch_time + layout_transform_time + input_flatten_states[j]
current_state_idx = i * num_input_states + j
self._full_states[current_state_idx] = total_time
if not input_node_time_counted:
self._global_counted_nodes_set.add(input_idx)
self._full_states = self._full_states.reshape(full_states_shape)
# If the out-degree of the input node is 1, we can remove the input node's
# dimension, since its states will not be needed anymore. Otherwise, the
# input node becomes a dependency.
if len(self._global_out_nodes_dict[input_idx]) == 1:
self._states = np.amin(self._full_states, axis=1)
self._dep = list(input_dep)
else:
self._states = self._full_states
self._dep = [input_idx,] + input_dep
# Update global dependency dictionary.
# This is to monitor the dependency states to decide
# when a dependency can be eliminated, so that the total
# number of states can be greatly reduced.
for dep_idx in self._dep:
self._global_dep_dict[dep_idx].remove(self._idx)
for child in self._global_out_nodes_dict[self._idx]:
self._global_dep_dict[dep_idx].add(child)
if len(self._global_out_nodes_dict[self._idx]) > 1:
self._global_dep_dict[self._idx] = set()
for child in self._global_out_nodes_dict[self._idx]:
self._global_dep_dict[self._idx].add(child)
def _create_multi_inputs_states(self):
"""State creation routine for multi_input operator
In tvm, layout transformation for an elemwise-like follow the rule which
all input operators transform their layouts to the leftmost input operator
layout. For example:
elemwise-sum
| | |
| | |
op0 op1 op2
In this block, the possible layout transformations are: op1 -> op0 and op2 -> op0.
In graph tuning, a 3-D array with shape (k0, k1, k2) can represent the layout
transformations between these three nodes. It is also possible some earlier states
belong to other nodes(We name them as dependency) are required for dynamic programming.
The final states array for this elemwise-sum can be with shape (e0, k0, k1, e1, k2).
To iterate through all states, we first align the shape of op0, op1 and op2 to be
(e0, k0, k1, e1, k2) by broadcasting the original states. We also record the axis of
each input node in the states array, together with the multiplier. For example,
the axis index for op0 is 1, and multiplier is k1 * e1 * k2. If current iterating index
in the flatten array is i, the index of op0 can be computed as:
i % (k0 * k1 * e1 * k2) // (k1 * e1 * k2).
"""
full_input_node_list = list(self._global_in_nodes_dict[self._idx])
input_index_list = []
# Remove input and parameter nodes
for input_idx in full_input_node_list:
if not is_input_node(self._global_node_list[input_idx],
self._global_input_names):
input_index_list.append(input_idx)
# Generate new states
states_list, aligned_node_list = DPStage.align_states(input_index_list,
self._global_stage_dict,
self._global_node_list)
target_node_idx, target_major_axis, target_multiplier, target_states = states_list[0]
aligned_shape = target_states.shape
self._full_states = np.zeros(aligned_shape).astype("float32").flatten()
self._full_states_idx = list(aligned_node_list)
num_states = self._full_states.shape[0]
node_time_counted = [item[0] in self._global_counted_nodes_set for item in states_list]
target_states = target_states.flatten()
src_states_list = [states_list[i][3].flatten() for i in range(1, len(states_list))]
for i in range(num_states):
target_sch_idx = (i % (target_multiplier *
aligned_shape[target_major_axis])) // target_multiplier
if node_time_counted[0]:
new_state = 0
else:
new_state = target_states[i]
for j in range(1, len(states_list)):
src_states = src_states_list[j - 1]
src_node_idx, src_major_axis, src_multiplier, _ = states_list[j]
src_sch_idx = (i % (src_multiplier *
aligned_shape[src_major_axis])) // src_multiplier
layout_transform_time = \
self._global_layout_transform_interlayer_cost\
[(src_node_idx, target_node_idx)][src_sch_idx][target_sch_idx]
if node_time_counted[j]:
new_state += layout_transform_time
else:
new_state += layout_transform_time + src_states[i]
self._full_states[i] = new_state
for i, node_counted in enumerate(node_time_counted):
if not node_counted:
self._global_counted_nodes_set.add(states_list[i][0])
self._full_states = self._full_states.reshape(aligned_shape)
# Remove dependency to reduce states
reduced_states = np.array(self._full_states)
reduced_states_transpose = [states_list[0][1]]
reduced_states_dep_list = []
self._dep = []
for i in range(len(reduced_states.shape)):
if i != states_list[0][1]:
reduced_states_transpose.append(i)
reduced_states_dep_list.append(aligned_node_list[i])
reduced_states = np.transpose(reduced_states, reduced_states_transpose)
shift = 0
for i, dep in enumerate(reduced_states_dep_list):
if dep not in self._global_dep_dict or len(self._global_dep_dict[dep]) == 1:
self._global_dep_dict.pop(dep, None)
reduced_states = np.amin(reduced_states, axis=i+1-shift)
shift += 1
else:
self._dep.append(dep)
self._states = reduced_states
# Update dependency
for dep in self._dep:
self._global_dep_dict[dep].remove(self._idx)
for child in self._global_out_nodes_dict[self._idx]:
self._global_dep_dict[dep].add(child)
if len(self._global_out_nodes_dict[self._idx]) > 1:
self._global_dep_dict[self._idx] = set()
for child in self._global_out_nodes_dict[self._idx]:
self._global_dep_dict[self._idx].add(child)
@property
def dep(self):
"""Get dependency list."""
return self._dep
@property
def states(self):
"""Get states."""
return self._states
@property
def full_states(self):
"""Get complete states."""
return self._full_states
@property
def full_states_idx(self):
"""Get node index of complete states."""
return self._full_states_idx
@staticmethod
def align_states(input_index_list, stage_dict, node_list):
"""Align all input node states shapes to be the same and transpose/reshape properly.
This is used in creating multi_input operator states.
Parameters
----------
input_index_list : list of int
List of input node index.
stage_dict : dict of int to Stage
Global dictionary of node index to stage.
node_list : list of dict
List of all nodes for current graph.
Returns
-------
states_list : list of tuple
List of aligned states.
aligned_node_list : list of int
List of node index for aligned states.
"""
aligned_node_list = list(input_index_list)
states_list = []
for input_idx in input_index_list:
input_node_stage = stage_dict[input_idx]
for dep_idx in input_node_stage.dep:
if dep_idx not in aligned_node_list:
aligned_node_list.append(dep_idx)
aligned_shape = tuple([len(node_list[idx]["record_candidates"])
for idx in aligned_node_list])
for input_idx in input_index_list:
input_node_stage = stage_dict[input_idx]
input_node_shape_idx_list = [input_idx] + input_node_stage.dep
transpose_idx_list = []
reshape_list = []
major_axis = -1
for i, idx in enumerate(aligned_node_list):
if input_idx == idx:
major_axis = i
if idx in input_node_shape_idx_list:
transpose_idx_list.append(idx)
reshape_list.append(aligned_shape[i])
else:
reshape_list.append(1)
transpose_list = [input_node_shape_idx_list.index(idx) for idx in transpose_idx_list]
input_node_states = np.transpose(input_node_stage.states, tuple(transpose_list))
input_node_states = np.reshape(input_node_states, tuple(reshape_list))
input_node_states = np.broadcast_to(input_node_states, aligned_shape)
multiplier = 1
for i in range(major_axis + 1, len(aligned_shape)):
multiplier *= aligned_shape[i]
states_list.append((input_idx, major_axis, multiplier, input_node_states))
return states_list, aligned_node_list
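# A minimal NumPy sketch of the alignment above, with assumed toy sizes:
# suppose op0 has states of shape (k0,) = (3,), op1 has states of shape
# (k1, e0) = (4, 2), and aligned_node_list = [op0, op1, e0], so the aligned
# shape is (3, 4, 2). Each input only needs a reshape plus broadcast:
#
#     import numpy as np
#     s0, s1 = np.zeros((3,)), np.zeros((4, 2))
#     s0_aligned = np.broadcast_to(s0.reshape(3, 1, 1), (3, 4, 2))
#     s1_aligned = np.broadcast_to(s1.reshape(1, 4, 2), (3, 4, 2))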
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=import-error,too-many-locals,too-many-statements,too-many-branches,unused-variable
"""Dynamic programming tuner."""
import sys
import numpy as np
from .base_graph_tuner import BaseGraphTuner
from .dynamic_programming_stage import DPStage
from .utils import has_multiple_inputs, is_input_node
if sys.version_info[0] == 3:
import queue
else:
import Queue as queue
class DPTuner(BaseGraphTuner):
"""Tuner which uses dynamic programming to solve MDP problem.
Note: currently dynamic programming is used to solve this MDP problem. However,
this problem is intrinsically non-polynomial. DP can't apply for more complicated
models, such as networks with many element-wise sum operators. In this case, switch
to heuristic algorithm such as PBQP tuner.
"""
def __init__(self, *args, **kwargs):
"""Create a dynamic programming tuner.
"""
super(DPTuner, self).__init__(*args, **kwargs)
self._num_states = self._max_num_states = None
self._stage_dict = {}
self._dep_dict = {}
self._counted_nodes_set = set()
self._global_data_dict = {
"dtype": self._dtype,
"counted_nodes_set": self._counted_nodes_set,
"stage_dict": self._stage_dict,
"in_nodes_dict": self._in_nodes_dict,
"out_nodes_dict": self._out_nodes_dict,
"dep_dict": self._dep_dict,
"node_list": self._node_list,
"input_shapes": self._input_shapes,
"layout_transform_interlayer_cost": self._layout_transform_interlayer_cost
}
def _check_num_states(self, num_states):
"""Track the number of states."""
self._num_states += num_states
if self._max_num_states is not None:
if self._num_states > self._max_num_states:
raise RuntimeError("Too many states detected while running dynamic "
"programming: got %d states but upper limit is %d." %
(self._num_states, self._max_num_states))
def _forward(self):
"""Forward pass in DP to generate states for all stages.
"""
self._logger.info("Start forward pass...")
for node_idx in sorted(self._in_nodes_dict.keys()):
stage = DPStage(idx=node_idx, target_ops=self._target_ops,
**self._global_data_dict)
self._check_num_states(stage.full_states.size)
self._stage_dict[node_idx] = stage
self._logger.info("Finished forward pass.")
def _backward(self):
"""Backward pass in DP to generate optimal solution.
"""
self._logger.info("Start backward pass...")
input_names = self._input_shapes.keys()
optimal_record_dict = {}
# Pick optimal schedule for output nodes
output_idx_list = []
for key, val in self._out_nodes_dict.items():
if not val:
output_idx_list.append(key)
states_list, aligned_node_list = DPStage.align_states(output_idx_list, self._stage_dict,
self._node_list)
num_states = states_list[0][3].size
self._check_num_states(num_states * len(output_idx_list))
aligned_node_shape = states_list[0][3].shape
min_time = 0
min_pos = -1
for states in states_list:
min_time += np.amax(states[3])
flatten_states_list = [current_states[3].flatten() for current_states in states_list]
for i in range(num_states):
current_time = 0
for j, current_states in enumerate(states_list):
current_time += flatten_states_list[j][i]
if min_time > current_time:
min_time = current_time
min_pos = i
for i, states in enumerate(states_list):
current_major_axis = states[1]
current_sch_idx = (min_pos % (states[2] *
aligned_node_shape[current_major_axis])) // states[2]
optimal_record_dict[aligned_node_list[i]] = current_sch_idx
# Pick optimal schedule for dependencies of output nodes
for i in range(len(states_list), len(aligned_node_list)):
multiplier = 1
for j in range(i + 1, len(aligned_node_list)):
multiplier *= aligned_node_shape[j]
optimal_record_dict[aligned_node_list[i]] = \
min_pos // multiplier % aligned_node_shape[i]
# Backward pass to get optimal schedules for other nodes
bfs_q = queue.Queue()
visited = set()
for out_idx in output_idx_list:
bfs_q.put(out_idx)
while not bfs_q.empty():
node_idx = bfs_q.get()
visited.add(node_idx)
if is_input_node(self._node_list[node_idx], input_names):
continue
optimal_sch_idx = optimal_record_dict[node_idx]
full_states = self._stage_dict[node_idx].full_states
if not has_multiple_inputs(self._node_list, node_idx, input_names):
input_idx = self._in_nodes_dict[node_idx][0]
if is_input_node(self._node_list[input_idx], input_names):
continue
if input_idx not in visited:
bfs_q.put(input_idx)
if input_idx not in optimal_record_dict:
dep_list = self._stage_dict[node_idx].dep
dep_idx = tuple([optimal_record_dict[item] for item in dep_list])
tmp = np.argmin(full_states, axis=1)
optimal_input_sch_idx = tmp[(optimal_sch_idx,) + dep_idx]
optimal_record_dict[input_idx] = optimal_input_sch_idx
else:
input_idx_list = self._in_nodes_dict[node_idx]
optimal_record_dict[input_idx_list[0]] = optimal_sch_idx
full_states_idx = self._stage_dict[node_idx].full_states_idx
tmp = full_states[optimal_sch_idx]
new_states_idx, new_states_pos = [], []
visited_states_idx, visited_states_pos = [], []
for i in range(1, len(full_states_idx)):
if full_states_idx[i] in optimal_record_dict:
visited_states_idx.append(full_states_idx[i])
visited_states_pos.append(i - 1)
else:
new_states_idx.append(full_states_idx[i])
new_states_pos.append(i - 1)
if visited_states_idx:
tmp = np.transpose(tmp, tuple(visited_states_pos + new_states_pos))
tmp = tmp[tuple([optimal_record_dict[idx] for idx in visited_states_idx])]
min_pos = np.argmin(tmp)
multiplier = 1
for i in range(len(new_states_idx)):
multiplier *= full_states.shape[new_states_pos[i] + 1]
for pos, idx in zip(new_states_pos, new_states_idx):
multiplier //= full_states.shape[pos + 1]
optimal_record_dict[idx] = min_pos // multiplier
min_pos %= multiplier
for input_idx in input_idx_list:
if input_idx not in visited:
bfs_q.put(input_idx)
self._optimal_record_dict = optimal_record_dict
for node_idx, _ in self._in_nodes_dict.items():
if self._node_list[node_idx]["op"] not in self._target_ops:
continue
self._logger.info("Finished backward pass...")
def run(self, **kwargs):
"""Run dynamic programming solver.
"""
max_num_states = kwargs.get("max_num_states", None)
self._num_states = 0
self._max_num_states = max_num_states
self._logger.info("Start to run dynamic programming algorithm...")
self._forward()
self._backward()
self._logger.info("Finished DPExecutor run.")
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name,too-many-locals
"""Partitioned Boolean Quadratic Programming Tuner"""
from ._base import INVALID_LAYOUT_TIME
from .base_graph_tuner import BaseGraphTuner
from .utils import is_input_node, has_multiple_inputs
class PBQPTuner(BaseGraphTuner):
"""An approximation method to deal with intractably
large size of graph tuning problem.
This graph coloring algorithm mainly comes from:
Lang Hames and Bernhard Scholz.
Nearly optimal register allocation with pbqp.JMLC 2006.
LNCS, vol.4228,pp. 346-361, 2016
"""
def __init__(self, *args, **kwargs):
"""Create a partitioned boolean quadratic programming tuner.
"""
super(PBQPTuner, self).__init__(*args, **kwargs)
# Remove input nodes
input_names = self._input_shapes.keys()
for node_idx in self._out_nodes_dict:
if is_input_node(self._node_list[node_idx], input_names):
for out_node_idx in self._out_nodes_dict[node_idx]:
self._in_nodes_dict[out_node_idx].remove(node_idx)
self._adj_dict = {}
for node_idx in self._in_nodes_dict:
self._adj_dict[node_idx] = list(self._in_nodes_dict[node_idx]) + \
list(self._out_nodes_dict[node_idx])
self._record_cost_dict = {}
for key in self._in_nodes_dict:
self._record_cost_dict[key] = []
for record in self._node_list[key]["record_candidates"]:
self._record_cost_dict[key].append(record[1].costs[0])
self._max_degree = -1
self._node_degree_dict = {}
for node_idx in self._in_nodes_dict:
node_degree = self._get_degree(node_idx)
self._node_degree_dict[node_idx] = node_degree
self._max_degree = max(self._max_degree, node_degree)
self._stack = []
self._buckets = [[] for _ in range(self._max_degree + 2)]
for node_idx in sorted(self._in_nodes_dict):
node_degree = self._get_degree(node_idx)
self._buckets[node_degree].append(node_idx)
self._is_optimal = True
def _get_degree(self, node_idx):
"""Get node degree.
"""
return len(self._adj_dict[node_idx])
def _reorder_adj_nodes(self, node_idx):
"""Update buckets list with current adjacency list.
"""
for adj_node in self._adj_dict[node_idx]:
current_degree = self._get_degree(adj_node)
prev_degree = self._node_degree_dict[adj_node]
if prev_degree != current_degree:
self._buckets[prev_degree].remove(adj_node)
self._buckets[current_degree].insert(0, adj_node)
self._node_degree_dict[adj_node] = current_degree
def _remove_node(self, node_idx):
"""Remove node from graph. Update adjacency list accordingly.
"""
node_degree = self._get_degree(node_idx)
self._buckets[node_degree].remove(node_idx)
for adj_node in self._adj_dict[node_idx]:
self._adj_dict[adj_node].remove(node_idx)
def _insert_edge(self, node_x, node_y, adj_cost_matrix):
"""Insert an edge between two nodes.
"""
self._layout_transform_interlayer_cost[(node_x, node_y)] = adj_cost_matrix
self._layout_transform_interlayer_cost[(node_y, node_x)] = []
for i in range(len(adj_cost_matrix[0])):
self._layout_transform_interlayer_cost[(node_y, node_x)].append([])
for cost_vec in adj_cost_matrix:
self._layout_transform_interlayer_cost[(node_y, node_x)][i] \
.append(cost_vec[i])
self._adj_dict[node_x].append(node_y)
self._adj_dict[node_y].append(node_x)
def _backward_insert_node(self, node_idx):
"""Reinsert node in backward pass.
"""
for adj_node in self._adj_dict[node_idx]:
self._adj_dict[adj_node].append(node_idx)
def _RI_reduction(self, node_idx):
"""Reduce nodes with degree 1.
"""
adj_node = self._adj_dict[node_idx][0]
ltf_matrix = self._layout_transform_interlayer_cost[(adj_node, node_idx)]
for i, cost_vec in enumerate(ltf_matrix):
min_cost = INVALID_LAYOUT_TIME
for j, cost in enumerate(cost_vec):
min_cost = min(min_cost, cost + self._record_cost_dict[node_idx][j])
self._record_cost_dict[adj_node][i] += min_cost
self._remove_node(node_idx)
self._reorder_adj_nodes(node_idx)
self._stack.append(node_idx)
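# A worked RI example (illustrative numbers): if the degree-1 node has record
# costs [2, 7] and the transform matrix towards its neighbor is
# ltf = [[0, 1], [4, 0]], the neighbor's record costs grow by
#     min(0 + 2, 1 + 7) = 2 for its schedule 0, and
#     min(4 + 2, 0 + 7) = 6 for its schedule 1,
# after which the node is removed and resolved later in the backward pass.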
def _RII_reduction(self, node_idx):
"""Reduce nodes with degree 2.
"""
adj_node_x, adj_node_y = self._adj_dict[node_idx]
ltf_matrix_x = self._layout_transform_interlayer_cost[(adj_node_x, node_idx)]
ltf_matrix_y = self._layout_transform_interlayer_cost[(adj_node_y, node_idx)]
delta_matrix = [[] for _ in range(len(ltf_matrix_x))]
for i, cost_vec_x in enumerate(ltf_matrix_x):
for j, cost_vec_y in enumerate(ltf_matrix_y):
min_cost = INVALID_LAYOUT_TIME
for k in range(len(self._record_cost_dict[node_idx])):
min_cost = min(min_cost, cost_vec_x[k] + cost_vec_y[k]
+ self._record_cost_dict[node_idx][k])
delta_matrix[i].append(min_cost)
if adj_node_x == adj_node_y:
for i, delta_row in enumerate(delta_matrix):
self._record_cost_dict[adj_node_x][i] += delta_row[i]
elif adj_node_x in self._adj_dict[adj_node_y]:
for i, _ in enumerate(delta_matrix):
for j, delta in enumerate(delta_matrix[i]):
self._layout_transform_interlayer_cost[(adj_node_x, adj_node_y)][i][j] \
+= delta
self._layout_transform_interlayer_cost[(adj_node_y, adj_node_x)][j][i] \
+= delta
else:
self._insert_edge(adj_node_x, adj_node_y, delta_matrix)
self._remove_node(node_idx)
self._reorder_adj_nodes(node_idx)
self._stack.append(node_idx)
def _RN_reduction(self, node_idx):
"""Reduce nodes with degree greater than 2.
"""
min_cost = INVALID_LAYOUT_TIME
record_idx = -1
for i, record_cost in enumerate(self._record_cost_dict[node_idx]):
current_cost = record_cost
for adj_node in self._adj_dict[node_idx]:
ltf_matrix = self._layout_transform_interlayer_cost[(node_idx, adj_node)]
adj_record_cost = list(self._record_cost_dict[adj_node])
for j, ltf_cost in enumerate(ltf_matrix[i]):
adj_record_cost[j] += ltf_cost
current_cost += min(adj_record_cost)
if current_cost < min_cost:
min_cost = current_cost
record_idx = i
if record_idx < 0:
raise RuntimeError("Can't find a soltuion for node %d when "
"applying RN reduction" % node_idx)
self._optimal_record_dict[node_idx] = record_idx
self._is_optimal = False
for adj_node in self._adj_dict[node_idx]:
ltf_matrix = self._layout_transform_interlayer_cost[(node_idx, adj_node)]
for i, ltf_cost in enumerate(ltf_matrix[record_idx]):
self._record_cost_dict[adj_node][i] += ltf_cost
self._remove_node(node_idx)
self._reorder_adj_nodes(node_idx)
self._stack.append(node_idx)
def _forward(self):
"""Forward pass in PBQP to reduce nodes.
"""
while True:
if self._buckets[1]:
node_idx = self._buckets[1][0]
self._RI_reduction(node_idx)
elif self._max_degree >= 2 and self._buckets[2]:
node_idx = self._buckets[2][0]
self._RII_reduction(node_idx)
elif self._max_degree >= 3:
max_degree_node = -1
for i in range(self._max_degree, 2, -1):
if self._buckets[i]:
max_degree_node = self._buckets[i][0]
self._RN_reduction(max_degree_node)
break
if max_degree_node < 0:
break
else:
break
def _backward(self):
"""Backward pass in PBQP to generate optimal solution.
"""
# Solve nodes left in the forward graph
for node_idx in self._buckets[0]:
record_costs = self._record_cost_dict[node_idx]
min_cost = min(record_costs)
self._optimal_record_dict[node_idx] = record_costs.index(min_cost)
# Solve nodes with one or two degrees
for node_idx in reversed(self._stack):
self._backward_insert_node(node_idx)
if node_idx not in self._optimal_record_dict:
record_costs = list(self._record_cost_dict[node_idx])
for adj_node in self._adj_dict[node_idx]:
adj_optimal_idx = self._optimal_record_dict[adj_node]
for i, _ in enumerate(record_costs):
record_costs[i] += \
self._layout_transform_interlayer_cost \
[(node_idx, adj_node)][i][adj_optimal_idx]
min_cost = min(record_costs)
self._optimal_record_dict[node_idx] = record_costs.index(min_cost)
def run(self, **kwargs):
"""Run partitioned boolean quadratic programming tuner.
"""
self._logger.info("Start to run PBQP algorithm...")
# Define virtual record lists and layout transformation matrices
# for multi-input nodes.
input_names = self._input_shapes.keys()
temp = {}
for key, val in self._in_nodes_dict.items():
target_input_idx = -1
target_input_pos = -1
if has_multiple_inputs(self._node_list, key, input_names):
for i, item in enumerate(val):
if not is_input_node(self._node_list[item], input_names):
target_input_idx = item
target_input_pos = i
break
temp[(target_input_idx, key)] = []
record_candidates = self._node_list[target_input_idx]["record_candidates"]
for j in range(len(record_candidates)):
temp[(target_input_idx, key)].append([])
for k in range(len(record_candidates)):
temp[(target_input_idx, key)][j].append(0 if j == k
else INVALID_LAYOUT_TIME)
for j in range(target_input_pos + 1, len(val)):
input_idx = val[j]
if is_input_node(self._node_list[input_idx], input_names):
continue
temp[(input_idx, key)] = \
self._layout_transform_interlayer_cost[(input_idx, target_input_idx)]
self._layout_transform_interlayer_cost.update(temp)
# Create reverse layout transformation matrices
temp = {}
for idx_pair, ltf_matrix in self._layout_transform_interlayer_cost.items():
reverse_key = (idx_pair[1], idx_pair[0])
reverse_matrix = [[] for _ in range(len(ltf_matrix[0]))]
for i, _ in enumerate(ltf_matrix):
for j, ltf in enumerate(ltf_matrix[i]):
reverse_matrix[j].append(ltf)
temp[reverse_key] = reverse_matrix
self._layout_transform_interlayer_cost.update(temp)
self._forward()
self._backward()
is_optimal = "optimal" if self._is_optimal else "sub-optimal"
msg = "Finished PBQPExecutor run. Got %s solution." % is_optimal
self._logger.info(msg)
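# Usage mirrors DPTuner (a sketch; constructor arguments are passed through
# unchanged to BaseGraphTuner):
#
#     tuner = PBQPTuner(func, input_shapes, records, target_ops, target)
#     tuner.benchmark_layout_transform(min_exec_num=100)
#     tuner.run()
#
# Prefer PBQPTuner when DPTuner's state count explodes; the final log message
# above reports whether the solution found is optimal or sub-optimal.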
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=wildcard-import
"""Graph tuner utility functions"""
from __future__ import absolute_import
from . import traverse_graph
from . import utils
from .traverse_graph import expr2graph, get_direct_ancestor, get_in_nodes, \
get_out_nodes
from .utils import has_multiple_inputs, is_input_node, bind_inputs
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=too-many-locals,too-many-statements,too-many-branches,protected-access
"""API for graph traversing."""
import threading
import topi
from tvm import relay, autotvm
from tvm.relay.expr import Call, Function, TupleGetItem, Var, Constant, Tuple
from tvm.relay.ty import TupleType, TensorType
from tvm.autotvm.task import TaskExtractEnv
from .._base import RULE_OUT_NODE_NAMES
from .utils import has_multiple_inputs, is_input_node
# Set up relay op base name -> topi compute functions mapping.
# NOTE: To add more ops, extend the following dictionary.
OP2COMPUTE = {
"conv2d" : [topi.nn.conv2d, topi.nn.depthwise_conv2d_nchw],
}
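# To support another target op, one would (hypothetically) register its topi
# compute functions here as well, e.g.:
#
#     OP2COMPUTE["dense"] = [topi.nn.dense]
#
# (dense support is illustrative only and not part of this change.)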
def expr2graph(expr, target_ops, node_dict, node_list):
"""Convert relay expr to graph data structure
and fetch workloads of target operators.
Parameters
----------
expr : tvm.relay.Expr.Function
Input relay function expression.
target_ops: List of str
List of target relay base op name
node_dict : dictionary from tvm.relay.Expr to int
Dictionary to record node index
node_list : list of dictionary
List of nodes which contains all expr in the input relay function.
Each node will be stored as a dictionary in the format of
{"op": str, "node": tvm.relay.expr, "inputs": [int], "types": [tvm.relay.Type],
"name": str, "workloads": [tuple], "topi_op": [function]}
"""
env = TaskExtractEnv.get(allow_duplicate=True)
topi_funcs = []
for op_name in target_ops:
if op_name not in OP2COMPUTE:
raise RuntimeError("Not supported relay op in graph tuner: %s"
% op_name)
topi_funcs += OP2COMPUTE[op_name]
env.reset(topi_funcs)
_expr2graph_impl(expr, target_ops, node_dict, node_list)
task_pos = 0
for node_entry in node_list:
if node_entry["op"] in target_ops:
task_name, args = env.task_collection[task_pos]
task = autotvm.task.create(task_name, args,
target="llvm",
target_host=None,
template_key='direct')
node_entry["workloads"] = [task.workload]
node_entry["topi_op"] = [task_name]
task_pos += 1
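# After conversion, a conv2d node entry looks roughly like this (a sketch;
# values are illustrative):
#
#     {"op": "conv2d", "name": None, "node": <relay Call>,
#      "inputs": [[0, 0, 0], [1, 0, 0]], "types": [TensorType(...)],
#      "topi_op": ["topi_nn_conv2d"], "workloads": [('conv2d', ...)]}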
def _expr2graph_impl(expr, target_ops, node_dict, node_list):
"""Implementation to convert relay expr to graph data structure
"""
def _traverse_expr(node):
if node in node_dict:
return
node_index = len(node_list)
node_entry = {"node": node, "inputs": [], "types": [],
"op": "null", "name": None}
if isinstance(node, Call):
op_name = node.op.name.split(".")[-1]
node_entry["op"] = op_name
for arg in node.args:
in_node_idx = node_dict[arg]
if isinstance(arg, (Tuple, TupleGetItem)):
node_entry["inputs"] += node_list[in_node_idx]["inputs"]
else:
node_entry["inputs"].append([in_node_idx, 0, 0])
infer_out = relay.ir_pass.infer_type(node)
out_type = infer_out._checked_type_
if isinstance(out_type, TensorType):
node_entry["types"].append(out_type)
elif isinstance(out_type, TupleType):
for tuple_type in out_type.fields:
node_entry["types"].append(tuple_type)
else:
raise RuntimeError("Unsupported output type %s in operator %s"
% (type(out_type), op_name))
# Utilize the tracing target to fetch workloads in topological order.
# Since we only need the workload, a dummy target can be used to
# create the task.
if op_name in target_ops:
params = []
for i, input_idx in enumerate(node_entry["inputs"]):
input_node_entry = node_list[input_idx[0]]
input_type = input_node_entry["types"][input_idx[1]]
if not isinstance(input_node_entry["node"], (Var, Call)):
raise RuntimeError("Graph tuner can only tune target "
"operators with input node of type "
"relay.expr.Var or relay.expr.Call. Now "
"find a target op %s with input type %s"
% (op_name, str(type(input_node_entry["node"]))))
free_var = relay.Var("var_%d" % i, input_type)
params.append(free_var)
call = relay.Call(node.op, params, node.attrs)
func = relay.Function(params, call)
relay.backend.compile_engine.get().clear()
build_thread = threading.Thread(target=relay.build,
args=(func,
"llvm -device=tracing",
None,
None))
build_thread.start()
build_thread.join()
elif isinstance(node, Var):
node_entry["name"] = node.name_hint
node_entry["types"] = [node.type_annotation]
elif isinstance(node, Function):
            # Ignore the root node since it equals the input function expression
if node != expr:
_expr2graph_impl(node, target_ops, node_dict, node_list)
return
elif isinstance(node, TupleGetItem):
node_entry["op"] = "TupleGetItem"
in_node_idx = node_dict[node.tuple_value]
node_entry["inputs"].append([in_node_idx, node.index, 0])
elif isinstance(node, Tuple):
node_entry["op"] = "Tuple"
for tuple_item in node:
in_node_idx = node_dict[tuple_item]
if isinstance(tuple_item, TupleGetItem):
node_entry["inputs"] += node_list[in_node_idx]["inputs"]
elif isinstance(tuple_item, Tuple):
raise RuntimeError("Graph tuner doesn't support nested tuple.")
else:
node_entry["inputs"].append([in_node_idx, 0, 0])
elif isinstance(node, Constant):
pass
elif isinstance(node, relay.op.op.Op):
return
else:
raise RuntimeError("Not supported relay node type in graph tuning: %s"
% str(type(node)))
node_dict[node] = node_index
node_list.append(node_entry)
relay.ir_pass.post_order_visit(expr, _traverse_expr)
def get_direct_ancestor(node_list, visited_dict, target_ops, node_idx, input_names):
"""Given a node_list in relay function and a node index, return the
closest ancestor which has op_name as operator name or is multi_input operator.
If node has multiple inputs, multiple ancestor nodes will be returned.
Parameters
----------
node_list : list of dict of str to object
List of all nodes in a graph.
visited_dict : dict of int to int
Nodes and corresponding ancestors which have been visited.
    target_ops : List of str
        List of target relay base op names
node_idx : int
Input node index.
input_names : list of str
Names of graph input nodes.
Returns
-------
out : list of int
List of ancestor node index.
"""
if node_idx in visited_dict:
return visited_dict[node_idx]
if is_input_node(node_list[node_idx], input_names):
return [node_idx]
node = node_list[node_idx]
    # Rule out nodes whose direct inputs are ruled-out operators
is_rule_out = False
for item_idx in node["inputs"]:
item = node_list[item_idx[0]]
if item["op"] in RULE_OUT_NODE_NAMES:
is_rule_out = True
break
if is_rule_out:
visited_dict[node_idx] = []
return []
node_direct_ancestor = []
for item_idx in node["inputs"]:
item = node_list[item_idx[0]]
is_multiple_inputs = has_multiple_inputs(node_list, item_idx[0], input_names)
if item["op"] in target_ops or is_multiple_inputs:
node_direct_ancestor.append(item_idx[0])
else:
tmp = get_direct_ancestor(node_list, visited_dict, target_ops,
item_idx[0], input_names)
for tmp_item in tmp:
node_direct_ancestor.append(tmp_item)
if not has_multiple_inputs(node_list, node_idx, input_names) and node_direct_ancestor:
node_direct_ancestor = [node_direct_ancestor[0]]
visited_dict[node_idx] = node_direct_ancestor
return node_direct_ancestor
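# For example, in the graph exercised by test_get_direct_ancestor (a conv2d
# followed by a multi-input add and another conv2d), the ancestors of the
# multi-input add node are the first conv2d node and the graph input "data".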
def get_in_nodes(node_list, target_ops, input_names):
"""Create a dictionary mapping from op_name nodes or multi_input
nodes to closest input ancestors.
Parameters
----------
node_list : list of dict of str to object
List of all nodes in a graph.
    target_ops : List of str
        List of target relay base op names
input_names : list of str
Names of graph input nodes.
Returns
-------
out : dict of int to list of int
Dictionary maps node index to closest input ancestors.
"""
visited_dict = {}
in_node_dict = {}
for i, node in enumerate(node_list):
if node["op"] in RULE_OUT_NODE_NAMES:
continue
get_direct_ancestor(node_list, visited_dict, target_ops, i, input_names)
for key, val in visited_dict.items():
node = node_list[key]
is_multiple_inputs = has_multiple_inputs(node_list, key, input_names)
if node["op"] in target_ops or is_multiple_inputs:
in_node_dict[key] = val
# Remove empty nodes
has_empty_node = True
out_node_dict = get_out_nodes(in_node_dict)
while has_empty_node:
empty_nodes = []
for key, val in in_node_dict.items():
if not val:
empty_nodes.append(key)
if empty_nodes:
has_empty_node = True
for node in empty_nodes:
del in_node_dict[node]
if node in out_node_dict:
for out_node in out_node_dict[node]:
in_node_dict[out_node].remove(node)
else:
has_empty_node = False
return in_node_dict
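# In test_get_in_nodes this yields {7: [3], 3: [2, 0], 2: [0]}: each conv2d or
# multi-input node maps to its closest conv2d, multi-input, or graph-input
# ancestors, with empty entries pruned.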
def get_out_nodes(in_node_dict):
"""Create output dictionary from input dictionary.
Parameters
----------
in_node_dict : dict of int to list of int
Dictionary maps node index to closest input ancestors.
It can be created with get_in_nodes.
Returns
-------
out : dict of int to list of int
Dictionary maps node index to closest output nodes.
"""
out_node_dict = {}
for key in in_node_dict:
out_node_dict[key] = []
for key, val in in_node_dict.items():
for item in val:
if item in out_node_dict:
out_node_dict[item].append(key)
else:
out_node_dict[item] = [key]
return out_node_dict
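# A small worked example of the inversion:
#
#     get_out_nodes({3: [0], 4: [3]})  # -> {3: [4], 4: [], 0: [3]}
#
# Node 0 feeds node 3, node 3 feeds node 4, and node 4 has no consumers.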
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=eval-used,invalid-name,too-many-arguments
"""Utility functions"""
from tvm import relay
def has_multiple_inputs(node_list, node_idx, input_names):
"""Check whether a node has multiple input nodes
except variable nodes.
Parameters
----------
node_list : list of dict of str to object
List of all nodes in a graph.
node_idx : int
Node index to be checked.
input_names : list of str
List of input names of graph.
Returns
-------
out : bool
Whether the specified node has multiple input nodes
"""
num_inputs = 0
node = node_list[node_idx]
for in_idx in node["inputs"]:
in_idx = in_idx[0]
in_node = node_list[in_idx]
# Exclude parameter nodes
if in_node["op"] != "null" or is_input_node(in_node,
input_names):
num_inputs += 1
return num_inputs > 1
def is_input_node(node_entry, input_names):
"""Whether a node is an input node.
Parameters
----------
node_entry : dict
Node entry.
input_names : list of str
List of input names of graph.
Returns
-------
out : bool
        Whether the node is an input node.
"""
return "name" in node_entry and node_entry["name"] in input_names
def bind_inputs(expr, input_shapes=None, input_dtypes="float32"):
"""Bind input variables of a relay function expression
to new shapes and/or dtypes.
Parameters
----------
expr : tvm.relay.Expr.Function
Input relay function expression.
input_shapes : dict of str to tuple of int, optional
Input shapes.
input_dtypes : str or dict of str to str, optional
Input dtypes.
Returns
-------
out : tvm.relay.Expr.Function
        Bound relay function expression.
"""
if input_shapes is None:
return expr
if isinstance(input_dtypes, str):
input_dtypes = {key : input_dtypes for key in input_shapes.keys()}
updated_input_dict = {}
for input_name in input_shapes.keys():
updated_input = relay.var(input_name, shape=input_shapes[input_name],
dtype=input_dtypes[input_name])
updated_input_dict[input_name] = updated_input
rebind_dict = {}
for var in expr.params:
if var.name_hint in updated_input_dict:
rebind_dict[var] = updated_input_dict[var.name_hint]
updated_expr = relay.expr.bind(expr, rebind_dict)
return relay.ir_pass.infer_type(updated_expr)
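# Minimal usage sketch (the shape below is illustrative only):
#
#     data = relay.var("data")
#     func = relay.Function([data], relay.nn.relu(data))
#     func = bind_inputs(func, {"data": (1, 16, 224, 224)}, "float32")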
......@@ -28,6 +28,7 @@ from .code_hash import attach_code_hash, attach_code_hash_to_arg
from .dispatcher import dispatcher, DispatchContext, ApplyConfig, ApplyHistoryBest, \
FallbackContext, clear_fallback_cache, ApplyGraphBest
from .topi_integration import register_topi_compute, register_topi_schedule
from .topi_integration import register_topi_compute, register_topi_schedule, \
TaskExtractEnv
from .nnvm_integration import extract_from_graph, extract_from_multiple_graph
from .relay_integration import extract_from_program, extract_from_multiple_program
......@@ -74,7 +74,7 @@ class TaskExtractEnv:
"""Global environment for extracting tuning tasks from nnvm graph"""
current = None
def __init__(self):
def __init__(self, allow_duplicate=False):
import topi
# topi compute -> autotvm task name
......@@ -106,6 +106,7 @@ class TaskExtractEnv:
topi.nn.deformable_conv2d_nchw: [topi.generic.schedule_deformable_conv2d_nchw],
}
self.allow_duplicate = allow_duplicate
self._register_tracing()
self._register_topi_task()
self.task_collection = []
......@@ -123,10 +124,9 @@ class TaskExtractEnv:
        assert not kwargs, "Do not support extracting tuning tasks when " \
            "kwargs is used in TOPI function call. " \
            "Please modify it to use only positional args."
if compute_func in self.wanted_topi_funcs: # record this call
key = (self.topi_to_task[compute_func], serialize_args(args))
if key not in self.task_collection:
if self.allow_duplicate or key not in self.task_collection:
self.task_collection.append(key)
return compute_func.fdefault(*args)
_local_scope(topi_compute)
......@@ -262,16 +262,25 @@ class TaskExtractEnv:
return self.task_collection
@staticmethod
def get():
def get(allow_duplicate=False):
"""Get the single instance of TaskExtractEnv
Parameters
----------
        allow_duplicate : boolean
            Whether to fetch all workloads in the network,
            even if some of them are identical. This is
            useful for graph tuning.
Returns
-------
env: TaskExtractEnv
The single instance of TaskExtractEnv
"""
if not TaskExtractEnv.current:
TaskExtractEnv.current = TaskExtractEnv()
TaskExtractEnv.current = TaskExtractEnv(allow_duplicate)
else:
TaskExtractEnv.current.allow_duplicate = allow_duplicate
return TaskExtractEnv.current
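    # Sketch of how graph tuning fetches the environment so that every
    # occurrence of a workload yields its own task:
    #
    #     env = TaskExtractEnv.get(allow_duplicate=True)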
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# NOTE: We name this test file to start with test_graph_tuner
# to make it execute after zero_rank tensor test cases. This
# helps avoid topi arithmetic operator overloading issue:
# https://github.com/dmlc/tvm/issues/3240.
# TODO: restore the file name after this issue is resolved.
import os
import copy
import numpy as np
import tvm
import tvm.relay.testing
from tvm import autotvm
from tvm import relay
from tvm.autotvm.task import ConfigEntity
from tvm.autotvm.measure import MeasureResult, MeasureInput
from tvm.autotvm.graph_tuner import DPTuner, PBQPTuner
from test_graph_tuner_utils import create_workload
def _create_data(target, dshape, dtype, layout):
data = relay.var("data", shape=dshape, dtype=dtype)
w0 = relay.var("w0_weight")
conv0 = relay.nn.conv2d(data, w0, channels=16, kernel_size=(3, 3), padding=(1, 1))
w1 = relay.var("w1_weight")
conv1 = relay.nn.conv2d(conv0, w1, channels=32, kernel_size=(1, 1))
w2 = relay.var("w2_weight")
conv2 = relay.nn.conv2d(conv1, w2, channels=32, kernel_size=(3, 3), padding=(1, 1))
out = relay.add(conv1, conv2)
net = relay.Function(relay.ir_pass.free_vars(out), out)
net, params = relay.testing.create_workload(net)
tasks = autotvm.task.extract_from_program(net,
target=target,
params=params,
ops=(relay.op.nn.conv2d,))
wkl_list = [
create_workload((1, 3, 8, 8), (16, 3, 3, 3), (1, 1), (1, 1), (1, 1), layout, layout, dtype, dtype),
create_workload((1, 16, 8, 8), (32, 16, 1, 1), (1, 1), (0, 0), (1, 1), layout, layout, dtype, dtype),
create_workload((1, 32, 8, 8), (32, 32, 3, 3), (1, 1), (1, 1), (1, 1), layout, layout, dtype, dtype),
]
costs = [0.04, 0.012, 0.03]
config_list = []
cfg_dict = {"i": -1,
"c": None,
"e": [["tile_ic", "sp", [3, 1]],
["tile_oc", "sp", [4, 4]],
["tile_ow", "sp", [4, 2]],
["unroll_kw", "ot", True]],
"t": ""}
config_list.append(ConfigEntity.from_json_dict(cfg_dict))
cfg_dict = {"i": -1,
"c": None,
"e": [["tile_ic", "sp", [2, 8]],
["tile_oc", "sp", [1, 32]],
["tile_oh", "ot", 1],
["tile_ow", "sp", [4, 2]]],
"t": ""}
config_list.append(ConfigEntity.from_json_dict(cfg_dict))
cfg_dict = {"i": -1,
"c": None,
"e": [["tile_ic", "sp", [8, 4]],
["tile_oc", "sp", [4, 8]],
["tile_ow", "sp", [2, 4]],
["unroll_kw", "ot", False]],
"t": ""}
config_list.append(ConfigEntity.from_json_dict(cfg_dict))
records = []
for wkl, cost, config, task in zip(wkl_list, costs, config_list, tasks):
task.workload = wkl
ms_input = MeasureInput(target=target, task=task, config=config)
ms_output = MeasureResult(costs=(cost,), error_no=0, all_cost=-1, timestamp=-1)
records.append((ms_input, ms_output))
ltf_records = []
ltf_arg = [tvm.placeholder((1, 64, 16, 16, 8), dtype=dtype), "NCHW8c", "NCHW512c"]
ltf_arg = autotvm.task.topi_integration.serialize_args(ltf_arg)
ltf_wkl = ('layout_transform',) + autotvm.task.args_to_workload(ltf_arg)
ltf_task = copy.deepcopy(tasks[0])
ltf_task.workload = ltf_wkl
ms_input = MeasureInput(target=target, task=ltf_task, config=None)
ms_output = MeasureResult(costs=(1.91224744e-05,), error_no=0, all_cost=-1, timestamp=-1)
ltf_records.append((ms_input, ms_output))
ltf_keys = []
ltf_arg = [tvm.placeholder((1, 4, 8, 8, 4), dtype=dtype), "NCHW4c", "NCHW8c"]
ltf_arg = autotvm.task.topi_integration.serialize_args(ltf_arg)
ltf_wkl = ('layout_transform',) + autotvm.task.args_to_workload(ltf_arg)
ltf_keys.append(ltf_wkl)
ltf_arg = [tvm.placeholder((1, 1, 8, 8, 32), dtype=dtype), "NCHW32c", "NCHW4c"]
ltf_arg = autotvm.task.topi_integration.serialize_args(ltf_arg)
ltf_wkl = ('layout_transform',) + autotvm.task.args_to_workload(ltf_arg)
ltf_keys.append(ltf_wkl)
ltf_arg = [tvm.placeholder((1, 4, 8, 8, 8), dtype=dtype), "NCHW8c", "NCHW32c"]
ltf_arg = autotvm.task.topi_integration.serialize_args(ltf_arg)
ltf_wkl = ('layout_transform',) + autotvm.task.args_to_workload(ltf_arg)
ltf_keys.append(ltf_wkl)
return net, records, ltf_records, ltf_keys, tasks
def test_graph_tuner_layout_transform():
log_file = "%s/test_tuner.log" % (os.getcwd())
target = "llvm"
dshape = (1, 3, 8, 8)
dtype = "float32"
layout = "NCHW"
target_ops = [relay.nn.conv2d]
g, records, ltf_records, ltf_keys, _ = _create_data(target, dshape, dtype, layout)
executor = DPTuner(g, {"data": dshape}, records, target_ops, target=target, log_file=log_file)
executor.benchmark_layout_transform(layout_records=ltf_records, infer_layout=True)
out = executor._layout_transform_perf_records
num_flops = 0
total_time = 0
for record in ltf_records:
ltf_wkl = record[0].task.workload
input_shape = ltf_wkl[1][1]
flops = np.prod(input_shape)
num_flops += flops
total_time += record[1].costs[0]
avg_time = total_time / num_flops
for ltf_workload in out:
input_shape = ltf_workload[1][1]
flops = 1
for i in input_shape:
flops *= i
expected_time = flops * avg_time
out_time = out[ltf_workload][1].costs[0]
assert expected_time == out_time, "Inferred layout transformation time mismatch for %s: " \
"expecting %f but got %f" % (str(ltf_workload), expected_time,
out_time)
def test_DPTuner_run():
log_file = "%s/test_tuner.log" % (os.getcwd())
target = "llvm"
dtype = "float32"
layout = "NCHW"
dshape = (1, 3, 8, 8)
target_ops = [relay.nn.conv2d]
g, records, ltf_records, ltf_keys, tasks = _create_data(target, dshape, dtype, layout)
costs = [0.02, 0.02, 0.045]
config_list = []
cfg_dict = {"i": -1,
"c": None,
"e": [["tile_ic", "sp", [1, 3]],
["tile_oc", "sp", [2, 8]],
["tile_ow", "sp", [4, 2]],
["unroll_kw", "ot", True]],
"t": ""}
config_list.append(ConfigEntity.from_json_dict(cfg_dict))
cfg_dict = {"i": -1,
"c": None,
"e": [["tile_ic", "sp", [4, 4]],
["tile_oc", "sp", [2, 16]],
["tile_oh", "ot", 1],
["tile_ow", "sp", [4, 2]]],
"t": ""}
config_list.append(ConfigEntity.from_json_dict(cfg_dict))
cfg_dict = {"i": -1,
"c": None,
"e": [["tile_ic", "sp", [16, 2]],
["tile_oc", "sp", [8, 4]],
["tile_ow", "sp", [2, 4]],
["unroll_kw", "ot", False]],
"t": ""}
config_list.append(ConfigEntity.from_json_dict(cfg_dict))
for cost, config, task in zip(costs, config_list, tasks):
ms_input = MeasureInput(target=target, task=task, config=config)
ms_output = MeasureResult(costs=(cost,), error_no=0, all_cost=-1, timestamp=-1)
records.append((ms_input, ms_output))
executor = DPTuner(g, {"data": dshape}, records, target_ops, target, log_file=log_file)
executor.benchmark_layout_transform(layout_records=ltf_records, infer_layout=True)
executor.run()
out = [record[0].config for record in executor.get_optimal_records()]
expected_out = [records[3][0].config, records[1][0].config, records[2][0].config]
assert expected_out == out, "Output mismatch: expecting %s but got %s" \
% (str(expected_out), str(out))
assert os.path.isfile(log_file), "No log file with name %s exists." % log_file
def test_PBQPTuner_run():
target = "llvm"
dtype = "float32"
layout = "NCHW"
dshape = (1, 3, 8, 8)
target_ops = [relay.nn.conv2d]
g, records, ltf_records, ltf_keys, tasks = _create_data(target, dshape, dtype, layout)
costs = [0.02, 0.02, 0.045]
config_list = []
cfg_dict = {"i": -1,
"c": None,
"e": [["tile_ic", "sp", [1, 3]],
["tile_oc", "sp", [2, 8]],
["tile_ow", "sp", [4, 2]],
["unroll_kw", "ot", True]],
"t": ""}
config_list.append(ConfigEntity.from_json_dict(cfg_dict))
cfg_dict = {"i": -1,
"c": None,
"e": [["tile_ic", "sp", [4, 4]],
["tile_oc", "sp", [2, 16]],
["tile_oh", "ot", 1],
["tile_ow", "sp", [4, 2]]],
"t": ""}
config_list.append(ConfigEntity.from_json_dict(cfg_dict))
cfg_dict = {"i": -1,
"c": None,
"e": [["tile_ic", "sp", [16, 2]],
["tile_oc", "sp", [8, 4]],
["tile_ow", "sp", [2, 4]],
["unroll_kw", "ot", False]],
"t": ""}
config_list.append(ConfigEntity.from_json_dict(cfg_dict))
for cost, config, task in zip(costs, config_list, tasks):
ms_input = MeasureInput(target=target, task=task, config=config)
ms_output = MeasureResult(costs=(cost,), error_no=0, all_cost=-1, timestamp=-1)
records.append((ms_input, ms_output))
executor = PBQPTuner(g, {"data": dshape}, records, target_ops, target)
executor.benchmark_layout_transform(layout_records=ltf_records, infer_layout=True)
executor.run()
out = [record[0].config for record in executor.get_optimal_records()]
expected_out = [records[3][0].config, records[1][0].config, records[2][0].config]
assert expected_out == out, "Output mismatch: expecting %s but got %s" \
% (str(expected_out), str(out))
if __name__=="__main__":
test_graph_tuner_layout_transform()
test_DPTuner_run()
test_PBQPTuner_run()
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# NOTE: We name this test file to start with test_graph_tuner
# to make it execute after zero_rank tensor test cases. This
# helps avoid topi arithmetic operator overloading issue:
# https://github.com/dmlc/tvm/issues/3240
# TODO: restore the file name after this issue is resolved.
import tvm
from tvm import autotvm, relay
from tvm.relay.testing import resnet
from tvm.autotvm.graph_tuner.utils import has_multiple_inputs, get_direct_ancestor, get_in_nodes, \
get_out_nodes, expr2graph, bind_inputs
from tvm.relay.expr import Call, TupleGetItem, Tuple
from topi.nn.conv2d import conv2d
def create_workload(dshape, kshape, strides,
padding, dilation, layout,
out_layout, dtype, out_dtype):
data = tvm.placeholder(dshape, dtype=dtype)
kernel = tvm.placeholder(kshape, dtype=dtype)
return autotvm.task.args_to_workload([data, kernel, strides, padding, dilation, layout,
out_dtype], conv2d)
def verify_has_multiple_inputs(node_list, node_idx, input_names, expected_result):
out = has_multiple_inputs(node_list, node_idx, input_names)
    assert out == expected_result, "Output mismatch: expected %s to be %s but got %s." \
        % (node_list[node_idx]["op"], str(expected_result), str(out))
def test_has_multiple_inputs():
data = relay.var("data")
out1 = data * relay.expr.const(3.0)
w0 = relay.var("w0")
out2 = relay.nn.conv2d(data, w0)
out = relay.add(out1, out2)
net = relay.Function(relay.ir_pass.free_vars(out), out)
net = bind_inputs(net, {"data": (1, 16, 224, 224), "w0": (16, 16, 1, 1)})
target_ops = ["conv2d"]
node_list = []
node_dict = {}
expr2graph(net, target_ops, node_dict, node_list)
input_names = ["data"]
verify_has_multiple_inputs(node_list, 2, input_names, False)
verify_has_multiple_inputs(node_list, 4, input_names, False)
verify_has_multiple_inputs(node_list, 5, input_names, True)
def test_expr2graph():
net, _ = resnet.get_workload(num_layers=50, batch_size=1)
node_dict = {}
node_list = []
target_ops = ["conv2d"]
op_name_list = []
def _count_node(node):
        if isinstance(node, relay.op.op.Op):
            return
if isinstance(node, Call):
op_name_list.append(node.op.name.split(".")[-1])
elif isinstance(node, TupleGetItem):
op_name_list.append("TupleGetItem")
elif isinstance(node, Tuple):
op_name_list.append("Tuple")
else:
op_name_list.append("null")
relay.ir_pass.post_order_visit(net, _count_node)
expr2graph(net, target_ops, node_dict, node_list)
for i, item in enumerate(zip(op_name_list, node_list)):
op_name, node = item
assert op_name == node["op"], "%dth Node operator mismatch: expecting %s but got %s" \
% (i, str(op_name), str(node["op"]))
def test_get_direct_ancestor():
data = relay.var("data")
w0 = relay.var("w0")
out1 = relay.nn.conv2d(data, w0)
out2 = relay.add(out1, data * relay.expr.const(5.0))
out3 = out2 + relay.expr.const(2.5)
w1 = relay.var("w1")
out = relay.nn.conv2d(out3, w1)
net = relay.Function(relay.ir_pass.free_vars(out), out)
net = bind_inputs(net, {"data": (1, 16, 224, 224), "w0": (16, 16, 1, 1), "w1": (16, 16, 1, 1)})
target_ops = ["conv2d"]
node_list = []
node_dict = {}
expr2graph(net, target_ops, node_dict, node_list)
visited_dict = {}
input_names = ["data"]
out = get_direct_ancestor(node_list, visited_dict, target_ops, 5, input_names)
assert out == [2, 0], "Output mismatch: expecting [2, 0] but got %s." % str(out)
def test_get_in_nodes():
data = relay.var("data")
w0 = relay.var("w0")
out1 = relay.nn.conv2d(data, w0)
out2 = relay.add(out1, data)
out3 = out2 + relay.expr.const(2.5)
w1 = relay.var("w1")
out = relay.nn.conv2d(out3, w1)
net = relay.Function(relay.ir_pass.free_vars(out), out)
net = bind_inputs(net, {"data": (1, 16, 224, 224), "w0": (16, 16, 1, 1), "w1": (16, 16, 1, 1)})
target_ops = ["conv2d"]
input_names = ["data"]
node_list = []
node_dict = {}
expr2graph(net, target_ops, node_dict, node_list)
out = get_in_nodes(node_list, target_ops, input_names)
expected_out = {7: [3], 3: [2, 0], 2: [0]}
diff_set = set(out) ^ set(expected_out)
if len(diff_set) != 0:
raise RuntimeError("Output mismatch: expecting %s but got %s." % (str(expected_out), str(out)))
def test_get_out_nodes():
in_nodes_dict = {8: [4], 4: [3, 0], 3: [0]}
expected_out = {0: [3, 4], 3: [4], 4: [8], 8: []}
out = get_out_nodes(in_nodes_dict)
diff_set = set(out) ^ set(expected_out)
if len(diff_set) != 0:
raise RuntimeError("Output mismatch: expecting %s but got %s." % (str(expected_out), str(out)))
if __name__ == "__main__":
test_has_multiple_inputs()
test_expr2graph()
test_get_direct_ancestor()
test_get_in_nodes()
test_get_out_nodes()
......@@ -94,6 +94,26 @@ def conv2d_alter_layout(attrs, inputs, tinfos, F):
# not to change by default
return None
@tvm.target.generic_func
def conv2d_infer_layout(workload, cfg):
"""Infer input/output shapes and layouts from a workload and cfg.
Parameters
----------
workload : tuple
conv2d workload
cfg : tuple
tvm.autotvm config
Returns
-------
Output : [tuple of tuple and str, tuple of tuple and str]
Input shapes and layouts, and output shapes and layouts
"""
raise ValueError("missing register for topi.nn.conv2d_infer_layout")
def _get_workload(data, kernel, stride, padding, out_dtype, data_layout='NCHW'):
""" Get the workload structure. """
......
......@@ -336,3 +336,22 @@ def depthwise_conv2d_NCHWc(Input, Filter, stride, padding, dilation,
5-D with shape [batch, out_channel_chunk, out_height, out_width, out_channel_block]
"""
raise ValueError("missing register for topi.nn.depthwise_conv2d_NCHWc")
@tvm.target.generic_func
def depthwise_conv2d_infer_layout(workload, cfg):
"""Infer input/output shapes and layouts from a workload and cfg.
Parameters
----------
    workload : tuple
        depthwise_conv2d workload
cfg : tuple
tvm.autotvm config
Returns
-------
Output : [tuple of tuple and str, tuple of tuple and str]
Input shapes and layouts, and output shapes and layouts
"""
raise ValueError("missing register for topi.nn.depthwise_conv2d_infer_layout")
......@@ -28,7 +28,7 @@ from .. import generic, tag
from .. import nn
from ..util import get_const_tuple
from ..nn.conv2d import conv2d, conv2d_NCHWc, \
conv2d_alter_layout, _get_workload as _get_conv2d_workload
conv2d_alter_layout, conv2d_infer_layout, _get_workload as _get_conv2d_workload
from ..nn.depthwise_conv2d import _get_workload as _get_depthwise_conv2d_workload
from ..nn.depthwise_conv2d import depthwise_conv2d_NCHWc, depthwise_conv2d_nchw
from ..nn.pad import pad
......@@ -475,6 +475,21 @@ def _alter_conv2d_layout(attrs, inputs, tinfo, F):
return F.nn.contrib_conv2d_nchwc(*copy_inputs, **new_attrs)
@conv2d_infer_layout.register("cpu")
def _conv2d_infer_layout(workload, cfg):
_, data, kernel, strides, padding, dilation, layout, dtype = workload
batch_size, in_channel, in_height, in_width = data[:-1]
out_channel, _, k_height, k_width = kernel[:-1]
out_height = (in_height + 2 * padding[0] - k_height) // strides[0] + 1
out_width = (in_width + 2 * padding[1] - k_width) // strides[1] + 1
tile_ic, tile_oc = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1]
in_shape = (batch_size, in_channel // tile_ic, in_height, in_width, tile_ic)
in_layout = "NCHW%dc" % tile_ic
out_shape = (batch_size, out_channel // tile_oc, out_height, out_width, tile_oc)
out_layout = "NCHW%dc" % tile_oc
return ((in_shape, in_layout),), ((out_shape, out_layout),)
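# Worked example (hypothetical workload): data (1, 64, 56, 56), kernel
# (64, 64, 3, 3), strides (1, 1), padding (1, 1) with tile_ic = 16 and
# tile_oc = 32 infers input ((1, 4, 56, 56, 16), "NCHW16c") and output
# ((1, 2, 56, 56, 32), "NCHW32c"), since (56 + 2 * 1 - 3) // 1 + 1 = 56.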
@autotvm.register_topi_compute(conv2d_NCHWc, 'cpu', 'direct')
def _declaration_conv_NCHWc(cfg, data, kernel, strides,
padding, dilation, layout, out_layout, out_dtype):
......
......@@ -25,7 +25,8 @@ from .. import generic, tag
from ..nn.pad import pad
from ..util import get_const_tuple
from ..nn.util import get_pad_tuple
from ..nn.depthwise_conv2d import depthwise_conv2d_NCHWc, _get_workload
from ..nn.depthwise_conv2d import depthwise_conv2d_NCHWc, _get_workload, \
depthwise_conv2d_infer_layout
from .util import get_fp32_len
......@@ -206,7 +207,7 @@ def _topi_nn_depthwise_conv2d_NCHWc(*args, **kwargs):
# change shape with the value in config
ic_bn, oc_bn = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1]
new_data_shape = (batch, in_channel // ic_bn, height, width, ic_bn)
new_kernel_shape = (out_channel // oc_bn, kh, kw, oc_bn)
new_kernel_shape = (out_channel // oc_bn, 1, kh, kw, 1, oc_bn)
new_data = tvm.placeholder(new_data_shape, data.dtype)
new_kernel = tvm.placeholder(new_kernel_shape, kernel.dtype)
......@@ -217,3 +218,18 @@ def _topi_nn_depthwise_conv2d_NCHWc(*args, **kwargs):
data_layout, out_layout, dtype)
s = schedule_depthwise_conv2d_NCHWc(cfg, [C])
return s, [new_data, new_kernel, C]
@depthwise_conv2d_infer_layout.register("cpu")
def _depthwise_conv2d_infer_layout(workload, cfg):
_, data, kernel, strides, padding, dilation, dtype = workload
batch_size, in_channel, in_height, in_width = data[:-1]
filter_channel, channel_multiplier, k_height, k_width = kernel[:-1]
out_channel = filter_channel * channel_multiplier
out_height = (in_height + 2 * padding[0] - k_height) // strides[0] + 1
out_width = (in_width + 2 * padding[1] - k_width) // strides[1] + 1
tile_ic, tile_oc = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1]
in_shape = (batch_size, in_channel // tile_ic, in_height, in_width, tile_ic)
in_layout = "NCHW%dc" % tile_ic
out_shape = (batch_size, out_channel // tile_oc, out_height, out_width, tile_oc)
out_layout = "NCHW%dc" % tile_oc
return ((in_shape, in_layout),), ((out_shape, out_layout),)
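# Worked example (hypothetical workload): data (1, 32, 56, 56), kernel
# (32, 1, 3, 3), strides (1, 1), padding (1, 1) with tile_ic = tile_oc = 8
# gives out_channel = 32 * 1 = 32, so both input and output are laid out as
# "NCHW8c" with shape (1, 4, 56, 56, 8).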
......@@ -30,6 +30,7 @@ from tvm import autotvm
from tvm import relay
from tvm.relay import testing
from tvm.autotvm.tuner import XGBTuner, GATuner, RandomTuner, GridSearchTuner
from tvm.autotvm.graph_tuner import DPTuner, PBQPTuner
import tvm.contrib.graph_runtime as runtime
#################################################################
......@@ -81,6 +82,7 @@ batch_size = 1
dtype = "float32"
model_name = "resnet-18"
log_file = "%s.log" % model_name
graph_opt_sch_file = "%s_graph_opt.log" % model_name
# Set number of threads used for tuning based on the number of
# physical CPU cores on your machine.
......@@ -157,6 +159,16 @@ def tune_kernels(tasks,
autotvm.callback.progress_bar(n_trial, prefix=prefix),
autotvm.callback.log_to_file(log_filename)])
# Use graph tuner to achieve graph level optimal schedules
# Set use_DP=False if it takes too long to finish.
def tune_graph(graph, dshape, records, opt_sch_file, use_DP=True):
target_op = [relay.nn.conv2d]
Tuner = DPTuner if use_DP else PBQPTuner
executor = Tuner(graph, {"data": dshape}, records, target_op, target)
executor.benchmark_layout_transform(min_exec_num=2000)
executor.run()
executor.write_opt_sch2record_file(opt_sch_file)
########################################################################
# Finally, we launch tuning jobs and evaluate the end-to-end performance.
......@@ -171,9 +183,10 @@ def tune_and_evaluate(tuning_opt):
# run tuning tasks
print("Tuning...")
tune_kernels(tasks, **tuning_opt)
tune_graph(net, data_shape, log_file, graph_opt_sch_file)
# compile kernels with history best records
with autotvm.apply_history_best(log_file):
# compile kernels with graph-level best records
with autotvm.apply_graph_best(graph_opt_sch_file):
print("Compile...")
with relay.build_config(opt_level=3):
graph, lib, params = relay.build_module.build(
......