
Sharded distributed sampler for cached dataloading in DDP #195

Merged: 54 commits into main from simple-cache on Jan 2, 2025
Conversation

@ziw-liu (Collaborator) commented Oct 21, 2024

Add a distributed sampler that only permutes indices within each rank's shard, improving the cache hit rate in DDP.

See viscy/scripts/shared_dict.py for usage.

Also includes changes from #196.
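
For context, a minimal sketch of the idea: each rank draws from a fixed, contiguous shard of the dataset and shuffling happens only within that shard, so whatever a rank caches in the first epoch stays useful in later epochs. The class below is an illustrative assumption, not necessarily the code added in this PR; see the PR diff and viscy/scripts/shared_dict.py for the actual implementation and usage.

```python
# Illustrative sketch only (names and details assumed, not viscy's code).
import torch
from torch.utils.data import DistributedSampler


class ShardedDistributedSampler(DistributedSampler):
    """Shuffle indices only within each rank's contiguous shard."""

    def __iter__(self):
        # DistributedSampler already computed
        # self.num_samples = ceil(len(dataset) / num_replicas).
        start = self.rank * self.num_samples
        shard = list(range(start, min(start + self.num_samples, len(self.dataset))))
        if len(shard) < self.num_samples:
            # Pad the last shard so every rank yields the same number of samples.
            shard += shard[: self.num_samples - len(shard)]
        if self.shuffle:
            g = torch.Generator()
            g.manual_seed(self.seed + self.epoch)
            order = torch.randperm(len(shard), generator=g).tolist()
            shard = [shard[i] for i in order]
        return iter(shard)
```

In a Lightning setup, such a sampler would typically be passed to the DataLoader explicitly, with the trainer's automatic distributed-sampler injection disabled so it is not replaced by the stock DistributedSampler.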

@ziw-liu ziw-liu marked this pull request as ready for review October 21, 2024 21:17
@ziw-liu ziw-liu requested a review from edyoshikun October 21, 2024 21:17
@ziw-liu ziw-liu added enhancement New feature or request translation Image translation (VS) labels Oct 21, 2024
@ziw-liu (Collaborator, Author) commented Oct 21, 2024

Example output:

GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/hpc/mydata/ziwen.liu/anaconda/2022.05/x86_64/envs/viscy/lib/python3.11/site-packages/lightning/pytorch/trainer/setup.py:177: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.
Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/3
Initializing distributed: GLOBAL_RANK: 2, MEMBER: 3/3
Initializing distributed: GLOBAL_RANK: 1, MEMBER: 2/3
----------------------------------------------------------------------------------------------------
distributed_backend=gloo
All distributed processes registered. Starting with 3 processes
----------------------------------------------------------------------------------------------------

=== Initializing cache pool for rank 0 ===
=== Initializing cache pool for rank 1 ===
=== Initializing cache pool for rank 2 ===

  | Name  | Type   | Params | Mode
0 | layer | Linear | 2      | train

2 Trainable params
0 Non-trainable params
2 Total params
0.000 Total estimated model params size (MB)
1 Modules in train mode
0 Modules in eval mode

Adding 31 to cache dict on rank 1
Adding 32 to cache dict on rank 2
Adding 38 to cache dict on rank 2
Adding 42 to cache dict on rank 0
Adding 30 to cache dict on rank 0
Adding 36 to cache dict on rank 0
Adding 37 to cache dict on rank 1
Adding 43 to cache dict on rank 1
Adding 48 to cache dict on rank 0
Adding 34 to cache dict on rank 1
Adding 49 to cache dict on rank 1
Adding 30 to cache dict on rank 2
Adding 44 to cache dict on rank 2
Adding 41 to cache dict on rank 2
Adding 35 to cache dict on rank 2
Adding 40 to cache dict on rank 1
Adding 46 to cache dict on rank 1
Adding 39 to cache dict on rank 0
Adding 33 to cache dict on rank 0
Adding 47 to cache dict on rank 2
Adding 45 to cache dict on rank 0
Adding 24 to cache dict on rank 2
Adding 13 to cache dict on rank 1
Adding 0 to cache dict on rank 0
Adding 20 to cache dict on rank 2
Adding 4 to cache dict on rank 0
Adding 29 to cache dict on rank 2
Adding 19 to cache dict on rank 1
Adding 26 to cache dict on rank 2
Adding 28 to cache dict on rank 2
=== Starting training ===
=== Starting training epoch 0 ===
Adding 8 to cache dict on rank 0
Adding 15 to cache dict on rank 1
Adding 3 to cache dict on rank 0
Adding 21 to cache dict on rank 2
Adding 11 to cache dict on rank 1
Adding 7 to cache dict on rank 0
Adding 23 to cache dict on rank 2
Adding 27 to cache dict on rank 2
Adding 22 to cache dict on rank 2
Adding 1 to cache dict on rank 0
Adding 9 to cache dict on rank 0
Adding 5 to cache dict on rank 0
Adding 17 to cache dict on rank 1
Adding 6 to cache dict on rank 0
Adding 18 to cache dict on rank 1
Adding 16 to cache dict on rank 1
Adding 14 to cache dict on rank 1
Adding 10 to cache dict on rank 1
Adding 25 to cache dict on rank 2
Adding 2 to cache dict on rank 0
Adding 12 to cache dict on rank 1
=== Starting training epoch 1 ===
=== Starting training epoch 2 ===
=== Starting training epoch 3 ===
=== Starting training epoch 4 ===
Trainer.fit stopped: max_epochs=5 reached.

@ziw-liu ziw-liu changed the base branch from ram_dataloader to main October 21, 2024 23:28
@ziw-liu ziw-liu added this to the v0.4.0 milestone Nov 12, 2024
@ziw-liu (Collaborator, Author) commented Nov 14, 2024

The FcmaeUNet class can now do both pre-training and fine-tuning, controlled via the pretraining flag in model_config. This keeps the old VSUNet with its sliding-window dataset integration, while also enabling GPU-accelerated augmentations for virtual staining training.
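
For illustration, here is a toy sketch of that pattern: a single LightningModule whose objective is switched by a pretraining flag in model_config. Everything except the pretraining flag is an assumption made up for this example; it is not the actual FcmaeUNet code.

```python
import torch
from torch import nn
import lightning.pytorch as pl


class ToyFlagUNet(pl.LightningModule):
    """Toy stand-in: one module doing either pre-training or fine-tuning."""

    def __init__(self, model_config: dict):
        super().__init__()
        self.pretraining = model_config["pretraining"]
        self.backbone = nn.Conv3d(model_config["in_channels"], 8, 3, padding=1)
        self.recon_head = nn.Conv3d(8, model_config["in_channels"], 1)
        self.stain_head = nn.Conv3d(8, model_config["out_channels"], 1)

    def training_step(self, batch, batch_idx):
        feat = self.backbone(batch["source"])
        if self.pretraining:
            # self-supervised: reconstruct the (masked) source channels
            return nn.functional.mse_loss(self.recon_head(feat), batch["source"])
        # fine-tuning: supervised virtual staining against the target channels
        return nn.functional.mse_loss(self.stain_head(feat), batch["target"])

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=2e-4)
```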

@ziw-liu (Collaborator, Author) commented Nov 15, 2024

@edyoshikun do you want to keep improving the prototype in hcs_ram.py, or should we remove it?

@edyoshikun (Contributor) commented:
Remove hcs_ram.py.

ziw-liu and others added 3 commits December 2, 2024 15:33
* fix spelling in docstring and comment

* add batched zoom transform for tta

* add standalone lightning module for arbitrary TTA

* fix composition of different zoom factors
@edyoshikun (Contributor) left a comment:
This LGTM. The pending thing would be to properly write the augmentations now. Happy to merge this first and then add the augmentations.

@ziw-liu (Collaborator, Author) commented Jan 2, 2025

properly write the augmentations now

Can you elaborate?

@edyoshikun (Contributor) commented Jan 2, 2025

I mean the tiling in 3D. I was thinking of the neuromast or mantis datasets, where, if the XY dimensions are smaller than one patch, we only crop/tile the top-left corner. The equivalent thing would happen in the Z dimension.

@ziw-liu (Collaborator, Author) commented Jan 2, 2025

I mean the tiling in 3D. I was thinking of the neuromast or mantis datasets, where, if the XY dimensions are smaller than one patch, we only crop/tile the top-left corner. The equivalent thing would happen in the Z dimension.

IIRC this only affects validation? Let's open an issue and fix the tiling transform separately.
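
To make the behavior under discussion concrete, here is a rough sketch of corner-anchored tiling along a single axis. The helper name and the numbers are made up for illustration; this is not the actual viscy tiling transform.

```python
def tile_starts(extent: int, patch: int) -> list[int]:
    """Start coordinates of non-overlapping tiles along one axis."""
    if extent <= patch:
        # Axis shorter than one patch: only a single corner-anchored tile,
        # which would need clipping or padding to fit the image.
        return [0]
    return list(range(0, extent - patch + 1, patch))


print(tile_starts(2048, 384))  # [0, 384, 768, 1152, 1536]; remainder past 1920 is not tiled
print(tile_starts(256, 384))   # [0]; only the top-left corner is covered
```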

@edyoshikun (Contributor) left a comment:
I have super minor comments. Otherwise, I've tested this with the phase pretraining and it works well.

Let's fix the tiling in a separate issue+PR.

```python
batch_size: int = 16,
num_workers: int = 8,
val_subsample_ratio: int = 30,
```
@edyoshikun (Contributor):
Minor preference, but when I see "ratio" I typically think of values between 0 and 1, like the masking ratio, the train/val ratio, etc.

@ziw-liu (Collaborator, Author):
What would be a better name?

@ziw-liu (Collaborator, Author):
Let's keep this for now as I am the only user at the moment.
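
For concreteness on the naming question above: assuming the integer is a subsampling factor (keep one of every N validation samples), which is an assumption rather than something confirmed here, the two conventions would look like this:

```python
n_val = 3000

val_subsample_ratio = 30                        # integer factor: keep every 30th sample
kept_a = n_val // val_subsample_ratio           # 100

val_subsample_fraction = 1 / 30                 # 0-1 value, the sense "ratio" usually suggests
kept_b = round(n_val * val_subsample_fraction)  # 100
```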

@ziw-liu ziw-liu merged commit 05af2e6 into main Jan 2, 2025
4 checks passed
@ziw-liu ziw-liu deleted the simple-cache branch January 2, 2025 22:42
edyoshikun added a commit that referenced this pull request Feb 13, 2025
* caching dataloader

* caching data module

* black

* ruff

* Bump torch to 2.4.1 (#174)

* update torch >2.4.1

* black

* ruff

* adding timeout to ram_dataloader

* bandaid to cached dataloader

* fixing the dataloader using torch collate_fn

* replacing dictionary with single array

* loading prior to epoch 0

* Revert "replacing dictionary with single array"

This reverts commit 8c13f49.

* using multiprocessing manager

* add sharded distributed sampler

* add example script for ddp caching

* format and lint

* adding the custom distributed sampler to hcs_ram.py

* adding sampler to val train dataloader

* fix divisibility of the last shard

* hcs_ram format and lint

* data module that only crops and does not collate

* wip: execute transforms on the GPU

* path for if not ddp

* fix randomness in inversion transform

* add option to pop the normalization metadata

* move gpu transform definition back to data module

* add tiled crop transform for validation

* add stack channel transform for gpu augmentation

* fix typing

* collate before sending to gpu

* inherit gpu transforms for livecell dataset

* update fcmae engine to apply per-dataset augmentations

* format and lint hcs_ram

* fix abc type hint

* update docstring style

* disable grad for validation transforms

* improve sample image logging in fcmae

* fix dataset length when batch size is larger than the dataset

* fix docstring

* add option to disable normalization metadata

* inherit gpu transform for ctmc

* remove duplicate method override

* update docstring for ctmc

* allow skipping caching for large datasets

* make the fcmae module compatible with image translation

* remove prototype implementation

* fix import path

* Arbitrary prediction time transforms (#209)

* fix spelling in docstring and comment

* add batched zoom transform for tta

* add standalone lightning module for arbitrary TTA

* fix composition of different zoom factors

* add docstrings

* fix typo in docstring

---------

Co-authored-by: Eduardo Hirata-Miyasaki <[email protected]>