Sharded distributed sampler for cached dataloading in DDP #195

Merged: 54 commits, Jan 2, 2025
Commits
e19ee14
caching dataloader
edyoshikun Sep 12, 2024
d31978d
caching data module
edyoshikun Sep 26, 2024
041d738
black
edyoshikun Sep 27, 2024
7f76174
ruff
edyoshikun Sep 27, 2024
85ea791
Bump torch to 2.4.1 (#174)
edyoshikun Sep 28, 2024
1838581
adding timeout to ram_dataloader
edyoshikun Sep 28, 2024
f5c01a3
bandaid to cached dataloader
edyoshikun Oct 4, 2024
26a06b8
fixing the dataloader using torch collate_fn
edyoshikun Oct 4, 2024
f2ff43c
replacing dictionary with single array
edyoshikun Oct 17, 2024
5fb96d7
loading prior to epoch 0
edyoshikun Oct 18, 2024
848cd63
Revert "replacing dictionary with single array"
edyoshikun Oct 19, 2024
f7e57ae
using multiprocessing manager
edyoshikun Oct 19, 2024
c4797b4
add sharded distributed sampler
ziw-liu Oct 21, 2024
2c31e7d
add example script for ddp caching
ziw-liu Oct 21, 2024
5300b4a
format and lint
ziw-liu Oct 21, 2024
8a8b4b0
adding the custom distributed sampler to hcs_ram.py
edyoshikun Oct 22, 2024
49764fa
adding sampler to val train dataloader
edyoshikun Oct 22, 2024
1fe5491
fix divisibility of the last shard
ziw-liu Oct 22, 2024
0b005cf
hcs_ram format and lint
ziw-liu Oct 22, 2024
023ca88
data module that only crops and does not collate
ziw-liu Oct 23, 2024
f7ce0ba
wip: execute transforms on the GPU
ziw-liu Oct 23, 2024
daa6860
path for the non-DDP case
edyoshikun Oct 24, 2024
55499de
fix randomness in inversion transform
ziw-liu Oct 29, 2024
4280677
add option to pop the normalization metadata
ziw-liu Oct 29, 2024
1561802
move gpu transform definition back to data module
ziw-liu Oct 30, 2024
2e37217
add tiled crop transform for validation
ziw-liu Oct 30, 2024
7edf36e
add stack channel transform for gpu augmentation
ziw-liu Oct 30, 2024
eda5d1b
fix typing
ziw-liu Oct 30, 2024
550101d
collate before sending to gpu
ziw-liu Oct 30, 2024
92e3722
inherit gpu transforms for livecell dataset
ziw-liu Oct 30, 2024
c185377
update fcmae engine to apply per-dataset augmentations
ziw-liu Oct 30, 2024
2ca134b
format and lint hcs_ram
ziw-liu Oct 30, 2024
70fcf1c
Merge branch 'simple-cache' into gpu-transform
ziw-liu Oct 31, 2024
be0e94f
fix abc type hint
ziw-liu Oct 31, 2024
92c4b0a
update docstring style
ziw-liu Oct 31, 2024
f7b585c
disable grad for validation transforms
ziw-liu Oct 31, 2024
42c49f5
improve sample image logging in fcmae
ziw-liu Oct 31, 2024
4bf1088
fix dataset length when batch size is larger than the dataset
ziw-liu Oct 31, 2024
3276950
fix docstring
ziw-liu Oct 31, 2024
14a16ed
add option to disable normalization metadata
ziw-liu Oct 31, 2024
6719305
inherit gpu transform for ctmc
ziw-liu Oct 31, 2024
fad3d4e
remove duplicate method override
ziw-liu Oct 31, 2024
07c1021
update docstring for ctmc
ziw-liu Nov 1, 2024
949c445
Merge pull request #196 from mehta-lab/gpu-transform
ziw-liu Nov 8, 2024
d331c1f
Merge branch 'main' into simple-cache
ziw-liu Nov 8, 2024
7d79473
allow skipping caching for large datasets
ziw-liu Nov 13, 2024
736d4c5
Merge branch 'main' into simple-cache
ziw-liu Nov 13, 2024
e548d52
make the fcmae module compatible with image translation
ziw-liu Nov 14, 2024
084717f
remove prototype implementation
ziw-liu Nov 19, 2024
fdc377a
fix import path
ziw-liu Nov 19, 2024
96313fa
Arbitrary prediction time transforms (#209)
ziw-liu Dec 2, 2024
6e1818b
add docstrings
ziw-liu Dec 2, 2024
9126083
Merge branch 'main' into simple-cache
ziw-liu Dec 4, 2024
b864c6e
fix typo in docstring
ziw-liu Jan 2, 2025
pyproject.toml: 2 changes (1 addition, 1 deletion)
@@ -11,7 +11,7 @@ license = { file = "LICENSE" }
 authors = [{ name = "CZ Biohub SF", email = "[email protected]" }]
 dependencies = [
     "iohub==0.1.0",
-    "torch>=2.1.2",
+    "torch>=2.4.1",
     "timm>=0.9.5",
     "tensorboard>=2.13.0",
     "lightning>=2.3.0",
viscy/data/distributed.py: 56 changes (56 additions, 0 deletions, new file)
@@ -0,0 +1,56 @@
"""Utilities for DDP training."""

from __future__ import annotations

import math
from typing import TYPE_CHECKING

import torch
import torch.distributed
from torch.utils.data.distributed import DistributedSampler

if TYPE_CHECKING:
from torch import Generator


class ShardedDistributedSampler(DistributedSampler):
def _sharded_randperm(self, max_size: int, generator: Generator) -> list[int]:
"""Generate a sharded random permutation of indices.
Overlap may occur in between the last two shards to maintain divisibility."""
sharded_randperm = [
torch.randperm(self.num_samples, generator=generator)
+ min(i * self.num_samples, max_size - self.num_samples)
for i in range(self.num_replicas)
]
indices = torch.stack(sharded_randperm, dim=1).reshape(-1)
return indices.tolist()

def __iter__(self):
"""Modified __iter__ method to shard data across distributed ranks."""
max_size = len(self.dataset) # type: ignore[arg-type]
if self.shuffle:
# deterministically shuffle based on epoch and seed
g = torch.Generator()
g.manual_seed(self.seed + self.epoch)
indices = self._sharded_randperm(max_size, g)
else:
indices = list(range(max_size))
if not self.drop_last:
# add extra samples to make it evenly divisible
padding_size = self.total_size - len(indices)
if padding_size <= len(indices):
indices += indices[:padding_size]
else:
indices += (indices * math.ceil(padding_size / len(indices)))[
:padding_size
]
else:
# remove tail of data to make it evenly divisible.
indices = indices[: self.total_size]
assert len(indices) == self.total_size

# subsample
indices = indices[self.rank : self.total_size : self.num_replicas]
assert len(indices) == self.num_samples

return iter(indices)
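
For readers skimming the diff, here is a minimal usage sketch, not part of the PR, of how ShardedDistributedSampler could be wired into a DataLoader under an initialized torch.distributed process group. The TensorDataset, batch size, seed, and the build_loader helper are illustrative assumptions.

# Minimal usage sketch (assumed setup, not from this PR): attach the sampler
# to a DataLoader when torch.distributed is initialized.
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader, TensorDataset

from viscy.data.distributed import ShardedDistributedSampler


def build_loader(batch_size: int = 4) -> DataLoader:
    # toy dataset standing in for a cached / RAM-backed dataset
    dataset = TensorDataset(torch.arange(7).float())
    sampler = ShardedDistributedSampler(
        dataset,
        num_replicas=dist.get_world_size(),
        rank=dist.get_rank(),
        shuffle=True,
        seed=42,
        drop_last=False,
    )
    return DataLoader(dataset, batch_size=batch_size, sampler=sampler)


# What the sharding does with 7 samples and 2 ranks:
# num_samples = ceil(7 / 2) = 4 and total_size = 8. Shard 0 is a permutation
# of {0, 1, 2, 3}; shard 1 is a permutation of {3, 4, 5, 6} (offset
# min(4, 7 - 4) = 3), so the last two shards overlap on index 3 to keep the
# shards evenly sized. The interleaved index list has length 8 = total_size,
# and indices[rank : total_size : num_replicas] recovers each rank's
# contiguous-shard permutation, whereas a plain DistributedSampler gives each
# rank a strided subset spanning the whole dataset.
#
# As with DistributedSampler, call sampler.set_epoch(epoch) at the start of
# each epoch so the shuffling differs across epochs.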