Commit 6a0fb6ef by masahi Committed by Tianqi Chen

Upsampling op support (#298)

* add nnvm upsampling symbol

* add upsampling mxnet frontend

* add doc for upsampling op

* cleanup upsampling test

* minor fix

* use schedule_injective for upsampling

* upgrade tvm
parent d059dbab
......@@ -257,6 +257,15 @@ struct GlobalPool2DParam : public dmlc::Parameter<GlobalPool2DParam> {
}
};
struct UpSamplingParam : public dmlc::Parameter<UpSamplingParam> {
int scale;
DMLC_DECLARE_PARAMETER(UpSamplingParam) {
DMLC_DECLARE_FIELD(scale)
.describe("upsampling scaling factor");
}
};
} // namespace top
} // namespace nnvm
......
......@@ -190,6 +190,12 @@ def _softmax_output(inputs, attrs):
new_attrs['axis'] = 1
return _get_nnvm_op(op_name)(inputs[0], **new_attrs)
def _upsampling(inputs, attrs):
scale = attrs.get('scale')
new_attrs = {'scale':int(scale)}
return _get_nnvm_op('upsampling')(inputs[0], **new_attrs)
_identity_list = ['__add_scalar__', '__add_symbol__', '__div_scalar__',
'__div_symbol__', '__mul_scalar__', '__mul_symbol__',
'__pow_scalar__', '__rdiv_scalar__', '__rpow_scalar__',
......@@ -231,6 +237,7 @@ _convert_map = {
'min_axis' : _rename('min'),
'reshape' : _reshape,
'sum_axis' : _rename('sum'),
'UpSampling' : _upsampling
}
def _convert_symbol(op_name, inputs, attrs,
......
......@@ -250,3 +250,18 @@ def schedule_global_avg_pool2d(_, outs, target):
return topi.generic.schedule_global_pool(outs)
reg.register_pattern("global_avg_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE)
@reg.register_compute("upsampling")
def compute_upsampling(attrs, inputs, _):
"""Compute definition of upsampling"""
scale = attrs.get_int("scale")
return topi.nn.upsampling(inputs[0], scale)
@reg.register_schedule("upsampling")
def schedule_upsampling(_, outs, target):
"""Compute definition of upsampling"""
with tvm.target.create(target):
return topi.generic.schedule_injective(outs)
reg.register_pattern("upsampling", OpPattern.OUT_ELEMWISE_FUSABLE)
/*!
* Copyright (c) 2017 by Contributors
* \file pooling.cc
* \brief Property def of pooling operators.
*/
#include <nnvm/op.h>
#include <nnvm/node.h>
#include <nnvm/op_attr_types.h>
#include <nnvm/top/nn.h>
#include "./nn_common.h"
#include "../op_common.h"
#include "../elemwise_op_common.h"
namespace nnvm {
namespace top {
DMLC_REGISTER_PARAMETER(UpSamplingParam);
inline bool UpSamplingInferShape(const nnvm::NodeAttrs& attrs,
std::vector<TShape>* in_shape,
std::vector<TShape>* out_shape) {
const UpSamplingParam& param = nnvm::get<UpSamplingParam>(attrs.parsed);
CHECK_EQ(in_shape->size(), 1U);
CHECK_EQ(out_shape->size(), 1U);
TShape dshape = (*in_shape)[0];
if (dshape.ndim() == 0) return false;
TShape oshape = dshape;
oshape[2] = oshape[2] * param.scale;
oshape[3] = oshape[3] * param.scale;
NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_shape, 0, oshape);
return true;
}
NNVM_REGISTER_OP(upsampling)
.describe(R"(Perform nearest neighbor upsampling to input array.
- **data**: Input is 4D array of shape (batch_size, channels, in_height, in_width).
- **out**: Output is 4D array of shape (batch_size, channels, in_height*scale, in_width*scale).
)" NNVM_ADD_FILELINE)
.add_argument("data", "4D Tensor", "Input data.")
.add_arguments(UpSamplingParam::__FIELDS__())
.set_attr_parser(ParamParser<UpSamplingParam>)
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<UpSamplingParam>)
.set_attr<FInferShape>("FInferShape", UpSamplingInferShape)
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.set_num_outputs(1)
.set_num_inputs(1)
.set_support_level(2);
} // namespace top
} // namespace nnvm
......@@ -148,6 +148,25 @@ def test_global_avg_pool2d():
np.testing.assert_allclose(out.asnumpy(), b_np, rtol=1e-5)
def test_upsampling():
x = sym.Variable("x")
scale = 2
y = sym.upsampling(x, scale=scale, name="y")
dtype = "float32"
dshape = (1, 16, 32, 32)
oshape = (1, 16, 32*scale, 32*scale)
shape_dict = {"x": dshape}
for target, ctx in ctx_list():
graph, lib, _ = nnvm.compiler.build(y, target, shape_dict)
m = graph_runtime.create(graph, lib, ctx)
a_np = np.random.uniform(size=dshape).astype(dtype)
data = tvm.nd.array(a_np)
m.run(x=data)
out = m.get_output(0, tvm.nd.empty(oshape, dtype))
b_np = topi.testing.upsampling_python(a_np, scale)
np.testing.assert_allclose(out.asnumpy(), b_np, rtol=1e-5)
if __name__ == "__main__":
test_conv2d()
test_grouped_conv2d()
......@@ -156,3 +175,4 @@ if __name__ == "__main__":
test_avg_pool2d()
test_global_max_pool2d()
test_global_avg_pool2d()
test_upsampling()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment