Commit 8b1fb4d5 by Xingyu Zhou, committed by Zhi

[Relay][Op] Enhance Upsample Operator to support float scales (#4206)

* add scale2 for upsample

* update unit test for upsampling

* support latest upsample op for multiple frontend

* fix lint

* fix lint

* fix lint

* fix lint

* update scale description and rebase
parent 2e07447e
...
@@ -700,6 +700,9 @@ inline Expr make_zero(Type t) {
   }                                                      \
   inline Expr Name(const Expr& a, int b) {               \
     return Name(a, make_const(a.type(), b));             \
+  }                                                      \
+  inline Expr Name(const Expr& a, double b) {            \
+    return Name(a, make_const(Float(64), b));            \
   }
 #define TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD(Name)   \
...
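Note: the new `double` overload is what lets the type relation further down multiply an integer shape expression by a floating-point scale. A minimal sketch of the equivalent expression built from the Python side (assuming the 0.6-era `tvm` API; not part of the diff):

```python
import tvm

# An int32 symbolic dimension scaled by a float64 constant, then rounded
# and cast back -- the same pattern UpSamplingRel builds in C++ below.
h = tvm.var("h")                          # int32 by default
scale_h = tvm.const(1.5, "float64")       # what make_const(Float(64), b) produces
out_h = tvm.expr.Cast("int32", tvm.round(h * scale_h))
```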
...
@@ -387,14 +387,17 @@ struct FIFOBufferAttrs : public tvm::AttrsNode<FIFOBufferAttrs> {
 /*! \brief Attributes for upsampling operator */
 struct UpSamplingAttrs : public tvm::AttrsNode<UpSamplingAttrs> {
-  int scale;
+  double scale_h;
+  double scale_w;
   std::string layout;
   std::string method;
   bool align_corners;

   TVM_DECLARE_ATTRS(UpSamplingAttrs, "relay.attrs.UpSamplingAttrs") {
-    TVM_ATTR_FIELD(scale)
-        .describe("Should be true to preserve the values at the corner pixels");
+    TVM_ATTR_FIELD(scale_h)
+        .describe("The upsampling factor for height");
+    TVM_ATTR_FIELD(scale_w)
+        .describe("The upsampling factor for width");
     TVM_ATTR_FIELD(layout).set_default("NCHW")
         .describe("Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc."
                   "'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
...
...
@@ -219,7 +219,8 @@ def _upsampling(children, attrs, odtype='float32'):
     method = attrs.get_str('method', 'NEAREST_NEIGHBOR')
     return op.nn.upsampling(
         children[0],
-        scale=scale,
+        scale_h=scale,
+        scale_w=scale,
         layout=layout,
         method=method)
...
...
@@ -280,7 +280,7 @@ class ResizeNearest(Caffe2OpConverter):
         assert width_scale == height_scale
         return _op.nn.upsampling(
-            inputs[0], scale=int(width_scale), method="NEAREST_NEIGHBOR")
+            inputs[0], scale_h=int(width_scale), scale_w=int(width_scale), method="NEAREST_NEIGHBOR")

 class Sum(Caffe2OpConverter):
...
...
@@ -313,7 +313,8 @@ def _UpsampleLayerParams(op, inexpr, etab):
         raise tvm.error.OpAttributeUnimplemented(
             'Upsample height and width must be equal.')
     interpolationMode = 'nearest_neighbor' if op.mode == 0 else 'bilinear'
-    return _op.nn.upsampling(inexpr, scale=op.scalingFactor[0], method=interpolationMode)
+    return _op.nn.upsampling(inexpr, scale_h=op.scalingFactor[0],
+                             scale_w=op.scalingFactor[1], method=interpolationMode)

 def _L2NormalizeLayerParams(op, inexpr, etab):
...
...
@@ -129,7 +129,7 @@ def _darknet_shortcut(inputs, params, attrs, prefix):
     if input_0_size > input_1_size:
         scale = int(input_0_size/input_1_size)
-        input_1 = get_relay_op('upsampling')(input_1, scale=scale)
+        input_1 = get_relay_op('upsampling')(input_1, scale_h=scale, scale_w=scale)
     elif input_0_size < input_1_size:
         stride = int(input_1_size/input_0_size)
...
@@ -196,7 +196,8 @@ def _darknet_reshape(inputs, params, attrs, prefix):
 def _darknet_upsampling(inputs, params, attrs, prefix):
     """Process the upsampling operation."""
     new_attrs = {}
-    new_attrs['scale'] = attrs.get('scale', 1)
+    new_attrs['scale_h'] = attrs.get('scale', 1)
+    new_attrs['scale_w'] = attrs.get('scale', 1)
     return get_relay_op('upsampling')(*inputs, **new_attrs)

 def _darknet_l2normalize(inputs, params, attrs, prefix):
...
...
@@ -398,13 +398,14 @@ def _convert_upsample(inexpr, keras_layer, _):
     params = {}
     if upsample_type == 'UpSampling1D':
         h = keras_layer.size
-        params['scale'] = h
+        params['scale_h'] = h
     elif upsample_type == 'UpSampling2D':
         h, w = keras_layer.size
         if h != w:
             raise tvm.error.OpAttributeInvalid(
                 'Height must equal width for operator Upsample.')
-        params['scale'] = h
+        params['scale_h'] = h
+        params['scale_w'] = h

         if hasattr(keras_layer, 'interpolation'):
             interpolation = keras_layer.interpolation
...
@@ -418,7 +419,8 @@ def _convert_upsample(inexpr, keras_layer, _):
         if h != w or w != d:
             raise tvm.error.OpAttributeInvalid(
                 'Height, width, and depth must all be equal for operator Upsample.')
-        params['scale'] = h
+        params['scale_h'] = h
+        params['scale_w'] = h
     else:
         raise tvm.error.OpNotImplemented(
             'Operator {} is not supported for frontend Keras.'.format(upsample_type))
...
...
@@ -112,7 +112,7 @@ def _transpose(inputs, attrs):
 def _upsampling(inputs, attrs):
     scale = attrs.get_int("scale")
-    return _op.nn.upsampling(inputs[0], scale=scale)
+    return _op.nn.upsampling(inputs[0], scale_h=scale, scale_w=scale)

 def _elemwise_sum(inputs, _, _dtype='float32'):
...
...
@@ -581,7 +581,7 @@ class Upsample(OnnxOpConverter):
         assert len(inputs) == 2, "Upsample op take 2 inputs, {} given".format(len(inputs))
         scales = params[inputs[1].name_hint].asnumpy()
         inputs = inputs[:1]
-        assert len(scales) == 4 and scales[0] == 1.0 and scales[1] == 1.0 and scales[2] == scales[3]
+        assert len(scales) == 4 and scales[0] == 1.0 and scales[1] == 1.0
         mode = attr.get('mode')
         if mode == b'nearest':
             method = "nearest_neighbor"
...
@@ -590,7 +590,8 @@ class Upsample(OnnxOpConverter):
         else:
             raise tvm.error.OpAttributeInvalid(
                 'Value {} in attribute "mode" of operator Upsample is not valid.'.format(mode))
-        attr = {'scale':int(scales[-1]), 'method':method, 'layout':'NCHW', 'align_corners':True}
+        attr = {'scale_h':scales[-2], 'scale_w':scales[-1], 'method':method,
+                'layout':'NCHW', 'align_corners':True}
         return AttrCvt('upsampling')(inputs, attr)
...
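Note: with per-axis scales, the converter no longer requires `scales[2] == scales[3]`. A small illustrative sketch of how the four-element ONNX `scales` tensor maps onto the new attributes (values made up for illustration):

```python
import numpy as np

# ONNX Upsample 'scales' is ordered [N, C, H, W].
scales = np.array([1.0, 1.0, 2.0, 1.5], dtype="float32")
assert len(scales) == 4 and scales[0] == 1.0 and scales[1] == 1.0
attr = {'scale_h': scales[-2], 'scale_w': scales[-1]}
# scale_h == 2.0, scale_w == 1.5; previously this pair had to be equal.
```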
...
@@ -409,11 +409,12 @@ def schedule_upsampling(_, outs, target):

 @reg.register_compute("nn.upsampling")
 def compute_upsampling(attrs, inputs, out_dtype, target):
-    scale = attrs.scale
+    scale_h = attrs.scale_h
+    scale_w = attrs.scale_w
     layout = attrs.layout
     method = attrs.method
     align_corners = attrs.align_corners
-    return [topi.nn.upsampling(inputs[0], scale, layout, method, align_corners)]
+    return [topi.nn.upsampling(inputs[0], scale_h, scale_w, layout, method, align_corners)]

 # pad
 reg.register_schedule("nn.pad", schedule_broadcast)
...
...
@@ -483,7 +483,8 @@ def global_avg_pool2d(data,
 def upsampling(data,
-               scale=1,
+               scale_h=1,
+               scale_w=1,
                layout="NCHW",
                method="nearest_neighbor",
                align_corners=False):
...
@@ -492,7 +493,7 @@ def upsampling(data,
     This operator takes data as input and does 2D scaling to the given scale factor.
     In the default case, where the data_layout is `NCHW`
     with data of shape (n, c, h, w)
-    out will have a shape (n, c, h*scale, w*scale)
+    out will have a shape (n, c, h*scale_h, w*scale_w)

     method indicates the algorithm to be used while calculating the out value
     and method can be one of ("bilinear", "nearest_neighbor", "bicubic")
...
@@ -502,8 +503,11 @@ def upsampling(data,
     data : tvm.relay.Expr
         The input data to the operator.

-    scale : tvm.relay.Expr
-        The scale factor for upsampling.
+    scale_h : tvm.relay.Expr
+        The scale factor for height upsampling.
+
+    scale_w : tvm.relay.Expr
+        The scale factor for width upsampling.

     layout : str, optional
         Layout of the input.
...
@@ -519,7 +523,7 @@ def upsampling(data,
     result : tvm.relay.Expr
         The computed result.
     """
-    return _make.upsampling(data, scale, layout, method, align_corners)
+    return _make.upsampling(data, scale_h, scale_w, layout, method, align_corners)

 def batch_flatten(data):
...
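Note: a usage sketch of the updated Relay API with fractional, unequal scales (assumes this patch is applied; shapes are illustrative):

```python
from tvm import relay

x = relay.var("x", shape=(1, 16, 32, 32), dtype="float32")
y = relay.nn.upsampling(x, scale_h=1.5, scale_w=2.0,
                        layout="NCHW", method="bilinear", align_corners=False)
# Inferred output shape: (1, 16, round(32*1.5), round(32*2.0)) = (1, 16, 48, 64)
```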
...
@@ -80,9 +80,8 @@ bool UpSamplingRel(const Array<Type>& types,
       << " But got " << in_layout;

   auto oshape = layout_converter.ForwardShape(data->shape);
-
-  oshape.Set(2, oshape[2] * param->scale);
-  oshape.Set(3, oshape[3] * param->scale);
+  oshape.Set(2, ir::Cast::make(oshape[2].type(), tvm::round(oshape[2] * param->scale_h)));
+  oshape.Set(3, ir::Cast::make(oshape[3].type(), tvm::round(oshape[3] * param->scale_w)));

   // assign output type
   reporter->Assign(types[1],
...
@@ -95,14 +94,16 @@ bool UpSamplingRel(const Array<Type>& types,
 // Positional relay function to create upsampling operator
 // used by frontend FFI.
 Expr MakeUpSampling(Expr data,
-                    int scale,
+                    double scale_h,
+                    double scale_w,
                     std::string layout,
                     std::string method,
                     bool align_corners) {
   auto attrs = make_node<UpSamplingAttrs>();
   attrs->layout = std::move(layout);
   attrs->method = std::move(method);
-  attrs->scale = scale;
+  attrs->scale_h = scale_h;
+  attrs->scale_w = scale_w;
   attrs->align_corners = align_corners;
   static const Op& op = Op::Get("nn.upsampling");
   return CallNode::make(op, {data}, Attrs(attrs), {});
...
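Note: the relation now computes `cast(dtype, round(dim * scale))` per axis instead of an integer multiply. Plugging in concrete numbers (plain Python, illustrative):

```python
# The shape rule applied to a 32x32 input with fractional scales.
h, w = 32, 32
scale_h, scale_w = 1.5, 2.0
out_h, out_w = int(round(h * scale_h)), int(round(w * scale_w))
assert (out_h, out_w) == (48, 64)
```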
...
@@ -232,14 +232,17 @@ def test_conv2d_transpose_run():

 def test_upsampling_infer_type():
     n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
+    scale = tvm.const(2.0, "float64")
     x = relay.var("x", relay.TensorType((n, c, h, w), "float32"))
-    y = relay.nn.upsampling(x, scale=2, layout="NCHW", method="bilinear")
+    y = relay.nn.upsampling(x, scale_h=2, scale_w=2, layout="NCHW", method="bilinear")
     "method=\"BINLINEAR\"" in y.astext()
     yy = run_infer_type(y)
-    assert yy.checked_type == relay.TensorType((n, c, h*2, w*2), "float32")
+    assert yy.checked_type == relay.TensorType((n, c, tvm.expr.Cast("int32", tvm.round(h*scale)),
+                                                tvm.expr.Cast("int32", tvm.round(w*scale))),
+                                               "float32")
     n, c = tvm.var("n"), tvm.var("c")
     x = relay.var("x", relay.TensorType((n, c, 100, 200), "float32"))
-    y = relay.nn.upsampling(x, scale=2, layout="NCHW", method="bilinear")
+    y = relay.nn.upsampling(x, scale_h=2, scale_w=2, layout="NCHW", method="bilinear")
     yy = run_infer_type(y)
     assert yy.checked_type == relay.TensorType((n, c, 200, 400), "float32")
...
@@ -504,29 +507,31 @@ def test_batch_flatten():

 def _test_upsampling(layout, method, align_corners=False):
     n, c, h, w = tvm.var("n"), 16, 32, 32
-    scale = 2
+    scale_h = 2.0
+    scale_w = 2.0
     dtype = "float32"
     def get_shape():
         if layout == "NCHW":
-            return (c, h, w), (c, h*scale, w*scale)
+            return (c, h, w), (c, int(round(h*scale_h)), int(round(w*scale_w)))
         else:
-            return (h, w, c), (h*scale, w*scale, c)
+            return (h, w, c), (int(round(h*scale_h)), int(round(w*scale_w)), c)
     ishape, oshape = get_shape()
     x = relay.var("x", relay.TensorType((n,) + ishape, dtype))
-    y = relay.nn.upsampling(x, scale=scale, layout=layout,
+    y = relay.nn.upsampling(x, scale_h=scale_h, scale_w=scale_w, layout=layout,
                             method=method, align_corners=align_corners)
     yy = run_infer_type(y)
     assert yy.checked_type == relay.TensorType((n,) + oshape, dtype)
     dshape = (1,) + ishape
     x = relay.var("x", shape=dshape)
-    y = relay.nn.upsampling(x, scale=scale, layout=layout,
+    y = relay.nn.upsampling(x, scale_h=scale_h, scale_w=scale_w, layout=layout,
                             method=method, align_corners=align_corners)
     func = relay.Function([x], y)
     data = np.random.uniform(size=dshape).astype(dtype)
     if method == "nearest_neighbor":
-        ref = topi.testing.upsampling_python(data, (scale, scale), layout)
+        ref = topi.testing.upsampling_python(data, (scale_h, scale_w), layout)
     else:
-        ref = topi.testing.bilinear_resize_python(data, (h*scale, w*scale), layout)
+        ref = topi.testing.bilinear_resize_python(data, (int(round(h*scale_h)),
+                                                         int(round(w*scale_w))), layout)
     for target, ctx in ctx_list():
         executor = relay.create_executor("graph", ctx=ctx, target=target)
         out = executor.evaluate(func)(data)
...
...
@@ -487,7 +487,7 @@ def test_alter_layout_nchw_upsamping_op():
         x = relay.var("x", shape=(1, 32, 28, 28))
         weight = relay.var('weight', shape=(32, 32, 3, 3))
         y = relay.nn.conv2d(x, weight, channels=32, kernel_size=(3, 3), padding=(1, 1))
-        y = relay.nn.upsampling(y, scale=2)
+        y = relay.nn.upsampling(y, scale_h=2, scale_w=2)
         y = relay.nn.avg_pool2d(y, pool_size=(2, 2), strides=(2, 2))
         y = relay.Function(analysis.free_vars(y), y)
         return y
...
@@ -506,7 +506,7 @@ def test_alter_layout_nchw_upsamping_op():
         x = relay.layout_transform(x, "NCHW", "NCHW16c")
         y = relay.nn.conv2d(x, weight, channels=32, kernel_size=(3, 3), padding=(1, 1),
                             data_layout="NCHW16c")
-        y = relay.nn.upsampling(y, scale=2, layout="NCHW16c")
+        y = relay.nn.upsampling(y, scale_h=2, scale_w=2, layout="NCHW16c")
         y = relay.nn.avg_pool2d(y, pool_size=(2, 2), strides=(2, 2), layout='NCHW16c')
         y = relay.layout_transform(y, "NCHW16c", "NCHW")
         y = relay.Function(analysis.free_vars(y), y)
...
...
@@ -126,7 +126,7 @@ def test_concatenate():
     def before(dshape):
         x = relay.var("x", shape=dshape)
         pooled = relay.nn.max_pool2d(x, pool_size=(2, 2), strides=(2, 2), padding=(0, 0))
-        upsampled = relay.nn.upsampling(pooled, scale=2, layout="NCHW")
+        upsampled = relay.nn.upsampling(pooled, scale_h=2, scale_w=2, layout="NCHW")
         concat = relay.concatenate((upsampled, x), axis=1)
         out = relay.add(concat, relay.const(1, "float32"))
         return relay.Function(relay.analysis.free_vars(out), out)
...
@@ -138,7 +138,7 @@ def test_concatenate():
         p0 = relay.var("p0", shape=(dshape[0], dshape[1], dshape[2]//2, dshape[3]//2))
         p1 = relay.var("p1", shape=dshape)
-        upsampled = relay.nn.upsampling(p0, scale=2, layout="NCHW")
+        upsampled = relay.nn.upsampling(p0, scale_h=2, scale_w=2, layout="NCHW")
         concat = relay.concatenate((upsampled, p1), axis=1)
         out = relay.add(concat, relay.const(1, "float32"))
         f1 = relay.Function([p0, p1], out)
...
@@ -164,7 +164,7 @@ def test_tuple_root():
     def before(dshape):
         x = relay.var("x", shape=dshape)
         pooled = relay.nn.max_pool2d(x, pool_size=(2, 2), strides=(2, 2), padding=(0, 0))
-        upsampled = relay.nn.upsampling(pooled, scale=2, layout="NCHW")
+        upsampled = relay.nn.upsampling(pooled, scale_h=2, scale_w=2, layout="NCHW")
         out = relay.Tuple((upsampled, x))
         return relay.Function(relay.analysis.free_vars(out), out)
...
@@ -174,7 +174,7 @@ def test_tuple_root():
         f0 = relay.Function([x], pooled)
         p0 = relay.var("p0", shape=(dshape[0], dshape[1], dshape[2]//2, dshape[3]//2))
-        upsampled = relay.nn.upsampling(p0, scale=2, layout="NCHW")
+        upsampled = relay.nn.upsampling(p0, scale_h=2, scale_w=2, layout="NCHW")
         f1 = relay.Function([p0], upsampled)
         x = relay.var("x", shape=dshape)
...
...
@@ -17,10 +17,12 @@
 """TVM operator upsampling compute."""
 from __future__ import absolute_import
 import topi
+import tvm
 from ..util import simplify


-def upsampling(data, scale, layout="NCHW", method='nearest_neighbor', align_corners=False):
+def upsampling(data, scale_h, scale_w, layout="NCHW", method='nearest_neighbor',
+               align_corners=False):
     """Perform upsampling on the data.
        Nearest neighbor and bilinear upsampling are supported.
...
@@ -31,8 +33,11 @@ def upsampling(data, scale, layout="NCHW", method='nearest_neighbor', align_corn
         [batch, channel, in_height, in_width]
         or  [batch, in_height, in_width, channel]

-    scale : int
-        Scaling factor
+    scale_h : float
+        Scaling factor for height
+
+    scale_w : float
+        Scaling factor for width

     layout : string, optional
         either "NCHW" or "NHWC"
...
@@ -43,14 +48,17 @@ def upsampling(data, scale, layout="NCHW", method='nearest_neighbor', align_corn
     Returns
     -------
     output : tvm.Tensor
-        4-D with shape [batch, channel, in_height*scale, in_width*scale]
+        4-D with shape [batch, channel, in_height*scale_h, in_width*scale_w]
         or [batch, in_height*scale, in_width*scale, channel]
     """
     base_layout = layout[0:4]
     if base_layout == "NCHW":
-        out_shape = (simplify(data.shape[2] * scale), simplify(data.shape[3] * scale))
+        out_shape = (simplify(topi.cast(tvm.round(data.shape[2] * scale_h), data.shape[2].dtype)),
+                     simplify(topi.cast(tvm.round(data.shape[3] * scale_w), data.shape[3].dtype)))
     elif layout == "NHWC":
-        out_shape = (simplify(data.shape[1] * scale), simplify(data.shape[2] * scale))
+        out_shape = (simplify(topi.cast(tvm.round(data.shape[1] * scale_h), data.shape[1].dtype)),
+                     simplify(topi.cast(tvm.round(data.shape[2] * scale_w), data.shape[2].dtype)))
     else:
         raise ValueError("not support this layout {} yet".format(layout))
     return topi.image.resize(data, out_shape, layout=layout,
...
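Note: a corresponding TOPI-level usage sketch (positional `scale_h`, `scale_w`, per the new signature; shapes illustrative):

```python
import tvm
import topi

# Build the compute with fractional scales; output H/W are rounded.
A = tvm.placeholder((1, 16, 32, 32), name="A")
B = topi.nn.upsampling(A, 1.5, 2.0, layout="NCHW",
                       method="nearest_neighbor", align_corners=False)
# B's H/W become round(32*1.5) = 48 and round(32*2.0) = 64.
```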
...
@@ -22,8 +22,8 @@ import numpy as np
 def upsample_nearest(arr, scale):
     """ Populate the array by scale factor"""
     h, w = arr.shape
-    out_h = math.floor(h * scale[0])
-    out_w = math.floor(w * scale[1])
+    out_h = int(round(h * scale[0]))
+    out_w = int(round(w * scale[1]))
     out = np.empty((out_h, out_w))
     for y in range(out_h):
         for x in range(out_w):
...
@@ -37,14 +37,16 @@ def upsampling_python(data, scale, layout='NCHW'):
     ishape = data.shape
     if layout == 'NCHW':
-        oshape = (ishape[0], ishape[1], math.floor(ishape[2]*scale[0]), math.floor(ishape[3]*scale[1]))
+        oshape = (ishape[0], ishape[1], int(round(ishape[2]*scale[0])),
+                  int(round(ishape[3]*scale[1])))
         output_np = np.zeros(oshape, dtype=data.dtype)
         for b in range(oshape[0]):
             for c in range(oshape[1]):
                 output_np[b, c, :, :] = upsample_nearest(data[b, c, :, :], scale)
         return output_np
     if layout == 'NHWC':
-        oshape = (ishape[0], math.floor(ishape[1]*scale[0]), math.floor(ishape[1]*scale[1]), ishape[3])
+        oshape = (ishape[0], int(round(ishape[1]*scale[0])),
+                  int(round(ishape[2]*scale[1])), ishape[3])
         output_np = np.zeros(oshape, dtype=data.dtype)
         for b in range(oshape[0]):
             for c in range(oshape[3]):
...
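Note: switching the reference implementation from `floor` to `round` matters precisely for near-integral fractional products; a quick plain-Python check (the long constant reappears in the test below):

```python
import math

# 22 * 1.954545497894287 = 43.0000009... -> both floor and round give 43,
# but a slightly truncated scale would floor to 42:
assert math.floor(22 * 1.954545497894287) == 43
assert math.floor(22 * 1.9545454) == 42          # floor under-shoots
assert int(round(22 * 1.9545454)) == 43          # round matches the intent
```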
...
@@ -23,30 +23,29 @@ import math
 from common import get_all_backend

-def verify_upsampling(batch, in_channel, in_height, in_width, scale, layout='NCHW', method="nearest_neighbor"):
+def verify_upsampling(batch, in_channel, in_height, in_width, scale_h, scale_w,
+                      layout='NCHW', method="nearest_neighbor"):
     if layout == 'NCHW':
         A = tvm.placeholder((batch, in_channel, in_height, in_width), name='A')
         dtype = A.dtype
-        out_shape = (batch, in_channel, in_height*scale, in_width*scale)
+        out_shape = (batch, in_channel, int(round(in_height*scale_h)), int(round(in_width*scale_w)))
         a_np = np.random.uniform(size=(batch, in_channel, in_height, in_width)).astype(dtype)
     elif layout == 'NHWC':
         A = tvm.placeholder((batch, in_height, in_width, in_channel), name='A')
         dtype = A.dtype
-        out_shape = (batch, in_height*scale, in_width*scale, in_channel)
+        out_shape = (batch, int(round(in_height*scale_h)), int(round(in_width*scale_w)), in_channel)
         a_np = np.random.uniform(size=(batch, in_height, in_width, in_channel)).astype(dtype)
     else:
         raise NotImplementedError(
             'Layout not supported {} '.format(layout))

-    B = topi.nn.upsampling(A, scale, layout=layout, method=method, align_corners=False)
+    B = topi.nn.upsampling(A, scale_h, scale_w, layout=layout, method=method, align_corners=False)

     if method == "bilinear":
-        out_size = (in_height*scale, in_width*scale)
+        out_size = (int(round(in_height*scale_h)), int(round(in_width*scale_w)))
         b_np = topi.testing.bilinear_resize_python(a_np, out_size, layout, align_corners=False)
     else:
-        b_np = topi.testing.upsampling_python(a_np, (scale, scale), layout)
+        b_np = topi.testing.upsampling_python(a_np, (scale_h, scale_w), layout)

     def check_device(device):
         ctx = tvm.context(device, 0)
...
@@ -68,20 +67,24 @@ def verify_upsampling(batch, in_channel, in_height, in_width, scale, layout='NCH
 def test_upsampling():
     # nearest_neighbor - NCHW
-    verify_upsampling(8, 16, 32, 32, 2)
-    verify_upsampling(2, 32, 64, 64, 3)
+    verify_upsampling(8, 16, 32, 32, 2.0, 2.0)
+    verify_upsampling(2, 32, 64, 64, 3.0, 3.0)
+    verify_upsampling(1, 64, 22, 32, 1.954545497894287, 2.0)

     ## nearest_neighbor - NHWC
-    verify_upsampling(8, 16, 32, 32, 2, layout="NHWC")
-    verify_upsampling(2, 32, 64, 64, 3, layout="NHWC")
+    verify_upsampling(8, 16, 32, 32, 2.0, 2.0, layout="NHWC")
+    verify_upsampling(2, 32, 64, 64, 3.0, 3.0, layout="NHWC")
+    verify_upsampling(1, 64, 22, 32, 1.954545497894287, 2.0, layout="NHWC")

     # bilinear - NCHW
-    verify_upsampling(2, 2, 32, 32, 2, method="bilinear")
-    verify_upsampling(2, 2, 32, 32, 3, method="bilinear")
+    verify_upsampling(2, 2, 32, 32, 2.0, 2.0, method="bilinear")
+    verify_upsampling(2, 2, 32, 32, 3.0, 3.0, method="bilinear")
+    verify_upsampling(1, 64, 22, 32, 1.954545497894287, 2.0, method="bilinear")

     # bilinear - NHWC
-    verify_upsampling(2, 2, 32, 32, 2, layout="NHWC", method="bilinear")
-    verify_upsampling(2, 2, 32, 32, 3, layout="NHWC", method="bilinear")
+    verify_upsampling(2, 2, 32, 32, 2.0, 2.0, layout="NHWC", method="bilinear")
+    verify_upsampling(2, 2, 32, 32, 3.0, 3.0, layout="NHWC", method="bilinear")
+    verify_upsampling(1, 64, 22, 32, 3.0, 3.0, layout="NHWC", method="bilinear")

 if __name__ == "__main__":
     test_upsampling()