Skip to content
Snippets Groups Projects
Commit 58fa0531 authored by Steven S. Lyubomirsky's avatar Steven S. Lyubomirsky Committed by Tianqi Chen
Browse files

Reverse shape dims of weight type (#2155)

parent f3ae3f20
No related branches found
No related tags found
No related merge requests found
......@@ -49,7 +49,7 @@ def lstm_cell(num_hidden, batch_size=1, dtype="float32", name=""):
builder = relay.ScopeBuilder()
input_type = relay.TensorType((batch_size, num_hidden), dtype)
weight_type = relay.TensorType((num_hidden, 4*num_hidden), dtype)
weight_type = relay.TensorType((4*num_hidden, num_hidden), dtype)
bias_type = relay.TensorType((4*num_hidden,), dtype)
dense_type = relay.TensorType((batch_size, 4*num_hidden), dtype)
......@@ -116,7 +116,7 @@ def lstm_cell(num_hidden, batch_size=1, dtype="float32", name=""):
def get_net(iterations, num_hidden, batch_size=1, dtype="float32"):
'''Constructs an unrolled RNN with LSTM cells'''
input_type = relay.TensorType((batch_size, num_hidden), dtype)
weight_type = relay.TensorType((num_hidden, 4*num_hidden), dtype)
weight_type = relay.TensorType((4*num_hidden, num_hidden), dtype)
bias_type = relay.TensorType((4*num_hidden,), dtype)
state_type = relay.TupleType([input_type, input_type])
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment