validation_step method with the signature breaks with decorators #17505

Closed
yaoyu-33 opened this issue Apr 27, 2023 · 2 comments · Fixed by #17507
Labels: bug, hooks, ver: 2.0.x

yaoyu-33 commented Apr 27, 2023

Bug description

    @torch.no_grad()
    def validation_step(self, dataloader_iter, batch_idx):

This breaks detection of the method signature: the first argument passed to validation_step is a batch, not an iterator.

What version are you seeing the problem on?

v2_0

How to reproduce the bug

import torch
import pytorch_lightning as pl
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

class MNISTClassifier(pl.LightningModule):
    def __init__(self):
        super(MNISTClassifier, self).__init__()
        self.layer = nn.Sequential(
            nn.Flatten(),
            nn.Linear(28 * 28, 64),
            nn.ReLU(),
            nn.Linear(64, 10)
        )

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        x = x.cuda()
        y = y.cuda()
        logits = self(x)
        loss = nn.functional.cross_entropy(logits, y)
        self.log('train_loss', loss)
        return loss

    @torch.no_grad()
    def validation_step(self, dataloader_iter, batch_idx):
        x, y = next(dataloader_iter)
        x = x.cuda()
        y = y.cuda()
        logits = self(x)
        loss = nn.functional.cross_entropy(logits, y)
        acc = (torch.argmax(logits, dim=1) == y).float().mean()
        metrics = {'val_loss': loss, 'val_acc': acc}
        self.log_dict(metrics)
        return metrics

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

def mnist_data_loader():
    transform = transforms.Compose([transforms.ToTensor()])
    train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    val_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)
    return train_loader, val_loader

def main():
    pl.seed_everything(42)
    model = MNISTClassifier()
    train_loader, val_loader = mnist_data_loader()
    trainer = pl.Trainer(max_epochs=5, accelerator="gpu", devices=1)
    trainer.fit(model, train_loader, val_loader)

if __name__ == '__main__':
    main()

Error messages and logs

/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/trainer.py in fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    518         model = _maybe_unwrap_optimized(model)
    519         self.strategy._lightning_module = model
--> 520         call._call_and_handle_interrupt(
    521             self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
    522         )

/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/call.py in _call_and_handle_interrupt(trainer, trainer_fn, *args, **kwargs)
     42             return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
     43         else:
---> 44             return trainer_fn(*args, **kwargs)
     45 
     46     except _TunerExitException:

/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/trainer.py in _fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    557             model_connected=self.lightning_module is not None,
    558         )
--> 559         self._run(model, ckpt_path=ckpt_path)
    560 
    561         assert self.state.stopped

/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/trainer.py in _run(self, model, ckpt_path)
    933         # RUN THE TRAINER
    934         # ----------------------------
--> 935         results = self._run_stage()
    936 
    937         # ----------------------------

/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/trainer.py in _run_stage(self)
    974         if self.training:
    975             with isolate_rng():
--> 976                 self._run_sanity_check()
    977             with torch.autograd.set_detect_anomaly(self._detect_anomaly):
    978                 self.fit_loop.run()

/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/trainer.py in _run_sanity_check(self)
   1003 
   1004             # run eval step
-> 1005             val_loop.run()
   1006 
   1007             call._call_callback_hooks(self, "on_sanity_check_end")

/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/utilities.py in _decorator(self, *args, **kwargs)
    175             context_manager = torch.no_grad
    176         with context_manager():
--> 177             return loop_run(self, *args, **kwargs)
    178 
    179     return _decorator

/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/evaluation_loop.py in run(self)
    113                 previous_dataloader_idx = dataloader_idx
    114                 # run step hooks
--> 115                 self._evaluation_step(batch, batch_idx, dataloader_idx)
    116             except StopIteration:
    117                 # this needs to wrap the `*_step` call too (not just `next`) for `dataloader_iter` support

/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/evaluation_loop.py in _evaluation_step(self, batch, batch_idx, dataloader_idx)
    373 
    374         hook_name = "test_step" if trainer.testing else "validation_step"
--> 375         output = call._call_strategy_hook(trainer, hook_name, *step_kwargs.values())
    376 
    377         self.batch_progress.increment_processed()

/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/call.py in _call_strategy_hook(trainer, hook_name, *args, **kwargs)
    286 
    287     with trainer.profiler.profile(f"[Strategy]{trainer.strategy.__class__.__name__}.{hook_name}"):
--> 288         output = fn(*args, **kwargs)
    289 
    290     # restore current_fx when nested context

/usr/local/lib/python3.10/dist-packages/pytorch_lightning/strategies/strategy.py in validation_step(self, *args, **kwargs)
    376         with self.precision_plugin.val_step_context():
    377             assert isinstance(self.model, ValidationStep)
--> 378             return self.model.validation_step(*args, **kwargs)
    379 
    380     def test_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:

/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py in decorate_context(*args, **kwargs)
    113     def decorate_context(*args, **kwargs):
    114         with ctx_factory():
--> 115             return func(*args, **kwargs)
    116 
    117     return decorate_context

<ipython-input-8-5ec322be3273> in validation_step(self, dataloader_iter, batch_idx)
     29     @torch.no_grad()
     30     def validation_step(self, dataloader_iter, batch_idx):
---> 31         x, y = next(dataloader_iter)
     32         x = x.cuda()
     33         y = y.cuda()

TypeError: 'list' object is not an iterator
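
One plausible way to see the failure, sketched with the standard library only. The traceback shows `torch.no_grad()` replacing the method with a `decorate_context(*args, **kwargs)` wrapper. If the hook's parameters are read with an inspection call that does not follow `__wrapped__` (assumed here to be `inspect.getfullargspec`; the thread does not confirm which call Lightning uses), the name `dataloader_iter` is no longer visible, which would explain why a batch is passed instead of an iterator. `fake_no_grad` and `with_context` are hypothetical stand-ins, and the wrapper is assumed to use `functools.wraps`.

```python
import functools
import inspect
from contextlib import contextmanager


@contextmanager
def fake_no_grad():
    # Hypothetical stand-in for torch.no_grad(); it does nothing here.
    yield


def with_context(ctx_factory):
    # Mimics the wrapper shape visible in the traceback:
    # decorate_context(*args, **kwargs) calling func inside a context.
    # Assumption: the wrapper is built with functools.wraps.
    def decorator(func):
        @functools.wraps(func)
        def decorate_context(*args, **kwargs):
            with ctx_factory():
                return func(*args, **kwargs)
        return decorate_context
    return decorator


class Plain:
    def validation_step(self, dataloader_iter, batch_idx):
        return next(dataloader_iter)


class Decorated:
    @with_context(fake_no_grad)
    def validation_step(self, dataloader_iter, batch_idx):
        return next(dataloader_iter)


def declared_args(fn):
    # getfullargspec does not follow __wrapped__, so it describes the wrapper.
    return inspect.getfullargspec(fn).args


def declared_params(fn):
    # inspect.signature follows __wrapped__ by default.
    return list(inspect.signature(fn).parameters)


plain_args = declared_args(Plain.validation_step)
decorated_args = declared_args(Decorated.validation_step)
decorated_params = declared_params(Decorated.validation_step)

print(plain_args)        # named parameters of the undecorated method
print(decorated_args)    # what the wrapper exposes to getfullargspec
print(decorated_params)  # what inspect.signature recovers via __wrapped__
```

In this sketch only `inspect.signature` still reports `dataloader_iter` for the decorated method, because it unwraps the function before reading its parameters.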

Environment

Current environment
  • CUDA:
    • GPU:
      • Tesla T4
    • available: True
    • version: 11.8
  • Lightning:
    • lightning-utilities: 0.8.0
    • pytorch-lightning: 2.0.2
    • torch: 2.0.0+cu118
    • torchaudio: 2.0.1+cu118
    • torchdata: 0.6.0
    • torchmetrics: 0.11.4
    • torchsummary: 1.5.1
    • torchtext: 0.15.1
    • torchvision: 0.15.1+cu118
  • Packages:
    • absl-py: 1.4.0
    • aiohttp: 3.8.4
    • aiosignal: 1.3.1
    • alabaster: 0.7.13
    • albumentations: 1.2.1
    • altair: 4.2.2
    • anyio: 3.6.2
    • appdirs: 1.4.4
    • argon2-cffi: 21.3.0
    • argon2-cffi-bindings: 21.2.0
    • arviz: 0.15.1
    • astropy: 5.2.2
    • astunparse: 1.6.3
    • async-timeout: 4.0.2
    • attrs: 23.1.0
    • audioread: 3.0.0
    • autograd: 1.5
    • babel: 2.12.1
    • backcall: 0.2.0
    • beautifulsoup4: 4.11.2
    • bleach: 6.0.0
    • blis: 0.7.9
    • blosc2: 2.0.0
    • bokeh: 2.4.3
    • branca: 0.6.0
    • cachecontrol: 0.12.11
    • cached-property: 1.5.2
    • cachetools: 5.3.0
    • catalogue: 2.0.8
    • certifi: 2022.12.7
    • cffi: 1.15.1
    • chardet: 4.0.0
    • charset-normalizer: 2.0.12
    • chex: 0.1.7
    • click: 8.1.3
    • cloudpickle: 2.2.1
    • cmake: 3.25.2
    • cmdstanpy: 1.1.0
    • colorcet: 3.0.1
    • colorlover: 0.3.0
    • community: 1.0.0b1
    • confection: 0.0.4
    • cons: 0.4.5
    • contextlib2: 0.6.0.post1
    • contourpy: 1.0.7
    • convertdate: 2.4.0
    • cryptography: 40.0.2
    • cufflinks: 0.17.3
    • cupy-cuda11x: 11.0.0
    • cvxopt: 1.3.0
    • cvxpy: 1.3.1
    • cycler: 0.11.0
    • cymem: 2.0.7
    • cython: 0.29.34
    • dask: 2022.12.1
    • datascience: 0.17.6
    • db-dtypes: 1.1.1
    • dbus-python: 1.2.16
    • debugpy: 1.6.6
    • decorator: 4.4.2
    • defusedxml: 0.7.1
    • distributed: 2022.12.1
    • dlib: 19.24.1
    • dm-tree: 0.1.8
    • docutils: 0.16
    • dopamine-rl: 4.0.6
    • duckdb: 0.7.1
    • earthengine-api: 0.1.350
    • easydict: 1.10
    • ecos: 2.0.12
    • editdistance: 0.6.2
    • en-core-web-sm: 3.5.0
    • entrypoints: 0.4
    • ephem: 4.1.4
    • et-xmlfile: 1.1.0
    • etils: 1.2.0
    • etuples: 0.3.8
    • exceptiongroup: 1.1.1
    • fastai: 2.7.12
    • fastcore: 1.5.29
    • fastdownload: 0.0.7
    • fastjsonschema: 2.16.3
    • fastprogress: 1.0.3
    • fastrlock: 0.8.1
    • filelock: 3.12.0
    • firebase-admin: 5.3.0
    • flask: 2.2.4
    • flatbuffers: 23.3.3
    • flax: 0.6.9
    • folium: 0.14.0
    • fonttools: 4.39.3
    • frozendict: 2.3.7
    • frozenlist: 1.3.3
    • fsspec: 2023.4.0
    • future: 0.18.3
    • gast: 0.4.0
    • gdal: 3.3.2
    • gdown: 4.6.6
    • gensim: 4.3.1
    • geographiclib: 2.0
    • geopy: 2.3.0
    • gin-config: 0.5.0
    • glob2: 0.7
    • google: 2.0.3
    • google-api-core: 2.11.0
    • google-api-python-client: 2.84.0
    • google-auth: 2.17.3
    • google-auth-httplib2: 0.1.0
    • google-auth-oauthlib: 1.0.0
    • google-cloud-bigquery: 3.9.0
    • google-cloud-bigquery-storage: 2.19.1
    • google-cloud-core: 2.3.2
    • google-cloud-datastore: 2.15.1
    • google-cloud-firestore: 2.11.0
    • google-cloud-language: 2.9.1
    • google-cloud-storage: 2.8.0
    • google-cloud-translate: 3.11.1
    • google-colab: 1.0.0
    • google-crc32c: 1.5.0
    • google-pasta: 0.2.0
    • google-resumable-media: 2.5.0
    • googleapis-common-protos: 1.59.0
    • googledrivedownloader: 0.4
    • graphviz: 0.20.1
    • greenlet: 2.0.2
    • grpcio: 1.54.0
    • grpcio-status: 1.48.2
    • gspread: 3.4.2
    • gspread-dataframe: 3.0.8
    • gym: 0.25.2
    • gym-notices: 0.0.8
    • h5netcdf: 1.1.0
    • h5py: 3.8.0
    • hijri-converter: 2.3.1
    • holidays: 0.23
    • holoviews: 1.15.4
    • html5lib: 1.1
    • httpimport: 1.3.0
    • httplib2: 0.21.0
    • humanize: 4.6.0
    • hyperopt: 0.2.7
    • idna: 3.4
    • imageio: 2.25.1
    • imageio-ffmpeg: 0.4.8
    • imagesize: 1.4.1
    • imbalanced-learn: 0.10.1
    • imgaug: 0.4.0
    • importlib-metadata: 6.6.0
    • importlib-resources: 5.12.0
    • imutils: 0.5.4
    • inflect: 6.0.4
    • iniconfig: 2.0.0
    • intel-openmp: 2023.1.0
    • ipykernel: 5.5.6
    • ipython: 7.34.0
    • ipython-genutils: 0.2.0
    • ipython-sql: 0.4.1
    • ipywidgets: 7.7.1
    • itsdangerous: 2.1.2
    • jax: 0.4.8
    • jaxlib: 0.4.7+cuda11.cudnn86
    • jieba: 0.42.1
    • jinja2: 3.1.2
    • joblib: 1.2.0
    • jsonpickle: 3.0.1
    • jsonschema: 4.3.3
    • jupyter-client: 6.1.12
    • jupyter-console: 6.1.0
    • jupyter-core: 5.3.0
    • jupyter-server: 1.24.0
    • jupyterlab-pygments: 0.2.2
    • jupyterlab-widgets: 3.0.7
    • kaggle: 1.5.13
    • keras: 2.12.0
    • keras-vis: 0.4.1
    • kiwisolver: 1.4.4
    • korean-lunar-calendar: 0.3.1
    • langcodes: 3.3.0
    • lazy-loader: 0.2
    • libclang: 16.0.0
    • librosa: 0.10.0.post2
    • lightgbm: 3.3.5
    • lightning-utilities: 0.8.0
    • lit: 16.0.2
    • llvmlite: 0.39.1
    • locket: 1.0.0
    • logical-unification: 0.4.5
    • lunarcalendar: 0.0.9
    • lxml: 4.9.2
    • markdown: 3.4.3
    • markdown-it-py: 2.2.0
    • markupsafe: 2.1.2
    • matplotlib: 3.7.1
    • matplotlib-inline: 0.1.6
    • matplotlib-venn: 0.11.9
    • mdurl: 0.1.2
    • minikanren: 1.0.3
    • missingno: 0.5.2
    • mistune: 0.8.4
    • mizani: 0.8.1
    • mkl: 2019.0
    • ml-dtypes: 0.1.0
    • mlxtend: 0.14.0
    • more-itertools: 9.1.0
    • moviepy: 1.0.3
    • mpmath: 1.3.0
    • msgpack: 1.0.5
    • multidict: 6.0.4
    • multipledispatch: 0.6.0
    • multitasking: 0.0.11
    • murmurhash: 1.0.9
    • music21: 8.1.0
    • natsort: 8.3.1
    • nbclient: 0.7.4
    • nbconvert: 6.5.4
    • nbformat: 5.8.0
    • nest-asyncio: 1.5.6
    • networkx: 3.1
    • nibabel: 3.0.2
    • nltk: 3.8.1
    • notebook: 6.4.8
    • numba: 0.56.4
    • numexpr: 2.8.4
    • numpy: 1.22.4
    • oauth2client: 4.1.3
    • oauthlib: 3.2.2
    • opencv-contrib-python: 4.7.0.72
    • opencv-python: 4.7.0.72
    • opencv-python-headless: 4.7.0.72
    • openpyxl: 3.0.10
    • opt-einsum: 3.3.0
    • optax: 0.1.5
    • orbax-checkpoint: 0.2.1
    • osqp: 0.6.2.post8
    • packaging: 23.1
    • palettable: 3.3.3
    • pandas: 1.5.3
    • pandas-datareader: 0.10.0
    • pandas-gbq: 0.17.9
    • pandocfilters: 1.5.0
    • panel: 0.14.4
    • param: 1.13.0
    • parso: 0.8.3
    • partd: 1.4.0
    • pathlib: 1.0.1
    • pathy: 0.10.1
    • patsy: 0.5.3
    • pep517: 0.13.0
    • pexpect: 4.8.0
    • pickleshare: 0.7.5
    • pillow: 8.4.0
    • pip: 23.0.1
    • pip-tools: 6.6.2
    • platformdirs: 3.3.0
    • plotly: 5.13.1
    • plotnine: 0.10.1
    • pluggy: 1.0.0
    • polars: 0.17.3
    • pooch: 1.6.0
    • portpicker: 1.3.9
    • prefetch-generator: 1.0.3
    • preshed: 3.0.8
    • prettytable: 0.7.2
    • proglog: 0.1.10
    • progressbar2: 4.2.0
    • prometheus-client: 0.16.0
    • promise: 2.3
    • prompt-toolkit: 3.0.38
    • prophet: 1.1.2
    • proto-plus: 1.22.2
    • protobuf: 3.20.3
    • psutil: 5.9.5
    • psycopg2: 2.9.6
    • ptyprocess: 0.7.0
    • py-cpuinfo: 9.0.0
    • py4j: 0.10.9.7
    • pyarrow: 9.0.0
    • pyasn1: 0.5.0
    • pyasn1-modules: 0.3.0
    • pycocotools: 2.0.6
    • pycparser: 2.21
    • pyct: 0.5.0
    • pydantic: 1.10.7
    • pydata-google-auth: 1.7.0
    • pydot: 1.4.2
    • pydot-ng: 2.0.0
    • pydotplus: 2.0.2
    • pydrive: 1.3.1
    • pyerfa: 2.0.0.3
    • pygame: 2.3.0
    • pygments: 2.14.0
    • pygobject: 3.36.0
    • pymc: 5.1.2
    • pymeeus: 0.5.12
    • pymystem3: 0.2.0
    • pyopengl: 3.1.6
    • pyparsing: 3.0.9
    • pyrsistent: 0.19.3
    • pysocks: 1.7.1
    • pytensor: 2.10.1
    • pytest: 7.2.2
    • python-apt: 0.0.0
    • python-dateutil: 2.8.2
    • python-louvain: 0.16
    • python-slugify: 8.0.1
    • python-utils: 3.5.2
    • pytorch-lightning: 2.0.2
    • pytz: 2022.7.1
    • pytz-deprecation-shim: 0.1.0.post0
    • pyviz-comms: 2.2.1
    • pywavelets: 1.4.1
    • pyyaml: 6.0
    • pyzmq: 23.2.1
    • qdldl: 0.1.7
    • qudida: 0.0.4
    • regex: 2022.10.31
    • requests: 2.27.1
    • requests-oauthlib: 1.3.1
    • requests-unixsocket: 0.2.0
    • rich: 13.3.4
    • rpy2: 3.5.5
    • rsa: 4.9
    • scikit-image: 0.19.3
    • scikit-learn: 1.2.2
    • scipy: 1.10.1
    • scs: 3.2.3
    • seaborn: 0.12.2
    • send2trash: 1.8.0
    • setuptools: 67.7.2
    • shapely: 2.0.1
    • six: 1.16.0
    • sklearn-pandas: 2.2.0
    • smart-open: 6.3.0
    • sniffio: 1.3.0
    • snowballstemmer: 2.2.0
    • sortedcontainers: 2.4.0
    • soundfile: 0.12.1
    • soupsieve: 2.4.1
    • soxr: 0.3.5
    • spacy: 3.5.2
    • spacy-legacy: 3.0.12
    • spacy-loggers: 1.0.4
    • sphinx: 3.5.4
    • sphinxcontrib-applehelp: 1.0.4
    • sphinxcontrib-devhelp: 1.0.2
    • sphinxcontrib-htmlhelp: 2.0.1
    • sphinxcontrib-jsmath: 1.0.1
    • sphinxcontrib-qthelp: 1.0.3
    • sphinxcontrib-serializinghtml: 1.1.5
    • sqlalchemy: 2.0.10
    • sqlparse: 0.4.4
    • srsly: 2.4.6
    • statsmodels: 0.13.5
    • sympy: 1.11.1
    • tables: 3.8.0
    • tabulate: 0.8.10
    • tblib: 1.7.0
    • tenacity: 8.2.2
    • tensorboard: 2.12.2
    • tensorboard-data-server: 0.7.0
    • tensorboard-plugin-wit: 1.8.1
    • tensorflow: 2.12.0
    • tensorflow-datasets: 4.8.3
    • tensorflow-estimator: 2.12.0
    • tensorflow-gcs-config: 2.12.0
    • tensorflow-hub: 0.13.0
    • tensorflow-io-gcs-filesystem: 0.32.0
    • tensorflow-metadata: 1.13.1
    • tensorflow-probability: 0.19.0
    • tensorstore: 0.1.36
    • termcolor: 2.3.0
    • terminado: 0.17.1
    • text-unidecode: 1.3
    • textblob: 0.17.1
    • tf-slim: 1.1.0
    • thinc: 8.1.9
    • threadpoolctl: 3.1.0
    • tifffile: 2023.4.12
    • tinycss2: 1.2.1
    • toml: 0.10.2
    • tomli: 2.0.1
    • toolz: 0.12.0
    • torch: 2.0.0+cu118
    • torchaudio: 2.0.1+cu118
    • torchdata: 0.6.0
    • torchmetrics: 0.11.4
    • torchsummary: 1.5.1
    • torchtext: 0.15.1
    • torchvision: 0.15.1+cu118
    • tornado: 6.2
    • tqdm: 4.65.0
    • traitlets: 5.7.1
    • triton: 2.0.0
    • tweepy: 4.13.0
    • typer: 0.7.0
    • typing-extensions: 4.5.0
    • tzdata: 2023.3
    • tzlocal: 4.3
    • uritemplate: 4.1.1
    • urllib3: 1.26.15
    • vega-datasets: 0.9.0
    • wasabi: 1.1.1
    • wcwidth: 0.2.6
    • webcolors: 1.13
    • webencodings: 0.5.1
    • websocket-client: 1.5.1
    • werkzeug: 2.3.0
    • wheel: 0.40.0
    • widgetsnbextension: 3.6.4
    • wordcloud: 1.8.2.2
    • wrapt: 1.14.1
    • xarray: 2022.12.0
    • xarray-einstats: 0.5.1
    • xgboost: 1.7.5
    • xlrd: 2.0.1
    • yarl: 1.9.2
    • yellowbrick: 1.5
    • yfinance: 0.2.18
    • zict: 3.0.0
    • zipp: 3.15.0
  • System:
    • OS: Linux
    • architecture:
      • 64bit
      • ELF
    • processor: x86_64
    • python: 3.10.11
    • version: #1 SMP Sat Dec 10 16:00:40 UTC 2022

More info

No response

cc @carmocca @awaelchli @Borda @justusschock

yaoyu-33 added the bug and needs triage labels Apr 27, 2023
carmocca added the hooks label and removed the needs triage label Apr 28, 2023
carmocca added this to the v1.9.x milestone Apr 28, 2023

carmocca (Contributor) commented

Opened #17507 with a fix.

However, note that wrapping validation_step with this context manager is not required. The Trainer disables the training mode automatically when validation starts.
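
A minimal sketch of the workaround this implies, assuming you still want an explicit context for some reason: leave `validation_step` undecorated and enter the context inside the body, so the declared parameters stay directly inspectable. This is consistent with the `_decorator` frame in the traceback, which already runs the loop inside `with context_manager():`. `fake_no_grad` and the `grad_enabled` flag are hypothetical stand-ins for `torch.no_grad()` and the autograd state, used only to keep the example free of dependencies.

```python
import inspect
from contextlib import contextmanager

grad_enabled = True  # hypothetical flag standing in for autograd state


@contextmanager
def fake_no_grad():
    # Hypothetical stand-in for torch.no_grad(): flips the flag, then restores it.
    global grad_enabled
    previous = grad_enabled
    grad_enabled = False
    try:
        yield
    finally:
        grad_enabled = previous


class Model:
    # No decorator, so the method keeps its own named parameters.
    def validation_step(self, dataloader_iter, batch_idx):
        with fake_no_grad():
            batch = next(dataloader_iter)
            return batch, grad_enabled


args = inspect.getfullargspec(Model.validation_step).args
batch, grad_inside = Model().validation_step(iter([(1, 2)]), 0)

print(args)          # the declared parameter names, including dataloader_iter
print(grad_inside)   # flag value observed inside the context
print(grad_enabled)  # flag value after the context exits
```

With the real library the body would use `with torch.no_grad():` instead, or simply drop the context, since the comment above says it is not required.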

yaoyu-33 (Author) commented

Hi, yeah, adding the context isn't required. I adapted the code from an open-source project that had this context, and I wasn't aware it would cause this issue.
