Commit 0241fdc5 by Albin Joy Committed by Yizhi Liu

[FRONTEND][ONNX]LRN support for ONNX (#1518)

* LRN support for ONNX

* [ONNX] Updated lrn testcases
parent a8574e7b
...@@ -499,6 +499,23 @@ class Gather(OnnxOpConverter): ...@@ -499,6 +499,23 @@ class Gather(OnnxOpConverter):
params[name] = indices params[name] = indices
return _sym.take(inputs[0], gather_indices, axis=axis) return _sym.take(inputs[0], gather_indices, axis=axis)
class LRN(OnnxOpConverter):
""" Operator converter for Local Response Normalization.
"""
@classmethod
def _impl_v1(cls, inputs, attr, params):
"""LRN support only NCHW format
https://github.com/onnx/onnx/blob/master/docs/Operators.md#LRN
"""
axis = 1
alpha = attr.get('alpha', 0.0001)
beta = attr.get('beta', 0.75)
bias = attr.get('bias', 1.0)
nsize = attr.get('size')
return _sym.lrn(inputs[0], size=nsize, axis=axis,
alpha=alpha, beta=beta, bias=bias)
# compatible operators that do NOT require any conversion. # compatible operators that do NOT require any conversion.
_identity_list = [] _identity_list = []
...@@ -586,7 +603,7 @@ def _get_convert_map(opset): ...@@ -586,7 +603,7 @@ def _get_convert_map(opset):
# 'LpNormalization' # 'LpNormalization'
'Dropout': AttrCvt('dropout', {'ratio': 'rate'}, ignores=['is_test']), 'Dropout': AttrCvt('dropout', {'ratio': 'rate'}, ignores=['is_test']),
'Flatten': Renamer('flatten'), 'Flatten': Renamer('flatten'),
# 'LRN' 'LRN': LRN.get_converter(opset),
# defs/reduction # defs/reduction
'ReduceMax': AttrCvt('max', {'axes', 'axis'}), 'ReduceMax': AttrCvt('max', {'axes', 'axis'}),
......
import numpy as np import numpy as np
import math
import nnvm import nnvm
import tvm import tvm
from tvm.contrib import graph_runtime from tvm.contrib import graph_runtime
...@@ -312,6 +313,58 @@ def test_matmul(): ...@@ -312,6 +313,58 @@ def test_matmul():
np.testing.assert_allclose(np.matmul(a_array, b_array), tvm_out.asnumpy(), rtol=1e-5, atol=1e-5) np.testing.assert_allclose(np.matmul(a_array, b_array), tvm_out.asnumpy(), rtol=1e-5, atol=1e-5)
def verify_lrn(shape, nsize, dtype, alpha=None, beta=None, bias=None):
in_array = np.random.uniform(size=shape).astype(dtype)
if alpha == None and beta == None and bias==None:
alpha = 0.0001
beta = 0.75
bias = 1.0
node = onnx.helper.make_node('LRN', inputs=['in'], outputs=['out'], size=nsize)
else:
node = onnx.helper.make_node('LRN', inputs=['in'], outputs=['out'], alpha=alpha,
beta=beta, bias=bias, size=nsize)
graph = helper.make_graph([node],
"lrn_test",
inputs = [helper.make_tensor_value_info("in", TensorProto.FLOAT, list(shape))],
outputs = [helper.make_tensor_value_info("out", TensorProto.FLOAT, list(shape))])
model = helper.make_model(graph, producer_name='lrn_test')
def _get_python_lrn():
square_sum = np.zeros(shape).astype(dtype)
for n, c, h, w in np.ndindex(in_array.shape):
square_sum[n, c, h, w] = sum(in_array[n,
max(0, c - int(math.floor((nsize - 1) / 2))): \
min(5, c + int(math.ceil((nsize - 1) / 2)) + 1),
h,
w] ** 2)
py_out = in_array / ((bias + (alpha / nsize) * square_sum) ** beta)
return py_out
for target, ctx in ctx_list():
new_sym, params = nnvm.frontend.from_onnx(model)
input_name = model.graph.input[0].name
shape_dict = {input_name: in_array.shape}
dtype_dict = {input_name: dtype}
graph, lib, params = nnvm.compiler.build(new_sym, target,
shape_dict, dtype_dict, params=params)
m = graph_runtime.create(graph, lib, ctx)
# set inputs
m.set_input(input_name, tvm.nd.array(in_array.astype(dtype)))
m.set_input(**params)
m.run()
# get outputs
tvm_out = m.get_output(0, tvm.nd.empty(shape, dtype))
py_out = _get_python_lrn()
np.testing.assert_allclose(py_out, tvm_out.asnumpy(), rtol=1e-5, atol=1e-5)
def test_lrn():
verify_lrn((5, 5, 5, 5), 3, 'float32')
verify_lrn((5, 5, 5, 5), 3, 'float32', alpha=0.0002, beta=0.5, bias=2.0)
if __name__ == '__main__': if __name__ == '__main__':
# verify_super_resolution_example() # verify_super_resolution_example()
# verify_squeezenet1_1() # verify_squeezenet1_1()
...@@ -328,3 +381,4 @@ if __name__ == '__main__': ...@@ -328,3 +381,4 @@ if __name__ == '__main__':
test_clip() test_clip()
test_matmul() test_matmul()
test_gather() test_gather()
test_lrn()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment