Commit 644a15c3 authored by MORINAGA, committed by Tianqi Chen

[Frontend][MXNet] argmax, argmin ops support (#2048)

parent ac37abf9
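
The change spans five areas: the ReduceParam struct gains a dtype field (default kInt32); the MXNet frontend gains _argmax and _argmin converters that request float32 output; the argmax/argmin registrations now take their output type from that field and cast the integer indices when float32 is requested; frontend tests cover both ops; and the CUDA reduce schedule now handles an elementwise op that follows a reduction instead of raising. A rough end-to-end sketch of what this enables (assuming an NNVM build that contains this change; shapes and names are illustrative):

import mxnet as mx
from nnvm.frontend import from_mxnet

data = mx.sym.var('data')
mx_sym = mx.sym.argmax(data, axis=1)
# The frontend maps MXNet argmax/argmin onto the NNVM ops registered below.
net, params = from_mxnet(mx_sym)
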
@@ -205,6 +205,7 @@ struct ReduceParam : public dmlc::Parameter<ReduceParam> {
TShape axis;
bool keepdims;
bool exclude;
int dtype;
DMLC_DECLARE_PARAMETER(ReduceParam) {
DMLC_DECLARE_FIELD(axis).set_default(TShape())
@@ -226,6 +227,8 @@ struct ReduceParam : public dmlc::Parameter<ReduceParam> {
"in the result as dimension with size one.");
DMLC_DECLARE_FIELD(exclude).set_default(false)
.describe("Whether to perform reduction on axis that are NOT in axis instead.");
DMLC_DECLARE_DTYPE_FIELD(dtype).set_default(kInt32)
.describe("Target data type.");
}
};
@@ -259,12 +259,12 @@ def _crop_like(inputs, attrs):
def _expand_dims(inputs, attrs):
op_name, new_attrs = "expand_dims", {}
op_name, new_attrs = 'expand_dims', {}
new_attrs['axis'] = _required_attr(attrs, 'axis')
return _get_nnvm_op(op_name)(*inputs, **new_attrs)
def _lrn(inputs, attrs):
op_name, new_attrs = "lrn", {}
op_name, new_attrs = 'lrn', {}
new_attrs['alpha'] = attrs.get('alpha', 0.0001)
new_attrs['beta'] = attrs.get('beta', 0.75)
new_attrs['bias'] = attrs.get('knorm', 2)
@@ -274,13 +274,27 @@ def _lrn(inputs, attrs):
return _get_nnvm_op(op_name)(*inputs, **new_attrs)
def _ones(_, attrs):
op_name = "ones"
op_name = 'ones'
return _get_nnvm_op(op_name)(**attrs)
def _zeros(_, attrs):
op_name = "zeros"
op_name = 'zeros'
return _get_nnvm_op(op_name)(**attrs)
def _argmax(inputs, attrs):
op_name, new_attrs = 'argmax', {}
new_attrs['dtype'] = 'float32'
new_attrs['axis'] = attrs.get('axis', 0)
new_attrs['keepdims'] = _parse_bool_str(attrs, 'keepdims', default="False")
return _get_nnvm_op(op_name)(*inputs, **new_attrs)
def _argmin(inputs, attrs):
op_name, new_attrs = 'argmin', {}
new_attrs['dtype'] = 'float32'
new_attrs['axis'] = attrs.get('axis', 0)
new_attrs['keepdims'] = _parse_bool_str(attrs, 'keepdims', default="False")
return _get_nnvm_op(op_name)(*inputs, **new_attrs)
_identity_list = ['__add_scalar__', '__add_symbol__', '__div_scalar__',
'__div_symbol__', '__mul_scalar__', '__mul_symbol__',
'__pow_scalar__', '__rdiv_scalar__', '__rpow_scalar__',
@@ -303,8 +317,10 @@ _convert_map = {
'_rminus_scalar': _rename('__rsub_scalar__'),
'_contrib_MultiBoxPrior' : _rename('multibox_prior'),
'_contrib_MultiBoxDetection' : _contrib_multibox_detection,
'_ones' : _ones,
'_zeros' : _zeros,
'_ones' : _ones,
'_zeros' : _zeros,
'argmax' : _argmax,
'argmin' : _argmin,
'Activation' : _activations,
'BatchNorm' : _batch_norm,
'BatchNorm_v1' : _batch_norm,
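
MXNet returns argmax/argmin indices as floating-point values, so the converters above hard-code dtype='float32' and fall back to axis=0 and keepdims=False when those attributes are absent. A minimal sketch of the NNVM call the converter effectively emits, assuming the new dtype field is exposed through the generated symbol API as a keyword argument:

import nnvm.symbol as sym

x = sym.Variable('data')
# Indices are computed as int32 internally and returned as float32,
# matching what MXNet's argmax produces.
y = sym.argmax(x, axis=1, dtype='float32')
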
@@ -272,15 +272,13 @@ NNVM_REGISTER_BASE_REDUCE_OP(collapse_sum)
return Array<Tensor>{ topi::collapse_sum(inputs[0], inputs[1]->shape) };
});
template<int Type>
inline bool InferFixedType(const NodeAttrs& attrs,
std::vector<int>* in_attrs,
std::vector<int>* out_attrs) {
// Static type inference for argmax operation. Argmax return indices which
// should have Int32 type as shapes do.
CHECK_EQ(in_attrs->size(), 1U);
CHECK_EQ(out_attrs->size(), 1U);
NNVM_ASSIGN_OUTPUT_TYPE(attrs, *out_attrs, 0, static_cast<int>(Type));
const ReduceParam& param = nnvm::get<ReduceParam>(attrs.parsed);
NNVM_ASSIGN_OUTPUT_TYPE(attrs, *out_attrs, 0, param.dtype);
return true;
}
@@ -291,7 +289,7 @@ values over a given axis.
)code" NNVM_ADD_FILELINE)
.add_argument("data", "Tensor", "The input")
.set_attr<FInferShape>("FInferShape", ReduceShape)
.set_attr<FInferType>("FInferType", InferFixedType<kInt32>)
.set_attr<FInferType>("FInferType", InferFixedType)
.set_attr<FCorrectLayout>("FCorrectLayout", ElemwiseFixedLayoutUnknownOut<1, 1>)
.set_num_inputs(1)
.set_attr<FTVMCompute>(
@@ -302,8 +300,9 @@ values over a given axis.
TShape r_axes = GetReduceAxes(inputs[0]->shape.size(),
param.axis, param.exclude);
auto axis = ShapeToArray(r_axes);
return Array<Tensor>{
topi::argmax(inputs[0], axis, param.keepdims) };
Tensor out = topi::argmax(inputs[0], axis, param.keepdims);
if (param.dtype == kFloat32) out = topi::cast(out, out_info[0]->dtype);
return Array<Tensor>{out};
});
NNVM_REGISTER_BASE_REDUCE_OP(argmin)
@@ -313,7 +312,7 @@ values over a given axis.
)code" NNVM_ADD_FILELINE)
.add_argument("data", "Tensor", "The input")
.set_attr<FInferShape>("FInferShape", ReduceShape)
.set_attr<FInferType>("FInferType", InferFixedType<kInt32>)
.set_attr<FInferType>("FInferType", InferFixedType)
.set_attr<FCorrectLayout>("FCorrectLayout", ElemwiseFixedLayoutUnknownOut<1, 1>)
.set_num_inputs(1)
.set_attr<FTVMCompute>(
@@ -324,8 +323,9 @@ values over a given axis.
TShape r_axes = GetReduceAxes(inputs[0]->shape.size(),
param.axis, param.exclude);
auto axis = ShapeToArray(r_axes);
return Array<Tensor>{
topi::argmin(inputs[0], axis, param.keepdims) };
Tensor out = topi::argmin(inputs[0], axis, param.keepdims);
if (param.dtype == kFloat32) out = topi::cast(out, out_info[0]->dtype);
return Array<Tensor>{out};
});
NNVM_REGISTER_REDUCE_OP(mean)
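
InferFixedType previously fixed the output type to Int32 through a template parameter; it now reads ReduceParam::dtype, and the compute bodies cast the int32 indices produced by topi::argmax / topi::argmin when float32 is requested. A rough TOPI-level sketch of the compute this builds, written in Python for illustration (the placeholder shape is arbitrary):

import tvm
import topi

data = tvm.placeholder((5, 3), name='data')
idx = topi.argmax(data, axis=1, keepdims=False)  # integer indices from the index reduction
out = topi.cast(idx, 'float32')                  # applied only when dtype=float32 is requested
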
@@ -174,6 +174,16 @@ def test_forward_zeros_like():
data = mx.sym.var('data')
mx_sym = mx.sym.zeros_like(data, dtype='float32')
verify_mxnet_frontend_impl(mx_sym, (2, 3, 4), (2, 3, 4))
def test_forward_argmax():
data = mx.sym.var('data')
mx_sym = mx.sym.argmax(data, axis=1)
verify_mxnet_frontend_impl(mx_sym, (5, 3), (5,))
def test_forward_argmin():
data = mx.sym.var('data')
mx_sym = mx.sym.argmin(data, axis=0)
verify_mxnet_frontend_impl(mx_sym, (5, 4), (4,))
if __name__ == '__main__':
test_forward_mlp()
@@ -194,3 +204,6 @@ if __name__ == '__main__':
test_forward_zeros()
test_forward_ones_like()
test_forward_zeros_like()
test_forward_argmax()
test_forward_argmin()
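
The new tests feed (5, 3) and (5, 4) inputs through verify_mxnet_frontend_impl and expect the reduced shapes (5,) and (4,). A small standalone sanity check of the MXNet reference behaviour the tests (and the frontend's dtype choice) rely on, for illustration only:

import numpy as np
import mxnet as mx

x = np.random.uniform(size=(5, 3)).astype('float32')
ref = mx.nd.argmax(mx.nd.array(x), axis=1).asnumpy()
assert ref.shape == (5,)           # matches the expected output shape above
assert ref.dtype == np.float32     # MXNet returns float indices, hence dtype='float32' in the frontend
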
@@ -107,7 +107,10 @@ def schedule_reduce(outs):
def traverse_after_reduce(operator):
"""Internal travserse function"""
if tag.is_broadcast(operator.tag):
raise RuntimeError("Not yet support ewise after reduce")
if operator not in scheduled_ops:
_schedule_injective(operator, sch)
for tensor in operator.input_tensors:
traverse_after_reduce(tensor.op)
elif operator.tag == 'comm_reduce':
_schedule_reduce(operator, sch, is_idx_reduce=False)
for tensor in operator.input_tensors:
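
Previously traverse_after_reduce raised as soon as it met a broadcast/elementwise op downstream of a reduction; it now schedules that op with _schedule_injective and keeps walking its inputs, which is what lets the float32 cast appended after argmax/argmin compile on CUDA. A minimal sketch of the pattern the schedule now accepts, following the usual TOPI test idiom (shapes illustrative):

import tvm
import topi

data = tvm.placeholder((5, 3), name='data')
idx = topi.argmax(data, axis=1, keepdims=False)  # index reduction (comm_reduce_idx)
out = topi.cast(idx, 'float32')                  # elementwise op that now follows the reduce
with tvm.target.create('cuda'):
    s = topi.generic.schedule_reduce([out])      # dispatches to the schedule patched above
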