Commit c8a0f524 by Yao Wang Committed by Yizhi Liu

[AutoTVM]Core functionality for Graph tuner (#2184)

* Add graph tuning

* Add tests

* Fix tests

* Fix pylint

* Small fix for docstring

* Minor fix

* Support fetching workload from relay expr

* Simplify benchmark layout transformation

* Add relay support

* Fix infer layout func name

* Refactor internal data representation

* Fix issues

* Add PBQP solver

* Fix layout transform check

* Add PBQPTuner test

* Fix lint

* Update tutorial

* Fix tutorial

* Fix lint

* Add relay test

* Remove nnvm since nnvm graph can be converted to relay function

* Modify benchmark layout wrt new layout_transform api

* Fix lint

* Update docstring for DP tuner

* Refactor traverse graph

* Support graph tuning for multiple target operators

* Fix fetching workloads

* Add x86 depthwise_conv2d infer_layout

* Fix x86 depthwise_conv2d autotvm

* Fix PBQP tuner

* Fix DP tuner

* Generate dummy layout transform record

* Update tutorial

* Modify layout records name

* Add ASF header

* Add ASF header for testing files

* Fix test

* Fix topi fetching

* Some refactors

* Fix lint

* Fix tutorial

* Rename test files

* Fix doc typo

* Add test case note link
parent 4767554c
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Autotvm graph tuner API."""
from __future__ import absolute_import as _abs
from . import _base
from . import base_graph_tuner
from .base_graph_tuner import BaseGraphTuner
from .dynamic_programming_tuner import DPTuner
from .pbqp_tuner import PBQPTuner
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
"""Helper functions and global data"""
RULE_OUT_NODE_NAMES = ["Tuple", "TupleGetItem", "batch_flatten", "transpose", "reshape",
"multibox_prior", "multibox_transform_loc", "where",
"non_max_suppression", "strided_slice"]
# We set a large time to represent an invalid layout-transformation.
# This number is set to be 10e9 seconds to align with autotvm.
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=import-error,too-many-locals,too-many-statements,too-many-branches,unused-variable
"""Dynamic programming tuner."""
import sys
import numpy as np
from .base_graph_tuner import BaseGraphTuner
from .dynamic_programming_stage import DPStage
from .utils import has_multiple_inputs, is_input_node
# The FIFO queue module is named ``queue`` on Python 3 and ``Queue`` on
# Python 2; bind whichever exists to the single name ``queue``.
if sys.version_info[0] >= 3:
    import queue
else:
    import Queue as queue
class DPTuner(BaseGraphTuner):
"""Tuner which uses dynamic programming to solve MDP problem.
Note: currently dynamic programming is used to solve this MDP problem. However,
this problem is intrinsically non-polynomial. DP can't apply for more complicated
models, such as networks with many element-wise sum operators. In this case, switch
to heuristic algorithm such as PBQP tuner.
def __init__(self, *args, **kwargs):
    """Create a dynamic programming tuner.

    All positional and keyword arguments are forwarded unchanged to the
    parent class constructor.
    """
    super(DPTuner, self).__init__(*args, **kwargs)
    # Both counters start as None; ``run`` assigns their working values.
    self._num_states = self._max_num_states = None
    self._stage_dict = {}
    self._dep_dict = {}
    self._counted_nodes_set = set()

    # Bundle of references (not copies) to the tuner's bookkeeping
    # containers, so updates made through either name are shared.
    self._global_data_dict = {
        "dtype": self._dtype,
        "counted_nodes_set": self._counted_nodes_set,
        "stage_dict": self._stage_dict,
        "in_nodes_dict": self._in_nodes_dict,
        "out_nodes_dict": self._out_nodes_dict,
        "dep_dict": self._dep_dict,
        "node_list": self._node_list,
        "input_shapes": self._input_shapes,
        "layout_transform_interlayer_cost": self._layout_transform_interlayer_cost
    }
def _check_num_states(self, num_states):
    """Track the number of states.

    Adds ``num_states`` to the running total and, when an upper limit is
    set, checks that the total does not exceed it.

    Parameters
    ----------
    num_states : int
        Number of states to add to the running total.

    Raises
    ------
    RuntimeError
        If an upper limit is set and the running total exceeds it.
    """
    self._num_states += num_states
    # A limit of None disables the check entirely.
    if self._max_num_states is not None and self._num_states > self._max_num_states:
        raise RuntimeError("Too many states detected while running dynamic "
                           "programming: got %d states but upper limit is %d." %
                           (self._num_states, self._max_num_states))
def _forward(self):
"""Forward pass in DP to generate states for all stages.
""""Start forward pass...")
for node_idx in sorted(self._in_nodes_dict.keys()):
stage = DPStage(idx=node_idx, target_ops=self._target_ops,
self._stage_dict[node_idx] = stage"Finished forward pass.")
def _backward(self):
"""Backward pass in DP to generate optimal solution.
""""Start backward pass...")
input_names = self._input_shapes.keys()
optimal_record_dict = {}
# Pick optimal schedule for output nodes
output_idx_list = []
for key, val in self._out_nodes_dict.items():
if not val:
states_list, aligned_node_list = DPStage.align_states(output_idx_list, self._stage_dict,
num_states = states_list[0][3].size
self._check_num_states(num_states * len(output_idx_list))
aligned_node_shape = states_list[0][3].shape
min_time = 0
min_pos = -1
for states in states_list:
min_time += np.amax(states[3])
flatten_states_list = [current_states[3].flatten() for current_states in states_list]
for i in range(num_states):
current_time = 0
for j, current_states in enumerate(states_list):
current_time += flatten_states_list[j][i]
if min_time > current_time:
min_time = current_time
min_pos = i
for i, states in enumerate(states_list):
current_major_axis = states[1]
current_sch_idx = (min_pos % (states[2] *
aligned_node_shape[current_major_axis])) // states[2]
optimal_record_dict[aligned_node_list[i]] = current_sch_idx
# Pick optimal schedule for dependencies of output nodes
for i in range(len(states_list), len(aligned_node_list)):
multiplier = 1
for j in range(i + 1, len(aligned_node_list)):
multiplier *= aligned_node_shape[j]
optimal_record_dict[aligned_node_list[i]] = \
min_pos // multiplier % aligned_node_shape[i]
# Backward pass to get optimal schedules for other nodes
bfs_q = queue.Queue()
visited = set()
for out_idx in output_idx_list:
while not bfs_q.empty():
node_idx = bfs_q.get()
if is_input_node(self._node_list[node_idx], input_names):
optimal_sch_idx = optimal_record_dict[node_idx]
full_states = self._stage_dict[node_idx].full_states
if not has_multiple_inputs(self._node_list, node_idx, input_names):
input_idx = self._in_nodes_dict[node_idx][0]
if is_input_node(self._node_list[input_idx], input_names):
if input_idx not in visited:
if input_idx not in optimal_record_dict:
dep_list = self._stage_dict[node_idx].dep
dep_idx = tuple([optimal_record_dict[item] for item in dep_list])
tmp = np.argmin(full_states, axis=1)
optimal_input_sch_idx = tmp[(optimal_sch_idx,) + dep_idx]
optimal_record_dict[input_idx] = optimal_input_sch_idx
input_idx_list = self._in_nodes_dict[node_idx]
optimal_record_dict[input_idx_list[0]] = optimal_sch_idx
full_states_idx = self._stage_dict[node_idx].full_states_idx
tmp = full_states[optimal_sch_idx]
new_states_idx, new_states_pos = [], []
visited_states_idx, visited_states_pos = [], []
for i in range(1, len(full_states_idx)):
if full_states_idx[i] in optimal_record_dict:
visited_states_pos.append(i - 1)
new_states_pos.append(i - 1)
if visited_states_idx:
tmp = np.transpose(tmp, tuple(visited_states_pos + new_states_pos))
tmp = tmp[tuple([optimal_record_dict[idx] for idx in visited_states_idx])]
min_pos = np.argmin(tmp)
multiplier = 1
for i in range(len(new_states_idx)):
multiplier *= full_states.shape[new_states_pos[i] + 1]
for pos, idx in zip(new_states_pos, new_states_idx):
multiplier //= full_states.shape[pos + 1]
optimal_record_dict[idx] = min_pos // multiplier
min_pos %= multiplier
for input_idx in input_idx_list:
if input_idx not in visited:
self._optimal_record_dict = optimal_record_dict
for node_idx, _ in self._in_nodes_dict.items():
if self._node_list[node_idx]["op"] not in self._target_ops:
continue"Finished backward pass...")
def run(self, **kwargs):
"""Run dynamic programming solver.
max_num_states = None if "max_num_states" not in kwargs else kwargs["max_num_states"]
self._num_states = 0
self._max_num_states = max_num_states"Start to run dynamic programming algorithm...")
self._backward()"Finished DPExecutor run.")
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=wildcard-import
"""Graph tuner utility functions"""
from __future__ import absolute_import
from . import traverse_graph
from . import utils
from .traverse_graph import expr2graph, get_direct_ancestor, get_in_nodes, \
from .utils import has_multiple_inputs, is_input_node, bind_inputs
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=eval-used,invalid-name,too-many-arguments
"""Utility functions"""
from tvm import relay
def has_multiple_inputs(node_list, node_idx, input_names):
    """Check whether a node has multiple input nodes
    except variable nodes.

    Parameters
    ----------
    node_list : list of dict of str to object
        List of all nodes in a graph.

    node_idx : int
        Node index to be checked.

    input_names : list of str
        List of input names of graph.

    Returns
    -------
    out : bool
        Whether the specified node has multiple input nodes
    """
    num_inputs = 0
    node = node_list[node_idx]
    for in_entry in node["inputs"]:
        # The first element of each input entry is the producer's node index.
        in_node = node_list[in_entry[0]]
        # Exclude parameter nodes: a "null" op only counts when it is one
        # of the graph's named inputs.
        if in_node["op"] != "null" or is_input_node(in_node, input_names):
            num_inputs += 1
    return num_inputs > 1
def is_input_node(node_entry, input_names):
    """Whether a node is an input node.

    Parameters
    ----------
    node_entry : dict
        Node entry.

    input_names : list of str
        List of input names of graph.

    Returns
    -------
    out : bool
        whether node is a input node.
    """
    # Entries without a "name" key can never be graph inputs.
    return "name" in node_entry and node_entry["name"] in input_names
def bind_inputs(expr, input_shapes=None, input_dtypes="float32"):
    """Bind input variables of a relay function expression
    to new shapes and/or dtypes.

    Parameters
    ----------
    expr : tvm.relay.Expr.Function
        Input relay function expression.

    input_shapes : dict of str to tuple of int, optional
        Input shapes.

    input_dtypes : str or dict of str to str, optional
        Input dtypes.

    Returns
    -------
    out : tvm.relay.Expr.Function
        Bind relay function expression.
    """
    # Nothing to rebind: hand the expression back untouched.
    if input_shapes is None:
        return expr
    # A single dtype string applies to every input named in input_shapes.
    if isinstance(input_dtypes, str):
        input_dtypes = {key: input_dtypes for key in input_shapes}

    updated_input_dict = {}
    for input_name, input_shape in input_shapes.items():
        updated_input_dict[input_name] = relay.var(input_name, shape=input_shape,
                                                   dtype=input_dtypes[input_name])

    # Map each existing parameter to its replacement, matched by name.
    rebind_dict = {}
    for var in expr.params:
        if var.name_hint in updated_input_dict:
            rebind_dict[var] = updated_input_dict[var.name_hint]
    updated_expr = relay.expr.bind(expr, rebind_dict)

    return relay.ir_pass.infer_type(updated_expr)
...@@ -28,6 +28,7 @@ from .code_hash import attach_code_hash, attach_code_hash_to_arg ...@@ -28,6 +28,7 @@ from .code_hash import attach_code_hash, attach_code_hash_to_arg
from .dispatcher import dispatcher, DispatchContext, ApplyConfig, ApplyHistoryBest, \ from .dispatcher import dispatcher, DispatchContext, ApplyConfig, ApplyHistoryBest, \
FallbackContext, clear_fallback_cache, ApplyGraphBest FallbackContext, clear_fallback_cache, ApplyGraphBest
from .topi_integration import register_topi_compute, register_topi_schedule from .topi_integration import register_topi_compute, register_topi_schedule, \
from .nnvm_integration import extract_from_graph, extract_from_multiple_graph from .nnvm_integration import extract_from_graph, extract_from_multiple_graph
from .relay_integration import extract_from_program, extract_from_multiple_program from .relay_integration import extract_from_program, extract_from_multiple_program
...@@ -74,7 +74,7 @@ class TaskExtractEnv: ...@@ -74,7 +74,7 @@ class TaskExtractEnv:
"""Global environment for extracting tuning tasks from nnvm graph""" """Global environment for extracting tuning tasks from nnvm graph"""
current = None current = None
def __init__(self): def __init__(self, allow_duplicate=False):
import topi import topi
# topi compute -> autotvm task name # topi compute -> autotvm task name
...@@ -106,6 +106,7 @@ class TaskExtractEnv: ...@@ -106,6 +106,7 @@ class TaskExtractEnv:
topi.nn.deformable_conv2d_nchw: [topi.generic.schedule_deformable_conv2d_nchw], topi.nn.deformable_conv2d_nchw: [topi.generic.schedule_deformable_conv2d_nchw],
} }
self.allow_duplicate = allow_duplicate
self._register_tracing() self._register_tracing()
self._register_topi_task() self._register_topi_task()
self.task_collection = [] self.task_collection = []
...@@ -123,10 +124,9 @@ class TaskExtractEnv: ...@@ -123,10 +124,9 @@ class TaskExtractEnv:
assert not kwargs, "Do not support extracting tuning tasks when" \ assert not kwargs, "Do not support extracting tuning tasks when" \
"kwargs is used in TOPI function call." \ "kwargs is used in TOPI function call." \
"Please modify it to use only positional args." "Please modify it to use only positional args."
if compute_func in self.wanted_topi_funcs: # record this call if compute_func in self.wanted_topi_funcs: # record this call
key = (self.topi_to_task[compute_func], serialize_args(args)) key = (self.topi_to_task[compute_func], serialize_args(args))
if key not in self.task_collection: if self.allow_duplicate or key not in self.task_collection:
self.task_collection.append(key) self.task_collection.append(key)
return compute_func.fdefault(*args) return compute_func.fdefault(*args)
_local_scope(topi_compute) _local_scope(topi_compute)
...@@ -262,16 +262,25 @@ class TaskExtractEnv: ...@@ -262,16 +262,25 @@ class TaskExtractEnv:
return self.task_collection return self.task_collection
@staticmethod @staticmethod
def get(): def get(allow_duplicate=False):
"""Get the single instance of TaskExtractEnv """Get the single instance of TaskExtractEnv
allow_duplicate : boolean
Whether to fetch all workloads in the network,
even though some of them are the same. This is
useful for graph tuning.
Returns Returns
------- -------
env: TaskExtractEnv env: TaskExtractEnv
The single instance of TaskExtractEnv The single instance of TaskExtractEnv
""" """
if not TaskExtractEnv.current: if not TaskExtractEnv.current:
TaskExtractEnv.current = TaskExtractEnv() TaskExtractEnv.current = TaskExtractEnv(allow_duplicate)
TaskExtractEnv.current.allow_duplicate = allow_duplicate
return TaskExtractEnv.current return TaskExtractEnv.current
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# NOTE: We name this test file to start with test_graph_tuner
# to make it execute after zero_rank tensor test cases. This
# helps avoid a topi arithmetic operator overloading issue
# (the link to that issue is missing from this copy).
# TODO: restore the file name after this issue is resolved.
import tvm
from tvm import autotvm, relay
from tvm.relay.testing import resnet
from tvm.autotvm.graph_tuner.utils import has_multiple_inputs, get_direct_ancestor, get_in_nodes, \
get_out_nodes, expr2graph, bind_inputs
from tvm.relay.expr import Call, TupleGetItem, Tuple
from topi.nn.conv2d import conv2d
def create_workload(dshape, kshape, strides,
                    padding, dilation, layout,
                    out_layout, dtype, out_dtype):
    """Build an autotvm conv2d workload from shapes and attributes.

    Placeholders of dtype ``dtype`` are created for the data and kernel
    shapes. ``out_layout`` is accepted but is not included in the
    argument list passed to ``args_to_workload``.
    """
    data = tvm.placeholder(dshape, dtype=dtype)
    kernel = tvm.placeholder(kshape, dtype=dtype)
    workload_args = [data, kernel, strides, padding, dilation, layout, out_dtype]
    return autotvm.task.args_to_workload(workload_args, conv2d)
def verify_has_multiple_inputs(node_list, node_idx, input_names, expected_result):
    """Assert that has_multiple_inputs returns ``expected_result`` for a node."""
    out = has_multiple_inputs(node_list, node_idx, input_names)
    assert out == expected_result, "Output mismatch: expecting checking %s to be %s but got %s." \
                                   % (node_list[node_idx]["op"], str(expected_result), str(out))
def test_has_multiple_inputs():
    """Check has_multiple_inputs on a small graph: add(data * 3.0, conv2d(data, w0))."""
    data = relay.var("data")
    out1 = data * relay.expr.const(3.0)
    w0 = relay.var("w0")
    out2 = relay.nn.conv2d(data, w0)
    out = relay.add(out1, out2)
    net = relay.Function(relay.ir_pass.free_vars(out), out)
    net = bind_inputs(net, {"data": (1, 16, 224, 224), "w0": (16, 16, 1, 1)})
    target_ops = ["conv2d"]
    node_list = []
    node_dict = {}
    expr2graph(net, target_ops, node_dict, node_list)
    input_names = ["data"]
    # Node indices refer to positions in the node_list filled by expr2graph.
    verify_has_multiple_inputs(node_list, 2, input_names, False)
    verify_has_multiple_inputs(node_list, 4, input_names, False)
    verify_has_multiple_inputs(node_list, 5, input_names, True)
def test_expr2graph():
net, _ = resnet.get_workload(num_layers=50, batch_size=1)
node_dict = {}
node_list = []
target_ops = ["conv2d"]
op_name_list = []
def _count_node(node):
if not isinstance(node, relay.op.op.Op,):
if isinstance(node, Call):
elif isinstance(node, TupleGetItem):
elif isinstance(node, Tuple):
relay.ir_pass.post_order_visit(net, _count_node)
expr2graph(net, target_ops, node_dict, node_list)
for i, item in enumerate(zip(op_name_list, node_list)):
op_name, node = item
assert op_name == node["op"], "%dth Node operator mismatch: expecting %s but got %s" \
% (i, str(op_name), str(node["op"]))
def test_get_direct_ancestor():
    """Check get_direct_ancestor for the second conv2d of a two-conv2d graph."""
    data = relay.var("data")
    w0 = relay.var("w0")
    out1 = relay.nn.conv2d(data, w0)
    out2 = relay.add(out1, data * relay.expr.const(5.0))
    out3 = out2 + relay.expr.const(2.5)
    w1 = relay.var("w1")
    out = relay.nn.conv2d(out3, w1)
    net = relay.Function(relay.ir_pass.free_vars(out), out)
    net = bind_inputs(net, {"data": (1, 16, 224, 224), "w0": (16, 16, 1, 1), "w1": (16, 16, 1, 1)})
    target_ops = ["conv2d"]
    node_list = []
    node_dict = {}
    expr2graph(net, target_ops, node_dict, node_list)
    visited_dict = {}
    input_names = ["data"]
    out = get_direct_ancestor(node_list, visited_dict, target_ops, 5, input_names)
    assert out == [2, 0], "Output mismatch: expecting [2, 0] but got %s." % str(out)
def test_get_in_nodes():
    """Check the keys of the dict returned by get_in_nodes on a two-conv2d graph."""
    data = relay.var("data")
    w0 = relay.var("w0")
    out1 = relay.nn.conv2d(data, w0)
    out2 = relay.add(out1, data)
    out3 = out2 + relay.expr.const(2.5)
    w1 = relay.var("w1")
    out = relay.nn.conv2d(out3, w1)
    net = relay.Function(relay.ir_pass.free_vars(out), out)
    net = bind_inputs(net, {"data": (1, 16, 224, 224), "w0": (16, 16, 1, 1), "w1": (16, 16, 1, 1)})
    target_ops = ["conv2d"]
    input_names = ["data"]
    node_list = []
    node_dict = {}
    expr2graph(net, target_ops, node_dict, node_list)
    out = get_in_nodes(node_list, target_ops, input_names)
    expected_out = {7: [3], 3: [2, 0], 2: [0]}
    # Symmetric difference of the two key sets; the value lists are not compared.
    diff_set = set(out) ^ set(expected_out)
    if diff_set:
        raise RuntimeError("Output mismatch: expecting %s but got %s." % (str(expected_out), str(out)))
def test_get_out_nodes():
    """Check the keys of the dict returned by get_out_nodes for a fixed input dict."""
    in_nodes_dict = {8: [4], 4: [3, 0], 3: [0]}
    expected_out = {0: [3, 4], 3: [4], 4: [8], 8: []}
    out = get_out_nodes(in_nodes_dict)
    # Symmetric difference of the two key sets; the value lists are not compared.
    diff_set = set(out) ^ set(expected_out)
    if diff_set:
        raise RuntimeError("Output mismatch: expecting %s but got %s." % (str(expected_out), str(out)))
if __name__ == "__main__":
...@@ -94,6 +94,26 @@ def conv2d_alter_layout(attrs, inputs, tinfos, F): ...@@ -94,6 +94,26 @@ def conv2d_alter_layout(attrs, inputs, tinfos, F):
# not to change by default # not to change by default
return None return None
def conv2d_infer_layout(workload, cfg):
    """Infer input/output shapes and layouts from a workload and cfg.

    This default body always raises.
    NOTE(review): the error message implies target-specific versions are
    registered elsewhere; no registration decorator is visible here - confirm.

    Parameters
    ----------
    workload : tuple
        conv2d workload

    cfg : tuple
        tvm.autotvm config

    Returns
    -------
    Output : [tuple of tuple and str, tuple of tuple and str]
        Input shapes and layouts, and output shapes and layouts

    Raises
    ------
    ValueError
        Always, by this default body.
    """
    raise ValueError("missing register for topi.nn.conv2d_infer_layout")
def _get_workload(data, kernel, stride, padding, out_dtype, data_layout='NCHW'): def _get_workload(data, kernel, stride, padding, out_dtype, data_layout='NCHW'):
""" Get the workload structure. """ """ Get the workload structure. """
...@@ -336,3 +336,22 @@ def depthwise_conv2d_NCHWc(Input, Filter, stride, padding, dilation, ...@@ -336,3 +336,22 @@ def depthwise_conv2d_NCHWc(Input, Filter, stride, padding, dilation,
5-D with shape [batch, out_channel_chunk, out_height, out_width, out_channel_block] 5-D with shape [batch, out_channel_chunk, out_height, out_width, out_channel_block]
""" """
raise ValueError("missing register for topi.nn.depthwise_conv2d_NCHWc") raise ValueError("missing register for topi.nn.depthwise_conv2d_NCHWc")
def depthwise_conv2d_infer_layout(workload, cfg):
    """Infer input/output shapes and layouts from a workload and cfg.

    This default body always raises.
    NOTE(review): the error message implies target-specific versions are
    registered elsewhere; no registration decorator is visible here - confirm.

    Parameters
    ----------
    workload : tuple
        depthwise conv2d workload

    cfg : tuple
        tvm.autotvm config

    Returns
    -------
    Output : [tuple of tuple and str, tuple of tuple and str]
        Input shapes and layouts, and output shapes and layouts

    Raises
    ------
    ValueError
        Always, by this default body.
    """
    raise ValueError("missing register for topi.nn.depthwise_conv2d_infer_layout")
...@@ -28,7 +28,7 @@ from .. import generic, tag ...@@ -28,7 +28,7 @@ from .. import generic, tag
from .. import nn from .. import nn
from ..util import get_const_tuple from ..util import get_const_tuple
from ..nn.conv2d import conv2d, conv2d_NCHWc, \ from ..nn.conv2d import conv2d, conv2d_NCHWc, \
conv2d_alter_layout, _get_workload as _get_conv2d_workload conv2d_alter_layout, conv2d_infer_layout, _get_workload as _get_conv2d_workload
from ..nn.depthwise_conv2d import _get_workload as _get_depthwise_conv2d_workload from ..nn.depthwise_conv2d import _get_workload as _get_depthwise_conv2d_workload
from ..nn.depthwise_conv2d import depthwise_conv2d_NCHWc, depthwise_conv2d_nchw from ..nn.depthwise_conv2d import depthwise_conv2d_NCHWc, depthwise_conv2d_nchw
from ..nn.pad import pad from ..nn.pad import pad
...@@ -475,6 +475,21 @@ def _alter_conv2d_layout(attrs, inputs, tinfo, F): ...@@ -475,6 +475,21 @@ def _alter_conv2d_layout(attrs, inputs, tinfo, F):
return F.nn.contrib_conv2d_nchwc(*copy_inputs, **new_attrs) return F.nn.contrib_conv2d_nchwc(*copy_inputs, **new_attrs)
def _conv2d_infer_layout(workload, cfg):
    """Infer blocked input/output shapes and layouts for a conv2d workload.

    Parameters
    ----------
    workload : tuple
        Eight-entry workload. The first entry is ignored; the rest are
        data, kernel, strides, padding, dilation, layout and dtype. The
        data and kernel entries hold four dimensions followed by one
        trailing element, which is dropped.

    cfg : mapping
        Config whose "tile_ic" and "tile_oc" entries expose a ``size``
        sequence; the last element of each is the channel block size.

    Returns
    -------
    Output : [tuple of tuple and str, tuple of tuple and str]
        ``((in_shape, in_layout),), ((out_shape, out_layout),)``
    """
    _, data, kernel, strides, padding, dilation, _, _ = workload
    batch_size, in_channel, in_height, in_width = data[:-1]
    out_channel, _, k_height, k_width = kernel[:-1]

    # Dilation may be a single int or a (height, width) pair.
    if isinstance(dilation, (tuple, list)):
        dilation_h, dilation_w = dilation
    else:
        dilation_h = dilation_w = dilation
    # A dilated kernel spans (k - 1) * dilation + 1 input positions; with
    # dilation 1 this is the plain kernel size.
    dilated_k_height = (k_height - 1) * dilation_h + 1
    dilated_k_width = (k_width - 1) * dilation_w + 1

    out_height = (in_height + 2 * padding[0] - dilated_k_height) // strides[0] + 1
    out_width = (in_width + 2 * padding[1] - dilated_k_width) // strides[1] + 1

    tile_ic, tile_oc = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1]
    in_shape = (batch_size, in_channel // tile_ic, in_height, in_width, tile_ic)
    in_layout = "NCHW%dc" % tile_ic
    out_shape = (batch_size, out_channel // tile_oc, out_height, out_width, tile_oc)
    out_layout = "NCHW%dc" % tile_oc
    return ((in_shape, in_layout),), ((out_shape, out_layout),)
@autotvm.register_topi_compute(conv2d_NCHWc, 'cpu', 'direct') @autotvm.register_topi_compute(conv2d_NCHWc, 'cpu', 'direct')
def _declaration_conv_NCHWc(cfg, data, kernel, strides, def _declaration_conv_NCHWc(cfg, data, kernel, strides,
padding, dilation, layout, out_layout, out_dtype): padding, dilation, layout, out_layout, out_dtype):
...@@ -25,7 +25,8 @@ from .. import generic, tag ...@@ -25,7 +25,8 @@ from .. import generic, tag
from ..nn.pad import pad from ..nn.pad import pad
from ..util import get_const_tuple from ..util import get_const_tuple
from ..nn.util import get_pad_tuple from ..nn.util import get_pad_tuple
from ..nn.depthwise_conv2d import depthwise_conv2d_NCHWc, _get_workload from ..nn.depthwise_conv2d import depthwise_conv2d_NCHWc, _get_workload, \
from .util import get_fp32_len from .util import get_fp32_len
...@@ -206,7 +207,7 @@ def _topi_nn_depthwise_conv2d_NCHWc(*args, **kwargs): ...@@ -206,7 +207,7 @@ def _topi_nn_depthwise_conv2d_NCHWc(*args, **kwargs):
# change shape with the value in config # change shape with the value in config
ic_bn, oc_bn = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1] ic_bn, oc_bn = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1]
new_data_shape = (batch, in_channel // ic_bn, height, width, ic_bn) new_data_shape = (batch, in_channel // ic_bn, height, width, ic_bn)
new_kernel_shape = (out_channel // oc_bn, kh, kw, oc_bn) new_kernel_shape = (out_channel // oc_bn, 1, kh, kw, 1, oc_bn)
new_data = tvm.placeholder(new_data_shape, data.dtype) new_data = tvm.placeholder(new_data_shape, data.dtype)
new_kernel = tvm.placeholder(new_kernel_shape, kernel.dtype) new_kernel = tvm.placeholder(new_kernel_shape, kernel.dtype)
...@@ -217,3 +218,18 @@ def _topi_nn_depthwise_conv2d_NCHWc(*args, **kwargs): ...@@ -217,3 +218,18 @@ def _topi_nn_depthwise_conv2d_NCHWc(*args, **kwargs):
data_layout, out_layout, dtype) data_layout, out_layout, dtype)
s = schedule_depthwise_conv2d_NCHWc(cfg, [C]) s = schedule_depthwise_conv2d_NCHWc(cfg, [C])
return s, [new_data, new_kernel, C] return s, [new_data, new_kernel, C]
def _depthwise_conv2d_infer_layout(workload, cfg):
    """Infer blocked input/output shapes and layouts for a depthwise conv2d workload.

    Parameters
    ----------
    workload : tuple
        Seven-entry workload. The first entry is ignored; the rest are
        data, kernel, strides, padding, dilation and dtype. The data and
        kernel entries hold four dimensions followed by one trailing
        element, which is dropped.

    cfg : mapping
        Config whose "tile_ic" and "tile_oc" entries expose a ``size``
        sequence; the last element of each is the channel block size.

    Returns
    -------
    Output : [tuple of tuple and str, tuple of tuple and str]
        ``((in_shape, in_layout),), ((out_shape, out_layout),)``
    """
    _, data, kernel, strides, padding, dilation, _ = workload
    batch_size, in_channel, in_height, in_width = data[:-1]
    filter_channel, channel_multiplier, k_height, k_width = kernel[:-1]
    out_channel = filter_channel * channel_multiplier

    # Dilation may be a single int or a (height, width) pair.
    if isinstance(dilation, (tuple, list)):
        dilation_h, dilation_w = dilation
    else:
        dilation_h = dilation_w = dilation
    # A dilated kernel spans (k - 1) * dilation + 1 input positions; with
    # dilation 1 this is the plain kernel size.
    dilated_k_height = (k_height - 1) * dilation_h + 1
    dilated_k_width = (k_width - 1) * dilation_w + 1

    out_height = (in_height + 2 * padding[0] - dilated_k_height) // strides[0] + 1
    out_width = (in_width + 2 * padding[1] - dilated_k_width) // strides[1] + 1

    tile_ic, tile_oc = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1]
    in_shape = (batch_size, in_channel // tile_ic, in_height, in_width, tile_ic)
    in_layout = "NCHW%dc" % tile_ic
    out_shape = (batch_size, out_channel // tile_oc, out_height, out_width, tile_oc)
    out_layout = "NCHW%dc" % tile_oc
    return ((in_shape, in_layout),), ((out_shape, out_layout),)
...@@ -30,6 +30,7 @@ from tvm import autotvm ...@@ -30,6 +30,7 @@ from tvm import autotvm
from tvm import relay from tvm import relay
from tvm.relay import testing from tvm.relay import testing
from tvm.autotvm.tuner import XGBTuner, GATuner, RandomTuner, GridSearchTuner from tvm.autotvm.tuner import XGBTuner, GATuner, RandomTuner, GridSearchTuner
from tvm.autotvm.graph_tuner import DPTuner, PBQPTuner
import tvm.contrib.graph_runtime as runtime import tvm.contrib.graph_runtime as runtime
################################################################# #################################################################
...@@ -81,6 +82,7 @@ batch_size = 1 ...@@ -81,6 +82,7 @@ batch_size = 1
dtype = "float32" dtype = "float32"
model_name = "resnet-18" model_name = "resnet-18"
log_file = "%s.log" % model_name log_file = "%s.log" % model_name
graph_opt_sch_file = "%s_graph_opt.log" % model_name
# Set number of threads used for tuning based on the number of # Set number of threads used for tuning based on the number of
# physical CPU cores on your machine. # physical CPU cores on your machine.
...@@ -157,6 +159,16 @@ def tune_kernels(tasks, ...@@ -157,6 +159,16 @@ def tune_kernels(tasks,
autotvm.callback.progress_bar(n_trial, prefix=prefix), autotvm.callback.progress_bar(n_trial, prefix=prefix),
autotvm.callback.log_to_file(log_filename)]) autotvm.callback.log_to_file(log_filename)])
# Use graph tuner to achieve graph level optimal schedules
# Set use_DP=False if it takes too long to finish.
def tune_graph(graph, dshape, records, opt_sch_file, use_DP=True):
target_op = [relay.nn.conv2d]
Tuner = DPTuner if use_DP else PBQPTuner
executor = Tuner(graph, {"data": dshape}, records, target_op, target)
######################################################################## ########################################################################
# Finally, we launch tuning jobs and evaluate the end-to-end performance. # Finally, we launch tuning jobs and evaluate the end-to-end performance.
...@@ -171,9 +183,10 @@ def tune_and_evaluate(tuning_opt): ...@@ -171,9 +183,10 @@ def tune_and_evaluate(tuning_opt):
# run tuning tasks # run tuning tasks
print("Tuning...") print("Tuning...")
tune_kernels(tasks, **tuning_opt) tune_kernels(tasks, **tuning_opt)
tune_graph(net, data_shape, log_file, graph_opt_sch_file)
# compile kernels with history best records # compile kernels with graph-level best records
with autotvm.apply_history_best(log_file): with autotvm.apply_graph_best(graph_opt_sch_file):
print("Compile...") print("Compile...")
with relay.build_config(opt_level=3): with relay.build_config(opt_level=3):
graph, lib, params = graph, lib, params =
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment