From 2cfde9c84e31d4178c7a5eada7e20354a764c76d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 12 Aug 2022 10:24:04 +0200 Subject: [PATCH] Replace unwrapping logic in strategies (#13738) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Carlos MocholĂ­ Co-authored-by: Rohit Gupta --- src/pytorch_lightning/CHANGELOG.md | 10 +++ src/pytorch_lightning/overrides/base.py | 78 +++++++++++++++---- .../overrides/data_parallel.py | 24 ++++-- .../overrides/distributed.py | 11 ++- src/pytorch_lightning/overrides/fairscale.py | 29 +++++-- .../plugins/precision/sharded_native_amp.py | 2 +- src/pytorch_lightning/strategies/bagua.py | 27 +++---- src/pytorch_lightning/strategies/ddp.py | 8 +- src/pytorch_lightning/strategies/deepspeed.py | 18 ++--- src/pytorch_lightning/strategies/ipu.py | 8 +- src/pytorch_lightning/strategies/parallel.py | 5 -- src/pytorch_lightning/strategies/sharded.py | 25 +++--- .../strategies/sharded_spawn.py | 26 +++---- src/pytorch_lightning/strategies/strategy.py | 11 +-- src/pytorch_lightning/strategies/tpu_spawn.py | 10 +-- src/pytorch_lightning/trainer/trainer.py | 8 +- tests/tests_pytorch/accelerators/test_ipu.py | 50 ++++++------ .../deprecated_api/test_remove_1-10.py | 44 +++++++++++ tests/tests_pytorch/helpers/runif.py | 2 +- tests/tests_pytorch/models/test_amp.py | 9 --- tests/tests_pytorch/overrides/test_base.py | 3 +- .../precision/test_sharded_precision.py | 2 +- .../strategies/test_sharded_strategy.py | 10 +-- .../connectors/test_callback_connector.py | 8 +- .../trainer/flags/test_overfit_batches.py | 2 +- tests/tests_pytorch/utilities/test_imports.py | 3 +- 26 files changed, 274 insertions(+), 159 deletions(-) diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md index 714d4340f1ba1..4f986257f33ed 100644 --- a/src/pytorch_lightning/CHANGELOG.md +++ b/src/pytorch_lightning/CHANGELOG.md @@ -28,6 +28,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Updated compatibility for LightningLite to run with the latest DeepSpeed 0.7.0 ([13967](https://github.com/Lightning-AI/lightning/pull/13967)) +- Replaced the unwrapping logic in strategies with direct access to unwrapped `LightningModule` ([#13738](https://github.com/Lightning-AI/lightning/pull/13738)) + + ### Deprecated - Deprecated `LightningDeepSpeedModule` ([#14000](https://github.com/Lightning-AI/lightning/pull/14000)) @@ -39,6 +42,13 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Deprecated the calls to `pytorch_lightning.utiltiies.meta` functions in favor of built-in https://github.com/pytorch/torchdistx support ([#13868](https://github.com/Lightning-AI/lightning/pull/13868)) +- Deprecated the `unwrap_lightning_module` and `unwrap_lightning_module_sharded` utility functions in favor of accessing the unwrapped `LightningModule` on the strategy directly ([#13738](https://github.com/Lightning-AI/lightning/pull/13738)) + + +- Deprecated the `pl_module` argument in `LightningParallelModule`, `LightningDistributedModule`, `LightningShardedDataParallel`, `LightningBaguaModule` and `LightningDeepSpeedModule` wrapper classes ([#13738](https://github.com/Lightning-AI/lightning/pull/13738)) + + + ### Removed - Removed the deprecated `Trainer.training_type_plugin` property in favor of `Trainer.strategy` ([#14011](https://github.com/Lightning-AI/lightning/pull/14011)) diff --git a/src/pytorch_lightning/overrides/base.py b/src/pytorch_lightning/overrides/base.py index 3e9fda2f966f5..07f30c271b207 100644 --- a/src/pytorch_lightning/overrides/base.py +++ b/src/pytorch_lightning/overrides/base.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Union +from typing import Any, Optional, Union import torch import torch.nn as nn @@ -20,6 +20,7 @@ import pytorch_lightning as pl from pytorch_lightning.core.mixins import DeviceDtypeModuleMixin +from pytorch_lightning.utilities import rank_zero_deprecation class _LightningPrecisionModuleWrapperBase(DeviceDtypeModuleMixin, torch.nn.Module): @@ -54,30 +55,47 @@ def forward(self, *args: Any, **kwargs: Any) -> Any: class _LightningModuleWrapperBase(DeviceDtypeModuleMixin, torch.nn.Module): - def __init__(self, pl_module: Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]) -> None: + def __init__( + self, forward_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] + ) -> None: """Wraps the user's LightningModule and redirects the forward call to the appropriate method, either ``training_step``, ``validation_step``, ``test_step``, or ``predict_step``. Inheriting classes may also modify the inputs or outputs of forward. Args: - pl_module: the model to wrap + forward_module: The module to wrap. If it's not a LightningModule, it must have an attribute ``.module`` + pointing to a LightningModule reference. """ super().__init__() - self.module = pl_module + if not isinstance(forward_module, pl.LightningModule) and ( + not isinstance(getattr(forward_module, "module", None), pl.LightningModule) + ): + raise ValueError( + "`forward_module` must be a `LightningModule` instance or have an attribute `.module` pointing to one," + f" got: {forward_module.__class__.__qualname__}" + ) + # TODO: In v1.10.0, remove the Optional type from forward_module and remove the assertion + assert forward_module is not None + self._forward_module = forward_module # set the parameters_to_ignore from LightningModule. 
- _ddp_params_and_buffers_to_ignore = getattr(pl_module, "_ddp_params_and_buffers_to_ignore", []) + _ddp_params_and_buffers_to_ignore = getattr(self._forward_module, "_ddp_params_and_buffers_to_ignore", []) self._ddp_params_and_buffers_to_ignore = [f"module.{p}" for p in _ddp_params_and_buffers_to_ignore] + @property + def lightning_module(self) -> "pl.LightningModule": + if isinstance(self._forward_module, pl.LightningModule): + return self._forward_module + return self._forward_module.module + def forward(self, *inputs: Any, **kwargs: Any) -> Any: - pl_module = unwrap_lightning_module(self.module) + pl_module = self.lightning_module trainer = pl_module._trainer if trainer is not None: - assert isinstance(self.module, (pl.LightningModule, _LightningPrecisionModuleWrapperBase)) if trainer.training: - output = self.module.training_step(*inputs, **kwargs) + output = self._forward_module.training_step(*inputs, **kwargs) # In manual_optimization, we need to prevent DDP reducer as # it is done manually in `LightningModule.manual_backward` # `require_backward_grad_sync` will be reset in the @@ -86,27 +104,53 @@ def forward(self, *inputs: Any, **kwargs: Any) -> Any: trainer.model.require_backward_grad_sync = False # type: ignore[assignment] return output if trainer.testing: - return self.module.test_step(*inputs, **kwargs) + return self._forward_module.test_step(*inputs, **kwargs) if trainer.sanity_checking or trainer.validating: - return self.module.validation_step(*inputs, **kwargs) + return self._forward_module.validation_step(*inputs, **kwargs) if trainer.predicting: - return self.module.predict_step(*inputs, **kwargs) - return self.module(*inputs, **kwargs) - - -def unwrap_lightning_module(wrapped_model: nn.Module) -> "pl.LightningModule": + return self._forward_module.predict_step(*inputs, **kwargs) + return self._forward_module(*inputs, **kwargs) + + @classmethod + def _validate_init_arguments( + cls, + pl_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None, + forward_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None, + ) -> None: + # TODO: In v1.10.0, remove this method and mark the forward_module init argument in all subclasses as required + if pl_module is not None: + rank_zero_deprecation( + f"The argument `pl_module` in `{cls.__name__}` is deprecated in v1.8.0 and will be removed in" + " v1.10.0. Please use `forward_module` instead." + ) + elif forward_module is None: + raise ValueError("Argument `forward_module` is required.") + + +def unwrap_lightning_module(wrapped_model: nn.Module, _suppress_warning: bool = False) -> "pl.LightningModule": """Recursively unwraps a :class:`~pytorch_lightning.core.module.LightningModule` by following the ``.module`` attributes on the wrapper. + .. deprecated:: v1.8.0 + The function ``unwrap_lightning_module`` is deprecated in v1.8.0 and will be removed in v1.10.0. Access the + ``LightningModule`` directly through the strategy attribute ``Strategy.lightning_module``. + Raises: TypeError: If the unwrapping leads to a module that is not a LightningModule and that cannot be unwrapped further. """ + if not _suppress_warning: + rank_zero_deprecation( + "The function `unwrap_lightning_module` is deprecated in v1.8.0 and will be removed in v1.10.0. Access the" + " `LightningModule` directly through the strategy attribute `Strategy.lightning_module`." 
+ ) model = wrapped_model if isinstance(model, (DistributedDataParallel, DataParallel)): model = unwrap_lightning_module(model.module) - if isinstance(model, (_LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase)): - model = unwrap_lightning_module(model.module) + if isinstance(model, _LightningModuleWrapperBase): + model = model.lightning_module + if isinstance(model, _LightningPrecisionModuleWrapperBase): + model = model.module if not isinstance(model, pl.LightningModule): raise TypeError(f"Unwrapping the module did not yield a `LightningModule`, got {type(model)} instead.") return model diff --git a/src/pytorch_lightning/overrides/data_parallel.py b/src/pytorch_lightning/overrides/data_parallel.py index 9fa253b9d8321..98d23cee391bc 100644 --- a/src/pytorch_lightning/overrides/data_parallel.py +++ b/src/pytorch_lightning/overrides/data_parallel.py @@ -13,7 +13,7 @@ # limitations under the License. import numbers import warnings -from typing import Any, cast, Union +from typing import Any, cast, Optional, Union import torch from torch import Tensor @@ -52,11 +52,23 @@ class LightningParallelModule(_LightningModuleWrapperBase): ) Args: - pl_module: the model to wrap + pl_module: The module to wrap. See description for `forward_module`. + + .. deprecated:: v1.8.0 + The argument ``pl_module`` is deprecated in v1.8.0 and will be removed in v1.10.0. Please use + ``forward_module`` instead. + + forward_module: The module to wrap. If it's not a ``LightningModule``, it must have an attribute ``.module`` + pointing to a ``LightningModule`` reference. """ - def __init__(self, pl_module: Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]) -> None: - super().__init__(pl_module) + def __init__( + self, + forward_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None, + pl_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None, + ) -> None: + self._validate_init_arguments(pl_module, forward_module) + super().__init__(forward_module=(pl_module or forward_module)) _ignore_scalar_return_in_dp() def forward(self, *inputs: Any, **kwargs: Any) -> Any: @@ -65,7 +77,7 @@ def forward(self, *inputs: Any, **kwargs: Any) -> Any: output = super().forward(*inputs, **kwargs) def output_transform(data: Any) -> Any: - device = cast(torch.device, self.module.device) + device = cast(torch.device, self.lightning_module.device) data = python_scalar_to_tensor(data, device) data = unsqueeze_scalar_tensor(data) return data @@ -95,7 +107,7 @@ def find_tensor_with_device(tensor: Tensor) -> Tensor: if replica_device is not None: # by calling .to() we force the update to the self.device property - self.module.to(device=replica_device) + self._forward_module.to(device=replica_device) else: rank_zero_warn( "Could not determine on which device the inputs are." 
diff --git a/src/pytorch_lightning/overrides/distributed.py b/src/pytorch_lightning/overrides/distributed.py index 929d1ed486f4a..3ecac8c1eea04 100644 --- a/src/pytorch_lightning/overrides/distributed.py +++ b/src/pytorch_lightning/overrides/distributed.py @@ -19,12 +19,19 @@ from torch.nn.parallel import DistributedDataParallel from torch.utils.data import BatchSampler, Dataset, DistributedSampler, Sampler -from pytorch_lightning.overrides.base import _LightningModuleWrapperBase +import pytorch_lightning as pl +from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase from pytorch_lightning.utilities.exceptions import MisconfigurationException class LightningDistributedModule(_LightningModuleWrapperBase): - ... + def __init__( + self, + forward_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None, + pl_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None, + ) -> None: + self._validate_init_arguments(pl_module, forward_module) + super().__init__(forward_module=(pl_module or forward_module)) def _find_tensors( diff --git a/src/pytorch_lightning/overrides/fairscale.py b/src/pytorch_lightning/overrides/fairscale.py index f48fa8dcf9ccf..d9fd2e60aff61 100644 --- a/src/pytorch_lightning/overrides/fairscale.py +++ b/src/pytorch_lightning/overrides/fairscale.py @@ -11,27 +11,44 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Optional, Union + import torch.nn as nn import pytorch_lightning as pl -from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, unwrap_lightning_module -from pytorch_lightning.utilities import _IS_WINDOWS, _module_available +from pytorch_lightning.overrides.base import ( + _LightningModuleWrapperBase, + _LightningPrecisionModuleWrapperBase, + unwrap_lightning_module, +) +from pytorch_lightning.utilities import rank_zero_deprecation +from pytorch_lightning.utilities.imports import _IS_WINDOWS, _module_available _FAIRSCALE_AVAILABLE = not _IS_WINDOWS and _module_available("fairscale.nn") -if _FAIRSCALE_AVAILABLE: + +if _FAIRSCALE_AVAILABLE: # pragma: no-cover from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel class LightningShardedDataParallel(_LightningModuleWrapperBase): - # Just do this for later docstrings - pass + def __init__( + self, + forward_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None, + pl_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None, + ) -> None: + self._validate_init_arguments(pl_module, forward_module) + super().__init__(forward_module=(pl_module or forward_module)) def unwrap_lightning_module_sharded(wrapped_model: nn.Module) -> "pl.LightningModule": + rank_zero_deprecation( + "The function `unwrap_lightning_module_sharded` is deprecated in v1.8.0 and will be removed in v1.10.0." + " Access the `LightningModule` directly through the strategy attribute `Strategy.lightning_module`." + ) model = wrapped_model if isinstance(model, ShardedDataParallel): model = model.module - return unwrap_lightning_module(model) + return unwrap_lightning_module(model, _suppress_warning=True) else: LightningShardedDataParallel = ... 
# type: ignore[assignment,misc] diff --git a/src/pytorch_lightning/plugins/precision/sharded_native_amp.py b/src/pytorch_lightning/plugins/precision/sharded_native_amp.py index f5646c2094253..570e25bd85caa 100644 --- a/src/pytorch_lightning/plugins/precision/sharded_native_amp.py +++ b/src/pytorch_lightning/plugins/precision/sharded_native_amp.py @@ -15,9 +15,9 @@ import torch -from pytorch_lightning.overrides.fairscale import _FAIRSCALE_AVAILABLE from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.imports import _FAIRSCALE_AVAILABLE if _FAIRSCALE_AVAILABLE: from fairscale.optim import OSS diff --git a/src/pytorch_lightning/strategies/bagua.py b/src/pytorch_lightning/strategies/bagua.py index d100d1aa97adc..f08d1aebf1b7c 100644 --- a/src/pytorch_lightning/strategies/bagua.py +++ b/src/pytorch_lightning/strategies/bagua.py @@ -7,11 +7,7 @@ from torch.nn import Module import pytorch_lightning as pl -from pytorch_lightning.overrides.base import ( - _LightningModuleWrapperBase, - _LightningPrecisionModuleWrapperBase, - unwrap_lightning_module, -) +from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.precision import PrecisionPlugin @@ -54,10 +50,16 @@ class LightningBaguaModule(_LightningModuleWrapperBase): - def __init__(self, pl_module: Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]) -> None: - super().__init__(pl_module) + def __init__( + self, + forward_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None, + pl_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None, + ) -> None: + self._validate_init_arguments(pl_module, forward_module) + forward_module = pl_module or forward_module + super().__init__(forward_module=forward_module) # Bagua use `bagua_module_name` to distinguish different modules - self._bagua_module_name = f"{pl_module.__class__.__name__}{id(pl_module)}" + self._bagua_module_name = f"{forward_module.__class__.__name__}{id(forward_module)}" class BaguaStrategy(DDPStrategy): @@ -109,13 +111,6 @@ def __init__( self._bagua_flatten = flatten self._bagua_kwargs = bagua_kwargs - @property - def lightning_module(self) -> Optional["pl.LightningModule"]: - model = self.model - if isinstance(model, BaguaDistributedDataParallel): - model = model.module - return unwrap_lightning_module(model) if model is not None else None - def setup_distributed(self) -> None: reset_seed() @@ -190,7 +185,7 @@ def _check_qadam_optimizer(self) -> None: def _configure_bagua_model(self, trainer: "pl.Trainer") -> None: model = LightningBaguaModule(self.model) # type: ignore[arg-type] - self._model = self._setup_model(model) + self.model = self._setup_model(model) # start the background communication for async algorithm if trainer.training and self._bagua_algorithm == "async": diff --git a/src/pytorch_lightning/strategies/ddp.py b/src/pytorch_lightning/strategies/ddp.py index 57ab3a151b011..f4f5397a78bca 100644 --- a/src/pytorch_lightning/strategies/ddp.py +++ b/src/pytorch_lightning/strategies/ddp.py @@ -34,7 +34,6 @@ from pytorch_lightning.overrides import LightningDistributedModule from pytorch_lightning.overrides.base 
import _LightningPrecisionModuleWrapperBase from pytorch_lightning.overrides.distributed import prepare_for_backward -from pytorch_lightning.overrides.fairscale import _FAIRSCALE_AVAILABLE from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.precision import PrecisionPlugin @@ -55,7 +54,12 @@ sync_ddp_if_available, ) from pytorch_lightning.utilities.exceptions import DeadlockDetectedException -from pytorch_lightning.utilities.imports import _IS_WINDOWS, _TORCH_GREATER_EQUAL_1_10, _TORCH_GREATER_EQUAL_1_11 +from pytorch_lightning.utilities.imports import ( + _FAIRSCALE_AVAILABLE, + _IS_WINDOWS, + _TORCH_GREATER_EQUAL_1_10, + _TORCH_GREATER_EQUAL_1_11, +) from pytorch_lightning.utilities.optimizer import optimizers_to_device from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_only, rank_zero_warn from pytorch_lightning.utilities.seed import reset_seed diff --git a/src/pytorch_lightning/strategies/deepspeed.py b/src/pytorch_lightning/strategies/deepspeed.py index 8acbc80257bd1..4a70eb983fd86 100644 --- a/src/pytorch_lightning/strategies/deepspeed.py +++ b/src/pytorch_lightning/strategies/deepspeed.py @@ -77,10 +77,14 @@ class LightningDeepSpeedModule(_LightningModuleWrapperBase): """ def __init__( - self, pl_module: Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase], precision: Union[str, int] + self, + forward_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None, + precision: Union[str, int] = 32, + pl_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None, ) -> None: rank_zero_deprecation("`LightningDeepSpeedModule` has been deprecated in v1.7.1 and will be removed in v1.9.0") - super().__init__(pl_module) + self._validate_init_arguments(pl_module, forward_module) + super().__init__(forward_module=(pl_module or forward_module)) self.precision = precision def forward(self, *inputs: Any, **kwargs: Any) -> Any: @@ -485,7 +489,7 @@ def init_deepspeed(self) -> None: ) assert isinstance(self.model, (pl.LightningModule, _LightningPrecisionModuleWrapperBase)) - model = _LightningModuleWrapperBase(pl_module=self.model) + model = _LightningModuleWrapperBase(forward_module=self.model) if self.lightning_module.trainer and self.lightning_module.trainer.training: self._initialize_deepspeed_train(model) @@ -611,14 +615,6 @@ def _initialize_deepspeed_inference(self, model: Module) -> None: ) self.model = model - @property - def lightning_module(self) -> Optional["pl.LightningModule"]: - # the model may not be wrapped with DeepEngine & _LightningModuleWrapperBase if calling this too early - module = getattr(self.model, "module", self.model) - module = module.module if isinstance(module, _LightningModuleWrapperBase) else module - assert isinstance(module, pl.LightningModule) or module is None - return module - @property def distributed_sampler_kwargs(self) -> Dict[str, int]: distributed_sampler_kwargs = dict(num_replicas=self.world_size, rank=self.global_rank) diff --git a/src/pytorch_lightning/strategies/ipu.py b/src/pytorch_lightning/strategies/ipu.py index 4bedbfd6d70fc..f56c095dc12c1 100644 --- a/src/pytorch_lightning/strategies/ipu.py +++ b/src/pytorch_lightning/strategies/ipu.py @@ -51,10 +51,14 @@ class LightningIPUModule(_LightningModuleWrapperBase): """ def __init__( - self, pl_module: Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase], 
precision: Union[str, int] + self, + forward_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None, + precision: Union[str, int] = 32, + pl_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None, ) -> None: rank_zero_deprecation("`LightningIPUModule` has been deprecated in v1.7.0 and will be removed in v1.8.0") - super().__init__(pl_module) + self._validate_init_arguments(pl_module, forward_module) + super().__init__(forward_module=(pl_module or forward_module)) self.precision = precision def forward(self, *inputs: Any, **kwargs: Any) -> Any: diff --git a/src/pytorch_lightning/strategies/parallel.py b/src/pytorch_lightning/strategies/parallel.py index 2517848274e3d..9d469313103a1 100644 --- a/src/pytorch_lightning/strategies/parallel.py +++ b/src/pytorch_lightning/strategies/parallel.py @@ -19,7 +19,6 @@ from torch import Tensor import pytorch_lightning as pl -from pytorch_lightning.overrides.base import unwrap_lightning_module from pytorch_lightning.plugins import LayerSync from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO @@ -55,10 +54,6 @@ def __init__( def root_device(self) -> torch.device: """Return the root device.""" - @property - def lightning_module(self) -> Optional["pl.LightningModule"]: - return unwrap_lightning_module(self.model) if self.model is not None else None - @property def global_rank(self) -> int: return self.cluster_environment.global_rank() if self.cluster_environment is not None else 0 diff --git a/src/pytorch_lightning/strategies/sharded.py b/src/pytorch_lightning/strategies/sharded.py index 01401bd53bb56..ce1e4cd96b961 100644 --- a/src/pytorch_lightning/strategies/sharded.py +++ b/src/pytorch_lightning/strategies/sharded.py @@ -20,20 +20,18 @@ import pytorch_lightning as pl from pytorch_lightning.core.optimizer import LightningOptimizer -from pytorch_lightning.overrides.fairscale import _FAIRSCALE_AVAILABLE +from pytorch_lightning.overrides.base import _LightningModuleWrapperBase from pytorch_lightning.strategies.ddp import DDPStrategy from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities.enums import PrecisionType from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.imports import _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE +from pytorch_lightning.utilities.imports import _FAIRSCALE_AVAILABLE, _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE from pytorch_lightning.utilities.optimizer import optimizers_to_device from pytorch_lightning.utilities.rank_zero import rank_zero_only if _FAIRSCALE_AVAILABLE: from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel from fairscale.optim import OSS - - from pytorch_lightning.overrides.fairscale import LightningShardedDataParallel, unwrap_lightning_module_sharded else: OSS = ShardedDataParallel = object @@ -44,6 +42,14 @@ class DDPShardedStrategy(DDPStrategy): strategy_name = "ddp_sharded" _REDUCE_BUFFER_SIZE_DEFAULT: int = 2**23 # 8M + def connect(self, model: "pl.LightningModule") -> None: + if not _FAIRSCALE_AVAILABLE: # pragma: no cover + raise MisconfigurationException( + "`DDPShardedStrategy` requires `fairscale` to be installed." + " Install it by running `pip install fairscale`." 
+ ) + return super().connect(model) + def setup(self, trainer: "pl.Trainer") -> None: # share ddp pids to all processes self._rank_0_will_call_children_scripts = self.broadcast(self._rank_0_will_call_children_scripts) @@ -70,7 +76,7 @@ def configure_ddp(self) -> None: self._set_ddp_kwargs() self.setup_optimizers(self.model.trainer) self.model, self.optimizers = self._setup_model_and_optimizers( - model=LightningShardedDataParallel(self.model), + model=_LightningModuleWrapperBase(self.model), optimizers=self.optimizers, ) optimizers_to_device(self.optimizers, self.root_device) @@ -128,15 +134,6 @@ def _optim_state_dict(self, optimizer): """ return optimizer.state_dict() - @property - def lightning_module(self) -> Optional["pl.LightningModule"]: - if not _FAIRSCALE_AVAILABLE: # pragma: no cover - raise MisconfigurationException( - "`DDPShardedStrategy` requires `fairscale` to be installed." - " Install it by running `pip install fairscale`." - ) - return unwrap_lightning_module_sharded(self.model) if self.model is not None else None - def pre_backward(self, closure_loss: Tensor) -> None: pass diff --git a/src/pytorch_lightning/strategies/sharded_spawn.py b/src/pytorch_lightning/strategies/sharded_spawn.py index 882302e101cb6..f19aae7302eea 100644 --- a/src/pytorch_lightning/strategies/sharded_spawn.py +++ b/src/pytorch_lightning/strategies/sharded_spawn.py @@ -12,18 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. from contextlib import contextmanager -from typing import Any, Dict, Generator, List, Optional, Tuple +from typing import Any, Dict, Generator, List, Tuple from torch import Tensor from torch.nn import Module from torch.optim import Optimizer import pytorch_lightning as pl -from pytorch_lightning.overrides.base import _LightningPrecisionModuleWrapperBase -from pytorch_lightning.overrides.fairscale import _FAIRSCALE_AVAILABLE +from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase from pytorch_lightning.strategies.ddp_spawn import DDPSpawnStrategy from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.imports import _FAIRSCALE_AVAILABLE from pytorch_lightning.utilities.optimizer import optimizers_to_device from pytorch_lightning.utilities.rank_zero import rank_zero_only @@ -31,7 +31,6 @@ from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel from fairscale.optim import OSS - from pytorch_lightning.overrides.fairscale import LightningShardedDataParallel, unwrap_lightning_module_sharded else: OSS = ShardedDataParallel = object @@ -41,13 +40,21 @@ class DDPSpawnShardedStrategy(DDPSpawnStrategy): strategy_name = "ddp_sharded_spawn" + def connect(self, model: "pl.LightningModule") -> None: + if not _FAIRSCALE_AVAILABLE: # pragma: no cover + raise MisconfigurationException( + "`DDPSpawnShardedStrategy` requires `fairscale` to be installed." + " Install it by running `pip install fairscale`." 
+ ) + return super().connect(model) + def configure_ddp(self) -> None: # set up optimizers after the wrapped module has been moved to the device assert self.lightning_module is not None self.setup_optimizers(self.lightning_module.trainer) assert isinstance(self.model, (pl.LightningModule, _LightningPrecisionModuleWrapperBase)) self.model, self.optimizers = self._setup_model_and_optimizers( - model=LightningShardedDataParallel(self.model), optimizers=self.optimizers + model=_LightningModuleWrapperBase(self.model), optimizers=self.optimizers ) optimizers_to_device(self.optimizers, self.root_device) @@ -104,15 +111,6 @@ def _optim_state_dict(self, optimizer: Optimizer) -> Dict[str, Any]: """ return optimizer.state_dict() - @property - def lightning_module(self) -> Optional["pl.LightningModule"]: - if not _FAIRSCALE_AVAILABLE: # pragma: no cover - raise MisconfigurationException( - "`DDPSpawnShardedStrategy` requires `fairscale` to be installed." - " Install it by running `pip install fairscale`." - ) - return unwrap_lightning_module_sharded(self.model) if self.model is not None else None - def pre_backward(self, closure_loss: Tensor) -> None: pass diff --git a/src/pytorch_lightning/strategies/strategy.py b/src/pytorch_lightning/strategies/strategy.py index 59f1e37095e60..c09e7eae8c586 100644 --- a/src/pytorch_lightning/strategies/strategy.py +++ b/src/pytorch_lightning/strategies/strategy.py @@ -24,7 +24,6 @@ import pytorch_lightning as pl from pytorch_lightning.core.optimizer import _init_optimizers_and_lr_schedulers, LightningOptimizer -from pytorch_lightning.overrides.base import unwrap_lightning_module from pytorch_lightning.plugins import TorchCheckpointIO from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.io.wrapper import _WrappingCheckpointIO @@ -62,8 +61,9 @@ def __init__( self._accelerator: Optional["pl.accelerators.accelerator.Accelerator"] = accelerator self._checkpoint_io: Optional[CheckpointIO] = checkpoint_io self._precision_plugin: Optional[PrecisionPlugin] = precision_plugin - self._launcher: Optional[_Launcher] = None + self._lightning_module: Optional[pl.LightningModule] = None self._model: Optional[Module] = None + self._launcher: Optional[_Launcher] = None self._optimizers: List[Optimizer] = [] self._lightning_optimizers: Dict[int, LightningOptimizer] = {} self.lr_scheduler_configs: List[LRSchedulerConfig] = [] @@ -113,8 +113,9 @@ def optimizers(self, optimizers: List[Optimizer]) -> None: idx: LightningOptimizer._to_lightning_optimizer(opt, self, idx) for idx, opt in enumerate(self.optimizers) } - def connect(self, model: Module) -> None: + def connect(self, model: "pl.LightningModule") -> None: """Called by the accelerator to connect the accelerator and the model with this plugin.""" + self._lightning_module = model self.model = model def _configure_launcher(self) -> None: @@ -328,7 +329,7 @@ def post_backward(self, closure_loss: Tensor) -> None: @property def model(self) -> Optional[Module]: """Returns the potentially wrapped LightningModule.""" - return self._model + return self._model if self._model is not None else self._lightning_module @model.setter def model(self, new_model: Optional[Module]) -> None: @@ -337,7 +338,7 @@ def model(self, new_model: Optional[Module]) -> None: @property def lightning_module(self) -> Optional["pl.LightningModule"]: """Returns the pure LightningModule without potential wrappers.""" - return unwrap_lightning_module(self.model) if self.model is not None else None + return 
self._lightning_module def load_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]: torch.cuda.empty_cache() diff --git a/src/pytorch_lightning/strategies/tpu_spawn.py b/src/pytorch_lightning/strategies/tpu_spawn.py index 62bb1c308480b..5ca8db74c4620 100644 --- a/src/pytorch_lightning/strategies/tpu_spawn.py +++ b/src/pytorch_lightning/strategies/tpu_spawn.py @@ -124,7 +124,7 @@ def _validate_patched_dataloaders(model: "pl.LightningModule") -> None: assert not isinstance(source.instance, (pl.LightningModule, pl.LightningDataModule)) TPUSpawnStrategy._validate_dataloader(source.instance) - def connect(self, model: "pl.LightningModule") -> None: # type: ignore + def connect(self, model: "pl.LightningModule") -> None: TPUSpawnStrategy._validate_patched_dataloaders(model) self.wrapped_model = xmp.MpModelWrapper(LightningDistributedModule(model)) return super().connect(model) @@ -139,11 +139,11 @@ def setup(self, trainer: "pl.Trainer") -> None: if self.debug: os.environ["PT_XLA_DEBUG"] = "1" - assert self.model - shared_params = find_shared_parameters(self.model) + assert self.lightning_module + shared_params = find_shared_parameters(self.lightning_module) self.model_to_device() - assert isinstance(self.model.module, Module) - set_shared_parameters(self.model.module, shared_params) + + set_shared_parameters(self.lightning_module, shared_params) self.setup_precision_plugin() if trainer.state.fn == TrainerFn.FITTING: diff --git a/src/pytorch_lightning/trainer/trainer.py b/src/pytorch_lightning/trainer/trainer.py index 6853c4328af46..5983324f2f62d 100644 --- a/src/pytorch_lightning/trainer/trainer.py +++ b/src/pytorch_lightning/trainer/trainer.py @@ -696,7 +696,7 @@ def fit( """ if not isinstance(model, pl.LightningModule): raise TypeError(f"`Trainer.fit()` requires a `LightningModule`, got: {model.__class__.__qualname__}") - self.strategy.model = model + self.strategy._lightning_module = model self._call_and_handle_interrupt( self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path ) @@ -778,7 +778,7 @@ def validate( """ if model is not None and not isinstance(model, pl.LightningModule): raise TypeError(f"`Trainer.validate()` requires a `LightningModule`, got: {model.__class__.__qualname__}") - self.strategy.model = model or self.lightning_module + self.strategy._lightning_module = model or self.lightning_module return self._call_and_handle_interrupt(self._validate_impl, model, dataloaders, ckpt_path, verbose, datamodule) def _validate_impl( @@ -868,7 +868,7 @@ def test( """ if model is not None and not isinstance(model, pl.LightningModule): raise TypeError(f"`Trainer.test()` requires a `LightningModule`, got: {model.__class__.__qualname__}") - self.strategy.model = model or self.lightning_module + self.strategy._lightning_module = model or self.lightning_module return self._call_and_handle_interrupt(self._test_impl, model, dataloaders, ckpt_path, verbose, datamodule) def _test_impl( @@ -957,7 +957,7 @@ def predict( """ if model is not None and not isinstance(model, pl.LightningModule): raise TypeError(f"`Trainer.predict()` requires a `LightningModule`, got: {model.__class__.__qualname__}") - self.strategy.model = model or self.lightning_module + self.strategy._lightning_module = model or self.lightning_module return self._call_and_handle_interrupt( self._predict_impl, model, dataloaders, datamodule, return_predictions, ckpt_path ) diff --git a/tests/tests_pytorch/accelerators/test_ipu.py b/tests/tests_pytorch/accelerators/test_ipu.py index 
db3b9d1f91952..470cb4a028bed 100644 --- a/tests/tests_pytorch/accelerators/test_ipu.py +++ b/tests/tests_pytorch/accelerators/test_ipu.py @@ -99,7 +99,7 @@ def test_epoch_end(self, outputs) -> None: @pytest.mark.skipif(_IPU_AVAILABLE, reason="test requires non-IPU machine") @mock.patch("pytorch_lightning.accelerators.ipu.IPUAccelerator.is_available", return_value=True) -def test_fail_if_no_ipus(mock_ipu_acc_avail, tmpdir): +def test_fail_if_no_ipus(_, tmpdir): with pytest.raises(MisconfigurationException, match="IPU Accelerator requires IPU devices to run"): Trainer(default_root_dir=tmpdir, accelerator="ipu", devices=1) @@ -118,7 +118,7 @@ def test_warning_if_ipus_not_used(): @RunIf(ipu=True) -def test_no_warning_plugin(tmpdir): +def test_no_warning_strategy(tmpdir): with pytest.warns(None) as record: Trainer(default_root_dir=tmpdir, max_epochs=1, strategy=IPUStrategy(training_opts=poptorch.Options())) assert len(record) == 0 @@ -235,7 +235,7 @@ def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None: @RunIf(ipu=True) -def test_device_iterations_ipu_plugin(tmpdir): +def test_device_iterations_ipu_strategy(tmpdir): class TestCallback(Callback): def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None: assert trainer.strategy.device_iterations == 2 @@ -442,10 +442,10 @@ def test_manual_poptorch_opts_custom(tmpdir): class TestCallback(Callback): def on_fit_end(self, trainer: Trainer, pl_module: LightningModule) -> None: # ensure dataloaders were correctly set up during training. - plugin = trainer.strategy - assert isinstance(plugin, IPUStrategy) - assert plugin.training_opts.replication_factor == 2 - assert plugin.inference_opts.replication_factor == 1 + strategy = trainer.strategy + assert isinstance(strategy, IPUStrategy) + assert strategy.training_opts.replication_factor == 2 + assert strategy.inference_opts.replication_factor == 1 val_dataloader = trainer.val_dataloaders[0] train_dataloader = trainer.train_dataloader @@ -456,21 +456,21 @@ def on_fit_end(self, trainer: Trainer, pl_module: LightningModule) -> None: assert train_dataloader.options.replication_factor == 2 assert val_dataloader.options.replication_factor == 1 - plugin = IPUStrategy(inference_opts=inference_opts, training_opts=training_opts) + strategy = IPUStrategy(inference_opts=inference_opts, training_opts=training_opts) # ensure we default to the training options replication factor - assert plugin.replication_factor == 2 - trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, strategy=plugin, callbacks=TestCallback()) + assert strategy.replication_factor == 2 + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, strategy=strategy, callbacks=TestCallback()) trainer.fit(model) - plugin = trainer.strategy - assert isinstance(plugin, IPUStrategy) + strategy = trainer.strategy + assert isinstance(strategy, IPUStrategy) - training_opts = plugin.training_opts + training_opts = strategy.training_opts assert training_opts.device_iterations == 8 assert training_opts.replication_factor == 2 assert training_opts.Training.gradient_accumulation == 2 - inference_opts = plugin.inference_opts + inference_opts = strategy.inference_opts assert inference_opts.device_iterations == 16 assert inference_opts.replication_factor == 1 assert inference_opts.Training.gradient_accumulation == 1 @@ -481,8 +481,8 @@ def test_replication_factor(tmpdir): """Ensure if the user passes manual poptorch Options with custom parameters set, we set them correctly in the dataloaders.""" - plugin = 
IPUStrategy() - trainer = Trainer(accelerator="ipu", devices=2, default_root_dir=tmpdir, fast_dev_run=True, strategy=plugin) + strategy = IPUStrategy() + trainer = Trainer(accelerator="ipu", devices=2, default_root_dir=tmpdir, fast_dev_run=True, strategy=strategy) assert isinstance(trainer.accelerator, IPUAccelerator) assert trainer.num_devices == 2 assert trainer.strategy.replication_factor == 2 @@ -492,11 +492,11 @@ def test_replication_factor(tmpdir): inference_opts = poptorch.Options() training_opts.replicationFactor(8) inference_opts.replicationFactor(7) - plugin = IPUStrategy(inference_opts=inference_opts, training_opts=training_opts) + strategy = IPUStrategy(inference_opts=inference_opts, training_opts=training_opts) - trainer = Trainer(default_root_dir=tmpdir, accelerator="ipu", devices=1, strategy=plugin) + trainer = Trainer(default_root_dir=tmpdir, accelerator="ipu", devices=1, strategy=strategy) trainer.optimizers = model.configure_optimizers()[0] - plugin.model = model + strategy._lightning_module = model model.trainer = trainer trainer.state.fn = TrainerFn.FITTING trainer.strategy.setup(trainer) @@ -551,7 +551,7 @@ def configure_optimizers(self): @RunIf(ipu=True) -def test_precision_plugin(tmpdir): +def test_precision_plugin(): """Ensure precision plugin value is set correctly.""" plugin = IPUPrecisionPlugin(precision=16) @@ -606,13 +606,13 @@ def test_set_devices_if_none_ipu(): @RunIf(ipu=True) -def test_strategy_choice_ipu_plugin(tmpdir): +def test_strategy_choice_ipu_strategy(): trainer = Trainer(strategy=IPUStrategy(), accelerator="ipu", devices=8) assert isinstance(trainer.strategy, IPUStrategy) @RunIf(ipu=True) -def test_device_type_when_ipu_strategy_passed(tmpdir): +def test_device_type_when_ipu_strategy_passed(): trainer = Trainer(strategy=IPUStrategy(), accelerator="ipu", devices=8) assert isinstance(trainer.strategy, IPUStrategy) assert isinstance(trainer.accelerator, IPUAccelerator) @@ -620,11 +620,11 @@ def test_device_type_when_ipu_strategy_passed(tmpdir): @RunIf(ipu=True) def test_poptorch_models_at_different_stages(tmpdir): - plugin = IPUStrategy() - trainer = Trainer(default_root_dir=tmpdir, strategy=plugin, accelerator="ipu", devices=8) + strategy = IPUStrategy() + trainer = Trainer(default_root_dir=tmpdir, strategy=strategy, accelerator="ipu", devices=8) model = BoringModel() model.trainer = trainer - plugin.model = model + strategy._lightning_module = model trainer.optimizers = model.configure_optimizers()[0] trainer.state.fn = TrainerFn.FITTING diff --git a/tests/tests_pytorch/deprecated_api/test_remove_1-10.py b/tests/tests_pytorch/deprecated_api/test_remove_1-10.py index 6a0a458c6c041..186e526313bba 100644 --- a/tests/tests_pytorch/deprecated_api/test_remove_1-10.py +++ b/tests/tests_pytorch/deprecated_api/test_remove_1-10.py @@ -11,11 +11,55 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+"""Test deprecated functionality which will be removed in v1.10.0.""" import pytest from pytorch_lightning import Trainer +from pytorch_lightning.demos.boring_classes import BoringModel +from pytorch_lightning.overrides import LightningDistributedModule, LightningParallelModule +from pytorch_lightning.overrides.base import unwrap_lightning_module +from pytorch_lightning.overrides.fairscale import LightningShardedDataParallel, unwrap_lightning_module_sharded +from pytorch_lightning.strategies.bagua import LightningBaguaModule +from pytorch_lightning.strategies.deepspeed import LightningDeepSpeedModule +from pytorch_lightning.strategies.ipu import LightningIPUModule +from tests_pytorch.helpers.runif import RunIf +from tests_pytorch.helpers.utils import no_warning_call def test_deprecated_amp_level(): with pytest.deprecated_call(match="Setting `amp_level` inside the `Trainer` is deprecated in v1.8.0"): Trainer(amp_level="O3", amp_backend="apex") + + +@pytest.mark.parametrize( + "wrapper_class", + [ + LightningParallelModule, + LightningDistributedModule, + LightningBaguaModule, + LightningDeepSpeedModule, + pytest.param(LightningShardedDataParallel, marks=RunIf(fairscale=True)), + LightningIPUModule, + ], +) +def test_v1_10_deprecated_pl_module_init_parameter(wrapper_class): + with no_warning_call( + DeprecationWarning, match=rf"The argument `pl_module` in `{wrapper_class.__name__}` is deprecated in v1.8.0" + ): + wrapper_class(BoringModel()) + + with pytest.deprecated_call( + match=rf"The argument `pl_module` in `{wrapper_class.__name__}` is deprecated in v1.8.0" + ): + wrapper_class(pl_module=BoringModel()) + + +def test_v1_10_deprecated_unwrap_lightning_module(): + with pytest.deprecated_call(match=r"The function `unwrap_lightning_module` is deprecated in v1.8.0"): + unwrap_lightning_module(BoringModel()) + + +@RunIf(fairscale=True) +def test_v1_10_deprecated_unwrap_lightning_module_sharded(): + with pytest.deprecated_call(match=r"The function `unwrap_lightning_module_sharded` is deprecated in v1.8.0"): + unwrap_lightning_module_sharded(BoringModel()) diff --git a/tests/tests_pytorch/helpers/runif.py b/tests/tests_pytorch/helpers/runif.py index abbca75f626ad..4074eaf725e1f 100644 --- a/tests/tests_pytorch/helpers/runif.py +++ b/tests/tests_pytorch/helpers/runif.py @@ -22,11 +22,11 @@ from pytorch_lightning.accelerators.mps import _MPS_AVAILABLE from pytorch_lightning.callbacks.progress.rich_progress import _RICH_AVAILABLE -from pytorch_lightning.overrides.fairscale import _FAIRSCALE_AVAILABLE from pytorch_lightning.strategies.bagua import _BAGUA_AVAILABLE from pytorch_lightning.strategies.deepspeed import _DEEPSPEED_AVAILABLE from pytorch_lightning.utilities.imports import ( _APEX_AVAILABLE, + _FAIRSCALE_AVAILABLE, _FAIRSCALE_FULLY_SHARDED_AVAILABLE, _HIVEMIND_AVAILABLE, _HOROVOD_AVAILABLE, diff --git a/tests/tests_pytorch/models/test_amp.py b/tests/tests_pytorch/models/test_amp.py index 159a3767c1df2..786de99f59714 100644 --- a/tests/tests_pytorch/models/test_amp.py +++ b/tests/tests_pytorch/models/test_amp.py @@ -96,8 +96,6 @@ def test_amp_cpus(tmpdir, strategy, precision, devices): trainer.test(model) trainer.predict(model, DataLoader(RandomDataset(32, 64))) - assert trainer.state.finished, f"Training failed with {trainer.state}" - @RunIf(min_cuda_gpus=2, min_torch="1.10") @pytest.mark.parametrize("strategy", [None, "dp", "ddp_spawn"]) @@ -121,8 +119,6 @@ def test_amp_gpus(tmpdir, strategy, precision, devices): trainer.test(model) trainer.predict(model, DataLoader(RandomDataset(32, 
64))) - assert trainer.state.finished, f"Training failed with {trainer.state}" - @RunIf(min_cuda_gpus=2) @mock.patch.dict( @@ -162,9 +158,6 @@ def test_amp_gpu_ddp_slurm_managed(tmpdir): ) trainer.fit(model) - # correct result and ok accuracy - assert trainer.state.finished, "amp + ddp model failed to complete" - # test root model address assert isinstance(trainer.strategy.cluster_environment, SLURMEnvironment) assert trainer.strategy.cluster_environment.resolve_root_node_address("abc") == "abc" @@ -185,7 +178,6 @@ def test_amp_without_apex(bwd_mock, tmpdir): trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, amp_backend="apex") assert trainer.amp_backend is None trainer.fit(model) - assert trainer.state.finished, f"Training failed with {trainer.state}" assert not bwd_mock.called @@ -213,7 +205,6 @@ def configure_optimizers(self): ) assert str(trainer.amp_backend) == "AMPType.APEX" trainer.fit(model) - assert trainer.state.finished, f"Training failed with {trainer.state}" # `max_steps` is fulfilled in the third batch first optimizer, but we don't check the loop # `done` condition until all optimizers have run, so the number of backwards is higher than `max_steps` assert bwd_mock.call_count == 6 diff --git a/tests/tests_pytorch/overrides/test_base.py b/tests/tests_pytorch/overrides/test_base.py index fa07912d0d44e..27d2db688d7ae 100644 --- a/tests/tests_pytorch/overrides/test_base.py +++ b/tests/tests_pytorch/overrides/test_base.py @@ -38,4 +38,5 @@ def test_unwrap_lightning_module(): wrapped_model = _LightningModuleWrapperBase(wrapped_model) wrapped_model = DataParallel(wrapped_model) - assert unwrap_lightning_module(wrapped_model) == model + with pytest.deprecated_call(match="The function `unwrap_lightning_module` is deprecated in v1.8.0"): + assert unwrap_lightning_module(wrapped_model) == model diff --git a/tests/tests_pytorch/plugins/precision/test_sharded_precision.py b/tests/tests_pytorch/plugins/precision/test_sharded_precision.py index 0c08c8e9540eb..ab7a4a432a2c6 100644 --- a/tests/tests_pytorch/plugins/precision/test_sharded_precision.py +++ b/tests/tests_pytorch/plugins/precision/test_sharded_precision.py @@ -15,8 +15,8 @@ import pytest import torch -from pytorch_lightning.overrides.fairscale import _FAIRSCALE_AVAILABLE from pytorch_lightning.plugins import ShardedNativeMixedPrecisionPlugin +from pytorch_lightning.utilities.imports import _FAIRSCALE_AVAILABLE from tests_pytorch.helpers.runif import RunIf ShardedGradScaler = None diff --git a/tests/tests_pytorch/strategies/test_sharded_strategy.py b/tests/tests_pytorch/strategies/test_sharded_strategy.py index ad0673ed1a5fa..a0abfb3f73ec0 100644 --- a/tests/tests_pytorch/strategies/test_sharded_strategy.py +++ b/tests/tests_pytorch/strategies/test_sharded_strategy.py @@ -7,9 +7,9 @@ from pytorch_lightning import LightningModule, Trainer from pytorch_lightning.demos.boring_classes import BoringModel -from pytorch_lightning.overrides.fairscale import _FAIRSCALE_AVAILABLE from pytorch_lightning.strategies import DDPShardedStrategy, DDPSpawnShardedStrategy from pytorch_lightning.trainer.states import TrainerFn +from pytorch_lightning.utilities.imports import _FAIRSCALE_AVAILABLE from tests_pytorch.helpers.runif import RunIf if _FAIRSCALE_AVAILABLE: @@ -256,8 +256,8 @@ def test_configure_ddp(tmpdir): def test_custom_kwargs_sharded(_, cls): """Tests to ensure that if custom kwargs are passed, they are set correctly.""" strategy = cls(reduce_fp16=True) - strategy.model = Mock(spec=LightningModule) - strategy.model.trainer = Mock() 
+ strategy._lightning_module = Mock(spec=LightningModule) + strategy._lightning_module.trainer = Mock() strategy.parallel_devices = [Mock()] class_name = "sharded" if isinstance(strategy, DDPShardedStrategy) else "sharded_spawn" @@ -276,8 +276,8 @@ def test_custom_kwargs_sharded_reduce_buffer_size(_, params, expected_buffer_siz """Tests to ensure that ``reduce_buffer_size`` is correctly set based on user kwargs.""" strategy = DDPShardedStrategy(**params) strategy.num_nodes = num_nodes - strategy.model = Mock(spec=LightningModule) - strategy.model.trainer = Mock() + strategy._lightning_module = Mock(spec=LightningModule) + strategy._lightning_module.trainer = Mock() strategy.parallel_devices = [Mock()] with mock.patch("pytorch_lightning.strategies.sharded.ShardedDataParallel", autospec=True) as mock_sharded: diff --git a/tests/tests_pytorch/trainer/connectors/test_callback_connector.py b/tests/tests_pytorch/trainer/connectors/test_callback_connector.py index 02e846425a2a0..c56f3fb4d988d 100644 --- a/tests/tests_pytorch/trainer/connectors/test_callback_connector.py +++ b/tests/tests_pytorch/trainer/connectors/test_callback_connector.py @@ -56,7 +56,7 @@ def test_checkpoint_callbacks_are_last(tmpdir): # no model callbacks model = LightningModule() model.configure_callbacks = lambda: [] - trainer.model = model + trainer.strategy._lightning_module = model cb_connector = CallbackConnector(trainer) cb_connector._attach_model_callbacks() assert trainer.callbacks == [ @@ -72,7 +72,7 @@ def test_checkpoint_callbacks_are_last(tmpdir): model = LightningModule() model.configure_callbacks = lambda: [checkpoint1, early_stopping, model_summary, checkpoint2] trainer = Trainer(callbacks=[progress_bar, lr_monitor, ModelCheckpoint(tmpdir)]) - trainer.model = model + trainer.strategy._lightning_module = model cb_connector = CallbackConnector(trainer) cb_connector._attach_model_callbacks() assert trainer.callbacks == [ @@ -154,7 +154,7 @@ def _attach_callbacks(trainer_callbacks, model_callbacks): enable_model_summary=False, callbacks=trainer_callbacks, ) - trainer.model = model + trainer.strategy._lightning_module = model cb_connector = CallbackConnector(trainer) cb_connector._attach_model_callbacks() return trainer @@ -212,7 +212,7 @@ def test_attach_model_callbacks_override_info(caplog): trainer = Trainer( enable_checkpointing=False, callbacks=[EarlyStopping(monitor="foo"), LearningRateMonitor(), TQDMProgressBar()] ) - trainer.model = model + trainer.strategy._lightning_module = model cb_connector = CallbackConnector(trainer) with caplog.at_level(logging.INFO): cb_connector._attach_model_callbacks() diff --git a/tests/tests_pytorch/trainer/flags/test_overfit_batches.py b/tests/tests_pytorch/trainer/flags/test_overfit_batches.py index 32f0b8938caf6..da3e154349e1b 100644 --- a/tests/tests_pytorch/trainer/flags/test_overfit_batches.py +++ b/tests/tests_pytorch/trainer/flags/test_overfit_batches.py @@ -142,7 +142,7 @@ def test_distributed_sampler_with_overfit_batches(): strategy="ddp_spawn", ) model.trainer = trainer - trainer.model = model + trainer.strategy._lightning_module = model trainer._data_connector.attach_dataloaders(model) trainer.reset_train_dataloader() train_sampler = trainer.train_dataloader.loaders.sampler diff --git a/tests/tests_pytorch/utilities/test_imports.py b/tests/tests_pytorch/utilities/test_imports.py index 25995bb029f3a..c673716c457f2 100644 --- a/tests/tests_pytorch/utilities/test_imports.py +++ b/tests/tests_pytorch/utilities/test_imports.py @@ -13,7 +13,6 @@ # limitations under the 
License. import operator -from pytorch_lightning.overrides.fairscale import _FAIRSCALE_AVAILABLE from pytorch_lightning.strategies.bagua import _BAGUA_AVAILABLE from pytorch_lightning.strategies.deepspeed import _DEEPSPEED_AVAILABLE from pytorch_lightning.utilities import ( @@ -23,7 +22,7 @@ _OMEGACONF_AVAILABLE, _POPTORCH_AVAILABLE, ) -from pytorch_lightning.utilities.imports import _compare_version, _RequirementAvailable, torch +from pytorch_lightning.utilities.imports import _compare_version, _FAIRSCALE_AVAILABLE, _RequirementAvailable, torch def test_module_exists():
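For downstream code that still calls the deprecated helper, a minimal migration sketch (assuming a build that already contains this patch; `BoringModel` is the demo module the tests above import):

    import pytorch_lightning as pl
    from pytorch_lightning.demos.boring_classes import BoringModel

    model = BoringModel()
    trainer = pl.Trainer(fast_dev_run=True)
    trainer.fit(model)

    # Before (deprecated in v1.8.0, removal planned for v1.10.0):
    #   from pytorch_lightning.overrides.base import unwrap_lightning_module
    #   unwrapped = unwrap_lightning_module(trainer.strategy.model)

    # After: the strategy keeps a direct reference to the pure LightningModule.
    unwrapped = trainer.strategy.lightning_module
    assert unwrapped is model

`Trainer.lightning_module` resolves to the same property, so most user-facing code keeps working unchanged.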
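For code that instantiates the wrapper classes directly, `pl_module` is now spelled `forward_module`; a short sketch with `LightningDistributedModule` (the other wrappers named in the changelog entry accept the same arguments):

    from pytorch_lightning.demos.boring_classes import BoringModel
    from pytorch_lightning.overrides import LightningDistributedModule

    model = BoringModel()

    # Before (emits a DeprecationWarning as of v1.8.0):
    #   wrapped = LightningDistributedModule(pl_module=model)

    # After: pass the module positionally or as `forward_module`.
    wrapped = LightningDistributedModule(forward_module=model)

    # The unwrapped module stays reachable through the new property.
    assert wrapped.lightning_module is model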
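For strategy authors, the reworked bookkeeping in `strategy.py` means `connect()` records the pure `LightningModule` separately from the (possibly wrapped) `model` attribute; a sketch using `SingleDeviceStrategy` purely as a concrete stand-in, not something this patch touches:

    import torch
    from pytorch_lightning.demos.boring_classes import BoringModel
    from pytorch_lightning.strategies import SingleDeviceStrategy

    model = BoringModel()
    strategy = SingleDeviceStrategy(device=torch.device("cpu"))

    strategy.connect(model)
    assert strategy.lightning_module is model  # direct reference, no recursive unwrapping
    assert strategy.model is model             # nothing has wrapped the module yet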