Commit 7264cb6a by Josh Fromm Committed by masahi

Changed topi cc resize to python implementation with new features. (#3788)

parent c870261f
...@@ -37,6 +37,7 @@ struct ResizeAttrs : public tvm::AttrsNode<ResizeAttrs> { ...@@ -37,6 +37,7 @@ struct ResizeAttrs : public tvm::AttrsNode<ResizeAttrs> {
std::string layout; std::string layout;
std::string method; std::string method;
bool align_corners; bool align_corners;
DataType out_dtype;
TVM_DECLARE_ATTRS(ResizeAttrs, "relay.attrs.ResizeAttrs") { TVM_DECLARE_ATTRS(ResizeAttrs, "relay.attrs.ResizeAttrs") {
TVM_ATTR_FIELD(size).set_default(NullValue<Array<IndexExpr> >()) TVM_ATTR_FIELD(size).set_default(NullValue<Array<IndexExpr> >())
...@@ -46,12 +47,16 @@ struct ResizeAttrs : public tvm::AttrsNode<ResizeAttrs> { ...@@ -46,12 +47,16 @@ struct ResizeAttrs : public tvm::AttrsNode<ResizeAttrs> {
"'N', 'C', 'H', 'W' stands for batch, channel, height, and width" "'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
"dimensions respectively. Resize is applied on the 'H' and" "dimensions respectively. Resize is applied on the 'H' and"
"'W' dimensions."); "'W' dimensions.");
TVM_ATTR_FIELD(method).set_default("BILINEAR") TVM_ATTR_FIELD(method).set_default("bilinear")
.describe("Specify the mode to use for scaling." .describe("Specify the mode to use for scaling."
"NEAREST_NEIGHBOR - Nearest Neighbor" "nearest_neighbor - Nearest Neighbor"
"BILINEAR - Bilinear Interpolation"); "bilinear - Bilinear Interpolation"
TVM_ATTR_FIELD(align_corners).set_default(false) "bicubic - Bicubic Interpolation");
TVM_ATTR_FIELD(align_corners).set_default(true)
.describe("Should be true to preserve the values at the corner pixels"); .describe("Should be true to preserve the values at the corner pixels");
TVM_ATTR_FIELD(out_dtype)
.set_default(NullValue<DataType>())
.describe("Output data type.");
} }
}; };
......
...@@ -381,6 +381,7 @@ struct UpSamplingAttrs : public tvm::AttrsNode<UpSamplingAttrs> { ...@@ -381,6 +381,7 @@ struct UpSamplingAttrs : public tvm::AttrsNode<UpSamplingAttrs> {
int scale; int scale;
std::string layout; std::string layout;
std::string method; std::string method;
bool align_corners;
TVM_DECLARE_ATTRS(UpSamplingAttrs, "relay.attrs.UpSamplingAttrs") { TVM_DECLARE_ATTRS(UpSamplingAttrs, "relay.attrs.UpSamplingAttrs") {
TVM_ATTR_FIELD(scale) TVM_ATTR_FIELD(scale)
...@@ -390,10 +391,13 @@ struct UpSamplingAttrs : public tvm::AttrsNode<UpSamplingAttrs> { ...@@ -390,10 +391,13 @@ struct UpSamplingAttrs : public tvm::AttrsNode<UpSamplingAttrs> {
"'N', 'C', 'H', 'W' stands for batch, channel, height, and width" "'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
"dimensions respectively. Upsampling is applied on the 'H' and" "dimensions respectively. Upsampling is applied on the 'H' and"
"'W' dimensions."); "'W' dimensions.");
TVM_ATTR_FIELD(method).set_default("NEAREST_NEIGHBOR") TVM_ATTR_FIELD(method).set_default("nearest_neighbor")
.describe("Specify the mode to use for scaling." .describe("Specify the mode to use for scaling."
"NEAREST_NEIGHBOR - Nearest Neighbor" "nearest_neighbor - Nearest Neighbor"
"BILINEAR - Bilinear Interpolation"); "bilinear - Bilinear Interpolation"
"bicubic - Bicubic Interpolation");
TVM_ATTR_FIELD(align_corners).set_default(false)
.describe("Should be true to preserve the values at the corner pixels");
} }
}; };
......
...@@ -324,12 +324,12 @@ def test_upsampling_bilinear(): ...@@ -324,12 +324,12 @@ def test_upsampling_bilinear():
data = tvm.nd.array(a_np) data = tvm.nd.array(a_np)
m.run(x=data) m.run(x=data)
out = m.get_output(0, tvm.nd.empty(oshape, dtype)) out = m.get_output(0, tvm.nd.empty(oshape, dtype))
b_np = topi.testing.bilinear_resize_python(a_np, (32*scale, 32*scale), "NCHW") b_np = topi.testing.bilinear_resize_python(a_np, (32*scale, 32*scale), "NCHW", align_corners=False)
tvm.testing.assert_allclose(out.asnumpy(), b_np, rtol=1e-5, atol=1e-5) tvm.testing.assert_allclose(out.asnumpy(), b_np, rtol=1e-5, atol=1e-5)
def test_resize_bilinear(): def test_resize_bilinear():
x = sym.Variable("x") x = sym.Variable("x")
y = sym.resize(x, size=(60, 60), method="BILINEAR", name="y", layout="NHWC") y = sym.resize(x, size=(60, 60), method="BILINEAR", name="y", layout="NHWC", align_corners=True)
dtype = "float32" dtype = "float32"
dshape = (1, 32, 32, 4) dshape = (1, 32, 32, 4)
oshape = (1, 60, 60, 4) oshape = (1, 60, 60, 4)
......
...@@ -425,7 +425,7 @@ def _test_upsample_bilinear(): ...@@ -425,7 +425,7 @@ def _test_upsample_bilinear():
y = helper.make_node("Upsample", ['in'], ['out'], mode='linear', scales=[1.0, 1.0, 2.0, 2.0]) y = helper.make_node("Upsample", ['in'], ['out'], mode='linear', scales=[1.0, 1.0, 2.0, 2.0])
in_array = np.random.uniform(size=in_shape).astype(np.float32) in_array = np.random.uniform(size=in_shape).astype(np.float32)
out_array = topi.testing.bilinear_resize_python(in_array, (3*scale, 3*scale), "NCHW") out_array = topi.testing.bilinear_resize_python(in_array, (3*scale, 3*scale), "NCHW", align_corners=False)
graph = helper.make_graph([y], graph = helper.make_graph([y],
'upsample_bilinear_test', 'upsample_bilinear_test',
...@@ -445,7 +445,7 @@ def _test_upsample_bilinear_opset9(): ...@@ -445,7 +445,7 @@ def _test_upsample_bilinear_opset9():
y = helper.make_node("Upsample", ['in','scales'], ['out'], mode='linear') y = helper.make_node("Upsample", ['in','scales'], ['out'], mode='linear')
scales=[1.0, 1.0, 2.0, 2.0] scales=[1.0, 1.0, 2.0, 2.0]
in_array = np.random.uniform(size=in_shape).astype(np.float32) in_array = np.random.uniform(size=in_shape).astype(np.float32)
out_array = topi.testing.bilinear_resize_python(in_array, (3*scale, 3*scale), "NCHW") out_array = topi.testing.bilinear_resize_python(in_array, (3*scale, 3*scale), "NCHW", align_corners=False)
ref_array = np.array(scales) ref_array = np.array(scales)
ref_node = helper.make_node('Constant', ref_node = helper.make_node('Constant',
......
...@@ -312,7 +312,7 @@ def _UpsampleLayerParams(op, inexpr, etab): ...@@ -312,7 +312,7 @@ def _UpsampleLayerParams(op, inexpr, etab):
if op.scalingFactor[0] != op.scalingFactor[1]: if op.scalingFactor[0] != op.scalingFactor[1]:
raise tvm.error.OpAttributeUnimplemented( raise tvm.error.OpAttributeUnimplemented(
'Upsample height and width must be equal.') 'Upsample height and width must be equal.')
interpolationMode = 'NEAREST_NEIGHBOR' if op.mode == 0 else 'BILINEAR' interpolationMode = 'nearest_neighbor' if op.mode == 0 else 'bilinear'
return _op.nn.upsampling(inexpr, scale=op.scalingFactor[0], method=interpolationMode) return _op.nn.upsampling(inexpr, scale=op.scalingFactor[0], method=interpolationMode)
......
...@@ -358,29 +358,30 @@ def _convert_pooling(inexpr, keras_layer, etab): ...@@ -358,29 +358,30 @@ def _convert_pooling(inexpr, keras_layer, etab):
def _convert_upsample(inexpr, keras_layer, _): def _convert_upsample(inexpr, keras_layer, _):
_check_data_format(keras_layer) _check_data_format(keras_layer)
upsample_type = type(keras_layer).__name__ upsample_type = type(keras_layer).__name__
params = {'layout': 'NHWC'}
if upsample_type == 'UpSampling1D': if upsample_type == 'UpSampling1D':
h = keras_layer.size h = keras_layer.size
params = {'scale': h} params['scale'] = h
elif upsample_type == 'UpSampling2D': elif upsample_type == 'UpSampling2D':
h, w = keras_layer.size h, w = keras_layer.size
if h != w: if h != w:
raise tvm.error.OpAttributeInvalid( raise tvm.error.OpAttributeInvalid(
'Height must equal width for operator Upsample.') 'Height must equal width for operator Upsample.')
params = {'scale': h} params['scale'] = h
if hasattr(keras_layer, 'interpolation'): if hasattr(keras_layer, 'interpolation'):
interpolation = keras_layer.interpolation interpolation = keras_layer.interpolation
if interpolation == 'nearest': if interpolation == 'nearest':
params['method'] = 'NEAREST_NEIGHBOR' params['method'] = 'nearest_neighbor'
else: else:
params['method'] = 'BILINEAR' params['method'] = 'bilinear'
elif upsample_type == 'UpSampling3D': elif upsample_type == 'UpSampling3D':
h, w, d = keras_layer.size h, w, d = keras_layer.size
if h != w or w != d: if h != w or w != d:
raise tvm.error.OpAttributeInvalid( raise tvm.error.OpAttributeInvalid(
'Height, width, and depth must all be equal for operator Upsample.') 'Height, width, and depth must all be equal for operator Upsample.')
params = {'scale': h} params['scale'] = h
else: else:
raise tvm.error.OpNotImplemented( raise tvm.error.OpNotImplemented(
'Operator {} is not supported for frontend Keras.'.format(upsample_type)) 'Operator {} is not supported for frontend Keras.'.format(upsample_type))
......
...@@ -559,13 +559,13 @@ class Upsample(OnnxOpConverter): ...@@ -559,13 +559,13 @@ class Upsample(OnnxOpConverter):
assert len(scales) == 4 and scales[0] == 1.0 and scales[1] == 1.0 and scales[2] == scales[3] assert len(scales) == 4 and scales[0] == 1.0 and scales[1] == 1.0 and scales[2] == scales[3]
mode = attr.get('mode') mode = attr.get('mode')
if mode == b'nearest': if mode == b'nearest':
method = "NEAREST_NEIGHBOR" method = "nearest_neighbor"
elif mode == b'linear': elif mode == b'linear':
method = "BILINEAR" method = "bilinear"
else: else:
raise tvm.error.OpAttributeInvalid( raise tvm.error.OpAttributeInvalid(
'Value {} in attribute "mode" of operator Upsample is not valid.'.format(mode)) 'Value {} in attribute "mode" of operator Upsample is not valid.'.format(mode))
attr = {'scale':int(scales[-1]), 'method':method, 'layout':'NCHW'} attr = {'scale':int(scales[-1]), 'method':method, 'layout':'NCHW', 'align_corners':True}
return AttrCvt('upsampling')(inputs, attr) return AttrCvt('upsampling')(inputs, attr)
......
...@@ -358,7 +358,7 @@ def _crop_and_resize(): ...@@ -358,7 +358,7 @@ def _crop_and_resize():
'Attribute method=nearest is not supported') 'Attribute method=nearest is not supported')
else: else:
attrs['align_corners'] = True attrs['align_corners'] = True
attrs['method'] = 'BILINEAR' attrs['method'] = 'bilinear'
out = None out = None
begin = [0] * data_dim begin = [0] * data_dim
...@@ -408,7 +408,7 @@ def _resize_bilinear(): ...@@ -408,7 +408,7 @@ def _resize_bilinear():
return AttrCvt(op_name="resize", return AttrCvt(op_name="resize",
ignores=['Tdim'], ignores=['Tdim'],
extras={'method': "BILINEAR"})(inputs, attr) extras={'method': "bilinear"})(inputs, attr)
return _impl return _impl
def _resize_nearest_neighbor(): def _resize_nearest_neighbor():
...@@ -423,7 +423,7 @@ def _resize_nearest_neighbor(): ...@@ -423,7 +423,7 @@ def _resize_nearest_neighbor():
return AttrCvt(op_name="resize", return AttrCvt(op_name="resize",
ignores=['Tdim'], ignores=['Tdim'],
extras={'method': "NEAREST_NEIGHBOR"})(inputs, attr) extras={'method': "nearest_neighbor"})(inputs, attr)
return _impl return _impl
def _check_numerics(): def _check_numerics():
......
...@@ -262,7 +262,7 @@ class OperatorConverter(object): ...@@ -262,7 +262,7 @@ class OperatorConverter(object):
# Options - align_corners (bool) # Options - align_corners (bool)
resize_options = None resize_options = None
align_corners = False align_corners = False
if method == "BILINEAR": if method == "bilinear":
assert op.BuiltinOptionsType() == BuiltinOptions.ResizeBilinearOptions assert op.BuiltinOptionsType() == BuiltinOptions.ResizeBilinearOptions
resize_options = ResizeBilinearOptions() resize_options = ResizeBilinearOptions()
elif tflite_ver >= 1130: elif tflite_ver >= 1130:
...@@ -280,11 +280,11 @@ class OperatorConverter(object): ...@@ -280,11 +280,11 @@ class OperatorConverter(object):
def convert_resize_bilinear(self, op): def convert_resize_bilinear(self, op):
"""Convert TFLite RESIZE_BILINEAR""" """Convert TFLite RESIZE_BILINEAR"""
return self._convert_resize("BILINEAR", op) return self._convert_resize("bilinear", op)
def convert_resize_nearest_neighbor(self, op): def convert_resize_nearest_neighbor(self, op):
"""Convert TFLite RESIZE_NEAREST_NEIGHBOR""" """Convert TFLite RESIZE_NEAREST_NEIGHBOR"""
return self._convert_resize("NEAREST_NEIGHBOR", op) return self._convert_resize("nearest_neighbor", op)
def convert_logistic(self, op): def convert_logistic(self, op):
"""Convert TFLite LOGISTIC""" """Convert TFLite LOGISTIC"""
......
...@@ -17,7 +17,20 @@ ...@@ -17,7 +17,20 @@
#pylint: disable=invalid-name, unused-argument #pylint: disable=invalid-name, unused-argument
"""Backend compiler related feature registration""" """Backend compiler related feature registration"""
from __future__ import absolute_import from __future__ import absolute_import
from ..op import register_schedule, schedule_injective
import topi
from .. import op as reg
from ..op import schedule_injective
# resize # resize
register_schedule("image.resize", schedule_injective) reg.register_schedule("image.resize", schedule_injective)
@reg.register_compute("image.resize")
def compute_resize(attrs, inputs, out_type, target):
size = attrs.size
layout = attrs.layout
method = attrs.method
align_corners = attrs.align_corners
out_dtype = attrs.out_dtype
return [topi.image.resize(inputs[0], size, layout, method, align_corners, out_dtype)]
...@@ -21,8 +21,9 @@ from . import _make ...@@ -21,8 +21,9 @@ from . import _make
def resize(data, def resize(data,
size, size,
layout="NCHW", layout="NCHW",
method="BILINEAR", method="bilinear",
align_corners=False): align_corners=True,
out_dtype=None):
"""Image resize operator. """Image resize operator.
This operator takes data as input and does 2D scaling to the given scale factor. This operator takes data as input and does 2D scaling to the given scale factor.
...@@ -31,7 +32,7 @@ def resize(data, ...@@ -31,7 +32,7 @@ def resize(data,
out will have a shape (n, c, size[0], size[1]) out will have a shape (n, c, size[0], size[1])
method indicates the algorithm to be used while calculating ghe out value method indicates the algorithm to be used while calculating ghe out value
and method can be one of ("BILINEAR", "NEAREST_NEIGHBOR") and method can be one of ("bilinear", "nearest_neighbor", "bicubic")
Parameters Parameters
---------- ----------
...@@ -45,14 +46,17 @@ def resize(data, ...@@ -45,14 +46,17 @@ def resize(data,
Layout of the input. Layout of the input.
method : str, optional method : str, optional
Scale method to used [NEAREST_NEIGHBOR, BILINEAR]. Scale method to used [nearest_neighbor, bilinear, bicubic].
align_corners : int, optional align_corners : int, optional
Should be true to preserve the values at the corner pixels Should be true to preserve the values at the corner pixels
out_dtype : str, optional
Type to return. If left None returns the same type as input.
Returns Returns
------- -------
result: relay.Expr result: relay.Expr
The resized result. The resized result.
""" """
return _make.resize(data, size, layout, method, align_corners) return _make.resize(data, size, layout, method, align_corners, out_dtype)
...@@ -376,6 +376,13 @@ def schedule_upsampling(_, outs, target): ...@@ -376,6 +376,13 @@ def schedule_upsampling(_, outs, target):
with target: with target:
return topi.generic.schedule_injective(outs) return topi.generic.schedule_injective(outs)
@reg.register_compute("nn.upsampling")
def compute_upsampling(attrs, inputs, out_dtype, target):
scale = attrs.scale
layout = attrs.layout
method = attrs.method
align_corners = attrs.align_corners
return [topi.nn.upsampling(inputs[0], scale, layout, method, align_corners)]
# pad # pad
reg.register_schedule("nn.pad", schedule_broadcast) reg.register_schedule("nn.pad", schedule_broadcast)
......
...@@ -481,7 +481,8 @@ def global_avg_pool2d(data, ...@@ -481,7 +481,8 @@ def global_avg_pool2d(data,
def upsampling(data, def upsampling(data,
scale=1, scale=1,
layout="NCHW", layout="NCHW",
method="NEAREST_NEIGHBOR"): method="nearest_neighbor",
align_corners=False):
"""Upsampling. """Upsampling.
This operator takes data as input and does 2D scaling to the given scale factor. This operator takes data as input and does 2D scaling to the given scale factor.
...@@ -490,7 +491,7 @@ def upsampling(data, ...@@ -490,7 +491,7 @@ def upsampling(data,
out will have a shape (n, c, h*scale, w*scale) out will have a shape (n, c, h*scale, w*scale)
method indicates the algorithm to be used while calculating the out value method indicates the algorithm to be used while calculating the out value
and method can be one of ("BILINEAR", "NEAREST_NEIGHBOR") and method can be one of ("bilinear", "nearest_neighbor", "bicubic")
Parameters Parameters
---------- ----------
...@@ -504,14 +505,17 @@ def upsampling(data, ...@@ -504,14 +505,17 @@ def upsampling(data,
Layout of the input. Layout of the input.
method : str, optional method : str, optional
Scale method to used [NEAREST_NEIGHBOR, BILINEAR]. Scale method to used [nearest_neighbor, bilinear, bicubic].
align_corners : bool, optional
Whether to keep corners in proper place.
Returns Returns
------- -------
result : tvm.relay.Expr result : tvm.relay.Expr
The computed result. The computed result.
""" """
return _make.upsampling(data, scale, layout, method) return _make.upsampling(data, scale, layout, method, align_corners)
def batch_flatten(data): def batch_flatten(data):
......
...@@ -25,8 +25,6 @@ ...@@ -25,8 +25,6 @@
#include <tvm/data_layout.h> #include <tvm/data_layout.h>
#include <tvm/relay/op.h> #include <tvm/relay/op.h>
#include <tvm/relay/attrs/image.h> #include <tvm/relay/attrs/image.h>
#include <topi/elemwise.h>
#include <topi/image/resize.h>
#include "../op_common.h" #include "../op_common.h"
namespace tvm { namespace tvm {
...@@ -56,49 +54,32 @@ bool ResizeRel(const Array<Type>& types, ...@@ -56,49 +54,32 @@ bool ResizeRel(const Array<Type>& types,
oshape.Set(2, param->size[0]); oshape.Set(2, param->size[0]);
oshape.Set(3, param->size[1]); oshape.Set(3, param->size[1]);
DataType out_dtype = param->out_dtype;
if (out_dtype.bits() == 0) {
out_dtype = data->dtype;
}
// assign output type // assign output type
reporter->Assign(types[1], reporter->Assign(types[1],
TensorTypeNode::make(layout_converter.BackwardShape(oshape), TensorTypeNode::make(layout_converter.BackwardShape(oshape),
data->dtype)); out_dtype));
return true; return true;
} }
Array<Tensor> ResizeCompute(const Attrs& attrs,
const Array<Tensor>& inputs,
const Type& out_type,
const Target& target) {
const auto* param = attrs.as<ResizeAttrs>();
CHECK(param != nullptr);
CHECK(param->layout == "NCHW" || param->layout == "NHWC");
const auto* out_ttype = out_type.as<TensorTypeNode>();
CHECK(out_ttype != nullptr);
Array<IndexExpr> oshape;
if (param->layout == "NCHW") {
oshape.push_back(out_ttype->shape[2]);
oshape.push_back(out_ttype->shape[3]);
} else if (param->layout == "NHWC") {
oshape.push_back(out_ttype->shape[1]);
oshape.push_back(out_ttype->shape[2]);
}
return Array<Tensor>{ topi::image::resize(inputs[0],
oshape,
param->layout,
param->align_corners,
param->method) };
}
// Positional relay function to create image operator // Positional relay function to create image operator
// used by frontend FFI. // used by frontend FFI.
Expr MakeResize(Expr data, Expr MakeResize(Expr data,
Array<IndexExpr> size, Array<IndexExpr> size,
std::string layout, std::string layout,
std::string method, std::string method,
bool align_corners) { bool align_corners,
DataType out_dtype) {
auto attrs = make_node<ResizeAttrs>(); auto attrs = make_node<ResizeAttrs>();
attrs->size = std::move(size); attrs->size = std::move(size);
attrs->layout = std::move(layout); attrs->layout = std::move(layout);
attrs->method = std::move(method); attrs->method = std::move(method);
attrs->align_corners = align_corners; attrs->align_corners = align_corners;
attrs->out_dtype = out_dtype;
static const Op& op = Op::Get("image.resize"); static const Op& op = Op::Get("image.resize");
return CallNode::make(op, {data}, Attrs(attrs), {}); return CallNode::make(op, {data}, Attrs(attrs), {});
} }
...@@ -127,7 +108,6 @@ RELAY_REGISTER_OP("image.resize") ...@@ -127,7 +108,6 @@ RELAY_REGISTER_OP("image.resize")
.add_argument("data", "Tensor", "The input tensor.") .add_argument("data", "Tensor", "The input tensor.")
.set_support_level(5) .set_support_level(5)
.add_type_rel("Resize", ResizeRel) .add_type_rel("Resize", ResizeRel)
.set_attr<FTVMCompute>("FTVMCompute", ResizeCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective); .set_attr<TOpPattern>("TOpPattern", kInjective);
} // namespace relay } // namespace relay
......
...@@ -27,8 +27,6 @@ ...@@ -27,8 +27,6 @@
#include <tvm/relay/attrs/nn.h> #include <tvm/relay/attrs/nn.h>
#include <tvm/relay/op_attr_types.h> #include <tvm/relay/op_attr_types.h>
#include <tvm/build_module.h> #include <tvm/build_module.h>
#include <topi/elemwise.h>
#include <topi/nn/upsampling.h>
#include <vector> #include <vector>
#include "../op_common.h" #include "../op_common.h"
...@@ -99,11 +97,13 @@ bool UpSamplingRel(const Array<Type>& types, ...@@ -99,11 +97,13 @@ bool UpSamplingRel(const Array<Type>& types,
Expr MakeUpSampling(Expr data, Expr MakeUpSampling(Expr data,
int scale, int scale,
std::string layout, std::string layout,
std::string method) { std::string method,
bool align_corners) {
auto attrs = make_node<UpSamplingAttrs>(); auto attrs = make_node<UpSamplingAttrs>();
attrs->layout = std::move(layout); attrs->layout = std::move(layout);
attrs->method = std::move(method); attrs->method = std::move(method);
attrs->scale = scale; attrs->scale = scale;
attrs->align_corners = align_corners;
static const Op& op = Op::Get("nn.upsampling"); static const Op& op = Op::Get("nn.upsampling");
return CallNode::make(op, {data}, Attrs(attrs), {}); return CallNode::make(op, {data}, Attrs(attrs), {});
} }
...@@ -135,38 +135,7 @@ RELAY_REGISTER_OP("nn.upsampling") ...@@ -135,38 +135,7 @@ RELAY_REGISTER_OP("nn.upsampling")
.add_type_rel("UpSampling", UpSamplingRel) .add_type_rel("UpSampling", UpSamplingRel)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", .set_attr<FInferCorrectLayout>("FInferCorrectLayout",
UpsamplingInferCorrectLayout<UpSamplingAttrs>) UpsamplingInferCorrectLayout<UpSamplingAttrs>)
.set_attr<TOpPattern>("TOpPattern", kInjective) .set_attr<TOpPattern>("TOpPattern", kInjective);
.set_attr<FTVMCompute>(
"FTVMCompute", [](const Attrs& attrs,
const Array<Tensor>& inputs,
const Type& out_type,
const Target& target) {
const auto* uattrs = attrs.as<UpSamplingAttrs>();
CHECK(uattrs != nullptr);
auto out_tt = out_type.as<TensorTypeNode>();
CHECK(out_tt) << "expected a tensor type: " << out_type;
const auto layout = uattrs->layout;
const auto base_layout = layout.substr(0, 4);
CHECK(base_layout == "NCHW" || layout == "NHWC")
<< "unknown layout: " << uattrs->layout;
Array<IndexExpr> oshape;
if (base_layout == "NCHW") {
oshape.push_back(out_tt->shape[2]);
oshape.push_back(out_tt->shape[3]);
} else if (layout == "NHWC") {
oshape.push_back(out_tt->shape[1]);
oshape.push_back(out_tt->shape[2]);
}
return Array<Tensor>{
topi::nn::upsampling(
inputs[0],
oshape,
uattrs->layout,
uattrs->method)
};
});
} // namespace relay } // namespace relay
......
...@@ -172,7 +172,7 @@ def test_forward_upsample(interpolation='nearest'): ...@@ -172,7 +172,7 @@ def test_forward_upsample(interpolation='nearest'):
data = keras.layers.Input(shape=(32, 32, 3)) data = keras.layers.Input(shape=(32, 32, 3))
x = keras.layers.UpSampling2D(size=(3, 3), interpolation=interpolation)(data) x = keras.layers.UpSampling2D(size=(3, 3), interpolation=interpolation)(data)
keras_model = keras.models.Model(data, x) keras_model = keras.models.Model(data, x)
verify_keras_frontend(keras_model) verify_keras_frontend(keras_model, need_transpose=False)
def test_forward_reshape(): def test_forward_reshape():
......
...@@ -1212,7 +1212,7 @@ def test_forward_crop_and_resize(): ...@@ -1212,7 +1212,7 @@ def test_forward_crop_and_resize():
_test_forward_crop_and_resize([1, 11, 11, 3], [[0, 0, .9, .9]], [0], [5, 5]) _test_forward_crop_and_resize([1, 11, 11, 3], [[0, 0, .9, .9]], [0], [5, 5])
_test_forward_crop_and_resize([1, 11, 11, 3], [[.1, .2, 1, 1]], [0], [5, 5]) _test_forward_crop_and_resize([1, 11, 11, 3], [[.1, .2, 1, 1]], [0], [5, 5])
_test_forward_crop_and_resize([1, 21, 21, 3], [[.2, .3, .7, .9]], [0], [3, 4]) _test_forward_crop_and_resize([1, 21, 21, 3], [[.2, .3, .7, .9]], [0], [3, 4])
_test_forward_crop_and_resize([1, 106, 106, 3], [[0.2, 0.4, 0.8, 0.8]], [0], [3, 3]) _test_forward_crop_and_resize([1, 41, 41, 3], [[0.2, 0.4, 0.8, 0.8]], [0], [3, 3])
_test_forward_crop_and_resize([10, 11, 11, 3], _test_forward_crop_and_resize([10, 11, 11, 3],
[[0, 0, 0.9, 0.9], [0.2, 0.2, 0.8, 0.8]], [[0, 0, 0.9, 0.9], [0.2, 0.2, 0.8, 0.8]],
[0, 1], [0, 1],
......
...@@ -233,13 +233,13 @@ def test_conv2d_transpose_run(): ...@@ -233,13 +233,13 @@ def test_conv2d_transpose_run():
def test_upsampling_infer_type(): def test_upsampling_infer_type():
n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w") n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
x = relay.var("x", relay.TensorType((n, c, h, w), "float32")) x = relay.var("x", relay.TensorType((n, c, h, w), "float32"))
y = relay.nn.upsampling(x, scale=2, layout="NCHW", method="BILINEAR") y = relay.nn.upsampling(x, scale=2, layout="NCHW", method="bilinear")
"method=\"BINLINEAR\"" in y.astext() "method=\"BINLINEAR\"" in y.astext()
yy = run_infer_type(y) yy = run_infer_type(y)
assert yy.checked_type == relay.TensorType((n, c, h*2, w*2), "float32") assert yy.checked_type == relay.TensorType((n, c, h*2, w*2), "float32")
n, c = tvm.var("n"), tvm.var("c") n, c = tvm.var("n"), tvm.var("c")
x = relay.var("x", relay.TensorType((n, c, 100, 200), "float32")) x = relay.var("x", relay.TensorType((n, c, 100, 200), "float32"))
y = relay.nn.upsampling(x, scale=2, layout="NCHW", method="BILINEAR") y = relay.nn.upsampling(x, scale=2, layout="NCHW", method="bilinear")
yy = run_infer_type(y) yy = run_infer_type(y)
assert yy.checked_type == relay.TensorType((n, c, 200, 400), "float32") assert yy.checked_type == relay.TensorType((n, c, 200, 400), "float32")
...@@ -502,7 +502,7 @@ def test_batch_flatten(): ...@@ -502,7 +502,7 @@ def test_batch_flatten():
np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=0.01) np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=0.01)
def _test_upsampling(layout, method): def _test_upsampling(layout, method, align_corners=False):
n, c, h, w = tvm.var("n"), 16, 32, 32 n, c, h, w = tvm.var("n"), 16, 32, 32
scale = 2 scale = 2
dtype = "float32" dtype = "float32"
...@@ -513,15 +513,17 @@ def _test_upsampling(layout, method): ...@@ -513,15 +513,17 @@ def _test_upsampling(layout, method):
return (h, w, c), (h*scale, w*scale, c) return (h, w, c), (h*scale, w*scale, c)
ishape, oshape = get_shape() ishape, oshape = get_shape()
x = relay.var("x", relay.TensorType((n,) + ishape, dtype)) x = relay.var("x", relay.TensorType((n,) + ishape, dtype))
y = relay.nn.upsampling(x, scale=scale, layout=layout, method=method) y = relay.nn.upsampling(x, scale=scale, layout=layout,
method=method, align_corners=align_corners)
yy = run_infer_type(y) yy = run_infer_type(y)
assert yy.checked_type == relay.TensorType((n,) + oshape, dtype) assert yy.checked_type == relay.TensorType((n,) + oshape, dtype)
dshape = (1,) + ishape dshape = (1,) + ishape
x = relay.var("x", shape=dshape) x = relay.var("x", shape=dshape)
y = relay.nn.upsampling(x, scale=scale, layout=layout, method=method) y = relay.nn.upsampling(x, scale=scale, layout=layout,
method=method, align_corners=align_corners)
func = relay.Function([x], y) func = relay.Function([x], y)
data = np.random.uniform(size=dshape).astype(dtype) data = np.random.uniform(size=dshape).astype(dtype)
if method == "NEAREST_NEIGHBOR": if method == "nearest_neighbor":
ref = topi.testing.upsampling_python(data, (scale, scale), layout) ref = topi.testing.upsampling_python(data, (scale, scale), layout)
else: else:
ref = topi.testing.bilinear_resize_python(data, (h*scale, w*scale), layout) ref = topi.testing.bilinear_resize_python(data, (h*scale, w*scale), layout)
...@@ -532,10 +534,10 @@ def _test_upsampling(layout, method): ...@@ -532,10 +534,10 @@ def _test_upsampling(layout, method):
def test_upsampling(): def test_upsampling():
_test_upsampling("NCHW", "NEAREST_NEIGHBOR") _test_upsampling("NCHW", "nearest_neighbor")
_test_upsampling("NCHW", "BILINEAR") _test_upsampling("NCHW", "bilinear", True)
_test_upsampling("NHWC", "NEAREST_NEIGHBOR") _test_upsampling("NHWC", "nearest_neighbor")
_test_upsampling("NHWC", "BILINEAR") _test_upsampling("NHWC", "bilinear", True)
def test_conv2d_int8_intrinsics(): def test_conv2d_int8_intrinsics():
......
...@@ -39,7 +39,7 @@ def test_resize_infer_type(): ...@@ -39,7 +39,7 @@ def test_resize_infer_type():
assert zz.checked_type == relay.TensorType((n, c, th, tw), "int8") assert zz.checked_type == relay.TensorType((n, c, th, tw), "int8")
x = relay.var("x", relay.TensorType((n, c, h, w), "int8")) x = relay.var("x", relay.TensorType((n, c, h, w), "int8"))
z= relay.image.resize(x, (100, 200), "NCHW", "BILINEAR", False) z= relay.image.resize(x, (100, 200), "NCHW", "bilinear", True)
assert "size=" in z.astext() assert "size=" in z.astext()
zz = run_infer_type(z) zz = run_infer_type(z)
assert zz.checked_type == relay.TensorType((n, c, 100, 200), "int8") assert zz.checked_type == relay.TensorType((n, c, 100, 200), "int8")
...@@ -52,12 +52,12 @@ def test_resize(): ...@@ -52,12 +52,12 @@ def test_resize():
size = (dshape[2] * scale, dshape[3] * scale) size = (dshape[2] * scale, dshape[3] * scale)
x_data = np.random.uniform(size=dshape).astype("float32") x_data = np.random.uniform(size=dshape).astype("float32")
if method == "BILINEAR": if method == "bilinear":
ref_res = topi.testing.bilinear_resize_python(x_data, size, layout) ref_res = topi.testing.bilinear_resize_python(x_data, size, layout)
else: else:
ref_res = topi.testing.upsampling_python(x_data, (scale, scale), layout) ref_res = topi.testing.upsampling_python(x_data, (scale, scale), layout)
x = relay.var("x", relay.TensorType(dshape, "float32")) x = relay.var("x", relay.TensorType(dshape, "float32"))
z = relay.image.resize(x, size, layout, method, False) z = relay.image.resize(x, size, layout, method, True)
assert "size=" in z.astext() assert "size=" in z.astext()
zz = run_infer_type(z) zz = run_infer_type(z)
assert zz.checked_type == relay.TensorType(ref_res.shape, "float32") assert zz.checked_type == relay.TensorType(ref_res.shape, "float32")
...@@ -68,7 +68,7 @@ def test_resize(): ...@@ -68,7 +68,7 @@ def test_resize():
intrp = relay.create_executor(kind, ctx=ctx, target=target) intrp = relay.create_executor(kind, ctx=ctx, target=target)
op_res = intrp.evaluate(func)(x_data) op_res = intrp.evaluate(func)(x_data)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5) tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
for method in ["BILINEAR", "NEAREST_NEIGHBOR"]: for method in ["bilinear", "nearest_neighbor"]:
for layout in ["NHWC", "NCHW"]: for layout in ["NHWC", "NCHW"]:
verify_resize((1, 4, 4, 4), 2, method, layout) verify_resize((1, 4, 4, 4), 2, method, layout)
......
...@@ -201,7 +201,6 @@ inline Tensor resize_nearest_neighbor(const Tensor& input, ...@@ -201,7 +201,6 @@ inline Tensor resize_nearest_neighbor(const Tensor& input,
bool align_corners = false, bool align_corners = false,
std::string name = "tensor", std::string name = "tensor",
std::string tag = kInjective) { std::string tag = kInjective) {
CHECK_EQ(align_corners, false) << "Align corners not supported for nearest neighbour";
auto base_layout = layout.substr(0, 4); auto base_layout = layout.substr(0, 4);
if (layout == "NHWC") { if (layout == "NHWC") {
return resize_nearest_neighbor_nhwc(input, shape, align_corners); return resize_nearest_neighbor_nhwc(input, shape, align_corners);
......
...@@ -14,11 +14,14 @@ ...@@ -14,11 +14,14 @@
# KIND, either express or implied. See the License for the # KIND, either express or implied. See the License for the
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
# pylint: disable=invalid-name
"""TVM operator input resize compute.""" """TVM operator input resize compute."""
from __future__ import absolute_import from __future__ import absolute_import
import topi import tvm
from .. import tag
def resize(data, size, layout="NCHW", align_corners=False, method="BILINEAR"):
def resize(data, size, layout="NCHW", method="bilinear", align_corners=True, out_dtype=None):
"""Perform resize operation on the data. """Perform resize operation on the data.
Parameters Parameters
...@@ -32,18 +35,178 @@ def resize(data, size, layout="NCHW", align_corners=False, method="BILINEAR"): ...@@ -32,18 +35,178 @@ def resize(data, size, layout="NCHW", align_corners=False, method="BILINEAR"):
Output resolution scale to Output resolution scale to
layout: string, optional layout: string, optional
either "NCHW" or "NHWC" "NCHW", "NHWC", or "NCHWc".
align_corners: Boolean, optional align_corners: Boolean, optional
To preserve the values at the corner pixels To preserve the values at the corner pixels.
method: {"BILINEAR", "NEAREST_NEIGHBOR"} method: {"bilinear", "nearest_neighbor", "bicubic"}
Method to be used for resizing. Method to be used for resizing.
out_dtype: string, optional
Type to return. If left None will be same as input type.
Returns Returns
------- -------
output : tvm.Tensor output : tvm.Tensor
4-D with shape [batch, channel, in_height*scale, in_width*scale] 4-D with shape [batch, channel, in_height*scale, in_width*scale]
or [batch, in_height*scale, in_width*scale, channel] or [batch, in_height*scale, in_width*scale, channel]
or 5-D with shape [batch, channel-major, in_height*scale, in_width*scale, channel-minor]
""" """
return topi.cpp.image.resize(data, size, layout, align_corners, method) method = method.lower()
if layout == 'NHWC':
in_n, in_h, in_w, in_c = data.shape
output_shape = [in_n, size[0], size[1], in_c]
elif layout == 'NCHW':
in_n, in_c, in_h, in_w = data.shape
output_shape = [in_n, in_c, size[0], size[1]]
# Otherwise layout must be NCHWxc
else:
in_n, in_c, in_h, in_w, in_cc = data.shape
output_shape = [in_n, in_c, size[0], size[1], in_cc]
if align_corners:
y_ratio = (in_h - 1).astype('float') / (size[0] - 1)
x_ratio = (in_w - 1).astype('float') / (size[1] - 1)
else:
y_ratio = (in_h).astype('float') / (size[0])
x_ratio = (in_w).astype('float') / (size[1])
def _get_pixel(n, c, y, x, cc):
y = tvm.max(tvm.min(y, in_h - 1), 0)
x = tvm.max(tvm.min(x, in_w - 1), 0)
if layout == 'NHWC':
return data(n, y, x, c).astype('float')
if layout == 'NCHW':
return data(n, c, y, x).astype('float')
# else must be NCHWxc
return data(n, c, y, x, cc).astype('float')
def _get_indices(*indices):
if layout == 'NHWC':
n, y, x, c = indices
cc = None
elif layout == 'NCHW':
n, c, y, x = indices
cc = None
else:
n, c, y, x, cc = indices
return n, c, y, x, cc
def _cast_output(value):
if out_dtype:
dtype = out_dtype
else:
dtype = data.dtype
return value.astype(dtype)
# Nearest neighbor computation
def _nearest_neighbor(*indices):
n, c, y, x, cc = _get_indices(*indices)
in_y = y_ratio * y
in_x = x_ratio * x
if align_corners:
yint = tvm.round(in_y).astype('int32')
xint = tvm.round(in_x).astype('int32')
else:
# Add epsilon to floor to prevent gpu rounding errors.
epsilon = 1e-5
yint = tvm.floor(in_y + epsilon).astype('int32')
xint = tvm.floor(in_x + epsilon).astype('int32')
return _cast_output(_get_pixel(n, c, yint, xint, cc))
# Bilinear helper functions and computation.
def _lerp(A, B, t):
return A * (1.0 - t) + B * t
def _bilinear(*indices):
n, c, y, x, cc = _get_indices(*indices)
in_y = y_ratio * y
in_x = x_ratio * x
xint = tvm.floor(in_x).astype('int32')
xfract = in_x - tvm.floor(in_x)
yint = tvm.floor(in_y).astype('int32')
yfract = in_y - tvm.floor(in_y)
p00 = _get_pixel(n, c, yint, xint, cc)
p10 = _get_pixel(n, c, yint, xint + 1, cc)
p01 = _get_pixel(n, c, yint + 1, xint, cc)
p11 = _get_pixel(n, c, yint + 1, xint + 1, cc)
col0 = _lerp(p00, p10, xfract)
col1 = _lerp(p01, p11, xfract)
value = _lerp(col0, col1, yfract)
return _cast_output(value)
# Bicubic helper function and computation.
def _cubic_kernel(A, B, C, D, t):
a = -A / 2.0 + (3.0*B) / 2.0 - (3.0*C) / 2.0 + D / 2.0
b = A - (5.0*B) / 2.0 + 2.0*C - D / 2.0
c = -A / 2.0 + C / 2.0
d = B
return a*t*t*t + b*t*t + c*t + d
def _bicubic(*indices):
n, c, y, x, cc = _get_indices(*indices)
in_y = y_ratio * y
in_x = x_ratio * x
xint = tvm.floor(in_x).astype('int32')
xfract = in_x - tvm.floor(in_x)
yint = tvm.floor(in_y).astype('int32')
yfract = in_y - tvm.floor(in_y)
# 1st row
p00 = _get_pixel(n, c, yint - 1, xint - 1, cc)
p10 = _get_pixel(n, c, yint - 1, xint + 0, cc)
p20 = _get_pixel(n, c, yint - 1, xint + 1, cc)
p30 = _get_pixel(n, c, yint - 1, xint + 2, cc)
# 2nd row
p01 = _get_pixel(n, c, yint + 0, xint - 1, cc)
p11 = _get_pixel(n, c, yint + 0, xint + 0, cc)
p21 = _get_pixel(n, c, yint + 0, xint + 1, cc)
p31 = _get_pixel(n, c, yint + 0, xint + 2, cc)
# 3rd row
p02 = _get_pixel(n, c, yint + 1, xint - 1, cc)
p12 = _get_pixel(n, c, yint + 1, xint + 0, cc)
p22 = _get_pixel(n, c, yint + 1, xint + 1, cc)
p32 = _get_pixel(n, c, yint + 1, xint + 2, cc)
# 4th row
p03 = _get_pixel(n, c, yint + 2, xint - 1, cc)
p13 = _get_pixel(n, c, yint + 2, xint + 0, cc)
p23 = _get_pixel(n, c, yint + 2, xint + 1, cc)
p33 = _get_pixel(n, c, yint + 2, xint + 2, cc)
# Interpolate bicubically
col0 = _cubic_kernel(p00, p10, p20, p30, xfract)
col1 = _cubic_kernel(p01, p11, p21, p31, xfract)
col2 = _cubic_kernel(p02, p12, p22, p32, xfract)
col3 = _cubic_kernel(p03, p13, p23, p33, xfract)
value = _cubic_kernel(col0, col1, col2, col3, yfract)
return _cast_output(value)
# Determine which interpolation method to use then run it.
if method == "nearest_neighbor":
compute_func = _nearest_neighbor
elif method == "bilinear":
compute_func = _bilinear
elif method == "bicubic":
compute_func = _bicubic
else:
raise ValueError('%s method is not supported.' % method)
return tvm.compute(output_shape, compute_func, name='resize', tag=tag.INJECTIVE)
...@@ -20,7 +20,7 @@ import topi ...@@ -20,7 +20,7 @@ import topi
from ..util import simplify from ..util import simplify
def upsampling(data, scale, layout="NCHW", method='NEAREST_NEIGHBOR'): def upsampling(data, scale, layout="NCHW", method='nearest_neighbor', align_corners=False):
"""Perform upsampling on the data. """Perform upsampling on the data.
Nearest neighbor and bilinear upsampling are supported. Nearest neighbor and bilinear upsampling are supported.
...@@ -37,7 +37,7 @@ def upsampling(data, scale, layout="NCHW", method='NEAREST_NEIGHBOR'): ...@@ -37,7 +37,7 @@ def upsampling(data, scale, layout="NCHW", method='NEAREST_NEIGHBOR'):
layout : string, optional layout : string, optional
either "NCHW" or "NHWC" either "NCHW" or "NHWC"
method : {"BILINEAR", "NEAREST_NEIGHBOR"} method : {"bilinear", "nearest_neighbor", "bicubic"}
Method to be used for upsampling. Method to be used for upsampling.
Returns Returns
...@@ -53,4 +53,5 @@ def upsampling(data, scale, layout="NCHW", method='NEAREST_NEIGHBOR'): ...@@ -53,4 +53,5 @@ def upsampling(data, scale, layout="NCHW", method='NEAREST_NEIGHBOR'):
out_shape = (simplify(data.shape[1] * scale), simplify(data.shape[2] * scale)) out_shape = (simplify(data.shape[1] * scale), simplify(data.shape[2] * scale))
else: else:
raise ValueError("not support this layout {} yet".format(layout)) raise ValueError("not support this layout {} yet".format(layout))
return topi.cpp.nn.upsampling(data, out_shape, layout, method) return topi.image.resize(data, out_shape, layout=layout,
method=method, align_corners=align_corners)
...@@ -19,7 +19,7 @@ ...@@ -19,7 +19,7 @@
import math import math
import numpy as np import numpy as np
def bilinear_resize_python(image, out_size, layout, align_corners=False): def bilinear_resize_python(image, out_size, layout, align_corners=True):
""" Bilinear scaling using python""" """ Bilinear scaling using python"""
(new_h, new_w) = out_size (new_h, new_w) = out_size
......
...@@ -23,7 +23,7 @@ import math ...@@ -23,7 +23,7 @@ import math
from common import get_all_backend from common import get_all_backend
def verify_resize(batch, in_channel, in_height, in_width, out_height, out_width, layout='NCHW', align_corners=False, method="BILINEAR"): def verify_resize(batch, in_channel, in_height, in_width, out_height, out_width, layout='NCHW', align_corners=True, method="bilinear"):
if layout == 'NCHW': if layout == 'NCHW':
A = tvm.placeholder((batch, in_channel, in_height, in_width), name='A', dtype='float32') A = tvm.placeholder((batch, in_channel, in_height, in_width), name='A', dtype='float32')
dtype = A.dtype dtype = A.dtype
...@@ -40,7 +40,7 @@ def verify_resize(batch, in_channel, in_height, in_width, out_height, out_width, ...@@ -40,7 +40,7 @@ def verify_resize(batch, in_channel, in_height, in_width, out_height, out_width,
B = topi.image.resize(A, (out_height, out_width), layout=layout, align_corners=align_corners, method=method) B = topi.image.resize(A, (out_height, out_width), layout=layout, align_corners=align_corners, method=method)
if method == "BILINEAR": if method == "bilinear":
b_np = topi.testing.bilinear_resize_python(a_np, (out_height, out_width), layout, align_corners) b_np = topi.testing.bilinear_resize_python(a_np, (out_height, out_width), layout, align_corners)
else: else:
scale_h = out_height / in_height scale_h = out_height / in_height
...@@ -76,8 +76,8 @@ def test_resize(): ...@@ -76,8 +76,8 @@ def test_resize():
# Scale NHWC + Align Corners # Scale NHWC + Align Corners
verify_resize(6, 32, 64, 64, 20, 20, "NHWC", True) verify_resize(6, 32, 64, 64, 20, 20, "NHWC", True)
# Nearest + Fractional # Nearest + Fractional
verify_resize(4, 16, 32, 32, 50, 50, 'NCHW', method="NEAREST_NEIGHBOR") verify_resize(4, 16, 32, 32, 50, 50, 'NCHW', method="nearest_neighbor", align_corners=False)
verify_resize(4, 16, 32, 32, 50, 50, 'NHWC', method="NEAREST_NEIGHBOR") verify_resize(4, 16, 32, 32, 50, 50, 'NHWC', method="nearest_neighbor", align_corners=False)
if __name__ == "__main__": if __name__ == "__main__":
test_resize() test_resize()
...@@ -23,7 +23,7 @@ import math ...@@ -23,7 +23,7 @@ import math
from common import get_all_backend from common import get_all_backend
def verify_upsampling(batch, in_channel, in_height, in_width, scale, layout='NCHW', method="NEAREST_NEIGHBOR"): def verify_upsampling(batch, in_channel, in_height, in_width, scale, layout='NCHW', method="nearest_neighbor"):
if layout == 'NCHW': if layout == 'NCHW':
...@@ -40,11 +40,11 @@ def verify_upsampling(batch, in_channel, in_height, in_width, scale, layout='NCH ...@@ -40,11 +40,11 @@ def verify_upsampling(batch, in_channel, in_height, in_width, scale, layout='NCH
raise NotImplementedError( raise NotImplementedError(
'Layout not supported {} '.format(layout)) 'Layout not supported {} '.format(layout))
B = topi.nn.upsampling(A, scale, layout=layout, method=method) B = topi.nn.upsampling(A, scale, layout=layout, method=method, align_corners=False)
if method == "BILINEAR": if method == "bilinear":
out_size = (in_height*scale, in_width*scale) out_size = (in_height*scale, in_width*scale)
b_np = topi.testing.bilinear_resize_python(a_np, out_size, layout) b_np = topi.testing.bilinear_resize_python(a_np, out_size, layout, align_corners=False)
else: else:
b_np = topi.testing.upsampling_python(a_np, (scale, scale), layout) b_np = topi.testing.upsampling_python(a_np, (scale, scale), layout)
...@@ -67,21 +67,21 @@ def verify_upsampling(batch, in_channel, in_height, in_width, scale, layout='NCH ...@@ -67,21 +67,21 @@ def verify_upsampling(batch, in_channel, in_height, in_width, scale, layout='NCH
check_device(device) check_device(device)
def test_upsampling(): def test_upsampling():
# NEAREST_NEIGHBOR - NCHW # nearest_neighbor - NCHW
verify_upsampling(8, 16, 32, 32, 2) verify_upsampling(8, 16, 32, 32, 2)
verify_upsampling(2, 32, 64, 64, 3) verify_upsampling(2, 32, 64, 64, 3)
# NEAREST_NEIGHBOR - NHWC ## nearest_neighbor - NHWC
verify_upsampling(8, 16, 32, 32, 2, layout="NHWC") verify_upsampling(8, 16, 32, 32, 2, layout="NHWC")
verify_upsampling(2, 32, 64, 64, 3, layout="NHWC") verify_upsampling(2, 32, 64, 64, 3, layout="NHWC")
# BILINEAR - NCHW # bilinear - NCHW
verify_upsampling(2, 2, 32, 32, 2, method="BILINEAR") verify_upsampling(2, 2, 32, 32, 2, method="bilinear")
verify_upsampling(2, 2, 32, 32, 3, method="BILINEAR") verify_upsampling(2, 2, 32, 32, 3, method="bilinear")
# BILINEAR - NHWC # bilinear - NHWC
verify_upsampling(2, 2, 32, 32, 2, layout="NHWC", method="BILINEAR") verify_upsampling(2, 2, 32, 32, 2, layout="NHWC", method="bilinear")
verify_upsampling(2, 2, 32, 32, 3, layout="NHWC", method="BILINEAR") verify_upsampling(2, 2, 32, 32, 3, layout="NHWC", method="bilinear")
if __name__ == "__main__": if __name__ == "__main__":
test_upsampling() test_upsampling()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment