This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

fixed symbols naming in RNNCell, LSTMCell, GRUCell (v1.3.x) (#13158)
* fixed symbols naming in RNNCell and LSTMCell

* fixed GRUCell as well

* added test

* fixed tests?
lebeg authored and szha committed Nov 7, 2018
1 parent 23c09c7 commit dff0431
Showing 2 changed files with 65 additions and 8 deletions.
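
The changes below replace plain Python operators on symbols (for example `i2h + h2h`) with explicitly named operators such as `F.elemwise_add(i2h, h2h, name=prefix+'plus0')`, so every node a cell creates carries that cell's prefix instead of an auto-generated name. A minimal sketch of the difference (illustrative only, not part of the commit; the prefix 'rnn0_' is made up here):

    import mxnet as mx

    a, b = mx.sym.var('a'), mx.sym.var('b')
    auto_named = a + b                                       # name chosen by MXNet's name manager (e.g. something like '_plus0')
    explicit = mx.sym.elemwise_add(a, b, name='rnn0_plus0')  # name fixed by the caller, carrying a cell-style prefix
    print(auto_named.name, explicit.name)

Stable, prefixed names are what the new test below exercises by exporting a hybridized cell and re-importing it with SymbolBlock.imports.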
25 changes: 17 additions & 8 deletions python/mxnet/gluon/rnn/rnn_cell.py
@@ -400,7 +400,8 @@ def hybrid_forward(self, F, inputs, states, i2h_weight,
         h2h = F.FullyConnected(data=states[0], weight=h2h_weight, bias=h2h_bias,
                                num_hidden=self._hidden_size,
                                name=prefix+'h2h')
-        output = self._get_activation(F, i2h + h2h, self._activation,
+        i2h_plus_h2h = F.elemwise_add(i2h, h2h, name=prefix+'plus0')
+        output = self._get_activation(F, i2h_plus_h2h, self._activation,
                                       name=prefix+'out')
 
         return output, [output]
@@ -513,7 +514,7 @@ def hybrid_forward(self, F, inputs, states, i2h_weight,
                                num_hidden=self._hidden_size*4, name=prefix+'i2h')
         h2h = F.FullyConnected(data=states[0], weight=h2h_weight, bias=h2h_bias,
                                num_hidden=self._hidden_size*4, name=prefix+'h2h')
-        gates = i2h + h2h
+        gates = F.elemwise_add(i2h, h2h, name=prefix+'plus0')
         slice_gates = F.SliceChannel(gates, num_outputs=4, name=prefix+'slice')
         in_gate = self._get_activation(
             F, slice_gates[0], self._recurrent_activation, name=prefix+'i')
@@ -523,9 +524,10 @@ def hybrid_forward(self, F, inputs, states, i2h_weight,
             F, slice_gates[2], self._activation, name=prefix+'c')
         out_gate = self._get_activation(
             F, slice_gates[3], self._recurrent_activation, name=prefix+'o')
-        next_c = F._internal._plus(forget_gate * states[1], in_gate * in_transform,
+        next_c = F._internal._plus(F.elemwise_mul(forget_gate, states[1], name=prefix+'mul0'),
+                                   F.elemwise_mul(in_gate, in_transform, name=prefix+'mul1'),
                                    name=prefix+'state')
-        next_h = F._internal._mul(out_gate, F.Activation(next_c, act_type=self._activation),
+        next_h = F._internal._mul(out_gate, F.Activation(next_c, act_type=self._activation, name=prefix+'activation0'),
                                   name=prefix+'out')
 
         return next_h, [next_h, next_c]
@@ -637,15 +639,22 @@ def hybrid_forward(self, F, inputs, states, i2h_weight,
         h2h_r, h2h_z, h2h = F.SliceChannel(h2h, num_outputs=3,
                                            name=prefix+'h2h_slice')
 
-        reset_gate = F.Activation(i2h_r + h2h_r, act_type="sigmoid",
+        reset_gate = F.Activation(F.elemwise_add(i2h_r, h2h_r, name=prefix+'plus0'), act_type="sigmoid",
                                   name=prefix+'r_act')
-        update_gate = F.Activation(i2h_z + h2h_z, act_type="sigmoid",
+        update_gate = F.Activation(F.elemwise_add(i2h_z, h2h_z, name=prefix+'plus1'), act_type="sigmoid",
                                    name=prefix+'z_act')
 
-        next_h_tmp = F.Activation(i2h + reset_gate * h2h, act_type="tanh",
+        next_h_tmp = F.Activation(F.elemwise_add(i2h,
+                                                 F.elemwise_mul(reset_gate, h2h, name=prefix+'mul0'),
+                                                 name=prefix+'plus2'),
+                                  act_type="tanh",
                                   name=prefix+'h_act')
 
-        next_h = F._internal._plus((1. - update_gate) * next_h_tmp, update_gate * prev_state_h,
+        ones = F.ones_like(update_gate, name=prefix+"ones_like0")
+        next_h = F._internal._plus(F.elemwise_mul(F.elemwise_sub(ones, update_gate, name=prefix+'minus0'),
+                                                  next_h_tmp,
+                                                  name=prefix+'mul1'),
+                                   F.elemwise_mul(update_gate, prev_state_h, name=prefix+'mul20'),
                                    name=prefix+'out')
 
         return next_h, [next_h]
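
Besides naming, the GRU change above also rewrites `(1. - update_gate)` as `elemwise_sub(ones_like(update_gate), update_gate)`, which is the same update expressed with nameable operators. A quick NDArray check of that equivalence (illustrative only, not from the commit):

    import mxnet as mx

    z, h_tmp, h_prev = (mx.nd.random.uniform(shape=(2, 3)) for _ in range(3))
    old_form = (1. - z) * h_tmp + z * h_prev                     # expression used before the patch
    ones = mx.nd.ones_like(z)
    new_form = mx.nd.elemwise_add(mx.nd.elemwise_mul(mx.nd.elemwise_sub(ones, z), h_tmp),
                                  mx.nd.elemwise_mul(z, h_prev)) # named-op form used after the patch
    assert (old_form - new_form).abs().max().asscalar() < 1e-6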
48 changes: 48 additions & 0 deletions tests/python/unittest/test_gluon_rnn.py
@@ -379,6 +379,54 @@ def test_rnn_cells():
     net.add(gluon.rnn.GRUCell(100, input_size=100))
     check_rnn_forward(net, mx.nd.ones((8, 3, 200)))
 
+
+def test_rnn_cells_export_import():
+    class RNNLayer(gluon.HybridBlock):
+        def __init__(self):
+            super(RNNLayer, self).__init__()
+            with self.name_scope():
+                self.cell = gluon.rnn.RNNCell(hidden_size=1)
+
+        def hybrid_forward(self, F, seq):
+            outputs, state = self.cell.unroll(inputs=seq, length=2, merge_outputs=True)
+            return outputs
+
+    class LSTMLayer(gluon.HybridBlock):
+        def __init__(self):
+            super(LSTMLayer, self).__init__()
+            with self.name_scope():
+                self.cell = gluon.rnn.LSTMCell(hidden_size=1)
+
+        def hybrid_forward(self, F, seq):
+            outputs, state = self.cell.unroll(inputs=seq, length=2, merge_outputs=True)
+            return outputs
+
+    class GRULayer(gluon.HybridBlock):
+        def __init__(self):
+            super(GRULayer, self).__init__()
+            with self.name_scope():
+                self.cell = gluon.rnn.GRUCell(hidden_size=1)
+
+        def hybrid_forward(self, F, seq):
+            outputs, state = self.cell.unroll(inputs=seq, length=2, merge_outputs=True)
+            return outputs
+
+    for hybrid in [RNNLayer(), LSTMLayer(), GRULayer()]:
+        hybrid.initialize()
+        hybrid.hybridize()
+        input = mx.nd.ones(shape=(1, 2, 1))
+        output1 = hybrid(input)
+        hybrid.export(path="./model", epoch=0)
+        symbol = mx.gluon.SymbolBlock.imports(
+            symbol_file="./model-symbol.json",
+            input_names=["data"],
+            param_file="./model-0000.params",
+            ctx=mx.Context.default_ctx
+        )
+        output2 = symbol(input)
+        assert_almost_equal(output1.asnumpy(), output2.asnumpy())
+
+
 def check_rnn_layer_forward(layer, inputs, states=None, run_only=False):
     layer.collect_params().initialize()
     inputs.attach_grad()
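
A way to eyeball the result of the fix (not part of the commit): after the test above calls hybrid.export(path="./model", epoch=0), the node names in the written symbol file can be listed and checked for the cell prefix.

    import json

    # Assumes ./model-symbol.json was produced by hybrid.export() as in the test above.
    with open("./model-symbol.json") as f:
        graph = json.load(f)
    print([node["name"] for node in graph["nodes"]])  # elementwise ops should now carry the cell's prefix rather than bare auto-names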
