Commit bd9c5640 authored by Klaus Zimmermann's avatar Klaus Zimmermann
Browse files

Add handling of aliased input variables (closes #155)

parent 5e3bad17
......@@ -23,6 +23,22 @@ def prepare_input_data(datafiles):
iris.util.unify_time_units(datacubes)
equalise_attributes(datacubes)
cubes = datacubes.concatenate()
var_names = [c.var_name for c in cubes]
if len(var_names) > len(set(var_names)): # noqa
cubes_per_var_name = {}
for c in cubes:
cs = cubes_per_var_name.setdefault(c.var_name, [])
cs.append(c)
inconsistent_var_names = []
for var_name, cubes in cubes_per_var_name.items():
if len(cubes) > 1:
inconsistent_var_names.append(var_name)
logging.error('Found more than one cube for var_name "{}".\n'
'{}'.format(
var_name,
'\n'.join(map(lambda c: str(c), cubes))))
raise ValueError('Found too many cubes for var_names {}. '
'See log for details.'.format(inconsistent_var_names))
return cubes
......
......@@ -3,15 +3,40 @@ variables:
standard_name: precipitation_flux
cell_methods:
- time: mean
aliases:
- pradjust
- prec
- rr
- precip
- RR
tas:
standard_name: air_temperature
cell_methods:
- time: mean
aliases:
- tasadjust
- tmean
- tm
- tg
- TG
- meant
tasmax:
standard_name: air_temperature
cell_methods:
- time: maximum
aliases:
- tasmaxadjust
- tmax
- tx
- maxt
- TX
tasmin:
standard_name: air_temperature
cell_methods:
- time: minimum
aliases:
- tasminadjust
- tmin
- tn
- mint
- TN
......@@ -3,31 +3,41 @@
import logging
from .aggregators import PointLocalAggregator
from .iris import cubelist_aggregated_by
from .iris import multicube_aggregated_by
from .period import build_period
class Index:
    """Compute a climate index over a period from a set of input cubes.

    ``metadata.input`` maps argument names to input-variable descriptions;
    each description carries a canonical ``var_name`` plus a list of
    ``aliases``.  Incoming cubes are matched against both, so datasets
    using alternative variable names (e.g. ``rr`` for ``pr``) are handled.

    NOTE(review): reconstructed from a diff rendering — the pre-commit
    lines (``self.output_metadata``, first ``self.mapping`` comprehension,
    ``cubelist_aggregated_by`` call) were superseded by this commit and
    have been dropped.
    """

    def __init__(self, index_function, metadata, period_spec):
        self.index_function = index_function
        self.metadata = metadata
        self.period = build_period(period_spec)
        self.aggregator = PointLocalAggregator(index_function,
                                               metadata.output)
        # Every argument name must be matched by exactly one cube later.
        self.input_argnames = set(metadata.input.keys())
        # var_name or alias -> argument name.
        self.mapping = {}
        for argname, iv in metadata.input.items():
            for key in [iv.var_name] + iv.aliases:
                self.mapping[key] = argname

    def __call__(self, cubes):
        """Run the index computation on `cubes` and return the result cube.

        Raises ValueError if any required input variable has no matching
        cube.
        """
        # Keep only cubes whose var_name (canonical or alias) names a
        # known input, restricted to the requested period.
        cube_mapping = {argname: cube.extract(self.period.constraint)
                        for cube in cubes
                        if (argname := self.mapping.get(cube.var_name))  # noqa
                        is not None}
        # Fail early if any required input is missing.
        for argname in self.input_argnames:
            if argname in cube_mapping:
                logging.info('Data found for input {}'.format(argname))
            else:
                raise ValueError('No data found for input {}'.format(argname))
        logging.info('Adding coord categorisation.')
        # All inputs share the period coordinate; take the first name.
        coord_name = list(map(
            self.period.add_coord_categorisation, cube_mapping.values()))[0]
        logging.info('Preparing cubes')
        self.index_function.prepare(cube_mapping)
        logging.info('Setting up aggregation')
        aggregated = multicube_aggregated_by(cube_mapping, coord_name,
                                             self.aggregator,
                                             period=self.period)
        aggregated.attributes['frequency'] = self.period.label
        return aggregated
......@@ -20,7 +20,7 @@ class CountLevelCrossings(IndexFunction):
def prepare(self, input_cubes):
props = {(cube.dtype, cube.units, cube.standard_name)
for cube in input_cubes}
for cube in input_cubes.values()}
assert len(props) == 1
dtype, units, standard_name = props.pop()
threshold = self.threshold
......@@ -186,8 +186,9 @@ class Statistics(ReducerMixin, IndexFunction):
def prepare(self, input_cubes):
super().prepare(input_cubes)
self.standard_name = input_cubes[0].standard_name
self.units = input_cubes[0].units
ref_cube = next(iter(input_cubes.values()))
self.standard_name = ref_cube.standard_name
self.units = ref_cube.units
def call_func(self, data, axis, **kwargs):
axis = normalize_axis(axis, data.ndim)
......@@ -206,8 +207,9 @@ class ThresholdedStatistics(ThresholdMixin, ReducerMixin, IndexFunction):
def prepare(self, input_cubes):
super().prepare(input_cubes)
self.standard_name = input_cubes[0].standard_name
self.units = input_cubes[0].units
ref_cube = next(iter(input_cubes.values()))
self.standard_name = ref_cube.standard_name
self.units = ref_cube.units
def call_func(self, data, axis, **kwargs):
axis = normalize_axis(axis, data.ndim)
......@@ -234,8 +236,9 @@ class TemperatureSum(ThresholdMixin, IndexFunction):
def prepare(self, input_cubes):
super().prepare(input_cubes)
self.standard_name = input_cubes[0].standard_name
if input_cubes[0].units.is_convertible('degC'):
ref_cube = next(iter(input_cubes.values()))
self.standard_name = ref_cube.standard_name
if ref_cube.units.is_convertible('degC'):
self.units = 'degC days'
else:
raise RuntimeError("Invalid input units")
......
......@@ -78,13 +78,14 @@ class ThresholdMixin:
self.extra_coords.append(threshold.copy())
def prepare(self, input_cubes):
ref_cube = next(iter(input_cubes.values()))
threshold = self.threshold
threshold.points = threshold.points.astype(input_cubes[0].dtype)
threshold.points = threshold.points.astype(ref_cube.dtype)
if threshold.has_bounds():
threshold.bounds = threshold.bounds.astype(input_cubes[0].dtype)
threshold.bounds = threshold.bounds.astype(ref_cube.dtype)
change_units(threshold,
input_cubes[0].units,
input_cubes[0].standard_name)
ref_cube.units,
ref_cube.standard_name)
super().prepare(input_cubes)
......
......@@ -3,11 +3,7 @@ import iris
import numpy as np
def cubelist_aggregated_by(cubes, coords, aggregator, mapping=None, **kwargs):
if len(cubes) == 1:
return cubes[0].aggregated_by(coords, aggregator, **kwargs)
if mapping is None:
mapping = {}
def multicube_aggregated_by(cubes, coords, aggregator, **kwargs):
# We assume all cubes have the same coordinates,
# but a test needs to be added.
groupby_coords = []
......@@ -18,13 +14,13 @@ def cubelist_aggregated_by(cubes, coords, aggregator, mapping=None, **kwargs):
aggregator, iris.analysis.WeightedAggregator
) and aggregator.uses_weighting(**kwargs):
raise ValueError(
"Invalid Aggregation, cubelist_aggregated_by() cannot use"
"Invalid Aggregation, multicube_aggregated_by() cannot use"
" weights."
)
reference_cube = cubes[0]
ref_cube = next(iter(cubes.values()))
coords = reference_cube._as_list_of_coords(coords)
coords = ref_cube._as_list_of_coords(coords)
for coord in sorted(coords, key=lambda coord: coord._as_defn()):
if coord.ndim > 1:
msg = (
......@@ -32,7 +28,7 @@ def cubelist_aggregated_by(cubes, coords, aggregator, mapping=None, **kwargs):
"multidimensional." % coord.name()
)
raise iris.exceptions.CoordinateMultiDimError(msg)
dimension = reference_cube.coord_dims(coord)
dimension = ref_cube.coord_dims(coord)
if not dimension:
msg = (
'Cannot group-by the coordinate "%s", as its '
......@@ -51,7 +47,7 @@ def cubelist_aggregated_by(cubes, coords, aggregator, mapping=None, **kwargs):
shared_coords = list(
filter(
lambda coord_: coord_ not in groupby_coords,
reference_cube.coords(contains_dimension=dimension_to_groupby),
ref_cube.coords(contains_dimension=dimension_to_groupby),
)
)
......@@ -59,7 +55,7 @@ def cubelist_aggregated_by(cubes, coords, aggregator, mapping=None, **kwargs):
shared_coords_and_dims = [
(coord_, index)
for coord_ in shared_coords
for (index, dim) in enumerate(reference_cube.coord_dims(coord_))
for (index, dim) in enumerate(ref_cube.coord_dims(coord_))
if dim == dimension_to_groupby
]
......@@ -71,24 +67,24 @@ def cubelist_aggregated_by(cubes, coords, aggregator, mapping=None, **kwargs):
# Create the resulting aggregate-by cube and remove the original
# coordinates that are going to be groupedby.
# aggregateby_cube = iris.util._strip_metadata_from_dims(
# reference_cube, [dimension_to_groupby]
# ref_cube, [dimension_to_groupby]
# )
key = [slice(None, None)] * reference_cube.ndim
key = [slice(None, None)] * ref_cube.ndim
# Generate unique index tuple key to maintain monotonicity.
key[dimension_to_groupby] = tuple(range(len(groupby)))
key = tuple(key)
# aggregateby_cube = aggregateby_cube[key]
aggregateby_cube = reference_cube[key]
aggregateby_cube = ref_cube[key]
for coord in groupby_coords + shared_coords:
aggregateby_cube.remove_coord(coord)
# Determine the group-by cube data shape.
data_shape = list(reference_cube.shape
data_shape = list(ref_cube.shape
+ aggregator.aggregate_shape(**kwargs))
data_shape[dimension_to_groupby] = len(groupby)
# Aggregate the group-by data.
if aggregator.lazy_func is not None and reference_cube.has_lazy_data():
if aggregator.lazy_func is not None and ref_cube.has_lazy_data():
def data_getter(cube):
return cube.lazy_data()
......@@ -105,14 +101,21 @@ def cubelist_aggregated_by(cubes, coords, aggregator, mapping=None, **kwargs):
back_slice = (slice(None, None),) * (
len(data_shape) - dimension_to_groupby - 1
)
groupby_subcubes = map(
lambda groupby_slice: {
mapping.get(cube.var_name, cube.var_name):
data_getter(cube[
front_slice + (groupby_slice,) + back_slice
]) for cube in cubes},
groupby.group(),
)
if len(cubes) == 1:
groupby_subcubes = map(
lambda groupby_slice: data_getter(
ref_cube[front_slice + (groupby_slice,) + back_slice]),
groupby.group(),)
else:
groupby_subcubes = map(
lambda groupby_slice: {
argname:
data_getter(cube[
front_slice + (groupby_slice,) + back_slice
]) for argname, cube in cubes.items()},
groupby.group(),
)
def agg(data):
result = aggregate(data, axis=dimension_to_groupby, **kwargs)
......@@ -125,7 +128,7 @@ def cubelist_aggregated_by(cubes, coords, aggregator, mapping=None, **kwargs):
aggregateby_cube, groupby_coords, aggregate=True, **kwargs
)
# Replace the appropriate coordinates within the aggregate-by cube.
(dim_coord,) = reference_cube.coords(
(dim_coord,) = ref_cube.coords(
dimensions=dimension_to_groupby, dim_coords=True
) or [None]
for coord in groupby.coords:
......@@ -139,7 +142,7 @@ def cubelist_aggregated_by(cubes, coords, aggregator, mapping=None, **kwargs):
)
else:
aggregateby_cube.add_aux_coord(
coord.copy(), reference_cube.coord_dims(coord)
coord.copy(), ref_cube.coord_dims(coord)
)
# Attach the aggregate-by data into the aggregate-by cube.
......
......@@ -107,7 +107,7 @@ def build_output_filename(index, datafiles, output_template):
output_template = guess_output_template(datafiles)
frequency = index.period.label
return output_template.format(frequency=frequency,
**index.output_metadata.drs)
**index.metadata.output.drs)
def do_main(index_catalog, requested_indices, datafiles,
......
......@@ -66,18 +66,22 @@ class InputVariable:
var_name: str
standard_name: str
cell_methods: List[CellMethod]
aliases: List[str]
def instantiate(self, parameters):
return InputVariable(
format_var_name(self.var_name, parameters),
self.standard_name.format(**parameters),
self.cell_methods)
self.cell_methods, self.aliases)
def build_variable(name, variable, path):
    """Build an InputVariable from a parsed config mapping.

    `name` is the canonical variable name, `variable` the parsed config
    entry (consumed destructively: 'cell_methods' is popped), `path`
    identifies the source file (unused in this function's visible body).

    A missing 'aliases' key defaults to an empty list so existing
    configuration files without aliases keep working.
    """
    cell_methods = [CellMethod(*cm.popitem())
                    for cm in variable.pop('cell_methods')]
    return InputVariable(name,
                         variable['standard_name'],
                         cell_methods,
                         variable.get('aliases', []))
class ParameterKind(Enum):
......
......@@ -30,7 +30,7 @@ setuptools.setup(
'Topic :: Scientific/Engineering :: GIS',
],
packages=setuptools.find_packages(exclude=['data', 'legacy']),
python_requires='>=3',
python_requires='>=3.8',
install_requires=[
'cf-units',
'dask>=2.4.0',
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment