Unverified Commit 0145cd50 by masahi Committed by GitHub

[Torch] Support Python list, more realistic recurrent networks (#5306)

* use funcs from prelude, pass around convert_map

* get relay input type from user ishape

* handle tuple unpack

* experimenting with static tensor array

* use prelude concat instead of cons + rev

* minor clean up

* fix layer norm conversion bug, unwrap tensor array

* add infer shape on tensor array

* pass around prelude for now

* compile worked but runtime error

* fix tensor array wrapping

* begin list dynamic test

* is_list_dynamic first version

* finish dynamic list test

* a few fix

* use shape_of function if Any is found

* improve size conversion

* working on adding free vars to loop block

* fixed inlined inner loop issue

* clean up free var handling

* add support for tensor array concat

* adding ta concat on last axis

* fix concat, but got runtime error

* disable concat on axis -1 for now

* add lstm tests

* revert unrelated change

* fix stacked bidir test

* minor fix to test

* relax tol a bit, revert dnnl change to avoid conflict

* simplify infer type, use input tensor shape rather than concat shape

* more shape fix
parent cd0d52da
...@@ -526,13 +526,13 @@ def test_forward_maxpool2d(): ...@@ -526,13 +526,13 @@ def test_forward_maxpool2d():
input_data = torch.rand(input_shape).float() input_data = torch.rand(input_shape).float()
verify_model(torch.nn.MaxPool2d(kernel_size=[1, 1]).eval(), verify_model(torch.nn.MaxPool2d(kernel_size=[1, 1]).eval(),
input_data) input_data)
verify_model(torch.nn.MaxPool2d(kernel_size=[10, 10]).eval(), verify_model(torch.nn.MaxPool2d(kernel_size=[10, 10]).eval(),
input_data) input_data)
verify_model(torch.nn.MaxPool2d(kernel_size=[4, 4], verify_model(torch.nn.MaxPool2d(kernel_size=[4, 4],
padding=2, padding=2,
stride=2).eval(), stride=2).eval(),
input_data) input_data)
def test_forward_maxpool1d(): def test_forward_maxpool1d():
torch.set_grad_enabled(False) torch.set_grad_enabled(False)
...@@ -540,13 +540,13 @@ def test_forward_maxpool1d(): ...@@ -540,13 +540,13 @@ def test_forward_maxpool1d():
input_data = torch.rand(input_shape).float() input_data = torch.rand(input_shape).float()
verify_model(torch.nn.MaxPool1d(kernel_size=1).eval(), verify_model(torch.nn.MaxPool1d(kernel_size=1).eval(),
input_data) input_data)
verify_model(torch.nn.MaxPool1d(kernel_size=10).eval(), verify_model(torch.nn.MaxPool1d(kernel_size=10).eval(),
input_data) input_data)
verify_model( torch.nn.MaxPool1d(kernel_size=4, verify_model(torch.nn.MaxPool1d(kernel_size=4,
padding=2, padding=2,
stride=2).eval(), stride=2).eval(),
input_data) input_data)
def test_forward_maxpool3d(): def test_forward_maxpool3d():
torch.set_grad_enabled(False) torch.set_grad_enabled(False)
...@@ -554,13 +554,13 @@ def test_forward_maxpool3d(): ...@@ -554,13 +554,13 @@ def test_forward_maxpool3d():
input_data = torch.rand(input_shape).float() input_data = torch.rand(input_shape).float()
verify_model(torch.nn.MaxPool3d(kernel_size=[1, 1, 1]).eval(), verify_model(torch.nn.MaxPool3d(kernel_size=[1, 1, 1]).eval(),
input_data) input_data)
verify_model(torch.nn.MaxPool3d(kernel_size=[10, 10, 10]).eval(), verify_model(torch.nn.MaxPool3d(kernel_size=[10, 10, 10]).eval(),
input_data) input_data)
verify_model(torch.nn.MaxPool3d(kernel_size=[4, 4, 4], verify_model(torch.nn.MaxPool3d(kernel_size=[4, 4, 4],
padding=2, padding=2,
stride=2).eval(), stride=2).eval(),
input_data) input_data)
def test_forward_split(): def test_forward_split():
torch.set_grad_enabled(False) torch.set_grad_enabled(False)
...@@ -577,13 +577,13 @@ def test_forward_split(): ...@@ -577,13 +577,13 @@ def test_forward_split():
input_data = torch.rand(input_shape).float() input_data = torch.rand(input_shape).float()
verify_model(Split(2, 0).float().eval(), verify_model(Split(2, 0).float().eval(),
input_data=input_data) input_data=input_data)
verify_model(Split(3, 1).float().eval(), verify_model(Split(3, 1).float().eval(),
input_data=input_data) input_data=input_data)
verify_model(Split(4, 1).float().eval(), verify_model(Split(4, 1).float().eval(),
input_data=input_data) input_data=input_data)
verify_model(Split([2, 3, 5], 1).float().eval(), verify_model(Split([2, 3, 5], 1).float().eval(),
input_data=input_data) input_data=input_data)
def test_forward_avgpool(): def test_forward_avgpool():
torch.set_grad_enabled(False) torch.set_grad_enabled(False)
...@@ -1363,3 +1363,8 @@ if __name__ == "__main__": ...@@ -1363,3 +1363,8 @@ if __name__ == "__main__":
# Test simple conditionals and loop # Test simple conditionals and loop
test_control_flow() test_control_flow()
test_simple_rnn() test_simple_rnn()
# More complex recurrent models
from lstm_test import custom_lstm_test
custom_lstm_test()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment