Fixes PyTorch 2.1 compatibility issues (#132)
cbalioglu authored Nov 4, 2023
1 parent 7d99d80 · commit 4ec9edf
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 25 deletions.
fairseq2n/CMakeLists.txt (2 changes: 1 addition & 1 deletion)
@@ -164,7 +164,7 @@ if(FAIRSEQ2N_THREAD_LIB STREQUAL "tbb")
     find_package(TBB 2021.8 REQUIRED)
 endif()
 
-find_package(Torch 1.12 REQUIRED)
+find_package(Torch 1.13 REQUIRED)
 
 if(FAIRSEQ2N_BUILD_PYTHON_BINDINGS)
     find_package(Python3 REQUIRED COMPONENTS Interpreter Development.Module)
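The minimum Torch requirement moves from 1.12 to 1.13. As an illustration only (not part of this commit), an equivalent runtime guard could look like this in Python, assuming the packaging library is available:

# Hypothetical sketch; `packaging` is assumed to be installed alongside pip/setuptools.
from packaging.version import Version

import torch

MIN_TORCH = Version("1.13.0")

# torch.__version__ may carry a local suffix such as "+cu118"; strip it before comparing.
installed = Version(torch.__version__.split("+")[0])

if installed < MIN_TORCH:
    raise RuntimeError(f"PyTorch {MIN_TORCH} or newer is required, found {installed}.")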
src/fairseq2/optim/optimizer_base.py (7 changes: 3 additions & 4 deletions)
@@ -10,14 +10,13 @@
 import torch
 from torch.optim import Optimizer
 
-from fairseq2.typing import finaloverride
-
 
 class OptimizerBase(ABC, Optimizer):
     """Represents the base class for all optimizers."""
 
-    @finaloverride
-    def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
+    def step(  # type: ignore[override]
+        self, closure: Optional[Callable[[], float]] = None
+    ) -> Optional[float]:
        loss = None
 
        prev_grad = torch.is_grad_enabled()
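For context, the signature of torch.optim.Optimizer.step is typed differently in recent PyTorch releases, which is presumably why the override now carries # type: ignore[override] instead of the finaloverride decorator. A minimal sketch of the same override-with-closure pattern, built only on the public torch.optim.Optimizer API (the ToySGD class below is hypothetical, not fairseq2 code):

# Hypothetical example; mirrors the closure-handling pattern of built-in optimizers.
from typing import Callable, Optional

import torch
from torch.optim import Optimizer


class ToySGD(Optimizer):
    """A bare-bones SGD used only to illustrate the step() signature."""

    def __init__(self, params, lr: float = 0.1) -> None:
        super().__init__(params, defaults={"lr": lr})

    def step(  # type: ignore[override]
        self, closure: Optional[Callable[[], float]] = None
    ) -> Optional[float]:
        loss = None

        if closure is not None:
            # The closure re-evaluates the model, so gradients must be enabled.
            with torch.enable_grad():
                loss = closure()

        # The parameter update itself should not be tracked by autograd.
        with torch.no_grad():
            for group in self.param_groups:
                for p in group["params"]:
                    if p.grad is not None:
                        p.add_(p.grad, alpha=-group["lr"])

        return loss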
tests/unit/nn/transformer/test_attention.py (33 changes: 13 additions & 20 deletions)
@@ -13,46 +13,39 @@
 from fairseq2.nn.padding import PaddingMask
 from fairseq2.nn.transformer import CustomAttentionMask, NaiveSDPA, TorchSDPA
 from fairseq2.utils.version import is_pt2_or_greater
-from tests.common import assert_close, device, tmp_rng_seed
+from tests.common import assert_close, device
 
 
 class TestScaledDotProductAttention:
     @pytest.mark.skipif(
         not is_pt2_or_greater(), reason="requires PyTorch 2.0.0 or greater"
     )
     # fmt: off
-    @pytest.mark.parametrize("use_key_padding_mask,use_attn_mask,attn_dropout_p,training",
+    @pytest.mark.parametrize("use_key_padding_mask,use_attn_mask,training",
         [
-            (False, False, 0.0, True),
-            (True, True, 0.0, True),
-            (False, True, 0.5, True),
-            (True, False, 0.5, True),
-            (False, False, 0.5, False),
-            (False, True, 0.9, False),
+            (False, False, True),
+            (True, True, True),
+            (False, True, True),
+            (True, False, True),
+            (False, False, False),
+            (False, True, False),
         ],
     )
     # fmt: on
     def test_torch_sdpa(
-        self,
-        use_key_padding_mask: bool,
-        use_attn_mask: bool,
-        attn_dropout_p: float,
-        training: bool,
+        self, use_key_padding_mask: bool, use_attn_mask: bool, training: bool
     ) -> None:
-        torch_sdpa = TorchSDPA(attn_dropout_p=attn_dropout_p)
-        naive_sdpa = NaiveSDPA(attn_dropout_p=attn_dropout_p)
+        torch_sdpa = TorchSDPA()
+        naive_sdpa = NaiveSDPA()
 
         if training:
             torch_sdpa.eval()
             naive_sdpa.eval()
 
         kwargs = self._get_sdpa_args(use_key_padding_mask, use_attn_mask)
 
-        with tmp_rng_seed(device):
-            attn1, _ = torch_sdpa(**kwargs)
-
-        with tmp_rng_seed(device):
-            attn2, _ = naive_sdpa(**kwargs)
+        attn1, _ = torch_sdpa(**kwargs)
+        attn2, _ = naive_sdpa(**kwargs)
 
         assert_close(attn1, attn2)
 
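The updated test drops attn_dropout_p and the tmp_rng_seed fixture: without dropout both SDPA paths are deterministic, so their outputs can be compared directly. A self-contained sketch of that idea against the public PyTorch 2.x API (not the fairseq2 test itself; naive_sdpa below is a hypothetical reference implementation):

# Hypothetical example; assumes PyTorch >= 2.0.
import math

import torch
import torch.nn.functional as F


def naive_sdpa(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # Explicit softmax(Q K^T / sqrt(d)) V, the reference the fused kernel is checked against.
    scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(q.size(-1))
    return torch.matmul(scores.softmax(dim=-1), v)


# (batch, heads, seq_len, head_dim) inputs.
q, k, v = (torch.randn(2, 4, 16, 32) for _ in range(3))

# With dropout_p=0.0 both paths are deterministic, so no RNG seeding is required.
fused = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0)
naive = naive_sdpa(q, k, v)

torch.testing.assert_close(fused, naive, atol=1e-5, rtol=1e-4)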
