FSDP full-precision param_dtype training with PyTorch < 2.0 triggers FSDP assertion error #18277

Closed · speediedan opened this issue Aug 10, 2023 · 0 comments · Fixed by #18278
Labels: bug (Something isn't working), strategy: fsdp (Fully Sharded Data Parallel), ver: 2.1.x

speediedan commented Aug 10, 2023

Bug description

When training with FSDP under PyTorch < 2.0 and a full-precision param_dtype (the 16-mixed, bf16-mixed, and 32-true configurations), training fails with the assertion error shown below.

This happens because FSDP < 2.0 uses whether param_dtype is None as a proxy for its _uses_param_mixed_precision property, while FSDPPrecisionPlugin currently defaults param_dtype to torch.float32 even when training in full precision.
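
For context, here is a minimal paraphrase of the check FSDP < 2.0 effectively performs (a sketch only, not the actual PyTorch source): any non-None param_dtype is taken to mean parameter mixed precision is enabled.

    from typing import Optional

    import torch


    def uses_param_mixed_precision(param_dtype: Optional[torch.dtype]) -> bool:
        # FSDP < 2.0 only looks at whether a param_dtype was supplied, not at its
        # value, so MixedPrecision(param_dtype=torch.float32) is treated as
        # "parameter mixed precision enabled" and later trips the
        # "Expects full precision but got torch.float32" assertion.
        return param_dtype is not None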

I'll be submitting a PR shortly that sets the MixedPrecision param_dtype to None when FSDP training uses a full-precision param_dtype with PyTorch < 2.0. Because there is substantial overlap with #18230, the PR will also include a fix for that issue, including the lightning_module_state_dict patch below.
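
As a rough illustration of the direction the fix takes (a sketch only, assuming a simplified precision-to-dtype mapping; the function name and structure below are illustrative, not the actual Lightning API):

    from typing import Optional

    import torch
    from torch.distributed.fsdp import MixedPrecision


    def mixed_precision_config(precision: str, torch_ge_2_0: bool) -> MixedPrecision:
        # Illustrative mapping of Lightning precision flags to FSDP dtypes.
        if precision == "16-mixed":
            param_dtype, reduce_dtype, buffer_dtype = torch.float32, torch.float16, torch.float16
        elif precision == "bf16-mixed":
            param_dtype, reduce_dtype, buffer_dtype = torch.float32, torch.bfloat16, torch.bfloat16
        elif precision == "32-true":
            param_dtype = reduce_dtype = buffer_dtype = torch.float32
        else:
            raise ValueError(f"unsupported precision: {precision!r}")
        if param_dtype == torch.float32 and not torch_ge_2_0:
            # FSDP < 2.0 reads any non-None param_dtype as "parameter mixed precision
            # enabled", so pass None to keep parameters in true full precision.
            param_dtype = None
        return MixedPrecision(param_dtype=param_dtype, reduce_dtype=reduce_dtype, buffer_dtype=buffer_dtype)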

What version are you seeing the problem on?

master

How to reproduce the bug

To reproduce the issue, run ./tests/tests_pytorch/strategies/test_fsdp.py::test_configure_model[32-true-expected_dtype0] with fast_dev_run disabled, after patching lightning_module_state_dict so the test can proceed on FSDP 1.x:
https://github.com/Lightning-AI/lightning/blob/c83774a1093fab53fef02ae2b824dd85ee21af0a/src/lightning/pytorch/strategies/fsdp.py#L171-L179

Patch the above with:

    def lightning_module_state_dict(self) -> Dict[str, Any]:
        """Gathers the full state dict by unsharding all the parameters.

        To avoid OOM, the returned parameters will only be returned on rank 0 and on CPU. All other ranks get an
        empty dict.
        """
        from torch.distributed.fsdp import FullyShardedDataParallel

        # On PyTorch < 2.0 these classes are importable from `torch.distributed.fsdp`
        # rather than `torch.distributed.fsdp.api`.
        if _TORCH_GREATER_EQUAL_2_0:
            from torch.distributed.fsdp.api import FullStateDictConfig, StateDictType
        else:
            from torch.distributed.fsdp import FullStateDictConfig, StateDictType
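
With the patch in place and fast_dev_run disabled in the test, the single parametrized case can be invoked directly, for example as below (assuming it is run from the repository root and the 2 CUDA GPUs the test requires are available):

    # Optional convenience runner for the failing case; invoking pytest on the
    # command line with the same test id works just as well.
    import pytest

    raise SystemExit(
        pytest.main(
            ["tests/tests_pytorch/strategies/test_fsdp.py::test_configure_model[32-true-expected_dtype0]", "-v"]
        )
    )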

Error messages and logs

An example of the resulting error:

./tests/tests_pytorch/strategies/test_fsdp.py::test_configure_model[32-true-expected_dtype0] Failed: [undefined]AssertionError
precision = '32-true', expected_dtype = torch.float32

    @RunIf(min_cuda_gpus=2, skip_windows=True, standalone=False)
    @pytest.mark.parametrize(
        ("precision", "expected_dtype"),
        [
            ("32-true", torch.float32),
        ],
    )
    def test_configure_model(precision, expected_dtype):
        """Test that the module under configure_model gets moved to the right device and dtype."""
        trainer = Trainer(
            accelerator="cuda",
            devices=2,
            strategy=FSDPStrategy(auto_wrap_policy=always_wrap_policy),
            precision=precision,
            #fast_dev_run=1,
        )
    
        class MyModel(BoringModel):
            def configure_model(self):
                self.layer = torch.nn.Linear(32, 2)
                # The model is on the CPU until after `.setup()``
                # TODO: Support initialization on meta device
                expected_device = torch.device("cpu")
                assert self.layer.weight.device == expected_device
                assert self.layer.weight.dtype == expected_dtype
    
            def on_fit_start(self):
                # Parameters get sharded in `.setup()` and moved to the target device
                assert self.layer.weight.device == torch.device("cuda", self.local_rank)
                assert self.layer.weight.dtype == expected_dtype
    
        model = MyModel()
>       trainer.fit(model)

tests/tests_pytorch/strategies/test_fsdp.py:675: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
src/lightning/pytorch/trainer/trainer.py:542: in fit
    call._call_and_handle_interrupt(
src/lightning/pytorch/trainer/call.py:43: in _call_and_handle_interrupt
    return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
src/lightning/pytorch/strategies/launchers/subprocess_script.py:98: in launch
    return function(*args, **kwargs)
src/lightning/pytorch/trainer/trainer.py:581: in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
src/lightning/pytorch/trainer/trainer.py:990: in _run
    results = self._run_stage()
src/lightning/pytorch/trainer/trainer.py:1033: in _run_stage
    self.fit_loop.run()
src/lightning/pytorch/loops/fit_loop.py:203: in run
    self.on_advance_end()
src/lightning/pytorch/loops/fit_loop.py:367: in on_advance_end
    call._call_callback_hooks(trainer, "on_train_epoch_end", monitoring_callbacks=True)
src/lightning/pytorch/trainer/call.py:208: in _call_callback_hooks
    fn(trainer, trainer.lightning_module, *args, **kwargs)
src/lightning/pytorch/callbacks/model_checkpoint.py:309: in on_train_epoch_end
    self._save_topk_checkpoint(trainer, monitor_candidates)
src/lightning/pytorch/callbacks/model_checkpoint.py:368: in _save_topk_checkpoint
    self._save_none_monitor_checkpoint(trainer, monitor_candidates)
src/lightning/pytorch/callbacks/model_checkpoint.py:684: in _save_none_monitor_checkpoint
    self._save_checkpoint(trainer, filepath)
src/lightning/pytorch/callbacks/model_checkpoint.py:371: in _save_checkpoint
    trainer.save_checkpoint(filepath, self.save_weights_only)
src/lightning/pytorch/trainer/trainer.py:1373: in save_checkpoint
    checkpoint = self._checkpoint_connector.dump_checkpoint(weights_only)
src/lightning/pytorch/trainer/connectors/checkpoint_connector.py:438: in dump_checkpoint
    "state_dict": self._get_lightning_module_state_dict(),
src/lightning/pytorch/trainer/connectors/checkpoint_connector.py:496: in _get_lightning_module_state_dict
    return self.trainer.strategy.lightning_module_state_dict()
src/lightning/pytorch/strategies/fsdp.py:191: in lightning_module_state_dict
    return self.model.state_dict()
/opt/miniconda/envs/lightning_dev_pt1x/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py:2416: in state_dict
    state_dict = super().state_dict(*args, **kwargs)
/opt/miniconda/envs/lightning_dev_pt1x/lib/python3.10/site-packages/torch/nn/modules/module.py:1448: in state_dict
    module.state_dict(destination=destination, prefix=prefix + name + '.', keep_vars=keep_vars)
/opt/miniconda/envs/lightning_dev_pt1x/lib/python3.10/site-packages/torch/nn/modules/module.py:1448: in state_dict
    module.state_dict(destination=destination, prefix=prefix + name + '.', keep_vars=keep_vars)
/opt/miniconda/envs/lightning_dev_pt1x/lib/python3.10/site-packages/torch/nn/modules/module.py:1448: in state_dict
    module.state_dict(destination=destination, prefix=prefix + name + '.', keep_vars=keep_vars)
/opt/miniconda/envs/lightning_dev_pt1x/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py:2402: in state_dict
    with summon_ctx:
/opt/miniconda/envs/lightning_dev_pt1x/lib/python3.10/contextlib.py:135: in __enter__
    return next(self.gen)
/opt/miniconda/envs/lightning_dev_pt1x/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py:2981: in _summon_full_params
    free_unsharded_flat_params = [handle.needs_unshard() for handle in self._handles]
/opt/miniconda/envs/lightning_dev_pt1x/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py:2981: in <listcomp>
    free_unsharded_flat_params = [handle.needs_unshard() for handle in self._handles]
/opt/miniconda/envs/lightning_dev_pt1x/lib/python3.10/site-packages/torch/distributed/fsdp/flat_param.py:681: in needs_unshard
    unsharded_flat_param = self._get_padded_unsharded_flat_param()
/opt/miniconda/envs/lightning_dev_pt1x/lib/python3.10/site-packages/torch/distributed/fsdp/flat_param.py:714: in _get_padded_unsharded_flat_param
    p_assert(
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

cond = False, s = 'Expects full precision but got torch.float32'
raise_assertion_error = True

    def p_assert(cond: Any, s: Any, raise_assertion_error: bool = True) -> None:
        """This is used as an alternate to ``assert`` when in the backward context
        to print the error message ``s`` since otherwise, it is swallowed."""
        if not cond:
            print(s)
            traceback.print_stack()
            if raise_assertion_error:
>               raise AssertionError
E               AssertionError

/opt/miniconda/envs/lightning_dev_pt1x/lib/python3.10/site-packages/torch/distributed/fsdp/_utils.py:149: AssertionError

Environment

Current environment
  • CUDA:
    • GPU:
      • NVIDIA GeForce RTX 4090
      • NVIDIA GeForce RTX 2070 SUPER
    • available: True
    • version: 11.7
  • Lightning:
    • lightning: 2.1.0.dev0
    • lightning-cloud: 0.5.37
    • lightning-utilities: 0.9.0
    • pytorch-lightning: 2.0.6
    • torch: 1.13.1
    • torchmetrics: 0.11.4
    • torchvision: 0.14.1
  • Packages:
    • absl-py: 1.4.0
    • aiohttp: 3.8.5
    • aiosignal: 1.3.1
    • alabaster: 0.7.13
    • antlr4-python3-runtime: 4.9.3
    • anyio: 3.7.1
    • argon2-cffi: 21.3.0
    • argon2-cffi-bindings: 21.2.0
    • arrow: 1.2.3
    • asttokens: 2.2.1
    • async-lru: 2.0.4
    • async-timeout: 4.0.3
    • attrs: 23.1.0
    • babel: 2.12.1
    • backcall: 0.2.0
    • backoff: 2.2.1
    • beautifulsoup4: 4.12.2
    • black: 23.7.0
    • bleach: 6.0.0
    • blessed: 1.20.0
    • brotlipy: 0.7.0
    • cachetools: 5.3.1
    • certifi: 2023.7.22
    • cffi: 1.15.1
    • charset-normalizer: 2.0.4
    • click: 8.1.6
    • cloudpickle: 2.2.1
    • coloredlogs: 15.0.1
    • comm: 0.1.4
    • contourpy: 1.1.0
    • coverage: 7.2.7
    • croniter: 1.4.1
    • cryptography: 41.0.2
    • curio: 1.6
    • cycler: 0.11.0
    • dateutils: 0.6.12
    • debugpy: 1.6.7.post1
    • decorator: 5.1.1
    • deepdiff: 6.3.1
    • deepspeed: 0.10.0
    • defusedxml: 0.7.1
    • docstring-parser: 0.15
    • docutils: 0.20.1
    • entrypoints: 0.4
    • exceptiongroup: 1.1.2
    • executing: 1.2.0
    • fastapi: 0.101.0
    • fastjsonschema: 2.18.0
    • flatbuffers: 23.5.26
    • fonttools: 4.42.0
    • fqdn: 1.5.1
    • frozenlist: 1.4.0
    • fsspec: 2023.6.0
    • google-auth: 2.22.0
    • google-auth-oauthlib: 1.0.0
    • grpcio: 1.57.0
    • gym: 0.26.2
    • gym-notices: 0.0.8
    • h11: 0.14.0
    • hjson: 3.1.0
    • humanfriendly: 10.0
    • hydra-core: 1.3.2
    • idna: 3.4
    • imagesize: 1.4.1
    • importlib-resources: 6.0.1
    • iniconfig: 2.0.0
    • inquirer: 3.1.3
    • ipykernel: 6.25.1
    • ipyparallel: 8.6.1
    • ipython: 8.1.1
    • ipython-genutils: 0.2.0
    • ipywidgets: 8.1.0
    • isoduration: 20.11.0
    • itsdangerous: 2.1.2
    • jedi: 0.19.0
    • jinja2: 3.1.2
    • joblib: 1.3.2
    • json5: 0.9.14
    • jsonargparse: 4.22.1
    • jsonpointer: 2.4
    • jsonschema: 4.19.0
    • jsonschema-specifications: 2023.7.1
    • jupyter-client: 8.3.0
    • jupyter-core: 5.3.1
    • jupyter-events: 0.7.0
    • jupyter-lsp: 2.2.0
    • jupyter-server: 2.7.0
    • jupyter-server-terminals: 0.4.4
    • jupyterlab: 4.0.4
    • jupyterlab-pygments: 0.2.2
    • jupyterlab-server: 2.24.0
    • jupyterlab-widgets: 3.0.8
    • kiwisolver: 1.4.4
    • lightning: 2.1.0.dev0
    • lightning-cloud: 0.5.37
    • lightning-utilities: 0.9.0
    • markdown: 3.4.4
    • markdown-it-py: 3.0.0
    • markupsafe: 2.1.3
    • matplotlib: 3.7.2
    • matplotlib-inline: 0.1.6
    • mdurl: 0.1.2
    • mistune: 3.0.1
    • mkl-fft: 1.3.6
    • mkl-random: 1.2.2
    • mkl-service: 2.4.0
    • mpmath: 1.3.0
    • multidict: 6.0.4
    • mypy-extensions: 1.0.0
    • nbclient: 0.8.0
    • nbconvert: 7.7.3
    • nbformat: 5.9.2
    • nest-asyncio: 1.5.7
    • ninja: 1.11.1
    • notebook: 7.0.2
    • notebook-shim: 0.2.3
    • numpy: 1.25.2
    • oauthlib: 3.2.2
    • omegaconf: 2.3.0
    • onnx: 1.12.0
    • onnxruntime: 1.15.1
    • ordered-set: 4.1.0
    • outcome: 1.2.0
    • overrides: 7.4.0
    • packaging: 23.1
    • pandas: 2.0.3
    • pandocfilters: 1.5.0
    • parso: 0.8.3
    • pathspec: 0.11.2
    • pexpect: 4.8.0
    • pickleshare: 0.7.5
    • pillow: 9.4.0
    • pip: 23.2.1
    • platformdirs: 3.10.0
    • pluggy: 1.2.0
    • prometheus-client: 0.17.1
    • prompt-toolkit: 3.0.39
    • protobuf: 3.20.1
    • psutil: 5.9.5
    • ptyprocess: 0.7.0
    • pure-eval: 0.2.2
    • py: 1.11.0
    • py-cpuinfo: 9.0.0
    • pyasn1: 0.5.0
    • pyasn1-modules: 0.3.0
    • pycparser: 2.21
    • pydantic: 1.10.12
    • pygame: 2.1.0
    • pygments: 2.16.1
    • pyjwt: 2.8.0
    • pyopenssl: 23.2.0
    • pyparsing: 3.0.9
    • pysocks: 1.7.1
    • pytest: 7.4.0
    • pytest-asyncio: 0.21.1
    • pytest-cov: 4.1.0
    • pytest-forked: 1.4.0
    • pytest-random-order: 1.1.0
    • pytest-rerunfailures: 10.3
    • python-dateutil: 2.8.2
    • python-editor: 1.0.4
    • python-json-logger: 2.0.7
    • python-multipart: 0.0.6
    • pytorch-lightning: 2.0.6
    • pytz: 2023.3
    • pyyaml: 6.0.1
    • pyzmq: 25.1.1
    • qtconsole: 5.4.3
    • qtpy: 2.3.1
    • readchar: 4.0.5
    • referencing: 0.30.2
    • requests: 2.31.0
    • requests-oauthlib: 1.3.1
    • rfc3339-validator: 0.1.4
    • rfc3986-validator: 0.1.1
    • rich: 13.5.2
    • rpds-py: 0.9.2
    • rsa: 4.9
    • scikit-learn: 1.3.0
    • scipy: 1.11.1
    • send2trash: 1.8.2
    • setuptools: 68.0.0
    • six: 1.16.0
    • sniffio: 1.3.0
    • snowballstemmer: 2.2.0
    • sortedcontainers: 2.4.0
    • soupsieve: 2.4.1
    • sphinx: 7.1.2
    • sphinxcontrib-applehelp: 1.0.6
    • sphinxcontrib-devhelp: 1.0.4
    • sphinxcontrib-htmlhelp: 2.0.3
    • sphinxcontrib-jsmath: 1.0.1
    • sphinxcontrib-qthelp: 1.0.5
    • sphinxcontrib-serializinghtml: 1.1.7
    • stack-data: 0.6.2
    • starlette: 0.27.0
    • starsessions: 1.3.0
    • sympy: 1.12
    • tensorboard: 2.14.0
    • tensorboard-data-server: 0.7.1
    • tensorboardx: 2.6.2
    • terminado: 0.17.1
    • testpath: 0.6.0
    • threadpoolctl: 3.2.0
    • tinycss2: 1.2.1
    • tomli: 2.0.1
    • torch: 1.13.1
    • torchmetrics: 0.11.4
    • torchvision: 0.14.1
    • tornado: 6.3.2
    • tqdm: 4.66.1
    • traitlets: 5.9.0
    • trio: 0.22.2
    • typeshed-client: 2.3.0
    • typing-extensions: 4.7.1
    • tzdata: 2023.3
    • uri-template: 1.3.0
    • urllib3: 1.26.16
    • uvicorn: 0.23.2
    • wcwidth: 0.2.6
    • webcolors: 1.13
    • webencodings: 0.5.1
    • websocket-client: 1.6.1
    • websockets: 11.0.3
    • werkzeug: 2.3.6
    • wheel: 0.38.4
    • widgetsnbextension: 4.0.8
    • yarl: 1.9.2
  • System:

More info

No response

cc @awaelchli @carmocca

@speediedan added labels bug (Something isn't working) and needs triage (Waiting to be triaged by maintainers) on Aug 10, 2023
@awaelchli added label strategy: fsdp (Fully Sharded Data Parallel) and removed needs triage (Waiting to be triaged by maintainers) on Aug 13, 2023
@awaelchli added this to the 2.0.x milestone on Aug 13, 2023