Commit 7264cb6a by Josh Fromm, committed by masahi

Changed topi cc resize to python implementation with new features. (#3788)

parent c870261f
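
In short, this commit replaces the C++ topi resize kernel with a pure-Python TVM compute, switches the method strings to lowercase ("nearest_neighbor", "bilinear", plus the new "bicubic"), threads align_corners through nn.upsampling, and adds an out_dtype option to image.resize. A minimal sketch of the updated call surface, assuming a TVM build from around this commit (shapes and values are illustrative only):

import tvm
import topi

# Hypothetical 1x3x32x32 NCHW placeholder; the resize below now lowers to pure TVM compute.
data = tvm.placeholder((1, 3, 32, 32), name="data", dtype="float32")
out = topi.image.resize(data, (64, 64), layout="NCHW",
                        method="bicubic",      # lowercase names; "bicubic" is new
                        align_corners=True,    # resize now defaults to align_corners=True
                        out_dtype="float32")   # new: optionally cast the output
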
......@@ -37,6 +37,7 @@ struct ResizeAttrs : public tvm::AttrsNode<ResizeAttrs> {
std::string layout;
std::string method;
bool align_corners;
DataType out_dtype;
TVM_DECLARE_ATTRS(ResizeAttrs, "relay.attrs.ResizeAttrs") {
TVM_ATTR_FIELD(size).set_default(NullValue<Array<IndexExpr> >())
......@@ -46,12 +47,16 @@ struct ResizeAttrs : public tvm::AttrsNode<ResizeAttrs> {
"'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
"dimensions respectively. Resize is applied on the 'H' and"
"'W' dimensions.");
TVM_ATTR_FIELD(method).set_default("BILINEAR")
TVM_ATTR_FIELD(method).set_default("bilinear")
.describe("Specify the mode to use for scaling."
"NEAREST_NEIGHBOR - Nearest Neighbor"
"BILINEAR - Bilinear Interpolation");
TVM_ATTR_FIELD(align_corners).set_default(false)
"nearest_neighbor - Nearest Neighbor"
"bilinear - Bilinear Interpolation"
"bicubic - Bicubic Interpolation");
TVM_ATTR_FIELD(align_corners).set_default(true)
.describe("Should be true to preserve the values at the corner pixels");
TVM_ATTR_FIELD(out_dtype)
.set_default(NullValue<DataType>())
.describe("Output data type.");
}
};
......
......@@ -381,6 +381,7 @@ struct UpSamplingAttrs : public tvm::AttrsNode<UpSamplingAttrs> {
int scale;
std::string layout;
std::string method;
bool align_corners;
TVM_DECLARE_ATTRS(UpSamplingAttrs, "relay.attrs.UpSamplingAttrs") {
TVM_ATTR_FIELD(scale)
......@@ -390,10 +391,13 @@ struct UpSamplingAttrs : public tvm::AttrsNode<UpSamplingAttrs> {
"'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
"dimensions respectively. Upsampling is applied on the 'H' and"
"'W' dimensions.");
TVM_ATTR_FIELD(method).set_default("NEAREST_NEIGHBOR")
TVM_ATTR_FIELD(method).set_default("nearest_neighbor")
.describe("Specify the mode to use for scaling."
"NEAREST_NEIGHBOR - Nearest Neighbor"
"BILINEAR - Bilinear Interpolation");
"nearest_neighbor - Nearest Neighbor"
"bilinear - Bilinear Interpolation"
"bicubic - Bicubic Interpolation");
TVM_ATTR_FIELD(align_corners).set_default(false)
.describe("Should be true to preserve the values at the corner pixels");
}
};
......
......@@ -324,12 +324,12 @@ def test_upsampling_bilinear():
data = tvm.nd.array(a_np)
m.run(x=data)
out = m.get_output(0, tvm.nd.empty(oshape, dtype))
b_np = topi.testing.bilinear_resize_python(a_np, (32*scale, 32*scale), "NCHW")
b_np = topi.testing.bilinear_resize_python(a_np, (32*scale, 32*scale), "NCHW", align_corners=False)
tvm.testing.assert_allclose(out.asnumpy(), b_np, rtol=1e-5, atol=1e-5)
def test_resize_bilinear():
x = sym.Variable("x")
y = sym.resize(x, size=(60, 60), method="BILINEAR", name="y", layout="NHWC")
y = sym.resize(x, size=(60, 60), method="BILINEAR", name="y", layout="NHWC", align_corners=True)
dtype = "float32"
dshape = (1, 32, 32, 4)
oshape = (1, 60, 60, 4)
......
......@@ -425,7 +425,7 @@ def _test_upsample_bilinear():
y = helper.make_node("Upsample", ['in'], ['out'], mode='linear', scales=[1.0, 1.0, 2.0, 2.0])
in_array = np.random.uniform(size=in_shape).astype(np.float32)
out_array = topi.testing.bilinear_resize_python(in_array, (3*scale, 3*scale), "NCHW")
out_array = topi.testing.bilinear_resize_python(in_array, (3*scale, 3*scale), "NCHW", align_corners=False)
graph = helper.make_graph([y],
'upsample_bilinear_test',
......@@ -445,7 +445,7 @@ def _test_upsample_bilinear_opset9():
y = helper.make_node("Upsample", ['in','scales'], ['out'], mode='linear')
scales=[1.0, 1.0, 2.0, 2.0]
in_array = np.random.uniform(size=in_shape).astype(np.float32)
out_array = topi.testing.bilinear_resize_python(in_array, (3*scale, 3*scale), "NCHW")
out_array = topi.testing.bilinear_resize_python(in_array, (3*scale, 3*scale), "NCHW", align_corners=False)
ref_array = np.array(scales)
ref_node = helper.make_node('Constant',
......
......@@ -312,7 +312,7 @@ def _UpsampleLayerParams(op, inexpr, etab):
if op.scalingFactor[0] != op.scalingFactor[1]:
raise tvm.error.OpAttributeUnimplemented(
'Upsample height and width must be equal.')
interpolationMode = 'NEAREST_NEIGHBOR' if op.mode == 0 else 'BILINEAR'
interpolationMode = 'nearest_neighbor' if op.mode == 0 else 'bilinear'
return _op.nn.upsampling(inexpr, scale=op.scalingFactor[0], method=interpolationMode)
......
......@@ -358,29 +358,30 @@ def _convert_pooling(inexpr, keras_layer, etab):
def _convert_upsample(inexpr, keras_layer, _):
_check_data_format(keras_layer)
upsample_type = type(keras_layer).__name__
params = {'layout': 'NHWC'}
if upsample_type == 'UpSampling1D':
h = keras_layer.size
params = {'scale': h}
params['scale'] = h
elif upsample_type == 'UpSampling2D':
h, w = keras_layer.size
if h != w:
raise tvm.error.OpAttributeInvalid(
'Height must equal width for operator Upsample.')
params = {'scale': h}
params['scale'] = h
if hasattr(keras_layer, 'interpolation'):
interpolation = keras_layer.interpolation
if interpolation == 'nearest':
params['method'] = 'NEAREST_NEIGHBOR'
params['method'] = 'nearest_neighbor'
else:
params['method'] = 'BILINEAR'
params['method'] = 'bilinear'
elif upsample_type == 'UpSampling3D':
h, w, d = keras_layer.size
if h != w or w != d:
raise tvm.error.OpAttributeInvalid(
'Height, width, and depth must all be equal for operator Upsample.')
params = {'scale': h}
params['scale'] = h
else:
raise tvm.error.OpNotImplemented(
'Operator {} is not supported for frontend Keras.'.format(upsample_type))
......
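
For reference, here is a rough sketch of the params dict this converter now builds for a hypothetical keras.layers.UpSampling2D(size=(2, 2), interpolation='bilinear') layer (the final call to _op.nn.upsampling falls outside the hunk shown):

# Illustrative only: values collected for UpSampling2D((2, 2), interpolation='bilinear').
params = {'layout': 'NHWC',        # set up front for every Keras upsample layer
          'scale': 2,              # from keras_layer.size; height == width enforced above
          'method': 'bilinear'}    # lowercase name replaces 'BILINEAR'
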
......@@ -559,13 +559,13 @@ class Upsample(OnnxOpConverter):
assert len(scales) == 4 and scales[0] == 1.0 and scales[1] == 1.0 and scales[2] == scales[3]
mode = attr.get('mode')
if mode == b'nearest':
method = "NEAREST_NEIGHBOR"
method = "nearest_neighbor"
elif mode == b'linear':
method = "BILINEAR"
method = "bilinear"
else:
raise tvm.error.OpAttributeInvalid(
'Value {} in attribute "mode" of operator Upsample is not valid.'.format(mode))
attr = {'scale':int(scales[-1]), 'method':method, 'layout':'NCHW'}
attr = {'scale':int(scales[-1]), 'method':method, 'layout':'NCHW', 'align_corners':True}
return AttrCvt('upsampling')(inputs, attr)
......
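
As a worked example of the mapping above (values are illustrative), an ONNX Upsample node with mode=b'linear' and scales=[1.0, 1.0, 2.0, 2.0] now yields:

scales = [1.0, 1.0, 2.0, 2.0]
attr = {'scale': int(scales[-1]),   # 2
        'method': 'bilinear',       # mode=b'linear' maps to the lowercase name
        'layout': 'NCHW',
        'align_corners': True}      # newly forwarded to nn.upsampling
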
......@@ -358,7 +358,7 @@ def _crop_and_resize():
'Attribute method=nearest is not supported')
else:
attrs['align_corners'] = True
attrs['method'] = 'BILINEAR'
attrs['method'] = 'bilinear'
out = None
begin = [0] * data_dim
......@@ -408,7 +408,7 @@ def _resize_bilinear():
return AttrCvt(op_name="resize",
ignores=['Tdim'],
extras={'method': "BILINEAR"})(inputs, attr)
extras={'method': "bilinear"})(inputs, attr)
return _impl
def _resize_nearest_neighbor():
......@@ -423,7 +423,7 @@ def _resize_nearest_neighbor():
return AttrCvt(op_name="resize",
ignores=['Tdim'],
extras={'method': "NEAREST_NEIGHBOR"})(inputs, attr)
extras={'method': "nearest_neighbor"})(inputs, attr)
return _impl
def _check_numerics():
......
......@@ -262,7 +262,7 @@ class OperatorConverter(object):
# Options - align_corners (bool)
resize_options = None
align_corners = False
if method == "BILINEAR":
if method == "bilinear":
assert op.BuiltinOptionsType() == BuiltinOptions.ResizeBilinearOptions
resize_options = ResizeBilinearOptions()
elif tflite_ver >= 1130:
......@@ -280,11 +280,11 @@ class OperatorConverter(object):
def convert_resize_bilinear(self, op):
"""Convert TFLite RESIZE_BILINEAR"""
return self._convert_resize("BILINEAR", op)
return self._convert_resize("bilinear", op)
def convert_resize_nearest_neighbor(self, op):
"""Convert TFLite RESIZE_NEAREST_NEIGHBOR"""
return self._convert_resize("NEAREST_NEIGHBOR", op)
return self._convert_resize("nearest_neighbor", op)
def convert_logistic(self, op):
"""Convert TFLite LOGISTIC"""
......
......@@ -17,7 +17,20 @@
#pylint: disable=invalid-name, unused-argument
"""Backend compiler related feature registration"""
from __future__ import absolute_import
from ..op import register_schedule, schedule_injective
import topi
from .. import op as reg
from ..op import schedule_injective
# resize
register_schedule("image.resize", schedule_injective)
reg.register_schedule("image.resize", schedule_injective)
@reg.register_compute("image.resize")
def compute_resize(attrs, inputs, out_type, target):
size = attrs.size
layout = attrs.layout
method = attrs.method
align_corners = attrs.align_corners
out_dtype = attrs.out_dtype
return [topi.image.resize(inputs[0], size, layout, method, align_corners, out_dtype)]
......@@ -21,8 +21,9 @@ from . import _make
def resize(data,
size,
layout="NCHW",
method="BILINEAR",
align_corners=False):
method="bilinear",
align_corners=True,
out_dtype=None):
"""Image resize operator.
This operator takes data as input and does 2D scaling to the given scale factor.
......@@ -31,7 +32,7 @@ def resize(data,
out will have a shape (n, c, size[0], size[1])
method indicates the algorithm to be used while calculating the out value
and method can be one of ("BILINEAR", "NEAREST_NEIGHBOR")
and method can be one of ("bilinear", "nearest_neighbor", "bicubic")
Parameters
----------
......@@ -45,14 +46,17 @@ def resize(data,
Layout of the input.
method : str, optional
Scale method to used [NEAREST_NEIGHBOR, BILINEAR].
Scale method to be used [nearest_neighbor, bilinear, bicubic].
align_corners : int, optional
Should be true to preserve the values at the corner pixels
out_dtype : str, optional
Type to return. If left None returns the same type as input.
Returns
-------
result: relay.Expr
The resized result.
"""
return _make.resize(data, size, layout, method, align_corners)
return _make.resize(data, size, layout, method, align_corners, out_dtype)
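
A short usage sketch of the updated relay.image.resize Python API, assuming a float32 NCHW tensor (shapes are illustrative):

from tvm import relay

x = relay.var("x", shape=(1, 3, 32, 32), dtype="float32")
# Lowercase method names, align_corners now defaulting to True, and the new out_dtype argument.
y = relay.image.resize(x, size=(64, 64), layout="NCHW",
                       method="bicubic", align_corners=True, out_dtype="float32")
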
......@@ -376,6 +376,13 @@ def schedule_upsampling(_, outs, target):
with target:
return topi.generic.schedule_injective(outs)
@reg.register_compute("nn.upsampling")
def compute_upsampling(attrs, inputs, out_dtype, target):
scale = attrs.scale
layout = attrs.layout
method = attrs.method
align_corners = attrs.align_corners
return [topi.nn.upsampling(inputs[0], scale, layout, method, align_corners)]
# pad
reg.register_schedule("nn.pad", schedule_broadcast)
......
......@@ -481,7 +481,8 @@ def global_avg_pool2d(data,
def upsampling(data,
scale=1,
layout="NCHW",
method="NEAREST_NEIGHBOR"):
method="nearest_neighbor",
align_corners=False):
"""Upsampling.
This operator takes data as input and does 2D scaling to the given scale factor.
......@@ -490,7 +491,7 @@ def upsampling(data,
out will have a shape (n, c, h*scale, w*scale)
method indicates the algorithm to be used while calculating the out value
and method can be one of ("BILINEAR", "NEAREST_NEIGHBOR")
and method can be one of ("bilinear", "nearest_neighbor", "bicubic")
Parameters
----------
......@@ -504,14 +505,17 @@ def upsampling(data,
Layout of the input.
method : str, optional
Scale method to used [NEAREST_NEIGHBOR, BILINEAR].
Scale method to be used [nearest_neighbor, bilinear, bicubic].
align_corners : bool, optional
Whether to keep corners in proper place.
Returns
-------
result : tvm.relay.Expr
The computed result.
"""
return _make.upsampling(data, scale, layout, method)
return _make.upsampling(data, scale, layout, method, align_corners)
def batch_flatten(data):
......
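
Likewise, relay.nn.upsampling now takes align_corners on the Python side; a minimal sketch (shapes illustrative):

from tvm import relay

x = relay.var("x", shape=(1, 16, 32, 32), dtype="float32")
y = relay.nn.upsampling(x, scale=2, layout="NCHW",
                        method="bilinear", align_corners=True)
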
......@@ -25,8 +25,6 @@
#include <tvm/data_layout.h>
#include <tvm/relay/op.h>
#include <tvm/relay/attrs/image.h>
#include <topi/elemwise.h>
#include <topi/image/resize.h>
#include "../op_common.h"
namespace tvm {
......@@ -56,49 +54,32 @@ bool ResizeRel(const Array<Type>& types,
oshape.Set(2, param->size[0]);
oshape.Set(3, param->size[1]);
DataType out_dtype = param->out_dtype;
if (out_dtype.bits() == 0) {
out_dtype = data->dtype;
}
// assign output type
reporter->Assign(types[1],
TensorTypeNode::make(layout_converter.BackwardShape(oshape),
data->dtype));
out_dtype));
return true;
}
Array<Tensor> ResizeCompute(const Attrs& attrs,
const Array<Tensor>& inputs,
const Type& out_type,
const Target& target) {
const auto* param = attrs.as<ResizeAttrs>();
CHECK(param != nullptr);
CHECK(param->layout == "NCHW" || param->layout == "NHWC");
const auto* out_ttype = out_type.as<TensorTypeNode>();
CHECK(out_ttype != nullptr);
Array<IndexExpr> oshape;
if (param->layout == "NCHW") {
oshape.push_back(out_ttype->shape[2]);
oshape.push_back(out_ttype->shape[3]);
} else if (param->layout == "NHWC") {
oshape.push_back(out_ttype->shape[1]);
oshape.push_back(out_ttype->shape[2]);
}
return Array<Tensor>{ topi::image::resize(inputs[0],
oshape,
param->layout,
param->align_corners,
param->method) };
}
// Positional relay function to create image operator
// used by frontend FFI.
Expr MakeResize(Expr data,
Array<IndexExpr> size,
std::string layout,
std::string method,
bool align_corners) {
bool align_corners,
DataType out_dtype) {
auto attrs = make_node<ResizeAttrs>();
attrs->size = std::move(size);
attrs->layout = std::move(layout);
attrs->method = std::move(method);
attrs->align_corners = align_corners;
attrs->out_dtype = out_dtype;
static const Op& op = Op::Get("image.resize");
return CallNode::make(op, {data}, Attrs(attrs), {});
}
......@@ -127,7 +108,6 @@ RELAY_REGISTER_OP("image.resize")
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(5)
.add_type_rel("Resize", ResizeRel)
.set_attr<FTVMCompute>("FTVMCompute", ResizeCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective);
} // namespace relay
......
......@@ -27,8 +27,6 @@
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/build_module.h>
#include <topi/elemwise.h>
#include <topi/nn/upsampling.h>
#include <vector>
#include "../op_common.h"
......@@ -99,11 +97,13 @@ bool UpSamplingRel(const Array<Type>& types,
Expr MakeUpSampling(Expr data,
int scale,
std::string layout,
std::string method) {
std::string method,
bool align_corners) {
auto attrs = make_node<UpSamplingAttrs>();
attrs->layout = std::move(layout);
attrs->method = std::move(method);
attrs->scale = scale;
attrs->align_corners = align_corners;
static const Op& op = Op::Get("nn.upsampling");
return CallNode::make(op, {data}, Attrs(attrs), {});
}
......@@ -135,38 +135,7 @@ RELAY_REGISTER_OP("nn.upsampling")
.add_type_rel("UpSampling", UpSamplingRel)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
UpsamplingInferCorrectLayout<UpSamplingAttrs>)
.set_attr<TOpPattern>("TOpPattern", kInjective)
.set_attr<FTVMCompute>(
"FTVMCompute", [](const Attrs& attrs,
const Array<Tensor>& inputs,
const Type& out_type,
const Target& target) {
const auto* uattrs = attrs.as<UpSamplingAttrs>();
CHECK(uattrs != nullptr);
auto out_tt = out_type.as<TensorTypeNode>();
CHECK(out_tt) << "expected a tensor type: " << out_type;
const auto layout = uattrs->layout;
const auto base_layout = layout.substr(0, 4);
CHECK(base_layout == "NCHW" || layout == "NHWC")
<< "unknown layout: " << uattrs->layout;
Array<IndexExpr> oshape;
if (base_layout == "NCHW") {
oshape.push_back(out_tt->shape[2]);
oshape.push_back(out_tt->shape[3]);
} else if (layout == "NHWC") {
oshape.push_back(out_tt->shape[1]);
oshape.push_back(out_tt->shape[2]);
}
return Array<Tensor>{
topi::nn::upsampling(
inputs[0],
oshape,
uattrs->layout,
uattrs->method)
};
});
.set_attr<TOpPattern>("TOpPattern", kInjective);
} // namespace relay
......
......@@ -172,7 +172,7 @@ def test_forward_upsample(interpolation='nearest'):
data = keras.layers.Input(shape=(32, 32, 3))
x = keras.layers.UpSampling2D(size=(3, 3), interpolation=interpolation)(data)
keras_model = keras.models.Model(data, x)
verify_keras_frontend(keras_model)
verify_keras_frontend(keras_model, need_transpose=False)
def test_forward_reshape():
......
......@@ -1212,7 +1212,7 @@ def test_forward_crop_and_resize():
_test_forward_crop_and_resize([1, 11, 11, 3], [[0, 0, .9, .9]], [0], [5, 5])
_test_forward_crop_and_resize([1, 11, 11, 3], [[.1, .2, 1, 1]], [0], [5, 5])
_test_forward_crop_and_resize([1, 21, 21, 3], [[.2, .3, .7, .9]], [0], [3, 4])
_test_forward_crop_and_resize([1, 106, 106, 3], [[0.2, 0.4, 0.8, 0.8]], [0], [3, 3])
_test_forward_crop_and_resize([1, 41, 41, 3], [[0.2, 0.4, 0.8, 0.8]], [0], [3, 3])
_test_forward_crop_and_resize([10, 11, 11, 3],
[[0, 0, 0.9, 0.9], [0.2, 0.2, 0.8, 0.8]],
[0, 1],
......
......@@ -233,13 +233,13 @@ def test_conv2d_transpose_run():
def test_upsampling_infer_type():
n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
x = relay.var("x", relay.TensorType((n, c, h, w), "float32"))
y = relay.nn.upsampling(x, scale=2, layout="NCHW", method="BILINEAR")
y = relay.nn.upsampling(x, scale=2, layout="NCHW", method="bilinear")
"method=\"BINLINEAR\"" in y.astext()
yy = run_infer_type(y)
assert yy.checked_type == relay.TensorType((n, c, h*2, w*2), "float32")
n, c = tvm.var("n"), tvm.var("c")
x = relay.var("x", relay.TensorType((n, c, 100, 200), "float32"))
y = relay.nn.upsampling(x, scale=2, layout="NCHW", method="BILINEAR")
y = relay.nn.upsampling(x, scale=2, layout="NCHW", method="bilinear")
yy = run_infer_type(y)
assert yy.checked_type == relay.TensorType((n, c, 200, 400), "float32")
......@@ -502,7 +502,7 @@ def test_batch_flatten():
np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=0.01)
def _test_upsampling(layout, method):
def _test_upsampling(layout, method, align_corners=False):
n, c, h, w = tvm.var("n"), 16, 32, 32
scale = 2
dtype = "float32"
......@@ -513,15 +513,17 @@ def _test_upsampling(layout, method):
return (h, w, c), (h*scale, w*scale, c)
ishape, oshape = get_shape()
x = relay.var("x", relay.TensorType((n,) + ishape, dtype))
y = relay.nn.upsampling(x, scale=scale, layout=layout, method=method)
y = relay.nn.upsampling(x, scale=scale, layout=layout,
method=method, align_corners=align_corners)
yy = run_infer_type(y)
assert yy.checked_type == relay.TensorType((n,) + oshape, dtype)
dshape = (1,) + ishape
x = relay.var("x", shape=dshape)
y = relay.nn.upsampling(x, scale=scale, layout=layout, method=method)
y = relay.nn.upsampling(x, scale=scale, layout=layout,
method=method, align_corners=align_corners)
func = relay.Function([x], y)
data = np.random.uniform(size=dshape).astype(dtype)
if method == "NEAREST_NEIGHBOR":
if method == "nearest_neighbor":
ref = topi.testing.upsampling_python(data, (scale, scale), layout)
else:
ref = topi.testing.bilinear_resize_python(data, (h*scale, w*scale), layout)
......@@ -532,10 +534,10 @@ def _test_upsampling(layout, method):
def test_upsampling():
_test_upsampling("NCHW", "NEAREST_NEIGHBOR")
_test_upsampling("NCHW", "BILINEAR")
_test_upsampling("NHWC", "NEAREST_NEIGHBOR")
_test_upsampling("NHWC", "BILINEAR")
_test_upsampling("NCHW", "nearest_neighbor")
_test_upsampling("NCHW", "bilinear", True)
_test_upsampling("NHWC", "nearest_neighbor")
_test_upsampling("NHWC", "bilinear", True)
def test_conv2d_int8_intrinsics():
......
......@@ -39,7 +39,7 @@ def test_resize_infer_type():
assert zz.checked_type == relay.TensorType((n, c, th, tw), "int8")
x = relay.var("x", relay.TensorType((n, c, h, w), "int8"))
z= relay.image.resize(x, (100, 200), "NCHW", "BILINEAR", False)
z= relay.image.resize(x, (100, 200), "NCHW", "bilinear", True)
assert "size=" in z.astext()
zz = run_infer_type(z)
assert zz.checked_type == relay.TensorType((n, c, 100, 200), "int8")
......@@ -52,12 +52,12 @@ def test_resize():
size = (dshape[2] * scale, dshape[3] * scale)
x_data = np.random.uniform(size=dshape).astype("float32")
if method == "BILINEAR":
if method == "bilinear":
ref_res = topi.testing.bilinear_resize_python(x_data, size, layout)
else:
ref_res = topi.testing.upsampling_python(x_data, (scale, scale), layout)
x = relay.var("x", relay.TensorType(dshape, "float32"))
z = relay.image.resize(x, size, layout, method, False)
z = relay.image.resize(x, size, layout, method, True)
assert "size=" in z.astext()
zz = run_infer_type(z)
assert zz.checked_type == relay.TensorType(ref_res.shape, "float32")
......@@ -68,7 +68,7 @@ def test_resize():
intrp = relay.create_executor(kind, ctx=ctx, target=target)
op_res = intrp.evaluate(func)(x_data)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
for method in ["BILINEAR", "NEAREST_NEIGHBOR"]:
for method in ["bilinear", "nearest_neighbor"]:
for layout in ["NHWC", "NCHW"]:
verify_resize((1, 4, 4, 4), 2, method, layout)
......
......@@ -201,7 +201,6 @@ inline Tensor resize_nearest_neighbor(const Tensor& input,
bool align_corners = false,
std::string name = "tensor",
std::string tag = kInjective) {
CHECK_EQ(align_corners, false) << "Align corners not supported for nearest neighbour";
auto base_layout = layout.substr(0, 4);
if (layout == "NHWC") {
return resize_nearest_neighbor_nhwc(input, shape, align_corners);
......
......@@ -14,11 +14,14 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
"""TVM operator input resize compute."""
from __future__ import absolute_import
import topi
import tvm
from .. import tag
def resize(data, size, layout="NCHW", align_corners=False, method="BILINEAR"):
def resize(data, size, layout="NCHW", method="bilinear", align_corners=True, out_dtype=None):
"""Perform resize operation on the data.
Parameters
......@@ -32,18 +35,178 @@ def resize(data, size, layout="NCHW", align_corners=False, method="BILINEAR"):
Output resolution to scale to
layout: string, optional
either "NCHW" or "NHWC"
"NCHW", "NHWC", or "NCHWc".
align_corners: Boolean, optional
To preserve the values at the corner pixels
To preserve the values at the corner pixels.
method: {"BILINEAR", "NEAREST_NEIGHBOR"}
method: {"bilinear", "nearest_neighbor", "bicubic"}
Method to be used for resizing.
out_dtype: string, optional
Type to return. If left None, will be same as input type.
Returns
-------
output : tvm.Tensor
4-D with shape [batch, channel, in_height*scale, in_width*scale]
or [batch, in_height*scale, in_width*scale, channel]
or 5-D with shape [batch, channel-major, in_height*scale, in_width*scale, channel-minor]
"""
return topi.cpp.image.resize(data, size, layout, align_corners, method)
method = method.lower()
if layout == 'NHWC':
in_n, in_h, in_w, in_c = data.shape
output_shape = [in_n, size[0], size[1], in_c]
elif layout == 'NCHW':
in_n, in_c, in_h, in_w = data.shape
output_shape = [in_n, in_c, size[0], size[1]]
# Otherwise layout must be NCHWxc
else:
in_n, in_c, in_h, in_w, in_cc = data.shape
output_shape = [in_n, in_c, size[0], size[1], in_cc]
if align_corners:
y_ratio = (in_h - 1).astype('float') / (size[0] - 1)
x_ratio = (in_w - 1).astype('float') / (size[1] - 1)
else:
y_ratio = (in_h).astype('float') / (size[0])
x_ratio = (in_w).astype('float') / (size[1])
def _get_pixel(n, c, y, x, cc):
y = tvm.max(tvm.min(y, in_h - 1), 0)
x = tvm.max(tvm.min(x, in_w - 1), 0)
if layout == 'NHWC':
return data(n, y, x, c).astype('float')
if layout == 'NCHW':
return data(n, c, y, x).astype('float')
# else must be NCHWxc
return data(n, c, y, x, cc).astype('float')
def _get_indices(*indices):
if layout == 'NHWC':
n, y, x, c = indices
cc = None
elif layout == 'NCHW':
n, c, y, x = indices
cc = None
else:
n, c, y, x, cc = indices
return n, c, y, x, cc
def _cast_output(value):
if out_dtype:
dtype = out_dtype
else:
dtype = data.dtype
return value.astype(dtype)
# Nearest neighbor computation
def _nearest_neighbor(*indices):
n, c, y, x, cc = _get_indices(*indices)
in_y = y_ratio * y
in_x = x_ratio * x
if align_corners:
yint = tvm.round(in_y).astype('int32')
xint = tvm.round(in_x).astype('int32')
else:
# Add epsilon to floor to prevent gpu rounding errors.
epsilon = 1e-5
yint = tvm.floor(in_y + epsilon).astype('int32')
xint = tvm.floor(in_x + epsilon).astype('int32')
return _cast_output(_get_pixel(n, c, yint, xint, cc))
# Bilinear helper functions and computation.
def _lerp(A, B, t):
return A * (1.0 - t) + B * t
def _bilinear(*indices):
n, c, y, x, cc = _get_indices(*indices)
in_y = y_ratio * y
in_x = x_ratio * x
xint = tvm.floor(in_x).astype('int32')
xfract = in_x - tvm.floor(in_x)
yint = tvm.floor(in_y).astype('int32')
yfract = in_y - tvm.floor(in_y)
p00 = _get_pixel(n, c, yint, xint, cc)
p10 = _get_pixel(n, c, yint, xint + 1, cc)
p01 = _get_pixel(n, c, yint + 1, xint, cc)
p11 = _get_pixel(n, c, yint + 1, xint + 1, cc)
col0 = _lerp(p00, p10, xfract)
col1 = _lerp(p01, p11, xfract)
value = _lerp(col0, col1, yfract)
return _cast_output(value)
# Bicubic helper function and computation.
def _cubic_kernel(A, B, C, D, t):
a = -A / 2.0 + (3.0*B) / 2.0 - (3.0*C) / 2.0 + D / 2.0
b = A - (5.0*B) / 2.0 + 2.0*C - D / 2.0
c = -A / 2.0 + C / 2.0
d = B
return a*t*t*t + b*t*t + c*t + d
def _bicubic(*indices):
n, c, y, x, cc = _get_indices(*indices)
in_y = y_ratio * y
in_x = x_ratio * x
xint = tvm.floor(in_x).astype('int32')
xfract = in_x - tvm.floor(in_x)
yint = tvm.floor(in_y).astype('int32')
yfract = in_y - tvm.floor(in_y)
# 1st row
p00 = _get_pixel(n, c, yint - 1, xint - 1, cc)
p10 = _get_pixel(n, c, yint - 1, xint + 0, cc)
p20 = _get_pixel(n, c, yint - 1, xint + 1, cc)
p30 = _get_pixel(n, c, yint - 1, xint + 2, cc)
# 2nd row
p01 = _get_pixel(n, c, yint + 0, xint - 1, cc)
p11 = _get_pixel(n, c, yint + 0, xint + 0, cc)
p21 = _get_pixel(n, c, yint + 0, xint + 1, cc)
p31 = _get_pixel(n, c, yint + 0, xint + 2, cc)
# 3rd row
p02 = _get_pixel(n, c, yint + 1, xint - 1, cc)
p12 = _get_pixel(n, c, yint + 1, xint + 0, cc)
p22 = _get_pixel(n, c, yint + 1, xint + 1, cc)
p32 = _get_pixel(n, c, yint + 1, xint + 2, cc)
# 4th row
p03 = _get_pixel(n, c, yint + 2, xint - 1, cc)
p13 = _get_pixel(n, c, yint + 2, xint + 0, cc)
p23 = _get_pixel(n, c, yint + 2, xint + 1, cc)
p33 = _get_pixel(n, c, yint + 2, xint + 2, cc)
# Interpolate bicubically
col0 = _cubic_kernel(p00, p10, p20, p30, xfract)
col1 = _cubic_kernel(p01, p11, p21, p31, xfract)
col2 = _cubic_kernel(p02, p12, p22, p32, xfract)
col3 = _cubic_kernel(p03, p13, p23, p33, xfract)
value = _cubic_kernel(col0, col1, col2, col3, yfract)
return _cast_output(value)
# Determine which interpolation method to use then run it.
if method == "nearest_neighbor":
compute_func = _nearest_neighbor
elif method == "bilinear":
compute_func = _bilinear
elif method == "bicubic":
compute_func = _bicubic
else:
raise ValueError('%s method is not supported.' % method)
return tvm.compute(output_shape, compute_func, name='resize', tag=tag.INJECTIVE)
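
To make the interpolation kernels above easy to sanity-check outside of TVM, here is a scalar Python transcription of _lerp and _cubic_kernel (a reference sketch for illustration, not part of the commit):

def lerp(a, b, t):
    # Linear interpolation between two samples; bilinear applies it along x, then along y.
    return a * (1.0 - t) + b * t

def cubic_kernel(p0, p1, p2, p3, t):
    # Catmull-Rom-style cubic through four samples; bicubic applies it to each row, then across rows.
    ca = -p0 / 2.0 + (3.0 * p1) / 2.0 - (3.0 * p2) / 2.0 + p3 / 2.0
    cb = p0 - (5.0 * p1) / 2.0 + 2.0 * p2 - p3 / 2.0
    cc = -p0 / 2.0 + p2 / 2.0
    cd = p1
    return ca * t**3 + cb * t**2 + cc * t + cd

# At t == 0 both kernels reproduce the nearer sample exactly.
assert lerp(1.0, 5.0, 0.0) == 1.0
assert cubic_kernel(0.0, 1.0, 2.0, 3.0, 0.0) == 1.0
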
......@@ -20,7 +20,7 @@ import topi
from ..util import simplify
def upsampling(data, scale, layout="NCHW", method='NEAREST_NEIGHBOR'):
def upsampling(data, scale, layout="NCHW", method='nearest_neighbor', align_corners=False):
"""Perform upsampling on the data.
Nearest neighbor and bilinear upsampling are supported.
......@@ -37,7 +37,7 @@ def upsampling(data, scale, layout="NCHW", method='NEAREST_NEIGHBOR'):
layout : string, optional
either "NCHW" or "NHWC"
method : {"BILINEAR", "NEAREST_NEIGHBOR"}
method : {"bilinear", "nearest_neighbor", "bicubic"}
Method to be used for upsampling.
Returns
......@@ -53,4 +53,5 @@ def upsampling(data, scale, layout="NCHW", method='NEAREST_NEIGHBOR'):
out_shape = (simplify(data.shape[1] * scale), simplify(data.shape[2] * scale))
else:
raise ValueError("not support this layout {} yet".format(layout))
return topi.cpp.nn.upsampling(data, out_shape, layout, method)
return topi.image.resize(data, out_shape, layout=layout,
method=method, align_corners=align_corners)
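
With this change, topi.nn.upsampling is a thin wrapper over the Python resize above; a sketch of the equivalence, with placeholder shapes chosen for illustration:

import tvm
import topi

data = tvm.placeholder((1, 16, 32, 32), name="data", dtype="float32")
scale = 2
# Both calls now lower to the same injective resize compute.
up = topi.nn.upsampling(data, scale, layout="NCHW",
                        method="nearest_neighbor", align_corners=False)
rs = topi.image.resize(data, (32 * scale, 32 * scale), layout="NCHW",
                       method="nearest_neighbor", align_corners=False)
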
......@@ -19,7 +19,7 @@
import math
import numpy as np
def bilinear_resize_python(image, out_size, layout, align_corners=False):
def bilinear_resize_python(image, out_size, layout, align_corners=True):
""" Bilinear scaling using python"""
(new_h, new_w) = out_size
......
......@@ -23,7 +23,7 @@ import math
from common import get_all_backend
def verify_resize(batch, in_channel, in_height, in_width, out_height, out_width, layout='NCHW', align_corners=False, method="BILINEAR"):
def verify_resize(batch, in_channel, in_height, in_width, out_height, out_width, layout='NCHW', align_corners=True, method="bilinear"):
if layout == 'NCHW':
A = tvm.placeholder((batch, in_channel, in_height, in_width), name='A', dtype='float32')
dtype = A.dtype
......@@ -40,7 +40,7 @@ def verify_resize(batch, in_channel, in_height, in_width, out_height, out_width,
B = topi.image.resize(A, (out_height, out_width), layout=layout, align_corners=align_corners, method=method)
if method == "BILINEAR":
if method == "bilinear":
b_np = topi.testing.bilinear_resize_python(a_np, (out_height, out_width), layout, align_corners)
else:
scale_h = out_height / in_height
......@@ -76,8 +76,8 @@ def test_resize():
# Scale NHWC + Align Corners
verify_resize(6, 32, 64, 64, 20, 20, "NHWC", True)
# Nearest + Fractional
verify_resize(4, 16, 32, 32, 50, 50, 'NCHW', method="NEAREST_NEIGHBOR")
verify_resize(4, 16, 32, 32, 50, 50, 'NHWC', method="NEAREST_NEIGHBOR")
verify_resize(4, 16, 32, 32, 50, 50, 'NCHW', method="nearest_neighbor", align_corners=False)
verify_resize(4, 16, 32, 32, 50, 50, 'NHWC', method="nearest_neighbor", align_corners=False)
if __name__ == "__main__":
test_resize()
......@@ -23,7 +23,7 @@ import math
from common import get_all_backend
def verify_upsampling(batch, in_channel, in_height, in_width, scale, layout='NCHW', method="NEAREST_NEIGHBOR"):
def verify_upsampling(batch, in_channel, in_height, in_width, scale, layout='NCHW', method="nearest_neighbor"):
if layout == 'NCHW':
......@@ -40,11 +40,11 @@ def verify_upsampling(batch, in_channel, in_height, in_width, scale, layout='NCH
raise NotImplementedError(
'Layout not supported {} '.format(layout))
B = topi.nn.upsampling(A, scale, layout=layout, method=method)
B = topi.nn.upsampling(A, scale, layout=layout, method=method, align_corners=False)
if method == "BILINEAR":
if method == "bilinear":
out_size = (in_height*scale, in_width*scale)
b_np = topi.testing.bilinear_resize_python(a_np, out_size, layout)
b_np = topi.testing.bilinear_resize_python(a_np, out_size, layout, align_corners=False)
else:
b_np = topi.testing.upsampling_python(a_np, (scale, scale), layout)
......@@ -67,21 +67,21 @@ def verify_upsampling(batch, in_channel, in_height, in_width, scale, layout='NCH
check_device(device)
def test_upsampling():
# NEAREST_NEIGHBOR - NCHW
# nearest_neighbor - NCHW
verify_upsampling(8, 16, 32, 32, 2)
verify_upsampling(2, 32, 64, 64, 3)
# NEAREST_NEIGHBOR - NHWC
# nearest_neighbor - NHWC
verify_upsampling(8, 16, 32, 32, 2, layout="NHWC")
verify_upsampling(2, 32, 64, 64, 3, layout="NHWC")
# BILINEAR - NCHW
verify_upsampling(2, 2, 32, 32, 2, method="BILINEAR")
verify_upsampling(2, 2, 32, 32, 3, method="BILINEAR")
# bilinear - NCHW
verify_upsampling(2, 2, 32, 32, 2, method="bilinear")
verify_upsampling(2, 2, 32, 32, 3, method="bilinear")
# BILINEAR - NHWC
verify_upsampling(2, 2, 32, 32, 2, layout="NHWC", method="BILINEAR")
verify_upsampling(2, 2, 32, 32, 3, layout="NHWC", method="BILINEAR")
# bilinear - NHWC
verify_upsampling(2, 2, 32, 32, 2, layout="NHWC", method="bilinear")
verify_upsampling(2, 2, 32, 32, 3, layout="NHWC", method="bilinear")
if __name__ == "__main__":
test_upsampling()