Commit 15957296 authored by Klaus Zimmermann's avatar Klaus Zimmermann
Browse files

Index function running statistics (closes #86)

parent ee6be72c
......@@ -147,6 +147,18 @@ index_functions:
percentiles:
kind: quantity
running_statistics:
description: |
First calculate a statistic within a moving window, then calculate a
different statistic across the moving windows.
parameters:
rolling_aggregator:
kind: reducer
window_size:
kind: quantity
reducer:
kind: reducer
temperature_sum:
description: |
Calculates the temperature sum above/below a threshold. First, the threshold
......
......@@ -12,6 +12,7 @@ from .index_functions import ( # noqa: F401
Statistics,
ThresholdedPercentile,
ThresholdedStatistics,
RunningStatistics,
TemperatureSum,
)
......
......@@ -6,7 +6,13 @@ from cf_units import Unit
import dask.array as da
import numpy as np
from .support import normalize_axis, IndexFunction, ThresholdMixin, ReducerMixin
from .support import (
normalize_axis,
IndexFunction,
ThresholdMixin,
ReducerMixin,
RollingWindowMixin,
)
from ..util import change_units
......@@ -323,6 +329,109 @@ class ThresholdedStatistics(ThresholdMixin, ReducerMixin, IndexFunction):
return res.astype("float32")
class RunningStatistics(RollingWindowMixin, IndexFunction):
def __init__(self, rolling_aggregator, window_size, reducer):
super().__init__(rolling_aggregator, window_size, reducer)
self.fuse_periods = True
self.bandwidth = self.window_size.points[0] // 2
self.tail_overlap = self.window_size.points[0] - 1
self.head_overlap = self.tail_overlap + self.window_size.points[0] % 2
def prepare(self, input_cubes):
super().prepare(input_cubes)
ref_cube = next(iter(input_cubes.values()))
self.standard_name = ref_cube.standard_name
self.units = ref_cube.units
def pre_aggregate_shape(self, *args, **kwargs):
return (self.head_overlap + self.tail_overlap + 1,)
def call_func(self, data, axis, **kwargs):
axis = normalize_axis(axis, data.ndim)
mask = np.ma.getmaskarray(data).any(axis=axis)
rolling_view = np.lib.stride_tricks.sliding_window_view(
data, self.window_size.points, axis
)
aggregated = self.rolling_aggregator(rolling_view, -1)
reduced = self.reducer(aggregated, axis=axis)
masked = np.ma.masked_array(np.ma.getdata(reduced), mask)
head_slices = (slice(None, None),) * axis + (slice(None, self.head_overlap),)
head = np.moveaxis(data[head_slices], axis, -1)
tail_slices = (slice(None, None),) * axis + (slice(-self.tail_overlap, None),)
tail = np.moveaxis(data[tail_slices], axis, -1)
res = np.concatenate([head, masked[..., np.newaxis], tail], axis=-1)
return res.astype("float32")
def lazy_func(self, data, axis, **kwargs):
axis = normalize_axis(axis, data.ndim)
mask = da.ma.getmaskarray(data).any(axis=axis)
rolling_view = da.overlap.sliding_window_view(
data, self.window_size.points, axis
)
aggregated = self.lazy_rolling_aggregator(rolling_view, -1)
reduced = self.lazy_reducer(aggregated, axis=axis)
masked = da.ma.masked_array(da.ma.getdata(reduced), mask)
head_slices = (slice(None, None),) * axis + (slice(None, self.head_overlap),)
head = np.moveaxis(data[head_slices], axis, -1)
tail_slices = (slice(None, None),) * axis + (slice(-self.tail_overlap, None),)
tail = np.moveaxis(data[tail_slices], axis, -1)
res = da.concatenate([head, masked[..., np.newaxis], tail], axis=-1)
return res.astype("float32")
def post_process(self, cube, data, coords, period, **kwargs):
def fuse(this, previous_tail, next_head):
head = this[..., : self.head_overlap]
pre_statistic = this[..., self.head_overlap]
tail = this[..., -self.tail_overlap :].copy()
head_overlap = np.concatenate(
[previous_tail[..., -self.bandwidth :], head], axis=-1
)
head_rolling_view = np.lib.stride_tricks.sliding_window_view(
head_overlap, self.window_size.points, -1
)
head_aggregated = self.lazy_rolling_aggregator(head_rolling_view, axis=-1)
tail_overlap = np.concatenate(
[tail, next_head[..., : self.bandwidth]], axis=-1
)
tail_rolling_view = np.lib.stride_tricks.sliding_window_view(
tail_overlap, self.window_size.points, -1
)
tail_aggregated = self.lazy_rolling_aggregator(tail_rolling_view, axis=-1)
concatenated = np.concatenate(
[head_aggregated, pre_statistic[..., np.newaxis], tail_aggregated],
axis=-1,
)
running_statistic = self.lazy_reducer(concatenated, axis=-1)
return running_statistic, tail
if self.fuse_periods and len(data) > 1:
stack = []
this = data[0]
tail_shape = this.shape[:-1] + (self.tail_overlap,)
previous_tail = da.ma.masked_array(
da.zeros(tail_shape, dtype=np.float32),
False,
)
for next_chunk in data[1:]:
next_head = next_chunk[..., : self.head_overlap].copy()
running_statistic, previous_tail = fuse(this, previous_tail, next_head)
stack.append(running_statistic)
this = next_chunk
head_shape = this.shape[:-1] + (self.head_overlap,)
next_head = da.ma.masked_array(
da.zeros(head_shape, dtype=np.float32),
False,
)
stack.append(fuse(next_chunk, previous_tail, next_head)[0])
res_data = da.stack(stack, axis=0)
else:
res_data = self.lazy_reducer(data[..., 1:], axis=-1)
return cube, res_data
class TemperatureSum(ThresholdMixin, IndexFunction):
def __init__(self, threshold, condition):
super().__init__(threshold, condition, units=Unit("days"))
......
......@@ -157,3 +157,13 @@ class ReducerMixin:
self.reducer = NUMPY_REDUCERS[reducer]
self.lazy_reducer = DASK_REDUCERS[reducer]
self.scalar_reducer = SCALAR_REDUCERS[reducer]
class RollingWindowMixin:
def __init__(self, rolling_aggregator, window_size, reducer, *args, **kwargs):
super().__init__(*args, **kwargs)
self.reducer = NUMPY_REDUCERS[reducer]
self.lazy_reducer = DASK_REDUCERS[reducer]
self.window_size = window_size
self.rolling_aggregator = NUMPY_REDUCERS[rolling_aggregator]
self.lazy_rolling_aggregator = NUMPY_REDUCERS[rolling_aggregator]
......@@ -59,6 +59,7 @@ setuptools.setup(
"interday_diurnal_temperature_range=climix.index_functions:InterdayDiurnalTemperatureRange", # noqa: E501
"last_occurrence=climix.index_functions:LastOccurrence",
"percentile=climix.index_functions:Percentile",
"running_statistics=climix.index_functions:RunningStatistics",
"spell_length=climix.index_functions:SpellLength",
"statistics=climix.index_functions:Statistics",
"temperature_sum=climix.index_functions:TemperatureSum",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment