diff --git a/src/lightning/fabric/CHANGELOG.md b/src/lightning/fabric/CHANGELOG.md index f22873a5cf58e..1da509630bc22 100644 --- a/src/lightning/fabric/CHANGELOG.md +++ b/src/lightning/fabric/CHANGELOG.md @@ -194,6 +194,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed model parameters getting shared between processes when running with `strategy="ddp_spawn"` and `accelerator="cpu"`; this has a necessary memory impact, as parameters are replicated for each process now ([#18238](https://github.com/Lightning-AI/lightning/pull/18238)) +- Fixed an issue causing the `_FabricOptimizer.state` to remain outdated after loading with `load_state_dict` ([#18488](https://github.com/Lightning-AI/lightning/pull/18488)) + + ## [2.0.7] - 2023-08-14 ### Changed diff --git a/src/lightning/fabric/wrappers.py b/src/lightning/fabric/wrappers.py index bea5b6ea3f8d3..77f679e588208 100644 --- a/src/lightning/fabric/wrappers.py +++ b/src/lightning/fabric/wrappers.py @@ -50,13 +50,11 @@ def __init__(self, optimizer: Optimizer, strategy: Strategy, callbacks: Optional """ - # `__del__` is skipped in case the optimizer has implemented custom destructor logic which we would - # not want to call on destruction of the `_FabricOptimizer - self.__dict__ = {k: v for k, v in optimizer.__dict__.items() if k not in ("state_dict", "step", "__del__")} self.__class__ = type("Fabric" + optimizer.__class__.__name__, (self.__class__, optimizer.__class__), {}) self._optimizer = optimizer self._strategy = strategy self._callbacks = callbacks or [] + self._refresh() @property def optimizer(self) -> Optimizer: @@ -65,6 +63,12 @@ def optimizer(self) -> Optimizer: def state_dict(self) -> Dict[str, Tensor]: return self._strategy.get_optimizer_state(self.optimizer) + def load_state_dict(self, state_dict: Dict[str, Tensor]) -> None: + self.optimizer.load_state_dict(state_dict) + # `Optimizer.load_state_dict` modifies `optimizer.__dict__`, so we need to update the `__dict__` on 
+ # this wrapper + self._refresh() + def step(self, closure: Optional[Callable] = None) -> Any: kwargs = {"closure": closure} if closure is not None else {} if hasattr(self._strategy, "model") and isinstance(self._strategy.model, Optimizable): @@ -82,6 +86,21 @@ def step(self, closure: Optional[Callable] = None) -> Any: hook(strategy=self._strategy, optimizer=optimizer) return output + def _refresh(self) -> None: + """Refreshes the ``__dict__`` so that it matches the internal states in the wrapped optimizer. + + This is only needed to present the user with an updated view in case they inspect the state of this wrapper. + """ + # `__del__` is skipped in case the optimizer has implemented custom destructor logic which we would + # not want to call on destruction of the `_FabricOptimizer` + self.__dict__.update( + { + k: v + for k, v in self.optimizer.__dict__.items() + if k not in ("load_state_dict", "state_dict", "step", "__del__") + } + ) + + class _FabricModule(_DeviceDtypeModuleMixin): def __init__( diff --git a/tests/tests_fabric/test_wrappers.py b/tests/tests_fabric/test_wrappers.py index 1b0f7333db177..b9c0f89bb97be 100644 --- a/tests/tests_fabric/test_wrappers.py +++ b/tests/tests_fabric/test_wrappers.py @@ -362,16 +362,35 @@ def test_fabric_optimizer_wraps(): def test_fabric_optimizer_state_dict(): """Test that the FabricOptimizer calls into the strategy to collect the state.""" - optimizer = Mock() + optimizer = Mock(spec=torch.optim.Adam) strategy = Mock() fabric_optimizer = _FabricOptimizer(optimizer=optimizer, strategy=strategy) fabric_optimizer.state_dict() strategy.get_optimizer_state.assert_called_with(optimizer) +def test_fabric_optimizer_load_state_dict(): + """Test that the FabricOptimizer can load the state dict on the wrapped optimizer and update its + internal `__dict__`.""" + model = torch.nn.Linear(1, 1) + optimizer = torch.optim.Adam(model.parameters()) + assert not optimizer.state # a fresh optimizer has no state + 
model(torch.rand(1)).backward() + optimizer.step() + assert optimizer.state + state_dict = optimizer.state_dict() + + optimizer = torch.optim.Adam(model.parameters()) # fresh optimizer + fabric_optimizer = _FabricOptimizer(optimizer=optimizer, strategy=Mock()) + assert not fabric_optimizer.state # a fresh optimizer has no state + fabric_optimizer.load_state_dict(state_dict) + assert fabric_optimizer.state + assert fabric_optimizer.optimizer.state_dict()["state"] == state_dict["state"] + + def test_fabric_optimizer_steps(): """Test that the FabricOptimizer forwards the step() and zero_grad() calls to the wrapped optimizer.""" - optimizer = Mock() + optimizer = Mock(spec=torch.optim.Adam) strategy = Mock(spec=["optimizer_step"]) strategy.optimizer_step.return_value = 123 fabric_optimizer = _FabricOptimizer(optimizer=optimizer, strategy=strategy)