Commit 97aaadeb by masahi Committed by Tianqi Chen

fix onnx conv2d_transpose loading (#245)

parent 51a25982
...@@ -64,7 +64,8 @@ def _elemwise(name): ...@@ -64,7 +64,8 @@ def _elemwise(name):
assert len(inputs) == 2, "Math op take 2 inputs, {} given".format(len(inputs)) assert len(inputs) == 2, "Math op take 2 inputs, {} given".format(len(inputs))
op_name = _math_name_picker(name)(attr) op_name = _math_name_picker(name)(attr)
axis = int(attr.get('axis', 0)) axis = int(attr.get('axis', 0))
if op_name == 'broadcast_add' and inputs[0].attr('op_name') == 'conv2d': conv_ops = ["conv2d", "conv2d_transpose"]
if op_name == 'broadcast_add' and inputs[0].attr('op_name') in conv_ops:
# TODO(zhreshold): remove hard coded infershape # TODO(zhreshold): remove hard coded infershape
inputs[1] = _sym.expand_dims(inputs[1], axis=axis, num_newaxis=2) inputs[1] = _sym.expand_dims(inputs[1], axis=axis, num_newaxis=2)
return get_nnvm_op(op_name)(*inputs) return get_nnvm_op(op_name)(*inputs)
...@@ -101,8 +102,10 @@ def _conv(): ...@@ -101,8 +102,10 @@ def _conv():
def _conv_transpose(): def _conv_transpose():
def _impl(inputs, attr, params): def _impl(inputs, attr, params):
# get number of channels # get number of channels
channels = _infer_channels(inputs[1], params) channels = _infer_channels(inputs[1], params, True)
attr['channels'] = channels attr['channels'] = channels
groups = attr.pop('group')
attr['groups'] = groups
return AttrCvt( return AttrCvt(
op_name=_dimension_picker('conv', '_transpose'), op_name=_dimension_picker('conv', '_transpose'),
transforms={ transforms={
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment