Commit be457348 by masahi Committed by Tianqi Chen

[TOPI] Upsampling op support (#772)

* add upsampling cpu op

* add upsampling gpu schedule

* add doc for upsampling op

add more doc

* cleanup upsampling test

* add doc

* fix lint

* fix lint

* fix lint

* remove unused import

* remove skimage dependency

* remove skimage import

* remove schedule_upsampling
parent af9f69a7
@@ -177,7 +177,6 @@ def schedule_global_pool(outs):
    """
    return _default_schedule(outs, False)

@tvm.target.generic_func
def schedule_binarize_pack(outs):
    """Schedule for binarize_pack
...
@@ -14,3 +14,4 @@ from .pooling import *
from .softmax import *
from .conv2d_transpose import *
from .bnn import *
from .upsampling import *
"""TVM operator upsampling compute."""
from __future__ import absolute_import
import tvm
def upsampling(data, scale):
    """Perform nearest neighbor upsampling on the data.
    Bilinear upsampling is not supported.

    Parameters
    ----------
    data : tvm.Tensor
        4-D with shape [batch, channel, in_height, in_width]

    scale : int
        upsampling scaling factor

    Returns
    -------
    output : tvm.Tensor
        4-D with shape [batch, channel, in_height*scale, in_width*scale]
    """
    batch, channel, in_h, in_w = data.shape
    oshape = (batch, channel, in_h * scale, in_w * scale)

    def _nearest(n, c, y, x):
        # Each output pixel reads its nearest-neighbor source pixel.
        return data[n, c, y / scale, x / scale]

    return tvm.compute(oshape, _nearest)
@@ -10,3 +10,4 @@ from .conv2d_transpose_nchw_python import conv2d_transpose_nchw_python
from .depthwise_conv2d_python import depthwise_conv2d_python_nchw, depthwise_conv2d_python_nhwc
from .dilate_python import dilate_python
from .softmax_python import softmax_python, log_softmax_python
from .upsampling_python import upsampling_python
# pylint: disable=invalid-name, line-too-long, unused-variable, too-many-locals
"""Upsampling in python"""
import numpy as np
def upsample_nearest(arr, scale):
    """Duplicate every element of a 2-D array `scale` times along both axes."""
    rows = np.repeat(arr, scale, axis=0)
    return np.repeat(rows, scale, axis=1)
def upsampling_python(data, scale):
    """Nearest neighbor upsampling reference implementation in NumPy.

    Parameters
    ----------
    data : numpy.ndarray
        4-D array with shape [batch, channel, in_height, in_width].

    scale : int
        Integer upsampling factor applied to both spatial axes.

    Returns
    -------
    numpy.ndarray
        4-D array [batch, channel, in_height*scale, in_width*scale]
        with the same dtype as `data`.
    """
    # Repeating along the two spatial axes replaces the per-(batch, channel)
    # Python loop and the intermediate zeros buffer; numpy.repeat preserves
    # the input dtype, so the output matches the original implementation.
    return data.repeat(scale, axis=2).repeat(scale, axis=3)
"""Test code for upsampling"""
import numpy as np
import tvm
import topi
import math
def verify_upsampling(batch, in_channel, in_height, in_width, scale):
    """Check topi.nn.upsampling against the NumPy reference on each target."""
    A = tvm.placeholder((batch, in_channel, in_height, in_width), name='A')
    B = topi.nn.upsampling(A, scale)
    dtype = A.dtype
    oshape = (batch, in_channel, in_height * scale, in_width * scale)

    input_np = np.random.uniform(size=(batch, in_channel, in_height, in_width)).astype(dtype)
    expected = topi.testing.upsampling_python(input_np, scale)

    def run_on(device):
        # Skip targets not compiled into this TVM build.
        if not tvm.module.enabled(device):
            print("Skip because %s is not enabled" % device)
            return
        print("Running on target: %s" % device)
        with tvm.target.create(device):
            sched = topi.generic.schedule_injective(B)
        ctx = tvm.context(device, 0)
        a = tvm.nd.array(input_np, ctx)
        b = tvm.nd.array(np.zeros(oshape, dtype=dtype), ctx)
        func = tvm.build(sched, [A, B], device)
        func(a, b)
        np.testing.assert_allclose(b.asnumpy(), expected, rtol=1e-5)

    for target in ['llvm', 'cuda']:
        run_on(target)
def test_upsampling():
    """Run the upsampling check for a couple of shape/scale combinations."""
    for args in [(8, 16, 32, 32, 2), (12, 32, 64, 64, 3)]:
        verify_upsampling(*args)


if __name__ == "__main__":
    test_upsampling()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment