Commit be8fa6ac by Zhi Committed by Thierry Moreau

[relay][frontend] clean up tf frontend (#3710)

* clean up tf frontend

* fix get_relay_op
parent 8d5de5ed
......@@ -17,6 +17,8 @@
"""Common utilities"""
from __future__ import absolute_import as _abs
import logging
import tvm
from topi.util import get_const_tuple
from .. import expr as _expr
from .. import module as _module
......@@ -224,6 +226,7 @@ class StrAttrsDict(object):
raise AttributeError("Required attribute {} not found.".format(key))
return default
def get_relay_op(op_name):
"""Get the callable function from Relay based on operator name.
Parameters
......@@ -246,9 +249,10 @@ def get_relay_op(op_name):
if op is not None:
break
if not op:
raise RuntimeError("Unable to map op_name {} to relay".format(op_name))
raise tvm.error.OpNotImplemented("Unable to map op_name {} to relay".format(op_name))
return op
class ExprTable(object):
"""Table storing Relay expressions by names."""
def __init__(self):
......@@ -298,21 +302,27 @@ class AttrCvt(object):
If set as str, returned operator name is the str.
If set as callable, returned operator is the str returned by calling:
`op_name = func(attr)`
transforms : dict of `new_name, or (new_name, default_value, transform function)`
If only a new_name is provided, it's like renaming the attribute name.
If default_value if provided, then the attribute is considered as optional.
If transform function is provided, the original attribute value is handled
by transform function.
excludes : list
A list of excluded attributes that should `NOT` appear.
Raise NotImplementedError if occurred.
disables : list
A list of attributes that is disabled in relay. Log warnings.
ignores : list
A list of attributes that is ignored in relay. Debug level logging.
extras : dict
A series of additional attributes should be added anyway to the returned
attribute dict.
custom_check : callable
A custom function takes attribute, and return True/False.
Raise RuntimeError if not bool(True) returned.
......@@ -329,6 +339,14 @@ class AttrCvt(object):
self._custom_check = custom_check
def __call__(self, inputs, attrs, *args):
self._ignores.append('_output_shapes')
self._ignores.append('_input_shapes')
self._ignores.append('T')
self._ignores.append('use_cudnn_on_gpu')
self._ignores.append('_node_name')
self._ignores.append('is_training')
self._ignores.append('_target_layout')
# apply custom check
if self._custom_check:
func, msg = self._custom_check
......@@ -348,7 +366,8 @@ class AttrCvt(object):
new_attrs = {}
for k in attrs.keys():
if k in self._excludes:
raise NotImplementedError("Attribute {} not supported yet.".format(k))
raise NotImplementedError('Attribute %s in operator %s is not' +
' supported.', k, op_name)
elif k in self._disables:
logging.warning("Attribute %s is disabled in relay.sym.%s", k, op_name)
elif k in self._ignores:
......@@ -401,6 +420,7 @@ class AttrCvt(object):
raise AttributeError("Required attribute {} not found.".format(key))
return attr[key]
def get_name(node):
name = ''
if hasattr(node, "name_hint"):
......@@ -410,17 +430,19 @@ def get_name(node):
def infer_type(node):
"""A method to infer the type of an intermediate node in the relay graph."""
mod = _module.Module.from_expr(node)
mod = node if isinstance(node, _module.Module) else _module.Module.from_expr(node)
mod = _transform.InferType()(mod)
entry = mod["main"]
return entry if isinstance(node, _expr.Function) else entry.body
def infer_shape(inputs):
"""A method to get the output shape of an intermediate node in the graph."""
out_type = infer_type(inputs)
out_shapes = get_const_tuple(out_type.checked_type.shape)
return out_shapes
def infer_channels(inputs, transpose=False):
"""A hack for getting 'channels' or 'units' since caffe2 does not provide
these attributes. We check the shape of weights provided to get the number.
......@@ -430,12 +452,14 @@ def infer_channels(inputs, transpose=False):
channels = out_shapes[0][0] if not transpose else out_shapes[0][1]
return channels
def new_var(name_hint,
type_annotation=None,
shape=None,
dtype="float32"):
return _expr.var(name_hint, type_annotation, shape, dtype)
class Renamer(object):
"""A simply renamer for operators.
......
......@@ -20,13 +20,14 @@ from __future__ import absolute_import as _abs
import json
import tvm
from .. import analysis, transform
from .. import analysis
from .. import expr as _expr
from .. import op as _op
from .. import module as _module
from ... import nd as _nd
from .common import StrAttrsDict
from .common import infer_type as _infer_type
from .nnvm_common import _rename, _binop_scalar, _rbinop_scalar, _reduce
from .nnvm_common import _arg_reduce, _init_op, _softmax_op, _cast
from .nnvm_common import _clip, _transpose, _upsampling
......@@ -41,13 +42,6 @@ _activation_map = {
"relu" : _op.nn.relu
}
def _infer_type(node):
"""A method to infer the type of an intermediate node in the relay graph."""
mod = _module.Module.from_expr(node)
mod = transform.InferType()(mod)
entry = mod["main"]
return entry if isinstance(node, _expr.Function) else entry.body
def _mx_fully_connected(inputs, attrs):
import mxnet as mx
units = attrs.get_int("num_hidden")
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment