Commit be8fa6ac by Zhi Committed by Thierry Moreau

[relay][frontend] clean up tf frontend (#3710)

* clean up tf frontend

* fix get_relay_op
parent 8d5de5ed
...@@ -17,6 +17,8 @@ ...@@ -17,6 +17,8 @@
"""Common utilities""" """Common utilities"""
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
import logging import logging
import tvm
from topi.util import get_const_tuple from topi.util import get_const_tuple
from .. import expr as _expr from .. import expr as _expr
from .. import module as _module from .. import module as _module
...@@ -224,6 +226,7 @@ class StrAttrsDict(object): ...@@ -224,6 +226,7 @@ class StrAttrsDict(object):
raise AttributeError("Required attribute {} not found.".format(key)) raise AttributeError("Required attribute {} not found.".format(key))
return default return default
def get_relay_op(op_name): def get_relay_op(op_name):
"""Get the callable function from Relay based on operator name. """Get the callable function from Relay based on operator name.
Parameters Parameters
...@@ -246,9 +249,10 @@ def get_relay_op(op_name): ...@@ -246,9 +249,10 @@ def get_relay_op(op_name):
if op is not None: if op is not None:
break break
if not op: if not op:
raise RuntimeError("Unable to map op_name {} to relay".format(op_name)) raise tvm.error.OpNotImplemented("Unable to map op_name {} to relay".format(op_name))
return op return op
class ExprTable(object): class ExprTable(object):
"""Table storing Relay expressions by names.""" """Table storing Relay expressions by names."""
def __init__(self): def __init__(self):
...@@ -298,21 +302,27 @@ class AttrCvt(object): ...@@ -298,21 +302,27 @@ class AttrCvt(object):
If set as str, returned operator name is the str. If set as str, returned operator name is the str.
If set as callable, returned operator is the str returned by calling: If set as callable, returned operator is the str returned by calling:
`op_name = func(attr)` `op_name = func(attr)`
transforms : dict of `new_name, or (new_name, default_value, transform function)` transforms : dict of `new_name, or (new_name, default_value, transform function)`
If only a new_name is provided, it's like renaming the attribute name. If only a new_name is provided, it's like renaming the attribute name.
If default_value if provided, then the attribute is considered as optional. If default_value if provided, then the attribute is considered as optional.
If transform function is provided, the original attribute value is handled If transform function is provided, the original attribute value is handled
by transform function. by transform function.
excludes : list excludes : list
A list of excluded attributes that should `NOT` appear. A list of excluded attributes that should `NOT` appear.
Raise NotImplementedError if occurred. Raise NotImplementedError if occurred.
disables : list disables : list
A list of attributes that is disabled in relay. Log warnings. A list of attributes that is disabled in relay. Log warnings.
ignores : list ignores : list
A list of attributes that is ignored in relay. Debug level logging. A list of attributes that is ignored in relay. Debug level logging.
extras : dict extras : dict
A series of additional attributes should be added anyway to the returned A series of additional attributes should be added anyway to the returned
attribute dict. attribute dict.
custom_check : callable custom_check : callable
A custom function takes attribute, and return True/False. A custom function takes attribute, and return True/False.
Raise RuntimeError if not bool(True) returned. Raise RuntimeError if not bool(True) returned.
...@@ -329,6 +339,14 @@ class AttrCvt(object): ...@@ -329,6 +339,14 @@ class AttrCvt(object):
self._custom_check = custom_check self._custom_check = custom_check
def __call__(self, inputs, attrs, *args): def __call__(self, inputs, attrs, *args):
self._ignores.append('_output_shapes')
self._ignores.append('_input_shapes')
self._ignores.append('T')
self._ignores.append('use_cudnn_on_gpu')
self._ignores.append('_node_name')
self._ignores.append('is_training')
self._ignores.append('_target_layout')
# apply custom check # apply custom check
if self._custom_check: if self._custom_check:
func, msg = self._custom_check func, msg = self._custom_check
...@@ -348,7 +366,8 @@ class AttrCvt(object): ...@@ -348,7 +366,8 @@ class AttrCvt(object):
new_attrs = {} new_attrs = {}
for k in attrs.keys(): for k in attrs.keys():
if k in self._excludes: if k in self._excludes:
raise NotImplementedError("Attribute {} not supported yet.".format(k)) raise NotImplementedError('Attribute %s in operator %s is not' +
' supported.', k, op_name)
elif k in self._disables: elif k in self._disables:
logging.warning("Attribute %s is disabled in relay.sym.%s", k, op_name) logging.warning("Attribute %s is disabled in relay.sym.%s", k, op_name)
elif k in self._ignores: elif k in self._ignores:
...@@ -401,6 +420,7 @@ class AttrCvt(object): ...@@ -401,6 +420,7 @@ class AttrCvt(object):
raise AttributeError("Required attribute {} not found.".format(key)) raise AttributeError("Required attribute {} not found.".format(key))
return attr[key] return attr[key]
def get_name(node): def get_name(node):
name = '' name = ''
if hasattr(node, "name_hint"): if hasattr(node, "name_hint"):
...@@ -410,17 +430,19 @@ def get_name(node): ...@@ -410,17 +430,19 @@ def get_name(node):
def infer_type(node): def infer_type(node):
"""A method to infer the type of an intermediate node in the relay graph.""" """A method to infer the type of an intermediate node in the relay graph."""
mod = _module.Module.from_expr(node) mod = node if isinstance(node, _module.Module) else _module.Module.from_expr(node)
mod = _transform.InferType()(mod) mod = _transform.InferType()(mod)
entry = mod["main"] entry = mod["main"]
return entry if isinstance(node, _expr.Function) else entry.body return entry if isinstance(node, _expr.Function) else entry.body
def infer_shape(inputs): def infer_shape(inputs):
"""A method to get the output shape of an intermediate node in the graph.""" """A method to get the output shape of an intermediate node in the graph."""
out_type = infer_type(inputs) out_type = infer_type(inputs)
out_shapes = get_const_tuple(out_type.checked_type.shape) out_shapes = get_const_tuple(out_type.checked_type.shape)
return out_shapes return out_shapes
def infer_channels(inputs, transpose=False): def infer_channels(inputs, transpose=False):
"""A hack for getting 'channels' or 'units' since caffe2 does not provide """A hack for getting 'channels' or 'units' since caffe2 does not provide
these attributes. We check the shape of weights provided to get the number. these attributes. We check the shape of weights provided to get the number.
...@@ -430,12 +452,14 @@ def infer_channels(inputs, transpose=False): ...@@ -430,12 +452,14 @@ def infer_channels(inputs, transpose=False):
channels = out_shapes[0][0] if not transpose else out_shapes[0][1] channels = out_shapes[0][0] if not transpose else out_shapes[0][1]
return channels return channels
def new_var(name_hint, def new_var(name_hint,
type_annotation=None, type_annotation=None,
shape=None, shape=None,
dtype="float32"): dtype="float32"):
return _expr.var(name_hint, type_annotation, shape, dtype) return _expr.var(name_hint, type_annotation, shape, dtype)
class Renamer(object): class Renamer(object):
"""A simply renamer for operators. """A simply renamer for operators.
......
...@@ -20,13 +20,14 @@ from __future__ import absolute_import as _abs ...@@ -20,13 +20,14 @@ from __future__ import absolute_import as _abs
import json import json
import tvm import tvm
from .. import analysis, transform from .. import analysis
from .. import expr as _expr from .. import expr as _expr
from .. import op as _op from .. import op as _op
from .. import module as _module from .. import module as _module
from ... import nd as _nd from ... import nd as _nd
from .common import StrAttrsDict from .common import StrAttrsDict
from .common import infer_type as _infer_type
from .nnvm_common import _rename, _binop_scalar, _rbinop_scalar, _reduce from .nnvm_common import _rename, _binop_scalar, _rbinop_scalar, _reduce
from .nnvm_common import _arg_reduce, _init_op, _softmax_op, _cast from .nnvm_common import _arg_reduce, _init_op, _softmax_op, _cast
from .nnvm_common import _clip, _transpose, _upsampling from .nnvm_common import _clip, _transpose, _upsampling
...@@ -41,13 +42,6 @@ _activation_map = { ...@@ -41,13 +42,6 @@ _activation_map = {
"relu" : _op.nn.relu "relu" : _op.nn.relu
} }
def _infer_type(node):
"""A method to infer the type of an intermediate node in the relay graph."""
mod = _module.Module.from_expr(node)
mod = transform.InferType()(mod)
entry = mod["main"]
return entry if isinstance(node, _expr.Function) else entry.body
def _mx_fully_connected(inputs, attrs): def _mx_fully_connected(inputs, attrs):
import mxnet as mx import mxnet as mx
units = attrs.get_int("num_hidden") units = attrs.get_int("num_hidden")
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment