Impossible to load a DataModule on CPU when checkpointed with GPU available #17945

Closed
rfruit17 opened this issue Jun 28, 2023 · 2 comments · Fixed by #17950
Labels
bug (Something isn't working), checkpointing (Related to checkpointing), ver: 2.0.x
Milestone
2.0.x

Comments


rfruit17 commented Jun 28, 2023

Bug description

It seems impossible to load a DataModule on a CPU-only machine when the checkpoint was written on a GPU-enabled machine.

What version are you seeing the problem on?

v2.0

How to reproduce the bug

Here is a minimal reproducible example of the bug I am facing (based on the "hello world" examples on the PyTorch Lightning website):

  1. In the script below, I train a simple model using a data module. The training is done on a GPU machine (torch.cuda.is_available() returns True):
from torch import nn
import torch
from torch.utils.data import random_split, DataLoader
from torch.nn import functional as F
import pytorch_lightning as pl
from torchvision.datasets import MNIST
from torchvision import transforms

class MNISTDataModule(pl.LightningDataModule):
    def __init__(self, data_dir: str = "./", batch_size: int = 32):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size

    def setup(self, stage: str):
        self.mnist_test = MNIST(self.data_dir, train=False, download=True, transform=transforms.ToTensor())
        self.mnist_predict = MNIST(self.data_dir, train=False, download=True, transform=transforms.ToTensor())
        mnist_full = MNIST(self.data_dir, train=True, download=True, transform=transforms.ToTensor())
        self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=self.batch_size)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=self.batch_size)

    def predict_dataloader(self):
        return DataLoader(self.mnist_predict, batch_size=self.batch_size)
        
class MNISTModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(28 * 28, 10)

    def forward(self, x):
        return torch.relu(self.l1(x.view(x.size(0), -1)))

    def training_step(self, batch, batch_nb):
        x, y = batch
        loss = F.cross_entropy(self(x), y)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.02)

mnist = MNISTDataModule("./")
model = MNISTModel()

trainer = pl.Trainer(
    default_root_dir="./",
    max_epochs=3,
)
trainer.fit(model, mnist)

  2. Then I load the checkpoint on a CPU-only machine (torch.cuda.is_available() returns False):
checkpoint_path = "./lightning_logs/version_0/checkpoints/epoch=2-step=5157.ckpt"

neural_net = MNISTModel.load_from_checkpoint(
    checkpoint_path,
    map_location=torch.device('cpu')
)

data_module = MNISTDataModule.load_from_checkpoint(
    checkpoint_path,
    map_location=torch.device('cpu')
)

When I omit map_location=torch.device('cpu'), I get the expected error: RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.
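
For reference, this matches standard PyTorch behaviour: the raw checkpoint can still be loaded on the CPU-only machine with plain torch.load as long as map_location is given (a minimal sketch, reusing the checkpoint_path variable from the snippet above):

import torch

# Load the raw checkpoint file on CPU; map_location remaps any CUDA storages.
checkpoint = torch.load(checkpoint_path, map_location=torch.device("cpu"))
print(list(checkpoint))  # top-level keys such as 'state_dict' and 'epoch'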

When I add map_location=torch.device('cpu'), the LightningModule is loaded correctly, but loading the data module fails with TypeError: pytorch_lightning.core.saving._load_from_checkpoint() got multiple values for keyword argument 'map_location'.

After checking the source code of LightningDataModule.load_from_checkpoint (see: https://lightning.ai/docs/pytorch/stable/_modules/lightning/pytorch/core/datamodule.html#LightningDataModule.load_from_checkpoint), I see that the problem comes from map_location being hard-coded to None in the internal call, which is what triggers the duplicate-keyword error:

loaded = _load_from_checkpoint(
    cls,
    checkpoint_path,
    map_location=None,
    hparams_file=hparams_file,
    strict=None,
    **kwargs,
)

Would it be possible to add map_location as an argument to LightningDataModule.load_from_checkpoint to fix this bug?
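
For completeness, a possible interim workaround (untested here, and relying on the private pytorch_lightning.core.saving._load_from_checkpoint helper quoted above, so it may break between releases) is to call that helper directly and pass map_location yourself:

import torch
from pytorch_lightning.core.saving import _load_from_checkpoint

# Bypass LightningDataModule.load_from_checkpoint so that map_location
# is not hard-coded to None before it reaches the helper.
data_module = _load_from_checkpoint(
    MNISTDataModule,
    checkpoint_path,
    map_location=torch.device("cpu"),
)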

Error messages and logs

TypeError: pytorch_lightning.core.saving._load_from_checkpoint() got multiple values for keyword argument 'map_location'

Environment

Current environment
- PyTorch version: 2.0.1
- PyTorch Lightning version: 2.0.4

More info

No response

cc @awaelchli

@rfruit17 rfruit17 added bug Something isn't working needs triage Waiting to be triaged by maintainers labels Jun 28, 2023

rjarun8 (Contributor) commented Jun 29, 2023

The source code of the DataModule confirms that the map_location argument is indeed set to None in the _load_from_checkpoint function call within the load_from_checkpoint method. This is the root cause of the issue.

The proposed solution would involve modifying the load_from_checkpoint method of the DataModule to accept map_location as an argument and pass it to the _load_from_checkpoint function. This would allow the user to specify the device to which the storage should be mapped when loading the checkpoint.

@classmethod
def load_from_checkpoint(
    cls,
    checkpoint_path: Union[_PATH, IO],
    hparams_file: Optional[_PATH] = None,
    map_location=None,  # Added this line
    **kwargs: Any,
) -> Self:
    r"""
    ...
    """
    loaded = _load_from_checkpoint(
        cls,
        checkpoint_path,
        map_location=map_location,  # Modified this line
        hparams_file=hparams_file,
        strict=None,
        **kwargs,
    )
    return cast(Self, loaded)
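
If a change along those lines is accepted, the failing call from the original report should work unchanged (shown again here, assuming the patched signature above):

data_module = MNISTDataModule.load_from_checkpoint(
    checkpoint_path,
    map_location=torch.device("cpu"),
)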

rfruit17 (Author) commented

That would be awesome!

@awaelchli awaelchli added checkpointing Related to checkpointing and removed needs triage Waiting to be triaged by maintainers labels Jul 1, 2023
@awaelchli awaelchli added this to the 2.0.x milestone Jul 1, 2023