diff --git a/src/lightning/fabric/fabric.py b/src/lightning/fabric/fabric.py
index 6e76a92f7cd13..5ed0fceb3ba1c 100644
--- a/src/lightning/fabric/fabric.py
+++ b/src/lightning/fabric/fabric.py
@@ -48,7 +48,13 @@
 from lightning.fabric.utilities.seed import seed_everything
 from lightning.fabric.utilities.types import ReduceOp
 from lightning.fabric.utilities.warnings import PossibleUserWarning
-from lightning.fabric.wrappers import _FabricDataLoader, _FabricModule, _FabricOptimizer, _unwrap_objects
+from lightning.fabric.wrappers import (
+    _FabricDataLoader,
+    _FabricModule,
+    _FabricOptimizer,
+    _unwrap_compiled,
+    _unwrap_objects,
+)
 
 
 def _do_nothing(*_: Any) -> None:
@@ -547,6 +553,7 @@ def no_backward_sync(self, module: _FabricModule, enabled: bool = True) -> Gener
             enabled: Whether the context manager is enabled or not. ``True`` means skip the sync, ``False`` means do
                 not skip.
         """
+        module = _unwrap_compiled(module)
         if not isinstance(module, _FabricModule):
             raise TypeError(
                 "You need to set up the model first before you can call `self.no_backward_sync()`:"
@@ -638,7 +645,8 @@ def load(
         # We need to unwrap objects (see above) but this creates a new dictionary. In-place updates
         # (for user metadata) wouldn't show up in the original dict, so we need to copy the data back.
         for k in list(unwrapped_state.keys()):
-            if isinstance(state[k], (_FabricModule, _FabricOptimizer, _FabricDataLoader)):
+            obj = _unwrap_compiled(state[k])
+            if isinstance(obj, (_FabricModule, _FabricOptimizer, _FabricDataLoader)):
                 continue
             state[k] = unwrapped_state[k]
         return remainder
diff --git a/src/lightning/fabric/wrappers.py b/src/lightning/fabric/wrappers.py
index 96ee7aedea205..947510244d0df 100644
--- a/src/lightning/fabric/wrappers.py
+++ b/src/lightning/fabric/wrappers.py
@@ -28,6 +28,7 @@
 from lightning.fabric.utilities import move_data_to_device
 from lightning.fabric.utilities.data import _set_sampler_epoch
 from lightning.fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0
 from lightning.fabric.utilities.types import Optimizable
 from lightning.fabric.utilities.warnings import PossibleUserWarning
 
@@ -218,15 +219,35 @@ def _unwrap_objects(collection: Any) -> Any:
     def _unwrap(
         obj: Union[_FabricModule, _FabricOptimizer, _FabricDataLoader]
     ) -> Union[nn.Module, Optimizer, DataLoader]:
-        if isinstance(obj, _FabricModule):
-            return obj._forward_module
+        if isinstance(unwrapped := _unwrap_compiled(obj), _FabricModule):
+            return unwrapped._forward_module
         if isinstance(obj, _FabricOptimizer):
             return obj.optimizer
         if isinstance(obj, _FabricDataLoader):
             return obj._dataloader
         return obj
 
-    return apply_to_collection(collection, dtype=(_FabricModule, _FabricOptimizer, _FabricDataLoader), function=_unwrap)
+    types = [_FabricModule, _FabricOptimizer, _FabricDataLoader]
+    if _TORCH_GREATER_EQUAL_2_0:
+        from torch._dynamo import OptimizedModule
+
+        types.append(OptimizedModule)
+
+    return apply_to_collection(collection, dtype=tuple(types), function=_unwrap)
+
+
+def _unwrap_compiled(obj: Any) -> Any:
+    """Removes the :class:`torch._dynamo.OptimizedModule` around the object if it is wrapped.
+
+    Use this function before instance checks against e.g. :class:`_FabricModule`.
+    """
+    if not _TORCH_GREATER_EQUAL_2_0:
+        return obj
+    from torch._dynamo import OptimizedModule
+
+    if isinstance(obj, OptimizedModule):
+        return obj._orig_mod
+    return obj
 
 
 def is_wrapped(obj: object) -> bool:
@@ -239,4 +260,5 @@ def is_wrapped(obj: object) -> bool:
     Args:
         obj: The object to test.
     """
+    obj = _unwrap_compiled(obj)
     return isinstance(obj, (_FabricModule, _FabricOptimizer, _FabricDataLoader))
diff --git a/tests/tests_fabric/test_wrappers.py b/tests/tests_fabric/test_wrappers.py
index 0996627e57a15..72ce305506304 100644
--- a/tests/tests_fabric/test_wrappers.py
+++ b/tests/tests_fabric/test_wrappers.py
@@ -23,7 +23,7 @@
 from lightning.fabric.fabric import Fabric
 from lightning.fabric.plugins import Precision
 from lightning.fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
-from lightning.fabric.wrappers import _FabricDataLoader, _FabricModule, _FabricOptimizer, is_wrapped
+from lightning.fabric.wrappers import _FabricDataLoader, _FabricModule, _FabricOptimizer, _unwrap_objects, is_wrapped
 from tests_fabric.helpers.runif import RunIf
 
 
@@ -358,7 +358,8 @@ def zero_grad(self, set_grads_to_None=False):
     custom_zero_grad.assert_called_with(set_grads_to_None=False)
 
 
-def test_is_wrapped():
+@pytest.mark.parametrize("compile", [False, pytest.param(True, marks=RunIf(dynamo=True))])
+def test_is_wrapped(compile):
     """Test that the `is_wrapped` utility recognizes when an object was wrapped by Fabric."""
     assert not is_wrapped(None)
 
@@ -368,6 +369,15 @@ def test_is_wrapped(compile):
     wrapped = _FabricModule(module, Mock())
     assert is_wrapped(wrapped)
 
+    # _FabricModule inside an OptimizedModule
+    if compile:
+        from torch._dynamo import OptimizedModule
+
+        module = torch.nn.Linear(2, 2)
+        wrapped = torch.compile(_FabricModule(module, Mock()))
+        assert isinstance(wrapped, OptimizedModule)
+        assert is_wrapped(wrapped)
+
     # _FabricOptimizer
     optimizer = torch.optim.Adam(module.parameters())
     assert not is_wrapped(optimizer)
@@ -381,6 +391,43 @@
     assert is_wrapped(wrapped)
 
 
+@pytest.mark.parametrize("compile", [False, pytest.param(True, marks=RunIf(dynamo=True))])
+def test_unwrap_objects(compile):
+    # empty container
+    assert _unwrap_objects({}) == {}
+
+    # container with pure objects and wrapped objects
+    module = torch.nn.Linear(1, 1)
+    wrapped_module = _FabricModule(module, Mock())
+    if compile:
+        wrapped_module = torch.compile(wrapped_module)
+    optimizer = torch.optim.Adam(module.parameters())
+    wrapped_optimizer = _FabricOptimizer(optimizer, Mock())
+    dataloader = DataLoader([1, 2, 3])
+    wrapped_dataloader = _FabricDataLoader(dataloader)
+    container = {
+        "int": 1,
+        "module": module,
+        "wrapped_module": wrapped_module,
+        "optimizer": optimizer,
+        "wrapped_optimizer": wrapped_optimizer,
+        "dataloader": dataloader,
+        "wrapped_dataloader": wrapped_dataloader,
+        "nested": [module, wrapped_module, optimizer, wrapped_optimizer, dataloader, wrapped_dataloader],
+    }
+    expected = {
+        "int": 1,
+        "module": module,
+        "wrapped_module": wrapped_module._forward_module,
+        "optimizer": optimizer,
+        "wrapped_optimizer": optimizer,
+        "dataloader": dataloader,
+        "wrapped_dataloader": dataloader,
+        "nested": [module, wrapped_module._forward_module, optimizer, optimizer, dataloader, dataloader],
+    }
+    assert _unwrap_objects(container) == expected
+
+
 def test_step_method_redirection():
     """Test that the FabricModule redirects the special `LightningModule.*_step` methods through the forward-
     module."""
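Not part of the patch: a minimal usage sketch of the behavior this change enables, assuming torch >= 2.0 (dynamo available) and a single-process CPU Fabric. It shows `is_wrapped` and the new `_unwrap_compiled` helper looking through the `OptimizedModule` wrapper that `torch.compile` places around a module set up by Fabric; the patch applies the same unwrapping inside `Fabric.no_backward_sync()` and `Fabric.load()`.

    import torch
    from lightning.fabric import Fabric
    from lightning.fabric.wrappers import _unwrap_compiled, is_wrapped

    fabric = Fabric(accelerator="cpu", devices=1)
    fabric.launch()

    # fabric.setup() returns a _FabricModule; torch.compile() then wraps it in an OptimizedModule.
    model = fabric.setup(torch.nn.Linear(2, 2))
    compiled = torch.compile(model)

    # Before this patch, is_wrapped() only inspected the outermost object and returned False here.
    assert is_wrapped(compiled)

    # _unwrap_compiled() strips the OptimizedModule and returns the underlying _FabricModule.
    assert _unwrap_compiled(compiled) is model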