Commit acc2151c by Ashok Emani Committed by Tianqi Chen

fix output_shape in conv2d_nchw (#1613)

parent e282915a
......@@ -265,7 +265,7 @@ inline tvm::Tensor conv2d_nchw(const tvm::Tensor& I,
auto pW = I->shape[3];
tvm::Array<tvm::Expr> output_shape{
I->shape[0], // B
W->shape[1], // O
W->shape[0], // O
(I->shape[2] - W->shape[2] + 2 * pad_h) / stride_h + 1, // H
(I->shape[3] - W->shape[3] + 2 * pad_w) / stride_w + 1 // W
};
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment