Commit d1e048b7 by hlu1 Committed by Tianqi Chen

Fix Softmax in onnx frontend (#1642)

parent ec0d497c
......@@ -597,6 +597,20 @@ class ArgMin(OnnxOpConverter):
attr = {'axis':axis, 'keepdims':keepdims}
return AttrCvt(op_name='argmin')(inputs, attr)
class Softmax(OnnxOpConverter):
""" Operator converter for Softmax.
"""
@classmethod
def _impl_v1(cls, inputs, attr, params):
# set default value when axis is not set in the model
if 'axis' not in attr:
attr['axis'] = 1
return AttrCvt(
op_name='softmax',
transforms={
'axis': ('axis', 1),
})(inputs, attr, params)
# compatible operators that do NOT require any conversion.
_identity_list = []
......@@ -664,7 +678,7 @@ def _get_convert_map(opset):
'Mean': Mean.get_converter(opset),
'Clip': AttrCvt('clip', transforms={'min': 'a_min', 'max': 'a_max'}),
# softmax default axis is different in onnx
'Softmax': AttrCvt('softmax', {'axis': ('axis', 1)}),
'Softmax': Softmax.get_converter(opset),
'LogSoftmax': AttrCvt('log_softmax', {'axis': ('axis', 1)}),
# 'Hardmax'
'Softsign': Softsign.get_converter(opset),
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# coding: utf-8
# pylint: disable=unused-argument
"""
Symbol of SqueezeNet
Reference:
Iandola, Forrest N., et al.
"Squeezenet: Alexnet-level accuracy with 50x fewer parameters and< 0.5 mb model size." (2016).
"""
from nnvm import symbol as sym
from nnvm.testing.utils import create_workload
# Helpers
def _make_fire(net, squeeze_channels, expand1x1_channels, expand3x3_channels):
net = _make_fire_conv(net, squeeze_channels, 1, 0)
left = _make_fire_conv(net, expand1x1_channels, 1, 0)
right = _make_fire_conv(net, expand3x3_channels, 3, 1)
# NOTE : Assume NCHW layout here
net = sym.concatenate(left, right, axis=1)
return net
def _make_fire_conv(net, channels, kernel_size, padding=0):
net = sym.conv2d(net, channels=channels, kernel_size=(kernel_size, kernel_size),
padding=(padding, padding))
net = sym.relu(net)
return net
# Net
def get_symbol(num_classes, version, **kwargs):
"""Get symbol of SqueezeNet
Parameters
----------
num_classes: int
The number of classification results
version : str, optional
"1.0" or "1.1" of SqueezeNet
"""
assert version == '1.1', ("Unsupported SqueezeNet version {version}:"
"1.1 expected".format(version=version))
net = sym.Variable("data")
net = sym.conv2d(net, channels=64, kernel_size=(3, 3), strides=(2, 2))
net = sym.relu(net)
net = sym.max_pool2d(net, pool_size=(3, 3), strides=(2, 2))
net = _make_fire(net, 16, 64, 64)
net = _make_fire(net, 16, 64, 64)
net = sym.max_pool2d(net, pool_size=(3, 3), strides=(2, 2))
net = _make_fire(net, 32, 128, 128)
net = _make_fire(net, 32, 128, 128)
net = sym.max_pool2d(net, pool_size=(3, 3), strides=(2, 2))
net = _make_fire(net, 48, 192, 192)
net = _make_fire(net, 48, 192, 192)
net = _make_fire(net, 64, 256, 256)
net = _make_fire(net, 64, 256, 256)
net = sym.dropout(net, rate=0.5)
net = sym.conv2d(net, channels=num_classes, kernel_size=(1, 1))
net = sym.relu(net)
net = sym.global_avg_pool2d(net)
return sym.softmax(net, axis=1)
def get_workload(batch_size=1, num_classes=1000, version='1.0',
image_shape=(3, 224, 224), dtype="float32", **kwargs):
"""Get benchmark workload for SqueezeNet
Parameters
----------
batch_size : int
The batch size used in the model
num_classes : int, optional
Number of classes
version : str, optional
"1.0" or "1.1" of SqueezeNet
image_shape : tuple, optional
The input image shape
dtype : str, optional
The data type
kwargs : dict
Extra arguments
Returns
-------
net : nnvm.Symbol
The computational graph
params : dict of str to NDArray
The parameters.
"""
net = get_symbol(num_classes=num_classes, version=version, **kwargs)
return create_workload(net, batch_size, image_shape, dtype)
......@@ -426,6 +426,33 @@ def test_upsample():
_test_upsample_nearest()
_test_upsample_bilinear()
def _test_softmax(inshape, axis):
opname = 'Softmax'
indata = np.random.uniform(size=inshape).astype(np.float32)
outshape = inshape
outdata = topi.testing.softmax_python(indata)
if isinstance(axis, int):
y = helper.make_node(opname, ['in'], ['out'], axis = axis)
elif axis is None:
y = helper.make_node(opname, ['in'], ['out'])
graph = helper.make_graph([y],
opname+'_test',
inputs = [helper.make_tensor_value_info("in",
TensorProto.FLOAT, list(indata.shape))],
outputs = [helper.make_tensor_value_info("out",
TensorProto.FLOAT, list(outdata.shape))])
model = helper.make_model(graph, producer_name=opname+'_test')
for target, ctx in ctx_list():
tvm_out = get_tvm_output(model, indata, target, ctx, outshape, 'float32')
np.testing.assert_allclose(outdata, tvm_out, rtol=1e-5, atol=1e-5)
def test_softmax():
_test_softmax((1, 10), None)
_test_softmax((1, 10), 1)
def verify_min(input_dim):
dtype = 'float32'
......@@ -676,3 +703,4 @@ if __name__ == '__main__':
test_forward_mean()
test_forward_hardsigmoid()
test_forward_arg_min_max()
test_softmax()
......@@ -3,6 +3,7 @@ import nnvm
import onnx
from nnvm.compiler import graph_util, graph_attr
from model_zoo import super_resolution, super_resolution_sym
from model_zoo import squeezenet as squeezenet
def compare_graph(onnx_file, nnvm_sym, ishape):
onnx_model = onnx.load(onnx_file)
......@@ -18,8 +19,16 @@ def compare_graph(onnx_file, nnvm_sym, ishape):
graph_util.check_graph_equal(g1, g2)
def test_super_resolution_example():
fname, symbol = super_resolution, super_resolution_sym
fname, symbol = "super_resolution.onnx", super_resolution_sym
compare_graph(fname, symbol, ishape=(1, 1, 224, 224))
def test_squeeze_net():
# Only works for model downloaded from
# https://github.com/onnx/models/tree/master/squeezenet
fname = "squeezenet1_1.onnx"
symbol, params = squeezenet.get_workload(version='1.1')
compare_graph(fname, symbol, ishape=(1, 3, 224, 224))
if __name__ == '__main__':
test_super_resolution_example()
test_squeeze_net()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment