diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index c7d286e1ed032..50c531f75afb6 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -17,6 +17,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed a potential bug with uploading model checkpoints to Neptune.ai by uploading files from stream ([#17430](https://github.com/Lightning-AI/lightning/pull/17430)) +- Fixed signature inspection of decorated hooks ([#17507](https://github.com/Lightning-AI/lightning/pull/17507)) + + ## [2.0.2] - 2023-04-24 diff --git a/src/lightning/pytorch/utilities/signature_utils.py b/src/lightning/pytorch/utilities/signature_utils.py index bed5e2f18a31b..0f41c5948fb46 100644 --- a/src/lightning/pytorch/utilities/signature_utils.py +++ b/src/lightning/pytorch/utilities/signature_utils.py @@ -25,6 +25,9 @@ def is_param_in_hook_signature( explicit: whether the parameter has to be explicitly declared min_args: whether the `signature` has at least `min_args` parameters """ + if hasattr(hook_fx, "__wrapped__"): + # in case the hook has a decorator + hook_fx = hook_fx.__wrapped__ parameters = inspect.getfullargspec(hook_fx) args = parameters.args[1:] # ignore `self` return ( diff --git a/tests/tests_pytorch/utilities/test_signature_utils.py b/tests/tests_pytorch/utilities/test_signature_utils.py new file mode 100644 index 0000000000000..7ad1c5855b50c --- /dev/null +++ b/tests/tests_pytorch/utilities/test_signature_utils.py @@ -0,0 +1,36 @@ +import torch + +from lightning.pytorch.utilities.signature_utils import is_param_in_hook_signature + + +def test_param_in_hook_signature(): + class LightningModule: + def validation_step(self, dataloader_iter, batch_idx): + ... + + model = LightningModule() + assert is_param_in_hook_signature(model.validation_step, "dataloader_iter", explicit=True) + + class LightningModule: + @torch.no_grad() + def validation_step(self, dataloader_iter, batch_idx): + ... + + model = LightningModule() + assert is_param_in_hook_signature(model.validation_step, "dataloader_iter", explicit=True) + + class LightningModule: + def validation_step(self, *args): + ... + + model = LightningModule() + assert not is_param_in_hook_signature(model.validation_step, "dataloader_iter", explicit=True) + assert is_param_in_hook_signature(model.validation_step, "dataloader_iter", explicit=False) + + class LightningModule: + def validation_step(self, a, b): + ... + + model = LightningModule() + assert not is_param_in_hook_signature(model.validation_step, "dataloader_iter", min_args=3) + assert is_param_in_hook_signature(model.validation_step, "dataloader_iter", min_args=2)