-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
SSL: fix MLP head and remove L2 normalization (#145)
* draft projection head per Update the projection head (normalization and size). #139 * reorganize comments in example fit config * configurable stem stride and projection dimensions * update type hint and docstring for ContrastiveEncoder * clarify embedding_dim * use the forward method directly for projected * normalize projections only when fitting the projected features saved during prediction is now *not* normalized * remove unused logger * refactor training code into translation and representation modules * extract image logging functions * use AdamW instead of Adam for contrastive learning * inline single-use argument * fix normalization * fix MLP layer order * fix output dimensions * remove L2 normalization before computing loss * compute rank of features and projections * documentation --------- Co-authored-by: Shalin Mehta <[email protected]>
- Loading branch information
1 parent
6e7d61f
commit 1f269c7
Showing
28 changed files
with
335 additions
and
336 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
from typing import Sequence | ||
|
||
import numpy as np | ||
from matplotlib.pyplot import get_cmap | ||
from skimage.exposure import rescale_intensity | ||
from torch import Tensor | ||
|
||
|
||
def detach_sample(imgs: Sequence[Tensor], log_samples_per_batch: int): | ||
num_samples = min(imgs[0].shape[0], log_samples_per_batch) | ||
samples = [] | ||
for i in range(num_samples): | ||
patches = [] | ||
for img in imgs: | ||
patch = img[i].detach().cpu().numpy() | ||
patch = np.squeeze(patch[:, patch.shape[1] // 2]) | ||
patches.append(patch) | ||
samples.append(patches) | ||
return samples | ||
|
||
|
||
def render_images(imgs: Sequence[Sequence[np.ndarray]], cmaps: list[str] = []): | ||
images_grid = [] | ||
for sample_images in imgs: | ||
images_row = [] | ||
for i, image in enumerate(sample_images): | ||
if cmaps: | ||
cm_name = cmaps[i] | ||
else: | ||
cm_name = "gray" if i == 0 else "inferno" | ||
if image.ndim == 2: | ||
image = image[np.newaxis] | ||
for channel in image: | ||
channel = rescale_intensity(channel, out_range=(0, 1)) | ||
render = get_cmap(cm_name)(channel, bytes=True)[..., :3] | ||
images_row.append(render) | ||
images_grid.append(np.concatenate(images_row, axis=1)) | ||
return np.concatenate(images_grid, axis=0) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.