Update deepspeed requirement support window (#16813)
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Co-authored-by: Jirka <jirka.borovec@seznam.cz>
3 people authored Apr 25, 2023
1 parent 843a167 commit b792c90
Showing 6 changed files with 7 additions and 33 deletions.
2 changes: 1 addition & 1 deletion requirements/fabric/strategies.txt
@@ -1,3 +1,3 @@
# NOTE: the upper bound for the package version is only set for CI stability, and it is dropped while installing this package
# in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment
-deepspeed >=0.6.3, !=0.7.0, <=0.8.0; platform_system != "Windows"
+deepspeed >=0.8.2, <=0.9.1; platform_system != "Windows"
2 changes: 1 addition & 1 deletion requirements/pytorch/strategies.txt
@@ -1,3 +1,3 @@
# NOTE: the upper bound for the package version is only set for CI stability, and it is dropped while installing this package
# in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment
-deepspeed >=0.6.3, !=0.7.0, <0.8.0; platform_system != "Windows" # TODO: Include 0.8.x after https://github.com/microsoft/DeepSpeed/commit/b587c7e85470329ac25df7c7c2521ff9b2833db7 gets released
+deepspeed >=0.8.2, <=0.9.1; platform_system != "Windows"
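
Both requirements files move the support window from the old `>=0.6.3, !=0.7.0` range capped at `0.8.0` to `>=0.8.2, <=0.9.1`, keeping the `platform_system != "Windows"` environment marker that skips the dependency on Windows. As a side note (not part of the commit), the new window is an ordinary PEP 440 specifier and can be evaluated with the `packaging` library:

```python
# Illustrative check of the new version window; not code from this commit.
from packaging.specifiers import SpecifierSet
from packaging.version import Version

window = SpecifierSet(">=0.8.2,<=0.9.1")  # the updated deepspeed constraint

for candidate in ("0.8.0", "0.8.2", "0.9.1", "0.9.2"):
    # SpecifierSet.contains() applies PEP 440 comparison rules.
    print(candidate, window.contains(Version(candidate)))
# 0.8.0 and 0.9.2 fall outside the window; 0.8.2 and 0.9.1 are accepted.
```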
13 changes: 1 addition & 12 deletions src/lightning/fabric/strategies/deepspeed.py
@@ -32,22 +32,11 @@
from lightning.fabric.strategies.ddp import DDPStrategy
from lightning.fabric.strategies.strategy import _Sharded
from lightning.fabric.utilities.distributed import log
-from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0
from lightning.fabric.utilities.rank_zero import rank_zero_info, rank_zero_warn
from lightning.fabric.utilities.seed import reset_seed
from lightning.fabric.utilities.types import _PATH

-_DEEPSPEED_AVAILABLE = (
-    # DeepSpeed fails under 0.8.2 with torch 2.0: https://github.com/microsoft/DeepSpeed/pull/2863
-    RequirementCache("deepspeed>=0.8.2")
-    or (
-        not _TORCH_GREATER_EQUAL_2_0
-        and RequirementCache("deepspeed")
-        # check packaging because of https://github.com/microsoft/DeepSpeed/pull/2771
-        # remove the packaging check when min version is >=0.8.1
-        and RequirementCache("packaging>=20.0")
-    )
-)
+_DEEPSPEED_AVAILABLE = RequirementCache("deepspeed")
if TYPE_CHECKING and _DEEPSPEED_AVAILABLE:
import deepspeed

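
With the minimum now pinned at 0.8.2 in the requirements files, the availability check in `deepspeed.py` no longer needs the torch-2.0 and `packaging` special cases and collapses to a plain `RequirementCache("deepspeed")`. The sketch below approximates what such a requirement check does using only the standard library and `packaging`; it is an illustration, not the actual `RequirementCache` implementation.

```python
# Rough approximation of a RequirementCache-style check (illustration only).
from importlib.metadata import PackageNotFoundError, version

from packaging.requirements import Requirement


def requirement_available(requirement: str) -> bool:
    """Return True if the distribution is installed and satisfies the given specifier."""
    req = Requirement(requirement)
    try:
        installed = version(req.name)
    except PackageNotFoundError:
        return False
    return installed in req.specifier


print(requirement_available("deepspeed"))         # bare check, as used after this commit
print(requirement_available("deepspeed>=0.8.2"))  # pinned check, as used before this commit
```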
15 changes: 1 addition & 14 deletions src/lightning/fabric/wrappers.py
@@ -49,9 +49,7 @@ def __init__(self, optimizer: Optimizer, strategy: Strategy) -> None:
"""
# `__del__` is skipped in case the optimizer has implemented custom destructor logic which we would
# not want to call on destruction of the `_FabricOptimizer
-self.__dict__ = {
-    k: v for k, v in optimizer.__dict__.items() if k not in ("state_dict", "step", "zero_grad", "__del__")
-}
+self.__dict__ = {k: v for k, v in optimizer.__dict__.items() if k not in ("state_dict", "step", "__del__")}
self.__class__ = type("Fabric" + optimizer.__class__.__name__, (self.__class__, optimizer.__class__), {})
self._optimizer = optimizer
self._strategy = strategy
@@ -75,10 +73,6 @@ def step(self, closure: Optional[Callable] = None) -> Any:
**kwargs,
)

-def zero_grad(self, **kwargs: Any) -> None:
-    kwargs = _process_optimizer_zero_grad_kwargs(self.optimizer, kwargs)
-    self.optimizer.zero_grad(**kwargs)


class _FabricModule(_DeviceDtypeModuleMixin):
def __init__(
@@ -220,13 +214,6 @@ def __iter__(self) -> Union[Iterator[Any], Generator[Any, None, None]]:
yield move_data_to_device(item, self._device)


-def _process_optimizer_zero_grad_kwargs(optimizer: Optimizer, kwargs: Dict[str, Any]) -> Dict[str, Any]:
-    if "set_to_none" in kwargs and "set_grads_to_None" in inspect.signature(optimizer.zero_grad).parameters:
-        # Some optimizers out there, for example DeepSpeedZeroOptimizer, use a different name than PyTorch
-        kwargs["set_grads_to_None"] = kwargs.pop("set_to_none")
-    return kwargs


def _unwrap_objects(collection: Any) -> Any:
def _unwrap(
obj: Union[_FabricModule, _FabricOptimizer, _FabricDataLoader]
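
In `wrappers.py`, the explicit `zero_grad` override and the `_process_optimizer_zero_grad_kwargs` shim (which translated PyTorch's `set_to_none` into the `set_grads_to_None` spelling used by older DeepSpeed ZeRO optimizers) are dropped, and `zero_grad` is no longer excluded when the optimizer's `__dict__` is copied. Calls to `zero_grad` on the wrapper therefore resolve directly to the underlying optimizer through the dynamic subclass that `_FabricOptimizer` builds. A generic sketch of that proxy pattern, with hypothetical `Proxy`/`FakeOptimizer` names that are not Lightning code:

```python
# Generic dynamic-proxy sketch (hypothetical names, not Lightning code).
from typing import Any, Tuple


class Proxy:
    def __init__(self, wrapped: Any, skip: Tuple[str, ...] = ("step", "state_dict", "__del__")) -> None:
        # Copy the wrapped object's instance attributes, except names the proxy overrides itself.
        self.__dict__ = {k: v for k, v in wrapped.__dict__.items() if k not in skip}
        # Re-type the proxy on the fly so isinstance(proxy, type(wrapped)) still holds
        # and any method not overridden here resolves to the wrapped class.
        self.__class__ = type("Proxy" + wrapped.__class__.__name__, (self.__class__, wrapped.__class__), {})
        self._wrapped = wrapped

    def step(self) -> Any:
        # The proxy's own override; everything else, e.g. zero_grad, falls through.
        return self._wrapped.step()


class FakeOptimizer:
    def __init__(self) -> None:
        self.lr = 0.1

    def step(self) -> str:
        return "stepped"

    def zero_grad(self) -> str:
        return "grads cleared"


proxy = Proxy(FakeOptimizer())
assert isinstance(proxy, FakeOptimizer)
print(proxy.step(), proxy.zero_grad(), proxy.lr)  # forwarded step, inherited zero_grad, copied attribute
```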
4 changes: 0 additions & 4 deletions tests/tests_fabric/test_wrappers.py
@@ -358,10 +358,6 @@ def zero_grad(self, set_grads_to_None=False):
fabric_optimizer = _FabricOptimizer(optimizer=optimizer, strategy=Mock())
fabric_optimizer.zero_grad()
custom_zero_grad.assert_called_with(set_grads_to_None=False)
-fabric_optimizer.zero_grad(set_to_none=False)
-custom_zero_grad.assert_called_with(set_grads_to_None=False)
-fabric_optimizer.zero_grad(set_to_none=True)
-custom_zero_grad.assert_called_with(set_grads_to_None=True)


def test_is_wrapped():
4 changes: 3 additions & 1 deletion tests/tests_pytorch/strategies/test_deepspeed_strategy.py
@@ -841,13 +841,15 @@ def on_train_batch_start(self, trainer, pl_module: LightningModule, batch: Any,
model = ModelParallelClassificationModel()
dm = ClassifDataModule()
verification_callback = VerificationCallback()
+strategy = DeepSpeedStrategy(stage=2, offload_optimizer=offload_optimizer)
+strategy.config["zero_force_ds_cpu_optimizer"] = False
trainer = Trainer(
default_root_dir=tmpdir,
# TODO: this test fails with max_epochs >1 as there are leftover batches per epoch.
# there's divergence in how Lightning handles the last batch of the epoch with how DeepSpeed does it.
# we step the optimizers on the last batch but DeepSpeed keeps the accumulation for the next epoch
max_epochs=1,
-strategy=DeepSpeedStrategy(stage=2, offload_optimizer=offload_optimizer),
+strategy=strategy,
accelerator="gpu",
devices=2,
limit_train_batches=5,
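
The test now builds the `DeepSpeedStrategy` up front so it can relax the generated DeepSpeed config before handing it to the `Trainer`: `zero_force_ds_cpu_optimizer` is set to `False`, presumably because newer DeepSpeed releases otherwise refuse to combine ZeRO optimizer offloading with a client (non-DeepSpeed) optimizer. A condensed sketch of that pattern, with the trainer arguments trimmed relative to the test:

```python
# Condensed from the updated test; arguments are trimmed for illustration.
from lightning.pytorch import Trainer
from lightning.pytorch.strategies import DeepSpeedStrategy

strategy = DeepSpeedStrategy(stage=2, offload_optimizer=True)
# Relax DeepSpeed's check that CPU offload must use a DeepSpeed-provided optimizer.
strategy.config["zero_force_ds_cpu_optimizer"] = False

trainer = Trainer(
    strategy=strategy,
    accelerator="gpu",
    devices=2,
    max_epochs=1,
    limit_train_batches=5,
)
```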
