Commit be8fa6ac by Zhi Committed by Thierry Moreau

[relay][frontend] clean up tf frontend (#3710)

* clean up tf frontend

* fix get_relay_op
parent 8d5de5ed
...@@ -17,6 +17,8 @@ ...@@ -17,6 +17,8 @@
"""Common utilities""" """Common utilities"""
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
import logging import logging
import tvm
from topi.util import get_const_tuple from topi.util import get_const_tuple
from .. import expr as _expr from .. import expr as _expr
from .. import module as _module from .. import module as _module
...@@ -224,6 +226,7 @@ class StrAttrsDict(object): ...@@ -224,6 +226,7 @@ class StrAttrsDict(object):
raise AttributeError("Required attribute {} not found.".format(key)) raise AttributeError("Required attribute {} not found.".format(key))
return default return default
def get_relay_op(op_name): def get_relay_op(op_name):
"""Get the callable function from Relay based on operator name. """Get the callable function from Relay based on operator name.
Parameters Parameters
...@@ -246,9 +249,10 @@ def get_relay_op(op_name): ...@@ -246,9 +249,10 @@ def get_relay_op(op_name):
if op is not None: if op is not None:
break break
if not op: if not op:
raise RuntimeError("Unable to map op_name {} to relay".format(op_name)) raise tvm.error.OpNotImplemented("Unable to map op_name {} to relay".format(op_name))
return op return op
class ExprTable(object): class ExprTable(object):
"""Table storing Relay expressions by names.""" """Table storing Relay expressions by names."""
def __init__(self): def __init__(self):
...@@ -298,21 +302,27 @@ class AttrCvt(object): ...@@ -298,21 +302,27 @@ class AttrCvt(object):
If set as str, returned operator name is the str. If set as str, returned operator name is the str.
If set as callable, returned operator is the str returned by calling: If set as callable, returned operator is the str returned by calling:
`op_name = func(attr)` `op_name = func(attr)`
transforms : dict of `new_name, or (new_name, default_value, transform function)` transforms : dict of `new_name, or (new_name, default_value, transform function)`
If only a new_name is provided, it's like renaming the attribute name. If only a new_name is provided, it's like renaming the attribute name.
If default_value if provided, then the attribute is considered as optional. If default_value if provided, then the attribute is considered as optional.
If transform function is provided, the original attribute value is handled If transform function is provided, the original attribute value is handled
by transform function. by transform function.
excludes : list excludes : list
A list of excluded attributes that should `NOT` appear. A list of excluded attributes that should `NOT` appear.
Raise NotImplementedError if occurred. Raise NotImplementedError if occurred.
disables : list disables : list
A list of attributes that is disabled in relay. Log warnings. A list of attributes that is disabled in relay. Log warnings.
ignores : list ignores : list
A list of attributes that is ignored in relay. Debug level logging. A list of attributes that is ignored in relay. Debug level logging.
extras : dict extras : dict
A series of additional attributes should be added anyway to the returned A series of additional attributes should be added anyway to the returned
attribute dict. attribute dict.
custom_check : callable custom_check : callable
A custom function takes attribute, and return True/False. A custom function takes attribute, and return True/False.
Raise RuntimeError if not bool(True) returned. Raise RuntimeError if not bool(True) returned.
...@@ -329,6 +339,14 @@ class AttrCvt(object): ...@@ -329,6 +339,14 @@ class AttrCvt(object):
self._custom_check = custom_check self._custom_check = custom_check
def __call__(self, inputs, attrs, *args): def __call__(self, inputs, attrs, *args):
self._ignores.append('_output_shapes')
self._ignores.append('_input_shapes')
self._ignores.append('T')
self._ignores.append('use_cudnn_on_gpu')
self._ignores.append('_node_name')
self._ignores.append('is_training')
self._ignores.append('_target_layout')
# apply custom check # apply custom check
if self._custom_check: if self._custom_check:
func, msg = self._custom_check func, msg = self._custom_check
...@@ -348,7 +366,8 @@ class AttrCvt(object): ...@@ -348,7 +366,8 @@ class AttrCvt(object):
new_attrs = {} new_attrs = {}
for k in attrs.keys(): for k in attrs.keys():
if k in self._excludes: if k in self._excludes:
raise NotImplementedError("Attribute {} not supported yet.".format(k)) raise NotImplementedError('Attribute %s in operator %s is not' +
' supported.', k, op_name)
elif k in self._disables: elif k in self._disables:
logging.warning("Attribute %s is disabled in relay.sym.%s", k, op_name) logging.warning("Attribute %s is disabled in relay.sym.%s", k, op_name)
elif k in self._ignores: elif k in self._ignores:
...@@ -401,6 +420,7 @@ class AttrCvt(object): ...@@ -401,6 +420,7 @@ class AttrCvt(object):
raise AttributeError("Required attribute {} not found.".format(key)) raise AttributeError("Required attribute {} not found.".format(key))
return attr[key] return attr[key]
def get_name(node): def get_name(node):
name = '' name = ''
if hasattr(node, "name_hint"): if hasattr(node, "name_hint"):
...@@ -410,17 +430,19 @@ def get_name(node): ...@@ -410,17 +430,19 @@ def get_name(node):
def infer_type(node): def infer_type(node):
"""A method to infer the type of an intermediate node in the relay graph.""" """A method to infer the type of an intermediate node in the relay graph."""
mod = _module.Module.from_expr(node) mod = node if isinstance(node, _module.Module) else _module.Module.from_expr(node)
mod = _transform.InferType()(mod) mod = _transform.InferType()(mod)
entry = mod["main"] entry = mod["main"]
return entry if isinstance(node, _expr.Function) else entry.body return entry if isinstance(node, _expr.Function) else entry.body
def infer_shape(inputs): def infer_shape(inputs):
"""A method to get the output shape of an intermediate node in the graph.""" """A method to get the output shape of an intermediate node in the graph."""
out_type = infer_type(inputs) out_type = infer_type(inputs)
out_shapes = get_const_tuple(out_type.checked_type.shape) out_shapes = get_const_tuple(out_type.checked_type.shape)
return out_shapes return out_shapes
def infer_channels(inputs, transpose=False): def infer_channels(inputs, transpose=False):
"""A hack for getting 'channels' or 'units' since caffe2 does not provide """A hack for getting 'channels' or 'units' since caffe2 does not provide
these attributes. We check the shape of weights provided to get the number. these attributes. We check the shape of weights provided to get the number.
...@@ -430,12 +452,14 @@ def infer_channels(inputs, transpose=False): ...@@ -430,12 +452,14 @@ def infer_channels(inputs, transpose=False):
channels = out_shapes[0][0] if not transpose else out_shapes[0][1] channels = out_shapes[0][0] if not transpose else out_shapes[0][1]
return channels return channels
def new_var(name_hint, def new_var(name_hint,
type_annotation=None, type_annotation=None,
shape=None, shape=None,
dtype="float32"): dtype="float32"):
return _expr.var(name_hint, type_annotation, shape, dtype) return _expr.var(name_hint, type_annotation, shape, dtype)
class Renamer(object): class Renamer(object):
"""A simply renamer for operators. """A simply renamer for operators.
......
...@@ -20,13 +20,14 @@ from __future__ import absolute_import as _abs ...@@ -20,13 +20,14 @@ from __future__ import absolute_import as _abs
import json import json
import tvm import tvm
from .. import analysis, transform from .. import analysis
from .. import expr as _expr from .. import expr as _expr
from .. import op as _op from .. import op as _op
from .. import module as _module from .. import module as _module
from ... import nd as _nd from ... import nd as _nd
from .common import StrAttrsDict from .common import StrAttrsDict
from .common import infer_type as _infer_type
from .nnvm_common import _rename, _binop_scalar, _rbinop_scalar, _reduce from .nnvm_common import _rename, _binop_scalar, _rbinop_scalar, _reduce
from .nnvm_common import _arg_reduce, _init_op, _softmax_op, _cast from .nnvm_common import _arg_reduce, _init_op, _softmax_op, _cast
from .nnvm_common import _clip, _transpose, _upsampling from .nnvm_common import _clip, _transpose, _upsampling
...@@ -41,13 +42,6 @@ _activation_map = { ...@@ -41,13 +42,6 @@ _activation_map = {
"relu" : _op.nn.relu "relu" : _op.nn.relu
} }
def _infer_type(node):
"""A method to infer the type of an intermediate node in the relay graph."""
mod = _module.Module.from_expr(node)
mod = transform.InferType()(mod)
entry = mod["main"]
return entry if isinstance(node, _expr.Function) else entry.body
def _mx_fully_connected(inputs, attrs): def _mx_fully_connected(inputs, attrs):
import mxnet as mx import mxnet as mx
units = attrs.get_int("num_hidden") units = attrs.get_int("num_hidden")
......
...@@ -19,20 +19,21 @@ ...@@ -19,20 +19,21 @@
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
from __future__ import print_function from __future__ import print_function
import logging
import warnings import warnings
from collections import defaultdict from collections import defaultdict
# Numpy support # Numpy support
import numpy as np import numpy as np
import tvm import tvm
from topi.util import get_const_tuple
from .. import analysis from .. import analysis
from .. import transform as _transform
from .. import expr as _expr from .. import expr as _expr
from .. import op as _op from .. import op as _op
from ..expr_functor import ExprMutator from ..expr_functor import ExprMutator
from .. import module as _module from .. import module as _module
from .common import AttrCvt, get_relay_op
from .common import infer_type as _infer_type
from .common import infer_shape as _infer_shape
from .common import infer_channels as _infer_channels
__all__ = ['from_tensorflow'] __all__ = ['from_tensorflow']
...@@ -50,140 +51,6 @@ def _infer_value(input_val, params): ...@@ -50,140 +51,6 @@ def _infer_value(input_val, params):
m.run() m.run()
return m.get_output(0) return m.get_output(0)
def _get_relay_op(op_name):
ops = [_op, _op.nn, _op.image, _op.vision]
for operator in ops:
try:
op = getattr(operator, op_name)
return op
except AttributeError:
continue
raise tvm.error.OpNotImplemented(
'Operator {} is not supported for frontend TensorFlow.'.format(op_name))
class AttrCvt(object):
"""Common attribute converter. An AttrConverter instance is a callable:
```
attr_converter = AttrConverter(op_name, transforms={'a':'b', 'c':('d', 1)})
new_op_name, new_attr = attr_converter(attrs)
```
Parameters
----------
op_name : str or callable
If set as str, returned operator name is the str.
If set as callable, returned operator is the str returned by calling:
`op_name = func(attr)`
transforms : dict of `new_name, or (new_name, default_value, transform function)`
If only a new_name is provided, it's like renaming the attribute name.
If default_value if provided, then the attribute is considered as optional.
If transform function is provided, the original attribute value is handled
by transform function.
excludes : list
A list of excluded attributes that should `NOT` appear.
Raise NotImplementedError if occurred.
disables : list
A list of attributes that is disabled in relay. Log warnings.
ignores : list
A list of attributes that is ignored in relay. Debug level logging.
extras : dict
A series of additional attributes should be added anyway to the returned
attribute dict.
custom_check : callable
A custom function takes attribute, and return True/False.
Raise RuntimeError if not bool(True) returned.
"""
def __init__(self, op_name, transforms=None,
excludes=None, disables=None, ignores=None,
extras=None, custom_check=None):
self._op_name = op_name
self._transforms = transforms if transforms else {}
self._excludes = excludes if excludes else []
self._disables = disables if disables else []
self._ignores = ignores if ignores else []
self._extras = extras if extras else {}
self._custom_check = custom_check
def __call__(self, inputs, attrs, *args):
self._ignores.append('_output_shapes')
self._ignores.append('_input_shapes')
self._ignores.append('T')
self._ignores.append('use_cudnn_on_gpu')
self._ignores.append('_node_name')
self._ignores.append('is_training')
self._ignores.append('_target_layout')
# apply custom check
if self._custom_check:
func, msg = self._custom_check
if not func(attrs):
raise RuntimeError("Check failed: {}".format(msg))
# get new op_name
if isinstance(self._op_name, str):
op_name = self._op_name
else:
assert callable(self._op_name), "op_name can either be string or callable"
op_name = self._op_name(attrs)
# convert attributes
new_attrs = {}
for k in attrs.keys():
if k in self._excludes:
raise tvm.error.OpAttributeUnImplemented(
'Attribute {} in operator {} is not supported.'.format(k, op_name))
elif k in self._disables:
logging.warning("Attribute %s is disabled in relay.%s", k, op_name)
elif k in self._ignores:
logging.debug("Attribute %s is ignored in relay.%s", k, op_name)
elif k in self._transforms:
new_name, defaults, transform = self._parse_default(self._transforms[k])
if defaults is None:
new_attr = self._required_attr(attrs, k)
else:
new_attr = attrs.get(k, None)
if new_attr is None:
new_attrs[new_name] = defaults
else:
new_attrs[new_name] = transform(new_attr)
else:
# copy
new_attrs[k] = attrs[k]
# add extras
new_attrs.update(self._extras)
return _get_relay_op(op_name)(*inputs, **new_attrs)
def _parse_default(self, target):
"""Helper function to parse default values."""
if not isinstance(target, (list, tuple)):
k, v, t = target, None, lambda x: x
elif len(target) == 1:
k, v, t = target[0], None, lambda x: x
elif len(target) == 2:
k, v, t = target[0], target[1], lambda x: x
elif len(target) > 2:
k, v, t = target[0], target[1], target[2]
else:
k = None # should raise
if not isinstance(k, str):
msg = "{} is not a valid target, (name, default) expected.".format(target)
raise ValueError(msg)
return k, v, t
def _parse_bool(self, value):
"""Helper function to parse default boolean values."""
if isinstance(value, str):
return value.strip().lower() in ['true', '1', 't', 'y', 'yes']
return bool(value)
def _required_attr(self, attr, key):
"""Wrapper for getting required attributes."""
assert isinstance(attr, dict)
if key not in attr:
raise tvm.error.OpAttributeRequired(
'Attribute {} not found in operator {}'.format(key, self._op_name))
return attr[key]
def _get_pad_pair(input1d, kernel1d, stride1d): def _get_pad_pair(input1d, kernel1d, stride1d):
if input1d % stride1d == 0: if input1d % stride1d == 0:
pad = max(kernel1d - stride1d, 0) pad = max(kernel1d - stride1d, 0)
...@@ -195,12 +62,6 @@ def _get_pad_pair(input1d, kernel1d, stride1d): ...@@ -195,12 +62,6 @@ def _get_pad_pair(input1d, kernel1d, stride1d):
return [pad_before, pad_after] return [pad_before, pad_after]
def _get_name_hint(node):
name = ''
if hasattr(node, "name_hint"):
name = node.name_hint
return name
def _math_name_picker(surfix): def _math_name_picker(surfix):
def _impl(attr): def _impl(attr):
return 'broadcast_' + surfix return 'broadcast_' + surfix
...@@ -222,30 +83,6 @@ def _dimension_constraint(): ...@@ -222,30 +83,6 @@ def _dimension_constraint():
return False return False
return _dim_check, "Only 2d kernel supported." return _dim_check, "Only 2d kernel supported."
def _infer_channels(node, params, transpose=False):
"""A hack for getting 'channels' or 'units' since tensorflow don't provide
these attributes. We check the shape of weights provided to get the number.
"""
out_shape = _infer_shape(node, params)
channels = out_shape[0] if not transpose else out_shape[1]
return channels
def _infer_out_shapes(inputs, params):
"""A method to get the output shape of intermediate nodes in the relay graph."""
return [_infer_shape(inputs, params)]
def _infer_type(node):
"""A method to infer the type of an intermediate node in the relay graph."""
mod = _module.Module.from_expr(node)
mod = _transform.InferType()(mod)
entry = mod["main"]
return entry if isinstance(node, _expr.Function) else entry.body
def _infer_shape(node, params=None):
"""A method to get the output shape of an intermediate node in the relay graph."""
out_type = _infer_type(node)
return get_const_tuple(out_type.checked_type.shape)
def _get_param(params, input_node): def _get_param(params, input_node):
return params.pop(input_node.name_hint).asnumpy() return params.pop(input_node.name_hint).asnumpy()
...@@ -280,7 +117,7 @@ def _argx(func, func_name): ...@@ -280,7 +117,7 @@ def _argx(func, func_name):
def _elemwise(name): def _elemwise(name):
def _impl(inputs, attr, params): def _impl(inputs, attr, params):
assert len(inputs) == 2, "{} take 2 inputs, {} given".format(name, len(inputs)) assert len(inputs) == 2, "{} take 2 inputs, {} given".format(name, len(inputs))
return _get_relay_op(name)(*inputs) return get_relay_op(name)(*inputs)
return _impl return _impl
def _pooling(name): def _pooling(name):
...@@ -300,7 +137,7 @@ def _pooling(name): ...@@ -300,7 +137,7 @@ def _pooling(name):
else: else:
msg = 'Value {} of attribute "data_format" of operator Pooling ' \ msg = 'Value {} of attribute "data_format" of operator Pooling ' \
'is not valid.' 'is not valid.'
raise tvm.error.OpAttributeInvalid(msg.format(attrs['data_format'])) raise tvm.error.OpAttributeInvalid(msg.format(attr['data_format']))
if attr['_target_layout'] == "NCHW" and attr['data_format'] == "NHWC": if attr['_target_layout'] == "NCHW" and attr['data_format'] == "NHWC":
tmp_shape = attr['_input_shapes'][inputs[0]] tmp_shape = attr['_input_shapes'][inputs[0]]
...@@ -539,7 +376,7 @@ def _crop_and_resize(): ...@@ -539,7 +376,7 @@ def _crop_and_resize():
res_crop = _op.strided_slice(inputs[0], begin=begin, end=size) res_crop = _op.strided_slice(inputs[0], begin=begin, end=size)
# 2) Resize # 2) Resize
res_resize = _get_relay_op('resize')(res_crop, **attrs) res_resize = get_relay_op('resize')(res_crop, **attrs)
out = _op.concatenate([out, res_resize], axis=0) if out else res_resize out = _op.concatenate([out, res_resize], axis=0) if out else res_resize
return out return out
return _impl return _impl
...@@ -598,7 +435,7 @@ def _check_numerics(): ...@@ -598,7 +435,7 @@ def _check_numerics():
def _matmul(): def _matmul():
def _impl(inputs, attr, params): def _impl(inputs, attr, params):
channels = _infer_channels(inputs[1], params, not attr['transpose_b']) channels = _infer_channels(inputs[1], not attr['transpose_b'])
if attr['transpose_a']: if attr['transpose_a']:
inputs[0] = _op.transpose(inputs[0], axes=(1, 0)) inputs[0] = _op.transpose(inputs[0], axes=(1, 0))
if not attr['transpose_b']: if not attr['transpose_b']:
...@@ -615,15 +452,10 @@ def _batch_matmul(): ...@@ -615,15 +452,10 @@ def _batch_matmul():
adj_y = attr['adj_y'] adj_y = attr['adj_y']
input_x = _op.transpose(inputs[0], axes=[0, 2, 1]) if adj_x else inputs[0] input_x = _op.transpose(inputs[0], axes=[0, 2, 1]) if adj_x else inputs[0]
input_y = _op.transpose(inputs[1], axes=[0, 2, 1]) if not adj_y else inputs[1] input_y = _op.transpose(inputs[1], axes=[0, 2, 1]) if not adj_y else inputs[1]
ret = _get_relay_op('batch_matmul')(input_x, input_y) ret = get_relay_op('batch_matmul')(input_x, input_y)
return ret return ret
return _impl return _impl
def _undef():
def _impl(inputs, attr, params):
return _sym.__undef__()
return _impl
def _identity(): def _identity():
def _impl(inputs, attr, params): def _impl(inputs, attr, params):
return inputs[0] return inputs[0]
...@@ -985,7 +817,7 @@ def _stridedSlice(): ...@@ -985,7 +817,7 @@ def _stridedSlice():
if begin_mask or end_mask or ellipsis_mask or new_axis_mask or shrink_axis_mask: if begin_mask or end_mask or ellipsis_mask or new_axis_mask or shrink_axis_mask:
begin, end, stride, fshape_indices = _transform_mask(stride_dim, ellipsis_mask) begin, end, stride, fshape_indices = _transform_mask(stride_dim, ellipsis_mask)
out = _op.strided_slice(inputs[0], begin=begin, end=end, strides=stride) out = _op.strided_slice(inputs[0], begin=begin, end=end, strides=stride)
out_shape = _infer_shape(out, params) out_shape = _infer_shape(out)
if not fshape_indices: if not fshape_indices:
fshape_indices = range(len(out_shape)) fshape_indices = range(len(out_shape))
...@@ -1178,8 +1010,8 @@ def _softplus(): ...@@ -1178,8 +1010,8 @@ def _softplus():
exp_out = AttrCvt('exp')(inputs, attr) exp_out = AttrCvt('exp')(inputs, attr)
inputs.append(tvm.relay.const(1, attr['T'].name)) inputs.append(tvm.relay.const(1, attr['T'].name))
rh = tvm.relay.const(1, attr['T'].name) rh = tvm.relay.const(1, attr['T'].name)
add_out = _get_relay_op('add')(exp_out, rh) add_out = get_relay_op('add')(exp_out, rh)
return _get_relay_op('log')(add_out) return get_relay_op('log')(add_out)
return _impl return _impl
def _topk(): def _topk():
...@@ -1200,7 +1032,7 @@ def _floordiv(): ...@@ -1200,7 +1032,7 @@ def _floordiv():
def _impl(inputs, attr, params): def _impl(inputs, attr, params):
assert len(inputs) == 2 assert len(inputs) == 2
div = AttrCvt('divide')(inputs, attr) div = AttrCvt('divide')(inputs, attr)
return _get_relay_op('floor')(div) return get_relay_op('floor')(div)
return _impl return _impl
def _logical(name): def _logical(name):
...@@ -1234,7 +1066,7 @@ def _space_to_batch_nd(): ...@@ -1234,7 +1066,7 @@ def _space_to_batch_nd():
axes = [2 * i + 2 for i in range(M)] + [0] + [2 * i + 1 for i in range(M)] + \ axes = [2 * i + 2 for i in range(M)] + [0] + [2 * i + 1 for i in range(M)] + \
list(range(1 + 2 * M, 1 + 2 * M + remaining_shape_length)) list(range(1 + 2 * M, 1 + 2 * M + remaining_shape_length))
permuted_reshaped_padded = tvm.relay.transpose(reshaped_padded, axes=axes) permuted_reshaped_padded = tvm.relay.transpose(reshaped_padded, axes=axes)
permuted_reshaped_padded_shape = _infer_shape(permuted_reshaped_padded, params) permuted_reshaped_padded_shape = _infer_shape(permuted_reshaped_padded)
# Reshape permuted_reshaped_padded to flatten block_shape into the batch dimension, # Reshape permuted_reshaped_padded to flatten block_shape into the batch dimension,
# producing an output tensor of shape: # producing an output tensor of shape:
# [batch * prod(block_shape)] + [padded_shape[1] / block_shape[0], ..., # [batch * prod(block_shape)] + [padded_shape[1] / block_shape[0], ...,
...@@ -1277,7 +1109,7 @@ def _batch_to_space_nd(): ...@@ -1277,7 +1109,7 @@ def _batch_to_space_nd():
# [batch / prod(block_shape), input_shape[1] * block_shape[0] - crops[0,0] - crops[0,1], # [batch / prod(block_shape), input_shape[1] * block_shape[0] - crops[0,0] - crops[0,1],
# ..., input_shape[M] * block_shape[M-1] - crops[M-1,0] - crops[M-1,1], # ..., input_shape[M] * block_shape[M-1] - crops[M-1,0] - crops[M-1,1],
# input_shape[M+1], ..., input_shape[N-1]] # input_shape[M+1], ..., input_shape[N-1]]
reshaped_permuted_shape = _infer_shape(reshaped_permuted, params) reshaped_permuted_shape = _infer_shape(reshaped_permuted)
cropped = reshaped_permuted cropped = reshaped_permuted
for axis in range(1, M+1): for axis in range(1, M+1):
crop = crops[axis - 1] crop = crops[axis - 1]
...@@ -1305,8 +1137,8 @@ def _log1p(): ...@@ -1305,8 +1137,8 @@ def _log1p():
# op description: https://www.tensorflow.org/api_docs/python/tf/math/log1p # op description: https://www.tensorflow.org/api_docs/python/tf/math/log1p
def _impl(inputs, attr, params): def _impl(inputs, attr, params):
one = tvm.relay.const(1, attr['T'].name) one = tvm.relay.const(1, attr['T'].name)
add_out = _get_relay_op('add')(inputs[0], one) add_out = get_relay_op('add')(inputs[0], one)
return _get_relay_op('log')(add_out) return get_relay_op('log')(add_out)
return _impl return _impl
# compatible operators that do NOT require any conversion. # compatible operators that do NOT require any conversion.
...@@ -2399,7 +2231,7 @@ class GraphProto(object): ...@@ -2399,7 +2231,7 @@ class GraphProto(object):
convert_map = convert_map if convert_map else _convert_map convert_map = convert_map if convert_map else _convert_map
convert_map_rnn = _convert_map_rnn convert_map_rnn = _convert_map_rnn
if op_name in identity_list: if op_name in identity_list:
sym = _get_relay_op(op_name)(*inputs, **attrs) sym = get_relay_op(op_name)(*inputs, **attrs)
elif op_name in convert_map: elif op_name in convert_map:
sym = convert_map[op_name](inputs, attrs, self._params) sym = convert_map[op_name](inputs, attrs, self._params)
elif op_name in convert_map_rnn: elif op_name in convert_map_rnn:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment