Commit f9788871 by Yida Wang, committed by Yizhi Liu

[TOPI] add basic scheduling for conv2d_transpose on x86 (#3491)

* initialize conv2d transpose scheduling on x86

* refine the scheduler a bit

* fix for lint

* address review comments; remove duplicate code

* fix lint
parent 59448fed
@@ -51,11 +51,15 @@ def conv2d_transpose_nchw(Input, Filter, strides, padding, out_dtype):
     Output : tvm.Tensor
         4-D with shape [batch, out_channel, out_height, out_width]
     """
-    batch, in_c, in_h, in_w = Input.shape
-    _, out_c, filter_h, filter_w = Filter.shape
+    return declaration_conv2d_transpose_impl(Input, Filter, strides, padding, out_dtype)
+
+def declaration_conv2d_transpose_impl(data, kernel, strides, padding, out_dtype):
+    """Implementation of conv2d transpose"""
+    batch, in_c, in_h, in_w = data.shape
+    _, out_c, filter_h, filter_w = kernel.shape
     stride_h, stride_w = strides
     # dilate stage
-    DilatedInput = dilate(Input, [1, 1, stride_h, stride_w], name='DilatedInput')
+    DilatedInput = dilate(data, [1, 1, stride_h, stride_w], name='DilatedInput')
     # padding stage
     fpad_top, fpad_left, fpad_bottom, fpad_right = get_pad_tuple(padding, (filter_h, filter_w))
     bpad_top = filter_h - 1 - fpad_top
@@ -78,7 +82,7 @@ def conv2d_transpose_nchw(Input, Filter, strides, padding, out_dtype):
         (batch, out_c, out_h, out_w),
         lambda b, c, h, w: tvm.sum(
             PaddedInput[b, dc, h+dh, w+dw].astype(out_dtype) *
-            Filter[dc, c, filter_h-1-dh, filter_w-1-dw].astype(out_dtype),
+            kernel[dc, c, filter_h-1-dh, filter_w-1-dw].astype(out_dtype),
             axis=[dc, dh, dw]), tag="conv2d_transpose_nchw")
     return Output
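The compute entry point is unchanged for callers; the refactor only moves the body into declaration_conv2d_transpose_impl so the x86 backend added below can reuse it. A minimal sketch of driving the shared implementation directly (not part of the commit; the placeholder shapes, strides and padding are illustrative assumptions):

    import tvm
    from topi.nn.conv2d_transpose import declaration_conv2d_transpose_impl

    # Illustrative NCHW input and IOHW weight placeholders (shapes are assumptions).
    data = tvm.placeholder((1, 16, 8, 8), name='data')
    kernel = tvm.placeholder((16, 32, 3, 3), name='kernel')

    # Builds the dilate -> pad -> convolution compute tagged 'conv2d_transpose_nchw'.
    out = declaration_conv2d_transpose_impl(data, kernel, (2, 2), (1, 1), 'float32')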
@@ -14,3 +14,4 @@ from .depthwise_conv2d import schedule_depthwise_conv2d_NCHWc
 from .dense import _schedule_dense, _schedule_dense_pack, _schedule_dense_nopack
 from .batch_matmul import schedule_batch_matmul
 from .roi_align import roi_align_nchw
+from .conv2d_transpose import schedule_conv2d_transpose
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name,unused-variable,unused-argument,no-member
"""Conv2D Transpose schedule on x86"""
import tvm
from tvm import autotvm
from .. import generic, tag
from ..nn.conv2d_transpose import conv2d_transpose_nchw, declaration_conv2d_transpose_impl


@autotvm.register_topi_compute(conv2d_transpose_nchw, 'cpu', ['direct'])
def _declaration_conv2d_transpose(cfg, data, kernel, strides, padding, out_dtype):
    # TODO cfg is not used for now
    return declaration_conv2d_transpose_impl(data, kernel, strides, padding, out_dtype)


@autotvm.register_topi_schedule(generic.schedule_conv2d_transpose_nchw, 'cpu', ['direct'])
def schedule_conv2d_transpose(cfg, outs):
    """Create schedule for tensors"""
    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
    s = tvm.create_schedule([x.op for x in outs])
    scheduled_ops = []

    def traverse(op):
        """Traverse operators from computation graph"""
        # inline all one-to-one-mapping operators except the last stage (output)
        if tag.is_injective(op.tag):
            if op not in s.outputs:
                s[op].compute_inline()
            for tensor in op.input_tensors:
                if tensor.op.input_tensors and tensor.op not in scheduled_ops:
                    traverse(tensor.op)

        if 'conv2d_transpose_nchw' in op.tag:
            C = op.output(0)

            N, OC, OH, OW = C.op.axis
            rc, ry, rx = C.op.reduce_axis
            # tile the output rows, the output channels and the reduction channel axis
            OH, oh = s[C].split(OH, factor=2)
            OC, oc = s[C].split(OC, factor=32)
            IC, ic = s[C].split(rc, factor=32)
            s[C].reorder(N, OC, OH, OW, oc, IC, ry, rx, ic)
            # fuse batch with the outer output-channel loop, run the fused loop in
            # parallel, and vectorize over the inner output-channel block
            N = s[C].fuse(N, OC)
            s[C].vectorize(oc)
            s[C].parallel(N)

        scheduled_ops.append(op)

    traverse(outs[0].op)
    return s
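As a rough end-to-end sketch (not part of the commit), under an LLVM CPU target the generic dispatcher should pick up this compute/schedule pair; the shapes, strides, and padding below are illustrative assumptions, and autotvm falls back to a default config since cfg is unused:

    import tvm
    import topi

    data = tvm.placeholder((1, 16, 8, 8), name='data')
    kernel = tvm.placeholder((16, 32, 3, 3), name='kernel')

    with tvm.target.create('llvm'):
        # Should dispatch to _declaration_conv2d_transpose / schedule_conv2d_transpose above.
        out = topi.nn.conv2d_transpose_nchw(data, kernel, (2, 2), (1, 1), 'float32')
        s = topi.generic.schedule_conv2d_transpose_nchw([out])
        func = tvm.build(s, [data, kernel, out], target='llvm')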