diff --git a/src/fairseq2/nn/transformer/multihead_attention.py b/src/fairseq2/nn/transformer/multihead_attention.py
index 42eb601f8..c7268a496 100644
--- a/src/fairseq2/nn/transformer/multihead_attention.py
+++ b/src/fairseq2/nn/transformer/multihead_attention.py
@@ -457,7 +457,7 @@ def forward(
         v = v.flatten(0, 1)
 
         if self.pos_encoder is not None:
-            q = self.pos_encoder(q, padding_mask, state_bag)
+            q = self.pos_encoder(q, padding_mask, state_bag=state_bag)
             k = self.pos_encoder(k, key_padding_mask)
 
         mask_pad = 0
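
Context for the change (an inference, not stated in the diff itself): the position encoder likely declares `state_bag` as a keyword-only parameter, so the old positional call on the removed line would either raise a `TypeError` or risk binding to the wrong parameter if the signature grew. Below is a minimal sketch of that signature shape; the class and parameter names are illustrative, not fairseq2's actual implementation.

```python
# Minimal sketch (assumed, not the fairseq2 source): the bare `*` makes
# `state_bag` keyword-only, which is why the call site must name it.
class PositionEncoder:
    def __call__(self, seqs, padding_mask, *, state_bag=None):
        # ... apply positional encoding to `seqs`, consulting `state_bag`
        # for the current step during incremental decoding ...
        return seqs


encoder = PositionEncoder()
q = [[1.0, 2.0]]

q = encoder(q, None, state_bag=None)  # OK: matches the "+" line

# q = encoder(q, None, None)          # TypeError: `state_bag` is
#                                     # keyword-only, as on the "-" line
```

Passing `state_bag` by keyword also keeps the `k` call unambiguous: `k` is encoded without incremental state, and naming the argument at the `q` call site makes that asymmetry explicit to the reader.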