Skip to content

Commit 9c76a14

Browse files
awaelchlishenmishajing
authored andcommitted
Fix saving relative symlink for ModelCheckpoint callback (#19303)
Co-authored-by: shenmishajing <[email protected]>
1 parent f3f23d3 commit 9c76a14

File tree

2 files changed

+22
-1
lines changed

2 files changed

+22
-1
lines changed

src/lightning/pytorch/callbacks/model_checkpoint.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -390,7 +390,7 @@ def _link_checkpoint(trainer: "pl.Trainer", filepath: str, linkpath: str) -> Non
390390
elif os.path.isdir(linkpath):
391391
shutil.rmtree(linkpath)
392392
try:
393-
os.symlink(filepath, linkpath)
393+
os.symlink(os.path.relpath(filepath, os.path.dirname(linkpath)), linkpath)
394394
except OSError:
395395
# on Windows, special permissions are required to create symbolic links as a regular user
396396
# fall back to copying the file

tests/tests_pytorch/checkpointing/test_model_checkpoint.py

+21
Original file line numberDiff line numberDiff line change
@@ -524,20 +524,23 @@ def test_model_checkpoint_link_checkpoint(tmp_path):
524524
ModelCheckpoint._link_checkpoint(trainer, filepath=str(file), linkpath=str(link))
525525
assert os.path.islink(link)
526526
assert os.path.realpath(link) == str(file)
527+
assert not os.path.isabs(os.readlink(link))
527528

528529
# link exists (is a file)
529530
new_file1 = tmp_path / "new_file1"
530531
new_file1.touch()
531532
ModelCheckpoint._link_checkpoint(trainer, filepath=str(new_file1), linkpath=str(link))
532533
assert os.path.islink(link)
533534
assert os.path.realpath(link) == str(new_file1)
535+
assert not os.path.isabs(os.readlink(link))
534536

535537
# link exists (is a link)
536538
new_file2 = tmp_path / "new_file2"
537539
new_file2.touch()
538540
ModelCheckpoint._link_checkpoint(trainer, filepath=str(new_file2), linkpath=str(link))
539541
assert os.path.islink(link)
540542
assert os.path.realpath(link) == str(new_file2)
543+
assert not os.path.isabs(os.readlink(link))
541544

542545
# link exists (is a folder)
543546
folder = tmp_path / "folder"
@@ -547,13 +550,15 @@ def test_model_checkpoint_link_checkpoint(tmp_path):
547550
ModelCheckpoint._link_checkpoint(trainer, filepath=str(folder), linkpath=str(folder_link))
548551
assert os.path.islink(folder_link)
549552
assert os.path.realpath(folder_link) == str(folder)
553+
assert not os.path.isabs(os.readlink(folder_link))
550554

551555
# link exists (is a link to a folder)
552556
new_folder = tmp_path / "new_folder"
553557
new_folder.mkdir()
554558
ModelCheckpoint._link_checkpoint(trainer, filepath=str(new_folder), linkpath=str(folder_link))
555559
assert os.path.islink(folder_link)
556560
assert os.path.realpath(folder_link) == str(new_folder)
561+
assert not os.path.isabs(os.readlink(folder_link))
557562

558563
# simulate permission error on Windows (creation of symbolic links requires privileges)
559564
file = tmp_path / "win_file"
@@ -565,6 +570,22 @@ def test_model_checkpoint_link_checkpoint(tmp_path):
565570
assert os.path.isfile(link) # fall back to copying instead of linking
566571

567572

573+
def test_model_checkpoint_link_checkpoint_relative_path(tmp_path, monkeypatch):
574+
"""Test that linking a checkpoint works with relative paths."""
575+
trainer = Mock()
576+
monkeypatch.chdir(tmp_path)
577+
578+
folder = Path("x/z/z")
579+
folder.mkdir(parents=True)
580+
file = folder / "file"
581+
file.touch()
582+
link = folder / "link"
583+
ModelCheckpoint._link_checkpoint(trainer, filepath=str(file.absolute()), linkpath=str(link.absolute()))
584+
assert os.path.islink(link)
585+
assert Path(os.readlink(link)) == file.relative_to(folder)
586+
assert not os.path.isabs(os.readlink(link))
587+
588+
568589
def test_invalid_top_k(tmpdir):
569590
"""Make sure that a MisconfigurationException is raised for a negative save_top_k argument."""
570591
with pytest.raises(MisconfigurationException, match=r".*Must be >= -1"):

0 commit comments

Comments
 (0)