Unverified Commit 0145cd50 by masahi Committed by GitHub

[Torch] Support Python list, more realistic recurrent networks (#5306)

* use funcs from prelude, pass around convert_map

* get relay input type from user ishape

* handle tuple unpack

* experimenting with static tensor array

* use prelude concat instead of cons + rev

* minor clean up

* fix layer norm conversion bug, unwrap tensor array

* add infer shape on tensor array

* pass around prelude for now

* compile worked but runtime error

* fix tensor array wrapping

* begin list dynamic test

* is_list_dynamic first version

* finish dynamic list test

* a few fix

* use shape_of function if Any is found

* improve size conversion

* working on adding free vars to loop block

* fixed inlined inner loop issue

* clean up free var handling

* add support for tensor array concat

* adding ta concat on last axis

* fix concat, but got runtime error

* disable concat on axis -1 for now

* add lstm tests

* revert unrelated change

* fix stacked bidir test

* minor fix to test

* relax tol a bit, revert dnnl change to avoid conflict

* simplify infer type, use input tensor shape rather than concat shape

* more shape fix
parent cd0d52da
......@@ -543,7 +543,7 @@ def test_forward_maxpool1d():
input_data)
verify_model(torch.nn.MaxPool1d(kernel_size=10).eval(),
input_data)
verify_model( torch.nn.MaxPool1d(kernel_size=4,
verify_model(torch.nn.MaxPool1d(kernel_size=4,
padding=2,
stride=2).eval(),
input_data)
......@@ -1363,3 +1363,8 @@ if __name__ == "__main__":
# Test simple conditionals and loop
test_control_flow()
test_simple_rnn()
# More complex recurrent models
from lstm_test import custom_lstm_test
custom_lstm_test()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment