diff --git a/pyproject.toml b/pyproject.toml index 55be6805..addbadc3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,7 +17,7 @@ classifiers = [ ] dependencies = [ "iohub==0.1.0", - "torch>=2.1.2", + "torch>=2.4.1", "timm>=0.9.5", "tensorboard>=2.13.0", "lightning>=2.3.0", diff --git a/viscy/data/combined.py b/viscy/data/combined.py index 31ea9f6c..87036723 100644 --- a/viscy/data/combined.py +++ b/viscy/data/combined.py @@ -1,10 +1,12 @@ from enum import Enum from typing import Literal, Sequence +import torch from lightning.pytorch import LightningDataModule from lightning.pytorch.utilities.combined_loader import CombinedLoader -from torch.utils.data import ConcatDataset, DataLoader +from torch.utils.data import ConcatDataset, DataLoader, Dataset +from viscy.data.distributed import ShardedDistributedSampler from viscy.data.hcs import _collate_samples @@ -19,11 +21,18 @@ class CombinedDataModule(LightningDataModule): """Wrapper for combining multiple data modules. For supported modes, see ``lightning.pytorch.utilities.combined_loader``. - :param Sequence[LightningDataModule] data_modules: data modules to combine - :param str train_mode: mode in training stage, defaults to "max_size_cycle" - :param str val_mode: mode in validation stage, defaults to "sequential" - :param str test_mode: mode in testing stage, defaults to "sequential" - :param str predict_mode: mode in prediction stage, defaults to "sequential" + Parameters + ---------- + data_modules : Sequence[LightningDataModule] + data modules to combine + train_mode : CombineMode, optional + mode in training stage, by default CombineMode.MAX_SIZE_CYCLE + val_mode : CombineMode, optional + mode in validation stage, by default CombineMode.SEQUENTIAL + test_mode : CombineMode, optional + mode in testing stage, by default CombineMode.SEQUENTIAL + predict_mode : CombineMode, optional + mode in prediction stage, by default CombineMode.SEQUENTIAL """ def __init__( @@ -76,10 +85,15 @@ def predict_dataloader(self): class ConcatDataModule(LightningDataModule): """ Concatenate multiple data modules. - The concatenated data module will have the same - batch size and number of workers as the first data module. - Each element will be sampled uniformly regardless of their original data module. - :param Sequence[LightningDataModule] data_modules: data modules to concatenate + + The concatenated data module will have the same batch size and number of workers + as the first data module. Each element will be sampled uniformly regardless of + their original data module. + + Parameters + ---------- + data_modules : Sequence[LightningDataModule] + Data modules to concatenate. 
""" def __init__(self, data_modules: Sequence[LightningDataModule]): @@ -133,3 +147,73 @@ def val_dataloader(self): shuffle=False, persistent_workers=bool(self.num_workers), ) + + +class CachedConcatDataModule(LightningDataModule): + def __init__(self, data_modules: Sequence[LightningDataModule]): + super().__init__() + self.data_modules = data_modules + self.num_workers = data_modules[0].num_workers + self.batch_size = data_modules[0].batch_size + for dm in data_modules: + if dm.num_workers != self.num_workers: + raise ValueError("Inconsistent number of workers") + if dm.batch_size != self.batch_size: + raise ValueError("Inconsistent batch size") + self.prepare_data_per_node = True + + def prepare_data(self): + for dm in self.data_modules: + dm.trainer = self.trainer + dm.prepare_data() + + def setup(self, stage: Literal["fit", "validate", "test", "predict"]): + self.train_patches_per_stack = 0 + for dm in self.data_modules: + dm.setup(stage) + if patches := getattr(dm, "train_patches_per_stack", 1): + if self.train_patches_per_stack == 0: + self.train_patches_per_stack = patches + elif self.train_patches_per_stack != patches: + raise ValueError("Inconsistent patches per stack") + if stage != "fit": + raise NotImplementedError("Only fit stage is supported") + self.train_dataset = ConcatDataset( + [dm.train_dataset for dm in self.data_modules] + ) + self.val_dataset = ConcatDataset([dm.val_dataset for dm in self.data_modules]) + + def _maybe_sampler( + self, dataset: Dataset, shuffle: bool + ) -> ShardedDistributedSampler | None: + return ( + ShardedDistributedSampler(dataset, shuffle=shuffle) + if torch.distributed.is_initialized() + else None + ) + + def train_dataloader(self) -> DataLoader: + sampler = self._maybe_sampler(self.train_dataset, shuffle=True) + return DataLoader( + self.train_dataset, + batch_size=self.batch_size, + shuffle=False if sampler else True, + sampler=sampler, + persistent_workers=True if self.num_workers > 0 else False, + num_workers=self.num_workers, + drop_last=True, + collate_fn=lambda x: x, + ) + + def val_dataloader(self) -> DataLoader: + sampler = self._maybe_sampler(self.val_dataset, shuffle=False) + return DataLoader( + self.val_dataset, + batch_size=self.batch_size, + shuffle=False, + sampler=sampler, + persistent_workers=True if self.num_workers > 0 else False, + num_workers=self.num_workers, + drop_last=False, + collate_fn=lambda x: x, + ) diff --git a/viscy/data/ctmc_v1.py b/viscy/data/ctmc_v1.py index d71859b8..3c888175 100644 --- a/viscy/data/ctmc_v1.py +++ b/viscy/data/ctmc_v1.py @@ -1,56 +1,86 @@ from pathlib import Path +import torch from iohub.ngff import open_ome_zarr -from lightning.pytorch import LightningDataModule from monai.transforms import Compose, MapTransform -from torch.utils.data import DataLoader -from viscy.data.hcs import ChannelMap, SlidingWindowDataset -from viscy.data.typing import Sample +from viscy.data.gpu_aug import CachedOmeZarrDataset, GPUTransformDataModule -class CTMCv1ValidationDataset(SlidingWindowDataset): - def __len__(self, subsample_rate: int = 30) -> int: - # sample every 30th frame in the videos - return super().__len__() // self.subsample_rate - - def __getitem__(self, index: int) -> Sample: - index = index * self.subsample_rate - return super().__getitem__(index) - - -class CTMCv1DataModule(LightningDataModule): +class CTMCv1DataModule(GPUTransformDataModule): """ Autoregression data module for the CTMCv1 dataset. Training and validation datasets are stored in separate HCS OME-Zarr stores. 
- :param str | Path train_data_path: Path to the training dataset - :param str | Path val_data_path: Path to the validation dataset - :param list[MapTransform] train_transforms: List of transforms for training - :param list[MapTransform] val_transforms: List of transforms for validation - :param int batch_size: Batch size, defaults to 16 - :param int num_workers: Number of workers, defaults to 8 - :param str channel_name: Name of the DIC channel, defaults to "DIC" + Parameters + ---------- + train_data_path : str or Path + Path to the training dataset. + val_data_path : str or Path + Path to the validation dataset. + train_cpu_transforms : list of MapTransform + List of CPU transforms for training. + val_cpu_transforms : list of MapTransform + List of CPU transforms for validation. + train_gpu_transforms : list of MapTransform + List of GPU transforms for training. + val_gpu_transforms : list of MapTransform + List of GPU transforms for validation. + batch_size : int, optional + Batch size, by default 16. + num_workers : int, optional + Number of dataloading workers, by default 8. + val_subsample_ratio : int, optional + Skip every N frames for validation to reduce redundancy in video, + by default 30. + channel_name : str, optional + Name of the DIC channel, by default "DIC". + pin_memory : bool, optional + Pin memory for dataloaders, by default True. """ def __init__( self, train_data_path: str | Path, val_data_path: str | Path, - train_transforms: list[MapTransform], - val_transforms: list[MapTransform], + train_cpu_transforms: list[MapTransform], + val_cpu_transforms: list[MapTransform], + train_gpu_transforms: list[MapTransform], + val_gpu_transforms: list[MapTransform], batch_size: int = 16, num_workers: int = 8, + val_subsample_ratio: int = 30, channel_name: str = "DIC", + pin_memory: bool = True, ) -> None: super().__init__() self.train_data_path = train_data_path self.val_data_path = val_data_path - self.train_transforms = train_transforms - self.val_transforms = val_transforms - self.channel_map = ChannelMap(source=[channel_name], target=[channel_name]) + self._train_cpu_transforms = Compose(train_cpu_transforms) + self._val_cpu_transforms = Compose(val_cpu_transforms) + self._train_gpu_transforms = Compose(train_gpu_transforms) + self._val_gpu_transforms = Compose(val_gpu_transforms) + self.channel_names = [channel_name] self.batch_size = batch_size self.num_workers = num_workers + self.val_subsample_ratio = val_subsample_ratio + self.pin_memory = pin_memory + + @property + def train_cpu_transforms(self) -> Compose: + return self._train_cpu_transforms + + @property + def val_cpu_transforms(self) -> Compose: + return self._val_cpu_transforms + + @property + def train_gpu_transforms(self) -> Compose: + return self._train_gpu_transforms + + @property + def val_gpu_transforms(self) -> Compose: + return self._val_gpu_transforms def setup(self, stage: str) -> None: if stage != "fit": @@ -58,37 +88,26 @@ def setup(self, stage: str) -> None: self._setup_fit() def _setup_fit(self) -> None: + cache_map = torch.multiprocessing.Manager().dict() train_plate = open_ome_zarr(self.train_data_path) val_plate = open_ome_zarr(self.val_data_path) train_positions = [p for _, p in train_plate.positions()] val_positions = [p for _, p in val_plate.positions()] - self.train_dataset = SlidingWindowDataset( - train_positions, - channels=self.channel_map, - z_window_size=1, - transform=Compose(self.train_transforms), + self.train_dataset = CachedOmeZarrDataset( + positions=train_positions, + 
channel_names=self.channel_names, + cache_map=cache_map, + transform=self.train_cpu_transforms, + load_normalization_metadata=False, ) - self.val_dataset = CTMCv1ValidationDataset( - val_positions, - channels=self.channel_map, - z_window_size=1, - transform=Compose(self.val_transforms), + full_val_dataset = CachedOmeZarrDataset( + positions=val_positions, + channel_names=self.channel_names, + cache_map=cache_map, + transform=self.val_cpu_transforms, + load_normalization_metadata=False, ) - - def train_dataloader(self) -> DataLoader: - return DataLoader( - self.train_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - persistent_workers=bool(self.num_workers), - shuffle=True, - ) - - def val_dataloader(self) -> DataLoader: - return DataLoader( - self.val_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - persistent_workers=bool(self.num_workers), - shuffle=False, + subsample_indices = list( + range(0, len(full_val_dataset), self.val_subsample_ratio) ) + self.val_dataset = torch.utils.data.Subset(full_val_dataset, subsample_indices) diff --git a/viscy/data/distributed.py b/viscy/data/distributed.py new file mode 100644 index 00000000..68e6d39e --- /dev/null +++ b/viscy/data/distributed.py @@ -0,0 +1,56 @@ +"""Utilities for DDP training.""" + +from __future__ import annotations + +import math +from typing import TYPE_CHECKING + +import torch +import torch.distributed +from torch.utils.data.distributed import DistributedSampler + +if TYPE_CHECKING: + from torch import Generator + + +class ShardedDistributedSampler(DistributedSampler): + def _sharded_randperm(self, max_size: int, generator: Generator) -> list[int]: + """Generate a sharded random permutation of indices. + Overlap may occur in between the last two shards to maintain divisibility.""" + sharded_randperm = [ + torch.randperm(self.num_samples, generator=generator) + + min(i * self.num_samples, max_size - self.num_samples) + for i in range(self.num_replicas) + ] + indices = torch.stack(sharded_randperm, dim=1).reshape(-1) + return indices.tolist() + + def __iter__(self): + """Modified __iter__ method to shard data across distributed ranks.""" + max_size = len(self.dataset) # type: ignore[arg-type] + if self.shuffle: + # deterministically shuffle based on epoch and seed + g = torch.Generator() + g.manual_seed(self.seed + self.epoch) + indices = self._sharded_randperm(max_size, g) + else: + indices = list(range(max_size)) + if not self.drop_last: + # add extra samples to make it evenly divisible + padding_size = self.total_size - len(indices) + if padding_size <= len(indices): + indices += indices[:padding_size] + else: + indices += (indices * math.ceil(padding_size / len(indices)))[ + :padding_size + ] + else: + # remove tail of data to make it evenly divisible. 
+ indices = indices[: self.total_size] + assert len(indices) == self.total_size + + # subsample + indices = indices[self.rank : self.total_size : self.num_replicas] + assert len(indices) == self.num_samples + + return iter(indices) diff --git a/viscy/data/gpu_aug.py b/viscy/data/gpu_aug.py new file mode 100644 index 00000000..a445fff5 --- /dev/null +++ b/viscy/data/gpu_aug.py @@ -0,0 +1,276 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from logging import getLogger +from pathlib import Path +from typing import TYPE_CHECKING, Literal + +import numpy as np +import torch +from iohub.ngff import Plate, Position, open_ome_zarr +from lightning.pytorch import LightningDataModule +from monai.data.meta_obj import set_track_meta +from monai.data.utils import list_data_collate +from monai.transforms.compose import Compose +from torch import Tensor +from torch.multiprocessing import Manager +from torch.utils.data import DataLoader, Dataset + +from viscy.data.distributed import ShardedDistributedSampler +from viscy.data.hcs import _ensure_channel_list, _read_norm_meta +from viscy.data.typing import DictTransform, NormMeta + +if TYPE_CHECKING: + from multiprocessing.managers import DictProxy + +_logger = getLogger("lightning.pytorch") + +_CacheMetadata = tuple[Position, int, NormMeta | None] + + +class GPUTransformDataModule(ABC, LightningDataModule): + """Abstract data module with GPU transforms.""" + + train_dataset: Dataset + val_dataset: Dataset + batch_size: int + num_workers: int + pin_memory: bool + + def _maybe_sampler( + self, dataset: Dataset, shuffle: bool + ) -> ShardedDistributedSampler | None: + return ( + ShardedDistributedSampler(dataset, shuffle=shuffle) + if torch.distributed.is_initialized() + else None + ) + + def train_dataloader(self) -> DataLoader: + sampler = self._maybe_sampler(self.train_dataset, shuffle=True) + _logger.debug(f"Using training sampler {sampler}") + return DataLoader( + self.train_dataset, + batch_size=self.batch_size, + shuffle=False if sampler else True, + sampler=sampler, + persistent_workers=True if self.num_workers > 0 else False, + num_workers=self.num_workers, + pin_memory=self.pin_memory, + drop_last=False, + collate_fn=list_data_collate, + ) + + def val_dataloader(self) -> DataLoader: + sampler = self._maybe_sampler(self.val_dataset, shuffle=False) + _logger.debug(f"Using validation sampler {sampler}") + return DataLoader( + self.val_dataset, + batch_size=self.batch_size, + shuffle=False, + sampler=sampler, + persistent_workers=True if self.num_workers > 0 else False, + num_workers=self.num_workers, + pin_memory=self.pin_memory, + drop_last=False, + collate_fn=list_data_collate, + ) + + @property + @abstractmethod + def train_cpu_transforms(self) -> Compose: ... + + @property + @abstractmethod + def train_gpu_transforms(self) -> Compose: ... + + @property + @abstractmethod + def val_cpu_transforms(self) -> Compose: ... + + @property + @abstractmethod + def val_gpu_transforms(self) -> Compose: ... + + +class CachedOmeZarrDataset(Dataset): + """Dataset for cached OME-Zarr arrays. + + Parameters + ---------- + positions : list[Position] + List of FOVs to load images from. + channel_names : list[str] + List of channel names to load. + cache_map : DictProxy + Shared dictionary for caching loaded volumes. 
+ transform : Compose | None, optional + Composed transforms to be applied on the CPU, by default None + array_key : str, optional + The image array key name (multi-scale level), by default "0" + load_normalization_metadata : bool, optional + Load normalization metadata in the sample dictionary, by default True + skip_cache : bool, optional + Skip caching to save RAM, by default False + """ + + def __init__( + self, + positions: list[Position], + channel_names: list[str], + cache_map: DictProxy, + transform: Compose | None = None, + array_key: str = "0", + load_normalization_metadata: bool = True, + skip_cache: bool = False, + ): + key = 0 + self._metadata_map: dict[int, _CacheMetadata] = {} + for position in positions: + img = position[array_key] + norm_meta = _read_norm_meta(position) + for time_idx in range(img.frames): + cache_map[key] = None + self._metadata_map[key] = (position, time_idx, norm_meta) + key += 1 + self.channels = {ch: position.get_channel_index(ch) for ch in channel_names} + self.array_key = array_key + self._cache_map = cache_map + self.transform = transform + self.load_normalization_metadata = load_normalization_metadata + self.skip_cache = skip_cache + + def __len__(self) -> int: + return len(self._metadata_map) + + def __getitem__(self, idx: int) -> dict[str, Tensor]: + position, time_idx, norm_meta = self._metadata_map[idx] + cache = self._cache_map[idx] + if cache is None: + _logger.debug(f"Loading volume for index {idx}") + volume = torch.from_numpy( + position[self.array_key] + .oindex[time_idx, list(self.channels.values())] + .astype(np.float32) + ) + if not self.skip_cache: + _logger.debug(f"Caching for index {idx}") + self._cache_map[idx] = volume + else: + _logger.debug(f"Using cached volume for index {idx}") + volume = cache + sample = {name: img[None] for name, img in zip(self.channels.keys(), volume)} + if self.load_normalization_metadata: + sample["norm_meta"] = norm_meta + if self.transform: + sample = self.transform(sample) + if not isinstance(sample, list): + sample = [sample] + return sample + + +class CachedOmeZarrDataModule(GPUTransformDataModule): + """Data module for cached OME-Zarr arrays. + + Parameters + ---------- + data_path : Path + Path to the HCS OME-Zarr dataset. + channels : str | list[str] + Channel names to load. + batch_size : int + Batch size for training and validation. + num_workers : int + Number of workers for data-loaders. + split_ratio : float + Fraction of the FOVs used for the training split. + The rest will be used for validation. + train_cpu_transforms : list[DictTransform] + Transforms to be applied on the CPU during training. + val_cpu_transforms : list[DictTransform] + Transforms to be applied on the CPU during validation. + train_gpu_transforms : list[DictTransform] + Transforms to be applied on the GPU during training. + val_gpu_transforms : list[DictTransform] + Transforms to be applied on the GPU during validation. 
+ pin_memory : bool, optional + Use page-locked memory in data-loaders, by default True + skip_cache : bool, optional + Skip caching for this dataset, by default False + """ + + def __init__( + self, + data_path: Path, + channels: str | list[str], + batch_size: int, + num_workers: int, + split_ratio: float, + train_cpu_transforms: list[DictTransform], + val_cpu_transforms: list[DictTransform], + train_gpu_transforms: list[DictTransform], + val_gpu_transforms: list[DictTransform], + pin_memory: bool = True, + skip_cache: bool = False, + ): + super().__init__() + self.data_path = data_path + self.channels = _ensure_channel_list(channels) + self.batch_size = batch_size + self.num_workers = num_workers + self.split_ratio = split_ratio + self._train_cpu_transforms = Compose(train_cpu_transforms) + self._val_cpu_transforms = Compose(val_cpu_transforms) + self._train_gpu_transforms = Compose(train_gpu_transforms) + self._val_gpu_transforms = Compose(val_gpu_transforms) + self.pin_memory = pin_memory + self.skip_cache = skip_cache + + @property + def train_cpu_transforms(self) -> Compose: + return self._train_cpu_transforms + + @property + def train_gpu_transforms(self) -> Compose: + return self._train_gpu_transforms + + @property + def val_cpu_transforms(self) -> Compose: + return self._val_cpu_transforms + + @property + def val_gpu_transforms(self) -> Compose: + return self._val_gpu_transforms + + def _set_fit_global_state(self, num_positions: int) -> list[int]: + # disable metadata tracking in MONAI for performance + set_track_meta(False) + # shuffle positions, randomness is handled globally + return torch.randperm(num_positions).tolist() + + def setup(self, stage: Literal["fit", "validate"]) -> None: + if stage not in ("fit", "validate"): + raise NotImplementedError("Only fit and validate stages are supported.") + cache_map = Manager().dict() + plate: Plate = open_ome_zarr(self.data_path, mode="r", layout="hcs") + positions = [p for _, p in plate.positions()] + shuffled_indices = self._set_fit_global_state(len(positions)) + num_train_fovs = int(len(positions) * self.split_ratio) + train_fovs = [positions[i] for i in shuffled_indices[:num_train_fovs]] + val_fovs = [positions[i] for i in shuffled_indices[num_train_fovs:]] + _logger.debug(f"Training FOVs: {[p.zgroup.name for p in train_fovs]}") + _logger.debug(f"Validation FOVs: {[p.zgroup.name for p in val_fovs]}") + self.train_dataset = CachedOmeZarrDataset( + train_fovs, + self.channels, + cache_map, + transform=self.train_cpu_transforms, + skip_cache=self.skip_cache, + ) + self.val_dataset = CachedOmeZarrDataset( + val_fovs, + self.channels, + cache_map, + transform=self.val_cpu_transforms, + skip_cache=self.skip_cache, + ) diff --git a/viscy/data/livecell.py b/viscy/data/livecell.py index bb8bb56c..e8da1eb4 100644 --- a/viscy/data/livecell.py +++ b/viscy/data/livecell.py @@ -1,45 +1,81 @@ +from __future__ import annotations + import json from pathlib import Path +from typing import TYPE_CHECKING import torch -from lightning.pytorch import LightningDataModule -from monai.transforms import Compose, MapTransform +from monai.transforms import Compose, MapTransform, Transform from pycocotools.coco import COCO from tifffile import imread from torch.utils.data import DataLoader, Dataset from torchvision.ops import box_convert +from viscy.data.gpu_aug import GPUTransformDataModule from viscy.data.typing import Sample +if TYPE_CHECKING: + from multiprocessing.managers import DictProxy + class LiveCellDataset(Dataset): """ LiveCell dataset. 
- :param list[Path] images: List of paths to single-page, single-channel TIFF files. - :param MapTransform | Compose transform: Transform to apply to the dataset + Parameters + ---------- + images : list of Path + List of paths to single-page, single-channel TIFF files. + transform : Transform or Compose + Transform to apply to the dataset. + cache_map : DictProxy + Shared dictionary for caching images. """ - def __init__(self, images: list[Path], transform: MapTransform | Compose) -> None: + def __init__( + self, + images: list[Path], + transform: Transform | Compose, + cache_map: DictProxy, + ) -> None: self.images = images self.transform = transform + self._cache_map = cache_map def __len__(self) -> int: return len(self.images) def __getitem__(self, idx: int) -> Sample: - image = imread(self.images[idx])[None, None] - image = torch.from_numpy(image).to(torch.float32) - image = self.transform(image) - return {"source": image, "target": image} + name = self.images[idx] + if name not in self._cache_map: + image = imread(name)[None, None] + image = torch.from_numpy(image).to(torch.float32) + self._cache_map[name] = image + else: + image = self._cache_map[name] + sample = Sample(source=image) + sample = self.transform(sample) + if not isinstance(sample, list): + sample = [sample] + return sample class LiveCellTestDataset(Dataset): """ LiveCell dataset. - :param list[Path] images: List of paths to single-page, single-channel TIFF files. - :param MapTransform | Compose transform: Transform to apply to the dataset + Parameters + ---------- + image_dir : Path + Directory containing the images. + transform : MapTransform | Compose + Transform to apply to the dataset. + annotations : Path + Path to the COCO annotations file. + load_target : bool, optional + Whether to load the target images (default is False). + load_labels : bool, optional + Whether to load the labels (default is False). 
""" def __init__( @@ -87,7 +123,7 @@ def __getitem__(self, idx: int) -> Sample: return sample -class LiveCellDataModule(LightningDataModule): +class LiveCellDataModule(GPUTransformDataModule): def __init__( self, train_val_images: Path | None = None, @@ -95,33 +131,60 @@ def __init__( train_annotations: Path | None = None, val_annotations: Path | None = None, test_annotations: Path | None = None, - train_transforms: list[MapTransform] = [], - val_transforms: list[MapTransform] = [], + train_cpu_transforms: list[MapTransform] = [], + val_cpu_transforms: list[MapTransform] = [], + train_gpu_transforms: list[MapTransform] = [], + val_gpu_transforms: list[MapTransform] = [], test_transforms: list[MapTransform] = [], batch_size: int = 16, num_workers: int = 8, + pin_memory: bool = True, ) -> None: super().__init__() - self.train_val_images = Path(train_val_images) - if not self.train_val_images.is_dir(): - raise NotADirectoryError(str(train_val_images)) - self.test_images = Path(test_images) - if not self.test_images.is_dir(): - raise NotADirectoryError(str(test_images)) - self.train_annotations = Path(train_annotations) - if not self.train_annotations.is_file(): - raise FileNotFoundError(str(train_annotations)) - self.val_annotations = Path(val_annotations) - if not self.val_annotations.is_file(): - raise FileNotFoundError(str(val_annotations)) - self.test_annotations = Path(test_annotations) - if not self.test_annotations.is_file(): - raise FileNotFoundError(str(test_annotations)) - self.train_transforms = Compose(train_transforms) - self.val_transforms = Compose(val_transforms) + if train_val_images is not None: + self.train_val_images = Path(train_val_images) + if not self.train_val_images.is_dir(): + raise NotADirectoryError(str(train_val_images)) + if test_images is not None: + self.test_images = Path(test_images) + if not self.test_images.is_dir(): + raise NotADirectoryError(str(test_images)) + if train_annotations is not None: + self.train_annotations = Path(train_annotations) + if not self.train_annotations.is_file(): + raise FileNotFoundError(str(train_annotations)) + if val_annotations is not None: + self.val_annotations = Path(val_annotations) + if not self.val_annotations.is_file(): + raise FileNotFoundError(str(val_annotations)) + if test_annotations is not None: + self.test_annotations = Path(test_annotations) + if not self.test_annotations.is_file(): + raise FileNotFoundError(str(test_annotations)) + self._train_cpu_transforms = Compose(train_cpu_transforms) + self._val_cpu_transforms = Compose(val_cpu_transforms) + self._train_gpu_transforms = Compose(train_gpu_transforms) + self._val_gpu_transforms = Compose(val_gpu_transforms) self.test_transforms = Compose(test_transforms) self.batch_size = batch_size self.num_workers = num_workers + self.pin_memory = pin_memory + + @property + def train_cpu_transforms(self) -> Compose: + return self._train_cpu_transforms + + @property + def val_cpu_transforms(self) -> Compose: + return self._val_cpu_transforms + + @property + def train_gpu_transforms(self) -> Compose: + return self._train_gpu_transforms + + @property + def val_gpu_transforms(self) -> Compose: + return self._val_gpu_transforms def setup(self, stage: str) -> None: if stage == "fit": @@ -135,15 +198,18 @@ def _parse_image_names(self, annotations: Path) -> list[Path]: return sorted(images) def _setup_fit(self) -> None: + cache_map = torch.multiprocessing.Manager().dict() train_images = self._parse_image_names(self.train_annotations) val_images = 
self._parse_image_names(self.val_annotations) self.train_dataset = LiveCellDataset( [self.train_val_images / f for f in train_images], - transform=self.train_transforms, + transform=self.train_cpu_transforms, + cache_map=cache_map, ) self.val_dataset = LiveCellDataset( [self.train_val_images / f for f in val_images], - transform=self.val_transforms, + transform=self.val_cpu_transforms, + cache_map=cache_map, ) def _setup_test(self) -> None: @@ -154,23 +220,6 @@ def _setup_test(self) -> None: load_labels=True, ) - def train_dataloader(self) -> DataLoader: - return DataLoader( - self.train_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - persistent_workers=bool(self.num_workers), - shuffle=True, - ) - - def val_dataloader(self) -> DataLoader: - return DataLoader( - self.val_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - persistent_workers=bool(self.num_workers), - ) - def test_dataloader(self) -> DataLoader: return DataLoader( self.test_dataset, batch_size=self.batch_size, num_workers=self.num_workers diff --git a/viscy/data/typing.py b/viscy/data/typing.py index fb7b6b73..f5e66caa 100644 --- a/viscy/data/typing.py +++ b/viscy/data/typing.py @@ -50,7 +50,7 @@ class Sample(TypedDict, total=False): # Instance segmentation masks labels: OneOrSeq[Tensor] # None: not available - norm_meta: NormMeta + norm_meta: NormMeta | None class ChannelMap(TypedDict): diff --git a/viscy/scripts/shared_dict.py b/viscy/scripts/shared_dict.py new file mode 100644 index 00000000..b29d4d86 --- /dev/null +++ b/viscy/scripts/shared_dict.py @@ -0,0 +1,121 @@ +from multiprocessing.managers import DictProxy + +import torch +from lightning.pytorch import LightningDataModule, LightningModule, Trainer +from lightning.pytorch.utilities import rank_zero_info +from torch.distributed import get_rank +from torch.multiprocessing import Manager +from torch.utils.data import DataLoader, Dataset, Subset + +from viscy.data.distributed import ShardedDistributedSampler + + +class CachedDataset(Dataset): + def __init__(self, shared_dict: DictProxy, length: int): + self.rank = get_rank() + print(f"=== Initializing cache pool for rank {self.rank} ===") + self.shared_dict = shared_dict + self.length = length + + def __getitem__(self, index): + if index not in self.shared_dict: + print(f"* Adding {index} to cache dict on rank {self.rank}") + self.shared_dict[index] = torch.tensor(index).float()[None] + return self.shared_dict[index] + + def __len__(self): + return self.length + + +class CachedDataModule(LightningDataModule): + def __init__( + self, + length: int, + split_ratio: float, + batch_size: int, + num_workers: int, + persistent_workers: bool, + ): + super().__init__() + self.length = length + self.split_ratio = split_ratio + self.batch_size = batch_size + self.num_workers = num_workers + self.persistent_workers = persistent_workers + + def setup(self, stage): + if stage != "fit": + raise NotImplementedError("Only fit stage is supported.") + shared_dict = Manager().dict() + dataset = CachedDataset(shared_dict, self.length) + split_idx = int(self.length * self.split_ratio) + self.train_dataset = Subset(dataset, range(0, split_idx)) + self.val_dataset = Subset(dataset, range(split_idx, self.length)) + + def train_dataloader(self): + sampler = ShardedDistributedSampler(self.train_dataset, shuffle=True) + return DataLoader( + self.train_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=False, + persistent_workers=self.persistent_workers, + drop_last=False, 
+ sampler=sampler, + ) + + def val_dataloader(self): + sampler = ShardedDistributedSampler(self.val_dataset, shuffle=False) + return DataLoader( + self.val_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=False, + persistent_workers=self.persistent_workers, + drop_last=False, + sampler=sampler, + ) + + +class DummyModel(LightningModule): + def __init__(self): + super().__init__() + self.layer = torch.nn.Linear(1, 1) + + def forward(self, x): + return self.layer(x) + + def on_train_start(self): + rank_zero_info("=== Starting training ===") + + def on_train_epoch_start(self): + rank_zero_info(f"=== Starting training epoch {self.current_epoch} ===") + + def training_step(self, batch, batch_idx): + loss = torch.nn.functional.mse_loss(self.layer(batch), torch.zeros_like(batch)) + return loss + + def validation_step(self, batch, batch_idx): + loss = torch.nn.functional.mse_loss(self.layer(batch), torch.zeros_like(batch)) + return loss + + def configure_optimizers(self): + return torch.optim.Adam(self.parameters(), lr=1e-3) + + +trainer = Trainer( + max_epochs=5, + strategy="ddp", + accelerator="cpu", + devices=3, + use_distributed_sampler=False, + enable_progress_bar=False, + logger=False, + enable_checkpointing=False, +) + +data_module = CachedDataModule( + length=50, batch_size=2, split_ratio=0.6, num_workers=4, persistent_workers=False +) +model = DummyModel() +trainer.fit(model, data_module) diff --git a/viscy/transforms.py b/viscy/transforms.py index 5eca0538..f4e8103b 100644 --- a/viscy/transforms.py +++ b/viscy/transforms.py @@ -2,8 +2,11 @@ from typing import Sequence, Union +import numpy as np +import torch from monai.transforms import ( MapTransform, + MultiSampleTrait, RandAdjustContrastd, RandAffined, RandGaussianNoised, @@ -12,12 +15,12 @@ RandScaleIntensityd, RandWeightedCropd, ScaleIntensityRangePercentilesd, + Transform, ) -from monai.transforms.transform import Randomizable -from numpy.random.mtrand import RandomState as RandomState +from torch import Tensor from typing_extensions import Iterable, Literal -from viscy.data.typing import Sample +from viscy.data.typing import ChannelMap, Sample class RandWeightedCropd(RandWeightedCropd): @@ -158,11 +161,20 @@ def __init__( class NormalizeSampled(MapTransform): """ - Normalize the sample - :param Union[str, Iterable[str]] keys: keys to normalize - :param str fov: fov path with respect to Plate - :param str subtrahend: subtrahend for normalization, defaults to "mean" - :param str divisor: divisor for normalization, defaults to "std" + Normalize the sample. + + Parameters + ---------- + keys : Union[str, Iterable[str]] + Keys to normalize. + level : {'fov_statistics', 'dataset_statistics'} + Level of normalization. + subtrahend : str, optional + Subtrahend for normalization, defaults to "mean". + divisor : str, optional + Divisor for normalization, defaults to "std". + remove_meta : bool, optional + Whether to remove metadata after normalization, defaults to False. 
""" def __init__( @@ -171,11 +183,13 @@ def __init__( level: Literal["fov_statistics", "dataset_statistics"], subtrahend="mean", divisor="std", + remove_meta: bool = False, ) -> None: super().__init__(keys, allow_missing_keys=False) self.subtrahend = subtrahend self.divisor = divisor self.level = level + self.remove_meta = remove_meta # TODO: need to implement the case where the preprocessing already exists def __call__(self, sample: Sample) -> Sample: @@ -184,6 +198,8 @@ def __call__(self, sample: Sample) -> Sample: subtrahend_val = level_meta[self.subtrahend] divisor_val = level_meta[self.divisor] + 1e-8 # avoid div by zero sample[key] = (sample[key] - subtrahend_val) / divisor_val + if self.remove_meta: + sample.pop("norm_meta") return sample def _normalize(): @@ -206,13 +222,112 @@ def __init__( def __call__(self, sample: Sample) -> Sample: self.randomize(None) + if not self._do_transform: + return sample for key in self.keys: if key in sample: sample[key] = -sample[key] return sample - def set_random_state( - self, seed: int | None = None, state: RandomState | None = None - ) -> Randomizable: - super().set_random_state(seed, state) - return self + +class TiledSpatialCropSamplesd(MapTransform, MultiSampleTrait): + """ + Crop multiple tiled ROIs from an image. + Used for deterministic cropping in validation. + """ + + def __init__( + self, + keys: Union[str, Iterable[str]], + roi_size: tuple[int, int, int], + num_samples: int, + ) -> None: + super().__init__(keys, allow_missing_keys=False) + self.roi_size = roi_size + self.num_samples = num_samples + + def _check_num_samples(self, spatial_size: np.ndarray, offset: int) -> np.ndarray: + max_grid_shape = spatial_size // self.roi_size + max_num_samples = max_grid_shape.prod() + if offset >= max_num_samples: + raise ValueError( + f"Number of samples {self.num_samples} should be " + f"smaller than {max_num_samples}." + ) + grid_idx = np.asarray(np.unravel_index(offset, max_grid_shape)) + return grid_idx * self.roi_size + + def _crop(self, img: Tensor, offset: int) -> Tensor: + spatial_size = np.array(img.shape[-3:]) + crop_start = self._check_num_samples(spatial_size, offset) + crop_end = crop_start + np.array(self.roi_size) + return img[ + ..., + crop_start[0] : crop_end[0], + crop_start[1] : crop_end[1], + crop_start[2] : crop_end[2], + ] + + def __call__(self, sample: Sample) -> Sample: + results = [] + for i in range(self.num_samples): + result = {} + for key in self.keys: + result[key] = self._crop(sample[key], i) + if "norm_meta" in sample: + result["norm_meta"] = sample["norm_meta"] + results.append(result) + return results + + +class StackChannelsd(MapTransform): + """Stack source and target channels.""" + + def __init__(self, channel_map: ChannelMap) -> None: + channel_names = [] + for channels in channel_map.values(): + channel_names.extend(channels) + super().__init__(channel_names, allow_missing_keys=False) + self.channel_map = channel_map + + def __call__(self, sample: Sample) -> Sample: + results = {} + for key, channels in self.channel_map.items(): + results[key] = torch.cat([sample[ch] for ch in channels], dim=0) + return results + + +class BatchedZoom(Transform): + "Batched zoom transform using ``torch.nn.functional.interpolate``." 
+ + def __init__( + self, + scale_factor: float | tuple[float, float, float], + mode: Literal[ + "nearest", + "nearest-exact", + "linear", + "bilinear", + "bicubic", + "trilinear", + "area", + ], + align_corners: bool | None = None, + recompute_scale_factor: bool | None = None, + antialias: bool = False, + ) -> None: + self.scale_factor = scale_factor + self.mode = mode + self.align_corners = align_corners + self.recompute_scale_factor = recompute_scale_factor + self.antialias = antialias + + def __call__(self, sample: Tensor) -> Tensor: + return torch.nn.functional.interpolate( + sample, + scale_factor=self.scale_factor, + mode=self.mode, + align_corners=self.align_corners, + recompute_scale_factor=self.recompute_scale_factor, + antialias=self.antialias, + ) diff --git a/viscy/translation/engine.py b/viscy/translation/engine.py index 698ff412..831217fc 100644 --- a/viscy/translation/engine.py +++ b/viscy/translation/engine.py @@ -1,12 +1,14 @@ import logging import os -from typing import Literal, Sequence, Union +import random +from typing import Callable, Literal, Sequence, Union import numpy as np import torch import torch.nn.functional as F from imageio import imwrite from lightning.pytorch import LightningModule +from monai.data.utils import collate_meta_tensor from monai.optimizers import WarmupCosineSchedule from monai.transforms import DivisiblePad, Rotate90 from torch import Tensor, nn @@ -23,6 +25,8 @@ structural_similarity_index_measure, ) +from viscy.data.combined import CombinedDataModule +from viscy.data.gpu_aug import GPUTransformDataModule from viscy.data.typing import Sample from viscy.translation.evaluation_metrics import mean_average_precision, ms_ssim_25d from viscy.unet.networks.fcmae import FullyConvolutionalMAE @@ -85,6 +89,13 @@ def forward(self, preds, target): return loss +class MaskedMSELoss(nn.Module): + def forward(self, preds, original, mask): + loss = F.mse_loss(preds, original, reduction="none") + loss = (loss.mean(2) * mask).sum() / mask.sum() + return loss + + class VSUNet(LightningModule): """Regression U-Net module for virtual staining. 
@@ -98,7 +109,7 @@ class VSUNet(LightningModule): :param float lr: learning rate in training, defaults to 1e-3 :param Literal['WarmupCosine', 'Constant'] schedule: learning rate scheduler, defaults to "Constant" - :param str chkpt_path: path to the checkpoint to load weights, defaults to None + :param str ckpt_path: path to the checkpoint to load weights, defaults to None :param int log_batches_per_epoch: number of batches to log each training/validation epoch, has to be smaller than steps per epoch, defaults to 8 @@ -365,7 +376,8 @@ def perform_test_time_augmentations(self, source: Tensor) -> Tensor: elif self.tta_type == "median": prediction = torch.stack(predictions).median(dim=0).values elif self.tta_type == "product": - # Perform multiplication of predictions in logarithmic space for numerical stability adding epsion to avoid log(0) case + # Perform multiplication of predictions in logarithmic space + # for numerical stability adding epsilon to avoid log(0) case log_predictions = torch.stack([torch.log(p + 1e-9) for p in predictions]) log_prediction_sum = log_predictions.sum(dim=0) prediction = torch.exp(log_prediction_sum) @@ -466,64 +478,175 @@ def _crop_to_original(self, tensor: Tensor) -> Tensor: return cropped_tensor +class AugmentedPredictionVSUNet(LightningModule): + def __init__( + self, + model: nn.Module, + forward_transforms: list[Callable[[Tensor], Tensor]], + inverse_transforms: list[Callable[[Tensor], Tensor]], + reduction: Literal["mean", "median"] = "mean", + ) -> None: + super().__init__() + down_factor = 2**model.num_blocks + self._predict_pad = DivisiblePad((0, 0, down_factor, down_factor)) + self.model = model + self._forward_transforms = forward_transforms + self._inverse_transforms = inverse_transforms + self._reduction = reduction + + def forward(self, x: Tensor) -> Tensor: + return self.model(x) + + def setup(self, stage: str) -> None: + if stage != "predict": + raise NotImplementedError( + f"Only the 'predict' stage is supported by {type(self)}" + ) + + def _reduce_predictions(self, preds: list[Tensor]) -> Tensor: + prediction = torch.stack(preds, dim=0) + if self._reduction == "mean": + prediction = prediction.mean(dim=0) + elif self._reduction == "median": + prediction = prediction.median(dim=0).values + return prediction + + def predict_step( + self, batch: Sample, batch_idx: int, dataloader_idx: int = 0 + ) -> Tensor: + source = batch["source"] + preds = [] + for forward_t, inverse_t in zip( + self._forward_transforms, self._inverse_transforms + ): + source = forward_t(source) + source = self._predict_pad(source) + pred = self.forward(source) + pred = self._predict_pad.inverse(pred) + pred = inverse_t(pred) + preds.append(pred) + if len(preds) == 1: + prediction = preds[0] + else: + prediction = self._reduce_predictions(preds) + return prediction + + class FcmaeUNet(VSUNet): - def __init__(self, fit_mask_ratio: float = 0.0, **kwargs): + def __init__( + self, + fit_mask_ratio: float = 0.0, + **kwargs, + ): super().__init__(architecture="fcmae", **kwargs) self.fit_mask_ratio = fit_mask_ratio + self.save_hyperparameters(ignore=["loss_function"]) + + def on_fit_start(self): + dm = self.trainer.datamodule + if not isinstance(dm, CombinedDataModule): + raise ValueError( + f"Container data module type {type(dm)} " + "is not supported for FCMAE training" + ) + for subdm in dm.data_modules: + if not isinstance(subdm, GPUTransformDataModule): + raise ValueError( + f"Member data module type {type(subdm)} " + "is not supported for FCMAE training" + ) + 
self.datamodules = dm.data_modules + if self.model.pretraining and not isinstance(self.loss_function, MaskedMSELoss): + raise ValueError( + "MaskedMSELoss is required for FCMAE pre-training, " + f"got {type(self.loss_function)}" + ) def forward(self, x: Tensor, mask_ratio: float = 0.0): return self.model(x, mask_ratio) - def forward_fit(self, batch: Sample) -> tuple[Tensor]: - source = batch["source"] + def forward_fit_fcmae( + self, batch: Sample, return_target: bool = False + ) -> tuple[Tensor, Tensor | None, Tensor]: + x = batch["source"] + pred, mask = self.forward(x, mask_ratio=self.fit_mask_ratio) + loss = self.loss_function(pred, x, mask) + if return_target: + target = x * mask.unsqueeze(2) + else: + target = None + return pred, target, loss + + def forward_fit_supervised(self, batch: Sample) -> tuple[Tensor, Tensor, Tensor]: + x = batch["source"] target = batch["target"] - pred, mask = self.forward(source, mask_ratio=self.fit_mask_ratio) - loss = F.mse_loss(pred, target, reduction="none") - loss = (loss.mean(2) * mask).sum() / mask.sum() - return source, target, pred, mask, loss + pred = self.forward(x) + loss = self.loss_function(pred, target) + return pred, target, loss - def training_step(self, batch: Sequence[Sample], batch_idx: int): - losses = [] - batch_size = 0 - for b in batch: - source, target, pred, mask, loss = self.forward_fit(b) - losses.append(loss) - batch_size += source.shape[0] + def forward_fit_task( + self, batch: Sample, batch_idx: int + ) -> tuple[Tensor, Tensor | None, Tensor]: + if self.model.pretraining: if batch_idx < self.log_batches_per_epoch: - self.training_step_outputs.extend( - detach_sample( - (source, target * mask.unsqueeze(2), pred), - self.log_samples_per_batch, - ) + return_target = True + else: + return_target = False + pred, target, loss = self.forward_fit_fcmae(batch, return_target) + else: + pred, target, loss = self.forward_fit_supervised(batch) + return pred, target, loss + + @torch.no_grad() + def train_transform_and_collate(self, batch: list[dict[str, Tensor]]) -> Sample: + transformed = [] + for dataset_batch, dm in zip(batch, self.datamodules): + dataset_batch = dm.train_gpu_transforms(dataset_batch) + transformed.extend(dataset_batch) + # shuffle references in place for better logging + random.shuffle(transformed) + return collate_meta_tensor(transformed) + + @torch.no_grad() + def val_transform_and_collate( + self, batch: list[Sample], dataloader_idx: int + ) -> Tensor: + batch = self.datamodules[dataloader_idx].val_gpu_transforms(batch) + return collate_meta_tensor(batch) + + def training_step(self, batch: list[list[Sample]], batch_idx: int) -> Tensor: + batch = self.train_transform_and_collate(batch) + pred, target, loss = self.forward_fit_task(batch, batch_idx) + if batch_idx < self.log_batches_per_epoch: + self.training_step_outputs.extend( + detach_sample( + (batch["source"], target, pred), self.log_samples_per_batch ) - loss_step = torch.stack(losses).mean() + ) self.log( "loss/train", - loss_step.to(self.device), + loss.to(self.device), on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True, - batch_size=batch_size, + batch_size=pred.shape[0], ) - return loss_step + return loss - def validation_step(self, batch: Sample, batch_idx: int, dataloader_idx: int = 0): - source, target, pred, mask, loss = self.forward_fit(batch) + def validation_step( + self, batch: list[Sample], batch_idx: int, dataloader_idx: int = 0 + ) -> None: + batch = self.val_transform_and_collate(batch, dataloader_idx) + pred, target, loss = self.forward_fit_task(batch,
batch_idx) if dataloader_idx + 1 > len(self.validation_losses): self.validation_losses.append([]) self.validation_losses[dataloader_idx].append(loss.detach()) self.log( - f"loss/val/{dataloader_idx}", - loss.to(self.device), - sync_dist=True, - batch_size=source.shape[0], + "loss/val", loss.to(self.device), sync_dist=True, batch_size=pred.shape[0] ) if batch_idx < self.log_batches_per_epoch: self.validation_step_outputs.extend( detach_sample( - (source, target * mask.unsqueeze(2), pred), - self.log_samples_per_batch, + (batch["source"], target, pred), self.log_samples_per_batch ) )
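
The sharding behavior of ShardedDistributedSampler introduced above can be checked without spinning up a process group by passing num_replicas and rank explicitly. A minimal sketch (not part of this diff; the toy dataset, replica count, and seed are illustrative):

import torch
from torch.utils.data import TensorDataset

from viscy.data.distributed import ShardedDistributedSampler

# 10-item toy dataset; simulate 3 ranks without initializing torch.distributed.
toy_ds = TensorDataset(torch.arange(10))
for rank in range(3):
    sampler = ShardedDistributedSampler(
        toy_ds, num_replicas=3, rank=rank, shuffle=True, seed=42
    )
    sampler.set_epoch(0)  # same epoch and seed on every rank -> consistent shards
    print(rank, list(sampler))

With this toy size, each rank permutes a contiguous shard (rank 0 draws from indices 0-3, rank 1 from 4-7, rank 2 from 6-9, with overlap between the last two shards to keep per-rank lengths equal), so a per-rank cache like the Manager dict used in gpu_aug.py only ever has to hold that rank's shard.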
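Usage sketch for the new GPU-transform pipeline (not part of this diff): the store path, channel name, and crop sizes below are hypothetical, and the plain MONAI transforms are stand-ins for whatever augmentations a training config would use. The CPU transform lists run in the dataloader workers; the GPU transform lists are applied later, on device, by the LightningModule (see FcmaeUNet.train_transform_and_collate / val_transform_and_collate above).

from pathlib import Path

from monai.transforms import (
    CenterSpatialCropd,
    NormalizeIntensityd,
    RandAffined,
    RandSpatialCropSamplesd,
)

from viscy.data.gpu_aug import CachedOmeZarrDataModule

dm = CachedOmeZarrDataModule(
    data_path=Path("plate.zarr"),  # hypothetical HCS OME-Zarr store
    channels="Phase3D",            # hypothetical channel name
    batch_size=16,
    num_workers=8,
    split_ratio=0.8,
    # cheap, shape-changing crops stay on the CPU workers
    train_cpu_transforms=[
        RandSpatialCropSamplesd(
            keys=["Phase3D"], roi_size=(5, 256, 256), num_samples=2
        )
    ],
    val_cpu_transforms=[CenterSpatialCropd(keys=["Phase3D"], roi_size=(5, 256, 256))],
    # heavier augmentations run on the GPU after the batch is moved to device
    train_gpu_transforms=[
        NormalizeIntensityd(keys=["Phase3D"]),
        RandAffined(keys=["Phase3D"], prob=0.5, rotate_range=(0.0, 0.0, 3.14)),
    ],
    val_gpu_transforms=[NormalizeIntensityd(keys=["Phase3D"])],
)

For FCMAE pre-training, one or more such modules are wrapped in viscy.data.combined.CombinedDataModule, which is the container type FcmaeUNet.on_fit_start expects.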