# test_nhwc_layout.py
# Compares an NNVM graph built with NCHW layout against the same graph built with NHWC layout.
import numpy as np
import tvm
from tvm.contrib import graph_runtime as runtime
import nnvm.symbol as sym
import nnvm.compiler
from nnvm.testing.config import ctx_list

def get_sym(layout, kernel_layout, channels):
    """Build a conv2d -> max_pool2d -> upsampling -> softmax symbol graph.

    Parameters
    ----------
    layout : str
        Data layout forwarded to conv2d, max_pool2d and upsampling. When it
        equals "NHWC" the softmax is taken over axis 3; for any other value
        it is taken over axis 1.
    kernel_layout : str
        Weight layout forwarded to conv2d.
    channels : int
        Forwarded to conv2d as ``channels``.

    Returns
    -------
    The symbol produced by the final softmax.
    """
    net = sym.Variable(name="data")
    net = sym.conv2d(data=net, kernel_size=(3, 3), channels=channels, padding=(1, 1),
                     layout=layout, kernel_layout=kernel_layout, use_bias=True)
    net = sym.max_pool2d(data=net, pool_size=(2, 2), strides=(2, 2), layout=layout)
    net = sym.upsampling(data=net, scale=2, layout=layout)
    # Softmax runs over the axis that holds "C" in the two layouts used here.
    softmax_axis = 3 if layout == "NHWC" else 1
    return sym.softmax(data=net, axis=softmax_axis)


def build_and_run(sym, params, data, out_shape):
    """Compile ``sym`` for the "llvm" target, run it on CPU 0, return output 0.

    Parameters
    ----------
    sym
        Graph handed to ``nnvm.compiler.build``.
    params : dict
        Parameters handed to the build; the parameters returned by the build
        are the ones loaded into the runtime module.
    data
        Value bound to the "data" input; its ``.shape`` is given to the build.
    out_shape : tuple
        Shape of the buffer allocated to receive output 0.

    Returns
    -------
    The result of calling ``.asnumpy()`` on output 0.
    """
    cpu = tvm.cpu(0)
    input_shapes = {"data": data.shape}
    graph, lib, built_params = nnvm.compiler.build(
        sym, "llvm", shape=input_shapes, params=params)

    executor = runtime.create(graph, lib, cpu)
    executor.set_input(**built_params)
    executor.set_input("data", data)
    executor.run()

    result = executor.get_output(0, tvm.nd.empty(out_shape))
    return result.asnumpy()


def test_nhwc():
    """Check that the NHWC graph reproduces the output of the NCHW graph.

    Both graphs are built by ``get_sym`` and share one random conv weight and
    bias. The NHWC run receives the weight transposed from OIHW to HWIO and
    the input transposed from NCHW to NHWC; its output is transposed back to
    NCHW before the two results are compared.
    """
    data_shape = (1, 3, 224, 224)
    out_channel = 8
    ctx = tvm.cpu(0)

    nchw_sym = get_sym("NCHW", "OIHW", out_channel)
    nhwc_sym = get_sym("NHWC", "HWIO", out_channel)

    conv_weight = np.random.uniform(-1, 1, (out_channel, 3, 3, 3)).astype(np.float32)
    conv_bias = np.random.uniform(-1, 1, (out_channel,)).astype(np.float32)

    # NOTE(review): the "conv2d0"/"conv2d1" prefixes are presumably generated
    # from the order in which the two symbols above were created - confirm
    # before reordering the get_sym calls or creating other symbols first.
    nchw_params = {
        "conv2d0_weight": tvm.nd.array(conv_weight, ctx=ctx),
        "conv2d0_bias": tvm.nd.array(conv_bias, ctx=ctx),
    }
    nhwc_params = {
        # (O, I, H, W) -> (H, W, I, O)
        "conv2d1_weight": tvm.nd.array(conv_weight.transpose(2, 3, 1, 0), ctx=ctx),
        "conv2d1_bias": tvm.nd.array(conv_bias, ctx=ctx),
    }

    data = np.random.uniform(-1, 1, data_shape).astype(np.float32)
    oshape = (1, out_channel, 224, 224)
    oshape_nhwc = (1, 224, 224, out_channel)
    nchw_output = build_and_run(nchw_sym, nchw_params, data, oshape)
    nhwc_output = build_and_run(nhwc_sym, nhwc_params, data.transpose(0, 2, 3, 1), oshape_nhwc)

    # Bring the NHWC result back to NCHW so both arrays are compared element-wise.
    tvm.testing.assert_allclose(nchw_output, nhwc_output.transpose(0, 3, 1, 2),
                                rtol=1e-5, atol=1e-5)


# Run the layout comparison directly when this file is executed as a script.
if __name__ == "__main__":
    test_nhwc()