Commit 2a7f7548 by Balint Cristian Committed by Tianqi Chen

Additional fix for PR#2972 (#3044)

parent 4ab97dfa
......@@ -700,7 +700,7 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F):
new_attrs = {k: attrs[k] for k in attrs.keys()}
if F == tvm.relay.op:
if F.__name__ == 'tvm.relay.op':
# Derive channels for frontends (e.g ONNX) that miss "channel" field.
new_attrs["channels"] = inputs[1].checked_type.shape[attrs['kernel_layout'].index('O')]
......
......@@ -371,7 +371,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, F):
copy_inputs = [s for s in inputs]
new_attrs = {k: attrs[k] for k in attrs.keys()}
if F == tvm.relay.op:
if F.__name__ == 'tvm.relay.op':
# Derive channels for frontends (e.g ONNX) that miss "channel" field.
new_attrs["channels"] = inputs[1].checked_type.shape[attrs['kernel_layout'].index('O')]
......
......@@ -54,7 +54,6 @@ def tile_and_bind3d(s, tensor, z, y, x, z_factor=2, y_factor=None, x_factor=None
@conv2d_alter_layout.register(["intel_graphics"])
def _alter_conv2d_layout(attrs, inputs, tinfos, F):
import nnvm.symbol as sym
copy_inputs = [s for s in inputs]
......@@ -75,11 +74,11 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, F):
new_attrs = {k: attrs[k] for k in attrs.keys()}
new_attrs["kernel_layout"] = 'OIHW%do' % (oc_bn)
if F == tvm.relay.op:
if F.__name__ == 'tvm.relay.op':
# Derive channels for frontends (e.g ONNX) that miss "channel" field.
new_attrs["channels"] = inputs[1].checked_type.shape[attrs['kernel_layout'].index('O')]
if F == sym:
if F.__name__ == 'nnvm.symbol':
out = F.contrib.conv2d_NCHWc(*copy_inputs, **new_attrs)
else:
out = F.nn.contrib_conv2d_nchwc(*copy_inputs, **new_attrs)
......
......@@ -323,12 +323,11 @@ def _topi_nn_conv2d_NCHWc(*args, **kwargs):
@conv2d_alter_layout.register("cpu")
def _alter_conv2d_layout(attrs, inputs, tinfo, F):
import nnvm.symbol as sym
copy_inputs = [s for s in inputs]
new_attrs = {k : attrs[k] for k in attrs.keys()}
if F == tvm.relay.op:
if F.__name__ == 'tvm.relay.op':
# Derive channels for frontends (e.g ONNX) that miss "channel" field.
new_attrs["channels"] = inputs[1].checked_type.shape[attrs['kernel_layout'].index('O')]
......@@ -336,13 +335,14 @@ def _alter_conv2d_layout(attrs, inputs, tinfo, F):
batch_size, in_channel, height, width = get_const_tuple(data.shape)
groups = attrs.get_int("groups")
out_channel = attrs.get_int("channels") if F == sym else new_attrs["channels"]
out_channel = attrs.get_int("channels") \
if F.__name__ == 'nnvm.symbol' else new_attrs["channels"]
padding = attrs.get_int_tuple("padding")
strides = attrs.get_int_tuple("strides")
dilation = attrs.get_int_tuple("dilation")
out_dtype = attrs["out_dtype"]
layout_name = 'layout' if F == sym else 'data_layout'
layout_name = 'layout' if F.__name__ == 'nnvm.symbol' else 'data_layout'
layout = attrs[layout_name]
kh, kw = attrs.get_int_tuple("kernel_size")
......@@ -399,12 +399,12 @@ def _alter_conv2d_layout(attrs, inputs, tinfo, F):
dispatch_ctx.update(target, new_workload, cfg)
if is_depthwise:
if F == sym:
if F.__name__ == 'nnvm.symbol':
logging.warning("Use native layout for depthwise convolution on NNVM.")
return None
return F.nn.contrib_depthwise_conv2d_nchwc(*copy_inputs, **new_attrs)
else:
if F == sym:
if F.__name__ == 'nnvm.symbol':
return F.contrib.conv2d_NCHWc(*copy_inputs, **new_attrs)
return F.nn.contrib_conv2d_nchwc(*copy_inputs, **new_attrs)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment