Support compiling a module after it was set up by Fabric #17529

Merged · 10 commits · May 3, 2023
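For orientation, here is a minimal, hypothetical sketch of the workflow this PR enables: calling `torch.compile` on a module *after* `fabric.setup()` has wrapped it. The model, optimizer, and data are placeholders chosen for the example, not taken from the PR.

```python
import torch
from lightning.fabric import Fabric

# Minimal sketch (assumes torch >= 2.0); model, optimizer, and data are illustrative only.
fabric = Fabric(accelerator="cpu", devices=1)
fabric.launch()

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
model, optimizer = fabric.setup(model, optimizer)

# Compile *after* Fabric has set up the module -- the case this PR adds support for.
model = torch.compile(model)

x = torch.randn(2, 4)
loss = model(x).sum()
fabric.backward(loss)
optimizer.step()
```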
12 changes: 10 additions & 2 deletions src/lightning/fabric/fabric.py
@@ -48,7 +48,13 @@
from lightning.fabric.utilities.seed import seed_everything
from lightning.fabric.utilities.types import ReduceOp
from lightning.fabric.utilities.warnings import PossibleUserWarning
from lightning.fabric.wrappers import _FabricDataLoader, _FabricModule, _FabricOptimizer, _unwrap_objects
from lightning.fabric.wrappers import (
_FabricDataLoader,
_FabricModule,
_FabricOptimizer,
_unwrap_compiled,
_unwrap_objects,
)


def _do_nothing(*_: Any) -> None:
@@ -547,6 +553,7 @@ def no_backward_sync(self, module: _FabricModule, enabled: bool = True) -> Gener
enabled: Whether the context manager is enabled or not. ``True`` means skip the sync, ``False`` means do not
skip.
"""
module = _unwrap_compiled(module)
if not isinstance(module, _FabricModule):
raise TypeError(
"You need to set up the model first before you can call `self.no_backward_sync()`:"
@@ -638,7 +645,8 @@ def load(
# We need to unwrap objects (see above) but this creates a new dictionary. In-place updates
# (for user metadata) wouldn't show up in the original dict, so we need to copy the data back.
for k in list(unwrapped_state.keys()):
if isinstance(state[k], (_FabricModule, _FabricOptimizer, _FabricDataLoader)):
obj = _unwrap_compiled(state[k])
if isinstance(obj, (_FabricModule, _FabricOptimizer, _FabricDataLoader)):
continue
state[k] = unwrapped_state[k]
return remainder
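The comment above explains why plain (unwrapped) entries are copied back into the caller's dictionary after loading. A small illustrative round trip; the file name and epoch values are placeholders for the example only:

```python
import torch
from lightning.fabric import Fabric

# Illustrative only; "demo.ckpt" and the epoch values are placeholders.
fabric = Fabric(accelerator="cpu", devices=1)
model = fabric.setup(torch.nn.Linear(2, 2))

fabric.save("demo.ckpt", {"model": model, "epoch": 3})

state = {"model": model, "epoch": 0}
fabric.load("demo.ckpt", state)
assert state["epoch"] == 3      # plain metadata is copied back into the original dict
assert state["model"] is model  # wrapped objects stay put; their weights are loaded in place
```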
28 changes: 25 additions & 3 deletions src/lightning/fabric/wrappers.py
@@ -28,6 +28,7 @@
from lightning.fabric.utilities import move_data_to_device
from lightning.fabric.utilities.data import _set_sampler_epoch
from lightning.fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0
from lightning.fabric.utilities.types import Optimizable
from lightning.fabric.utilities.warnings import PossibleUserWarning

@@ -218,15 +219,35 @@ def _unwrap_objects(collection: Any) -> Any:
def _unwrap(
obj: Union[_FabricModule, _FabricOptimizer, _FabricDataLoader]
) -> Union[nn.Module, Optimizer, DataLoader]:
if isinstance(obj, _FabricModule):
return obj._forward_module
if isinstance(unwrapped := _unwrap_compiled(obj), _FabricModule):
return unwrapped._forward_module
if isinstance(obj, _FabricOptimizer):
return obj.optimizer
if isinstance(obj, _FabricDataLoader):
return obj._dataloader
return obj

return apply_to_collection(collection, dtype=(_FabricModule, _FabricOptimizer, _FabricDataLoader), function=_unwrap)
types = [_FabricModule, _FabricOptimizer, _FabricDataLoader]
if _TORCH_GREATER_EQUAL_2_0:
from torch._dynamo import OptimizedModule

types.append(OptimizedModule)

return apply_to_collection(collection, dtype=tuple(types), function=_unwrap)


def _unwrap_compiled(obj: Any) -> Any:
"""Removes the :class:`torch._dynamo.OptimizedModule` around the object if it is wrapped.

Use this function before instance checks against e.g. :class:`_FabricModule`.
"""
if not _TORCH_GREATER_EQUAL_2_0:
return obj
from torch._dynamo import OptimizedModule

if isinstance(obj, OptimizedModule):
return obj._orig_mod
return obj


def is_wrapped(obj: object) -> bool:
@@ -239,4 +260,5 @@ def is_wrapped(obj: object) -> bool:
Args:
obj: The object to test.
"""
obj = _unwrap_compiled(obj)
return isinstance(obj, (_FabricModule, _FabricOptimizer, _FabricDataLoader))
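To illustrate the new helper, a short sketch of `_unwrap_compiled` and the updated `is_wrapped` check, mirroring the tests below (requires torch >= 2.0 for `torch.compile`; the `Mock()` stands in for the precision plugin, as in the tests):

```python
import torch
from unittest.mock import Mock

from lightning.fabric.wrappers import _FabricModule, _unwrap_compiled, is_wrapped

# Sketch mirroring the tests below; requires torch >= 2.0.
module = torch.nn.Linear(2, 2)
fabric_module = _FabricModule(module, Mock())  # Mock() stands in for the precision plugin
compiled = torch.compile(fabric_module)

assert _unwrap_compiled(compiled) is fabric_module  # strips the OptimizedModule layer
assert is_wrapped(compiled)                         # instance check now sees through torch.compile
```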
51 changes: 49 additions & 2 deletions tests/tests_fabric/test_wrappers.py
@@ -23,7 +23,7 @@
from lightning.fabric.fabric import Fabric
from lightning.fabric.plugins import Precision
from lightning.fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
from lightning.fabric.wrappers import _FabricDataLoader, _FabricModule, _FabricOptimizer, is_wrapped
from lightning.fabric.wrappers import _FabricDataLoader, _FabricModule, _FabricOptimizer, _unwrap_objects, is_wrapped
from tests_fabric.helpers.runif import RunIf


@@ -358,7 +358,8 @@ def zero_grad(self, set_grads_to_None=False):
custom_zero_grad.assert_called_with(set_grads_to_None=False)


def test_is_wrapped():
@pytest.mark.parametrize("compile", [False, pytest.param(True, marks=RunIf(dynamo=True))])
def test_is_wrapped(compile):
"""Test that the `is_wrapped` utility recognizes when an object was wrapped by Fabric."""
assert not is_wrapped(None)

@@ -368,6 +369,15 @@ def test_is_wrapped():
wrapped = _FabricModule(module, Mock())
assert is_wrapped(wrapped)

# _FabricModule inside an OptimizedModule
if compile:
from torch._dynamo import OptimizedModule

module = torch.nn.Linear(2, 2)
wrapped = torch.compile(_FabricModule(module, Mock()))
assert isinstance(wrapped, OptimizedModule)
assert is_wrapped(wrapped)

# _FabricOptimizer
optimizer = torch.optim.Adam(module.parameters())
assert not is_wrapped(optimizer)
@@ -381,6 +391,43 @@
assert is_wrapped(wrapped)


@pytest.mark.parametrize("compile", [False, pytest.param(True, marks=RunIf(dynamo=True))])
def test_unwrap_objects(compile):
# empty container
assert _unwrap_objects({}) == {}

# container with pure objects and wrapped objects
module = torch.nn.Linear(1, 1)
wrapped_module = _FabricModule(module, Mock())
if compile:
wrapped_module = torch.compile(wrapped_module)
optimizer = torch.optim.Adam(module.parameters())
wrapped_optimizer = _FabricOptimizer(optimizer, Mock())
dataloader = DataLoader([1, 2, 3])
wrapped_dataloader = _FabricDataLoader(dataloader)
container = {
"int": 1,
"module": module,
"wrapped_module": wrapped_module,
"optimizer": optimizer,
"wrapped_optimizer": wrapped_optimizer,
"dataloader": dataloader,
"wrapped_dataloader": wrapped_dataloader,
"nested": [module, wrapped_module, optimizer, wrapped_optimizer, dataloader, wrapped_dataloader],
}
expected = {
"int": 1,
"module": module,
"wrapped_module": wrapped_module._forward_module,
"optimizer": optimizer,
"wrapped_optimizer": optimizer,
"dataloader": dataloader,
"wrapped_dataloader": dataloader,
"nested": [module, wrapped_module._forward_module, optimizer, optimizer, dataloader, dataloader],
}
assert _unwrap_objects(container) == expected


def test_step_method_redirection():
"""Test that the FabricModule redirects the special `LightningModule.*_step` methods through the forward-
module."""