Commit 3f599a60 by Xingjian Shi Committed by Tianqi Chen

add squeeze (#52)

* add transform

* fix

* update doc

* Update tvm
parent 5541a275
......@@ -41,6 +41,7 @@ This level enables fully connected multi-layer perceptron.
nnvm.symbol.flatten
nnvm.symbol.concatenate
nnvm.symbol.expand_dims
nnvm.symbol.squeeze
nnvm.symbol.split
nnvm.symbol.dropout
nnvm.symbol.batch_norm
......@@ -112,6 +113,7 @@ Detailed Definitions
.. autofunction:: nnvm.symbol.flatten
.. autofunction:: nnvm.symbol.concatenate
.. autofunction:: nnvm.symbol.expand_dims
.. autofunction:: nnvm.symbol.squeeze
.. autofunction:: nnvm.symbol.split
.. autofunction:: nnvm.symbol.dropout
.. autofunction:: nnvm.symbol.batch_norm
......
......@@ -79,6 +79,16 @@ struct ReshapeParam : public dmlc::Parameter<ReshapeParam> {
}
};
struct SqueezeParam : public dmlc::Parameter<SqueezeParam> {
TShape axis;
DMLC_DECLARE_PARAMETER(SqueezeParam) {
DMLC_DECLARE_FIELD(axis).set_default(TShape())
.describe("The axis to squeeze in the input tensor."
" If set to None, all size=1 axes will be squeezed");
}
};
struct ScalarParam : public dmlc::Parameter<ScalarParam> {
double scalar;
......
......@@ -47,7 +47,7 @@ class Engine(object):
"""
res = _list_cache_items()
assert len(res) % 2 == 0
return [(res[2*i], res[2*i+1]) for i in range(len(res)/2)]
return [(res[2*i], res[2*i+1]) for i in range(len(res) // 2)]
def clear_cache(self):
"""Clear the existing cached functions."""
......
......@@ -36,6 +36,16 @@ def compute_reshape(attrs, inputs, out_info):
reg.register_pattern("reshape", OpPattern.INJECTIVE)
reg.register_schedule("reshape", _fschedule_injective)
# reshape
@reg.register_compute("squeeze")
def compute_squeeze(attrs, inputs, out_info):
"""Compute definition of reshape"""
axis = attrs.get_int_tuple("axis")
axis = tuple(axis) if axis else None
return topi.squeeze(inputs[0], axis)
reg.register_pattern("squeeze", OpPattern.INJECTIVE)
reg.register_schedule("squeeze", _fschedule_injective)
# concatenate
@reg.register_compute("concatenate")
def compute_concatenate(attrs, inputs, out_info):
......
......@@ -10,6 +10,7 @@
#include <dmlc/parameter.h>
#include <string>
#include <vector>
#include <unordered_set>
namespace nnvm {
namespace top {
......
......@@ -445,6 +445,80 @@ The significance of each is explained below:
.set_num_outputs(1)
.set_support_level(3);
// squeeze
DMLC_REGISTER_PARAMETER(SqueezeParam);
inline bool SqueezeShape(const nnvm::NodeAttrs& attrs,
std::vector<TShape>* in_attrs,
std::vector<TShape>* out_attrs) {
const SqueezeParam& param = nnvm::get<SqueezeParam>(attrs.parsed);
CHECK_EQ(in_attrs->size(), 1U);
CHECK_EQ(out_attrs->size(), 1U);
const TShape& shp = (*in_attrs)[0];
if (shp.ndim() == 0) return false;
std::vector<int64_t> oshape;
if (param.axis.ndim() == 0) {
for (dim_t i = 0; i < shp.ndim(); ++i) {
if(shp[i] != 1) {
oshape.emplace_back(shp[i]);
}
}
} else {
std::unordered_set<dim_t> axis_checker;
for (size_t i = 0; i < param.axis.ndim(); ++i) {
if(param.axis[i] < 0) {
int real_axis = param.axis[i] + static_cast<int>(shp.ndim());
CHECK(real_axis < static_cast<int>(shp.ndim()) && real_axis >= 0);
axis_checker.insert(real_axis);
}
}
for (size_t i = 0; i < shp.ndim(); ++i) {
if(axis_checker.find(i) == axis_checker.end()) {
oshape.emplace_back(shp[i]);
} else {
CHECK_EQ(shp[i], 1) << "The squeezed axis must have shape 1!"
<< "Want to squeeze " << i
<< ", which has shape" << shp[i];
}
}
}
if(oshape.size() == 0) {
// Handles the case where all axes are squeezed.
oshape.push_back(1);
}
TShape out_shape(oshape.begin(), oshape.end());
CHECK_EQ(out_shape.Size(), shp.Size())
<< "Target shape size is different to source. "
<< "Target: " << out_shape
<< "\nSource: " << shp;
NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_attrs, 0, out_shape);
return true;
}
NNVM_REGISTER_OP(squeeze)
.describe(R"code(Squeeze axises in the array.
Examples::
x = [[[0], [1], [2]]]
squeeze(x) = [0, 1, 2]
squeeze(x, 0) = [[0], [1], [2]]
squeeze(x, (0, 2)) = [0, 1, 2]
)code" NNVM_ADD_FILELINE)
.add_argument("data", "Tensor", "Source input")
.add_arguments(SqueezeParam::__FIELDS__())
.set_attr_parser(ParamParser<SqueezeParam>)
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<SqueezeParam>)
.set_attr<nnvm::FInferShape>("FInferShape", SqueezeShape)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
.set_num_inputs(1)
.set_num_outputs(1)
.set_support_level(1);
// tranpose
DMLC_REGISTER_PARAMETER(TransposeParam);
......
......@@ -220,6 +220,30 @@ def test_split():
verify_split((5, 3), [3], axis=0)
verify_split((5, 9, 3), [3, 4], axis=1)
def verify_squeeze(dshape, axis):
x = sym.Variable("x")
if axis:
y = sym.squeeze(x, axis=axis)
else:
y = sym.squeeze(x)
y = y + 1
dtype = "float32"
for target, ctx in ctx_list():
graph, lib, _ = nnvm.compiler.build(y, target, {"x": dshape})
m = graph_runtime.create(graph, lib, ctx)
# set input
data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype))
m.run(x=data)
out_np = np.squeeze(data.asnumpy(), axis=axis) + 1
out = m.get_output(0, tvm.nd.empty(out_np.shape))
np.testing.assert_allclose(out.asnumpy(), out_np, atol=1e-5, rtol=1e-5)
def test_squeeze():
verify_squeeze((1, 3, 2, 5), None)
verify_squeeze((1, 3, 1), axis=0)
verify_squeeze((1, 3, 2, 5, 1), axis=-1)
if __name__ == "__main__":
test_split()
test_concatenate()
......@@ -232,3 +256,4 @@ if __name__ == "__main__":
test_tanh()
test_sigmoid()
test_softmax()
test_squeeze()
......@@ -71,7 +71,6 @@ def test_reshape():
verify_reshape((2, 3, 4), (8, 3))
verify_reshape((4, 7), (2, 7, 2))
if __name__ == "__main__":
test_reshape()
test_reduce()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment