From 94cf0218d34b9d3a7ec599358e98d5f1efa2ec5a Mon Sep 17 00:00:00 2001
From: awaelchli
Date: Tue, 5 Sep 2023 23:09:47 +0200
Subject: [PATCH 1/5] refresh dict

---
 src/lightning/fabric/wrappers.py | 21 ++++++++++++++++++---
 1 file changed, 18 insertions(+), 3 deletions(-)

diff --git a/src/lightning/fabric/wrappers.py b/src/lightning/fabric/wrappers.py
index bea5b6ea3f8d3..b54b5f2bb2125 100644
--- a/src/lightning/fabric/wrappers.py
+++ b/src/lightning/fabric/wrappers.py
@@ -50,13 +50,11 @@ def __init__(self, optimizer: Optimizer, strategy: Strategy, callbacks: Optional
         """
-        # `__del__` is skipped in case the optimizer has implemented custom destructor logic which we would
-        # not want to call on destruction of the `_FabricOptimizer
-        self.__dict__ = {k: v for k, v in optimizer.__dict__.items() if k not in ("state_dict", "step", "__del__")}
         self.__class__ = type("Fabric" + optimizer.__class__.__name__, (self.__class__, optimizer.__class__), {})
         self._optimizer = optimizer
         self._strategy = strategy
         self._callbacks = callbacks or []
+        self._refresh()

     @property
     def optimizer(self) -> Optimizer:
@@ -65,6 +63,12 @@ def optimizer(self) -> Optimizer:
     def state_dict(self) -> Dict[str, Tensor]:
         return self._strategy.get_optimizer_state(self.optimizer)

+    def load_state_dict(self, state_dict: Dict[str, Tensor]) -> None:
+        self.optimizer.load_state_dict(state_dict)
+        # `Optimizer.load_state_dict` modifies `optimizer.__dict__`, so we need to update the `__dict__` on
+        # this wrapper
+        self._refresh()
+
     def step(self, closure: Optional[Callable] = None) -> Any:
         kwargs = {"closure": closure} if closure is not None else {}
         if hasattr(self._strategy, "model") and isinstance(self._strategy.model, Optimizable):
@@ -82,6 +86,17 @@ def step(self, closure: Optional[Callable] = None) -> Any:
             hook(strategy=self._strategy, optimizer=optimizer)
         return output

+    def _refresh(self) -> None:
+        """Refreshes the ``__dict__`` so that it matches the internal state of the wrapped optimizer.
+
+        This is only needed to present the user with an updated view in case they inspect the state of this wrapper.
+        """
+        # `__del__` is skipped in case the optimizer has implemented custom destructor logic which we would
+        # not want to call on destruction of the `_FabricOptimizer`
+        self.__dict__.update({
+            k: v for k, v in self.optimizer.__dict__.items() if k not in ("load_state_dict", "state_dict", "step", "__del__")
+        })
+

 class _FabricModule(_DeviceDtypeModuleMixin):
     def __init__(
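Note on the mechanics behind this first patch: `_FabricOptimizer` copies the wrapped optimizer's `__dict__` into itself, so without a refresh the copy goes stale as soon as `Optimizer.load_state_dict` swaps out entries like `state`. A minimal standalone sketch of that staleness (the `SnapshotWrapper` class and variable names below are illustrative, not Fabric code):

import torch

class SnapshotWrapper:
    def __init__(self, optimizer):
        # One-time copy of the optimizer's attributes, roughly what the old
        # `_FabricOptimizer.__init__` did before this patch
        self.__dict__ = {k: v for k, v in optimizer.__dict__.items()}
        self._optimizer = optimizer

model = torch.nn.Linear(1, 1)

# Build a populated state dict to load later
source = torch.optim.Adam(model.parameters())
model(torch.rand(1)).backward()
source.step()
state_dict = source.state_dict()

# Wrap a fresh optimizer (empty state), then load the populated state into it
fresh = torch.optim.Adam(model.parameters())
wrapper = SnapshotWrapper(fresh)
fresh.load_state_dict(state_dict)

print(bool(fresh.state))    # True: `load_state_dict` rebuilt the optimizer's state
print(bool(wrapper.state))  # False: the snapshot still holds the old empty dict

Re-running the `__dict__` copy after every load, which is what the new `_refresh()` does both in `__init__` and in `load_state_dict`, keeps the wrapper's view in sync.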
+ """ + # `__del__` is skipped in case the optimizer has implemented custom destructor logic which we would + # not want to call on destruction of the `_FabricOptimizer + self.__dict__.update({ + k: v for k, v in self.optimizer.__dict__.items() if k not in ("load_state_dict", "state_dict", "step", "__del__") + }) + class _FabricModule(_DeviceDtypeModuleMixin): def __init__( From ff71e83d4da9db3c947ea38e9c736ac585426117 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Tue, 5 Sep 2023 23:18:34 +0200 Subject: [PATCH 2/5] add test --- tests/tests_fabric/test_wrappers.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/tests/tests_fabric/test_wrappers.py b/tests/tests_fabric/test_wrappers.py index 1b0f7333db177..f7c7da5abac59 100644 --- a/tests/tests_fabric/test_wrappers.py +++ b/tests/tests_fabric/test_wrappers.py @@ -369,6 +369,25 @@ def test_fabric_optimizer_state_dict(): strategy.get_optimizer_state.assert_called_with(optimizer) +def test_fabric_optimizer_load_state_dict(): + """Test that the FabricOptimizer can load the state dict on the wrapped optimizer and update its + internal `__dict__`.""" + model = torch.nn.Linear(1, 1) + optimizer = torch.optim.Adam(model.parameters()) + assert not optimizer.state # a fresh optimizer has no state + model(torch.rand(1)).backward() + optimizer.step() + assert optimizer.state + state_dict = optimizer.state_dict() + + optimizer = torch.optim.Adam(model.parameters()) # fresh optimizer + fabric_optimizer = _FabricOptimizer(optimizer=optimizer, strategy=Mock()) + assert not fabric_optimizer.state # a fresh optimizer has no state + fabric_optimizer.load_state_dict(state_dict) + assert fabric_optimizer.state + assert fabric_optimizer.optimizer.state_dict()["state"] == state_dict["state"] + + def test_fabric_optimizer_steps(): """Test that the FabricOptimizer forwards the step() and zero_grad() calls to the wrapped optimizer.""" optimizer = Mock() From 951114f391c3c29ab4663c9be815aaf81b04f4e8 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 5 Sep 2023 21:23:13 +0000 Subject: [PATCH 3/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/lightning/fabric/wrappers.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/lightning/fabric/wrappers.py b/src/lightning/fabric/wrappers.py index b54b5f2bb2125..77f679e588208 100644 --- a/src/lightning/fabric/wrappers.py +++ b/src/lightning/fabric/wrappers.py @@ -93,9 +93,13 @@ def _refresh(self) -> None: """ # `__del__` is skipped in case the optimizer has implemented custom destructor logic which we would # not want to call on destruction of the `_FabricOptimizer - self.__dict__.update({ - k: v for k, v in self.optimizer.__dict__.items() if k not in ("load_state_dict", "state_dict", "step", "__del__") - }) + self.__dict__.update( + { + k: v + for k, v in self.optimizer.__dict__.items() + if k not in ("load_state_dict", "state_dict", "step", "__del__") + } + ) class _FabricModule(_DeviceDtypeModuleMixin): From 914fc320073bfb1251d452ade55697816173e890 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Tue, 5 Sep 2023 23:27:46 +0200 Subject: [PATCH 4/5] changelog --- src/lightning/fabric/CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/lightning/fabric/CHANGELOG.md b/src/lightning/fabric/CHANGELOG.md index f22873a5cf58e..1da509630bc22 100644 --- a/src/lightning/fabric/CHANGELOG.md +++ b/src/lightning/fabric/CHANGELOG.md @@ -194,6 
From 914fc320073bfb1251d452ade55697816173e890 Mon Sep 17 00:00:00 2001
From: awaelchli
Date: Tue, 5 Sep 2023 23:27:46 +0200
Subject: [PATCH 4/5] changelog

---
 src/lightning/fabric/CHANGELOG.md | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/src/lightning/fabric/CHANGELOG.md b/src/lightning/fabric/CHANGELOG.md
index f22873a5cf58e..1da509630bc22 100644
--- a/src/lightning/fabric/CHANGELOG.md
+++ b/src/lightning/fabric/CHANGELOG.md
@@ -194,6 +194,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed model parameters getting shared between processes when running with `strategy="ddp_spawn"` and `accelerator="cpu"`; this has a necessary memory impact, as parameters are replicated for each process now ([#18238](https://github.com/Lightning-AI/lightning/pull/18238))

+- Fixed an issue causing `_FabricOptimizer.state` to remain outdated after calling `load_state_dict` ([#18488](https://github.com/Lightning-AI/lightning/pull/18488))
+
+
 ## [2.0.7] - 2023-08-14

 ### Changed

From 9731a4b43368bd310db4aaf201f52b0dc4f6f1f7 Mon Sep 17 00:00:00 2001
From: awaelchli
Date: Tue, 5 Sep 2023 23:36:09 +0200
Subject: [PATCH 5/5] update test

---
 tests/tests_fabric/test_wrappers.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/tests_fabric/test_wrappers.py b/tests/tests_fabric/test_wrappers.py
index f7c7da5abac59..b9c0f89bb97be 100644
--- a/tests/tests_fabric/test_wrappers.py
+++ b/tests/tests_fabric/test_wrappers.py
@@ -362,7 +362,7 @@ def test_fabric_optimizer_wraps():

 def test_fabric_optimizer_state_dict():
     """Test that the FabricOptimizer calls into the strategy to collect the state."""
-    optimizer = Mock()
+    optimizer = Mock(spec=torch.optim.Adam)
     strategy = Mock()
     fabric_optimizer = _FabricOptimizer(optimizer=optimizer, strategy=strategy)
     fabric_optimizer.state_dict()
@@ -390,7 +390,7 @@ def test_fabric_optimizer_load_state_dict():

 def test_fabric_optimizer_steps():
     """Test that the FabricOptimizer forwards the step() and zero_grad() calls to the wrapped optimizer."""
-    optimizer = Mock()
+    optimizer = Mock(spec=torch.optim.Adam)
     strategy = Mock(spec=["optimizer_step"])
     strategy.optimizer_step.return_value = 123
    fabric_optimizer = _FabricOptimizer(optimizer=optimizer, strategy=strategy)
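Side note on the `Mock(spec=torch.optim.Adam)` change in this last patch: with the wrapper now copying `optimizer.__dict__` during `__init__` via `_refresh()`, the tests presumably need mocks whose attribute surface matches a real optimizer. A generic `unittest.mock` illustration of what `spec` buys (not part of the diff):

import torch
from unittest.mock import Mock

loose = Mock()
loose.not_an_optimizer_method()  # passes silently, returning another Mock

strict = Mock(spec=torch.optim.Adam)
strict.step()  # fine: `step` exists on Adam
try:
    strict.not_an_optimizer_method()
except AttributeError:
    print("rejected: not part of the Adam API")

# `spec` also makes isinstance checks pass, which matters when the code
# under test type-checks its input:
assert isinstance(strict, torch.optim.Adam)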