2.1D upscale decoder #37
Conversation
* sample multiple patches from one stack
* do not use type annotations from `__future__` (it breaks jsonargparse)
* fix channel stacking for non-training samples
* remove batch size from the model; the metrics will be automatically reduced by Lightning
That's because it is using non-trainable bilinear interpolation as the upsampling method. This introduces the fewest artifacts from random weight initialization (there are no learnable weights), but has the disadvantage that the convolutions have to happen at higher resolution to refine the upsampled features, which incurs a higher computational cost.
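For illustration, a minimal sketch of what an interpolation-based upsampling block can look like; the module name, channel sizes, and 2D layout are hypothetical, not VisCy's actual decoder:

```python
import torch
import torch.nn as nn


class InterpUpBlock(nn.Module):
    """Hypothetical decoder block: non-trainable bilinear upsampling
    followed by convolutions that refine features at the higher resolution."""

    def __init__(self, in_channels: int, out_channels: int) -> None:
        super().__init__()
        # Fixed bilinear kernel: no learnable weights, so no random-init artifacts.
        self.upsample = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False)
        # The refinement convolutions now run on the 2x larger feature map,
        # which is where the extra compute cost comes from.
        self.refine = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.GELU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.refine(self.upsample(x))


# Example: upsampling a 64-channel feature map from 32x32 to 64x64.
block = InterpUpBlock(64, 32)
print(block(torch.randn(1, 64, 32, 32)).shape)  # torch.Size([1, 32, 64, 64])
```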
Does it need to be a Python script, or will a config file suffice?
* updated intro and paths
* updated figures, tested data loader
* setup.sh fetches correct dataset
* finalized the exercise outline
* semi-final exercise
* parts 1 and 2 tested, part 3 outline ready
* clearer variables, train with larger patch size
* fix typo
* clarify variable names
* trying to log graph
* match example size with training
* reuse globals
* fix reference
* log sample images from the first batch
* wider model
* low LR solution
* fix path
* seed everything
* fix test dataset without masks
* metrics solution (this needs a new test dataset)
* fetch test data, compute metrics
* bypass cellpose import error due to numpy version conflicts
* final exercise
* moved files
* fixed formatting - ready for review
* viscy -> VisCy (#34) (#39): introducing capitalization to highlight the vision and single-cell aspects of the pipeline
* trying to log graph
* log graph
* black

Co-authored-by: Shalin Mehta <[email protected]>
Co-authored-by: Shalin Mehta <[email protected]>
Sliding windows are blended with a uniform average.
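For illustration, a minimal sketch of uniform-average blending of overlapping window predictions along one axis; the function and shapes are hypothetical, not VisCy's prediction writer:

```python
import numpy as np


def blend_uniform(patches, positions, out_length, window):
    """Average overlapping 1D window predictions with uniform weights.

    patches:   list of arrays of shape (window,)
    positions: start index of each window in the output
    """
    out = np.zeros(out_length, dtype=np.float32)
    weight = np.zeros(out_length, dtype=np.float32)
    for patch, start in zip(patches, positions):
        out[start : start + window] += patch
        weight[start : start + window] += 1.0  # every window contributes equally
    return out / np.maximum(weight, 1.0)  # avoid division by zero at uncovered positions


# Two half-overlapping windows over a length-12 signal.
preds = [np.ones(8), 3 * np.ones(8)]
print(blend_uniform(preds, positions=[0, 4], out_length=12, window=8))
# the overlap region (indices 4-7) is averaged to 2.0
```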
@@ -5,15 +5,15 @@

# %%
model = VSUNet(
Thanks for this script @ziw-liu. This is very helpful for exploring the specs and diagrams of different architectures.
The SVG outputs from torchview look great!
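For reference, a minimal sketch of exporting such an SVG with torchview; the toy model below is a stand-in for the configured VSUNet, and rendering assumes the graphviz executable is installed:

```python
import torch.nn as nn
from torchview import draw_graph

# Toy stand-in for the translation model; replace with the configured VSUNet.
model = nn.Sequential(
    nn.Conv2d(1, 16, kernel_size=4, stride=4),
    nn.GELU(),
    nn.Conv2d(16, 1, kernel_size=1),
)

graph = draw_graph(
    model,
    input_size=(1, 1, 256, 256),  # (batch, channels, height, width)
    graph_name="toy_model",
    roll=True,  # merge repeated blocks into a single rolled node
)
# torchview wraps a graphviz Digraph; render it to SVG on disk.
graph.visual_graph.render(format="svg")
```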
norm_layer_cl=timm.layers.LayerNorm,
)
class Conv21dStem(nn.Module):
What is the benefit of downsampling the input with `stride = kernel_size`? Why not learn the filters with `stride=1`, and then average pool? I am worried this may affect the shift invariance of the model by worsening the aliasing effects described in this paper.
This is borrowed from ConvNeXt. The idea is to patchify the image tokens and learn embeddings with large kernels in deep layers.

In practice, modern DCNNs rarely use stride 1 for the first convolution layer, because the high resolution restricts the features that can be learned at each stage. For example, ResNet (2015) uses a k=7, s=2 convolution followed by stride-2 pooling as its stem. This is especially true for high-resolution 3D input.

Another distinction is that the natural images used in the paper you linked are low resolution, meaning that very small windows (e.g. 3x3) can contain critical texture information in both luminance and chrominance. In the microscopy images we train on, voxel sizes are co-designed with the diffraction limit, and 4x4 patches should not contain more than one distinguishable feature.

Also, classification networks suffer from feature-space collapse and are thus much more sensitive to this problem than models trained for dense tasks. Since our model generalizes even on a different microscope where the optical transfer function is completely different, I don't think there is empirical evidence supporting stride 1 at the stem layer.
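For illustration, a minimal 2D sketch contrasting a ConvNeXt-style patchify stem with the stride-1-then-pool alternative discussed above; the channel count and input size are hypothetical, and this is not the actual Conv21dStem:

```python
import torch
import torch.nn as nn

# ConvNeXt-style patchify stem: the 4x4 downsampling itself is learned,
# and each output token sees a non-overlapping 4x4 patch.
patchify_stem = nn.Sequential(
    nn.Conv2d(1, 96, kernel_size=4, stride=4),
    nn.GroupNorm(1, 96),  # stand-in for a channels-last LayerNorm
)

# Stride-1 alternative: learn filters at full resolution, then downsample
# with a fixed average pool. The downsampling itself is not learned, and
# convolving at full resolution is far more expensive for large inputs.
stride1_stem = nn.Sequential(
    nn.Conv2d(1, 96, kernel_size=4, stride=1, padding=2),
    nn.AvgPool2d(kernel_size=4, stride=4),
)

x = torch.randn(1, 1, 256, 256)
print(patchify_stem(x).shape, stride1_stem(x).shape)
# both yield 64x64 feature maps, but the stride-1 conv evaluates ~16x more positions
```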
Very interesting paper! Pasting the relevant bit of text:
The improvement in image classification accuracy with patchification is reported to be 0.1%. That may not be reproducible.
> Another distinction is that the natural images used in the paper you linked are low resolution, meaning that very small windows (e.g. 3x3) can contain critical texture information in both luminance and chrominance.
Our data is typically critically sampled (Nyquist criterion) and we deconvolve to boost high spatial frequencies. Subsampling can lead to aliasing.
Did the checkerboard artifacts coincide with the use of patchify?
You may have already done experiments with different strides (levels of patchification). If yes, please point me to them.
> The improvement in image classification accuracy with patchification is reported to be 0.1%. That may not be reproducible.
Note that they compared against the ResNet stem, i.e. a 7x7 stride-2 convolution followed by max pooling for a total of 4x4 downsampling. Both methods use a stem containing one convolution layer to achieve 4x4 downsampling; the difference is that in the latter design the downsampling is completely learned. This should actually help avoid aliasing.

Stride 1 at the first convolution followed by pooling would result in a stem that does not learn the downsampling at all. It is also very expensive given the large 3D input: with the current width, that single layer would take up more than 10% of the FLOPs of the entire network. This is one of the main reasons why the 2.5D network can only learn 16 filters in the first layer on a 48 GB GPU at typical patch sizes. In other words, without the initial aggressive learned projection, the 2.1D network would be more similar to the 2.5D network in terms of capacity.
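For a rough sense of the cost difference, a back-of-the-envelope multiply-accumulate count under hypothetical input and width choices (not profiled numbers from VisCy):

```python
# Rough multiply-accumulate count for a conv layer:
# MACs = C_in * C_out * prod(kernel) * prod(output_spatial)
def conv_macs(c_in, c_out, kernel, out_spatial):
    k = 1
    for v in kernel:
        k *= v
    s = 1
    for v in out_spatial:
        s *= v
    return c_in * c_out * k * s


# Hypothetical 3D patch: 1 channel, 5 z-slices, 1024 x 1024 pixels.
# Patchify stem: 4x4 in-plane downsampling learned by the first conv.
patchify = conv_macs(1, 96, (5, 4, 4), (1, 256, 256))

# Stride-1 stem with the same kernel, producing a full-resolution
# feature map before any pooling.
stride1 = conv_macs(1, 96, (5, 4, 4), (1, 1024, 1024))

print(f"patchify stem: {patchify / 1e9:.2f} GMACs")
print(f"stride-1 stem: {stride1 / 1e9:.2f} GMACs ({stride1 / patchify:.0f}x more)")
```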
> Did the checkerboard artifacts coincide with the use of patchify?
The checkerboard artifacts are due to the upscaling method, and they do not appear when linear upsampling is used. This is a well-studied problem, and I haven't seen any literature linking it to the stride in the first encoder layer.
> You may have already done experiments with different strides (levels of patchification). If yes, please point me to them.
I haven't done any such experiments, since they would require significantly shrinking the rest of the network. That is not necessarily bad, but would likely require a different overall design; see HRNet for an example.
At this stage I have been avoiding hyper-parameter tuning unless there is a clear issue it can fix.
There is a third option: since the pipeline now supports distributed training, we can scale the model beyond the computational complexity of old 2.5D models and learn both high-resolution feature maps and very deep representations.
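For context, a minimal sketch of what scaling out with Lightning's DDP strategy could look like; the trainer settings are illustrative, and `model`/`datamodule` stand in for the configured VSUNet and data module:

```python
# Minimal sketch of multi-GPU training with Lightning's DDP strategy
# (assumes a machine with 4 GPUs; adjust `devices` to what is available).
from lightning.pytorch import Trainer

trainer = Trainer(
    accelerator="gpu",
    devices=4,             # spread the batch across 4 GPUs
    strategy="ddp",        # synchronous data-parallel training
    precision="16-mixed",  # mixed precision to fit a wider/deeper model
    max_epochs=100,
)
# trainer.fit(model, datamodule=datamodule)
```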
model,
model.example_input_array,
graph_name="2D UNet",
roll=True,
"deconv" value for this key leads to an error during graph construction.
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
File ~/conda/envs/04_image_translation/lib/python3.10/site-packages/torchview/torchview.py:256, in forward_prop(model, x, device, model_graph, mode, **kwargs)
255 if isinstance(x, (list, tuple)):
--> 256 _ = model.to(device)(*x, **kwargs)
257 elif isinstance(x, Mapping):
File ~/conda/envs/04_image_translation/lib/python3.10/site-packages/torchview/recorder_tensor.py:146, in module_forward_wrapper.._module_forward_wrapper(mod, *args, **kwargs)
144 # TODO: check if output contains RecorderTensor
145 # this seems not to be necessary so far
--> 146 out = _orig_module_forward(mod, *args, **kwargs)
148 model_graph.context_tracker['current_depth'] = cur_depth
File ~/conda/envs/04_image_translation/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1499 or _global_backward_pre_hooks or _global_backward_hooks
1500 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501 return forward_call(*args, **kwargs)
1502 # Do not call functions when jit is used
File ~/code/viscy/viscy/light/engine.py:215, in VSUNet.forward(self, x)
214 def forward(self, x) -> torch.Tensor:
--> 215 return self.model(x)
File ~/conda/envs/04_image_translation/lib/python3.10/site-packages/torchview/recorder_tensor.py:146, in module_forward_wrapper.._module_forward_wrapper(mod, *args, **kwargs)
...
266 ) from e
267 finally:
268 model.train(saved_model_mode)
I was planning to deprecate that option. Do you think there is value in keeping (and fixing) it?
Adds a pixelshuffle mode to the 2.1D network. This avoids the checkerboard artifact introduced by transposed convolution layers.
Initialization (samples from the first training epoch) of deconvolution (left) and pixelshuffle (right):

After fitting:
