Commit 5498e54d by Animesh Jain Committed by Yao Wang

[Relay][Legalize][ARM_CPU] Handling NHWC layout for arm_cpu. (#3754)

parent aee16d87
...@@ -204,10 +204,11 @@ def alter_op_layout_conv2d(attrs, inputs, tinfos): ...@@ -204,10 +204,11 @@ def alter_op_layout_conv2d(attrs, inputs, tinfos):
from ... import op from ... import op
return topi.nn.conv2d_alter_layout(attrs, inputs, tinfos, op) return topi.nn.conv2d_alter_layout(attrs, inputs, tinfos, op)
# A placeholder to have at least one invocation of register legalize to register FTVMLegalize.
@reg.register_legalize("nn.conv2d") @reg.register_legalize("nn.conv2d")
def legalize_conv2d(attrs, inputs, arg_dtypes): def legalize_conv2d(attrs, inputs, arg_dtypes):
return None """Legalize conv2d"""
from ... import op
return topi.nn.conv2d_legalize(attrs, inputs, arg_dtypes, op)
reg.register_pattern("nn.conv2d", OpPattern.OUT_ELEMWISE_FUSABLE) reg.register_pattern("nn.conv2d", OpPattern.OUT_ELEMWISE_FUSABLE)
......
...@@ -15,9 +15,11 @@ ...@@ -15,9 +15,11 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
"""Test legalize pass""" """Test legalize pass"""
import numpy as np
import tvm import tvm
from tvm import relay from tvm import relay
from tvm.contrib import graph_runtime
from tvm.relay.op import register_legalize from tvm.relay.op import register_legalize
from tvm.relay import transform, analysis from tvm.relay import transform, analysis
...@@ -123,8 +125,52 @@ def test_legalize_multi_input(): ...@@ -123,8 +125,52 @@ def test_legalize_multi_input():
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a) assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
def test_legalize_arm_layout_functional():
"""Test if the legalized conversion yields same result as original"""
def get_output(func, data_val, parameters):
with relay.build_config(opt_level=0):
graph, lib, params = relay.build(func, target='llvm', params=parameters)
m = graph_runtime.create(graph, lib, tvm.cpu())
m.set_input("data", data_val)
m.set_input(**params)
m.run()
out = m.get_output(0, tvm.nd.empty((1, 224, 224, 32), 'float32')).asnumpy()
return out
def before():
n, ic, ih, iw, oc, kh, kw = 1, 16, 224, 224, 32, 3, 3
data = relay.var("data", relay.TensorType((n, ih, iw, ic), 'float32'))
kernel = relay.var("kernel", relay.TensorType((kh, kw, ic, oc), 'float32'))
y = relay.nn.conv2d(data, kernel,
kernel_size=(kh, kw),
channels=oc,
padding=(1, 1),
dilation=(1, 1),
data_layout='NHWC',
kernel_layout='HWIO',
out_dtype='float32')
func = relay.Function([data, kernel], y)
return func
@register_legalize("nn.conv2d", level=101)
def legalize_conv2d(attrs, inputs, arg_types):
from topi.arm_cpu.conv2d import _conv2d_legalize
return _conv2d_legalize(attrs, inputs, arg_types, tvm.relay.op)
a = before()
b = run_opt_pass(a, transform.Legalize())
assert b.astext().count('transpose') == 3
wdata = np.random.rand(3, 3, 16, 32) * 10
parameters = {"kernel": tvm.nd.array(wdata.astype('float32'))}
data_val = np.random.rand(1, 224, 224, 16).astype('float32')
ref_out = get_output(a, data_val, parameters)
legalized_out = get_output(b, data_val, parameters)
np.testing.assert_allclose(ref_out, legalized_out, rtol=0.01)
if __name__ == "__main__": if __name__ == "__main__":
test_legalize() test_legalize()
test_legalize_none() test_legalize_none()
test_legalize_multi_input() test_legalize_multi_input()
test_legalize_arm_layout_functional()
...@@ -31,6 +31,7 @@ from ..nn import dilate, pad, conv2d, conv2d_alter_layout, \ ...@@ -31,6 +31,7 @@ from ..nn import dilate, pad, conv2d, conv2d_alter_layout, \
conv2d_winograd_without_weight_transform, \ conv2d_winograd_without_weight_transform, \
conv2d_winograd_nnpack_without_weight_transform, \ conv2d_winograd_nnpack_without_weight_transform, \
depthwise_conv2d_nchw depthwise_conv2d_nchw
from ..nn import conv2d_legalize
from ..nn.util import get_const_int, get_pad_tuple from ..nn.util import get_const_int, get_pad_tuple
from ..nn.winograd_util import winograd_transform_matrices from ..nn.winograd_util import winograd_transform_matrices
...@@ -783,3 +784,33 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F): ...@@ -783,3 +784,33 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F):
# currently we only have contrib_spatial_pack and direct template # currently we only have contrib_spatial_pack and direct template
# add more schedule templates. # add more schedule templates.
return None return None
@conv2d_legalize.register("arm_cpu")
def _conv2d_legalize(attrs, inputs, arg_types, F):
if F.__name__ != 'tvm.relay.op':
return None
if attrs['data_layout'] == 'NHWC':
data, kernel = inputs
if attrs['kernel_layout'] == 'HWIO':
# Handle HWIO layout. This is common in TF graph.
kernel = F.transpose(kernel, axes=(3, 2, 0, 1))
elif attrs['kernel_layout'] == 'HWOI':
# Handle HWOI layout. This is common in TF depthwise conv2d graph.
kernel = F.transpose(kernel, axes=(2, 3, 0, 1))
elif attrs['kernel_layout'] != 'OIHW':
return None
warnings.warn("Legalize arm_cpu - NHWC schedule absent. Inserting layout transforms to "
+ "fallback to NCHW. This can result in performance degradation.")
# Set new attrs for the tranposed conv.
new_attrs = {k: attrs[k] for k in attrs.keys()}
new_attrs['data_layout'] = 'NCHW'
new_attrs['kernel_layout'] = 'OIHW'
# Convert from NHWC to NCHW.
data = F.transpose(data, axes=(0, 3, 1, 2))
conv = F.nn.conv2d(data, kernel, **new_attrs)
# Convert back to original NHWC layout.
out = F.transpose(conv, axes=(0, 2, 3, 1))
return out
return None
...@@ -72,6 +72,28 @@ def conv2d(input, filter, strides, padding, dilation, layout='NCHW', out_dtype=N ...@@ -72,6 +72,28 @@ def conv2d(input, filter, strides, padding, dilation, layout='NCHW', out_dtype=N
@tvm.target.generic_func @tvm.target.generic_func
def conv2d_legalize(attrs, inputs, arg_dtypes, F):
"""Legalizes Conv2D op.
Parameters
----------
attrs : nnvm.top.AttrDict or tvm.attrs.Attrs
Attributes of current convolution
inputs : list of tvm.relay.Expr
The args of the Relay expr to be legalized.
arg_dtypes : list of types
List of types of input arguments
F: symbol
The context, can be either nnvm.sym or relay.op
Note
----
Unlike other TOPI functions, this function operates on both graph level and operator level,
so we have to pass 'F' to make it support our two versions of graph IR, NNVM and Relay.
"""
# not to change by default
return None
@tvm.target.generic_func
def conv2d_alter_layout(attrs, inputs, tinfos, F): def conv2d_alter_layout(attrs, inputs, tinfos, F):
"""Change Conv2D layout. """Change Conv2D layout.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment