Commit c9f9a3f9 by Siju Committed by Tianqi Chen

l2normalization operator support for tensorflow (#1528)

parent 7ea06e6e
......@@ -434,6 +434,21 @@ def _lrn():
return AttrCvt(op_name='lrn')(new_inputs, attr_new)
return _impl
def _sum():
def _impl(inputs, attr, params):
axis = params.pop(inputs[1].list_output_names()[0]).asnumpy()
return AttrCvt(
op_name='sum',
extras={'axis': axis},
transforms={'keep_dims':'keepdims'},
ignores=['name', 'Tidx'])(inputs[0], attr)
return _impl
def _square():
def _impl(inputs, attr, params):
return _sym.elemwise_mul(inputs[0], inputs[0])
return _impl
def _gather_v2():
"Tensorflow now support only gatherv2"
def _impl(inputs, attr, params):
......@@ -651,13 +666,17 @@ _convert_map = {
'Identity' : _identity(),
'MatMul' : _matmul(),
'MaxPool' : _pooling('max_pool'),
'Add' : _elemwise('add'),
'Sub' : _elemwise('sub'),
'Mul' : _elemwise('mul'),
'Maximum' : _elemwise('max'),
'Minimum' : _elemwise('min'),
'Sum' : _sum(),
'Square' : _square(),
'Relu' : AttrCvt('relu'),
'Reshape' : _reshape(),
'ResizeBilinear' : _resize_bilinear(),
'Softmax' : AttrCvt('softmax', {'axis': ('axis', 1)}),
'Sub' : _elemwise('sub'),
'Add' : _elemwise('add'),
'Rsqrt' : _rsqrt(),
'Squeeze' : _squeeze(),
'FusedBatchNorm' : _fused_batch_norm(),
......
......@@ -12,6 +12,7 @@ import tensorflow as tf
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import graph_util
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import math_ops
......@@ -948,7 +949,6 @@ def _test_lrn(ishape, size, axis, bias, alpha, beta):
sess,
sess.graph.as_graph_def(add_shapes=True),
['lrn'],)
tf_output = run_tf_graph(sess, inp_array, 'lrn0_data:0', 'lrn:0')
tvm_output = run_tvm_graph(graph_def,
inp_array,
......@@ -959,6 +959,42 @@ def _test_lrn(ishape, size, axis, bias, alpha, beta):
def test_forward_lrn():
_test_lrn((1, 3, 20, 20), 3, 1, 1.0, 1.0, 0.5)
#######################################################################
# l2_normalize
# ------------
def _test_l2_normalize(ishape, eps, axis):
""" testing l2 normalize (uses max, sum, square, sqrt frontend operators)"""
inp_array = np.random.uniform(size=ishape).astype(np.float32)
inp_array.fill(1)
with tf.Graph().as_default():
in1 = tf.placeholder(shape=inp_array.shape, dtype=inp_array.dtype, name="Placeholder")
nn.l2_normalize(in1,
axis=axis,
epsilon=eps,
name=None,
dim=None)
with tf.Session() as sess:
graph_def = tf.graph_util.convert_variables_to_constants(
sess,
sess.graph.as_graph_def(add_shapes=True),
['l2_normalize'],
)
tf_output = run_tf_graph(sess, inp_array, 'Placeholder:0', 'Placeholder:0')
tvm_output = run_tvm_graph(graph_def,
inp_array,
"Placeholder",
tf_output.shape,
tf_output.dtype)
np.testing.assert_allclose(tf_output, tvm_output, atol=1e-3, rtol=1e-3)
sess.close()
def test_forward_l2_normalize():
_test_l2_normalize((1, 3, 20, 20), 0.001, (0,))
#######################################################################
# Main
# ----
if __name__ == '__main__':
......@@ -981,3 +1017,4 @@ if __name__ == '__main__':
test_forward_gather()
test_forward_ptb()
test_forward_lrn()
test_forward_l2_normalize()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment