Skip to content

Commit

Permalink
UMAP line plot to assess temporal smoothness in features space (#176)
Browse files Browse the repository at this point in the history
* add maplotlib style sheet for figure making

* add cell division attribution

* add matplotlib style sheet

* move attribution computation to lca

* tweak contrast limits and text

* add captum to optional dependencies

* move attribution function to a method of the classifier

* add script to show organelle dynamics

* add occlusion attribution

* more generic save path

* add uninfected cell

* tweak subplot spacing

* lower case titles

* reduce UMAP components to 2 and add indices

* add script to make the bridge gaps figure
  • Loading branch information
ziw-liu authored Sep 27, 2024
1 parent 10219d3 commit 42a0cb5
Show file tree
Hide file tree
Showing 3 changed files with 117 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def render(img, cmaps: list[str]):
subfigs = fig.subfigures(1, 2, wspace=0.02, width_ratios=[4, 7])

umap_fig = subfigs[0]
umap_fig.suptitle("A", horizontalalignment="left", x=0, y=1)
umap_fig.suptitle("a", horizontalalignment="left", x=0, y=1)
umap_ax = umap_fig.subplots(1, 1)
umap_ax.invert_xaxis()

Expand Down Expand Up @@ -208,7 +208,7 @@ def render(img, cmaps: list[str]):
)

img_fig = subfigs[1]
img_fig.suptitle("B", horizontalalignment="left", x=-0, y=1)
img_fig.suptitle("b", horizontalalignment="left", x=-0, y=1)
img_axes = img_fig.subplots(3, 4, sharex=True, sharey=True)

for i, (ax, rend, time, track_name) in enumerate(
Expand Down
111 changes: 111 additions & 0 deletions applications/contrastive_phenotyping/figures/track_smoothness.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
# %%
from pathlib import Path

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from cmap import Colormap
from iohub import open_ome_zarr
from skimage.color import label2rgb
from skimage.exposure import rescale_intensity

from viscy.representation.embedding_writer import read_embedding_dataset
from viscy.representation.evaluation import compute_umap

# %%
t_slice = slice(18, 33)
y_slice = slice(16, 144)
x_slice = slice(0, 224)

phase = open_ome_zarr(
"/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/registered_test.zarr/B/4/8"
)["0"][t_slice, 3, 31, y_slice, x_slice]

segments = open_ome_zarr(
"/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/track_test.zarr/B/4/8"
)["0"][t_slice, 0, 0, y_slice, x_slice]

# %%
features = read_embedding_dataset(
"/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/time_interval/predict/feb_test_time_interval_1_epoch_178.zarr"
)

# %%
_, _, umap_df = compute_umap(features)
umap_df

# %%
track_ids = np.unique(segments)[1:]
track_ids

# %%
selected_umap = umap_df[
(umap_df["fov_name"] == "/B/4/8")
& umap_df["track_id"].isin(track_ids)
& (umap_df["t"] >= t_slice.start)
& (umap_df["t"] < t_slice.stop)
]

selected_umap["HPI"] = selected_umap["t"] * 0.5 + 3

# %%
plt.style.use("../evaluation/figure.mplstyle")
fig = plt.figure(figsize=(5.5, 4.5), layout="constrained")
subfigs = fig.subfigures(2, 1, wspace=0.02, height_ratios=[3, 2])

img_fig = subfigs[0]
img_fig.suptitle("a", horizontalalignment="left", x=0, y=1)
img_ax = img_fig.subplots(3, 5)

clim = 0.03
cmap = Colormap("tab10")

labels = label2rgb(
segments,
image=rescale_intensity(phase, in_range=(-clim, clim), out_range=(0, 1)),
colors=cmap(range(10)),
)

for t, (a, rgb) in enumerate(zip(img_ax.flatten(), labels)):
a.imshow(rgb)
a.set_title(f"{(t+t_slice.start)/2 + 3} HPI")
a.axis("off")

line_fig = subfigs[1]
line_fig.suptitle("b", horizontalalignment="left", x=0, y=1)
line_ax_1 = line_fig.subplots(1, 1)
line_ax_2 = line_ax_1.twinx()
sns.lineplot(
data=selected_umap,
x="HPI",
y="UMAP1",
hue="track_id",
palette=[c for c in cmap([2, 4, 6])],
ax=line_ax_1,
)
sns.move_legend(line_ax_1, "upper right", title="Track ID")
sns.lineplot(
data=selected_umap,
x="HPI",
y="UMAP2",
hue="track_id",
palette=[c for c in cmap([2, 4, 6])],
ax=line_ax_2,
linestyle="--",
legend=False,
)

fmt = mpl.ticker.StrMethodFormatter("{x:.1f}")
for a in [line_ax_1, line_ax_2]:
a.xaxis.set_major_formatter(fmt)
a.yaxis.set_major_formatter(fmt)

# %%
fig.savefig(
Path.home()
/ "gdrive/publications/learning_impacts_of_infection/fig_manuscript/si/appendix_track_smoothness.pdf",
dpi=300,
)

# %%
8 changes: 4 additions & 4 deletions viscy/representation/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,22 +250,22 @@ def compute_umap(embedding_dataset, normalize_features=True):

# Compute UMAP for features and projections
# Computing 3 components to enable 3D visualization.
umap_features = umap.UMAP(random_state=42, n_components=3)
umap_projection = umap.UMAP(random_state=42, n_components=3)
umap_features = umap.UMAP(random_state=42, n_components=2)
umap_projection = umap.UMAP(random_state=42, n_components=2)
umap_features_embedding = umap_features.fit_transform(scaled_features)
umap_projection_embedding = umap_projection.fit_transform(scaled_projections)

# Prepare DataFrame with id and UMAP coordinates
umap_df = pd.DataFrame(
{
"id": embedding_dataset["id"].values,
"track_id": embedding_dataset["track_id"].values,
"t": embedding_dataset["t"].values,
"fov_name": embedding_dataset["fov_name"].values,
"UMAP1": umap_features_embedding[:, 0],
"UMAP2": umap_features_embedding[:, 1],
"UMAP3": umap_features_embedding[:, 2],
"UMAP1_proj": umap_projection_embedding[:, 0],
"UMAP2_proj": umap_projection_embedding[:, 1],
"UMAP3_proj": umap_projection_embedding[:, 2],
}
)

Expand Down

0 comments on commit 42a0cb5

Please sign in to comment.