Commit 44d8203f by nhynes Committed by Tianqi Chen

[NNVM][TOPI] Add gradients for broadcast_* ops (#1234)

parent 3f7cce3b
...@@ -53,11 +53,18 @@ class Tuple { ...@@ -53,11 +53,18 @@ class Tuple {
this->assign(init.begin(), init.end()); this->assign(init.begin(), init.end());
} }
/*! /*!
* \brief constructor from vector
* \param init the vector
*/
inline Tuple(std::vector<ValueType> init) { // NOLINT(runtime/explicit)
this->assign(init.begin(), init.end());
}
/*!
* \brief move constructor from Tuple * \brief move constructor from Tuple
* \param src the source shape * \param src the source shape
*/ */
inline Tuple(Tuple<ValueType>&& src) { // NOLINT(*) inline Tuple(Tuple<ValueType>&& src) { // NOLINT(runtime/explicit)
this->swap(src); this->swap(src);
} }
/*! /*!
......
...@@ -37,3 +37,7 @@ reg.register_schedule("max", _fschedule_reduce) ...@@ -37,3 +37,7 @@ reg.register_schedule("max", _fschedule_reduce)
# min # min
reg.register_pattern("min", OpPattern.COMM_REDUCE) reg.register_pattern("min", OpPattern.COMM_REDUCE)
reg.register_schedule("min", _fschedule_reduce) reg.register_schedule("min", _fschedule_reduce)
# collapse sum
reg.register_pattern("collapse_sum", OpPattern.COMM_REDUCE)
reg.register_schedule("collapse_sum", _fschedule_reduce)
...@@ -238,7 +238,15 @@ Example:: ...@@ -238,7 +238,15 @@ Example::
broadcast_add(x, y) = [[ 1., 1., 1.], broadcast_add(x, y) = [[ 1., 1., 1.],
[ 2., 2., 2.]] [ 2., 2., 2.]]
)code" NNVM_ADD_FILELINE); )code" NNVM_ADD_FILELINE)
.set_attr<FGradient>(
"FGradient", [](const NodePtr& n,
const std::vector<NodeEntry>& ograds) {
return std::vector<NodeEntry>{
MakeNode("collapse_sum", n->attrs.name + "_dlhs", { ograds[0], n->inputs[0] }),
MakeNode("collapse_sum", n->attrs.name + "_drhs", { ograds[0], n->inputs[1] })
};
});
NNVM_REGISTER_BINARY_BROADCAST_OP(broadcast_sub, subtract) NNVM_REGISTER_BINARY_BROADCAST_OP(broadcast_sub, subtract)
...@@ -256,7 +264,18 @@ Example:: ...@@ -256,7 +264,18 @@ Example::
broadcast_sub(x, y) = [[ 1., 1., 1.], broadcast_sub(x, y) = [[ 1., 1., 1.],
[ 0., 0., 0.]] [ 0., 0., 0.]]
)code" NNVM_ADD_FILELINE); )code" NNVM_ADD_FILELINE)
.set_attr<FGradient>(
"FGradient", [](const NodePtr& n,
const std::vector<NodeEntry>& ograds) {
return std::vector<NodeEntry>{
MakeNode("collapse_sum", n->attrs.name + "_dlhs", { ograds[0], n->inputs[0] }),
MakeNode("collapse_sum", n->attrs.name + "_drhs", {
MakeNode("negative", n->attrs.name + "_drhs_neg", {ograds[0]}),
n->inputs[1]
})
};
});
NNVM_REGISTER_BINARY_BROADCAST_OP(broadcast_mul, multiply) NNVM_REGISTER_BINARY_BROADCAST_OP(broadcast_mul, multiply)
...@@ -273,7 +292,22 @@ Example:: ...@@ -273,7 +292,22 @@ Example::
broadcast_mul(x, y) = [[ 0., 0., 0.], broadcast_mul(x, y) = [[ 0., 0., 0.],
[ 1., 1., 1.]] [ 1., 1., 1.]]
)code" NNVM_ADD_FILELINE); )code" NNVM_ADD_FILELINE)
.set_attr<FGradient>(
"FGradient", [](const NodePtr& n,
const std::vector<NodeEntry>& ograds) {
NodeEntry dlhs = MakeNode("collapse_sum", n->attrs.name + "_dlhs_sum", {
MakeNode("broadcast_mul", n->attrs.name + "_dlhs_mul",
{ n->inputs[1], ograds[0] }),
n->inputs[0]
});
NodeEntry drhs = MakeNode("collapse_sum", n->attrs.name + "_drhs_sum", {
MakeNode("broadcast_mul", n->attrs.name + "_drhs_mul",
{ n->inputs[0], ograds[0] }),
n->inputs[1]
});
return std::vector<NodeEntry>{ dlhs, drhs };
});
NNVM_REGISTER_BINARY_BROADCAST_OP(broadcast_div, divide) NNVM_REGISTER_BINARY_BROADCAST_OP(broadcast_div, divide)
...@@ -291,7 +325,26 @@ Example:: ...@@ -291,7 +325,26 @@ Example::
broadcast_div(x, y) = [[ 3., 3., 3.], broadcast_div(x, y) = [[ 3., 3., 3.],
[ 2., 2., 2.]] [ 2., 2., 2.]]
)code" NNVM_ADD_FILELINE); )code" NNVM_ADD_FILELINE)
.set_attr<FGradient>(
"FGradient", [](const NodePtr& n,
const std::vector<NodeEntry>& ograds) {
NodeEntry dlhs = MakeNode("collapse_sum", n->attrs.name + "_dlhs_sum", {
MakeNode("broadcast_div", n->attrs.name + "_dlhs_div",
{ ograds[0], n->inputs[1] }),
n->inputs[0]
});
NodeEntry dy = MakeNode("broadcast_div", n->attrs.name + "_drhs_div", {
NodeEntry{n, 0, 0},
MakeNode("__mul_scalar__", n->attrs.name + "_rhs_by_two",
{n->inputs[1]}, {{"scalar", "2"}})
});
NodeEntry drhs = MakeNode("collapse_sum", n->attrs.name + "_drhs_sum", {
MakeNode("broadcast_mul", n->attrs.name + "_drhs_mul", { dy, ograds[0] }),
n->inputs[1]
});
return std::vector<NodeEntry>{ dlhs, drhs };
});
} // namespace top } // namespace top
} // namespace nnvm } // namespace nnvm
...@@ -9,9 +9,13 @@ ...@@ -9,9 +9,13 @@
#include <nnvm/compiler/op_attr_types.h> #include <nnvm/compiler/op_attr_types.h>
#include <nnvm/compiler/util.h> #include <nnvm/compiler/util.h>
#include <nnvm/top/tensor.h> #include <nnvm/top/tensor.h>
#include <numeric>
#include "../op_common.h" #include "../op_common.h"
#include "../elemwise_op_common.h" #include "../elemwise_op_common.h"
#include "topi/detail/constant_utils.h"
#include "topi/elemwise.h"
#include "topi/reduction.h" #include "topi/reduction.h"
#include "topi/transform.h"
namespace nnvm { namespace nnvm {
namespace top { namespace top {
...@@ -21,58 +25,61 @@ using namespace nnvm::compiler; ...@@ -21,58 +25,61 @@ using namespace nnvm::compiler;
// reduce // reduce
DMLC_REGISTER_PARAMETER(ReduceParam); DMLC_REGISTER_PARAMETER(ReduceParam);
inline TShape ReduceShapeImpl(const TShape& ishape, inline TShape GetReduceAxes(const uint32_t indim,
const TShape& axis, const TShape& axis,
bool keepdims,
bool exclude) { bool exclude) {
if (axis.ndim() == 0) { if (axis.ndim() == 0) {
if (keepdims) { TShape r_axes(indim);
return TShape(ishape.ndim()); std::iota(r_axes.begin(), r_axes.end(), 0);
} else { return r_axes;
return TShape(1);
} }
}
CHECK_LT(axis[axis.ndim() - 1], ishape.ndim()) CHECK_LT(axis[axis.ndim() - 1], indim)
<< "Reduction axis " << axis[axis.ndim() - 1] << "Reduction axis " << axis[axis.ndim() - 1]
<< " Exceeds input dimensions " << ishape; << " exceeds input dimensions " << indim;
TShape in_axis = axis; TShape in_axis = axis;
for (auto& i : in_axis) { for (auto& i : in_axis) {
i = i < 0 ? i + ishape.ndim(): i; i = i < 0 ? i + indim : i;
CHECK_GE(i, 0) << "axis out of bounds in reduce operator"; CHECK_GE(i, 0) << "axis out of bounds in reduce operator";
CHECK_LT(i, ishape.ndim()) << "axis out of bounds in reduce operator"; CHECK_LT(i, indim) << "axis out of bounds in reduce operator";
} }
std::sort(in_axis.begin(), in_axis.end()); std::sort(in_axis.begin(), in_axis.end());
if (!exclude) return in_axis;
if (keepdims) { TShape r_axis(indim - in_axis.ndim());
TShape oshape(ishape); for (unsigned i = 0, j = 0, k = 0; i < indim; ++i) {
if (exclude) { if (i == in_axis[j]) {
for (dim_t i = 0, j = 0; i < ishape.ndim(); ++i) {
if (j < in_axis.ndim() && i == in_axis[j]) {
++j; ++j;
continue; continue;
} }
oshape[i] = 1; r_axis[k++] = i;
}
return oshape;
} }
return r_axis;
}
for (dim_t i = 0; i < in_axis.ndim(); ++i) { inline TShape ReduceShapeImpl(const TShape& ishape,
oshape[in_axis[i]] = 1; const TShape& axis,
} bool keepdims,
return oshape; bool exclude) {
} uint32_t indim = ishape.ndim();
TShape r_axes = GetReduceAxes(indim, axis, exclude);
if (!r_axes.ndim()) return ishape;
if (r_axes.ndim() == indim)
return TShape(keepdims ? indim : 1);
if (exclude) { if (keepdims) {
TShape oshape = TShape(in_axis.ndim()); TShape oshape(ishape);
for (dim_t i = 0; i < in_axis.ndim(); ++i) { for (unsigned i = 0, j = 0; i < indim; ++i) {
oshape[i] = ishape[in_axis[i]]; if (i != r_axes[j]) continue;
oshape[i] = 1;
++j;
} }
return oshape; return oshape;
} }
TShape oshape = TShape(std::max<dim_t>(1, ishape.ndim() - in_axis.ndim()));
for (dim_t i = 0, j = 0, k = 0; i < ishape.ndim(); ++i) { TShape oshape(indim - r_axes.ndim());
if (j < in_axis.ndim() && i == in_axis[j]) { for (unsigned i = 0, j = 0, k = 0; i < indim; ++i) {
if (i == r_axes[j]) {
++j; ++j;
continue; continue;
} }
...@@ -95,6 +102,16 @@ inline bool ReduceShape(const nnvm::NodeAttrs& attrs, ...@@ -95,6 +102,16 @@ inline bool ReduceShape(const nnvm::NodeAttrs& attrs,
return true; return true;
} }
inline bool CollapseShape(const nnvm::NodeAttrs& attrs,
std::vector<TShape>* in_attrs,
std::vector<TShape>* out_attrs) {
CHECK_EQ(in_attrs->size(), 2U);
CHECK_EQ(out_attrs->size(), 1U);
if ((*in_attrs)[0].ndim() == 1) return false;
NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_attrs, 0, (*in_attrs)[1]);
return true;
}
template<typename PType> template<typename PType>
inline void AxesParamParser(nnvm::NodeAttrs* attrs) { inline void AxesParamParser(nnvm::NodeAttrs* attrs) {
PType param; PType param;
...@@ -103,18 +120,21 @@ inline void AxesParamParser(nnvm::NodeAttrs* attrs) { ...@@ -103,18 +120,21 @@ inline void AxesParamParser(nnvm::NodeAttrs* attrs) {
attrs->parsed = std::move(param); attrs->parsed = std::move(param);
} }
#define NNVM_REGISTER_REDUCE_OP(op) \ #define NNVM_REGISTER_BASE_REDUCE_OP(op) \
NNVM_REGISTER_OP(op) \ NNVM_REGISTER_OP(op) \
.add_argument("data", "Tensor", "The input") \
.add_arguments(ReduceParam::__FIELDS__()) \ .add_arguments(ReduceParam::__FIELDS__()) \
.set_attr_parser(AxesParamParser<ReduceParam>) \ .set_attr_parser(AxesParamParser<ReduceParam>) \
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<ReduceParam>) \ .set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<ReduceParam>) \
.set_num_outputs(1)
#define NNVM_REGISTER_REDUCE_OP(op) \
NNVM_REGISTER_BASE_REDUCE_OP(op) \
.add_argument("data", "Tensor", "The input") \
.set_attr<FInferShape>("FInferShape", ReduceShape) \ .set_attr<FInferShape>("FInferShape", ReduceShape) \
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>) \ .set_attr<FInferType>("FInferType", ElemwiseType<1, 1>) \
.set_attr<FCorrectLayout>("FCorrectLayout", \ .set_attr<FCorrectLayout>("FCorrectLayout", \
ElemwiseFixedLayoutUnknownOut<1, 1>) \ ElemwiseFixedLayoutUnknownOut<1, 1>) \
.set_num_inputs(1) \ .set_num_inputs(1)
.set_num_outputs(1)
NNVM_REGISTER_REDUCE_OP(sum) NNVM_REGISTER_REDUCE_OP(sum)
.describe(R"code(Computes the sum of array elements over given axes. .describe(R"code(Computes the sum of array elements over given axes.
...@@ -139,20 +159,10 @@ Example:: ...@@ -139,20 +159,10 @@ Example::
const Array<Tensor>& inputs, const Array<Tensor>& inputs,
const Array<Tensor>& out_info) { const Array<Tensor>& out_info) {
const ReduceParam& param = nnvm::get<ReduceParam>(attrs.parsed); const ReduceParam& param = nnvm::get<ReduceParam>(attrs.parsed);
Array<Expr> axis; TShape r_axes = GetReduceAxes(inputs[0]->shape.size(),
if (param.exclude) { param.axis, param.exclude);
std::set<dim_t> exclude_axis; if (!r_axes.ndim()) return Array<Tensor> { topi::identity(inputs[0]) };
for (dim_t i = 0; i < param.axis.ndim(); ++i) { auto axis = ShapeToArray(r_axes);
exclude_axis.insert(param.axis[i]);
}
for (dim_t i = 0; i < static_cast<int>(inputs[0].ndim()); ++i) {
if (exclude_axis.count(i) == 0) {
axis.push_back(make_const(Int(32), i));
}
}
} else {
axis = ShapeToArray(param.axis);
}
return Array<Tensor>{ return Array<Tensor>{
topi::sum(inputs[0], axis, param.keepdims) }; topi::sum(inputs[0], axis, param.keepdims) };
}) })
...@@ -178,7 +188,9 @@ NNVM_REGISTER_REDUCE_OP(max) ...@@ -178,7 +188,9 @@ NNVM_REGISTER_REDUCE_OP(max)
const Array<Tensor>& inputs, const Array<Tensor>& inputs,
const Array<Tensor>& out_info) { const Array<Tensor>& out_info) {
const ReduceParam& param = nnvm::get<ReduceParam>(attrs.parsed); const ReduceParam& param = nnvm::get<ReduceParam>(attrs.parsed);
auto axis = ShapeToArray(param.axis); TShape r_axes = GetReduceAxes(inputs[0]->shape.size(),
param.axis, param.exclude);
auto axis = ShapeToArray(r_axes);
return Array<Tensor>{ return Array<Tensor>{
topi::max(inputs[0], axis, param.keepdims) }; topi::max(inputs[0], axis, param.keepdims) };
}) })
...@@ -210,7 +222,9 @@ NNVM_REGISTER_REDUCE_OP(min) ...@@ -210,7 +222,9 @@ NNVM_REGISTER_REDUCE_OP(min)
const Array<Tensor>& inputs, const Array<Tensor>& inputs,
const Array<Tensor>& out_info) { const Array<Tensor>& out_info) {
const ReduceParam& param = nnvm::get<ReduceParam>(attrs.parsed); const ReduceParam& param = nnvm::get<ReduceParam>(attrs.parsed);
auto axis = ShapeToArray(param.axis); TShape r_axes = GetReduceAxes(inputs[0]->shape.size(),
param.axis, param.exclude);
auto axis = ShapeToArray(r_axes);
return Array<Tensor>{ return Array<Tensor>{
topi::min(inputs[0], axis, param.keepdims) }; topi::min(inputs[0], axis, param.keepdims) };
}) })
...@@ -233,6 +247,20 @@ NNVM_REGISTER_REDUCE_OP(min) ...@@ -233,6 +247,20 @@ NNVM_REGISTER_REDUCE_OP(min)
}; };
}); });
NNVM_REGISTER_BASE_REDUCE_OP(collapse_sum)
.add_argument("data", "Tensor", "The input")
.add_argument("as", "Tensor", "The reference")
.set_attr<FInferShape>("FInferShape", CollapseShape)
.set_attr<FInferType>("FInferType", ElemwiseType<2, 1>)
.set_attr<FCorrectLayout>("FCorrectLayout", ElemwiseFixedLayoutUnknownOut<2, 1>)
.set_num_inputs(2)
.describe(R"code(Reduces lhs to the shape of rhs via sum)code" NNVM_ADD_FILELINE)
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
const Array<Tensor>& out_info) {
return Array<Tensor>{ topi::collapse_sum(inputs[0], inputs[1]->shape) };
});
} // namespace top } // namespace top
} // namespace nnvm } // namespace nnvm
...@@ -240,10 +240,8 @@ will return a new array with shape ``(2,1,1,1,1,1,3,4)``. ...@@ -240,10 +240,8 @@ will return a new array with shape ``(2,1,1,1,1,1,3,4)``.
.set_attr<FGradient>( .set_attr<FGradient>(
"FGradient", [](const NodePtr& n, "FGradient", [](const NodePtr& n,
const std::vector<NodeEntry>& ograds){ const std::vector<NodeEntry>& ograds){
const ExpandDimsParam& param = nnvm::get<ExpandDimsParam>(n->attrs.parsed);
return std::vector<NodeEntry> { return std::vector<NodeEntry> {
MakeNode("sum", n->attrs.name + "_grad", {ograds[0]}, MakeNode("collapse_sum", n->attrs.name + "_grad", {ograds[0], n->inputs[0]})
{{"axis", std::to_string(param.axis)}})
}; };
}) })
.set_support_level(1); .set_support_level(1);
......
...@@ -81,7 +81,23 @@ def verify_reduce(dshape, fnp, fsym, **kwargs): ...@@ -81,7 +81,23 @@ def verify_reduce(dshape, fnp, fsym, **kwargs):
np.testing.assert_allclose(out.asnumpy(), out_np, atol=1e-5, rtol=1e-5) np.testing.assert_allclose(out.asnumpy(), out_np, atol=1e-5, rtol=1e-5)
def test_tranpose(): def verify_collapse(dshape, target_shape, fnp):
x = sym.Variable("x", shape=dshape)
t = sym.Variable("t", shape=target_shape)
y = sym.collapse_sum(x, t)
dtype = "float32"
for target, ctx in ctx_list():
graph, lib, _ = nnvm.compiler.build(y, target,
{"x": dshape, "t": target_shape})
m = graph_runtime.create(graph, lib, ctx)
data = np.random.uniform(size=dshape).astype(dtype)
m.run(x=data)
out = m.get_output(0, tvm.nd.empty(target_shape))
out_np = fnp(data)
np.testing.assert_allclose(out.asnumpy(), out_np, atol=1e-5, rtol=1e-5)
def test_transpose():
verify_transpose((2, 3, 4), (0, 2, 1)) verify_transpose((2, 3, 4), (0, 2, 1))
verify_transpose((2, 3, 4), None) verify_transpose((2, 3, 4), None)
...@@ -90,6 +106,22 @@ def test_reduce(): ...@@ -90,6 +106,22 @@ def test_reduce():
verify_reduce((2, 3, 4), np.max, sym.max, axis=1, keepdims=True) verify_reduce((2, 3, 4), np.max, sym.max, axis=1, keepdims=True)
verify_reduce((4, 4, 3), np.min, sym.min, keepdims=True) verify_reduce((4, 4, 3), np.min, sym.min, keepdims=True)
verify_reduce((4, 4, 3), np.sum, sym.sum, axis=(0, 2)) verify_reduce((4, 4, 3), np.sum, sym.sum, axis=(0, 2))
verify_reduce((4, 4, 3), np.sum, sym.sum)
def test_collapse():
verify_collapse((2, 3, 4), (1,), lambda x: x.sum())
verify_collapse((2, 3, 4), (1, 1, 1), lambda x: x.sum(keepdims=True))
verify_collapse((2, 3, 4), (1, 1), lambda x: x.sum().reshape(1, 1))
verify_collapse((2, 3, 4), (1, 4), lambda x: x.reshape(-1, 4).sum(0, keepdims=True))
verify_collapse((2, 3, 4), (3, 4), lambda x: x.sum(0))
verify_collapse((2, 3, 4), (1, 3, 4), lambda x: x.sum(0, keepdims=True))
verify_collapse((2, 3, 4), (1, 1, 4), lambda x: x.sum((0, 1), keepdims=True))
verify_collapse((2, 3, 4), (2, 1, 4), lambda x: x.sum(1, keepdims=True))
verify_collapse((2, 3, 4), (2, 1, 1), lambda x: x.sum((1, 2), keepdims=True))
verify_collapse((2, 3, 4), (2, 3, 1), lambda x: x.sum(2, keepdims=True))
verify_collapse((2, 3, 4), (2, 3, 4), lambda x: x)
def verify_flip(ishape, axis): def verify_flip(ishape, axis):
x = sym.Variable("x") x = sym.Variable("x")
...@@ -106,6 +138,7 @@ def verify_flip(ishape, axis): ...@@ -106,6 +138,7 @@ def verify_flip(ishape, axis):
out = m.get_output(0, tvm.nd.empty(res.shape)) out = m.get_output(0, tvm.nd.empty(res.shape))
np.testing.assert_allclose(out.asnumpy(), res, atol=1e-5, rtol=1e-5) np.testing.assert_allclose(out.asnumpy(), res, atol=1e-5, rtol=1e-5)
def test_flip(): def test_flip():
verify_flip((3, 4, 3), 1) verify_flip((3, 4, 3), 1)
verify_flip((3, 4, 3), 0) verify_flip((3, 4, 3), 0)
...@@ -114,6 +147,7 @@ def test_flip(): ...@@ -114,6 +147,7 @@ def test_flip():
verify_flip((3, 4, 3), -3) verify_flip((3, 4, 3), -3)
verify_flip((3, 4, 3), -2) verify_flip((3, 4, 3), -2)
def verify_reshape(dshape, oshape): def verify_reshape(dshape, oshape):
x = sym.Variable("x") x = sym.Variable("x")
y = sym.reshape(x, shape=oshape) y = sym.reshape(x, shape=oshape)
...@@ -156,6 +190,45 @@ def test_clip(): ...@@ -156,6 +190,45 @@ def test_clip():
helper(y, inputs, dtype, forward, backward) helper(y, inputs, dtype, forward, backward)
def test_broadcast():
a = sym.Variable("a")
b = sym.Variable("b")
inputs = [('a', (3, 4, 5), a),
('b', (1, 5), b)]
dtype = "float32"
def _collapse(g):
return g.reshape(-1, inputs[-1][1][-1]).sum(0, keepdims=True)
y = sym.broadcast_add(a, b)
def _backward_add(head_grads, a, b):
da = head_grads
db = _collapse(head_grads)
return da, db
helper(y, inputs, dtype, lambda a, b: a + b, _backward_add)
y = sym.broadcast_sub(a, b)
def _backward_sub(head_grads, a, b):
da = head_grads
db = -_collapse(head_grads)
return da, db
helper(y, inputs, dtype, lambda a, b: a - b, _backward_sub)
y = sym.broadcast_mul(a, b)
def _backward_mul(head_grads, a, b):
da = head_grads * b
db = _collapse(head_grads * a)
return da, db
helper(y, inputs, dtype, lambda a, b: a * b, _backward_mul)
y = sym.broadcast_div(a, b)
def _backward_div(head_grads, a, b):
da = head_grads / b
db = _collapse(head_grads * a / (2 * b**2))
return da, db
helper(y, inputs, dtype, lambda a, b: a / b, _backward_div)
def test_greater(): def test_greater():
l = sym.Variable("l") l = sym.Variable("l")
r = sym.Variable("r") r = sym.Variable("r")
...@@ -472,8 +545,10 @@ def test_nms(): ...@@ -472,8 +545,10 @@ def test_nms():
if __name__ == "__main__": if __name__ == "__main__":
test_reshape() test_reshape()
test_broadcast()
test_reduce() test_reduce()
test_tranpose() test_collapse()
test_transpose()
test_clip() test_clip()
test_greater() test_greater()
test_less() test_less()
......
...@@ -6,12 +6,15 @@ ...@@ -6,12 +6,15 @@
#ifndef TOPI_REDUCTION_H_ #ifndef TOPI_REDUCTION_H_
#define TOPI_REDUCTION_H_ #define TOPI_REDUCTION_H_
#include <algorithm>
#include <string> #include <string>
#include <set> #include <set>
#include <vector> #include <vector>
#include <iterator> #include <iterator>
#include "topi/elemwise.h"
#include "topi/tags.h" #include "topi/tags.h"
#include "topi/transform.h"
#include "topi/detail/ravel_unravel.h" #include "topi/detail/ravel_unravel.h"
#include "topi/detail/constant_utils.h" #include "topi/detail/constant_utils.h"
#include "tvm/tvm.h" #include "tvm/tvm.h"
...@@ -91,6 +94,9 @@ inline Array<Expr> MakeReduceTargetShape(const std::vector<int>& real_axis, ...@@ -91,6 +94,9 @@ inline Array<Expr> MakeReduceTargetShape(const std::vector<int>& real_axis,
target_shape.push_back(data->shape[i]); target_shape.push_back(data->shape[i]);
} }
} }
if (target_shape.size() == 0) {
target_shape.push_back(1);
}
} }
return target_shape; return target_shape;
} }
...@@ -99,55 +105,72 @@ inline Array<Expr> MakeReduceTargetShape(const std::vector<int>& real_axis, ...@@ -99,55 +105,72 @@ inline Array<Expr> MakeReduceTargetShape(const std::vector<int>& real_axis,
* \brief Create a reduction operation. * \brief Create a reduction operation.
* *
* \param data The input tensor. * \param data The input tensor.
* \param axis The axes along which the reduction is performed.
* \param func The reduction function eg. tvm::sum * \param func The reduction function eg. tvm::sum
* \param keepdims If this is set to true, the axes which are reduced are * \param target_shape The output Tensor shape.
* left in the result as dimensions with size one. This enables the result * \param reduce_axes The real axes along which the reduction is performed.
* to broadcast correctly against the input array. * \param squeeze_axes The real axes to squeeze. Unsqueezed, reduced axes will
* have shape 1 in the output tensor.
* *
* \return The result tensor. * \return The result tensor.
*/ */
inline Tensor CommReduce(const Tensor& data, inline Tensor DoCommReduce(const Tensor& data,
const Array<Expr>& axis,
FReduce func, FReduce func,
bool keepdims = false) { const Array<Expr>& target_shape,
auto ndim = data->shape.size(); const std::vector<int>& reduce_axes,
CHECK_NE(ndim, 0) << "Cannot reduce a 0 dim Tensor"; const std::vector<int>& squeeze_axes) {
auto axis_val = detail::GetConstIntValues(axis, "axis"); auto r_axes = MakeReduceAxes(reduce_axes, data);
auto real_axis = GetRealAxis(static_cast<int>(ndim), axis_val); auto compute = [&](const Array<Var>& indices) {
auto reduce_axes = MakeReduceAxes(real_axis, data);
auto target_shape = MakeReduceTargetShape(real_axis, data, keepdims);
auto compute = [ndim, keepdims, &real_axis, &reduce_axes, &func, &data]
(const Array<Var>& indices) {
Array<Expr> eval_range; Array<Expr> eval_range;
Array<Var> eval_indices; Array<Var> eval_indices;
int arg_counter = 0; int arg_counter = 0;
int red_counter = 0; int red_counter = 0;
for (size_t i = 0; i < ndim; ++i) { for (size_t i = 0; i < data->shape.size(); ++i) {
if (std::find(real_axis.begin(), real_axis.end(), i) != real_axis.end()) { bool squeeze_i = std::find(squeeze_axes.begin(), squeeze_axes.end(), i) != squeeze_axes.end();
if (std::find(reduce_axes.begin(), reduce_axes.end(), i) != reduce_axes.end()) {
// real_axis contains i // real_axis contains i
eval_range.push_back(reduce_axes[red_counter]); eval_range.push_back(r_axes[red_counter]);
eval_indices.push_back(reduce_axes[red_counter]->var); eval_indices.push_back(r_axes[red_counter]->var);
red_counter++; red_counter++;
} else { arg_counter += !squeeze_i;
if (!keepdims) { continue;
}
eval_range.push_back(indices[arg_counter]); eval_range.push_back(indices[arg_counter]);
arg_counter++; arg_counter++;
} else {
eval_range.push_back(indices[i]);
}
}
} }
return func(data(eval_range), reduce_axes); return func(data(eval_range), r_axes);
}; };
return tvm::compute(target_shape, compute, data->op->name + "_red", kCommReduce); return tvm::compute(target_shape, compute, data->op->name + "_red", kCommReduce);
} }
/*! /*!
* \brief Create a reduction operation.
*
* \param data The input tensor.
* \param axis The axes along which the reduction is performed.
* \param func The reduction function eg. tvm::sum
* \param keepdims If this is set to true, the axes which are reduced are
* left in the result as dimensions with size one. This enables the result
* to broadcast correctly against the input array.
*
* \return The result tensor.
*/
inline Tensor CommReduce(const Tensor& data,
const Array<Expr>& axis,
FReduce func,
bool keepdims = false) {
auto ndim = data->shape.size();
CHECK_NE(ndim, 0) << "Cannot reduce a 0 dim Tensor";
auto axis_val = detail::GetConstIntValues(axis, "axis");
auto real_axis = GetRealAxis(static_cast<int>(ndim), axis_val);
auto target_shape = MakeReduceTargetShape(real_axis, data, keepdims);
return DoCommReduce(data, func, target_shape, real_axis,
keepdims ? std::vector<int>() : real_axis);
}
/*!
* \brief Create an index reduction operation. * \brief Create an index reduction operation.
* *
* \param data The input tensor. * \param data The input tensor.
...@@ -281,6 +304,34 @@ inline Tensor sum(const Tensor& data, Array<Expr> axis, bool keepdims = false) { ...@@ -281,6 +304,34 @@ inline Tensor sum(const Tensor& data, Array<Expr> axis, bool keepdims = false) {
return CommReduce(data, axis, tvm::sum, keepdims); return CommReduce(data, axis, tvm::sum, keepdims);
} }
inline Tensor collapse_sum(const Tensor& data, Array<Expr> target_shape) {
CHECK_GE(data->shape.size(), target_shape.size());
auto ishape = detail::GetConstIntValues(data->shape, "ishape");
auto oshape = detail::GetConstIntValues(target_shape, "oshape");
std::vector<int> reduce_axes;
std::vector<int> squeeze_axes;
for (int i_ax = ishape.size() - 1,
o_ax = oshape.size() - 1; i_ax >= 0; --i_ax) {
if (o_ax >= 0 && ishape[i_ax] == oshape[o_ax]) {
--o_ax;
continue;
}
reduce_axes.push_back(i_ax);
if (o_ax < 0) { // squeeze o_ax if was added during expansion
squeeze_axes.push_back(i_ax);
} else if (oshape[o_ax] == 1) {
--o_ax;
}
}
if (reduce_axes.size() == 0) return topi::identity(data, "tensor", kCommReduce);
std::reverse(reduce_axes.begin(), reduce_axes.end());
std::reverse(squeeze_axes.begin(), squeeze_axes.end());
return DoCommReduce(data, tvm::sum, target_shape, reduce_axes, squeeze_axes);
}
/*! /*!
* \brief Creates an operation that finds the minimum of elements over * \brief Creates an operation that finds the minimum of elements over
* a given axis. * a given axis.
......
...@@ -4,6 +4,7 @@ from __future__ import absolute_import as _abs ...@@ -4,6 +4,7 @@ from __future__ import absolute_import as _abs
import tvm import tvm
from .. import tag from .. import tag
from .. import generic from .. import generic
from .injective import _schedule_injective
def _schedule_reduce(op, sch, is_idx_reduce=False): def _schedule_reduce(op, sch, is_idx_reduce=False):
if is_idx_reduce: if is_idx_reduce:
...@@ -11,7 +12,9 @@ def _schedule_reduce(op, sch, is_idx_reduce=False): ...@@ -11,7 +12,9 @@ def _schedule_reduce(op, sch, is_idx_reduce=False):
else: else:
data_in = op.input_tensors[0] data_in = op.input_tensors[0]
data_out = op.output(0) data_out = op.output(0)
assert len(sch[data_out].op.reduce_axis) > 0, "reduce_axis must be bigger than zero!"
if not sch[data_out].op.reduce_axis:
return _schedule_injective(op, sch)
if len(sch[data_out].op.axis) > 0: if len(sch[data_out].op.axis) > 0:
all_reduce = False all_reduce = False
......
...@@ -73,6 +73,7 @@ def verify_reduce_map_ele(in_shape, axis, keepdims, type="sum"): ...@@ -73,6 +73,7 @@ def verify_reduce_map_ele(in_shape, axis, keepdims, type="sum"):
out_npy = _my_npy_argmin(in_npy_map, axis=axis, keepdims=keepdims) out_npy = _my_npy_argmin(in_npy_map, axis=axis, keepdims=keepdims)
else: else:
raise NotImplementedError raise NotImplementedError
out_npy = np.atleast_1d(out_npy)
data_tvm = tvm.nd.array(in_npy, ctx=ctx) data_tvm = tvm.nd.array(in_npy, ctx=ctx)
out_tvm = tvm.nd.empty(shape=out_npy.shape, ctx=ctx, dtype=out_dtype) out_tvm = tvm.nd.empty(shape=out_npy.shape, ctx=ctx, dtype=out_dtype)
for _ in range(1): for _ in range(1):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment