From ca652c983ab593297377ffcf96d81e7ccf7429cd Mon Sep 17 00:00:00 2001 From: lanougue Date: Wed, 16 Jun 2021 00:12:09 +0200 Subject: [PATCH] try --- xrft/xrft.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/xrft/xrft.py b/xrft/xrft.py index bc987fdf..0d88ebb9 100644 --- a/xrft/xrft.py +++ b/xrft/xrft.py @@ -450,6 +450,7 @@ def dft( dims=up_dim, coords={up_dim: newcoords[up_dim]}, ) # taking advantage of xarray broadcasting and ordered coordinates + daft[up_dim].attrs.update({"direct_lag": lag.obj}) if true_amplitude: daft = daft * np.prod(delta_x) @@ -539,8 +540,9 @@ def idft( dim = [d for d in dim if d != real_dim] + [ real_dim ] # real dim has to be moved or added at the end ! - - if lag is not None: + if lag is None: + lag = [daft[d].attrs.get("direct_lag", 0.0) for d in dim] + else: if isinstance(lag, float) or isinstance(lag, int): lag = [lag] if len(dim) != len(lag): @@ -549,8 +551,8 @@ def idft( msg = "Setting lag with true_phase=False does not guarantee accurate idft." warnings.warn(msg, Warning) - for d, l in zip(dim, lag): - daft = daft * np.exp(1j * 2.0 * np.pi * daft[d] * l) + for d, l in zip(dim, lag): + daft = daft * np.exp(1j * 2.0 * np.pi * daft[d] * l) if chunks_to_segments: daft = _stack_chunks(daft, dim)