Commit 68c03944 by Siva Committed by Tianqi Chen

[NHWC] InferShape Layout conversion fix. (#372)

parent 50c20b76
......@@ -48,7 +48,7 @@ inline bool Conv2DInferShape(const nnvm::NodeAttrs& attrs,
param.kernel_size[0],
param.kernel_size[1]});
wshape = ConvertLayout(wshape, kNCHW, param.layout);
wshape = ConvertLayout(wshape, kNCHW, param.layout, true);
wshape[0] *= param.groups;
NNVM_ASSIGN_INPUT_SHAPE(attrs, *in_shape, Conv2DParam::kWeight, wshape);
......@@ -189,7 +189,7 @@ inline bool Conv2DTransposeInferShape(const nnvm::NodeAttrs& attrs,
param.channels / param.groups,
param.kernel_size[0],
param.kernel_size[1]});
wshape = ConvertLayout(wshape, kNCHW, param.layout);
wshape = ConvertLayout(wshape, kNCHW, param.layout, true);
NNVM_ASSIGN_INPUT_SHAPE(attrs, *in_shape, Conv2DTransposeParam::kWeight, wshape);
if (param.use_bias) {
......
......@@ -40,7 +40,7 @@ inline std::vector<std::string> UseBiasListInputNames(const NodeAttrs& attrs) {
* \param dst_layout target layout
* \return shape in target layout
*/
inline TShape ConvertLayout(TShape src, int src_layout, int dst_layout) {
inline TShape ConvertLayout(TShape src, int src_layout, int dst_layout, bool is_weight = false) {
if (src_layout == dst_layout) return src;
TShape dst = src;
if (src.ndim() == 3) {
......@@ -68,9 +68,16 @@ inline TShape ConvertLayout(TShape src, int src_layout, int dst_layout) {
switch (src_layout) {
case kNCHW: break;
case kNHWC: {
dst[2] = src[1];
dst[3] = src[2];
dst[1] = src[3];
if (is_weight) {
dst[2] = src[0];
dst[3] = src[1];
dst[1] = src[2];
dst[0] = src[3];
} else {
dst[2] = src[1];
dst[3] = src[2];
dst[1] = src[3];
}
break;
}
default: {
......@@ -81,9 +88,16 @@ inline TShape ConvertLayout(TShape src, int src_layout, int dst_layout) {
switch (dst_layout) {
case kNCHW: break;
case kNHWC: {
dst[1] = src[2];
dst[2] = src[3];
dst[3] = src[1];
if (is_weight) {
dst[0] = src[2];
dst[1] = src[3];
dst[2] = src[1];
dst[3] = src[0];
} else {
dst[1] = src[2];
dst[2] = src[3];
dst[3] = src[1];
}
break;
}
default: {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment