Skip to content

Commit d02009a

Browse files
Fix saving relative symlink for ModelCheckpoint callback (#19303)
Co-authored-by: awaelchli <[email protected]>
1 parent e89f46a commit d02009a

File tree

3 files changed

+25
-1
lines changed

3 files changed

+25
-1
lines changed

src/lightning/pytorch/CHANGELOG.md

+3
Original file line number | Diff line number | Diff line change
@@ -74,6 +74,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
7474
- Fixed warning for Dataloader if `num_workers=1` and CPU count is 1 ([#19224](https://github.com/Lightning-AI/lightning/pull/19224))
7575

7676

77+
- Fixed an issue with the ModelCheckpoint callback not saving relative symlinks with `ModelCheckpoint(save_last="link")` ([#19303](https://github.com/Lightning-AI/lightning/pull/19303))
78+
79+
7780
## [2.1.3] - 2023-12-21
7881

7982
### Changed

src/lightning/pytorch/callbacks/model_checkpoint.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -403,7 +403,7 @@ def _link_checkpoint(trainer: "pl.Trainer", filepath: str, linkpath: str) -> Non
403403
elif os.path.isdir(linkpath):
404404
shutil.rmtree(linkpath)
405405
try:
406-
os.symlink(filepath, linkpath)
406+
os.symlink(os.path.relpath(filepath, os.path.dirname(linkpath)), linkpath)
407407
except OSError:
408408
# on Windows, special permissions are required to create symbolic links as a regular user
409409
# fall back to copying the file

tests/tests_pytorch/checkpointing/test_model_checkpoint.py

+21
Original file line number | Diff line number | Diff line change
@@ -534,20 +534,23 @@ def test_model_checkpoint_link_checkpoint(tmp_path):
534534
ModelCheckpoint._link_checkpoint(trainer, filepath=str(file), linkpath=str(link))
535535
assert os.path.islink(link)
536536
assert os.path.realpath(link) == str(file)
537+
assert not os.path.isabs(os.readlink(link))
537538

538539
# link exists (is a file)
539540
new_file1 = tmp_path / "new_file1"
540541
new_file1.touch()
541542
ModelCheckpoint._link_checkpoint(trainer, filepath=str(new_file1), linkpath=str(link))
542543
assert os.path.islink(link)
543544
assert os.path.realpath(link) == str(new_file1)
545+
assert not os.path.isabs(os.readlink(link))
544546

545547
# link exists (is a link)
546548
new_file2 = tmp_path / "new_file2"
547549
new_file2.touch()
548550
ModelCheckpoint._link_checkpoint(trainer, filepath=str(new_file2), linkpath=str(link))
549551
assert os.path.islink(link)
550552
assert os.path.realpath(link) == str(new_file2)
553+
assert not os.path.isabs(os.readlink(link))
551554

552555
# link exists (is a folder)
553556
folder = tmp_path / "folder"
@@ -557,13 +560,15 @@ def test_model_checkpoint_link_checkpoint(tmp_path):
557560
ModelCheckpoint._link_checkpoint(trainer, filepath=str(folder), linkpath=str(folder_link))
558561
assert os.path.islink(folder_link)
559562
assert os.path.realpath(folder_link) == str(folder)
563+
assert not os.path.isabs(os.readlink(folder_link))
560564

561565
# link exists (is a link to a folder)
562566
new_folder = tmp_path / "new_folder"
563567
new_folder.mkdir()
564568
ModelCheckpoint._link_checkpoint(trainer, filepath=str(new_folder), linkpath=str(folder_link))
565569
assert os.path.islink(folder_link)
566570
assert os.path.realpath(folder_link) == str(new_folder)
571+
assert not os.path.isabs(os.readlink(folder_link))
567572

568573
# simulate permission error on Windows (creation of symbolic links requires privileges)
569574
file = tmp_path / "win_file"
@@ -575,6 +580,22 @@ def test_model_checkpoint_link_checkpoint(tmp_path):
575580
assert os.path.isfile(link) # fall back to copying instead of linking
576581

577582

583+
def test_model_checkpoint_link_checkpoint_relative_path(tmp_path, monkeypatch):
584+
"""Test that linking a checkpoint works with relative paths."""
585+
trainer = Mock()
586+
monkeypatch.chdir(tmp_path)
587+
588+
folder = Path("x/z/z")
589+
folder.mkdir(parents=True)
590+
file = folder / "file"
591+
file.touch()
592+
link = folder / "link"
593+
ModelCheckpoint._link_checkpoint(trainer, filepath=str(file.absolute()), linkpath=str(link.absolute()))
594+
assert os.path.islink(link)
595+
assert Path(os.readlink(link)) == file.relative_to(folder)
596+
assert not os.path.isabs(os.readlink(link))
597+
598+
578599
def test_invalid_top_k(tmpdir):
579600
"""Make sure that a MisconfigurationException is raised for a negative save_top_k argument."""
580601
with pytest.raises(MisconfigurationException, match=r".*Must be >= -1"):

0 commit comments

Comments
 (0)