
Saving learning rate schedules with Fabric #18493

Closed
nikvaessen opened this issue Sep 6, 2023 · 1 comment · Fixed by #18513
Labels
bug Something isn't working checkpointing Related to checkpointing fabric lightning.fabric.Fabric lr scheduler ver: 2.0.x
Milestone

Comments

@nikvaessen
Contributor

nikvaessen commented Sep 6, 2023

Bug description

It is unclear to me how learning rate schedulers should be used alongside fabric.save. I've defaulted to manually saving with fabric.save(..., {..., 'schedule': schedule.state_dict()}) and manually loading with ckpt = fabric.load(...); schedule.load_state_dict(ckpt['schedule']). As stated in #18482, the recommended way of loading is to use a state object, which assumes you pass the scheduler itself to fabric.save, not its state_dict(). This leads to some weird behavior; see the code below.

It is also unclear to me whether a learning rate scheduler should be constructed with the bare optimizer or with the wrapped Fabric optimizer object.

What version are you seeing the problem on?

v2.0

How to reproduce the bug

import lightning
import torch

INSTANTIATE_SCHEDULE_ON_WRAPPER = False


class Network(torch.nn.Module):
    def __init__(self):
        super().__init__()

        self.fc1 = torch.nn.Linear(1000, 10)

    def forward(self, x):
        return torch.nn.functional.relu(self.fc1(x))


def no_schedule():
    fabric = lightning.Fabric(accelerator="gpu", devices=1)
    fabric.launch()

    network = Network()
    opt = torch.optim.Adam(network.parameters())

    network, opt = fabric.setup(network, opt)

    state = {"network": network, "opt": opt}
    fabric.save("no_schedule.ckpt", state)
    print("no schedule succeeded")


def exp_decay():
    fabric = lightning.Fabric(accelerator="gpu", devices=1)
    fabric.launch()

    network = Network()
    opt = torch.optim.Adam(network.parameters())

    if INSTANTIATE_SCHEDULE_ON_WRAPPER:
        network, opt = fabric.setup(network, opt)
        lr_schedule = torch.optim.lr_scheduler.ExponentialLR(opt, 0.99)
    else:
        lr_schedule = torch.optim.lr_scheduler.ExponentialLR(opt, 0.99)
        network, opt = fabric.setup(network, opt)

    state = {"network": network, "opt": opt, "lr_schedule": lr_schedule}
    fabric.save("exp_decay.ckpt", state)
    print("exp decay succeeded")


def cyclic():
    fabric = lightning.Fabric(accelerator="gpu", devices=1)
    fabric.launch()

    network = Network()
    opt = torch.optim.Adam(network.parameters())

    if INSTANTIATE_SCHEDULE_ON_WRAPPER:
        network, opt = fabric.setup(network, opt)
        lr_schedule = torch.optim.lr_scheduler.CyclicLR(
            opt, 1e-5, 1e-6, cycle_momentum=False
        )
    else:
        lr_schedule = torch.optim.lr_scheduler.CyclicLR(
            opt, 1e-5, 1e-6, cycle_momentum=False
        )
        network, opt = fabric.setup(network, opt)

    state = {"network": network, "opt": opt, "lr_schedule": lr_schedule}
    fabric.save("cyclic.ckpt", state)
    print("cyclic LR succeeded")


def main():
    print(lightning.__version__)
    torch.manual_seed(123)

    no_schedule()
    exp_decay()
    cyclic()


if __name__ == "__main__":
    main()

Error messages and logs

If run with INSTANTIATE_SCHEDULE_ON_WRAPPER = False, the methods no_schedule() and exp_decay() succeed, but cyclic() fails with:

Traceback (most recent call last):
  File "/home/nvaessen/phd/repo/nanow2v2/playground/loading_bug/buggy_saving.py", line 85, in <module>
    main()
  File "/home/nvaessen/phd/repo/nanow2v2/playground/loading_bug/buggy_saving.py", line 81, in main
    cyclic()
  File "/home/nvaessen/phd/repo/nanow2v2/playground/loading_bug/buggy_saving.py", line 71, in cyclic
    fabric.save("cyclic.ckpt", state)
  File "/home/nvaessen/phd/repo/nanow2v2/.venv/lib/python3.10/site-packages/lightning/fabric/fabric.py", line 633, in save
    return self._strategy.save_checkpoint(path=path, state=_unwrap_objects(state))
  File "/home/nvaessen/phd/repo/nanow2v2/.venv/lib/python3.10/site-packages/lightning/fabric/strategies/strategy.py", line 241, in save_checkpoint
    self.checkpoint_io.save_checkpoint(checkpoint=state, path=path, storage_options=storage_options)
  File "/home/nvaessen/phd/repo/nanow2v2/.venv/lib/python3.10/site-packages/lightning/fabric/plugins/io/torch_io.py", line 59, in save_checkpoint
    _atomic_save(checkpoint, path)
  File "/home/nvaessen/phd/repo/nanow2v2/.venv/lib/python3.10/site-packages/lightning/fabric/utilities/cloud_io.py", line 72, in _atomic_save
    torch.save(checkpoint, bytesbuffer)
  File "/home/nvaessen/phd/repo/nanow2v2/.venv/lib/python3.10/site-packages/torch/serialization.py", line 441, in save
    _save(obj, opened_zipfile, pickle_module, pickle_protocol)
  File "/home/nvaessen/phd/repo/nanow2v2/.venv/lib/python3.10/site-packages/torch/serialization.py", line 653, in _save
    pickler.dump(obj)
TypeError: cannot pickle 'WeakMethod' object
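For context: as far as I can tell, in torch 2.0 CyclicLR keeps a weakref.WeakMethod reference to its scale function, and weak references cannot be pickled, which appears to be the source of this error. A minimal illustration of just the pickling failure, independent of Lightning (the Holder class is made up for the demo):

```python
# Minimal illustration (independent of Lightning): weak references are not
# picklable. CyclicLR holds a weakref.WeakMethod to its scale function in
# torch 2.0, which appears to be why torch.save chokes on the live scheduler.
import pickle
import weakref


class Holder:
    def method(self):
        pass


holder = Holder()
wm = weakref.WeakMethod(holder.method)

try:
    pickle.dumps(wm)
except TypeError as e:
    # Raises TypeError: cannot pickle a WeakMethod
    # (the exact message varies by Python version).
    print(e)
```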

If run with INSTANTIATE_SCHEDULE_ON_WRAPPER = True, both exp_decay() and cyclic() fail with:

Traceback (most recent call last):
  File "/home/nvaessen/phd/repo/nanow2v2/playground/loading_bug/buggy_saving.py", line 83, in <module>
    main()
  File "/home/nvaessen/phd/repo/nanow2v2/playground/loading_bug/buggy_saving.py", line 79, in main
    cyclic()
  File "/home/nvaessen/phd/repo/nanow2v2/playground/loading_bug/buggy_saving.py", line 69, in cyclic
    fabric.save("cyclic.ckpt", state)
  File "/home/nvaessen/phd/repo/nanow2v2/.venv/lib/python3.10/site-packages/lightning/fabric/fabric.py", line 633, in save
    return self._strategy.save_checkpoint(path=path, state=_unwrap_objects(state))
  File "/home/nvaessen/phd/repo/nanow2v2/.venv/lib/python3.10/site-packages/lightning/fabric/strategies/strategy.py", line 241, in save_checkpoint
    self.checkpoint_io.save_checkpoint(checkpoint=state, path=path, storage_options=storage_options)
  File "/home/nvaessen/phd/repo/nanow2v2/.venv/lib/python3.10/site-packages/lightning/fabric/plugins/io/torch_io.py", line 59, in save_checkpoint
    _atomic_save(checkpoint, path)
  File "/home/nvaessen/phd/repo/nanow2v2/.venv/lib/python3.10/site-packages/lightning/fabric/utilities/cloud_io.py", line 72, in _atomic_save
    torch.save(checkpoint, bytesbuffer)
  File "/home/nvaessen/phd/repo/nanow2v2/.venv/lib/python3.10/site-packages/torch/serialization.py", line 441, in save
    _save(obj, opened_zipfile, pickle_module, pickle_protocol)
  File "/home/nvaessen/phd/repo/nanow2v2/.venv/lib/python3.10/site-packages/torch/serialization.py", line 653, in _save
    pickler.dump(obj)
_pickle.PicklingError: Can't pickle <class 'lightning.fabric.wrappers.FabricAdam'>: attribute lookup FabricAdam on lightning.fabric.wrappers failed 

If we change state = {..., "lr_schedule": lr_schedule} to state = {..., "lr_schedule": lr_schedule.state_dict()}, saving succeeds for both settings of INSTANTIATE_SCHEDULE_ON_WRAPPER. However, passing state_dict() directly means that subsequent calls to schedule.step() will not be reflected in the central state object.
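To make that last point concrete, here is a hypothetical minimal stateful object (a stand-in for an LR scheduler, no Lightning required; TinyScheduler is made up for the demo) showing that placing state_dict() in the state object captures a one-time snapshot:

```python
# Hypothetical minimal stateful object (stand-in for an LR scheduler)
# illustrating the snapshot problem: the dict placed in `state` is frozen
# at the moment state_dict() is called.
class TinyScheduler:
    def __init__(self):
        self.last_epoch = 0

    def step(self):
        self.last_epoch += 1

    def state_dict(self):
        return {"last_epoch": self.last_epoch}

    def load_state_dict(self, state):
        self.last_epoch = state["last_epoch"]


sched = TinyScheduler()
state = {"lr_schedule": sched.state_dict()}  # snapshot taken here

sched.step()  # advances the live scheduler only

print(state["lr_schedule"]["last_epoch"])  # 0 -- the snapshot is stale
print(sched.state_dict()["last_epoch"])    # 1 -- the live state moved on
```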

Environment

Current environment
  • CUDA:
    • GPU:
      • NVIDIA GeForce GTX 1080 Ti
    • available: True
    • version: 11.7
  • Lightning:
    • lightning: 2.0.8
    • lightning-cloud: 0.5.37
    • lightning-utilities: 0.8.0
    • pytorch-lightning: 2.0.3
    • torch: 2.0.1
    • torch-tb-profiler: 0.4.1
    • torchaudio: 2.0.1
    • torchdata: 0.6.1
    • torchmetrics: 0.11.4
    • torchvision: 0.15.1
  • Packages:
    • absl-py: 1.4.0
    • aiohttp: 3.8.4
    • aiosignal: 1.3.1
    • alembic: 1.11.1
    • antlr4-python3-runtime: 4.9.3
    • anyio: 3.7.0
    • appdirs: 1.4.4
    • arrow: 1.2.3
    • async-timeout: 4.0.2
    • attrs: 23.1.0
    • autopage: 0.5.1
    • backoff: 2.2.1
    • beautifulsoup4: 4.12.2
    • blessed: 1.20.0
    • cachetools: 5.3.1
    • certifi: 2023.5.7
    • cffi: 1.15.1
    • charset-normalizer: 3.1.0
    • click: 8.1.3
    • cliff: 4.3.0
    • cloudpickle: 2.2.1
    • cmaes: 0.9.1
    • cmake: 3.26.4
    • cmd2: 2.4.3
    • colorlog: 6.7.0
    • contourpy: 1.1.0
    • croniter: 1.3.15
    • cycler: 0.11.0
    • dateutils: 0.6.12
    • deepdiff: 6.3.0
    • docker-pycreds: 0.4.0
    • exceptiongroup: 1.1.1
    • fastapi: 0.98.0
    • filelock: 3.12.2
    • fonttools: 4.40.0
    • frozenlist: 1.3.3
    • fsspec: 2023.6.0
    • gitdb: 4.0.10
    • gitpython: 3.1.31
    • google-auth: 2.20.0
    • google-auth-oauthlib: 0.4.6
    • greenlet: 2.0.2
    • grpcio: 1.54.2
    • h11: 0.14.0
    • huggingface-hub: 0.15.1
    • hydra-core: 1.3.2
    • hydra-optuna-sweeper: 1.2.0
    • hydra-submitit-launcher: 1.2.0
    • idna: 3.4
    • importlib-metadata: 6.7.0
    • iniconfig: 2.0.0
    • inquirer: 3.1.3
    • itsdangerous: 2.1.2
    • jinja2: 3.1.2
    • jiwer: 3.0.0
    • joblib: 1.2.0
    • kiwisolver: 1.4.4
    • lightning: 2.0.8
    • lightning-cloud: 0.5.37
    • lightning-utilities: 0.8.0
    • lit: 16.0.6
    • mako: 1.2.4
    • markdown: 3.4.3
    • markdown-it-py: 3.0.0
    • markupsafe: 2.1.3
    • matplotlib: 3.7.1
    • mdurl: 0.1.2
    • mpmath: 1.3.0
    • multidict: 6.0.4
    • nanow2v2: 1.0
    • networkx: 3.1
    • numpy: 1.25.0
    • nvidia-cublas-cu11: 11.10.3.66
    • nvidia-cuda-cupti-cu11: 11.7.101
    • nvidia-cuda-nvrtc-cu11: 11.7.99
    • nvidia-cuda-runtime-cu11: 11.7.99
    • nvidia-cudnn-cu11: 8.5.0.96
    • nvidia-cufft-cu11: 10.9.0.58
    • nvidia-curand-cu11: 10.2.10.91
    • nvidia-cusolver-cu11: 11.4.0.1
    • nvidia-cusparse-cu11: 11.7.4.91
    • nvidia-nccl-cu11: 2.14.3
    • nvidia-nvtx-cu11: 11.7.91
    • oauthlib: 3.2.2
    • omegaconf: 2.3.0
    • optuna: 2.10.1
    • ordered-set: 4.1.0
    • packaging: 23.1
    • pandas: 2.0.2
    • pathtools: 0.1.2
    • pbr: 5.11.1
    • pillow: 9.5.0
    • pip: 23.1.2
    • pluggy: 1.2.0
    • polars: 0.18.3
    • prettytable: 3.8.0
    • protobuf: 4.23.3
    • psutil: 5.9.5
    • pyasn1: 0.5.0
    • pyasn1-modules: 0.3.0
    • pycparser: 2.21
    • pydantic: 1.10.9
    • pygments: 2.15.1
    • pyjwt: 2.7.0
    • pyparsing: 3.1.0
    • pyperclip: 1.8.2
    • pytest: 7.4.0
    • python-dateutil: 2.8.2
    • python-dotenv: 1.0.0
    • python-editor: 1.0.4
    • python-multipart: 0.0.6
    • pytorch-lightning: 2.0.3
    • pytz: 2023.3
    • pyyaml: 6.0
    • rapidfuzz: 2.13.7
    • readchar: 4.0.5
    • regex: 2023.6.3
    • requests: 2.31.0
    • requests-oauthlib: 1.3.1
    • rich: 13.4.2
    • rsa: 4.9
    • safetensors: 0.3.1
    • scikit-learn: 1.2.2
    • scipy: 1.10.1
    • seaborn: 0.12.2
    • sentry-sdk: 1.25.1
    • setproctitle: 1.3.2
    • setuptools: 65.5.0
    • six: 1.16.0
    • smmap: 5.0.0
    • sniffio: 1.3.0
    • soundfile: 0.12.1
    • soupsieve: 2.4.1
    • sqlalchemy: 2.0.17
    • starlette: 0.27.0
    • starsessions: 1.3.0
    • stevedore: 5.1.0
    • submitit: 1.4.5
    • sympy: 1.12
    • tensorboard: 2.12.0
    • tensorboard-data-server: 0.7.1
    • tensorboard-plugin-wit: 1.8.1
    • threadpoolctl: 3.1.0
    • tokenizers: 0.13.3
    • tomli: 2.0.1
    • torch: 2.0.1
    • torch-tb-profiler: 0.4.1
    • torchaudio: 2.0.1
    • torchdata: 0.6.1
    • torchmetrics: 0.11.4
    • torchvision: 0.15.1
    • tqdm: 4.65.0
    • traitlets: 5.9.0
    • transformers: 4.30.2
    • triton: 2.0.0
    • typing-extensions: 4.6.3
    • tzdata: 2023.3
    • urllib3: 1.26.16
    • uvicorn: 0.22.0
    • wandb: 0.15.5
    • wcwidth: 0.2.6
    • websocket-client: 1.6.0
    • websockets: 11.0.3
    • werkzeug: 2.3.6
    • wheel: 0.40.0
    • yarl: 1.9.2
    • zipp: 3.15.0
  • System:
    • OS: Linux
    • architecture:
      • 64bit
      • ELF
    • processor:
    • python: 3.10.12
    • release: 6.4.12-arch1-1
    • version: #1 SMP PREEMPT_DYNAMIC Thu, 24 Aug 2023 00:38:14 +0000

More info

No response

cc @awaelchli @carmocca @justusschock

@nikvaessen nikvaessen added bug Something isn't working needs triage Waiting to be triaged by maintainers labels Sep 6, 2023
@awaelchli
Contributor

awaelchli commented Sep 6, 2023

Hey @nikvaessen
Thanks again for the feedback here! And it's absolutely fantastic that you provided a runnable example!

You are right: the pickle error is due to the fact that the scheduler holds a reference to the wrapped optimizer. In general, however, we shouldn't expect optimizers, schedulers, or other modules to be pickleable.

My proposal is that we should support saving stateful objects other than optimizers and modules. The LRScheduler has a state_dict and a load_state_dict method. We could automatically save and load these correctly for the user.

As a workaround, here is what I suggest for now.
when saving:

    state = {"network": network, "opt": opt, "lr_schedule": lr_schedule.state_dict()}
    fabric.save("exp_decay.ckpt", state)

when loading:

    scheduler_state = {}
    state = {"network": network, "opt": opt, "lr_schedule": scheduler_state}
    fabric.load(path, state)
    lr_schedule.load_state_dict(scheduler_state)

This is an important use case we need to cover (and document).
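A rough sketch of what the proposal could look like, purely hypothetical and not Fabric's actual implementation (the helper names unwrap_state/restore_state and the DemoScheduler class are made up): duck-type on the state_dict/load_state_dict protocol.

```python
# Hypothetical sketch of the proposal -- NOT Fabric's actual implementation.
# Any object exposing state_dict()/load_state_dict() is treated as stateful:
# saved via its state_dict, and restored in place on load.


def unwrap_state(state):
    """Replace stateful objects with their state_dict() just before saving."""
    return {
        key: obj.state_dict() if hasattr(obj, "state_dict") else obj
        for key, obj in state.items()
    }


def restore_state(state, checkpoint):
    """Load checkpointed dicts back into the live stateful objects."""
    for key, obj in state.items():
        if hasattr(obj, "load_state_dict"):
            obj.load_state_dict(checkpoint[key])
        else:
            state[key] = checkpoint[key]


class DemoScheduler:
    """Stand-in for an LR scheduler implementing the stateful protocol."""

    def __init__(self, last_epoch=0):
        self.last_epoch = last_epoch

    def state_dict(self):
        return {"last_epoch": self.last_epoch}

    def load_state_dict(self, state):
        self.last_epoch = state["last_epoch"]
```

With something like this, fabric.save could unwrap stateful entries internally and fabric.load could restore them, so users would keep passing the live scheduler object in the state dict.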

@awaelchli awaelchli added fabric lightning.fabric.Fabric checkpointing Related to checkpointing lr scheduler and removed needs triage Waiting to be triaged by maintainers labels Sep 6, 2023
@awaelchli awaelchli added this to the future milestone Sep 6, 2023
@carmocca carmocca modified the milestones: future, 2.1 Sep 28, 2023