Commit c8a0f524 by Yao Wang, committed by Yizhi Liu

[AutoTVM] Core functionality for Graph tuner (#2184)

* Add graph tuning

* Add tests

* Fix tests

* Fix pylint

* Small fix for docstring

* Minor fix

* Support fetching workload from relay expr

* Simplify benchmark layout transformation

* Add relay support

* Fix infer layout func name

* Refactor internal data representation

* Fix issues

* Add PBQP solver

* Fix layout transform check

* Add PBQPTuner test

* Fix lint

* Update tutorial

* Fix tutorial

* Fix lint

* Add relay test

* Remove nnvm since nnvm graph can be converted to relay function

* Modify benchmark layout wrt new layout_transform api

* Fix lint

* Update docstring for DP tuner

* Refactor traverse graph

* Support graph tuning for multiple target operators

* Fix fetching workloads

* Add x86 depthwise_conv2d infer_layout

* Fix x86 depthwise_conv2d autotvm

* Fix PBQP tuner

* Fix DP tuner

* Generate dummy layout transform record

* Update tutorial

* Modify layout records name

* Add ASF header

* Add ASF header for testing files

* Fix test

* Fix topi fetching

* Some refactors

* Fix lint

* Fix tutorial

* Rename test files

* Fix doc typo

* Add test case note link
parent 4767554c
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Autotvm graph tuner API."""
from __future__ import absolute_import as _abs
from . import _base
from . import base_graph_tuner
from .base_graph_tuner import BaseGraphTuner
from .dynamic_programming_tuner import DPTuner
from .pbqp_tuner import PBQPTuner
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
"""Helper functions and global data"""
RULE_OUT_NODE_NAMES = ["Tuple", "TupleGetItem", "batch_flatten", "transpose", "reshape",
"multibox_prior", "multibox_transform_loc", "where",
"non_max_suppression", "strided_slice"]
# We set a large time to represent an invalid layout-transformation.
# This number is set to be 10e9 seconds to align with autotvm.
INVALID_LAYOUT_TIME = 10e9
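

# Illustrative sketch only (an assumption, not part of this module): one way
# INVALID_LAYOUT_TIME can act as a penalty cost, so that layout transformations
# whose benchmarking failed are never selected by the tuners.
def _example_transform_cost(measured_seconds, failed):
    """Hypothetical helper returning the cost recorded for a layout transform."""
    if failed:
        # Align with autotvm: a failed measurement gets a prohibitively large time.
        return INVALID_LAYOUT_TIME
    return measured_seconds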
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=import-error,too-many-locals,too-many-statements,too-many-branches,unused-variable
"""Dynamic programming tuner."""
import sys
import numpy as np
from .base_graph_tuner import BaseGraphTuner
from .dynamic_programming_stage import DPStage
from .utils import has_multiple_inputs, is_input_node
if sys.version_info[0] == 3:
    import queue
else:
    import Queue as queue
class DPTuner(BaseGraphTuner):
"""Tuner which uses dynamic programming to solve MDP problem.
Note: currently dynamic programming is used to solve this MDP problem. However,
this problem is intrinsically non-polynomial. DP can't apply for more complicated
models, such as networks with many element-wise sum operators. In this case, switch
to heuristic algorithm such as PBQP tuner.
"""
    def __init__(self, *args, **kwargs):
        """Create a dynamic programming tuner.
        """
        super(DPTuner, self).__init__(*args, **kwargs)
        self._num_states = self._max_num_states = None
        self._stage_dict = {}
        self._dep_dict = {}
        self._counted_nodes_set = set()
        self._global_data_dict = {
            "dtype": self._dtype,
            "counted_nodes_set": self._counted_nodes_set,
            "stage_dict": self._stage_dict,
            "in_nodes_dict": self._in_nodes_dict,
            "out_nodes_dict": self._out_nodes_dict,
            "dep_dict": self._dep_dict,
            "node_list": self._node_list,
            "input_shapes": self._input_shapes,
            "layout_transform_interlayer_cost": self._layout_transform_interlayer_cost
        }

    def _check_num_states(self, num_states):
        """Track the number of states."""
        self._num_states += num_states
        if self._max_num_states is not None:
            if self._num_states > self._max_num_states:
                raise RuntimeError("Too many states detected while running dynamic "
                                   "programming: got %d states but upper limit is %d." %
                                   (self._num_states, self._max_num_states))

    def _forward(self):
        """Forward pass in DP to generate states for all stages.
        """
        self._logger.info("Start forward pass...")
        for node_idx in sorted(self._in_nodes_dict.keys()):
            stage = DPStage(idx=node_idx, target_ops=self._target_ops,
                            **self._global_data_dict)
            self._check_num_states(stage.full_states.size)
            self._stage_dict[node_idx] = stage
        self._logger.info("Finished forward pass.")

    def _backward(self):
        """Backward pass in DP to generate optimal solution.
        """
        self._logger.info("Start backward pass...")
        input_names = self._input_shapes.keys()
        optimal_record_dict = {}
        # Pick optimal schedule for output nodes
        output_idx_list = []
        for key, val in self._out_nodes_dict.items():
            if not val:
                output_idx_list.append(key)
        states_list, aligned_node_list = DPStage.align_states(output_idx_list, self._stage_dict,
                                                              self._node_list)
        num_states = states_list[0][3].size
        self._check_num_states(num_states * len(output_idx_list))
        aligned_node_shape = states_list[0][3].shape
        min_time = 0
        min_pos = -1
        for states in states_list:
            min_time += np.amax(states[3])
        flatten_states_list = [current_states[3].flatten() for current_states in states_list]
        for i in range(num_states):
            current_time = 0
            for j, current_states in enumerate(states_list):
                current_time += flatten_states_list[j][i]
            if min_time > current_time:
                min_time = current_time
                min_pos = i
        for i, states in enumerate(states_list):
            current_major_axis = states[1]
            current_sch_idx = (min_pos % (states[2] *
                                          aligned_node_shape[current_major_axis])) // states[2]
            optimal_record_dict[aligned_node_list[i]] = current_sch_idx
        # Pick optimal schedule for dependencies of output nodes
        for i in range(len(states_list), len(aligned_node_list)):
            multiplier = 1
            for j in range(i + 1, len(aligned_node_list)):
                multiplier *= aligned_node_shape[j]
            optimal_record_dict[aligned_node_list[i]] = \
                min_pos // multiplier % aligned_node_shape[i]
        # Backward pass to get optimal schedules for other nodes
        bfs_q = queue.Queue()
        visited = set()
        for out_idx in output_idx_list:
            bfs_q.put(out_idx)
        while not bfs_q.empty():
            node_idx = bfs_q.get()
            visited.add(node_idx)
            if is_input_node(self._node_list[node_idx], input_names):
                continue
            optimal_sch_idx = optimal_record_dict[node_idx]
            full_states = self._stage_dict[node_idx].full_states
            if not has_multiple_inputs(self._node_list, node_idx, input_names):
                input_idx = self._in_nodes_dict[node_idx][0]
                if is_input_node(self._node_list[input_idx], input_names):
                    continue
                if input_idx not in visited:
                    bfs_q.put(input_idx)
                if input_idx not in optimal_record_dict:
                    dep_list = self._stage_dict[node_idx].dep
                    dep_idx = tuple([optimal_record_dict[item] for item in dep_list])
                    tmp = np.argmin(full_states, axis=1)
                    optimal_input_sch_idx = tmp[(optimal_sch_idx,) + dep_idx]
                    optimal_record_dict[input_idx] = optimal_input_sch_idx
            else:
                input_idx_list = self._in_nodes_dict[node_idx]
                optimal_record_dict[input_idx_list[0]] = optimal_sch_idx
                full_states_idx = self._stage_dict[node_idx].full_states_idx
                tmp = full_states[optimal_sch_idx]
                new_states_idx, new_states_pos = [], []
                visited_states_idx, visited_states_pos = [], []
                for i in range(1, len(full_states_idx)):
                    if full_states_idx[i] in optimal_record_dict:
                        visited_states_idx.append(full_states_idx[i])
                        visited_states_pos.append(i - 1)
                    else:
                        new_states_idx.append(full_states_idx[i])
                        new_states_pos.append(i - 1)
                if visited_states_idx:
                    tmp = np.transpose(tmp, tuple(visited_states_pos + new_states_pos))
                    tmp = tmp[tuple([optimal_record_dict[idx] for idx in visited_states_idx])]
                min_pos = np.argmin(tmp)
                multiplier = 1
                for i in range(len(new_states_idx)):
                    multiplier *= full_states.shape[new_states_pos[i] + 1]
                for pos, idx in zip(new_states_pos, new_states_idx):
                    multiplier //= full_states.shape[pos + 1]
                    optimal_record_dict[idx] = min_pos // multiplier
                    min_pos %= multiplier
                for input_idx in input_idx_list:
                    if input_idx not in visited:
                        bfs_q.put(input_idx)
        self._optimal_record_dict = optimal_record_dict
        for node_idx, _ in self._in_nodes_dict.items():
            if self._node_list[node_idx]["op"] not in self._target_ops:
                continue
        self._logger.info("Finished backward pass...")

    def run(self, **kwargs):
        """Run dynamic programming solver.
        """
        max_num_states = None if "max_num_states" not in kwargs else kwargs["max_num_states"]
        self._num_states = 0
        self._max_num_states = max_num_states
        self._logger.info("Start to run dynamic programming algorithm...")
        self._forward()
        self._backward()
        self._logger.info("Finished DPExecutor run.")
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=wildcard-import
"""Graph tuner utility functions"""
from __future__ import absolute_import
from . import traverse_graph
from . import utils
from .traverse_graph import expr2graph, get_direct_ancestor, get_in_nodes, \
    get_out_nodes
from .utils import has_multiple_inputs, is_input_node, bind_inputs
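

# A minimal sketch (hedged), following the unit tests added in this commit:
# expr2graph flattens a relay function into a node list, and get_in_nodes /
# get_out_nodes derive the dependency dictionaries consumed by the tuners.
def _example_build_graph_dicts(net):
    node_list, node_dict = [], {}
    expr2graph(net, ["conv2d"], node_dict, node_list)
    in_nodes = get_in_nodes(node_list, ["conv2d"], ["data"])
    out_nodes = get_out_nodes(in_nodes)
    return node_list, in_nodes, out_nodes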
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=eval-used,invalid-name,too-many-arguments
"""Utility functions"""
from tvm import relay
def has_multiple_inputs(node_list, node_idx, input_names):
"""Check whether a node has multiple input nodes
except variable nodes.
Parameters
----------
node_list : list of dict of str to object
List of all nodes in a graph.
node_idx : int
Node index to be checked.
input_names : list of str
List of input names of graph.
Returns
-------
out : bool
Whether the specified node has multiple input nodes
"""
num_inputs = 0
node = node_list[node_idx]
for in_idx in node["inputs"]:
in_idx = in_idx[0]
in_node = node_list[in_idx]
# Exclude parameter nodes
if in_node["op"] != "null" or is_input_node(in_node,
input_names):
num_inputs += 1
return num_inputs > 1
def is_input_node(node_entry, input_names):
"""Whether a node is an input node.
Parameters
----------
node_entry : dict
Node entry.
input_names : list of str
List of input names of graph.
Returns
-------
out : bool
whether node is a input node.
"""
return "name" in node_entry and node_entry["name"] in input_names
def bind_inputs(expr, input_shapes=None, input_dtypes="float32"):
"""Bind input variables of a relay function expression
to new shapes and/or dtypes.
Parameters
----------
expr : tvm.relay.Expr.Function
Input relay function expression.
input_shapes : dict of str to tuple of int, optional
Input shapes.
input_dtypes : str or dict of str to str, optional
Input dtypes.
Returns
-------
out : tvm.relay.Expr.Function
Bind relay function expression.
"""
if input_shapes is None:
return expr
if isinstance(input_dtypes, str):
input_dtypes = {key : input_dtypes for key in input_shapes.keys()}
updated_input_dict = {}
for input_name in input_shapes.keys():
updated_input = relay.var(input_name, shape=input_shapes[input_name],
dtype=input_dtypes[input_name])
updated_input_dict[input_name] = updated_input
rebind_dict = {}
for var in expr.params:
if var.name_hint in updated_input_dict:
rebind_dict[var] = updated_input_dict[var.name_hint]
updated_expr = relay.expr.bind(expr, rebind_dict)
return relay.ir_pass.infer_type(updated_expr)
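

# Example (hedged), mirroring how the unit tests in this commit use bind_inputs:
# free relay variables are rebound to concrete shapes before graph analysis.
def _example_bind_inputs():
    data = relay.var("data")
    w0 = relay.var("w0")
    out = relay.nn.conv2d(data, w0)
    net = relay.Function(relay.ir_pass.free_vars(out), out)
    # Dtypes default to "float32"; only shapes are supplied here.
    return bind_inputs(net, {"data": (1, 16, 224, 224), "w0": (16, 16, 1, 1)})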
......@@ -28,6 +28,7 @@ from .code_hash import attach_code_hash, attach_code_hash_to_arg
from .dispatcher import dispatcher, DispatchContext, ApplyConfig, ApplyHistoryBest, \
FallbackContext, clear_fallback_cache, ApplyGraphBest
from .topi_integration import register_topi_compute, register_topi_schedule
from .topi_integration import register_topi_compute, register_topi_schedule, \
    TaskExtractEnv
from .nnvm_integration import extract_from_graph, extract_from_multiple_graph
from .relay_integration import extract_from_program, extract_from_multiple_program
......@@ -74,7 +74,7 @@ class TaskExtractEnv:
"""Global environment for extracting tuning tasks from nnvm graph"""
current = None
def __init__(self):
def __init__(self, allow_duplicate=False):
import topi
# topi compute -> autotvm task name
......@@ -106,6 +106,7 @@ class TaskExtractEnv:
topi.nn.deformable_conv2d_nchw: [topi.generic.schedule_deformable_conv2d_nchw],
}
self.allow_duplicate = allow_duplicate
self._register_tracing()
self._register_topi_task()
self.task_collection = []
......@@ -123,10 +124,9 @@ class TaskExtractEnv:
assert not kwargs, "Do not support extracting tuning tasks when" \
"kwargs is used in TOPI function call." \
"Please modify it to use only positional args."
if compute_func in self.wanted_topi_funcs: # record this call
key = (self.topi_to_task[compute_func], serialize_args(args))
if key not in self.task_collection:
if self.allow_duplicate or key not in self.task_collection:
self.task_collection.append(key)
return compute_func.fdefault(*args)
_local_scope(topi_compute)
......@@ -262,16 +262,25 @@ class TaskExtractEnv:
return self.task_collection
@staticmethod
def get():
def get(allow_duplicate=False):
"""Get the single instance of TaskExtractEnv
Parameters
----------
allow_duplicate : boolean
Whether to fetch all workloads in the network,
even though some of them are the same. This is
useful for graph tuning.
Returns
-------
env: TaskExtractEnv
The single instance of TaskExtractEnv
"""
if not TaskExtractEnv.current:
TaskExtractEnv.current = TaskExtractEnv()
TaskExtractEnv.current = TaskExtractEnv(allow_duplicate)
else:
TaskExtractEnv.current.allow_duplicate = allow_duplicate
return TaskExtractEnv.current
......
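# Hedged sketch of the new allow_duplicate switch: graph tuning needs one task per
# conv2d call site, even when several layers share an identical workload, so the
# extraction environment can now be asked to keep duplicate task keys.
def _example_task_extract_env():
    from tvm.autotvm.task.topi_integration import TaskExtractEnv
    env = TaskExtractEnv.get(allow_duplicate=True)  # keep repeated workloads
    return env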
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# NOTE: We name this test file to start with test_graph_tuner
# to make it execute after zero_rank tensor test cases. This
# helps avoid topi arithmetic operator overloading issue:
# https://github.com/dmlc/tvm/issues/3240
# TODO: restore the file name after this issue is resolved.
import tvm
from tvm import autotvm, relay
from tvm.relay.testing import resnet
from tvm.autotvm.graph_tuner.utils import has_multiple_inputs, get_direct_ancestor, get_in_nodes, \
    get_out_nodes, expr2graph, bind_inputs
from tvm.relay.expr import Call, TupleGetItem, Tuple
from topi.nn.conv2d import conv2d
def create_workload(dshape, kshape, strides,
                    padding, dilation, layout,
                    out_layout, dtype, out_dtype):
    data = tvm.placeholder(dshape, dtype=dtype)
    kernel = tvm.placeholder(kshape, dtype=dtype)
    return autotvm.task.args_to_workload([data, kernel, strides, padding, dilation, layout,
                                          out_dtype], conv2d)


def verify_has_multiple_inputs(node_list, node_idx, input_names, expected_result):
    out = has_multiple_inputs(node_list, node_idx, input_names)
    assert out == expected_result, "Output mismatch: expecting checking %s to be %s but got %s." \
                                   % (node_list[node_idx]["op"], str(expected_result), str(out))


def test_has_multiple_inputs():
    data = relay.var("data")
    out1 = data * relay.expr.const(3.0)
    w0 = relay.var("w0")
    out2 = relay.nn.conv2d(data, w0)
    out = relay.add(out1, out2)
    net = relay.Function(relay.ir_pass.free_vars(out), out)
    net = bind_inputs(net, {"data": (1, 16, 224, 224), "w0": (16, 16, 1, 1)})
    target_ops = ["conv2d"]
    node_list = []
    node_dict = {}
    expr2graph(net, target_ops, node_dict, node_list)
    input_names = ["data"]
    verify_has_multiple_inputs(node_list, 2, input_names, False)
    verify_has_multiple_inputs(node_list, 4, input_names, False)
    verify_has_multiple_inputs(node_list, 5, input_names, True)
def test_expr2graph():
    net, _ = resnet.get_workload(num_layers=50, batch_size=1)
    node_dict = {}
    node_list = []
    target_ops = ["conv2d"]
    op_name_list = []

    def _count_node(node):
        if isinstance(node, relay.op.op.Op):
            # Skip operator nodes themselves; only record expression nodes.
            return
        if isinstance(node, Call):
            op_name_list.append(node.op.name.split(".")[-1])
        elif isinstance(node, TupleGetItem):
            op_name_list.append("TupleGetItem")
        elif isinstance(node, Tuple):
            op_name_list.append("Tuple")
        else:
            op_name_list.append("null")

    relay.ir_pass.post_order_visit(net, _count_node)
    expr2graph(net, target_ops, node_dict, node_list)
    for i, item in enumerate(zip(op_name_list, node_list)):
        op_name, node = item
        assert op_name == node["op"], "%dth Node operator mismatch: expecting %s but got %s" \
                                      % (i, str(op_name), str(node["op"]))
def test_get_direct_ancestor():
    data = relay.var("data")
    w0 = relay.var("w0")
    out1 = relay.nn.conv2d(data, w0)
    out2 = relay.add(out1, data * relay.expr.const(5.0))
    out3 = out2 + relay.expr.const(2.5)
    w1 = relay.var("w1")
    out = relay.nn.conv2d(out3, w1)
    net = relay.Function(relay.ir_pass.free_vars(out), out)
    net = bind_inputs(net, {"data": (1, 16, 224, 224), "w0": (16, 16, 1, 1), "w1": (16, 16, 1, 1)})
    target_ops = ["conv2d"]
    node_list = []
    node_dict = {}
    expr2graph(net, target_ops, node_dict, node_list)
    visited_dict = {}
    input_names = ["data"]
    out = get_direct_ancestor(node_list, visited_dict, target_ops, 5, input_names)
    assert out == [2, 0], "Output mismatch: expecting [2, 0] but got %s." % str(out)


def test_get_in_nodes():
    data = relay.var("data")
    w0 = relay.var("w0")
    out1 = relay.nn.conv2d(data, w0)
    out2 = relay.add(out1, data)
    out3 = out2 + relay.expr.const(2.5)
    w1 = relay.var("w1")
    out = relay.nn.conv2d(out3, w1)
    net = relay.Function(relay.ir_pass.free_vars(out), out)
    net = bind_inputs(net, {"data": (1, 16, 224, 224), "w0": (16, 16, 1, 1), "w1": (16, 16, 1, 1)})
    target_ops = ["conv2d"]
    input_names = ["data"]
    node_list = []
    node_dict = {}
    expr2graph(net, target_ops, node_dict, node_list)
    out = get_in_nodes(node_list, target_ops, input_names)
    expected_out = {7: [3], 3: [2, 0], 2: [0]}
    diff_set = set(out) ^ set(expected_out)
    if len(diff_set) != 0:
        raise RuntimeError("Output mismatch: expecting %s but got %s." % (str(expected_out), str(out)))


def test_get_out_nodes():
    in_nodes_dict = {8: [4], 4: [3, 0], 3: [0]}
    expected_out = {0: [3, 4], 3: [4], 4: [8], 8: []}
    out = get_out_nodes(in_nodes_dict)
    diff_set = set(out) ^ set(expected_out)
    if len(diff_set) != 0:
        raise RuntimeError("Output mismatch: expecting %s but got %s." % (str(expected_out), str(out)))


if __name__ == "__main__":
    test_has_multiple_inputs()
    test_expr2graph()
    test_get_direct_ancestor()
    test_get_in_nodes()
    test_get_out_nodes()
......@@ -94,6 +94,26 @@ def conv2d_alter_layout(attrs, inputs, tinfos, F):
# not to change by default
return None
@tvm.target.generic_func
def conv2d_infer_layout(workload, cfg):
"""Infer input/output shapes and layouts from a workload and cfg.
Parameters
----------
workload : tuple
conv2d workload
cfg : tuple
tvm.autotvm config
Returns
-------
Output : [tuple of tuple and str, tuple of tuple and str]
Input shapes and layouts, and output shapes and layouts
"""
raise ValueError("missing register for topi.nn.conv2d_infer_layout")
def _get_workload(data, kernel, stride, padding, out_dtype, data_layout='NCHW'):
""" Get the workload structure. """
......
......@@ -336,3 +336,22 @@ def depthwise_conv2d_NCHWc(Input, Filter, stride, padding, dilation,
5-D with shape [batch, out_channel_chunk, out_height, out_width, out_channel_block]
"""
raise ValueError("missing register for topi.nn.depthwise_conv2d_NCHWc")
@tvm.target.generic_func
def depthwise_conv2d_infer_layout(workload, cfg):
"""Infer input/output shapes and layouts from a workload and cfg.
Parameters
----------
workload : tuple
conv2d workload
cfg : tuple
tvm.autotvm config
Returns
-------
Output : [tuple of tuple and str, tuple of tuple and str]
Input shapes and layouts, and output shapes and layouts
"""
raise ValueError("missing register for topi.nn.depthwise_conv2d_infer_layout")
......@@ -28,7 +28,7 @@ from .. import generic, tag
from .. import nn
from ..util import get_const_tuple
from ..nn.conv2d import conv2d, conv2d_NCHWc, \
    conv2d_alter_layout, _get_workload as _get_conv2d_workload
    conv2d_alter_layout, conv2d_infer_layout, _get_workload as _get_conv2d_workload
from ..nn.depthwise_conv2d import _get_workload as _get_depthwise_conv2d_workload
from ..nn.depthwise_conv2d import depthwise_conv2d_NCHWc, depthwise_conv2d_nchw
from ..nn.pad import pad
......@@ -475,6 +475,21 @@ def _alter_conv2d_layout(attrs, inputs, tinfo, F):
return F.nn.contrib_conv2d_nchwc(*copy_inputs, **new_attrs)
@conv2d_infer_layout.register("cpu")
def _conv2d_infer_layout(workload, cfg):
    _, data, kernel, strides, padding, dilation, layout, dtype = workload
    batch_size, in_channel, in_height, in_width = data[:-1]
    out_channel, _, k_height, k_width = kernel[:-1]
    out_height = (in_height + 2 * padding[0] - k_height) // strides[0] + 1
    out_width = (in_width + 2 * padding[1] - k_width) // strides[1] + 1
    tile_ic, tile_oc = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1]
    in_shape = (batch_size, in_channel // tile_ic, in_height, in_width, tile_ic)
    in_layout = "NCHW%dc" % tile_ic
    out_shape = (batch_size, out_channel // tile_oc, out_height, out_width, tile_oc)
    out_layout = "NCHW%dc" % tile_oc
    return ((in_shape, in_layout),), ((out_shape, out_layout),)
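
# Worked example (hedged, hypothetical numbers): for a workload with data shape
# (1, 64, 56, 56), kernel (128, 64, 3, 3), strides (1, 1), padding (1, 1), and a
# config selecting tile_ic = tile_oc = 16, the code above computes
#     out_height = (56 + 2 * 1 - 3) // 1 + 1 = 56
#     out_width  = (56 + 2 * 1 - 3) // 1 + 1 = 56
#     in_shape  = (1, 4, 56, 56, 16),  in_layout  = "NCHW16c"
#     out_shape = (1, 8, 56, 56, 16),  out_layout = "NCHW16c"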
@autotvm.register_topi_compute(conv2d_NCHWc, 'cpu', 'direct')
def _declaration_conv_NCHWc(cfg, data, kernel, strides,
padding, dilation, layout, out_layout, out_dtype):
......
......@@ -25,7 +25,8 @@ from .. import generic, tag
from ..nn.pad import pad
from ..util import get_const_tuple
from ..nn.util import get_pad_tuple
from ..nn.depthwise_conv2d import depthwise_conv2d_NCHWc, _get_workload
from ..nn.depthwise_conv2d import depthwise_conv2d_NCHWc, _get_workload, \
    depthwise_conv2d_infer_layout
from .util import get_fp32_len
......@@ -206,7 +207,7 @@ def _topi_nn_depthwise_conv2d_NCHWc(*args, **kwargs):
# change shape with the value in config
ic_bn, oc_bn = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1]
new_data_shape = (batch, in_channel // ic_bn, height, width, ic_bn)
new_kernel_shape = (out_channel // oc_bn, kh, kw, oc_bn)
new_kernel_shape = (out_channel // oc_bn, 1, kh, kw, 1, oc_bn)
new_data = tvm.placeholder(new_data_shape, data.dtype)
new_kernel = tvm.placeholder(new_kernel_shape, kernel.dtype)
......@@ -217,3 +218,18 @@ def _topi_nn_depthwise_conv2d_NCHWc(*args, **kwargs):
data_layout, out_layout, dtype)
s = schedule_depthwise_conv2d_NCHWc(cfg, [C])
return s, [new_data, new_kernel, C]
@depthwise_conv2d_infer_layout.register("cpu")
def _depthwise_conv2d_infer_layout(workload, cfg):
    _, data, kernel, strides, padding, dilation, dtype = workload
    batch_size, in_channel, in_height, in_width = data[:-1]
    filter_channel, channel_multiplier, k_height, k_width = kernel[:-1]
    out_channel = filter_channel * channel_multiplier
    out_height = (in_height + 2 * padding[0] - k_height) // strides[0] + 1
    out_width = (in_width + 2 * padding[1] - k_width) // strides[1] + 1
    tile_ic, tile_oc = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1]
    in_shape = (batch_size, in_channel // tile_ic, in_height, in_width, tile_ic)
    in_layout = "NCHW%dc" % tile_ic
    out_shape = (batch_size, out_channel // tile_oc, out_height, out_width, tile_oc)
    out_layout = "NCHW%dc" % tile_oc
    return ((in_shape, in_layout),), ((out_shape, out_layout),)
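
# Worked example (hedged, hypothetical numbers): for a depthwise workload with
# data (1, 32, 112, 112), kernel (32, 1, 3, 3), strides (1, 1), padding (1, 1),
# and tile_ic = tile_oc = 8, out_channel = 32 * 1 = 32 and
#     out_height = out_width = (112 + 2 * 1 - 3) // 1 + 1 = 112
#     in_shape  = (1, 4, 112, 112, 8),  in_layout  = "NCHW8c"
#     out_shape = (1, 4, 112, 112, 8),  out_layout = "NCHW8c"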
......@@ -30,6 +30,7 @@ from tvm import autotvm
from tvm import relay
from tvm.relay import testing
from tvm.autotvm.tuner import XGBTuner, GATuner, RandomTuner, GridSearchTuner
from tvm.autotvm.graph_tuner import DPTuner, PBQPTuner
import tvm.contrib.graph_runtime as runtime
#################################################################
......@@ -81,6 +82,7 @@ batch_size = 1
dtype = "float32"
model_name = "resnet-18"
log_file = "%s.log" % model_name
graph_opt_sch_file = "%s_graph_opt.log" % model_name
# Set number of threads used for tuning based on the number of
# physical CPU cores on your machine.
......@@ -157,6 +159,16 @@ def tune_kernels(tasks,
autotvm.callback.progress_bar(n_trial, prefix=prefix),
autotvm.callback.log_to_file(log_filename)])
# Use graph tuner to achieve graph level optimal schedules
# Set use_DP=False if it takes too long to finish.
def tune_graph(graph, dshape, records, opt_sch_file, use_DP=True):
    target_op = [relay.nn.conv2d]
    Tuner = DPTuner if use_DP else PBQPTuner
    executor = Tuner(graph, {"data": dshape}, records, target_op, target)
    executor.benchmark_layout_transform(min_exec_num=2000)
    executor.run()
    executor.write_opt_sch2record_file(opt_sch_file)
########################################################################
# Finally, we launch tuning jobs and evaluate the end-to-end performance.
......@@ -171,9 +183,10 @@ def tune_and_evaluate(tuning_opt):
    # run tuning tasks
    print("Tuning...")
    tune_kernels(tasks, **tuning_opt)
    tune_graph(net, data_shape, log_file, graph_opt_sch_file)

    # compile kernels with history best records
    with autotvm.apply_history_best(log_file):
    # compile kernels with graph-level best records
    with autotvm.apply_graph_best(graph_opt_sch_file):
        print("Compile...")
        with relay.build_config(opt_level=3):
            graph, lib, params = relay.build_module.build(
......