Commit 6d1dc4ae by Sergey Mironov Committed by Tianqi Chen

[NNVM] Support argmax/argmin in tensorflow frontend (#1514)

parent 71cff3e8
...@@ -91,6 +91,20 @@ def _rsqrt(): ...@@ -91,6 +91,20 @@ def _rsqrt():
return AttrCvt(op_name="__pow_scalar__", extras={'scalar': -0.5})(inputs, attr) return AttrCvt(op_name="__pow_scalar__", extras={'scalar': -0.5})(inputs, attr)
return _impl return _impl
def _argx(func, func_name):
""" A common wrapper for argmin and argmax operations """
def _impl(inputs, attr, params):
try:
# In Tensorflow, `axis` argument is a Tensor, not attribute. We
# support the case where it inputs from a scalar constant.
axis_input_name = inputs[1].list_output_names()[0]
axis_input_vlaue = params[axis_input_name].asnumpy()[0]
except (IndexError, KeyError):
raise TypeError( \
"Unsupported argument for `{}` : `axis` should be a constant".format(func_name))
return func(inputs[0], axis=axis_input_vlaue, keepdims=False)
return _impl
def _elemwise(name): def _elemwise(name):
def _impl(inputs, attr, *args): def _impl(inputs, attr, *args):
assert len(inputs) == 2, "Math op take 2 inputs, {} given".format(len(inputs)) assert len(inputs) == 2, "Math op take 2 inputs, {} given".format(len(inputs))
...@@ -664,6 +678,8 @@ _identity_list = [] ...@@ -664,6 +678,8 @@ _identity_list = []
# for 1 to N mapping(composed), use custom callable functions # for 1 to N mapping(composed), use custom callable functions
# for N to 1 mapping, currently not supported(?) # for N to 1 mapping, currently not supported(?)
_convert_map = { _convert_map = {
'ArgMax' : _argx(_sym.argmax, 'argmax'),
'ArgMin' : _argx(_sym.argmin, 'argmin'),
'AvgPool' : _pooling('avg_pool'), 'AvgPool' : _pooling('avg_pool'),
'BatchNormWithGlobalNormalization' : _batch_norm(), 'BatchNormWithGlobalNormalization' : _batch_norm(),
'BiasAdd' : _bias_add(), 'BiasAdd' : _bias_add(),
...@@ -879,6 +895,28 @@ class RecurrentNetworks(object): ...@@ -879,6 +895,28 @@ class RecurrentNetworks(object):
params, num_layers) params, num_layers)
return sym return sym
def _parse_import_prerequisites(graph):
""" Calculate the named preconditions from TensorFlow `graph`.
Return prerequisites for parsing:
a. Set of operator names which don't have their mapping in TVM, i.e.
which are not supported
"""
missing_operators = set()
for node in graph.node:
if node.op == "Placeholder":
pass
elif node.op == "Const":
pass
else:
if any([node.op in t for t in [_identity_list, _convert_map, _convert_map_rnn]]):
pass
else:
missing_operators.add(node.op)
return missing_operators
class GraphProto(object): class GraphProto(object):
""" A helper class for handling nnvm graph copying from Tensorflow GraphDef. """ A helper class for handling nnvm graph copying from Tensorflow GraphDef.
Definition: Definition:
...@@ -901,7 +939,7 @@ class GraphProto(object): ...@@ -901,7 +939,7 @@ class GraphProto(object):
Follow the tensorflow graph definition to parse and convert it to NNVM. Follow the tensorflow graph definition to parse and convert it to NNVM.
Some of the assumptions listed below. Some of the assumptions listed below.
-> First Const or Placeholder node will be considered as graph input. -> First Placeholder or Const node will be considered as graph input.
-> Rest all Const nodes are params. -> Rest all Const nodes are params.
-> Last node is assumed as graph output. -> Last node is assumed as graph output.
-> _output_shapes : Attribute should present in the tenserflow forzen graph. -> _output_shapes : Attribute should present in the tenserflow forzen graph.
...@@ -910,6 +948,7 @@ class GraphProto(object): ...@@ -910,6 +948,7 @@ class GraphProto(object):
-> CheckNumerics: No implementation as of now for this. -> CheckNumerics: No implementation as of now for this.
Just copies input to output. Just copies input to output.
TODO: Change algorithm to stop treating first 'Const' in a special way.
Parameters Parameters
---------- ----------
...@@ -923,10 +962,6 @@ class GraphProto(object): ...@@ -923,10 +962,6 @@ class GraphProto(object):
params : dict params : dict
A dict of name: tvm.nd.array pairs, used as pretrained weights A dict of name: tvm.nd.array pairs, used as pretrained weights
""" """
# Parse throught all nodes and start extracting
# params aka Const nodes
# input nodes : First const node
# normal nodes : other normal nodes
try: try:
from tensorflow.python.framework import tensor_util from tensorflow.python.framework import tensor_util
...@@ -934,12 +969,18 @@ class GraphProto(object): ...@@ -934,12 +969,18 @@ class GraphProto(object):
raise ImportError( raise ImportError(
"Unable to import tensorflow which is required {}".format(e)) "Unable to import tensorflow which is required {}".format(e))
missing_operators = _parse_import_prerequisites(graph)
if missing_operators:
raise NotImplementedError( \
"The following operators are not implemented: {}".format(missing_operators))
# Parse the nodes to re-create TF graph using Symbol API of NNVM
for node in graph.node: for node in graph.node:
# Tensorflow doesn't have seperate list for params extraction. # Tensorflow doesn't have seperate list for params extraction.
# Operator name 'Const' is treated as a parameter to build NNVM params dict. # Operator name 'Const' is treated as a parameter to build NNVM params dict.
input_shapes = {} input_shapes = {}
if node.op == "Placeholder": if node.op == "Placeholder":
# Assuming only one input graph with type 'Placeholder'
self._input_node = node.name self._input_node = node.name
self._num_input += 1 self._num_input += 1
...@@ -954,7 +995,6 @@ class GraphProto(object): ...@@ -954,7 +995,6 @@ class GraphProto(object):
raise NotImplementedError( \ raise NotImplementedError( \
"Please freeze the graph with add_shapes=True") "Please freeze the graph with add_shapes=True")
elif node.op == "Const": elif node.op == "Const":
# Assuming first Const node as Graph Input node
if self._input_node == '': if self._input_node == '':
self._input_node = node.name self._input_node = node.name
self._num_input += 1 self._num_input += 1
...@@ -997,7 +1037,7 @@ class GraphProto(object): ...@@ -997,7 +1037,7 @@ class GraphProto(object):
# Pass the node name too in attr # Pass the node name too in attr
attr["_node_name"] = node.name attr["_node_name"] = node.name
#ToDo: Some of the tensorflow operators maintain internaly maintain #ToDo: Some of the tensorflow operators internaly maintain
#execution layers and its output name will the layer number along with #execution layers and its output name will the layer number along with
#graph node name.eg: Node name:- 'Model/RNN/cell_0/RnnCell', but the #graph node name.eg: Node name:- 'Model/RNN/cell_0/RnnCell', but the
#output name will be 'Model/RNN/cell_0/RnnCell:0'. In this case, #output name will be 'Model/RNN/cell_0/RnnCell:0'. In this case,
......
...@@ -404,6 +404,37 @@ def test_forward_sigmoid(): ...@@ -404,6 +404,37 @@ def test_forward_sigmoid():
_test_sigmoid(np.random.uniform(size=(3, 4, 4, 3)).astype('float32')) _test_sigmoid(np.random.uniform(size=(3, 4, 4, 3)).astype('float32'))
#######################################################################
# Argmin/Argmax
# -------------
def _test_argx(func, data, **kwargs):
with tf.Graph().as_default():
inp = constant_op.constant(data, shape=data.shape, dtype=data.dtype, name="c0")
# pylint: disable=unused-variable
out = func(inp, name="argx0", **kwargs)
# pylint: enable=unused-variable
with tf.Session() as sess:
graph_def = tf.graph_util.convert_variables_to_constants(
sess=sess,
input_graph_def=sess.graph.as_graph_def(add_shapes=True),
output_node_names=["argx0"])
tf_output = run_tf_graph(sess, data, input_node="c0:0", output_node="argx0:0")
tvm_output = run_tvm_graph(graph_def, data, "c0", tf_output.shape, output_dtype='int32')
np.testing.assert_allclose(tf_output, tvm_output, atol=1e-5, rtol=1e-5)
sess.close()
def test_argmin_argmax():
for axis in [None,0,1,2]:
data = np.random.uniform(size=(8,4,9)).astype('float32')
_test_argx(tf.argmax, data=data, axis=axis)
_test_argx(tf.argmin, data=data, axis=axis)
####################################################################### #######################################################################
# Variable # Variable
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment