#!/usr/bin/env python
# Copyright (c) 2022-2024, openradar developers.
# Distributed under the MIT License. See LICENSE for more info.
"""
XRadar Util
===========
.. autosummary::
:nosignatures:
:toctree: generated/
{}
"""
__all__ = [
"has_import",
"get_first_angle",
"get_second_angle",
"remove_duplicate_rays",
"reindex_angle",
"extract_angle_parameters",
"ipol_time",
"rolling_dim",
"get_sweep_keys",
"apply_to_sweeps",
"apply_to_volume",
"map_over_sweeps",
]
__doc__ = __doc__.format("\n   ".join(__all__))
import contextlib
import functools
import gzip
import importlib.util
import io
import warnings

import numpy as np
import xarray as xr
from scipy import interpolate
def has_import(pkg_name):
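    """Check if a package is available for import.

    Returns the package's module spec (truthy) if found, otherwise ``None``.
    """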
return importlib.util.find_spec(pkg_name)
def get_first_angle(ds):
"""Return name of first angle dimension from given dataset.
Parameters
----------
ds : xarray.Dataset
Dateset to get first angle name from.
Returns
-------
first_angle : str
Name of first angle dimension.
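
    Examples
    --------
    A minimal sketch with synthetic data (variable names are illustrative):

    >>> import numpy as np
    >>> import xarray as xr
    >>> ds = xr.Dataset(
    ...     {"DBZH": (("azimuth", "range"), np.zeros((360, 100)))},
    ...     coords={"azimuth": np.arange(360.0), "range": np.arange(100.0)},
    ... )
    >>> get_first_angle(ds)
    'azimuth'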
"""
first_angle = list(set(ds.dims) & set(ds.coords) ^ {"range"})[0]
if first_angle == "time":
        raise ValueError(
            "first dimension is ``time``, but ``azimuth`` or ``elevation`` is needed"
        )
return first_angle
def get_second_angle(ds):
"""Return name of second angle coordinate from given dataset.
Parameters
----------
ds : xarray.Dataset
Dateset to get second angle name from.
Returns
-------
out : str
Name of second angle coordinate.
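
    Examples
    --------
    A minimal sketch with synthetic data (variable names are illustrative):

    >>> import numpy as np
    >>> import xarray as xr
    >>> ds = xr.Dataset(
    ...     {"DBZH": (("azimuth", "range"), np.zeros((360, 100)))},
    ...     coords={
    ...         "azimuth": np.arange(360.0),
    ...         "range": np.arange(100.0),
    ...         "elevation": ("azimuth", np.full(360, 0.5)),
    ...     },
    ... )
    >>> get_second_angle(ds)
    'elevation'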
"""
return list(
((set(ds.coords) | set(ds.variables)) ^ set(ds.dims)) & {"azimuth", "elevation"}
)[0]
def remove_duplicate_rays(ds):
"""Remove duplicate rays.
Parameters
----------
ds : xarray.Dataset
Dateset to remove duplicate rays.
Returns
-------
ds : xarray.Dataset
Dataset with duplicate rays removed
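
    Examples
    --------
    A minimal sketch with one duplicated azimuth (synthetic data):

    >>> import numpy as np
    >>> import xarray as xr
    >>> ds = xr.Dataset(
    ...     {"DBZH": (("azimuth", "range"), np.zeros((4, 10)))},
    ...     coords={"azimuth": [0.5, 1.5, 1.5, 2.5], "range": np.arange(10.0)},
    ... )
    >>> remove_duplicate_rays(ds).sizes["azimuth"]
    3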
"""
first_angle = get_first_angle(ds)
_, idx = np.unique(ds[first_angle], return_index=True)
if len(idx) < len(ds[first_angle]):
        # todo:
        # if times have been calculated with a wrong number of rays
        # (ODIM_H5), we would need to recalculate the times -
        # how should we find out?
ds = ds.isel({first_angle: idx})
return ds
def _reindex_angle(ds, array, tolerance, method="nearest"):
"""Reindex first angle.
Missing values will be filled by variable's _FillValue.
Parameters
----------
ds : xarray.Dataset
Dateset to reindex first angle.
array : array-like
Array with angle values which the Dataset should reindex to.
tolerance : float
Angle tolerance up to which angles should be considered for used method.
Keyword Arguments
-----------------
method : str
Reindexing method, defaults to "nearest". See :py:meth:`xarray.Dataset.reindex`.
Returns
-------
ds : xarray.Dataset
Reindexed dataset
"""
# handle fill value
fill_value = {
k: np.asarray(v._FillValue).astype(v.dtype)
for k, v in ds.items()
if hasattr(v, "_FillValue")
}
angle = get_first_angle(ds)
# reindex
ds = ds.reindex(
{angle: array},
method=method,
tolerance=tolerance,
fill_value=fill_value,
)
return ds
def reindex_angle(
ds,
start_angle=None,
stop_angle=None,
angle_res=None,
direction=None,
method="nearest",
tolerance=None,
):
"""Reindex along first angle.
Missing values will be filled by variable's ``_FillValue``.
Parameters
----------
ds : xarray.Dataset
Dateset to reindex first angle.
Keyword Arguments
-----------------
start_angle : float
Start angle of dataset.
stop_angle : float
Stop angle of dataset.
angle_res : float
Angle resolution of the dataset.
direction : int
Sweep direction, -1 -> CCW, 1 -> CW.
method : str
Reindexing method, defaults to "nearest". See :py:meth:`xarray.Dataset.reindex`.
tolerance : float
Angle tolerance up to which angles should be considered for used method.
Defaults to angle_res / 2.
Returns
-------
ds : xarray.Dataset
Reindexed dataset
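
    Examples
    --------
    A sketch with three synthetic rays reindexed onto a regular 1 degree grid
    (data and variable names are illustrative):

    >>> import numpy as np
    >>> import xarray as xr
    >>> ds = xr.Dataset(
    ...     {"DBZH": (("azimuth", "range"), np.ones((3, 10)))},
    ...     coords={
    ...         "azimuth": np.array([0.6, 1.4, 3.6]),
    ...         "range": np.arange(10.0),
    ...         "elevation": ("azimuth", np.full(3, 0.5)),
    ...     },
    ... )
    >>> out = reindex_angle(
    ...     ds, start_angle=0, stop_angle=5, angle_res=1.0, direction=1
    ... )
    >>> out["azimuth"].values
    array([0.5, 1.5, 2.5, 3.5, 4.5])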
"""
if tolerance is None:
tolerance = angle_res / 2.0
# handle angle order, angle dimension
first_angle = get_first_angle(ds)
second_angle = get_second_angle(ds)
expected_angle_span = abs(stop_angle - start_angle)
expected_number_rays = int(np.round(expected_angle_span / angle_res, decimals=0))
# create reindexing angle
ang = start_angle + direction * np.arange(
angle_res / 2.0,
expected_number_rays * angle_res,
angle_res,
dtype=ds[first_angle].dtype,
)
ds = ds.sortby(first_angle)
ds = _reindex_angle(ds, ang, tolerance, method=method)
# check secondary angle coordinate (no nan)
# set nan values to reasonable median
sang = ds[second_angle]
if np.count_nonzero(np.isnan(sang)):
ds[second_angle] = sang.fillna(sang.median(skipna=True))
return ds
def _trailing_zeros(da, digits=16, dim=0):
"""Calculate number of trailing zeros for input cast to int."""
for i in range(digits):
x = da.dropna(da.dims[dim]).astype(int).values % np.power(10, i)
if not x.any():
continue
else:
break
return i - 1
def _ipol_time(da, dim0, a1gate=0, direction=1):
"""Interpolate/extrapolate missing time steps (NaT).
Parameters
----------
da : xarray.DataArray
DataArray to interpolate/extrapolate missing timesteps.
Returns
-------
da : xarray.DataArray
DataArray with interpolated/extrapolated timesteps.
"""
dtype = da.dtype
idx = slice(None, None)
# extract wanted section
da_sel = da.isel({dim0: idx})
# get sorting indices along first dimension
sidx = da_sel[dim0].argsort()
# special handling for wrap-around angles
angles = da_sel[dim0].values
# a1gate should normally only be set for PPI,
if a1gate > 0:
angles[-a1gate:] += 360
da_sel = da_sel.assign_coords({dim0: angles})
# prepare azimuth array for interpolation
angles = da_sel[dim0].diff(dim0).cumsum(dim0).pad({dim0: (1, 0)}, constant_values=0)
da_sel = da_sel.assign_coords({dim0: angles * direction})
# apply original order for interpolation, get angles
angles = da_sel.sortby([sidx])[dim0]
# drop NaT from selection for creation of interpolator
da_sel = da_sel.dropna(dim0)
# setup interpolator
ipol = interpolate.interp1d(
da_sel[dim0].values,
da_sel.astype(int),
fill_value="extrapolate",
assume_sorted=False,
)
# floating point interpolation might introduce spurious digits
# get least significant digit
sig = np.power(10, _trailing_zeros(da_sel.time))
# interpolate and round to the least significant digit
data = np.rint(ipol(angles) / sig).astype(int) * sig
# apply interpolated times into original DataArray
da.loc[{dim0: idx}] = data.astype(dtype)
return da
def ipol_time(ds, *, a1gate_idx=None, direction=None, **kwargs):
"""Interpolate/extrapolate missing time steps (NaT).
Parameters
----------
ds : xarray.Dataset
Dataset to interpolate/extrapolate missing timesteps.
Keyword Arguments
-----------------
a1gate_idx : int | None
First measured gate. 0 assumed, if None.
direction : int | None
1: CW, -1: CCW, Clockwise assumed, if None.
Returns
-------
ds : xarray.Dataset
Dataset with interpolated/extrapolated timesteps.
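
    Examples
    --------
    A hedged sketch (assumes ``ds`` is a sweep dataset with NaT gaps in its
    ``time`` coordinate, e.g. as read from an ODIM_H5 file):

    >>> ds = ipol_time(ds, a1gate_idx=0, direction=1)  # doctest: +SKIP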
"""
# get first dim and sort to get common state
dim0 = get_first_angle(ds)
ds = ds.sortby(dim0)
# return early, if nothing to do
if not np.isnan(ds.time).any():
return ds
if direction is None:
# set clockwise, rotating in positive direction
direction = 1
time = ds.time.astype(int)
# skip NaT (-9223372036854775808) for amin/amax calculation
amin = time.where(time > -9223372036854775808).argmin(dim0, skipna=True).values
amax = time.where(time > -9223372036854775808).argmax(dim0, skipna=True).values
time = ds.time
if a1gate_idx is None:
# if times are not sorted ascending
if amin > amax:
# check if we have missing times between amax and amin
# todo: align with start or end per keyword argument
if (amin - amax) > 1:
                warnings.warn(
                    "Rays might be missing at the beginning and/or end of the sweep. "
                    "`a1gate` information is needed to fully recover. "
                    "Assuming the sweep starts at the first valid ray."
                )
# set a1gate to amin
a1gate_idx = amin
else:
a1gate_idx = 0
if a1gate_idx > 0:
# roll first ray to 0-index, interpolate, roll-back
time = time.roll({dim0: -a1gate_idx}, roll_coords=True)
time = time.pipe(_ipol_time, dim0, a1gate=a1gate_idx, direction=direction)
time = time.roll({dim0: a1gate_idx}, roll_coords=True)
else:
time = time.pipe(_ipol_time, dim0, direction=direction)
ds_out = ds.assign({"time": ([dim0], time.values)})
return ds_out.sortby(dim0)
@contextlib.contextmanager
def _get_data_file(file, file_or_filelike):
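    """Yield a file-like object or the file path itself.

    For ``file_or_filelike == "filelike"``, the (possibly gzip-compressed) file
    is read into an in-memory :py:class:`io.BytesIO`; otherwise ``file`` is
    yielded unchanged.
    """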
if file_or_filelike == "filelike":
_open = open
if file[-3:] == ".gz":
_open = gzip.open
with _open(file, mode="r+b") as f:
yield io.BytesIO(f.read())
else:
yield file
def rolling_dim(data, window):
"""Return array with rolling dimension of window-length added at the end."""
shape = data.shape[:-1] + (data.shape[-1] - window + 1, window)
strides = data.strides + (data.strides[-1],)
return np.lib.stride_tricks.as_strided(data, shape=shape, strides=strides)
def get_sweep_keys(dtree):
"""Return which nodes in the datatree contain sweep variables
Parameters
----------
dtree : xarray.DataTree
Datatree to check for sweep_n keys
Returns
-------
keys : list
List of associated keys with sweep_n
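
    Examples
    --------
    A minimal sketch with empty datasets:

    >>> import xarray as xr
    >>> dtree = xr.DataTree.from_dict(
    ...     {
    ...         "/sweep_0": xr.Dataset(),
    ...         "/sweep_1": xr.Dataset(),
    ...         "/radar_parameters": xr.Dataset(),
    ...     }
    ... )
    >>> sorted(get_sweep_keys(dtree))
    ['sweep_0', 'sweep_1']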
"""
    sweep_group_keys = []
    for key in list(dtree.children):
        parts = key.split("_")
        try:
            # try to parse the second part of the key as an int
            int(parts[1])
            # check for "sweep" in the first part of the key
            assert "sweep" in parts[0]
            sweep_group_keys.append(key)
        # this fails for non-numeric suffixes - e.g. sweep_group_attrs
        except ValueError:
            pass
        # this fails if "sweep" is not in the key - e.g. radar_parameters
        except AssertionError:
            pass
        # this fails for keys without an underscore
        except IndexError:
            pass
    return sweep_group_keys
def apply_to_sweeps(dtree, func, *args, **kwargs):
"""
Applies a given function to all sweep nodes in the radar volume.
Parameters
----------
dtree : xarray.DataTree
The DataTree object representing the radar volume.
func : function
The function to apply to each sweep.
*args : tuple
Additional positional arguments to pass to the function.
**kwargs : dict
Additional keyword arguments to pass to the function.
Returns
-------
xarray.DataTree
A new DataTree object with the function applied to all sweeps.
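
    Examples
    --------
    A hedged sketch (assumes ``dtree`` is an existing radar volume; the lambda
    is illustrative):

    >>> filtered = apply_to_sweeps(dtree, lambda ds: ds.where(ds.DBZH > 0))  # doctest: +SKIP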
"""
# Create a new tree dictionary
tree = {"/": dtree.ds} # Start with the root Dataset
# Add all nodes except the root
tree.update({node.path: node.ds for node in dtree.subtree if node.path != "/"})
# Apply the function to all sweep nodes and update the tree dictionary
tree.update(
{
node.path: func(dtree[node.path].to_dataset(), *args, **kwargs)
for node in dtree.match("sweep*").subtree
if node.path.startswith("/sweep")
}
)
# Return a new DataTree constructed from the modified tree dictionary
return xr.DataTree.from_dict(tree)
def apply_to_volume(dtree, func, *args, **kwargs):
"""
Alias for apply_to_sweeps.
Applies a given function to all sweep nodes in the radar volume.
Parameters
----------
dtree : xarray.DataTree
The DataTree object representing the radar volume.
func : function
The function to apply to each sweep.
*args : tuple
Additional positional arguments to pass to the function.
**kwargs : dict
Additional keyword arguments to pass to the function.
Returns
-------
xarray.DataTree
A new DataTree object with the function applied to all sweeps.
"""
return apply_to_sweeps(dtree, func, *args, **kwargs)
def map_over_sweeps(func):
"""
Decorator to apply a function only to sweep nodes in a DataTree.
This decorator first checks whether the dataset provided to the function has the 'range' dimension,
indicating it's a sweep node. If true, the function is applied. Non-sweep nodes are left unchanged.
Parameters
----------
func : callable
A function that operates on an xarray Dataset. The function must accept a Dataset as its
first argument and return a modified Dataset.
Returns
-------
callable
A function that can be applied to all sweep nodes in a DataTree.
Examples
--------
>>> @map_over_sweeps
>>> def calculate_rain_rate(ds, ref_field='DBZH'):
>>> # Function logic to calculate rain rate
>>> return ds
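
    The decorated function can then be applied to a whole volume, e.g. (a
    hedged sketch, with ``dtree`` an existing radar volume):

    >>> dtree = calculate_rain_rate(dtree, ref_field='DBZH')  # doctest: +SKIP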
"""
@functools.wraps(func)
def _func(*args, **kwargs):
"""
Internal function to apply `func` only to sweep nodes.
Checks for the presence of the 'range' dimension to identify sweep nodes. Non-sweep nodes
are left unchanged.
Parameters
----------
*args : tuple
Positional arguments passed to the function.
**kwargs : dict
Keyword arguments passed to the function.
Returns
-------
Dataset or unchanged object
The modified Dataset if applied to a sweep node, otherwise the unchanged object.
"""
if "range" in args[0].dims:
return func(*args, **kwargs)
else:
return args[0]
# map _func over datasets in a DataTree
def _map_over_sweeps(*args, **kwargs):
return xr.map_over_datasets(functools.partial(_func, **kwargs), *args)
return _map_over_sweeps