
SSL: fix MLP head and remove L2 normalization #141

Closed
wants to merge 17 commits

Conversation

mattersoflight
Member

@mattersoflight mattersoflight commented Aug 17, 2024

Fix the sequence of batchnorm and linear in the last MLP layer of the contrastive encoder model.

Remove L2 normalization before computing the triplet loss. This works best when the dimension of the projections is also reduced.

Refactor the light module into representation and translation to separate the pipelining code for the different tasks.

Fix #139, fix #138.
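
For illustration, a minimal sketch of the loss change, assuming a standard PyTorch triplet-margin setup (tensor names and shapes are placeholders, not the actual VisCy code):

```python
import torch
import torch.nn.functional as F

# Illustrative batch of projections with a reduced dimension (e.g. 16).
anchor, positive, negative = (torch.randn(8, 16) for _ in range(3))

# Previously the projections were L2-normalized first, e.g.:
#   anchor = F.normalize(anchor, p=2, dim=1)
# Now the raw projections go straight into the loss, and the optimizer
# is left to control their scale.
loss = F.triplet_margin_loss(anchor, positive, negative, margin=1.0)
```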

@mattersoflight mattersoflight requested a review from ziw-liu August 17, 2024 02:17
@mattersoflight
Member Author

@ziw-liu started the draft. I'll let you complete it. Feel free to make other architectural improvements as you go.

@mattersoflight mattersoflight changed the title from "draft projection head" to "projection head (and other architectural changes." Aug 17, 2024
@mattersoflight mattersoflight changed the title from "projection head (and other architectural changes." to "projection head (and other architectural changes)" Aug 17, 2024
@ziw-liu ziw-liu added the "bug" (Something isn't working) and "breaking" (Breaking changes) labels Aug 21, 2024
The projected features saved during prediction are now *not* normalized.
@ziw-liu ziw-liu added the "enhancement" (New feature or request) label Aug 21, 2024
@Soorya19Pradeep
Contributor

@ziw-liu, I get the following error:

[rank3]: else self._run_ddp_forward(*inputs, **kwargs)
[rank3]: File "/hpc/mydata/soorya.pradeep/scratch/viscy_test/lib/python3.10/site-packages/torch/nn/parallel/distributed.py", line 1454, in _run_ddp_forward
[rank3]: return self.module(*inputs, **kwargs) # type: ignore[index]
[rank3]: File "/hpc/mydata/soorya.pradeep/scratch/viscy_test/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank3]: return self._call_impl(*args, **kwargs)
[rank3]: File "/hpc/mydata/soorya.pradeep/scratch/viscy_test/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank3]: return forward_call(*args, **kwargs)
[rank3]: File "/hpc/mydata/soorya.pradeep/scratch/viscy_test/lib/python3.10/site-packages/lightning/pytorch/strategies/strategy.py", line 633, in wrapped_forward
[rank3]: out = method(*args, **kwargs)
[rank3]: File "/hpc/mydata/soorya.pradeep/scratch/viscy_infection_phenotyping/VisCy/viscy/representation/engine.py", line 144, in validation_step
[rank3]: self._log_metrics(
[rank3]: File "/hpc/mydata/soorya.pradeep/scratch/viscy_infection_phenotyping/VisCy/viscy/representation/engine.py", line 74, in _log_metrics
[rank3]: cosine_sim_pos = F.cosine_similarity(anchor, positive, dim=1).mean()
[rank3]: IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)

Sanity Checking: | | 0/? [00:00<?, ?it/s]
Sanity Checking: 0%| | 0/2 [00:00<?, ?it/s]
Sanity Checking DataLoader 0: 0%| | 0/2 [00:00<?, ?it/s]
srun: error: gpu-f-2: tasks 0,2: Exited with exit code 1
srun: error: gpu-f-2: task 3: Exited with exit code 1

@ziw-liu
Collaborator

ziw-liu commented Aug 21, 2024

@Soorya19Pradeep Fixed in 11fa65a. This is still very much a work in progress; I have not trained a full model yet.
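
One plausible reading of the failure mode (a guess, not necessarily what 11fa65a changed): `F.cosine_similarity(..., dim=1)` raises exactly this IndexError when the embeddings arrive as 1-D tensors, while indexing the last dimension works for both batched and unbatched inputs:

```python
import torch
import torch.nn.functional as F

anchor = torch.randn(16)    # unbatched, 1-D: only dims -1 and 0 exist
positive = torch.randn(16)

# F.cosine_similarity(anchor, positive, dim=1)  # IndexError, as in the log above
sim = F.cosine_similarity(anchor, positive, dim=-1)  # works for 1-D and (N, D)
```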

@ziw-liu
Collaborator

ziw-liu commented Aug 23, 2024

@mattersoflight With batch normalization in the MLP head, the features still have a much lower rank (11) than the projections (96).

Edit: see full analysis here.
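
For reference, one way such a rank comparison can be made, assuming per-sample vectors are stacked into (N, D) matrices (the shapes here are placeholders):

```python
import torch

features = torch.randn(512, 768)    # encoder features (placeholder shape)
projections = torch.randn(512, 96)  # MLP-head projections (placeholder shape)

# A numerical rank far below D would indicate dimensional collapse.
print(torch.linalg.matrix_rank(features))
print(torch.linalg.matrix_rank(projections))
```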

@mattersoflight
Member Author

@ziw-liu our next steps:

  • Reorder the batchnorm and linear layers at the network's end (MLP -> batchnorm), as sketched after this list.
  • Do not L2-normalize the projections during fitting; let the optimizer enforce normalization.
  • Run two computational experiments with the above changes: projection dim = 128 and projection dim ~ 16.
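
A minimal sketch of the proposed head ordering, assuming a two-layer MLP (the helper name and sizes are hypothetical, not the actual VisCy code):

```python
import torch.nn as nn

def make_projection_head(in_dim: int, hidden_dim: int, out_dim: int) -> nn.Sequential:
    """Hypothetical projection head: batchnorm *after* the final linear layer."""
    return nn.Sequential(
        nn.Linear(in_dim, hidden_dim),
        nn.BatchNorm1d(hidden_dim),
        nn.ReLU(inplace=True),
        nn.Linear(hidden_dim, out_dim),  # MLP -> batchnorm ordering
        nn.BatchNorm1d(out_dim),         # no L2 normalization afterwards
    )

head = make_projection_head(768, 768, 128)  # projection dim = 128, or ~ 16 for the second run
```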

@mattersoflight mattersoflight marked this pull request as ready for review August 27, 2024 21:11
@ziw-liu ziw-liu changed the base branch from main to representation-learning August 28, 2024 00:29
@ziw-liu ziw-liu changed the title from "projection head (and other architectural changes)" to "SSL: fix MLP head and remove L2 normalization" Aug 28, 2024
@ziw-liu ziw-liu deleted the branch representation-learning August 28, 2024 00:40
@ziw-liu ziw-liu closed this Aug 28, 2024