Commit 1d243664 by Animesh Jain, committed by Yizhi Liu

[TOPI][AlterOpLayout][ARM] Enabling NHWC to NCHW layout transformation. (#4249)

parent d2fc0252
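This commit teaches the ARM backend to rewrite NHWC convolutions into NCHW ones (plus boundary layout transforms) so that the tuned NCHW schedules can be reused. The whole change hinges on three axis permutations; as a quick reference, here is a minimal NumPy sketch of them (shapes are arbitrary, chosen only for illustration):

```python
import numpy as np

# The three layout conversions this commit inserts, written as transposes.
data = np.empty((1, 56, 56, 64))    # NHWC activation
w_hwio = np.empty((3, 3, 64, 128))  # HWIO kernel (common in TF conv2d)
w_hwoi = np.empty((3, 3, 64, 1))    # HWOI kernel (common in TF depthwise conv2d)

assert data.transpose(0, 3, 1, 2).shape == (1, 64, 56, 56)    # NHWC -> NCHW
assert w_hwio.transpose(3, 2, 0, 1).shape == (128, 64, 3, 3)  # HWIO -> OIHW
assert w_hwoi.transpose(2, 3, 0, 1).shape == (64, 1, 3, 3)    # HWOI -> OIHW
```

The same permutations appear throughout the diff as `relay.layout_transform`, `relay.transpose`, and `F.transpose` calls.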
--- a/tests/python/relay/test_pass_alter_op_layout.py
+++ b/tests/python/relay/test_pass_alter_op_layout.py
@@ -916,6 +916,67 @@ def test_alter_layout_sum():
    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)


def test_alter_layout_nhwc_nchw_arm():
    """Check NHWC to NCHW conversion for a small sequence of ops."""
    # Register alter op layout. "level" is used to override the previously registered functions.
    @register_alter_op_layout("nn.conv2d", level=115)
    def alter_conv2d(attrs, inputs, tinfos):
        from topi.arm_cpu.conv2d import _alter_conv2d_layout_arm
        return _alter_conv2d_layout_arm(attrs, inputs, tinfos, tvm.relay)

    # Check NHWC conversion.
    def before_nhwc():
        x = relay.var("x", shape=(1, 56, 56, 64))
        weight1 = relay.var('weight1', shape=(3, 3, 64, 64))
        weight2 = relay.var('weight2', shape=(3, 3, 64, 64))
        y = relay.nn.conv2d(x, weight1,
                            channels=64,
                            kernel_size=(3, 3),
                            data_layout='NHWC',
                            kernel_layout='HWIO')
        y = relay.nn.relu(y)
        y = relay.nn.avg_pool2d(y,
                                pool_size=(1, 1),
                                layout='NHWC')
        y = relay.nn.conv2d(y, weight2,
                            channels=64,
                            kernel_size=(3, 3),
                            data_layout='NHWC',
                            kernel_layout='HWIO')
        y = relay.nn.relu(y)
        y = relay.Function(analysis.free_vars(y), y)
        return y

    def expected_nhwc():
        x = relay.var("x", shape=(1, 56, 56, 64))
        weight1 = relay.var('weight1', shape=(3, 3, 64, 64))
        weight2 = relay.var('weight2', shape=(3, 3, 64, 64))
        y = relay.layout_transform(x, "NHWC", "NCHW")
        weight1 = relay.layout_transform(weight1, "HWIO", "OIHW")
        weight2 = relay.layout_transform(weight2, "HWIO", "OIHW")
        y = relay.nn.conv2d(y, weight1,
                            channels=64,
                            kernel_size=(3, 3))
        y = relay.nn.relu(y)
        y = relay.nn.avg_pool2d(y,
                                pool_size=(1, 1))
        y = relay.nn.conv2d(y, weight2,
                            channels=64,
                            kernel_size=(3, 3))
        y = relay.nn.relu(y)
        y = relay.layout_transform(y, "NCHW", "NHWC")
        y = relay.Function(analysis.free_vars(y), y)
        return y

    a = before_nhwc()
    a = run_opt_pass(a, transform.AlterOpLayout())
    b = expected_nhwc()
    b = run_opt_pass(b, transform.InferType())
    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
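The expected graph sinks a single NHWC-to-NCHW transform below the whole conv/relu/pool chain because every op in between is either layout-aware (conv2d, avg_pool2d) or elementwise. A minimal NumPy sanity check of the elementwise case, which is why relu needs no layout attribute:

```python
import numpy as np

# Elementwise ops commute with layout transposition, so the pass can keep
# the intermediate tensors in NCHW and only convert at the graph boundary.
x = np.random.randn(1, 56, 56, 64).astype('float32')  # NHWC
relu = lambda t: np.maximum(t, 0)
assert np.array_equal(relu(x).transpose(0, 3, 1, 2),   # relu, then NHWC->NCHW
                      relu(x.transpose(0, 3, 1, 2)))   # NHWC->NCHW, then relu
```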


if __name__ == "__main__":
    test_alter_op()
    test_alter_return_none()
@@ -932,3 +993,4 @@ if __name__ == "__main__":
    test_alter_layout_pad()
    test_alter_layout_pool()
    test_alter_layout_sum()
    test_alter_layout_nhwc_nchw_arm()
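The tests above lean on the `run_opt_pass` helper defined near the top of this test file but not shown in the diff. For the TVM 0.6-era API it typically looks like the following reconstruction (an assumption, not the verbatim upstream helper):

```python
from tvm import relay
from tvm.relay import transform

def run_opt_pass(expr, passes):
    # Wrap the expression in a module, run the given pass(es), and return
    # the entry function for comparison with alpha_equal.
    passes = passes if isinstance(passes, list) else [passes]
    mod = relay.Module.from_expr(expr)
    seq = transform.Sequential(passes)
    with transform.PassContext(opt_level=3):
        mod = seq(mod)
    return mod["main"]
```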
--- a/tests/python/relay/test_pass_legalize.py
+++ b/tests/python/relay/test_pass_legalize.py
@@ -171,53 +171,9 @@ def test_legalize_multi_input():
    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)


def test_legalize_arm_layout_functional():
    """Test that the legalized conversion yields the same result as the original."""
    def get_output(func, data_val, parameters):
        with relay.build_config(opt_level=0):
            graph, lib, params = relay.build(func, target='llvm', params=parameters)
        m = graph_runtime.create(graph, lib, tvm.cpu())
        m.set_input("data", data_val)
        m.set_input(**params)
        m.run()
        out = m.get_output(0, tvm.nd.empty((1, 224, 224, 32), 'float32')).asnumpy()
        return out

    def before():
        n, ic, ih, iw, oc, kh, kw = 1, 16, 224, 224, 32, 3, 3
        data = relay.var("data", relay.TensorType((n, ih, iw, ic), 'float32'))
        kernel = relay.var("kernel", relay.TensorType((kh, kw, ic, oc), 'float32'))
        y = relay.nn.conv2d(data, kernel,
                            kernel_size=(kh, kw),
                            channels=oc,
                            padding=(1, 1),
                            dilation=(1, 1),
                            data_layout='NHWC',
                            kernel_layout='HWIO',
                            out_dtype='float32')
        func = relay.Function([data, kernel], y)
        return func

    @register_legalize("nn.conv2d", level=105)
    def legalize_conv2d(attrs, inputs, types):
        from topi.arm_cpu.conv2d import _conv2d_legalize
        return _conv2d_legalize(attrs, inputs, types)

    a = before()
    b = run_opt_pass(a, transform.Legalize())
    assert b.astext().count('transpose') == 3

    wdata = np.random.rand(3, 3, 16, 32) * 10
    parameters = {"kernel": tvm.nd.array(wdata.astype('float32'))}
    data_val = np.random.rand(1, 224, 224, 16).astype('float32')
    ref_out = get_output(a, data_val, parameters)
    legalized_out = get_output(b, data_val, parameters)
    np.testing.assert_allclose(ref_out, legalized_out, rtol=0.01)
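The `count('transpose') == 3` assertion in this removed test pins down exactly the three boundary conversions that legalization inserts. A shape-only NumPy sketch with the test's sizes:

```python
import numpy as np

data = np.empty((1, 224, 224, 16), 'float32')  # NHWC input
kernel = np.empty((3, 3, 16, 32), 'float32')   # HWIO kernel

d = data.transpose(0, 3, 1, 2)        # transpose 1: NHWC -> NCHW
k = kernel.transpose(3, 2, 0, 1)      # transpose 2: HWIO -> OIHW
# ... the NCHW conv2d with padding=(1, 1) runs here, yielding NCHW output ...
out_nchw = np.empty((1, 32, 224, 224), 'float32')
out = out_nchw.transpose(0, 2, 3, 1)  # transpose 3: NCHW -> NHWC
assert out.shape == (1, 224, 224, 32)
```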
if __name__ == "__main__":
    test_legalize()
    test_legalize_none()
    test_legalize_multiple_ops()
    test_legalize_multi_input()
    test_legalize_arm_layout_functional()
--- a/topi/python/topi/arm_cpu/conv2d.py
+++ b/topi/python/topi/arm_cpu/conv2d.py
@@ -22,7 +22,6 @@ import logging
import tvm
from tvm import autotvm
from tvm import relay
import tvm.contrib.nnpack
from ..generic import schedule_conv2d_nchw, schedule_conv2d_winograd_without_weight_transform, \
@@ -32,7 +31,6 @@ from ..nn import dilate, pad, conv2d, conv2d_alter_layout, \
    conv2d_winograd_without_weight_transform, \
    conv2d_winograd_nnpack_without_weight_transform, \
    depthwise_conv2d_nchw
from ..nn import conv2d_legalize
from ..nn.util import get_const_int, get_pad_tuple
from ..nn.winograd_util import winograd_transform_matrices
from .conv2d_spatial_pack import conv2d_spatial_pack_nchw, \
@@ -508,32 +506,63 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F):
    groups = attrs.get_int('groups')
    data_layout_key = "data_layout" if "data_layout" in new_attrs else "layout"
    layout = attrs[data_layout_key]
    kernel_layout = attrs['kernel_layout']
    out_dtype = attrs["out_dtype"]
    if out_dtype in ("same", ""):
        out_dtype = tinfos[0].dtype

    if layout != 'NCHW':
        return None
    if dilation != (1, 1):
        logger.warning("Does not support weight pre-transform for dilated convolution.")
        return None

    # query config of this workload
    data, kernel = tinfos[0:2]
    if groups == 1:
        workload = autotvm.task.args_to_workload(
            [data, kernel, strides, padding, dilation, layout, out_dtype], conv2d)
    else:
        workload = autotvm.task.args_to_workload(
            [data, kernel, strides, padding, dilation, out_dtype], depthwise_conv2d_nchw)

    if layout == 'NCHW' and kernel_layout == 'OIHW':
        N, CI, H, W = get_const_tuple(data.shape)
        CO, _, KH, KW = get_const_tuple(kernel.shape)
    elif layout == 'NHWC' and kernel_layout == 'HWIO':
        N, H, W, CI = get_const_tuple(data.shape)
        KH, KW, _, CO = get_const_tuple(kernel.shape)
        # Also rewrite the workload, so that the NCHW config is picked up
        # later when the op is converted to NCHW layout.
        new_data = tvm.placeholder((N, CI, H, W), dtype=data.dtype)
        new_kernel = tvm.placeholder((CO, CI, KH, KW), dtype=kernel.dtype)
        new_layout = 'NCHW'
        workload = autotvm.task.args_to_workload(
            [new_data, new_kernel, strides, padding, dilation, new_layout, out_dtype], conv2d)
    elif layout == 'NHWC' and kernel_layout == 'HWOI':
        # This is the case for depthwise convolution.
        N, H, W, CI = get_const_tuple(data.shape)
        KH, KW, CO, M = get_const_tuple(kernel.shape)
        # Also rewrite the workload, so that the NCHW config is picked up
        # later when the op is converted to NCHW layout.
        new_data = tvm.placeholder((N, CI, H, W), dtype=data.dtype)
        new_kernel = tvm.placeholder((CO, M, KH, KW), dtype=kernel.dtype)
        workload = autotvm.task.args_to_workload(
            [new_data, new_kernel, strides, padding, dilation, out_dtype], depthwise_conv2d_nchw)
    else:
        return None

    idxd = tvm.indexdiv

    if groups == 1:
        # query config of this workload
        workload = autotvm.task.args_to_workload(
            [data, kernel, strides, padding, dilation, layout, out_dtype], conv2d)
        target = tvm.target.current_target()
        dispatch_ctx = autotvm.DispatchContext.current
        cfg = dispatch_ctx.query(target, workload)

        if cfg.is_fallback:  # if is fallback, clear query cache and return None
            autotvm.task.clear_fallback_cache(target, workload)
            if layout == 'NHWC' and kernel_layout == 'HWIO':
                new_attrs['data_layout'] = 'NCHW'
                new_attrs['kernel_layout'] = 'OIHW'
                return F.nn.conv2d(*copy_inputs, **new_attrs)
            return None

        if cfg.template_key == 'direct':  # pack weight tensor
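The NHWC branches above rebuild the autotvm workload in NCHW form, so that the config query hits the NCHW schedule that will actually run once the op is converted. A standalone sketch of that mirroring (the shapes and the float32 dtype are assumptions for illustration):

```python
import tvm

# NHWC data and HWIO kernel shapes as they arrive from the frontend.
N, H, W, CI = 1, 56, 56, 64
KH, KW, CO = 3, 3, 64

# Mirror them as the NCHW/OIHW placeholders the tuned schedules expect.
new_data = tvm.placeholder((N, CI, H, W), dtype='float32')
new_kernel = tvm.placeholder((CO, CI, KH, KW), dtype='float32')
assert [int(s.value) for s in new_data.shape] == [1, 64, 56, 56]
```

If no tuned NCHW config exists either, the `cfg.is_fallback` branch still returns a plain NCHW conv2d, and AlterOpLayout inserts the boundary layout transforms automatically.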
@@ -541,7 +570,8 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F):
            new_attrs['kernel_layout'] = 'OIHW%do' % VC

            # Store the same config for the altered operator (workload)
            new_data = data
            new_data = tvm.placeholder((N, CI, H, W), dtype=data.dtype)
            new_attrs[data_layout_key] = 'NCHW'
            new_kernel = tvm.placeholder((idxd(CO, VC), CI, KH, KW, VC), dtype=kernel.dtype)
            new_workload = autotvm.task.args_to_workload(
                [new_data, new_kernel, strides, padding, dilation, 'NCHW', out_dtype], conv2d)
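The `'OIHW%do' % VC` layout (e.g. `OIHW4o` for VC = 4) tiles the output channels so that a vector-width chunk sits innermost, matching the `(CO // VC, CI, KH, KW, VC)` placeholder above. A NumPy sketch of the packing, assuming this split/reorder convention:

```python
import numpy as np

CO, CI, KH, KW, VC = 8, 16, 3, 3, 4
kernel = np.arange(CO * CI * KH * KW, dtype='float32').reshape(CO, CI, KH, KW)

# OIHW -> OIHW4o: split O into (CO // VC, VC) and move the VC chunk innermost.
packed = kernel.reshape(CO // VC, VC, CI, KH, KW).transpose(0, 2, 3, 4, 1)
assert packed.shape == (CO // VC, CI, KH, KW, VC)
```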
@@ -560,7 +590,10 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F):
            tile_size = _pick_tile_size(tinfos[0], tinfos[1])
            VC = cfg['tile_bna'].val

            weight = F.nn.contrib_conv2d_winograd_weight_transform(copy_inputs[1],
            weight = copy_inputs[1]
            if kernel_layout != 'OIHW':
                weight = F.transpose(weight, axes=(2, 3, 0, 1))
            weight = F.nn.contrib_conv2d_winograd_weight_transform(weight,
                                                                   tile_size=tile_size)
            if VC > 0:
                weight = F.reshape(weight,
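`contrib_conv2d_winograd_weight_transform` precomputes U = G g Gᵀ for each (output, input) channel pair. A NumPy sketch for F(2x2, 3x3), i.e. tile_size = 2 and alpha = tile_size + 3 - 1 = 4; the (alpha, alpha, CO, CI) output layout is stated here as an assumption for illustration:

```python
import numpy as np

# Kernel-transform matrix G for Winograd F(2x2, 3x3).
G = np.array([[1.0, 0.0, 0.0],
              [0.5, 0.5, 0.5],
              [0.5, -0.5, 0.5],
              [0.0, 0.0, 1.0]], dtype='float32')

CO, CI = 4, 3
g = np.random.randn(CO, CI, 3, 3).astype('float32')  # OIHW kernel

# U[a, b, o, i] = sum over j, k of G[a, j] * g[o, i, j, k] * G[b, k]
U = np.einsum('aj,oijk,bk->aboi', G, g, G)
assert U.shape == (4, 4, CO, CI)
```

This is also why the kernel must be brought into OIHW order first, hence the `F.transpose` guard above.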
@@ -581,9 +614,10 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F):
            copy_inputs[1] = weight
            new_attrs['tile_size'] = tile_size
            new_attrs[data_layout_key] = 'NCHW'

            # Store the same config for the altered operator (workload)
            new_data = data
            new_data = tvm.placeholder((N, CI, H, W), dtype=data.dtype)
            new_workload = autotvm.task.args_to_workload(
                [new_data, new_weight, strides, padding, dilation,
                 new_attrs[data_layout_key], out_dtype, tile_size],
@@ -596,14 +630,21 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F):
            # For winograd_nnpack_fp16, the precompute-prune pass must run on a
            # device where float16 is supported.
            weight_dtype = 'float32'
            weight = copy_inputs[1]
            if kernel_layout != 'OIHW':
                weight = F.transpose(weight, axes=(2, 3, 0, 1))
            weight = F.nn.contrib_conv2d_winograd_weight_transform(weight,
                                                                   tile_size=tile_size)
            transformed_kernel = F.nn.contrib_conv2d_winograd_nnpack_weight_transform(
                copy_inputs[1],
                weight,
                convolution_algorithm=cfg['winograd_nnpack_algorithm'].val,
                out_dtype=weight_dtype)
            copy_inputs[1] = transformed_kernel

            new_data = data
            new_data = tvm.placeholder((N, CI, H, W), dtype=data.dtype)
            new_kernel = tvm.placeholder((CO, CI, 8, 8), "float32")
            bias = tvm.placeholder((CO, ), "float32")
            new_attrs[data_layout_key] = 'NCHW'
            new_workload = autotvm.task.args_to_workload(
                [new_data, new_kernel, bias, strides,
                 padding, dilation, new_attrs[data_layout_key], out_dtype]
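The `(CO, CI, 8, 8)` placeholder encodes NNPACK's transformed-kernel tile size. A one-line worked check, assuming NNPACK's F(6x6, 3x3) Winograd algorithm for 3x3 kernels:

```python
# alpha = m + r - 1 for Winograd F(m x m, r x r); NNPACK uses F(6x6, 3x3).
m, r = 6, 3
alpha = m + r - 1
assert (alpha, alpha) == (8, 8)  # hence the (CO, CI, 8, 8) placeholder
```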
@@ -617,22 +658,30 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F):
        else:
            raise RuntimeError("Unsupported template_key '%s'" % cfg.template_key)
    else:
        workload = autotvm.task.args_to_workload(
            [data, kernel, strides, padding, dilation, out_dtype], depthwise_conv2d_nchw)
        target = tvm.target.current_target()
        dispatch_ctx = autotvm.DispatchContext.current
        cfg = dispatch_ctx.query(target, workload)

        if cfg.is_fallback:  # if is fallback, clear query cache and return None
            autotvm.task.clear_fallback_cache(tvm.target.current_target(), workload)
            if layout == 'NHWC' and kernel_layout == 'HWOI':
                new_attrs['data_layout'] = 'NCHW'
                new_attrs['kernel_layout'] = 'OIHW'
                return F.nn.conv2d(*copy_inputs, **new_attrs)
            return None

        if cfg.template_key == 'contrib_spatial_pack':
            VC = cfg['tile_co'].size[-1]
            new_attrs['kernel_layout'] = 'OIHW%do' % (cfg['tile_co'].size[-1])

            # Store the same config for the altered operator (workload)
            new_data = data
            new_data = tvm.placeholder((N, CI, H, W), dtype=data.dtype)
            new_attrs[data_layout_key] = 'NCHW'
            if attrs['kernel_layout'] == 'OIHW':
                CO, M, KH, KW = get_const_tuple(kernel.shape)
            elif attrs['kernel_layout'] == 'HWOI':
                KH, KW, CO, M = get_const_tuple(kernel.shape)
            else:
                raise RuntimeError("Depthwise conv2d should have either OIHW or HWOI kernel layout")
            new_kernel = tvm.placeholder((idxd(CO, VC), M, KH, KW, VC), dtype=kernel.dtype)
            new_workload = autotvm.task.args_to_workload(
                [new_data, new_kernel, strides, padding, dilation, out_dtype],
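For depthwise kernels, the HWOI axes map to (KH, KW, CO, M) with M the channel multiplier, so transposing with axes (2, 3, 0, 1) yields the (CO, M, KH, KW) 'OIHW'-style form the branch above extracts. A NumPy sketch, with the VC tiling from the placeholder shown for a hypothetical VC = 8:

```python
import numpy as np

KH, KW, CO, M = 3, 3, 64, 1
w_hwoi = np.empty((KH, KW, CO, M), 'float32')  # depthwise HWOI kernel
w_oihw = w_hwoi.transpose(2, 3, 0, 1)          # -> (CO, M, KH, KW)
assert w_oihw.shape == (CO, M, KH, KW)

# contrib_spatial_pack then tiles CO, as in the new_kernel placeholder above.
VC = 8
assert (CO // VC, M, KH, KW, VC) == (8, 1, 3, 3, 8)
```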
@@ -644,48 +693,3 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F):
        # currently we only have contrib_spatial_pack and direct template;
        # add more schedule templates.

    return None

@conv2d_legalize.register("arm_cpu")
def _conv2d_legalize(attrs, inputs, arg_types):
    """Legalizes Conv2D op.

    Parameters
    ----------
    attrs : tvm.attrs.Attrs
        Attributes of current convolution
    inputs : list of tvm.relay.Expr
        The args of the Relay expr to be legalized
    arg_types : list of types
        List of input and output types

    Returns
    -------
    result : tvm.relay.Expr
        The legalized expr
    """
    if attrs['data_layout'] == 'NHWC':
        data, kernel = inputs
        if attrs['kernel_layout'] == 'HWIO':
            # Handle HWIO layout. This is common in TF graph.
            kernel = relay.transpose(kernel, axes=(3, 2, 0, 1))
        elif attrs['kernel_layout'] == 'HWOI':
            # Handle HWOI layout. This is common in TF depthwise conv2d graph.
            kernel = relay.transpose(kernel, axes=(2, 3, 0, 1))
        elif attrs['kernel_layout'] != 'OIHW':
            return None

        logger.warning("Legalize arm_cpu - NHWC schedule absent. Inserting layout transforms to "
                       + "fall back to NCHW. This can result in performance degradation.")

        # Set new attrs for the transposed conv.
        new_attrs = {k: attrs[k] for k in attrs.keys()}
        new_attrs['data_layout'] = 'NCHW'
        new_attrs['kernel_layout'] = 'OIHW'

        # Convert from NHWC to NCHW.
        data = relay.transpose(data, axes=(0, 3, 1, 2))
        conv = relay.nn.conv2d(data, kernel, **new_attrs)

        # Convert back to the original NHWC layout.
        out = relay.transpose(conv, axes=(0, 2, 3, 1))
        return out

    return None
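As a numerical sanity check of the (now removed) legalization pattern, here is a NumPy sketch using a 1x1 convolution; the 1x1 case is chosen so the conv reduces to a channel contraction and the example stays short. It is an illustrative stand-in, not the TVM kernel:

```python
import numpy as np

data = np.random.rand(1, 8, 8, 16).astype('float32')     # NHWC
kernel = np.random.rand(1, 1, 16, 32).astype('float32')  # HWIO, 1x1

# Direct NHWC 1x1 conv: contract the input-channel axis.
ref = np.tensordot(data, kernel[0, 0], axes=([3], [0]))  # NHWC output

# Legalized path: transpose in, convolve in NCHW/OIHW, transpose out.
d = data.transpose(0, 3, 1, 2)                  # NHWC -> NCHW
k = kernel.transpose(3, 2, 0, 1)                # HWIO -> OIHW
out = np.einsum('nchw,oc->nohw', d, k[:, :, 0, 0])
assert np.allclose(ref, out.transpose(0, 2, 3, 1), atol=1e-5)
```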