Commit 420ec786 by Haichen Shen Committed by Tianqi Chen

[NNVM][OP] Allow two input tensors with different type in reshape_like op (#2052)

parent c8245e9a
...@@ -290,7 +290,8 @@ _identity_list = ['__add_scalar__', '__add_symbol__', '__div_scalar__', ...@@ -290,7 +290,8 @@ _identity_list = ['__add_scalar__', '__add_symbol__', '__div_scalar__',
'elemwise_div', 'elemwise_mul', 'elemwise_sub', 'exp', 'elemwise_div', 'elemwise_mul', 'elemwise_sub', 'exp',
'flatten', 'log', 'log_softmax', 'max', 'min', 'negative', 'flatten', 'log', 'log_softmax', 'max', 'min', 'negative',
'ones_like', 'relu', 'sigmoid', 'slice_like', 'softmax', 'ones_like', 'relu', 'sigmoid', 'slice_like', 'softmax',
'sum', 'tanh', 'transpose', 'zeros_like', 'gather_nd'] 'sum', 'tanh', 'transpose', 'zeros_like', 'gather_nd',
'reshape_like']
_convert_map = { _convert_map = {
'_copy' : _rename('copy'), '_copy' : _rename('copy'),
......
...@@ -631,6 +631,15 @@ The significance of each is explained below: ...@@ -631,6 +631,15 @@ The significance of each is explained below:
}) })
.set_support_level(3); .set_support_level(3);
inline bool ReshapeLikeInferType(const NodeAttrs &attrs,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs) {
CHECK_EQ(in_attrs->size(), 2U);
CHECK_EQ(out_attrs->size(), 1U);
NNVM_ASSIGN_OUTPUT_TYPE(attrs, *out_attrs, 0, (*in_attrs)[0]);
return true;
}
NNVM_REGISTER_OP(reshape_like) NNVM_REGISTER_OP(reshape_like)
.describe(R"code(Reshapes the input array by the size of another array. .describe(R"code(Reshapes the input array by the size of another array.
For an input array with shape ``(d1, d2, ..., dk)``, `reshape_like` operation reshapes For an input array with shape ``(d1, d2, ..., dk)``, `reshape_like` operation reshapes
...@@ -651,7 +660,7 @@ the input array into an output array with the same shape as the second input arr ...@@ -651,7 +660,7 @@ the input array into an output array with the same shape as the second input arr
NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_attrs, 0, in_attrs->at(1)); NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_attrs, 0, in_attrs->at(1));
return true; return true;
}) })
.set_attr<FInferType>("FInferType", ElemwiseType<2, 1>) .set_attr<FInferType>("FInferType", ReshapeLikeInferType)
// never transform layout of the second input array. // never transform layout of the second input array.
.set_attr<FCorrectLayout>("FCorrectLayout", ElemwiseFixedLayoutUnknownOut<1, 1>) .set_attr<FCorrectLayout>("FCorrectLayout", ElemwiseFixedLayoutUnknownOut<1, 1>)
.set_attr<FGradient>( .set_attr<FGradient>(
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment