From 8e5036c912aa93f4c01dee577c2930e39c06e09e Mon Sep 17 00:00:00 2001
From: Can Balioglu
Date: Wed, 13 Sep 2023 08:06:39 -0700
Subject: [PATCH] Improve layer output hook API

---
 src/fairseq2/models/w2vbert/model.py   |  4 +++-
 src/fairseq2/nn/transformer/decoder.py | 13 ++++++++++---
 src/fairseq2/nn/transformer/encoder.py | 13 ++++++++++---
 3 files changed, 23 insertions(+), 7 deletions(-)

diff --git a/src/fairseq2/models/w2vbert/model.py b/src/fairseq2/models/w2vbert/model.py
index 5db89fd92..a5e579f7e 100644
--- a/src/fairseq2/models/w2vbert/model.py
+++ b/src/fairseq2/models/w2vbert/model.py
@@ -96,12 +96,14 @@ def layer_output_hook(
             layer_output: Tensor,
             layer_padding_mask: Optional[Tensor],
             num_layers: int,
-        ) -> None:
+        ) -> bool:
             nonlocal w2v2_layer_output
 
             if layer_idx == num_layers - self.num_bert_encoder_layers - 1:
                 w2v2_layer_output = layer_output
 
+            return True
+
         # TODO: Should we pad for fp16?
         encoder_output, _ = self.w2v2_model.encoder(
             seqs, padding_mask, layer_output_hook=layer_output_hook
diff --git a/src/fairseq2/nn/transformer/decoder.py b/src/fairseq2/nn/transformer/decoder.py
index c5668e02d..8a342bf56 100644
--- a/src/fairseq2/nn/transformer/decoder.py
+++ b/src/fairseq2/nn/transformer/decoder.py
@@ -96,16 +96,22 @@ def __call__(
         layer_output: Tensor,
         layer_padding_mask: Optional[Tensor],
         num_layers: int,
-    ) -> None:
+    ) -> bool:
         """
         :param layer_idx:
             The index of the layer in the decoder stack.
         :param layer_output:
             The decoded output of the layer.
         :param layer_padding_mask:
-            The padding mask of `layer_output`.
+            The padding mask of ``layer_output``.
         :param num_layers:
             The number of layers in the decoder stack.
+
+        :returns:
+            ``True`` if the decoder should continue executing the remaining
+            layers in the stack; ``False`` if the decoder should stop executing
+            the remaining layers and treat this layer as the final layer in the
+            stack.
         """
 
 
@@ -202,7 +208,8 @@ def forward(
             )
 
             if layer_output_hook is not None:
-                layer_output_hook(layer_idx, seqs, padding_mask, num_layers)
+                if not layer_output_hook(layer_idx, seqs, padding_mask, num_layers):
+                    break
 
         if self.layer_norm is not None:
             seqs = self.layer_norm(seqs)
diff --git a/src/fairseq2/nn/transformer/encoder.py b/src/fairseq2/nn/transformer/encoder.py
index 88864b062..e7d447a71 100644
--- a/src/fairseq2/nn/transformer/encoder.py
+++ b/src/fairseq2/nn/transformer/encoder.py
@@ -77,16 +77,22 @@ def __call__(
         layer_output: Tensor,
         layer_padding_mask: Optional[Tensor],
         num_layers: int,
-    ) -> None:
+    ) -> bool:
         """
         :param layer_idx:
             The index of the layer in the encoder stack.
         :param layer_output:
             The encoded output of the layer.
         :param layer_padding_mask:
-            The padding mask of `layer_output`.
+            The padding mask of ``layer_output``.
         :param num_layers:
             The number of layers in the encoder stack.
+
+        :returns:
+            ``True`` if the encoder should continue executing the remaining
+            layers in the stack; ``False`` if the encoder should stop executing
+            the remaining layers and treat this layer as the final layer in the
+            stack.
         """
 
 
@@ -158,7 +164,8 @@ def forward(
         seqs, padding_mask = layer(seqs, padding_mask)
 
         if layer_output_hook is not None:
-            layer_output_hook(layer_idx, seqs, padding_mask, num_layers)
+            if not layer_output_hook(layer_idx, seqs, padding_mask, num_layers):
+                break
 
         if self.layer_norm is not None:
             seqs = self.layer_norm(seqs)
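
Usage note (editor's sketch, not part of the patch): the snippet below shows
how a caller might use the new boolean contract to stop an encoder early,
modeled on the w2vbert hook above. Only the hook signature, the
`layer_output_hook` keyword argument, and the two-value return of the
encoder's `forward` come from this patch; `forward_up_to_middle_layer`, the
middle-layer choice, and the assumption that `TransformerEncoder` is
importable from `fairseq2.nn.transformer` are hypothetical.

    from typing import Optional

    from torch import Tensor

    from fairseq2.nn.transformer import TransformerEncoder


    def forward_up_to_middle_layer(
        encoder: TransformerEncoder, seqs: Tensor, padding_mask: Optional[Tensor]
    ) -> Optional[Tensor]:
        """Runs ``encoder`` only up to (and including) its middle layer."""
        middle_layer_output: Optional[Tensor] = None

        def hook(
            layer_idx: int,
            layer_output: Tensor,
            layer_padding_mask: Optional[Tensor],
            num_layers: int,
        ) -> bool:
            nonlocal middle_layer_output

            if layer_idx == num_layers // 2:
                middle_layer_output = layer_output

                # Returning ``False`` stops the stack here; the layers above
                # ``layer_idx`` are skipped.
                return False

            # Returning ``True`` lets the encoder continue with the next layer.
            return True

        encoder(seqs, padding_mask, layer_output_hook=hook)

        return middle_layer_output

Compared to the previous ``None``-returning hooks, the boolean contract lets a
caller stop the stack at any layer; existing hooks keep the old behavior by
adding ``return True``, as the w2vbert change above does.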