Commit 5b8a8d00 by ziheng, committed by Tianqi Chen

[UTILS] Move target to tvm; rename convolution as conv2d (#492)

* Move target to tvm; rename convolution as conv2d

* Fix

* Fix
parent dc7ab96b
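
The hunks below move the target management module from topi into tvm and rename the convolution operator, its schedules, and the tests to conv2d. As a minimal before/after sketch of the user-facing change (the names come from the hunks themselves and the placeholder shapes mirror the test case at the bottom; this sketch is illustrative, not part of the diff):

    import tvm
    import topi

    # Old API: topi.target.rasp(), topi.nn.convolution, topi.rasp.schedule_convolution
    # New API after this commit:
    A = tvm.placeholder((1, 64, 56, 56), name='A')   # NCHW input, values mirror the test case
    W = tvm.placeholder((64, 64, 3, 3), name='W')    # OIHW kernel
    with tvm.target.rasp():                          # target now lives under tvm
        B = topi.nn.conv2d(A, W, stride=1, padding=1)
        s = topi.rasp.schedule_conv2d([B])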
@@ -14,6 +14,7 @@ from . import schedule
 from . import module
 from . import node
 from . import ir_builder
+from . import target
 from . import ndarray as nd
 from .ndarray import context, cpu, gpu, opencl, cl, metal, mtl, vpi, rocm
...
"""Target management API of topi""" """Target management API of tvm"""
from __future__ import absolute_import from __future__ import absolute_import
......
@@ -16,6 +16,5 @@ from .broadcast import *
 from . import nn
 from . import cuda
 from . import rasp
-from . import target
 from . import testing
 from . import util
@@ -3,7 +3,7 @@
 from __future__ import absolute_import as _abs
 from .batch_norm import *
-from .convolution import *
+from .conv2d import *
 from .depthwise_convolution import *
 from .elemwise import *
 from .dilate import *
...
 # pylint: disable=invalid-name, unused-variable, too-many-locals
-"""Convolution operators"""
+"""Conv2D operators"""
 from __future__ import absolute_import as _abs
 from collections import namedtuple
 import tvm
+from tvm import target as _target
 from .pad import pad
 from .util import get_pad_tuple
 from ..util import simplify
-from .. import target as _target

-# workload description of convolution
+# workload description of conv2d
 Workload = namedtuple('Workload',
                       ['height', 'width', 'in_filter', 'out_filter',
                        'hkernel', 'wkernel', 'hpad', 'wpad', 'hstride', 'wstride'])
@@ -43,8 +43,8 @@ _CONV_SCHEDULE = {}
 # platform specific declaration
 _CONV_DECLARATION = {}

-def convolution(data, kernel, stride, padding, layout='NCHW'):
-    """Convolution operator.
+def conv2d(data, kernel, stride, padding, layout='NCHW'):
+    """Conv2D operator.

     Parameters
     ----------
@@ -75,9 +75,9 @@ def convolution(data, kernel, stride, padding, layout='NCHW'):
     # default declaration
     if layout == 'NCHW':
-        conv2d_nchw(data, kernel, stride, padding)
+        return conv2d_nchw(data, kernel, stride, padding)
     elif layout == 'HWCN':
-        conv2d_hwcn(data, kernel, stride, padding)
+        return conv2d_hwcn(data, kernel, stride, padding)
     else:
         raise ValueError("not support this layout {} yet".format(layout))
...
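
Besides the rename, this hunk adds the missing return statements: previously conv2d_nchw and conv2d_hwcn were called but their results were discarded, so the default declaration returned None. A minimal sketch of the fixed dispatch, with illustrative placeholder shapes (not part of the diff):

    import tvm
    import topi

    data = tvm.placeholder((1, 64, 56, 56), name='data')     # NCHW input
    kernel = tvm.placeholder((64, 64, 3, 3), name='kernel')  # OIHW kernel
    out = topi.nn.conv2d(data, kernel, stride=1, padding=1, layout='NCHW')
    # `out` is now the tensor produced by conv2d_nchw; before the fix it was None.
    # layout='HWCN' dispatches to conv2d_hwcn; any other layout raises ValueError.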
@@ -2,4 +2,4 @@
 """Raspberry pi specific declaration and schedules."""
 from __future__ import absolute_import as _abs
-from .convolution import *
+from .conv2d import *
 # pylint: disable=invalid-name,unused-variable,invalid-name
-"""Convolution schedule on raspberry pi"""
+"""Conv2D schedule on raspberry pi"""
 from __future__ import absolute_import as _abs
 import tvm
-from .. import target as _target
+from tvm import target as _target
 from .. import tag
-from ..nn.convolution import SpatialPack, Im2ColPack
-from ..nn.convolution import _CONV_DECLARATION, _CONV_SCHEDULE
-from ..nn.convolution import _WORKLOADS, _SCH_TO_DECL_FUNC
-from ..nn.convolution import _get_workload, _get_schedule
+from ..nn.conv2d import SpatialPack, Im2ColPack
+from ..nn.conv2d import _CONV_DECLARATION, _CONV_SCHEDULE
+from ..nn.conv2d import _WORKLOADS, _SCH_TO_DECL_FUNC
+from ..nn.conv2d import _get_workload, _get_schedule
 from ..nn.util import infer_pad, infer_stride

 _SCHEDULES = [
@@ -264,7 +264,7 @@ def _schedule_im2col_conv2d(s, data, data_pad, data_col, data_vec,
     return s

-def schedule_convolution(outs):
+def schedule_conv2d(outs):
     """Create schedule for tensors"""
     s = tvm.create_schedule([x.op for x in outs])
...
"""Example code to do convolution.""" """Example code to do conv2d."""
import os import os
import numpy as np import numpy as np
import tvm import tvm
@@ -7,20 +7,20 @@ from tvm.contrib.pickle_memoize import memoize
 from topi.util import get_const_tuple

-def verify_convolution(batch, in_size, in_channel, num_filter, kernel, stride, padding):
+def verify_conv2d(batch, in_size, in_channel, num_filter, kernel, stride, padding):
     in_height = in_width = in_size
-    with topi.target.rasp():
+    with tvm.target.rasp():
         A = tvm.placeholder((batch, in_channel, in_height, in_width), name='A')
         W = tvm.placeholder((num_filter, in_channel, kernel, kernel), name='W')
-        B = topi.nn.convolution(A, W, stride, padding)
-        s = topi.rasp.schedule_convolution([B])
+        B = topi.nn.conv2d(A, W, stride, padding)
+        s = topi.rasp.schedule_conv2d([B])

     a_shape = get_const_tuple(A.shape)
     w_shape = get_const_tuple(W.shape)
     dtype = A.dtype

-    @memoize("topi.tests.test_topi_convolution.verify_convolution")
+    @memoize("topi.tests.test_topi_conv2d.verify_conv2d")
     def get_ref_data():
         a_np = np.random.uniform(size=a_shape).astype(dtype)
         w_np = np.random.uniform(size=w_shape).astype(dtype)
@@ -37,8 +37,8 @@ def verify_convolution(batch, in_size, in_channel, num_filter, kernel, stride, padding):
     func(a, w, b)
     np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)

-def test_convolution():
-    verify_convolution(1, 56, 64, 64, 3, 1, 1)
+def test_conv2d():
+    verify_conv2d(1, 56, 64, 64, 3, 1, 1)

 if __name__ == "__main__":
-    test_convolution()
+    test_conv2d()