Commit d059dbab by Joshua Z. Zhang Committed by Tianqi Chen

[WIP] Onnx1.0 (#294)

* add more op for onnx 1.0

* fix syntax

* fix lint

* fix

* update 1.0

* fix

* update model
parent 063cd412
......@@ -150,6 +150,90 @@ def _gemm():
return _sym.dense(alpha * inputs[0], inputs[1], beta * inputs[2], units=channels)
return _impl
def _thresholded_relu():
def _impl(inputs, attr, params):
alpha = float(attr.get('alpha', 0.0))
return _sym.relu(inputs[0] - alpha)
return _impl
def _scaled_tanh():
def _impl(inputs, attr, params):
alpha = float(attr.get('alpha', 1.0))
beta = float(attr.get('beta', 1.0))
return _sym.tanh(beta * inputs[0]) * alpha
return _impl
def parametric_soft_plus():
def _impl(inputs, attr, params):
alpha = float(attr.get('alpha', 1.0))
beta = float(attr.get('beta', 1.0))
return _sym.log(_sym.exp(beta * inputs[0]) + 1) * alpha
return _impl
def _scale():
def _impl(inputs, attr, params):
scale = float(attr.get('scale', 1.0))
return inputs[0] * scale
return _impl
def _absolute():
"""This is a workaround."""
def _impl(inputs, attr, params):
return _sym.relu(inputs[0]) + _sym.relu(_sym.negative(inputs[0]))
return _impl
def _reciprocal():
def _impl(inputs, attr, params):
return 1.0 / inputs[0]
return _impl
def _selu():
def _impl(inputs, attr, params):
alpha = float(attr.get('alpha', 1.6732))
gamma = float(attr.get('gamma', 1.0507))
return gamma * (-alpha * _sym.relu(1 - _sym.exp(inputs[0]))
+ _sym.relu(inputs[0]))
return _impl
def _elu():
def _impl(inputs, attr, params):
alpha = float(attr.get('alpha', 1.0))
return -alpha * _sym.relu(1 - _sym.exp(inputs[0])) + _sym.relu(inputs[0])
return _impl
def _prelu():
def _impl(inputs, attr, params):
assert len(inputs) == 2, "Prelu need 2 inputs, {} given".format(len(inputs))
channels = _infer_channels(inputs[1], params, False)
if channels == 1:
return inputs[0] * inputs[1]
return _sym.broadcast_mul(inputs[0], inputs[1])
return _impl
def _softsign():
def _impl(inputs, attr, params):
return inputs[0] / (1 + _absolute()(inputs, attr, params))
return _impl
def _softplus():
def _impl(inputs, attr, params):
return _sym.log(_sym.exp(x) + 1)
return _impl
def _pad():
def _impl(inputs, attr, params):
# get number of channels
channels = _infer_channels(inputs[1], params, True)
attr['channels'] = channels
groups = attr.pop('group')
attr['groups'] = groups
return AttrCvt(
op_name='pad',
transforms={
'value': 'pad_value',
'pads': 'pad_width'},
custom_check=lambda attrs: attrs.get('mode') == 'constant')(inputs, attr)
return _impl
# compatible operators that do NOT require any conversion.
_identity_list = []
......@@ -161,7 +245,22 @@ _identity_list = []
# for N to 1 mapping, currently not supported(?)
_convert_map = {
# defs/experimental
'Identity' : Renamer('copy'),
# 'Affine'
'ThresholdedRelu': _thresholded_relu(),
'ScaledTanh' : _scaled_tanh(),
'ParametricSoftplus': parametric_soft_plus(),
# 'ConstantFill'
# 'GivenTensorFill'
'FC' : AttrCvt('dense', ignores=['axis', 'axis_w']),
'Scale' : _scale(),
# 'GRUUnit'
# 'ATen'
# 'ImageScaler'
# 'MeanVarianceNormalization'
# 'Crop'
# 'Embedding'
# 'Upsample'
'SpatialBN' : _batch_norm(),
# defs/generator
......@@ -179,28 +278,35 @@ _convert_map = {
'Mul' : _elemwise('mul'),
'Div' : _elemwise('div'),
'Neg' : Renamer('negative'),
# 'Abs'
# 'Reciprocal'
'Abs' : _absolute(),
'Reciprocal' : _reciprocal(),
# 'Floor'
# 'Ceil'
# 'Sqrt'
'Sqrt' : Renamer('sqrt'),
'Relu' : Renamer('relu'),
'LeakyRelu' : Renamer('leaky_relu'),
# 'Selu'
# 'Elu'
'Selu' : _selu(),
'Elu' : _elu(),
'Exp' : Renamer('exp'),
'Log' : Renamer('log'),
'Tanh' : Renamer('tanh'),
# 'Pow'
# 'Dot'
# 'PRelu'
'PRelu' : _prelu(),
'Sigmoid' : Renamer('sigmoid'),
# 'HardSigmoid'
# 'Max' : this is the elemwise maximum
# 'Min' : this is the elemwise minimum
# 'Sum' : elemwise sum
# 'Mean'
# 'Clip'
# softmax default axis is different in onnx
'Softmax' : AttrCvt('softmax', {'axis': ('axis', 1)}),
'LogSoftmax' : AttrCvt('log_softmax', {'axis': ('axis', 1)}),
# 'Hardmax'
'Softsign' : _softsign(),
'SoftPlus' : _softplus(),
'Gemm' : _gemm(),
# 'MatMul' batch stacked dot operation
# defs/nn
'AveragePool' : _pooling('avg_pool'),
......@@ -210,6 +316,8 @@ _convert_map = {
'GlobalAveragePool': Renamer('global_avg_pool2d'),
'GlobalMaxPool' : Renamer('global_max_pool2d'),
'BatchNormalization': _batch_norm(),
# 'InstanceNormalization'
# 'LpNormalization'
'Dropout' : AttrCvt('dropout', {'ratio': 'rate'}, ignores=['is_test']),
'Flatten' : Renamer('flatten'),
# 'LRN'
......@@ -233,6 +341,7 @@ _convert_map = {
'Transpose' : AttrCvt('transpose', {'perm': 'axes'}),
# 'Gather'
# 'Squeeze'
'Pad' : _pad(),
}
......
......@@ -24,7 +24,8 @@ def _as_abs_path(fname):
URLS = {
'super_resolution.onnx': 'https://gist.github.com/zhreshold/bcda4716699ac97ea44f791c24310193/raw/93672b029103648953c4e5ad3ac3aadf346a4cdc/super_resolution_0.2.onnx',
'squeezenet1_1.onnx': 'https://gist.github.com/zhreshold/bcda4716699ac97ea44f791c24310193/raw/93672b029103648953c4e5ad3ac3aadf346a4cdc/squeezenet1_1_0.2.onnx',
'lenet.onnx': 'https://gist.github.com/zhreshold/bcda4716699ac97ea44f791c24310193/raw/93672b029103648953c4e5ad3ac3aadf346a4cdc/lenet_0.2.onnx'}
'lenet.onnx': 'https://gist.github.com/zhreshold/bcda4716699ac97ea44f791c24310193/raw/93672b029103648953c4e5ad3ac3aadf346a4cdc/lenet_0.2.onnx',
'resnet18_1_0.onnx': 'https://gist.github.com/zhreshold/bcda4716699ac97ea44f791c24310193/raw/b385b1b242dc89a35dd808235b885ed8a19aedc1/resnet18_1.0.onnx'}
# download and add paths
for k, v in URLS.items():
......
......@@ -4,7 +4,7 @@ import tvm
from tvm.contrib import graph_runtime
from nnvm.testing.config import ctx_list
import onnx
from model_zoo import super_resolution, squeezenet1_1, lenet
from model_zoo import super_resolution, squeezenet1_1, lenet, resnet18_1_0
def verify_onnx_forward_impl(graph_file, data_shape, out_shape):
import onnx_caffe2.backend
......@@ -44,7 +44,11 @@ def verify_squeezenet1_1():
def verify_lenet():
verify_onnx_forward_impl(lenet, (1, 1, 28, 28), (1, 10))
def verify_resnet18():
verify_onnx_forward_impl(resnet18_1_0, (1, 3, 224, 224), (1, 1000))
if __name__ == '__main__':
verify_super_resolution_example()
verify_squeezenet1_1()
verify_lenet()
# verify_super_resolution_example()
# verify_squeezenet1_1()
# verify_lenet()
verify_resnet18()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment