Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refresh _FabricOptimizer.__dict__ when loading a state dict #18488

Merged
merged 6 commits into from
Sep 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/lightning/fabric/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed model parameters getting shared between processes when running with `strategy="ddp_spawn"` and `accelerator="cpu"`; this has a necessary memory impact, as parameters are replicated for each process now ([#18238](https://github.com/Lightning-AI/lightning/pull/18238))


- Fixed an issue causing the `_FabricOptimizer.state` to remain outdated after loading with `load_state_dict` ([#18488](https://github.com/Lightning-AI/lightning/pull/18488))


## [2.0.7] - 2023-08-14

### Changed
Expand Down
25 changes: 22 additions & 3 deletions src/lightning/fabric/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,11 @@ def __init__(self, optimizer: Optimizer, strategy: Strategy, callbacks: Optional


"""
# `__del__` is skipped in case the optimizer has implemented custom destructor logic which we would
# not want to call on destruction of the `_FabricOptimizer
self.__dict__ = {k: v for k, v in optimizer.__dict__.items() if k not in ("state_dict", "step", "__del__")}
self.__class__ = type("Fabric" + optimizer.__class__.__name__, (self.__class__, optimizer.__class__), {})
self._optimizer = optimizer
self._strategy = strategy
self._callbacks = callbacks or []
self._refresh()

@property
def optimizer(self) -> Optimizer:
Expand All @@ -65,6 +63,12 @@ def optimizer(self) -> Optimizer:
def state_dict(self) -> Dict[str, Tensor]:
    """Collect and return the optimizer state by delegating to the strategy.

    The strategy may need to gather/convert sharded state, so the wrapped
    optimizer's ``state_dict`` is not called directly.
    """
    strategy = self._strategy
    return strategy.get_optimizer_state(self.optimizer)

def load_state_dict(self, state_dict: Dict[str, Tensor]) -> None:
    """Load ``state_dict`` into the wrapped optimizer, then resync this wrapper.

    ``Optimizer.load_state_dict`` mutates the wrapped optimizer's ``__dict__``,
    so the mirrored attributes on this wrapper must be refreshed afterwards.
    """
    wrapped = self.optimizer
    wrapped.load_state_dict(state_dict)
    self._refresh()

def step(self, closure: Optional[Callable] = None) -> Any:
kwargs = {"closure": closure} if closure is not None else {}
if hasattr(self._strategy, "model") and isinstance(self._strategy.model, Optimizable):
Expand All @@ -82,6 +86,21 @@ def step(self, closure: Optional[Callable] = None) -> Any:
hook(strategy=self._strategy, optimizer=optimizer)
return output

def _refresh(self) -> None:
    """Refreshes the ``__dict__`` so that it matches the internal states in the wrapped optimizer.

    This is only needed to present the user with an updated view in case they inspect the state of this wrapper.
    """
    # Methods overridden by this wrapper must not be copied over, and `__del__` is
    # skipped in case the optimizer has implemented custom destructor logic which we
    # would not want to call on destruction of the `_FabricOptimizer`.
    excluded = ("load_state_dict", "state_dict", "step", "__del__")
    refreshed = {}
    for key, value in self.optimizer.__dict__.items():
        if key not in excluded:
            refreshed[key] = value
    self.__dict__.update(refreshed)


class _FabricModule(_DeviceDtypeModuleMixin):
def __init__(
Expand Down
23 changes: 21 additions & 2 deletions tests/tests_fabric/test_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,16 +362,35 @@ def test_fabric_optimizer_wraps():

def test_fabric_optimizer_state_dict():
    """Test that the FabricOptimizer calls into the strategy to collect the state."""
    wrapped = Mock(spec=torch.optim.Adam)
    strategy = Mock()
    fabric_optimizer = _FabricOptimizer(optimizer=wrapped, strategy=strategy)
    fabric_optimizer.state_dict()
    # The state must come from the strategy (which may gather sharded state),
    # not directly from the wrapped optimizer.
    strategy.get_optimizer_state.assert_called_with(wrapped)


def test_fabric_optimizer_load_state_dict():
    """Test that the FabricOptimizer can load the state dict on the wrapped optimizer and update its
    internal `__dict__`."""
    model = torch.nn.Linear(1, 1)

    # Produce a non-empty state dict by taking one optimizer step.
    stepped = torch.optim.Adam(model.parameters())
    assert not stepped.state  # a fresh optimizer has no state
    model(torch.rand(1)).backward()
    stepped.step()
    assert stepped.state
    saved_state = stepped.state_dict()

    # Load it into a wrapper around a fresh optimizer.
    fresh = torch.optim.Adam(model.parameters())
    fabric_optimizer = _FabricOptimizer(optimizer=fresh, strategy=Mock())
    assert not fabric_optimizer.state  # a fresh optimizer has no state
    fabric_optimizer.load_state_dict(saved_state)

    # Both the wrapper's mirrored view and the wrapped optimizer must reflect the loaded state.
    assert fabric_optimizer.state
    assert fabric_optimizer.optimizer.state_dict()["state"] == saved_state["state"]


def test_fabric_optimizer_steps():
"""Test that the FabricOptimizer forwards the step() and zero_grad() calls to the wrapped optimizer."""
optimizer = Mock()
optimizer = Mock(spec=torch.optim.Adam)
strategy = Mock(spec=["optimizer_step"])
strategy.optimizer_step.return_value = 123
fabric_optimizer = _FabricOptimizer(optimizer=optimizer, strategy=strategy)
Expand Down
Loading