Commit 58fa0531 by Steven S. Lyubomirsky Committed by Tianqi Chen

Reverse shape dims of weight type (#2155)

parent f3ae3f20
......@@ -49,7 +49,7 @@ def lstm_cell(num_hidden, batch_size=1, dtype="float32", name=""):
builder = relay.ScopeBuilder()
input_type = relay.TensorType((batch_size, num_hidden), dtype)
weight_type = relay.TensorType((num_hidden, 4*num_hidden), dtype)
weight_type = relay.TensorType((4*num_hidden, num_hidden), dtype)
bias_type = relay.TensorType((4*num_hidden,), dtype)
dense_type = relay.TensorType((batch_size, 4*num_hidden), dtype)
......@@ -116,7 +116,7 @@ def lstm_cell(num_hidden, batch_size=1, dtype="float32", name=""):
def get_net(iterations, num_hidden, batch_size=1, dtype="float32"):
'''Constructs an unrolled RNN with LSTM cells'''
input_type = relay.TensorType((batch_size, num_hidden), dtype)
weight_type = relay.TensorType((num_hidden, 4*num_hidden), dtype)
weight_type = relay.TensorType((4*num_hidden, num_hidden), dtype)
bias_type = relay.TensorType((4*num_hidden,), dtype)
state_type = relay.TupleType([input_type, input_type])
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment