Commit a2870fef by Siju Committed by Tianqi Chen

[NNVM][TENSORFLOW]Local Response Normalization added for tensorflow (#1522)

parent 32076df8
......@@ -468,6 +468,20 @@ def _fill():
ignores=['index_type', 'T'])(new_inputs, attr)
return _impl
def _lrn():
def _impl(inputs, attr, params):
new_inputs = []
attr_new = {}
depth_radius = attr.get('depth_radius', 5)
size = (depth_radius * 2) + 1
attr_new['axis'] = 3 # Fix axis, NHWC format
attr_new['size'] = size
attr_new['bias'] = attr.get('bias', 1)
attr_new['alpha'] = attr.get('alpha', 1) * size
attr_new['beta'] = attr.get('beta', 0.5)
return AttrCvt(op_name='lrn')(new_inputs, attr_new)
return _impl
def _gather_v2():
"Tensorflow now support only gatherv2"
def _impl(inputs, attr, params):
......@@ -680,6 +694,7 @@ _convert_map = {
'Fill' : _fill(),
'GatherV2' : _gather_v2(),
'StridedSlice' : _stridedSlice(),
'LRN' : _lrn(),
}
# _convert_map_rnn defines maps of rnn operator name to
......
......@@ -855,6 +855,40 @@ def test_forward_ptb():
assert(tvm_sample_str == tf_sample_str)
#######################################################################
# LRN (Local Response Normalization)
# ----------------------------------
def _test_lrn(ishape, size, axis, bias, alpha, beta):
""" testing local response normalization """
lrn_depth_radius = size / 2
inp_array = np.random.uniform(size=ishape).astype(np.float32)
with tf.Graph().as_default():
in1 = tf.placeholder(shape=inp_array.shape, dtype=inp_array.dtype, name="lrn0_data")
nn_ops.local_response_normalization(in1,
name="lrn",
depth_radius=lrn_depth_radius,
bias=bias,
alpha=alpha,
beta=beta)
with tf.Session() as sess:
graph_def = tf.graph_util.convert_variables_to_constants(
sess,
sess.graph.as_graph_def(add_shapes=True),
['lrn'],)
tf_output = run_tf_graph(sess, inp_array, 'lrn0_data:0', 'lrn:0')
tvm_output = run_tvm_graph(graph_def,
inp_array,
"lrn0_data", tf_output.shape, tf_output.dtype)
np.testing.assert_allclose(tf_output, tvm_output, atol=1e-3, rtol=1e-3)
sess.close()
def test_forward_lrn():
_test_lrn((1, 3, 20, 20), 3, 1, 1.0, 1.0, 0.5)
# Main
# ----
if __name__ == '__main__':
......@@ -875,3 +909,4 @@ if __name__ == '__main__':
test_forward_stridedslice()
test_forward_gather()
test_forward_ptb()
test_forward_lrn()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment