Commit 44d8203f by nhynes Committed by Tianqi Chen

[NNVM][TOPI] Add gradients for broadcast_* ops (#1234)

parent 3f7cce3b
......@@ -53,11 +53,18 @@ class Tuple {
this->assign(init.begin(), init.end());
* \brief constructor from vector
* \param init the vector
inline Tuple(std::vector<ValueType> init) { // NOLINT(runtime/explicit)
this->assign(init.begin(), init.end());
* \brief move constructor from Tuple
* \param src the source shape
inline Tuple(Tuple<ValueType>&& src) { // NOLINT(*)
inline Tuple(Tuple<ValueType>&& src) { // NOLINT(runtime/explicit)
......@@ -37,3 +37,7 @@ reg.register_schedule("max", _fschedule_reduce)
# min
reg.register_pattern("min", OpPattern.COMM_REDUCE)
reg.register_schedule("min", _fschedule_reduce)
# collapse sum
reg.register_pattern("collapse_sum", OpPattern.COMM_REDUCE)
reg.register_schedule("collapse_sum", _fschedule_reduce)
......@@ -238,7 +238,15 @@ Example::
broadcast_add(x, y) = [[ 1., 1., 1.],
[ 2., 2., 2.]]
"FGradient", [](const NodePtr& n,
const std::vector<NodeEntry>& ograds) {
return std::vector<NodeEntry>{
MakeNode("collapse_sum", n-> + "_dlhs", { ograds[0], n->inputs[0] }),
MakeNode("collapse_sum", n-> + "_drhs", { ograds[0], n->inputs[1] })
NNVM_REGISTER_BINARY_BROADCAST_OP(broadcast_sub, subtract)
......@@ -256,7 +264,18 @@ Example::
broadcast_sub(x, y) = [[ 1., 1., 1.],
[ 0., 0., 0.]]
"FGradient", [](const NodePtr& n,
const std::vector<NodeEntry>& ograds) {
return std::vector<NodeEntry>{
MakeNode("collapse_sum", n-> + "_dlhs", { ograds[0], n->inputs[0] }),
MakeNode("collapse_sum", n-> + "_drhs", {
MakeNode("negative", n-> + "_drhs_neg", {ograds[0]}),
NNVM_REGISTER_BINARY_BROADCAST_OP(broadcast_mul, multiply)
......@@ -273,7 +292,22 @@ Example::
broadcast_mul(x, y) = [[ 0., 0., 0.],
[ 1., 1., 1.]]
"FGradient", [](const NodePtr& n,
const std::vector<NodeEntry>& ograds) {
NodeEntry dlhs = MakeNode("collapse_sum", n-> + "_dlhs_sum", {
MakeNode("broadcast_mul", n-> + "_dlhs_mul",
{ n->inputs[1], ograds[0] }),
NodeEntry drhs = MakeNode("collapse_sum", n-> + "_drhs_sum", {
MakeNode("broadcast_mul", n-> + "_drhs_mul",
{ n->inputs[0], ograds[0] }),
return std::vector<NodeEntry>{ dlhs, drhs };
......@@ -291,7 +325,26 @@ Example::
broadcast_div(x, y) = [[ 3., 3., 3.],
[ 2., 2., 2.]]
"FGradient", [](const NodePtr& n,
const std::vector<NodeEntry>& ograds) {
// Body of the broadcast_div gradient: builds one gradient entry per input.
// NOTE(review): `n-> + "..."` in the node names below is not valid C++; the
// member after `n->` seems to have been dropped from this listing (presumably
// `n->attrs.name`) -- confirm against the real file.
// lhs gradient: ograds[0] / rhs, reduced with collapse_sum.
NodeEntry dlhs = MakeNode("collapse_sum", n-> + "_dlhs_sum", {
MakeNode("broadcast_div", n-> + "_dlhs_div",
{ ograds[0], n->inputs[1] }),
// dy divides NodeEntry{n, 0, 0} (presumably this node's own output, lhs / rhs)
// by rhs scaled by the scalar 2, i.e. dy = out / (2 * rhs).
// NOTE(review): analytically d(lhs / rhs) / d(rhs) = -lhs / rhs^2 = -out / rhs.
// The "__mul_scalar__" by 2 therefore looks wrong: the divisor should be the
// negated rhs, otherwise the rhs gradient has the wrong sign and half the
// magnitude. Confirm, and fix together with any test that encodes
// a / (2 * b**2) as the expected value.
NodeEntry dy = MakeNode("broadcast_div", n-> + "_drhs_div", {
NodeEntry{n, 0, 0},
MakeNode("__mul_scalar__", n-> + "_rhs_by_two",
{n->inputs[1]}, {{"scalar", "2"}})
// rhs gradient: dy * ograds[0], reduced with collapse_sum.
NodeEntry drhs = MakeNode("collapse_sum", n-> + "_drhs_sum", {
MakeNode("broadcast_mul", n-> + "_drhs_mul", { dy, ograds[0] }),
// Gradients are returned in input order: lhs first, then rhs.
return std::vector<NodeEntry>{ dlhs, drhs };
} // namespace top
} // namespace nnvm
......@@ -9,9 +9,13 @@
#include <nnvm/compiler/op_attr_types.h>
#include <nnvm/compiler/util.h>
#include <nnvm/top/tensor.h>
#include <numeric>
#include "../op_common.h"
#include "../elemwise_op_common.h"
#include "topi/detail/constant_utils.h"
#include "topi/elemwise.h"
#include "topi/reduction.h"
#include "topi/transform.h"
namespace nnvm {
namespace top {
......@@ -21,58 +25,61 @@ using namespace nnvm::compiler;
// reduce
inline TShape ReduceShapeImpl(const TShape& ishape,
inline TShape GetReduceAxes(const uint32_t indim,
const TShape& axis,
bool keepdims,
bool exclude) {
if (axis.ndim() == 0) {
if (keepdims) {
return TShape(ishape.ndim());
} else {
return TShape(1);
TShape r_axes(indim);
std::iota(r_axes.begin(), r_axes.end(), 0);
return r_axes;
CHECK_LT(axis[axis.ndim() - 1], ishape.ndim())
CHECK_LT(axis[axis.ndim() - 1], indim)
<< "Reduction axis " << axis[axis.ndim() - 1]
<< " Exceeds input dimensions " << ishape;
<< " exceeds input dimensions " << indim;
TShape in_axis = axis;
for (auto& i : in_axis) {
i = i < 0 ? i + ishape.ndim(): i;
i = i < 0 ? i + indim : i;
CHECK_GE(i, 0) << "axis out of bounds in reduce operator";
CHECK_LT(i, ishape.ndim()) << "axis out of bounds in reduce operator";
CHECK_LT(i, indim) << "axis out of bounds in reduce operator";
std::sort(in_axis.begin(), in_axis.end());
if (keepdims) {
TShape oshape(ishape);
if (exclude) {
for (dim_t i = 0, j = 0; i < ishape.ndim(); ++i) {
if (j < in_axis.ndim() && i == in_axis[j]) {
if (!exclude) return in_axis;
TShape r_axis(indim - in_axis.ndim());
for (unsigned i = 0, j = 0, k = 0; i < indim; ++i) {
if (i == in_axis[j]) {
oshape[i] = 1;
return oshape;
r_axis[k++] = i;
return r_axis;
for (dim_t i = 0; i < in_axis.ndim(); ++i) {
oshape[in_axis[i]] = 1;
return oshape;
inline TShape ReduceShapeImpl(const TShape& ishape,
const TShape& axis,
bool keepdims,
bool exclude) {
uint32_t indim = ishape.ndim();
TShape r_axes = GetReduceAxes(indim, axis, exclude);
if (!r_axes.ndim()) return ishape;
if (r_axes.ndim() == indim)
return TShape(keepdims ? indim : 1);
if (exclude) {
TShape oshape = TShape(in_axis.ndim());
for (dim_t i = 0; i < in_axis.ndim(); ++i) {
oshape[i] = ishape[in_axis[i]];
if (keepdims) {
TShape oshape(ishape);
for (unsigned i = 0, j = 0; i < indim; ++i) {
if (i != r_axes[j]) continue;
oshape[i] = 1;
return oshape;
TShape oshape = TShape(std::max<dim_t>(1, ishape.ndim() - in_axis.ndim()));
for (dim_t i = 0, j = 0, k = 0; i < ishape.ndim(); ++i) {
if (j < in_axis.ndim() && i == in_axis[j]) {
TShape oshape(indim - r_axes.ndim());
for (unsigned i = 0, j = 0, k = 0; i < indim; ++i) {
if (i == r_axes[j]) {
......@@ -95,6 +102,16 @@ inline bool ReduceShape(const nnvm::NodeAttrs& attrs,
return true;
// Shape inference for an operator with two inputs and one output: the output
// shape is assigned from the shape of the second input.
// in_attrs must hold exactly two shapes and out_attrs exactly one (CHECKed).
// Returns false when the first input's shape has ndim() == 1, true otherwise.
inline bool CollapseShape(const nnvm::NodeAttrs& attrs,
std::vector<TShape>* in_attrs,
std::vector<TShape>* out_attrs) {
CHECK_EQ(in_attrs->size(), 2U);
CHECK_EQ(out_attrs->size(), 1U);
// NOTE(review): this bails out for every rank-1 first input; presumably it
// is meant as a "shape not known yet" test -- confirm the intended condition.
if ((*in_attrs)[0].ndim() == 1) return false;
// The output takes the shape of the second (reference) input.
NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_attrs, 0, (*in_attrs)[1]);
return true;
template<typename PType>
inline void AxesParamParser(nnvm::NodeAttrs* attrs) {
PType param;
......@@ -103,18 +120,21 @@ inline void AxesParamParser(nnvm::NodeAttrs* attrs) {
attrs->parsed = std::move(param);
.add_argument("data", "Tensor", "The input") \
.add_arguments(ReduceParam::__FIELDS__()) \
.set_attr_parser(AxesParamParser<ReduceParam>) \
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<ReduceParam>) \
.add_argument("data", "Tensor", "The input") \
.set_attr<FInferShape>("FInferShape", ReduceShape) \
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>) \
.set_attr<FCorrectLayout>("FCorrectLayout", \
ElemwiseFixedLayoutUnknownOut<1, 1>) \
.set_num_inputs(1) \
.describe(R"code(Computes the sum of array elements over given axes.
......@@ -139,20 +159,10 @@ Example::
const Array<Tensor>& inputs,
const Array<Tensor>& out_info) {
const ReduceParam& param = nnvm::get<ReduceParam>(attrs.parsed);
Array<Expr> axis;
if (param.exclude) {
std::set<dim_t> exclude_axis;
for (dim_t i = 0; i < param.axis.ndim(); ++i) {
for (dim_t i = 0; i < static_cast<int>(inputs[0].ndim()); ++i) {
if (exclude_axis.count(i) == 0) {
axis.push_back(make_const(Int(32), i));
} else {
axis = ShapeToArray(param.axis);
TShape r_axes = GetReduceAxes(inputs[0]->shape.size(),
param.axis, param.exclude);
if (!r_axes.ndim()) return Array<Tensor> { topi::identity(inputs[0]) };
auto axis = ShapeToArray(r_axes);
return Array<Tensor>{
topi::sum(inputs[0], axis, param.keepdims) };
......@@ -178,7 +188,9 @@ NNVM_REGISTER_REDUCE_OP(max)
const Array<Tensor>& inputs,
const Array<Tensor>& out_info) {
const ReduceParam& param = nnvm::get<ReduceParam>(attrs.parsed);
auto axis = ShapeToArray(param.axis);
TShape r_axes = GetReduceAxes(inputs[0]->shape.size(),
param.axis, param.exclude);
auto axis = ShapeToArray(r_axes);
return Array<Tensor>{
topi::max(inputs[0], axis, param.keepdims) };
......@@ -210,7 +222,9 @@ NNVM_REGISTER_REDUCE_OP(min)
const Array<Tensor>& inputs,
const Array<Tensor>& out_info) {
const ReduceParam& param = nnvm::get<ReduceParam>(attrs.parsed);
auto axis = ShapeToArray(param.axis);
TShape r_axes = GetReduceAxes(inputs[0]->shape.size(),
param.axis, param.exclude);
auto axis = ShapeToArray(r_axes);
return Array<Tensor>{
topi::min(inputs[0], axis, param.keepdims) };
......@@ -233,6 +247,20 @@ NNVM_REGISTER_REDUCE_OP(min)
.add_argument("data", "Tensor", "The input")
.add_argument("as", "Tensor", "The reference")
.set_attr<FInferShape>("FInferShape", CollapseShape)
.set_attr<FInferType>("FInferType", ElemwiseType<2, 1>)
.set_attr<FCorrectLayout>("FCorrectLayout", ElemwiseFixedLayoutUnknownOut<2, 1>)
.describe(R"code(Reduces lhs to the shape of rhs via sum)code" NNVM_ADD_FILELINE)
"FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
const Array<Tensor>& out_info) {
return Array<Tensor>{ topi::collapse_sum(inputs[0], inputs[1]->shape) };
} // namespace top
} // namespace nnvm
......@@ -240,10 +240,8 @@ will return a new array with shape ``(2,1,1,1,1,1,3,4)``.
"FGradient", [](const NodePtr& n,
const std::vector<NodeEntry>& ograds){
const ExpandDimsParam& param = nnvm::get<ExpandDimsParam>(n->attrs.parsed);
return std::vector<NodeEntry> {
MakeNode("sum", n-> + "_grad", {ograds[0]},
{{"axis", std::to_string(param.axis)}})
MakeNode("collapse_sum", n-> + "_grad", {ograds[0], n->inputs[0]})
......@@ -81,7 +81,23 @@ def verify_reduce(dshape, fnp, fsym, **kwargs):
np.testing.assert_allclose(out.asnumpy(), out_np, atol=1e-5, rtol=1e-5)
def test_tranpose():
def verify_collapse(dshape, target_shape, fnp):
    """Check ``collapse_sum`` against a NumPy reference on every enabled target.

    Parameters
    ----------
    dshape : tuple of int
        Shape of the input ``x`` that is reduced.
    target_shape : tuple of int
        Shape of the reference input ``t``; the output buffer is allocated
        with this shape.
    fnp : callable
        NumPy reference: takes the random input array and returns the
        expected output.
    """
    x = sym.Variable("x", shape=dshape)
    t = sym.Variable("t", shape=target_shape)
    y = sym.collapse_sum(x, t)
    dtype = "float32"
    for target, ctx in ctx_list():
        graph, lib, _ = nnvm.compiler.build(y, target,
                                            {"x": dshape, "t": target_shape})
        m = graph_runtime.create(graph, lib, ctx)
        data = np.random.uniform(size=dshape).astype(dtype)
        # Execute the graph before fetching the result; otherwise the output
        # buffer is read without ever having been computed.
        m.run(x=data)
        out = m.get_output(0, tvm.nd.empty(target_shape))
        out_np = fnp(data)
        np.testing.assert_allclose(out.asnumpy(), out_np, atol=1e-5, rtol=1e-5)
def test_transpose():
verify_transpose((2, 3, 4), (0, 2, 1))
verify_transpose((2, 3, 4), None)
......@@ -90,6 +106,22 @@ def test_reduce():
verify_reduce((2, 3, 4), np.max, sym.max, axis=1, keepdims=True)
verify_reduce((4, 4, 3), np.min, sym.min, keepdims=True)
verify_reduce((4, 4, 3), np.sum, sym.sum, axis=(0, 2))
verify_reduce((4, 4, 3), np.sum, sym.sum)
def test_collapse():
    """Run verify_collapse on a (2, 3, 4) input for a series of target shapes."""
    dshape = (2, 3, 4)
    # (target shape, NumPy reference) pairs, checked in this order.
    cases = [
        ((1,), lambda x: x.sum()),
        ((1, 1, 1), lambda x: x.sum(keepdims=True)),
        ((1, 1), lambda x: x.sum().reshape(1, 1)),
        ((1, 4), lambda x: x.reshape(-1, 4).sum(0, keepdims=True)),
        ((3, 4), lambda x: x.sum(0)),
        ((1, 3, 4), lambda x: x.sum(0, keepdims=True)),
        ((1, 1, 4), lambda x: x.sum((0, 1), keepdims=True)),
        ((2, 1, 4), lambda x: x.sum(1, keepdims=True)),
        ((2, 1, 1), lambda x: x.sum((1, 2), keepdims=True)),
        ((2, 3, 1), lambda x: x.sum(2, keepdims=True)),
        ((2, 3, 4), lambda x: x),
    ]
    for target_shape, reference in cases:
        verify_collapse(dshape, target_shape, reference)
def verify_flip(ishape, axis):
x = sym.Variable("x")
......@@ -106,6 +138,7 @@ def verify_flip(ishape, axis):
out = m.get_output(0, tvm.nd.empty(res.shape))
np.testing.assert_allclose(out.asnumpy(), res, atol=1e-5, rtol=1e-5)
def test_flip():
verify_flip((3, 4, 3), 1)
verify_flip((3, 4, 3), 0)
......@@ -114,6 +147,7 @@ def test_flip():
verify_flip((3, 4, 3), -3)
verify_flip((3, 4, 3), -2)
def verify_reshape(dshape, oshape):
x = sym.Variable("x")
y = sym.reshape(x, shape=oshape)
......@@ -156,6 +190,45 @@ def test_clip():
helper(y, inputs, dtype, forward, backward)
def test_broadcast():
    """Check forward results and gradients of the broadcast_* binary ops.

    ``a`` has shape (3, 4, 5) and ``b`` has shape (1, 5), so every gradient
    with respect to ``b`` has to be summed back down to ``b``'s shape.
    """
    a = sym.Variable("a")
    b = sym.Variable("b")
    inputs = [('a', (3, 4, 5), a),
              ('b', (1, 5), b)]
    dtype = "float32"

    def _collapse(g):
        # Sum over every leading position, keeping only b's last axis:
        # (3, 4, 5) -> (12, 5) -> (1, 5).
        return g.reshape(-1, inputs[-1][1][-1]).sum(0, keepdims=True)

    y = sym.broadcast_add(a, b)
    def _backward_add(head_grads, a, b):
        da = head_grads
        db = _collapse(head_grads)
        return da, db
    helper(y, inputs, dtype, lambda a, b: a + b, _backward_add)

    y = sym.broadcast_sub(a, b)
    def _backward_sub(head_grads, a, b):
        da = head_grads
        db = -_collapse(head_grads)
        return da, db
    helper(y, inputs, dtype, lambda a, b: a - b, _backward_sub)

    y = sym.broadcast_mul(a, b)
    def _backward_mul(head_grads, a, b):
        da = head_grads * b
        db = _collapse(head_grads * a)
        return da, db
    helper(y, inputs, dtype, lambda a, b: a * b, _backward_mul)

    y = sym.broadcast_div(a, b)
    def _backward_div(head_grads, a, b):
        da = head_grads / b
        # d(a / b) / db = -a / b**2 (negative sign, no factor of two).
        db = _collapse(-head_grads * a / b**2)
        return da, db
    helper(y, inputs, dtype, lambda a, b: a / b, _backward_div)
def test_greater():
l = sym.Variable("l")
r = sym.Variable("r")
......@@ -472,8 +545,10 @@ def test_nms():
if __name__ == "__main__":
......@@ -6,12 +6,15 @@
#include <algorithm>
#include <string>
#include <set>
#include <vector>
#include <iterator>
#include "topi/elemwise.h"
#include "topi/tags.h"
#include "topi/transform.h"
#include "topi/detail/ravel_unravel.h"
#include "topi/detail/constant_utils.h"
#include "tvm/tvm.h"
......@@ -91,6 +94,9 @@ inline Array<Expr> MakeReduceTargetShape(const std::vector<int>& real_axis,
if (target_shape.size() == 0) {
return target_shape;
......@@ -99,55 +105,72 @@ inline Array<Expr> MakeReduceTargetShape(const std::vector<int>& real_axis,
* \brief Create a reduction operation.
* \param data The input tensor.
* \param axis The axes along which the reduction is performed.
* \param func The reduction function eg. tvm::sum
* \param keepdims If this is set to true, the axes which are reduced are
* left in the result as dimensions with size one. This enables the result
* to broadcast correctly against the input array.
* \param target_shape The output Tensor shape.
* \param reduce_axes The real axes along which the reduction is performed.
* \param squeeze_axes The real axes to squeeze. Unsqueezed, reduced axes will
* have shape 1 in the output tensor.
* \return The result tensor.
inline Tensor CommReduce(const Tensor& data,
const Array<Expr>& axis,
inline Tensor DoCommReduce(const Tensor& data,
FReduce func,
bool keepdims = false) {
auto ndim = data->shape.size();
CHECK_NE(ndim, 0) << "Cannot reduce a 0 dim Tensor";
auto axis_val = detail::GetConstIntValues(axis, "axis");
auto real_axis = GetRealAxis(static_cast<int>(ndim), axis_val);
auto reduce_axes = MakeReduceAxes(real_axis, data);
auto target_shape = MakeReduceTargetShape(real_axis, data, keepdims);
auto compute = [ndim, keepdims, &real_axis, &reduce_axes, &func, &data]
(const Array<Var>& indices) {
const Array<Expr>& target_shape,
const std::vector<int>& reduce_axes,
const std::vector<int>& squeeze_axes) {
auto r_axes = MakeReduceAxes(reduce_axes, data);
auto compute = [&](const Array<Var>& indices) {
Array<Expr> eval_range;
Array<Var> eval_indices;
int arg_counter = 0;
int red_counter = 0;
for (size_t i = 0; i < ndim; ++i) {
if (std::find(real_axis.begin(), real_axis.end(), i) != real_axis.end()) {
for (size_t i = 0; i < data->shape.size(); ++i) {
bool squeeze_i = std::find(squeeze_axes.begin(), squeeze_axes.end(), i) != squeeze_axes.end();
if (std::find(reduce_axes.begin(), reduce_axes.end(), i) != reduce_axes.end()) {
// real_axis contains i
} else {
if (!keepdims) {
arg_counter += !squeeze_i;
} else {
return func(data(eval_range), reduce_axes);
return func(data(eval_range), r_axes);
return tvm::compute(target_shape, compute, data->op->name + "_red", kCommReduce);
* \brief Create a reduction operation.
* \param data The input tensor.
* \param axis The axes along which the reduction is performed.
* \param func The reduction function eg. tvm::sum
* \param keepdims If this is set to true, the axes which are reduced are
* left in the result as dimensions with size one. This enables the result
* to broadcast correctly against the input array.
* \return The result tensor.
// Validates the input rank, resolves the requested axes and the output shape,
// then hands the actual reduction to DoCommReduce.
inline Tensor CommReduce(const Tensor& data,
const Array<Expr>& axis,
FReduce func,
bool keepdims = false) {
auto ndim = data->shape.size();
CHECK_NE(ndim, 0) << "Cannot reduce a 0 dim Tensor";
// Pull the constant integer values out of the axis expressions.
auto axis_val = detail::GetConstIntValues(axis, "axis");
// Resolve them against the input rank (GetRealAxis is defined outside this view).
auto real_axis = GetRealAxis(static_cast<int>(ndim), axis_val);
auto target_shape = MakeReduceTargetShape(real_axis, data, keepdims);
// With keepdims no axis is squeezed; otherwise every reduced axis is squeezed.
return DoCommReduce(data, func, target_shape, real_axis,
keepdims ? std::vector<int>() : real_axis);
* \brief Create an index reduction operation.
* \param data The input tensor.
......@@ -281,6 +304,34 @@ inline Tensor sum(const Tensor& data, Array<Expr> axis, bool keepdims = false) {
return CommReduce(data, axis, tvm::sum, keepdims);
inline Tensor collapse_sum(const Tensor& data, Array<Expr> target_shape) {
CHECK_GE(data->shape.size(), target_shape.size());
auto ishape = detail::GetConstIntValues(data->shape, "ishape");
auto oshape = detail::GetConstIntValues(target_shape, "oshape");
std::vector<int> reduce_axes;
std::vector<int> squeeze_axes;
for (int i_ax = ishape.size() - 1,
o_ax = oshape.size() - 1; i_ax >= 0; --i_ax) {
if (o_ax >= 0 && ishape[i_ax] == oshape[o_ax]) {
if (o_ax < 0) { // squeeze o_ax if was added during expansion
} else if (oshape[o_ax] == 1) {
if (reduce_axes.size() == 0) return topi::identity(data, "tensor", kCommReduce);
std::reverse(reduce_axes.begin(), reduce_axes.end());
std::reverse(squeeze_axes.begin(), squeeze_axes.end());
return DoCommReduce(data, tvm::sum, target_shape, reduce_axes, squeeze_axes);
* \brief Creates an operation that finds the minimum of elements over
* a given axis.
......@@ -4,6 +4,7 @@ from __future__ import absolute_import as _abs
import tvm
from .. import tag
from .. import generic
from .injective import _schedule_injective
def _schedule_reduce(op, sch, is_idx_reduce=False):
if is_idx_reduce:
......@@ -11,7 +12,9 @@ def _schedule_reduce(op, sch, is_idx_reduce=False):
data_in = op.input_tensors[0]
data_out = op.output(0)
assert len(sch[data_out].op.reduce_axis) > 0, "reduce_axis must be bigger than zero!"
if not sch[data_out].op.reduce_axis:
return _schedule_injective(op, sch)
if len(sch[data_out].op.axis) > 0:
all_reduce = False
......@@ -73,6 +73,7 @@ def verify_reduce_map_ele(in_shape, axis, keepdims, type="sum"):
out_npy = _my_npy_argmin(in_npy_map, axis=axis, keepdims=keepdims)
raise NotImplementedError
out_npy = np.atleast_1d(out_npy)
data_tvm = tvm.nd.array(in_npy, ctx=ctx)
out_tvm = tvm.nd.empty(shape=out_npy.shape, ctx=ctx, dtype=out_dtype)
for _ in range(1):
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment