Commit 68c03944 by Siva Committed by Tianqi Chen

[NHWC] InferShape Layout conversion fix. (#372)

parent 50c20b76
...@@ -48,7 +48,7 @@ inline bool Conv2DInferShape(const nnvm::NodeAttrs& attrs, ...@@ -48,7 +48,7 @@ inline bool Conv2DInferShape(const nnvm::NodeAttrs& attrs,
param.kernel_size[0], param.kernel_size[0],
param.kernel_size[1]}); param.kernel_size[1]});
wshape = ConvertLayout(wshape, kNCHW, param.layout); wshape = ConvertLayout(wshape, kNCHW, param.layout, true);
wshape[0] *= param.groups; wshape[0] *= param.groups;
NNVM_ASSIGN_INPUT_SHAPE(attrs, *in_shape, Conv2DParam::kWeight, wshape); NNVM_ASSIGN_INPUT_SHAPE(attrs, *in_shape, Conv2DParam::kWeight, wshape);
...@@ -189,7 +189,7 @@ inline bool Conv2DTransposeInferShape(const nnvm::NodeAttrs& attrs, ...@@ -189,7 +189,7 @@ inline bool Conv2DTransposeInferShape(const nnvm::NodeAttrs& attrs,
param.channels / param.groups, param.channels / param.groups,
param.kernel_size[0], param.kernel_size[0],
param.kernel_size[1]}); param.kernel_size[1]});
wshape = ConvertLayout(wshape, kNCHW, param.layout); wshape = ConvertLayout(wshape, kNCHW, param.layout, true);
NNVM_ASSIGN_INPUT_SHAPE(attrs, *in_shape, Conv2DTransposeParam::kWeight, wshape); NNVM_ASSIGN_INPUT_SHAPE(attrs, *in_shape, Conv2DTransposeParam::kWeight, wshape);
if (param.use_bias) { if (param.use_bias) {
......
...@@ -40,7 +40,7 @@ inline std::vector<std::string> UseBiasListInputNames(const NodeAttrs& attrs) { ...@@ -40,7 +40,7 @@ inline std::vector<std::string> UseBiasListInputNames(const NodeAttrs& attrs) {
* \param dst_layout target layout * \param dst_layout target layout
* \return shape in target layout * \return shape in target layout
*/ */
inline TShape ConvertLayout(TShape src, int src_layout, int dst_layout) { inline TShape ConvertLayout(TShape src, int src_layout, int dst_layout, bool is_weight = false) {
if (src_layout == dst_layout) return src; if (src_layout == dst_layout) return src;
TShape dst = src; TShape dst = src;
if (src.ndim() == 3) { if (src.ndim() == 3) {
...@@ -68,9 +68,16 @@ inline TShape ConvertLayout(TShape src, int src_layout, int dst_layout) { ...@@ -68,9 +68,16 @@ inline TShape ConvertLayout(TShape src, int src_layout, int dst_layout) {
switch (src_layout) { switch (src_layout) {
case kNCHW: break; case kNCHW: break;
case kNHWC: { case kNHWC: {
dst[2] = src[1]; if (is_weight) {
dst[3] = src[2]; dst[2] = src[0];
dst[1] = src[3]; dst[3] = src[1];
dst[1] = src[2];
dst[0] = src[3];
} else {
dst[2] = src[1];
dst[3] = src[2];
dst[1] = src[3];
}
break; break;
} }
default: { default: {
...@@ -81,9 +88,16 @@ inline TShape ConvertLayout(TShape src, int src_layout, int dst_layout) { ...@@ -81,9 +88,16 @@ inline TShape ConvertLayout(TShape src, int src_layout, int dst_layout) {
switch (dst_layout) { switch (dst_layout) {
case kNCHW: break; case kNCHW: break;
case kNHWC: { case kNHWC: {
dst[1] = src[2]; if (is_weight) {
dst[2] = src[3]; dst[0] = src[2];
dst[3] = src[1]; dst[1] = src[3];
dst[2] = src[1];
dst[3] = src[0];
} else {
dst[1] = src[2];
dst[2] = src[3];
dst[3] = src[1];
}
break; break;
} }
default: { default: {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment