`validation_step` method with the `dataloader_iter` signature breaks with decorators #17505

yaoyu-33 opened this issue Apr 27, 2023 · 2 comments · Fixed by #17507

bug Something isn't working hooks Related to the hooks API ver: 2.0.x


yaoyu-33 commented Apr 27, 2023

Bug description

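Decorating a `validation_step` that uses the `dataloader_iter` signature, for example:

    @torch.no_grad()
    def validation_step(self, dataloader_iter, batch_idx):

breaks detection of the method signature: the first argument is then an already-fetched batch (a list), not an iterator.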

What version are you seeing the problem on?


How to reproduce the bug

import torch
import pytorch_lightning as pl
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

class MNISTClassifier(pl.LightningModule):
    """Small MNIST classifier used to reproduce the decorated-``validation_step`` bug.

    ``validation_step`` takes ``(dataloader_iter, batch_idx)`` and is wrapped in
    ``@torch.no_grad()``. With the decorator in place the evaluation loop hands
    the step an already-fetched batch (a list) instead of the iterator, so
    ``next(dataloader_iter)`` raises ``TypeError: 'list' object is not an iterator``.
    """

    def __init__(self):
        super().__init__()
        self.layer = nn.Sequential(
            # MNIST samples arrive as (N, 1, 28, 28); flatten them to (N, 784)
            # so they match the first Linear layer's input size.
            nn.Flatten(),
            nn.Linear(28 * 28, 64),
            # NOTE(review): non-linearity reconstructed; two stacked Linear
            # layers without one collapse into a single linear map.
            nn.ReLU(),
            nn.Linear(64, 10),
        )

    def forward(self, x):
        """Return the class logits for a batch of images."""
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        """Compute the cross-entropy loss for one batch and log it as ``train_loss``."""
        x, y = batch
        # Explicit device moves, as in the original report; requires a CUDA device.
        x = x.cuda()
        y = y.cuda()
        logits = self(x)
        loss = nn.functional.cross_entropy(logits, y)
        self.log('train_loss', loss)
        return loss

    # This decorator is the trigger for the reported bug and is kept on purpose:
    # it is what makes the reproduction fail.
    @torch.no_grad()
    def validation_step(self, dataloader_iter, batch_idx):
        """Pull one batch from ``dataloader_iter`` and return its loss and accuracy.

        Returns:
            dict with ``val_loss`` (cross-entropy) and ``val_acc`` (mean accuracy).
        """
        x, y = next(dataloader_iter)
        x = x.cuda()
        y = y.cuda()
        logits = self(x)
        loss = nn.functional.cross_entropy(logits, y)
        acc = (torch.argmax(logits, dim=1) == y).float().mean()
        metrics = {'val_loss': loss, 'val_acc': acc}
        return metrics

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with a learning rate of 1e-3."""
        return torch.optim.Adam(self.parameters(), lr=1e-3)

def mnist_data_loader():
    """Build the MNIST train and validation ``DataLoader`` objects.

    Both splits are stored under ``./data`` (downloaded if missing) and
    converted to tensors. Batches hold 64 samples; only the training loader
    shuffles.

    Returns:
        A ``(train_loader, val_loader)`` tuple.
    """
    to_tensor = transforms.Compose([transforms.ToTensor()])
    # Arguments shared by both splits; only the `train` flag differs.
    common = dict(root='./data', download=True, transform=to_tensor)
    train_split = datasets.MNIST(train=True, **common)
    val_split = datasets.MNIST(train=False, **common)
    return (
        DataLoader(train_split, batch_size=64, shuffle=True),
        DataLoader(val_split, batch_size=64, shuffle=False),
    )

def main():
    """Fit the reproduction model on MNIST for five epochs on a single GPU."""
    model = MNISTClassifier()
    train_loader, val_loader = mnist_data_loader()
    trainer = pl.Trainer(max_epochs=5, accelerator="gpu", devices=1)
    # Positional order follows `fit(model, train_dataloaders, val_dataloaders, ...)`.
    trainer.fit(model, train_loader, val_loader)

# Run the reproduction only when executed as a script, not on import.
if __name__ == '__main__':
    main()

Error messages and logs

[/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/trainer.py](https://localhost:8080/#) in fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    518         model = _maybe_unwrap_optimized(model)
    519         self.strategy._lightning_module = model
--> 520         call._call_and_handle_interrupt(
    521             self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
    522         )

[/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/call.py](https://localhost:8080/#) in _call_and_handle_interrupt(trainer, trainer_fn, *args, **kwargs)
     42             return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
     43         else:
---> 44             return trainer_fn(*args, **kwargs)
     46     except _TunerExitException:

[/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/trainer.py](https://localhost:8080/#) in _fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    557             model_connected=self.lightning_module is not None,
    558         )
--> 559         self._run(model, ckpt_path=ckpt_path)
    561         assert self.state.stopped

[/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/trainer.py](https://localhost:8080/#) in _run(self, model, ckpt_path)
    933         # RUN THE TRAINER
    934         # ----------------------------
--> 935         results = self._run_stage()
    937         # ----------------------------

[/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/trainer.py](https://localhost:8080/#) in _run_stage(self)
    974         if self.training:
    975             with isolate_rng():
--> 976                 self._run_sanity_check()
    977             with torch.autograd.set_detect_anomaly(self._detect_anomaly):

[/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/trainer.py](https://localhost:8080/#) in _run_sanity_check(self)
   1004             # run eval step
-> 1005             val_loop.run()
   1007             call._call_callback_hooks(self, "on_sanity_check_end")

[/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/utilities.py](https://localhost:8080/#) in _decorator(self, *args, **kwargs)
    175             context_manager = torch.no_grad
    176         with context_manager():
--> 177             return loop_run(self, *args, **kwargs)
    179     return _decorator

[/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/evaluation_loop.py](https://localhost:8080/#) in run(self)
    113                 previous_dataloader_idx = dataloader_idx
    114                 # run step hooks
--> 115                 self._evaluation_step(batch, batch_idx, dataloader_idx)
    116             except StopIteration:
    117                 # this needs to wrap the `*_step` call too (not just `next`) for `dataloader_iter` support

[/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/evaluation_loop.py](https://localhost:8080/#) in _evaluation_step(self, batch, batch_idx, dataloader_idx)
    374         hook_name = "test_step" if trainer.testing else "validation_step"
--> 375         output = call._call_strategy_hook(trainer, hook_name, *step_kwargs.values())
    377         self.batch_progress.increment_processed()

[/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/call.py](https://localhost:8080/#) in _call_strategy_hook(trainer, hook_name, *args, **kwargs)
    287     with trainer.profiler.profile(f"[Strategy]{trainer.strategy.__class__.__name__}.{hook_name}"):
--> 288         output = fn(*args, **kwargs)
    290     # restore current_fx when nested context

[/usr/local/lib/python3.10/dist-packages/pytorch_lightning/strategies/strategy.py](https://localhost:8080/#) in validation_step(self, *args, **kwargs)
    376         with self.precision_plugin.val_step_context():
    377             assert isinstance(self.model, ValidationStep)
--> 378             return self.model.validation_step(*args, **kwargs)
    380     def test_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:

[/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py](https://localhost:8080/#) in decorate_context(*args, **kwargs)
    113     def decorate_context(*args, **kwargs):
    114         with ctx_factory():
--> 115             return func(*args, **kwargs)
    117     return decorate_context

[<ipython-input-8-5ec322be3273>](https://localhost:8080/#) in validation_step(self, dataloader_iter, batch_idx)
     29     @torch.no_grad()
     30     def validation_step(self, dataloader_iter, batch_idx):
---> 31         x, y = next(dataloader_iter)
     32         x = x.cuda()
     33         y = y.cuda()

TypeError: 'list' object is not an iterator


Current environment
More info

No response

Opened #17507 with a fix.

However, note that wrapping `validation_step` with the `torch.no_grad()` context manager is not required. The Trainer already runs the validation loop with gradients disabled and takes the model out of training mode automatically when validation starts.

Hi, yes, adding the context manager isn't required. I adapted the code from an open-source project that used it, and I wasn't aware it would cause this issue.

