
[feature] : u-nets: add ability to remove regions that don't touch original grain #1107

Open
SylviaWhittle opened this issue Mar 14, 2025 · 0 comments · May be fixed by #1108
Labels
enhancement New feature or request


@SylviaWhittle (Collaborator)

Is your feature request related to a problem?

During a project with @llwiggins, we found that u-net masks would often include DNA that was not part of the molecule of interest. This causes the processing pipeline to include other grains when calculating statistics for the grain of interest, as if they were part of it.

Describe the solution you would like.

A simple solution to this is to:

  • Take the u-net mask
  • Flatten the classes into one combined class
  • Via connected-component analysis, delete any component not touching the original mask

This is far from a perfect solution, but it worked for what I was doing. Here is an image showing the result:

Image
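The three steps listed earlier can be sketched in isolation. This is a minimal, self-contained sketch, not the branch code: it assumes a `(height, width, classes)` one-hot mask with channel 0 as background (as the branch code's `range(1, ...)` channel loop implies), it swaps `skimage.measure.label` for a hand-written breadth-first flood fill using 8-connectivity (which matches `label`'s default full connectivity in 2-D), and the names `flatten_classes`, `label_components` and `remove_untouched_regions` are hypothetical. It returns a filtered copy and leaves the background channel untouched.

```python
from collections import deque

import numpy as np


def flatten_classes(mask: np.ndarray) -> np.ndarray:
    """Step 2: merge every non-background channel (assumed to be 1..C-1) into one boolean mask."""
    return np.any(mask[..., 1:] > 0, axis=-1)


def label_components(binary: np.ndarray) -> tuple[np.ndarray, int]:
    """Label 8-connected components of a 2-D boolean mask with a breadth-first flood fill."""
    labels = np.zeros(binary.shape, dtype=np.int32)
    count = 0
    rows, cols = binary.shape
    for start in zip(*np.nonzero(binary)):
        if labels[start]:
            continue  # already reached from an earlier seed pixel
        count += 1
        labels[start] = count
        queue = deque([start])
        while queue:
            r, c = queue.popleft()
            for dr in (-1, 0, 1):
                for dc in (-1, 0, 1):
                    nr, nc = r + dr, c + dc
                    if 0 <= nr < rows and 0 <= nc < cols and binary[nr, nc] and not labels[nr, nc]:
                        labels[nr, nc] = count
                        queue.append((nr, nc))
    return labels, count


def remove_untouched_regions(original: np.ndarray, predicted: np.ndarray) -> np.ndarray:
    """Step 3: return a copy of `predicted` without components that share no pixel with `original`."""
    original_flat = flatten_classes(original)
    labels, count = label_components(flatten_classes(predicted))
    # Labels of predicted components that overlap the original mask somewhere.
    touching = set(np.unique(labels[original_flat & (labels > 0)]).tolist())
    result = predicted.copy()
    for region in range(1, count + 1):
        if region not in touching:
            result[labels == region, 1:] = 0  # background channel is left as-is
    return result


# Usage: one predicted region overlaps the original grain, one does not.
original = np.zeros((5, 5, 2), dtype=np.int32)
original[..., 0] = 1
original[1, 1, 1] = 1
predicted = np.zeros((5, 5, 2), dtype=np.int32)
predicted[..., 0] = 1
predicted[1, 1:3, 1] = 1  # touches the original pixel at (1, 1)
predicted[4, 4, 1] = 1  # isolated region, expected to be removed
cleaned = remove_untouched_regions(original, predicted)
```

Because the flood fill uses 8-connectivity, pixels touching only at a corner count as one component; switching to 4-connectivity would split such regions and could drop more of the prediction.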

And here is the code I wrote for it in the llwiggins/grain_restructure_multiclass branch:

grains.py > improve_grain_segmentation_unet

```python
if unet_config["remove_disconnected_grains"]:
    # Remove grains that are not connected to the original grain
    original_grain_mask = graincrop.mask
    predicted_mask = Grains.remove_disconnected_grains(
        original_grain_tensor=original_grain_mask,
        predicted_grain_tensor=predicted_mask,
    )
```

grains.py > Grains.remove_disconnected_grains

```python
import numpy as np
import numpy.typing as npt
from skimage.measure import label, regionprops


class Grains:
    # Other Grains methods (including flatten_multi_class_tensor) are omitted from this excerpt.

    @staticmethod
    def remove_disconnected_grains(
        original_grain_tensor: npt.NDArray,
        predicted_grain_tensor: npt.NDArray,
    ) -> npt.NDArray:
        """
        Remove grains that are not connected to the original grains.

        The predicted grain tensor is modified in place and also returned.

        Parameters
        ----------
        original_grain_tensor : npt.NDArray
            3-D NumPy array of the original grain tensor.
        predicted_grain_tensor : npt.NDArray
            3-D NumPy array of the predicted grain tensor.

        Returns
        -------
        npt.NDArray
            3-D NumPy array of the predicted grain tensor with grains not connected to the original grains removed.
        """
        # flatten the masks and compare connected components
        original_mask_flattened = Grains.flatten_multi_class_tensor(original_grain_tensor)
        predicted_mask_flattened = Grains.flatten_multi_class_tensor(predicted_grain_tensor)
        # get the connected components of both flattened masks
        original_mask_flattened_labelled = label(original_mask_flattened)
        predicted_mask_flattened_labelled = label(predicted_mask_flattened)
        # for each region of the predicted mask, check if it overlaps with any of the original mask regions
        # (the original mask is expected to only have one region, but just in case future edits don't follow
        # this assumption, I check all regions)
        predicted_mask_regions = regionprops(predicted_mask_flattened_labelled)
        original_mask_regions = regionprops(original_mask_flattened_labelled)
        # if the predicted mask region doesn't overlap with any of the original mask regions, set it to 0
        for predicted_mask_region in predicted_mask_regions:
            predicted_mask_region_mask = predicted_mask_flattened_labelled == predicted_mask_region.label
            overlap = False
            for original_mask_region in original_mask_regions:
                original_mask_region_mask = original_mask_flattened_labelled == original_mask_region.label
                if np.any(predicted_mask_region_mask & original_mask_region_mask):
                    # a region in the flattened original mask shares a pixel with the flattened predicted mask
                    overlap = True
                    break
            if not overlap:
                # zero the region in every non-background channel (channel 0, the background, is left unchanged)
                for channel in range(1, predicted_grain_tensor.shape[-1]):
                    predicted_grain_tensor[predicted_mask_region_mask, channel] = 0

        return predicted_grain_tensor
```

and a test:

test_grains.py > test_remove_disconnected_grains

```python
import numpy as np
import numpy.typing as npt
import pytest

from topostats.grains import Grains


@pytest.mark.parametrize(
    ("original_grain_tensor", "predicted_grain_tensor", "expected_result_grain_tensor"),
    [
        pytest.param(
            np.stack(
                [
                    np.array(
                        [
                            [1, 1, 1, 1, 1, 1, 1],
                            [1, 1, 1, 1, 1, 1, 1],
                            [1, 1, 1, 1, 1, 1, 1],
                            [1, 1, 1, 1, 1, 1, 1],
                            [1, 1, 1, 1, 1, 1, 1],
                            [1, 1, 1, 1, 1, 1, 1],
                            [1, 1, 1, 1, 1, 1, 1],
                        ]
                    ),
                    np.array(
                        [
                            [0, 0, 0, 0, 0, 0, 0],
                            [0, 1, 0, 0, 0, 0, 0],
                            [0, 1, 0, 0, 0, 0, 0],
                            [0, 0, 0, 0, 0, 0, 0],
                            [0, 0, 0, 0, 0, 0, 0],
                            [0, 0, 0, 0, 0, 0, 0],
                            [0, 0, 0, 0, 0, 0, 0],
                        ]
                    ),
                    np.array(
                        [
                            [0, 0, 0, 0, 0, 0, 0],
                            [0, 0, 1, 0, 0, 0, 0],
                            [0, 0, 1, 0, 0, 0, 0],
                            [0, 0, 0, 0, 0, 0, 0],
                            [0, 0, 0, 0, 0, 0, 0],
                            [0, 0, 0, 0, 0, 0, 0],
                            [0, 0, 0, 0, 0, 0, 0],
                        ]
                    ),
                ],
                axis=-1,
            ),
            np.stack(
                [
                    np.array(
                        [
                            [1, 1, 1, 1, 1, 1, 1],
                            [1, 1, 1, 1, 1, 1, 1],
                            [1, 1, 1, 1, 1, 1, 1],
                            [1, 1, 1, 1, 1, 1, 1],
                            [1, 1, 1, 1, 1, 1, 1],
                            [1, 1, 1, 1, 1, 1, 1],
                            [1, 1, 1, 1, 1, 1, 1],
                        ]
                    ),
                    np.array(
                        [
                            [0, 0, 0, 0, 0, 0, 0],
                            [0, 1, 1, 0, 0, 0, 0],
                            [0, 0, 1, 0, 0, 0, 0],
                            [0, 0, 0, 0, 0, 0, 0],
                            [0, 0, 0, 0, 0, 1, 0],
                            [0, 0, 0, 0, 0, 1, 0],
                            [0, 0, 0, 0, 0, 0, 0],
                        ]
                    ),
                    np.array(
                        [
                            [0, 0, 0, 0, 0, 0, 0],
                            [0, 0, 0, 0, 0, 0, 0],
                            [0, 1, 0, 0, 0, 0, 0],
                            [0, 0, 0, 0, 0, 0, 0],
                            [0, 0, 0, 0, 1, 0, 0],
                            [0, 0, 0, 0, 1, 0, 0],
                            [0, 0, 0, 0, 0, 0, 0],
                        ]
                    ),
                ],
                axis=-1,
            ),
            np.stack(
                [
                    np.array(
                        [
                            [1, 1, 1, 1, 1, 1, 1],
                            [1, 1, 1, 1, 1, 1, 1],
                            [1, 1, 1, 1, 1, 1, 1],
                            [1, 1, 1, 1, 1, 1, 1],
                            [1, 1, 1, 1, 1, 1, 1],
                            [1, 1, 1, 1, 1, 1, 1],
                            [1, 1, 1, 1, 1, 1, 1],
                        ]
                    ),
                    np.array(
                        [
                            [0, 0, 0, 0, 0, 0, 0],
                            [0, 1, 1, 0, 0, 0, 0],
                            [0, 0, 1, 0, 0, 0, 0],
                            [0, 0, 0, 0, 0, 0, 0],
                            [0, 0, 0, 0, 0, 0, 0],
                            [0, 0, 0, 0, 0, 0, 0],
                            [0, 0, 0, 0, 0, 0, 0],
                        ]
                    ),
                    np.array(
                        [
                            [0, 0, 0, 0, 0, 0, 0],
                            [0, 0, 0, 0, 0, 0, 0],
                            [0, 1, 0, 0, 0, 0, 0],
                            [0, 0, 0, 0, 0, 0, 0],
                            [0, 0, 0, 0, 0, 0, 0],
                            [0, 0, 0, 0, 0, 0, 0],
                            [0, 0, 0, 0, 0, 0, 0],
                        ]
                    ),
                ],
                axis=-1,
            ),
        )
    ],
)
def test_remove_disconnected_grains(
    original_grain_tensor: npt.NDArray[np.int32],
    predicted_grain_tensor: npt.NDArray[np.int32],
    expected_result_grain_tensor: npt.NDArray[np.int32],
) -> None:
    """Test the remove_disconnected_grains method of the Grains class."""
    result_grain_tensor = Grains.remove_disconnected_grains(
        original_grain_tensor, predicted_grain_tensor
    )

    np.testing.assert_array_equal(result_grain_tensor, expected_result_grain_tensor)
```
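The nested loop in `remove_disconnected_grains` rebuilds a full-image boolean mask for every pair of predicted and original regions. The same keep-or-drop decision can likely be made in a single vectorised pass: collect the labels found under the original mask with `np.unique`, then drop every other label with `np.isin`. This is a minimal sketch, assuming `predicted_labelled` is the output of `label()` on the flattened predicted mask and that channel 0 is background; the name `drop_non_overlapping_labels` is hypothetical, and unlike the branch code it returns a copy rather than modifying the predicted tensor in place.

```python
import numpy as np


def drop_non_overlapping_labels(
    predicted_labelled: np.ndarray,
    original_flat: np.ndarray,
    predicted_tensor: np.ndarray,
) -> np.ndarray:
    """Zero non-background channels of every labelled component with no pixel inside `original_flat`."""
    foreground = predicted_labelled > 0
    # Labels of all predicted components that share at least one pixel with the original mask.
    touching_labels = np.unique(predicted_labelled[foreground & original_flat.astype(bool)])
    # Foreground pixels whose component label is not in that set.
    to_remove = foreground & ~np.isin(predicted_labelled, touching_labels)
    result = predicted_tensor.copy()
    result[to_remove, 1:] = 0  # channel 0 (background) is left unchanged, as in the branch code
    return result


# Usage with a hand-written labelling (in practice this would come from skimage.measure.label).
labelled = np.array([[1, 1, 0], [0, 0, 0], [0, 0, 2]])
original_flat = np.array([[1, 0, 0], [0, 0, 0], [0, 0, 0]])
tensor = np.stack([np.ones((3, 3), dtype=int), (labelled > 0).astype(int)], axis=-1)
cleaned = drop_non_overlapping_labels(labelled, original_flat, tensor)
```

Here component 1 touches the original pixel at `(0, 0)` and is kept, while component 2 is zeroed. The work is roughly a few passes over the image, rather than one full-image comparison per pair of regions.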

Describe the alternatives you have considered.

N/A

Additional context

No response

SylviaWhittle added the enhancement label on Mar 14, 2025.
SylviaWhittle linked a pull request that will close this issue on Mar 14, 2025.