From 1bc680c99816c35d2a174d4807aebad9f622171f Mon Sep 17 00:00:00 2001 From: Alishba Imran Date: Tue, 17 Sep 2024 11:43:54 -0700 Subject: [PATCH 01/18] updated files --- .../evaluation/cosine_similarity.py | 313 ++++++++++++++++++ .../evaluation/pca_umap_embeddings_time.py | 222 +++++++++++++ viscy/representation/evaluation.py | 130 ++++++++ 3 files changed, 665 insertions(+) create mode 100644 applications/contrastive_phenotyping/evaluation/cosine_similarity.py create mode 100644 applications/contrastive_phenotyping/evaluation/pca_umap_embeddings_time.py diff --git a/applications/contrastive_phenotyping/evaluation/cosine_similarity.py b/applications/contrastive_phenotyping/evaluation/cosine_similarity.py new file mode 100644 index 000000000..79b25741b --- /dev/null +++ b/applications/contrastive_phenotyping/evaluation/cosine_similarity.py @@ -0,0 +1,313 @@ +# %% +from pathlib import Path + + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import plotly.express as px +import seaborn as sns +from sklearn.decomposition import PCA +from sklearn.preprocessing import StandardScaler +from umap import UMAP +from sklearn.decomposition import PCA + + +from viscy.representation.embedding_writer import read_embedding_dataset +from viscy.representation.evaluation import dataset_of_tracks, load_annotation +from sklearn.metrics.pairwise import cosine_similarity +from collections import defaultdict +from viscy.representation.evaluation import calculate_cosine_similarity_cell +from viscy.representation.evaluation import compute_displacement_mean_std +from viscy.representation.evaluation import compute_displacement + + +# %% Paths and parameters. + + +features_path_30_min = Path( + "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/time_interval/predict/feb_test_time_interval_1_epoch_178.zarr" +) + + +feature_path_no_track = Path("/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/negpair_random_sampling2/feb_fixed_test_predict.zarr") + + +features_path_any_time = Path("/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/negpair_difcell_randomtime_sampling/Ver2_updateTracking_refineModel/predictions/Feb_1chan_128patch_32projDim/1chan_128patch_63ckpt_FebTest.zarr") + + +data_path = Path( + "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/registered_test.zarr" +) + + +tracks_path = Path( + "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/track_test.zarr" +) + + +# %% Load embedding datasets for all three sampling +fov_name = '/B/4/6' +track_id = 4 + + +embedding_dataset_30_min = read_embedding_dataset(features_path_30_min) +embedding_dataset_no_track = read_embedding_dataset(feature_path_no_track) +embedding_dataset_any_time = read_embedding_dataset(features_path_any_time) + + +# Calculate cosine similarities for each sampling +time_points_30_min, cosine_similarities_30_min = calculate_cosine_similarity_cell(embedding_dataset_30_min, fov_name, track_id) +time_points_no_track, cosine_similarities_no_track = calculate_cosine_similarity_cell(embedding_dataset_no_track, fov_name, track_id) +time_points_any_time, cosine_similarities_any_time = calculate_cosine_similarity_cell(embedding_dataset_any_time, fov_name, track_id) + + +# %% Plot cosine similarities over time for all three conditions + + 
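+# Sketch of what calculate_cosine_similarity_cell computes: for one
+# (fov_name, track_id) track it returns the cosine similarity between the
+# embedding at each time point and the embedding at the first time point,
+# so each curve below starts at 1.0 and drifts as the representation changes.
+# Minimal hypothetical example (uses the numpy import above; the values are
+# illustrative, not from the datasets):
+# emb = np.array([[1.0, 0.0], [0.9, 0.1], [0.5, 0.5]])  # t = 0, 1, 2
+# sims = emb @ emb[0] / (np.linalg.norm(emb, axis=1) * np.linalg.norm(emb[0]))
+# sims -> array([1.0, 0.9939, 0.7071])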
+plt.figure(figsize=(10, 6)) + + +plt.plot(time_points_no_track, cosine_similarities_no_track, marker='o', label='classical contrastive (no tracking)') +plt.plot(time_points_any_time, cosine_similarities_any_time, marker='o', label='cell aware') +plt.plot(time_points_30_min, cosine_similarities_30_min, marker='o', label='cell & time aware (interval 30 min)') + + + + +plt.xlabel("Time Delay (t)") +plt.ylabel("Cosine Similarity with First Time Point") +plt.title("Cosine Similarity Over Time for Infected Cell") + + +#plt.savefig('infected_cell_example.svg', format='svg') +#plt.savefig('infected_cell_example.pdf', format='pdf') + + +plt.grid(True) +plt.legend() + + + + +plt.show() +# %% + + +# %% import statements + + +from pathlib import Path +import numpy as np +import matplotlib.pyplot as plt +from sklearn.metrics.pairwise import euclidean_distances + + +from viscy.representation.embedding_writer import read_embedding_dataset +from viscy.representation.evaluation import dataset_of_tracks, load_annotation +from sklearn.metrics.pairwise import cosine_similarity + + +# %% Paths to datasets +features_path_30_min = Path("/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/time_interval/predict/feb_test_time_interval_1_epoch_178.zarr") +feature_path_no_track = Path("/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/negpair_random_sampling2/feb_fixed_test_predict.zarr") +#features_path_any_time = Path("/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/negpair_difcell_randomtime_sampling/Ver2_updateTracking_refineModel/predictions/Feb_1chan_128patch_32projDim/1chan_128patch_63ckpt_FebTest.zarr") + + +# %% Read embedding datasets +embedding_dataset_30_min = read_embedding_dataset(features_path_30_min) +embedding_dataset_no_track = read_embedding_dataset(feature_path_no_track) +#embedding_dataset_any_time = read_embedding_dataset(features_path_any_time) + + +# %% Compute displacements for both datasets (using Euclidean distance and Cosine similarity) +max_tau = 10 # Maximum time shift (tau) to compute displacements + + +mean_displacement_30_min, std_displacement_30_min = compute_displacement_mean_std(embedding_dataset_30_min, max_tau, use_cosine=False, use_dissimilarity=True) +mean_displacement_no_track, std_displacement_no_track = compute_displacement_mean_std(embedding_dataset_no_track, max_tau, use_cosine=False, use_dissimilarity=True) +#mean_displacement_any_time, std_displacement_any_time = compute_displacement_mean_std(embedding_dataset_any_time, max_tau, use_cosine=False) + + +mean_displacement_30_min_cosine, std_displacement_30_min_cosine = compute_displacement_mean_std(embedding_dataset_30_min, max_tau, use_cosine=True, use_dissimilarity=True) +mean_displacement_no_track_cosine, std_displacement_no_track_cosine = compute_displacement_mean_std(embedding_dataset_no_track, max_tau, use_cosine=True, use_dissimilarity=True) +#mean_displacement_any_time_cosine, std_displacement_any_time_cosine = compute_displacement_mean_std(embedding_dataset_any_time, max_tau, use_cosine=True) +# %% Plot 1: Euclidean Displacements +plt.figure(figsize=(10, 6)) + + +taus = list(mean_displacement_30_min.keys()) +mean_values_30_min = list(mean_displacement_30_min.values()) +std_values_30_min = list(std_displacement_30_min.values()) + + +mean_values_no_track = list(mean_displacement_no_track.values()) +std_values_no_track = 
list(std_displacement_no_track.values()) + + +# mean_values_any_time = list(mean_displacement_any_time.values()) +# std_values_any_time = list(std_displacement_any_time.values()) + + +# Plotting Euclidean displacements +plt.plot(taus, mean_values_30_min, marker='o', label='Cell & Time Aware (30 min interval)') +plt.fill_between(taus, np.array(mean_values_30_min) - np.array(std_values_30_min), np.array(mean_values_30_min) + np.array(std_values_30_min), + color='gray', alpha=0.3, label='Std Dev (30 min interval)') + + +plt.plot(taus, mean_values_no_track, marker='o', label='Classical Contrastive (No Tracking)') +plt.fill_between(taus, np.array(mean_values_no_track) - np.array(std_values_no_track), np.array(mean_values_no_track) + np.array(std_values_no_track), + color='blue', alpha=0.3, label='Std Dev (No Tracking)') + + +plt.xlabel('Time Shift (τ)') +plt.ylabel('Displacement') +plt.title('Embedding Displacement Over Time') +plt.grid(True) +plt.legend() + + +# plt.savefig('embedding_displacement_euclidean.svg', format='svg') +# plt.savefig('embedding_displacement_euclidean.pdf', format='pdf') + + +# Show the Euclidean plot +plt.show() + + +# %% Plot 2: Cosine Displacements +plt.figure(figsize=(10, 6)) + + +# Plotting Cosine displacements +mean_values_30_min_cosine = list(mean_displacement_30_min_cosine.values()) +std_values_30_min_cosine = list(std_displacement_30_min_cosine.values()) + + +mean_values_no_track_cosine = list(mean_displacement_no_track_cosine.values()) +std_values_no_track_cosine = list(std_displacement_no_track_cosine.values()) + + +plt.plot(taus, mean_values_30_min_cosine, marker='o', label='Cell & Time Aware (30 min interval)') +plt.fill_between(taus, np.array(mean_values_30_min_cosine) - np.array(std_values_30_min_cosine), np.array(mean_values_30_min_cosine) + np.array(std_values_30_min_cosine), + color='gray', alpha=0.3, label='Std Dev (30 min interval, Cosine)') + + +plt.plot(taus, mean_values_no_track_cosine, marker='o', label='Classical Contrastive (No Tracking)') +plt.fill_between(taus, np.array(mean_values_no_track_cosine) - np.array(std_values_no_track_cosine), np.array(mean_values_no_track_cosine) + np.array(std_values_no_track_cosine), + color='blue', alpha=0.3, label='Std Dev (No Tracking)') + + +plt.xlabel('Time Shift (τ)') +plt.ylabel('Cosine Similarity') +plt.title('Embedding Displacement Over Time') +plt.grid(True) +plt.legend() + + +# Show the Cosine plot +plt.show() +# %% + + +import seaborn as sns +import pandas as pd +import numpy as np +import matplotlib.pyplot as plt +from pathlib import Path +from collections import defaultdict +from sklearn.metrics.pairwise import cosine_similarity +from viscy.representation.embedding_writer import read_embedding_dataset + + +# %% Paths to datasets +features_path_30_min = Path("/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/time_interval/predict/feb_test_time_interval_1_epoch_178.zarr") +feature_path_no_track = Path("/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/negpair_random_sampling2/feb_fixed_test_predict.zarr") + + +# %% Read embedding datasets +embedding_dataset_30_min = read_embedding_dataset(features_path_30_min) +embedding_dataset_no_track = read_embedding_dataset(feature_path_no_track) + + +# %% Compute displacements for both datasets (using Cosine similarity) +max_tau = 10 # Maximum time shift (tau) to compute displacements + + +# Compute displacements for Cell & Time Aware (30 min 
interval) using Cosine similarity +displacement_per_tau_aware_cosine = compute_displacement(embedding_dataset_30_min, max_tau, use_cosine=True, use_dissimilarity=True) + + +# Compute displacements for Classical Contrastive (No Tracking) using Cosine similarity +displacement_per_tau_contrastive_cosine = compute_displacement(embedding_dataset_no_track, max_tau, use_cosine=True, use_dissimilarity=True) + + +# %% Prepare data for violin plot +# Prepare the data in a long-form DataFrame for the violin plot +def prepare_violin_data(taus, displacement_aware, displacement_contrastive): + # Create a list to hold the data + data = [] + + + # Populate the data for Cell & Time Aware + for tau in taus: + displacements_aware = displacement_aware.get(tau, []) + for displacement in displacements_aware: + data.append({'Time Shift (τ)': tau, 'Displacement': displacement, 'Sampling': 'Cell & Time Aware (30 min interval)'}) + + + # Populate the data for Classical Contrastive + for tau in taus: + displacements_contrastive = displacement_contrastive.get(tau, []) + for displacement in displacements_contrastive: + data.append({'Time Shift (τ)': tau, 'Displacement': displacement, 'Sampling': 'Classical Contrastive (No Tracking)'}) + + + # Convert to a DataFrame + df = pd.DataFrame(data) + return df + + +# Assuming 'displacement_per_tau_aware_cosine' and 'displacement_per_tau_contrastive_cosine' hold the displacements as dictionaries +taus = list(displacement_per_tau_aware_cosine.keys()) + + +# Prepare the violin plot data +df = prepare_violin_data(taus, displacement_per_tau_aware_cosine, displacement_per_tau_contrastive_cosine) + + +# Create a violin plot using seaborn +plt.figure(figsize=(12, 8)) +sns.violinplot( + x='Time Shift (τ)', + y='Displacement', + hue='Sampling', + data=df, + palette='Set2', + scale='width', + bw=0.2, + inner=None, + split=True, + cut=0 +) + + +# Add labels and title +plt.xlabel('Time Shift (τ)', fontsize=14) +plt.ylabel('Cosine Dissimilarity', fontsize=14) +plt.title('Cosine Dissimilarity Distribution', fontsize=16) +plt.grid(True, linestyle='--', alpha=0.6) # Lighter grid lines for less distraction +plt.legend(title='Sampling', fontsize=12, title_fontsize=14) + + +plt.ylim(0, 0.5) + + +# Save the violin plot as SVG and PDF +plt.savefig('violin_plot_cosine_similarity.svg', format='svg') +plt.savefig('violin_plot_cosine_similarity.pdf', format='pdf') + + +# Show the plot +plt.show() +# %% diff --git a/applications/contrastive_phenotyping/evaluation/pca_umap_embeddings_time.py b/applications/contrastive_phenotyping/evaluation/pca_umap_embeddings_time.py new file mode 100644 index 000000000..435dafba5 --- /dev/null +++ b/applications/contrastive_phenotyping/evaluation/pca_umap_embeddings_time.py @@ -0,0 +1,222 @@ +# %% +from pathlib import Path + + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import plotly.express as px +import seaborn as sns +from sklearn.decomposition import PCA +from sklearn.preprocessing import StandardScaler +from umap import UMAP +from sklearn.decomposition import PCA + + +from viscy.representation.embedding_writer import read_embedding_dataset +from viscy.representation.evaluation import dataset_of_tracks, load_annotation + + +# %% Paths and parameters. 
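+# Note on the data layout assumed below: read_embedding_dataset returns an
+# xarray Dataset whose "features" variable is indexed by a "sample"
+# dimension carrying (fov_name, track_id, t) coordinates, which is why the
+# scatter plots can color points by features["t"].
+# UMAP() is stochastic; to make the panels comparable across reruns, a
+# fixed seed could be passed (hypothetical, not used in this script):
+# umap = UMAP(random_state=42)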
+ + +features_path = Path( + "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/time_interval/predict/feb_test_time_interval_1_epoch_178.zarr" +) +data_path = Path( + "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/registered_test.zarr" +) +tracks_path = Path( + "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/track_test.zarr" +) + + +# %% +embedding_dataset = read_embedding_dataset(features_path) +embedding_dataset + + +# %% +# Compute UMAP over all features +features = embedding_dataset["features"] +# or select a well: +# features = features[features["fov_name"].str.contains("B/4")] + + +scaled_features = StandardScaler().fit_transform(features.values) +umap = UMAP() +# Fit UMAP on all features +embedding = umap.fit_transform(scaled_features) + + +# %% +# Add UMAP coordinates to the dataset and plot w/ time + + +features = ( + features.assign_coords(UMAP1=("sample", embedding[:, 0])) + .assign_coords(UMAP2=("sample", embedding[:, 1])) + .set_index(sample=["UMAP1", "UMAP2"], append=True) +) +features + + + + +sns.scatterplot( + x=features["UMAP1"], y=features["UMAP2"], hue=features["t"], s=7, alpha=0.8 +) + + +# Add the title to the plot +plt.title("Cell & Time Aware (30 min interval)") +plt.savefig('umap_cell_time_aware_time.svg', format='svg') +plt.savefig('umap_cell_time_aware_time.pdf', format='pdf') +# Show the plot +plt.show() + + +# %% + + +any_features_path = Path("/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/negpair_difcell_randomtime_sampling/Ver2_updateTracking_refineModel/predictions/Feb_2chan_128patch_32projDim/2chan_128patch_56ckpt_FebTest.zarr") +embedding_dataset = read_embedding_dataset(any_features_path) +embedding_dataset + + +# %% +# Compute UMAP over all features +features = embedding_dataset["features"] +# or select a well: +# features = features[features["fov_name"].str.contains("B/4")] + + +scaled_features = StandardScaler().fit_transform(features.values) +umap = UMAP() +# Fit UMAP on all features +embedding = umap.fit_transform(scaled_features) + + +# %% Any time sampling plot + + +features = ( + features.assign_coords(UMAP1=("sample", embedding[:, 0])) + .assign_coords(UMAP2=("sample", embedding[:, 1])) + .set_index(sample=["UMAP1", "UMAP2"], append=True) +) +features + + + + +sns.scatterplot( + x=features["UMAP1"], y=features["UMAP2"], hue=features["t"], s=7, alpha=0.8 +) + + +# Add the title to the plot +plt.title("Cell Aware Any Time") +plt.savefig('umap_cell_aware_time.png', format='png') +#plt.savefig('umap_cell_aware_time.pdf', format='pdf') +# Show the plot +plt.show() + + +# %% + + +contrastive_learning_path = Path("/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/negpair_random_sampling2/feb_fixed_test_predict.zarr") +embedding_dataset = read_embedding_dataset(contrastive_learning_path) +embedding_dataset + + +# %% +# Compute UMAP over all features +features = embedding_dataset["features"] +# or select a well: +# features = features[features["fov_name"].str.contains("B/4")] + + +scaled_features = StandardScaler().fit_transform(features.values) +umap = UMAP() +# Fit UMAP on all features +embedding = umap.fit_transform(scaled_features) + + +# %% Any time sampling plot + + +features = ( + features.assign_coords(UMAP1=("sample", embedding[:, 0])) + 
.assign_coords(UMAP2=("sample", embedding[:, 1])) + .set_index(sample=["UMAP1", "UMAP2"], append=True) +) +features + + + + +sns.scatterplot( + x=features["UMAP1"], y=features["UMAP2"], hue=features["t"], s=7, alpha=0.8 +) + + +# Add the title to the plot +plt.title("Classical Contrastive Learning") +plt.savefig('classical_time.svg', format='svg') +plt.savefig('classical_time.pdf', format='pdf') + + +# Show the plot +plt.show() + + +# %% PCA + + +pca = PCA(n_components=4) +# scaled_features = StandardScaler().fit_transform(features.values) +# pca_features = pca.fit_transform(scaled_features) +pca_features = pca.fit_transform(features.values) + + +features = ( + features.assign_coords(PCA1=("sample", pca_features[:, 0])) + .assign_coords(PCA2=("sample", pca_features[:, 1])) + .assign_coords(PCA3=("sample", pca_features[:, 2])) + .assign_coords(PCA4=("sample", pca_features[:, 3])) + .set_index(sample=["PCA1", "PCA2", "PCA3", "PCA4"], append=True) +) + + +# %% plot PCA components w/ time + + +plt.figure(figsize=(10, 10)) +sns.scatterplot(x=features["PCA1"], y=features["PCA2"], hue=features["t"], s=7, alpha=0.8) + + +# %% OVERLAY INFECTION ANNOTATION +ann_root = Path( + "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/supervised_inf_pred" +) + + +infection = load_annotation( + features, + ann_root / "extracted_inf_state.csv", + "infection_state", + {0.0: "background", 1.0: "uninfected", 2.0: "infected"}, +) + + +# %% +sns.scatterplot(x=features["UMAP1"], y=features["UMAP2"], hue=infection, s=7, alpha=0.8) + + +# %% plot PCA components with infection hue +sns.scatterplot(x=features["PCA1"], y=features["PCA2"], hue=infection, s=7, alpha=0.8) + + +# %% diff --git a/viscy/representation/evaluation.py b/viscy/representation/evaluation.py index cbf8ead06..0d9e505af 100644 --- a/viscy/representation/evaluation.py +++ b/viscy/representation/evaluation.py @@ -17,6 +17,8 @@ from sklearn.preprocessing import StandardScaler from viscy.data.triplet import TripletDataModule +from sklearn.metrics.pairwise import cosine_similarity +from collections import defaultdict """ This module enables evaluation of learned representations using annotations, such as @@ -379,3 +381,131 @@ def compute_std_dev(image): std_dev = np.std(image) return std_dev + +# Function to extract embeddings and calculate cosine similarities for a specific cell +def calculate_cosine_similarity_cell(embedding_dataset, fov_name, track_id): + # Filter the dataset for the specific infected cell + filtered_data = embedding_dataset.where( + (embedding_dataset['fov_name'] == fov_name) & + (embedding_dataset['track_id'] == track_id), + drop=True + ) + + # Extract the feature embeddings and time points + features = filtered_data['features'].values # (sample, features) + time_points = filtered_data['t'].values # (sample,) + + # Get the first time point's embedding + first_time_point_embedding = features[0].reshape(1, -1) + + # Calculate cosine similarity between each time point and the first time point + cosine_similarities = [] + for i in range(len(time_points)): + similarity = cosine_similarity( + first_time_point_embedding, features[i].reshape(1, -1) + ) + cosine_similarities.append(similarity[0][0]) + + return time_points, cosine_similarities + + +# Function to compute the norm of differences between embeddings at t and t + tau +def compute_displacement_mean_std(embedding_dataset, max_tau=10, use_cosine=False, use_dissimilarity=False): + # Get the arrays of (fov_name, track_id, t, and 
embeddings) + fov_names = embedding_dataset['fov_name'].values + track_ids = embedding_dataset['track_id'].values + timepoints = embedding_dataset['t'].values + embeddings = embedding_dataset['features'].values + + # Dictionary to store displacements for each tau + displacement_per_tau = defaultdict(list) + + # Iterate over all entries in the dataset + for i in range(len(fov_names)): + fov_name = fov_names[i] + track_id = track_ids[i] + current_time = timepoints[i] + current_embedding = embeddings[i] + + # For each time point t, compute displacements for t + tau + for tau in range(1, max_tau + 1): + future_time = current_time + tau + + # Find if future_time exists for the same (fov_name, track_id) + matching_indices = np.where( + (fov_names == fov_name) & (track_ids == track_id) & (timepoints == future_time) + )[0] + + if len(matching_indices) == 1: + # Get the embedding at t + tau + future_embedding = embeddings[matching_indices[0]] + + if use_cosine: + # Compute cosine similarity + similarity = cosine_similarity(current_embedding.reshape(1, -1), future_embedding.reshape(1, -1))[0][0] + # Choose whether to use similarity or dissimilarity + if use_dissimilarity: + displacement = 1 - similarity # Cosine dissimilarity + else: + displacement = similarity # Cosine similarity + else: + # Compute the Euclidean distance, elementwise square on difference + displacement = np.sum((current_embedding - future_embedding) ** 2) + + # Store the displacement for the given tau + displacement_per_tau[tau].append(displacement) + + # Compute mean and std displacement for each tau by averaging the displacements + mean_displacement_per_tau = {tau: np.mean(displacements) for tau, displacements in displacement_per_tau.items()} + std_displacement_per_tau = {tau: np.std(displacements) for tau, displacements in displacement_per_tau.items()} + + return mean_displacement_per_tau, std_displacement_per_tau + + +# Function to compute the norm of differences between embeddings at t and t + tau +def compute_displacement(embedding_dataset, max_tau=10, use_cosine=False, use_dissimilarity=False): + # Get the arrays of (fov_name, track_id, t, and embeddings) + fov_names = embedding_dataset['fov_name'].values + track_ids = embedding_dataset['track_id'].values + timepoints = embedding_dataset['t'].values + embeddings = embedding_dataset['features'].values + + # Dictionary to store displacements for each tau + displacement_per_tau = defaultdict(list) + + # Iterate over all entries in the dataset + for i in range(len(fov_names)): + fov_name = fov_names[i] + track_id = track_ids[i] + current_time = timepoints[i] + current_embedding = embeddings[i] + + # For each time point t, compute displacements for t + tau + for tau in range(1, max_tau + 1): + future_time = current_time + tau + + # Find if future_time exists for the same (fov_name, track_id) + matching_indices = np.where( + (fov_names == fov_name) & (track_ids == track_id) & (timepoints == future_time) + )[0] + + if len(matching_indices) == 1: + # Get the embedding at t + tau + future_embedding = embeddings[matching_indices[0]] + + if use_cosine: + # Compute cosine similarity + similarity = cosine_similarity(current_embedding.reshape(1, -1), future_embedding.reshape(1, -1))[0][0] + # Choose whether to use similarity or dissimilarity + if use_dissimilarity: + displacement = 1 - similarity # Cosine dissimilarity + else: + displacement = similarity # Cosine similarity + else: + # Compute the Euclidean distance, elementwise square on difference + displacement = 
np.sum((current_embedding - future_embedding) ** 2) + + # Store the displacement for the given tau + displacement_per_tau[tau].append(displacement) + + return displacement_per_tau \ No newline at end of file From cb263ade3dc9a14770293d756a36b4c28a53db5c Mon Sep 17 00:00:00 2001 From: Alishba Imran Date: Tue, 17 Sep 2024 11:47:38 -0700 Subject: [PATCH 02/18] format fixed for tests --- viscy/representation/evaluation.py | 264 ++++++++++++++++------------- 1 file changed, 143 insertions(+), 121 deletions(-) diff --git a/viscy/representation/evaluation.py b/viscy/representation/evaluation.py index 0d9e505af..e4830dbc5 100644 --- a/viscy/representation/evaluation.py +++ b/viscy/representation/evaluation.py @@ -1,3 +1,5 @@ +from collections import defaultdict + import numpy as np import pandas as pd import umap @@ -13,12 +15,11 @@ normalized_mutual_info_score, silhouette_score, ) +from sklearn.metrics.pairwise import cosine_similarity from sklearn.neighbors import KNeighborsClassifier from sklearn.preprocessing import StandardScaler from viscy.data.triplet import TripletDataModule -from sklearn.metrics.pairwise import cosine_similarity -from collections import defaultdict """ This module enables evaluation of learned representations using annotations, such as @@ -382,130 +383,151 @@ def compute_std_dev(image): return std_dev + # Function to extract embeddings and calculate cosine similarities for a specific cell def calculate_cosine_similarity_cell(embedding_dataset, fov_name, track_id): - # Filter the dataset for the specific infected cell - filtered_data = embedding_dataset.where( - (embedding_dataset['fov_name'] == fov_name) & - (embedding_dataset['track_id'] == track_id), - drop=True - ) - - # Extract the feature embeddings and time points - features = filtered_data['features'].values # (sample, features) - time_points = filtered_data['t'].values # (sample,) - - # Get the first time point's embedding - first_time_point_embedding = features[0].reshape(1, -1) - - # Calculate cosine similarity between each time point and the first time point - cosine_similarities = [] - for i in range(len(time_points)): - similarity = cosine_similarity( - first_time_point_embedding, features[i].reshape(1, -1) - ) - cosine_similarities.append(similarity[0][0]) - - return time_points, cosine_similarities + # Filter the dataset for the specific infected cell + filtered_data = embedding_dataset.where( + (embedding_dataset["fov_name"] == fov_name) + & (embedding_dataset["track_id"] == track_id), + drop=True, + ) + + # Extract the feature embeddings and time points + features = filtered_data["features"].values # (sample, features) + time_points = filtered_data["t"].values # (sample,) + + # Get the first time point's embedding + first_time_point_embedding = features[0].reshape(1, -1) + + # Calculate cosine similarity between each time point and the first time point + cosine_similarities = [] + for i in range(len(time_points)): + similarity = cosine_similarity( + first_time_point_embedding, features[i].reshape(1, -1) + ) + cosine_similarities.append(similarity[0][0]) + + return time_points, cosine_similarities # Function to compute the norm of differences between embeddings at t and t + tau -def compute_displacement_mean_std(embedding_dataset, max_tau=10, use_cosine=False, use_dissimilarity=False): - # Get the arrays of (fov_name, track_id, t, and embeddings) - fov_names = embedding_dataset['fov_name'].values - track_ids = embedding_dataset['track_id'].values - timepoints = embedding_dataset['t'].values - embeddings = 
embedding_dataset['features'].values - - # Dictionary to store displacements for each tau - displacement_per_tau = defaultdict(list) - - # Iterate over all entries in the dataset - for i in range(len(fov_names)): - fov_name = fov_names[i] - track_id = track_ids[i] - current_time = timepoints[i] - current_embedding = embeddings[i] - - # For each time point t, compute displacements for t + tau - for tau in range(1, max_tau + 1): - future_time = current_time + tau - - # Find if future_time exists for the same (fov_name, track_id) - matching_indices = np.where( - (fov_names == fov_name) & (track_ids == track_id) & (timepoints == future_time) - )[0] - - if len(matching_indices) == 1: - # Get the embedding at t + tau - future_embedding = embeddings[matching_indices[0]] - - if use_cosine: - # Compute cosine similarity - similarity = cosine_similarity(current_embedding.reshape(1, -1), future_embedding.reshape(1, -1))[0][0] - # Choose whether to use similarity or dissimilarity - if use_dissimilarity: - displacement = 1 - similarity # Cosine dissimilarity - else: - displacement = similarity # Cosine similarity - else: - # Compute the Euclidean distance, elementwise square on difference - displacement = np.sum((current_embedding - future_embedding) ** 2) - - # Store the displacement for the given tau - displacement_per_tau[tau].append(displacement) - - # Compute mean and std displacement for each tau by averaging the displacements - mean_displacement_per_tau = {tau: np.mean(displacements) for tau, displacements in displacement_per_tau.items()} - std_displacement_per_tau = {tau: np.std(displacements) for tau, displacements in displacement_per_tau.items()} - - return mean_displacement_per_tau, std_displacement_per_tau +def compute_displacement_mean_std( + embedding_dataset, max_tau=10, use_cosine=False, use_dissimilarity=False +): + # Get the arrays of (fov_name, track_id, t, and embeddings) + fov_names = embedding_dataset["fov_name"].values + track_ids = embedding_dataset["track_id"].values + timepoints = embedding_dataset["t"].values + embeddings = embedding_dataset["features"].values + + # Dictionary to store displacements for each tau + displacement_per_tau = defaultdict(list) + + # Iterate over all entries in the dataset + for i in range(len(fov_names)): + fov_name = fov_names[i] + track_id = track_ids[i] + current_time = timepoints[i] + current_embedding = embeddings[i] + + # For each time point t, compute displacements for t + tau + for tau in range(1, max_tau + 1): + future_time = current_time + tau + + # Find if future_time exists for the same (fov_name, track_id) + matching_indices = np.where( + (fov_names == fov_name) + & (track_ids == track_id) + & (timepoints == future_time) + )[0] + + if len(matching_indices) == 1: + # Get the embedding at t + tau + future_embedding = embeddings[matching_indices[0]] + + if use_cosine: + # Compute cosine similarity + similarity = cosine_similarity( + current_embedding.reshape(1, -1), + future_embedding.reshape(1, -1), + )[0][0] + # Choose whether to use similarity or dissimilarity + if use_dissimilarity: + displacement = 1 - similarity # Cosine dissimilarity + else: + displacement = similarity # Cosine similarity + else: + # Compute the Euclidean distance, elementwise square on difference + displacement = np.sum((current_embedding - future_embedding) ** 2) + + # Store the displacement for the given tau + displacement_per_tau[tau].append(displacement) + + # Compute mean and std displacement for each tau by averaging the displacements + 
mean_displacement_per_tau = { + tau: np.mean(displacements) + for tau, displacements in displacement_per_tau.items() + } + std_displacement_per_tau = { + tau: np.std(displacements) + for tau, displacements in displacement_per_tau.items() + } + + return mean_displacement_per_tau, std_displacement_per_tau # Function to compute the norm of differences between embeddings at t and t + tau -def compute_displacement(embedding_dataset, max_tau=10, use_cosine=False, use_dissimilarity=False): - # Get the arrays of (fov_name, track_id, t, and embeddings) - fov_names = embedding_dataset['fov_name'].values - track_ids = embedding_dataset['track_id'].values - timepoints = embedding_dataset['t'].values - embeddings = embedding_dataset['features'].values - - # Dictionary to store displacements for each tau - displacement_per_tau = defaultdict(list) - - # Iterate over all entries in the dataset - for i in range(len(fov_names)): - fov_name = fov_names[i] - track_id = track_ids[i] - current_time = timepoints[i] - current_embedding = embeddings[i] - - # For each time point t, compute displacements for t + tau - for tau in range(1, max_tau + 1): - future_time = current_time + tau - - # Find if future_time exists for the same (fov_name, track_id) - matching_indices = np.where( - (fov_names == fov_name) & (track_ids == track_id) & (timepoints == future_time) - )[0] - - if len(matching_indices) == 1: - # Get the embedding at t + tau - future_embedding = embeddings[matching_indices[0]] - - if use_cosine: - # Compute cosine similarity - similarity = cosine_similarity(current_embedding.reshape(1, -1), future_embedding.reshape(1, -1))[0][0] - # Choose whether to use similarity or dissimilarity - if use_dissimilarity: - displacement = 1 - similarity # Cosine dissimilarity - else: - displacement = similarity # Cosine similarity - else: - # Compute the Euclidean distance, elementwise square on difference - displacement = np.sum((current_embedding - future_embedding) ** 2) - - # Store the displacement for the given tau - displacement_per_tau[tau].append(displacement) - - return displacement_per_tau \ No newline at end of file +def compute_displacement( + embedding_dataset, max_tau=10, use_cosine=False, use_dissimilarity=False +): + # Get the arrays of (fov_name, track_id, t, and embeddings) + fov_names = embedding_dataset["fov_name"].values + track_ids = embedding_dataset["track_id"].values + timepoints = embedding_dataset["t"].values + embeddings = embedding_dataset["features"].values + + # Dictionary to store displacements for each tau + displacement_per_tau = defaultdict(list) + + # Iterate over all entries in the dataset + for i in range(len(fov_names)): + fov_name = fov_names[i] + track_id = track_ids[i] + current_time = timepoints[i] + current_embedding = embeddings[i] + + # For each time point t, compute displacements for t + tau + for tau in range(1, max_tau + 1): + future_time = current_time + tau + + # Find if future_time exists for the same (fov_name, track_id) + matching_indices = np.where( + (fov_names == fov_name) + & (track_ids == track_id) + & (timepoints == future_time) + )[0] + + if len(matching_indices) == 1: + # Get the embedding at t + tau + future_embedding = embeddings[matching_indices[0]] + + if use_cosine: + # Compute cosine similarity + similarity = cosine_similarity( + current_embedding.reshape(1, -1), + future_embedding.reshape(1, -1), + )[0][0] + # Choose whether to use similarity or dissimilarity + if use_dissimilarity: + displacement = 1 - similarity # Cosine dissimilarity + else: + 
displacement = similarity # Cosine similarity + else: + # Compute the Euclidean distance, elementwise square on difference + displacement = np.sum((current_embedding - future_embedding) ** 2) + + # Store the displacement for the given tau + displacement_per_tau[tau].append(displacement) + + return displacement_per_tau From 398ed4408615dd03bfc6eeecb829dbb92c2fe4e6 Mon Sep 17 00:00:00 2001 From: Alishba Imran Date: Tue, 17 Sep 2024 15:09:09 -0700 Subject: [PATCH 03/18] updated scripts --- .../evaluation/cosine_similarity.py | 139 ++++++++++++++++-- .../evaluation/pca_umap_embeddings_time.py | 33 +++-- 2 files changed, 141 insertions(+), 31 deletions(-) diff --git a/applications/contrastive_phenotyping/evaluation/cosine_similarity.py b/applications/contrastive_phenotyping/evaluation/cosine_similarity.py index 79b25741b..9962d2c3b 100644 --- a/applications/contrastive_phenotyping/evaluation/cosine_similarity.py +++ b/applications/contrastive_phenotyping/evaluation/cosine_similarity.py @@ -48,7 +48,7 @@ # %% Load embedding datasets for all three sampling fov_name = '/B/4/6' -track_id = 4 +track_id = 38 embedding_dataset_30_min = read_embedding_dataset(features_path_30_min) @@ -77,10 +77,10 @@ plt.xlabel("Time Delay (t)") plt.ylabel("Cosine Similarity with First Time Point") -plt.title("Cosine Similarity Over Time for Infected Cell") +plt.title("Cosine Similarity Over Time for Infected Cell (FOV: /B/4/6, Track ID: 38)") -#plt.savefig('infected_cell_example.svg', format='svg') +plt.savefig('example_cell_inf.svg', format='svg') #plt.savefig('infected_cell_example.pdf', format='pdf') @@ -124,13 +124,13 @@ max_tau = 10 # Maximum time shift (tau) to compute displacements -mean_displacement_30_min, std_displacement_30_min = compute_displacement_mean_std(embedding_dataset_30_min, max_tau, use_cosine=False, use_dissimilarity=True) -mean_displacement_no_track, std_displacement_no_track = compute_displacement_mean_std(embedding_dataset_no_track, max_tau, use_cosine=False, use_dissimilarity=True) +# mean_displacement_30_min, std_displacement_30_min = compute_displacement_mean_std(embedding_dataset_30_min, max_tau, use_cosine=False, use_dissimilarity=False) +# mean_displacement_no_track, std_displacement_no_track = compute_displacement_mean_std(embedding_dataset_no_track, max_tau, use_cosine=False, use_dissimilarity=False) #mean_displacement_any_time, std_displacement_any_time = compute_displacement_mean_std(embedding_dataset_any_time, max_tau, use_cosine=False) -mean_displacement_30_min_cosine, std_displacement_30_min_cosine = compute_displacement_mean_std(embedding_dataset_30_min, max_tau, use_cosine=True, use_dissimilarity=True) -mean_displacement_no_track_cosine, std_displacement_no_track_cosine = compute_displacement_mean_std(embedding_dataset_no_track, max_tau, use_cosine=True, use_dissimilarity=True) +mean_displacement_30_min_cosine, std_displacement_30_min_cosine = compute_displacement_mean_std(embedding_dataset_30_min, max_tau, use_cosine=True, use_dissimilarity=False) +mean_displacement_no_track_cosine, std_displacement_no_track_cosine = compute_displacement_mean_std(embedding_dataset_no_track, max_tau, use_cosine=True, use_dissimilarity=False) #mean_displacement_any_time_cosine, std_displacement_any_time_cosine = compute_displacement_mean_std(embedding_dataset_any_time, max_tau, use_cosine=True) # %% Plot 1: Euclidean Displacements plt.figure(figsize=(10, 6)) @@ -178,6 +178,7 @@ # %% Plot 2: Cosine Displacements plt.figure(figsize=(10, 6)) +taus = list(mean_displacement_30_min_cosine.keys()) # 
Plotting Cosine displacements mean_values_30_min_cosine = list(mean_displacement_30_min_cosine.values()) @@ -190,7 +191,7 @@ plt.plot(taus, mean_values_30_min_cosine, marker='o', label='Cell & Time Aware (30 min interval)') plt.fill_between(taus, np.array(mean_values_30_min_cosine) - np.array(std_values_30_min_cosine), np.array(mean_values_30_min_cosine) + np.array(std_values_30_min_cosine), - color='gray', alpha=0.3, label='Std Dev (30 min interval, Cosine)') + color='gray', alpha=0.3, label='Std Dev (30 min interval)') plt.plot(taus, mean_values_no_track_cosine, marker='o', label='Classical Contrastive (No Tracking)') @@ -201,6 +202,10 @@ plt.xlabel('Time Shift (τ)') plt.ylabel('Cosine Similarity') plt.title('Embedding Displacement Over Time') + +plt.savefig('std_cosine_plot.svg', format='svg') + + plt.grid(True) plt.legend() @@ -235,11 +240,11 @@ # Compute displacements for Cell & Time Aware (30 min interval) using Cosine similarity -displacement_per_tau_aware_cosine = compute_displacement(embedding_dataset_30_min, max_tau, use_cosine=True, use_dissimilarity=True) +displacement_per_tau_aware_cosine = compute_displacement(embedding_dataset_30_min, max_tau, use_cosine=True, use_dissimilarity=False) # Compute displacements for Classical Contrastive (No Tracking) using Cosine similarity -displacement_per_tau_contrastive_cosine = compute_displacement(embedding_dataset_no_track, max_tau, use_cosine=True, use_dissimilarity=True) +displacement_per_tau_contrastive_cosine = compute_displacement(embedding_dataset_no_track, max_tau, use_cosine=True, use_dissimilarity=False) # %% Prepare data for violin plot @@ -268,7 +273,6 @@ def prepare_violin_data(taus, displacement_aware, displacement_contrastive): return df -# Assuming 'displacement_per_tau_aware_cosine' and 'displacement_per_tau_contrastive_cosine' hold the displacements as dictionaries taus = list(displacement_per_tau_aware_cosine.keys()) @@ -294,20 +298,125 @@ def prepare_violin_data(taus, displacement_aware, displacement_contrastive): # Add labels and title plt.xlabel('Time Shift (τ)', fontsize=14) -plt.ylabel('Cosine Dissimilarity', fontsize=14) -plt.title('Cosine Dissimilarity Distribution', fontsize=16) +plt.ylabel('Cosine Similarity', fontsize=14) +plt.title('Cosine Similarity Distribution on Features', fontsize=16) plt.grid(True, linestyle='--', alpha=0.6) # Lighter grid lines for less distraction plt.legend(title='Sampling', fontsize=12, title_fontsize=14) -plt.ylim(0, 0.5) +plt.ylim(0.5, 1.0) # Save the violin plot as SVG and PDF -plt.savefig('violin_plot_cosine_similarity.svg', format='svg') +plt.savefig('updated_violin_plot_cosine_similarity.svg', format='svg') plt.savefig('violin_plot_cosine_similarity.pdf', format='pdf') # Show the plot plt.show() +# %% using umap violin plot + +# Import necessary libraries +import seaborn as sns +import pandas as pd +import numpy as np +import matplotlib.pyplot as plt +from pathlib import Path +from collections import defaultdict +from sklearn.metrics.pairwise import cosine_similarity +from viscy.representation.embedding_writer import read_embedding_dataset +from sklearn.preprocessing import StandardScaler +from umap import UMAP +import xarray as xr + +# %% Paths to datasets +features_path_30_min = Path("/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/time_interval/predict/feb_test_time_interval_1_epoch_178.zarr") +feature_path_no_track = 
Path("/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/negpair_random_sampling2/feb_fixed_test_predict.zarr") + +# %% Read embedding datasets +embedding_dataset_30_min = read_embedding_dataset(features_path_30_min) +embedding_dataset_no_track = read_embedding_dataset(feature_path_no_track) + +# %% Compute UMAP on features +# %% Compute UMAP on features +def compute_umap(dataset): + features = dataset["features"] + scaled_features = StandardScaler().fit_transform(features.values) + umap = UMAP(n_components=2) # Reduce to 2 dimensions + embedding = umap.fit_transform(scaled_features) + + # Add UMAP coordinates using xarray functionality + umap_features = features.assign_coords(UMAP1=("sample", embedding[:, 0]), UMAP2=("sample", embedding[:, 1])) + return umap_features + +# Apply UMAP to both datasets +umap_features_30_min = compute_umap(embedding_dataset_30_min) +umap_features_no_track = compute_umap(embedding_dataset_no_track) + +# %% Compute displacements using UMAP coordinates (using Cosine similarity) +max_tau = 10 # Maximum time shift (tau) to compute displacements + +# Compute displacements for UMAP-processed Cell & Time Aware (30 min interval) +displacement_per_tau_aware_umap_cosine = compute_displacement(umap_features_30_min, max_tau, use_cosine=True, use_dissimilarity=False) + +# Compute displacements for UMAP-processed Classical Contrastive (No Tracking) +displacement_per_tau_contrastive_umap_cosine = compute_displacement(umap_features_no_track, max_tau, use_cosine=True, use_dissimilarity=False) + +# %% Prepare data for violin plot +def prepare_violin_data(taus, displacement_aware, displacement_contrastive): + # Create a list to hold the data + data = [] + + # Populate the data for Cell & Time Aware + for tau in taus: + displacements_aware = displacement_aware.get(tau, []) + for displacement in displacements_aware: + data.append({'Time Shift (τ)': tau, 'Displacement': displacement, 'Sampling': 'Cell & Time Aware (30 min interval)'}) + + # Populate the data for Classical Contrastive + for tau in taus: + displacements_contrastive = displacement_contrastive.get(tau, []) + for displacement in displacements_contrastive: + data.append({'Time Shift (τ)': tau, 'Displacement': displacement, 'Sampling': 'Classical Contrastive (No Tracking)'}) + + # Convert to a DataFrame + df = pd.DataFrame(data) + return df + +taus = list(displacement_per_tau_aware_umap_cosine.keys()) + +# Prepare the violin plot data +df = prepare_violin_data(taus, displacement_per_tau_aware_umap_cosine, displacement_per_tau_contrastive_umap_cosine) + +# %% Create a violin plot using seaborn +plt.figure(figsize=(12, 8)) +sns.violinplot( + x='Time Shift (τ)', + y='Displacement', + hue='Sampling', + data=df, + palette='Set2', + scale='width', + bw=0.2, + inner=None, + split=True, + cut=0 +) + +# Add labels and title +plt.xlabel('Time Shift (τ)', fontsize=14) +plt.ylabel('Cosine Similarity', fontsize=14) +plt.title('Cosine Similarity Distribution using UMAP Features', fontsize=16) +plt.grid(True, linestyle='--', alpha=0.6) # Lighter grid lines for less distraction +plt.legend(title='Sampling', fontsize=12, title_fontsize=14) + +plt.ylim(0.5, 1.0) + +# Save the violin plot as SVG and PDF +plt.savefig('umap_violin_plot_cosine_similarity.svg', format='svg') +# plt.savefig('violin_plot_cosine_similarity_umap.pdf', format='pdf') + +# Show the plot +plt.show() + # %% diff --git a/applications/contrastive_phenotyping/evaluation/pca_umap_embeddings_time.py 
b/applications/contrastive_phenotyping/evaluation/pca_umap_embeddings_time.py index 435dafba5..bafe98eb9 100644 --- a/applications/contrastive_phenotyping/evaluation/pca_umap_embeddings_time.py +++ b/applications/contrastive_phenotyping/evaluation/pca_umap_embeddings_time.py @@ -69,9 +69,11 @@ # Add the title to the plot -plt.title("Cell & Time Aware (30 min interval)") -plt.savefig('umap_cell_time_aware_time.svg', format='svg') -plt.savefig('umap_cell_time_aware_time.pdf', format='pdf') +plt.title("Cell & Time Aware Sampling (30 min interval)") +plt.xlim(-10, 20) +plt.ylim(-10, 20) +#plt.savefig('umap_cell_time_aware_time.svg', format='svg') +plt.savefig('updated_cell_time_aware_time.png', format='png') # Show the plot plt.show() @@ -79,7 +81,7 @@ # %% -any_features_path = Path("/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/negpair_difcell_randomtime_sampling/Ver2_updateTracking_refineModel/predictions/Feb_2chan_128patch_32projDim/2chan_128patch_56ckpt_FebTest.zarr") +any_features_path = Path("/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/negpair_difcell_randomtime_sampling/Ver2_updateTracking_refineModel/predictions/Feb_1chan_128patch_32projDim/1chan_128patch_63ckpt_FebTest.zarr") embedding_dataset = read_embedding_dataset(any_features_path) embedding_dataset @@ -108,16 +110,18 @@ features - - sns.scatterplot( x=features["UMAP1"], y=features["UMAP2"], hue=features["t"], s=7, alpha=0.8 ) # Add the title to the plot -plt.title("Cell Aware Any Time") -plt.savefig('umap_cell_aware_time.png', format='png') +plt.title("Cell Aware Sampling") + +plt.xlim(-10, 20) +plt.ylim(-10, 20) + +plt.savefig('updated_cell_aware_time.png', format='png') #plt.savefig('umap_cell_aware_time.pdf', format='pdf') # Show the plot plt.show() @@ -154,19 +158,16 @@ ) features - - - sns.scatterplot( x=features["UMAP1"], y=features["UMAP2"], hue=features["t"], s=7, alpha=0.8 ) - # Add the title to the plot -plt.title("Classical Contrastive Learning") -plt.savefig('classical_time.svg', format='svg') -plt.savefig('classical_time.pdf', format='pdf') - +plt.title("Classical Contrastive Learning Sampling") +plt.xlim(-10, 20) +plt.ylim(-10, 20) +plt.savefig('updated_classical_time.png', format='png') +#plt.savefig('classical_time.pdf', format='pdf') # Show the plot plt.show() From 9068b677b82a7fc82f0509267af34c5ac547fbde Mon Sep 17 00:00:00 2001 From: Alishba Imran Date: Tue, 17 Sep 2024 18:37:01 -0700 Subject: [PATCH 04/18] umap dist code --- .../evaluation/cosine_similarity.py | 89 ++++++++++++------- .../evaluation/pca_umap_embeddings_time.py | 4 +- viscy/representation/evaluation.py | 11 ++- 3 files changed, 68 insertions(+), 36 deletions(-) diff --git a/applications/contrastive_phenotyping/evaluation/cosine_similarity.py b/applications/contrastive_phenotyping/evaluation/cosine_similarity.py index 9962d2c3b..2d9fc38af 100644 --- a/applications/contrastive_phenotyping/evaluation/cosine_similarity.py +++ b/applications/contrastive_phenotyping/evaluation/cosine_similarity.py @@ -1,7 +1,6 @@ # %% from pathlib import Path - import matplotlib.pyplot as plt import numpy as np import pandas as pd @@ -12,7 +11,6 @@ from umap import UMAP from sklearn.decomposition import PCA - from viscy.representation.embedding_writer import read_embedding_dataset from viscy.representation.evaluation import dataset_of_tracks, load_annotation from sklearn.metrics.pairwise import cosine_similarity @@ -21,7 +19,6 @@ from 
viscy.representation.evaluation import compute_displacement_mean_std from viscy.representation.evaluation import compute_displacement - # %% Paths and parameters. @@ -33,7 +30,7 @@ feature_path_no_track = Path("/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/negpair_random_sampling2/feb_fixed_test_predict.zarr") -features_path_any_time = Path("/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/negpair_difcell_randomtime_sampling/Ver2_updateTracking_refineModel/predictions/Feb_1chan_128patch_32projDim/1chan_128patch_63ckpt_FebTest.zarr") +features_path_any_time = Path("/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/negpair_difcell_randomtime_sampling/Ver2_updateTracking_refineModel/predictions/Feb_2chan_128patch_32projDim/2chan_128patch_56ckpt_FebTest.zarr") data_path = Path( @@ -48,46 +45,37 @@ # %% Load embedding datasets for all three sampling fov_name = '/B/4/6' -track_id = 38 - +track_id = 52 embedding_dataset_30_min = read_embedding_dataset(features_path_30_min) embedding_dataset_no_track = read_embedding_dataset(feature_path_no_track) embedding_dataset_any_time = read_embedding_dataset(features_path_any_time) - # Calculate cosine similarities for each sampling time_points_30_min, cosine_similarities_30_min = calculate_cosine_similarity_cell(embedding_dataset_30_min, fov_name, track_id) time_points_no_track, cosine_similarities_no_track = calculate_cosine_similarity_cell(embedding_dataset_no_track, fov_name, track_id) time_points_any_time, cosine_similarities_any_time = calculate_cosine_similarity_cell(embedding_dataset_any_time, fov_name, track_id) - # %% Plot cosine similarities over time for all three conditions - plt.figure(figsize=(10, 6)) - plt.plot(time_points_no_track, cosine_similarities_no_track, marker='o', label='classical contrastive (no tracking)') plt.plot(time_points_any_time, cosine_similarities_any_time, marker='o', label='cell aware') plt.plot(time_points_30_min, cosine_similarities_30_min, marker='o', label='cell & time aware (interval 30 min)') - - - plt.xlabel("Time Delay (t)") plt.ylabel("Cosine Similarity with First Time Point") -plt.title("Cosine Similarity Over Time for Infected Cell (FOV: /B/4/6, Track ID: 38)") - +plt.title("Cosine Similarity Over Time for Infected Cell") -plt.savefig('example_cell_inf.svg', format='svg') #plt.savefig('infected_cell_example.pdf', format='pdf') plt.grid(True) -plt.legend() +plt.legend() +plt.savefig('new_example_cell.svg', format='svg') plt.show() @@ -203,12 +191,12 @@ plt.ylabel('Cosine Similarity') plt.title('Embedding Displacement Over Time') -plt.savefig('std_cosine_plot.svg', format='svg') + plt.grid(True) plt.legend() - +plt.savefig('1_std_cosine_plot.svg', format='svg') # Show the Cosine plot plt.show() @@ -240,15 +228,14 @@ # Compute displacements for Cell & Time Aware (30 min interval) using Cosine similarity -displacement_per_tau_aware_cosine = compute_displacement(embedding_dataset_30_min, max_tau, use_cosine=True, use_dissimilarity=False) +displacement_per_tau_aware_cosine = compute_displacement(embedding_dataset_30_min, max_tau, use_cosine=True, use_dissimilarity=False, use_umap=False) # Compute displacements for Classical Contrastive (No Tracking) using Cosine similarity -displacement_per_tau_contrastive_cosine = compute_displacement(embedding_dataset_no_track, max_tau, use_cosine=True, use_dissimilarity=False) 
+displacement_per_tau_contrastive_cosine = compute_displacement(embedding_dataset_no_track, max_tau, use_cosine=True, use_dissimilarity=False, use_umap=False) # %% Prepare data for violin plot -# Prepare the data in a long-form DataFrame for the violin plot def prepare_violin_data(taus, displacement_aware, displacement_contrastive): # Create a list to hold the data data = [] @@ -304,19 +291,19 @@ def prepare_violin_data(taus, displacement_aware, displacement_contrastive): plt.legend(title='Sampling', fontsize=12, title_fontsize=14) -plt.ylim(0.5, 1.0) +#plt.ylim(0.5, 1.0) # Save the violin plot as SVG and PDF -plt.savefig('updated_violin_plot_cosine_similarity.svg', format='svg') -plt.savefig('violin_plot_cosine_similarity.pdf', format='pdf') +plt.savefig('1fixed_violin_plot_cosine_similarity.svg', format='svg') +# plt.savefig('violin_plot_cosine_similarity.pdf', format='pdf') # Show the plot plt.show() # %% using umap violin plot -# Import necessary libraries +# Import necessary libraries, try euclidean distance for both features and import seaborn as sns import pandas as pd import numpy as np @@ -337,7 +324,6 @@ def prepare_violin_data(taus, displacement_aware, displacement_contrastive): embedding_dataset_30_min = read_embedding_dataset(features_path_30_min) embedding_dataset_no_track = read_embedding_dataset(feature_path_no_track) -# %% Compute UMAP on features # %% Compute UMAP on features def compute_umap(dataset): features = dataset["features"] @@ -353,14 +339,34 @@ def compute_umap(dataset): umap_features_30_min = compute_umap(embedding_dataset_30_min) umap_features_no_track = compute_umap(embedding_dataset_no_track) +# %% +print(umap_features_30_min) +# %% Visualize UMAP embeddings +# # Visualize UMAP embeddings for the 30 min interval +# plt.figure(figsize=(8, 6)) +# plt.scatter(umap_features_30_min[:, 0], umap_features_30_min[:, 1], c=embedding_dataset_30_min["t"].values, cmap='viridis') +# plt.colorbar(label='Timepoints') +# plt.title('UMAP Projection of Features (30 min Interval)') +# plt.xlabel('UMAP1') +# plt.ylabel('UMAP2') +# plt.show() + +# # Visualize UMAP embeddings for the No Tracking dataset +# plt.figure(figsize=(8, 6)) +# plt.scatter(umap_features_no_track[:, 0], umap_features_no_track[:, 1], c=embedding_dataset_no_track["t"].values, cmap='viridis') +# plt.colorbar(label='Timepoints') +# plt.title('UMAP Projection of Features (No Tracking)') +# plt.xlabel('UMAP1') +# plt.ylabel('UMAP2') +# plt.show() # %% Compute displacements using UMAP coordinates (using Cosine similarity) max_tau = 10 # Maximum time shift (tau) to compute displacements # Compute displacements for UMAP-processed Cell & Time Aware (30 min interval) -displacement_per_tau_aware_umap_cosine = compute_displacement(umap_features_30_min, max_tau, use_cosine=True, use_dissimilarity=False) +displacement_per_tau_aware_umap_cosine = compute_displacement(umap_features_30_min, max_tau, use_cosine=True, use_dissimilarity=False, use_umap=True) # Compute displacements for UMAP-processed Classical Contrastive (No Tracking) -displacement_per_tau_contrastive_umap_cosine = compute_displacement(umap_features_no_track, max_tau, use_cosine=True, use_dissimilarity=False) +displacement_per_tau_contrastive_umap_cosine = compute_displacement(umap_features_no_track, max_tau, use_cosine=True, use_dissimilarity=False, use_umap=True) # %% Prepare data for violin plot def prepare_violin_data(taus, displacement_aware, displacement_contrastive): @@ -410,13 +416,34 @@ def prepare_violin_data(taus, displacement_aware, 
displacement_contrastive): plt.grid(True, linestyle='--', alpha=0.6) # Lighter grid lines for less distraction plt.legend(title='Sampling', fontsize=12, title_fontsize=14) -plt.ylim(0.5, 1.0) +#plt.ylim(0, 1) # Save the violin plot as SVG and PDF -plt.savefig('umap_violin_plot_cosine_similarity.svg', format='svg') +plt.savefig('fixed_plot_cosine_similarity.svg', format='svg') # plt.savefig('violin_plot_cosine_similarity_umap.pdf', format='pdf') # Show the plot plt.show() + + # %% +# %% Visualize Displacement Distributions (Example Code) +# Compare displacement distributions for τ = 1 +# plt.figure(figsize=(10, 6)) +# sns.histplot(displacement_per_tau_aware_umap_cosine[1], kde=True, label='UMAP - 30 min Interval', color='blue') +# sns.histplot(displacement_per_tau_contrastive_umap_cosine[1], kde=True, label='UMAP - No Tracking', color='green') +# plt.legend() +# plt.title('Comparison of Displacement Distributions for τ = 1 (UMAP)') +# plt.xlabel('Displacement') +# plt.show() + +# # Compare displacement distributions for the full feature set (same τ = 1) +# plt.figure(figsize=(10, 6)) +# sns.histplot(displacement_per_tau_aware_cosine[1], kde=True, label='Full Features - 30 min Interval', color='red') +# sns.histplot(displacement_per_tau_contrastive_cosine[1], kde=True, label='Full Features - No Tracking', color='orange') +# plt.legend() +# plt.title('Comparison of Displacement Distributions for τ = 1 (Full Features)') +# plt.xlabel('Displacement') +# plt.show() +# # %% diff --git a/applications/contrastive_phenotyping/evaluation/pca_umap_embeddings_time.py b/applications/contrastive_phenotyping/evaluation/pca_umap_embeddings_time.py index bafe98eb9..69adb5fba 100644 --- a/applications/contrastive_phenotyping/evaluation/pca_umap_embeddings_time.py +++ b/applications/contrastive_phenotyping/evaluation/pca_umap_embeddings_time.py @@ -81,7 +81,7 @@ # %% -any_features_path = Path("/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/negpair_difcell_randomtime_sampling/Ver2_updateTracking_refineModel/predictions/Feb_1chan_128patch_32projDim/1chan_128patch_63ckpt_FebTest.zarr") +any_features_path = Path("/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/negpair_difcell_randomtime_sampling/Ver2_updateTracking_refineModel/predictions/Feb_2chan_128patch_32projDim/2chan_128patch_56ckpt_FebTest.zarr") embedding_dataset = read_embedding_dataset(any_features_path) embedding_dataset @@ -121,7 +121,7 @@ plt.xlim(-10, 20) plt.ylim(-10, 20) -plt.savefig('updated_cell_aware_time.png', format='png') +plt.savefig('1_updated_cell_aware_time.png', format='png') #plt.savefig('umap_cell_aware_time.pdf', format='pdf') # Show the plot plt.show() diff --git a/viscy/representation/evaluation.py b/viscy/representation/evaluation.py index e4830dbc5..e1b2dc7ed 100644 --- a/viscy/representation/evaluation.py +++ b/viscy/representation/evaluation.py @@ -479,13 +479,18 @@ def compute_displacement_mean_std( # Function to compute the norm of differences between embeddings at t and t + tau def compute_displacement( - embedding_dataset, max_tau=10, use_cosine=False, use_dissimilarity=False -): + embedding_dataset, max_tau=10, use_cosine=False, use_dissimilarity=False, use_umap=False): # Get the arrays of (fov_name, track_id, t, and embeddings) fov_names = embedding_dataset["fov_name"].values track_ids = embedding_dataset["track_id"].values timepoints = embedding_dataset["t"].values - embeddings = 
embedding_dataset["features"].values + + if use_umap: + umap1 = embedding_dataset["UMAP1"].values + umap2 = embedding_dataset["UMAP2"].values + embeddings = np.vstack((umap1, umap2)).T + else: + embeddings = embedding_dataset["features"].values # Dictionary to store displacements for each tau displacement_per_tau = defaultdict(list) From 90d6cf1aec6d0bd49a50cb752e3773894b9ee524 Mon Sep 17 00:00:00 2001 From: Shalin Mehta Date: Wed, 18 Sep 2024 10:24:52 -0700 Subject: [PATCH 05/18] bug fixes and linting --- .../evaluation/cosine_similarity.py | 407 +++++++++++------- .../evaluation/pca_umap_embeddings_time.py | 85 ++-- 2 files changed, 293 insertions(+), 199 deletions(-) diff --git a/applications/contrastive_phenotyping/evaluation/cosine_similarity.py b/applications/contrastive_phenotyping/evaluation/cosine_similarity.py index 2d9fc38af..78a4906c9 100644 --- a/applications/contrastive_phenotyping/evaluation/cosine_similarity.py +++ b/applications/contrastive_phenotyping/evaluation/cosine_similarity.py @@ -1,50 +1,51 @@ # %% +# Import necessary libraries, try euclidean distance for both features and from pathlib import Path import matplotlib.pyplot as plt import numpy as np import pandas as pd -import plotly.express as px import seaborn as sns -from sklearn.decomposition import PCA from sklearn.preprocessing import StandardScaler from umap import UMAP -from sklearn.decomposition import PCA from viscy.representation.embedding_writer import read_embedding_dataset -from viscy.representation.evaluation import dataset_of_tracks, load_annotation -from sklearn.metrics.pairwise import cosine_similarity -from collections import defaultdict -from viscy.representation.evaluation import calculate_cosine_similarity_cell -from viscy.representation.evaluation import compute_displacement_mean_std -from viscy.representation.evaluation import compute_displacement +from viscy.representation.evaluation import ( + calculate_cosine_similarity_cell, + compute_displacement, + compute_displacement_mean_std, +) # %% Paths and parameters. 
features_path_30_min = Path( - "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/time_interval/predict/feb_test_time_interval_1_epoch_178.zarr" + "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/time_interval/predict/feb_test_time_interval_1_epoch_178.zarr" ) -feature_path_no_track = Path("/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/negpair_random_sampling2/feb_fixed_test_predict.zarr") +feature_path_no_track = Path( + "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/negpair_random_sampling2/feb_fixed_test_predict.zarr" +) -features_path_any_time = Path("/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/negpair_difcell_randomtime_sampling/Ver2_updateTracking_refineModel/predictions/Feb_2chan_128patch_32projDim/2chan_128patch_56ckpt_FebTest.zarr") +features_path_any_time = Path( + "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/negpair_difcell_randomtime_sampling/Ver2_updateTracking_refineModel/predictions/Feb_2chan_128patch_32projDim/2chan_128patch_56ckpt_FebTest.zarr" +) data_path = Path( - "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/registered_test.zarr" + "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/registered_test.zarr" ) tracks_path = Path( - "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/track_test.zarr" + "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/track_test.zarr" ) # %% Load embedding datasets for all three sampling -fov_name = '/B/4/6' +fov_name = "/B/4/6" track_id = 52 embedding_dataset_30_min = read_embedding_dataset(features_path_30_min) @@ -52,30 +53,48 @@ embedding_dataset_any_time = read_embedding_dataset(features_path_any_time) # Calculate cosine similarities for each sampling -time_points_30_min, cosine_similarities_30_min = calculate_cosine_similarity_cell(embedding_dataset_30_min, fov_name, track_id) -time_points_no_track, cosine_similarities_no_track = calculate_cosine_similarity_cell(embedding_dataset_no_track, fov_name, track_id) -time_points_any_time, cosine_similarities_any_time = calculate_cosine_similarity_cell(embedding_dataset_any_time, fov_name, track_id) +time_points_30_min, cosine_similarities_30_min = calculate_cosine_similarity_cell( + embedding_dataset_30_min, fov_name, track_id +) +time_points_no_track, cosine_similarities_no_track = calculate_cosine_similarity_cell( + embedding_dataset_no_track, fov_name, track_id +) +time_points_any_time, cosine_similarities_any_time = calculate_cosine_similarity_cell( + embedding_dataset_any_time, fov_name, track_id +) # %% Plot cosine similarities over time for all three conditions plt.figure(figsize=(10, 6)) -plt.plot(time_points_no_track, cosine_similarities_no_track, marker='o', label='classical contrastive (no tracking)') -plt.plot(time_points_any_time, cosine_similarities_any_time, marker='o', label='cell aware') -plt.plot(time_points_30_min, cosine_similarities_30_min, marker='o', label='cell & time aware (interval 30 min)') +plt.plot( + time_points_no_track, + cosine_similarities_no_track, + marker="o", + 
label="classical contrastive (no tracking)", +) +plt.plot( + time_points_any_time, cosine_similarities_any_time, marker="o", label="cell aware" +) +plt.plot( + time_points_30_min, + cosine_similarities_30_min, + marker="o", + label="cell & time aware (interval 30 min)", +) plt.xlabel("Time Delay (t)") plt.ylabel("Cosine Similarity with First Time Point") plt.title("Cosine Similarity Over Time for Infected Cell") -#plt.savefig('infected_cell_example.pdf', format='pdf') +# plt.savefig('infected_cell_example.pdf', format='pdf') plt.grid(True) plt.legend() -plt.savefig('new_example_cell.svg', format='svg') +plt.savefig("new_example_cell.svg", format="svg") plt.show() @@ -85,27 +104,20 @@ # %% import statements -from pathlib import Path -import numpy as np -import matplotlib.pyplot as plt -from sklearn.metrics.pairwise import euclidean_distances - - -from viscy.representation.embedding_writer import read_embedding_dataset -from viscy.representation.evaluation import dataset_of_tracks, load_annotation -from sklearn.metrics.pairwise import cosine_similarity - - # %% Paths to datasets -features_path_30_min = Path("/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/time_interval/predict/feb_test_time_interval_1_epoch_178.zarr") -feature_path_no_track = Path("/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/negpair_random_sampling2/feb_fixed_test_predict.zarr") -#features_path_any_time = Path("/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/negpair_difcell_randomtime_sampling/Ver2_updateTracking_refineModel/predictions/Feb_1chan_128patch_32projDim/1chan_128patch_63ckpt_FebTest.zarr") +features_path_30_min = Path( + "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/time_interval/predict/feb_test_time_interval_1_epoch_178.zarr" +) +feature_path_no_track = Path( + "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/negpair_random_sampling2/feb_fixed_test_predict.zarr" +) +# features_path_any_time = Path("/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/negpair_difcell_randomtime_sampling/Ver2_updateTracking_refineModel/predictions/Feb_1chan_128patch_32projDim/1chan_128patch_63ckpt_FebTest.zarr") # %% Read embedding datasets embedding_dataset_30_min = read_embedding_dataset(features_path_30_min) embedding_dataset_no_track = read_embedding_dataset(feature_path_no_track) -#embedding_dataset_any_time = read_embedding_dataset(features_path_any_time) +# embedding_dataset_any_time = read_embedding_dataset(features_path_any_time) # %% Compute displacements for both datasets (using Euclidean distance and Cosine similarity) @@ -114,23 +126,31 @@ # mean_displacement_30_min, std_displacement_30_min = compute_displacement_mean_std(embedding_dataset_30_min, max_tau, use_cosine=False, use_dissimilarity=False) # mean_displacement_no_track, std_displacement_no_track = compute_displacement_mean_std(embedding_dataset_no_track, max_tau, use_cosine=False, use_dissimilarity=False) -#mean_displacement_any_time, std_displacement_any_time = compute_displacement_mean_std(embedding_dataset_any_time, max_tau, use_cosine=False) +# mean_displacement_any_time, std_displacement_any_time = compute_displacement_mean_std(embedding_dataset_any_time, max_tau, use_cosine=False) 
-mean_displacement_30_min_cosine, std_displacement_30_min_cosine = compute_displacement_mean_std(embedding_dataset_30_min, max_tau, use_cosine=True, use_dissimilarity=False) -mean_displacement_no_track_cosine, std_displacement_no_track_cosine = compute_displacement_mean_std(embedding_dataset_no_track, max_tau, use_cosine=True, use_dissimilarity=False) -#mean_displacement_any_time_cosine, std_displacement_any_time_cosine = compute_displacement_mean_std(embedding_dataset_any_time, max_tau, use_cosine=True) +mean_displacement_30_min_cosine, std_displacement_30_min_cosine = ( + compute_displacement_mean_std( + embedding_dataset_30_min, max_tau, use_cosine=True, use_dissimilarity=False + ) +) +mean_displacement_no_track_cosine, std_displacement_no_track_cosine = ( + compute_displacement_mean_std( + embedding_dataset_no_track, max_tau, use_cosine=True, use_dissimilarity=False + ) +) +# mean_displacement_any_time_cosine, std_displacement_any_time_cosine = compute_displacement_mean_std(embedding_dataset_any_time, max_tau, use_cosine=True) # %% Plot 1: Euclidean Displacements plt.figure(figsize=(10, 6)) -taus = list(mean_displacement_30_min.keys()) -mean_values_30_min = list(mean_displacement_30_min.values()) -std_values_30_min = list(std_displacement_30_min.values()) +taus = list(mean_displacement_30_min_cosine.keys()) +mean_values_30_min = list(mean_displacement_30_min_cosine.values()) +std_values_30_min = list(std_displacement_30_min_cosine.values()) -mean_values_no_track = list(mean_displacement_no_track.values()) -std_values_no_track = list(std_displacement_no_track.values()) +mean_values_no_track = list(mean_displacement_no_track_cosine.values()) +std_values_no_track = list(std_displacement_no_track_cosine.values()) # mean_values_any_time = list(mean_displacement_any_time.values()) @@ -138,19 +158,35 @@ # Plotting Euclidean displacements -plt.plot(taus, mean_values_30_min, marker='o', label='Cell & Time Aware (30 min interval)') -plt.fill_between(taus, np.array(mean_values_30_min) - np.array(std_values_30_min), np.array(mean_values_30_min) + np.array(std_values_30_min), - color='gray', alpha=0.3, label='Std Dev (30 min interval)') +plt.plot( + taus, mean_values_30_min, marker="o", label="Cell & Time Aware (30 min interval)" +) +plt.fill_between( + taus, + np.array(mean_values_30_min) - np.array(std_values_30_min), + np.array(mean_values_30_min) + np.array(std_values_30_min), + color="gray", + alpha=0.3, + label="Std Dev (30 min interval)", +) -plt.plot(taus, mean_values_no_track, marker='o', label='Classical Contrastive (No Tracking)') -plt.fill_between(taus, np.array(mean_values_no_track) - np.array(std_values_no_track), np.array(mean_values_no_track) + np.array(std_values_no_track), - color='blue', alpha=0.3, label='Std Dev (No Tracking)') +plt.plot( + taus, mean_values_no_track, marker="o", label="Classical Contrastive (No Tracking)" +) +plt.fill_between( + taus, + np.array(mean_values_no_track) - np.array(std_values_no_track), + np.array(mean_values_no_track) + np.array(std_values_no_track), + color="blue", + alpha=0.3, + label="Std Dev (No Tracking)", +) -plt.xlabel('Time Shift (τ)') -plt.ylabel('Displacement') -plt.title('Embedding Displacement Over Time') +plt.xlabel("Time Shift (τ)") +plt.ylabel("Displacement") +plt.title("Embedding Displacement Over Time") plt.grid(True) plt.legend() @@ -177,45 +213,59 @@ std_values_no_track_cosine = list(std_displacement_no_track_cosine.values()) -plt.plot(taus, mean_values_30_min_cosine, marker='o', label='Cell & Time Aware (30 min interval)') 
-plt.fill_between(taus, np.array(mean_values_30_min_cosine) - np.array(std_values_30_min_cosine), np.array(mean_values_30_min_cosine) + np.array(std_values_30_min_cosine), - color='gray', alpha=0.3, label='Std Dev (30 min interval)') - - -plt.plot(taus, mean_values_no_track_cosine, marker='o', label='Classical Contrastive (No Tracking)') -plt.fill_between(taus, np.array(mean_values_no_track_cosine) - np.array(std_values_no_track_cosine), np.array(mean_values_no_track_cosine) + np.array(std_values_no_track_cosine), - color='blue', alpha=0.3, label='Std Dev (No Tracking)') +plt.plot( + taus, + mean_values_30_min_cosine, + marker="o", + label="Cell & Time Aware (30 min interval)", +) +plt.fill_between( + taus, + np.array(mean_values_30_min_cosine) - np.array(std_values_30_min_cosine), + np.array(mean_values_30_min_cosine) + np.array(std_values_30_min_cosine), + color="gray", + alpha=0.3, + label="Std Dev (30 min interval)", +) -plt.xlabel('Time Shift (τ)') -plt.ylabel('Cosine Similarity') -plt.title('Embedding Displacement Over Time') +plt.plot( + taus, + mean_values_no_track_cosine, + marker="o", + label="Classical Contrastive (No Tracking)", +) +plt.fill_between( + taus, + np.array(mean_values_no_track_cosine) - np.array(std_values_no_track_cosine), + np.array(mean_values_no_track_cosine) + np.array(std_values_no_track_cosine), + color="blue", + alpha=0.3, + label="Std Dev (No Tracking)", +) +plt.xlabel("Time Shift (τ)") +plt.ylabel("Cosine Similarity") +plt.title("Embedding Displacement Over Time") plt.grid(True) plt.legend() -plt.savefig('1_std_cosine_plot.svg', format='svg') +plt.savefig("1_std_cosine_plot.svg", format="svg") # Show the Cosine plot plt.show() # %% -import seaborn as sns -import pandas as pd -import numpy as np -import matplotlib.pyplot as plt -from pathlib import Path -from collections import defaultdict -from sklearn.metrics.pairwise import cosine_similarity -from viscy.representation.embedding_writer import read_embedding_dataset - - # %% Paths to datasets -features_path_30_min = Path("/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/time_interval/predict/feb_test_time_interval_1_epoch_178.zarr") -feature_path_no_track = Path("/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/negpair_random_sampling2/feb_fixed_test_predict.zarr") +features_path_30_min = Path( + "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/time_interval/predict/feb_test_time_interval_1_epoch_178.zarr" +) +feature_path_no_track = Path( + "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/negpair_random_sampling2/feb_fixed_test_predict.zarr" +) # %% Read embedding datasets @@ -228,74 +278,97 @@ # Compute displacements for Cell & Time Aware (30 min interval) using Cosine similarity -displacement_per_tau_aware_cosine = compute_displacement(embedding_dataset_30_min, max_tau, use_cosine=True, use_dissimilarity=False, use_umap=False) +displacement_per_tau_aware_cosine = compute_displacement( + embedding_dataset_30_min, + max_tau, + use_cosine=True, + use_dissimilarity=False, + use_umap=False, +) # Compute displacements for Classical Contrastive (No Tracking) using Cosine similarity -displacement_per_tau_contrastive_cosine = compute_displacement(embedding_dataset_no_track, max_tau, use_cosine=True, use_dissimilarity=False, use_umap=False) 
+displacement_per_tau_contrastive_cosine = compute_displacement( + embedding_dataset_no_track, + max_tau, + use_cosine=True, + use_dissimilarity=False, + use_umap=False, +) # %% Prepare data for violin plot def prepare_violin_data(taus, displacement_aware, displacement_contrastive): - # Create a list to hold the data - data = [] - - - # Populate the data for Cell & Time Aware - for tau in taus: - displacements_aware = displacement_aware.get(tau, []) - for displacement in displacements_aware: - data.append({'Time Shift (τ)': tau, 'Displacement': displacement, 'Sampling': 'Cell & Time Aware (30 min interval)'}) - + # Create a list to hold the data + data = [] - # Populate the data for Classical Contrastive - for tau in taus: - displacements_contrastive = displacement_contrastive.get(tau, []) - for displacement in displacements_contrastive: - data.append({'Time Shift (τ)': tau, 'Displacement': displacement, 'Sampling': 'Classical Contrastive (No Tracking)'}) + # Populate the data for Cell & Time Aware + for tau in taus: + displacements_aware = displacement_aware.get(tau, []) + for displacement in displacements_aware: + data.append( + { + "Time Shift (τ)": tau, + "Displacement": displacement, + "Sampling": "Cell & Time Aware (30 min interval)", + } + ) + # Populate the data for Classical Contrastive + for tau in taus: + displacements_contrastive = displacement_contrastive.get(tau, []) + for displacement in displacements_contrastive: + data.append( + { + "Time Shift (τ)": tau, + "Displacement": displacement, + "Sampling": "Classical Contrastive (No Tracking)", + } + ) - # Convert to a DataFrame - df = pd.DataFrame(data) - return df + # Convert to a DataFrame + df = pd.DataFrame(data) + return df taus = list(displacement_per_tau_aware_cosine.keys()) # Prepare the violin plot data -df = prepare_violin_data(taus, displacement_per_tau_aware_cosine, displacement_per_tau_contrastive_cosine) +df = prepare_violin_data( + taus, displacement_per_tau_aware_cosine, displacement_per_tau_contrastive_cosine +) # Create a violin plot using seaborn plt.figure(figsize=(12, 8)) sns.violinplot( - x='Time Shift (τ)', - y='Displacement', - hue='Sampling', - data=df, - palette='Set2', - scale='width', - bw=0.2, - inner=None, - split=True, - cut=0 + x="Time Shift (τ)", + y="Displacement", + hue="Sampling", + data=df, + palette="Set2", + scale="width", + bw=0.2, + inner=None, + split=True, + cut=0, ) # Add labels and title -plt.xlabel('Time Shift (τ)', fontsize=14) -plt.ylabel('Cosine Similarity', fontsize=14) -plt.title('Cosine Similarity Distribution on Features', fontsize=16) -plt.grid(True, linestyle='--', alpha=0.6) # Lighter grid lines for less distraction -plt.legend(title='Sampling', fontsize=12, title_fontsize=14) +plt.xlabel("Time Shift (τ)", fontsize=14) +plt.ylabel("Cosine Similarity", fontsize=14) +plt.title("Cosine Similarity Distribution on Features", fontsize=16) +plt.grid(True, linestyle="--", alpha=0.6) # Lighter grid lines for less distraction +plt.legend(title="Sampling", fontsize=12, title_fontsize=14) -#plt.ylim(0.5, 1.0) +# plt.ylim(0.5, 1.0) # Save the violin plot as SVG and PDF -plt.savefig('1fixed_violin_plot_cosine_similarity.svg', format='svg') +plt.savefig("1fixed_violin_plot_cosine_similarity.svg", format="svg") # plt.savefig('violin_plot_cosine_similarity.pdf', format='pdf') @@ -303,45 +376,40 @@ def prepare_violin_data(taus, displacement_aware, displacement_contrastive): plt.show() # %% using umap violin plot -# Import necessary libraries, try euclidean distance for both features and 
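# (Editor's aside: illustrative sketch, not part of the original patch.)
# prepare_violin_data above emits a long-form DataFrame, one row per
# (tau, displacement, sampling) observation, which is the layout that
# seaborn.violinplot expects for its `x`, `y`, and `hue` arguments. A toy run
# with hypothetical {tau: displacements} inputs:
import pandas as pd

aware = {1: [0.95, 0.97], 2: [0.90, 0.91]}
contrastive = {1: [0.80, 0.82], 2: [0.70, 0.75]}
rows = []
for tau in (1, 2):
    for d in aware[tau]:
        rows.append({"Time Shift (τ)": tau, "Displacement": d, "Sampling": "aware"})
    for d in contrastive[tau]:
        rows.append({"Time Shift (τ)": tau, "Displacement": d, "Sampling": "contrastive"})
print(pd.DataFrame(rows))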
-import seaborn as sns -import pandas as pd -import numpy as np -import matplotlib.pyplot as plt -from pathlib import Path -from collections import defaultdict -from sklearn.metrics.pairwise import cosine_similarity -from viscy.representation.embedding_writer import read_embedding_dataset -from sklearn.preprocessing import StandardScaler -from umap import UMAP -import xarray as xr - # %% Paths to datasets -features_path_30_min = Path("/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/time_interval/predict/feb_test_time_interval_1_epoch_178.zarr") -feature_path_no_track = Path("/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/negpair_random_sampling2/feb_fixed_test_predict.zarr") +features_path_30_min = Path( + "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/time_interval/predict/feb_test_time_interval_1_epoch_178.zarr" +) +feature_path_no_track = Path( + "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/negpair_random_sampling2/feb_fixed_test_predict.zarr" +) # %% Read embedding datasets embedding_dataset_30_min = read_embedding_dataset(features_path_30_min) embedding_dataset_no_track = read_embedding_dataset(feature_path_no_track) + # %% Compute UMAP on features def compute_umap(dataset): features = dataset["features"] scaled_features = StandardScaler().fit_transform(features.values) umap = UMAP(n_components=2) # Reduce to 2 dimensions embedding = umap.fit_transform(scaled_features) - + # Add UMAP coordinates using xarray functionality - umap_features = features.assign_coords(UMAP1=("sample", embedding[:, 0]), UMAP2=("sample", embedding[:, 1])) + umap_features = features.assign_coords( + UMAP1=("sample", embedding[:, 0]), UMAP2=("sample", embedding[:, 1]) + ) return umap_features + # Apply UMAP to both datasets umap_features_30_min = compute_umap(embedding_dataset_30_min) umap_features_no_track = compute_umap(embedding_dataset_no_track) # %% print(umap_features_30_min) -# %% Visualize UMAP embeddings +# %% Visualize UMAP embeddings # # Visualize UMAP embeddings for the 30 min interval # plt.figure(figsize=(8, 6)) # plt.scatter(umap_features_30_min[:, 0], umap_features_30_min[:, 1], c=embedding_dataset_30_min["t"].values, cmap='viridis') @@ -363,10 +431,23 @@ def compute_umap(dataset): max_tau = 10 # Maximum time shift (tau) to compute displacements # Compute displacements for UMAP-processed Cell & Time Aware (30 min interval) -displacement_per_tau_aware_umap_cosine = compute_displacement(umap_features_30_min, max_tau, use_cosine=True, use_dissimilarity=False, use_umap=True) +displacement_per_tau_aware_umap_cosine = compute_displacement( + umap_features_30_min, + max_tau, + use_cosine=True, + use_dissimilarity=False, + use_umap=True, +) # Compute displacements for UMAP-processed Classical Contrastive (No Tracking) -displacement_per_tau_contrastive_umap_cosine = compute_displacement(umap_features_no_track, max_tau, use_cosine=True, use_dissimilarity=False, use_umap=True) +displacement_per_tau_contrastive_umap_cosine = compute_displacement( + umap_features_no_track, + max_tau, + use_cosine=True, + use_dissimilarity=False, + use_umap=True, +) + # %% Prepare data for violin plot def prepare_violin_data(taus, displacement_aware, displacement_contrastive): @@ -377,56 +458,72 @@ def prepare_violin_data(taus, displacement_aware, displacement_contrastive): for tau 
in taus: displacements_aware = displacement_aware.get(tau, []) for displacement in displacements_aware: - data.append({'Time Shift (τ)': tau, 'Displacement': displacement, 'Sampling': 'Cell & Time Aware (30 min interval)'}) + data.append( + { + "Time Shift (τ)": tau, + "Displacement": displacement, + "Sampling": "Cell & Time Aware (30 min interval)", + } + ) # Populate the data for Classical Contrastive for tau in taus: displacements_contrastive = displacement_contrastive.get(tau, []) for displacement in displacements_contrastive: - data.append({'Time Shift (τ)': tau, 'Displacement': displacement, 'Sampling': 'Classical Contrastive (No Tracking)'}) + data.append( + { + "Time Shift (τ)": tau, + "Displacement": displacement, + "Sampling": "Classical Contrastive (No Tracking)", + } + ) # Convert to a DataFrame df = pd.DataFrame(data) return df + taus = list(displacement_per_tau_aware_umap_cosine.keys()) # Prepare the violin plot data -df = prepare_violin_data(taus, displacement_per_tau_aware_umap_cosine, displacement_per_tau_contrastive_umap_cosine) +df = prepare_violin_data( + taus, + displacement_per_tau_aware_umap_cosine, + displacement_per_tau_contrastive_umap_cosine, +) # %% Create a violin plot using seaborn plt.figure(figsize=(12, 8)) sns.violinplot( - x='Time Shift (τ)', - y='Displacement', - hue='Sampling', + x="Time Shift (τ)", + y="Displacement", + hue="Sampling", data=df, - palette='Set2', - scale='width', + palette="Set2", + scale="width", bw=0.2, inner=None, split=True, - cut=0 + cut=0, ) # Add labels and title -plt.xlabel('Time Shift (τ)', fontsize=14) -plt.ylabel('Cosine Similarity', fontsize=14) -plt.title('Cosine Similarity Distribution using UMAP Features', fontsize=16) -plt.grid(True, linestyle='--', alpha=0.6) # Lighter grid lines for less distraction -plt.legend(title='Sampling', fontsize=12, title_fontsize=14) +plt.xlabel("Time Shift (τ)", fontsize=14) +plt.ylabel("Cosine Similarity", fontsize=14) +plt.title("Cosine Similarity Distribution using UMAP Features", fontsize=16) +plt.grid(True, linestyle="--", alpha=0.6) # Lighter grid lines for less distraction +plt.legend(title="Sampling", fontsize=12, title_fontsize=14) -#plt.ylim(0, 1) +# plt.ylim(0, 1) # Save the violin plot as SVG and PDF -plt.savefig('fixed_plot_cosine_similarity.svg', format='svg') +plt.savefig("fixed_plot_cosine_similarity.svg", format="svg") # plt.savefig('violin_plot_cosine_similarity_umap.pdf', format='pdf') # Show the plot plt.show() - # %% # %% Visualize Displacement Distributions (Example Code) # Compare displacement distributions for τ = 1 diff --git a/applications/contrastive_phenotyping/evaluation/pca_umap_embeddings_time.py b/applications/contrastive_phenotyping/evaluation/pca_umap_embeddings_time.py index 69adb5fba..5f59da3e0 100644 --- a/applications/contrastive_phenotyping/evaluation/pca_umap_embeddings_time.py +++ b/applications/contrastive_phenotyping/evaluation/pca_umap_embeddings_time.py @@ -1,33 +1,26 @@ # %% from pathlib import Path - import matplotlib.pyplot as plt -import numpy as np -import pandas as pd -import plotly.express as px import seaborn as sns from sklearn.decomposition import PCA from sklearn.preprocessing import StandardScaler from umap import UMAP -from sklearn.decomposition import PCA - from viscy.representation.embedding_writer import read_embedding_dataset -from viscy.representation.evaluation import dataset_of_tracks, load_annotation - +from viscy.representation.evaluation import load_annotation # %% Paths and parameters. 
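# (Editor's aside: illustrative note, not part of the original patch.)
# UMAP is stochastic, so the scatter plots and the UMAP-based displacement
# values in these scripts can differ between runs. If exact reproducibility
# were wanted, umap-learn accepts a fixed seed (at the cost of some
# parallelism); a hedged sketch:
from umap import UMAP

umap_model = UMAP(n_components=2, random_state=42)  # assumed-acceptable fixed seed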
features_path = Path( - "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/time_interval/predict/feb_test_time_interval_1_epoch_178.zarr" + "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/time_interval/predict/feb_test_time_interval_1_epoch_178.zarr" ) data_path = Path( - "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/registered_test.zarr" + "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/registered_test.zarr" ) tracks_path = Path( - "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/track_test.zarr" + "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/track_test.zarr" ) @@ -54,17 +47,15 @@ features = ( - features.assign_coords(UMAP1=("sample", embedding[:, 0])) - .assign_coords(UMAP2=("sample", embedding[:, 1])) - .set_index(sample=["UMAP1", "UMAP2"], append=True) + features.assign_coords(UMAP1=("sample", embedding[:, 0])) + .assign_coords(UMAP2=("sample", embedding[:, 1])) + .set_index(sample=["UMAP1", "UMAP2"], append=True) ) features - - sns.scatterplot( - x=features["UMAP1"], y=features["UMAP2"], hue=features["t"], s=7, alpha=0.8 + x=features["UMAP1"], y=features["UMAP2"], hue=features["t"], s=7, alpha=0.8 ) @@ -72,8 +63,8 @@ plt.title("Cell & Time Aware Sampling (30 min interval)") plt.xlim(-10, 20) plt.ylim(-10, 20) -#plt.savefig('umap_cell_time_aware_time.svg', format='svg') -plt.savefig('updated_cell_time_aware_time.png', format='png') +# plt.savefig('umap_cell_time_aware_time.svg', format='svg') +plt.savefig("updated_cell_time_aware_time.png", format="png") # Show the plot plt.show() @@ -81,7 +72,9 @@ # %% -any_features_path = Path("/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/negpair_difcell_randomtime_sampling/Ver2_updateTracking_refineModel/predictions/Feb_2chan_128patch_32projDim/2chan_128patch_56ckpt_FebTest.zarr") +any_features_path = Path( + "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/negpair_difcell_randomtime_sampling/Ver2_updateTracking_refineModel/predictions/Feb_2chan_128patch_32projDim/2chan_128patch_56ckpt_FebTest.zarr" +) embedding_dataset = read_embedding_dataset(any_features_path) embedding_dataset @@ -103,15 +96,15 @@ features = ( - features.assign_coords(UMAP1=("sample", embedding[:, 0])) - .assign_coords(UMAP2=("sample", embedding[:, 1])) - .set_index(sample=["UMAP1", "UMAP2"], append=True) + features.assign_coords(UMAP1=("sample", embedding[:, 0])) + .assign_coords(UMAP2=("sample", embedding[:, 1])) + .set_index(sample=["UMAP1", "UMAP2"], append=True) ) features sns.scatterplot( - x=features["UMAP1"], y=features["UMAP2"], hue=features["t"], s=7, alpha=0.8 + x=features["UMAP1"], y=features["UMAP2"], hue=features["t"], s=7, alpha=0.8 ) @@ -121,8 +114,8 @@ plt.xlim(-10, 20) plt.ylim(-10, 20) -plt.savefig('1_updated_cell_aware_time.png', format='png') -#plt.savefig('umap_cell_aware_time.pdf', format='pdf') +plt.savefig("1_updated_cell_aware_time.png", format="png") +# plt.savefig('umap_cell_aware_time.pdf', format='pdf') # Show the plot plt.show() @@ -130,7 +123,9 @@ # %% -contrastive_learning_path = 
Path("/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/negpair_random_sampling2/feb_fixed_test_predict.zarr") +contrastive_learning_path = Path( + "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/negpair_random_sampling2/feb_fixed_test_predict.zarr" +) embedding_dataset = read_embedding_dataset(contrastive_learning_path) embedding_dataset @@ -152,22 +147,22 @@ features = ( - features.assign_coords(UMAP1=("sample", embedding[:, 0])) - .assign_coords(UMAP2=("sample", embedding[:, 1])) - .set_index(sample=["UMAP1", "UMAP2"], append=True) + features.assign_coords(UMAP1=("sample", embedding[:, 0])) + .assign_coords(UMAP2=("sample", embedding[:, 1])) + .set_index(sample=["UMAP1", "UMAP2"], append=True) ) features sns.scatterplot( - x=features["UMAP1"], y=features["UMAP2"], hue=features["t"], s=7, alpha=0.8 + x=features["UMAP1"], y=features["UMAP2"], hue=features["t"], s=7, alpha=0.8 ) # Add the title to the plot plt.title("Classical Contrastive Learning Sampling") plt.xlim(-10, 20) plt.ylim(-10, 20) -plt.savefig('updated_classical_time.png', format='png') -#plt.savefig('classical_time.pdf', format='pdf') +plt.savefig("updated_classical_time.png", format="png") +# plt.savefig('classical_time.pdf', format='pdf') # Show the plot plt.show() @@ -183,11 +178,11 @@ features = ( - features.assign_coords(PCA1=("sample", pca_features[:, 0])) - .assign_coords(PCA2=("sample", pca_features[:, 1])) - .assign_coords(PCA3=("sample", pca_features[:, 2])) - .assign_coords(PCA4=("sample", pca_features[:, 3])) - .set_index(sample=["PCA1", "PCA2", "PCA3", "PCA4"], append=True) + features.assign_coords(PCA1=("sample", pca_features[:, 0])) + .assign_coords(PCA2=("sample", pca_features[:, 1])) + .assign_coords(PCA3=("sample", pca_features[:, 2])) + .assign_coords(PCA4=("sample", pca_features[:, 3])) + .set_index(sample=["PCA1", "PCA2", "PCA3", "PCA4"], append=True) ) @@ -195,20 +190,22 @@ plt.figure(figsize=(10, 10)) -sns.scatterplot(x=features["PCA1"], y=features["PCA2"], hue=features["t"], s=7, alpha=0.8) +sns.scatterplot( + x=features["PCA1"], y=features["PCA2"], hue=features["t"], s=7, alpha=0.8 +) # %% OVERLAY INFECTION ANNOTATION ann_root = Path( - "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/supervised_inf_pred" + "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/supervised_inf_pred" ) infection = load_annotation( - features, - ann_root / "extracted_inf_state.csv", - "infection_state", - {0.0: "background", 1.0: "uninfected", 2.0: "infected"}, + features, + ann_root / "extracted_inf_state.csv", + "infection_state", + {0.0: "background", 1.0: "uninfected", 2.0: "infected"}, ) From a875bbb9e199edb9e939cbf12a75c739dc03e977 Mon Sep 17 00:00:00 2001 From: Soorya Pradeep Date: Thu, 19 Sep 2024 06:52:04 -0700 Subject: [PATCH 06/18] logistic regression script --- .../evaluation/log_regresssion_training.py | 111 ++++++++++++++++++ 1 file changed, 111 insertions(+) create mode 100644 applications/contrastive_phenotyping/evaluation/log_regresssion_training.py diff --git a/applications/contrastive_phenotyping/evaluation/log_regresssion_training.py b/applications/contrastive_phenotyping/evaluation/log_regresssion_training.py new file mode 100644 index 000000000..0bb7a4b32 --- /dev/null +++ b/applications/contrastive_phenotyping/evaluation/log_regresssion_training.py @@ -0,0 +1,111 @@ + +# %% 
+from pathlib import Path + + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import plotly.express as px +import seaborn as sns +from sklearn.decomposition import PCA +from sklearn.preprocessing import StandardScaler +from umap import UMAP +from sklearn.decomposition import PCA + + +from viscy.representation.embedding_writer import read_embedding_dataset +from viscy.representation.evaluation import dataset_of_tracks, load_annotation + + +# %% Paths and parameters. + + +features_path = Path( + "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/time_interval/predict/feb_test_time_interval_1_epoch_178.zarr" +) +data_path = Path( + "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/registered_test.zarr" +) +tracks_path = Path( + "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/track_test.zarr" +) + + +# %% +embedding_dataset = read_embedding_dataset(features_path) +embedding_dataset + +# %% +# Compute UMAP over all features +features = embedding_dataset["features"] +# or select a well: +# features = features[features["fov_name"].str.contains("B/4")] + +# %% OVERLAY INFECTION ANNOTATION +ann_root = Path( + "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/supervised_inf_pred" +) + + +infection = load_annotation( + features, + ann_root / "extracted_inf_state.csv", + "infection_state", + {0.0: "background", 1.0: "uninfected", 2.0: "infected"}, +) + +# %% plot the umap + +infection_npy = infection.cat.codes.values + +# Filter out the background class +infection_npy_filtered = infection_npy[infection_npy != 0] + +feature_npy = features.values +feature_npy_filtered = feature_npy[infection_npy != 0] + +# %% combine the umap, pca and infection annotation in one dataframe + +data = pd.DataFrame({"infection": infection_npy_filtered}) + +# add time and well info into dataframe +time_npy = features["t"].values +time_npy_filtered = time_npy[infection_npy != 0] +data["time"] = time_npy_filtered + +fov_name_list = features["fov_name"].values +fov_name_list_filtered = fov_name_list[infection_npy != 0] +data["fov_name"] = fov_name_list_filtered + +# Add all 768 features to the dataframe +for i in range(768): + data[f"feature_{i+1}"] = feature_npy_filtered[:, i] + +# %% manually split the dataset into training and testing set by well name + +# dataframe for training set, fov names starts with "/B/4/6" or "/B/4/7" or "/A/3/" +data_train_val = data[data["fov_name"].str.contains("/B/4/6") | data["fov_name"].str.contains("/B/4/7") | data["fov_name"].str.contains("/A/3/")] + +# dataframe for testing set, fov names starts with "/B/4/8" or "/B/4/9" or "/A/4/" +data_test = data[data["fov_name"].str.contains("/B/4/8") | data["fov_name"].str.contains("/B/4/9") | data["fov_name"].str.contains("/B/3/")] + +# %% train a linear classifier to predict infection state from PCA components + +from sklearn.linear_model import LogisticRegression +from sklearn.model_selection import train_test_split +from sklearn.metrics import classification_report + +x_train = data_train_val.drop(columns=["infection", "fov_name", "time"]) +y_train = data_train_val["infection"] + +# train a logistic regression model +clf = LogisticRegression(random_state=0).fit(x_train, y_train) + +x_test = data_test.drop(columns=["infection", "fov_name", "time"]) +y_test = data_test["infection"] + +# predict the 
infection state for the testing set +y_pred = clf.predict(x_test) + +# %% From fd73c11e8d48e5a5a976b447e0dd18da8064944d Mon Sep 17 00:00:00 2001 From: Soorya Pradeep Date: Thu, 19 Sep 2024 06:52:47 -0700 Subject: [PATCH 07/18] add infection figure script --- .../figures/figure_cell_infection.py | 237 ++++++++++++++++++ 1 file changed, 237 insertions(+) create mode 100644 applications/contrastive_phenotyping/figures/figure_cell_infection.py diff --git a/applications/contrastive_phenotyping/figures/figure_cell_infection.py b/applications/contrastive_phenotyping/figures/figure_cell_infection.py new file mode 100644 index 000000000..18add8718 --- /dev/null +++ b/applications/contrastive_phenotyping/figures/figure_cell_infection.py @@ -0,0 +1,237 @@ + +# %% +from pathlib import Path +import sys + +sys.path.append("/hpc/mydata/soorya.pradeep/scratch/viscy_infection_phenotyping/VisCy") + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import plotly.express as px +import seaborn as sns +from sklearn.decomposition import PCA +from sklearn.preprocessing import StandardScaler +from umap import UMAP +from sklearn.decomposition import PCA + + +from viscy.representation.embedding_writer import read_embedding_dataset +from viscy.representation.evaluation import dataset_of_tracks, load_annotation + + +# %% Paths and parameters. + + +features_path = Path( + "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/time_interval/predict/feb_test_time_interval_1_epoch_178.zarr" +) +data_path = Path( + "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/registered_test.zarr" +) +tracks_path = Path( + "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/track_test.zarr" +) + + +# %% +embedding_dataset = read_embedding_dataset(features_path) +embedding_dataset + +# %% +# Compute UMAP over all features +features = embedding_dataset["features"] +# or select a well: +# features = features[features["fov_name"].str.contains("B/4")] + + +scaled_features = StandardScaler().fit_transform(features.values) +umap = UMAP() +# Fit UMAP on all features +embedding = umap.fit_transform(scaled_features) + +features = ( + features.assign_coords(UMAP1=("sample", embedding[:, 0])) + .assign_coords(UMAP2=("sample", embedding[:, 1])) + .set_index(sample=["UMAP1", "UMAP2"], append=True) +) +features + +pca = PCA(n_components=4) +# scaled_features = StandardScaler().fit_transform(features.values) +# pca_features = pca.fit_transform(scaled_features) +pca_features = pca.fit_transform(features.values) + + +features = ( + features.assign_coords(PCA1=("sample", pca_features[:, 0])) + .assign_coords(PCA2=("sample", pca_features[:, 1])) + .assign_coords(PCA3=("sample", pca_features[:, 2])) + .assign_coords(PCA4=("sample", pca_features[:, 3])) + .set_index(sample=["PCA1", "PCA2", "PCA3", "PCA4"], append=True) +) + +# %% OVERLAY INFECTION ANNOTATION +ann_root = Path( + "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/supervised_inf_pred" +) + +infection = load_annotation( + features, + ann_root / "extracted_inf_state.csv", + "infection_state", + {0.0: "background", 1.0: "uninfected", 2.0: "infected"}, +) + +# %% plot the umap + +# remove the rows in umap and annotation for background class +# Convert UMAP coordinates to a DataFrame +umap_npy = embedding.copy() +infection_npy = infection.cat.codes.values 
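# (Editor's aside: illustrative sketch, not part of the original patch.)
# The logistic-regression script above imports classification_report but never
# calls it. A minimal, self-contained evaluation on stand-in data (all names
# here are hypothetical; in the real script y_test and y_pred come from the
# well-based train/test split):
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import classification_report

rng = np.random.default_rng(0)
X = rng.normal(size=(120, 8))      # stand-in for the 768-dimensional features
y = rng.integers(1, 3, size=120)   # stand-in labels: 1=uninfected, 2=infected
clf_demo = LogisticRegression(max_iter=1000).fit(X[:90], y[:90])
print(classification_report(y[90:], clf_demo.predict(X[90:])))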
+ +# Filter out the background class +umap_npy_filtered = umap_npy[infection_npy != 0] +infection_npy_filtered = infection_npy[infection_npy != 0] + +feature_npy = features.values +feature_npy_filtered = feature_npy[infection_npy != 0] + +sns.scatterplot( + x=umap_npy_filtered[:, 0], + y=umap_npy_filtered[:, 1], + hue=infection_npy_filtered, + palette={1: 'steelblue', 2: 'orangered'}, + hue_order=[1, 2], + s=7, + alpha=0.8, +) +plt.legend([], [], frameon=False) +plt.savefig('/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/Figure_panels/infection/umap_infection.png', format='png', dpi=300) + +# %% plot PCA components with infection hue + +pca_npy = pca_features.copy() +pca_npy_filtered = pca_npy[infection_npy != 0] + +sns.scatterplot( + x=pca_npy_filtered[:, 0], + y=pca_npy_filtered[:, 1], + hue=infection_npy_filtered, + palette={1: 'steelblue', 2: 'orangered'}, + hue_order=[1, 2], + s=7, + alpha=0.8, +) +plt.legend([], [], frameon=False) +plt.savefig('/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/Figure_panels/infection/pca_infection.png', format='png', dpi=300) + +# %% combine the umap, pca and infection annotation in one dataframe + +data = pd.DataFrame( + { + "UMAP1": umap_npy_filtered[:, 0], + "UMAP2": umap_npy_filtered[:, 1], + "PCA1": pca_npy_filtered[:, 0], + "PCA2": pca_npy_filtered[:, 1], + "PCA3": pca_npy_filtered[:, 2], + "PCA4": pca_npy_filtered[:, 3], + "infection": infection_npy_filtered, + } +) + +# add time and well info into dataframe +time_npy = features["t"].values +time_npy_filtered = time_npy[infection_npy != 0] +data["time"] = time_npy_filtered + +fov_name_list = features["fov_name"].values +fov_name_list_filtered = fov_name_list[infection_npy != 0] +data["fov_name"] = fov_name_list_filtered + +# Add all 768 features to the dataframe +for i in range(768): + data[f"feature_{i+1}"] = feature_npy_filtered[:, i] + +# %% manually split the dataset into training and testing set by well name + +# dataframe for training set, fov names starts with "/B/4/6" or "/B/4/7" or "/A/3/" +data_train_val = data[data["fov_name"].str.contains("/B/4/6") | data["fov_name"].str.contains("/B/4/7") | data["fov_name"].str.contains("/A/3/")] + +# dataframe for testing set, fov names starts with "/B/4/8" or "/B/4/9" or "/A/4/" +data_test = data[data["fov_name"].str.contains("/B/4/8") | data["fov_name"].str.contains("/B/4/9") | data["fov_name"].str.contains("/B/3/")] + +# %% train a linear classifier to predict infection state from PCA components + +from sklearn.linear_model import LogisticRegression +from sklearn.model_selection import train_test_split +from sklearn.metrics import classification_report + +x_train = data_train_val.drop(columns=["infection", "fov_name", "time", "UMAP1", "UMAP2", "PCA1", "PCA2", "PCA3", "PCA4"]) +y_train = data_train_val["infection"] + +# train a logistic regression model +clf = LogisticRegression(random_state=0).fit(x_train, y_train) + +x_test = data_test.drop(columns=["infection", "fov_name", "time", "UMAP1", "UMAP2", "PCA1", "PCA2", "PCA3", "PCA4"]) +y_test = data_test["infection"] + +# predict the infection state for the testing set +y_pred = clf.predict(x_test) + +# %% construct confusion matrix to compare the true and predicted infection state + +from sklearn.metrics import confusion_matrix +import seaborn as sns + +cm = confusion_matrix(y_test, y_pred) +cm_percentage = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis] * 100 +sns.heatmap(cm_percentage, annot=True, 
fmt=".2f", cmap="viridis") +plt.xlabel("Predicted") +plt.ylabel("True") +plt.title("Confusion Matrix (Percentage)") +plt.xticks(ticks=[0.5, 1.5], labels=['uninfected', 'infected']) +plt.yticks(ticks=[0.5, 1.5], labels=['uninfected', 'infected']) +plt.savefig('/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/Figure_panels/infection/confusion_matrix.svg', format='svg') + +# %% use the trained classifier to perform prediction on the entire dataset + +data_test["predicted_infection"] = y_pred + +# plot the predicted infection state over time for /B/3 well and /B/4 well +time_points_test = np.unique(data_test["time"]) + +infected_test_cntrl = [] +infected_test_infected = [] + +for time in time_points_test: + infected_cell = data_test[(data_test['fov_name'].str.startswith('/B/3')) & (data_test['time'] == time) & (data_test['predicted_infection'] == 2)].shape[0] + total_cell = data_test[(data_test['fov_name'].str.startswith('/B/3')) & (data_test['time'] == time)].shape[0] + infected_test_cntrl.append(infected_cell*100 / total_cell) + infected_cell = data_test[(data_test['fov_name'].str.startswith('/B/4')) & (data_test['time'] == time) & (data_test['predicted_infection'] == 2)].shape[0] + total_cell = data_test[(data_test['fov_name'].str.startswith('/B/4')) & (data_test['time'] == time)].shape[0] + infected_test_infected.append(infected_cell*100 /total_cell) + + +infected_true_cntrl = [] +infected_true_infected = [] + +for time in time_points_test: + infected_cell = data_test[(data_test['fov_name'].str.startswith('/B/3')) & (data_test['time'] == time) & (data_test['infection'] == 2)].shape[0] + total_cell = data_test[(data_test['fov_name'].str.startswith('/B/3')) & (data_test['time'] == time)].shape[0] + infected_true_cntrl.append(infected_cell*100 / total_cell) + infected_cell = data_test[(data_test['fov_name'].str.startswith('/B/4')) & (data_test['time'] == time) & (data_test['infection'] == 2)].shape[0] + total_cell = data_test[(data_test['fov_name'].str.startswith('/B/4')) & (data_test['time'] == time)].shape[0] + infected_true_infected.append(infected_cell*100 /total_cell) + +# plot infected percentage over time for both wells +plt.plot(time_points_test*0.5 + 3, infected_test_cntrl, label='mock predicted') +plt.plot(time_points_test*0.5 + 3, infected_test_infected, label='infected predicted') +plt.plot(time_points_test*0.5 + 3, infected_true_cntrl, label='mock true') +plt.plot(time_points_test*0.5 + 3, infected_true_infected, label='infected true') +plt.xlabel('Time (hours)') +plt.ylabel('Infected percentage') +plt.legend() +plt.savefig('/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/Figure_panels/infection/infected_percentage.svg', format='svg') + +# %% From 76fae6396406da3a14270cd750ac7a6e7f66dcaa Mon Sep 17 00:00:00 2001 From: Soorya Pradeep Date: Sun, 22 Sep 2024 21:04:16 -0700 Subject: [PATCH 08/18] Add script for generating infection figure and perform prediction on the June dataset --- .../figures/figure_cell_infection.py | 182 +++++++++++++++++- 1 file changed, 175 insertions(+), 7 deletions(-) diff --git a/applications/contrastive_phenotyping/figures/figure_cell_infection.py b/applications/contrastive_phenotyping/figures/figure_cell_infection.py index 18add8718..fc3cb2a6e 100644 --- a/applications/contrastive_phenotyping/figures/figure_cell_infection.py +++ b/applications/contrastive_phenotyping/figures/figure_cell_infection.py @@ -224,14 +224,182 @@ total_cell = 
data_test[(data_test['fov_name'].str.startswith('/B/4')) & (data_test['time'] == time)].shape[0] infected_true_infected.append(infected_cell*100 /total_cell) + +# %% perform prediction on the june dataset + +# Paths and parameters. +features_path = Path( + "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/time_interval/predict/jun_time_interval_1_epoch_178.zarr" +) +data_path = Path( + "/hpc/projects/intracellular_dashboard/viral-sensor/2024_06_13_SEC61_TOMM20_ZIKV_DENGUE_1/2-register/registered_chunked.zarr" +) +tracks_path = Path( + "/hpc/projects/intracellular_dashboard/viral-sensor/2024_06_13_SEC61_TOMM20_ZIKV_DENGUE_1/4.2-tracking/track.zarr" +) + +# %% +embedding_dataset = read_embedding_dataset(features_path) +embedding_dataset + +# %% +june_features = embedding_dataset["features"] + +scaled_features = StandardScaler().fit_transform(june_features.values) +umap = UMAP() +# Fit UMAP on all features +embedding = umap.fit_transform(scaled_features) + +june_features = ( + june_features.assign_coords(UMAP1=("sample", embedding[:, 0])) + .assign_coords(UMAP2=("sample", embedding[:, 1])) + .set_index(sample=["UMAP1", "UMAP2"], append=True) +) +june_features + +pca = PCA(n_components=4) +pca_features = pca.fit_transform(june_features.values) + +# %% + +# sns.scatterplot( +# x=june_features["UMAP1"], +# y=june_features["UMAP2"], +# hue=june_pred, +# palette={1: 'blue', 2: 'red'}, +# hue_order=[1, 2], +# s=7, +# alpha=0.8, +# ) +# plt.legend([], [], frameon=False) +# plt.xlim(0, 15) +# plt.savefig('/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/Figure_panels/infection/june_umap_infection.png', format='png', dpi=300) + +# %% plot June and Feb test combined UMAP + +june_umap_npy = embedding.copy() +june_pca_npy = pca_features.copy() +june_data = pd.DataFrame( + { + "UMAP1": june_umap_npy[:, 0], + "UMAP2": june_umap_npy[:, 1], + "PCA1": june_pca_npy[:, 0], + "PCA2": june_pca_npy[:, 1], + "PCA3": june_pca_npy[:, 2], + "PCA4": june_pca_npy[:, 3], + "infection": np.nan, + } +) + +# add time and well info into dataframe +june_data["time"] = june_features["t"].values + +june_data["fov_name"] = june_features["fov_name"].values + +# Add all 768 features to the dataframe +june_features_npy = june_features.values +for i in range(768): + june_data[f"feature_{i+1}"] = june_features_npy[:, i] + +# use one mock and one dengue infecected well only +june_data = june_data[june_data["fov_name"].str.contains("/0/6") | june_data["fov_name"].str.contains("/0/2")] + +# add the predicted infection state +june_pred = clf.predict(june_data.drop(columns=["infection", "fov_name", "time", "UMAP1", "UMAP2", "PCA1", "PCA2", "PCA3", "PCA4"])) +june_data["predicted_infection"] = june_pred + +# %% combine the june and feb data + +combined_data = pd.concat([data_test, june_data]) + +# perform the umap analysis again with the 768 features +features = combined_data.drop(columns=["infection", "predicted_infection", "fov_name", "time", "UMAP1", "UMAP2", "PCA1", "PCA2", "PCA3", "PCA4"]) +scaled_features = StandardScaler().fit_transform(features.values) +umap = UMAP() +# Fit UMAP on all features +embedding = umap.fit_transform(scaled_features) + +# overwrite the umap coordinates on combined data +combined_data["UMAP1"] = embedding[:, 0] +combined_data["UMAP2"] = embedding[:, 1] + +# plot the combined data with 'fov_name' starting with '/A and '/B' hue 'infection' and '/0' hue 'predicted_infection' +Feb_split = 
combined_data[combined_data["fov_name"].str.contains("/A") | combined_data["fov_name"].str.contains("/B")] +June_split = combined_data[combined_data["fov_name"].str.contains("/0")] + +sns.scatterplot( + x=Feb_split["UMAP1"], + y=Feb_split["UMAP2"], + hue=Feb_split["infection"], + palette={1: 'steelblue', 2: 'orangered'}, + hue_order=[1, 2], + s=7, + alpha=0.8, +) +sns.scatterplot( + x=June_split["UMAP1"], + y=June_split["UMAP2"], + hue=June_split["predicted_infection"], + palette={1: 'blue', 2: 'red'}, + hue_order=[1, 2], + s=7, + alpha=0.8, +) +plt.legend([], [], frameon=False) +# plt.savefig('/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/Figure_panels/infection/combined_umap_infection.png', format='png', dpi=300) + +# plot the scatterplot hue well name '/A' and '/B' are blue and '/0' are red +sns.scatterplot( + x=combined_data["UMAP1"], + y=combined_data["UMAP2"], + hue=combined_data["fov_name"].apply(lambda x: 'blue' if x.startswith('/0') else 'red'), + s=7, + alpha=0.8, +) +plt.xlim(5, 15) +plt.ylim(-10, 10) +plt.legend([], [], frameon=False) +plt.savefig('/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/Figure_panels/infection/combined_umap_well.png', format='png', dpi=300) + +# plot the predicted infection state with combined data +sns.scatterplot( + x=combined_data["UMAP1"], + y=combined_data["UMAP2"], + hue=combined_data["predicted_infection"], + palette={1: 'blue', 2: 'red'}, + hue_order=[1, 2], + s=7, + alpha=0.8, +) +plt.xlim(5, 15) +plt.ylim(-10, 10) +plt.legend([], [], frameon=False) +plt.savefig('/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/Figure_panels/infection/combined_umap_predicted_infection.png', format='png', dpi=300) + +# %% plot % infected over time + +time_points_june = np.unique(June_split["time"]) + +infected_june_cntrl = [] +infected_june_infected = [] + +for time in time_points_june: + infected_june = June_split[(June_split['fov_name'].str.startswith('/0/2')) & (June_split['time'] == time) & (June_split['predicted_infection'] == 2)].shape[0] + total_june = June_split[(June_split['fov_name'].str.startswith('/0/2')) & (June_split['time'] == time)].shape[0] + infected_june_cntrl.append(infected_june*100 / total_june) + infected_june = June_split[(June_split['fov_name'].str.startswith('/0/6')) & (June_split['time'] == time) & (June_split['predicted_infection'] == 2)].shape[0] + total_june = June_split[(June_split['fov_name'].str.startswith('/0/6')) & (June_split['time'] == time)].shape[0] + infected_june_infected.append(infected_june*100 /total_june) + + # plot infected percentage over time for both wells -plt.plot(time_points_test*0.5 + 3, infected_test_cntrl, label='mock predicted') -plt.plot(time_points_test*0.5 + 3, infected_test_infected, label='infected predicted') -plt.plot(time_points_test*0.5 + 3, infected_true_cntrl, label='mock true') -plt.plot(time_points_test*0.5 + 3, infected_true_infected, label='infected true') +plt.plot(time_points_test*0.5 + 3, infected_true_cntrl, label='mock true', color='steelblue') +plt.plot(time_points_test*0.5 + 3, infected_test_cntrl, label='mock predicted', color='blue') +plt.plot(time_points_test*0.5 + 3, infected_true_infected, label='MOI true', color='orangered') +plt.plot(time_points_test*0.5 + 3, infected_test_infected, label='MOI predicted', color='red') +plt.plot(time_points_june*2 + 3, infected_june_cntrl, label='mock new predicted', color='green') +plt.plot(time_points_june*2 + 3, 
infected_june_infected, label='MOI new predicted', color='brown') plt.xlabel('Time (hours)') plt.ylabel('Infected percentage') plt.legend() -plt.savefig('/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/Figure_panels/infection/infected_percentage.svg', format='svg') - -# %% +plt.savefig('/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/Figure_panels/infection/infected_percentage_withJune.svg', format='svg') From e835f9b7c93e4ef8146c8577fc3c2e65bfd13954 Mon Sep 17 00:00:00 2001 From: Soorya Pradeep Date: Mon, 23 Sep 2024 10:54:59 -0700 Subject: [PATCH 09/18] Format code --- .../figures/figure_cell_infection.py | 406 ++++++++++++------ 1 file changed, 280 insertions(+), 126 deletions(-) diff --git a/applications/contrastive_phenotyping/figures/figure_cell_infection.py b/applications/contrastive_phenotyping/figures/figure_cell_infection.py index fc3cb2a6e..b47bcc968 100644 --- a/applications/contrastive_phenotyping/figures/figure_cell_infection.py +++ b/applications/contrastive_phenotyping/figures/figure_cell_infection.py @@ -1,4 +1,3 @@ - # %% from pathlib import Path import sys @@ -24,13 +23,13 @@ features_path = Path( - "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/time_interval/predict/feb_test_time_interval_1_epoch_178.zarr" + "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/time_interval/predict/feb_test_time_interval_1_epoch_178.zarr" ) data_path = Path( - "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/registered_test.zarr" + "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/registered_test.zarr" ) tracks_path = Path( - "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/track_test.zarr" + "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/track_test.zarr" ) @@ -51,9 +50,9 @@ embedding = umap.fit_transform(scaled_features) features = ( - features.assign_coords(UMAP1=("sample", embedding[:, 0])) - .assign_coords(UMAP2=("sample", embedding[:, 1])) - .set_index(sample=["UMAP1", "UMAP2"], append=True) + features.assign_coords(UMAP1=("sample", embedding[:, 0])) + .assign_coords(UMAP2=("sample", embedding[:, 1])) + .set_index(sample=["UMAP1", "UMAP2"], append=True) ) features @@ -64,23 +63,23 @@ features = ( - features.assign_coords(PCA1=("sample", pca_features[:, 0])) - .assign_coords(PCA2=("sample", pca_features[:, 1])) - .assign_coords(PCA3=("sample", pca_features[:, 2])) - .assign_coords(PCA4=("sample", pca_features[:, 3])) - .set_index(sample=["PCA1", "PCA2", "PCA3", "PCA4"], append=True) + features.assign_coords(PCA1=("sample", pca_features[:, 0])) + .assign_coords(PCA2=("sample", pca_features[:, 1])) + .assign_coords(PCA3=("sample", pca_features[:, 2])) + .assign_coords(PCA4=("sample", pca_features[:, 3])) + .set_index(sample=["PCA1", "PCA2", "PCA3", "PCA4"], append=True) ) # %% OVERLAY INFECTION ANNOTATION ann_root = Path( - "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/supervised_inf_pred" + "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/supervised_inf_pred" ) infection = load_annotation( - features, - ann_root / 
"extracted_inf_state.csv", - "infection_state", - {0.0: "background", 1.0: "uninfected", 2.0: "infected"}, + features, + ann_root / "extracted_inf_state.csv", + "infection_state", + {0.0: "background", 1.0: "uninfected", 2.0: "infected"}, ) # %% plot the umap @@ -98,16 +97,20 @@ feature_npy_filtered = feature_npy[infection_npy != 0] sns.scatterplot( - x=umap_npy_filtered[:, 0], - y=umap_npy_filtered[:, 1], - hue=infection_npy_filtered, - palette={1: 'steelblue', 2: 'orangered'}, - hue_order=[1, 2], - s=7, - alpha=0.8, + x=umap_npy_filtered[:, 0], + y=umap_npy_filtered[:, 1], + hue=infection_npy_filtered, + palette={1: "steelblue", 2: "orangered"}, + hue_order=[1, 2], + s=7, + alpha=0.8, ) plt.legend([], [], frameon=False) -plt.savefig('/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/Figure_panels/infection/umap_infection.png', format='png', dpi=300) +plt.savefig( + "/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/Figure_panels/infection/umap_infection.png", + format="png", + dpi=300, +) # %% plot PCA components with infection hue @@ -115,29 +118,33 @@ pca_npy_filtered = pca_npy[infection_npy != 0] sns.scatterplot( - x=pca_npy_filtered[:, 0], - y=pca_npy_filtered[:, 1], - hue=infection_npy_filtered, - palette={1: 'steelblue', 2: 'orangered'}, - hue_order=[1, 2], - s=7, - alpha=0.8, + x=pca_npy_filtered[:, 0], + y=pca_npy_filtered[:, 1], + hue=infection_npy_filtered, + palette={1: "steelblue", 2: "orangered"}, + hue_order=[1, 2], + s=7, + alpha=0.8, ) plt.legend([], [], frameon=False) -plt.savefig('/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/Figure_panels/infection/pca_infection.png', format='png', dpi=300) +plt.savefig( + "/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/Figure_panels/infection/pca_infection.png", + format="png", + dpi=300, +) # %% combine the umap, pca and infection annotation in one dataframe data = pd.DataFrame( - { - "UMAP1": umap_npy_filtered[:, 0], - "UMAP2": umap_npy_filtered[:, 1], - "PCA1": pca_npy_filtered[:, 0], - "PCA2": pca_npy_filtered[:, 1], - "PCA3": pca_npy_filtered[:, 2], - "PCA4": pca_npy_filtered[:, 3], - "infection": infection_npy_filtered, - } + { + "UMAP1": umap_npy_filtered[:, 0], + "UMAP2": umap_npy_filtered[:, 1], + "PCA1": pca_npy_filtered[:, 0], + "PCA2": pca_npy_filtered[:, 1], + "PCA3": pca_npy_filtered[:, 2], + "PCA4": pca_npy_filtered[:, 3], + "infection": infection_npy_filtered, + } ) # add time and well info into dataframe @@ -156,10 +163,18 @@ # %% manually split the dataset into training and testing set by well name # dataframe for training set, fov names starts with "/B/4/6" or "/B/4/7" or "/A/3/" -data_train_val = data[data["fov_name"].str.contains("/B/4/6") | data["fov_name"].str.contains("/B/4/7") | data["fov_name"].str.contains("/A/3/")] +data_train_val = data[ + data["fov_name"].str.contains("/B/4/6") + | data["fov_name"].str.contains("/B/4/7") + | data["fov_name"].str.contains("/A/3/") +] # dataframe for testing set, fov names starts with "/B/4/8" or "/B/4/9" or "/A/4/" -data_test = data[data["fov_name"].str.contains("/B/4/8") | data["fov_name"].str.contains("/B/4/9") | data["fov_name"].str.contains("/B/3/")] +data_test = data[ + data["fov_name"].str.contains("/B/4/8") + | data["fov_name"].str.contains("/B/4/9") + | data["fov_name"].str.contains("/B/3/") +] # %% train a linear classifier to predict infection state from PCA components @@ -167,13 +182,37 @@ from 
sklearn.model_selection import train_test_split from sklearn.metrics import classification_report -x_train = data_train_val.drop(columns=["infection", "fov_name", "time", "UMAP1", "UMAP2", "PCA1", "PCA2", "PCA3", "PCA4"]) +x_train = data_train_val.drop( + columns=[ + "infection", + "fov_name", + "time", + "UMAP1", + "UMAP2", + "PCA1", + "PCA2", + "PCA3", + "PCA4", + ] +) y_train = data_train_val["infection"] # train a logistic regression model clf = LogisticRegression(random_state=0).fit(x_train, y_train) -x_test = data_test.drop(columns=["infection", "fov_name", "time", "UMAP1", "UMAP2", "PCA1", "PCA2", "PCA3", "PCA4"]) +x_test = data_test.drop( + columns=[ + "infection", + "fov_name", + "time", + "UMAP1", + "UMAP2", + "PCA1", + "PCA2", + "PCA3", + "PCA4", + ] +) y_test = data_test["infection"] # predict the infection state for the testing set @@ -185,14 +224,17 @@ import seaborn as sns cm = confusion_matrix(y_test, y_pred) -cm_percentage = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis] * 100 +cm_percentage = cm.astype("float") / cm.sum(axis=1)[:, np.newaxis] * 100 sns.heatmap(cm_percentage, annot=True, fmt=".2f", cmap="viridis") plt.xlabel("Predicted") plt.ylabel("True") plt.title("Confusion Matrix (Percentage)") -plt.xticks(ticks=[0.5, 1.5], labels=['uninfected', 'infected']) -plt.yticks(ticks=[0.5, 1.5], labels=['uninfected', 'infected']) -plt.savefig('/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/Figure_panels/infection/confusion_matrix.svg', format='svg') +plt.xticks(ticks=[0.5, 1.5], labels=["uninfected", "infected"]) +plt.yticks(ticks=[0.5, 1.5], labels=["uninfected", "infected"]) +plt.savefig( + "/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/Figure_panels/infection/confusion_matrix.svg", + format="svg", +) # %% use the trained classifier to perform prediction on the entire dataset @@ -205,44 +247,68 @@ infected_test_infected = [] for time in time_points_test: - infected_cell = data_test[(data_test['fov_name'].str.startswith('/B/3')) & (data_test['time'] == time) & (data_test['predicted_infection'] == 2)].shape[0] - total_cell = data_test[(data_test['fov_name'].str.startswith('/B/3')) & (data_test['time'] == time)].shape[0] - infected_test_cntrl.append(infected_cell*100 / total_cell) - infected_cell = data_test[(data_test['fov_name'].str.startswith('/B/4')) & (data_test['time'] == time) & (data_test['predicted_infection'] == 2)].shape[0] - total_cell = data_test[(data_test['fov_name'].str.startswith('/B/4')) & (data_test['time'] == time)].shape[0] - infected_test_infected.append(infected_cell*100 /total_cell) + infected_cell = data_test[ + (data_test["fov_name"].str.startswith("/B/3")) + & (data_test["time"] == time) + & (data_test["predicted_infection"] == 2) + ].shape[0] + total_cell = data_test[ + (data_test["fov_name"].str.startswith("/B/3")) & (data_test["time"] == time) + ].shape[0] + infected_test_cntrl.append(infected_cell * 100 / total_cell) + infected_cell = data_test[ + (data_test["fov_name"].str.startswith("/B/4")) + & (data_test["time"] == time) + & (data_test["predicted_infection"] == 2) + ].shape[0] + total_cell = data_test[ + (data_test["fov_name"].str.startswith("/B/4")) & (data_test["time"] == time) + ].shape[0] + infected_test_infected.append(infected_cell * 100 / total_cell) infected_true_cntrl = [] infected_true_infected = [] for time in time_points_test: - infected_cell = data_test[(data_test['fov_name'].str.startswith('/B/3')) & (data_test['time'] == time) & 
(data_test['infection'] == 2)].shape[0]
-    total_cell = data_test[(data_test['fov_name'].str.startswith('/B/3')) & (data_test['time'] == time)].shape[0]
-    infected_true_cntrl.append(infected_cell*100 / total_cell)
-    infected_cell = data_test[(data_test['fov_name'].str.startswith('/B/4')) & (data_test['time'] == time) & (data_test['infection'] == 2)].shape[0]
-    total_cell = data_test[(data_test['fov_name'].str.startswith('/B/4')) & (data_test['time'] == time)].shape[0]
-    infected_true_infected.append(infected_cell*100 /total_cell)
+    infected_cell = data_test[
+        (data_test["fov_name"].str.startswith("/B/3"))
+        & (data_test["time"] == time)
+        & (data_test["infection"] == 2)
+    ].shape[0]
+    total_cell = data_test[
+        (data_test["fov_name"].str.startswith("/B/3")) & (data_test["time"] == time)
+    ].shape[0]
+    infected_true_cntrl.append(infected_cell * 100 / total_cell)
+    infected_cell = data_test[
+        (data_test["fov_name"].str.startswith("/B/4"))
+        & (data_test["time"] == time)
+        & (data_test["infection"] == 2)
+    ].shape[0]
+    total_cell = data_test[
+        (data_test["fov_name"].str.startswith("/B/4")) & (data_test["time"] == time)
+    ].shape[0]
+    infected_true_infected.append(infected_cell * 100 / total_cell)
 
 # %% perform prediction on the june dataset
 
 # Paths and parameters.
 features_path = Path(
-    "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/time_interval/predict/jun_time_interval_1_epoch_178.zarr"
+    "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/time_interval/predict/jun_time_interval_1_epoch_178.zarr"
 )
 data_path = Path(
-    "/hpc/projects/intracellular_dashboard/viral-sensor/2024_06_13_SEC61_TOMM20_ZIKV_DENGUE_1/2-register/registered_chunked.zarr"
+    "/hpc/projects/intracellular_dashboard/viral-sensor/2024_06_13_SEC61_TOMM20_ZIKV_DENGUE_1/2-register/registered_chunked.zarr"
 )
 tracks_path = Path(
-    "/hpc/projects/intracellular_dashboard/viral-sensor/2024_06_13_SEC61_TOMM20_ZIKV_DENGUE_1/4.2-tracking/track.zarr"
+    "/hpc/projects/intracellular_dashboard/viral-sensor/2024_06_13_SEC61_TOMM20_ZIKV_DENGUE_1/4.2-tracking/track.zarr"
 )
 
 # %%
 embedding_dataset = read_embedding_dataset(features_path)
 embedding_dataset
 
-# %% 
+# %%
 june_features = embedding_dataset["features"]
 
 scaled_features = StandardScaler().fit_transform(june_features.values)
@@ -251,9 +317,9 @@
 embedding = umap.fit_transform(scaled_features)
 
 june_features = (
-    june_features.assign_coords(UMAP1=("sample", embedding[:, 0]))
-    .assign_coords(UMAP2=("sample", embedding[:, 1]))
-    .set_index(sample=["UMAP1", "UMAP2"], append=True)
+    june_features.assign_coords(UMAP1=("sample", embedding[:, 0]))
+    .assign_coords(UMAP2=("sample", embedding[:, 1]))
+    .set_index(sample=["UMAP1", "UMAP2"], append=True)
 )
 june_features
 
@@ -280,15 +346,15 @@
 june_umap_npy = embedding.copy()
 june_pca_npy = pca_features.copy()
 june_data = pd.DataFrame(
-    {
-        "UMAP1": june_umap_npy[:, 0],
-        "UMAP2": june_umap_npy[:, 1],
-        "PCA1": june_pca_npy[:, 0],
-        "PCA2": june_pca_npy[:, 1],
-        "PCA3": june_pca_npy[:, 2],
-        "PCA4": june_pca_npy[:, 3],
-        "infection": np.nan,
-    }
+    {
+        "UMAP1": june_umap_npy[:, 0],
+        "UMAP2": june_umap_npy[:, 1],
+        "PCA1": june_pca_npy[:, 0],
+        "PCA2": june_pca_npy[:, 1],
+        "PCA3": june_pca_npy[:, 2],
+        "PCA4": june_pca_npy[:, 3],
+        "infection": np.nan,
+    }
 )
 
 # add time and well info into dataframe
@@ -302,10 +368,27 @@
 june_data[f"feature_{i+1}"] = june_features_npy[:, i]
 
 # use one mock and one dengue infected well only
-june_data 
= june_data[june_data["fov_name"].str.contains("/0/6") | june_data["fov_name"].str.contains("/0/2")] +june_data = june_data[ + june_data["fov_name"].str.contains("/0/6") + | june_data["fov_name"].str.contains("/0/2") +] # add the predicted infection state -june_pred = clf.predict(june_data.drop(columns=["infection", "fov_name", "time", "UMAP1", "UMAP2", "PCA1", "PCA2", "PCA3", "PCA4"])) +june_pred = clf.predict( + june_data.drop( + columns=[ + "infection", + "fov_name", + "time", + "UMAP1", + "UMAP2", + "PCA1", + "PCA2", + "PCA3", + "PCA4", + ] + ) +) june_data["predicted_infection"] = june_pred # %% combine the june and feb data @@ -313,7 +396,20 @@ combined_data = pd.concat([data_test, june_data]) # perform the umap analysis again with the 768 features -features = combined_data.drop(columns=["infection", "predicted_infection", "fov_name", "time", "UMAP1", "UMAP2", "PCA1", "PCA2", "PCA3", "PCA4"]) +features = combined_data.drop( + columns=[ + "infection", + "predicted_infection", + "fov_name", + "time", + "UMAP1", + "UMAP2", + "PCA1", + "PCA2", + "PCA3", + "PCA4", + ] +) scaled_features = StandardScaler().fit_transform(features.values) umap = UMAP() # Fit UMAP on all features @@ -324,57 +420,70 @@ combined_data["UMAP2"] = embedding[:, 1] # plot the combined data with 'fov_name' starting with '/A and '/B' hue 'infection' and '/0' hue 'predicted_infection' -Feb_split = combined_data[combined_data["fov_name"].str.contains("/A") | combined_data["fov_name"].str.contains("/B")] +Feb_split = combined_data[ + combined_data["fov_name"].str.contains("/A") + | combined_data["fov_name"].str.contains("/B") +] June_split = combined_data[combined_data["fov_name"].str.contains("/0")] sns.scatterplot( - x=Feb_split["UMAP1"], - y=Feb_split["UMAP2"], - hue=Feb_split["infection"], - palette={1: 'steelblue', 2: 'orangered'}, - hue_order=[1, 2], - s=7, - alpha=0.8, + x=Feb_split["UMAP1"], + y=Feb_split["UMAP2"], + hue=Feb_split["infection"], + palette={1: "steelblue", 2: "orangered"}, + hue_order=[1, 2], + s=7, + alpha=0.8, ) sns.scatterplot( - x=June_split["UMAP1"], - y=June_split["UMAP2"], - hue=June_split["predicted_infection"], - palette={1: 'blue', 2: 'red'}, - hue_order=[1, 2], - s=7, - alpha=0.8, + x=June_split["UMAP1"], + y=June_split["UMAP2"], + hue=June_split["predicted_infection"], + palette={1: "blue", 2: "red"}, + hue_order=[1, 2], + s=7, + alpha=0.8, ) plt.legend([], [], frameon=False) # plt.savefig('/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/Figure_panels/infection/combined_umap_infection.png', format='png', dpi=300) # plot the scatterplot hue well name '/A' and '/B' are blue and '/0' are red sns.scatterplot( - x=combined_data["UMAP1"], - y=combined_data["UMAP2"], - hue=combined_data["fov_name"].apply(lambda x: 'blue' if x.startswith('/0') else 'red'), - s=7, - alpha=0.8, + x=combined_data["UMAP1"], + y=combined_data["UMAP2"], + hue=combined_data["fov_name"].apply( + lambda x: "blue" if x.startswith("/0") else "red" + ), + s=7, + alpha=0.8, ) plt.xlim(5, 15) plt.ylim(-10, 10) plt.legend([], [], frameon=False) -plt.savefig('/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/Figure_panels/infection/combined_umap_well.png', format='png', dpi=300) +plt.savefig( + "/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/Figure_panels/infection/combined_umap_well.png", + format="png", + dpi=300, +) # plot the predicted infection state with combined data sns.scatterplot( - 
x=combined_data["UMAP1"], - y=combined_data["UMAP2"], - hue=combined_data["predicted_infection"], - palette={1: 'blue', 2: 'red'}, - hue_order=[1, 2], - s=7, - alpha=0.8, -) + x=combined_data["UMAP1"], + y=combined_data["UMAP2"], + hue=combined_data["predicted_infection"], + palette={1: "blue", 2: "red"}, + hue_order=[1, 2], + s=7, + alpha=0.8, +) plt.xlim(5, 15) plt.ylim(-10, 10) plt.legend([], [], frameon=False) -plt.savefig('/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/Figure_panels/infection/combined_umap_predicted_infection.png', format='png', dpi=300) +plt.savefig( + "/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/Figure_panels/infection/combined_umap_predicted_infection.png", + format="png", + dpi=300, +) # %% plot % infected over time @@ -384,22 +493,67 @@ infected_june_infected = [] for time in time_points_june: - infected_june = June_split[(June_split['fov_name'].str.startswith('/0/2')) & (June_split['time'] == time) & (June_split['predicted_infection'] == 2)].shape[0] - total_june = June_split[(June_split['fov_name'].str.startswith('/0/2')) & (June_split['time'] == time)].shape[0] - infected_june_cntrl.append(infected_june*100 / total_june) - infected_june = June_split[(June_split['fov_name'].str.startswith('/0/6')) & (June_split['time'] == time) & (June_split['predicted_infection'] == 2)].shape[0] - total_june = June_split[(June_split['fov_name'].str.startswith('/0/6')) & (June_split['time'] == time)].shape[0] - infected_june_infected.append(infected_june*100 /total_june) + infected_june = June_split[ + (June_split["fov_name"].str.startswith("/0/2")) + & (June_split["time"] == time) + & (June_split["predicted_infection"] == 2) + ].shape[0] + total_june = June_split[ + (June_split["fov_name"].str.startswith("/0/2")) & (June_split["time"] == time) + ].shape[0] + infected_june_cntrl.append(infected_june * 100 / total_june) + infected_june = June_split[ + (June_split["fov_name"].str.startswith("/0/6")) + & (June_split["time"] == time) + & (June_split["predicted_infection"] == 2) + ].shape[0] + total_june = June_split[ + (June_split["fov_name"].str.startswith("/0/6")) & (June_split["time"] == time) + ].shape[0] + infected_june_infected.append(infected_june * 100 / total_june) # plot infected percentage over time for both wells -plt.plot(time_points_test*0.5 + 3, infected_true_cntrl, label='mock true', color='steelblue') -plt.plot(time_points_test*0.5 + 3, infected_test_cntrl, label='mock predicted', color='blue') -plt.plot(time_points_test*0.5 + 3, infected_true_infected, label='MOI true', color='orangered') -plt.plot(time_points_test*0.5 + 3, infected_test_infected, label='MOI predicted', color='red') -plt.plot(time_points_june*2 + 3, infected_june_cntrl, label='mock new predicted', color='green') -plt.plot(time_points_june*2 + 3, infected_june_infected, label='MOI new predicted', color='brown') -plt.xlabel('Time (hours)') -plt.ylabel('Infected percentage') +plt.plot( + time_points_test * 0.5 + 3, + infected_true_cntrl, + label="mock true", + color="steelblue", +) +plt.plot( + time_points_test * 0.5 + 3, + infected_test_cntrl, + label="mock predicted", + color="blue", +) +plt.plot( + time_points_test * 0.5 + 3, + infected_true_infected, + label="MOI true", + color="orangered", +) +plt.plot( + time_points_test * 0.5 + 3, + infected_test_infected, + label="MOI predicted", + color="red", +) +plt.plot( + time_points_june * 2 + 3, + infected_june_cntrl, + label="mock new predicted", + 
color="green", +) +plt.plot( + time_points_june * 2 + 3, + infected_june_infected, + label="MOI new predicted", + color="brown", +) +plt.xlabel("Time (hours)") +plt.ylabel("Infected percentage") plt.legend() -plt.savefig('/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/Figure_panels/infection/infected_percentage_withJune.svg', format='svg') +plt.savefig( + "/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/Figure_panels/infection/infected_percentage_withJune.svg", + format="svg", +) From 727309c9a39a93e3cc52842e6c27148979936edc Mon Sep 17 00:00:00 2001 From: Soorya Pradeep Date: Mon, 23 Sep 2024 11:20:04 -0700 Subject: [PATCH 10/18] Black format evaluation module and fix import in figure_cell_infection script --- .../figures/figure_cell_infection.py | 2 +- viscy/representation/evaluation.py | 9 +++++++-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/applications/contrastive_phenotyping/figures/figure_cell_infection.py b/applications/contrastive_phenotyping/figures/figure_cell_infection.py index b47bcc968..5a0ad4735 100644 --- a/applications/contrastive_phenotyping/figures/figure_cell_infection.py +++ b/applications/contrastive_phenotyping/figures/figure_cell_infection.py @@ -16,7 +16,7 @@ from viscy.representation.embedding_writer import read_embedding_dataset -from viscy.representation.evaluation import dataset_of_tracks, load_annotation +from viscy.representation.evaluation import load_annotation # %% Paths and parameters. diff --git a/viscy/representation/evaluation.py b/viscy/representation/evaluation.py index e1b2dc7ed..fdfbca0e8 100644 --- a/viscy/representation/evaluation.py +++ b/viscy/representation/evaluation.py @@ -479,7 +479,12 @@ def compute_displacement_mean_std( # Function to compute the norm of differences between embeddings at t and t + tau def compute_displacement( - embedding_dataset, max_tau=10, use_cosine=False, use_dissimilarity=False, use_umap=False): + embedding_dataset, + max_tau=10, + use_cosine=False, + use_dissimilarity=False, + use_umap=False, +): # Get the arrays of (fov_name, track_id, t, and embeddings) fov_names = embedding_dataset["fov_name"].values track_ids = embedding_dataset["track_id"].values @@ -488,7 +493,7 @@ def compute_displacement( if use_umap: umap1 = embedding_dataset["UMAP1"].values umap2 = embedding_dataset["UMAP2"].values - embeddings = np.vstack((umap1, umap2)).T + embeddings = np.vstack((umap1, umap2)).T else: embeddings = embedding_dataset["features"].values From 60f37d437ad3ac9757a969e2e79e7113139f2260 Mon Sep 17 00:00:00 2001 From: Soorya Pradeep Date: Tue, 24 Sep 2024 15:08:03 -0700 Subject: [PATCH 11/18] Refactor scatterplot colors and markers --- .../figures/figure_cell_infection.py | 50 +++++++++++-------- 1 file changed, 30 insertions(+), 20 deletions(-) diff --git a/applications/contrastive_phenotyping/figures/figure_cell_infection.py b/applications/contrastive_phenotyping/figures/figure_cell_infection.py index 5a0ad4735..807f04424 100644 --- a/applications/contrastive_phenotyping/figures/figure_cell_infection.py +++ b/applications/contrastive_phenotyping/figures/figure_cell_infection.py @@ -427,19 +427,19 @@ June_split = combined_data[combined_data["fov_name"].str.contains("/0")] sns.scatterplot( - x=Feb_split["UMAP1"], - y=Feb_split["UMAP2"], - hue=Feb_split["infection"], - palette={1: "steelblue", 2: "orangered"}, + x=June_split["UMAP1"], + y=June_split["UMAP2"], + hue=June_split["predicted_infection"], + palette={1: "blue", 2: 
"red"}, hue_order=[1, 2], s=7, alpha=0.8, ) sns.scatterplot( - x=June_split["UMAP1"], - y=June_split["UMAP2"], - hue=June_split["predicted_infection"], - palette={1: "blue", 2: "red"}, + x=Feb_split["UMAP1"], + y=Feb_split["UMAP2"], + hue=Feb_split["infection"], + palette={1: "steelblue", 2: "orange"}, hue_order=[1, 2], s=7, alpha=0.8, @@ -448,17 +448,21 @@ # plt.savefig('/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/Figure_panels/infection/combined_umap_infection.png', format='png', dpi=300) # plot the scatterplot hue well name '/A' and '/B' are blue and '/0' are red +combined_data["color"] = combined_data["fov_name"].apply( + lambda x: "brown" if x.startswith("/0") else "green" +) + sns.scatterplot( x=combined_data["UMAP1"], y=combined_data["UMAP2"], - hue=combined_data["fov_name"].apply( - lambda x: "blue" if x.startswith("/0") else "red" - ), + hue="color", + palette={"green": "green", "brown": "brown"}, + data=combined_data, s=7, - alpha=0.8, + alpha=0.2, # Increased transparency ) -plt.xlim(5, 15) -plt.ylim(-10, 10) +plt.xlim(-5, 5) +plt.ylim(-2, 20) plt.legend([], [], frameon=False) plt.savefig( "/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/Figure_panels/infection/combined_umap_well.png", @@ -476,8 +480,8 @@ s=7, alpha=0.8, ) -plt.xlim(5, 15) -plt.ylim(-10, 10) +plt.xlim(-5, 5) +plt.ylim(-2, 20) plt.legend([], [], frameon=False) plt.savefig( "/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/Figure_panels/infection/combined_umap_predicted_infection.png", @@ -519,38 +523,44 @@ infected_true_cntrl, label="mock true", color="steelblue", + linestyle="--", ) plt.plot( time_points_test * 0.5 + 3, infected_test_cntrl, label="mock predicted", color="blue", + marker="+", ) plt.plot( time_points_test * 0.5 + 3, infected_true_infected, label="MOI true", - color="orangered", + color="orange", + linestyle="--", ) plt.plot( time_points_test * 0.5 + 3, infected_test_infected, label="MOI predicted", color="red", + marker="+", ) plt.plot( time_points_june * 2 + 3, infected_june_cntrl, label="mock new predicted", - color="green", + color="blue", + marker="o", ) plt.plot( time_points_june * 2 + 3, infected_june_infected, label="MOI new predicted", - color="brown", + color="red", + marker="o", ) -plt.xlabel("Time (hours)") +plt.xlabel("HPI") plt.ylabel("Infected percentage") plt.legend() plt.savefig( From 65e05699f0920ea7f997f1c64aadba6dab7e46f0 Mon Sep 17 00:00:00 2001 From: Soorya Pradeep Date: Tue, 24 Sep 2024 15:08:24 -0700 Subject: [PATCH 12/18] Calculate model accuracy --- .../Infection_classifier_accuracy.py | 71 +++++++++++++++++++ 1 file changed, 71 insertions(+) create mode 100644 applications/infection_classification/Infection_classifier_accuracy.py diff --git a/applications/infection_classification/Infection_classifier_accuracy.py b/applications/infection_classification/Infection_classifier_accuracy.py new file mode 100644 index 000000000..97958b018 --- /dev/null +++ b/applications/infection_classification/Infection_classifier_accuracy.py @@ -0,0 +1,71 @@ +# %% script to compare the output from the supervised model and human revised annotations to get the accuracy of the model + +import numpy as np +from iohub import open_ome_zarr +from scipy.ndimage import label + +# %% datapaths + +# Path to model output +data_out_path = 
"/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/supervised_inf_pred/supervised_test.zarr" + +# Path to the human revised annotations +human_corrected_path = "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/supervised_inf_pred/supervised_test_corrected.zarr" + +# %% Load data and compute the number of objects in each class + +data_out = open_ome_zarr(data_out_path, layout="hcs", mode="r+") +human_corrected = open_ome_zarr(human_corrected_path, layout="hcs", mode="r+") + +out_medians = [] +HC_medians = [] +for well_id, well_data in data_out.wells(): + well_name, well_no = well_id.split("/") + + for pos_name, pos_data in well_data.positions(): + + out_data = pos_data.data.numpy() + T, C, Z, Y, X = out_data.shape + + HC_data = human_corrected[well_id + "/" + pos_name + "/0"] + HC_data = HC_data.numpy() + + # Compute the number of objects in the model output + for t in range(T): + out_img = out_data[t, 0, 0] + + # Compute the number of objects in the model output + out_labeled, num_out_objects = label(out_img > 0) + + # Compute the median of pixel values in each object in the model output + for obj_id in range(1, num_out_objects + 1): + obj_pixels = out_img[out_labeled == obj_id] + out_medians.append(np.median(obj_pixels)) + + # repeat for human acorrected annotations + HC_img = HC_data[t, 0, 0] + HC_labeled, num_HC_objects = label(HC_img > 0) + + for obj_id in range(1, num_HC_objects + 1): + obj_pixels = HC_img[HC_labeled == obj_id] + HC_medians.append(np.median(obj_pixels)) + +# %% Compute the accuracy + +num_twos_in_out_medians = out_medians.count(2) +num_twos_in_HC_medians = HC_medians.count(2) +error_inf = ( + (num_twos_in_HC_medians - num_twos_in_out_medians) / num_twos_in_HC_medians +) * 100 + +num_ones_in_out_medians = out_medians.count(1) +num_ones_in_HC_medians = HC_medians.count(1) +error_uninf = ( + (num_ones_in_HC_medians - num_ones_in_out_medians) / num_ones_in_HC_medians +) * 100 + +avg_error = (np.abs(error_inf) + np.abs(error_uninf)) / 2 + +accuracy = 100 - avg_error + +# %% From 05661da13643f706e772f04a25fdbb979f0a00cd Mon Sep 17 00:00:00 2001 From: Soorya Pradeep Date: Wed, 25 Sep 2024 13:26:51 -0700 Subject: [PATCH 13/18] Add script for appendix video --- .../figures/figure_cell_infection.py | 77 +++++++++++++++++++ 1 file changed, 77 insertions(+) diff --git a/applications/contrastive_phenotyping/figures/figure_cell_infection.py b/applications/contrastive_phenotyping/figures/figure_cell_infection.py index 807f04424..7888e8f53 100644 --- a/applications/contrastive_phenotyping/figures/figure_cell_infection.py +++ b/applications/contrastive_phenotyping/figures/figure_cell_infection.py @@ -567,3 +567,80 @@ "/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/Figure_panels/infection/infected_percentage_withJune.svg", format="svg", ) + +# %% appendix video for infection dynamics umap, Feb test data, colored by human revised annotation + +for time in range(48): + plt.clf() + sns.scatterplot( + data=data_test[(data_test["time"] == time)], + x="UMAP1", + y="UMAP2", + hue="infection", + palette={1: "steelblue", 2: "orangered"}, + hue_order=[1, 2], + s=20, + alpha=0.8, + ) + handles, _ = plt.gca().get_legend_handles_labels() + plt.legend(handles=handles, labels=['uninfected', 'infected']) + plt.suptitle(f"Time: {time*0.5+3} HPI") + plt.ylim(-10, 20) + plt.xlim(2, 18) + plt.savefig( + 
f"/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/Figure_panels/infection/video_umap/umap_feb_true_infection_" + str(time).zfill(3) + ".png", + format="png", + dpi=300, + ) + +# %% appendix video for infection dynamics umap, Feb test data, colored by predicted infection + +for time in range(48): + plt.clf() + sns.scatterplot( + data=data_test[(data_test["time"] == time)], + x="UMAP1", + y="UMAP2", + hue="predicted_infection", + palette={1: "blue", 2: "red"}, + hue_order=[1, 2], + s=20, + alpha=0.8, + ) + handles, _ = plt.gca().get_legend_handles_labels() + plt.legend(handles=handles, labels=['uninfected', 'infected']) + plt.suptitle(f"Time: {time*0.5+3} HPI") + plt.ylim(-10, 18) + plt.xlim(2, 18) + plt.savefig( + f"/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/Figure_panels/infection/video_umap/umap_feb_predicted_infection_" + str(time).zfill(3) + ".png", + format="png", + dpi=300, + ) + +# %% appendix video for infection dynamics umap, June data, colored by predicted infection + +for time in range(12): + plt.clf() + sns.scatterplot( + data=June_split[(June_split["time"] == time)], + x="UMAP1", + y="UMAP2", + hue="predicted_infection", + palette={1: "blue", 2: "red"}, + hue_order=[1, 2], + s=20, + alpha=0.8, + ) + handles, _ = plt.gca().get_legend_handles_labels() + plt.legend(handles=handles, labels=['uninfected', 'infected']) + plt.suptitle(f"Time: {time*2+3} HPI") + plt.ylim(-8, 10) + plt.xlim(-5, 5) + plt.savefig( + f"/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/Figure_panels/infection/video_umap/umap_june_predicted_infection_" + str(time).zfill(3) + ".png", + format="png", + dpi=300, + ) + +# %% From b0a869539c55aa98b8f9645ec1f1e03dcdb62fa7 Mon Sep 17 00:00:00 2001 From: Soorya Pradeep Date: Wed, 25 Sep 2024 13:27:08 -0700 Subject: [PATCH 14/18] formatted code --- .../figures/figure_cell_infection.py | 22 ++++++++++++------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/applications/contrastive_phenotyping/figures/figure_cell_infection.py b/applications/contrastive_phenotyping/figures/figure_cell_infection.py index 7888e8f53..b14ae3aeb 100644 --- a/applications/contrastive_phenotyping/figures/figure_cell_infection.py +++ b/applications/contrastive_phenotyping/figures/figure_cell_infection.py @@ -579,16 +579,18 @@ hue="infection", palette={1: "steelblue", 2: "orangered"}, hue_order=[1, 2], - s=20, + s=20, alpha=0.8, ) handles, _ = plt.gca().get_legend_handles_labels() - plt.legend(handles=handles, labels=['uninfected', 'infected']) + plt.legend(handles=handles, labels=["uninfected", "infected"]) plt.suptitle(f"Time: {time*0.5+3} HPI") plt.ylim(-10, 20) plt.xlim(2, 18) plt.savefig( - f"/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/Figure_panels/infection/video_umap/umap_feb_true_infection_" + str(time).zfill(3) + ".png", + f"/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/Figure_panels/infection/video_umap/umap_feb_true_infection_" + + str(time).zfill(3) + + ".png", format="png", dpi=300, ) @@ -608,12 +610,14 @@ alpha=0.8, ) handles, _ = plt.gca().get_legend_handles_labels() - plt.legend(handles=handles, labels=['uninfected', 'infected']) + plt.legend(handles=handles, labels=["uninfected", "infected"]) plt.suptitle(f"Time: {time*0.5+3} HPI") plt.ylim(-10, 18) plt.xlim(2, 18) plt.savefig( - 
f"/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/Figure_panels/infection/video_umap/umap_feb_predicted_infection_" + str(time).zfill(3) + ".png", + f"/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/Figure_panels/infection/video_umap/umap_feb_predicted_infection_" + + str(time).zfill(3) + + ".png", format="png", dpi=300, ) @@ -633,14 +637,16 @@ alpha=0.8, ) handles, _ = plt.gca().get_legend_handles_labels() - plt.legend(handles=handles, labels=['uninfected', 'infected']) + plt.legend(handles=handles, labels=["uninfected", "infected"]) plt.suptitle(f"Time: {time*2+3} HPI") plt.ylim(-8, 10) plt.xlim(-5, 5) plt.savefig( - f"/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/Figure_panels/infection/video_umap/umap_june_predicted_infection_" + str(time).zfill(3) + ".png", + f"/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/Figure_panels/infection/video_umap/umap_june_predicted_infection_" + + str(time).zfill(3) + + ".png", format="png", dpi=300, - ) + ) # %% From f71bb3d8af1b6eedf41ebca8dfb33e3315e4f605 Mon Sep 17 00:00:00 2001 From: Alishba Imran <44557946+alishbaimran@users.noreply.github.com> Date: Thu, 26 Sep 2024 13:29:01 -0700 Subject: [PATCH 15/18] updated displacement funcs for full embeddings --- viscy/representation/evaluation.py | 80 ++++++++++++++++++++++++++++++ 1 file changed, 80 insertions(+) diff --git a/viscy/representation/evaluation.py b/viscy/representation/evaluation.py index fdfbca0e8..e45c19bcf 100644 --- a/viscy/representation/evaluation.py +++ b/viscy/representation/evaluation.py @@ -541,3 +541,83 @@ def compute_displacement( displacement_per_tau[tau].append(displacement) return displacement_per_tau + +def calculate_normalized_euclidean_distance_cell(embedding_dataset, fov_name, track_id): + filtered_data = embedding_dataset.where( + (embedding_dataset["fov_name"] == fov_name) + & (embedding_dataset["track_id"] == track_id), + drop=True, + ) + + features = filtered_data["features"].values # (sample, features) + time_points = filtered_data["t"].values # (sample,) + + normalized_features = features / np.linalg.norm(features, axis=1, keepdims=True) + + # Get the first time point's normalized embedding + first_time_point_embedding = normalized_features[0].reshape(1, -1) + + euclidean_distances = [] + for i in range(len(time_points)): + distance = np.linalg.norm(first_time_point_embedding - normalized_features[i].reshape(1, -1)) + euclidean_distances.append(distance) + + return time_points, euclidean_distances + +def compute_displacement_mean_std_full( + embedding_dataset, max_tau=10 +): + fov_names = embedding_dataset["fov_name"].values + track_ids = embedding_dataset["track_id"].values + timepoints = embedding_dataset["t"].values + embeddings = embedding_dataset["features"].values + + cell_identifiers = np.array(list(zip(fov_names, track_ids)), dtype=[('fov_name', 'O'), ('track_id', 'int64')]) + + unique_cells = np.unique(cell_identifiers) + + displacement_per_tau = defaultdict(list) + + for cell in unique_cells: + fov_name = cell['fov_name'] + track_id = cell['track_id'] + + indices = np.where((fov_names == fov_name) & (track_ids == track_id))[0] + + cell_timepoints = timepoints[indices] + cell_embeddings = embeddings[indices] + + sorted_indices = np.argsort(cell_timepoints) + cell_timepoints = cell_timepoints[sorted_indices] + cell_embeddings = cell_embeddings[sorted_indices] + + for i in 
range(len(cell_timepoints)): + current_time = cell_timepoints[i] + current_embedding = cell_embeddings[i] + + current_embedding = current_embedding / np.linalg.norm(current_embedding) + + for tau in range(0, max_tau + 1): + future_time = current_time + tau + + future_index = np.where(cell_timepoints == future_time)[0] + + if len(future_index) >= 1: + future_embedding = cell_embeddings[future_index[0]] + future_embedding = future_embedding / np.linalg.norm(future_embedding) + + distance = np.linalg.norm(current_embedding - future_embedding) + + displacement_per_tau[tau].append(distance) + + mean_displacement_per_tau = { + tau: np.mean(displacements) + for tau, displacements in displacement_per_tau.items() + } + std_displacement_per_tau = { + tau: np.std(displacements) + for tau, displacements in displacement_per_tau.items() + } + + return mean_displacement_per_tau, std_displacement_per_tau + From 5dbee3464254c31c3d07d3395e501b41cd728a54 Mon Sep 17 00:00:00 2001 From: Alishba Imran <44557946+alishbaimran@users.noreply.github.com> Date: Thu, 26 Sep 2024 13:30:44 -0700 Subject: [PATCH 16/18] script for displacement computation --- .../evaluation/displacement.py | 118 ++++++++++++++++++ 1 file changed, 118 insertions(+) create mode 100644 applications/contrastive_phenotyping/evaluation/displacement.py diff --git a/applications/contrastive_phenotyping/evaluation/displacement.py b/applications/contrastive_phenotyping/evaluation/displacement.py new file mode 100644 index 000000000..a0d46c28a --- /dev/null +++ b/applications/contrastive_phenotyping/evaluation/displacement.py @@ -0,0 +1,118 @@ +# %% +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import plotly.express as px +import seaborn as sns +from sklearn.decomposition import PCA +from sklearn.preprocessing import StandardScaler +from umap import UMAP +from sklearn.decomposition import PCA +from matplotlib.font_manager import FontProperties + +from viscy.representation.embedding_writer import read_embedding_dataset +from viscy.representation.evaluation import dataset_of_tracks, load_annotation +from viscy.representation.evaluation import calculate_normalized_euclidean_distance_cell +from viscy.representation.evaluation import compute_displacement_mean_std_full +from sklearn.metrics.pairwise import cosine_similarity +from collections import defaultdict +from scipy.ndimage import gaussian_filter1d + +# %% paths + +features_path_30_min = Path( + "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/time_interval/predict/feb_test_time_interval_1_epoch_178.zarr" +) + +feature_path_no_track = Path("/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/negpair_random_sampling2/feb_fixed_test_predict.zarr") + +features_path_any_time = Path("/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/negpair_difcell_randomtime_sampling/Ver2_updateTracking_refineModel/predictions/Feb_2chan_128patch_32projDim/2chan_128patch_56ckpt_FebTest.zarr") + +data_path = Path( + "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/registered_test.zarr" +) + +tracks_path = Path( + "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/track_test.zarr" +) + +# %% Load embedding datasets for all three sampling +fov_name = '/B/4/6' +track_id = 52 + 
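+# NOTE (editorial sketch): the (fov_name, track_id) pair above pins one example
+# infected-cell track; any track present in all three embedding datasets below
+# can be substituted here.
+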
+embedding_dataset_30_min = read_embedding_dataset(features_path_30_min)
+embedding_dataset_no_track = read_embedding_dataset(feature_path_no_track)
+embedding_dataset_any_time = read_embedding_dataset(features_path_any_time)
+
+# %%
+# Calculate the normalized Euclidean displacement from the first time point for each sampling strategy
+time_points_30_min, euclidean_distances_30_min = calculate_normalized_euclidean_distance_cell(embedding_dataset_30_min, fov_name, track_id)
+time_points_no_track, euclidean_distances_no_track = calculate_normalized_euclidean_distance_cell(embedding_dataset_no_track, fov_name, track_id)
+time_points_any_time, euclidean_distances_any_time = calculate_normalized_euclidean_distance_cell(embedding_dataset_any_time, fov_name, track_id)
+
+# %% Plot displacement over time for all three conditions
+
+plt.figure(figsize=(10, 6))
+
+plt.plot(time_points_no_track, euclidean_distances_no_track, marker='o', label='classical contrastive (no tracking)')
+plt.plot(time_points_any_time, euclidean_distances_any_time, marker='o', label='cell aware')
+plt.plot(time_points_30_min, euclidean_distances_30_min, marker='o', label='cell & time aware (interval 30 min)')
+
+plt.xlabel("Time Delay (t)", fontsize=10)
+plt.ylabel("Normalized Euclidean Distance with First Time Point", fontsize=10)
+plt.title("Normalized Euclidean Distance (Features) Over Time for Infected Cell", fontsize=12)
+
+plt.grid(True)
+plt.legend(fontsize=10)
+
+#plt.savefig('4_euc_dist_full.svg', format='svg')
+plt.show()
+
+
+# %% Paths to datasets
+features_path_30_min = Path("/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/time_interval/predict/feb_test_time_interval_1_epoch_178.zarr")
+feature_path_no_track = Path("/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/negpair_random_sampling2/feb_fixed_test_predict.zarr")
+
+embedding_dataset_30_min = read_embedding_dataset(features_path_30_min)
+embedding_dataset_no_track = read_embedding_dataset(feature_path_no_track)
+
+
+# %%
+max_tau = 10
+
+mean_displacement_30_min_euc, std_displacement_30_min_euc = compute_displacement_mean_std_full(embedding_dataset_30_min, max_tau)
+mean_displacement_no_track_euc, std_displacement_no_track_euc = compute_displacement_mean_std_full(embedding_dataset_no_track, max_tau)
+
+# %% Plot mean Euclidean displacement (with std bands) against time shift
+plt.figure(figsize=(10, 6))
+
+taus = list(mean_displacement_30_min_euc.keys())
+
+mean_values_30_min_euc = list(mean_displacement_30_min_euc.values())
+std_values_30_min_euc = list(std_displacement_30_min_euc.values())
+
+plt.plot(taus, mean_values_30_min_euc, marker='o', label='Cell & Time Aware (30 min interval)', color='green')
+plt.fill_between(taus,
+                 np.array(mean_values_30_min_euc) - np.array(std_values_30_min_euc),
+                 np.array(mean_values_30_min_euc) + np.array(std_values_30_min_euc),
+                 color='green', alpha=0.3, label='Std Dev (30 min interval)')
+
+mean_values_no_track_euc = list(mean_displacement_no_track_euc.values())
+std_values_no_track_euc = list(std_displacement_no_track_euc.values())
+
+plt.plot(taus, mean_values_no_track_euc, marker='o', label='Classical Contrastive (No Tracking)', color='blue')
+plt.fill_between(taus,
+                 np.array(mean_values_no_track_euc) - np.array(std_values_no_track_euc),
+                 np.array(mean_values_no_track_euc) + np.array(std_values_no_track_euc),
+                 color='blue', alpha=0.3, label='Std Dev (No Tracking)')
+
+plt.xlabel('Time Shift (τ)')
+plt.ylabel('Euclidean Distance')
+plt.title('Embedding Displacement Over Time (Features)')
+ 
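+# The shaded bands span the mean ± one standard deviation of the per-track
+# displacements at each time shift.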
+plt.grid(True) +plt.legend() + +plt.show() From e78cff9e4b3a315cd147a46a03044db80e01c0db Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Fri, 27 Sep 2024 14:44:20 -0700 Subject: [PATCH 17/18] fix style --- viscy/representation/evaluation.py | 28 +++++++++++++++++----------- 1 file changed, 17 insertions(+), 11 deletions(-) diff --git a/viscy/representation/evaluation.py b/viscy/representation/evaluation.py index 812e186e5..622f532ee 100644 --- a/viscy/representation/evaluation.py +++ b/viscy/representation/evaluation.py @@ -602,6 +602,7 @@ def compute_displacement( return displacement_per_tau + def calculate_normalized_euclidean_distance_cell(embedding_dataset, fov_name, track_id): filtered_data = embedding_dataset.where( (embedding_dataset["fov_name"] == fov_name) @@ -619,31 +620,35 @@ def calculate_normalized_euclidean_distance_cell(embedding_dataset, fov_name, tr euclidean_distances = [] for i in range(len(time_points)): - distance = np.linalg.norm(first_time_point_embedding - normalized_features[i].reshape(1, -1)) + distance = np.linalg.norm( + first_time_point_embedding - normalized_features[i].reshape(1, -1) + ) euclidean_distances.append(distance) return time_points, euclidean_distances -def compute_displacement_mean_std_full( - embedding_dataset, max_tau=10 -): + +def compute_displacement_mean_std_full(embedding_dataset, max_tau=10): fov_names = embedding_dataset["fov_name"].values track_ids = embedding_dataset["track_id"].values timepoints = embedding_dataset["t"].values embeddings = embedding_dataset["features"].values - cell_identifiers = np.array(list(zip(fov_names, track_ids)), dtype=[('fov_name', 'O'), ('track_id', 'int64')]) + cell_identifiers = np.array( + list(zip(fov_names, track_ids)), + dtype=[("fov_name", "O"), ("track_id", "int64")], + ) unique_cells = np.unique(cell_identifiers) displacement_per_tau = defaultdict(list) for cell in unique_cells: - fov_name = cell['fov_name'] - track_id = cell['track_id'] + fov_name = cell["fov_name"] + track_id = cell["track_id"] indices = np.where((fov_names == fov_name) & (track_ids == track_id))[0] - + cell_timepoints = timepoints[indices] cell_embeddings = embeddings[indices] @@ -664,10 +669,12 @@ def compute_displacement_mean_std_full( if len(future_index) >= 1: future_embedding = cell_embeddings[future_index[0]] - future_embedding = future_embedding / np.linalg.norm(future_embedding) + future_embedding = future_embedding / np.linalg.norm( + future_embedding + ) distance = np.linalg.norm(current_embedding - future_embedding) - + displacement_per_tau[tau].append(distance) mean_displacement_per_tau = { @@ -680,4 +687,3 @@ def compute_displacement_mean_std_full( } return mean_displacement_per_tau, std_displacement_per_tau - From 4c48802605ab06ee023c22cb846f86c30d3c1a6d Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Fri, 27 Sep 2024 14:45:33 -0700 Subject: [PATCH 18/18] fix docstring format --- viscy/representation/evaluation.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/viscy/representation/evaluation.py b/viscy/representation/evaluation.py index 622f532ee..343519d75 100644 --- a/viscy/representation/evaluation.py +++ b/viscy/representation/evaluation.py @@ -444,8 +444,8 @@ def compute_radial_intensity_gradient(image): return radial_intensity_gradient[0] -# Function to extract embeddings and calculate cosine similarities for a specific cell def calculate_cosine_similarity_cell(embedding_dataset, fov_name, track_id): + """Extract embeddings and calculate cosine similarities for a specific cell""" # Filter 
the dataset for the specific infected cell filtered_data = embedding_dataset.where( (embedding_dataset["fov_name"] == fov_name) @@ -471,10 +471,10 @@ def calculate_cosine_similarity_cell(embedding_dataset, fov_name, track_id): return time_points, cosine_similarities -# Function to compute the norm of differences between embeddings at t and t + tau def compute_displacement_mean_std( embedding_dataset, max_tau=10, use_cosine=False, use_dissimilarity=False ): + """Compute the norm of differences between embeddings at t and t + tau""" # Get the arrays of (fov_name, track_id, t, and embeddings) fov_names = embedding_dataset["fov_name"].values track_ids = embedding_dataset["track_id"].values @@ -537,7 +537,6 @@ def compute_displacement_mean_std( return mean_displacement_per_tau, std_displacement_per_tau -# Function to compute the norm of differences between embeddings at t and t + tau def compute_displacement( embedding_dataset, max_tau=10, @@ -545,6 +544,7 @@ def compute_displacement( use_dissimilarity=False, use_umap=False, ): + """Compute the norm of differences between embeddings at t and t + tau""" # Get the arrays of (fov_name, track_id, t, and embeddings) fov_names = embedding_dataset["fov_name"].values track_ids = embedding_dataset["track_id"].values