diff --git a/applications/contrastive_phenotyping/evaluation/cosine_similarity.py b/applications/contrastive_phenotyping/evaluation/cosine_similarity.py
new file mode 100644
index 000000000..78a4906c9
--- /dev/null
+++ b/applications/contrastive_phenotyping/evaluation/cosine_similarity.py
@@ -0,0 +1,546 @@
+# %%
+# Compare embeddings across sampling strategies using cosine similarity;
+# Euclidean distance on the features is also explored further below.
+from pathlib import Path
+
+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
+import seaborn as sns
+from sklearn.preprocessing import StandardScaler
+from umap import UMAP
+
+from viscy.representation.embedding_writer import read_embedding_dataset
+from viscy.representation.evaluation import (
+    calculate_cosine_similarity_cell,
+    compute_displacement,
+    compute_displacement_mean_std,
+)
+
+# %% Paths and parameters.
+
+
+features_path_30_min = Path(
+    "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/time_interval/predict/feb_test_time_interval_1_epoch_178.zarr"
+)
+
+
+feature_path_no_track = Path(
+    "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/negpair_random_sampling2/feb_fixed_test_predict.zarr"
+)
+
+
+features_path_any_time = Path(
+    "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/negpair_difcell_randomtime_sampling/Ver2_updateTracking_refineModel/predictions/Feb_2chan_128patch_32projDim/2chan_128patch_56ckpt_FebTest.zarr"
+)
+
+
+data_path = Path(
+    "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/registered_test.zarr"
+)
+
+
+tracks_path = Path(
+    "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/track_test.zarr"
+)
+
+
+# %% Load embedding datasets for all three sampling strategies
+fov_name = "/B/4/6"
+track_id = 52
+
+embedding_dataset_30_min = read_embedding_dataset(features_path_30_min)
+embedding_dataset_no_track = read_embedding_dataset(feature_path_no_track)
+embedding_dataset_any_time = read_embedding_dataset(features_path_any_time)
+
+# Calculate cosine similarities for each sampling strategy
+time_points_30_min, cosine_similarities_30_min = calculate_cosine_similarity_cell(
+    embedding_dataset_30_min, fov_name, track_id
+)
+time_points_no_track, cosine_similarities_no_track = calculate_cosine_similarity_cell(
+    embedding_dataset_no_track, fov_name, track_id
+)
+time_points_any_time, cosine_similarities_any_time = calculate_cosine_similarity_cell(
+    embedding_dataset_any_time, fov_name, track_id
+)
+
+# %% Plot cosine similarities over time for all three conditions
+
+plt.figure(figsize=(10, 6))
+
+plt.plot(
+    time_points_no_track,
+    cosine_similarities_no_track,
+    marker="o",
+    label="classical contrastive (no tracking)",
+)
+plt.plot(
+    time_points_any_time, cosine_similarities_any_time, marker="o", label="cell aware"
+)
+plt.plot(
+    time_points_30_min,
+    cosine_similarities_30_min,
+    marker="o",
+    label="cell & time aware (interval 30 min)",
+)
+
+plt.xlabel("Time Delay (t)")
+plt.ylabel("Cosine Similarity with First Time Point")
+plt.title("Cosine Similarity Over Time for Infected Cell")
+
+# plt.savefig('infected_cell_example.pdf', format='pdf')
+
+
+plt.grid(True)
+
+plt.legend()
+
+plt.savefig("new_example_cell.svg", format="svg")
+
+
+plt.show()
+
+# %% Paths to datasets
+features_path_30_min = Path(
"/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/time_interval/predict/feb_test_time_interval_1_epoch_178.zarr" +) +feature_path_no_track = Path( + "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/negpair_random_sampling2/feb_fixed_test_predict.zarr" +) +# features_path_any_time = Path("/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/negpair_difcell_randomtime_sampling/Ver2_updateTracking_refineModel/predictions/Feb_1chan_128patch_32projDim/1chan_128patch_63ckpt_FebTest.zarr") + + +# %% Read embedding datasets +embedding_dataset_30_min = read_embedding_dataset(features_path_30_min) +embedding_dataset_no_track = read_embedding_dataset(feature_path_no_track) +# embedding_dataset_any_time = read_embedding_dataset(features_path_any_time) + + +# %% Compute displacements for both datasets (using Euclidean distance and Cosine similarity) +max_tau = 10 # Maximum time shift (tau) to compute displacements + + +# mean_displacement_30_min, std_displacement_30_min = compute_displacement_mean_std(embedding_dataset_30_min, max_tau, use_cosine=False, use_dissimilarity=False) +# mean_displacement_no_track, std_displacement_no_track = compute_displacement_mean_std(embedding_dataset_no_track, max_tau, use_cosine=False, use_dissimilarity=False) +# mean_displacement_any_time, std_displacement_any_time = compute_displacement_mean_std(embedding_dataset_any_time, max_tau, use_cosine=False) + + +mean_displacement_30_min_cosine, std_displacement_30_min_cosine = ( + compute_displacement_mean_std( + embedding_dataset_30_min, max_tau, use_cosine=True, use_dissimilarity=False + ) +) +mean_displacement_no_track_cosine, std_displacement_no_track_cosine = ( + compute_displacement_mean_std( + embedding_dataset_no_track, max_tau, use_cosine=True, use_dissimilarity=False + ) +) +# mean_displacement_any_time_cosine, std_displacement_any_time_cosine = compute_displacement_mean_std(embedding_dataset_any_time, max_tau, use_cosine=True) +# %% Plot 1: Euclidean Displacements +plt.figure(figsize=(10, 6)) + + +taus = list(mean_displacement_30_min_cosine.keys()) +mean_values_30_min = list(mean_displacement_30_min_cosine.values()) +std_values_30_min = list(std_displacement_30_min_cosine.values()) + + +mean_values_no_track = list(mean_displacement_no_track_cosine.values()) +std_values_no_track = list(std_displacement_no_track_cosine.values()) + + +# mean_values_any_time = list(mean_displacement_any_time.values()) +# std_values_any_time = list(std_displacement_any_time.values()) + + +# Plotting Euclidean displacements +plt.plot( + taus, mean_values_30_min, marker="o", label="Cell & Time Aware (30 min interval)" +) +plt.fill_between( + taus, + np.array(mean_values_30_min) - np.array(std_values_30_min), + np.array(mean_values_30_min) + np.array(std_values_30_min), + color="gray", + alpha=0.3, + label="Std Dev (30 min interval)", +) + + +plt.plot( + taus, mean_values_no_track, marker="o", label="Classical Contrastive (No Tracking)" +) +plt.fill_between( + taus, + np.array(mean_values_no_track) - np.array(std_values_no_track), + np.array(mean_values_no_track) + np.array(std_values_no_track), + color="blue", + alpha=0.3, + label="Std Dev (No Tracking)", +) + + +plt.xlabel("Time Shift (τ)") +plt.ylabel("Displacement") +plt.title("Embedding Displacement Over Time") +plt.grid(True) +plt.legend() + + +# plt.savefig('embedding_displacement_euclidean.svg', format='svg') 
+# plt.savefig('embedding_displacement_euclidean.pdf', format='pdf') + + +# Show the Euclidean plot +plt.show() + + +# %% Plot 2: Cosine Displacements +plt.figure(figsize=(10, 6)) + +taus = list(mean_displacement_30_min_cosine.keys()) + +# Plotting Cosine displacements +mean_values_30_min_cosine = list(mean_displacement_30_min_cosine.values()) +std_values_30_min_cosine = list(std_displacement_30_min_cosine.values()) + + +mean_values_no_track_cosine = list(mean_displacement_no_track_cosine.values()) +std_values_no_track_cosine = list(std_displacement_no_track_cosine.values()) + + +plt.plot( + taus, + mean_values_30_min_cosine, + marker="o", + label="Cell & Time Aware (30 min interval)", +) +plt.fill_between( + taus, + np.array(mean_values_30_min_cosine) - np.array(std_values_30_min_cosine), + np.array(mean_values_30_min_cosine) + np.array(std_values_30_min_cosine), + color="gray", + alpha=0.3, + label="Std Dev (30 min interval)", +) + + +plt.plot( + taus, + mean_values_no_track_cosine, + marker="o", + label="Classical Contrastive (No Tracking)", +) +plt.fill_between( + taus, + np.array(mean_values_no_track_cosine) - np.array(std_values_no_track_cosine), + np.array(mean_values_no_track_cosine) + np.array(std_values_no_track_cosine), + color="blue", + alpha=0.3, + label="Std Dev (No Tracking)", +) + + +plt.xlabel("Time Shift (τ)") +plt.ylabel("Cosine Similarity") +plt.title("Embedding Displacement Over Time") + + +plt.grid(True) +plt.legend() +plt.savefig("1_std_cosine_plot.svg", format="svg") + +# Show the Cosine plot +plt.show() +# %% + + +# %% Paths to datasets +features_path_30_min = Path( + "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/time_interval/predict/feb_test_time_interval_1_epoch_178.zarr" +) +feature_path_no_track = Path( + "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/negpair_random_sampling2/feb_fixed_test_predict.zarr" +) + + +# %% Read embedding datasets +embedding_dataset_30_min = read_embedding_dataset(features_path_30_min) +embedding_dataset_no_track = read_embedding_dataset(feature_path_no_track) + + +# %% Compute displacements for both datasets (using Cosine similarity) +max_tau = 10 # Maximum time shift (tau) to compute displacements + + +# Compute displacements for Cell & Time Aware (30 min interval) using Cosine similarity +displacement_per_tau_aware_cosine = compute_displacement( + embedding_dataset_30_min, + max_tau, + use_cosine=True, + use_dissimilarity=False, + use_umap=False, +) + + +# Compute displacements for Classical Contrastive (No Tracking) using Cosine similarity +displacement_per_tau_contrastive_cosine = compute_displacement( + embedding_dataset_no_track, + max_tau, + use_cosine=True, + use_dissimilarity=False, + use_umap=False, +) + + +# %% Prepare data for violin plot +def prepare_violin_data(taus, displacement_aware, displacement_contrastive): + # Create a list to hold the data + data = [] + + # Populate the data for Cell & Time Aware + for tau in taus: + displacements_aware = displacement_aware.get(tau, []) + for displacement in displacements_aware: + data.append( + { + "Time Shift (τ)": tau, + "Displacement": displacement, + "Sampling": "Cell & Time Aware (30 min interval)", + } + ) + + # Populate the data for Classical Contrastive + for tau in taus: + displacements_contrastive = displacement_contrastive.get(tau, []) + for displacement in displacements_contrastive: + data.append( + { + "Time Shift (τ)": tau, + 
"Displacement": displacement, + "Sampling": "Classical Contrastive (No Tracking)", + } + ) + + # Convert to a DataFrame + df = pd.DataFrame(data) + return df + + +taus = list(displacement_per_tau_aware_cosine.keys()) + + +# Prepare the violin plot data +df = prepare_violin_data( + taus, displacement_per_tau_aware_cosine, displacement_per_tau_contrastive_cosine +) + + +# Create a violin plot using seaborn +plt.figure(figsize=(12, 8)) +sns.violinplot( + x="Time Shift (τ)", + y="Displacement", + hue="Sampling", + data=df, + palette="Set2", + scale="width", + bw=0.2, + inner=None, + split=True, + cut=0, +) + + +# Add labels and title +plt.xlabel("Time Shift (τ)", fontsize=14) +plt.ylabel("Cosine Similarity", fontsize=14) +plt.title("Cosine Similarity Distribution on Features", fontsize=16) +plt.grid(True, linestyle="--", alpha=0.6) # Lighter grid lines for less distraction +plt.legend(title="Sampling", fontsize=12, title_fontsize=14) + + +# plt.ylim(0.5, 1.0) + + +# Save the violin plot as SVG and PDF +plt.savefig("1fixed_violin_plot_cosine_similarity.svg", format="svg") +# plt.savefig('violin_plot_cosine_similarity.pdf', format='pdf') + + +# Show the plot +plt.show() +# %% using umap violin plot + +# %% Paths to datasets +features_path_30_min = Path( + "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/time_interval/predict/feb_test_time_interval_1_epoch_178.zarr" +) +feature_path_no_track = Path( + "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/negpair_random_sampling2/feb_fixed_test_predict.zarr" +) + +# %% Read embedding datasets +embedding_dataset_30_min = read_embedding_dataset(features_path_30_min) +embedding_dataset_no_track = read_embedding_dataset(feature_path_no_track) + + +# %% Compute UMAP on features +def compute_umap(dataset): + features = dataset["features"] + scaled_features = StandardScaler().fit_transform(features.values) + umap = UMAP(n_components=2) # Reduce to 2 dimensions + embedding = umap.fit_transform(scaled_features) + + # Add UMAP coordinates using xarray functionality + umap_features = features.assign_coords( + UMAP1=("sample", embedding[:, 0]), UMAP2=("sample", embedding[:, 1]) + ) + return umap_features + + +# Apply UMAP to both datasets +umap_features_30_min = compute_umap(embedding_dataset_30_min) +umap_features_no_track = compute_umap(embedding_dataset_no_track) + +# %% +print(umap_features_30_min) +# %% Visualize UMAP embeddings +# # Visualize UMAP embeddings for the 30 min interval +# plt.figure(figsize=(8, 6)) +# plt.scatter(umap_features_30_min[:, 0], umap_features_30_min[:, 1], c=embedding_dataset_30_min["t"].values, cmap='viridis') +# plt.colorbar(label='Timepoints') +# plt.title('UMAP Projection of Features (30 min Interval)') +# plt.xlabel('UMAP1') +# plt.ylabel('UMAP2') +# plt.show() + +# # Visualize UMAP embeddings for the No Tracking dataset +# plt.figure(figsize=(8, 6)) +# plt.scatter(umap_features_no_track[:, 0], umap_features_no_track[:, 1], c=embedding_dataset_no_track["t"].values, cmap='viridis') +# plt.colorbar(label='Timepoints') +# plt.title('UMAP Projection of Features (No Tracking)') +# plt.xlabel('UMAP1') +# plt.ylabel('UMAP2') +# plt.show() +# %% Compute displacements using UMAP coordinates (using Cosine similarity) +max_tau = 10 # Maximum time shift (tau) to compute displacements + +# Compute displacements for UMAP-processed Cell & Time Aware (30 min interval) +displacement_per_tau_aware_umap_cosine = 
compute_displacement( + umap_features_30_min, + max_tau, + use_cosine=True, + use_dissimilarity=False, + use_umap=True, +) + +# Compute displacements for UMAP-processed Classical Contrastive (No Tracking) +displacement_per_tau_contrastive_umap_cosine = compute_displacement( + umap_features_no_track, + max_tau, + use_cosine=True, + use_dissimilarity=False, + use_umap=True, +) + + +# %% Prepare data for violin plot +def prepare_violin_data(taus, displacement_aware, displacement_contrastive): + # Create a list to hold the data + data = [] + + # Populate the data for Cell & Time Aware + for tau in taus: + displacements_aware = displacement_aware.get(tau, []) + for displacement in displacements_aware: + data.append( + { + "Time Shift (τ)": tau, + "Displacement": displacement, + "Sampling": "Cell & Time Aware (30 min interval)", + } + ) + + # Populate the data for Classical Contrastive + for tau in taus: + displacements_contrastive = displacement_contrastive.get(tau, []) + for displacement in displacements_contrastive: + data.append( + { + "Time Shift (τ)": tau, + "Displacement": displacement, + "Sampling": "Classical Contrastive (No Tracking)", + } + ) + + # Convert to a DataFrame + df = pd.DataFrame(data) + return df + + +taus = list(displacement_per_tau_aware_umap_cosine.keys()) + +# Prepare the violin plot data +df = prepare_violin_data( + taus, + displacement_per_tau_aware_umap_cosine, + displacement_per_tau_contrastive_umap_cosine, +) + +# %% Create a violin plot using seaborn +plt.figure(figsize=(12, 8)) +sns.violinplot( + x="Time Shift (τ)", + y="Displacement", + hue="Sampling", + data=df, + palette="Set2", + scale="width", + bw=0.2, + inner=None, + split=True, + cut=0, +) + +# Add labels and title +plt.xlabel("Time Shift (τ)", fontsize=14) +plt.ylabel("Cosine Similarity", fontsize=14) +plt.title("Cosine Similarity Distribution using UMAP Features", fontsize=16) +plt.grid(True, linestyle="--", alpha=0.6) # Lighter grid lines for less distraction +plt.legend(title="Sampling", fontsize=12, title_fontsize=14) + +# plt.ylim(0, 1) + +# Save the violin plot as SVG and PDF +plt.savefig("fixed_plot_cosine_similarity.svg", format="svg") +# plt.savefig('violin_plot_cosine_similarity_umap.pdf', format='pdf') + +# Show the plot +plt.show() + + +# %% +# %% Visualize Displacement Distributions (Example Code) +# Compare displacement distributions for τ = 1 +# plt.figure(figsize=(10, 6)) +# sns.histplot(displacement_per_tau_aware_umap_cosine[1], kde=True, label='UMAP - 30 min Interval', color='blue') +# sns.histplot(displacement_per_tau_contrastive_umap_cosine[1], kde=True, label='UMAP - No Tracking', color='green') +# plt.legend() +# plt.title('Comparison of Displacement Distributions for τ = 1 (UMAP)') +# plt.xlabel('Displacement') +# plt.show() + +# # Compare displacement distributions for the full feature set (same τ = 1) +# plt.figure(figsize=(10, 6)) +# sns.histplot(displacement_per_tau_aware_cosine[1], kde=True, label='Full Features - 30 min Interval', color='red') +# sns.histplot(displacement_per_tau_contrastive_cosine[1], kde=True, label='Full Features - No Tracking', color='orange') +# plt.legend() +# plt.title('Comparison of Displacement Distributions for τ = 1 (Full Features)') +# plt.xlabel('Displacement') +# plt.show() +# # %% diff --git a/applications/contrastive_phenotyping/evaluation/displacement.py b/applications/contrastive_phenotyping/evaluation/displacement.py new file mode 100644 index 000000000..a0d46c28a --- /dev/null +++ 
b/applications/contrastive_phenotyping/evaluation/displacement.py
@@ -0,0 +1,118 @@
+# %%
+from pathlib import Path
+
+import matplotlib.pyplot as plt
+import numpy as np
+
+from viscy.representation.embedding_writer import read_embedding_dataset
+from viscy.representation.evaluation import (
+    calculate_normalized_euclidean_distance_cell,
+    compute_displacement_mean_std_full,
+)
+
+# %% paths
+
+features_path_30_min = Path(
+    "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/time_interval/predict/feb_test_time_interval_1_epoch_178.zarr"
+)
+
+feature_path_no_track = Path(
+    "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/negpair_random_sampling2/feb_fixed_test_predict.zarr"
+)
+
+features_path_any_time = Path(
+    "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/negpair_difcell_randomtime_sampling/Ver2_updateTracking_refineModel/predictions/Feb_2chan_128patch_32projDim/2chan_128patch_56ckpt_FebTest.zarr"
+)
+
+data_path = Path(
+    "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/registered_test.zarr"
+)
+
+tracks_path = Path(
+    "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/track_test.zarr"
+)
+
+# %% Load embedding datasets for all three sampling strategies
+fov_name = "/B/4/6"
+track_id = 52
+
+embedding_dataset_30_min = read_embedding_dataset(features_path_30_min)
+embedding_dataset_no_track = read_embedding_dataset(feature_path_no_track)
+embedding_dataset_any_time = read_embedding_dataset(features_path_any_time)
+
+# %%
+# Calculate the normalized Euclidean distance to the first timepoint for each
+# sampling strategy (the function returns distances, not cosine similarities)
+time_points_30_min, euclidean_distances_30_min = (
+    calculate_normalized_euclidean_distance_cell(
+        embedding_dataset_30_min, fov_name, track_id
+    )
+)
+time_points_no_track, euclidean_distances_no_track = (
+    calculate_normalized_euclidean_distance_cell(
+        embedding_dataset_no_track, fov_name, track_id
+    )
+)
+time_points_any_time, euclidean_distances_any_time = (
+    calculate_normalized_euclidean_distance_cell(
+        embedding_dataset_any_time, fov_name, track_id
+    )
+)
+
+# %% Plot displacement over time for all three conditions
+
+plt.figure(figsize=(10, 6))
+
+plt.plot(
+    time_points_no_track,
+    euclidean_distances_no_track,
+    marker="o",
+    label="classical contrastive (no tracking)",
+)
+plt.plot(
+    time_points_any_time,
+    euclidean_distances_any_time,
+    marker="o",
+    label="cell aware",
+)
+plt.plot(
+    time_points_30_min,
+    euclidean_distances_30_min,
+    marker="o",
+    label="cell & time aware (interval 30 min)",
+)
+
+plt.xlabel("Time Delay (t)", fontsize=10)
+plt.ylabel("Normalized Euclidean Distance with First Time Point", fontsize=10)
+plt.title(
+    "Normalized Euclidean Distance (Features) Over Time for Infected Cell", fontsize=12
+)
+
+plt.grid(True)
+plt.legend(fontsize=10)
+
+# plt.savefig('4_euc_dist_full.svg', format='svg')
+plt.show()
+
+
+# %% Paths to datasets
+features_path_30_min = Path(
+    "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/time_interval/predict/feb_test_time_interval_1_epoch_178.zarr"
+)
+feature_path_no_track = Path(
+    "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/negpair_random_sampling2/feb_fixed_test_predict.zarr"
+)
+
+embedding_dataset_30_min = read_embedding_dataset(features_path_30_min)
+embedding_dataset_no_track = read_embedding_dataset(feature_path_no_track)
+
+
+# %%
+max_tau = 10
+
+mean_displacement_30_min_euc, std_displacement_30_min_euc = (
+    compute_displacement_mean_std_full(embedding_dataset_30_min, max_tau)
+)
+mean_displacement_no_track_euc, std_displacement_no_track_euc = (
+    compute_displacement_mean_std_full(embedding_dataset_no_track, max_tau)
+)
+
+# %% Plot: Euclidean distance displacements
+plt.figure(figsize=(10, 6))
+
+taus = list(mean_displacement_30_min_euc.keys())
+
+mean_values_30_min_euc = list(mean_displacement_30_min_euc.values())
+std_values_30_min_euc = list(std_displacement_30_min_euc.values())
+
+plt.plot(
+    taus,
+    mean_values_30_min_euc,
+    marker="o",
+    label="Cell & Time Aware (30 min interval)",
+    color="green",
+)
+plt.fill_between(
+    taus,
+    np.array(mean_values_30_min_euc) - np.array(std_values_30_min_euc),
+    np.array(mean_values_30_min_euc) + np.array(std_values_30_min_euc),
+    color="green",
+    alpha=0.3,
+    label="Std Dev (30 min interval)",
+)
+
+mean_values_no_track_euc = list(mean_displacement_no_track_euc.values())
+std_values_no_track_euc = list(std_displacement_no_track_euc.values())
+
+plt.plot(
+    taus,
+    mean_values_no_track_euc,
+    marker="o",
+    label="Classical Contrastive (No Tracking)",
+    color="blue",
+)
+plt.fill_between(
+    taus,
+    np.array(mean_values_no_track_euc) - np.array(std_values_no_track_euc),
+    np.array(mean_values_no_track_euc) + np.array(std_values_no_track_euc),
+    color="blue",
+    alpha=0.3,
+    label="Std Dev (No Tracking)",
+)
+
+plt.xlabel("Time Shift (τ)")
+plt.ylabel("Euclidean Distance")
+plt.title("Embedding Displacement Over Time (Features)")
+
+plt.grid(True)
+plt.legend()
+
+plt.show()
diff --git a/applications/contrastive_phenotyping/evaluation/log_regresssion_training.py b/applications/contrastive_phenotyping/evaluation/log_regresssion_training.py
new file mode 100644
index 000000000..0bb7a4b32
--- /dev/null
+++ b/applications/contrastive_phenotyping/evaluation/log_regresssion_training.py
@@ -0,0 +1,111 @@
+# %%
+from pathlib import Path
+
+import pandas as pd
+
+from viscy.representation.embedding_writer import read_embedding_dataset
+from viscy.representation.evaluation import load_annotation
+
+
+# %% Paths and parameters.
+
+
+features_path = Path(
+    "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/time_interval/predict/feb_test_time_interval_1_epoch_178.zarr"
+)
+data_path = Path(
+    "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/registered_test.zarr"
+)
+tracks_path = Path(
+    "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/track_test.zarr"
+)
+
+
+# %%
+embedding_dataset = read_embedding_dataset(features_path)
+embedding_dataset
+
+# %%
+# Select the learned features to train on
+features = embedding_dataset["features"]
+# or select a well:
+# features = features[features["fov_name"].str.contains("B/4")]
+
+# %% OVERLAY INFECTION ANNOTATION
+ann_root = Path(
+    "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/supervised_inf_pred"
+)
+
+
+infection = load_annotation(
+    features,
+    ann_root / "extracted_inf_state.csv",
+    "infection_state",
+    {0.0: "background", 1.0: "uninfected", 2.0: "infected"},
+)
+
+# %% filter out the background class
+
+infection_npy = infection.cat.codes.values
+
+# Filter out the background class
+infection_npy_filtered = infection_npy[infection_npy != 0]
+
+feature_npy = features.values
+feature_npy_filtered = feature_npy[infection_npy != 0]
+
+# %% combine the features and infection annotation in one dataframe
+
+data = pd.DataFrame({"infection": infection_npy_filtered})
+
+# add time and well info into dataframe
+time_npy = features["t"].values
+time_npy_filtered = time_npy[infection_npy != 0]
+data["time"] = time_npy_filtered
+
+fov_name_list = features["fov_name"].values
+fov_name_list_filtered = fov_name_list[infection_npy != 0]
+data["fov_name"] = fov_name_list_filtered
+
+# Add all 768 features to the dataframe
+for i in range(768):
+    data[f"feature_{i+1}"] = feature_npy_filtered[:, i]
+
+# %% manually split the dataset into training and testing sets by well name
+
+# dataframe for the training set: fov names starting with "/B/4/6", "/B/4/7", or "/A/3/"
+data_train_val = data[
+    data["fov_name"].str.contains("/B/4/6")
+    | data["fov_name"].str.contains("/B/4/7")
+    | data["fov_name"].str.contains("/A/3/")
+]
+
+# dataframe for the testing set: fov names starting with "/B/4/8", "/B/4/9", or "/B/3/"
+data_test = data[
+    data["fov_name"].str.contains("/B/4/8")
+    | data["fov_name"].str.contains("/B/4/9")
+    | data["fov_name"].str.contains("/B/3/")
+]
+
+# %% train a logistic regression classifier to predict infection state from the learned features
+
+from sklearn.linear_model import LogisticRegression
+from sklearn.metrics import classification_report
+
+x_train = data_train_val.drop(columns=["infection", "fov_name", "time"])
+y_train = data_train_val["infection"]
+
+# train a logistic regression model
+clf = LogisticRegression(random_state=0).fit(x_train, y_train)
+
+x_test = data_test.drop(columns=["infection", "fov_name", "time"])
+y_test = data_test["infection"]
+
+# predict the infection state for the testing set
+y_pred = clf.predict(x_test)
+
+# report test-set performance
+print(classification_report(y_test, y_pred))
+
+# %%
diff --git a/applications/contrastive_phenotyping/evaluation/pca_umap_embeddings_time.py b/applications/contrastive_phenotyping/evaluation/pca_umap_embeddings_time.py
new file mode 100644
index 000000000..5f59da3e0
--- /dev/null
+++ b/applications/contrastive_phenotyping/evaluation/pca_umap_embeddings_time.py
@@ -0,0 +1,220 @@
+# %%
+from pathlib import Path
+
+import matplotlib.pyplot as plt
+import seaborn as sns
+from sklearn.decomposition import PCA +from sklearn.preprocessing import StandardScaler +from umap import UMAP + +from viscy.representation.embedding_writer import read_embedding_dataset +from viscy.representation.evaluation import load_annotation + +# %% Paths and parameters. + + +features_path = Path( + "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/time_interval/predict/feb_test_time_interval_1_epoch_178.zarr" +) +data_path = Path( + "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/registered_test.zarr" +) +tracks_path = Path( + "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/track_test.zarr" +) + + +# %% +embedding_dataset = read_embedding_dataset(features_path) +embedding_dataset + + +# %% +# Compute UMAP over all features +features = embedding_dataset["features"] +# or select a well: +# features = features[features["fov_name"].str.contains("B/4")] + + +scaled_features = StandardScaler().fit_transform(features.values) +umap = UMAP() +# Fit UMAP on all features +embedding = umap.fit_transform(scaled_features) + + +# %% +# Add UMAP coordinates to the dataset and plot w/ time + + +features = ( + features.assign_coords(UMAP1=("sample", embedding[:, 0])) + .assign_coords(UMAP2=("sample", embedding[:, 1])) + .set_index(sample=["UMAP1", "UMAP2"], append=True) +) +features + + +sns.scatterplot( + x=features["UMAP1"], y=features["UMAP2"], hue=features["t"], s=7, alpha=0.8 +) + + +# Add the title to the plot +plt.title("Cell & Time Aware Sampling (30 min interval)") +plt.xlim(-10, 20) +plt.ylim(-10, 20) +# plt.savefig('umap_cell_time_aware_time.svg', format='svg') +plt.savefig("updated_cell_time_aware_time.png", format="png") +# Show the plot +plt.show() + + +# %% + + +any_features_path = Path( + "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/negpair_difcell_randomtime_sampling/Ver2_updateTracking_refineModel/predictions/Feb_2chan_128patch_32projDim/2chan_128patch_56ckpt_FebTest.zarr" +) +embedding_dataset = read_embedding_dataset(any_features_path) +embedding_dataset + + +# %% +# Compute UMAP over all features +features = embedding_dataset["features"] +# or select a well: +# features = features[features["fov_name"].str.contains("B/4")] + + +scaled_features = StandardScaler().fit_transform(features.values) +umap = UMAP() +# Fit UMAP on all features +embedding = umap.fit_transform(scaled_features) + + +# %% Any time sampling plot + + +features = ( + features.assign_coords(UMAP1=("sample", embedding[:, 0])) + .assign_coords(UMAP2=("sample", embedding[:, 1])) + .set_index(sample=["UMAP1", "UMAP2"], append=True) +) +features + + +sns.scatterplot( + x=features["UMAP1"], y=features["UMAP2"], hue=features["t"], s=7, alpha=0.8 +) + + +# Add the title to the plot +plt.title("Cell Aware Sampling") + +plt.xlim(-10, 20) +plt.ylim(-10, 20) + +plt.savefig("1_updated_cell_aware_time.png", format="png") +# plt.savefig('umap_cell_aware_time.pdf', format='pdf') +# Show the plot +plt.show() + + +# %% + + +contrastive_learning_path = Path( + "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/negpair_random_sampling2/feb_fixed_test_predict.zarr" +) +embedding_dataset = read_embedding_dataset(contrastive_learning_path) +embedding_dataset + + +# %% +# Compute UMAP over all features +features = embedding_dataset["features"] 
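+# `features` is an xarray DataArray of shape (sample, features); the sample
+# dimension carries the fov_name, track_id, and t coordinates written by the
+# embedding writer, which the scatterplots below use for coloring.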
+# or select a well: +# features = features[features["fov_name"].str.contains("B/4")] + + +scaled_features = StandardScaler().fit_transform(features.values) +umap = UMAP() +# Fit UMAP on all features +embedding = umap.fit_transform(scaled_features) + + +# %% Any time sampling plot + + +features = ( + features.assign_coords(UMAP1=("sample", embedding[:, 0])) + .assign_coords(UMAP2=("sample", embedding[:, 1])) + .set_index(sample=["UMAP1", "UMAP2"], append=True) +) +features + +sns.scatterplot( + x=features["UMAP1"], y=features["UMAP2"], hue=features["t"], s=7, alpha=0.8 +) + +# Add the title to the plot +plt.title("Classical Contrastive Learning Sampling") +plt.xlim(-10, 20) +plt.ylim(-10, 20) +plt.savefig("updated_classical_time.png", format="png") +# plt.savefig('classical_time.pdf', format='pdf') + +# Show the plot +plt.show() + + +# %% PCA + + +pca = PCA(n_components=4) +# scaled_features = StandardScaler().fit_transform(features.values) +# pca_features = pca.fit_transform(scaled_features) +pca_features = pca.fit_transform(features.values) + + +features = ( + features.assign_coords(PCA1=("sample", pca_features[:, 0])) + .assign_coords(PCA2=("sample", pca_features[:, 1])) + .assign_coords(PCA3=("sample", pca_features[:, 2])) + .assign_coords(PCA4=("sample", pca_features[:, 3])) + .set_index(sample=["PCA1", "PCA2", "PCA3", "PCA4"], append=True) +) + + +# %% plot PCA components w/ time + + +plt.figure(figsize=(10, 10)) +sns.scatterplot( + x=features["PCA1"], y=features["PCA2"], hue=features["t"], s=7, alpha=0.8 +) + + +# %% OVERLAY INFECTION ANNOTATION +ann_root = Path( + "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/supervised_inf_pred" +) + + +infection = load_annotation( + features, + ann_root / "extracted_inf_state.csv", + "infection_state", + {0.0: "background", 1.0: "uninfected", 2.0: "infected"}, +) + + +# %% +sns.scatterplot(x=features["UMAP1"], y=features["UMAP2"], hue=infection, s=7, alpha=0.8) + + +# %% plot PCA components with infection hue +sns.scatterplot(x=features["PCA1"], y=features["PCA2"], hue=infection, s=7, alpha=0.8) + + +# %% diff --git a/applications/contrastive_phenotyping/figures/figure_cell_infection.py b/applications/contrastive_phenotyping/figures/figure_cell_infection.py new file mode 100644 index 000000000..b14ae3aeb --- /dev/null +++ b/applications/contrastive_phenotyping/figures/figure_cell_infection.py @@ -0,0 +1,652 @@ +# %% +from pathlib import Path +import sys + +sys.path.append("/hpc/mydata/soorya.pradeep/scratch/viscy_infection_phenotyping/VisCy") + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import plotly.express as px +import seaborn as sns +from sklearn.decomposition import PCA +from sklearn.preprocessing import StandardScaler +from umap import UMAP +from sklearn.decomposition import PCA + + +from viscy.representation.embedding_writer import read_embedding_dataset +from viscy.representation.evaluation import load_annotation + + +# %% Paths and parameters. 
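+# All inputs below live on the group's HPC store, and the figure panels are
+# written to the Figure_panels directory further down; both will need to be
+# adjusted when running outside that environment.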
+ + +features_path = Path( + "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/time_interval/predict/feb_test_time_interval_1_epoch_178.zarr" +) +data_path = Path( + "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/registered_test.zarr" +) +tracks_path = Path( + "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/track_test.zarr" +) + + +# %% +embedding_dataset = read_embedding_dataset(features_path) +embedding_dataset + +# %% +# Compute UMAP over all features +features = embedding_dataset["features"] +# or select a well: +# features = features[features["fov_name"].str.contains("B/4")] + + +scaled_features = StandardScaler().fit_transform(features.values) +umap = UMAP() +# Fit UMAP on all features +embedding = umap.fit_transform(scaled_features) + +features = ( + features.assign_coords(UMAP1=("sample", embedding[:, 0])) + .assign_coords(UMAP2=("sample", embedding[:, 1])) + .set_index(sample=["UMAP1", "UMAP2"], append=True) +) +features + +pca = PCA(n_components=4) +# scaled_features = StandardScaler().fit_transform(features.values) +# pca_features = pca.fit_transform(scaled_features) +pca_features = pca.fit_transform(features.values) + + +features = ( + features.assign_coords(PCA1=("sample", pca_features[:, 0])) + .assign_coords(PCA2=("sample", pca_features[:, 1])) + .assign_coords(PCA3=("sample", pca_features[:, 2])) + .assign_coords(PCA4=("sample", pca_features[:, 3])) + .set_index(sample=["PCA1", "PCA2", "PCA3", "PCA4"], append=True) +) + +# %% OVERLAY INFECTION ANNOTATION +ann_root = Path( + "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/supervised_inf_pred" +) + +infection = load_annotation( + features, + ann_root / "extracted_inf_state.csv", + "infection_state", + {0.0: "background", 1.0: "uninfected", 2.0: "infected"}, +) + +# %% plot the umap + +# remove the rows in umap and annotation for background class +# Convert UMAP coordinates to a DataFrame +umap_npy = embedding.copy() +infection_npy = infection.cat.codes.values + +# Filter out the background class +umap_npy_filtered = umap_npy[infection_npy != 0] +infection_npy_filtered = infection_npy[infection_npy != 0] + +feature_npy = features.values +feature_npy_filtered = feature_npy[infection_npy != 0] + +sns.scatterplot( + x=umap_npy_filtered[:, 0], + y=umap_npy_filtered[:, 1], + hue=infection_npy_filtered, + palette={1: "steelblue", 2: "orangered"}, + hue_order=[1, 2], + s=7, + alpha=0.8, +) +plt.legend([], [], frameon=False) +plt.savefig( + "/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/Figure_panels/infection/umap_infection.png", + format="png", + dpi=300, +) + +# %% plot PCA components with infection hue + +pca_npy = pca_features.copy() +pca_npy_filtered = pca_npy[infection_npy != 0] + +sns.scatterplot( + x=pca_npy_filtered[:, 0], + y=pca_npy_filtered[:, 1], + hue=infection_npy_filtered, + palette={1: "steelblue", 2: "orangered"}, + hue_order=[1, 2], + s=7, + alpha=0.8, +) +plt.legend([], [], frameon=False) +plt.savefig( + "/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/Figure_panels/infection/pca_infection.png", + format="png", + dpi=300, +) + +# %% combine the umap, pca and infection annotation in one dataframe + +data = pd.DataFrame( + { + "UMAP1": umap_npy_filtered[:, 0], + "UMAP2": umap_npy_filtered[:, 1], + 
"PCA1": pca_npy_filtered[:, 0], + "PCA2": pca_npy_filtered[:, 1], + "PCA3": pca_npy_filtered[:, 2], + "PCA4": pca_npy_filtered[:, 3], + "infection": infection_npy_filtered, + } +) + +# add time and well info into dataframe +time_npy = features["t"].values +time_npy_filtered = time_npy[infection_npy != 0] +data["time"] = time_npy_filtered + +fov_name_list = features["fov_name"].values +fov_name_list_filtered = fov_name_list[infection_npy != 0] +data["fov_name"] = fov_name_list_filtered + +# Add all 768 features to the dataframe +for i in range(768): + data[f"feature_{i+1}"] = feature_npy_filtered[:, i] + +# %% manually split the dataset into training and testing set by well name + +# dataframe for training set, fov names starts with "/B/4/6" or "/B/4/7" or "/A/3/" +data_train_val = data[ + data["fov_name"].str.contains("/B/4/6") + | data["fov_name"].str.contains("/B/4/7") + | data["fov_name"].str.contains("/A/3/") +] + +# dataframe for testing set, fov names starts with "/B/4/8" or "/B/4/9" or "/A/4/" +data_test = data[ + data["fov_name"].str.contains("/B/4/8") + | data["fov_name"].str.contains("/B/4/9") + | data["fov_name"].str.contains("/B/3/") +] + +# %% train a linear classifier to predict infection state from PCA components + +from sklearn.linear_model import LogisticRegression +from sklearn.model_selection import train_test_split +from sklearn.metrics import classification_report + +x_train = data_train_val.drop( + columns=[ + "infection", + "fov_name", + "time", + "UMAP1", + "UMAP2", + "PCA1", + "PCA2", + "PCA3", + "PCA4", + ] +) +y_train = data_train_val["infection"] + +# train a logistic regression model +clf = LogisticRegression(random_state=0).fit(x_train, y_train) + +x_test = data_test.drop( + columns=[ + "infection", + "fov_name", + "time", + "UMAP1", + "UMAP2", + "PCA1", + "PCA2", + "PCA3", + "PCA4", + ] +) +y_test = data_test["infection"] + +# predict the infection state for the testing set +y_pred = clf.predict(x_test) + +# %% construct confusion matrix to compare the true and predicted infection state + +from sklearn.metrics import confusion_matrix +import seaborn as sns + +cm = confusion_matrix(y_test, y_pred) +cm_percentage = cm.astype("float") / cm.sum(axis=1)[:, np.newaxis] * 100 +sns.heatmap(cm_percentage, annot=True, fmt=".2f", cmap="viridis") +plt.xlabel("Predicted") +plt.ylabel("True") +plt.title("Confusion Matrix (Percentage)") +plt.xticks(ticks=[0.5, 1.5], labels=["uninfected", "infected"]) +plt.yticks(ticks=[0.5, 1.5], labels=["uninfected", "infected"]) +plt.savefig( + "/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/Figure_panels/infection/confusion_matrix.svg", + format="svg", +) + +# %% use the trained classifier to perform prediction on the entire dataset + +data_test["predicted_infection"] = y_pred + +# plot the predicted infection state over time for /B/3 well and /B/4 well +time_points_test = np.unique(data_test["time"]) + +infected_test_cntrl = [] +infected_test_infected = [] + +for time in time_points_test: + infected_cell = data_test[ + (data_test["fov_name"].str.startswith("/B/3")) + & (data_test["time"] == time) + & (data_test["predicted_infection"] == 2) + ].shape[0] + total_cell = data_test[ + (data_test["fov_name"].str.startswith("/B/3")) & (data_test["time"] == time) + ].shape[0] + infected_test_cntrl.append(infected_cell * 100 / total_cell) + infected_cell = data_test[ + (data_test["fov_name"].str.startswith("/B/4")) + & (data_test["time"] == time) + & (data_test["predicted_infection"] == 2) + ].shape[0] 
+ total_cell = data_test[ + (data_test["fov_name"].str.startswith("/B/4")) & (data_test["time"] == time) + ].shape[0] + infected_test_infected.append(infected_cell * 100 / total_cell) + + +infected_true_cntrl = [] +infected_true_infected = [] + +for time in time_points_test: + infected_cell = data_test[ + (data_test["fov_name"].str.startswith("/B/3")) + & (data_test["time"] == time) + & (data_test["infection"] == 2) + ].shape[0] + total_cell = data_test[ + (data_test["fov_name"].str.startswith("/B/3")) & (data_test["time"] == time) + ].shape[0] + infected_true_cntrl.append(infected_cell * 100 / total_cell) + infected_cell = data_test[ + (data_test["fov_name"].str.startswith("/B/4")) + & (data_test["time"] == time) + & (data_test["infection"] == 2) + ].shape[0] + total_cell = data_test[ + (data_test["fov_name"].str.startswith("/B/4")) & (data_test["time"] == time) + ].shape[0] + infected_true_infected.append(infected_cell * 100 / total_cell) + + +# %% perform prediction on the june dataset + +# Paths and parameters. +features_path = Path( + "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/time_interval/predict/jun_time_interval_1_epoch_178.zarr" +) +data_path = Path( + "/hpc/projects/intracellular_dashboard/viral-sensor/2024_06_13_SEC61_TOMM20_ZIKV_DENGUE_1/2-register/registered_chunked.zarr" +) +tracks_path = Path( + "/hpc/projects/intracellular_dashboard/viral-sensor/2024_06_13_SEC61_TOMM20_ZIKV_DENGUE_1/4.2-tracking/track.zarr" +) + +# %% +embedding_dataset = read_embedding_dataset(features_path) +embedding_dataset + +# %% +june_features = embedding_dataset["features"] + +scaled_features = StandardScaler().fit_transform(june_features.values) +umap = UMAP() +# Fit UMAP on all features +embedding = umap.fit_transform(scaled_features) + +june_features = ( + june_features.assign_coords(UMAP1=("sample", embedding[:, 0])) + .assign_coords(UMAP2=("sample", embedding[:, 1])) + .set_index(sample=["UMAP1", "UMAP2"], append=True) +) +june_features + +pca = PCA(n_components=4) +pca_features = pca.fit_transform(june_features.values) + +# %% + +# sns.scatterplot( +# x=june_features["UMAP1"], +# y=june_features["UMAP2"], +# hue=june_pred, +# palette={1: 'blue', 2: 'red'}, +# hue_order=[1, 2], +# s=7, +# alpha=0.8, +# ) +# plt.legend([], [], frameon=False) +# plt.xlim(0, 15) +# plt.savefig('/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/Figure_panels/infection/june_umap_infection.png', format='png', dpi=300) + +# %% plot June and Feb test combined UMAP + +june_umap_npy = embedding.copy() +june_pca_npy = pca_features.copy() +june_data = pd.DataFrame( + { + "UMAP1": june_umap_npy[:, 0], + "UMAP2": june_umap_npy[:, 1], + "PCA1": june_pca_npy[:, 0], + "PCA2": june_pca_npy[:, 1], + "PCA3": june_pca_npy[:, 2], + "PCA4": june_pca_npy[:, 3], + "infection": np.nan, + } +) + +# add time and well info into dataframe +june_data["time"] = june_features["t"].values + +june_data["fov_name"] = june_features["fov_name"].values + +# Add all 768 features to the dataframe +june_features_npy = june_features.values +for i in range(768): + june_data[f"feature_{i+1}"] = june_features_npy[:, i] + +# use one mock and one dengue infecected well only +june_data = june_data[ + june_data["fov_name"].str.contains("/0/6") + | june_data["fov_name"].str.contains("/0/2") +] + +# add the predicted infection state +june_pred = clf.predict( + june_data.drop( + columns=[ + "infection", + "fov_name", + "time", + "UMAP1", + "UMAP2", + 
"PCA1", + "PCA2", + "PCA3", + "PCA4", + ] + ) +) +june_data["predicted_infection"] = june_pred + +# %% combine the june and feb data + +combined_data = pd.concat([data_test, june_data]) + +# perform the umap analysis again with the 768 features +features = combined_data.drop( + columns=[ + "infection", + "predicted_infection", + "fov_name", + "time", + "UMAP1", + "UMAP2", + "PCA1", + "PCA2", + "PCA3", + "PCA4", + ] +) +scaled_features = StandardScaler().fit_transform(features.values) +umap = UMAP() +# Fit UMAP on all features +embedding = umap.fit_transform(scaled_features) + +# overwrite the umap coordinates on combined data +combined_data["UMAP1"] = embedding[:, 0] +combined_data["UMAP2"] = embedding[:, 1] + +# plot the combined data with 'fov_name' starting with '/A and '/B' hue 'infection' and '/0' hue 'predicted_infection' +Feb_split = combined_data[ + combined_data["fov_name"].str.contains("/A") + | combined_data["fov_name"].str.contains("/B") +] +June_split = combined_data[combined_data["fov_name"].str.contains("/0")] + +sns.scatterplot( + x=June_split["UMAP1"], + y=June_split["UMAP2"], + hue=June_split["predicted_infection"], + palette={1: "blue", 2: "red"}, + hue_order=[1, 2], + s=7, + alpha=0.8, +) +sns.scatterplot( + x=Feb_split["UMAP1"], + y=Feb_split["UMAP2"], + hue=Feb_split["infection"], + palette={1: "steelblue", 2: "orange"}, + hue_order=[1, 2], + s=7, + alpha=0.8, +) +plt.legend([], [], frameon=False) +# plt.savefig('/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/Figure_panels/infection/combined_umap_infection.png', format='png', dpi=300) + +# plot the scatterplot hue well name '/A' and '/B' are blue and '/0' are red +combined_data["color"] = combined_data["fov_name"].apply( + lambda x: "brown" if x.startswith("/0") else "green" +) + +sns.scatterplot( + x=combined_data["UMAP1"], + y=combined_data["UMAP2"], + hue="color", + palette={"green": "green", "brown": "brown"}, + data=combined_data, + s=7, + alpha=0.2, # Increased transparency +) +plt.xlim(-5, 5) +plt.ylim(-2, 20) +plt.legend([], [], frameon=False) +plt.savefig( + "/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/Figure_panels/infection/combined_umap_well.png", + format="png", + dpi=300, +) + +# plot the predicted infection state with combined data +sns.scatterplot( + x=combined_data["UMAP1"], + y=combined_data["UMAP2"], + hue=combined_data["predicted_infection"], + palette={1: "blue", 2: "red"}, + hue_order=[1, 2], + s=7, + alpha=0.8, +) +plt.xlim(-5, 5) +plt.ylim(-2, 20) +plt.legend([], [], frameon=False) +plt.savefig( + "/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/Figure_panels/infection/combined_umap_predicted_infection.png", + format="png", + dpi=300, +) + +# %% plot % infected over time + +time_points_june = np.unique(June_split["time"]) + +infected_june_cntrl = [] +infected_june_infected = [] + +for time in time_points_june: + infected_june = June_split[ + (June_split["fov_name"].str.startswith("/0/2")) + & (June_split["time"] == time) + & (June_split["predicted_infection"] == 2) + ].shape[0] + total_june = June_split[ + (June_split["fov_name"].str.startswith("/0/2")) & (June_split["time"] == time) + ].shape[0] + infected_june_cntrl.append(infected_june * 100 / total_june) + infected_june = June_split[ + (June_split["fov_name"].str.startswith("/0/6")) + & (June_split["time"] == time) + & (June_split["predicted_infection"] == 2) + ].shape[0] + total_june = June_split[ + 
(June_split["fov_name"].str.startswith("/0/6")) & (June_split["time"] == time) + ].shape[0] + infected_june_infected.append(infected_june * 100 / total_june) + + +# plot infected percentage over time for both wells +plt.plot( + time_points_test * 0.5 + 3, + infected_true_cntrl, + label="mock true", + color="steelblue", + linestyle="--", +) +plt.plot( + time_points_test * 0.5 + 3, + infected_test_cntrl, + label="mock predicted", + color="blue", + marker="+", +) +plt.plot( + time_points_test * 0.5 + 3, + infected_true_infected, + label="MOI true", + color="orange", + linestyle="--", +) +plt.plot( + time_points_test * 0.5 + 3, + infected_test_infected, + label="MOI predicted", + color="red", + marker="+", +) +plt.plot( + time_points_june * 2 + 3, + infected_june_cntrl, + label="mock new predicted", + color="blue", + marker="o", +) +plt.plot( + time_points_june * 2 + 3, + infected_june_infected, + label="MOI new predicted", + color="red", + marker="o", +) +plt.xlabel("HPI") +plt.ylabel("Infected percentage") +plt.legend() +plt.savefig( + "/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/Figure_panels/infection/infected_percentage_withJune.svg", + format="svg", +) + +# %% appendix video for infection dynamics umap, Feb test data, colored by human revised annotation + +for time in range(48): + plt.clf() + sns.scatterplot( + data=data_test[(data_test["time"] == time)], + x="UMAP1", + y="UMAP2", + hue="infection", + palette={1: "steelblue", 2: "orangered"}, + hue_order=[1, 2], + s=20, + alpha=0.8, + ) + handles, _ = plt.gca().get_legend_handles_labels() + plt.legend(handles=handles, labels=["uninfected", "infected"]) + plt.suptitle(f"Time: {time*0.5+3} HPI") + plt.ylim(-10, 20) + plt.xlim(2, 18) + plt.savefig( + f"/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/Figure_panels/infection/video_umap/umap_feb_true_infection_" + + str(time).zfill(3) + + ".png", + format="png", + dpi=300, + ) + +# %% appendix video for infection dynamics umap, Feb test data, colored by predicted infection + +for time in range(48): + plt.clf() + sns.scatterplot( + data=data_test[(data_test["time"] == time)], + x="UMAP1", + y="UMAP2", + hue="predicted_infection", + palette={1: "blue", 2: "red"}, + hue_order=[1, 2], + s=20, + alpha=0.8, + ) + handles, _ = plt.gca().get_legend_handles_labels() + plt.legend(handles=handles, labels=["uninfected", "infected"]) + plt.suptitle(f"Time: {time*0.5+3} HPI") + plt.ylim(-10, 18) + plt.xlim(2, 18) + plt.savefig( + f"/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/Figure_panels/infection/video_umap/umap_feb_predicted_infection_" + + str(time).zfill(3) + + ".png", + format="png", + dpi=300, + ) + +# %% appendix video for infection dynamics umap, June data, colored by predicted infection + +for time in range(12): + plt.clf() + sns.scatterplot( + data=June_split[(June_split["time"] == time)], + x="UMAP1", + y="UMAP2", + hue="predicted_infection", + palette={1: "blue", 2: "red"}, + hue_order=[1, 2], + s=20, + alpha=0.8, + ) + handles, _ = plt.gca().get_legend_handles_labels() + plt.legend(handles=handles, labels=["uninfected", "infected"]) + plt.suptitle(f"Time: {time*2+3} HPI") + plt.ylim(-8, 10) + plt.xlim(-5, 5) + plt.savefig( + f"/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/Figure_panels/infection/video_umap/umap_june_predicted_infection_" + + str(time).zfill(3) + + ".png", + format="png", + dpi=300, + ) + +# %% diff 
--git a/applications/infection_classification/Infection_classifier_accuracy.py b/applications/infection_classification/Infection_classifier_accuracy.py
new file mode 100644
index 000000000..97958b018
--- /dev/null
+++ b/applications/infection_classification/Infection_classifier_accuracy.py
@@ -0,0 +1,71 @@
+# %% script to compare the supervised model's output with human-revised annotations and compute the model's accuracy
+
+import numpy as np
+from iohub import open_ome_zarr
+from scipy.ndimage import label
+
+# %% datapaths
+
+# Path to model output
+data_out_path = "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/supervised_inf_pred/supervised_test.zarr"
+
+# Path to the human-revised annotations
+human_corrected_path = "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/supervised_inf_pred/supervised_test_corrected.zarr"
+
+# %% Load data and compute the number of objects in each class
+
+data_out = open_ome_zarr(data_out_path, layout="hcs", mode="r")
+human_corrected = open_ome_zarr(human_corrected_path, layout="hcs", mode="r")
+
+out_medians = []
+HC_medians = []
+for well_id, well_data in data_out.wells():
+    well_name, well_no = well_id.split("/")
+
+    for pos_name, pos_data in well_data.positions():
+
+        out_data = pos_data.data.numpy()
+        T, C, Z, Y, X = out_data.shape
+
+        HC_data = human_corrected[well_id + "/" + pos_name + "/0"]
+        HC_data = HC_data.numpy()
+
+        # Iterate over timepoints
+        for t in range(T):
+            out_img = out_data[t, 0, 0]
+
+            # Label connected objects in the model output
+            out_labeled, num_out_objects = label(out_img > 0)
+
+            # Compute the median pixel value (class label) of each object in the model output
+            for obj_id in range(1, num_out_objects + 1):
+                obj_pixels = out_img[out_labeled == obj_id]
+                out_medians.append(np.median(obj_pixels))
+
+            # repeat for the human-corrected annotations
+            HC_img = HC_data[t, 0, 0]
+            HC_labeled, num_HC_objects = label(HC_img > 0)
+
+            for obj_id in range(1, num_HC_objects + 1):
+                obj_pixels = HC_img[HC_labeled == obj_id]
+                HC_medians.append(np.median(obj_pixels))
+
+# %% Compute the accuracy
+
+num_twos_in_out_medians = out_medians.count(2)
+num_twos_in_HC_medians = HC_medians.count(2)
+error_inf = (
+    (num_twos_in_HC_medians - num_twos_in_out_medians) / num_twos_in_HC_medians
+) * 100
+
+num_ones_in_out_medians = out_medians.count(1)
+num_ones_in_HC_medians = HC_medians.count(1)
+error_uninf = (
+    (num_ones_in_HC_medians - num_ones_in_out_medians) / num_ones_in_HC_medians
+) * 100
+
+avg_error = (np.abs(error_inf) + np.abs(error_uninf)) / 2
+
+accuracy = 100 - avg_error
+
+# %%
diff --git a/viscy/representation/evaluation.py b/viscy/representation/evaluation.py
index f459e8b46..343519d75 100644
--- a/viscy/representation/evaluation.py
+++ b/viscy/representation/evaluation.py
@@ -1,3 +1,5 @@
+from collections import defaultdict
+
 import numpy as np
 import pandas as pd
 import umap
@@ -12,6 +14,7 @@
     normalized_mutual_info_score,
     silhouette_score,
 )
+from sklearn.metrics.pairwise import cosine_similarity
 from sklearn.neighbors import KNeighborsClassifier
 from sklearn.preprocessing import StandardScaler

@@ -439,3 +442,248 @@ def compute_radial_intensity_gradient(image):
     )
     return radial_intensity_gradient[0]
+
+
+def calculate_cosine_similarity_cell(embedding_dataset, fov_name, track_id):
+    """Extract embeddings and calculate cosine similarities for a specific cell."""
+    # Filter the dataset for the
specific infected cell + filtered_data = embedding_dataset.where( + (embedding_dataset["fov_name"] == fov_name) + & (embedding_dataset["track_id"] == track_id), + drop=True, + ) + + # Extract the feature embeddings and time points + features = filtered_data["features"].values # (sample, features) + time_points = filtered_data["t"].values # (sample,) + + # Get the first time point's embedding + first_time_point_embedding = features[0].reshape(1, -1) + + # Calculate cosine similarity between each time point and the first time point + cosine_similarities = [] + for i in range(len(time_points)): + similarity = cosine_similarity( + first_time_point_embedding, features[i].reshape(1, -1) + ) + cosine_similarities.append(similarity[0][0]) + + return time_points, cosine_similarities + + +def compute_displacement_mean_std( + embedding_dataset, max_tau=10, use_cosine=False, use_dissimilarity=False +): + """Compute the norm of differences between embeddings at t and t + tau""" + # Get the arrays of (fov_name, track_id, t, and embeddings) + fov_names = embedding_dataset["fov_name"].values + track_ids = embedding_dataset["track_id"].values + timepoints = embedding_dataset["t"].values + embeddings = embedding_dataset["features"].values + + # Dictionary to store displacements for each tau + displacement_per_tau = defaultdict(list) + + # Iterate over all entries in the dataset + for i in range(len(fov_names)): + fov_name = fov_names[i] + track_id = track_ids[i] + current_time = timepoints[i] + current_embedding = embeddings[i] + + # For each time point t, compute displacements for t + tau + for tau in range(1, max_tau + 1): + future_time = current_time + tau + + # Find if future_time exists for the same (fov_name, track_id) + matching_indices = np.where( + (fov_names == fov_name) + & (track_ids == track_id) + & (timepoints == future_time) + )[0] + + if len(matching_indices) == 1: + # Get the embedding at t + tau + future_embedding = embeddings[matching_indices[0]] + + if use_cosine: + # Compute cosine similarity + similarity = cosine_similarity( + current_embedding.reshape(1, -1), + future_embedding.reshape(1, -1), + )[0][0] + # Choose whether to use similarity or dissimilarity + if use_dissimilarity: + displacement = 1 - similarity # Cosine dissimilarity + else: + displacement = similarity # Cosine similarity + else: + # Compute the Euclidean distance, elementwise square on difference + displacement = np.sum((current_embedding - future_embedding) ** 2) + + # Store the displacement for the given tau + displacement_per_tau[tau].append(displacement) + + # Compute mean and std displacement for each tau by averaging the displacements + mean_displacement_per_tau = { + tau: np.mean(displacements) + for tau, displacements in displacement_per_tau.items() + } + std_displacement_per_tau = { + tau: np.std(displacements) + for tau, displacements in displacement_per_tau.items() + } + + return mean_displacement_per_tau, std_displacement_per_tau + + +def compute_displacement( + embedding_dataset, + max_tau=10, + use_cosine=False, + use_dissimilarity=False, + use_umap=False, +): + """Compute the norm of differences between embeddings at t and t + tau""" + # Get the arrays of (fov_name, track_id, t, and embeddings) + fov_names = embedding_dataset["fov_name"].values + track_ids = embedding_dataset["track_id"].values + timepoints = embedding_dataset["t"].values + + if use_umap: + umap1 = embedding_dataset["UMAP1"].values + umap2 = embedding_dataset["UMAP2"].values + embeddings = np.vstack((umap1, umap2)).T + else: + 
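+        # otherwise use the full learned feature vectors (one row per sample)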
embeddings = embedding_dataset["features"].values + + # Dictionary to store displacements for each tau + displacement_per_tau = defaultdict(list) + + # Iterate over all entries in the dataset + for i in range(len(fov_names)): + fov_name = fov_names[i] + track_id = track_ids[i] + current_time = timepoints[i] + current_embedding = embeddings[i] + + # For each time point t, compute displacements for t + tau + for tau in range(1, max_tau + 1): + future_time = current_time + tau + + # Find if future_time exists for the same (fov_name, track_id) + matching_indices = np.where( + (fov_names == fov_name) + & (track_ids == track_id) + & (timepoints == future_time) + )[0] + + if len(matching_indices) == 1: + # Get the embedding at t + tau + future_embedding = embeddings[matching_indices[0]] + + if use_cosine: + # Compute cosine similarity + similarity = cosine_similarity( + current_embedding.reshape(1, -1), + future_embedding.reshape(1, -1), + )[0][0] + # Choose whether to use similarity or dissimilarity + if use_dissimilarity: + displacement = 1 - similarity # Cosine dissimilarity + else: + displacement = similarity # Cosine similarity + else: + # Compute the Euclidean distance, elementwise square on difference + displacement = np.sum((current_embedding - future_embedding) ** 2) + + # Store the displacement for the given tau + displacement_per_tau[tau].append(displacement) + + return displacement_per_tau + + +def calculate_normalized_euclidean_distance_cell(embedding_dataset, fov_name, track_id): + filtered_data = embedding_dataset.where( + (embedding_dataset["fov_name"] == fov_name) + & (embedding_dataset["track_id"] == track_id), + drop=True, + ) + + features = filtered_data["features"].values # (sample, features) + time_points = filtered_data["t"].values # (sample,) + + normalized_features = features / np.linalg.norm(features, axis=1, keepdims=True) + + # Get the first time point's normalized embedding + first_time_point_embedding = normalized_features[0].reshape(1, -1) + + euclidean_distances = [] + for i in range(len(time_points)): + distance = np.linalg.norm( + first_time_point_embedding - normalized_features[i].reshape(1, -1) + ) + euclidean_distances.append(distance) + + return time_points, euclidean_distances + + +def compute_displacement_mean_std_full(embedding_dataset, max_tau=10): + fov_names = embedding_dataset["fov_name"].values + track_ids = embedding_dataset["track_id"].values + timepoints = embedding_dataset["t"].values + embeddings = embedding_dataset["features"].values + + cell_identifiers = np.array( + list(zip(fov_names, track_ids)), + dtype=[("fov_name", "O"), ("track_id", "int64")], + ) + + unique_cells = np.unique(cell_identifiers) + + displacement_per_tau = defaultdict(list) + + for cell in unique_cells: + fov_name = cell["fov_name"] + track_id = cell["track_id"] + + indices = np.where((fov_names == fov_name) & (track_ids == track_id))[0] + + cell_timepoints = timepoints[indices] + cell_embeddings = embeddings[indices] + + sorted_indices = np.argsort(cell_timepoints) + cell_timepoints = cell_timepoints[sorted_indices] + cell_embeddings = cell_embeddings[sorted_indices] + + for i in range(len(cell_timepoints)): + current_time = cell_timepoints[i] + current_embedding = cell_embeddings[i] + + current_embedding = current_embedding / np.linalg.norm(current_embedding) + + for tau in range(0, max_tau + 1): + future_time = current_time + tau + + future_index = np.where(cell_timepoints == future_time)[0] + + if len(future_index) >= 1: + future_embedding = 
cell_embeddings[future_index[0]] + future_embedding = future_embedding / np.linalg.norm( + future_embedding + ) + + distance = np.linalg.norm(current_embedding - future_embedding) + + displacement_per_tau[tau].append(distance) + + mean_displacement_per_tau = { + tau: np.mean(displacements) + for tau, displacements in displacement_per_tau.items() + } + std_displacement_per_tau = { + tau: np.std(displacements) + for tau, displacements in displacement_per_tau.items() + } + + return mean_displacement_per_tau, std_displacement_per_tau
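+
+
+# Note: the per-embedding L2 normalization in compute_displacement_mean_std_full
+# means its Euclidean distances are monotonically related to the cosine
+# similarities used elsewhere in this module: for unit vectors a and b,
+#     ||a - b||^2 == 2 * (1 - cos(a, b))
+# A minimal, self-contained check of this identity (illustrative sketch with
+# random vectors; not part of the evaluation API):
+if __name__ == "__main__":
+    rng = np.random.default_rng(0)
+    a, b = rng.normal(size=8), rng.normal(size=8)
+    a, b = a / np.linalg.norm(a), b / np.linalg.norm(b)
+    lhs = np.linalg.norm(a - b) ** 2
+    rhs = 2 * (1 - cosine_similarity(a.reshape(1, -1), b.reshape(1, -1))[0][0])
+    assert np.isclose(lhs, rhs)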