diff --git a/climix/index_functions/__init__.py b/climix/index_functions/__init__.py index c128b888725ed8a3e47d7320ff7c602e090f80c5..fcb653748934472b63406127cc91c8d8e46b7cee 100644 --- a/climix/index_functions/__init__.py +++ b/climix/index_functions/__init__.py @@ -10,6 +10,7 @@ from .index_functions import ( # noqa: F401 LastOccurrence, Percentile, Statistics, + ThresholdedDiurnalTemperatureRange, ThresholdedPercentile, ThresholdedStatistics, RunningStatistics, diff --git a/climix/index_functions/index_functions.py b/climix/index_functions/index_functions.py index 81a52c3994258efb748c02b20da908423b9ced13..99f4a1f7f547f1f2fb18cc7dc1b0f91d94a66c9b 100644 --- a/climix/index_functions/index_functions.py +++ b/climix/index_functions/index_functions.py @@ -8,6 +8,7 @@ import numpy as np from .support import ( normalize_axis, + DifferenceThresholdMixin, IndexFunction, ThresholdMixin, ReducerMixin, @@ -263,6 +264,41 @@ class Statistics(ReducerMixin, IndexFunction): return res.astype("float32") +class ThresholdedDiurnalTemperatureRange(DifferenceThresholdMixin, IndexFunction): + def __init__(self, threshold, condition): + super().__init__( + threshold, condition, units=Unit("days") + ) + + def prepare(self, input_cubes): + props = { + (cube.dtype, cube.units, cube.standard_name) + for cube in input_cubes.values() + } + assert len(props) == 1 + dtype, units, standard_name = props.pop() + assert units.is_convertible(Unit("degree_Celsius")) + super().prepare(input_cubes) + + def call_func(self, data, axis, **kwargs): + dtr = data["high_data"] - data["low_data"] + axis = normalize_axis(axis, dtr.ndim) + mask = np.ma.getmaskarray(dtr).any(axis=axis) + cond = self.condition(dtr, self.threshold.points) + res = np.count_nonzero(cond, axis=axis) + res = np.ma.masked_array(np.ma.getdata(res), mask) + return res.astype("float32") + + def lazy_func(self, data, axis, **kwargs): + dtr = data["high_data"] - data["low_data"] + axis = normalize_axis(axis, dtr.ndim) + mask = da.ma.getmaskarray(dtr).any(axis=axis) + cond = self.lazy_condition(dtr, self.threshold.points) + res = da.count_nonzero(cond, axis=axis) + res = da.ma.masked_array(da.ma.getdata(res), mask) + return res.astype("float32") + + class ThresholdedPercentile(ThresholdMixin, IndexFunction): def __init__(self, threshold, condition, percentiles, interpolation="linear"): super().__init__(threshold, condition) diff --git a/climix/index_functions/support.py b/climix/index_functions/support.py index 371f799287a37b91d1c05eb46921bda23acf8a17..bd44d6c0917e6eea4c0384acb75cdce17ac85cc4 100644 --- a/climix/index_functions/support.py +++ b/climix/index_functions/support.py @@ -116,6 +116,32 @@ def normalize_axis(axis, ndim): return axis +class DifferenceThresholdMixin: + def __init__(self, threshold, condition, *args, **kwargs): + super().__init__(*args, **kwargs) + self.threshold = threshold + self.condition = NUMPY_OPERATORS[condition] + self.lazy_condition = DASK_OPERATORS[condition] + self.extra_coords.append(threshold.copy()) + + def prepare(self, input_cubes): + ref_cube = next(iter(input_cubes.values())) + threshold = self.threshold + threshold.points = threshold.points.astype(ref_cube.dtype) + if threshold.has_bounds(): + threshold.bounds = threshold.bounds.astype(ref_cube.dtype) + threshold_ref = threshold.copy() + threshold_ref.points.fill(0) + if threshold_ref.has_bounds(): + threshold_ref.bounds.fill(0) + change_units(threshold, ref_cube.units, ref_cube.standard_name) + change_units(threshold_ref, ref_cube.units, ref_cube.standard_name) + threshold.points = threshold.points - threshold_ref.points + if threshold.has_bounds(): + threshold.bounds = threshold.bounds - threshold_ref.bounds + super().prepare(input_cubes) + + class IndexFunction: def __init__(self, standard_name=None, units=Unit("no_unit")): super().__init__() diff --git a/setup.py b/setup.py index 8b5602dffb7421247cc09773dd33af321dff2863..d898399d3df59b05a724b59b19325f4babedb121 100755 --- a/setup.py +++ b/setup.py @@ -70,6 +70,7 @@ setuptools.setup( "spell_length=climix.index_functions:SpellLength", "statistics=climix.index_functions:Statistics", "temperature_sum=climix.index_functions:TemperatureSum", + "thresholded_diurnal_temperature_range=climix.index_functions:ThresholdedDiurnalTemperatureRange", # noqa: E501 "thresholded_percentile=climix.index_functions:ThresholdedPercentile", "thresholded_running_statistics=climix.index_functions:ThresholdedRunningStatistics", # noqa: E501 "thresholded_statistics=climix.index_functions:ThresholdedStatistics",