
fix: Fixed TransformerDecoder for PyTorch 1.10 (#539)
fg-mindee committed Oct 22, 2021
1 parent 49dc156 commit 323d484
Showing 1 changed file with 2 additions and 3 deletions.
doctr/models/recognition/transformer/pytorch.py (5 changes: 2 additions & 3 deletions)
@@ -59,6 +59,7 @@ def __init__(
                 dim_feedforward=dff,
                 dropout=dropout,
                 activation='relu',
+                batch_first=True,
             ) for _ in range(num_layers)
         ])

@@ -79,13 +80,11 @@ def forward(
         x += self.pos_encoding[:, :seq_len, :]
         x = self.dropout(x)

-        # Batch first = False in decoder
-        x = x.permute(1, 0, 2)
+        # Batch first = True in decoder
         for i in range(self.num_layers):
             x = self.dec_layers[i](
                 tgt=x, memory=enc_output, tgt_mask=look_ahead_mask, memory_mask=padding_mask
             )

         # shape (batch_size, target_seq_len, d_model)
-        x = x.permute(1, 0, 2)
         return x
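
For context, a minimal, hypothetical sketch (not the doctr module itself) of the behaviour this fix relies on: with batch_first=True, nn.TransformerDecoderLayer consumes (batch, seq, d_model) tensors directly, so the permute(1, 0, 2) calls removed above are no longer needed. The sizes and tensors below are purely illustrative.

    import torch
    from torch import nn

    # Hypothetical sizes for illustration only, not doctr's defaults
    d_model, nhead, dff = 512, 8, 2048

    # batch_first=True makes the layer accept (batch, seq, d_model) inputs
    layer = nn.TransformerDecoderLayer(
        d_model=d_model,
        nhead=nhead,
        dim_feedforward=dff,
        dropout=0.1,
        activation='relu',
        batch_first=True,
    )

    tgt = torch.rand(4, 30, d_model)     # (batch_size, target_seq_len, d_model)
    memory = torch.rand(4, 50, d_model)  # (batch_size, source_seq_len, d_model)

    out = layer(tgt, memory)             # stays (batch_size, target_seq_len, d_model)
    print(out.shape)                     # torch.Size([4, 30, 512])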
