Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Fix the bug of BidirectionalCell #13575

Merged
merged 8 commits into from
Dec 13, 2018
1 change: 1 addition & 0 deletions CONTRIBUTORS.md
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@ List of Contributors
* [Rahul Padmanabhan](https://github.com/rahul3)
* [Yuxi Hu](https://github.com/yuxihu)
* [Harsh Patel](https://github.com/harshp8l)
* [Xiao Wang](https://github.com/BeyonderXX)

Label Bot
---------
Expand Down
28 changes: 28 additions & 0 deletions tests/python/unittest/test_gluon_rnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,34 @@ def test_bidirectional():
assert outs == [(10, 200), (10, 200), (10, 200)]


def test_bidirectional_unroll_valid_length():
# Test BidirectionalCell.
# In 1.3.1 version, after hybridize( ), BidirectionalCell would failed when pass valid_length to unroll( ).
class BiLSTM(gluon.nn.HybridBlock):
def __init__(self, rnn_size, time_step, **kwargs):
super(BiLSTM, self).__init__(**kwargs)
self.time_step = time_step
with self.name_scope():
self.bi_lstm = gluon.rnn.BidirectionalCell(
gluon.rnn.LSTMCell(rnn_size, prefix='rnn_l0_'),
gluon.rnn.LSTMCell(rnn_size, prefix='rnn_r0_'),
output_prefix='lstm_bi_')

def hybrid_forward(self, F, inputs, valid_len):
outputs, states = self.bi_lstm.unroll(self.time_step, inputs, valid_length=valid_len,
layout='TNC', merge_outputs='True')
BeyonderXX marked this conversation as resolved.
Show resolved Hide resolved
return outputs, states

rnn_size, time_step = 100, 3
net = BiLSTM(rnn_size, time_step)
net.initialize()
net.hybridize()
inputs_data = mx.nd.random.uniform(shape=(3, 10, 50))
valid_len = mx.nd.array(range(1, 11))
outputs, _ = net(inputs_data, valid_len)
assert outputs.shape == (3, 10, 200)


@assert_raises_cudnn_not_satisfied(min_version='5.1.10')
def test_layer_bidirectional():
class RefBiLSTM(gluon.Block):
Expand Down