Commit 6fe5b108 by yuruofeifei, committed by Yizhi Liu

[NNVM][TOPI] Add mean and product operators (#1628)

* Add mean and product operators

* Fix typo

* Fix lint

* fix test

* Fix gpu schedule

* Update doc

* remove mean from topi

* Add nnvm test

* Fix cuda schedule

* Remove cuda schedule
parent 8f5d3bd2
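For orientation, here is a minimal sketch of how the reducers added in this commit are exercised end to end through the NNVM symbol API. The shapes, the `llvm` target, and the variable names are illustrative assumptions, not part of the patch:

```python
# Illustration only: drive the new mean reducer through nnvm.compiler.
# Shapes, target, and names are assumptions for this sketch.
import numpy as np
import tvm
import nnvm.symbol as sym
import nnvm.compiler
from tvm.contrib import graph_runtime

dshape = (4, 4, 3)
x = sym.Variable("x")
y = sym.mean(x, axis=(0, 2))          # operator registered by this commit

graph, lib, _ = nnvm.compiler.build(y, "llvm", {"x": dshape})
m = graph_runtime.create(graph, lib, tvm.cpu(0))

data = np.random.uniform(size=dshape).astype("float32")
m.run(x=data)
out = m.get_output(0, tvm.nd.empty((4,), "float32"))
np.testing.assert_allclose(out.asnumpy(), data.mean(axis=(0, 2)), rtol=1e-5)
```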
@@ -49,6 +49,7 @@ List of operators
topi.min
topi.argmax
topi.argmin
topi.prod
topi.broadcast_to
topi.add
topi.subtract
@@ -107,6 +108,7 @@ topi
.. autofunction:: topi.max
.. autofunction:: topi.sum
.. autofunction:: topi.min
.. autofunction:: topi.prod
.. autofunction:: topi.broadcast_to
.. autofunction:: topi.add
.. autofunction:: topi.subtract
...
@@ -114,6 +114,8 @@ This level enables typical convnet models.
nnvm.symbol.sum
nnvm.symbol.min
nnvm.symbol.max
nnvm.symbol.mean
nnvm.symbol.prod
nnvm.symbol.broadcast_add
nnvm.symbol.broadcast_sub
nnvm.symbol.broadcast_mul
@@ -228,6 +230,8 @@ Detailed Definitions
.. autofunction:: nnvm.symbol.sum
.. autofunction:: nnvm.symbol.min
.. autofunction:: nnvm.symbol.max
.. autofunction:: nnvm.symbol.mean
.. autofunction:: nnvm.symbol.prod
.. autofunction:: nnvm.symbol.broadcast_add
.. autofunction:: nnvm.symbol.broadcast_sub
.. autofunction:: nnvm.symbol.broadcast_mul
...
@@ -36,6 +36,7 @@ using HalideIR::Internal::Variable;
using HalideIR::Internal::make_const;
using HalideIR::Internal::make_zero;
using HalideIR::Internal::make_one;
using HalideIR::Internal::as_const_int;
using HalideIR::Internal::as_const_uint;
using HalideIR::Internal::const_true;
...
@@ -41,6 +41,12 @@ TVM_DLL Expr max(Expr source, Array<IterVar> axis);
*/
TVM_DLL Expr min(Expr source, Array<IterVar> axis);
/*!
* \brief product of source expression over axis
* \param source The source expression.
* \param axis List of iteration variables that will be used for reduction.
*/
TVM_DLL Expr prod(Expr source, Array<IterVar> axis);
// Unary intrinsic operators
#define TVM_DECLARE_INTRIN_UNARY(OpName) \
...
@@ -49,3 +49,11 @@ reg.register_schedule("argmax", _fschedule_reduce)
# argmin
reg.register_pattern("argmin", OpPattern.COMM_REDUCE)
reg.register_schedule("argmin", _fschedule_reduce)
# mean
reg.register_pattern("mean", OpPattern.COMM_REDUCE)
reg.register_schedule("mean", _fschedule_reduce)
# product
reg.register_pattern("prod", OpPattern.COMM_REDUCE)
reg.register_schedule("prod", _fschedule_reduce)
@@ -322,6 +322,70 @@ values over a given axis.
topi::argmin(inputs[0], axis, param.keepdims) };
});
NNVM_REGISTER_REDUCE_OP(mean)
.describe(R"code(Computes the mean of array elements over given axes.
Example::
data = [[[1,2],[2,3],[1,3]],
[[1,4],[4,3],[5,2]],
[[7,1],[7,2],[7,3]]]
mean(data)
[3.22]
mean(data, axis=[1,2])
[ 2. 3.16666667 4.5]
)code" NNVM_ADD_FILELINE)
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
const Array<Tensor>& out_info) {
const ReduceParam& param = nnvm::get<ReduceParam>(attrs.parsed);
TShape r_axes = GetReduceAxes(inputs[0]->shape.size(),
param.axis, param.exclude);
if (!r_axes.ndim()) return Array<Tensor> { topi::identity(inputs[0]) };
auto axis = ShapeToArray(r_axes);
Expr count = make_one(inputs[0]->dtype);
for (auto& i : r_axes) {
count *= inputs[0]->shape[i];
}
return Array<Tensor>{
topi::divide(topi::sum(inputs[0], axis, param.keepdims), count) };
});
NNVM_REGISTER_REDUCE_OP(prod)
.describe(R"code(Computes the products of array elements over given axes.
Example::
data = [[[1,2],[2,3],[1,3]],
[[1,4],[4,3],[5,2]],
[[7,1],[7,2],[7,3]]]
mean(data, axis=1)
[35562240]
mean(data, axis=[1,2])
[ 36 480 2058]
)code" NNVM_ADD_FILELINE)
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
const Array<Tensor>& out_info) {
const ReduceParam& param = nnvm::get<ReduceParam>(attrs.parsed);
TShape r_axes = GetReduceAxes(inputs[0]->shape.size(),
param.axis, param.exclude);
if (!r_axes.ndim()) return Array<Tensor> { topi::identity(inputs[0]) };
auto axis = ShapeToArray(r_axes);
return Array<Tensor>{
topi::prod(inputs[0], axis, param.keepdims) };
});
} // namespace top
} // namespace nnvm
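The mean kernel registered above is just the sum divided by the product of the reduced extents (the `count` expression). A NumPy analogue of that decomposition, shown for illustration only:

```python
# NumPy analogue of the FTVMCompute body for `mean` above (illustration only):
# sum over the reduced axes, divided by the product of the reduced extents.
import numpy as np

data = np.array([[[1, 2], [2, 3], [1, 3]],
                 [[1, 4], [4, 3], [5, 2]],
                 [[7, 1], [7, 2], [7, 3]]], dtype="float32")
axes = (1, 2)
count = np.prod([data.shape[i] for i in axes])   # 3 * 2 = 6
assert np.allclose(data.sum(axis=axes) / count, data.mean(axis=axes))
```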
@@ -31,6 +31,9 @@ def verify_reduce_explicit(dshape, data, result, fsym, oshape=None, otype='float
x = sym.Variable("x")
y = fsym(x + 0, **kwargs)
for target, ctx in ctx_list():
# TODO(yuruofei): remove when cuda reduce schedule is done
if target == 'cuda' and fsym == sym.mean:
continue
graph, lib, _ = nnvm.compiler.build(y, target, {"x": dshape})
m = graph_runtime.create(graph, lib, ctx)
# set input
@@ -93,6 +96,13 @@ def test_reduce():
verify_reduce((4, 4, 3), np.min, sym.min, keepdims=True)
verify_reduce((4, 4, 3), np.sum, sym.sum, axis=(0, 2))
verify_reduce((4, 4, 3), np.sum, sym.sum)
verify_reduce((128, 24, 128), np.mean, sym.mean, axis=(0, 1), keepdims=False)
verify_reduce((128, 24, 128), np.mean, sym.mean, axis=(0, 2), keepdims=False)
verify_reduce((128, 24, 128), np.mean, sym.mean, axis=(0, 1), keepdims=True)
verify_reduce((128, 24, 128), np.mean, sym.mean, axis=(0, 2), keepdims=True)
verify_reduce((128, 24, 128), np.mean, sym.mean, keepdims=True)
verify_reduce((128, 24, 128), np.mean, sym.mean, keepdims=False)
verify_reduce((128, 24, 128), np.mean, sym.mean, axis=(0, 1, 2), keepdims=True)
data = np.array([[[1,2],[3,4]],[[3,44],[5,6]]], dtype=np.float32)
verify_reduce_explicit([2,2,2], data, np.array([[1,1],[1,0]]), sym.argmax, otype='int32', axis=[0,2], exclude=True)
...
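The test diff above only adds NNVM-level checks for `sym.mean`; a `sym.prod` check following the same `verify_reduce` pattern (hypothetical, not part of this commit's test diff) would look like:

```python
# Hypothetical prod analogues of the mean checks above
# (not part of this commit's test diff).
verify_reduce((4, 4, 3), np.prod, sym.prod, axis=(0, 2), keepdims=False)
verify_reduce((4, 4, 3), np.prod, sym.prod, keepdims=True)
```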
@@ -35,4 +35,13 @@ Expr min(Expr source, Array<IterVar> rdom) {
return ir::Reduce::make(combiner, {source}, rdom, make_const(Bool(1), true), 0);
}
Expr prod(Expr source, Array<IterVar> rdom) {
Var x("x"), y("y");
Expr result = ir::Mul::make(x, y);
Expr identity_element = make_one(source.type());
ir::CommReducer combiner =
ir::CommReducerNode::make({x}, {y}, {result}, {identity_element});
return ir::Reduce::make(combiner, {source}, rdom, make_const(Bool(1), true), 0);
}
} // namespace tvm
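The C++ `tvm::prod` above is built from a `CommReducer` whose combiner is multiplication and whose identity element is one. A rough Python-side analogue using `tvm.comm_reducer`, shown only to illustrate the construction:

```python
# Python-side analogue of the tvm::prod reducer above (illustration only):
# combiner is multiplication, identity element is 1.
import tvm

prod = tvm.comm_reducer(lambda x, y: x * y,
                        lambda t: tvm.const(1, dtype=t),
                        name="prod")

n = tvm.var("n")
A = tvm.placeholder((n,), name="A")
k = tvm.reduce_axis((0, n), name="k")
B = tvm.compute((1,), lambda _: prod(A[k], axis=k), name="B")
```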
@@ -12,6 +12,7 @@
#include <vector>
#include <iterator>
#include "topi/broadcast.h"
#include "topi/elemwise.h"
#include "topi/tags.h"
#include "topi/transform.h"
@@ -288,6 +289,11 @@ inline Expr MaxOp(Expr source, Array<IterVar> axis) {
return tvm::max(source, axis); // NOLINT(*)
}
/*! \brief Wrap tvm::prod to ensure we get the correct overload */
inline Expr ProdOp(Expr source, Array<IterVar> axis) {
return tvm::prod(source, axis); // NOLINT(*)
}
/*!
* \brief Creates an operation that sums array elements over a given axis
*
@@ -426,5 +432,21 @@ inline Tensor argmax(const Tensor& data, Array<Expr> axis, bool keepdims = false
return CommReduceIdx(data, axis, func, keepdims);
}
/*!
* \brief Creates product operation over given axis.
*
* \param data The input tensor
* \param axis The axis to do product over. If axis is empty, the
* operation will do the product over all elements of the array.
* \param keepdims If this is set to true, the axes which are reduced are
* left in the result as dimensions with size one. This enables the result
* to broadcast correctly against the input array.
*
* \return A Tensor whose op member is the prod operation
*/
inline Tensor prod(const Tensor& data, Array<Expr> axis, bool keepdims = false) { // NOLINT(*)
return CommReduce(data, axis, ProdOp, keepdims);
}
} // namespace topi
#endif  // TOPI_REDUCTION_H_
@@ -2,8 +2,8 @@
"""Reduce operators"""
from __future__ import absolute_import as _abs
import tvm
from . import cpp
from . import tag
from .util import ravel_index
def _get_real_axis(ndim, axis):
if axis is None:
@@ -26,130 +26,6 @@ def _get_real_axis(ndim, axis):
return real_axis
def get_reduce_out_shape(src_shape, axis=None, keepdims=False):
"""Get the output shape for the reduction OPs
Parameters
----------
src_shape : tuple of int or tvm.expr.IntImm
axis : None or int or tuple of int
keepdims : bool
Returns
-------
dst_shape : tuple of int or tvm.expr.IntImm
"""
real_axis = _get_real_axis(len(src_shape), axis)
if keepdims:
dst_shape = [src_shape[i] if i in real_axis else 1 for i in range(len(src_shape))]
else:
dst_shape = []
for i in range(len(src_shape)):
if i not in real_axis:
dst_shape.append(src_shape[i])
return dst_shape
def _argmax_comp(lhs, rhs):
"""Compare function of argmax"""
idx = tvm.make.Select((lhs[1] >= rhs[1]), lhs[0], rhs[0])
val = tvm.make.Select((lhs[1] >= rhs[1]), lhs[1], rhs[1])
return idx, val
def _argmax_init(idx_typ, val_typ):
"""Initial ind and val of argmax"""
return tvm.const(-1, idx_typ), tvm.min_value(val_typ)
def _argmin_comp(lhs, rhs):
"""Compare function of argmin"""
idx = tvm.make.Select((lhs[1] <= rhs[1]), lhs[0], rhs[0])
val = tvm.make.Select((lhs[1] <= rhs[1]), lhs[1], rhs[1])
return idx, val
def _argmin_init(idx_typ, val_typ):
"""Initial ind and val of argmax"""
return tvm.const(-1, idx_typ), tvm.max_value(val_typ)
def _choose_idx(idx, _, *indices):
"""Chose the idx from idx and val"""
return idx(*indices)
def comm_reduce(data, axis=None, keepdims=False, func=tvm.sum, is_idx_reduce=False):
"""Reducing the data
Parameters
----------
data : tvm.Tensor
The input data
axis : None or int or tuple of int
Axis or axes along which a sum is performed.
The default, axis=None, will sum all of the elements of the input array.
If axis is negative it counts from the last to the first axis.
keepdims : bool
If this is set to True, the axes which are reduced are left in the result as dimensions
with size one.
With this option, the result will broadcast correctly against the input array.
func : function
functions like tvm.sum, tvm.max, tvm.min
Returns
-------
ret : tvm.Tensor
"""
ndim = len(data.shape)
assert ndim != 0, "Reduce a dim-0 input is not supported!"
real_axis = _get_real_axis(ndim, axis)
reduce_axes = [tvm.reduce_axis((0, data.shape[i]), "k%d" %i) for i in real_axis]
if keepdims:
target_shape = [1 if i in real_axis else data.shape[i] for i in range(ndim)]
else:
target_shape = []
for i in range(ndim):
if i not in real_axis:
target_shape.append(tvm.convert(data.shape[i]))
def _compute(*indices):
eval_range = []
eval_indices = []
if not keepdims:
arg_counter = 0
else:
arg_counter = None
red_counter = 0
for i in range(len(data.shape)):
if i in real_axis:
eval_range.append(reduce_axes[red_counter])
eval_indices.append(reduce_axes[red_counter].var)
red_counter += 1
else:
if not keepdims:
eval_range.append(indices[arg_counter])
arg_counter += 1
else:
eval_range.append(indices[i])
if not is_idx_reduce:
return func(data[tuple(eval_range)], axis=reduce_axes)
idx = ravel_index(eval_indices, [data.shape[i] for i in real_axis])
return func((idx, data[tuple(eval_range)]), axis=reduce_axes)
if is_idx_reduce:
temp_idx, temp_val = tvm.compute(target_shape, _compute, name=data.name + "_red_temp")
out = tvm.compute(target_shape,
lambda *indices: _choose_idx(temp_idx, temp_val, *indices),
name=data.name + "_red")
else:
out = tvm.compute(target_shape, _compute, name=data.name + "_red")
return out
@tvm.tag_scope(tag=tag.COMM_REDUCE)
def sum(data, axis=None, keepdims=False):
"""Sum of array elements over a given axis or a list of axes
@@ -173,7 +49,7 @@ def sum(data, axis=None, keepdims=False):
-------
ret : tvm.Tensor
"""
return comm_reduce(data, axis=axis, keepdims=keepdims, func=tvm.sum)
return cpp.sum(data, axis, keepdims)
@tvm.tag_scope(tag=tag.COMM_REDUCE)
@@ -199,7 +75,7 @@ def max(data, axis=None, keepdims=False):
-------
ret : tvm.Tensor
"""
return comm_reduce(data, axis=axis, keepdims=keepdims, func=tvm.max)
return cpp.max(data, axis, keepdims)
@tvm.tag_scope(tag=tag.COMM_REDUCE)
@@ -225,7 +101,7 @@ def min(data, axis=None, keepdims=False):
-------
ret : tvm.Tensor
"""
return comm_reduce(data, axis=axis, keepdims=keepdims, func=tvm.min)
return cpp.min(data, axis, keepdims)
@tvm.tag_scope(tag=tag.COMM_REDUCE_IDX)
@@ -251,8 +127,7 @@ def argmax(data, axis=None, keepdims=False):
-------
ret : tvm.Tensor
"""
_argmax = tvm.comm_reducer(fcombine=_argmax_comp, fidentity=_argmax_init, name='argmax')
return comm_reduce(data, axis=axis, keepdims=keepdims, func=_argmax, is_idx_reduce=True)
return cpp.argmax(data, axis, keepdims)
@tvm.tag_scope(tag=tag.COMM_REDUCE_IDX)
@@ -278,5 +153,30 @@ def argmin(data, axis=None, keepdims=False):
-------
ret : tvm.Tensor
"""
_argmin = tvm.comm_reducer(fcombine=_argmin_comp, fidentity=_argmin_init, name='argmin')
return comm_reduce(data, axis=axis, keepdims=keepdims, func=_argmin, is_idx_reduce=True)
return cpp.argmin(data, axis, keepdims)
@tvm.tag_scope(tag=tag.COMM_REDUCE)
def prod(data, axis=None, keepdims=False):
"""Product of array elements over a given axis or a list of axes
Parameters
----------
data : tvm.Tensor
The input tvm tensor
axis : None or int or tuple of int
Axis or axes along which a prod operation is performed.
The default, axis=None, computes the product over all elements of the
input array. If axis is negative it counts from the last to the first axis.
keepdims : bool
If this is set to True, the axes which are reduced are left in the result as dimensions
with size one.
With this option, the result will broadcast correctly against the input array.
Returns
-------
ret : tvm.Tensor
"""
return cpp.prod(data, axis, keepdims)
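For reference, a small sketch of lowering the new Python `topi.prod` wrapper with a default CPU schedule; the shapes and the `llvm` target are assumptions made for this illustration:

```python
# Illustration only: lower topi.prod with a default schedule on CPU.
import numpy as np
import tvm
import topi

A = tvm.placeholder((4, 4), name="A")
B = topi.prod(A, axis=1)                 # wrapper added above
s = tvm.create_schedule(B.op)
f = tvm.build(s, [A, B], "llvm", name="prod")

a = tvm.nd.array(np.random.uniform(1, 2, size=(4, 4)).astype("float32"))
b = tvm.nd.array(np.zeros((4,), dtype="float32"))
f(a, b)
np.testing.assert_allclose(b.asnumpy(), a.asnumpy().prod(axis=1), rtol=1e-5)
```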
@@ -230,6 +230,11 @@ TVM_REGISTER_GLOBAL("topi.argmax")
*rv = topi::argmax(args[0], ArrayOrInt(args[1]), args[2]);
});
TVM_REGISTER_GLOBAL("topi.prod")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = topi::prod(args[0], ArrayOrInt(args[1]), args[2]);
});
/* Ops from transform.h */
TVM_REGISTER_GLOBAL("topi.expand_dims")
.set_body([](TVMArgs args, TVMRetValue *rv) {
...
@@ -72,6 +72,7 @@ def verify_reduce_map_ele(in_shape, axis, keepdims, type="sum", dtype="float32")
out_npy = _my_npy_argmin(in_npy_map, axis=axis, keepdims=keepdims)
else:
raise NotImplementedError
out_npy = np.atleast_1d(out_npy)
data_tvm = tvm.nd.array(in_npy, ctx=ctx)
out_tvm = tvm.nd.empty(shape=out_npy.shape, ctx=ctx, dtype=out_dtype)
for _ in range(1):
...
@@ -42,6 +42,8 @@ def verify_reduce_map_ele(in_shape, axis, keepdims, type="sum"):
elif type == "argmin":
B = topi.cpp.argmin(A1, axis, keepdims)
out_dtype = "int32"
elif type == "prod":
B = topi.cpp.prod(A1, axis, keepdims)
else:
raise NotImplementedError
@@ -57,7 +59,7 @@ def verify_reduce_map_ele(in_shape, axis, keepdims, type="sum"):
else:
s = topi.cpp.cuda.schedule_reduce(target, [B])
foo = tvm.build(s, [A, B], device, name="sum")
foo = tvm.build(s, [A, B], device, name=type)
# Test
in_npy = np.random.uniform(size=in_shape).astype(np.float32)
in_npy_map = np.sqrt(np.exp(in_npy)).astype(np.float32)
@@ -71,6 +73,8 @@ def verify_reduce_map_ele(in_shape, axis, keepdims, type="sum"):
out_npy = _my_npy_argmax(in_npy_map, axis=axis, keepdims=keepdims)
elif type == "argmin":
out_npy = _my_npy_argmin(in_npy_map, axis=axis, keepdims=keepdims)
elif type == "prod":
out_npy = in_npy_map.prod(axis=axis, keepdims=keepdims)
else:
raise NotImplementedError
out_npy = np.atleast_1d(out_npy)
@@ -100,21 +104,29 @@ def verify_reduce_map_ele(in_shape, axis, keepdims, type="sum"):
def test_reduce_map():
verify_reduce_map_ele(in_shape=(128, 24, 128, 24),
axis=(1, 2, 3),
keepdims=True,
type="sum")
verify_reduce_map_ele(in_shape=(128, 24 * 128 * 24),
axis=(1,),
keepdims=False,
type="max")
verify_reduce_map_ele(in_shape=(32, 128, 24),
axis=None,
keepdims=True,
type="sum")
verify_reduce_map_ele(in_shape=(128, 24, 128, 24),
axis=(0, 2),
keepdims=False,
type="min")
verify_reduce_map_ele(in_shape=(128, 4, 4, 128),
axis=(1, ),
keepdims=True,
type="prod")
verify_reduce_map_ele(in_shape=(4, 4),
axis=(0, 1),
keepdims=False,
type="prod")
verify_reduce_map_ele(in_shape=(32, 128),
axis=1,
keepdims=True,
...