"""Shape-inference tests for NNVM operators.

Each test builds a small symbolic graph, runs the ``InferShape`` pass,
and asserts the inferred shapes of named nodes and auxiliary parameters.
"""
import json
import nnvm.symbol as sym
import nnvm.graph as graph


def infer_shape(symbol):
    """Run the InferShape pass on *symbol*.

    Returns a dict mapping each node name to the list of inferred output
    shapes (one list entry per output entry of that node).

    NOTE: parameter was named ``sym`` before, shadowing the module alias
    ``nnvm.symbol as sym`` — renamed; callers pass it positionally.
    """
    g = graph.create(symbol)
    g._set_json_attr("shape_attr_key", "shape")
    g = g.apply("InferShape")
    sdict = {}
    vshape = g.json_attr("shape")
    entry_ptr = g.index.entry_ptr
    for i, n in enumerate(g.index.nodes):
        # entry_ptr[i]:entry_ptr[i+1] spans all output entries of node i.
        begin, end = entry_ptr[i], entry_ptr[i + 1]
        sdict[n["name"]] = vshape[begin:end]
    return sdict


# Level 1
def test_dense():
    x = sym.Variable("x", shape=(10, 20))
    y = sym.dense(x, units=30, name="fc")
    sdict = infer_shape(y)
    assert sdict["fc"][0] == [10, 30]
    # Bias is created implicitly with one element per output unit.
    assert sdict["fc_bias"][0] == [30]


def test_concatenate():
    x1 = sym.Variable("x", shape=(10, 20))
    x2 = sym.Variable("y", shape=(10, 30))
    z = sym.concatenate(x1, x2, name="concat")
    sdict = infer_shape(z)
    assert sdict["concat"][0] == [10, 50]
    z = sym.concatenate(x1, x1, axis=0, name="concat")
    sdict = infer_shape(z)
    assert sdict["concat"][0] == [20, 20]


def test_expand_dims():
    x = sym.Variable("x", shape=(10, 20))
    y = sym.expand_dims(x, axis=1, name="y")
    sdict = infer_shape(y)
    assert sdict["y"][0] == [10, 1, 20]
    # Negative axis inserts at the end; num_newaxis adds several dims.
    y = sym.expand_dims(x, axis=-1, name="y", num_newaxis=2)
    sdict = infer_shape(y)
    assert sdict["y"][0] == [10, 20, 1, 1]


def test_split():
    x1 = sym.Variable("x", shape=(10, 20))
    # A list of indices gives explicit cut points along the last axis.
    z = sym.split(x1, indices_or_sections=[11], name="y")
    sdict = infer_shape(z)
    assert sdict["y"][0] == [10, 11]
    assert sdict["y"][1] == [10, 9]
    # An integer gives that many equal sections.
    z = sym.split(x1, indices_or_sections=2, name="y")
    sdict = infer_shape(z)
    assert sdict["y"][0] == [10, 10]
    assert sdict["y"][1] == [10, 10]


def test_batchnorm():
    x = sym.Variable("x", shape=(10, 20))
    y = sym.batch_norm(1 / x, name="bn")
    sdict = infer_shape(y)
    assert sdict["bn_gamma"][0] == [20]

    # The auxiliary parameters take their length from the chosen axis.
    x = sym.Variable("x", shape=(10, 20, 30, 40))
    y = sym.batch_norm(data=x, axis=0, epsilon=2e-5, name='bn')
    sdict = infer_shape(y)
    assert sdict['bn_moving_var'][0] == [10]

    y = sym.batch_norm(data=x, axis=1, epsilon=2e-5, name='bn')
    sdict = infer_shape(y)
    assert sdict['bn_gamma'][0] == [20]

    y = sym.batch_norm(data=x, axis=2, epsilon=2e-5, name='bn')
    sdict = infer_shape(y)
    assert sdict['bn_beta'][0] == [30]

    y = sym.batch_norm(data=x, axis=3, epsilon=2e-5, name='bn')
    sdict = infer_shape(y)
    assert sdict['bn_moving_mean'][0] == [40]


def test_flatten():
    x = sym.Variable("x", shape=(10, 20, 10))
    y = sym.flatten(x) * 2
    y = sym.exp(y, name="y")
    sdict = infer_shape(y)
    assert sdict["y"][0] == [10, 200]


# Level 2
def test_conv2d():
    def check(in_shape, out_shape, **kwargs):
        x = sym.Variable("x", shape=in_shape)
        y = sym.conv2d(x, name="y", **kwargs)
        sdict = infer_shape(y)
        assert tuple(sdict["y"][0]) == tuple(out_shape)

    check((4, 10, 10, 12),
          (4, 12, 10, 12),
          channels=12,
          kernel_size=(3,3),
          padding=(1,1))
    check((4, 10, 12, 4),
          (4, 8, 8, 5),
          channels=5,
          kernel_size=(3, 5),
          layout="NHWC")
    check((4, 10, 12, 4),
          (4, 6, 8, 5),
          channels=5,
          dilation=(2, 2),
          kernel_size=(3, 3),
          layout="NHWC")
    check((4, 10, 12, 4),
          (4, 5, 6, 5),
          channels=5,
          strides=(2, 2),
          kernel_size=(3, 3),
          padding=(1, 1),
          layout="NHWC")


def test_conv2d_transpose():
    def check(in_shape, out_shape, **kwargs):
        x = sym.Variable("x", shape=in_shape)
        y = sym.conv2d_transpose(x, name="y", **kwargs)
        sdict = infer_shape(y)
        assert tuple(sdict["y"][0]) == tuple(out_shape)

    check((4, 10, 10, 12),
          (4, 15, 10, 12),
          channels=15,
          kernel_size=(3,3),
          padding=(1,1))
    check((4, 10, 10, 12),
          (4, 15, 10, 14),
          channels=15,
          kernel_size=(3, 5),
          padding=(1, 1))
    check((4, 10, 10, 12),
          (4, 15, 11, 15),
          channels=15,
          kernel_size=(3, 5),
          padding=(1, 1),
          output_padding=(1, 1))
    check((4, 10, 10, 12),
          (4, 15, 15, 11),
          channels=11,
          kernel_size=(5, 5),
          output_padding=(1, 1),
          layout="NHWC")


def test_max_pool2d():
    def check(in_shape, out_shape, **kwargs):
        x = sym.Variable("x", shape=in_shape)
        y = sym.max_pool2d(x, name="y", **kwargs)
        sdict = infer_shape(y)
        assert tuple(sdict["y"][0]) == tuple(out_shape)

    check((4, 10, 12, 12),
          (4, 10, 12, 12),
          pool_size=(3,3),
          padding=(1,1))
    check((4, 10, 12, 12),
          (4, 10, 6, 6),
          pool_size=(3, 3),
          padding=(1, 1),
          strides=(2, 2))
    # ceil_mode rounds the spatial size up instead of down.
    check((4, 10, 12, 12),
          (4, 10, 7, 7),
          pool_size=(3, 3),
          padding=(1, 1),
          strides=(2, 2),
          ceil_mode=True)
    check((4, 12, 14, 10),
          (4, 6, 7, 10),
          pool_size=(3, 3),
          padding=(1, 1),
          strides=(2, 2),
          layout="NHWC")


def test_global_pool2d():
    def check(in_shape, out_shape, **kwargs):
        x = sym.Variable("x", shape=in_shape)
        y = sym.global_max_pool2d(x, name="y", **kwargs)
        sdict = infer_shape(y)
        assert tuple(sdict["y"][0]) == tuple(out_shape)

    check((4, 10, 12, 12),
          (4, 10, 1, 1))
    check((4, 10, 12, 12),
          (4, 1, 1, 12),
          layout="NHWC")


# Level 3
def test_reshape():
    def check(in_shape, tshape, out_shape):
        x = sym.Variable("x", shape=in_shape)
        y = sym.reshape(x, shape=tshape, name="y")
        sdict = infer_shape(y)
        assert tuple(sdict["y"][0]) == tuple(out_shape)

    check((4,), (2, 2), (2, 2))
    # 0 copies the corresponding input dimension.
    check((2, 3, 4), (4, 0, 2), (4, 3, 2))
    check((2, 3, 4), (2, 0, 0), (2, 3, 4))
    # -1 infers a single remaining dimension.
    check((2, 3, 4), (6, 1, -1), (6, 1, 4))
    check((2, 3, 4), (3, -1, 8), (3, 1, 8))
    check((2, 3, 4), (-1,), (24,))
    # -2 copies all remaining input dimensions.
    check((2, 3, 4), (-2,), (2, 3, 4))
    check((2, 3, 4), (2, -2), (2, 3, 4))
    check((2, 3, 4), (-2, 1, 1), (2, 3, 4, 1, 1))
    # -3 merges two consecutive input dimensions into one.
    check((2, 3, 4), (-3, 4), (6, 4))
    check((2, 3, 4, 5), (-3, -3), (6, 20))
    check((2, 3, 4), (0, -3), (2, 12))
    check((2, 3, 4), (-3, -2), (6, 4))
    # -4 splits one input dimension into the following two values.
    check((2, 3, 4), (-4, 1, 2, -2), (1, 2, 3, 4))
    check((2, 3, 4), (2, -4, -1, 3, -2), (2, 1, 3, 4))


# Level 4
def test_transpose():
    def check(in_shape, out_shape, **kwargs):
        x = sym.Variable("x", shape=in_shape)
        y = sym.transpose(x, name="y", **kwargs)
        sdict = infer_shape(y)
        assert tuple(sdict["y"][0]) == tuple(out_shape)

    check((4, 1), (1, 4))
    check((0, 1, 2, 3), (1, 2, 3, 0), axes=(1, 2, 3, 0))


def test_broadcast_to():
    def check(in_shape, tshape, out_shape):
        x = sym.Variable("x", shape=in_shape)
        y = sym.broadcast_to(x, shape=tshape, name="y")
        sdict = infer_shape(y)
        assert tuple(sdict["y"][0]) == tuple(out_shape)

    check((4, 1), (0, 4), (4, 4))
    check((4, 1, 5), (0, 4, 5), (4, 4, 5))


def test_broadcast_binary():
    def check(lhs_shape, rhs_shape, out_shape):
        x = sym.Variable("x", shape=lhs_shape)
        y = sym.Variable("y", shape=rhs_shape)
        z = sym.broadcast_add(x, y, name="y")
        sdict = infer_shape(z)
        assert tuple(sdict["y"][0]) == tuple(out_shape)

    # BUG FIX: was (4) — a plain int, not a 1-element tuple; the missing
    # comma meant shape=4 rather than the intended 1-D shape (4,).
    check((4, 1), (4,), (4, 4))
    check((5, 1, 1), (1, 4, 4), (5, 4, 4))
    check((6, 1, 4), (5, 4), (6, 5, 4))


def test_reduce():
    def check(in_shape, out_shape, **kwargs):
        x = sym.Variable("x", shape=in_shape)
        y = sym.sum(x, name="y", **kwargs)
        sdict = infer_shape(y)
        assert tuple(sdict["y"][0]) == tuple(out_shape)

    check((4, 5), (4,), axis=1)
    check((4, 5), (4, 1), axis=1, keepdims=True)
    check((4, 5), (1, 5), axis=0, keepdims=True)
    # Empty axis tuple reduces over all axes.
    check((4, 5), (1, 1), axis=(), keepdims=True)
    check((4, 5), (1,), axis=())
    check((4, 5, 10), (5,), axis=(0, 2))
    check((4, 5, 10), (1, 5, 1), axis=(0, 2), keepdims=True)


if __name__ == "__main__":
    test_expand_dims()
    test_dense()
    test_concatenate()
    test_split()
    test_batchnorm()
    test_flatten()
    test_conv2d()
    test_conv2d_transpose()
    test_max_pool2d()
    test_global_pool2d()
    test_reshape()
    test_broadcast_to()
    test_broadcast_binary()
    test_reduce()
    test_transpose()