diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index b3248b265eb3a..240c942c6644b 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -83,6 +83,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed an issue that caused `num_nodes` not to be set correctly for `FSDPStrategy` ([#17438](https://github.com/Lightning-AI/lightning/pull/17438)) +- Fixed signature inspection of decorated hooks ([#17507](https://github.com/Lightning-AI/lightning/pull/17507)) + + ## [2.0.1.post0] - 2023-04-11 ### Fixed diff --git a/src/lightning/pytorch/utilities/signature_utils.py b/src/lightning/pytorch/utilities/signature_utils.py index bed5e2f18a31b..0f41c5948fb46 100644 --- a/src/lightning/pytorch/utilities/signature_utils.py +++ b/src/lightning/pytorch/utilities/signature_utils.py @@ -25,6 +25,9 @@ def is_param_in_hook_signature( explicit: whether the parameter has to be explicitly declared min_args: whether the `signature` has at least `min_args` parameters """ + if hasattr(hook_fx, "__wrapped__"): + # in case the hook has a decorator + hook_fx = hook_fx.__wrapped__ parameters = inspect.getfullargspec(hook_fx) args = parameters.args[1:] # ignore `self` return ( diff --git a/tests/tests_pytorch/utilities/test_signature_utils.py b/tests/tests_pytorch/utilities/test_signature_utils.py new file mode 100644 index 0000000000000..7ad1c5855b50c --- /dev/null +++ b/tests/tests_pytorch/utilities/test_signature_utils.py @@ -0,0 +1,36 @@ +import torch + +from lightning.pytorch.utilities.signature_utils import is_param_in_hook_signature + + +def test_param_in_hook_signature(): + class LightningModule: + def validation_step(self, dataloader_iter, batch_idx): + ... + + model = LightningModule() + assert is_param_in_hook_signature(model.validation_step, "dataloader_iter", explicit=True) + + class LightningModule: + @torch.no_grad() + def validation_step(self, dataloader_iter, batch_idx): + ... + + model = LightningModule() + assert is_param_in_hook_signature(model.validation_step, "dataloader_iter", explicit=True) + + class LightningModule: + def validation_step(self, *args): + ... + + model = LightningModule() + assert not is_param_in_hook_signature(model.validation_step, "dataloader_iter", explicit=True) + assert is_param_in_hook_signature(model.validation_step, "dataloader_iter", explicit=False) + + class LightningModule: + def validation_step(self, a, b): + ... + + model = LightningModule() + assert not is_param_in_hook_signature(model.validation_step, "dataloader_iter", min_args=3) + assert is_param_in_hook_signature(model.validation_step, "dataloader_iter", min_args=2)