
Fabric wrappers for optimizers do not load state dict #18482

Closed
nikvaessen opened this issue Sep 5, 2023 · 3 comments · Fixed by #18488
Labels: bug (Something isn't working) · fabric (lightning.fabric.Fabric) · ver: 2.0.x

nikvaessen commented Sep 5, 2023

Bug description

When `load_state_dict()` is called on an optimizer wrapper returned by `fabric.setup(...)`, the loaded state is silently dropped: a subsequent `state_dict()` call on the wrapper returns an empty `"state"` dict.
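For intuition, this resembles a generic wrapper pitfall. The sketch below uses hypothetical `TinyOptimizer` and `OptimizerWrapper` classes, not Fabric's actual internals, and shows one *plausible* mechanism (not necessarily the real one fixed in #18488): a wrapper that snapshots the inner optimizer's attributes at construction time keeps referencing the stale, empty state after `load_state_dict` rebinds the attribute on the inner object.

```python
class TinyOptimizer:
    """Stand-in for torch.optim.Optimizer (hypothetical, for illustration)."""
    def __init__(self):
        self.state = {}  # per-parameter state, empty until populated

    def state_dict(self):
        return {"state": dict(self.state)}

    def load_state_dict(self, sd):
        # Rebinds self.state to a NEW dict object, much like torch's
        # Optimizer.__setstate__ does during load_state_dict.
        self.state = dict(sd["state"])


class OptimizerWrapper:
    """Hypothetical wrapper that snapshots the inner optimizer's attributes."""
    def __init__(self, optimizer):
        # Copies a reference to the (still empty) state dict.
        self.__dict__.update(optimizer.__dict__)
        self.optimizer = optimizer

    def state_dict(self):
        # Reads the snapshotted attribute, not the live inner optimizer.
        return {"state": dict(self.state)}


inner = TinyOptimizer()
wrapper = OptimizerWrapper(inner)

ckpt = {"state": {0: {"step": 5}}}
inner.load_state_dict(ckpt)  # inner state is now populated ...

print(inner.state_dict())    # {'state': {0: {'step': 5}}}
print(wrapper.state_dict())  # {'state': {}}  <- stale snapshot looks "empty"
```

The stale snapshot is invisible until you compare the wrapper's `state_dict()` against the inner optimizer's, which is exactly the symptom reported below.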

What version are you seeing the problem on?

v2.0

How to reproduce the bug

import lightning
import torch

USE_FABRIC_SETUP = True


class Network(torch.nn.Module):
    def __init__(self):
        super().__init__()

        self.fc1 = torch.nn.Linear(1000, 10)

    def forward(self, x):
        return torch.nn.functional.relu(self.fc1(x))


def first_session():
    fabric = lightning.Fabric(accelerator="gpu", devices=1)
    fabric.launch()

    network = Network()
    opt = torch.optim.Adam(network.parameters())

    if USE_FABRIC_SETUP:
        network, opt = fabric.setup(network, opt)
    else:
        network = network.cuda()

    for i in range(5):
        opt.zero_grad()

        inp = torch.rand((8, 1000)).cuda()
        pred = network(inp)

        target = torch.randint(low=0, high=9, size=(8,)).cuda()
        loss = torch.nn.functional.cross_entropy(pred, target)

        fabric.backward(loss)
        opt.step()

    fabric.save("first_session.ckpt", {"network": network, "opt": opt})


def second_session():
    fabric = lightning.Fabric(accelerator="gpu", devices=1)
    fabric.launch()

    network = Network()
    opt = torch.optim.Adam(network.parameters())

    if USE_FABRIC_SETUP:
        network, opt = fabric.setup(network, opt)

    ckpt = fabric.load("first_session.ckpt")

    network.load_state_dict(ckpt["network"])
    opt.load_state_dict(ckpt["opt"])

    assert len(opt.state_dict()["state"]) == len(ckpt["opt"]["state"])


def main():
    print(lightning.__version__)
    torch.manual_seed(123)

    first_session()
    second_session()


if __name__ == "__main__":
    main()

Error messages and logs

Traceback (most recent call last):
  File "/home/nvaessen/phd/repo/nanow2v2/playground/loading_bug/buggy_loading.py", line 71, in <module>
    main()
  File "/home/nvaessen/phd/repo/nanow2v2/playground/loading_bug/buggy_loading.py", line 67, in main
    second_session()
  File "/home/nvaessen/phd/repo/nanow2v2/playground/loading_bug/buggy_loading.py", line 59, in second_session
    assert len(opt.state_dict()["state"]) == len(ckpt["opt"]["state"])
AssertionError

The checkpoint (`ckpt['opt']`):

{'state': {0: {'step': tensor(5.), 'exp_avg': tensor([[ 0.0008,  0.0006,  0.0015,  ...,  0.0010,  0.0012,  0.0011],
        [ 0.0013,  0.0017,  0.0026,  ...,  0.0018,  0.0029,  0.0013],
        [-0.0038, -0.0014, -0.0019,  ..., -0.0006, -0.0010, -0.0031],
        ...,
        [ 0.0196,  0.0150,  0.0168,  ...,  0.0152,  0.0192,  0.0213],
        [ 0.0039,  0.0045,  0.0067,  ...,  0.0053,  0.0063,  0.0055],
        [ 0.0006,  0.0004,  0.0009,  ...,  0.0004,  0.0010,  0.0003]]), 'exp_avg_sq': tensor([[1.5035e-07, 8.4814e-08, 5.0681e-07,  ..., 2.2370e-07, 3.5894e-07,
         2.7821e-07],
        [4.1096e-07, 6.4151e-07, 1.5764e-06,  ..., 7.8267e-07, 1.9212e-06,
         3.8671e-07],
        [8.4621e-06, 6.0896e-06, 1.3199e-05,  ..., 6.2065e-06, 9.0778e-06,
         1.1578e-05],
        ...,
        [2.4836e-05, 2.4092e-05, 2.5735e-05,  ..., 2.5762e-05, 3.9960e-05,
         3.7594e-05],
        [1.9858e-06, 2.4212e-06, 6.8515e-06,  ..., 3.5609e-06, 6.0807e-06,
         4.0056e-06],
        [9.6174e-08, 3.3862e-08, 1.6834e-07,  ..., 3.2796e-08, 2.2325e-07,
         1.7551e-08]])}, 1: {'step': tensor(5.), 'exp_avg': tensor([0.0032, 0.0051, 0.0019, 0.0000, 0.0007, 0.0042, 0.0058, 0.0310, 0.0122,
        0.0016]), 'exp_avg_sq': tensor([2.3620e-06, 5.9441e-06, 2.4069e-05, 0.0000e+00, 1.1375e-07, 4.0922e-06,
        7.9036e-06, 1.0822e-04, 2.1972e-05, 5.7628e-07])}}, 'param_groups': [{'lr': 0.001, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False, 'maximize': False, 'foreach': None, 'capturable': False, 'differentiable': False, 'fused': None, 'params': [0, 1]}]}

The result of calling state_dict() after calling opt.load_state_dict():

{'state': {}, 'param_groups': [{'lr': 0.001, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False, 'maximize': False, 'foreach': None, 'capturable': False, 'differentiable': False, 'fused': None, 'params': [0, 1]}]}

Environment

  • CUDA:
    • GPU:
      • NVIDIA GeForce GTX 1080 Ti
    • available: True
    • version: 11.7
  • Lightning:
    • lightning: 2.0.8
    • lightning-cloud: 0.5.37
    • lightning-utilities: 0.8.0
    • pytorch-lightning: 2.0.3
    • torch: 2.0.1
    • torch-tb-profiler: 0.4.1
    • torchaudio: 2.0.1
    • torchdata: 0.6.1
    • torchmetrics: 0.11.4
    • torchvision: 0.15.1
  • Packages:
    • absl-py: 1.4.0
    • aiohttp: 3.8.4
    • aiosignal: 1.3.1
    • alembic: 1.11.1
    • antlr4-python3-runtime: 4.9.3
    • anyio: 3.7.0
    • appdirs: 1.4.4
    • arrow: 1.2.3
    • async-timeout: 4.0.2
    • attrs: 23.1.0
    • autopage: 0.5.1
    • backoff: 2.2.1
    • beautifulsoup4: 4.12.2
    • blessed: 1.20.0
    • cachetools: 5.3.1
    • certifi: 2023.5.7
    • cffi: 1.15.1
    • charset-normalizer: 3.1.0
    • click: 8.1.3
    • cliff: 4.3.0
    • cloudpickle: 2.2.1
    • cmaes: 0.9.1
    • cmake: 3.26.4
    • cmd2: 2.4.3
    • colorlog: 6.7.0
    • contourpy: 1.1.0
    • croniter: 1.3.15
    • cycler: 0.11.0
    • dateutils: 0.6.12
    • deepdiff: 6.3.0
    • docker-pycreds: 0.4.0
    • exceptiongroup: 1.1.1
    • fastapi: 0.98.0
    • filelock: 3.12.2
    • fonttools: 4.40.0
    • frozenlist: 1.3.3
    • fsspec: 2023.6.0
    • gitdb: 4.0.10
    • gitpython: 3.1.31
    • google-auth: 2.20.0
    • google-auth-oauthlib: 0.4.6
    • greenlet: 2.0.2
    • grpcio: 1.54.2
    • h11: 0.14.0
    • huggingface-hub: 0.15.1
    • hydra-core: 1.3.2
    • hydra-optuna-sweeper: 1.2.0
    • hydra-submitit-launcher: 1.2.0
    • idna: 3.4
    • importlib-metadata: 6.7.0
    • iniconfig: 2.0.0
    • inquirer: 3.1.3
    • itsdangerous: 2.1.2
    • jinja2: 3.1.2
    • jiwer: 3.0.0
    • joblib: 1.2.0
    • kiwisolver: 1.4.4
    • lightning: 2.0.8
    • lightning-cloud: 0.5.37
    • lightning-utilities: 0.8.0
    • lit: 16.0.6
    • mako: 1.2.4
    • markdown: 3.4.3
    • markdown-it-py: 3.0.0
    • markupsafe: 2.1.3
    • matplotlib: 3.7.1
    • mdurl: 0.1.2
    • mpmath: 1.3.0
    • multidict: 6.0.4
    • nanow2v2: 1.0
    • networkx: 3.1
    • numpy: 1.25.0
    • nvidia-cublas-cu11: 11.10.3.66
    • nvidia-cuda-cupti-cu11: 11.7.101
    • nvidia-cuda-nvrtc-cu11: 11.7.99
    • nvidia-cuda-runtime-cu11: 11.7.99
    • nvidia-cudnn-cu11: 8.5.0.96
    • nvidia-cufft-cu11: 10.9.0.58
    • nvidia-curand-cu11: 10.2.10.91
    • nvidia-cusolver-cu11: 11.4.0.1
    • nvidia-cusparse-cu11: 11.7.4.91
    • nvidia-nccl-cu11: 2.14.3
    • nvidia-nvtx-cu11: 11.7.91
    • oauthlib: 3.2.2
    • omegaconf: 2.3.0
    • optuna: 2.10.1
    • ordered-set: 4.1.0
    • packaging: 23.1
    • pandas: 2.0.2
    • pathtools: 0.1.2
    • pbr: 5.11.1
    • pillow: 9.5.0
    • pip: 23.1.2
    • pluggy: 1.2.0
    • polars: 0.18.3
    • prettytable: 3.8.0
    • protobuf: 4.23.3
    • psutil: 5.9.5
    • pyasn1: 0.5.0
    • pyasn1-modules: 0.3.0
    • pycparser: 2.21
    • pydantic: 1.10.9
    • pygments: 2.15.1
    • pyjwt: 2.7.0
    • pyparsing: 3.1.0
    • pyperclip: 1.8.2
    • pytest: 7.4.0
    • python-dateutil: 2.8.2
    • python-dotenv: 1.0.0
    • python-editor: 1.0.4
    • python-multipart: 0.0.6
    • pytorch-lightning: 2.0.3
    • pytz: 2023.3
    • pyyaml: 6.0
    • rapidfuzz: 2.13.7
    • readchar: 4.0.5
    • regex: 2023.6.3
    • requests: 2.31.0
    • requests-oauthlib: 1.3.1
    • rich: 13.4.2
    • rsa: 4.9
    • safetensors: 0.3.1
    • scikit-learn: 1.2.2
    • scipy: 1.10.1
    • seaborn: 0.12.2
    • sentry-sdk: 1.25.1
    • setproctitle: 1.3.2
    • setuptools: 65.5.0
    • six: 1.16.0
    • smmap: 5.0.0
    • sniffio: 1.3.0
    • soundfile: 0.12.1
    • soupsieve: 2.4.1
    • sqlalchemy: 2.0.17
    • starlette: 0.27.0
    • starsessions: 1.3.0
    • stevedore: 5.1.0
    • submitit: 1.4.5
    • sympy: 1.12
    • tensorboard: 2.12.0
    • tensorboard-data-server: 0.7.1
    • tensorboard-plugin-wit: 1.8.1
    • threadpoolctl: 3.1.0
    • tokenizers: 0.13.3
    • tomli: 2.0.1
    • torch: 2.0.1
    • torch-tb-profiler: 0.4.1
    • torchaudio: 2.0.1
    • torchdata: 0.6.1
    • torchmetrics: 0.11.4
    • torchvision: 0.15.1
    • tqdm: 4.65.0
    • traitlets: 5.9.0
    • transformers: 4.30.2
    • triton: 2.0.0
    • typing-extensions: 4.6.3
    • tzdata: 2023.3
    • urllib3: 1.26.16
    • uvicorn: 0.22.0
    • wandb: 0.15.5
    • wcwidth: 0.2.6
    • websocket-client: 1.6.0
    • websockets: 11.0.3
    • werkzeug: 2.3.6
    • wheel: 0.40.0
    • yarl: 1.9.2
    • zipp: 3.15.0
  • System:
    • OS: Linux
    • architecture:
      • 64bit
      • ELF
    • processor:
    • python: 3.10.12
    • release: 6.4.12-arch1-1
    • version: #1 SMP PREEMPT_DYNAMIC Thu, 24 Aug 2023 00:38:14 +0000

More info

No response

cc @carmocca @justusschock @awaelchli

nikvaessen added the "bug" and "needs triage" labels on Sep 5, 2023
awaelchli removed the "needs triage" label, self-assigned the issue, and added the "fabric (lightning.fabric.Fabric)" label on Sep 5, 2023
awaelchli commented Sep 5, 2023

Hey @nikvaessen
This is very odd, I'll look into it.

By the way, in case you didn't know, the recommended way to load in Fabric is

fabric.load("first_session.ckpt", {"network": network, "opt": opt})

because this generalizes across all strategies and accelerators, and offers a convenient way to make scripts stateful in general.
And this will pass your assertion. So you can use this way as a workaround until I make the bugfix.

Thanks for reporting!

nikvaessen commented Sep 6, 2023

> And this will pass your assertion. So you can use this way as a workaround until I make the bugfix.

I've tried the following modification to the reloading part of the reproduction code sample:

def second_session():
    fabric = lightning.Fabric(accelerator="gpu", devices=1)
    fabric.launch()

    network = Network()
    opt = torch.optim.Adam(network.parameters())

    if USE_FABRIC_SETUP:
        network, opt = fabric.setup(network, opt)
    else:
        network.cuda()

    state = {"network": network, "opt": opt}
    remainder = fabric.load("first_session.ckpt", state)

    print("remainder:", remainder)
    print("wrapper", opt.state_dict())
    print("optimizer", opt.optimizer.state_dict())
    print("checkpoint\n", torch.load("first_session.ckpt"))

    assert len(opt.state_dict()["state"]) >= 0

This still results in:

remainder: {}
wrapper {'state': {}, 'param_groups': [{'lr': 0.001, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False, 'maximize': False, 'foreach': None, 'capturable': False, 'differentiable': False, 'fused': None, 'params': [0, 1]}]}
optimizer {'state': {}, 'param_groups': [{'lr': 0.001, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False, 'maximize': False, 'foreach': None, 'capturable': False, 'differentiable': False, 'fused': None, 'params': [0, 1]}]}

If I run my code with `pip install git+https://github.com/Lightning-AI/lightning@bugfix/fabric-optimizer-load-state`, the state dictionary is correctly loaded. So thanks for the bugfix :)

@awaelchli
@nikvaessen Thanks for confirming!
