Resume training from last.ckpt #19251

Closed · Karesto opened this issue Jan 9, 2024 · 3 comments

Labels: bug (Something isn't working), checkpointing (Related to checkpointing), ver: 2.1.x


Karesto commented Jan 9, 2024

Bug description

I was wondering what the purpose of the save_last parameter of the ModelCheckpoint callback is.
I assume it is meant to provide a last.ckpt file that you can always refer to; this file is a symlink pointing to the most recently saved checkpoint.
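(For context, a minimal sketch of the kind of configuration being described; the arguments mirror the ones used later in this thread:)

from lightning.pytorch.callbacks import ModelCheckpoint

# save_last asks the callback to maintain a last.ckpt alongside the top-k
# checkpoints; with the symlink behaviour described above, it points at the
# most recently written checkpoint file.
checkpoint_cb = ModelCheckpoint(monitor="train_loss", mode="min", save_top_k=5, save_last="link")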

Given that, I cannot load last.ckpt:

Traceback (most recent call last):
  File "/home/****/****/main.py", line 110, in <module>
    trainer.fit(model(model_params), dataloader, ckpt_path = path)
  File "/home/****/miniconda3/envs/training/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 544, in fit
    call._call_and_handle_interrupt(
  File "/home/****/miniconda3/envs/training/lib/python3.11/site-packages/lightning/pytorch/trainer/call.py", line 44, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/****/miniconda3/envs/training/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 580, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/home/****/miniconda3/envs/training/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 955, in _run
    self._checkpoint_connector._restore_modules_and_callbacks(ckpt_path)
  File "/home/****/miniconda3/envs/training/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/checkpoint_connector.py", line 395, in _restore_modules_and_callbacks
    self.resume_start(checkpoint_path)
  File "/home/****/miniconda3/envs/training/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/checkpoint_connector.py", line 79, in resume_start
    loaded_checkpoint = self.trainer.strategy.load_checkpoint(checkpoint_path)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/****/miniconda3/envs/training/lib/python3.11/site-packages/lightning/pytorch/strategies/strategy.py", line 359, in load_checkpoint
    return self.checkpoint_io.load_checkpoint(checkpoint_path)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/****/miniconda3/envs/training/lib/python3.11/site-packages/lightning/fabric/plugins/io/torch_io.py", line 77, in load_checkpoint
    raise FileNotFoundError(f"Checkpoint file not found: {path}")
FileNotFoundError: Checkpoint file not found: /home/****/****/lightning_logs/2.0.0/checkpoints/last.ckpt

However, the last.ckpt file does exist, and so does the checkpoint it points to.
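A quick way to tell whether the link entry or its target is the missing piece (a diagnostic sketch added for clarity; the path is the relative tail of the one in the traceback):

import os

path = "lightning_logs/2.0.0/checkpoints/last.ckpt"
print(os.path.islink(path))   # is last.ckpt actually a symlink?
if os.path.islink(path):
    print(os.readlink(path))  # where does the link point?
print(os.path.lexists(path))  # True if the link entry itself exists
print(os.path.exists(path))   # False would mean the link target is missing (a dangling link)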

What version are you seeing the problem on?

v2.1

How to reproduce the bug

path = "path/to/last.ckpt"
    trainer = pl.Trainer(**Training_args)
    trainer.fit(model(model_params), dataloader, ckpt_path = path)

Error messages and logs

No response

Environment

No response

More info

No response

cc @awaelchli

Karesto added the labels bug (Something isn't working) and needs triage (Waiting to be triaged by maintainers) on Jan 9, 2024
awaelchli (Contributor) commented Jan 18, 2024

@Karesto Thanks for reporting this. I haven't been able to reproduce it. Would you be able to modify this example so that it breaks as described in your case?

import torch
from lightning.pytorch import LightningModule, Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from torch.utils.data import DataLoader, Dataset


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
model = BoringModel()
trainer = Trainer(max_epochs=1, callbacks=ModelCheckpoint(dirpath="checkpoints", save_last="link"))
trainer.fit(model, train_dataloaders=train_data)
trainer = Trainer(max_epochs=2)
trainer.fit(model, train_dataloaders=train_data, ckpt_path="checkpoints/last.ckpt")

That would be very helpful, thank you.

awaelchli added the labels repro needed (The issue is missing a reproducible example) and checkpointing (Related to checkpointing), and removed needs triage (Waiting to be triaged by maintainers), on Jan 18, 2024
Karesto (Author) commented Jan 22, 2024

Hi,
While I have not been able to get a minimal reproduction of the bug, this is the code I am using:

import torch
from torch.utils.data import DataLoader, Dataset
from lightning.pytorch import LightningModule, Trainer
from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor
from lightning.pytorch.loggers.tensorboard import TensorBoardLogger

logger = TensorBoardLogger(".", version="3.3.0")


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


lr_monitor = LearningRateMonitor(logging_interval="step")

Training_args = {
    "accelerator": "auto",
    "strategy": "auto",
    "devices": [1],
    "num_nodes": 1,
    "logger": logger,
    "callbacks": [
        ModelCheckpoint(monitor="train_loss", save_last="link", mode="min", save_top_k=5),
        lr_monitor,
    ],
    "fast_dev_run": False,
    "max_epochs": 10,
    "min_epochs": 5,
    "detect_anomaly": False,
}

train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
model = BoringModel()
trainer = Trainer(**Training_args)
trainer.fit(model, train_data)
trainer = Trainer(**Training_args)
trainer.fit(model, train_data, ckpt_path="lightning_logs/3.3.0/checkpoints/last.ckpt")

I cannot get last.ckpt to be a symlink, whether I use True or "link" as the save_last argument of ModelCheckpoint.
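(One quick way to verify this, as an illustrative snippet using the checkpoint path from the code above:)

import os

# With save_last="link" this should print True; per the observation above, it does not here.
print(os.path.islink("lightning_logs/3.3.0/checkpoints/last.ckpt"))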

awaelchli (Contributor) commented:
@Karesto OK, I was able to reproduce this on Lightning version 2.1.3. Thanks for providing the code.
The good news is that this has already been fixed by #19303.
I verified this by taking your code example and running it against the commit in #19303 and the one before it.

The fix will come in the next patch release (2.1.4).
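Until the patch release is out, a possible workaround (an untested sketch on my part; the checkpoint filename below is illustrative) is to resume from the concrete checkpoint file rather than from the last.ckpt link:

# Workaround sketch: point ckpt_path at the real checkpoint file instead of
# the possibly broken last.ckpt link. Replace the filename with whatever the
# ModelCheckpoint callback actually wrote to the checkpoints directory.
trainer.fit(model, train_data, ckpt_path="lightning_logs/3.3.0/checkpoints/epoch=9-step=320.ckpt")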

awaelchli removed the repro needed (The issue is missing a reproducible example) label on Jan 25, 2024