Commit c3deec19 by optima2005 Committed by masahi

[TOPI] add 3D upsampling Op. (#4584)

* [TOPI] add 3D upsampling Op.

* fix lint issues

* change align_corners to coordinate_transformation_mode

* fix resize3d half_pixel

* make a simple function and clean up trilinear_resize3d_python

* fix doc
parent 1071e242
......@@ -77,6 +77,7 @@ This level enables typical convnet models.
tvm.relay.nn.global_max_pool2d
tvm.relay.nn.global_avg_pool2d
tvm.relay.nn.upsampling
tvm.relay.nn.upsampling3d
tvm.relay.nn.batch_flatten
tvm.relay.nn.pad
tvm.relay.nn.lrn
......@@ -254,6 +255,7 @@ Level 2 Definitions
.. autofunction:: tvm.relay.nn.global_max_pool2d
.. autofunction:: tvm.relay.nn.global_avg_pool2d
.. autofunction:: tvm.relay.nn.upsampling
.. autofunction:: tvm.relay.nn.upsampling3d
.. autofunction:: tvm.relay.nn.batch_flatten
.. autofunction:: tvm.relay.nn.pad
.. autofunction:: tvm.relay.nn.lrn
......
......@@ -589,6 +589,39 @@ struct UpSamplingAttrs : public tvm::AttrsNode<UpSamplingAttrs> {
}
};
/*! \brief Attributes for upsampling3d operator */
struct UpSampling3DAttrs : public tvm::AttrsNode<UpSampling3DAttrs> {
double scale_d;
double scale_h;
double scale_w;
std::string layout;
std::string method;
std::string coordinate_transformation_mode;
TVM_DECLARE_ATTRS(UpSampling3DAttrs, "relay.attrs.UpSampling3DAttrs") {
TVM_ATTR_FIELD(scale_d)
.describe("The upsampling factor for depth");
TVM_ATTR_FIELD(scale_h)
.describe("The upsampling factor for height");
TVM_ATTR_FIELD(scale_w)
.describe("The upsampling factor for width");
TVM_ATTR_FIELD(layout).set_default("NCDHW")
.describe("Dimension ordering of input data. Can be 'NCDHW', 'NDHWC', etc."
"'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width"
"dimensions respectively. Upsampling is applied on the 'D', 'H' and"
"'W' dimensions.");
TVM_ATTR_FIELD(method).set_default("nearest_neighbor")
.describe("Specify the mode to use for scaling."
"nearest_neighbor - Nearest Neighbor"
"trilinear - Trilinear Interpolation");
TVM_ATTR_FIELD(coordinate_transformation_mode).set_default("half_pixel")
.describe("Describes how to transform the coordinate in the resized tensor"
"to the coordinate in the original tensor."
"Refer to the ONNX Resize operator specification for details"
"Available options are half_pixel, align_corners and asymmetric");
}
};
/*! \brief Attributes used for the padding operator */
struct PadAttrs : public tvm::AttrsNode<PadAttrs> {
double pad_value;
......
......@@ -582,6 +582,25 @@ def compute_upsampling(attrs, inputs, out_dtype, target):
align_corners = attrs.align_corners
return [topi.nn.upsampling(inputs[0], scale_h, scale_w, layout, method, align_corners)]
# upsampling3d
reg.register_schedule("nn.upsampling3d", reg.schedule_injective)
def schedule_upsampling3d(_, outs, target):
"""Schedule definition of upsampling3d"""
with target:
return topi.generic.schedule_injective(outs)
@reg.register_compute("nn.upsampling3d")
def compute_upsampling3d(attrs, inputs, out_dtype, target):
scale_d = attrs.scale_d
scale_h = attrs.scale_h
scale_w = attrs.scale_w
layout = attrs.layout
method = attrs.method
coordinate_transformation_mode = attrs.coordinate_transformation_mode
return [topi.nn.upsampling3d(inputs[0], scale_d, scale_h, scale_w, layout, method,\
coordinate_transformation_mode)]
# pad
reg.register_schedule("nn.pad", schedule_broadcast)
......
......@@ -771,6 +771,58 @@ def upsampling(data,
return _make.upsampling(data, scale_h, scale_w, layout, method, align_corners)
def upsampling3d(data,
scale_d=1,
scale_h=1,
scale_w=1,
layout="NCDHW",
method="nearest_neighbor",
coordinate_transformation_mode="half_pixel"):
"""3D Upsampling.
This operator takes data as input and does 3D scaling to the given scale factor.
In the default case, where the data_layout is `NCDHW`
with data of shape (n, c, d, h, w)
out will have a shape (n, c, d*scale_d, h*scale_h, w*scale_w)
method indicates the algorithm to be used while calculating the out value
and method can be one of ("trilinear", "nearest_neighbor")
Parameters
----------
data : tvm.relay.Expr
The input data to the operator.
scale_d : tvm.relay.Expr
The scale factor for depth upsampling.
scale_h : tvm.relay.Expr
The scale factor for height upsampling.
scale_w : tvm.relay.Expr
The scale factor for width upsampling.
layout : str, optional
Layout of the input.
method : str, optional
Scale method to used [nearest_neighbor, trilinear].
coordinate_transformation_mode: string, optional
Describes how to transform the coordinate in the resized tensor
to the coordinate in the original tensor.
Refer to the ONNX Resize operator specification for details.
Available options are "half_pixel", "align_corners" and "asymmetric".
Returns
-------
result : tvm.relay.Expr
The computed result.
"""
return _make.upsampling3d(data, scale_d, scale_h, scale_w, layout, method,
coordinate_transformation_mode)
def batch_flatten(data):
"""BatchFlatten.
......
......@@ -64,6 +64,10 @@ class UpSamplingAttrs(Attrs):
"""Attributes for nn.upsampling"""
@register_relay_attr_node
class UpSampling3DAttrs(Attrs):
"""Attributes for nn.upsampling3d"""
@register_relay_attr_node
class PadAttrs(Attrs):
"""Attributes for nn.pad"""
......
......@@ -33,6 +33,7 @@ namespace tvm {
namespace relay {
TVM_REGISTER_NODE_TYPE(UpSamplingAttrs);
TVM_REGISTER_NODE_TYPE(UpSampling3DAttrs);
template <typename T>
Array<Array<Layout> > UpsamplingInferCorrectLayout(
......@@ -50,8 +51,11 @@ Array<Array<Layout> > UpsamplingInferCorrectLayout(
Layout input = new_in_layouts[0];
if (input.IndexOf(LayoutAxis::Get('W')) == raw_layout.IndexOf(LayoutAxis::Get('W')) &&
input.IndexOf(LayoutAxis::Get('H')) == raw_layout.IndexOf(LayoutAxis::Get('H')) &&
!input.Contains(LayoutAxis::Get('w')) && !input.Contains(LayoutAxis::Get('h'))) {
params->layout = input.name(); // modify self to follow the input layout
!input.Contains(LayoutAxis::Get('w')) && !input.Contains(LayoutAxis::Get('h'))&&
(input.IndexOf(LayoutAxis::Get('D')) == -1 ||
(input.IndexOf(LayoutAxis::Get('D')) == raw_layout.IndexOf(LayoutAxis::Get('D')) &&
!input.Contains(LayoutAxis::Get('d'))))) {
params->layout = input.name(); // modify self to follow the input layout
}
}
......@@ -108,7 +112,6 @@ Expr MakeUpSampling(Expr data,
return CallNode::make(op, {data}, Attrs(attrs), {});
}
TVM_REGISTER_API("relay.op.nn._make.upsampling")
.set_body_typed(MakeUpSampling);
......@@ -138,5 +141,86 @@ RELAY_REGISTER_OP("nn.upsampling")
.set_attr<TOpPattern>("TOpPattern", kInjective);
// UpSampling3D
bool UpSampling3DRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 2);
const auto* data = types[0].as<TensorTypeNode>();
if (data == nullptr) return false;
static const Layout kNCDHW("NCDHW");
const UpSampling3DAttrs* param = attrs.as<UpSampling3DAttrs>();
CHECK(param != nullptr);
const Layout in_layout(param->layout);
auto layout_converter = BijectiveLayoutNode::make(in_layout, kNCDHW);
CHECK(layout_converter.defined())
<< "UpSampling3D only support input layouts that are convertible from NCDHW."
<< " But got " << in_layout;
auto oshape = layout_converter.ForwardShape(data->shape);
oshape.Set(2, ir::Cast::make(oshape[2].dtype(), tvm::round(oshape[2] * param->scale_d)));
oshape.Set(3, ir::Cast::make(oshape[3].dtype(), tvm::round(oshape[3] * param->scale_h)));
oshape.Set(4, ir::Cast::make(oshape[4].dtype(), tvm::round(oshape[4] * param->scale_w)));
// assign output type
reporter->Assign(types[1],
TensorTypeNode::make(layout_converter.BackwardShape(oshape),
data->dtype));
return true;
}
// Positional relay function to create upsampling3d operator
// used by frontend FFI.
Expr MakeUpSampling3D(Expr data,
double scale_d,
double scale_h,
double scale_w,
std::string layout,
std::string method,
std::string coordinate_transformation_mode) {
auto attrs = make_node<UpSampling3DAttrs>();
attrs->layout = std::move(layout);
attrs->method = std::move(method);
attrs->scale_d = scale_d;
attrs->scale_h = scale_h;
attrs->scale_w = scale_w;
attrs->coordinate_transformation_mode = coordinate_transformation_mode;
static const Op& op = Op::Get("nn.upsampling3d");
return CallNode::make(op, {data}, Attrs(attrs), {});
}
TVM_REGISTER_API("relay.op.nn._make.upsampling3d")
.set_body_typed(MakeUpSampling3D);
RELAY_REGISTER_OP("nn.upsampling3d")
.describe(R"code(Perform upsampling on input array with nearest neighbour or
bilinear interpolation.
- **data**: data is 5D array of shape
(batch_size, channels, in_depth, in_height, in_width) for NCDHW
(batch_size, in_depth, in_height, in_width, channels) for NDHWC
- **out**: Output is 5D array of shape
for layout NCDHW
(batch_size, channels, in_depth*scale, in_height*scale, in_width*scale)
for layout NDHWC
(batch_size, in_depth*scale, in_height*scale, in_width*scale, channels)
)code" TVM_ADD_FILELINE)
.set_attrs_type<UpSampling3DAttrs>()
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(2)
.add_type_rel("UpSampling3D", UpSampling3DRel)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
UpsamplingInferCorrectLayout<UpSampling3DAttrs>)
.set_attr<TOpPattern>("TOpPattern", kInjective);
} // namespace relay
} // namespace tvm
......@@ -456,6 +456,22 @@ def test_upsampling_infer_type():
yy = run_infer_type(y)
assert yy.checked_type == relay.TensorType((n, c, 200, 400), "float32")
def test_upsampling3d_infer_type():
n, c, d, h, w = tvm.var("n"), tvm.var("c"), tvm.var("d"), tvm.var("h"), tvm.var("w")
scale = tvm.const(2.0, "float64")
x = relay.var("x", relay.TensorType((n, c, d, h, w), "float32"))
y = relay.nn.upsampling3d(x, scale_d=2, scale_h=2, scale_w=2, layout="NCDHW", method="trilinear")
yy = run_infer_type(y)
assert yy.checked_type == relay.TensorType((n, c, tvm.expr.Cast("int32", tvm.round(d*scale)),
tvm.expr.Cast("int32", tvm.round(h*scale)),
tvm.expr.Cast("int32", tvm.round(w*scale))),
"float32")
n, c = tvm.var("n"), tvm.var("c")
x = relay.var("x", relay.TensorType((n, c, 100, 100, 200), "float32"))
y = relay.nn.upsampling3d(x, scale_d=2, scale_h=2, scale_w=2, layout="NCDHW", method="trilinear")
yy = run_infer_type(y)
assert yy.checked_type == relay.TensorType((n, c, 200, 200, 400), "float32")
def _test_pool2d(opfunc, reffunc):
n, c, h, w = tvm.var("n"), 10, 224, 224
......@@ -782,6 +798,50 @@ def test_upsampling():
_test_upsampling("NHWC", "nearest_neighbor")
_test_upsampling("NHWC", "bilinear", True)
def _test_upsampling3d(layout, method, coordinate_transformation_mode="half_pixel"):
n, c, d, h, w = tvm.var("n"), 8, 16, 16, 16
scale_d = 2.0
scale_h = 2.0
scale_w = 2.0
dtype = "float32"
def get_shape():
if layout == "NCDHW":
return (c, d, h, w), (c, int(round(d*scale_d)), int(round(h*scale_h)),\
int(round(w*scale_w)))
else:
return (d, h, w, c), (int(round(d*scale_d)), int(round(h*scale_h)),\
int(round(w*scale_w)), c)
ishape, oshape = get_shape()
x = relay.var("x", relay.TensorType((n,) + ishape, dtype))
y = relay.nn.upsampling3d(x, scale_d=scale_d, scale_h=scale_h, scale_w=scale_w,\
layout=layout, method=method,\
coordinate_transformation_mode=coordinate_transformation_mode)
yy = run_infer_type(y)
assert yy.checked_type == relay.TensorType((n,) + oshape, dtype)
dshape = (1,) + ishape
x = relay.var("x", shape=dshape)
y = relay.nn.upsampling3d(x, scale_d=scale_d, scale_h=scale_h, scale_w=scale_w,\
layout=layout, method=method,\
coordinate_transformation_mode=coordinate_transformation_mode)
func = relay.Function([x], y)
data = np.random.uniform(size=dshape).astype(dtype)
if method == "nearest_neighbor":
ref = topi.testing.upsampling3d_python(data, (scale_d, scale_h, scale_w), layout)
else:
ref = topi.testing.trilinear_resize3d_python(data, (int(round(d*scale_d)),\
int(round(h*scale_h)),\
int(round(w*scale_w))), layout)
for target, ctx in ctx_list():
executor = relay.create_executor("graph", ctx=ctx, target=target)
out = executor.evaluate(func)(data)
tvm.testing.assert_allclose(out.asnumpy(), ref, rtol=1e-5, atol=1e-5)
def test_upsampling3d():
_test_upsampling3d("NCDHW", "nearest_neighbor")
_test_upsampling3d("NCDHW", "trilinear", "align_corners")
_test_upsampling3d("NDHWC", "nearest_neighbor")
_test_upsampling3d("NDHWC", "trilinear", "align_corners")
def test_conv2d_int8_intrinsics():
def _compile(ic, oc, target, data_layout, kernel_layout, dtypes):
......@@ -935,6 +995,7 @@ if __name__ == "__main__":
test_conv2d_infer_type()
test_bitpack_infer_type()
test_upsampling_infer_type()
test_upsampling3d_infer_type()
test_flatten_infer_type()
test_pad_infer_type()
test_pad_run()
......@@ -948,4 +1009,5 @@ if __name__ == "__main__":
test_bitserial_conv2d_infer_type()
test_batch_flatten()
test_upsampling()
test_upsampling3d()
test_conv2d_int8_intrinsics()
......@@ -210,3 +210,172 @@ def resize(data, size, layout="NCHW", method="bilinear", align_corners=True, out
raise ValueError('%s method is not supported.' % method)
return tvm.compute(output_shape, compute_func, name='resize', tag=tag.INJECTIVE)
def resize3d(data, size, layout="NCDHW", method="nearest_neighbor",
coordinate_transformation_mode="align_corners", out_dtype=None):
"""Perform resize operation on the data.
Parameters
----------
inputs: tvm.Tensor
inputs is a 5-D tensor with shape
[batch, channel, in_depth, in_height, in_width]
or [batch, in_depth, in_height, in_width, channel]
size: Tuple
Output resolution scale to
layout: string, optional
"NCDHW", "NDHWC", or "NCDHWc".
coordinate_transformation_mode: string, optional
Describes how to transform the coordinate in the resized tensor
to the coordinate in the original tensor.
Refer to the ONNX Resize operator specification for details.
Available options are "half_pixel", "align_corners" and "asymmetric".
method: {"trilinear", "nearest_neighbor"}
Method to be used for resizing.
out_dtype: string, optional
Type to return. If left None will be same as input type.
Returns
-------
output : tvm.Tensor
5-D with shape [batch, channel, in_depth*scale, in_height*scale, in_width*scale]
or [batch, in_depth*scale, in_height*scale, in_width*scale, channel]
or 5-D with shape [batch, channel-major, in_depth*scale, in_height*scale, in_width*scale,
channel-minor]
"""
method = method.lower()
if layout == 'NDHWC':
in_n, in_d, in_h, in_w, in_c = data.shape
output_shape = [in_n, size[0], size[1], size[2], in_c]
elif layout == 'NCDHW':
in_n, in_c, in_d, in_h, in_w = data.shape
output_shape = [in_n, in_c, size[0], size[1], size[2]]
# Otherwise layout must be NCHWxc
else:
in_n, in_c, in_d, in_h, in_w, in_cc = data.shape
output_shape = [in_n, in_c, size[0], size[1], size[2], in_cc]
if coordinate_transformation_mode == "align_corners":
z_ratio = (in_d - 1).astype('float') / (size[0] - 1)
y_ratio = (in_h - 1).astype('float') / (size[1] - 1)
x_ratio = (in_w - 1).astype('float') / (size[2] - 1)
elif coordinate_transformation_mode in ["asymmetric", "half_pixel"]:
z_ratio = (in_d).astype('float') / (size[0])
y_ratio = (in_h).astype('float') / (size[1])
x_ratio = (in_w).astype('float') / (size[2])
else:
raise ValueError("Unsupported coordinate_transformation_mode: {}".format(
coordinate_transformation_mode))
def _get_pixel(n, c, z, y, x, cc):
z = tvm.max(tvm.min(z, in_d - 1), 0)
y = tvm.max(tvm.min(y, in_h - 1), 0)
x = tvm.max(tvm.min(x, in_w - 1), 0)
if layout == 'NDHWC':
return data(n, z, y, x, c).astype('float')
if layout == 'NCDHW':
return data(n, c, z, y, x).astype('float')
# else must be NCDHWxc
return data(n, c, z, y, x, cc).astype('float')
def _get_indices(*indices):
if layout == 'NDHWC':
n, z, y, x, c = indices
cc = None
elif layout == 'NCDHW':
n, c, z, y, x = indices
cc = None
else:
n, c, z, y, x, cc = indices
return n, c, z, y, x, cc
def _cast_output(value):
if out_dtype:
dtype = out_dtype
else:
dtype = data.dtype
return value.astype(dtype)
# Nearest neighbor computation
def _nearest_neighbor(*indices):
n, c, z, y, x, cc = _get_indices(*indices)
in_z = z_ratio * z
in_y = y_ratio * y
in_x = x_ratio * x
if coordinate_transformation_mode == "align_corners":
zint = tvm.round(in_z).astype('int32')
yint = tvm.round(in_y).astype('int32')
xint = tvm.round(in_x).astype('int32')
elif coordinate_transformation_mode in ["asymmetric", "half_pixel"]:
# Add epsilon to floor to prevent gpu rounding errors.
epsilon = 1e-5
zint = tvm.floor(in_z + epsilon).astype('int32')
yint = tvm.floor(in_y + epsilon).astype('int32')
xint = tvm.floor(in_x + epsilon).astype('int32')
else:
raise ValueError("Unsupported coordinate_transformation_mode: {}".format(
coordinate_transformation_mode))
return _cast_output(_get_pixel(n, c, zint, yint, xint, cc))
# Trilinear helper functions and computation.
def _lerp(A, B, t):
return A * (1.0 - t) + B * t
def _trilinear(*indices):
n, c, z, y, x, cc = _get_indices(*indices)
if coordinate_transformation_mode == "half_pixel":
in_z = z_ratio * (z + 0.5) - 0.5
in_y = y_ratio * (y + 0.5) - 0.5
in_x = x_ratio * (x + 0.5) - 0.5
else:
in_z = z_ratio * z
in_y = y_ratio * y
in_x = x_ratio * x
zint = tvm.floor(in_z).astype('int32')
zfract = in_z - tvm.floor(in_z)
xint = tvm.floor(in_x).astype('int32')
xfract = in_x - tvm.floor(in_x)
yint = tvm.floor(in_y).astype('int32')
yfract = in_y - tvm.floor(in_y)
p000 = _get_pixel(n, c, zint, yint, xint, cc)
p001 = _get_pixel(n, c, zint, yint, xint + 1, cc)
p010 = _get_pixel(n, c, zint, yint + 1, xint, cc)
p011 = _get_pixel(n, c, zint, yint + 1, xint + 1, cc)
p100 = _get_pixel(n, c, zint + 1, yint, xint, cc)
p101 = _get_pixel(n, c, zint + 1, yint, xint + 1, cc)
p110 = _get_pixel(n, c, zint + 1, yint + 1, xint, cc)
p111 = _get_pixel(n, c, zint + 1, yint + 1, xint + 1, cc)
dep00 = _lerp(p000, p100, zfract)
dep01 = _lerp(p001, p101, zfract)
dep10 = _lerp(p010, p110, zfract)
dep11 = _lerp(p011, p111, zfract)
col0 = _lerp(dep00, dep01, xfract)
col1 = _lerp(dep10, dep11, xfract)
value = _lerp(col0, col1, yfract)
return _cast_output(value)
# Determine which interpolation method to use then run it.
if method == "nearest_neighbor":
compute_func = _nearest_neighbor
elif method == "trilinear":
compute_func = _trilinear
else:
raise ValueError('%s method is not supported.' % method)
return tvm.compute(output_shape, compute_func, name='resize3d', tag=tag.INJECTIVE)
......@@ -63,3 +63,58 @@ def upsampling(data, scale_h, scale_w, layout="NCHW", method='nearest_neighbor',
raise ValueError("not support this layout {} yet".format(layout))
return topi.image.resize(data, out_shape, layout=layout,
method=method, align_corners=align_corners)
def upsampling3d(data, scale_d, scale_h, scale_w, layout="NCDHW", method='nearest_neighbor',
coordinate_transformation_mode="half_pixel"):
"""Perform upsampling on the data.
Nearest neighbor and bilinear upsampling are supported.
Parameters
----------
inputs : tvm.Tensor
inputs is a 5-D tensor with shape
[batch, channel, in_depth, in_height, in_width]
or [batch, in_depth, in_height, in_width, channel]
scale_d : float
Scaling factor for depth
scale_h : float
Scaling factor for height
scale_w : float
Scaling factor for width
layout : string, optional
either "NCDHW" or "NDHWC"
method : {"trilinear", "nearest_neighbor"}
Method to be used for upsampling.
coordinate_transformation_mode: string, optional
Describes how to transform the coordinate in the resized tensor
to the coordinate in the original tensor.
Refer to the ONNX Resize operator specification for details.
Available options are "half_pixel", "align_corners" and "asymmetric".
Returns
-------
output : tvm.Tensor
5-D with shape [batch, channel, in_depth*scale, in_height*scale, in_width*scale]
or [batch, in_depth*scale, in_height*scale, in_width*scale, channel]
"""
base_layout = layout[0:5]
if base_layout == "NCDHW":
out_shape = (simplify(topi.cast(tvm.round(data.shape[2] * scale_d), data.shape[2].dtype)),
simplify(topi.cast(tvm.round(data.shape[3] * scale_h), data.shape[3].dtype)),
simplify(topi.cast(tvm.round(data.shape[4] * scale_w), data.shape[4].dtype)))
elif layout == "NDHWC":
out_shape = (simplify(topi.cast(tvm.round(data.shape[1] * scale_d), data.shape[1].dtype)),
simplify(topi.cast(tvm.round(data.shape[2] * scale_h), data.shape[2].dtype)),
simplify(topi.cast(tvm.round(data.shape[3] * scale_w), data.shape[3].dtype)))
else:
raise ValueError("not support this layout {} yet".format(layout))
return topi.image.resize3d(data, out_shape, layout=layout, method=method,
coordinate_transformation_mode=coordinate_transformation_mode)
......@@ -31,8 +31,9 @@ from .deformable_conv2d_nchw_python import deformable_conv2d_nchw_python
from .depthwise_conv2d_python import depthwise_conv2d_python_nchw, depthwise_conv2d_python_nhwc
from .dilate_python import dilate_python
from .softmax_python import softmax_python, log_softmax_python
from .upsampling_python import upsampling_python
from .upsampling_python import upsampling_python, upsampling3d_python
from .bilinear_resize_python import bilinear_resize_python
from .trilinear_resize3d_python import trilinear_resize3d_python
from .reorg_python import reorg_python
from .roi_align_python import roi_align_nchw_python
from .roi_pool_python import roi_pool_nchw_python
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, line-too-long, unused-variable, too-many-locals, too-many-nested-blocks
"""Trilinear 3D resize in python"""
import math
import numpy as np
def trilinear_resize3d_python(data_in, out_size, layout,
coordinate_transformation_mode="align_corners"):
""" Trilinear 3d scaling using python"""
(new_d, new_h, new_w) = out_size
if layout == 'NDHWC':
(batch, d, h, w, channel) = data_in.shape
data_out = np.ones((batch, new_d, new_h, new_w, channel))
else:
(batch, channel, d, h, w) = data_in.shape
data_out = np.ones((batch, channel, new_d, new_h, new_w))
if coordinate_transformation_mode == "align_corners":
depth_scale = np.float32(d-1) / np.float32(out_size[0]-1)
height_scale = np.float32(h-1) / np.float32(out_size[1]-1)
width_scale = np.float32(w-1) / np.float32(out_size[2]-1)
elif coordinate_transformation_mode in ["asymmetric", "half_pixel"]:
depth_scale = np.float32(d) / np.float32(out_size[0])
height_scale = np.float32(h) / np.float32(out_size[1])
width_scale = np.float32(w) / np.float32(out_size[2])
else:
raise ValueError("Unsupported coordinate_transformation_mode: {}".format(
coordinate_transformation_mode))
def _lerp(A, B, t):
return A * (1.0 - t) + B * t
def _in_coord(new_coord, scale, shape, mode):
if mode == "half_pixel":
in_coord = (new_coord + 0.5) * scale - 0.5
else:
in_coord = new_coord * scale
coord0 = int(math.floor(in_coord))
coord1 = max(min(coord0 + 1, shape - 1), 0)
coord0 = max(coord0, 0)
coord_lerp = in_coord - math.floor(in_coord)
return coord0, coord1, coord_lerp
for b in range(batch):
for i in range(channel):
for m in range(new_d):
for j in range(new_h):
for k in range(new_w):
z0, z1, z_lerp = _in_coord(m, depth_scale, d,\
coordinate_transformation_mode)
y0, y1, y_lerp = _in_coord(j, height_scale, h,\
coordinate_transformation_mode)
x0, x1, x_lerp = _in_coord(k, width_scale, w,\
coordinate_transformation_mode)
if layout == 'NDHWC':
A0 = data_in[b][z0][y0][x0][i]
B0 = data_in[b][z0][y0][x1][i]
C0 = data_in[b][z0][y1][x0][i]
D0 = data_in[b][z0][y1][x1][i]
A1 = data_in[b][z1][y0][x0][i]
B1 = data_in[b][z1][y0][x1][i]
C1 = data_in[b][z1][y1][x0][i]
D1 = data_in[b][z1][y1][x1][i]
else:
A0 = data_in[b][i][z0][y0][x0]
B0 = data_in[b][i][z0][y0][x1]
C0 = data_in[b][i][z0][y1][x0]
D0 = data_in[b][i][z0][y1][x1]
A1 = data_in[b][i][z1][y0][x0]
B1 = data_in[b][i][z1][y0][x1]
C1 = data_in[b][i][z1][y1][x0]
D1 = data_in[b][i][z1][y1][x1]
A = _lerp(A0, A1, z_lerp)
B = _lerp(B0, B1, z_lerp)
C = _lerp(C0, C1, z_lerp)
D = _lerp(D0, D1, z_lerp)
top = _lerp(A, B, x_lerp)
bottom = _lerp(C, D, x_lerp)
pixel = np.float32(_lerp(top, bottom, y_lerp))
if layout == 'NDHWC':
data_out[b][m][j][k][i] = pixel
else:
data_out[b][i][m][j][k] = pixel
return data_out
......@@ -53,3 +53,45 @@ def upsampling_python(data, scale, layout='NCHW'):
output_np[b, :, :, c] = upsample_nearest(data[b, :, :, c], scale)
return output_np
raise ValueError("not support this layout {} yet".format(layout))
def upsample3d_nearest(arr, scale):
""" Populate the array by scale factor"""
d, h, w = arr.shape
out_d = int(round(d * scale[0]))
out_h = int(round(h * scale[1]))
out_w = int(round(w * scale[2]))
out = np.empty((out_d, out_h, out_w))
for z in range(out_d):
for y in range(out_h):
for x in range(out_w):
in_z = math.floor(z / scale[0])
in_y = math.floor(y / scale[1])
in_x = math.floor(x / scale[2])
out[z, y, x] = arr[in_z, in_y, in_x]
return out
def upsampling3d_python(data, scale, layout='NCDHW'):
""" Python version of 3D scaling using nearest neighbour """
ishape = data.shape
if layout == 'NCDHW':
oshape = (ishape[0], ishape[1],
int(round(ishape[2]*scale[0])),
int(round(ishape[3]*scale[1])),
int(round(ishape[4]*scale[2])))
output_np = np.zeros(oshape, dtype=data.dtype)
for b in range(oshape[0]):
for c in range(oshape[1]):
output_np[b, c, :, :, :] = upsample3d_nearest(data[b, c, :, :, :], scale)
return output_np
if layout == 'NDHWC':
oshape = (ishape[0],
int(round(ishape[1]*scale[0])),
int(round(ishape[2]*scale[1])),
int(round(ishape[3]*scale[2])), ishape[4])
output_np = np.zeros(oshape, dtype=data.dtype)
for b in range(oshape[0]):
for c in range(oshape[4]):
output_np[b, :, :, :, c] = upsample3d_nearest(data[b, :, :, :, c], scale)
return output_np
raise ValueError("not support this layout {} yet".format(layout))
......@@ -79,5 +79,68 @@ def test_resize():
verify_resize(4, 16, 32, 32, 50, 50, 'NCHW', method="nearest_neighbor", align_corners=False)
verify_resize(4, 16, 32, 32, 50, 50, 'NHWC', method="nearest_neighbor", align_corners=False)
def verify_resize3d(batch, in_channel, in_depth, in_height, in_width, out_depth, out_height, out_width,
layout='NCDHW', coordinate_transformation_mode="half_pixel", method="trilinear"):
if layout == 'NCDHW':
A = tvm.placeholder((batch, in_channel, in_depth, in_height, in_width), name='A', dtype='float32')
dtype = A.dtype
out_shape = (batch, in_channel, out_depth, out_height, out_width)
a_np = np.random.uniform(size=(batch, in_channel, in_depth, in_height, in_width)).astype(dtype)
elif layout == 'NDHWC':
A = tvm.placeholder((batch, in_depth, in_height, in_width, in_channel), name='A', dtype='float32')
dtype = A.dtype
out_shape = (batch, out_depth, out_height, out_width, in_channel)
a_np = np.random.uniform(size=(batch, in_depth, in_height, in_width, in_channel)).astype(dtype)
else:
raise NotImplementedError(
'Layout not supported {} '.format(layout))
B = topi.image.resize3d(A, (out_depth, out_height, out_width), layout=layout,
coordinate_transformation_mode=coordinate_transformation_mode, method=method)
if method == "trilinear":
b_np = topi.testing.trilinear_resize3d_python(a_np, (out_depth, out_height, out_width), layout,
coordinate_transformation_mode)
else:
scale_d = out_depth / in_depth
scale_h = out_height / in_height
scale_w = out_width / in_width
b_np = topi.testing.upsampling3d_python(a_np, (scale_d, scale_h, scale_w), layout)
def check_device(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
with tvm.target.create(device):
s = topi.generic.schedule_injective(B)
a = tvm.nd.array(a_np, ctx)
b = tvm.nd.array(np.zeros(out_shape, dtype=dtype), ctx)
f = tvm.build(s, [A, B], device)
f(a, b)
tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-3, atol=1e-3)
for device in get_all_backend():
check_device(device)
def test_resize3d():
# Trilinear
verify_resize3d(4, 8, 16, 16, 16, 25, 25, 25, 'NCDHW')
verify_resize3d(1, 8, 16, 16, 16, 25, 25, 25, "NDHWC")
verify_resize3d(3, 16, 32, 32, 32, 10, 10, 10, 'NCDHW', "align_corners")
verify_resize3d(3, 16, 32, 32, 32, 10, 10, 10, 'NDHWC', "align_corners")
verify_resize3d(3, 16, 32, 32, 32, 10, 10, 10, 'NCDHW', "asymmetric")
verify_resize3d(3, 16, 32, 32, 32, 10, 10, 10, 'NDHWC', "asymmetric")
# Nearest neighbor
verify_resize3d(4, 8, 16, 16, 16, 25, 25, 25, 'NCDHW', method="nearest_neighbor")
verify_resize3d(4, 8, 16, 16, 16, 25, 25, 25, 'NDHWC', method="nearest_neighbor")
if __name__ == "__main__":
test_resize()
test_resize3d()
......@@ -86,5 +86,73 @@ def test_upsampling():
verify_upsampling(2, 2, 32, 32, 3.0, 3.0, layout="NHWC", method="bilinear")
verify_upsampling(1, 64, 22, 32, 3.0, 3.0, layout="NHWC", method="bilinear")
def verify_upsampling3d(batch, in_channel, in_depth, in_height, in_width, scale_d, scale_h, scale_w,
layout='NCDHW', method="nearest_neighbor"):
if layout == 'NCDHW':
A = tvm.placeholder((batch, in_channel, in_depth, in_height, in_width), name='A')
dtype = A.dtype
out_shape = (batch, in_channel, int(round(in_depth*scale_d)), int(round(in_height*scale_h)),
int(round(in_width*scale_w)))
a_np = np.random.uniform(size=(batch, in_channel, in_depth, in_height, in_width)).astype(dtype)
elif layout == 'NDHWC':
A = tvm.placeholder((batch, in_depth, in_height, in_width, in_channel), name='A')
dtype = A.dtype
out_shape = (batch, int(round(in_depth*scale_d)), int(round(in_height*scale_h)),
int(round(in_width*scale_w)), in_channel)
a_np = np.random.uniform(size=(batch, in_depth, in_height, in_width, in_channel)).astype(dtype)
else:
raise NotImplementedError(
'Layout not supported {} '.format(layout))
B = topi.nn.upsampling3d(A, scale_d, scale_h, scale_w, layout=layout, method=method,
coordinate_transformation_mode="half_pixel")
if method == "trilinear":
out_size = (int(round(in_depth*scale_d)), int(round(in_height*scale_h)), int(round(in_width*scale_w)))
b_np = topi.testing.trilinear_resize3d_python(a_np, out_size, layout,
coordinate_transformation_mode="half_pixel")
else:
b_np = topi.testing.upsampling3d_python(a_np, (scale_d, scale_h, scale_w), layout)
def check_device(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
with tvm.target.create(device):
s = topi.generic.schedule_injective(B)
a = tvm.nd.array(a_np, ctx)
b = tvm.nd.array(np.zeros(out_shape, dtype=dtype), ctx)
f = tvm.build(s, [A, B], device)
f(a, b)
tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5, atol=1e-5)
for device in get_all_backend():
check_device(device)
def test_upsampling3d():
# nearest_neighbor - NCDHW
verify_upsampling3d(8, 8, 16, 16, 16, 2.0, 2.0, 2.0)
verify_upsampling3d(2, 16, 32, 32, 32, 3.0, 3.0, 3.0)
verify_upsampling3d(1, 8, 11, 16, 6, 1.954545497894287, 2.0, 1.5)
## nearest_neighbor - NDHWC
verify_upsampling3d(8, 8, 16, 16, 16, 2.0, 2.0, 2.0, layout="NDHWC")
verify_upsampling3d(2, 16, 32, 32, 32, 3.0, 3.0, 3.0, layout="NDHWC")
verify_upsampling3d(1, 8, 11, 16, 6, 1.954545497894287, 2.0, 1.5, layout="NDHWC")
# trilinear - NCDHW
verify_upsampling3d(2, 2, 16, 16, 16, 2.0, 2.0, 2.0, method="trilinear")
verify_upsampling3d(2, 2, 32, 32, 32, 3.0, 3.0, 3.0, method="trilinear")
verify_upsampling3d(1, 2, 11, 16, 6, 1.954545497894287, 2.0, 1.5, method="trilinear")
# trilinear - NDHWC
verify_upsampling3d(2, 2, 16, 16, 16, 2.0, 2.0, 2.0, layout="NDHWC", method="trilinear")
verify_upsampling3d(2, 2, 32, 32, 32, 3.0, 3.0, 3.0, layout="NDHWC", method="trilinear")
verify_upsampling3d(1, 2, 11, 16, 6, 1.954545497894287, 2.0, 1.5, layout="NDHWC", method="trilinear")
if __name__ == "__main__":
test_upsampling()
test_upsampling3d()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment