Commit d7c600b8 by Tianqi Chen Committed by GitHub

[TOPI] Update parameter name of conv2d (#1380)

parent 30f09bea
......@@ -9,8 +9,8 @@ These are utility functions used for testing and tutorial file.
from __future__ import division
import math
import numpy as np
import cv2
from cffi import FFI
import cv2
def _resize_image(img, w_in, h_in):
"""Resize the image to the given height and width."""
......
# pylint: disable=invalid-name, unused-variable, too-many-locals, unused-argument
# pylint: disable=invalid-name, unused-variable, too-many-locals
# pylint: disable=unused-argument, redefined-builtin
"""Conv2D operators"""
from __future__ import absolute_import as _abs
from collections import namedtuple
......@@ -75,7 +76,7 @@ _WORKLOADS = [
_CONV_SCHEDULE = {}
@tvm.target.generic_func
def conv2d(data, kernel, stride, padding, layout='NCHW', out_dtype=None):
def conv2d(input, filter, strides, padding, layout='NCHW', out_dtype=None):
"""Conv2D operator.
Parameters
......@@ -86,7 +87,7 @@ def conv2d(data, kernel, stride, padding, layout='NCHW', out_dtype=None):
filter : tvm.Tensor
4-D with shape [num_filter, in_channel, filter_height, filter_width]
stride : int or a list/tuple of two ints
strides : int or a list/tuple of two ints
stride size, or [stride_height, stride_width]
padding : int or a list/tuple of two ints
......@@ -103,11 +104,11 @@ def conv2d(data, kernel, stride, padding, layout='NCHW', out_dtype=None):
# search platform specific declaration first
# default declaration
if layout == 'NCHW':
return conv2d_nchw(data, kernel, stride, padding, out_dtype)
return conv2d_nchw(input, filter, strides, padding, out_dtype)
elif layout == 'HWCN':
return conv2d_hwcn(data, kernel, stride, padding, out_dtype)
return conv2d_hwcn(input, filter, strides, padding, out_dtype)
elif layout == 'NHWC':
return conv2d_nhwc(data, kernel, stride, padding, out_dtype)
return conv2d_nhwc(input, filter, strides, padding, out_dtype)
else:
raise ValueError("not support this layout {} yet".format(layout))
......
......@@ -2,7 +2,6 @@
"""Conv2D schedule on raspberry pi"""
from __future__ import absolute_import as _abs
import tvm
from tvm import target as _target
from .. import tag
from ..nn.conv2d import conv2d as _conv2d, _get_schedule
from ..nn.conv2d import SpatialPack, Im2ColPack
......@@ -201,9 +200,7 @@ def _schedule_im2col_conv2d(s, data, data_pad, data_col, data_vec,
else:
stride = infer_stride(data_pad, kernel, output)
wkl = _get_workload(data, kernel, stride, padding, output.dtype)
with _target.rasp():
sch = _get_schedule(wkl)
sch = _get_schedule(wkl)
H, W = wkl.height, wkl.width
CI = wkl.in_filter
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment