Skip to content

Commit

Permalink
SSL: fix MLP head and remove L2 normalization (#145)
Browse files Browse the repository at this point in the history
* draft projection head per Update the projection head (normalization and size). #139

* reorganize comments in example fit config

* configurable stem stride and projection dimensions

* update type hint and docstring for ContrastiveEncoder

* clarify embedding_dim

* use the forward method directly for projected

* normalize projections only when fitting
the projected features saved during prediction are now *not* normalized

* remove unused logger

* refactor training code into translation and representation modules

* extract image logging functions

* use AdamW instead of Adam for contrastive learning

* inline single-use argument

* fix normalization

* fix MLP layer order

* fix output dimensions

* remove L2 normalization before computing loss

* compute rank of features and projections

* documentation

---------

Co-authored-by: Shalin Mehta <[email protected]>
  • Loading branch information
ziw-liu and mattersoflight authored Aug 31, 2024
1 parent 6e7d61f commit 1f269c7
Show file tree
Hide file tree
Showing 28 changed files with 335 additions and 336 deletions.
43 changes: 32 additions & 11 deletions applications/contrastive_phenotyping/contrastive_cli/fit.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# See help here on how to configure hyper-parameters with config files: https://lightning.ai/docs/pytorch/stable/cli/lightning_cli_advanced.html
# See help here on how to configure hyper-parameters with config files:
# https://lightning.ai/docs/pytorch/stable/cli/lightning_cli_advanced.html
seed_everything: 42
trainer:
accelerator: gpu
Expand All @@ -8,16 +9,19 @@ trainer:
precision: 32-true
logger:
class_path: lightning.pytorch.loggers.TensorBoardLogger
# Nesting the logger config like this is equivalent to
# supplying the following argument to `lightning.pytorch.Trainer`:
# logger=TensorBoardLogger(
# "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/contrastive_tune_augmentations",
# log_graph=True,
# version="vanilla",
# )
init_args:
save_dir: /hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/contrastive_tune_augmentations
version: chocolate # this is the name of the experiment. The logs will be saved in save_dir/lightning_logs/version
# this is the name of the experiment.
# The logs will be saved in `save_dir/lightning_logs/version`
version: l2_projection_batchnorm
log_graph: True
# Nesting the logger config like this is equivalent to supplying the following argument to lightning.pytorch.Trainer
# logger=TensorBoardLogger(
# "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/contrastive_tune_augmentations",
# log_graph=True,
# version="vanilla",
# )
callbacks:
- class_path: lightning.pytorch.callbacks.LearningRateMonitor
init_args:
Expand All @@ -34,12 +38,29 @@ trainer:
enable_checkpointing: true
inference_mode: true
use_distributed_sampler: true
# synchronize batchnorm parameters across multiple GPUs.
# important for contrastive learning to normalize the tensors across the whole batch.
sync_batchnorm: true
model:
backbone: convnext_tiny
in_channels: 2
encoder:
class_path: viscy.representation.contrastive.ContrastiveEncoder
init_args:
backbone: convnext_tiny
in_channels: 2
in_stack_depth: 15
stem_kernel_size: [5, 4, 4]
stem_stride: [5, 4, 4]
embedding_dim: 768
projection_dim: 128
drop_path_rate: 0.0
loss_function:
class_path: torch.nn.TripletMarginLoss
init_args:
margin: 0.5
lr: 0.0002
log_batches_per_epoch: 3
log_samples_per_batch: 3
lr: 0.0002
example_input_array_shape: [1, 2, 15, 256, 256]
data:
data_path: /hpc/projects/virtual_staining/2024_02_04_A549_DENV_ZIKV_timelapse/registered_chunked.zarr
tracks_path: /hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/7.1-seg_track/tracking_v1.zarr
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,21 +11,21 @@
from umap import UMAP


from viscy.light.embedding_writer import read_embedding_dataset
from viscy.representation.embedding_writer import read_embedding_dataset
from viscy.data.triplet import TripletDataset, TripletDataModule
from iohub import open_ome_zarr
import monai.transforms as transforms

# %% Paths and parameters.

features_path = Path(
"/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/contrastive_tune_augmentations/predict/2024_02_04/tokenized-drop_path_0_0-2024-06-13.zarr"
"/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/contrastive_tune_augmentations/predict/2024_06_13/l2_projection_batchnorm-128p.zarr"
)
data_path = Path(
"/hpc/projects/virtual_staining/2024_02_04_A549_DENV_ZIKV_timelapse/registered_chunked.zarr"
"/hpc/projects/intracellular_dashboard/viral-sensor/2024_06_13_SEC61_TOMM20_ZIKV_DENGUE_1/2-register/registered_chunked.zarr"
)
tracks_path = Path(
"/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/7.1-seg_track/tracking_v1.zarr"
"/hpc/projects/intracellular_dashboard/viral-sensor/2024_06_13_SEC61_TOMM20_ZIKV_DENGUE_1/4.1-tracking/test_tracking_4.zarr"
)

# %%
Expand All @@ -34,8 +34,8 @@

# %%
# Compute PCA of the features and projections to estimate the number of components to keep.
PCA_features = PCA().fit(embedding_dataset["features"].values)
PCA_projection = PCA().fit(embedding_dataset["projections"].values)
PCA_features = PCA(n_components=100).fit(embedding_dataset["features"].values)
PCA_projection = PCA(n_components=100).fit(embedding_dataset["projections"].values)

plt.plot(PCA_features.explained_variance_ratio_, label="features")
plt.plot(PCA_projection.explained_variance_ratio_, label="projections")
Expand All @@ -50,11 +50,15 @@
# * Heatmaps of annotations over UMAPs.


# %%
print(np.linalg.matrix_rank(embedding_dataset["features"].values))
print(np.linalg.matrix_rank(embedding_dataset["projections"].values))

# %%
# Extract a track from the dataset and visualize its features.

fov_name = "/B/4/4"
track_id = 71
fov_name = "/0/1/000000" # "/B/4/4" FOV names can change between datasets.
track_id = 21
all_tracks_FOV = embedding_dataset.sel(fov_name=fov_name)
a_track_in_FOV = all_tracks_FOV.sel(track_id=track_id)
# Why is sample dimension ~22000 long after the dataset is sliced by FOV and by track_id?
Expand Down Expand Up @@ -253,7 +257,7 @@ def load_annotation(da, path, name, categories: dict | None = None):

# %%
ann_root = Path(
"/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/7.1-seg_track"
"/hpc/projects/intracellular_dashboard/viral-sensor/2024_06_13_SEC61_TOMM20_ZIKV_DENGUE_1/4.1-tracking"
)

infection = load_annotation(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ trainer:
num_nodes: 1
precision: 32-true
callbacks:
- class_path: viscy.light.embedding_writer.EmbeddingWriter
- class_path: viscy.representation.embedding_writer.EmbeddingWriter
init_args:
output_path: "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/contrastive_tune_augmentations/predict/test_prediction_code.zarr"
# edit the following lines to specify logging path
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from lightning.pytorch.loggers import TensorBoardLogger

from viscy.data.triplet import TripletDataModule
from viscy.light.engine import ContrastiveModule
from viscy.representation.engine import ContrastiveModule


def main():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import torch
import torchview

from viscy.light.engine import ContrastiveModule
from viscy.representation.engine import ContrastiveModule
from viscy.representation.contrastive import ContrastiveEncoder, UNeXt2Stem

# %load_ext autoreload
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from torch import Tensor

from viscy.data.hcs import Sample
from viscy.light.engine import VSUNet
from viscy.translation.engine import VSUNet

#
# %% Methods to compute confusion matrix per cell using torchmetrics
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
)

from viscy.data.hcs import HCSDataModule
from viscy.light.predict_writer import HCSPredictionWriter
from viscy.translation.predict_writer import HCSPredictionWriter
from viscy.transforms import NormalizeSampled

# %% # %% write the predictions to a zarr file
Expand Down
2 changes: 1 addition & 1 deletion examples/configs/predict_example.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ predict:
num_nodes: 1
precision: 32-true
callbacks:
- class_path: viscy.light.predict_writer.HCSPredictionWriter
- class_path: viscy.translation.predict_writer.HCSPredictionWriter
init_args:
output_store: null
write_input: false
Expand Down
7 changes: 4 additions & 3 deletions examples/virtual_staining/VS_model_inference/demo_vscyto2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,12 @@

from iohub import open_ome_zarr
from plot import plot_vs_n_fluor

# Viscy classes for the trainer and model
from viscy.data.hcs import HCSDataModule
from viscy.light.engine import FcmaeUNet
from viscy.light.predict_writer import HCSPredictionWriter
from viscy.light.trainer import VSTrainer
from viscy.translation.engine import FcmaeUNet
from viscy.translation.predict_writer import HCSPredictionWriter
from viscy.translation.trainer import VSTrainer
from viscy.transforms import NormalizeSampled

# %% [markdown]
Expand Down
6 changes: 3 additions & 3 deletions examples/virtual_staining/VS_model_inference/demo_vscyto3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@
from plot import plot_vs_n_fluor
from viscy.data.hcs import HCSDataModule
# Viscy classes for the trainer and model
from viscy.light.engine import VSUNet
from viscy.light.predict_writer import HCSPredictionWriter
from viscy.light.trainer import VSTrainer
from viscy.translation.engine import VSUNet
from viscy.translation.predict_writer import HCSPredictionWriter
from viscy.translation.trainer import VSTrainer
from viscy.transforms import NormalizeSampled

# %% [markdown]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@
from plot import plot_vs_n_fluor
from viscy.data.hcs import HCSDataModule
# Viscy classes for the trainer and model
from viscy.light.engine import VSUNet
from viscy.light.predict_writer import HCSPredictionWriter
from viscy.light.trainer import VSTrainer
from viscy.translation.engine import VSUNet
from viscy.translation.predict_writer import HCSPredictionWriter
from viscy.translation.trainer import VSTrainer
from viscy.transforms import NormalizeSampled

# %% [markdown]
Expand Down
4 changes: 2 additions & 2 deletions examples/virtual_staining/dlmbl_exercise/solution.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,8 @@
from viscy.data.hcs import HCSDataModule
from viscy.evaluation.evaluation_metrics import mean_average_precision
# Trainer class and UNet.
from viscy.light.engine import MixedLoss, VSUNet
from viscy.light.trainer import VSTrainer
from viscy.translation.engine import MixedLoss, VSUNet
from viscy.translation.trainer import VSTrainer
# training augmentations
from viscy.transforms import (NormalizeSampled, RandAdjustContrastd,
RandAffined, RandGaussianNoised,
Expand Down
4 changes: 2 additions & 2 deletions examples/virtual_staining/img2img_translation/solution.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,8 +128,8 @@
# HCSDataModule makes it easy to load data during training.
from viscy.data.hcs import HCSDataModule
# Trainer class and UNet.
from viscy.light.engine import MixedLoss, VSUNet
from viscy.light.trainer import VSTrainer
from viscy.translation.engine import MixedLoss, VSUNet
from viscy.translation.trainer import VSTrainer
# training augmentations
from viscy.transforms import (NormalizeSampled, RandAdjustContrastd,
RandAffined, RandGaussianNoised,
Expand Down
2 changes: 1 addition & 1 deletion tests/data/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from pytest import mark

from viscy.data.hcs import HCSDataModule
from viscy.light.trainer import VSTrainer
from viscy.translation.trainer import VSTrainer


@mark.parametrize("default_channels", [True, False])
Expand Down
2 changes: 1 addition & 1 deletion tests/light/test_engine.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from viscy.light.engine import FcmaeUNet
from viscy.translation.engine import FcmaeUNet


def test_fcmae_vsunet() -> None:
Expand Down
38 changes: 38 additions & 0 deletions viscy/_log_images.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from typing import Sequence

import numpy as np
from matplotlib.pyplot import get_cmap
from skimage.exposure import rescale_intensity
from torch import Tensor


def detach_sample(imgs: Sequence[Tensor], log_samples_per_batch: int):
num_samples = min(imgs[0].shape[0], log_samples_per_batch)
samples = []
for i in range(num_samples):
patches = []
for img in imgs:
patch = img[i].detach().cpu().numpy()
patch = np.squeeze(patch[:, patch.shape[1] // 2])
patches.append(patch)
samples.append(patches)
return samples


def render_images(
    imgs: Sequence[Sequence[np.ndarray]], cmaps: list[str] | None = None
) -> np.ndarray:
    """Render a grid of grayscale images as a single RGB image.

    Each inner sequence becomes one row of the grid; channels within an
    image are concatenated horizontally after colormap rendering.

    :param imgs: nested [sample][image] 2D (or CYX) arrays, e.g. from
        :py:func:`detach_sample`
    :param cmaps: matplotlib colormap name per image position;
        defaults to "gray" for the first image and "inferno" for the rest
        (avoid a mutable default argument by using ``None`` as sentinel)
    :return: (H, W, 3) uint8 RGB image of the full grid
    """
    images_grid = []
    for sample_images in imgs:
        images_row = []
        for i, image in enumerate(sample_images):
            if cmaps:
                cm_name = cmaps[i]
            else:
                cm_name = "gray" if i == 0 else "inferno"
            # promote 2D images to a single-channel CYX stack
            if image.ndim == 2:
                image = image[np.newaxis]
            for channel in image:
                # normalize each channel independently before colormapping
                channel = rescale_intensity(channel, out_range=(0, 1))
                # drop the alpha channel returned by the colormap
                render = get_cmap(cm_name)(channel, bytes=True)[..., :3]
                images_row.append(render)
        images_grid.append(np.concatenate(images_row, axis=1))
    return np.concatenate(images_grid, axis=0)
4 changes: 2 additions & 2 deletions viscy/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
from lightning.pytorch.loggers import TensorBoardLogger

from viscy.data.hcs import HCSDataModule
from viscy.light.engine import VSUNet
from viscy.light.trainer import VSTrainer
from viscy.translation.engine import VSUNet
from viscy.translation.trainer import VSTrainer


class VSLightningCLI(LightningCLI):
Expand Down
2 changes: 1 addition & 1 deletion viscy/cli/contrastive_triplet.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from lightning.pytorch.loggers import TensorBoardLogger

from viscy.data.triplet import TripletDataModule
from viscy.light.engine import ContrastiveModule
from viscy.representation.engine import ContrastiveModule


class ContrastiveLightningCLI(LightningCLI):
Expand Down
2 changes: 1 addition & 1 deletion viscy/data/hcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ def __getitem__(self, index: int) -> Sample:
class MaskTestDataset(SlidingWindowDataset):
"""Torch dataset where each element is a window of
(C, Z, Y, X) where C=2 (source and target) and Z is ``z_window_size``.
This a testing stage version of :py:class:`viscy.light.data.SlidingWindowDataset`,
This is a testing stage version of :py:class:`viscy.data.hcs.SlidingWindowDataset`,
and can only be used with batch size 1 for efficiency (no padding for collation),
since the mask is not available for each stack.
Expand Down
Loading

0 comments on commit 1f269c7

Please sign in to comment.