Commit 1d243664 by Animesh Jain Committed by Yizhi Liu

[TOPI][AlterOpLayout][ARM] Enabling NHWC to NCHW layout transformation. (#4249)

parent d2fc0252
......@@ -916,6 +916,67 @@ def test_alter_layout_sum():
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
def test_alter_layout_nhwc_nchw_arm():
    """Check NHWC to NCHW conversion for a small sequence of ops."""
    # Register an alter-op-layout hook for conv2d; "level" overrides any
    # previously registered function for the same op.
    @register_alter_op_layout("nn.conv2d", level=115)
    def alter_conv2d(attrs, inputs, tinfos):
        from topi.arm_cpu.conv2d import _alter_conv2d_layout_arm
        return _alter_conv2d_layout_arm(attrs, inputs, tinfos, tvm.relay)

    def build_nhwc_graph():
        # conv2d -> relu -> avg_pool2d -> conv2d -> relu, all NHWC/HWIO.
        data = relay.var("x", shape=(1, 56, 56, 64))
        w1 = relay.var('weight1', shape=(3, 3, 64, 64))
        w2 = relay.var('weight2', shape=(3, 3, 64, 64))
        out = relay.nn.conv2d(data, w1,
                              channels=64,
                              kernel_size=(3, 3),
                              data_layout='NHWC',
                              kernel_layout='HWIO')
        out = relay.nn.relu(out)
        out = relay.nn.avg_pool2d(out, pool_size=(1, 1), layout='NHWC')
        out = relay.nn.conv2d(out, w2,
                              channels=64,
                              kernel_size=(3, 3),
                              data_layout='NHWC',
                              kernel_layout='HWIO')
        out = relay.nn.relu(out)
        return relay.Function(analysis.free_vars(out), out)

    def build_expected_graph():
        # Same graph after AlterOpLayout: a single NHWC->NCHW transform on
        # the input (and HWIO->OIHW on the weights), NCHW ops in the middle,
        # and one NCHW->NHWC transform on the result.
        data = relay.var("x", shape=(1, 56, 56, 64))
        w1 = relay.var('weight1', shape=(3, 3, 64, 64))
        w2 = relay.var('weight2', shape=(3, 3, 64, 64))
        out = relay.layout_transform(data, "NHWC", "NCHW")
        w1 = relay.layout_transform(w1, "HWIO", "OIHW")
        w2 = relay.layout_transform(w2, "HWIO", "OIHW")
        out = relay.nn.conv2d(out, w1, channels=64, kernel_size=(3, 3))
        out = relay.nn.relu(out)
        out = relay.nn.avg_pool2d(out, pool_size=(1, 1))
        out = relay.nn.conv2d(out, w2, channels=64, kernel_size=(3, 3))
        out = relay.nn.relu(out)
        out = relay.layout_transform(out, "NCHW", "NHWC")
        return relay.Function(analysis.free_vars(out), out)

    actual = run_opt_pass(build_nhwc_graph(), transform.AlterOpLayout())
    expected = run_opt_pass(build_expected_graph(), transform.InferType())
    assert analysis.alpha_equal(actual, expected), "Actual = \n" + str(actual)
if __name__ == "__main__":
test_alter_op()
test_alter_return_none()
......@@ -932,3 +993,4 @@ if __name__ == "__main__":
test_alter_layout_pad()
test_alter_layout_pool()
test_alter_layout_sum()
test_alter_layout_nhwc_nchw_arm()
......@@ -171,53 +171,9 @@ def test_legalize_multi_input():
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
def test_legalize_arm_layout_functional():
    """Test if the legalized conversion yields same result as original"""
    def get_output(func, data_val, parameters):
        # Build and run `func` on CPU at opt_level=0 so later optimization
        # passes do not mask the effect of legalization on the result.
        with relay.build_config(opt_level=0):
            graph, lib, params = relay.build(func, target='llvm', params=parameters)
        m = graph_runtime.create(graph, lib, tvm.cpu())
        m.set_input("data", data_val)
        m.set_input(**params)
        m.run()
        # Output tensor is NHWC: (1, 224, 224, 32).
        out = m.get_output(0, tvm.nd.empty((1, 224, 224, 32), 'float32')).asnumpy()
        return out

    def before():
        # Single NHWC/HWIO conv2d: 16 -> 32 channels, 3x3 kernel, pad 1.
        n, ic, ih, iw, oc, kh, kw = 1, 16, 224, 224, 32, 3, 3
        data = relay.var("data", relay.TensorType((n, ih, iw, ic), 'float32'))
        kernel = relay.var("kernel", relay.TensorType((kh, kw, ic, oc), 'float32'))
        y = relay.nn.conv2d(data, kernel,
                            kernel_size=(kh, kw),
                            channels=oc,
                            padding=(1, 1),
                            dilation=(1, 1),
                            data_layout='NHWC',
                            kernel_layout='HWIO',
                            out_dtype='float32')
        func = relay.Function([data, kernel], y)
        return func

    # Register the ARM legalize hook; "level" overrides any previously
    # registered legalize function for nn.conv2d. Must be registered
    # before transform.Legalize() runs below.
    @register_legalize("nn.conv2d", level=105)
    def legalize_conv2d(attrs, inputs, types):
        from topi.arm_cpu.conv2d import _conv2d_legalize
        return _conv2d_legalize(attrs, inputs, types)

    a = before()
    b = run_opt_pass(a, transform.Legalize())
    # Legalization should insert exactly 3 transposes
    # (data in, kernel, data out).
    assert b.astext().count('transpose') == 3

    # Functional check: original and legalized graphs must produce the
    # same numeric output on random data.
    wdata = np.random.rand(3, 3, 16, 32) * 10
    parameters = {"kernel": tvm.nd.array(wdata.astype('float32'))}
    data_val = np.random.rand(1, 224, 224, 16).astype('float32')
    ref_out = get_output(a, data_val, parameters)
    legalized_out = get_output(b, data_val, parameters)
    np.testing.assert_allclose(ref_out, legalized_out, rtol=0.01)
if __name__ == "__main__":
    # NOTE(review): call order matters — each test registers its own
    # legalize function at an increasing "level", so reordering could
    # change which registration is in effect for later tests.
    test_legalize()
    test_legalize_none()
    test_legalize_multiple_ops()
    test_legalize_multi_input()
    test_legalize_arm_layout_functional()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment