Unverified Commit 575d5369 by Huacong Yang Committed by GitHub

[RELAY][FRONTEND][CAFFE2] add Mul and ConvTranspose operator (#5302)

parent f506c8b1
......@@ -172,6 +172,12 @@ class Add(Elemwise):
name = 'add'
class Mul(Elemwise):
""" Operator converter for Mul.
"""
name = 'multiply'
class Pool(Caffe2OpConverter):
""" A helper class for pool op converters.
"""
......@@ -233,6 +239,33 @@ class Conv(Caffe2OpConverter):
return out
class ConvTranspose(Caffe2OpConverter):
""" Operator converter for ConvTranspose.
"""
@classmethod
def _impl(cls, inputs, args, params):
# get number of channels
channels = infer_channels(inputs[1], True)
args['channels'] = channels
_clean_up_pool_args(args)
out = AttrCvt(
op_name=dimension_picker('conv', '_transpose'),
transforms={
'kernel_shape': 'kernel_size',
'pads': ('padding', (0, 0), revert_caffe2_pad),
'dilations': ('dilation', (1, 1)),
'order': ('data_layout', ("NCHW"), lambda x: x if isinstance(x, str) else x.decode('UTF-8')),
},
excludes=[],
ignores=_caffe2_internal_args,
custom_check=dimension_constraint())(inputs[:2], args, params)
use_bias = len(inputs) == 3
if use_bias:
out = _op.nn.bias_add(out, inputs[2])
return out
class Concat(Caffe2OpConverter):
""" Operator converter for Concat.
"""
......@@ -353,12 +386,14 @@ def _get_convert_map():
# caffe2 common operators
'Add': Add.get_converter(),
'Sum': Sum.get_converter(),
'Mul': Mul.get_converter(),
'Softmax': Softmax.get_converter(),
# nn
'AveragePool': AveragePool.get_converter(),
'MaxPool': MaxPool.get_converter(),
'Conv': Conv.get_converter(),
'ConvTranspose': ConvTranspose.get_converter(),
'Concat': Concat.get_converter(),
'FC': FC.get_converter(),
'SpatialBN': SpatialBN.get_converter(),
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment