Commit 1cacc1f5 authored by Klaus Zimmermann

Improve postprocessing (closes #180)

parent 0ca00ff5
@@ -22,16 +22,15 @@ class PointLocalAggregator(Aggregator):
         cube.standard_name = self.index_function.standard_name
         cube.units = self.index_function.units
 
-    def post_process(self, collapsed_cube, data_result, coords, **kwargs):
-        cube = super().post_process(collapsed_cube, data_result, coords,
-                                    **kwargs)
+    def post_process(self, cube, data, coords, **kwargs):
         try:
             post_processor = self.index_function.post_process
         except AttributeError:
             # this index function does not require post processing
             pass
         else:
-            cube = post_processor(cube, data_result, coords, **kwargs)
+            cube, data = post_processor(cube, data, coords, **kwargs)
+        cube = super().post_process(cube, data, coords, **kwargs)
         standard_name = self.output_metadata.standard_name
         unit_standard_name = standard_name
         proposed_standard_name = self.output_metadata.proposed_standard_name
@@ -44,3 +43,6 @@ class PointLocalAggregator(Aggregator):
         for extra_coord in self.index_function.extra_coords:
             cube.add_aux_coord(extra_coord)
         return cube
+
+    def pre_aggregate_shape(self, *args, **kwargs):
+        return self.index_function.pre_aggregate_shape(*args, **kwargs)
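Read together, these two hunks change the post-processing contract between the aggregator and the index functions: an optional index-function post_process now receives the cube and the aggregated data separately and must return the pair (cube, data), and the aggregator calls super().post_process only afterwards, on whatever the index function returned. A minimal sketch written against that contract could look as follows; the class name and the trivial rescaling are invented for illustration and are not part of the commit:

    class ExampleIndexFunction:
        def post_process(self, cube, data, coords, **kwargs):
            # Adjust the aggregated values; the cube is passed through so the
            # aggregator can finish its own post-processing on it afterwards.
            return cube, data * 2.0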
@@ -118,19 +118,17 @@ class FirstOccurrence(ThresholdMixin, IndexFunction):
         res = da.ma.masked_array(da.ma.getdata(res), mask)
         return res.astype('float32')
 
-    def post_process(self, collapsed_cube, data_result, coords,
-                     period, **kwargs):
-        time = collapsed_cube.coord('time')
+    def post_process(self, cube, data, coords, period, **kwargs):
+        time = cube.coord('time')
         calendar = time.units.calendar
-        offsets = np.empty_like(time.points, dtype=data_result.dtype)
+        offsets = np.empty_like(time.points, dtype=data.dtype)
         for i, representative_date in enumerate(time.cells()):
             year = representative_date.point.year
             start_date = datetime(year, period.first_month_number, 1)
             units = Unit(f'days since {year}-01-01', calendar=calendar)
             offsets[i] = units.date2num(start_date)
-        collapsed_cube.data = (collapsed_cube.core_data()
-                               + offsets[:, None, None])
-        return collapsed_cube
+        result_data = data + offsets[:, None, None]
+        return cube, result_data
 
 
 class InterdayDiurnalTemperatureRange(IndexFunction):
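The occurrence functions shift the aggregated values, which count days relative to the start of the aggregation period, into days since the first of January of the respective year by adding a per-year offset computed with cf_units. A small standalone illustration of that offset arithmetic, with a season and year made up for the example:

    from datetime import datetime
    from cf_units import Unit

    # Made-up example: a June-August season in the year 2000.
    year, first_month_number = 2000, 6
    units = Unit(f'days since {year}-01-01', calendar='standard')
    offset = units.date2num(datetime(year, first_month_number, 1))
    print(offset)  # 152.0: 1 June 2000 is 152 days after 1 January 2000

The change itself is that the shifted result is now returned as data alongside the cube instead of being written into collapsed_cube.data, in line with the revised aggregator contract above.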
@@ -181,19 +179,17 @@ class LastOccurrence(ThresholdMixin, IndexFunction):
         res = da.ma.masked_array(da.ma.getdata(res), mask)
         return res.astype('float32')
 
-    def post_process(self, collapsed_cube, data_result, coords,
-                     period, **kwargs):
-        time = collapsed_cube.coord('time')
+    def post_process(self, cube, data, coords, period, **kwargs):
+        time = cube.coord('time')
         calendar = time.units.calendar
-        offsets = np.empty_like(time.points, dtype=data_result.dtype)
+        offsets = np.empty_like(time.points, dtype=data.dtype)
         for i, representative_date in enumerate(time.cells()):
             year = representative_date.point.year
             start_date = datetime(year, period.first_month_number, 1)
             units = Unit(f'days since {year}-01-01', calendar=calendar)
             offsets[i] = units.date2num(start_date)
-        collapsed_cube.data = (collapsed_cube.core_data()
-                               + offsets[:, None, None])
-        return collapsed_cube
+        result_data = data + offsets[:, None, None]
+        return cube, result_data
 
 
 class Percentile(IndexFunction):
@@ -82,6 +82,9 @@ class IndexFunction:
     def prepare(self, input_cubes):
         pass
 
+    def pre_aggregate_shape(self, *args, **kwargs):
+        return ()
+
 
 class ThresholdMixin:
     def __init__(self, threshold, condition, *args, **kwargs):
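IndexFunction now provides a do-nothing default that reports no extra trailing dimensions, so index functions whose aggregation produces a single value per cell need no change. An index function that yields several values per cell would be expected to override it; the following override is purely hypothetical and not taken from this commit:

    class ExampleMultiValueFunction:
        """Hypothetical index function returning several values per grid cell."""

        def __init__(self, percentiles=(25.0, 50.0, 75.0)):
            self.percentiles = list(percentiles)

        def pre_aggregate_shape(self, *args, **kwargs):
            # One extra trailing dimension, sized by the number of requested values.
            return (len(self.percentiles),)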
@@ -80,7 +80,7 @@ def multicube_aggregated_by(cubes, coords, aggregator, **kwargs):
 
     # Determine the group-by cube data shape.
     data_shape = list(ref_cube.shape
-                      + aggregator.aggregate_shape(**kwargs))
+                      + aggregator.pre_aggregate_shape(**kwargs))
    data_shape[dimension_to_groupby] = len(groupby)
 
     # Aggregate the group-by data.
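In the patched group-by helper, the output array is sized from the reference cube's shape plus whatever trailing dimensions the aggregator announces via pre_aggregate_shape, and the group-by dimension is then resized to the number of groups. A toy run of that arithmetic, with invented shapes:

    # Invented shapes, for illustration only.
    ref_cube_shape = (3600, 96, 192)       # (time, latitude, longitude)
    extra = (4,)                           # aggregator.pre_aggregate_shape(**kwargs)
    data_shape = list(ref_cube_shape + extra)
    dimension_to_groupby = 0
    data_shape[dimension_to_groupby] = 10  # e.g. ten group-by periods
    print(data_shape)                      # [10, 96, 192, 4]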