Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Flexible coordinate transform #9543

Merged
merged 18 commits into from
Feb 14, 2025
Merged

Conversation

benbovy
Copy link
Member

@benbovy benbovy commented Sep 24, 2024

This PR is a first step towards adding generic support for coordinate transforms in Xarray (i.e., analytical coordinates, functional index, etc.), which has been discussed already in different issues or threads:

(I might miss other issues / discussions)

I started with a few rough experimentations but ended up with something more concrete that seems to work reasonably well, hence directly opening a (draft) PR. The design & implementation detailled below is still very much open to discussion, though! There's an usage example further below using a 2D affine transformation. It would be nice to test this with other examples.

cc @rabernat @dcherian @TomNicholas @martindurant

Design / Implementation

This PR adds three new classes that should facilitate integrating any coordinate transform into Xarray:

CoordinateTransform

Abstract (wrapper) class to handle coordinate transformation with support of dimension and coordinate names.

  • many transforms should be pluggable via this class (by subclassing it)
  • supports bulk (vectorized) transformation, both in forward and reverse direction
  • supports any arbitrary number of coordinates / dimensions
    • lon, lat = f(x, y)
    • lon, lat, time = f2(x, y, t)
    • etc.
  • one restriction is that the coordinates of a same transform must all have the same dimensions
    • lon(x,y) / lat(x,y)
    • lon(x,y,t) / lat(x,y,t) / time(x,y,t)
    • etc.
    • much simpler!
    • In some cases however, the transform parameter values are such that it can be applied independently over each dimension (e.g., rioxarray creates x(x) / y(y) dimension coordinates when the affine transform is rectilinear with no rotation). For those cases we cannot use a single CoordinateTransform instance, but it is still possible to wrap the same underlying transform object in several instances and link their respective coordinates at the xarray Index level (see below).

CoordinateTransformIndexingAdapter

Internal class for creating indexable coordinate variables from a transform (no need to change the Xarray data model!).

  • wraps a CoordinateTransform instance
  • coordinate labels are computed on-demand (lazy coordinates)
  • supports both (explicit) orthogonal and vectorized indexing
  • supports dimension re-ordering (transpose)
  • doesn't support item assignment (of course)

CoordinateTransformIndex

Helper class for creating Xarray (custom) indexes based on coordinate transforms.

  • wraps a CoordinateTransform instance
  • takes care of creating the index (lazy) coordinates
  • supports label-based selection (i.e., using "physical" or "world" labels)
    • only advanced (point-wise) indexing for now
    • any idea on what else would be nice here and how to implement it?
  • supports alignment by comparing indexes based on their transform (not on their explicit coordinate labels)
    • only exact alignment for now (no join)
  • may be used directly, although should mostly be either subclassed or encapsulated in another Xarray Index class
    • in the rioxarray example (see above), we might want to encapsulate two instances of CoordinateTransformIndex into a custom RasterIndex for the x and y coordinates respectively
    • a custom Xarray index is the right place for encoding / decoding coordinate definition

Usage Example (Affine 2D)

CoordinateTransform subclass

Let's write a subclass of CoordinateTransform that handles 2-d affine transformation (using affine). It is basically boilerplate code that takes care of dimension or coordinate names around the unlabelled input/output arrays of the underlying affine.Affine object.

import affine
import xarray as xr


class Affine2DCoordinateTransform(xr.CoordinateTransform):
    """Affine 2D coordinate transform."""

    affine: affine.Affine
    xy_dims = tuple[str]
    
    def __init__(
        self,
        affine: affine.Affine,
        coord_names: Iterable[Hashable],
        dim_size: Mapping[str, int],
        dtype: Any = np.dtype(np.float64),
    ):
        # two dimensions
        assert len(coord_names) == 2
        assert len(dim_size) == 2

        super().__init__(coord_names, dim_size, dtype=dtype)
        self.affine = affine

        # array dimensions in reverse order (y = rows, x = cols)
        self.xy_dims = tuple(self.dims)
        self.dims = (self.dims[1], self.dims[0])

    def forward(self, dim_positions):
        positions = [dim_positions[dim] for dim in self.xy_dims]
        x_labels, y_labels = self.affine * tuple(positions)

        results = {}
        for name, labels in zip(self.coord_names, [x_labels, y_labels]):
            results[name] = labels

        return results

    def reverse(self, coord_labels):
        labels = [coord_labels[name] for name in self.coord_names]
        x_positions, y_positions = ~self.affine * tuple(labels)

        results = {}
        for dim, positions in zip(self.xy_dims, [x_positions, y_positions]):
            results[dim] = positions

        return results
    
    def equals(self, other):
        return self.affine == other.affine and self.dim_size == other.dim_size

Dataset, coordinates and index creation

In this example the index and the lazy coordinates are created from scratch, no pre-existing (explicit) coordinates are required!

from xarray.indexes import CoordinateTransformIndex

transform = Affine2DCoordinateTransform(
    affine.Affine.scale(1.0, 2.0),
    coord_names=("xc", "yc"),
    dim_size={"x": 10_000, "y": 20_000},
)

index = CoordinateTransformIndex(transform)
ds = xr.Dataset(coords=xr.Coordinates.from_xindex(index))

The resulting Dataset:

>>> ds
<xarray.Dataset> Size: 3GB
Dimensions:  (y: 10000, x: 20000)
Coordinates:
  * xc       (y, x) float64 2GB 0.0 1.0 2.0 3.0 4.0 ... 2e+04 2e+04 2e+04 2e+04
  * yc       (y, x) float64 2GB 0.0 2.0 4.0 6.0 ... 1.999e+04 2e+04 2e+04
Dimensions without coordinates: y, x
Data variables:
    *empty*
Indexes:
  ┌ xc       CoordinateTransformIndexyc

Coordinates "xc" and "yc" are big but they are lazy!

>>> ds.xc
<xarray.DataArray 'xc' (y: 10000, x: 20000)> Size: 2GB
[200000000 values with dtype=float64]
Coordinates:
  * xc       (y, x) float64 2GB 0.0 1.0 2.0 3.0 4.0 ... 2e+04 2e+04 2e+04 2e+04
  * yc       (y, x) float64 2GB 0.0 2.0 4.0 6.0 ... 1.999e+04 2e+04 2e+04
Dimensions without coordinates: y, x
Indexes:
  ┌ xc       CoordinateTransformIndexyc

>>> ds["xc"].variable._data
CoordinateTransformIndexingAdapter(transform=<__main__.Affine2DCoordinateTransform object at 0x15fb63790>)

Indexing

Orthogonal indexing (it is fast, it only computes 2x6 coordinate values below):

>>> ds.yc.isel(y=[0, 1, 3], x=slice(0, 2))
<xarray.DataArray 'yc' (y: 3, x: 2)> Size: 48B
array([[0., 0.],
       [2., 2.],
       [6., 6.]])
Coordinates:
    xc       (y, x) float64 48B 0.0 1.0 0.0 1.0 0.0 1.0
    yc       (y, x) float64 48B 0.0 0.0 2.0 2.0 6.0 6.0
Dimensions without coordinates: y, x

Also works after re-ordering the dimensions:

>>> ds.transpose().yc.isel(y=[0, 1, 3], x=slice(0, 2))
<xarray.DataArray 'yc' (x: 2, y: 3)> Size: 48B
array([[0., 2., 6.],
       [0., 2., 6.]])
Coordinates:
    xc       (x, y) float64 48B 0.0 0.0 0.0 1.0 1.0 1.0
    yc       (x, y) float64 48B 0.0 2.0 6.0 0.0 2.0 6.0
Dimensions without coordinates: x, y

Vectorized indexing:

>>> ds.yc.isel(
...     y=xr.Variable("points", [0, 1, 3]),
...     x=xr.Variable("points", [0, 1, 3]),
... )
<xarray.DataArray 'yc' (points: 3)> Size: 24B
array([0., 2., 6.])
Coordinates:
    xc       (points) float64 24B 0.0 1.0 3.0
    yc       (points) float64 24B 0.0 2.0 6.0
Dimensions without coordinates: points

Label-based selection

Point-wise selection:

>>> ds.sel(
...     xc=xr.Variable("points", [101.34, 545.23, 876.76]),
...     yc=xr.Variable("points", [13.12, 54.98, 76.43]),
...     method="nearest",
... )
<xarray.Dataset> Size: 48B
Dimensions:  (points: 3)
Coordinates:
    xc       (points) float64 24B 101.0 545.0 877.0
    yc       (points) float64 24B 14.0 54.0 76.0
Dimensions without coordinates: points
Data variables:
    *empty*

What's next?

A few potential improvements from here:

  • allow returning or re-calculating the transform instead of computing the coordinate labels while indexing
    • when possible, this should keep the xarray coordinates lazy and should also preserve their index
    • the obvious case if when a full slice is given for each dimension... Are there other less obvious cases?
    • allow CoordinateTransform.forward() to return a new instance of CoordinateTransform?
  • allow special handling of dimension reduction (e.g., return another transform in the reduced space)
  • possible to add generic support for joining / concatenating coordinate transforms? I.e., implement CoordinateTransformIndex.concat and CoordinateTransformIndex.join
  • handle chunking
    • maybe best solved at a higher level? I.e., one transform instance per chunk
  • add some convenient API for setting a new Xarray index from existing dimensions in a Dataset or DataArray?

@mdsumner
Copy link

Nice!! Thanks, I'm having fun with this - appreciate all the detail and functionality here it really helps a (non-native) Python learner.

@mdsumner
Copy link

mdsumner commented Sep 25, 2024

One thing is that the coordinate values are currently "left"/"top" aligned, not the centre, so here we start at left/top 0,0 and end at 4,4.

from xarray.indexes import CoordinateTransformIndex

transform = Affine2DCoordinateTransform(
    affine.Affine.scale(1.0, 1.0),
    coord_names=("xc", "yc"),
    dim_size={"x": 5, "y": 5},
)

index = CoordinateTransformIndex(transform)
ds = xr.Dataset(coords=index.create_coordinates())

ds.yc.values
#array([[0., 0., 0., 0., 0.],
#       [1., 1., 1., 1., 1.],
#       [2., 2., 2., 2., 2.],
#       [3., 3., 3., 3., 3.],
#       [4., 4., 4., 4., 4.]])

I'm assuming that a pure-scale puts us in the cell-area context of 0, 0, shape[0], shape[1] and so I pursued a world-realistic-ish context to convince myself.

I prefer to think in shape+bbox than in transforms when no shear is needed, so I'm using the gdal transform as an intermediate with a helper fun:

https://gist.github.com/mdsumner/dde0b611a4523e3485006c0df0143c2d

(fwiw, I'm sure this is obvious and not exactly priority rn but I'm excited to be able to delve into this and flesh out how I think about it in this context)

edit: I appreciate there's no absolutely right answer here, you might want (even decoupled per dimension) different alignment for your lazy coords in different contexts.

@benbovy
Copy link
Member Author

benbovy commented Sep 25, 2024

Thanks for the feedback @mdsumner.

Yes the idea is to have something very generic in Xarray such that we can build domain-specific applications on top of it. It is very useful to test this functionality in various contexts now to make sure we are providing the right levels of abstractions. So please keep having fun with this :-) !

Regarding your example, I think that rioxarray combines the input affine transformation with affine.Affine.translation(0.5, 0.5) to make coordinate values center aligned. If it is more natural to think in shape+bbox than in transforms in the geo domain, let's build something on top of CoordinateTransformIndex, e.g., something like below adapted from your gist example:

class GeoIndex(CoordinateTransformIndex):

    @classmethod
    def from_shape(cls, shape, bbox=None, center=True):
        if bbox is None: 
            bbox = (0.0, 0.0, shape[0], shape[1])
    
        gdal = (
            bbox[0], (bbox[2] - bbox[0]) / shape[0], 0.0, 
            bbox[3], 0.0, (bbox[1] - bbox[3]) / shape[1]
        )

        aff = affine.Affine.from_gdal(*gdal)

        if center:
            coord_names = ("xc", "yc")
            aff *= affine.Affine.translation(0.5, 0.5)
        else:
            # left/top
            coord_names = ("xlt", "ylt")
            
        transform = Affine2DCoordinateTransform(
            aff,
            coord_names=coord_names,
            dim_size={"x": shape[1], "y": shape[0]},
        )
            
        return cls(transform=transform)
>>> bbox = (-3950000, -3950000, 3950000, 4350000)
>>> shape = (316, 332)
>>> index = GeoIndex.from_shape(shape, bbox=bbox)
>>> ds = xr.Dataset(coords=xr.Coordinates.from_xindex(index))
>>> ds.isel(x=slice(0, 159), y=slice(0, 167))
<xarray.Dataset> Size: 425kB
Dimensions:  (y: 167, x: 159)
Coordinates:
    xc       (y, x) float64 212kB -3.938e+06 -3.912e+06 ... -1.25e+04 1.25e+04
    yc       (y, x) float64 212kB 4.338e+06 4.338e+06 ... 1.875e+05 1.875e+05
Dimensions without coordinates: y, x
Data variables:
    *empty*

(This gives the same coordinate values than the "ice" dataset loaded from the .tif file using the rasterio engine in your 2nd gist)

Comment on lines +1479 to +1482
# TODO: rounding the decimal positions is not always the behavior we expect
# (there are different ways to represent implicit intervals)
# we should probably make this customizable.
pos = np.round(pos).astype("int")
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is important I think.

If the coordinates values correspond to the physical values at the top/left pixel corners in the 2D case, we may rather want np.floor(pos).astype("int") when converting decimal positions (obtained by inverse transformation) to integer indexers.

@martindurant
Copy link
Contributor

Great to see this. I haven't looked at the implementation yet, but I think I agree with the description whole heartedly.

It would be the place of the various IO backends to instantiate the affine (or whatever) transform from the the metadata standards of the respective formats.

@benbovy
Copy link
Member Author

benbovy commented Sep 25, 2024

For completeness, here is an implementation of the 1-dimensional "range index" discussed in #8955.

The coordinate transform subclass:

class Range1DCoordinateTransform(xr.CoordinateTransform):
    """Simple bounded interval 1-d coordinate transform."""

    left: float
    right: float
    dim: str
    size: int

    def __init__(
        self,
        left: float,
        right: float,
        coord_name: Hashable,
        dim: str,
        size: int,
        dtype: Any = None,
    ):  
        if dtype is None:
            dtype = np.dtype(np.float64)

        super().__init__([coord_name], {dim: size}, dtype=dtype)

        self.left = left
        self.right = right
        self.dim = dim
        self.size = size

    def forward(self, dim_positions):
        positions = dim_positions[self.dim]
        labels = self.left + positions * (self.right - self.left) / self.size
        return {self.dim: labels}
        
    def reverse(self, coord_labels):
        labels = coord_labels[self.coord_names[0]]
        positions = (labels - self.left) * self.size / (self.right - self.left)
        return {self.dim: positions}

    def equals(self, other):
        return (
            self.left == other.left
            and self.right == other.right
            and self.size == other.size
        )

Dataset creation:

>>> range_tr = Range1DCoordinateTransform(1.0, 2.0, "x", "x", 100)
>>> index = CoordinateTransformIndex(range_tr)
>>> ds = xr.Dataset(data_vars={"foo": ("x", np.arange(100))}, coords=xr.Coordinates.from_xindex(index))
>>> ds
<xarray.Dataset> Size: 2kB
Dimensions:  (x: 100)
Coordinates:
  * x        (x) float64 800B 1.0 1.01 1.02 1.03 1.04 ... 1.96 1.97 1.98 1.99
Data variables:
    foo      (x) int64 800B 0 1 2 3 4 5 6 7 8 9 ... 91 92 93 94 95 96 97 98 99
Indexes:
    x        CoordinateTransformIndex

This example is interesting because in this simple case we would expect a few more operations to work than in the case of more complex transformations such as 2D affine with rotation and/or shear, e.g.,

  • indexing with a slice (step=1) should preserve the coordinate index but it doesn't:
>>> ds.isel(x=slice(5, 10)).xindexes
Indexes:
    *empty*
  • basic label-based selection should also work, but it is not supported:
>>> ds.sel(x=1.65, method="nearest")
TypeError: CoordinateTransformIndex only supports advanced (point-wise) indexing with either xarray.DataArray or xarray.Variable objects.

Perhaps we could try adding support for this in CoordinateTransform and/or CoordinateTransformIndex? My concern is that we may end up cluttering the interface / implementation of those classes with many special cases.

An alternative option is building on top of it, e.g., in this case also provide a Range1DIndex class like so:

---- expand here to see the implementation of Range1DIndex ----
from xarray.core.indexes import IndexSelResult


class Range1DIndex(CoordinateTransformIndex):

    transform: Range1DCoordinateTransform
    dim: str
    coord_name: Hashable
    size: int

    def __init__(
        self,
        left: float,
        right: float,
        coord_name: Hashable,
        dim: str,
        size: int,
        dtype: Any = None,
    ):
        self.transform = Range1DCoordinateTransform(
            left, right, coord_name, dim, size, dtype
        )
        self.dim = dim
        self.coord_name = coord_name
        self.size = size

    def isel(self, indexers):
        idxer = indexers[self.dim]

        # straightforward to generate a new index if a slice is given with step 1
        if isinstance(idxer, slice) and (idxer.step == 1 or idxer.step is None):
            start = max(idxer.start, 0)
            stop = min(idxer.stop, self.size)
            
            new_left = self.transform.forward({self.dim: start})[self.coord_name]
            new_right = self.transform.forward({self.dim: stop})[self.coord_name]
            new_size = stop - start

            return Range1DIndex(new_left, new_right, self.coord_name, self.dim, new_size)

        return None

    def sel(self, labels, method=None, tolerance=None):
        label = labels[self.dim]

        if isinstance(label, slice):
            if label.step is None:
                # slice indexing (preserve the index)
                pos = self.transform.reverse({self.dim: np.array([label.start, label.stop])})
                pos = np.round(pos[self.coord_name]).astype("int")
                new_start = max(pos[0], 0)
                new_stop = min(pos[1], self.size)
                return IndexSelResult({self.dim: slice(new_start, new_stop)})
            else:
                # otherwise convert to basic (array) indexing
                label = np.arange(label.start, label.stop, label.step)

        # support basic indexing (in the 1D case basic vs. vectorized indexing
        # are pretty much similar)
        unwrap_xr = False
        if not isinstance(label, xr.Variable | xr.DataArray):
            # basic indexing -> either scalar or 1-d array
            try:
                var = xr.Variable("_", label)
            except ValueError:
                var = xr.Variable((), label)
            labels = {self.dim: var}
            unwrap_xr = True

        result = super().sel(labels, method=method, tolerance=tolerance)

        if unwrap_xr:
            dim_indexers = {self.dim: result.dim_indexers[self.dim].values}
            result = IndexSelResult(dim_indexers)
        
        return result
>>> index = Range1DIndex(1.0, 2.0, "x", "x", 100)
>>> ds2 = xr.Dataset(data_vars={"foo": ("x", np.arange(100))}, coords=xr.Coordinates.from_xindex(index))

Slicing (notice the preserved Range1DIndex):

>>> ds2.isel(x=slice(5, 10))
<xarray.Dataset> Size: 80B
Dimensions:  (x: 5)
Coordinates:
  * x        (x) float64 40B 1.05 1.06 1.07 1.08 1.09
Data variables:
    foo      (x) int64 40B 5 6 7 8 9
Indexes:
    x        Range1DIndex

Some basic label-based selection:

>>> ds2.sel(x=1.654, method="nearest")
<xarray.Dataset> Size: 16B
Dimensions:  ()
Coordinates:
    x        float64 8B 1.65
Data variables:
    foo      int64 8B 65

>>> ds2.sel(x=slice(1.465, 1.874), method="nearest")   # preserves the index!
<xarray.Dataset> Size: 640B
Dimensions:  (x: 40)
Coordinates:
  * x        (x) float64 320B 1.47 1.48 1.49 1.5 1.51 ... 1.83 1.84 1.85 1.86
Data variables:
    foo      (x) int64 320B 47 48 49 50 51 52 53 54 ... 79 80 81 82 83 84 85 86
Indexes:
    x        Range1DIndex

For such a simple 1-d range example, the coordinate transform abstraction is actually a bit overkill but still has the advantage of providing the lazy coordinate variable "for free".

More consistent with the rest of Xarray API where `coords` is used
everywhere.
@astrofrog
Copy link

@benbovy - @Cadair and I have been playing around with trying to get this to work with the astropy APE 14 WCS specification. Here is a minimal example:

https://gist.github.com/Cadair/4a03750868e044ac4bdd6f3a04ed7abc

We are running into a bug in the __repr__ which is causing an out of bounds error. It seems that accessing the coordinates directly works so it seems to be a problem specific to the __repr__?

Another unrelated comment: it would be nice to have the CoordinateTransform class be a proper abc class, and have the methods that need to be implemented be defined as abstract methods (e.g. forward and reverse)

@benbovy
Copy link
Member Author

benbovy commented Oct 2, 2024

Thanks for the feedback @astrofrog!

I'll look into the __repr__ issue. Could you provide a minimal reproducible example or a link where I can download the FITS file used in your example, please?

@Cadair
Copy link

Cadair commented Oct 2, 2024

@benbovy the fits file is included with astropy , so the code in the notebook should run as-is I believe.

@benbovy
Copy link
Member Author

benbovy commented Oct 2, 2024

Ah thanks. The __repr__ issue should now be fixed in 09667c5.

return None

def sel(
self, labels: dict[Any, Any], method=None, tolerance=None
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How hard would it be to support tolerance in some form? This is a common and useful form of error checking.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pretty tricky to support it here I think, probably better to handle it on a per case basis.

For basic transformations I guess it could be possible to calculate a single, uniform tolerance value in decimal array index units and validate the selected elements using those units (cheap). In other cases we would need to compute the forward transformation of the extracted array indices and then validate the selected elements based on distances in physical units (more expensive).

Also, there may be cases where the coordinates of a same transform object don’t have all the same physical units (e.g., both degrees and radians coordinates in an Astropy WCS object). Unless we forbid that in xarray.CoordinateTransform, it doesn’t make much sense to pass a single tolerance value. Passing a dictionary tolerance={coord_name: value} doesn’t look very nice either IMO. A {unit: value} dict looks better but adding explicit support for units here might be opening a can of worms.

@shoyer
Copy link
Member

shoyer commented Oct 2, 2024

This very exciting! Nice work.

For indexing, it may be worth considering if you can implement .interp(). In practice I think that is often more desirable than nearest neighbor lookup.

@rbavery
Copy link

rbavery commented Oct 21, 2024

rioxarray creates x(x) / y(y) dimension coordinates when the affine transform is rectilinear with no rotation). For those cases we cannot use a single CoordinateTransform instance

I think rioxarray does this for performance reasons because it is faster and possible to correctly apply the affine transformation without calling numpy.meshgrid when the affine is rectilinear with no rotation. But with flexible coordinates, I think both approaches could be replaced with only a single CoordinateTransform with some refactoring.

import numpy
import affine

transform = affine.Affine.translation(.5,.5).scale(1.0)
width = 512
height = 200
%%time
x_coords, _ = transform * (numpy.arange(width), numpy.zeros(width))
_, y_coords = transform * (numpy.zeros(height), numpy.arange(height))

Wall time: 290 μs

%%time
x_coords_mesh, y_coords_mesh = transform * numpy.meshgrid(
    numpy.arange(width),
    numpy.arange(height),
)

Wall time: 3.09 ms

possible to add generic support for joining / concatenating coordinate transforms? I.e., implement CoordinateTransformIndex.concat and CoordinateTransformIndex.join

This sounds very valuable but want to make sure I understand what is meant. If I have a dataset of rasters across different UTM projections, would this allow me to read each with rioxarray and then concatenate the raster arrays such that each raster maintains it's original CRS? Or would this enable concatenating rasters that are already in the same CRS? Or something else?

My use case for this is I'd like to avoid reprojection and have a single xarray.DataArray representing rasters spread over global extents. And I'd like to be able to save this concatenated xarray DataArray to a Zarr v3 store with sharding in a way that preserves each CRS, with GeoZarr.

@benbovy
Copy link
Member Author

benbovy commented Oct 22, 2024

both approaches could be replaced with only a single CoordinateTransform with some refactoring.

Hmm do you have an idea on how this refactoring would look like? I've tried implementing a version of CoordinateTransform that supports coordinates with different dimensions but I eventually gave up because it was too complicated.

Here is one way to support the rectilinear / no rotation affine transform with independent x, y 1-dimensional coordinates without any refactoring:

  • a CoordinateTransform subclass that wraps an affine.Affine instance for either the x or y coordinate
---- expand here to see the implementation of AxisAffineCoordinateTransform ----
class AxisAffineCoordinateTransform(xr.CoordinateTransform):
    """1-axis wrapper of an affine 2D coordinate transform
    with no skew/rotation.
    
    """

    affine: affine.Affine
    is_xaxis: bool
    coord_name: Hashable
    dim: str
    size: int
    
    def __init__(
        self,
        affine: affine.Affine,
        coord_name: Hashable,
        dim: str,
        size: int,
        is_xaxis: bool,
        dtype: Any = np.dtype(np.float64),
    ):
        if (not affine.is_rectilinear or (affine.b == affine.d != 0)):
            raise ValueError("affine must be rectilinear with no rotation")

        super().__init__((coord_name,), {dim: size}, dtype=dtype)
        self.affine = affine
        self.is_xaxis = is_xaxis
        self.coord_name = coord_name
        self.dim = dim
        self.size = size

    def forward(self, dim_positions):
        positions = dim_positions[self.dim]

        if self.is_xaxis:
            labels, _ = self.affine * (positions, np.zeros_like(positions))
        else:
            _, labels = self.affine * (np.zeros_like(positions), positions)

        return {self.coord_name: labels}

    def reverse(self, coord_labels):
        labels = coord_labels[self.coord_name]

        if self.is_xaxis:
            positions, _ = ~self.affine * (labels, np.zeros_like(labels))
        else:
            _, positions = ~self.affine * (np.zeros_like(labels), labels)

        return {self.dim: positions}
    
    def equals(self, other):
        return self.affine == other.affine and self.dim_size == other.dim_size
  • an Xarray Index that encapsulates two CoordinateTransformIndex instances (sharing the same Affine object) for the x and y axis respectively
---- expand here to see the implementation of RasterIndex ----
from xarray import Variable
from xarray.indexes import CoordinateTransformIndex
from xarray.core.indexing import IndexSelResult, merge_sel_results


class RasterIndex(xr.indexes.Index):

    def __init__(
        self,
        x_index: CoordinateTransformIndex,
        y_index: CoordinateTransformIndex,
    ):
        self.x_index = x_index
        self.y_index = y_index

    @classmethod
    def from_transform(
        cls,
        affine: affine.Affine,
        shape: tuple[int, int],
        xy_coord_names: tuple[Hashable, Hashable] = ("x", "y"),
    ):
        # shape is in y, x order
        xtr = AxisAffineCoordinateTransform(
            affine, xy_coord_names[0], xy_coord_names[0], shape[1], is_xaxis=True
        )
        ytr = AxisAffineCoordinateTransform(
            affine, xy_coord_names[1], xy_coord_names[1], shape[0], is_xaxis=False
        )

        return cls(CoordinateTransformIndex(xtr), CoordinateTransformIndex(ytr))

    def create_variables(
        self, variables: Mapping[Any, Variable] | None = None
    ) -> dict[Hashable, Variable]:
        return {**self.x_index.create_variables(), **self.y_index.create_variables()}

    def create_coords(self) -> xr.Coordinates:
        variables = self.create_variables()
        indexes = {name: self for name in variables}
        return xr.Coordinates(coords=variables, indexes=indexes)
    
    def sel(
        self, labels: dict[Any, Any], method=None, tolerance=None
    ) -> IndexSelResult:
        results = []

        xlabels = {k: v for k, v in labels if k in self.x_index.transform.coord_names}
        if xlabels:
            results.append(self.x_index.sel(xlabels))
        
        ylabels = {k: v for k, v in labels if k in self.y_index.transform.coord_names}
        if ylabels:
            results.append(self.y_index.sel(ylabels))
        
        return merge_sel_results(results)
       
     def equals(self, other: Self) -> bool:
        return self.x_index.equals(other.x_index) and self.y_index.equals(other.y_index)

Usage example:

>>> index = RasterIndex.from_transform(affine.Affine.translation(0.5, 0.5), (1000, 2000))
>>> ds = xr.Dataset(coords=xr.Coordinates.from_xindex(index))
>>> ds
<xarray.Dataset> Size: 24kB
Dimensions:  (x: 2000, y: 1000)
Coordinates:
  * x        (x) float64 16kB 0.5 1.5 2.5 3.5 ... 1.998e+03 1.998e+03 2e+03
  * y        (y) float64 8kB 0.5 1.5 2.5 3.5 4.5 ... 996.5 997.5 998.5 999.5
Data variables:
    *empty*
Indexes:
  ┌ x        RasterIndexy

>>> ds.isel(x=slice(100, 200), y=500)
<xarray.Dataset> Size: 808B
Dimensions:  (x: 100)
Coordinates:
    x        (x) float64 800B 100.5 101.5 102.5 103.5 ... 197.5 198.5 199.5
    y        float64 8B 500.5
Data variables:
    *empty*

@benbovy
Copy link
Member Author

benbovy commented Oct 22, 2024

My use case for this is I'd like to avoid reprojection and have a single xarray.DataArray representing rasters spread over global extents. And I'd like to be able to save this concatenated xarray DataArray to a Zarr v3 store with sharding in a way that preserves each CRS, with GeoZarr.

This seems complicated to me. It would be easier in this case to have a unique CRS per DataArray and provide a virtual layer built on top of DataArray to handle lazy reprojection (e.g., similarly to GDAL VRT I guess?).

Also, IIUC Zarr doesn't allow per-chunk or per-shard metadata so it isn't clear to me how GeoZarr would support multiple CRS datasets (zarr-developers/geozarr-spec#4).

@martindurant
Copy link
Contributor

Zarr doesn't allow per-chunk or per-shard metadata

perhaps virtualizarr can do this (@TomNicholas )

@shoyer
Copy link
Member

shoyer commented Jan 29, 2025

Hi @benbovy -- this would be a really amazing feature to land! Are you still planning to work on this, blocked by something, or could someone else try to take it across the line?

@benbovy
Copy link
Member Author

benbovy commented Jan 29, 2025

Hi @shoyer -- Yes! I started looking again at this earlier today actually, I opened #10000 as a side PR.

Apart from unit tests I don't think there is more to add in this PR, which provides all the basic functionality for some concrete use cases (as shown by the examples in the comments here). So hopefully this should be ready sooner than later!

More functionality can be implemented in follow-up PRs, as well as documentation (after playing a bit more with this experimental feature).

@TomNicholas
Copy link
Member

I opened #10000 as a side PR.

Issue number 10,000!

More functionality can be implemented in follow-up PRs

I support merging this as soon as possible, because it enables so much.

@dcherian
Copy link
Contributor

@maxrjones and I would like to play around with this. Can we remove the top-level import and merge it? We can add the top-level import once it's ready.

@benbovy
Copy link
Member Author

benbovy commented Feb 12, 2025

@dcherian Yes let's do this (unless we want to avoid breaking changes in the next related PRs... although I'm already pretty happy with the API here). I'll add unit tests in a follow-up PR.

@keewis
Copy link
Collaborator

keewis commented Feb 12, 2025

we want to avoid breaking changes in the next related PRs

I think the idea was to merge but not expose as public API for now, so that we don't need to worry about breaking changes

@benbovy
Copy link
Member Author

benbovy commented Feb 12, 2025

Hmm maybe we should fix CI before merging it, though? I can have a look tomorrow.

@dcherian
Copy link
Contributor

yes exactly, just fix mypy, merge and don't tell anyone about it till we're ready haha. Let's clearly add an experimental API warning to the docstring too.

Copy link
Contributor

@maxrjones maxrjones left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for this PR - super excited to try it out!

These suggested changes make mypy happy again.

@benbovy benbovy marked this pull request as ready for review February 13, 2025 13:27
@benbovy
Copy link
Member Author

benbovy commented Feb 13, 2025

I added the tests here (mostly complete for the features added). This should be ready for review (or for merging :-)).

@benbovy benbovy added the plan to merge Final call for comments label Feb 14, 2025
@dcherian
Copy link
Contributor

I took a quick look through the tests. Thanks @benbovy! This is a large leap forward.

@dcherian dcherian merged commit 4bbab48 into pydata:main Feb 14, 2025
35 checks passed
benbovy added a commit to benbovy/rioxarray that referenced this pull request Feb 19, 2025
Based on coordinate transform examples copied and adapted from
pydata/xarray#9543.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
Status: Done
Development

Successfully merging this pull request may close these issues.