Commit 2374852e by guopinglong Committed by Tianqi Chen

make shape inference of BatchNorm layout neutral (#301)

* make shape inference of BatchNorm layout neutral

* refactor to use the axis variable to do BatchNorm shape inference

* refactor to use the axis variable to do BatchNorm shape inference

* add unittest to the axis param for batch norm shape inference
parent 7112dd78
...@@ -117,12 +117,15 @@ DMLC_REGISTER_PARAMETER(BatchNormParam); ...@@ -117,12 +117,15 @@ DMLC_REGISTER_PARAMETER(BatchNormParam);
inline bool BatchNormInferShape(const nnvm::NodeAttrs& attrs, inline bool BatchNormInferShape(const nnvm::NodeAttrs& attrs,
std::vector<TShape>* in_shape, std::vector<TShape>* in_shape,
std::vector<TShape>* out_shape) { std::vector<TShape>* out_shape) {
const BatchNormParam& param = nnvm::get<BatchNormParam>(attrs.parsed);
CHECK_EQ(in_shape->size(), 5U) CHECK_EQ(in_shape->size(), 5U)
<< "Input:[data, gamma, beta, moving_mean, moving_var]"; << "Input:[data, gamma, beta, moving_mean, moving_var]";
CHECK_EQ(out_shape->size(), 3U); CHECK_EQ(out_shape->size(), 3U);
const TShape &dshape = in_shape->at(0); const TShape &dshape = in_shape->at(0);
if (dshape.ndim() == 0) return false; if (dshape.ndim() == 0) return false;
TShape bshape({dshape[1]}); CHECK((size_t)param.axis < dshape.Size());
TShape bshape({dshape[param.axis]});
NNVM_ASSIGN_INPUT_SHAPE(attrs, *in_shape, 1, bshape); NNVM_ASSIGN_INPUT_SHAPE(attrs, *in_shape, 1, bshape);
NNVM_ASSIGN_INPUT_SHAPE(attrs, *in_shape, 2, bshape); NNVM_ASSIGN_INPUT_SHAPE(attrs, *in_shape, 2, bshape);
NNVM_ASSIGN_INPUT_SHAPE(attrs, *in_shape, 3, bshape); NNVM_ASSIGN_INPUT_SHAPE(attrs, *in_shape, 3, bshape);
......
...@@ -62,6 +62,22 @@ def test_batchnorm(): ...@@ -62,6 +62,22 @@ def test_batchnorm():
sdict = infer_shape(y) sdict = infer_shape(y)
assert(sdict["bn_gamma"][0] == [20]) assert(sdict["bn_gamma"][0] == [20])
x = sym.Variable("x", shape=(10, 20, 30, 40))
y = sym.batch_norm(data=x, axis=0, epsilon=2e-5, name='bn')
sdict = infer_shape(y)
assert(sdict['bn_moving_var'][0] == [10])
y = sym.batch_norm(data=x, axis=1, epsilon=2e-5, name='bn')
sdict = infer_shape(y)
assert(sdict['bn_gamma'][0] == [20])
y = sym.batch_norm(data=x, axis=2, epsilon=2e-5, name='bn')
sdict = infer_shape(y)
assert(sdict['bn_beta'][0] == [30])
y = sym.batch_norm(data=x, axis=3, epsilon=2e-5, name='bn')
sdict = infer_shape(y)
assert(sdict['bn_moving_mean'][0] == [40])
def test_flatten(): def test_flatten():
x = sym.Variable("x", shape=(10, 20, 10)) x = sym.Variable("x", shape=(10, 20, 10))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment