Convert layer_output_hook to a PyTorch hook #121

Merged on Nov 1, 2023 (1 commit)
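This PR replaces the layer_output_hook keyword argument of TransformerEncoder.forward and TransformerDecoder.forward with a register_layer_output_hook method that returns a torch.utils.hooks.RemovableHandle, mirroring PyTorch's own hook registration. A minimal before/after sketch; encoder, hook, seqs, and padding_mask are assumed to exist and are not part of the diff:

    # Before this change, the hook had to be threaded through every forward call:
    #     encoder_output, _ = encoder(seqs, padding_mask, layer_output_hook=hook)

    # After this change, the hook is registered once; the returned RemovableHandle
    # removes it explicitly via handle.remove() or on exit when used as a context
    # manager, which is how the w2v-BERT model below uses it.
    with encoder.register_layer_output_hook(hook):
        encoder_output, encoder_padding_mask = encoder(seqs, padding_mask)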
src/fairseq2/models/w2vbert/model.py (10 changes: 4 additions, 6 deletions)
@@ -92,10 +92,10 @@ def forward(self, batch: SequenceBatch) -> "W2VBertOutput":

w2v2_layer_output = None

def layer_output_hook(
def hook(
layer_idx: int,
layer_output: Tensor,
layer_padding_mask: PaddingMask,
layer_padding_mask: Optional[PaddingMask],
num_layers: int,
) -> bool:
nonlocal w2v2_layer_output
@@ -105,10 +105,8 @@ def layer_output_hook(

return True

# TODO: Should we pad for fp16?
encoder_output, _ = self.w2v2_model.encoder(
seqs, padding_mask, layer_output_hook=layer_output_hook
)
with self.w2v2_model.encoder.register_layer_output_hook(hook):
encoder_output, _ = self.w2v2_model.encoder(seqs, padding_mask)

assert w2v2_layer_output is not None
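The hook above captures a single layer's output; the same mechanism can collect every intermediate output. A sketch assuming an already constructed fairseq2 TransformerEncoder (encoder) and its seqs/padding_mask inputs; the real mask type is PaddingMask, typed as Any here to keep the snippet self-contained:

    from typing import Any, List, Optional

    from torch import Tensor

    layer_outputs: List[Tensor] = []

    def collect_layer_outputs(
        layer_idx: int,
        layer_output: Tensor,
        layer_padding_mask: Optional[Any],  # PaddingMask in fairseq2
        num_layers: int,
    ) -> bool:
        # Store a detached copy of each layer's output. Returning True keeps the
        # hook chain running, as in the w2v-BERT hook above.
        layer_outputs.append(layer_output.detach())
        return True

    handle = encoder.register_layer_output_hook(collect_layer_outputs)
    try:
        encoder_output, encoder_padding_mask = encoder(seqs, padding_mask)
    finally:
        handle.remove()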

src/fairseq2/nn/transformer/__init__.py (6 changes: 3 additions, 3 deletions)
@@ -68,6 +68,9 @@
from fairseq2.nn.transformer.multihead_attention import (
AttentionWeightHook as AttentionWeightHook,
)
from fairseq2.nn.transformer.multihead_attention import (
AttentionWeightStoreHook as AttentionWeightStoreHook,
)
from fairseq2.nn.transformer.multihead_attention import (
FullAttentionState as FullAttentionState,
)
@@ -86,9 +89,6 @@
from fairseq2.nn.transformer.multihead_attention import (
StaticAttentionState as StaticAttentionState,
)
from fairseq2.nn.transformer.multihead_attention import (
StoreAttentionWeights as StoreAttentionWeights,
)
from fairseq2.nn.transformer.norm_order import (
TransformerNormOrder as TransformerNormOrder,
)
src/fairseq2/nn/transformer/decoder.py (46 changes: 36 additions, 10 deletions)
@@ -4,11 +4,15 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Iterable, Optional, Protocol, Tuple, final
from collections import OrderedDict
from typing import Dict, Iterable, Optional, Protocol, Tuple, final

from torch import Tensor
from torch.nn import Module
from torch.utils.hooks import RemovableHandle

from fairseq2.nn.incremental_state import IncrementalStateBag
from fairseq2.nn.module_list import ModuleList
@@ -33,6 +37,8 @@ class TransformerDecoder(Module, ABC):
model_dim: int
layers: ModuleList

_layer_output_hooks: Dict[int, DecoderLayerOutputHook]

def __init__(self, model_dim: int) -> None:
"""
:param model_dim:
@@ -42,6 +48,8 @@ def __init__(self, model_dim: int) -> None:

self.model_dim = model_dim

self._layer_output_hooks = OrderedDict()

@abstractmethod
def forward(
self,
@@ -50,7 +58,6 @@ def forward(
encoder_output: Optional[Tensor] = None,
encoder_padding_mask: Optional[PaddingMask] = None,
*,
layer_output_hook: Optional["DecoderLayerOutputHook"] = None,
state_bag: Optional[IncrementalStateBag] = None,
) -> Tuple[Tensor, Optional[PaddingMask]]:
"""
@@ -70,9 +77,6 @@ def forward(
The padding mask of ``encoder_output``. *Shape:* :math:`(N,S_{enc})`,
where :math:`N` is the batch size and :math:`S_{enc}` is the encoder
output sequence length.
:param layer_output_hook:
If not ``None``, it will be called with the output of each layer in
the decoder stack.
:param state_bag:
The state bag to use for incremental decoding.

@@ -82,6 +86,27 @@ def forward(
``padding_mask``.
"""

def register_layer_output_hook(
self, hook: DecoderLayerOutputHook
) -> RemovableHandle:
"""Register a layer output hook on the module.

The hook will be called every time after a layer in the decoder stack
has computed an output.

:param hook:
The hook to register.

:returns:
A handle that can be used to remove the added hook by calling
``handle.remove()``.
"""
handle = RemovableHandle(self._layer_output_hooks)

self._layer_output_hooks[handle.id] = hook

return handle

def extra_repr(self) -> str:
""":meta private:"""
return f"model_dim={self.model_dim}"
@@ -187,11 +212,12 @@ def forward(
encoder_output: Optional[Tensor] = None,
encoder_padding_mask: Optional[PaddingMask] = None,
*,
layer_output_hook: Optional[DecoderLayerOutputHook] = None,
state_bag: Optional[IncrementalStateBag] = None,
) -> Tuple[Tensor, Optional[PaddingMask]]:
if layer_output_hook is not None and self.layers.drop_p > 0.0:
raise ValueError("`layer_hook` must be `None` when LayerDrop is enabled.")
if self._layer_output_hooks and self.layers.drop_p > 0.0:
raise ValueError(
"The layer output hooks cannot be run when LayerDrop is enabled."
)

num_layers = len(self.layers)

@@ -212,8 +238,8 @@
state_bag=state_bag,
)

if layer_output_hook is not None:
if not layer_output_hook(layer_idx, seqs, padding_mask, num_layers):
for hook in self._layer_output_hooks.values():
if not hook(layer_idx, seqs, padding_mask, num_layers):
break

if self.layer_norm is not None:
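The handle returned by register_layer_output_hook above is a plain torch.utils.hooks.RemovableHandle bound to the module's hook OrderedDict: the handle allocates an id, the hook is stored under that id, and handle.remove() deletes the entry. A standalone sketch of that mechanism, independent of fairseq2:

    from collections import OrderedDict
    from typing import Callable, Dict

    from torch.utils.hooks import RemovableHandle

    hooks: Dict[int, Callable[..., bool]] = OrderedDict()

    def noop_hook(*args: object) -> bool:
        return True

    handle = RemovableHandle(hooks)  # allocates handle.id, keeps a weak reference to `hooks`
    hooks[handle.id] = noop_hook

    assert handle.id in hooks
    handle.remove()                  # deletes hooks[handle.id]
    assert handle.id not in hooks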
src/fairseq2/nn/transformer/encoder.py (59 changes: 40 additions, 19 deletions)
@@ -4,11 +4,15 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Iterable, Optional, Protocol, Tuple, final
from collections import OrderedDict
from typing import Dict, Iterable, Optional, Protocol, Tuple, final

from torch import Tensor
from torch.nn import Module
from torch.utils.hooks import RemovableHandle

from fairseq2.nn.module_list import ModuleList
from fairseq2.nn.normalization import LayerNorm
@@ -29,6 +33,8 @@ class TransformerEncoder(Module, ABC):
model_dim: int
layers: ModuleList

_layer_output_hooks: Dict[int, EncoderLayerOutputHook]

def __init__(self, model_dim: int) -> None:
"""
:param model_dim:
@@ -38,13 +44,11 @@ def __init__(self, model_dim: int) -> None:

self.model_dim = model_dim

self._layer_output_hooks = OrderedDict()

@abstractmethod
def forward(
self,
seqs: Tensor,
padding_mask: Optional[PaddingMask],
*,
layer_output_hook: Optional["EncoderLayerOutputHook"] = None,
self, seqs: Tensor, padding_mask: Optional[PaddingMask]
) -> Tuple[Tensor, Optional[PaddingMask]]:
"""
:param seqs:
@@ -54,23 +58,42 @@ def forward(
:param padding_mask:
The padding mask of ``seqs``. *Shape:* :math:`(N,S)`, where :math:`N`
is the batch size and :math:`S` is the sequence length.
:param layer_output_hook:
If not ``None``, it will be called with the output of each layer in
the encoder stack.

:returns:
- The encoder output. *Shape:* Same as ``seqs``.
- The padding mask of the encoder output. *Shape:* Same as
``padding_mask``.
"""

def register_layer_output_hook(
self, hook: EncoderLayerOutputHook
) -> RemovableHandle:
"""Register a layer output hook on the module.

The hook will be called every time after a layer in the encoder stack
has computed an output.

:param hook:
The hook to register.

:returns:
A handle that can be used to remove the added hook by calling
``handle.remove()``.
"""
handle = RemovableHandle(self._layer_output_hooks)

self._layer_output_hooks[handle.id] = hook

return handle

def extra_repr(self) -> str:
""":meta private:"""
return f"model_dim={self.model_dim}"


class EncoderLayerOutputHook(Protocol):
"""Represents a hook to pass to :meth:`~TransformerEncoder.forward`."""
"""Represents a hook to pass to
:meth:`~TransformerEncoder.register_layer_output_hook`."""

def __call__(
self,
Expand Down Expand Up @@ -153,14 +176,12 @@ def __init__(

@finaloverride
def forward(
self,
seqs: Tensor,
padding_mask: Optional[PaddingMask],
*,
layer_output_hook: Optional[EncoderLayerOutputHook] = None,
self, seqs: Tensor, padding_mask: Optional[PaddingMask]
) -> Tuple[Tensor, Optional[PaddingMask]]:
if layer_output_hook is not None and self.layers.drop_p > 0.0:
raise ValueError("`layer_hook` must be `None` when LayerDrop is enabled.")
if self._layer_output_hooks and self.layers.drop_p > 0.0:
raise ValueError(
"The layer output hooks cannot be run when LayerDrop is enabled."
)

num_layers = len(self.layers)

@@ -174,8 +195,8 @@ def forward(
for layer_idx, layer in enumerate(self.layers.drop_iter()):
seqs, padding_mask = layer(seqs, padding_mask, self_attn_mask)

if layer_output_hook is not None:
if not layer_output_hook(layer_idx, seqs, padding_mask, num_layers):
for hook in self._layer_output_hooks.values():
if not hook(layer_idx, seqs, padding_mask, num_layers):
break

if self.layer_norm is not None:
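Because the hook registry is an OrderedDict keyed by handle id, several hooks can be registered on the same encoder, and the forward pass above calls them in registration order for each layer. A short sketch, again assuming an existing encoder, two hook callables, and the usual inputs:

    # Both hooks run for every layer output, in the order they were registered.
    first_handle = encoder.register_layer_output_hook(first_hook)
    second_handle = encoder.register_layer_output_hook(second_hook)

    encoder_output, encoder_padding_mask = encoder(seqs, padding_mask)

    second_handle.remove()
    first_handle.remove()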
src/fairseq2/nn/transformer/multihead_attention.py (36 changes: 11 additions, 25 deletions)
@@ -4,6 +4,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

from abc import ABC, abstractmethod
from collections import OrderedDict
from typing import Dict, MutableSequence, Optional, Protocol, Tuple, final
@@ -31,7 +33,7 @@ class MultiheadAttention(Module, ABC):
num_heads: int
model_dim: int

_attn_weight_hooks: Dict[int, "AttentionWeightHook"]
_attn_weight_hooks: Dict[int, AttentionWeightHook]

def __init__(self, model_dim: int, num_heads: int) -> None:
"""
@@ -94,11 +96,11 @@ def forward(
:math:`M` is the dimensionality of the model.
"""

def register_attn_weight_hook(self, hook: "AttentionWeightHook") -> RemovableHandle:
def register_attn_weight_hook(self, hook: AttentionWeightHook) -> RemovableHandle:
"""Register an attention weight hook on the module.

The hook will be called every time after the module computes attention
weights.
The hook will be called every time after the module has computed
attention weights.

:param hook:
The hook to register.
@@ -113,23 +115,6 @@ def register_attn_weight_hook(self, hook: "AttentionWeightHook") -> RemovableHan

return handle

def _run_attn_weight_hooks(self, attn: Tensor, attn_weights: Tensor) -> None:
"""Run registered attention weight hooks.

:param attn:
The computed attention values. *Shape:* :math:`(N,S,V)`, where
:math:`N` is the batch size, :math:`S` is the sequence length, and
:math:`V` is the value size.
:param attn_weights:
The computed attention weights. *Shape:* :math:`(N,S,S_{kv})`, where
:math:`N` is the batch size, :math:`S` is the sequence length, and
:math:`S_{kv}` is the key/value sequence length.

:meta public:
"""
for hook in self._attn_weight_hooks.values():
hook(self, attn, attn_weights)

def extra_repr(self) -> str:
""":meta private:"""
return f"num_heads={self.num_heads}, model_dim={self.model_dim}"
@@ -156,7 +141,7 @@ def __call__(
"""


class StoreAttentionWeights:
class AttentionWeightStoreHook:
"""Stores attention weights in a provided storage.

.. note::
Expand Down Expand Up @@ -195,7 +180,7 @@ class StandardMultiheadAttention(MultiheadAttention):
sdpa: SDPA
head_scale_weight: Optional[Parameter]
output_proj: Projection
state_factory: Optional["AttentionStateFactory"]
state_factory: Optional[AttentionStateFactory]

def __init__(
self,
@@ -212,7 +197,7 @@ def __init__(
scale_heads: bool = False,
output_proj: Optional[Projection] = None,
bias: bool = True,
state_factory: Optional["AttentionStateFactory"] = None,
state_factory: Optional[AttentionStateFactory] = None,
device: Optional[Device] = None,
dtype: Optional[DataType] = None,
) -> None:
@@ -475,7 +460,8 @@ def forward(
)

if attn_weights is not None:
self._run_attn_weight_hooks(attn, attn_weights)
for hook in self._attn_weight_hooks.values():
hook(self, attn, attn_weights)

# (N, H, S, V_h) -> (N, S, H, V_h)
attn = attn.transpose(1, 2)
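The attention-weight side already used this registration pattern; the changes here rename StoreAttentionWeights to AttentionWeightStoreHook and inline the old _run_attn_weight_hooks helper into forward. A sketch of a custom hook matching the call `hook(self, attn, attn_weights)` shown above; mha is assumed to be an already constructed fairseq2 MultiheadAttention:

    from typing import List, Tuple

    from torch import Tensor

    stored_weights: List[Tuple[Tensor, Tensor]] = []

    def store_attn_weights(m: object, attn: Tensor, attn_weights: Tensor) -> None:
        # Called after attention is computed; keep detached copies of the
        # attention values and weights for later inspection.
        stored_weights.append((attn.detach(), attn_weights.detach()))

    handle = mha.register_attn_weight_hook(store_attn_weights)
    # ...run forward passes; whenever attention weights are computed, one
    # (attn, attn_weights) pair is appended to stored_weights...
    handle.remove()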