Commit 6d1dc4ae by Sergey Mironov, committed by Tianqi Chen

[NNVM] Support argmax/argmin in tensorflow frontend (#1514)

parent 71cff3e8
@@ -91,6 +91,20 @@ def _rsqrt():
        return AttrCvt(op_name="__pow_scalar__", extras={'scalar': -0.5})(inputs, attr)
    return _impl
def _argx(func, func_name):
    """ A common wrapper for argmin and argmax operations """
    def _impl(inputs, attr, params):
        try:
            # In Tensorflow, the `axis` argument is a Tensor, not an attribute.
            # We support the case where it comes from a scalar constant.
            axis_input_name = inputs[1].list_output_names()[0]
            axis_input_value = params[axis_input_name].asnumpy()[0]
        except (IndexError, KeyError):
            raise TypeError(
                "Unsupported argument for `{}` : `axis` should be a constant".format(func_name))
        return func(inputs[0], axis=axis_input_value, keepdims=False)
    return _impl
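For context, TensorFlow serializes the reduction axis of `tf.argmax`/`tf.argmin` as a separate `Const` input node of the `ArgMax`/`ArgMin` op rather than as a node attribute, which is why `_impl` looks up `inputs[1]` in `params`. A minimal sketch that makes this visible (TF 1.x API; the exact name of the axis node, e.g. 'argx0/dimension', may vary by TensorFlow version):

import numpy as np
import tensorflow as tf

with tf.Graph().as_default() as g:
    data = tf.constant(np.zeros((8, 4), dtype='float32'), name='c0')
    out = tf.argmax(data, axis=1, name='argx0')
    # The GraphDef holds three nodes: the data Const, a small Const
    # carrying the axis value, and the ArgMax node wired to both.
    print([(n.name, n.op) for n in g.as_graph_def().node])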
def _elemwise(name):
    def _impl(inputs, attr, *args):
        assert len(inputs) == 2, "Math op takes 2 inputs, {} given".format(len(inputs))
@@ -664,6 +678,8 @@ _identity_list = []
# for 1 to N mapping (composed), use custom callable functions
# for N to 1 mapping, currently not supported (?)
_convert_map = {
    'ArgMax'                            : _argx(_sym.argmax, 'argmax'),
    'ArgMin'                            : _argx(_sym.argmin, 'argmin'),
    'AvgPool'                           : _pooling('avg_pool'),
    'BatchNormWithGlobalNormalization'  : _batch_norm(),
    'BiasAdd'                           : _bias_add(),
@@ -879,6 +895,28 @@ class RecurrentNetworks(object):
                                             params, num_layers)
        return sym
def _parse_import_prerequisites(graph):
    """ Calculate the named preconditions from TensorFlow `graph`.
        Return prerequisites for parsing:
        a. Set of operator names which don't have their mapping in TVM, i.e.
           which are not supported.
    """
    missing_operators = set()
    for node in graph.node:
        if node.op == "Placeholder":
            pass
        elif node.op == "Const":
            pass
        else:
            # An op is supported if it appears in any of the three lookup tables.
            if any([node.op in t for t in [_identity_list, _convert_map, _convert_map_rnn]]):
                pass
            else:
                missing_operators.add(node.op)
    return missing_operators
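With this precheck, an unsupported operator now surfaces as a single clear NotImplementedError at import time instead of a KeyError deep inside conversion. A minimal sketch of the user-facing flow, assuming `graph_def` is a frozen tf.GraphDef already loaded (the 'model.pb' path in the comment is hypothetical):

import nnvm

# graph_def: a frozen tf.GraphDef, e.g. parsed from a .pb file via
# graph_def.ParseFromString(open('model.pb', 'rb').read())  -- assumption.
try:
    sym, params = nnvm.frontend.from_tensorflow(graph_def)
except NotImplementedError as err:
    # e.g. "The following operators are not implemented: {'Cumsum'}"
    print(err)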
class GraphProto(object):
    """ A helper class for handling nnvm graph copying from Tensorflow GraphDef.
    Definition:
@@ -901,7 +939,7 @@ class GraphProto(object):
        Follow the tensorflow graph definition to parse and convert it to NNVM.
        Some of the assumptions are listed below.

            -> First Const or Placeholder node will be considered as graph input.
            -> First Placeholder or Const node will be considered as graph input.
            -> All other Const nodes are params.
            -> Last node is assumed as graph output.
            -> _output_shapes : Attribute should be present in the tensorflow frozen graph.
@@ -910,6 +948,7 @@ class GraphProto(object):
            -> CheckNumerics: No implementation for this as of now;
               it just copies input to output.

        TODO: Change the algorithm to stop treating the first 'Const' in a special way.
        Parameters
        ----------
@@ -923,10 +962,6 @@ class GraphProto(object):
        params : dict
            A dict of name: tvm.nd.array pairs, used as pretrained weights
        """
        # Parse through all nodes and start extracting
        # params aka Const nodes
        # input nodes : First const node
        # normal nodes : other normal nodes
        try:
            from tensorflow.python.framework import tensor_util
@@ -934,12 +969,18 @@ class GraphProto(object):
            raise ImportError(
                "Unable to import tensorflow which is required: {}".format(e))

        missing_operators = _parse_import_prerequisites(graph)

        if missing_operators:
            raise NotImplementedError(
                "The following operators are not implemented: {}".format(missing_operators))
        # Parse the nodes to re-create the TF graph using the Symbol API of NNVM
        for node in graph.node:
            # Tensorflow doesn't have a separate list for params extraction.
            # Operator name 'Const' is treated as a parameter to build the NNVM params dict.
            input_shapes = {}
            if node.op == "Placeholder":
                # Assuming the graph has only one input node, of type 'Placeholder'
                self._input_node = node.name
                self._num_input += 1
@@ -954,7 +995,6 @@ class GraphProto(object):
                    raise NotImplementedError(
                        "Please freeze the graph with add_shapes=True")
            elif node.op == "Const":
                # Assuming first Const node as Graph Input node
                if self._input_node == '':
                    self._input_node = node.name
                    self._num_input += 1
@@ -997,7 +1037,7 @@ class GraphProto(object):
                # Pass the node name too in attr
                attr["_node_name"] = node.name

                #ToDo: Some of the tensorflow operators maintain internaly maintain
                #ToDo: Some of the tensorflow operators internally maintain
                #execution layers and their output name will be the layer number along
                #with the graph node name, e.g. node name 'Model/RNN/cell_0/RnnCell', but
                #the output name will be 'Model/RNN/cell_0/RnnCell:0'. In this case,
@@ -404,6 +404,37 @@ def test_forward_sigmoid():
    _test_sigmoid(np.random.uniform(size=(3, 4, 4, 3)).astype('float32'))
#######################################################################
# Argmin/Argmax
# -------------
def _test_argx(func, data, **kwargs):
    with tf.Graph().as_default():
        inp = constant_op.constant(data, shape=data.shape, dtype=data.dtype, name="c0")

        # pylint: disable=unused-variable
        out = func(inp, name="argx0", **kwargs)
        # pylint: enable=unused-variable

        with tf.Session() as sess:
            graph_def = tf.graph_util.convert_variables_to_constants(
                sess=sess,
                input_graph_def=sess.graph.as_graph_def(add_shapes=True),
                output_node_names=["argx0"])

            tf_output = run_tf_graph(sess, data, input_node="c0:0", output_node="argx0:0")
            tvm_output = run_tvm_graph(graph_def, data, "c0", tf_output.shape, output_dtype='int32')

            np.testing.assert_allclose(tf_output, tvm_output, atol=1e-5, rtol=1e-5)
            sess.close()
def test_argmin_argmax():
    for axis in [None, 0, 1, 2]:
        data = np.random.uniform(size=(8, 4, 9)).astype('float32')
        _test_argx(tf.argmax, data=data, axis=axis)
        _test_argx(tf.argmin, data=data, axis=axis)
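Note that the test freezes the session graph with as_graph_def(add_shapes=True), which satisfies the frontend's `_output_shapes` requirement noted in the docstring above. For a real model the same preparation step applies; a minimal sketch (the output node name 'out' and the path 'model.pb' are hypothetical placeholders):

import tensorflow as tf

# Hedged sketch: freeze a graph for the NNVM frontend. add_shapes=True
# attaches the _output_shapes attribute that from_tensorflow relies on.
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    frozen = tf.graph_util.convert_variables_to_constants(
        sess, sess.graph.as_graph_def(add_shapes=True), ['out'])
    tf.train.write_graph(frozen, '.', 'model.pb', as_text=False)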
#######################################################################
# Variable