Commit 4c408e07 authored by Klaus Zimmermann

Multiple inputs index functions (closes #82)

parent a101b428
@@ -19,11 +19,11 @@ def ignore_cb(cube, field, filename):

 def prepare_input_data(datafiles):
-    datacube = iris.load_raw(datafiles, callback=ignore_cb)
-    iris.util.unify_time_units(datacube)
-    equalise_attributes(datacube)
-    cube = datacube.concatenate_cube()
-    return cube
+    datacubes = iris.load_raw(datafiles, callback=ignore_cb)
+    iris.util.unify_time_units(datacubes)
+    equalise_attributes(datacubes)
+    cubes = datacubes.concatenate()
+    return cubes


 def save(result, output_filename, sliced_mode=False):
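For context, `CubeList.concatenate_cube` insists that all loaded parts combine into exactly one cube, while `CubeList.concatenate` joins compatible pieces and returns a `CubeList` with one cube per variable, which is what multi-input indices need. A minimal sketch with synthetic data (the variable names `tasmax`/`tasmin` are made up for illustration):

```python
import numpy as np
from iris.cube import Cube, CubeList
from iris.coords import DimCoord


def piece(start, var_name):
    # A 3-day slice of a daily series, starting at day `start`.
    time = DimCoord(np.arange(start, start + 3, dtype=float),
                    standard_name='time', units='days since 2000-01-01')
    return Cube(np.zeros(3), var_name=var_name,
                dim_coords_and_dims=[(time, 0)])


raw = CubeList([piece(0, 'tasmax'), piece(3, 'tasmax'),
                piece(0, 'tasmin'), piece(3, 'tasmin')])
cubes = raw.concatenate()           # two cubes, one per variable
print([c.var_name for c in cubes])  # ['tasmax', 'tasmin']
# raw.concatenate_cube() would raise here, because the result is not
# a single cube.
```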
@@ -3,26 +3,31 @@
 import logging

 from .aggregators import PointLocalAggregator
+from .iris import cubelist_aggregated_by
 from .period import build_period


 class Index:
-    def __init__(self, index_function, output_metadata, period_spec):
+    def __init__(self, index_function, metadata, period_spec):
         self.index_function = index_function
-        self.output_metadata = output_metadata
+        self.output_metadata = metadata.output
         self.period = build_period(period_spec)
         self.aggregator = PointLocalAggregator(index_function,
-                                               output_metadata)
+                                               metadata.output)
+        self.mapping = {iv.var_name: argname
+                        for argname, iv in metadata.input.items()}

-    def __call__(self, cube):
+    def __call__(self, cubes):
         logging.info('Adding coord categorisation.')
-        coord_name = self.period.add_coord_categorisation(cube)
+        coord_name = list(map(self.period.add_coord_categorisation, cubes))[0]
         logging.info('Extracting period cube')
-        sub_cube = cube.extract(self.period.constraint)
+        sub_cubes = cubes.extract(self.period.constraint)
         logging.info('Preparing cube')
-        self.index_function.prepare(sub_cube)
+        self.index_function.prepare(sub_cubes)
         logging.info('Setting up aggregation')
-        aggregated = sub_cube.aggregated_by(coord_name, self.aggregator,
-                                            period=self.period)
+        aggregated = cubelist_aggregated_by(sub_cubes, coord_name,
+                                            self.aggregator,
+                                            self.mapping,
+                                            period=self.period)
         aggregated.attributes['frequency'] = self.period.label
         return aggregated
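The new `self.mapping` inverts `metadata.input`, translating each input cube's `var_name` into the argument name the index function expects. A hypothetical illustration of the shapes involved (the real metadata objects come from the index definitions; these stand-ins only mimic the attributes `__init__` uses):

```python
from types import SimpleNamespace

# Stand-ins for the real metadata objects, for illustration only.
metadata = SimpleNamespace(
    input={'low_data': SimpleNamespace(var_name='tasmin'),
           'high_data': SimpleNamespace(var_name='tasmax')})
mapping = {iv.var_name: argname for argname, iv in metadata.input.items()}
print(mapping)  # {'tasmin': 'low_data', 'tasmax': 'high_data'}
```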
@@ -156,10 +156,10 @@ class Statistics(ReducerMixin, IndexFunction):
     def __init__(self, reducer):
         super().__init__(reducer)

-    def prepare(self, input_cube):
-        super().prepare(input_cube)
-        self.standard_name = input_cube.standard_name
-        self.units = input_cube.units
+    def prepare(self, input_cubes):
+        super().prepare(input_cubes)
+        self.standard_name = input_cubes[0].standard_name
+        self.units = input_cubes[0].units

     def call_func(self, data, axis, **kwargs):
         axis = normalize_axis(axis, data.ndim)
@@ -176,10 +176,10 @@ class ThresholdedStatistics(ThresholdMixin, ReducerMixin, IndexFunction):
     def __init__(self, threshold, condition, reducer):
         super().__init__(threshold, condition, reducer, units=Unit('days'))

-    def prepare(self, input_cube):
-        super().prepare(input_cube)
-        self.standard_name = input_cube.standard_name
-        self.units = input_cube.units
+    def prepare(self, input_cubes):
+        super().prepare(input_cubes)
+        self.standard_name = input_cubes[0].standard_name
+        self.units = input_cubes[0].units

     def call_func(self, data, axis, **kwargs):
         axis = normalize_axis(axis, data.ndim)
@@ -204,10 +204,10 @@ class TemperatureSum(ThresholdMixin, IndexFunction):
         self.fun = lambda d, t: np.maximum(t - d, 0)
         self.lazy_fun = lambda d, t: da.maximum(t - d, 0)

-    def prepare(self, input_cube):
-        super().prepare(input_cube)
-        self.standard_name = input_cube.standard_name
-        if input_cube.units.is_convertible('degC'):
+    def prepare(self, input_cubes):
+        super().prepare(input_cubes)
+        self.standard_name = input_cubes[0].standard_name
+        if input_cubes[0].units.is_convertible('degC'):
             self.units = 'degC days'
         else:
             raise RuntimeError("Invalid input units")
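The units guard relies on `cf_units.Unit.is_convertible`, so any absolute-temperature input passes while non-temperature units are rejected. For instance:

```python
from cf_units import Unit

print(Unit('K').is_convertible('degC'))           # True
print(Unit('degF').is_convertible('degC'))        # True
print(Unit('kg m-2 s-1').is_convertible('degC'))  # False
```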
@@ -65,7 +65,7 @@ class IndexFunction:
         self.units = units
         self.extra_coords = []

-    def prepare(self, input_cube):
+    def prepare(self, input_cubes):
         pass
@@ -77,15 +77,15 @@ class ThresholdMixin:
         self.lazy_condition = DASK_OPERATORS[condition]
         self.extra_coords.append(threshold.copy())

-    def prepare(self, input_cube):
+    def prepare(self, input_cubes):
         threshold = self.threshold
-        threshold.points = threshold.points.astype(input_cube.dtype)
+        threshold.points = threshold.points.astype(input_cubes[0].dtype)
         if threshold.has_bounds():
-            threshold.bounds = threshold.bounds.astype(input_cube.dtype)
+            threshold.bounds = threshold.bounds.astype(input_cubes[0].dtype)
         change_units(threshold,
-                     input_cube.units,
-                     input_cube.standard_name)
-        super().prepare(input_cube)
+                     input_cubes[0].units,
+                     input_cubes[0].standard_name)
+        super().prepare(input_cubes)


 class ReducerMixin:
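After this change the mixin takes its dtype and units cues from the first cube in the list. A small sketch of the dtype cast in isolation (the coordinate and dtype here are made up for illustration):

```python
import numpy as np
from iris.coords import AuxCoord

threshold = AuxCoord(np.array([25.0]), units='degC')
input_dtype = np.float32             # e.g. input_cubes[0].dtype
threshold.points = threshold.points.astype(input_dtype)
print(threshold.points.dtype)        # float32
```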
import dask.array as da
import iris
import numpy as np


def cubelist_aggregated_by(cubes, coords, aggregator, mapping=None, **kwargs):
    if len(cubes) == 1:
        return cubes[0].aggregated_by(coords, aggregator, **kwargs)
    if mapping is None:
        mapping = {}
    # We assume all cubes have the same coordinates,
    # but a test needs to be added.
    groupby_coords = []
    dimension_to_groupby = None
    # We can't handle weights
    if isinstance(
        aggregator, iris.analysis.WeightedAggregator
    ) and aggregator.uses_weighting(**kwargs):
        raise ValueError(
            "Invalid Aggregation, cubelist_aggregated_by() cannot use"
            " weights."
        )
    reference_cube = cubes[0]
    coords = reference_cube._as_list_of_coords(coords)
    for coord in sorted(coords, key=lambda coord: coord._as_defn()):
        if coord.ndim > 1:
            msg = (
                "Cannot aggregate_by coord %s as it is "
                "multidimensional." % coord.name()
            )
            raise iris.exceptions.CoordinateMultiDimError(msg)
        dimension = reference_cube.coord_dims(coord)
        if not dimension:
            msg = (
                'Cannot group-by the coordinate "%s", as its '
                "dimension does not describe any data." % coord.name()
            )
            raise iris.exceptions.CoordinateCollapseError(msg)
        if dimension_to_groupby is None:
            dimension_to_groupby = dimension[0]
        if dimension_to_groupby != dimension[0]:
            msg = "Cannot group-by coordinates over different dimensions."
            raise iris.exceptions.CoordinateCollapseError(msg)
        groupby_coords.append(coord)
    # Determine the other coordinates that share the same group-by
    # coordinate dimension.
    shared_coords = list(
        filter(
            lambda coord_: coord_ not in groupby_coords,
            reference_cube.coords(contains_dimension=dimension_to_groupby),
        )
    )
    # Determine which of each shared coord's dimensions will be aggregated.
    shared_coords_and_dims = [
        (coord_, index)
        for coord_ in shared_coords
        for (index, dim) in enumerate(reference_cube.coord_dims(coord_))
        if dim == dimension_to_groupby
    ]
    # Create the aggregation group-by instance.
    groupby = iris.analysis._Groupby(
        groupby_coords, shared_coords_and_dims
    )
    # Create the resulting aggregate-by cube and remove the original
    # coordinates that are going to be grouped by.
    # aggregateby_cube = iris.util._strip_metadata_from_dims(
    #     reference_cube, [dimension_to_groupby]
    # )
    key = [slice(None, None)] * reference_cube.ndim
    # Generate unique index tuple key to maintain monotonicity.
    key[dimension_to_groupby] = tuple(range(len(groupby)))
    key = tuple(key)
    # aggregateby_cube = aggregateby_cube[key]
    aggregateby_cube = reference_cube[key]
    for coord in groupby_coords + shared_coords:
        aggregateby_cube.remove_coord(coord)
    # Determine the group-by cube data shape.
    data_shape = list(reference_cube.shape
                      + aggregator.aggregate_shape(**kwargs))
    data_shape[dimension_to_groupby] = len(groupby)
    # Aggregate the group-by data.
    if aggregator.lazy_func is not None and reference_cube.has_lazy_data():
        def data_getter(cube):
            return cube.lazy_data()

        aggregate = aggregator.lazy_aggregate
        stack = da.stack
    else:
        def data_getter(cube):
            return cube.data

        aggregate = aggregator.aggregate
        stack = np.stack
    front_slice = (slice(None, None),) * dimension_to_groupby
    back_slice = (slice(None, None),) * (
        len(data_shape) - dimension_to_groupby - 1
    )
    groupby_subcubes = map(
        lambda groupby_slice: {
            mapping.get(cube.var_name, cube.var_name):
            data_getter(cube[
                front_slice + (groupby_slice,) + back_slice
            ]) for cube in cubes},
        groupby.group(),
    )

    def agg(data):
        result = aggregate(data, axis=dimension_to_groupby, **kwargs)
        return result

    result = list(map(agg, groupby_subcubes))
    aggregateby_data = stack(result, axis=dimension_to_groupby)
    # Add the aggregation meta data to the aggregate-by cube.
    aggregator.update_metadata(
        aggregateby_cube, groupby_coords, aggregate=True, **kwargs
    )
    # Replace the appropriate coordinates within the aggregate-by cube.
    (dim_coord,) = reference_cube.coords(
        dimensions=dimension_to_groupby, dim_coords=True
    ) or [None]
    for coord in groupby.coords:
        if (
            dim_coord is not None
            and dim_coord._as_defn() == coord._as_defn()
            and isinstance(coord, iris.coords.DimCoord)
        ):
            aggregateby_cube.add_dim_coord(
                coord.copy(), dimension_to_groupby
            )
        else:
            aggregateby_cube.add_aux_coord(
                coord.copy(), reference_cube.coord_dims(coord)
            )
    # Attach the aggregate-by data into the aggregate-by cube.
    aggregateby_cube = aggregator.post_process(
        aggregateby_cube, aggregateby_data, coords, **kwargs
    )
    return aggregateby_cube
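A usage sketch for the new helper. With a single-element `CubeList` it defers straight to `Cube.aggregated_by`, so plain iris aggregators work unchanged; with several cubes, the aggregator's `call_func` instead receives a dict of arrays keyed by the mapped argument names. The import path below is assumed from the `from .iris import cubelist_aggregated_by` line above:

```python
import numpy as np
import iris.analysis
import iris.coord_categorisation
from iris.cube import Cube, CubeList
from iris.coords import DimCoord

from climix.iris import cubelist_aggregated_by  # package name assumed

time = DimCoord(np.arange(730, dtype=float),
                standard_name='time', units='days since 2000-01-01')
cube = Cube(np.arange(730, dtype=float), var_name='tasmax',
            dim_coords_and_dims=[(time, 0)])
iris.coord_categorisation.add_year(cube, 'time')

# Single-cube fast path: equivalent to cube.aggregated_by('year', MEAN).
annual = cubelist_aggregated_by(CubeList([cube]), 'year', iris.analysis.MEAN)
print(annual.shape)  # (2,): one value per calendar year
```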
@@ -312,7 +312,7 @@ class IndexCatalog:
                 logging.error(f'Could not build index function for index '
                               f'{index} from definition {definition}')
                 raise
-            index = Index(index_function, definition.output, period_spec)
+            index = Index(index_function, definition, period_spec)
             indices.append(index)
         return indices