Unverified Commit 49d304fc by Animesh Jain Committed by GitHub

[TOPI-ARM] Do not alter layout if layout is NHWC (#5350)

* [TOPI-ARM] Do not alter layout if layout is NHWC

* Add test.
parent 6dfcd375
......@@ -940,11 +940,8 @@ def test_alter_layout_sum():
assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
# TODO(@anijain2305, @icemelon9): We should fix this. This doesn't seem to be the
# right behavior of alter_layout
@pytest.mark.skip
def test_alter_layout_nhwc_nchw_arm():
""" Check NHWC to NHCW conversion for a small sequence of ops."""
def test_alter_layout_nhwc_arm():
""" Check that AlterOplayout does not alter NHWC data layout. """
def alter_conv2d(attrs, inputs, tinfos, out_type):
import topi
with tvm.target.create("llvm -device=arm_cpu"):
......@@ -974,25 +971,7 @@ def test_alter_layout_nhwc_nchw_arm():
return y
def expected_nhwc():
x = relay.var("x", shape=(1, 56, 56, 64))
weight1 = relay.var('weight1', shape=(3, 3, 64, 64))
weight2 = relay.var('weight2', shape=(3, 3, 64, 64))
y = relay.layout_transform(x, "NHWC", "NCHW")
weight1 = relay.layout_transform(weight1, "HWIO", "OIHW")
weight2 = relay.layout_transform(weight2, "HWIO", "OIHW")
y = relay.nn.conv2d(y, weight1,
channels=64,
kernel_size=(3, 3))
y = relay.nn.relu(y)
y = relay.nn.avg_pool2d(y,
pool_size=(1,1))
y = relay.nn.conv2d(y, weight2,
channels=64,
kernel_size=(3, 3))
y = relay.nn.relu(y)
y = relay.layout_transform(y, "NCHW", "NHWC")
y = relay.Function(analysis.free_vars(y), y)
return y
return before_nhwc()
with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d):
a = before_nhwc()
......@@ -1060,5 +1039,5 @@ if __name__ == "__main__":
test_alter_layout_pad()
test_alter_layout_pool()
test_alter_layout_sum()
# test_alter_layout_nhwc_nchw_arm()
test_alter_layout_nhwc_arm()
test_alter_op_with_global_var()
......@@ -59,6 +59,10 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
data, kernel = tinfos
out_dtype = out_type.dtype
# We only perform layout alteration for NCHW data layout.
if data_layout == "NHWC":
return None
# Extract data types
data_tensor, kernel_tensor = tinfos
data_dtype = data_tensor.dtype
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment