Skip to content

Commit

Permalink
Support inspecting the signature of decorated hooks (#17507)
Browse files Browse the repository at this point in the history
  • Loading branch information
carmocca authored Apr 28, 2023
1 parent 82aa7b4 commit 3867045
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 0 deletions.
3 changes: 3 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed an issue that caused `num_nodes` not to be set correctly for `FSDPStrategy` ([#17438](https://github.com/Lightning-AI/lightning/pull/17438))


- Fixed signature inspection of decorated hooks ([#17507](https://github.com/Lightning-AI/lightning/pull/17507))


## [2.0.1.post0] - 2023-04-11

### Fixed
Expand Down
3 changes: 3 additions & 0 deletions src/lightning/pytorch/utilities/signature_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ def is_param_in_hook_signature(
explicit: whether the parameter has to be explicitly declared
min_args: whether the `signature` has at least `min_args` parameters
"""
if hasattr(hook_fx, "__wrapped__"):
# in case the hook has a decorator
hook_fx = hook_fx.__wrapped__
parameters = inspect.getfullargspec(hook_fx)
args = parameters.args[1:] # ignore `self`
return (
Expand Down
36 changes: 36 additions & 0 deletions tests/tests_pytorch/utilities/test_signature_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import torch

from lightning.pytorch.utilities.signature_utils import is_param_in_hook_signature


def test_param_in_hook_signature():
class LightningModule:
def validation_step(self, dataloader_iter, batch_idx):
...

model = LightningModule()
assert is_param_in_hook_signature(model.validation_step, "dataloader_iter", explicit=True)

class LightningModule:
@torch.no_grad()
def validation_step(self, dataloader_iter, batch_idx):
...

model = LightningModule()
assert is_param_in_hook_signature(model.validation_step, "dataloader_iter", explicit=True)

class LightningModule:
def validation_step(self, *args):
...

model = LightningModule()
assert not is_param_in_hook_signature(model.validation_step, "dataloader_iter", explicit=True)
assert is_param_in_hook_signature(model.validation_step, "dataloader_iter", explicit=False)

class LightningModule:
def validation_step(self, a, b):
...

model = LightningModule()
assert not is_param_in_hook_signature(model.validation_step, "dataloader_iter", min_args=3)
assert is_param_in_hook_signature(model.validation_step, "dataloader_iter", min_args=2)

0 comments on commit 3867045

Please sign in to comment.