Commit 56397826 by Zhi, committed by Tianqi Chen

hotfix for onnx (#3387)

parent 1119c40b
@@ -23,6 +23,7 @@ import numpy as np
 import tvm
 from ... import nd as _nd
 from .. import ir_pass
+from .. import transform as _transform
 from .. import expr as _expr
 from .. import module as _module
 from .. import op as _op
@@ -409,21 +410,27 @@ class Reshape(OnnxOpConverter):
             shape = tuple(params[inputs[1].name_hint].asnumpy())
             out = _op.reshape(inputs[0], shape)
         else:
-            # Try to infer shape by precompute prune if possible.
-            # TODO: good to check inputs to be in params.
-            # to be enhanced when relay support list_input_names API of NNVM
-            logging.warning("Infering Reshape argument by precompute")
-            func = _expr.Function(ir_pass.free_vars(inputs[1]), inputs[1])
+            data, shape = inputs
+            logging.warning("Constant evaluating Reshape's shape argument, may reduce performance")
+            shape_params = ir_pass.free_vars(shape)
+            func = _expr.Function(shape_params, shape)
+            mod = _module.Module.from_expr(func)
+            seq = _transform.Sequential([_transform.InferType(),
+                                         _transform.FoldConstant(),
+                                         _transform.FuseOps(0),
+                                         _transform.InferType()])
+            with tvm.relay.PassContext(opt_level=2):
+                mod = seq(mod)
             with tvm.relay.build_config(opt_level=0):
-                graph, lib, params = tvm.relay.build(func, target="llvm", params=params)
-            ctx = tvm.context("llvm", 0)
-            from tvm.contrib import graph_runtime
-            m = graph_runtime.create(graph, lib, ctx)
-            m.set_input(**params)
-            m.run()
-            params_new = m.get_output(0)
-            inputs.pop(1)
-            out = _op.reshape(inputs[0], tuple(params_new.asnumpy().astype('int32').flatten()))
+                ex = tvm.relay.create_executor("debug", mod=mod)
+            inputs = []
+            for sp in shape_params:
+                if not sp.name_hint in params:
+                    sh = [int(i) for i in sp.type_annotation.shape]
+                    inputs.append(
+                        tvm.nd.array(np.random.rand(*sh).astype('float32')))
+            static_shape = ex.evaluate()(*inputs, **params)
+            out = _op.reshape(data, newshape=tuple(static_shape.asnumpy()))
         return out
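The new Reshape path above wraps the shape sub-expression in a small module, folds it to a constant, and evaluates it with the debug executor instead of compiling and running it through graph_runtime. Below is a minimal, self-contained sketch of that fold-then-evaluate pattern using the public `tvm.relay` names that the frontend's `_expr`/`_module`/`_transform` aliases point at; module and pass names are assumed to match the Relay API of this commit's era and may differ in later releases.

```python
# Sketch only: constant-evaluate a "shape" expression with Relay's debug
# executor, mirroring the Reshape change above (era-specific API assumed).
import numpy as np
import tvm
from tvm import relay

# A shape expression built purely from constants, standing in for the
# subgraph that feeds Reshape's second input in a real ONNX model.
shape_expr = relay.add(relay.const(np.array([1, 2, 3], dtype='int64')),
                       relay.const(np.array([1, 1, 1], dtype='int64')))
func = relay.Function(relay.ir_pass.free_vars(shape_expr), shape_expr)
mod = relay.module.Module.from_expr(func)

# Fold the expression down to a constant before evaluating it; no
# graph_runtime build step is needed anymore.
seq = relay.transform.Sequential([relay.transform.InferType(),
                                  relay.transform.FoldConstant(),
                                  relay.transform.InferType()])
with relay.PassContext(opt_level=2):
    mod = seq(mod)

ex = relay.create_executor("debug", mod=mod)
static_shape = ex.evaluate()()
print(tuple(static_shape.asnumpy()))  # (2, 3, 4)
```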
@@ -568,6 +575,7 @@ class Shape(OnnxOpConverter):
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
+        # TODO(@jroesch): use shape_of once it has been fixed)
         return _op.shape_of(inputs[0])


 class Cast(OnnxOpConverter):
@@ -1058,8 +1066,15 @@ class GraphProto(object):
             if op_name == "Constant":
                 t_proto = self._parse_attr(node.attribute)["value"]
                 self._num_param += 1
-                self._params[node.output[0]] = self._parse_array(t_proto)
-                self._nodes[node.output[0]] = new_var(node.output[0], shape=list(t_proto.dims))
+                # We should convert scalar integers to int32, to normalize.
+                array = self._parse_array(t_proto)
+                if len(array.shape) == 0 and array.dtype == 'int64':
+                    array = _nd.array(array.asnumpy().astype('int32'))
+                self._params[node.output[0]] = array
+                self._nodes[node.output[0]] = new_var(
+                    node.output[0],
+                    shape=list(t_proto.dims),
+                    dtype=array.dtype)
             else:
                 if op_name == "ConstantFill":
                     fill_value = attr.get('value', 0.0)
@@ -14,8 +14,11 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
+import attr
 import numpy as np
 import math
+import torch
+import torchvision
 import topi
 import topi.testing
 import tvm
@@ -1072,6 +1075,48 @@ def test_LogSoftmax():
                      'LogSoftmax',
                      {'axis': 1})

+def check_torch_conversion(model, input_size):
+    dummy_input = torch.randn(*input_size)
+    file_name = '{}.onnx'.format(model.__name__)
+    # Set verbose=True for more output
+    torch.onnx.export(model(), dummy_input, file_name, export_params=True, verbose=False)
+    onnx_model = onnx.load(file_name)
+    shapes = { '0' : input_size }
+    expr, params = relay.frontend.from_onnx(onnx_model, shape=shapes)
+
+def test_resnet():
+    check_torch_conversion(torchvision.models.resnet18, (1,3,224,224))
+    # check_torch_conversion(torchvision.models.resnet101, (1,3,224,224))
+
+# def test_alexnet():
+    # Torch's ONNX export does not support the adaptive pooling used by AlexNet?
+    # check_torch_conversion(torchvision.models.alexnet, (1,3,224,224))
+
+# Torch's ONNX export does not support the adaptive pooling used by vgg16?
+# def test_vgg16():
+#     check_torch_conversion(torchvision.models.vgg16, (1,3,224,224))
+
+# TODO(@jroesch): Update Torch + ONNX to support this import.
+# def test_squeezenet():
+#     # Torch's ONNX export does not support the max pooling used by Squezenet
+#     check_torch_conversion(torchvision.models.squeezenet1_0, (1,3,224,224))
+
+def test_densenet():
+    check_torch_conversion(torchvision.models.densenet161, (1,3,224,224))
+
+def test_inception():
+    check_torch_conversion(torchvision.models.inception_v3, (1,3,224,224))
+
+# TODO(@jroesch): Update Torch + ONNX to support this import.
+# def test_googlenet():
+#     check_torch_conversion(torchvision.models.googlenet, (1,3,224,224))
+
+# TODO(@jroesch): Update Torch + ONNX to support this import.
+# def test_shufflenetv2():
+#     check_torch_conversion(torchvision.models.shufflenetv2, (1,3,224,224))
+
 if __name__ == '__main__':
     test_flatten()
     test_reshape()
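The new check_torch_conversion helper round-trips a torchvision model through Torch's ONNX exporter and back into Relay. A standalone version of the same flow, outside the test harness, is sketched below; the file name and input shape are arbitrary, and torch, torchvision, and onnx are assumed to be installed.

```python
# Sketch: export a torchvision model to ONNX, then import it with the
# Relay ONNX frontend, as the tests above do.
import onnx
import torch
import torchvision
from tvm import relay

model = torchvision.models.resnet18()
dummy_input = torch.randn(1, 3, 224, 224)
torch.onnx.export(model, dummy_input, 'resnet18.onnx', export_params=True, verbose=False)

onnx_model = onnx.load('resnet18.onnx')
# Torch names the first graph input '0', so the shape dict is keyed on that name.
expr, params = relay.frontend.from_onnx(onnx_model, shape={'0': (1, 3, 224, 224)})
```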
@@ -1111,3 +1156,6 @@ if __name__ == '__main__':
     test_ParametricSoftplus()
     test_Scale()
     test_LogSoftmax()
+    test_resnet()
+    test_inception()
+    test_densenet()