Commit 7084d89e authored by Klaus Zimmermann's avatar Klaus Zimmermann
Browse files

Add season_length index function

parent 785805fb
......@@ -23,5 +23,6 @@ from .percentile_functions import ( # noqa: F401
from .spell_functions import ( # noqa: F401
FirstSpell,
SeasonLength,
SpellLength,
)
from functools import partial
from cf_units import Unit
import dask.array as da
import numpy as np
from .spell_kernels import make_first_spell_kernels, make_spell_length_kernels
from .support import normalize_axis, IndexFunction, ThresholdMixin, ReducerMixin
from .support import (
normalize_axis,
IndexFunction,
ThresholdMixin,
ReducerMixin,
NUMPY_OPERATORS,
DASK_OPERATORS,
change_units,
)
class FirstSpell(ThresholdMixin, IndexFunction):
......@@ -37,7 +47,7 @@ class FirstSpell(ThresholdMixin, IndexFunction):
meta=np.array((), dtype=int),
)
res = first_spell_data[..., 2].copy()
res = da.where(res >= 0, res - offset, res)
res = da.where(res >= 0, res + offset, res)
res = da.ma.masked_array(da.ma.getdata(res), mask)
return res.astype("float32")
......@@ -56,6 +66,111 @@ class FirstSpell(ThresholdMixin, IndexFunction):
return res
class SeasonLength(IndexFunction):
def __init__(
self,
start_threshold,
start_condition,
duration,
dead_period,
end_threshold,
end_condition,
):
super().__init__(units=Unit("days"))
self.duration = duration
self.dead_period = dead_period
self.kernels = make_first_spell_kernels(duration.points[0])
self.start_threshold = start_threshold
self.start_condition = NUMPY_OPERATORS[start_condition]
self.start_lazy_condition = DASK_OPERATORS[start_condition]
self.extra_coords.append(start_threshold.copy())
self.end_threshold = end_threshold
self.end_condition = NUMPY_OPERATORS[end_condition]
self.end_lazy_condition = DASK_OPERATORS[end_condition]
self.extra_coords.append(end_threshold.copy())
def prepare(self, input_cubes):
ref_cube = next(iter(input_cubes.values()))
threshold = self.start_threshold
threshold.points = threshold.points.astype(ref_cube.dtype)
if threshold.has_bounds():
threshold.bounds = threshold.bounds.astype(ref_cube.dtype)
change_units(threshold, ref_cube.units, ref_cube.standard_name)
threshold = self.end_threshold
threshold.points = threshold.points.astype(ref_cube.dtype)
if threshold.has_bounds():
threshold.bounds = threshold.bounds.astype(ref_cube.dtype)
change_units(threshold, ref_cube.units, ref_cube.standard_name)
super().prepare(input_cubes)
def pre_aggregate_shape(self, *args, **kwargs):
return (4,)
def call_func(self, data, axis, **kwargs):
raise NotImplementedError
def lazy_func(self, data, axis, **kwargs):
axis = normalize_axis(axis, data.ndim)
mask = da.ma.getmaskarray(data).any(axis=axis)
data = da.moveaxis(data, axis, -1)
max_length = data.shape[-1]
first_start_spell_data = da.reduction(
data,
partial(
self.chunk,
condition=self.start_condition,
threshold=self.start_threshold.points[0],
),
self.aggregate,
keepdims=True,
output_size=4,
axis=-1,
dtype=int,
concatenate=False,
meta=np.array((), dtype=int),
)
start = first_start_spell_data[..., 2].copy()
offset = self.dead_period.points[0]
data = data[..., offset:]
first_end_spell_data = da.reduction(
data,
partial(
self.chunk,
condition=self.end_condition,
threshold=self.end_threshold.points[0],
),
self.aggregate,
keepdims=True,
output_size=4,
axis=-1,
dtype=int,
concatenate=False,
meta=np.array((), dtype=int),
)
end = first_end_spell_data[..., 2].copy()
end = da.where(end >= 0, end + offset, end)
res = da.where(end < 0, max_length, end)
res = da.where(start < 0, 0, res - start)
res = da.ma.masked_array(da.ma.getdata(res), mask)
return res.astype("float32")
def chunk(
self, raw_data, axis, keepdims, condition, threshold, computing_meta=False
):
if computing_meta:
return np.array((), dtype=int)
data = condition(raw_data, threshold)
chunk_res = self.kernels.chunk(data)
return chunk_res
def aggregate(self, x_chunk, axis, keepdims):
if not isinstance(x_chunk, list):
return x_chunk
res = self.kernels.aggregate(np.array(x_chunk))
return res
class SpellLength(ThresholdMixin, ReducerMixin, IndexFunction):
def __init__(self, threshold, condition, reducer, fuse_periods=False):
super().__init__(threshold, condition, reducer, units=Unit("days"))
......
......@@ -61,6 +61,7 @@ setuptools.setup(
"last_occurrence=climix.index_functions:LastOccurrence",
"percentile=climix.index_functions:Percentile",
"running_statistics=climix.index_functions:RunningStatistics",
"season_length=climix.index_functions:SeasonLength",
"spell_length=climix.index_functions:SpellLength",
"statistics=climix.index_functions:Statistics",
"temperature_sum=climix.index_functions:TemperatureSum",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment