Commit 891c4117 by Siva, committed by Tianqi Chen

[FRONTEND][ONNX] Some bug fixes and Shape operator fixed for relay. (#2850)

* [FRONTEND][ONNX] Some bug fixes and Shape operator fixed for relay.

* test cases

* ci error
parent b590c4f2
@@ -321,6 +321,10 @@ class AttrCvt(object):
         else:
             assert callable(self._op_name), "op_name can either be string or callable"
             op_name = self._op_name(attrs)
+
+        # ignore 'tvm_custom' always
+        self._ignores.append('tvm_custom')
+
         # convert attributes
         new_attrs = {}
         for k in attrs.keys():
@@ -329,7 +333,8 @@ class AttrCvt(object):
             elif k in self._disables:
                 logging.warning("Attribute %s is disabled in relay.sym.%s", k, op_name)
             elif k in self._ignores:
-                logging.debug("Attribute %s is ignored in relay.sym.%s", k, op_name)
+                if k != 'tvm_custom':
+                    logging.warning("Attribute %s is ignored in relay.sym.%s", k, op_name)
             elif k in self._transforms:
                 new_name, defaults, transform = self._parse_default(self._transforms[k])
                 if defaults is None:
@@ -416,4 +421,6 @@ class Renamer(object):
         self._new_name = new_name

     def __call__(self, inputs, attrs, *args):
+        if 'tvm_custom' in attrs:
+            attrs.pop('tvm_custom')
         return get_relay_op(self._new_name)(*inputs, **attrs)
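The two hunks above work together with the importer change further down: each ONNX node's attribute dict now carries a 'tvm_custom' entry holding importer metadata (the node name), so AttrCvt registers it as always-ignored and Renamer pops it before forwarding the remaining attributes as keyword arguments. A minimal sketch of why the pop is needed; the attrs dict and the relu op here are purely illustrative, not taken from the patch:

from tvm import relay

# Attribute dict as a converter would see it after the importer tags the node.
# 'tvm_custom' is importer metadata, not an argument of any relay op.
attrs = {'tvm_custom': {'name': 'relu_0'}}

x = relay.var("x", shape=(1, 4), dtype="float32")
try:
    # Forwarding the raw attrs would fail: relay.nn.relu() takes no such kwarg.
    relay.nn.relu(x, **attrs)
except TypeError as err:
    print(err)

attrs.pop('tvm_custom')          # what Renamer.__call__ now does
out = relay.nn.relu(x, **attrs)  # succeeds with the cleaned attributes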
@@ -106,7 +106,7 @@ class Pool(OnnxOpConverter):
                 'pads': ('padding', (0, 0), revert_caffe2_pad)
             },
             # very weird attributes here in onnx, force check
-            ignores=['dilations'],
+            ignores=['dilations', 'auto_pad'],
             # TODO(zhreshold): make sure ceil_mode in onnx, and layout?
             extras={'ceil_mode': False},
             custom_check=dimension_constraint())(inputs, attr, params)
@@ -160,6 +160,7 @@ class Conv(OnnxOpConverter):
                 'dilations': ('dilation', (0, 0)),
                 'pads': ('padding', (0, 0), revert_caffe2_pad),
                 'group': ('groups', 1)},
+            ignores=['auto_pad'],
             custom_check=dimension_constraint())(inputs[:2], attr, params)
         use_bias = len(inputs) == 3
         if use_bias:
@@ -332,7 +333,21 @@ class Reshape(OnnxOpConverter):
             shape = tuple(params[inputs[1].name_hint].asnumpy())
             out = _op.reshape(inputs[0], shape)
         else:
-            out = _op.reshape_like(inputs[0], inputs[1])
+            # Try to infer shape by precompute prune if possible.
+            # TODO: good to check inputs to be in params.
+            #       to be enhanced when relay support list_input_names API of NNVM
+            logging.warning("Infering Reshape argument by precompute")
+            func = _expr.Function(ir_pass.free_vars(inputs[1]), inputs[1])
+            with tvm.relay.build_config(opt_level=0):
+                graph, lib, params = tvm.relay.build(func, target="llvm", params=params)
+            ctx = tvm.context("llvm", 0)
+            from tvm.contrib import graph_runtime
+            m = graph_runtime.create(graph, lib, ctx)
+            m.set_input(**params)
+            m.run()
+            params_new = m.get_output(0)
+            inputs.pop(1)
+            out = _op.reshape(inputs[0], tuple(params_new.asnumpy().astype('int32').flatten()))

         return out
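When the target shape is not available in params, the new else branch evaluates the expression feeding Reshape's second input once, at opt_level 0, and passes the resulting values to a static _op.reshape. A standalone sketch of that precompute pattern, assuming the relay/graph_runtime API of the tree this patch targets (a constant stands in for the shape-producing subexpression):

import numpy as np
import tvm
from tvm import relay
from tvm.contrib import graph_runtime

# Stand-in for the expression feeding Reshape's second input.
shape_expr = relay.const(np.array([6, 2, 4, 3], dtype="int32"))

# Wrap it in a function over its free variables and build it unoptimized.
func = relay.Function(relay.ir_pass.free_vars(shape_expr), shape_expr)
with relay.build_config(opt_level=0):
    graph, lib, params = relay.build(func, target="llvm")

# Run it once on CPU and read the result back as a static shape tuple.
m = graph_runtime.create(graph, lib, tvm.context("llvm", 0))
if params:
    m.set_input(**params)
m.run()
target_shape = tuple(m.get_output(0).asnumpy().astype('int32').flatten())
print(target_shape)  # (6, 2, 4, 3), usable as the static argument to _op.reshape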
@@ -477,10 +492,7 @@ class Shape(OnnxOpConverter):
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
-        # Result of this operator is prominently used by reshape operator.
-        # Just pass the input as it is so that reshape_like can be used there.
-        logging.warning("Shape: Differently implemented in relay as a bypass (dummy operator)")
-        return inputs[0]
+        return _op.shape_of(inputs[0])

 class Cast(OnnxOpConverter):
     """ Operator converter for Cast.
@@ -494,7 +506,7 @@ class Cast(OnnxOpConverter):
     def _impl_v5(cls, inputs, attr, params):
         try:
             from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE
-            attr['to'] = TENSOR_TYPE_TO_NP_TYPE[attr['to']]
+            attr['to'] = str(TENSOR_TYPE_TO_NP_TYPE[attr['to']])
         except ImportError as e:
             raise ImportError(
                 "Unable to import onnx.mapping which is required {}".format(e))
@@ -674,6 +686,11 @@ class ReduceMean(Reduce):
     """
     name = 'mean'

+class ReduceProd(Reduce):
+    """ Operator converter for ReduceProd.
+    """
+    name = 'prod'
+
 class ArgMax(OnnxOpConverter):
     """ Operator converter for ArgMax.
     """
@@ -826,6 +843,7 @@ def _get_convert_map(opset):
         'ReduceMin': ReduceMin.get_converter(opset),
         'ReduceSum': ReduceSum.get_converter(opset),
         'ReduceMean': ReduceMean.get_converter(opset),
+        'ReduceProd': ReduceProd.get_converter(opset),
         # 'ReduceProd'
         # 'ReduceLogSumExp'
         'ArgMax': ArgMax.get_converter(opset),
@@ -842,8 +860,7 @@ def _get_convert_map(opset):
         'Squeeze': AttrCvt('squeeze', {'axes': 'axis'}),
         'Unsqueeze': Unsqueeze.get_converter(opset),
         'Pad': Pad.get_converter(opset),
-        # TODO(zhreshold) Shape op is implemented as bypass op in relay
-        # 'Shape': Shape.get_converter(opset),
+        'Shape': Shape.get_converter(opset),
     }
@@ -883,6 +900,7 @@ class GraphProto(object):
         ----------
         graph : onnx protobuf object
             The loaded onnx graph
+        opset : opset version

         Returns
@@ -911,12 +929,12 @@ class GraphProto(object):
                                               dtype=self._params[i_name].dtype)
             else:
                 self._num_input += 1
-                shape = self._shape[i_name] if i_name in self._shape else ()
+                tshape = self._shape[i_name] if i_name in self._shape else ()
                 if isinstance(self._dtype, dict):
                     dtype = self._dtype[i_name] if i_name in self._dtype else d_type
                 else:
                     dtype = d_type
-                self._nodes[i_name] = new_var(i_name, shape=shape, dtype=dtype)
+                self._nodes[i_name] = new_var(i_name, shape=tshape, dtype=dtype)
         # construct nodes, nodes are stored as directed acyclic graph
         for node in graph.node:
             op_name = node.op_type
@@ -936,6 +954,10 @@ class GraphProto(object):
                     self._nodes[i_name] = new_var(node.output[0], shape=(), dtype=dtype)
                     inputs.append(self._nodes[i_name])
+
+            i_name = self._parse_value_proto(node)
+            attr['tvm_custom'] = {}
+            attr['tvm_custom']['name'] = i_name
+
             op = self._convert_operator(op_name, inputs, attr, opset)
             node_output = self._fix_outputs(op_name, node.output)
             if not isinstance(op, _expr.TupleWrapper):
......
@@ -113,35 +113,36 @@ def test_reshape():
         tvm.testing.assert_allclose(ref_shape, tvm_out.shape)

-def test_reshape_like():
+def test_shape():
     in_shape = (4, 3, 3, 4)
-    ref_shape = (3, 4, 4, 3)
-    ref_array = np.random.uniform(size=ref_shape).astype('float32')
+    ref_shape = (6, 2, 4, 3)
+    ref_array = np.array(ref_shape)
     ref_node = onnx.helper.make_node('Constant',
                                      inputs=[],
                                      outputs=['ref_in'],
                                      value=onnx.helper.make_tensor(name = 'const_tensor',
-                                                                   data_type = onnx.TensorProto.FLOAT,
+                                                                   data_type = onnx.TensorProto.INT32,
                                                                    dims = ref_array.shape,
-                                                                   vals = ref_array.flatten().astype(float)))
-    copy_node = helper.make_node("Identity", ["ref_in"], ["copy_in"])
-    reshape_node = helper.make_node("Reshape", ["in", "copy_in"], ["out"])
+                                                                   vals = ref_array.flatten().astype(int)))
+    reshape_node = helper.make_node("Reshape", ["in", "ref_in"], ["out"])
+    shape_node = helper.make_node("Shape", ['out'], ['final_out'])

-    graph = helper.make_graph([ref_node, copy_node, reshape_node],
-                              "reshape_like_test",
+    graph = helper.make_graph([ref_node, reshape_node, shape_node],
+                              "shape_test",
                               inputs = [helper.make_tensor_value_info("in",
                                             TensorProto.FLOAT, list(in_shape))],
-                              outputs = [helper.make_tensor_value_info("out",
+                              outputs = [helper.make_tensor_value_info("final_out",
                                             TensorProto.FLOAT, list(ref_shape))])

-    model = helper.make_model(graph, producer_name='reshape_like_test')
+    model = helper.make_model(graph, producer_name='shape_test')

     for target, ctx in ctx_list():
-        x = np.random.uniform(size=in_shape).astype('float32')
-        tvm_out = get_tvm_output(model, x, target, ctx, ref_shape, 'float32')
-        tvm.testing.assert_allclose(ref_shape, tvm_out.shape)
+        x = np.random.uniform(size=in_shape).astype('int32')
+        tvm_out = get_tvm_output(model, x, target, ctx, ref_shape, 'int32')
+        tvm.testing.assert_allclose(ref_shape, tvm_out)

 def _test_power_iteration(x_shape, y_shape):
     if isinstance(y_shape, int):
@@ -995,7 +996,7 @@ def test_LogSoftmax():
 if __name__ == '__main__':
     test_reshape()
-    test_reshape_like()
+    test_shape()
     test_power()
     test_squeeze()
     test_unsqueeze()
......