Commit d7c600b8 by Tianqi Chen Committed by GitHub

[TOPI] Update parameter name of conv2d (#1380)

parent 30f09bea
...@@ -9,8 +9,8 @@ These are utility functions used for testing and tutorial file. ...@@ -9,8 +9,8 @@ These are utility functions used for testing and tutorial file.
from __future__ import division from __future__ import division
import math import math
import numpy as np import numpy as np
import cv2
from cffi import FFI from cffi import FFI
import cv2
def _resize_image(img, w_in, h_in): def _resize_image(img, w_in, h_in):
"""Resize the image to the given height and width.""" """Resize the image to the given height and width."""
......
# pylint: disable=invalid-name, unused-variable, too-many-locals, unused-argument # pylint: disable=invalid-name, unused-variable, too-many-locals
# pylint: disable=unused-argument, redefined-builtin
"""Conv2D operators""" """Conv2D operators"""
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
from collections import namedtuple from collections import namedtuple
...@@ -75,7 +76,7 @@ _WORKLOADS = [ ...@@ -75,7 +76,7 @@ _WORKLOADS = [
_CONV_SCHEDULE = {} _CONV_SCHEDULE = {}
@tvm.target.generic_func @tvm.target.generic_func
def conv2d(data, kernel, stride, padding, layout='NCHW', out_dtype=None): def conv2d(input, filter, strides, padding, layout='NCHW', out_dtype=None):
"""Conv2D operator. """Conv2D operator.
Parameters Parameters
...@@ -86,7 +87,7 @@ def conv2d(data, kernel, stride, padding, layout='NCHW', out_dtype=None): ...@@ -86,7 +87,7 @@ def conv2d(data, kernel, stride, padding, layout='NCHW', out_dtype=None):
filter : tvm.Tensor filter : tvm.Tensor
4-D with shape [num_filter, in_channel, filter_height, filter_width] 4-D with shape [num_filter, in_channel, filter_height, filter_width]
stride : int or a list/tuple of two ints strides : int or a list/tuple of two ints
stride size, or [stride_height, stride_width] stride size, or [stride_height, stride_width]
padding : int or a list/tuple of two ints padding : int or a list/tuple of two ints
...@@ -103,11 +104,11 @@ def conv2d(data, kernel, stride, padding, layout='NCHW', out_dtype=None): ...@@ -103,11 +104,11 @@ def conv2d(data, kernel, stride, padding, layout='NCHW', out_dtype=None):
# search platform specific declaration first # search platform specific declaration first
# default declaration # default declaration
if layout == 'NCHW': if layout == 'NCHW':
return conv2d_nchw(data, kernel, stride, padding, out_dtype) return conv2d_nchw(input, filter, strides, padding, out_dtype)
elif layout == 'HWCN': elif layout == 'HWCN':
return conv2d_hwcn(data, kernel, stride, padding, out_dtype) return conv2d_hwcn(input, filter, strides, padding, out_dtype)
elif layout == 'NHWC': elif layout == 'NHWC':
return conv2d_nhwc(data, kernel, stride, padding, out_dtype) return conv2d_nhwc(input, filter, strides, padding, out_dtype)
else: else:
raise ValueError("not support this layout {} yet".format(layout)) raise ValueError("not support this layout {} yet".format(layout))
......
...@@ -2,7 +2,6 @@ ...@@ -2,7 +2,6 @@
"""Conv2D schedule on raspberry pi""" """Conv2D schedule on raspberry pi"""
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
import tvm import tvm
from tvm import target as _target
from .. import tag from .. import tag
from ..nn.conv2d import conv2d as _conv2d, _get_schedule from ..nn.conv2d import conv2d as _conv2d, _get_schedule
from ..nn.conv2d import SpatialPack, Im2ColPack from ..nn.conv2d import SpatialPack, Im2ColPack
...@@ -201,9 +200,7 @@ def _schedule_im2col_conv2d(s, data, data_pad, data_col, data_vec, ...@@ -201,9 +200,7 @@ def _schedule_im2col_conv2d(s, data, data_pad, data_col, data_vec,
else: else:
stride = infer_stride(data_pad, kernel, output) stride = infer_stride(data_pad, kernel, output)
wkl = _get_workload(data, kernel, stride, padding, output.dtype) wkl = _get_workload(data, kernel, stride, padding, output.dtype)
sch = _get_schedule(wkl)
with _target.rasp():
sch = _get_schedule(wkl)
H, W = wkl.height, wkl.width H, W = wkl.height, wkl.width
CI = wkl.in_filter CI = wkl.in_filter
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment