Commit 9ec0e5ce by masahi Committed by Zhi

remove unnecessary cast to int32 (#4573)

parent dfc4009c
......@@ -1414,8 +1414,6 @@ class GraphProto(object):
self._num_param += 1
# We should convert scalar integers to int32, to normalize.
array = self._parse_array(t_proto)
if len(array.shape) == 0 and array.dtype == 'int64':
array = _nd.array(array.asnumpy().astype('int32'))
self._params[node.output[0]] = array
self._nodes[node.output[0]] = new_var(
node.output[0],
......
......@@ -1826,6 +1826,24 @@ def test_convtranspose():
verify_convtranspose((1, 1, 3, 3), (1, 2, 3, 3), (1, 2, 7, 3), [1, 2, 1, 2])
def test_unsqueeze_constant():
from torch.nn import Linear, Sequential, Module
class Flatten(Module):
def forward(self, input):
return input.view(input.size(0), -1)
import tempfile
with tempfile.NamedTemporaryFile() as fp:
file_name = fp.name
input_size = (1, 16, 32, 32)
dummy_input = torch.randn(*input_size)
layer = Sequential(Flatten(), Linear(16 * 32 * 32, 64))
torch.onnx.export(layer, dummy_input, file_name, export_params=True)
onnx_model = onnx.load(file_name)
relay.frontend.from_onnx(onnx_model, {'0': input_size})
if __name__ == '__main__':
test_flatten()
test_reshape()
......@@ -1882,3 +1900,4 @@ if __name__ == '__main__':
test_space_to_depth()
test_conv()
test_convtranspose()
test_unsqueeze_constant()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment