Add testing for PyTorch 2.4 (Fabric) #20028

Merged: 19 commits, Jul 2, 2024
Changes from all commits
10 changes: 8 additions & 2 deletions .azure/gpu-tests-fabric.yml
@@ -62,6 +62,9 @@ jobs:
"Lightning | latest":
image: "pytorchlightning/pytorch_lightning:base-cuda-py3.11-torch2.3-cuda12.1.0"
PACKAGE_NAME: "lightning"
"Lightning | future":
image: "pytorchlightning/pytorch_lightning:base-cuda-py3.11-torch2.4-cuda12.1.0"
PACKAGE_NAME: "lightning"
workspace:
clean: all
steps:
@@ -72,9 +75,12 @@ jobs:
echo "##vso[task.setvariable variable=TORCH_URL]https://download.pytorch.org/whl/cu${cuda_ver}/torch_stable.html"
scope=$(python -c 'n = "$(PACKAGE_NAME)" ; print(dict(fabric="lightning_fabric").get(n, n))')
echo "##vso[task.setvariable variable=COVERAGE_SOURCE]$scope"
python_ver=$(python -c "import sys; print(f'{sys.version_info.major}{sys.version_info.minor}')")
echo "##vso[task.setvariable variable=PYTHON_VERSION_MM]$python_ver"
displayName: "set env. vars"
- bash: |
echo "##vso[task.setvariable variable=TORCH_URL]https://download.pytorch.org/whl/test/cu${CUDA_VERSION_MM}/torch_test.html"
echo "##vso[task.setvariable variable=TORCH_URL]https://download.pytorch.org/whl/test/cu${CUDA_VERSION_MM}"
echo "##vso[task.setvariable variable=TORCHVISION_URL]https://download.pytorch.org/whl/test/cu124/torchvision-0.19.0%2Bcu124-cp${PYTHON_VERSION_MM}-cp${PYTHON_VERSION_MM}-linux_x86_64.whl"
condition: endsWith(variables['Agent.JobName'], 'future')
displayName: "set env. vars 4 future"

@@ -103,7 +109,7 @@ jobs:

- bash: |
extra=$(python -c "print({'lightning': 'fabric-'}.get('$(PACKAGE_NAME)', ''))")
pip install -e ".[${extra}dev]" pytest-timeout -U --find-links="${TORCH_URL}"
pip install -e ".[${extra}dev]" pytest-timeout -U --find-links="${TORCH_URL}" --find-links="${TORCHVISION_URL}"
displayName: "Install package & dependencies"

- bash: |
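Note on the new `PYTHON_VERSION_MM` variable: the pinned torchvision test wheel encodes the CPython ABI tag in its filename, so the URL has to be assembled per interpreter version. A small illustrative sketch (not pipeline output) of how the tag and URL come together:

```python
# Mirrors the `python -c` one-liner added above and the URL it gets substituted into.
import sys

python_ver = f"{sys.version_info.major}{sys.version_info.minor}"  # "311" on the Python 3.11 image
torchvision_url = (
    "https://download.pytorch.org/whl/test/cu124/"
    f"torchvision-0.19.0%2Bcu124-cp{python_ver}-cp{python_ver}-linux_x86_64.whl"
)
print(torchvision_url)  # passed to pip via --find-links on the "future" job only
```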
8 changes: 6 additions & 2 deletions .github/workflows/ci-tests-fabric.yml
@@ -49,6 +49,10 @@ jobs:
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.3" }
- { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.3" }
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.3" }
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.4" }
- { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.4" }
# TODO: PyTorch 2.4 on Windows not yet working with `torch.distributed` (not compiled with libuv support)
# - { os: "windows-2022", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.4" }
# only run PyTorch latest with Python latest, use Fabric scope to limit dependency issues
- { os: "macOS-12", pkg-name: "fabric", python-version: "3.11", pytorch-version: "2.1" }
- { os: "ubuntu-22.04", pkg-name: "fabric", python-version: "3.11", pytorch-version: "2.1" }
@@ -79,7 +83,7 @@ jobs:
FREEZE_REQUIREMENTS: ${{ ! (github.ref == 'refs/heads/master' || startsWith(github.ref, 'refs/heads/release/')) }}
PYPI_CACHE_DIR: "_pip-wheels"
TORCH_URL_STABLE: "https://download.pytorch.org/whl/cpu/torch_stable.html"
TORCH_URL_TEST: "https://download.pytorch.org/whl/test/cpu/torch_test.html"
TORCH_URL_TEST: "https://download.pytorch.org/whl/test/cpu/torch"
# TODO: Remove this - Enable running MPS tests on this platform
DISABLE_MPS: ${{ matrix.os == 'macOS-14' && '1' || '0' }}
steps:
@@ -118,7 +122,7 @@ jobs:
- name: Env. variables
run: |
# Switch PyTorch URL
python -c "print('TORCH_URL=' + str('${{env.TORCH_URL_TEST}}' if '${{ matrix.pytorch-version }}' == '2.3' else '${{env.TORCH_URL_STABLE}}'))" >> $GITHUB_ENV
python -c "print('TORCH_URL=' + str('${{env.TORCH_URL_TEST}}' if '${{ matrix.pytorch-version }}' == '2.4' else '${{env.TORCH_URL_STABLE}}'))" >> $GITHUB_ENV
# Switch coverage scope
python -c "print('COVERAGE_SCOPE=' + str('lightning' if '${{matrix.pkg-name}}' == 'lightning' else 'lightning_fabric'))" >> $GITHUB_ENV
# if you install mono-package set dependency only for this subpackage
2 changes: 1 addition & 1 deletion requirements/fabric/base.txt
@@ -2,7 +2,7 @@
# in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment

numpy >=1.21.0, <1.27.0
torch >=2.1.0, <2.4.0
torch >=2.1.0, <2.5.0
fsspec[http] >=2022.5.0, <2024.4.0
packaging >=20.0, <=23.1
typing-extensions >=4.4.0, <4.10.0
2 changes: 1 addition & 1 deletion requirements/fabric/examples.txt
@@ -1,6 +1,6 @@
# NOTE: the upper bound for the package version is only set for CI stability, and it is dropped while installing this package
# in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment

torchvision >=0.16.0, <0.19.0
torchvision >=0.16.0, <0.20.0
torchmetrics >=0.10.0, <1.3.0
lightning-utilities >=0.8.0, <0.12.0
5 changes: 3 additions & 2 deletions src/lightning/fabric/plugins/precision/amp.py
@@ -22,6 +22,7 @@

from lightning.fabric.plugins.precision.precision import Precision
from lightning.fabric.plugins.precision.utils import _convert_fp_tensor
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4
from lightning.fabric.utilities.types import Optimizable


@@ -39,7 +40,7 @@ def __init__(
self,
precision: Literal["16-mixed", "bf16-mixed"],
device: str,
scaler: Optional[torch.cuda.amp.GradScaler] = None,
scaler: Optional["torch.cuda.amp.GradScaler"] = None,
) -> None:
if precision not in ("16-mixed", "bf16-mixed"):
raise ValueError(
@@ -49,7 +50,7 @@ def __init__(

self.precision = precision
if scaler is None and self.precision == "16-mixed":
scaler = torch.cuda.amp.GradScaler()
scaler = torch.amp.GradScaler(device=device) if _TORCH_GREATER_EQUAL_2_4 else torch.cuda.amp.GradScaler()
if scaler is not None and self.precision == "bf16-mixed":
raise ValueError(f"`precision='bf16-mixed'` does not use a scaler, found {scaler}.")
self.device = device
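The substantive change here is the scaler class: `torch.cuda.amp.GradScaler` is deprecated in PyTorch 2.4 in favor of the device-agnostic `torch.amp.GradScaler`. A minimal sketch of the version-gated selection, assuming only that the 2.4 API matches what the diff above uses:

```python
# Sketch: choose the GradScaler class based on the installed PyTorch version.
import operator

import torch
from lightning_utilities.core.imports import compare_version

_TORCH_GREATER_EQUAL_2_4 = compare_version("torch", operator.ge, "2.4.0")

scaler = (
    torch.amp.GradScaler(device="cuda")  # device-agnostic API, no FutureWarning on 2.4
    if _TORCH_GREATER_EQUAL_2_4
    else torch.cuda.amp.GradScaler()  # legacy class, emits a FutureWarning on 2.4
)
print(type(scaler))
```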
4 changes: 4 additions & 0 deletions src/lightning/fabric/strategies/fsdp.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import shutil
import warnings
from contextlib import ExitStack, nullcontext
from datetime import timedelta
from functools import partial
@@ -83,6 +84,9 @@

_FSDP_ALIASES = ("fsdp", "fsdp_cpu_offload")

# TODO: Switch to new state-dict APIs
warnings.filterwarnings("ignore", category=FutureWarning, message=".*FSDP.state_dict_type.*") # from torch >= 2.4


class FSDPStrategy(ParallelStrategy, _Sharded):
r"""Strategy for Fully Sharded Data Parallel provided by torch.distributed.
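The module-level filter above silences only FutureWarnings whose message matches the `FSDP.state_dict_type` pattern, so unrelated FutureWarnings still surface. A self-contained sketch of that behavior (the warning text below is illustrative, not the exact message torch emits):

```python
# Only the message matching the regex is suppressed; the other FutureWarning is still recorded.
import warnings

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    warnings.filterwarnings("ignore", category=FutureWarning, message=".*FSDP.state_dict_type.*")
    warnings.warn("FSDP.state_dict_type() is deprecated", FutureWarning)  # matched -> ignored
    warnings.warn("some unrelated deprecation", FutureWarning)            # not matched -> recorded

assert len(caught) == 1 and "unrelated" in str(caught[0].message)
```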
18 changes: 6 additions & 12 deletions src/lightning/fabric/strategies/model_parallel.py
@@ -70,7 +70,7 @@ class ModelParallelStrategy(ParallelStrategy):

Currently supports up to 2D parallelism. Specifically, it supports the combination of
Fully Sharded Data-Parallel 2 (FSDP2) with Tensor Parallelism (DTensor). These PyTorch APIs are currently still
experimental in PyTorch. Requires PyTorch 2.3 or newer.
experimental in PyTorch. Requires PyTorch 2.4 or newer.

Arguments:
parallelize_fn: A function that applies parallelisms to a module. The strategy will provide the
@@ -95,8 +95,8 @@ def __init__(
timeout: Optional[timedelta] = default_pg_timeout,
) -> None:
super().__init__()
if not _TORCH_GREATER_EQUAL_2_3:
raise ImportError(f"{type(self).__name__} requires PyTorch 2.3 or higher.")
if not _TORCH_GREATER_EQUAL_2_4:
raise ImportError(f"{type(self).__name__} requires PyTorch 2.4 or higher.")
self._parallelize_fn = parallelize_fn
self._data_parallel_size = data_parallel_size
self._tensor_parallel_size = tensor_parallel_size
@@ -178,7 +178,7 @@ def setup_module(self, module: TModel) -> TModel:
if any(isinstance(mod, FullyShardedDataParallel) for mod in module.modules()):
raise TypeError(
"Found modules that are wrapped with `torch.distributed.fsdp.FullyShardedDataParallel`."
f" The `{self.__class__.__name__}` only supports the new FSDP2 APIs in PyTorch >= 2.3."
f" The `{self.__class__.__name__}` only supports the new FSDP2 APIs in PyTorch >= 2.4."
)

module = self._parallelize_fn(module, self.device_mesh)
@@ -329,10 +329,10 @@ def __init__(self, module: Module, enabled: bool) -> None:
self._enabled = enabled

def _set_requires_grad_sync(self, requires_grad_sync: bool) -> None:
from torch.distributed._composable.fsdp import FSDP
from torch.distributed._composable.fsdp import FSDPModule

for mod in self._module.modules():
if isinstance(mod, FSDP):
if isinstance(mod, FSDPModule):
mod.set_requires_gradient_sync(requires_grad_sync, recurse=False)

def __enter__(self) -> None:
@@ -458,9 +458,6 @@ def _load_checkpoint(
return metadata

if _is_full_checkpoint(path):
if not _TORCH_GREATER_EQUAL_2_4:
raise ImportError("Loading a non-distributed checkpoint into a distributed model requires PyTorch >= 2.4.")

checkpoint = torch.load(path, mmap=True, map_location="cpu")
_load_raw_module_state(checkpoint.pop(module_key), module, strict=strict)

@@ -546,9 +543,6 @@ def _load_raw_module_state(
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

if _has_dtensor_modules(module):
if not _TORCH_GREATER_EQUAL_2_4:
raise ImportError("Loading a non-distributed checkpoint into a distributed model requires PyTorch >= 2.4.")

from torch.distributed.checkpoint.state_dict import StateDictOptions, set_model_state_dict

state_dict_options = StateDictOptions(
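Background for the `FSDP` -> `FSDPModule` rename: in the FSDP2 API, `fully_shard` swaps the module's class for a subclass that mixes in `FSDPModule` instead of wrapping it, so `isinstance` checks must target `FSDPModule`. A rough sketch under the assumption of PyTorch 2.4 and an already-initialized process group (for example a `torchrun` launch on CUDA devices):

```python
# Sketch only: FSDP2 usage of the per-module gradient-sync toggle that _set_requires_grad_sync relies on.
import torch.nn as nn
from torch.distributed._composable.fsdp import FSDPModule, fully_shard

model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))
fully_shard(model)  # requires an initialized process group / device mesh

assert isinstance(model, FSDPModule)  # the class is swapped in place, not wrapped
for mod in model.modules():
    if isinstance(mod, FSDPModule):
        mod.set_requires_gradient_sync(False, recurse=False)  # e.g. while accumulating gradients
```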
2 changes: 1 addition & 1 deletion src/lightning/fabric/utilities/imports.py
@@ -28,7 +28,7 @@

_TORCH_GREATER_EQUAL_2_2 = compare_version("torch", operator.ge, "2.2.0")
_TORCH_GREATER_EQUAL_2_3 = compare_version("torch", operator.ge, "2.3.0")
_TORCH_GREATER_EQUAL_2_4 = compare_version("torch", operator.ge, "2.4.0", use_base_version=True)
_TORCH_GREATER_EQUAL_2_4 = compare_version("torch", operator.ge, "2.4.0")

_PYTHON_GREATER_EQUAL_3_8_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 8)
_PYTHON_GREATER_EQUAL_3_10_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 10)
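Dropping `use_base_version=True` tightens the flag: pre-release builds of 2.4 (nightlies, release candidates) no longer count as 2.4, only actual 2.4.0+ installs do. A short sketch of the underlying comparison, assuming the usual PEP 440 semantics that `compare_version` builds on:

```python
# A 2.4 release candidate sorts below 2.4.0 under PEP 440, but its base version equals 2.4.0.
from packaging.version import Version

rc = Version("2.4.0rc1")
print(rc >= Version("2.4.0"))                        # False: strict check rejects pre-releases
print(Version(rc.base_version) >= Version("2.4.0"))  # True: base-version check accepts them
```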
5 changes: 4 additions & 1 deletion src/lightning/fabric/utilities/testing/_runif.py
@@ -24,6 +24,7 @@
from lightning.fabric.accelerators.cuda import num_cuda_devices
from lightning.fabric.accelerators.mps import MPSAccelerator
from lightning.fabric.strategies.deepspeed import _DEEPSPEED_AVAILABLE
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4


def _runif_reasons(
@@ -111,7 +112,9 @@ def _runif_reasons(
reasons.append("Standalone execution")
kwargs["standalone"] = True

if deepspeed and not (_DEEPSPEED_AVAILABLE and RequirementCache(module="deepspeed.utils")):
if deepspeed and not (
_DEEPSPEED_AVAILABLE and not _TORCH_GREATER_EQUAL_2_4 and RequirementCache(module="deepspeed.utils")
):
reasons.append("Deepspeed")

if dynamo:
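In plain terms, the reworked condition marks DeepSpeed tests to be skipped unless DeepSpeed is importable and the installed torch is older than 2.4. A minimal sketch of the gate with the version flag hard-coded for illustration:

```python
# Hypothetical evaluation of the new gate; names mirror the change above.
from lightning_utilities.core.imports import RequirementCache

_DEEPSPEED_AVAILABLE = bool(RequirementCache("deepspeed"))
_TORCH_GREATER_EQUAL_2_4 = True  # pretend torch 2.4 is installed

deepspeed = True  # the test carries @RunIf(deepspeed=True)
if deepspeed and not (
    _DEEPSPEED_AVAILABLE and not _TORCH_GREATER_EQUAL_2_4 and RequirementCache(module="deepspeed.utils")
):
    print("skip: Deepspeed")  # on torch >= 2.4 the test is skipped even if DeepSpeed is installed
```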
1 change: 1 addition & 0 deletions tests/tests_fabric/conftest.py
@@ -104,6 +104,7 @@ def thread_police_duuu_daaa_duuu_daaa():
elif (
thread.name == "QueueFeederThread" # tensorboardX
or thread.name == "QueueManagerThread" # torch.compile
or "(_read_thread)" in thread.name # torch.compile
):
thread.join(timeout=20)
elif (
7 changes: 5 additions & 2 deletions tests/tests_fabric/plugins/precision/test_amp.py
@@ -17,11 +17,13 @@
import pytest
import torch
from lightning.fabric.plugins.precision.amp import MixedPrecision
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4


def test_amp_precision_default_scaler():
precision = MixedPrecision(precision="16-mixed", device=Mock())
assert isinstance(precision.scaler, torch.cuda.amp.GradScaler)
scaler_cls = torch.amp.GradScaler if _TORCH_GREATER_EQUAL_2_4 else torch.cuda.amp.GradScaler
assert isinstance(precision.scaler, scaler_cls)


def test_amp_precision_scaler_with_bf16():
@@ -36,7 +38,8 @@ def test_amp_precision_forward_context():
"""Test to ensure that the context manager correctly is set to bfloat16 on CPU and CUDA."""
precision = MixedPrecision(precision="16-mixed", device="cuda")
assert precision.device == "cuda"
assert isinstance(precision.scaler, torch.cuda.amp.GradScaler)
scaler_cls = torch.amp.GradScaler if _TORCH_GREATER_EQUAL_2_4 else torch.cuda.amp.GradScaler
assert isinstance(precision.scaler, scaler_cls)
assert torch.get_default_dtype() == torch.float32
with precision.forward_context():
assert torch.get_autocast_gpu_dtype() == torch.float16
(changes to an additional test file; file name not captured above)
@@ -17,6 +17,7 @@
import torch
import torch.nn as nn
from lightning.fabric import Fabric, seed_everything
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4

from tests_fabric.helpers.runif import RunIf

@@ -82,7 +83,8 @@ def run(fused=False):
optimizer = torch.optim.Adam(model.parameters(), lr=1.0, fused=fused)

model, optimizer = fabric.setup(model, optimizer)
assert isinstance(fabric._precision.scaler, torch.cuda.amp.GradScaler)
scaler_cls = torch.amp.GradScaler if _TORCH_GREATER_EQUAL_2_4 else torch.cuda.amp.GradScaler
assert isinstance(fabric._precision.scaler, scaler_cls)

data = torch.randn(10, 10, device="cuda")
target = torch.randn(10, 10, device="cuda")
4 changes: 2 additions & 2 deletions tests/tests_fabric/plugins/precision/test_bitsandbytes.py
@@ -93,7 +93,7 @@ def __init__(self):
precision.convert_module(model)


@RunIf(min_cuda_gpus=1)
@RunIf(min_cuda_gpus=1, max_torch="2.4")
@pytest.mark.skipif(not _BITSANDBYTES_AVAILABLE, reason="bitsandbytes unavailable")
@pytest.mark.parametrize(
("args", "expected"),
@@ -232,7 +232,7 @@ def __init__(self):
assert model.l.weight.dtype == expected


@RunIf(min_cuda_gpus=1)
@RunIf(min_cuda_gpus=1, max_torch="2.4")
@pytest.mark.skipif(not _BITSANDBYTES_AVAILABLE, reason="bitsandbytes unavailable")
def test_load_quantized_checkpoint(tmp_path):
"""Test that a checkpoint saved from a quantized model can be loaded back into a quantized model."""
1 change: 1 addition & 0 deletions tests/tests_fabric/strategies/test_dp.py
@@ -74,6 +74,7 @@ def __instancecheck__(self, instance):
assert strategy.get_module_state_dict(wrapped_module).keys() == original_module.state_dict().keys()


@pytest.mark.filterwarnings("ignore::FutureWarning")
@pytest.mark.parametrize(
"precision",
[
6 changes: 5 additions & 1 deletion tests/tests_fabric/strategies/test_fsdp_integration.py
@@ -118,7 +118,7 @@ def get_model(self):
return model


@RunIf(min_cuda_gpus=2, standalone=True)
@RunIf(min_cuda_gpus=2, standalone=True, max_torch="2.4")
@pytest.mark.parametrize("precision", ["16-mixed", pytest.param("bf16-mixed", marks=RunIf(bf16_cuda=True))])
@pytest.mark.parametrize("manual_wrapping", [True, False])
def test_train_save_load(tmp_path, manual_wrapping, precision):
@@ -173,6 +173,7 @@ def test_train_save_load(tmp_path, manual_wrapping, precision):
assert state["coconut"] == 11


@pytest.mark.filterwarnings("ignore::FutureWarning")
@RunIf(min_cuda_gpus=2, standalone=True)
def test_save_full_state_dict(tmp_path):
"""Test that FSDP saves the full state into a single file with `state_dict_type="full"`."""
@@ -287,6 +288,7 @@ def test_save_full_state_dict(tmp_path):
trainer.run()


@pytest.mark.filterwarnings("ignore::FutureWarning")
@RunIf(min_cuda_gpus=2, standalone=True)
def test_load_full_state_dict_into_sharded_model(tmp_path):
"""Test that the strategy can load a full-state checkpoint into a FSDP sharded model."""
@@ -469,6 +471,7 @@ def _run_setup_assertions(empty_init, expected_device):
_run_setup_assertions(empty_init=True, expected_device=torch.device("meta"))


@pytest.mark.filterwarnings("ignore::FutureWarning")
@RunIf(min_cuda_gpus=2, standalone=True)
def test_save_filter(tmp_path):
fabric = Fabric(accelerator="cuda", strategy=FSDPStrategy(state_dict_type="full"), devices=2)
@@ -602,6 +605,7 @@ def test_clip_gradients(clip_type, precision):
optimizer.zero_grad()


@pytest.mark.filterwarnings("ignore::FutureWarning")
@RunIf(min_cuda_gpus=2, standalone=True, min_torch="2.3.0")
def test_save_sharded_and_consolidate_and_load(tmp_path):
"""Test the consolidation of a FSDP-sharded checkpoint into a single file."""