From e19ee14dec4b0cf61938d391ad81625ce4811ef5 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Thu, 12 Sep 2024 15:47:38 -0700 Subject: [PATCH 01/49] caching dataloader --- viscy/data/hcs_ram.py | 207 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 207 insertions(+) create mode 100644 viscy/data/hcs_ram.py diff --git a/viscy/data/hcs_ram.py b/viscy/data/hcs_ram.py new file mode 100644 index 00000000..f9e0fa82 --- /dev/null +++ b/viscy/data/hcs_ram.py @@ -0,0 +1,207 @@ +import logging +import math +import os +import re +import tempfile +from pathlib import Path +from typing import Callable, Literal, Sequence + +import numpy as np +import torch +import zarr +from imageio import imread +from iohub.ngff import ImageArray, Plate, Position, open_ome_zarr +from lightning.pytorch import LightningDataModule +from monai.data import set_track_meta +from monai.data.utils import collate_meta_tensor +from monai.transforms import ( + CenterSpatialCropd, + Compose, + MapTransform, + MultiSampleTrait, + RandAffined, +) +from torch import Tensor +from torch.utils.data import DataLoader, Dataset + +from viscy.data.typing import ChannelMap, DictTransform, HCSStackIndex, NormMeta, Sample +from viscy.data.hcs import _read_norm_meta +from tqdm import tqdm + +_logger = logging.getLogger("lightning.pytorch") + +# TODO: cache the norm metadata when caching the dataset + + +def _stack_channels( + sample_images: list[dict[str, Tensor]] | dict[str, Tensor], + channels: ChannelMap, + key: str, +) -> Tensor | list[Tensor]: + """Stack single-channel images into a multi-channel tensor.""" + if not isinstance(sample_images, list): + return torch.stack([sample_images[ch][0] for ch in channels[key]]) + # training time + return [torch.stack([im[ch][0] for ch in channels[key]]) for im in sample_images] + + +class CachedDataset(Dataset): + """ + A dataset that caches the data in RAM. + It relies on the `__getitem__` method to load the data on the 1st epoch. 
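+    Subsequent epochs read the cached tensors straight from memory,
+    so only the first access to each position touches the Zarr store.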
+ """ + + def __init__( + self, + positions: list[Position], + channels: ChannelMap, + transform: DictTransform | None = None, + ): + super().__init__() + self.positions = positions + self.channels = channels + self.transform = transform + + self.source_ch_idx = [ + positions[0].get_channel_index(c) for c in channels["source"] + ] + self.target_ch_idx = ( + [positions[0].get_channel_index(c) for c in channels["target"]] + if "target" in channels + else None + ) + self._position_mapping() + self.cache_dict = {} + + def _position_mapping(self) -> None: + self.position_keys = [] + self.norm_meta_dict = {} + + for pos in self.positions: + self.position_keys.append(pos.data.name) + self.norm_meta_dict[str(pos.data.name)] = _read_norm_meta(pos) + + def _cache_dataset(self, index: int, channel_index: list[int], t: int = 0) -> None: + # Add the position to the cached_dict + # TODO: hardcoding to t=0 + self.cache_dict[str(self.position_keys[index])] = torch.from_numpy( + self.positions[index] + .data.oindex[slice(t, t + 1), channel_index, :] + .astype(np.float32) + ) + + def _get_weight_map(self, position: Position) -> Tensor: + # Get the weight map from the position for the MONAI weightedcrop transform + raise NotImplementedError + + def __len__(self) -> int: + return len(self.positions) + + def __getitem__(self, index: int) -> Sample: + + ch_names = self.channels["source"].copy() + ch_idx = self.source_ch_idx.copy() + if self.target_ch_idx is not None: + ch_names.extend(self.channels["target"]) + ch_idx.extend(self.target_ch_idx) + + # Check if the sample is in the cache else add it + # Split the tensor into the channels + sample_id = self.position_keys[index] + if sample_id not in self.cache_dict: + logging.debug(f"Adding {sample_id} to cache") + self._cache_dataset(index, channel_index=ch_idx) + + # Get the sample from the cache + images = self.cache_dict[sample_id].unbind(dim=1) + norm_meta = self.norm_meta_dict[str(sample_id)] + + sample_images = {k: v for k, v in zip(ch_names, images)} + + if self.target_ch_idx is not None: + # FIXME: this uses the first target channel as weight for performance + # since adding a reference to a tensor does not copy + # maybe write a weight map in preprocessing to use more information? 
+ sample_images["weight"] = sample_images[self.channels["target"][0]] + if norm_meta is not None: + sample_images["norm_meta"] = norm_meta + if self.transform: + sample_images = self.transform(sample_images) + if "weight" in sample_images: + del sample_images["weight"] + sample = { + "index": sample_id, + "source": _stack_channels(sample_images, self.channels, "source"), + "norm_meta": norm_meta, + } + if self.target_ch_idx is not None: + sample["target"] = _stack_channels(sample_images, self.channels, "target") + return sample + + def _load_sample(self, position: Position) -> Sample: + source, target = self.channel_map.source, self.channel_map.target + source_data = self._load_channel_data(position, source) + target_data = self._load_channel_data(position, target) + sample = {"source": source_data, "target": target_data} + return sample + + +class CachedDataloader(LightningDataModule): + def __init__( + self, + data_path: str, + source_channel: str | Sequence[str], + target_channel: str | Sequence[str], + split_ratio: float = 0.8, + batch_size: int = 16, + num_workers: int = 8, + architecture: Literal["2D", "UNeXt2", "2.5D", "3D", "fcmae"] = "UNeXt2", + yx_patch_size: tuple[int, int] = (256, 256), + normalizations: list[MapTransform] = [], + augmentations: list[MapTransform] = [], + ): + super().__init__() + self.data_path = data_path + self.source_channel = source_channel + self.target_channel = target_channel + self.batch_size = batch_size + self.num_workers = num_workers + self.target_2d = False if architecture in ["UNeXt2", "3D", "fcmae"] else True + self.split_ratio = split_ratio + self.yx_patch_size = yx_patch_size + self.normalizations = normalizations + self.augmentations = augmentations + + @property + def _base_dataset_settings(self) -> dict[str, dict[str, list[str]] | int]: + return { + "channels": {"source": self.source_channel}, + } + + def setup(self, stage: Literal["fit", "validate", "test", "predict"]) -> None: + dataset_settings = self._base_dataset_settings + if stage in ("fit", "validate"): + self._setup_fit(dataset_settings) + elif stage == "test": + self._setup_test(dataset_settings) + elif stage == "predict": + self._setup_predict(dataset_settings) + else: + raise NotImplementedError(f"Stage {stage} is not supported") + + def _setup_fit(self, dataset_settings: dict) -> None: + """ + Setup the train and validation datasets. 
+ """ + train_transform, val_transform = self._fit_transform() + dataset_settings["channels"]["target"] = self.target_channel + # Load the plate + plate = open_ome_zarr(self.data_path) + + pass + + def _setup_test(self) -> None: + pass + + def _setup_val(self) -> None: + pass From d31978d928af8820f3e7b3db2a8f0aa4ac3a9fc4 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Wed, 25 Sep 2024 17:27:04 -0700 Subject: [PATCH 02/49] caching data module --- viscy/data/hcs_ram.py | 88 ++++++++++++++++++++++++++++++++++++++----- 1 file changed, 78 insertions(+), 10 deletions(-) diff --git a/viscy/data/hcs_ram.py b/viscy/data/hcs_ram.py index f9e0fa82..5e72f6f6 100644 --- a/viscy/data/hcs_ram.py +++ b/viscy/data/hcs_ram.py @@ -98,7 +98,6 @@ def __len__(self) -> int: return len(self.positions) def __getitem__(self, index: int) -> Sample: - ch_names = self.channels["source"].copy() ch_idx = self.source_ch_idx.copy() if self.target_ch_idx is not None: @@ -109,7 +108,7 @@ def __getitem__(self, index: int) -> Sample: # Split the tensor into the channels sample_id = self.position_keys[index] if sample_id not in self.cache_dict: - logging.debug(f"Adding {sample_id} to cache") + logging.info(f"Adding {sample_id} to cache") self._cache_dataset(index, channel_index=ch_idx) # Get the sample from the cache @@ -146,7 +145,7 @@ def _load_sample(self, position: Position) -> Sample: return sample -class CachedDataloader(LightningDataModule): +class CachedDataModule(LightningDataModule): def __init__( self, data_path: str, @@ -159,6 +158,7 @@ def __init__( yx_patch_size: tuple[int, int] = (256, 256), normalizations: list[MapTransform] = [], augmentations: list[MapTransform] = [], + z_window_size: int = 1, ): super().__init__() self.data_path = data_path @@ -171,6 +171,7 @@ def __init__( self.yx_patch_size = yx_patch_size self.normalizations = normalizations self.augmentations = augmentations + self.z_window_size = z_window_size @property def _base_dataset_settings(self) -> dict[str, dict[str, list[str]] | int]: @@ -183,12 +184,53 @@ def setup(self, stage: Literal["fit", "validate", "test", "predict"]) -> None: if stage in ("fit", "validate"): self._setup_fit(dataset_settings) elif stage == "test": - self._setup_test(dataset_settings) + raise NotImplementedError("Test stage is not supported") elif stage == "predict": - self._setup_predict(dataset_settings) + raise NotImplementedError("Predict stage is not supported") else: raise NotImplementedError(f"Stage {stage} is not supported") + def _train_transform(self) -> list[Callable]: + if self.augmentations: + for aug in self.augmentations: + if isinstance(aug, MultiSampleTrait): + num_samples = aug.cropper.num_samples + if self.batch_size % num_samples != 0: + raise ValueError( + "Batch size must be divisible by `num_samples` per stack. " + f"Got batch size {self.batch_size} and " + f"number of samples {num_samples} for " + f"transform type {type(aug)}." + ) + self.train_patches_per_stack = num_samples + return list(self.augmentations) + + def _fit_transform(self) -> tuple[Compose, Compose]: + """(normalization -> maybe augmentation -> center crop) + Deterministic center crop as the last step of training and validation.""" + # TODO: These have a fixed order for now... 
() + final_crop = [ + CenterSpatialCropd( + keys=self.source_channel + self.target_channel, + roi_size=( + self.z_window_size, + self.yx_patch_size[0], + self.yx_patch_size[1], + ), + ) + ] + train_transform = Compose( + self.normalizations + self._train_transform() + final_crop + ) + val_transform = Compose(self.normalizations + final_crop) + return train_transform, val_transform + + def _set_fit_global_state(self, num_positions: int) -> torch.Tensor: + # disable metadata tracking in MONAI for performance + set_track_meta(False) + # shuffle positions, randomness is handled globally + return torch.randperm(num_positions) + def _setup_fit(self, dataset_settings: dict) -> None: """ Setup the train and validation datasets. @@ -197,11 +239,37 @@ def _setup_fit(self, dataset_settings: dict) -> None: dataset_settings["channels"]["target"] = self.target_channel # Load the plate plate = open_ome_zarr(self.data_path) + # shuffle positions, randomness is handled globally + positions = [pos for _, pos in plate.positions()] + shuffled_indices = self._set_fit_global_state(len(positions)) + positions = list(positions[i] for i in shuffled_indices) + num_train_fovs = int(len(positions) * self.split_ratio) - pass + self.train_dataset = CachedDataset( + positions[:num_train_fovs], + transform=train_transform, + **dataset_settings, + ) + self.val_dataset = CachedDataset( + positions[num_train_fovs:], + transform=val_transform, + **dataset_settings, + ) - def _setup_test(self) -> None: - pass + def train_dataloader(self) -> DataLoader: + return DataLoader( + self.train_dataset, + batch_size=self.batch_size // self.train_patches_per_stack, + num_workers=self.num_workers, + persistent_workers=bool(self.num_workers), + shuffle=True, + ) - def _setup_val(self) -> None: - pass + def val_dataloader(self) -> DataLoader: + return DataLoader( + self.val_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + persistent_workers=bool(self.num_workers), + shuffle=False, + ) From 041d73837665bb3d465b48d29d63dc0621931cf4 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Thu, 26 Sep 2024 21:06:49 -0700 Subject: [PATCH 03/49] black --- viscy/data/hcs_ram.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/viscy/data/hcs_ram.py b/viscy/data/hcs_ram.py index 5e72f6f6..8a813dce 100644 --- a/viscy/data/hcs_ram.py +++ b/viscy/data/hcs_ram.py @@ -224,7 +224,7 @@ def _fit_transform(self) -> tuple[Compose, Compose]: ) val_transform = Compose(self.normalizations + final_crop) return train_transform, val_transform - + def _set_fit_global_state(self, num_positions: int) -> torch.Tensor: # disable metadata tracking in MONAI for performance set_track_meta(False) From 7f76174917513a3cee4e565008cc7d7b60527cfe Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Thu, 26 Sep 2024 21:13:25 -0700 Subject: [PATCH 04/49] ruff --- viscy/data/hcs_ram.py | 14 ++------------ 1 file changed, 2 insertions(+), 12 deletions(-) diff --git a/viscy/data/hcs_ram.py b/viscy/data/hcs_ram.py index 8a813dce..74240deb 100644 --- a/viscy/data/hcs_ram.py +++ b/viscy/data/hcs_ram.py @@ -1,32 +1,22 @@ import logging -import math -import os -import re -import tempfile -from pathlib import Path from typing import Callable, Literal, Sequence import numpy as np import torch -import zarr -from imageio import imread -from iohub.ngff import ImageArray, Plate, Position, open_ome_zarr +from iohub.ngff import Position, open_ome_zarr from lightning.pytorch import LightningDataModule from monai.data import set_track_meta 
-from monai.data.utils import collate_meta_tensor from monai.transforms import ( CenterSpatialCropd, Compose, MapTransform, MultiSampleTrait, - RandAffined, ) from torch import Tensor from torch.utils.data import DataLoader, Dataset -from viscy.data.typing import ChannelMap, DictTransform, HCSStackIndex, NormMeta, Sample from viscy.data.hcs import _read_norm_meta -from tqdm import tqdm +from viscy.data.typing import ChannelMap, DictTransform, Sample _logger = logging.getLogger("lightning.pytorch") From 85ea7915a82dba31c733d01b4842bc8ff1e7f9aa Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Fri, 27 Sep 2024 18:18:20 -0700 Subject: [PATCH 05/49] Bump torch to 2.4.1 (#174) * update torch >2.4.1 * black * ruff --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index dc263580..d07187fc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,7 @@ license = { file = "LICENSE" } authors = [{ name = "CZ Biohub SF", email = "compmicro@czbiohub.org" }] dependencies = [ "iohub==0.1.0", - "torch>=2.1.2", + "torch>=2.4.1", "timm>=0.9.5", "tensorboard>=2.13.0", "lightning>=2.3.0", From 18385813b1097ddf736d09cce8bb9e4d99745ef6 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Sat, 28 Sep 2024 11:05:41 -0700 Subject: [PATCH 06/49] adding timeout to ram_dataloader --- viscy/data/hcs_ram.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/viscy/data/hcs_ram.py b/viscy/data/hcs_ram.py index 74240deb..fb1e1de8 100644 --- a/viscy/data/hcs_ram.py +++ b/viscy/data/hcs_ram.py @@ -32,9 +32,11 @@ def _stack_channels( if not isinstance(sample_images, list): return torch.stack([sample_images[ch][0] for ch in channels[key]]) # training time + # sample_images is a list['Phase3D'].shape = (1,3,256,256) return [torch.stack([im[ch][0] for ch in channels[key]]) for im in sample_images] + class CachedDataset(Dataset): """ A dataset that caches the data in RAM. @@ -149,6 +151,7 @@ def __init__( normalizations: list[MapTransform] = [], augmentations: list[MapTransform] = [], z_window_size: int = 1, + timeout: int = 600, ): super().__init__() self.data_path = data_path @@ -162,6 +165,7 @@ def __init__( self.normalizations = normalizations self.augmentations = augmentations self.z_window_size = z_window_size + self.timeout = timeout @property def _base_dataset_settings(self) -> dict[str, dict[str, list[str]] | int]: @@ -253,6 +257,7 @@ def train_dataloader(self) -> DataLoader: num_workers=self.num_workers, persistent_workers=bool(self.num_workers), shuffle=True, + timeout=self.timeout ) def val_dataloader(self) -> DataLoader: @@ -262,4 +267,5 @@ def val_dataloader(self) -> DataLoader: num_workers=self.num_workers, persistent_workers=bool(self.num_workers), shuffle=False, + timeout=self.timeout ) From f5c01a31ce4cff5f79f65d46b072fba9594af55d Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Fri, 4 Oct 2024 13:07:14 +0200 Subject: [PATCH 07/49] bandaid to cached dataloader --- viscy/data/hcs_ram.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/viscy/data/hcs_ram.py b/viscy/data/hcs_ram.py index fb1e1de8..bec0b0da 100644 --- a/viscy/data/hcs_ram.py +++ b/viscy/data/hcs_ram.py @@ -117,7 +117,8 @@ def __getitem__(self, index: int) -> Sample: if norm_meta is not None: sample_images["norm_meta"] = norm_meta if self.transform: - sample_images = self.transform(sample_images) + # FIX ME: check why the transforms return a list? 
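+            # (likely because multi-sample crops such as RandWeightedCropd
+            # return a list with one dict per cropped patch)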
+ sample_images = self.transform(sample_images)[0] if "weight" in sample_images: del sample_images["weight"] sample = { @@ -185,6 +186,11 @@ def setup(self, stage: Literal["fit", "validate", "test", "predict"]) -> None: raise NotImplementedError(f"Stage {stage} is not supported") def _train_transform(self) -> list[Callable]: + """ Set the train augmentations + + + """ + if self.augmentations: for aug in self.augmentations: if isinstance(aug, MultiSampleTrait): @@ -197,6 +203,10 @@ def _train_transform(self) -> list[Callable]: f"transform type {type(aug)}." ) self.train_patches_per_stack = num_samples + else: + self.augmentations=[] + + _logger.info(f'Training augmentations: {self.augmentations}') return list(self.augmentations) def _fit_transform(self) -> tuple[Compose, Compose]: From 26a06b86491d7fea807afee3b170cd04d61aa5c4 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Fri, 4 Oct 2024 09:02:20 -0700 Subject: [PATCH 08/49] fixing the dataloader using torch collate_fn --- viscy/data/hcs_ram.py | 31 +++++++++++++++++++++++++++---- 1 file changed, 27 insertions(+), 4 deletions(-) diff --git a/viscy/data/hcs_ram.py b/viscy/data/hcs_ram.py index bec0b0da..2e4f4f2c 100644 --- a/viscy/data/hcs_ram.py +++ b/viscy/data/hcs_ram.py @@ -14,6 +14,7 @@ ) from torch import Tensor from torch.utils.data import DataLoader, Dataset +from monai.data.utils import collate_meta_tensor from viscy.data.hcs import _read_norm_meta from viscy.data.typing import ChannelMap, DictTransform, Sample @@ -35,6 +36,25 @@ def _stack_channels( # sample_images is a list['Phase3D'].shape = (1,3,256,256) return [torch.stack([im[ch][0] for ch in channels[key]]) for im in sample_images] +def _collate_samples(batch: Sequence[Sample]) -> Sample: + """Collate samples into a batch sample. + + :param Sequence[Sample] batch: a sequence of dictionaries, + where each key may point to a value of a single tensor or a list of tensors, + as is the case with ``train_patches_per_stack > 1``. + :return Sample: Batch sample (dictionary of tensors) + """ + collated: Sample = {} + for key in batch[0].keys(): + data = [] + for sample in batch: + if isinstance(sample[key], Sequence): + data.extend(sample[key]) + else: + data.append(sample[key]) + collated[key] = collate_meta_tensor(data) + return collated + class CachedDataset(Dataset): @@ -118,7 +138,7 @@ def __getitem__(self, index: int) -> Sample: sample_images["norm_meta"] = norm_meta if self.transform: # FIX ME: check why the transforms return a list? 
- sample_images = self.transform(sample_images)[0] + sample_images = self.transform(sample_images) if "weight" in sample_images: del sample_images["weight"] sample = { @@ -206,7 +226,7 @@ def _train_transform(self) -> list[Callable]: else: self.augmentations=[] - _logger.info(f'Training augmentations: {self.augmentations}') + _logger.debug(f'Training augmentations: {self.augmentations}') return list(self.augmentations) def _fit_transform(self) -> tuple[Compose, Compose]: @@ -267,7 +287,9 @@ def train_dataloader(self) -> DataLoader: num_workers=self.num_workers, persistent_workers=bool(self.num_workers), shuffle=True, - timeout=self.timeout + timeout=self.timeout, + collate_fn=_collate_samples, + drop_last=True ) def val_dataloader(self) -> DataLoader: @@ -277,5 +299,6 @@ def val_dataloader(self) -> DataLoader: num_workers=self.num_workers, persistent_workers=bool(self.num_workers), shuffle=False, - timeout=self.timeout + timeout=self.timeout, + ) From f2ff43c0ac8b1370c3f93401b881170c930e725c Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Wed, 16 Oct 2024 17:52:28 -0700 Subject: [PATCH 09/49] replacing dictionary with single array --- viscy/data/hcs_ram.py | 62 ++++++++++++++++++++++++++++++++----------- 1 file changed, 46 insertions(+), 16 deletions(-) diff --git a/viscy/data/hcs_ram.py b/viscy/data/hcs_ram.py index 2e4f4f2c..c8f0227c 100644 --- a/viscy/data/hcs_ram.py +++ b/viscy/data/hcs_ram.py @@ -83,24 +83,39 @@ def __init__( else None ) self._position_mapping() - self.cache_dict = {} + + self.cache_order = [] + self.cache_record = torch.zeros(len(self.positions)) + # Caching the dataset as two separate arrays + # self._init_cache_dataset() def _position_mapping(self) -> None: self.position_keys = [] + self.position_shape_tczyx= (1,1,1,1,1) self.norm_meta_dict = {} for pos in self.positions: self.position_keys.append(pos.data.name) self.norm_meta_dict[str(pos.data.name)] = _read_norm_meta(pos) + # FIX: Use the position shape + self.position_shape_zyx = pos.data.shape[-3:] - def _cache_dataset(self, index: int, channel_index: list[int], t: int = 0) -> None: - # Add the position to the cached_dict - # TODO: hardcoding to t=0 - self.cache_dict[str(self.position_keys[index])] = torch.from_numpy( - self.positions[index] - .data.oindex[slice(t, t + 1), channel_index, :] - .astype(np.float32) - ) + def _init_cache_dataset(self, t_idx=1, ch_idx=1) -> None: + _logger.info('Initializing cache array') + # FIXME assumes t=1 + self.cache = torch.zeros(((len(self.positions),t_idx,len(ch_idx),)+ self.position_shape_zyx)) + + + # def _cache_dataset(self, index: int, channel_index: list[int], t: int = 0) -> None: + # # Add the position to the cached_dict + # # TODO: hardcoding to t=0 + # _logger.info(f'Adding {self.position_keys[index]} to cache') + + # # self.cache_dict[str(self.position_keys[index])] = torch.from_numpy( + # # self.positions[index] + # # .data.oindex[slice(t, t + 1), channel_index, :] + # # .astype(np.float32) + # # ) def _get_weight_map(self, position: Position) -> Tensor: # Get the weight map from the position for the MONAI weightedcrop transform @@ -117,14 +132,30 @@ def __getitem__(self, index: int) -> Sample: ch_idx.extend(self.target_ch_idx) # Check if the sample is in the cache else add it - # Split the tensor into the channels - sample_id = self.position_keys[index] - if sample_id not in self.cache_dict: - logging.info(f"Adding {sample_id} to cache") - self._cache_dataset(index, channel_index=ch_idx) + if self.cache_record[index]==0: + #if all entries of 
self.cache_record are zero + if self.cache_record.sum()==0: + #FIXME hardcoding t_idx=1 + self._init_cache_dataset(ch_idx=ch_idx,t_idx=1) + + # Flip the bit + self.cache_record[index]=1 + self.cache_order.append(index) + # Stack the data + _logger.info(f'Adding {self.position_keys[index]} to cache') + _logger.info(f'Cache_order: {self.cache_order}') + _logger.info(f'caching index: {index}') + #FIX ME: hardcoding t=0 and make this part of function + t=0 + # Insert the data into the cache + self.cache[index]=torch.from_numpy(self.positions[index] + .data.oindex[slice(t, t + 1), ch_idx, :] + .astype(np.float32)) # Get the sample from the cache - images = self.cache_dict[sample_id].unbind(dim=1) + # images = self.cache_dict[sample_id].unbind(dim=1) + sample_id = self.position_keys[index] + images = self.cache[index].unbind(dim=1) norm_meta = self.norm_meta_dict[str(sample_id)] sample_images = {k: v for k, v in zip(ch_names, images)} @@ -207,7 +238,6 @@ def setup(self, stage: Literal["fit", "validate", "test", "predict"]) -> None: def _train_transform(self) -> list[Callable]: """ Set the train augmentations - """ From 5fb96d75ace8775d8a152ebee6b900f9ed1be59a Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Fri, 18 Oct 2024 10:04:34 -0700 Subject: [PATCH 10/49] loading prior to epoch 0 --- viscy/data/hcs_ram.py | 104 ++++++++++++++++++++++++------------------ 1 file changed, 59 insertions(+), 45 deletions(-) diff --git a/viscy/data/hcs_ram.py b/viscy/data/hcs_ram.py index c8f0227c..98bfaba0 100644 --- a/viscy/data/hcs_ram.py +++ b/viscy/data/hcs_ram.py @@ -23,6 +23,15 @@ # TODO: cache the norm metadata when caching the dataset +# Map the NumPy dtype to the corresponding PyTorch dtype +numpy_to_torch_dtype = { + np.dtype('float32'): torch.float32, + np.dtype('float64'): torch.float64, + np.dtype('int32'): torch.int32, + np.dtype('int64'): torch.int64, + np.dtype('uint8'): torch.int8, + np.dtype('uint16'): torch.int16, +} def _stack_channels( sample_images: list[dict[str, Tensor]] | dict[str, Tensor], @@ -54,9 +63,7 @@ def _collate_samples(batch: Sequence[Sample]) -> Sample: data.append(sample[key]) collated[key] = collate_meta_tensor(data) return collated - - - + class CachedDataset(Dataset): """ A dataset that caches the data in RAM. 
@@ -82,10 +89,18 @@ def __init__( if "target" in channels else None ) + # Get total num channels + self.total_ch_names = self.channels["source"].copy() + self.total_ch_idx = self.source_ch_idx.copy() + if self.target_ch_idx is not None: + self.total_ch_names.extend(self.channels["target"]) + self.total_ch_idx.extend(self.target_ch_idx) self._position_mapping() self.cache_order = [] self.cache_record = torch.zeros(len(self.positions)) + self._init_cache_dataset() + # Caching the dataset as two separate arrays # self._init_cache_dataset() @@ -97,25 +112,27 @@ def _position_mapping(self) -> None: for pos in self.positions: self.position_keys.append(pos.data.name) self.norm_meta_dict[str(pos.data.name)] = _read_norm_meta(pos) - # FIX: Use the position shape self.position_shape_zyx = pos.data.shape[-3:] + self._cache_dtype = numpy_to_torch_dtype.get(pos.data.dtype, torch.float32) # Default to torch.float32 if not found - def _init_cache_dataset(self, t_idx=1, ch_idx=1) -> None: + def _init_cache_dataset(self) -> None: _logger.info('Initializing cache array') # FIXME assumes t=1 - self.cache = torch.zeros(((len(self.positions),t_idx,len(ch_idx),)+ self.position_shape_zyx)) - - - # def _cache_dataset(self, index: int, channel_index: list[int], t: int = 0) -> None: - # # Add the position to the cached_dict - # # TODO: hardcoding to t=0 - # _logger.info(f'Adding {self.position_keys[index]} to cache') - - # # self.cache_dict[str(self.position_keys[index])] = torch.from_numpy( - # # self.positions[index] - # # .data.oindex[slice(t, t + 1), channel_index, :] - # # .astype(np.float32) - # # ) + t_idx = 1 + self.cache = torch.zeros(((len(self.positions),t_idx,len(self.total_ch_idx))+ self.position_shape_zyx),dtype=self._cache_dtype) + _logger.info(f'Cache shape: {self.cache.shape}') + + #TODO Caching here to see if multiprocessing is faster + t=0 + + for i, pos in enumerate(self.positions): + _logger.info(f'Caching position {i}/{len(self.positions)}') + ## Insert the data into the cache + data = pos.data.oindex[slice(t, t + 1), self.total_ch_idx, :] + if data.dtype != np.float32: + data = data.astype(np.float32) + self.cache[i]= torch.from_numpy(data) + del data def _get_weight_map(self, position: Position) -> Tensor: # Get the weight map from the position for the MONAI weightedcrop transform @@ -125,35 +142,33 @@ def __len__(self) -> int: return len(self.positions) def __getitem__(self, index: int) -> Sample: - ch_names = self.channels["source"].copy() - ch_idx = self.source_ch_idx.copy() - if self.target_ch_idx is not None: - ch_names.extend(self.channels["target"]) - ch_idx.extend(self.target_ch_idx) - + #FIXME replace this after debugging + ch_idx = self.total_ch_idx + ch_names = self.total_ch_names + # Check if the sample is in the cache else add it - if self.cache_record[index]==0: - #if all entries of self.cache_record are zero - if self.cache_record.sum()==0: - #FIXME hardcoding t_idx=1 - self._init_cache_dataset(ch_idx=ch_idx,t_idx=1) - - # Flip the bit - self.cache_record[index]=1 - self.cache_order.append(index) - # Stack the data - _logger.info(f'Adding {self.position_keys[index]} to cache') - _logger.info(f'Cache_order: {self.cache_order}') - _logger.info(f'caching index: {index}') - #FIX ME: hardcoding t=0 and make this part of function - t=0 - # Insert the data into the cache - self.cache[index]=torch.from_numpy(self.positions[index] - .data.oindex[slice(t, t + 1), ch_idx, :] - .astype(np.float32)) + # if self.cache_record[index]== 0: + # # Flip the bit + # self.cache_record[index]=1 + # 
self.cache_order.append(index) + + # # Stack the data + # _logger.info(f'Adding {self.position_keys[index]} to cache') + # _logger.info(f'Cache_order: {self.cache_order}') + # _logger.info(f'caching index: {index}') + + # #FIX ME: hardcoding t=0 and make this part of function + # t=0 + + # # Insert the data into the cache + # data = self.positions[index].data.oindex[slice(t, t + 1), ch_idx, :] + # if data.dtype != np.float32: + # data = data.astype(np.float32) + # self.cache[index]= torch.from_numpy(data) + # del data # Get the sample from the cache - # images = self.cache_dict[sample_id].unbind(dim=1) + _logger.info(f'Getting sample {index} from cache') sample_id = self.position_keys[index] images = self.cache[index].unbind(dim=1) norm_meta = self.norm_meta_dict[str(sample_id)] @@ -168,7 +183,6 @@ def __getitem__(self, index: int) -> Sample: if norm_meta is not None: sample_images["norm_meta"] = norm_meta if self.transform: - # FIX ME: check why the transforms return a list? sample_images = self.transform(sample_images) if "weight" in sample_images: del sample_images["weight"] From 848cd63ce06b4dd5c36548bd109ef87f6b232176 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Fri, 18 Oct 2024 17:32:51 -0700 Subject: [PATCH 11/49] Revert "replacing dictionary with single array" This reverts commit 8c13f49498eb862e9f94518be727f47682d2cdcf. --- viscy/data/hcs_ram.py | 29 ++++++++++++++++++++++++++++- 1 file changed, 28 insertions(+), 1 deletion(-) diff --git a/viscy/data/hcs_ram.py b/viscy/data/hcs_ram.py index 98bfaba0..4ecf2f6f 100644 --- a/viscy/data/hcs_ram.py +++ b/viscy/data/hcs_ram.py @@ -96,6 +96,7 @@ def __init__( self.total_ch_names.extend(self.channels["target"]) self.total_ch_idx.extend(self.target_ch_idx) self._position_mapping() +<<<<<<< HEAD self.cache_order = [] self.cache_record = torch.zeros(len(self.positions)) @@ -103,15 +104,18 @@ def __init__( # Caching the dataset as two separate arrays # self._init_cache_dataset() +======= + self.cache_dict = {} +>>>>>>> parent of 8c13f49 (replacing dictionary with single array) def _position_mapping(self) -> None: self.position_keys = [] - self.position_shape_tczyx= (1,1,1,1,1) self.norm_meta_dict = {} for pos in self.positions: self.position_keys.append(pos.data.name) self.norm_meta_dict[str(pos.data.name)] = _read_norm_meta(pos) +<<<<<<< HEAD self.position_shape_zyx = pos.data.shape[-3:] self._cache_dtype = numpy_to_torch_dtype.get(pos.data.dtype, torch.float32) # Default to torch.float32 if not found @@ -133,6 +137,17 @@ def _init_cache_dataset(self) -> None: data = data.astype(np.float32) self.cache[i]= torch.from_numpy(data) del data +======= + + def _cache_dataset(self, index: int, channel_index: list[int], t: int = 0) -> None: + # Add the position to the cached_dict + # TODO: hardcoding to t=0 + self.cache_dict[str(self.position_keys[index])] = torch.from_numpy( + self.positions[index] + .data.oindex[slice(t, t + 1), channel_index, :] + .astype(np.float32) + ) +>>>>>>> parent of 8c13f49 (replacing dictionary with single array) def _get_weight_map(self, position: Position) -> Tensor: # Get the weight map from the position for the MONAI weightedcrop transform @@ -147,6 +162,7 @@ def __getitem__(self, index: int) -> Sample: ch_names = self.total_ch_names # Check if the sample is in the cache else add it +<<<<<<< HEAD # if self.cache_record[index]== 0: # # Flip the bit # self.cache_record[index]=1 @@ -171,6 +187,16 @@ def __getitem__(self, index: int) -> Sample: _logger.info(f'Getting sample {index} from cache') sample_id = 
self.position_keys[index] images = self.cache[index].unbind(dim=1) +======= + # Split the tensor into the channels + sample_id = self.position_keys[index] + if sample_id not in self.cache_dict: + logging.info(f"Adding {sample_id} to cache") + self._cache_dataset(index, channel_index=ch_idx) + + # Get the sample from the cache + images = self.cache_dict[sample_id].unbind(dim=1) +>>>>>>> parent of 8c13f49 (replacing dictionary with single array) norm_meta = self.norm_meta_dict[str(sample_id)] sample_images = {k: v for k, v in zip(ch_names, images)} @@ -252,6 +278,7 @@ def setup(self, stage: Literal["fit", "validate", "test", "predict"]) -> None: def _train_transform(self) -> list[Callable]: """ Set the train augmentations + """ From f7e57ae03a2ceb4ffb5347bb6b1d5d9107a8a50c Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Fri, 18 Oct 2024 18:11:01 -0700 Subject: [PATCH 12/49] using multiprocessing manager --- viscy/data/hcs_ram.py | 95 ++++++++++++------------------------------- 1 file changed, 27 insertions(+), 68 deletions(-) diff --git a/viscy/data/hcs_ram.py b/viscy/data/hcs_ram.py index 4ecf2f6f..d8b1b698 100644 --- a/viscy/data/hcs_ram.py +++ b/viscy/data/hcs_ram.py @@ -18,6 +18,8 @@ from viscy.data.hcs import _read_norm_meta from viscy.data.typing import ChannelMap, DictTransform, Sample +from multiprocessing import Manager +from datetime import datetime _logger = logging.getLogger("lightning.pytorch") @@ -96,17 +98,12 @@ def __init__( self.total_ch_names.extend(self.channels["target"]) self.total_ch_idx.extend(self.target_ch_idx) self._position_mapping() -<<<<<<< HEAD - - self.cache_order = [] - self.cache_record = torch.zeros(len(self.positions)) - self._init_cache_dataset() - - # Caching the dataset as two separate arrays - # self._init_cache_dataset() -======= + + # Cached dictionary with tensors self.cache_dict = {} ->>>>>>> parent of 8c13f49 (replacing dictionary with single array) + manager = Manager() + self.cache_dict = manager.dict() + self._cached_pos=[] def _position_mapping(self) -> None: self.position_keys = [] @@ -115,39 +112,15 @@ def _position_mapping(self) -> None: for pos in self.positions: self.position_keys.append(pos.data.name) self.norm_meta_dict[str(pos.data.name)] = _read_norm_meta(pos) -<<<<<<< HEAD - self.position_shape_zyx = pos.data.shape[-3:] - self._cache_dtype = numpy_to_torch_dtype.get(pos.data.dtype, torch.float32) # Default to torch.float32 if not found - - def _init_cache_dataset(self) -> None: - _logger.info('Initializing cache array') - # FIXME assumes t=1 - t_idx = 1 - self.cache = torch.zeros(((len(self.positions),t_idx,len(self.total_ch_idx))+ self.position_shape_zyx),dtype=self._cache_dtype) - _logger.info(f'Cache shape: {self.cache.shape}') - - #TODO Caching here to see if multiprocessing is faster - t=0 - - for i, pos in enumerate(self.positions): - _logger.info(f'Caching position {i}/{len(self.positions)}') - ## Insert the data into the cache - data = pos.data.oindex[slice(t, t + 1), self.total_ch_idx, :] - if data.dtype != np.float32: - data = data.astype(np.float32) - self.cache[i]= torch.from_numpy(data) - del data -======= def _cache_dataset(self, index: int, channel_index: list[int], t: int = 0) -> None: # Add the position to the cached_dict # TODO: hardcoding to t=0 - self.cache_dict[str(self.position_keys[index])] = torch.from_numpy( - self.positions[index] - .data.oindex[slice(t, t + 1), channel_index, :] - .astype(np.float32) - ) ->>>>>>> parent of 8c13f49 (replacing dictionary with single array) + data 
=self.positions[index].data.oindex[slice(t, t + 1), channel_index, :] + if data.dtype != np.float32: + data = data.astype(np.float32) + self.cache_dict[str(self.position_keys[index])] = torch.from_numpy(data) + def _get_weight_map(self, position: Position) -> Tensor: # Get the weight map from the position for the MONAI weightedcrop transform @@ -162,43 +135,20 @@ def __getitem__(self, index: int) -> Sample: ch_names = self.total_ch_names # Check if the sample is in the cache else add it -<<<<<<< HEAD - # if self.cache_record[index]== 0: - # # Flip the bit - # self.cache_record[index]=1 - # self.cache_order.append(index) - - # # Stack the data - # _logger.info(f'Adding {self.position_keys[index]} to cache') - # _logger.info(f'Cache_order: {self.cache_order}') - # _logger.info(f'caching index: {index}') - - # #FIX ME: hardcoding t=0 and make this part of function - # t=0 - - # # Insert the data into the cache - # data = self.positions[index].data.oindex[slice(t, t + 1), ch_idx, :] - # if data.dtype != np.float32: - # data = data.astype(np.float32) - # self.cache[index]= torch.from_numpy(data) - # del data - - # Get the sample from the cache - _logger.info(f'Getting sample {index} from cache') - sample_id = self.position_keys[index] - images = self.cache[index].unbind(dim=1) -======= # Split the tensor into the channels sample_id = self.position_keys[index] if sample_id not in self.cache_dict: - logging.info(f"Adding {sample_id} to cache") + _logger.info(f"Adding {sample_id} to cache") + self._cached_pos.append(index) + _logger.info(f"Cached positions: {self._cached_pos}") self._cache_dataset(index, channel_index=ch_idx) # Get the sample from the cache + _logger.info('Getting sample from cache') + start_time = datetime.now() images = self.cache_dict[sample_id].unbind(dim=1) ->>>>>>> parent of 8c13f49 (replacing dictionary with single array) norm_meta = self.norm_meta_dict[str(sample_id)] - + after_cache = datetime.now() - start_time sample_images = {k: v for k, v in zip(ch_names, images)} if self.target_ch_idx is not None: @@ -209,7 +159,9 @@ def __getitem__(self, index: int) -> Sample: if norm_meta is not None: sample_images["norm_meta"] = norm_meta if self.transform: + before_transform = datetime.now() sample_images = self.transform(sample_images) + after_transform = datetime.now() - before_transform if "weight" in sample_images: del sample_images["weight"] sample = { @@ -219,6 +171,11 @@ def __getitem__(self, index: int) -> Sample: } if self.target_ch_idx is not None: sample["target"] = _stack_channels(sample_images, self.channels, "target") + + _logger.info(f"\nTime taken to cache: {after_cache}") + _logger.info(f"Time taken to transform: {after_transform}") + _logger.info(f"Time taken to get sample: {datetime.now() - start_time}\n") + return sample def _load_sample(self, position: Position) -> Sample: @@ -357,6 +314,7 @@ def train_dataloader(self) -> DataLoader: batch_size=self.batch_size // self.train_patches_per_stack, num_workers=self.num_workers, persistent_workers=bool(self.num_workers), + pin_memory=True, shuffle=True, timeout=self.timeout, collate_fn=_collate_samples, @@ -369,6 +327,7 @@ def val_dataloader(self) -> DataLoader: batch_size=self.batch_size, num_workers=self.num_workers, persistent_workers=bool(self.num_workers), + pin_memory=True, shuffle=False, timeout=self.timeout, From c4797b4529dfd846127dab7bac58231fa19a1b7f Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Mon, 21 Oct 2024 14:13:44 -0700 Subject: [PATCH 13/49] add sharded distributed sampler --- 
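Not part of the patch itself: a small sanity check of the sharding behaviour,
assuming a toy 10-element dataset split across 2 replicas. Unlike the stock
DistributedSampler, each rank here only draws indices from its own contiguous
shard, so each rank's in-memory cache only needs to hold its own shard of the
data.

    import torch
    from torch.utils.data import TensorDataset

    from viscy.data.distributed import ShardedDistributedSampler

    dataset = TensorDataset(torch.arange(10))
    for rank in range(2):
        sampler = ShardedDistributedSampler(
            dataset, num_replicas=2, rank=rank, shuffle=True, seed=0
        )
        sampler.set_epoch(0)
        # rank 0 yields a permutation of 0-4, rank 1 a permutation of 5-9
        print(rank, list(sampler))
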
viscy/data/distributed.py | 51 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 51 insertions(+) create mode 100644 viscy/data/distributed.py diff --git a/viscy/data/distributed.py b/viscy/data/distributed.py new file mode 100644 index 00000000..bd3ab618 --- /dev/null +++ b/viscy/data/distributed.py @@ -0,0 +1,51 @@ +"""Utilities for DDP training.""" + +import math + +import torch +from torch.utils.data.distributed import DistributedSampler + + +class ShardedDistributedSampler(DistributedSampler): + def _sharded_randperm(self, generator): + """Generate a sharded random permutation of indices.""" + indices = torch.tensor(range(len(self.dataset))) + permuted = torch.stack( + [ + torch.randperm(self.num_samples, generator=generator) + + i * self.num_samples + for i in range(self.num_replicas) + ], + dim=1, + ).reshape(-1) + return indices[permuted].tolist() + + def __iter__(self): + """Modified __iter__ method to shard data across distributed ranks.""" + if self.shuffle: + # deterministically shuffle based on epoch and seed + g = torch.Generator() + g.manual_seed(self.seed + self.epoch) + indices = self._sharded_randperm(g) + else: + indices = list(range(len(self.dataset))) # type: ignore[arg-type] + + if not self.drop_last: + # add extra samples to make it evenly divisible + padding_size = self.total_size - len(indices) + if padding_size <= len(indices): + indices += indices[:padding_size] + else: + indices += (indices * math.ceil(padding_size / len(indices)))[ + :padding_size + ] + else: + # remove tail of data to make it evenly divisible. + indices = indices[: self.total_size] + assert len(indices) == self.total_size + + # subsample + indices = indices[self.rank : self.total_size : self.num_replicas] + assert len(indices) == self.num_samples + + return iter(indices) From 2c31e7d09653bb831c3cdd022ba782988bd93893 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Mon, 21 Oct 2024 14:14:05 -0700 Subject: [PATCH 14/49] add example script for ddp caching --- viscy/scripts/shared_dict.py | 121 +++++++++++++++++++++++++++++++++++ 1 file changed, 121 insertions(+) create mode 100644 viscy/scripts/shared_dict.py diff --git a/viscy/scripts/shared_dict.py b/viscy/scripts/shared_dict.py new file mode 100644 index 00000000..b29d4d86 --- /dev/null +++ b/viscy/scripts/shared_dict.py @@ -0,0 +1,121 @@ +from multiprocessing.managers import DictProxy + +import torch +from lightning.pytorch import LightningDataModule, LightningModule, Trainer +from lightning.pytorch.utilities import rank_zero_info +from torch.distributed import get_rank +from torch.multiprocessing import Manager +from torch.utils.data import DataLoader, Dataset, Subset + +from viscy.data.distributed import ShardedDistributedSampler + + +class CachedDataset(Dataset): + def __init__(self, shared_dict: DictProxy, length: int): + self.rank = get_rank() + print(f"=== Initializing cache pool for rank {self.rank} ===") + self.shared_dict = shared_dict + self.length = length + + def __getitem__(self, index): + if index not in self.shared_dict: + print(f"* Adding {index} to cache dict on rank {self.rank}") + self.shared_dict[index] = torch.tensor(index).float()[None] + return self.shared_dict[index] + + def __len__(self): + return self.length + + +class CachedDataModule(LightningDataModule): + def __init__( + self, + length: int, + split_ratio: float, + batch_size: int, + num_workers: int, + persistent_workers: bool, + ): + super().__init__() + self.length = length + self.split_ratio = split_ratio + self.batch_size = batch_size + self.num_workers 
= num_workers + self.persistent_workers = persistent_workers + + def setup(self, stage): + if stage != "fit": + raise NotImplementedError("Only fit stage is supported.") + shared_dict = Manager().dict() + dataset = CachedDataset(shared_dict, self.length) + split_idx = int(self.length * self.split_ratio) + self.train_dataset = Subset(dataset, range(0, split_idx)) + self.val_dataset = Subset(dataset, range(split_idx, self.length)) + + def train_dataloader(self): + sampler = ShardedDistributedSampler(self.train_dataset, shuffle=True) + return DataLoader( + self.train_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=False, + persistent_workers=self.persistent_workers, + drop_last=False, + sampler=sampler, + ) + + def val_dataloader(self): + sampler = ShardedDistributedSampler(self.val_dataset, shuffle=False) + return DataLoader( + self.val_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=False, + persistent_workers=self.persistent_workers, + drop_last=False, + sampler=sampler, + ) + + +class DummyModel(LightningModule): + def __init__(self): + super().__init__() + self.layer = torch.nn.Linear(1, 1) + + def forward(self, x): + return self.layer(x) + + def on_train_start(self): + rank_zero_info("=== Starting training ===") + + def on_train_epoch_start(self): + rank_zero_info(f"=== Starting training epoch {self.current_epoch} ===") + + def training_step(self, batch, batch_idx): + loss = torch.nn.functional.mse_loss(self.layer(batch), torch.zeros_like(batch)) + return loss + + def validation_step(self, batch, batch_idx): + loss = torch.nn.functional.mse_loss(self.layer(batch), torch.zeros_like(batch)) + return loss + + def configure_optimizers(self): + return torch.optim.Adam(self.parameters(), lr=1e-3) + + +trainer = Trainer( + max_epochs=5, + strategy="ddp", + accelerator="cpu", + devices=3, + use_distributed_sampler=False, + enable_progress_bar=False, + logger=False, + enable_checkpointing=False, +) + +data_module = CachedDataModule( + length=50, batch_size=2, split_ratio=0.6, num_workers=4, persistent_workers=False +) +model = DummyModel() +trainer.fit(model, data_module) From 5300b4a932c6bf334a7943182c5d65c464019985 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Mon, 21 Oct 2024 16:32:34 -0700 Subject: [PATCH 15/49] format and lint --- viscy/data/hcs_ram.py | 50 +++++++++++++++++++++---------------------- 1 file changed, 24 insertions(+), 26 deletions(-) diff --git a/viscy/data/hcs_ram.py b/viscy/data/hcs_ram.py index d8b1b698..e891beda 100644 --- a/viscy/data/hcs_ram.py +++ b/viscy/data/hcs_ram.py @@ -1,4 +1,6 @@ import logging +from datetime import datetime +from multiprocessing import Manager from typing import Callable, Literal, Sequence import numpy as np @@ -6,6 +8,7 @@ from iohub.ngff import Position, open_ome_zarr from lightning.pytorch import LightningDataModule from monai.data import set_track_meta +from monai.data.utils import collate_meta_tensor from monai.transforms import ( CenterSpatialCropd, Compose, @@ -14,12 +17,9 @@ ) from torch import Tensor from torch.utils.data import DataLoader, Dataset -from monai.data.utils import collate_meta_tensor from viscy.data.hcs import _read_norm_meta from viscy.data.typing import ChannelMap, DictTransform, Sample -from multiprocessing import Manager -from datetime import datetime _logger = logging.getLogger("lightning.pytorch") @@ -27,14 +27,15 @@ # Map the NumPy dtype to the corresponding PyTorch dtype numpy_to_torch_dtype = { - np.dtype('float32'): torch.float32, - 
np.dtype('float64'): torch.float64, - np.dtype('int32'): torch.int32, - np.dtype('int64'): torch.int64, - np.dtype('uint8'): torch.int8, - np.dtype('uint16'): torch.int16, + np.dtype("float32"): torch.float32, + np.dtype("float64"): torch.float64, + np.dtype("int32"): torch.int32, + np.dtype("int64"): torch.int64, + np.dtype("uint8"): torch.int8, + np.dtype("uint16"): torch.int16, } + def _stack_channels( sample_images: list[dict[str, Tensor]] | dict[str, Tensor], channels: ChannelMap, @@ -47,6 +48,7 @@ def _stack_channels( # sample_images is a list['Phase3D'].shape = (1,3,256,256) return [torch.stack([im[ch][0] for ch in channels[key]]) for im in sample_images] + def _collate_samples(batch: Sequence[Sample]) -> Sample: """Collate samples into a batch sample. @@ -65,7 +67,8 @@ def _collate_samples(batch: Sequence[Sample]) -> Sample: data.append(sample[key]) collated[key] = collate_meta_tensor(data) return collated - + + class CachedDataset(Dataset): """ A dataset that caches the data in RAM. @@ -98,12 +101,12 @@ def __init__( self.total_ch_names.extend(self.channels["target"]) self.total_ch_idx.extend(self.target_ch_idx) self._position_mapping() - + # Cached dictionary with tensors self.cache_dict = {} manager = Manager() self.cache_dict = manager.dict() - self._cached_pos=[] + self._cached_pos = [] def _position_mapping(self) -> None: self.position_keys = [] @@ -116,12 +119,11 @@ def _position_mapping(self) -> None: def _cache_dataset(self, index: int, channel_index: list[int], t: int = 0) -> None: # Add the position to the cached_dict # TODO: hardcoding to t=0 - data =self.positions[index].data.oindex[slice(t, t + 1), channel_index, :] + data = self.positions[index].data.oindex[slice(t, t + 1), channel_index, :] if data.dtype != np.float32: data = data.astype(np.float32) self.cache_dict[str(self.position_keys[index])] = torch.from_numpy(data) - def _get_weight_map(self, position: Position) -> Tensor: # Get the weight map from the position for the MONAI weightedcrop transform raise NotImplementedError @@ -130,10 +132,10 @@ def __len__(self) -> int: return len(self.positions) def __getitem__(self, index: int) -> Sample: - #FIXME replace this after debugging + # FIXME replace this after debugging ch_idx = self.total_ch_idx ch_names = self.total_ch_names - + # Check if the sample is in the cache else add it # Split the tensor into the channels sample_id = self.position_keys[index] @@ -144,7 +146,7 @@ def __getitem__(self, index: int) -> Sample: self._cache_dataset(index, channel_index=ch_idx) # Get the sample from the cache - _logger.info('Getting sample from cache') + _logger.info("Getting sample from cache") start_time = datetime.now() images = self.cache_dict[sample_id].unbind(dim=1) norm_meta = self.norm_meta_dict[str(sample_id)] @@ -234,10 +236,7 @@ def setup(self, stage: Literal["fit", "validate", "test", "predict"]) -> None: raise NotImplementedError(f"Stage {stage} is not supported") def _train_transform(self) -> list[Callable]: - """ Set the train augmentations - - - """ + """Set the train augmentations""" if self.augmentations: for aug in self.augmentations: @@ -252,9 +251,9 @@ def _train_transform(self) -> list[Callable]: ) self.train_patches_per_stack = num_samples else: - self.augmentations=[] - - _logger.debug(f'Training augmentations: {self.augmentations}') + self.augmentations = [] + + _logger.debug(f"Training augmentations: {self.augmentations}") return list(self.augmentations) def _fit_transform(self) -> tuple[Compose, Compose]: @@ -318,7 +317,7 @@ def 
train_dataloader(self) -> DataLoader: shuffle=True, timeout=self.timeout, collate_fn=_collate_samples, - drop_last=True + drop_last=True, ) def val_dataloader(self) -> DataLoader: @@ -330,5 +329,4 @@ def val_dataloader(self) -> DataLoader: pin_memory=True, shuffle=False, timeout=self.timeout, - ) From 8a8b4b017b42d0dc680e6c523be9138748985b50 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Tue, 22 Oct 2024 09:40:26 -0700 Subject: [PATCH 16/49] addding the custom distrb sampler to hcs_ram.py --- viscy/data/hcs_ram.py | 35 +++++++++++++++++++++++------------ 1 file changed, 23 insertions(+), 12 deletions(-) diff --git a/viscy/data/hcs_ram.py b/viscy/data/hcs_ram.py index e891beda..7e406179 100644 --- a/viscy/data/hcs_ram.py +++ b/viscy/data/hcs_ram.py @@ -1,6 +1,7 @@ import logging from datetime import datetime from multiprocessing import Manager +from multiprocessing.managers import DictProxy from typing import Callable, Literal, Sequence import numpy as np @@ -20,6 +21,9 @@ from viscy.data.hcs import _read_norm_meta from viscy.data.typing import ChannelMap, DictTransform, Sample +from viscy.data.distributed import ShardedDistributedSampler +from torch.distributed import get_rank +import torch.distributed as dist _logger = logging.getLogger("lightning.pytorch") @@ -68,7 +72,10 @@ def _collate_samples(batch: Sequence[Sample]) -> Sample: collated[key] = collate_meta_tensor(data) return collated - +def is_ddp_enabled() -> bool: + """Check if distributed data parallel (DDP) is initialized.""" + return dist.is_available() and dist.is_initialized() + class CachedDataset(Dataset): """ A dataset that caches the data in RAM. @@ -77,11 +84,17 @@ class CachedDataset(Dataset): def __init__( self, + shared_dict: DictProxy, positions: list[Position], channels: ChannelMap, transform: DictTransform | None = None, ): super().__init__() + if is_ddp_enabled(): + self.rank = dist.get_rank() + _logger.info(f"=== Initializing cache pool for rank {self.rank} ===") + + self.cache_dict = shared_dict self.positions = positions self.channels = channels self.transform = transform @@ -103,9 +116,7 @@ def __init__( self._position_mapping() # Cached dictionary with tensors - self.cache_dict = {} - manager = Manager() - self.cache_dict = manager.dict() + # TODO: Delete after testing self._cached_pos = [] def _position_mapping(self) -> None: @@ -132,18 +143,13 @@ def __len__(self) -> int: return len(self.positions) def __getitem__(self, index: int) -> Sample: - # FIXME replace this after debugging - ch_idx = self.total_ch_idx - ch_names = self.total_ch_names - # Check if the sample is in the cache else add it - # Split the tensor into the channels sample_id = self.position_keys[index] if sample_id not in self.cache_dict: _logger.info(f"Adding {sample_id} to cache") self._cached_pos.append(index) _logger.info(f"Cached positions: {self._cached_pos}") - self._cache_dataset(index, channel_index=ch_idx) + self._cache_dataset(index, channel_index=self.total_ch_idx) # Get the sample from the cache _logger.info("Getting sample from cache") @@ -151,7 +157,7 @@ def __getitem__(self, index: int) -> Sample: images = self.cache_dict[sample_id].unbind(dim=1) norm_meta = self.norm_meta_dict[str(sample_id)] after_cache = datetime.now() - start_time - sample_images = {k: v for k, v in zip(ch_names, images)} + sample_images = {k: v for k, v in zip(self.total_ch_names, images)} if self.target_ch_idx is not None: # FIXME: this uses the first target channel as weight for performance @@ -296,31 +302,36 @@ def _setup_fit(self, 
dataset_settings: dict) -> None: positions = list(positions[i] for i in shuffled_indices) num_train_fovs = int(len(positions) * self.split_ratio) + shared_dict = Manager().dict() self.train_dataset = CachedDataset( + shared_dict, positions[:num_train_fovs], transform=train_transform, **dataset_settings, ) self.val_dataset = CachedDataset( + shared_dict, positions[num_train_fovs:], transform=val_transform, **dataset_settings, ) def train_dataloader(self) -> DataLoader: + sampler = ShardedDistributedSampler(self.train_dataset, shuffle=True) return DataLoader( self.train_dataset, batch_size=self.batch_size // self.train_patches_per_stack, num_workers=self.num_workers, persistent_workers=bool(self.num_workers), pin_memory=True, - shuffle=True, + shuffle=False, timeout=self.timeout, collate_fn=_collate_samples, drop_last=True, ) def val_dataloader(self) -> DataLoader: + sampler = ShardedDistributedSampler(self.val_dataset, shuffle=False) return DataLoader( self.val_dataset, batch_size=self.batch_size, From 49764faa6f433d3a7bad9f70334b924a402244f9 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Tue, 22 Oct 2024 09:52:57 -0700 Subject: [PATCH 17/49] adding sampler to val train dataloader --- viscy/data/hcs_ram.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/viscy/data/hcs_ram.py b/viscy/data/hcs_ram.py index 7e406179..cedc3403 100644 --- a/viscy/data/hcs_ram.py +++ b/viscy/data/hcs_ram.py @@ -75,7 +75,7 @@ def _collate_samples(batch: Sequence[Sample]) -> Sample: def is_ddp_enabled() -> bool: """Check if distributed data parallel (DDP) is initialized.""" return dist.is_available() and dist.is_initialized() - + class CachedDataset(Dataset): """ A dataset that caches the data in RAM. @@ -328,6 +328,7 @@ def train_dataloader(self) -> DataLoader: timeout=self.timeout, collate_fn=_collate_samples, drop_last=True, + sampler=sampler, ) def val_dataloader(self) -> DataLoader: @@ -340,4 +341,5 @@ def val_dataloader(self) -> DataLoader: pin_memory=True, shuffle=False, timeout=self.timeout, + sampler=sampler ) From 1fe54913bf045cab07b1537e735b000784f2a4a8 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Tue, 22 Oct 2024 10:49:38 -0700 Subject: [PATCH 18/49] fix divisibility of the last shard --- viscy/data/distributed.py | 35 ++++++++++++++++++++--------------- 1 file changed, 20 insertions(+), 15 deletions(-) diff --git a/viscy/data/distributed.py b/viscy/data/distributed.py index bd3ab618..68e6d39e 100644 --- a/viscy/data/distributed.py +++ b/viscy/data/distributed.py @@ -1,35 +1,40 @@ """Utilities for DDP training.""" +from __future__ import annotations + import math +from typing import TYPE_CHECKING import torch +import torch.distributed from torch.utils.data.distributed import DistributedSampler +if TYPE_CHECKING: + from torch import Generator + class ShardedDistributedSampler(DistributedSampler): - def _sharded_randperm(self, generator): - """Generate a sharded random permutation of indices.""" - indices = torch.tensor(range(len(self.dataset))) - permuted = torch.stack( - [ - torch.randperm(self.num_samples, generator=generator) - + i * self.num_samples - for i in range(self.num_replicas) - ], - dim=1, - ).reshape(-1) - return indices[permuted].tolist() + def _sharded_randperm(self, max_size: int, generator: Generator) -> list[int]: + """Generate a sharded random permutation of indices. 
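+        Each rank later takes a strided slice of the flattened result,
+        so it only draws indices from its own contiguous shard.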
+ Overlap may occur in between the last two shards to maintain divisibility.""" + sharded_randperm = [ + torch.randperm(self.num_samples, generator=generator) + + min(i * self.num_samples, max_size - self.num_samples) + for i in range(self.num_replicas) + ] + indices = torch.stack(sharded_randperm, dim=1).reshape(-1) + return indices.tolist() def __iter__(self): """Modified __iter__ method to shard data across distributed ranks.""" + max_size = len(self.dataset) # type: ignore[arg-type] if self.shuffle: # deterministically shuffle based on epoch and seed g = torch.Generator() g.manual_seed(self.seed + self.epoch) - indices = self._sharded_randperm(g) + indices = self._sharded_randperm(max_size, g) else: - indices = list(range(len(self.dataset))) # type: ignore[arg-type] - + indices = list(range(max_size)) if not self.drop_last: # add extra samples to make it evenly divisible padding_size = self.total_size - len(indices) From 0b005cfb6407b018adc7104cf184667a108a2990 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Tue, 22 Oct 2024 11:04:19 -0700 Subject: [PATCH 19/49] hcs_ram format and lint --- viscy/data/hcs_ram.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/viscy/data/hcs_ram.py b/viscy/data/hcs_ram.py index cedc3403..aa24f28e 100644 --- a/viscy/data/hcs_ram.py +++ b/viscy/data/hcs_ram.py @@ -6,6 +6,7 @@ import numpy as np import torch +import torch.distributed as dist from iohub.ngff import Position, open_ome_zarr from lightning.pytorch import LightningDataModule from monai.data import set_track_meta @@ -19,11 +20,9 @@ from torch import Tensor from torch.utils.data import DataLoader, Dataset +from viscy.data.distributed import ShardedDistributedSampler from viscy.data.hcs import _read_norm_meta from viscy.data.typing import ChannelMap, DictTransform, Sample -from viscy.data.distributed import ShardedDistributedSampler -from torch.distributed import get_rank -import torch.distributed as dist _logger = logging.getLogger("lightning.pytorch") @@ -72,10 +71,12 @@ def _collate_samples(batch: Sequence[Sample]) -> Sample: collated[key] = collate_meta_tensor(data) return collated + def is_ddp_enabled() -> bool: """Check if distributed data parallel (DDP) is initialized.""" return dist.is_available() and dist.is_initialized() + class CachedDataset(Dataset): """ A dataset that caches the data in RAM. 
@@ -341,5 +342,5 @@ def val_dataloader(self) -> DataLoader: pin_memory=True, shuffle=False, timeout=self.timeout, - sampler=sampler + sampler=sampler, ) From 023ca880131ae54657c16d4fa2205286064b78ab Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Wed, 23 Oct 2024 10:00:53 -0700 Subject: [PATCH 20/49] data module that only crops and does not collate --- viscy/data/gpu_aug.py | 148 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 148 insertions(+) create mode 100644 viscy/data/gpu_aug.py diff --git a/viscy/data/gpu_aug.py b/viscy/data/gpu_aug.py new file mode 100644 index 00000000..573bb734 --- /dev/null +++ b/viscy/data/gpu_aug.py @@ -0,0 +1,148 @@ +from __future__ import annotations + +from logging import getLogger +from pathlib import Path +from typing import TYPE_CHECKING, Literal + +import numpy as np +import torch +from iohub.ngff import Plate, Position, open_ome_zarr +from lightning.pytorch import LightningDataModule +from monai.data.meta_obj import set_track_meta +from monai.transforms import Compose +from torch import Tensor +from torch.multiprocessing import Manager +from torch.utils.data import DataLoader, Dataset, Subset + +from viscy.data.distributed import ShardedDistributedSampler +from viscy.data.hcs import _ensure_channel_list, _read_norm_meta +from viscy.data.typing import DictTransform + +if TYPE_CHECKING: + from multiprocessing.managers import DictProxy + +_logger = getLogger("lightning.pytorch") + + +class CachedOmeZarrDataset(Dataset): + def __init__( + self, + positions: list[Position], + channel_names: list[str], + cache_map: DictProxy, + transform: DictTransform | None = None, + array_key: str = "0", + ): + key = 0 + self._metadata_map = {} + for position in positions: + img = position[array_key] + norm_meta = _read_norm_meta(position) + for time_idx in range(img.frames): + cache_map[key] = None + self._metadata_map[key] = (position, time_idx, norm_meta) + key += 1 + self.channels = {ch: position.get_channel_index(ch) for ch in channel_names} + self.array_key = array_key + self._cache_map = cache_map + self.transform = transform + + def __len__(self) -> int: + return len(self._cache_map) + + def __getitem__(self, idx: int) -> Tensor: + position, time_idx, norm_meta = self._metadata_map[idx] + cache = self._cache_map[idx] + if cache is None: + _logger.debug(f"Caching for index {idx}") + volume = torch.from_numpy( + position[self.array_key] + .oindex[time_idx, list(self.channels.values())] + .astype(np.float32) + ) + self._cache_map[idx] = volume + else: + _logger.debug(f"Using cached volume for index {idx}") + volume = cache + sample = {name: img[None] for name, img in zip(self.channels.keys(), volume)} + sample["norm_meta"] = norm_meta + if self.transform: + sample = self.transform(sample) + if not isinstance(sample, list): + sample = [sample] + out_tensors = [] + for s in sample: + s.pop("norm_meta") + s_out = torch.cat(list(s.values())) + out_tensors.append(s_out) + return out_tensors + + +class CachedOmeZarrDataModule(LightningDataModule): + def __init__( + self, + data_path: Path, + channels: str | list[str], + batch_size: int, + num_workers: int, + split_ratio: float, + transforms: list[DictTransform], + ): + super().__init__() + self.data_path = data_path + self.channels = _ensure_channel_list(channels) + self.batch_size = batch_size + self.num_workers = num_workers + self.split_ratio = split_ratio + self.transforms = Compose(transforms) + + def _set_fit_global_state(self, num_positions: int) -> list[int]: + # disable metadata tracking in MONAI for 
performance + set_track_meta(False) + # shuffle positions, randomness is handled globally + return torch.randperm(num_positions).tolist() + + def setup(self, stage: Literal["fit", "validate"]) -> None: + cache_map = Manager().dict() + plate: Plate = open_ome_zarr(self.data_path, mode="r", layout="hcs") + positions = [p for _, p in plate.positions()] + shuffled_indices = self._set_fit_global_state(len(positions)) + num_train_fovs = int(len(positions) * self.split_ratio) + dataset = CachedOmeZarrDataset( + positions, self.channels, cache_map, self.transforms + ) + self.train_dataset = Subset(dataset, shuffled_indices[:num_train_fovs]) + self.val_dataset = Subset(dataset, shuffled_indices[num_train_fovs:]) + + def _maybe_sampler( + self, dataset: Dataset, shuffle: bool + ) -> ShardedDistributedSampler | None: + return ( + ShardedDistributedSampler(dataset, shuffle=shuffle) + if torch.distributed.is_initialized() + else None + ) + + def train_dataloader(self) -> DataLoader: + sampler = self._maybe_sampler(self.train_dataset, shuffle=True) + return DataLoader( + self.train_dataset, + batch_size=self.batch_size, + shuffle=False if sampler else True, + sampler=sampler, + persistent_workers=True if self.num_workers > 0 else False, + num_workers=self.num_workers, + drop_last=True, + ) + + def val_dataloader(self) -> DataLoader: + sampler = self._maybe_sampler(self.val_dataset, shuffle=False) + return DataLoader( + self.val_dataset, + batch_size=self.batch_size, + shuffle=False, + sampler=sampler, + persistent_workers=True if self.num_workers > 0 else False, + num_workers=self.num_workers, + drop_last=False, + ) From f7ce0ba381d6d338d37dac3eddd8080440818dc1 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Wed, 23 Oct 2024 10:01:28 -0700 Subject: [PATCH 21/49] wip: execute transforms on the GPU --- viscy/data/combined.py | 76 ++++++++++++++++++++++++++++++++++++- viscy/translation/engine.py | 73 +++++++++++++++++++---------------- 2 files changed, 114 insertions(+), 35 deletions(-) diff --git a/viscy/data/combined.py b/viscy/data/combined.py index 31ea9f6c..9585345c 100644 --- a/viscy/data/combined.py +++ b/viscy/data/combined.py @@ -1,11 +1,13 @@ from enum import Enum from typing import Literal, Sequence +import torch from lightning.pytorch import LightningDataModule from lightning.pytorch.utilities.combined_loader import CombinedLoader -from torch.utils.data import ConcatDataset, DataLoader +from torch.utils.data import ConcatDataset, DataLoader, Dataset -from viscy.data.hcs import _collate_samples +from viscy.data.distributed import ShardedDistributedSampler +from viscy.data.hcs_ram import _collate_samples class CombineMode(Enum): @@ -133,3 +135,73 @@ def val_dataloader(self): shuffle=False, persistent_workers=bool(self.num_workers), ) + + +class CachedConcatDataModule(LightningDataModule): + def __init__(self, data_modules: Sequence[LightningDataModule]): + super().__init__() + self.data_modules = data_modules + self.num_workers = data_modules[0].num_workers + self.batch_size = data_modules[0].batch_size + for dm in data_modules: + if dm.num_workers != self.num_workers: + raise ValueError("Inconsistent number of workers") + if dm.batch_size != self.batch_size: + raise ValueError("Inconsistent batch size") + self.prepare_data_per_node = True + + def prepare_data(self): + for dm in self.data_modules: + dm.trainer = self.trainer + dm.prepare_data() + + def setup(self, stage: Literal["fit", "validate", "test", "predict"]): + self.train_patches_per_stack = 0 + for dm in self.data_modules: + 
dm.setup(stage) + if patches := getattr(dm, "train_patches_per_stack", 1): + if self.train_patches_per_stack == 0: + self.train_patches_per_stack = patches + elif self.train_patches_per_stack != patches: + raise ValueError("Inconsistent patches per stack") + if stage != "fit": + raise NotImplementedError("Only fit stage is supported") + self.train_dataset = ConcatDataset( + [dm.train_dataset for dm in self.data_modules] + ) + self.val_dataset = ConcatDataset([dm.val_dataset for dm in self.data_modules]) + + def _maybe_sampler( + self, dataset: Dataset, shuffle: bool + ) -> ShardedDistributedSampler | None: + return ( + ShardedDistributedSampler(dataset, shuffle=shuffle) + if torch.distributed.is_initialized() + else None + ) + + def train_dataloader(self) -> DataLoader: + sampler = self._maybe_sampler(self.train_dataset, shuffle=True) + return DataLoader( + self.train_dataset, + batch_size=self.batch_size, + shuffle=False if sampler else True, + sampler=sampler, + persistent_workers=True if self.num_workers > 0 else False, + num_workers=self.num_workers, + drop_last=True, + collate_fn=lambda x: x, + ) + + def val_dataloader(self) -> DataLoader: + sampler = self._maybe_sampler(self.val_dataset, shuffle=False) + return DataLoader( + self.val_dataset, + batch_size=self.batch_size, + shuffle=False, + sampler=sampler, + persistent_workers=True if self.num_workers > 0 else False, + num_workers=self.num_workers, + drop_last=False, + collate_fn=lambda x: x, + ) diff --git a/viscy/translation/engine.py b/viscy/translation/engine.py index aa7ac24b..4e8af365 100644 --- a/viscy/translation/engine.py +++ b/viscy/translation/engine.py @@ -8,7 +8,7 @@ from imageio import imwrite from lightning.pytorch import LightningModule from monai.optimizers import WarmupCosineSchedule -from monai.transforms import DivisiblePad, Rotate90 +from monai.transforms import Compose, DivisiblePad, Rotate90 from torch import Tensor, nn from torch.optim.lr_scheduler import ConstantLR from torchmetrics.functional import ( @@ -466,63 +466,70 @@ def _crop_to_original(self, tensor: Tensor) -> Tensor: class FcmaeUNet(VSUNet): - def __init__(self, fit_mask_ratio: float = 0.0, **kwargs): + def __init__( + self, + fit_mask_ratio: float = 0.0, + train_transforms=[], + validation_transforms=[], + **kwargs, + ): super().__init__(architecture="fcmae", **kwargs) self.fit_mask_ratio = fit_mask_ratio + self.train_transforms = Compose(train_transforms) + self.validation_transforms = Compose(validation_transforms) def forward(self, x: Tensor, mask_ratio: float = 0.0): return self.model(x, mask_ratio) - def forward_fit(self, batch: Sample) -> tuple[Tensor]: - source = batch["source"] - target = batch["target"] - pred, mask = self.forward(source, mask_ratio=self.fit_mask_ratio) - loss = F.mse_loss(pred, target, reduction="none") + def forward_fit(self, batch: Tensor) -> tuple[Tensor]: + pred, mask = self.forward(batch, mask_ratio=self.fit_mask_ratio) + loss = F.mse_loss(pred, batch, reduction="none") loss = (loss.mean(2) * mask).sum() / mask.sum() - return source, target, pred, mask, loss - - def training_step(self, batch: Sequence[Sample], batch_idx: int): - losses = [] - batch_size = 0 - for b in batch: - source, target, pred, mask, loss = self.forward_fit(b) - losses.append(loss) - batch_size += source.shape[0] - if batch_idx < self.log_batches_per_epoch: - self.training_step_outputs.extend( - detach_sample( - (source, target * mask.unsqueeze(2), pred), - self.log_samples_per_batch, - ) + return pred, mask, loss + + def 
transform_and_collate(self, batch: list[Tensor], transforms: Compose) -> Tensor: + transformed = [] + for sample in batch: + for element in sample: + transformed.append(transforms(element)) + return torch.stack(transformed) + + def training_step(self, batch: list[Tensor], batch_idx: int): + batch = self.transform_and_collate(batch, self.train_transforms) + pred, mask, loss = self.forward_fit(batch) + if batch_idx < self.log_batches_per_epoch: + self.training_step_outputs.extend( + detach_sample( + (batch, batch * mask.unsqueeze(2), pred), self.log_samples_per_batch ) - loss_step = torch.stack(losses).mean() + ) self.log( "loss/train", - loss_step.to(self.device), + loss.to(self.device), on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True, - batch_size=batch_size, + batch_size=batch.shape[0], ) - return loss_step + return loss - def validation_step(self, batch: Sample, batch_idx: int, dataloader_idx: int = 0): - source, target, pred, mask, loss = self.forward_fit(batch) + def validation_step( + self, batch: list[Tensor], batch_idx: int, dataloader_idx: int = 0 + ): + batch = self.transform_and_collate(batch, self.validation_transforms) + pred, mask, loss = self.forward_fit(batch) if dataloader_idx + 1 > len(self.validation_losses): self.validation_losses.append([]) self.validation_losses[dataloader_idx].append(loss.detach()) self.log( - f"loss/val/{dataloader_idx}", - loss.to(self.device), - sync_dist=True, - batch_size=source.shape[0], + "loss/val", loss.to(self.device), sync_dist=True, batch_size=batch.shape[0] ) if batch_idx < self.log_batches_per_epoch: self.validation_step_outputs.extend( detach_sample( - (source, target * mask.unsqueeze(2), pred), + (batch, batch * mask.unsqueeze(2), pred), self.log_samples_per_batch, ) ) From daa686028575868d7f87f34955f7451d9fcfac19 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Thu, 24 Oct 2024 10:49:49 -0700 Subject: [PATCH 22/49] path for if not ddp --- viscy/data/hcs_ram.py | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/viscy/data/hcs_ram.py b/viscy/data/hcs_ram.py index aa24f28e..a9ff25d3 100644 --- a/viscy/data/hcs_ram.py +++ b/viscy/data/hcs_ram.py @@ -6,7 +6,6 @@ import numpy as np import torch -import torch.distributed as dist from iohub.ngff import Position, open_ome_zarr from lightning.pytorch import LightningDataModule from monai.data import set_track_meta @@ -20,9 +19,11 @@ from torch import Tensor from torch.utils.data import DataLoader, Dataset -from viscy.data.distributed import ShardedDistributedSampler from viscy.data.hcs import _read_norm_meta from viscy.data.typing import ChannelMap, DictTransform, Sample +from viscy.data.distributed import ShardedDistributedSampler +from torch.distributed import get_rank +import torch.distributed as dist _logger = logging.getLogger("lightning.pytorch") @@ -71,12 +72,10 @@ def _collate_samples(batch: Sequence[Sample]) -> Sample: collated[key] = collate_meta_tensor(data) return collated - def is_ddp_enabled() -> bool: """Check if distributed data parallel (DDP) is initialized.""" return dist.is_available() and dist.is_initialized() - class CachedDataset(Dataset): """ A dataset that caches the data in RAM. 
@@ -318,7 +317,11 @@ def _setup_fit(self, dataset_settings: dict) -> None: ) def train_dataloader(self) -> DataLoader: - sampler = ShardedDistributedSampler(self.train_dataset, shuffle=True) + if is_ddp_enabled(): + sampler = ShardedDistributedSampler(self.train_dataset, shuffle=True) + else: + sampler = None + _logger.info("Using standard sampler for non-distributed training") return DataLoader( self.train_dataset, batch_size=self.batch_size // self.train_patches_per_stack, @@ -333,7 +336,12 @@ def train_dataloader(self) -> DataLoader: ) def val_dataloader(self) -> DataLoader: - sampler = ShardedDistributedSampler(self.val_dataset, shuffle=False) + if is_ddp_enabled(): + sampler = ShardedDistributedSampler(self.val_dataset, shuffle=False) + else: + sampler = None + _logger.info("Using standard sampler for non-distributed validation") + return DataLoader( self.val_dataset, batch_size=self.batch_size, @@ -342,5 +350,5 @@ def val_dataloader(self) -> DataLoader: pin_memory=True, shuffle=False, timeout=self.timeout, - sampler=sampler, + sampler=sampler ) From 55499deec2babd58844df80a35334fd6d26542f3 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Mon, 28 Oct 2024 20:01:03 -0700 Subject: [PATCH 23/49] fix randomness in inversion transform --- viscy/transforms.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/viscy/transforms.py b/viscy/transforms.py index 5eca0538..eb894393 100644 --- a/viscy/transforms.py +++ b/viscy/transforms.py @@ -206,6 +206,8 @@ def __init__( def __call__(self, sample: Sample) -> Sample: self.randomize(None) + if not self._do_transform: + return sample for key in self.keys: if key in sample: sample[key] = -sample[key] From 42806775cefca873107b9b39d59f4fe7f4f3ab90 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Mon, 28 Oct 2024 20:02:12 -0700 Subject: [PATCH 24/49] add option to pop the normalization metadata --- viscy/transforms.py | 23 ++++++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/viscy/transforms.py b/viscy/transforms.py index eb894393..3a2ee0e7 100644 --- a/viscy/transforms.py +++ b/viscy/transforms.py @@ -158,11 +158,20 @@ def __init__( class NormalizeSampled(MapTransform): """ - Normalize the sample - :param Union[str, Iterable[str]] keys: keys to normalize - :param str fov: fov path with respect to Plate - :param str subtrahend: subtrahend for normalization, defaults to "mean" - :param str divisor: divisor for normalization, defaults to "std" + Normalize the sample. + + Parameters + ---------- + keys : Union[str, Iterable[str]] + Keys to normalize. + level : {'fov_statistics', 'dataset_statistics'} + Level of normalization. + subtrahend : str, optional + Subtrahend for normalization, defaults to "mean". + divisor : str, optional + Divisor for normalization, defaults to "std". + remove_meta : bool, optional + Whether to remove metadata after normalization, defaults to False. 
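A minimal usage sketch of the option added above (not part of the patch; the "Phase3D" key is a placeholder channel name, and the sample is assumed to carry per-channel "norm_meta" statistics):

    from viscy.transforms import NormalizeSampled

    norm = NormalizeSampled(
        keys=["Phase3D"],        # placeholder channel key
        level="fov_statistics",  # normalize with per-FOV statistics
        subtrahend="mean",
        divisor="std",
        remove_meta=True,        # pops sample["norm_meta"] after normalizing
    )
    # normalized = norm(sample)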
""" def __init__( @@ -171,11 +180,13 @@ def __init__( level: Literal["fov_statistics", "dataset_statistics"], subtrahend="mean", divisor="std", + remove_meta: bool = False, ) -> None: super().__init__(keys, allow_missing_keys=False) self.subtrahend = subtrahend self.divisor = divisor self.level = level + self.remove_meta = remove_meta # TODO: need to implement the case where the preprocessing already exists def __call__(self, sample: Sample) -> Sample: @@ -184,6 +195,8 @@ def __call__(self, sample: Sample) -> Sample: subtrahend_val = level_meta[self.subtrahend] divisor_val = level_meta[self.divisor] + 1e-8 # avoid div by zero sample[key] = (sample[key] - subtrahend_val) / divisor_val + if self.remove_meta: + sample.pop("norm_meta") return sample def _normalize(): From 1561802ad1dcb58451093e7efe74d8769f9452f6 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Wed, 30 Oct 2024 14:25:12 -0700 Subject: [PATCH 25/49] move gpu transform definition back to data module --- viscy/data/gpu_aug.py | 153 ++++++++++++++++++++++++++++-------------- 1 file changed, 102 insertions(+), 51 deletions(-) diff --git a/viscy/data/gpu_aug.py b/viscy/data/gpu_aug.py index 573bb734..76f9a469 100644 --- a/viscy/data/gpu_aug.py +++ b/viscy/data/gpu_aug.py @@ -1,5 +1,6 @@ from __future__ import annotations +from abc import ABC, abstractmethod from logging import getLogger from pathlib import Path from typing import TYPE_CHECKING, Literal @@ -9,20 +10,77 @@ from iohub.ngff import Plate, Position, open_ome_zarr from lightning.pytorch import LightningDataModule from monai.data.meta_obj import set_track_meta -from monai.transforms import Compose +from monai.transforms.compose import Compose from torch import Tensor from torch.multiprocessing import Manager -from torch.utils.data import DataLoader, Dataset, Subset +from torch.utils.data import DataLoader, Dataset from viscy.data.distributed import ShardedDistributedSampler from viscy.data.hcs import _ensure_channel_list, _read_norm_meta -from viscy.data.typing import DictTransform +from viscy.data.typing import DictTransform, NormMeta if TYPE_CHECKING: from multiprocessing.managers import DictProxy _logger = getLogger("lightning.pytorch") +_CacheMetadata = tuple[Position, int, NormMeta | None] + + +class GPUTransformDataModule(ABC, LightningDataModule): + def _maybe_sampler( + self, dataset: Dataset, shuffle: bool + ) -> ShardedDistributedSampler | None: + return ( + ShardedDistributedSampler(dataset, shuffle=shuffle) + if torch.distributed.is_initialized() + else None + ) + + def train_dataloader(self) -> DataLoader: + sampler = self._maybe_sampler(self.train_dataset, shuffle=True) + _logger.debug(f"Using training sampler {sampler}") + return DataLoader( + self.train_dataset, + batch_size=self.batch_size, + shuffle=False if sampler else True, + sampler=sampler, + persistent_workers=True if self.num_workers > 0 else False, + num_workers=self.num_workers, + pin_memory=self.pin_memory, + drop_last=True, + ) + + def val_dataloader(self) -> DataLoader: + sampler = self._maybe_sampler(self.val_dataset, shuffle=False) + _logger.debug(f"Using validation sampler {sampler}") + return DataLoader( + self.val_dataset, + batch_size=self.batch_size, + shuffle=False, + sampler=sampler, + persistent_workers=True if self.num_workers > 0 else False, + num_workers=self.num_workers, + pin_memory=self.pin_memory, + drop_last=False, + ) + + @property + @abstractmethod + def train_cpu_transforms(self) -> Compose: ... + + @property + @abstractmethod + def train_gpu_transforms(self) -> Compose: ... 
+ + @property + @abstractmethod + def val_cpu_transforms(self) -> Compose: ... + + @property + @abstractmethod + def val_gpu_transforms(self) -> Compose: ... + class CachedOmeZarrDataset(Dataset): def __init__( @@ -30,11 +88,11 @@ def __init__( positions: list[Position], channel_names: list[str], cache_map: DictProxy, - transform: DictTransform | None = None, + transform: Compose | None = None, array_key: str = "0", ): key = 0 - self._metadata_map = {} + self._metadata_map: dict[int, _CacheMetadata] = {} for position in positions: img = position[array_key] norm_meta = _read_norm_meta(position) @@ -50,7 +108,7 @@ def __init__( def __len__(self) -> int: return len(self._cache_map) - def __getitem__(self, idx: int) -> Tensor: + def __getitem__(self, idx: int) -> dict[str, Tensor]: position, time_idx, norm_meta = self._metadata_map[idx] cache = self._cache_map[idx] if cache is None: @@ -70,15 +128,10 @@ def __getitem__(self, idx: int) -> Tensor: sample = self.transform(sample) if not isinstance(sample, list): sample = [sample] - out_tensors = [] - for s in sample: - s.pop("norm_meta") - s_out = torch.cat(list(s.values())) - out_tensors.append(s_out) - return out_tensors + return sample -class CachedOmeZarrDataModule(LightningDataModule): +class CachedOmeZarrDataModule(GPUTransformDataModule): def __init__( self, data_path: Path, @@ -86,7 +139,11 @@ def __init__( batch_size: int, num_workers: int, split_ratio: float, - transforms: list[DictTransform], + train_cpu_transforms: list[DictTransform], + val_cpu_transforms: list[DictTransform], + train_gpu_transforms: list[DictTransform], + val_gpu_transforms: list[DictTransform], + pin_memory: bool = True, ): super().__init__() self.data_path = data_path @@ -94,7 +151,27 @@ def __init__( self.batch_size = batch_size self.num_workers = num_workers self.split_ratio = split_ratio - self.transforms = Compose(transforms) + self._train_cpu_transforms = Compose(train_cpu_transforms) + self._val_cpu_transforms = Compose(val_cpu_transforms) + self._train_gpu_transforms = Compose(train_gpu_transforms) + self._val_gpu_transforms = Compose(val_gpu_transforms) + self.pin_memory = pin_memory + + @property + def train_cpu_transforms(self) -> Compose: + return self._train_cpu_transforms + + @property + def train_gpu_transforms(self) -> Compose: + return self._train_gpu_transforms + + @property + def val_cpu_transforms(self) -> Compose: + return self._val_cpu_transforms + + @property + def val_gpu_transforms(self) -> Compose: + return self._val_gpu_transforms def _set_fit_global_state(self, num_positions: int) -> list[int]: # disable metadata tracking in MONAI for performance @@ -103,46 +180,20 @@ def _set_fit_global_state(self, num_positions: int) -> list[int]: return torch.randperm(num_positions).tolist() def setup(self, stage: Literal["fit", "validate"]) -> None: + if stage not in ("fit", "validate"): + raise NotImplementedError("Only fit and validate stages are supported.") cache_map = Manager().dict() plate: Plate = open_ome_zarr(self.data_path, mode="r", layout="hcs") positions = [p for _, p in plate.positions()] shuffled_indices = self._set_fit_global_state(len(positions)) num_train_fovs = int(len(positions) * self.split_ratio) - dataset = CachedOmeZarrDataset( - positions, self.channels, cache_map, self.transforms + train_fovs = [positions[i] for i in shuffled_indices[:num_train_fovs]] + val_fovs = [positions[i] for i in shuffled_indices[num_train_fovs:]] + _logger.debug(f"Training FOVs: {[p.zgroup.name for p in train_fovs]}") + _logger.debug(f"Validation 
FOVs: {[p.zgroup.name for p in val_fovs]}") + self.train_dataset = CachedOmeZarrDataset( + train_fovs, self.channels, cache_map, transform=self.train_cpu_transforms ) - self.train_dataset = Subset(dataset, shuffled_indices[:num_train_fovs]) - self.val_dataset = Subset(dataset, shuffled_indices[num_train_fovs:]) - - def _maybe_sampler( - self, dataset: Dataset, shuffle: bool - ) -> ShardedDistributedSampler | None: - return ( - ShardedDistributedSampler(dataset, shuffle=shuffle) - if torch.distributed.is_initialized() - else None - ) - - def train_dataloader(self) -> DataLoader: - sampler = self._maybe_sampler(self.train_dataset, shuffle=True) - return DataLoader( - self.train_dataset, - batch_size=self.batch_size, - shuffle=False if sampler else True, - sampler=sampler, - persistent_workers=True if self.num_workers > 0 else False, - num_workers=self.num_workers, - drop_last=True, - ) - - def val_dataloader(self) -> DataLoader: - sampler = self._maybe_sampler(self.val_dataset, shuffle=False) - return DataLoader( - self.val_dataset, - batch_size=self.batch_size, - shuffle=False, - sampler=sampler, - persistent_workers=True if self.num_workers > 0 else False, - num_workers=self.num_workers, - drop_last=False, + self.val_dataset = CachedOmeZarrDataset( + val_fovs, self.channels, cache_map, transform=self.val_cpu_transforms ) From 2e37217285c5329be1f934c9e6413cd7dc1abfa4 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Wed, 30 Oct 2024 16:44:54 -0700 Subject: [PATCH 26/49] add tiled crop transform for validation --- viscy/transforms.py | 56 ++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 55 insertions(+), 1 deletion(-) diff --git a/viscy/transforms.py b/viscy/transforms.py index 3a2ee0e7..cddd693c 100644 --- a/viscy/transforms.py +++ b/viscy/transforms.py @@ -2,8 +2,11 @@ from typing import Sequence, Union +import numpy as np +import torch from monai.transforms import ( MapTransform, + MultiSampleTrait, RandAdjustContrastd, RandAffined, RandGaussianNoised, @@ -15,9 +18,10 @@ ) from monai.transforms.transform import Randomizable from numpy.random.mtrand import RandomState as RandomState +from torch import Tensor from typing_extensions import Iterable, Literal -from viscy.data.typing import Sample +from viscy.data.typing import ChannelMap, Sample class RandWeightedCropd(RandWeightedCropd): @@ -231,3 +235,53 @@ def set_random_state( ) -> Randomizable: super().set_random_state(seed, state) return self + + +class TiledSpatialCropSamplesd(MapTransform, MultiSampleTrait): + """ + Crop multiple tiled ROIs from an image. + Used for deterministic cropping in validation. + """ + + def __init__( + self, + keys: Union[str, Iterable[str]], + roi_size: tuple[int, int, int], + num_samples: int, + ) -> None: + super().__init__(keys, allow_missing_keys=False) + self.roi_size = roi_size + self.num_samples = num_samples + + def _check_num_samples(self, spatial_size: np.ndarray, offset: int) -> np.ndarray: + max_grid_shape = spatial_size // self.roi_size + max_num_samples = max_grid_shape.prod() + if offset >= max_num_samples: + raise ValueError( + f"Number of samples {self.num_samples} should be " + f"smaller than {max_num_samples}." 
+ ) + grid_idx = np.asarray(np.unravel_index(offset, max_grid_shape)) + return grid_idx * self.roi_size + + def _crop(self, img: Tensor, offset: int) -> Tensor: + spatial_size = np.array(img.shape[-3:]) + crop_start = self._check_num_samples(spatial_size, offset) + crop_end = crop_start + np.array(self.roi_size) + return img[ + ..., + crop_start[0] : crop_end[0], + crop_start[1] : crop_end[1], + crop_start[2] : crop_end[2], + ] + + def __call__(self, sample: Sample) -> Sample: + results = [] + for i in range(self.num_samples): + result = {} + for key in self.keys: + result[key] = self._crop(sample[key], i) + if "norm_meta" in sample: + result["norm_meta"] = sample["norm_meta"] + results.append(result) + return results From 7edf36e3f712f6c6328b0f36349826b795257c7a Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Wed, 30 Oct 2024 16:45:31 -0700 Subject: [PATCH 27/49] add stack channel transform for gpu augmentation --- viscy/transforms.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/viscy/transforms.py b/viscy/transforms.py index cddd693c..8c0b3ec3 100644 --- a/viscy/transforms.py +++ b/viscy/transforms.py @@ -285,3 +285,20 @@ def __call__(self, sample: Sample) -> Sample: result["norm_meta"] = sample["norm_meta"] results.append(result) return results + + +class StackChannelsd(MapTransform): + """Stack source and target channels.""" + + def __init__(self, channel_map: ChannelMap) -> None: + channel_names = [] + for channels in channel_map.values(): + channel_names.extend(channels) + super().__init__(channel_names, allow_missing_keys=False) + self.channel_map = channel_map + + def __call__(self, sample: Sample) -> Sample: + results = {} + for key, channels in self.channel_map.items(): + results[key] = torch.cat([sample[ch] for ch in channels], dim=0) + return results From eda5d1ba2bee44be0e41f85f3381e574e2fbb210 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Wed, 30 Oct 2024 16:46:02 -0700 Subject: [PATCH 28/49] fix typing --- viscy/data/typing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/viscy/data/typing.py b/viscy/data/typing.py index fb7b6b73..f5e66caa 100644 --- a/viscy/data/typing.py +++ b/viscy/data/typing.py @@ -50,7 +50,7 @@ class Sample(TypedDict, total=False): # Instance segmentation masks labels: OneOrSeq[Tensor] # None: not available - norm_meta: NormMeta + norm_meta: NormMeta | None class ChannelMap(TypedDict): From 550101d627412b6424c92f3b00344040a24c9272 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Wed, 30 Oct 2024 16:46:33 -0700 Subject: [PATCH 29/49] collate before sending to gpu --- viscy/data/gpu_aug.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/viscy/data/gpu_aug.py b/viscy/data/gpu_aug.py index 76f9a469..75c91d23 100644 --- a/viscy/data/gpu_aug.py +++ b/viscy/data/gpu_aug.py @@ -10,6 +10,7 @@ from iohub.ngff import Plate, Position, open_ome_zarr from lightning.pytorch import LightningDataModule from monai.data.meta_obj import set_track_meta +from monai.data.utils import list_data_collate from monai.transforms.compose import Compose from torch import Tensor from torch.multiprocessing import Manager @@ -49,6 +50,7 @@ def train_dataloader(self) -> DataLoader: num_workers=self.num_workers, pin_memory=self.pin_memory, drop_last=True, + collate_fn=list_data_collate, ) def val_dataloader(self) -> DataLoader: @@ -63,6 +65,7 @@ def val_dataloader(self) -> DataLoader: num_workers=self.num_workers, pin_memory=self.pin_memory, drop_last=False, + collate_fn=list_data_collate, ) @property From 
92e3722d7b2cf9f18aa5c220310828630afaa369 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Wed, 30 Oct 2024 16:47:52 -0700 Subject: [PATCH 30/49] inherit gpu transforms for livecell dataset --- viscy/data/livecell.py | 149 +++++++++++++++++++++++++++-------------- 1 file changed, 99 insertions(+), 50 deletions(-) diff --git a/viscy/data/livecell.py b/viscy/data/livecell.py index bb8bb56c..e8da1eb4 100644 --- a/viscy/data/livecell.py +++ b/viscy/data/livecell.py @@ -1,45 +1,81 @@ +from __future__ import annotations + import json from pathlib import Path +from typing import TYPE_CHECKING import torch -from lightning.pytorch import LightningDataModule -from monai.transforms import Compose, MapTransform +from monai.transforms import Compose, MapTransform, Transform from pycocotools.coco import COCO from tifffile import imread from torch.utils.data import DataLoader, Dataset from torchvision.ops import box_convert +from viscy.data.gpu_aug import GPUTransformDataModule from viscy.data.typing import Sample +if TYPE_CHECKING: + from multiprocessing.managers import DictProxy + class LiveCellDataset(Dataset): """ LiveCell dataset. - :param list[Path] images: List of paths to single-page, single-channel TIFF files. - :param MapTransform | Compose transform: Transform to apply to the dataset + Parameters + ---------- + images : list of Path + List of paths to single-page, single-channel TIFF files. + transform : Transform or Compose + Transform to apply to the dataset. + cache_map : DictProxy + Shared dictionary for caching images. """ - def __init__(self, images: list[Path], transform: MapTransform | Compose) -> None: + def __init__( + self, + images: list[Path], + transform: Transform | Compose, + cache_map: DictProxy, + ) -> None: self.images = images self.transform = transform + self._cache_map = cache_map def __len__(self) -> int: return len(self.images) def __getitem__(self, idx: int) -> Sample: - image = imread(self.images[idx])[None, None] - image = torch.from_numpy(image).to(torch.float32) - image = self.transform(image) - return {"source": image, "target": image} + name = self.images[idx] + if name not in self._cache_map: + image = imread(name)[None, None] + image = torch.from_numpy(image).to(torch.float32) + self._cache_map[name] = image + else: + image = self._cache_map[name] + sample = Sample(source=image) + sample = self.transform(sample) + if not isinstance(sample, list): + sample = [sample] + return sample class LiveCellTestDataset(Dataset): """ LiveCell dataset. - :param list[Path] images: List of paths to single-page, single-channel TIFF files. - :param MapTransform | Compose transform: Transform to apply to the dataset + Parameters + ---------- + image_dir : Path + Directory containing the images. + transform : MapTransform | Compose + Transform to apply to the dataset. + annotations : Path + Path to the COCO annotations file. + load_target : bool, optional + Whether to load the target images (default is False). + load_labels : bool, optional + Whether to load the labels (default is False). 
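A hedged construction sketch for this test dataset (the paths below are placeholders; in practice the data module performs this wiring in `_setup_test`):

    from pathlib import Path

    from monai.transforms import Compose

    from viscy.data.livecell import LiveCellTestDataset

    test_dataset = LiveCellTestDataset(
        image_dir=Path("livecell/test_images"),              # placeholder directory
        transform=Compose([]),                               # identity transform for illustration
        annotations=Path("livecell/test_annotations.json"),  # placeholder COCO file
        load_labels=True,                                    # also load instance masks
    )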
""" def __init__( @@ -87,7 +123,7 @@ def __getitem__(self, idx: int) -> Sample: return sample -class LiveCellDataModule(LightningDataModule): +class LiveCellDataModule(GPUTransformDataModule): def __init__( self, train_val_images: Path | None = None, @@ -95,33 +131,60 @@ def __init__( train_annotations: Path | None = None, val_annotations: Path | None = None, test_annotations: Path | None = None, - train_transforms: list[MapTransform] = [], - val_transforms: list[MapTransform] = [], + train_cpu_transforms: list[MapTransform] = [], + val_cpu_transforms: list[MapTransform] = [], + train_gpu_transforms: list[MapTransform] = [], + val_gpu_transforms: list[MapTransform] = [], test_transforms: list[MapTransform] = [], batch_size: int = 16, num_workers: int = 8, + pin_memory: bool = True, ) -> None: super().__init__() - self.train_val_images = Path(train_val_images) - if not self.train_val_images.is_dir(): - raise NotADirectoryError(str(train_val_images)) - self.test_images = Path(test_images) - if not self.test_images.is_dir(): - raise NotADirectoryError(str(test_images)) - self.train_annotations = Path(train_annotations) - if not self.train_annotations.is_file(): - raise FileNotFoundError(str(train_annotations)) - self.val_annotations = Path(val_annotations) - if not self.val_annotations.is_file(): - raise FileNotFoundError(str(val_annotations)) - self.test_annotations = Path(test_annotations) - if not self.test_annotations.is_file(): - raise FileNotFoundError(str(test_annotations)) - self.train_transforms = Compose(train_transforms) - self.val_transforms = Compose(val_transforms) + if train_val_images is not None: + self.train_val_images = Path(train_val_images) + if not self.train_val_images.is_dir(): + raise NotADirectoryError(str(train_val_images)) + if test_images is not None: + self.test_images = Path(test_images) + if not self.test_images.is_dir(): + raise NotADirectoryError(str(test_images)) + if train_annotations is not None: + self.train_annotations = Path(train_annotations) + if not self.train_annotations.is_file(): + raise FileNotFoundError(str(train_annotations)) + if val_annotations is not None: + self.val_annotations = Path(val_annotations) + if not self.val_annotations.is_file(): + raise FileNotFoundError(str(val_annotations)) + if test_annotations is not None: + self.test_annotations = Path(test_annotations) + if not self.test_annotations.is_file(): + raise FileNotFoundError(str(test_annotations)) + self._train_cpu_transforms = Compose(train_cpu_transforms) + self._val_cpu_transforms = Compose(val_cpu_transforms) + self._train_gpu_transforms = Compose(train_gpu_transforms) + self._val_gpu_transforms = Compose(val_gpu_transforms) self.test_transforms = Compose(test_transforms) self.batch_size = batch_size self.num_workers = num_workers + self.pin_memory = pin_memory + + @property + def train_cpu_transforms(self) -> Compose: + return self._train_cpu_transforms + + @property + def val_cpu_transforms(self) -> Compose: + return self._val_cpu_transforms + + @property + def train_gpu_transforms(self) -> Compose: + return self._train_gpu_transforms + + @property + def val_gpu_transforms(self) -> Compose: + return self._val_gpu_transforms def setup(self, stage: str) -> None: if stage == "fit": @@ -135,15 +198,18 @@ def _parse_image_names(self, annotations: Path) -> list[Path]: return sorted(images) def _setup_fit(self) -> None: + cache_map = torch.multiprocessing.Manager().dict() train_images = self._parse_image_names(self.train_annotations) val_images = 
self._parse_image_names(self.val_annotations) self.train_dataset = LiveCellDataset( [self.train_val_images / f for f in train_images], - transform=self.train_transforms, + transform=self.train_cpu_transforms, + cache_map=cache_map, ) self.val_dataset = LiveCellDataset( [self.train_val_images / f for f in val_images], - transform=self.val_transforms, + transform=self.val_cpu_transforms, + cache_map=cache_map, ) def _setup_test(self) -> None: @@ -154,23 +220,6 @@ def _setup_test(self) -> None: load_labels=True, ) - def train_dataloader(self) -> DataLoader: - return DataLoader( - self.train_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - persistent_workers=bool(self.num_workers), - shuffle=True, - ) - - def val_dataloader(self) -> DataLoader: - return DataLoader( - self.val_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - persistent_workers=bool(self.num_workers), - ) - def test_dataloader(self) -> DataLoader: return DataLoader( self.test_dataset, batch_size=self.batch_size, num_workers=self.num_workers From c185377960b2a5b4ad0afc5d745ecf74efebdc7d Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Wed, 30 Oct 2024 16:48:37 -0700 Subject: [PATCH 31/49] update fcmae engine to apply per-dataset augmentations --- viscy/translation/engine.py | 62 ++++++++++++++++++++++++++----------- 1 file changed, 44 insertions(+), 18 deletions(-) diff --git a/viscy/translation/engine.py b/viscy/translation/engine.py index 4e8af365..221dbb06 100644 --- a/viscy/translation/engine.py +++ b/viscy/translation/engine.py @@ -7,6 +7,7 @@ import torch.nn.functional as F from imageio import imwrite from lightning.pytorch import LightningModule +from monai.data.utils import collate_meta_tensor from monai.optimizers import WarmupCosineSchedule from monai.transforms import Compose, DivisiblePad, Rotate90 from torch import Tensor, nn @@ -23,6 +24,8 @@ structural_similarity_index_measure, ) +from viscy.data.combined import CombinedDataModule +from viscy.data.gpu_aug import GPUTransformDataModule from viscy.data.typing import Sample from viscy.translation.evaluation_metrics import mean_average_precision, ms_ssim_25d from viscy.unet.networks.fcmae import FullyConvolutionalMAE @@ -477,30 +480,53 @@ def __init__( self.fit_mask_ratio = fit_mask_ratio self.train_transforms = Compose(train_transforms) self.validation_transforms = Compose(validation_transforms) + self.save_hyperparameters() + + def on_fit_start(self): + dm = self.trainer.datamodule + if not isinstance(dm, CombinedDataModule): + raise ValueError( + f"Container data module type {type(dm)} " + "is not supported for FCMAE training" + ) + for subdm in dm.data_modules: + if not isinstance(subdm, GPUTransformDataModule): + raise ValueError( + f"Member data module type {type(subdm)} " + "is not supported for FCMAE training" + ) + self.datamodules = dm.data_modules def forward(self, x: Tensor, mask_ratio: float = 0.0): return self.model(x, mask_ratio) - def forward_fit(self, batch: Tensor) -> tuple[Tensor]: + def forward_fit(self, batch: Tensor) -> tuple[Tensor, Tensor, Tensor]: pred, mask = self.forward(batch, mask_ratio=self.fit_mask_ratio) loss = F.mse_loss(pred, batch, reduction="none") loss = (loss.mean(2) * mask).sum() / mask.sum() return pred, mask, loss - def transform_and_collate(self, batch: list[Tensor], transforms: Compose) -> Tensor: + @torch.no_grad() + def train_transform_and_collate(self, batch: list[dict[Sample]]) -> Tensor: transformed = [] - for sample in batch: - for element in sample: - 
transformed.append(transforms(element)) - return torch.stack(transformed) - - def training_step(self, batch: list[Tensor], batch_idx: int): - batch = self.transform_and_collate(batch, self.train_transforms) - pred, mask, loss = self.forward_fit(batch) + for dataset_batch, dm in zip(batch, self.datamodules): + dataset_batch = dm.train_gpu_transforms(dataset_batch) + transformed.extend(dataset_batch) + return collate_meta_tensor(transformed)["source"] + + def val_transform_and_collate( + self, batch: list[Sample], dataloader_idx: int + ) -> Tensor: + batch = self.datamodules[dataloader_idx].val_gpu_transforms(batch) + return collate_meta_tensor(batch)["source"] + + def training_step(self, batch: list[list[Sample]], batch_idx: int) -> Tensor: + x = self.train_transform_and_collate(batch) + pred, mask, loss = self.forward_fit(x) if batch_idx < self.log_batches_per_epoch: self.training_step_outputs.extend( detach_sample( - (batch, batch * mask.unsqueeze(2), pred), self.log_samples_per_batch + (x, x * mask.unsqueeze(2), pred), self.log_samples_per_batch ) ) self.log( @@ -511,25 +537,25 @@ def training_step(self, batch: list[Tensor], batch_idx: int): prog_bar=True, logger=True, sync_dist=True, - batch_size=batch.shape[0], + batch_size=x.shape[0], ) return loss def validation_step( - self, batch: list[Tensor], batch_idx: int, dataloader_idx: int = 0 - ): - batch = self.transform_and_collate(batch, self.validation_transforms) - pred, mask, loss = self.forward_fit(batch) + self, batch: list[Sample], batch_idx: int, dataloader_idx: int = 0 + ) -> None: + x = self.val_transform_and_collate(batch, dataloader_idx) + pred, mask, loss = self.forward_fit(x) if dataloader_idx + 1 > len(self.validation_losses): self.validation_losses.append([]) self.validation_losses[dataloader_idx].append(loss.detach()) self.log( - "loss/val", loss.to(self.device), sync_dist=True, batch_size=batch.shape[0] + "loss/val", loss.to(self.device), sync_dist=True, batch_size=x.shape[0] ) if batch_idx < self.log_batches_per_epoch: self.validation_step_outputs.extend( detach_sample( - (batch, batch * mask.unsqueeze(2), pred), + (x, x * mask.unsqueeze(2), pred), self.log_samples_per_batch, ) ) From 2ca134b1dffcc7e9d90b99dec8de8c486005aa68 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Wed, 30 Oct 2024 16:59:06 -0700 Subject: [PATCH 32/49] format and lint hcs_ram --- viscy/data/hcs_ram.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/viscy/data/hcs_ram.py b/viscy/data/hcs_ram.py index a9ff25d3..51eb8e90 100644 --- a/viscy/data/hcs_ram.py +++ b/viscy/data/hcs_ram.py @@ -6,6 +6,7 @@ import numpy as np import torch +import torch.distributed as dist from iohub.ngff import Position, open_ome_zarr from lightning.pytorch import LightningDataModule from monai.data import set_track_meta @@ -19,11 +20,9 @@ from torch import Tensor from torch.utils.data import DataLoader, Dataset +from viscy.data.distributed import ShardedDistributedSampler from viscy.data.hcs import _read_norm_meta from viscy.data.typing import ChannelMap, DictTransform, Sample -from viscy.data.distributed import ShardedDistributedSampler -from torch.distributed import get_rank -import torch.distributed as dist _logger = logging.getLogger("lightning.pytorch") @@ -72,10 +71,12 @@ def _collate_samples(batch: Sequence[Sample]) -> Sample: collated[key] = collate_meta_tensor(data) return collated + def is_ddp_enabled() -> bool: """Check if distributed data parallel (DDP) is initialized.""" return dist.is_available() and dist.is_initialized() + class 
CachedDataset(Dataset): """ A dataset that caches the data in RAM. @@ -341,7 +342,7 @@ def val_dataloader(self) -> DataLoader: else: sampler = None _logger.info("Using standard sampler for non-distributed validation") - + return DataLoader( self.val_dataset, batch_size=self.batch_size, @@ -350,5 +351,5 @@ def val_dataloader(self) -> DataLoader: pin_memory=True, shuffle=False, timeout=self.timeout, - sampler=sampler + sampler=sampler, ) From be0e94fb55bec3d36ad29222457881b3df0fec85 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Wed, 30 Oct 2024 17:04:46 -0700 Subject: [PATCH 33/49] fix abc type hint --- viscy/data/gpu_aug.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/viscy/data/gpu_aug.py b/viscy/data/gpu_aug.py index 75c91d23..43aeab34 100644 --- a/viscy/data/gpu_aug.py +++ b/viscy/data/gpu_aug.py @@ -29,6 +29,12 @@ class GPUTransformDataModule(ABC, LightningDataModule): + train_dataset: Dataset + val_dataset: Dataset + batch_size: int + num_workers: int + pin_memory: bool + def _maybe_sampler( self, dataset: Dataset, shuffle: bool ) -> ShardedDistributedSampler | None: From 92c4b0a7bc7784a626cfefa9155987db0350cc5d Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Wed, 30 Oct 2024 17:11:08 -0700 Subject: [PATCH 34/49] update docstring style --- viscy/data/combined.py | 30 +++++++++++++++++++++--------- 1 file changed, 21 insertions(+), 9 deletions(-) diff --git a/viscy/data/combined.py b/viscy/data/combined.py index 9585345c..e2e864d6 100644 --- a/viscy/data/combined.py +++ b/viscy/data/combined.py @@ -21,11 +21,18 @@ class CombinedDataModule(LightningDataModule): """Wrapper for combining multiple data modules. For supported modes, see ``lightning.pytorch.utilities.combined_loader``. - :param Sequence[LightningDataModule] data_modules: data modules to combine - :param str train_mode: mode in training stage, defaults to "max_size_cycle" - :param str val_mode: mode in validation stage, defaults to "sequential" - :param str test_mode: mode in testing stage, defaults to "sequential" - :param str predict_mode: mode in prediction stage, defaults to "sequential" + Parameters + ---------- + data_modules : Sequence[LightningDataModule] + data modules to combine + train_mode : CombineMode, optional + mode in training stage, by default CombineMode.MAX_SIZE_CYCLE + val_mode : CombineMode, optional + _description_, by default CombineMode.SEQUENTIAL + test_mode : CombineMode, optional + mode in testing stage, by default CombineMode.SEQUENTIAL + predict_mode : CombineMode, optional + mode in prediction stage, by default CombineMode.SEQUENTIAL """ def __init__( @@ -78,10 +85,15 @@ def predict_dataloader(self): class ConcatDataModule(LightningDataModule): """ Concatenate multiple data modules. - The concatenated data module will have the same - batch size and number of workers as the first data module. - Each element will be sampled uniformly regardless of their original data module. - :param Sequence[LightningDataModule] data_modules: data modules to concatenate + + The concatenated data module will have the same batch size and number of workers + as the first data module. Each element will be sampled uniformly regardless of + their original data module. + + Parameters + ---------- + data_modules : Sequence[LightningDataModule] + Data modules to concatenate. 
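A hedged usage sketch for the concatenated module (dm_a, dm_b, and model are placeholders assumed to be configured elsewhere; batch size and worker count follow the first member, as documented above):

    from lightning.pytorch import Trainer

    from viscy.data.combined import ConcatDataModule

    data = ConcatDataModule([dm_a, dm_b])  # elements sampled uniformly across both datasets
    trainer = Trainer(max_epochs=1)
    trainer.fit(model, datamodule=data)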
""" def __init__(self, data_modules: Sequence[LightningDataModule]): From f7b585c121662e2a4475184e90bea74671703271 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Wed, 30 Oct 2024 17:17:36 -0700 Subject: [PATCH 35/49] disable grad for validation transforms --- viscy/translation/engine.py | 1 + 1 file changed, 1 insertion(+) diff --git a/viscy/translation/engine.py b/viscy/translation/engine.py index 221dbb06..66e18f8b 100644 --- a/viscy/translation/engine.py +++ b/viscy/translation/engine.py @@ -514,6 +514,7 @@ def train_transform_and_collate(self, batch: list[dict[Sample]]) -> Tensor: transformed.extend(dataset_batch) return collate_meta_tensor(transformed)["source"] + @torch.no_grad() def val_transform_and_collate( self, batch: list[Sample], dataloader_idx: int ) -> Tensor: From 42c49f571c967c8c42f886433b3e8709a96407bb Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Thu, 31 Oct 2024 10:39:34 -0700 Subject: [PATCH 36/49] improve sample image logging in fcmae --- viscy/translation/engine.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/viscy/translation/engine.py b/viscy/translation/engine.py index 66e18f8b..5c49aa10 100644 --- a/viscy/translation/engine.py +++ b/viscy/translation/engine.py @@ -1,5 +1,6 @@ import logging import os +import random from typing import Literal, Sequence, Union import numpy as np @@ -9,7 +10,7 @@ from lightning.pytorch import LightningModule from monai.data.utils import collate_meta_tensor from monai.optimizers import WarmupCosineSchedule -from monai.transforms import Compose, DivisiblePad, Rotate90 +from monai.transforms import DivisiblePad, Rotate90 from torch import Tensor, nn from torch.optim.lr_scheduler import ConstantLR from torchmetrics.functional import ( @@ -472,14 +473,10 @@ class FcmaeUNet(VSUNet): def __init__( self, fit_mask_ratio: float = 0.0, - train_transforms=[], - validation_transforms=[], **kwargs, ): super().__init__(architecture="fcmae", **kwargs) self.fit_mask_ratio = fit_mask_ratio - self.train_transforms = Compose(train_transforms) - self.validation_transforms = Compose(validation_transforms) self.save_hyperparameters() def on_fit_start(self): @@ -512,6 +509,8 @@ def train_transform_and_collate(self, batch: list[dict[Sample]]) -> Tensor: for dataset_batch, dm in zip(batch, self.datamodules): dataset_batch = dm.train_gpu_transforms(dataset_batch) transformed.extend(dataset_batch) + # shuffle references in place for better logging + random.shuffle(transformed) return collate_meta_tensor(transformed)["source"] @torch.no_grad() @@ -525,10 +524,9 @@ def training_step(self, batch: list[list[Sample]], batch_idx: int) -> Tensor: x = self.train_transform_and_collate(batch) pred, mask, loss = self.forward_fit(x) if batch_idx < self.log_batches_per_epoch: + target = x * mask.unsqueeze(2) self.training_step_outputs.extend( - detach_sample( - (x, x * mask.unsqueeze(2), pred), self.log_samples_per_batch - ) + detach_sample((x, target, pred), self.log_samples_per_batch) ) self.log( "loss/train", From 4bf108881ed91f9e8706f7c74dc657b6e47eff9c Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Thu, 31 Oct 2024 10:40:03 -0700 Subject: [PATCH 37/49] fix dataset length when batch size is larger than the dataset --- viscy/data/gpu_aug.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/viscy/data/gpu_aug.py b/viscy/data/gpu_aug.py index 43aeab34..2b42d14b 100644 --- a/viscy/data/gpu_aug.py +++ b/viscy/data/gpu_aug.py @@ -55,7 +55,7 @@ def train_dataloader(self) -> DataLoader: persistent_workers=True if 
self.num_workers > 0 else False, num_workers=self.num_workers, pin_memory=self.pin_memory, - drop_last=True, + drop_last=False, collate_fn=list_data_collate, ) @@ -115,7 +115,7 @@ def __init__( self.transform = transform def __len__(self) -> int: - return len(self._cache_map) + return len(self._metadata_map) def __getitem__(self, idx: int) -> dict[str, Tensor]: position, time_idx, norm_meta = self._metadata_map[idx] From 32769509c2f0781a4923e58daffcb5de7dea69ab Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Thu, 31 Oct 2024 10:43:35 -0700 Subject: [PATCH 38/49] fix docstring --- viscy/data/combined.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/viscy/data/combined.py b/viscy/data/combined.py index e2e864d6..28cf9339 100644 --- a/viscy/data/combined.py +++ b/viscy/data/combined.py @@ -28,7 +28,7 @@ class CombinedDataModule(LightningDataModule): train_mode : CombineMode, optional mode in training stage, by default CombineMode.MAX_SIZE_CYCLE val_mode : CombineMode, optional - _description_, by default CombineMode.SEQUENTIAL + mode in validation stage, by default CombineMode.SEQUENTIAL test_mode : CombineMode, optional mode in testing stage, by default CombineMode.SEQUENTIAL predict_mode : CombineMode, optional @@ -86,8 +86,8 @@ class ConcatDataModule(LightningDataModule): """ Concatenate multiple data modules. - The concatenated data module will have the same batch size and number of workers - as the first data module. Each element will be sampled uniformly regardless of + The concatenated data module will have the same batch size and number of workers + as the first data module. Each element will be sampled uniformly regardless of their original data module. Parameters From 14a16ed935beb1f1546d627b26e660e5fbdc4824 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Thu, 31 Oct 2024 14:52:15 -0700 Subject: [PATCH 39/49] add option to disable normalization metadata --- viscy/data/gpu_aug.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/viscy/data/gpu_aug.py b/viscy/data/gpu_aug.py index 2b42d14b..e02ab582 100644 --- a/viscy/data/gpu_aug.py +++ b/viscy/data/gpu_aug.py @@ -99,6 +99,7 @@ def __init__( cache_map: DictProxy, transform: Compose | None = None, array_key: str = "0", + load_normalization_metadata: bool = True, ): key = 0 self._metadata_map: dict[int, _CacheMetadata] = {} @@ -113,6 +114,7 @@ def __init__( self.array_key = array_key self._cache_map = cache_map self.transform = transform + self.load_normalization_metadata = load_normalization_metadata def __len__(self) -> int: return len(self._metadata_map) @@ -132,7 +134,8 @@ def __getitem__(self, idx: int) -> dict[str, Tensor]: _logger.debug(f"Using cached volume for index {idx}") volume = cache sample = {name: img[None] for name, img in zip(self.channels.keys(), volume)} - sample["norm_meta"] = norm_meta + if self.load_normalization_metadata: + sample["norm_meta"] = norm_meta if self.transform: sample = self.transform(sample) if not isinstance(sample, list): From 6719305fb315493232a208eb6100804cba64410e Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Thu, 31 Oct 2024 14:52:30 -0700 Subject: [PATCH 40/49] inherit gpu transform for ctmc --- viscy/data/ctmc_v1.py | 95 ++++++++++++++++++++++--------------------- 1 file changed, 48 insertions(+), 47 deletions(-) diff --git a/viscy/data/ctmc_v1.py b/viscy/data/ctmc_v1.py index d71859b8..8fea3146 100644 --- a/viscy/data/ctmc_v1.py +++ b/viscy/data/ctmc_v1.py @@ -1,25 +1,13 @@ from pathlib import Path +import torch from iohub.ngff import 
open_ome_zarr -from lightning.pytorch import LightningDataModule from monai.transforms import Compose, MapTransform -from torch.utils.data import DataLoader -from viscy.data.hcs import ChannelMap, SlidingWindowDataset -from viscy.data.typing import Sample +from viscy.data.gpu_aug import CachedOmeZarrDataset, GPUTransformDataModule -class CTMCv1ValidationDataset(SlidingWindowDataset): - def __len__(self, subsample_rate: int = 30) -> int: - # sample every 30th frame in the videos - return super().__len__() // self.subsample_rate - - def __getitem__(self, index: int) -> Sample: - index = index * self.subsample_rate - return super().__getitem__(index) - - -class CTMCv1DataModule(LightningDataModule): +class CTMCv1DataModule(GPUTransformDataModule): """ Autoregression data module for the CTMCv1 dataset. Training and validation datasets are stored in separate HCS OME-Zarr stores. @@ -37,20 +25,44 @@ def __init__( self, train_data_path: str | Path, val_data_path: str | Path, - train_transforms: list[MapTransform], - val_transforms: list[MapTransform], + train_cpu_transforms: list[MapTransform], + val_cpu_transforms: list[MapTransform], + train_gpu_transforms: list[MapTransform], + val_gpu_transforms: list[MapTransform], batch_size: int = 16, num_workers: int = 8, + val_subsample_ratio: int = 30, channel_name: str = "DIC", + pin_memory: bool = True, ) -> None: super().__init__() self.train_data_path = train_data_path self.val_data_path = val_data_path - self.train_transforms = train_transforms - self.val_transforms = val_transforms - self.channel_map = ChannelMap(source=[channel_name], target=[channel_name]) + self._train_cpu_transforms = Compose(train_cpu_transforms) + self._val_cpu_transforms = Compose(val_cpu_transforms) + self._train_gpu_transforms = Compose(train_gpu_transforms) + self._val_gpu_transforms = Compose(val_gpu_transforms) + self.channel_names = [channel_name] self.batch_size = batch_size self.num_workers = num_workers + self.val_subsample_ratio = val_subsample_ratio + self.pin_memory = pin_memory + + @property + def train_cpu_transforms(self) -> Compose: + return self._train_cpu_transforms + + @property + def val_cpu_transforms(self) -> Compose: + return self._val_cpu_transforms + + @property + def train_gpu_transforms(self) -> Compose: + return self._train_gpu_transforms + + @property + def val_gpu_transforms(self) -> Compose: + return self._val_gpu_transforms def setup(self, stage: str) -> None: if stage != "fit": @@ -58,37 +70,26 @@ def setup(self, stage: str) -> None: self._setup_fit() def _setup_fit(self) -> None: + cache_map = torch.multiprocessing.Manager().dict() train_plate = open_ome_zarr(self.train_data_path) val_plate = open_ome_zarr(self.val_data_path) train_positions = [p for _, p in train_plate.positions()] val_positions = [p for _, p in val_plate.positions()] - self.train_dataset = SlidingWindowDataset( - train_positions, - channels=self.channel_map, - z_window_size=1, - transform=Compose(self.train_transforms), + self.train_dataset = CachedOmeZarrDataset( + positions=train_positions, + channel_names=self.channel_names, + cache_map=cache_map, + transform=self.train_cpu_transforms, + load_normalization_metadata=False, ) - self.val_dataset = CTMCv1ValidationDataset( - val_positions, - channels=self.channel_map, - z_window_size=1, - transform=Compose(self.val_transforms), + full_val_dataset = CachedOmeZarrDataset( + positions=val_positions, + channel_names=self.channel_names, + cache_map=cache_map, + transform=self.val_cpu_transforms, + 
load_normalization_metadata=False, ) - - def train_dataloader(self) -> DataLoader: - return DataLoader( - self.train_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - persistent_workers=bool(self.num_workers), - shuffle=True, - ) - - def val_dataloader(self) -> DataLoader: - return DataLoader( - self.val_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - persistent_workers=bool(self.num_workers), - shuffle=False, + subsample_indices = list( + range(0, len(full_val_dataset), self.val_subsample_ratio) ) + self.val_dataset = torch.utils.data.Subset(full_val_dataset, subsample_indices) From fad3d4eb296b8feb9f7afd0cb12c770f62248e2b Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Thu, 31 Oct 2024 14:53:34 -0700 Subject: [PATCH 41/49] remove duplicate method overrride --- viscy/transforms.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/viscy/transforms.py b/viscy/transforms.py index 8c0b3ec3..22f55e2c 100644 --- a/viscy/transforms.py +++ b/viscy/transforms.py @@ -16,8 +16,6 @@ RandWeightedCropd, ScaleIntensityRangePercentilesd, ) -from monai.transforms.transform import Randomizable -from numpy.random.mtrand import RandomState as RandomState from torch import Tensor from typing_extensions import Iterable, Literal @@ -230,12 +228,6 @@ def __call__(self, sample: Sample) -> Sample: sample[key] = -sample[key] return sample - def set_random_state( - self, seed: int | None = None, state: RandomState | None = None - ) -> Randomizable: - super().set_random_state(seed, state) - return self - class TiledSpatialCropSamplesd(MapTransform, MultiSampleTrait): """ From 07c1021bbe3e8ea000974d08361b3913bc040b0f Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Fri, 1 Nov 2024 13:15:33 -0700 Subject: [PATCH 42/49] update docstring for ctmc --- viscy/data/ctmc_v1.py | 32 +++++++++++++++++++++++++------- 1 file changed, 25 insertions(+), 7 deletions(-) diff --git a/viscy/data/ctmc_v1.py b/viscy/data/ctmc_v1.py index 8fea3146..3a0248cc 100644 --- a/viscy/data/ctmc_v1.py +++ b/viscy/data/ctmc_v1.py @@ -12,13 +12,31 @@ class CTMCv1DataModule(GPUTransformDataModule): Autoregression data module for the CTMCv1 dataset. Training and validation datasets are stored in separate HCS OME-Zarr stores. - :param str | Path train_data_path: Path to the training dataset - :param str | Path val_data_path: Path to the validation dataset - :param list[MapTransform] train_transforms: List of transforms for training - :param list[MapTransform] val_transforms: List of transforms for validation - :param int batch_size: Batch size, defaults to 16 - :param int num_workers: Number of workers, defaults to 8 - :param str channel_name: Name of the DIC channel, defaults to "DIC" + Parameters + ---------- + train_data_path : str or Path + Path to the training dataset. + val_data_path : str or Path + Path to the validation dataset. + train_cpu_transforms : list of MapTransform + List of CPU transforms for training. + val_cpu_transforms : list of MapTransform + List of CPU transforms for validation. + train_gpu_transforms : list of MapTransform + List of GPU transforms for training. + val_gpu_transforms : list of MapTransform + List of GPU transforms for validation. + batch_size : int, optional + Batch size, by default 16. + num_workers : int, optional + Number of dataloading workers, by default 8. + val_subsample_ratio : int, optional + Skip evert N frames for validation to reduce redundancy in video, + by default 30. + channel_name : str, optional + Name of the DIC channel, by default "DIC". 
+ pin_memory : bool, optional + Pin memory for dataloaders, by default True. """ def __init__( From 7d79473d6c5a35c06c3643f70275f795c17c8910 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Wed, 13 Nov 2024 10:10:42 -0800 Subject: [PATCH 43/49] allow skipping caching for large datasets --- viscy/data/gpu_aug.py | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/viscy/data/gpu_aug.py b/viscy/data/gpu_aug.py index e02ab582..f5481ce0 100644 --- a/viscy/data/gpu_aug.py +++ b/viscy/data/gpu_aug.py @@ -100,6 +100,7 @@ def __init__( transform: Compose | None = None, array_key: str = "0", load_normalization_metadata: bool = True, + skip_cache: bool = False, ): key = 0 self._metadata_map: dict[int, _CacheMetadata] = {} @@ -115,6 +116,7 @@ def __init__( self._cache_map = cache_map self.transform = transform self.load_normalization_metadata = load_normalization_metadata + self.skip_cache = skip_cache def __len__(self) -> int: return len(self._metadata_map) @@ -123,13 +125,15 @@ def __getitem__(self, idx: int) -> dict[str, Tensor]: position, time_idx, norm_meta = self._metadata_map[idx] cache = self._cache_map[idx] if cache is None: - _logger.debug(f"Caching for index {idx}") + _logger.debug(f"Loading volume for index {idx}") volume = torch.from_numpy( position[self.array_key] .oindex[time_idx, list(self.channels.values())] .astype(np.float32) ) - self._cache_map[idx] = volume + if not self.skip_cache: + _logger.debug(f"Caching for index {idx}") + self._cache_map[idx] = volume else: _logger.debug(f"Using cached volume for index {idx}") volume = cache @@ -156,6 +160,7 @@ def __init__( train_gpu_transforms: list[DictTransform], val_gpu_transforms: list[DictTransform], pin_memory: bool = True, + skip_cache: bool = False, ): super().__init__() self.data_path = data_path @@ -168,6 +173,7 @@ def __init__( self._train_gpu_transforms = Compose(train_gpu_transforms) self._val_gpu_transforms = Compose(val_gpu_transforms) self.pin_memory = pin_memory + self.skip_cache = skip_cache @property def train_cpu_transforms(self) -> Compose: @@ -204,8 +210,16 @@ def setup(self, stage: Literal["fit", "validate"]) -> None: _logger.debug(f"Training FOVs: {[p.zgroup.name for p in train_fovs]}") _logger.debug(f"Validation FOVs: {[p.zgroup.name for p in val_fovs]}") self.train_dataset = CachedOmeZarrDataset( - train_fovs, self.channels, cache_map, transform=self.train_cpu_transforms + train_fovs, + self.channels, + cache_map, + transform=self.train_cpu_transforms, + skip_cache=self.skip_cache, ) self.val_dataset = CachedOmeZarrDataset( - val_fovs, self.channels, cache_map, transform=self.val_cpu_transforms + val_fovs, + self.channels, + cache_map, + transform=self.val_cpu_transforms, + skip_cache=self.skip_cache, ) From e548d52de93469b3b54579bed1b0090847e0e090 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Thu, 14 Nov 2024 14:10:14 -0800 Subject: [PATCH 44/49] make the fcmae module compatible with image translation --- viscy/translation/engine.py | 74 +++++++++++++++++++++++++++---------- 1 file changed, 55 insertions(+), 19 deletions(-) diff --git a/viscy/translation/engine.py b/viscy/translation/engine.py index c30f06bc..9f710b77 100644 --- a/viscy/translation/engine.py +++ b/viscy/translation/engine.py @@ -89,6 +89,13 @@ def forward(self, preds, target): return loss +class MaskedMSELoss(nn.Module): + def forward(self, preds, original, mask): + loss = F.mse_loss(preds, original, reduction="none") + loss = (loss.mean(2) * mask).sum() / mask.sum() + return loss + + class 
VSUNet(LightningModule): """Regression U-Net module for virtual staining. @@ -478,7 +485,7 @@ def __init__( ): super().__init__(architecture="fcmae", **kwargs) self.fit_mask_ratio = fit_mask_ratio - self.save_hyperparameters() + self.save_hyperparameters(ignore=["loss_function"]) def on_fit_start(self): dm = self.trainer.datamodule @@ -494,40 +501,70 @@ def on_fit_start(self): "is not supported for FCMAE training" ) self.datamodules = dm.data_modules + if self.model.pretraining and not isinstance(self.loss_function, MaskedMSELoss): + raise ValueError( + "MaskedMSELoss is required for FCMAE pre-training, " + f"got {type(self.loss_function)}" + ) def forward(self, x: Tensor, mask_ratio: float = 0.0): return self.model(x, mask_ratio) - def forward_fit(self, batch: Tensor) -> tuple[Tensor, Tensor, Tensor]: - pred, mask = self.forward(batch, mask_ratio=self.fit_mask_ratio) - loss = F.mse_loss(pred, batch, reduction="none") - loss = (loss.mean(2) * mask).sum() / mask.sum() - return pred, mask, loss + def forward_fit_fcmae( + self, batch: Sample, return_target: bool = False + ) -> tuple[Tensor, Tensor | None, Tensor]: + x = batch["source"] + pred, mask = self.forward(x, mask_ratio=self.fit_mask_ratio) + loss = self.loss_function(pred, x, mask) + if return_target: + target = x * mask.unsqueeze(2) + else: + target = None + return pred, target, loss + + def forward_fit_supervised(self, batch: Sample) -> tuple[Tensor, Tensor, Tensor]: + x = batch["source"] + target = batch["target"] + pred = self.forward(x) + loss = self.loss_function(pred, target) + return pred, target, loss + + def forward_fit_task( + self, batch: Sample, batch_idx: int + ) -> tuple[Tensor, Tensor | None, Tensor]: + if self.model.pretraining: + if batch_idx < self.log_batches_per_epoch: + return_target = True + pred, target, loss = self.forward_fit_fcmae(batch, return_target) + else: + pred, target, loss = self.forward_fit_supervised(batch) + return pred, target, loss @torch.no_grad() - def train_transform_and_collate(self, batch: list[dict[Sample]]) -> Tensor: + def train_transform_and_collate(self, batch: list[dict[str, Tensor]]) -> Sample: transformed = [] for dataset_batch, dm in zip(batch, self.datamodules): dataset_batch = dm.train_gpu_transforms(dataset_batch) transformed.extend(dataset_batch) # shuffle references in place for better logging random.shuffle(transformed) - return collate_meta_tensor(transformed)["source"] + return collate_meta_tensor(transformed) @torch.no_grad() def val_transform_and_collate( self, batch: list[Sample], dataloader_idx: int ) -> Tensor: batch = self.datamodules[dataloader_idx].val_gpu_transforms(batch) - return collate_meta_tensor(batch)["source"] + return collate_meta_tensor(batch) def training_step(self, batch: list[list[Sample]], batch_idx: int) -> Tensor: - x = self.train_transform_and_collate(batch) - pred, mask, loss = self.forward_fit(x) + batch = self.train_transform_and_collate(batch) + pred, target, loss = self.forward_fit_task(batch, batch_idx) if batch_idx < self.log_batches_per_epoch: - target = x * mask.unsqueeze(2) self.training_step_outputs.extend( - detach_sample((x, target, pred), self.log_samples_per_batch) + detach_sample( + (batch["source"], target, pred), self.log_samples_per_batch + ) ) self.log( "loss/train", @@ -537,25 +574,24 @@ def training_step(self, batch: list[list[Sample]], batch_idx: int) -> Tensor: prog_bar=True, logger=True, sync_dist=True, - batch_size=x.shape[0], + batch_size=pred.shape[0], ) return loss def validation_step( self, batch: list[Sample], 
batch_idx: int, dataloader_idx: int = 0 ) -> None: - x = self.val_transform_and_collate(batch, dataloader_idx) - pred, mask, loss = self.forward_fit(x) + batch = self.val_transform_and_collate(batch, dataloader_idx) + pred, target, loss = self.forward_fit_task(batch, batch_idx) if dataloader_idx + 1 > len(self.validation_losses): self.validation_losses.append([]) self.validation_losses[dataloader_idx].append(loss.detach()) self.log( - "loss/val", loss.to(self.device), sync_dist=True, batch_size=x.shape[0] + "loss/val", loss.to(self.device), sync_dist=True, batch_size=pred.shape[0] ) if batch_idx < self.log_batches_per_epoch: self.validation_step_outputs.extend( detach_sample( - (x, x * mask.unsqueeze(2), pred), - self.log_samples_per_batch, + (batch["source"], target, pred), self.log_samples_per_batch ) ) From 084717f6ff8bb4f855a47e9e3c3718c69b6f6e55 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Mon, 18 Nov 2024 16:23:16 -0800 Subject: [PATCH 45/49] remove prototype implementation --- viscy/data/hcs_ram.py | 355 ------------------------------------------ 1 file changed, 355 deletions(-) delete mode 100644 viscy/data/hcs_ram.py diff --git a/viscy/data/hcs_ram.py b/viscy/data/hcs_ram.py deleted file mode 100644 index 51eb8e90..00000000 --- a/viscy/data/hcs_ram.py +++ /dev/null @@ -1,355 +0,0 @@ -import logging -from datetime import datetime -from multiprocessing import Manager -from multiprocessing.managers import DictProxy -from typing import Callable, Literal, Sequence - -import numpy as np -import torch -import torch.distributed as dist -from iohub.ngff import Position, open_ome_zarr -from lightning.pytorch import LightningDataModule -from monai.data import set_track_meta -from monai.data.utils import collate_meta_tensor -from monai.transforms import ( - CenterSpatialCropd, - Compose, - MapTransform, - MultiSampleTrait, -) -from torch import Tensor -from torch.utils.data import DataLoader, Dataset - -from viscy.data.distributed import ShardedDistributedSampler -from viscy.data.hcs import _read_norm_meta -from viscy.data.typing import ChannelMap, DictTransform, Sample - -_logger = logging.getLogger("lightning.pytorch") - -# TODO: cache the norm metadata when caching the dataset - -# Map the NumPy dtype to the corresponding PyTorch dtype -numpy_to_torch_dtype = { - np.dtype("float32"): torch.float32, - np.dtype("float64"): torch.float64, - np.dtype("int32"): torch.int32, - np.dtype("int64"): torch.int64, - np.dtype("uint8"): torch.int8, - np.dtype("uint16"): torch.int16, -} - - -def _stack_channels( - sample_images: list[dict[str, Tensor]] | dict[str, Tensor], - channels: ChannelMap, - key: str, -) -> Tensor | list[Tensor]: - """Stack single-channel images into a multi-channel tensor.""" - if not isinstance(sample_images, list): - return torch.stack([sample_images[ch][0] for ch in channels[key]]) - # training time - # sample_images is a list['Phase3D'].shape = (1,3,256,256) - return [torch.stack([im[ch][0] for ch in channels[key]]) for im in sample_images] - - -def _collate_samples(batch: Sequence[Sample]) -> Sample: - """Collate samples into a batch sample. - - :param Sequence[Sample] batch: a sequence of dictionaries, - where each key may point to a value of a single tensor or a list of tensors, - as is the case with ``train_patches_per_stack > 1``. 
- :return Sample: Batch sample (dictionary of tensors) - """ - collated: Sample = {} - for key in batch[0].keys(): - data = [] - for sample in batch: - if isinstance(sample[key], Sequence): - data.extend(sample[key]) - else: - data.append(sample[key]) - collated[key] = collate_meta_tensor(data) - return collated - - -def is_ddp_enabled() -> bool: - """Check if distributed data parallel (DDP) is initialized.""" - return dist.is_available() and dist.is_initialized() - - -class CachedDataset(Dataset): - """ - A dataset that caches the data in RAM. - It relies on the `__getitem__` method to load the data on the 1st epoch. - """ - - def __init__( - self, - shared_dict: DictProxy, - positions: list[Position], - channels: ChannelMap, - transform: DictTransform | None = None, - ): - super().__init__() - if is_ddp_enabled(): - self.rank = dist.get_rank() - _logger.info(f"=== Initializing cache pool for rank {self.rank} ===") - - self.cache_dict = shared_dict - self.positions = positions - self.channels = channels - self.transform = transform - - self.source_ch_idx = [ - positions[0].get_channel_index(c) for c in channels["source"] - ] - self.target_ch_idx = ( - [positions[0].get_channel_index(c) for c in channels["target"]] - if "target" in channels - else None - ) - # Get total num channels - self.total_ch_names = self.channels["source"].copy() - self.total_ch_idx = self.source_ch_idx.copy() - if self.target_ch_idx is not None: - self.total_ch_names.extend(self.channels["target"]) - self.total_ch_idx.extend(self.target_ch_idx) - self._position_mapping() - - # Cached dictionary with tensors - # TODO: Delete after testing - self._cached_pos = [] - - def _position_mapping(self) -> None: - self.position_keys = [] - self.norm_meta_dict = {} - - for pos in self.positions: - self.position_keys.append(pos.data.name) - self.norm_meta_dict[str(pos.data.name)] = _read_norm_meta(pos) - - def _cache_dataset(self, index: int, channel_index: list[int], t: int = 0) -> None: - # Add the position to the cached_dict - # TODO: hardcoding to t=0 - data = self.positions[index].data.oindex[slice(t, t + 1), channel_index, :] - if data.dtype != np.float32: - data = data.astype(np.float32) - self.cache_dict[str(self.position_keys[index])] = torch.from_numpy(data) - - def _get_weight_map(self, position: Position) -> Tensor: - # Get the weight map from the position for the MONAI weightedcrop transform - raise NotImplementedError - - def __len__(self) -> int: - return len(self.positions) - - def __getitem__(self, index: int) -> Sample: - # Check if the sample is in the cache else add it - sample_id = self.position_keys[index] - if sample_id not in self.cache_dict: - _logger.info(f"Adding {sample_id} to cache") - self._cached_pos.append(index) - _logger.info(f"Cached positions: {self._cached_pos}") - self._cache_dataset(index, channel_index=self.total_ch_idx) - - # Get the sample from the cache - _logger.info("Getting sample from cache") - start_time = datetime.now() - images = self.cache_dict[sample_id].unbind(dim=1) - norm_meta = self.norm_meta_dict[str(sample_id)] - after_cache = datetime.now() - start_time - sample_images = {k: v for k, v in zip(self.total_ch_names, images)} - - if self.target_ch_idx is not None: - # FIXME: this uses the first target channel as weight for performance - # since adding a reference to a tensor does not copy - # maybe write a weight map in preprocessing to use more information? 
- sample_images["weight"] = sample_images[self.channels["target"][0]] - if norm_meta is not None: - sample_images["norm_meta"] = norm_meta - if self.transform: - before_transform = datetime.now() - sample_images = self.transform(sample_images) - after_transform = datetime.now() - before_transform - if "weight" in sample_images: - del sample_images["weight"] - sample = { - "index": sample_id, - "source": _stack_channels(sample_images, self.channels, "source"), - "norm_meta": norm_meta, - } - if self.target_ch_idx is not None: - sample["target"] = _stack_channels(sample_images, self.channels, "target") - - _logger.info(f"\nTime taken to cache: {after_cache}") - _logger.info(f"Time taken to transform: {after_transform}") - _logger.info(f"Time taken to get sample: {datetime.now() - start_time}\n") - - return sample - - def _load_sample(self, position: Position) -> Sample: - source, target = self.channel_map.source, self.channel_map.target - source_data = self._load_channel_data(position, source) - target_data = self._load_channel_data(position, target) - sample = {"source": source_data, "target": target_data} - return sample - - -class CachedDataModule(LightningDataModule): - def __init__( - self, - data_path: str, - source_channel: str | Sequence[str], - target_channel: str | Sequence[str], - split_ratio: float = 0.8, - batch_size: int = 16, - num_workers: int = 8, - architecture: Literal["2D", "UNeXt2", "2.5D", "3D", "fcmae"] = "UNeXt2", - yx_patch_size: tuple[int, int] = (256, 256), - normalizations: list[MapTransform] = [], - augmentations: list[MapTransform] = [], - z_window_size: int = 1, - timeout: int = 600, - ): - super().__init__() - self.data_path = data_path - self.source_channel = source_channel - self.target_channel = target_channel - self.batch_size = batch_size - self.num_workers = num_workers - self.target_2d = False if architecture in ["UNeXt2", "3D", "fcmae"] else True - self.split_ratio = split_ratio - self.yx_patch_size = yx_patch_size - self.normalizations = normalizations - self.augmentations = augmentations - self.z_window_size = z_window_size - self.timeout = timeout - - @property - def _base_dataset_settings(self) -> dict[str, dict[str, list[str]] | int]: - return { - "channels": {"source": self.source_channel}, - } - - def setup(self, stage: Literal["fit", "validate", "test", "predict"]) -> None: - dataset_settings = self._base_dataset_settings - if stage in ("fit", "validate"): - self._setup_fit(dataset_settings) - elif stage == "test": - raise NotImplementedError("Test stage is not supported") - elif stage == "predict": - raise NotImplementedError("Predict stage is not supported") - else: - raise NotImplementedError(f"Stage {stage} is not supported") - - def _train_transform(self) -> list[Callable]: - """Set the train augmentations""" - - if self.augmentations: - for aug in self.augmentations: - if isinstance(aug, MultiSampleTrait): - num_samples = aug.cropper.num_samples - if self.batch_size % num_samples != 0: - raise ValueError( - "Batch size must be divisible by `num_samples` per stack. " - f"Got batch size {self.batch_size} and " - f"number of samples {num_samples} for " - f"transform type {type(aug)}." 
- ) - self.train_patches_per_stack = num_samples - else: - self.augmentations = [] - - _logger.debug(f"Training augmentations: {self.augmentations}") - return list(self.augmentations) - - def _fit_transform(self) -> tuple[Compose, Compose]: - """(normalization -> maybe augmentation -> center crop) - Deterministic center crop as the last step of training and validation.""" - # TODO: These have a fixed order for now... () - final_crop = [ - CenterSpatialCropd( - keys=self.source_channel + self.target_channel, - roi_size=( - self.z_window_size, - self.yx_patch_size[0], - self.yx_patch_size[1], - ), - ) - ] - train_transform = Compose( - self.normalizations + self._train_transform() + final_crop - ) - val_transform = Compose(self.normalizations + final_crop) - return train_transform, val_transform - - def _set_fit_global_state(self, num_positions: int) -> torch.Tensor: - # disable metadata tracking in MONAI for performance - set_track_meta(False) - # shuffle positions, randomness is handled globally - return torch.randperm(num_positions) - - def _setup_fit(self, dataset_settings: dict) -> None: - """ - Setup the train and validation datasets. - """ - train_transform, val_transform = self._fit_transform() - dataset_settings["channels"]["target"] = self.target_channel - # Load the plate - plate = open_ome_zarr(self.data_path) - # shuffle positions, randomness is handled globally - positions = [pos for _, pos in plate.positions()] - shuffled_indices = self._set_fit_global_state(len(positions)) - positions = list(positions[i] for i in shuffled_indices) - num_train_fovs = int(len(positions) * self.split_ratio) - - shared_dict = Manager().dict() - self.train_dataset = CachedDataset( - shared_dict, - positions[:num_train_fovs], - transform=train_transform, - **dataset_settings, - ) - self.val_dataset = CachedDataset( - shared_dict, - positions[num_train_fovs:], - transform=val_transform, - **dataset_settings, - ) - - def train_dataloader(self) -> DataLoader: - if is_ddp_enabled(): - sampler = ShardedDistributedSampler(self.train_dataset, shuffle=True) - else: - sampler = None - _logger.info("Using standard sampler for non-distributed training") - return DataLoader( - self.train_dataset, - batch_size=self.batch_size // self.train_patches_per_stack, - num_workers=self.num_workers, - persistent_workers=bool(self.num_workers), - pin_memory=True, - shuffle=False, - timeout=self.timeout, - collate_fn=_collate_samples, - drop_last=True, - sampler=sampler, - ) - - def val_dataloader(self) -> DataLoader: - if is_ddp_enabled(): - sampler = ShardedDistributedSampler(self.val_dataset, shuffle=False) - else: - sampler = None - _logger.info("Using standard sampler for non-distributed validation") - - return DataLoader( - self.val_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - persistent_workers=bool(self.num_workers), - pin_memory=True, - shuffle=False, - timeout=self.timeout, - sampler=sampler, - ) From fdc377a40995ab84aa66bc41c054a19b5d7d473d Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Mon, 18 Nov 2024 16:34:10 -0800 Subject: [PATCH 46/49] fix import path --- viscy/data/combined.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/viscy/data/combined.py b/viscy/data/combined.py index 28cf9339..87036723 100644 --- a/viscy/data/combined.py +++ b/viscy/data/combined.py @@ -7,7 +7,7 @@ from torch.utils.data import ConcatDataset, DataLoader, Dataset from viscy.data.distributed import ShardedDistributedSampler -from viscy.data.hcs_ram import _collate_samples +from 
viscy.data.hcs import _collate_samples class CombineMode(Enum): From 96313fa451eb46c45680fc96b08be1bd2b4c60e2 Mon Sep 17 00:00:00 2001 From: Ziwen Liu <67518483+ziw-liu@users.noreply.github.com> Date: Mon, 2 Dec 2024 15:33:09 -0800 Subject: [PATCH 47/49] Arbitrary prediction time transforms (#209) * fix spelling in docstring and comment * add batched zoom transform for tta * add standalone lightning module for arbitrary TTA * fix composition of different zoom factors --- viscy/transforms.py | 37 ++++++++++++++++++++++ viscy/translation/engine.py | 61 +++++++++++++++++++++++++++++++++++-- 2 files changed, 95 insertions(+), 3 deletions(-) diff --git a/viscy/transforms.py b/viscy/transforms.py index 22f55e2c..f4e8103b 100644 --- a/viscy/transforms.py +++ b/viscy/transforms.py @@ -15,6 +15,7 @@ RandScaleIntensityd, RandWeightedCropd, ScaleIntensityRangePercentilesd, + Transform, ) from torch import Tensor from typing_extensions import Iterable, Literal @@ -294,3 +295,39 @@ def __call__(self, sample: Sample) -> Sample: for key, channels in self.channel_map.items(): results[key] = torch.cat([sample[ch] for ch in channels], dim=0) return results + + +class BatchedZoom(Transform): + "Batched zoom transform using ``torch.nn.functional.interpolate``." + + def __init__( + self, + scale_factor: float | tuple[float, float, float], + mode: Literal[ + "nearest", + "nearest-exact", + "linear", + "bilinear", + "bicubic", + "trilinear", + "area", + ], + align_corners: bool | None = None, + recompute_scale_factor: bool | None = None, + antialias: bool = False, + ) -> None: + self.scale_factor = scale_factor + self.mode = mode + self.align_corners = align_corners + self.recompute_scale_factor = recompute_scale_factor + self.antialias = antialias + + def __call__(self, sample: Tensor) -> Tensor: + return torch.nn.functional.interpolate( + sample, + scale_factor=self.scale_factor, + mode=self.mode, + align_corners=self.align_corners, + recompute_scale_factor=self.recompute_scale_factor, + antialias=self.antialias, + ) diff --git a/viscy/translation/engine.py b/viscy/translation/engine.py index 9f710b77..831217fc 100644 --- a/viscy/translation/engine.py +++ b/viscy/translation/engine.py @@ -1,7 +1,7 @@ import logging import os import random -from typing import Literal, Sequence, Union +from typing import Callable, Literal, Sequence, Union import numpy as np import torch @@ -109,7 +109,7 @@ class VSUNet(LightningModule): :param float lr: learning rate in training, defaults to 1e-3 :param Literal['WarmupCosine', 'Constant'] schedule: learning rate scheduler, defaults to "Constant" - :param str chkpt_path: path to the checkpoint to load weights, defaults to None + :param str ckpt_path: path to the checkpoint to load weights, defaults to None :param int log_batches_per_epoch: number of batches to log each training/validation epoch, has to be smaller than steps per epoch, defaults to 8 @@ -376,7 +376,8 @@ def perform_test_time_augmentations(self, source: Tensor) -> Tensor: elif self.tta_type == "median": prediction = torch.stack(predictions).median(dim=0).values elif self.tta_type == "product": - # Perform multiplication of predictions in logarithmic space for numerical stability adding epsion to avoid log(0) case + # Perform multiplication of predictions in logarithmic space + # for numerical stability adding epsilon to avoid log(0) case log_predictions = torch.stack([torch.log(p + 1e-9) for p in predictions]) log_prediction_sum = log_predictions.sum(dim=0) prediction = torch.exp(log_prediction_sum) @@ -477,6 
+478,60 @@ def _crop_to_original(self, tensor: Tensor) -> Tensor: return cropped_tensor +class AugmentedPredictionVSUNet(LightningModule): + def __init__( + self, + model: nn.Module, + forward_transforms: list[Callable[[Tensor], Tensor]], + inverse_transforms: list[Callable[[Tensor], Tensor]], + reduction: Literal["mean", "median"] = "mean", + ) -> None: + super().__init__() + down_factor = 2**model.num_blocks + self._predict_pad = DivisiblePad((0, 0, down_factor, down_factor)) + self.model = model + self._forward_transforms = forward_transforms + self._inverse_transforms = inverse_transforms + self._reduction = reduction + + def forward(self, x: Tensor) -> Tensor: + return self.model(x) + + def setup(self, stage: str) -> None: + if stage != "predict": + raise NotImplementedError( + f"Only the 'predict' stage is supported by {type(self)}" + ) + + def _reduce_predictions(self, preds: list[Tensor]) -> Tensor: + prediction = torch.stack(preds, dim=0) + if self._reduction == "mean": + prediction = prediction.mean(dim=0) + elif self._reduction == "median": + prediction = prediction.median(dim=0).values + return prediction + + def predict_step( + self, batch: Sample, batch_idx: int, dataloader_idx: int = 0 + ) -> Tensor: + source = batch["source"] + preds = [] + for forward_t, inverse_t in zip( + self._forward_transforms, self._inverse_transforms + ): + source = forward_t(source) + source = self._predict_pad(source) + pred = self.forward(source) + pred = self._predict_pad.inverse(pred) + pred = inverse_t(pred) + preds.append(pred) + if len(preds) == 1: + prediction = preds[0] + else: + prediction = self._reduce_predictions(preds) + return prediction + + class FcmaeUNet(VSUNet): def __init__( self, From 6e1818bdb50e85b344823029686aef8881e80ca8 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Mon, 2 Dec 2024 15:45:18 -0800 Subject: [PATCH 48/49] add docstrings --- viscy/data/gpu_aug.py | 51 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 51 insertions(+) diff --git a/viscy/data/gpu_aug.py b/viscy/data/gpu_aug.py index f5481ce0..a445fff5 100644 --- a/viscy/data/gpu_aug.py +++ b/viscy/data/gpu_aug.py @@ -29,6 +29,8 @@ class GPUTransformDataModule(ABC, LightningDataModule): + """Abstract data module with GPU transforms.""" + train_dataset: Dataset val_dataset: Dataset batch_size: int @@ -92,6 +94,26 @@ def val_gpu_transforms(self) -> Compose: ... class CachedOmeZarrDataset(Dataset): + """Dataset for cached OME-Zarr arrays. + + Parameters + ---------- + positions : list[Position] + List of FOVs to load images from. + channel_names : list[str] + List of channel names to load. + cache_map : DictProxy + Shared dictionary for caching loaded volumes. + transform : Compose | None, optional + Composed transforms to be applied on the CPU, by default None + array_key : str, optional + The image array key name (multi-scale level), by default "0" + load_normalization_metadata : bool, optional + Load normalization metadata in the sample dictionary, by default True + skip_cache : bool, optional + Skip caching to save RAM, by default False + """ + def __init__( self, positions: list[Position], @@ -148,6 +170,35 @@ def __getitem__(self, idx: int) -> dict[str, Tensor]: class CachedOmeZarrDataModule(GPUTransformDataModule): + """Data module for cached OME-Zarr arrays. + + Parameters + ---------- + data_path : Path + Path to the HCS OME-Zarr dataset. + channels : str | list[str] + Channel names to load. + batch_size : int + Batch size for training and validation. 
+ num_workers : int + Number of workers for data-loaders. + split_ratio : float + Fraction of the FOVs used for the training split. + The rest will be used for validation. + train_cpu_transforms : list[DictTransform] + Transforms to be applied on the CPU during training. + val_cpu_transforms : list[DictTransform] + Transforms to be applied on the CPU during validation. + train_gpu_transforms : list[DictTransform] + Transforms to be applied on the GPU during training. + val_gpu_transforms : list[DictTransform] + Transforms to be applied on the GPU during validation. + pin_memory : bool, optional + Use page-locked memory in data-loaders, by default True + skip_cache : bool, optional + Skip caching for this dataset, by default False + """ + def __init__( self, data_path: Path, From b864c6e1100376618954c4b2fff17d1faab62344 Mon Sep 17 00:00:00 2001 From: Ziwen Liu <67518483+ziw-liu@users.noreply.github.com> Date: Thu, 2 Jan 2025 14:30:40 -0800 Subject: [PATCH 49/49] fix typo in docstring --- viscy/data/ctmc_v1.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/viscy/data/ctmc_v1.py b/viscy/data/ctmc_v1.py index 3a0248cc..3c888175 100644 --- a/viscy/data/ctmc_v1.py +++ b/viscy/data/ctmc_v1.py @@ -31,7 +31,7 @@ class CTMCv1DataModule(GPUTransformDataModule): num_workers : int, optional Number of dataloading workers, by default 8. val_subsample_ratio : int, optional - Skip evert N frames for validation to reduce redundancy in video, + Skip every N frames for validation to reduce redundancy in video, by default 30. channel_name : str, optional Name of the DIC channel, by default "DIC".
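
A minimal usage sketch for the prediction-time zoom augmentation introduced in PATCH 47: each forward `BatchedZoom` is paired with its reciprocal inverse and the per-scale predictions are reduced by `AugmentedPredictionVSUNet`. This is illustrative only; `trained_model`, `trainer`, and `predict_dm` are assumed to be set up elsewhere (the model must expose `num_blocks`, which the module uses to pick the divisible padding), and the scale values are arbitrary examples.

    from viscy.transforms import BatchedZoom
    from viscy.translation.engine import AugmentedPredictionVSUNet

    # assumption: trained_model is an already-trained virtual-staining
    # network with a `num_blocks` attribute, matching the VSUNet forward API
    scales = [0.5, 1.0, 2.0]
    vs_tta = AugmentedPredictionVSUNet(
        model=trained_model,
        forward_transforms=[
            BatchedZoom(scale_factor=s, mode="trilinear") for s in scales
        ],
        inverse_transforms=[
            # invert each zoom with its reciprocal so predictions align
            BatchedZoom(scale_factor=1 / s, mode="trilinear") for s in scales
        ],
        reduction="median",
    )
    # assumption: a Lightning Trainer and a prediction data module exist
    # predictions = trainer.predict(vs_tta, datamodule=predict_dm)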