Commit 644a15c3 by MORINAGA, committed by Tianqi Chen

[Frontend][MXNet] argmax, argmin ops support (#2048)

parent ac37abf9
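This commit teaches the NNVM MXNet frontend to convert MXNet's `argmax` and `argmin` operators. A minimal sketch of how the new conversion path might be exercised, assuming an NNVM-era TVM build that includes this commit (the symbol and shapes below are illustrative, not part of the diff):

```python
# Hedged sketch: convert an MXNet symbol that uses argmax through the NNVM frontend.
import mxnet as mx
import nnvm

data = mx.sym.var('data')
mx_sym = mx.sym.argmax(data, axis=1)

# from_mxnet returns the converted NNVM symbol plus a (here empty) params dict.
nnvm_sym, params = nnvm.frontend.from_mxnet(mx_sym)
print(nnvm_sym.debug_str())
```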
@@ -205,6 +205,7 @@ struct ReduceParam : public dmlc::Parameter<ReduceParam> {
   TShape axis;
   bool keepdims;
   bool exclude;
+  int dtype;
   DMLC_DECLARE_PARAMETER(ReduceParam) {
     DMLC_DECLARE_FIELD(axis).set_default(TShape())
@@ -226,6 +227,8 @@ struct ReduceParam : public dmlc::Parameter<ReduceParam> {
               "in the result as dimension with size one.");
     DMLC_DECLARE_FIELD(exclude).set_default(false)
       .describe("Whether to perform reduction on axis that are NOT in axis instead.");
+    DMLC_DECLARE_DTYPE_FIELD(dtype).set_default(kInt32)
+      .describe("Target data type.");
   }
 };
......
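The new `dtype` field makes the index type of `argmax`/`argmin` a regular op attribute, defaulting to `kInt32`. A hedged sketch of how the parameter would then surface in the NNVM Python symbol API (the exact call is illustrative):

```python
# Hedged sketch: ReduceParam.dtype is exposed like any other op attribute,
# so the symbol API should accept it as a keyword argument.
import nnvm.symbol as sym

x = sym.Variable('x')
idx_i32 = sym.argmax(x, axis=1)                   # int32 indices (the default)
idx_f32 = sym.argmax(x, axis=1, dtype='float32')  # float32 indices, matching MXNet
```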
@@ -259,12 +259,12 @@ def _crop_like(inputs, attrs):
 
 def _expand_dims(inputs, attrs):
-    op_name, new_attrs = "expand_dims", {}
+    op_name, new_attrs = 'expand_dims', {}
     new_attrs['axis'] = _required_attr(attrs, 'axis')
     return _get_nnvm_op(op_name)(*inputs, **new_attrs)
 
 def _lrn(inputs, attrs):
-    op_name, new_attrs = "lrn", {}
+    op_name, new_attrs = 'lrn', {}
     new_attrs['alpha'] = attrs.get('alpha', 0.0001)
     new_attrs['beta'] = attrs.get('beta', 0.75)
     new_attrs['bias'] = attrs.get('knorm', 2)
@@ -274,13 +274,27 @@ def _lrn(inputs, attrs):
     return _get_nnvm_op(op_name)(*inputs, **new_attrs)
 
 def _ones(_, attrs):
-    op_name = "ones"
+    op_name = 'ones'
     return _get_nnvm_op(op_name)(**attrs)
 
 def _zeros(_, attrs):
-    op_name = "zeros"
+    op_name = 'zeros'
     return _get_nnvm_op(op_name)(**attrs)
 
+def _argmax(inputs, attrs):
+    op_name, new_attrs = 'argmax', {}
+    new_attrs['dtype'] = 'float32'
+    new_attrs['axis'] = attrs.get('axis', 0)
+    new_attrs['keepdims'] = _parse_bool_str(attrs, 'keepdims', default="False")
+    return _get_nnvm_op(op_name)(*inputs, **new_attrs)
+
+def _argmin(inputs, attrs):
+    op_name, new_attrs = 'argmin', {}
+    new_attrs['dtype'] = 'float32'
+    new_attrs['axis'] = attrs.get('axis', 0)
+    new_attrs['keepdims'] = _parse_bool_str(attrs, 'keepdims', default="False")
+    return _get_nnvm_op(op_name)(*inputs, **new_attrs)
+
 _identity_list = ['__add_scalar__', '__add_symbol__', '__div_scalar__',
                   '__div_symbol__', '__mul_scalar__', '__mul_symbol__',
                   '__pow_scalar__', '__rdiv_scalar__', '__rpow_scalar__',
@@ -305,6 +319,8 @@ _convert_map = {
     '_contrib_MultiBoxDetection' : _contrib_multibox_detection,
     '_ones'          : _ones,
     '_zeros'         : _zeros,
+    'argmax'         : _argmax,
+    'argmin'         : _argmin,
     'Activation'     : _activations,
     'BatchNorm'      : _batch_norm,
     'BatchNorm_v1'   : _batch_norm,
......
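The converters request `dtype='float32'` because MXNet's `argmax`/`argmin` return indices as floating-point values by default, so the converted NNVM graph keeps the same output type as the original model. A quick check of that MXNet behaviour (illustrative, not part of the commit):

```python
# Hedged sketch: MXNet yields argmax indices as float32 by default,
# which is why the frontend asks NNVM for float32 output as well.
import mxnet as mx
import numpy as np

x = mx.nd.array(np.random.rand(5, 3).astype('float32'))
idx = mx.nd.argmax(x, axis=1)
print(idx.dtype)  # expected: float32
```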
@@ -272,15 +272,13 @@ NNVM_REGISTER_BASE_REDUCE_OP(collapse_sum)
     return Array<Tensor>{ topi::collapse_sum(inputs[0], inputs[1]->shape) };
 });
 
-template<int Type>
 inline bool InferFixedType(const NodeAttrs& attrs,
                            std::vector<int>* in_attrs,
                            std::vector<int>* out_attrs) {
-  // Static type inference for argmax operation. Argmax return indices which
-  // should have Int32 type as shapes do.
   CHECK_EQ(in_attrs->size(), 1U);
   CHECK_EQ(out_attrs->size(), 1U);
-  NNVM_ASSIGN_OUTPUT_TYPE(attrs, *out_attrs, 0, static_cast<int>(Type));
+  const ReduceParam& param = nnvm::get<ReduceParam>(attrs.parsed);
+  NNVM_ASSIGN_OUTPUT_TYPE(attrs, *out_attrs, 0, param.dtype);
   return true;
 }
@@ -291,7 +289,7 @@ values over a given axis.
 )code" NNVM_ADD_FILELINE)
 .add_argument("data", "Tensor", "The input")
 .set_attr<FInferShape>("FInferShape", ReduceShape)
-.set_attr<FInferType>("FInferType", InferFixedType<kInt32>)
+.set_attr<FInferType>("FInferType", InferFixedType)
 .set_attr<FCorrectLayout>("FCorrectLayout", ElemwiseFixedLayoutUnknownOut<1, 1>)
 .set_num_inputs(1)
 .set_attr<FTVMCompute>(
@@ -302,8 +300,9 @@ values over a given axis.
     TShape r_axes = GetReduceAxes(inputs[0]->shape.size(),
                                   param.axis, param.exclude);
     auto axis = ShapeToArray(r_axes);
-    return Array<Tensor>{
-      topi::argmax(inputs[0], axis, param.keepdims) };
+    Tensor out = topi::argmax(inputs[0], axis, param.keepdims);
+    if (param.dtype == kFloat32) out = topi::cast(out, out_info[0]->dtype);
+    return Array<Tensor>{out};
 });
 
 NNVM_REGISTER_BASE_REDUCE_OP(argmin)
@@ -313,7 +312,7 @@ values over a given axis.
 )code" NNVM_ADD_FILELINE)
 .add_argument("data", "Tensor", "The input")
 .set_attr<FInferShape>("FInferShape", ReduceShape)
-.set_attr<FInferType>("FInferType", InferFixedType<kInt32>)
+.set_attr<FInferType>("FInferType", InferFixedType)
 .set_attr<FCorrectLayout>("FCorrectLayout", ElemwiseFixedLayoutUnknownOut<1, 1>)
 .set_num_inputs(1)
 .set_attr<FTVMCompute>(
@@ -324,8 +323,9 @@ values over a given axis.
     TShape r_axes = GetReduceAxes(inputs[0]->shape.size(),
                                   param.axis, param.exclude);
     auto axis = ShapeToArray(r_axes);
-    return Array<Tensor>{
-      topi::argmin(inputs[0], axis, param.keepdims) };
+    Tensor out = topi::argmin(inputs[0], axis, param.keepdims);
+    if (param.dtype == kFloat32) out = topi::cast(out, out_info[0]->dtype);
+    return Array<Tensor>{out};
 });
 
 NNVM_REGISTER_REDUCE_OP(mean)
......
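TOPI's `argmax`/`argmin` produce integer indices, so when `dtype` is `float32` the compute casts the result to the inferred output type. A hedged, standalone sketch of the same pattern at the TOPI level, assuming an NNVM-era TVM/TOPI installation (the tensor name and shape are illustrative):

```python
# Hedged sketch of the compute pattern used above: take argmax in the index
# dtype TOPI produces, then cast when float32 output is requested.
import tvm
import topi

data = tvm.placeholder((5, 3), name='data', dtype='float32')
idx = topi.argmax(data, axis=1, keepdims=False)  # integer index tensor
out = topi.cast(idx, 'float32')                  # mirrors the param.dtype == kFloat32 branch
print(idx.dtype, out.dtype)
```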
@@ -175,6 +175,16 @@ def test_forward_zeros_like():
     mx_sym = mx.sym.zeros_like(data, dtype='float32')
     verify_mxnet_frontend_impl(mx_sym, (2, 3, 4), (2, 3, 4))
 
+def test_forward_argmax():
+    data = mx.sym.var('data')
+    mx_sym = mx.sym.argmax(data, axis=1)
+    verify_mxnet_frontend_impl(mx_sym, (5, 3), (5,))
+
+def test_forward_argmin():
+    data = mx.sym.var('data')
+    mx_sym = mx.sym.argmin(data, axis=0)
+    verify_mxnet_frontend_impl(mx_sym, (5, 4), (4,))
+
 if __name__ == '__main__':
     test_forward_mlp()
     test_forward_vgg()
@@ -194,3 +204,6 @@ if __name__ == '__main__':
     test_forward_zeros()
     test_forward_ones_like()
     test_forward_zeros_like()
+    test_forward_argmax()
+    test_forward_argmin()
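The expected output shapes in these tests follow directly from reducing over one axis: `argmax` over `axis=1` of a `(5, 3)` input yields shape `(5,)`, and `argmin` over `axis=0` of a `(5, 4)` input yields `(4,)`. A quick NumPy illustration of the same reduction semantics:

```python
# NumPy is used here only to illustrate the shapes the new tests expect;
# the tests themselves go through the NNVM frontend and verify values too.
import numpy as np

x = np.random.rand(5, 3).astype('float32')
print(np.argmax(x, axis=1).shape)  # (5,)

y = np.random.rand(5, 4).astype('float32')
print(np.argmin(y, axis=0).shape)  # (4,)
```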
@@ -107,7 +107,10 @@ def schedule_reduce(outs):
     def traverse_after_reduce(operator):
         """Internal traverse function"""
         if tag.is_broadcast(operator.tag):
-            raise RuntimeError("Not yet support ewise after reduce")
+            if operator not in scheduled_ops:
+                _schedule_injective(operator, sch)
+            for tensor in operator.input_tensors:
+                traverse_after_reduce(tensor.op)
         elif operator.tag == 'comm_reduce':
             _schedule_reduce(operator, sch, is_idx_reduce=False)
             for tensor in operator.input_tensors:
......
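This scheduling change is what lets the cast introduced above run on CUDA: with `dtype='float32'` the graph contains an elementwise cast after the index reduction, so `traverse_after_reduce` now schedules that injective op and keeps walking its inputs instead of raising. A hedged sketch of the operator pattern the traversal accepts after this change (names and shapes are illustrative; no GPU is needed just to construct the schedule):

```python
# Hedged sketch: an elementwise cast consuming an argmax result can now be
# scheduled by the CUDA reduce scheduler instead of triggering
# "Not yet support ewise after reduce".
import tvm
import topi

data = tvm.placeholder((5, 3), name='data', dtype='float32')
idx = topi.argmax(data, axis=1)
out = topi.cast(idx, 'float32')  # elementwise op after the index reduction

sch = topi.cuda.schedule_reduce([out])  # previously raised RuntimeError for this pattern
```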