2.1D upscale decoder #37

Merged 30 commits from 21d-upscale-decoder into main on Aug 30, 2023
Conversation

@ziw-liu (Collaborator) commented Aug 21, 2023

Adds a pixelshuffle mode to the 2.1D network. This avoids the checkerboard artifact introduced by transposed convolution layers.

Initialization (samples from the first training epoch) of deconvolution (left) and pixelshuffle (right):
[image: initialization samples]

After fitting:
[image: samples after fitting]
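
For context, a minimal PyTorch sketch of the two upsampling styles (illustrative only, not the actual VisCy modules; the class names and channel counts are hypothetical):

```python
import torch
import torch.nn as nn


class DeconvUpsample(nn.Module):
    """Transposed convolution: prone to checkerboard artifacts at initialization."""

    def __init__(self, in_channels: int, out_channels: int, scale: int = 2) -> None:
        super().__init__()
        self.up = nn.ConvTranspose2d(
            in_channels, out_channels, kernel_size=scale, stride=scale
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.up(x)


class PixelShuffleUpsample(nn.Module):
    """Sub-pixel convolution: convolve at low resolution, then rearrange channels."""

    def __init__(self, in_channels: int, out_channels: int, scale: int = 2) -> None:
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels, out_channels * scale**2, kernel_size=3, padding=1
        )
        self.shuffle = nn.PixelShuffle(scale)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.shuffle(self.conv(x))


x = torch.rand(1, 64, 32, 32)
print(DeconvUpsample(64, 32)(x).shape)        # (1, 32, 64, 64)
print(PixelShuffleUpsample(64, 32)(x).shape)  # (1, 32, 64, 64)
```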

ziw-liu and others added 2 commits August 17, 2023 14:52
* sample multiple patches from one stack

* do not use type annotations from future
it breaks jsonargparse

* fix channel stacking for non-training samples

* remove batch size from model
the metrics will be automatically reduced by lightning
@mattersoflight (Member) commented:

@ziw-liu Exciting!

A question: why don't we notice the checkerboard artifact with 2D UNet?
Please add a training script for the 2.1D model to the examples. I will find some time to test after #36 is reviewed at the course.

@mattersoflight (Member) commented:

I suggest we merge #36 into #37, before merging both to main.

@ziw-liu (Collaborator, Author) commented Aug 21, 2023

A question: why don't we notice the checkerboard artifact with 2D UNet?

That's because the 2D UNet uses non-trainable bilinear interpolation as its upsampling method. This introduces the least artifact from random weight initialization (there are no learnable weights), but with the disadvantage that the convolutions have to refine the upsampled features at the higher resolution, which incurs $O(s^2)$ cost ($s$ is the scaling factor in 2D) compared to sub-pixel convolution, where upsampling is the last step of a given resolution level. See this paper for details.
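
A rough sketch of the ordering difference (hypothetical channel width c and scale s = 2; in the bilinear design the refinement convolutions run on s² times more pixels, while in the sub-pixel design the shuffle is the last step of the level):

```python
import torch
import torch.nn as nn

s, c = 2, 64  # hypothetical scale factor and channel width

# Bilinear decoder stage: upsample first, then refine at the higher resolution.
bilinear_stage = nn.Sequential(
    nn.Upsample(scale_factor=s, mode="bilinear", align_corners=False),
    nn.Conv2d(c, c, kernel_size=3, padding=1),  # runs on s**2 times more pixels
    nn.ReLU(),
    nn.Conv2d(c, c, kernel_size=3, padding=1),
)

# Sub-pixel decoder stage: refine at the lower resolution, upsample last.
subpixel_stage = nn.Sequential(
    nn.Conv2d(c, c, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(c, c * s**2, kernel_size=3, padding=1),
    nn.PixelShuffle(s),
)

x = torch.rand(1, c, 32, 32)
print(bilinear_stage(x).shape, subpixel_stage(x).shape)  # both (1, 64, 64, 64)
```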

@ziw-liu (Collaborator, Author) commented Aug 21, 2023

Please add a training script for 2.1D model to examples.

Does it need to be a Python script or will a config file suffice?

ziw-liu and others added 5 commits August 22, 2023 23:30
* updated intro and paths

* updated figures, tested data loader

* setup.sh fetches correct dataset

* finalized the exercise outline

* semi-final exercise

* parts 1 and 2 tested, part 3 outline ready

* clearer variables, train with larger patch size

* fix typo

* clarify variable names

* trying to log graph

* match example size with training

* reuse globals

* fix reference

* log sample images from the first batch

* wider model

* low LR solution

* fix path

* seed everything

* fix test dataset without masks

* metrics solution
this needs a new test dataset

* fetch test data, compute metrics

* bypass cellpose import error due to numpy version conflicts

* final exercise

* moved files

* fixed formatting - ready for review

* viscy -> VisCy (#34) (#39)

Introducing capitalization to highlight vision and single-cell aspects of the pipeline.

* trying to log graph

* log graph

* black

---------

Co-authored-by: Shalin Mehta <[email protected]>
Co-authored-by: Shalin Mehta <[email protected]>
@mattersoflight mattersoflight self-requested a review August 23, 2023 18:19
@ziw-liu ziw-liu requested a review from edyoshikun August 29, 2023 19:51
@@ -5,15 +5,15 @@

# %%
model = VSUNet(
Member:

Thanks for this script @ziw-liu. This is very helpful for exploring the specs and diagrams of different architectures.

Member:

The svg outputs from torchview look great!
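
For reference, a minimal torchview sketch in the spirit of this script (the model below is a small stand-in rather than VSUNet, to avoid guessing its constructor arguments; the graph_name and roll arguments mirror the ones in the diff):

```python
import torch.nn as nn
from torchview import draw_graph

# stand-in model; the actual script instantiates VSUNet
model = nn.Sequential(
    nn.Conv2d(1, 16, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(16, 1, kernel_size=3, padding=1),
)

graph = draw_graph(
    model,
    input_size=(1, 1, 256, 256),
    graph_name="2D UNet",
    roll=True,  # roll repeated blocks into a compact diagram
)
# render the computation graph to SVG with graphviz
graph.visual_graph.render("2d_unet", format="svg")
```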

norm_layer_cl=timm.layers.LayerNorm,
)


class Conv21dStem(nn.Module):
Member:

What is the benefit of downsampling the input with stride = kernel_size? Why not learn the filters with stride=1, and then average pool?

I am worried this may affect the shift invariance of the model by worsening the aliasing effects described in this paper.

Collaborator (Author):

This is borrowed from ConvNeXt. The idea is to patchify the image tokens and learn embeddings with large kernels in the deep layers.

In practice, modern DCNNs rarely use stride 1 for the first convolution layer, because the high resolution restricts the features that can be learned at each stage. For example, ResNet (2015) uses a k=7, s=2 convolution followed by 2x2 pooling as its stem. This is especially true for high-resolution 3D input.

Another distinction is that the natural images used in the paper you linked are low resolution, meaning that very small windows (e.g. 3x3) can contain critical texture information in both luminance and chrominance. In the microscopy images we train on, voxel sizes are co-designed with the diffraction limit, and 4x4 patches should not contain more than one distinguishable feature.

Also, classification networks suffer from feature-space collapse and are thus much more sensitive to this problem than models trained for dense tasks. Since our model generalizes even on a different microscope where the optical transfer function is completely different, I don't think there is empirical evidence supporting stride 1 at the stem layer.
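
To make the design concrete, here is a sketch of a ConvNeXt-style patchify stem for a (N, C, D, H, W) input, with stride equal to kernel size so the 4x4 downsampling is fully learned (the channel count, patch size, and depth handling are illustrative, not the actual Conv21dStem settings):

```python
import torch
import torch.nn as nn


class PatchifyStem3d(nn.Module):
    """Non-overlapping patch embedding: stride == kernel_size in (H, W)."""

    def __init__(self, in_channels: int = 1, embed_dim: int = 64, patch: int = 4) -> None:
        super().__init__()
        # depth kernel/stride of 1 keeps Z unchanged in this sketch
        self.proj = nn.Conv3d(
            in_channels,
            embed_dim,
            kernel_size=(1, patch, patch),
            stride=(1, patch, patch),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)


x = torch.rand(1, 1, 5, 256, 256)  # (N, C, D, H, W)
print(PatchifyStem3d()(x).shape)   # (1, 64, 5, 64, 64)
```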

Member:

Very interesting paper! Pasting the relevant bit of text:
[image: excerpt from the paper]

The improvement in image classification accuracy with patchification is reported to be 0.1%. That may not be reproducible.

Another distinction is that the natural images used in the paper you linked are low resolution, meaning that very small windows (e.g. 3x3) can contain critical texture information in both luminance and chrominance.

Our data is typically critically sampled (Nyquist criterion) and we deconvolve to boost high spatial frequencies. Subsampling can lead to aliasing.

Did the checkerboard artifacts coincide with the use of patchify?
You may have already done experiments with different strides (levels of patchification). If yes, please point me to them.
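
For comparison, a sketch of the alternative raised earlier, learning filters at stride 1 and then average pooling; the pooling acts as a crude low-pass filter before subsampling, which is the anti-aliasing argument above (channel counts and shapes are hypothetical):

```python
import torch
import torch.nn as nn

# hypothetical stride-1 stem followed by average pooling for 4x4 downsampling
stride1_then_pool = nn.Sequential(
    nn.Conv3d(1, 64, kernel_size=(1, 3, 3), stride=1, padding=(0, 1, 1)),
    nn.AvgPool3d(kernel_size=(1, 4, 4)),  # low-pass + subsample in (H, W)
)

x = torch.rand(1, 1, 5, 256, 256)
print(stride1_then_pool(x).shape)  # (1, 64, 5, 64, 64)
```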

Collaborator (Author):

The improvement in image classification accuracy with patchification is reported to be 0.1%. That may not be reproducible.

Note that they compared against the ResNet stem, i.e. a 7x7, stride-2 convolution followed by max pooling for a total of 4x4 downsampling. Both designs use a stem containing one convolution layer to achieve 4x4 downsampling; the difference is that in the latter (patchify) design the downsampling is completely learned. This should actually help avoid aliasing.

Stride 1 at the first convolution followed by pooling would result in a stem that does not learn the downsampling at all. It is also very expensive given the large 3D input: with the current width, that single layer would take up more than 10% of the FLOPs of the entire network. This is one of the main reasons why the 2.5D network can only learn 16 filters in its first layer on a 48 GB GPU with typical patch sizes. In other words, without the initial aggressive learned projection, the 2.1D network would be more similar to the 2.5D network in terms of capacity.
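
As a back-of-the-envelope illustration of that cost argument (hypothetical single-channel input and 64 output filters on a 5x256x256 patch, not numbers profiled from VisCy):

```python
# multiply-accumulate counts for the first layer only
c_out, d, h, w = 64, 5, 256, 256
k = 3  # spatial kernel of a hypothetical stride-1 stem convolution
p = 4  # patch size of the stride-4 patchify stem

# stride-1 convolution: one k*k window per full-resolution output voxel
macs_stride1 = c_out * d * h * w * (k * k)

# patchify convolution: p**2 fewer output positions, but a larger p*p kernel
macs_patchify = c_out * d * (h // p) * (w // p) * (p * p)

print(f"stride-1 stem: {macs_stride1 / 1e6:.0f} MMACs")
print(f"patchify stem: {macs_patchify / 1e6:.0f} MMACs")
print(f"ratio: {macs_stride1 / macs_patchify:.1f}x")  # k**2 = 9x here
```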

Did the checkerboard artifacts coincide with the use of patchify?

The checkerboard artifacts are due to the upscaling method, and they do not appear when linear upsampling is used. This is a well-studied problem, and I haven't seen any literature linking it to the stride in the first encoder layer.

Collaborator (Author):

You may have already done experiments with different strides (levels of patchification). If yes, please point me to them.

I haven't done any experiments, since that would require significantly shrinking the rest of the network. This is not necessarily bad, but it would likely require a different overall design. See HRNet for an example.

At this stage I have been avoiding hyper-parameter tuning unless there is a clear issue it can fix.

Collaborator (Author):

There is a third option: since the pipeline now supports distributed training, we can scale the model beyond the computational complexity of old 2.5D models and learn both high-resolution feature maps and very deep representations.

model,
model.example_input_array,
graph_name="2D UNet",
roll=True,
Member:

"deconv" value for this key leads to an error during graph construction.

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
File ~/conda/envs/04_image_translation/lib/python3.10/site-packages/torchview/torchview.py:256, in forward_prop(model, x, device, model_graph, mode, **kwargs)
    255 if isinstance(x, (list, tuple)):
--> 256     _ = model.to(device)(*x, **kwargs)
    257 elif isinstance(x, Mapping):

File ~/conda/envs/04_image_translation/lib/python3.10/site-packages/torchview/recorder_tensor.py:146, in module_forward_wrapper.._module_forward_wrapper(mod, *args, **kwargs)
    144 # TODO: check if output contains RecorderTensor
    145 # this seems not to be necessary so far
--> 146 out = _orig_module_forward(mod, *args, **kwargs)
    148 model_graph.context_tracker['current_depth'] = cur_depth

File ~/conda/envs/04_image_translation/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used

File ~/code/viscy/viscy/light/engine.py:215, in VSUNet.forward(self, x)
    214 def forward(self, x) -> torch.Tensor:
--> 215     return self.model(x)

File ~/conda/envs/04_image_translation/lib/python3.10/site-packages/torchview/recorder_tensor.py:146, in module_forward_wrapper.._module_forward_wrapper(mod, *args, **kwargs)
...
    266     ) from e
    267 finally:
    268     model.train(saved_model_mode)

Collaborator (Author):

I was planning to deprecate that option. Do you think there is value in keeping (and fixing) it?

@mattersoflight mattersoflight merged commit b4ec13c into main Aug 30, 2023
@ziw-liu ziw-liu deleted the 21d-upscale-decoder branch August 30, 2023 18:44
Development

Successfully merging this pull request may close these issues.

Unable to load two channels as inputs to do fluorescence to phase image translation
3 participants