Fabric wrappers for optimizers do not load state dict #18482
Hey @nikvaessen, by the way, in case you didn't know: the recommended way to load a checkpoint in Fabric is `fabric.load("first_session.ckpt", {"network": network, "opt": opt})`, because this generalizes across all strategies and accelerators, and offers a convenient way to make scripts stateful in general. Thanks for reporting!
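As a torch-only sketch of what that recommended pattern amounts to (no Fabric involved, `Network` is a stand-in model, and an in-memory buffer replaces the checkpoint file): one checkpoint object holds the `state_dict` of every stateful object, and loading restores each object in place.

```python
import io
import torch

class Network(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)

network = Network()
opt = torch.optim.Adam(network.parameters())

# One training step so the optimizer accumulates state (exp_avg, step, ...).
network(torch.randn(8, 4)).sum().backward()
opt.step()

# "first session": save one combined checkpoint.
buffer = io.BytesIO()
torch.save({"network": network.state_dict(), "opt": opt.state_dict()}, buffer)

# "second session": restore into freshly constructed objects.
buffer.seek(0)
ckpt = torch.load(buffer)
network2 = Network()
opt2 = torch.optim.Adam(network2.parameters())
network2.load_state_dict(ckpt["network"])
opt2.load_state_dict(ckpt["opt"])

# The restored optimizer state is non-empty -- the behavior the bug breaks.
assert len(opt2.state_dict()["state"]) > 0
```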
I've tried the following modification to the reloading part of the reproduction code sample:

```python
def second_session():
    fabric = lightning.Fabric(accelerator="gpu", devices=1)
    fabric.launch()
    network = Network()
    opt = torch.optim.Adam(network.parameters())
    if USE_FABRIC_SETUP:
        network, opt = fabric.setup(network, opt)
    else:
        network.cuda()
    state = {"network": network, "opt": opt}
    remainder = fabric.load("first_session.ckpt", state)
    print("remainder:", remainder)
    print("wrapper", opt.state_dict())
    print("optimizer", opt.optimizer.state_dict())
    print("checkpoint\n", torch.load("first_session.ckpt"))
    assert len(opt.state_dict()["state"]) > 0  # fails: state is still empty
```

This still results in an empty optimizer `state_dict`.
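A possible workaround, sketched here torch-only and assuming the Fabric wrapper exposes the wrapped optimizer as `.optimizer` (as the prints above use): restore the state on the inner optimizer directly instead of going through the wrapper. `_WrapperSketch` is a minimal stand-in, not Fabric's real class.

```python
import torch

class _WrapperSketch:
    """Minimal stand-in for an optimizer wrapper that forwards reads."""
    def __init__(self, optimizer):
        self.optimizer = optimizer

    def state_dict(self):
        return self.optimizer.state_dict()

model = torch.nn.Linear(4, 2)
inner = torch.optim.Adam(model.parameters())
model(torch.randn(3, 4)).sum().backward()
inner.step()                    # populate optimizer state
saved = inner.state_dict()

fresh = torch.nn.Linear(4, 2)
opt = _WrapperSketch(torch.optim.Adam(fresh.parameters()))
opt.optimizer.load_state_dict(saved)   # bypass the wrapper on load

# The loaded state is visible through the wrapper's state_dict().
assert len(opt.state_dict()["state"]) > 0
```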
If I run my code with
@nikvaessen Thanks for confirming!
Bug description

When `load_state_dict()` is called on an optimizer object returned by `fabric.setup(...)`, the resulting `state_dict` will be empty.

What version are you seeing the problem on?

v2.0
How to reproduce the bug
Error messages and logs
The checkpoint (`ckpt['opt']`):

The result of calling `state_dict()` after calling `opt.load_state_dict()`:

Environment
Current environment
More info
No response
cc @carmocca @justusschock @awaelchli