Commit caf15b62 authored by Klaus Zimmermann

Percentile based indices (closes #196)

parent 6d0aeddb
# -*- coding: utf-8 -*-

import numpy as np
import sparse


def dask_take_along_axis_chunk(x, idx, offset, x_size, axis):
    # Needed when idx is unsigned
    idx = idx.astype(np.int64)
    # Normalize negative indices
    idx = np.where(idx < 0, idx + x_size, idx)
    # A chunk of the offset dask Array is a numpy array with shape (1, ).
    # It indicates the index of the first element along axis of the current
    # chunk of x.
    idx = idx - offset
    # Drop elements of idx that do not fall inside the current chunk of x
    idx_filter = (idx >= 0) & (idx < x.shape[axis])
    idx[~idx_filter] = 0
    res = np.take_along_axis(x, idx, axis=axis)
    res[~idx_filter] = 0
    return sparse.COO(np.expand_dims(res, axis))
def dask_take_along_axis(x, index, axis):
    from dask.array.core import Array, blockwise, from_array, map_blocks

    if axis < 0:
        axis += x.ndim
    assert 0 <= axis < x.ndim
    assert (x.shape[:axis] + x.shape[axis + 1:]
            == index.shape[:axis] + index.shape[axis + 1:])
    if np.isnan(x.chunks[axis]).any():
        raise NotImplementedError(
            "take_along_axis for an array with unknown chunks with "
            "a dask.array of ints is not supported"
        )
    # Calculate the offset at which each chunk starts along axis
    # e.g. chunks=(..., (5, 3, 4), ...) -> offset=[0, 5, 8]
    offset = np.roll(np.cumsum(x.chunks[axis]), 1)
    offset[0] = 0
    offset = from_array(offset, chunks=1)
    # Tamper with the declared chunks of offset to make blockwise align it
    # with x[axis]
    offset = Array(offset.dask, offset.name, (x.chunks[axis],), offset.dtype)
    # Define axis labels for blockwise
    x_axes = tuple(range(x.ndim))
    idx_label = (x.ndim,)  # arbitrary unused
    index_axes = x_axes[:axis] + idx_label + x_axes[axis + 1:]
    offset_axes = (axis,)
    p_axes = x_axes[:axis + 1] + idx_label + x_axes[axis + 1:]
    # Calculate the cartesian product of every chunk of x vs
    # every chunk of index
    p = blockwise(
        dask_take_along_axis_chunk,
        p_axes,
        x,
        x_axes,
        index,
        index_axes,
        offset,
        offset_axes,
        x_size=x.shape[axis],
        axis=axis,
        meta=np.empty((0,) * index.ndim, dtype=x.dtype),
        dtype=x.dtype,
    )
    res = p.sum(axis=axis)
    res = map_blocks(lambda sparse_x: sparse_x.todense(), res, dtype=res.dtype)
    return res
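# Editorial sketch, not part of the commit: dask_take_along_axis is meant to
# mirror numpy.take_along_axis for chunked dask arrays, so a minimal
# self-check against the numpy result (shapes and chunk sizes chosen purely
# for illustration) could look like this:
#
#     import dask.array as da
#     x_np = np.random.random((4, 10))
#     idx_np = np.argsort(x_np, axis=-1)[..., :3]   # three smallest per row
#     x_da = da.from_array(x_np, chunks=(2, 5))
#     idx_da = da.from_array(idx_np, chunks=(2, 3))
#     res = dask_take_along_axis(x_da, idx_da, axis=-1).compute()
#     assert np.allclose(res, np.take_along_axis(x_np, idx_np, axis=-1))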
@@ -21,6 +21,9 @@ class Index:
            self.mapping[key] = argname

    def __call__(self, cubes, client=None, sliced_mode=False):
        logging.info('Starting preprocess')
        self.index_function.preprocess(cubes, client)
        logging.info('Finished preprocess')
        cube_mapping = {argname: cube.extract(self.period.constraint)
                        for cube in cubes
                        if (argname := self.mapping.get(cube.var_name))  # noqa
@@ -15,6 +15,10 @@ from .index_functions import (  # noqa: F401
    TemperatureSum,
)
from .percentile_functions import (  # noqa: F401
    PercentileOccurrence,
)
from .spell_functions import (  # noqa: F401
    SpellLength,
)
import logging
import cftime
import dask.array as da
from dask.distributed import progress
import numpy as np
from .support import IndexFunction, parse_timerange
from ..dask_take_along_axis import dask_take_along_axis
def calc_doy_indices(day_of_year):
    max_doy = day_of_year.max()
    for doy in range(1, max_doy + 1):
        exact_inds = np.nonzero(day_of_year == doy)[0]
        inds = np.stack(
            [exact_inds + i for i in range(-2, 3)],
            axis=-1,
        ).ravel()
        yield doy, inds
def calc_index(prob, n):
    delta = prob * (n + 1./3.) + 1./3.
    j = int(delta)
    gamma = delta - j
    if j > n/2:
        k = n - j
    else:
        k = -j - 1
    return (k, gamma)
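# Editorial sketch, not part of the commit: calc_index implements the
# interpolation rule delta = prob * (n + 1/3) + 1/3, split into an integer
# part j and a fractional weight gamma; the sign of k appears to encode
# whether the relevant order statistics are counted from the top (positive)
# or from the bottom (negative).  A worked example with prob=0.9 and n=20:
#     delta = 0.9 * (20 + 1/3) + 1/3 = 18.6333...
#     j = 18, gamma = 0.6333...
#     j > n/2, so k = n - j = 2, and calc_index(0.9, 20) == (2, 0.6333...)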
class BootstrapQuantiles:
    def __init__(self, data, prob, first_year, window_size, client):
        self.first_year = first_year
        n = data.shape[-1]
        self.k, self.gamma = calc_index(prob, n)
        self.order_indices = client.persist(da.argtopk(data,
                                                       self.k - window_size,
                                                       axis=-1))
        progress(self.order_indices)
        self.order_statistics = client.persist(
            dask_take_along_axis(data, self.order_indices, axis=-1))
        progress(self.order_statistics)
        self.years = client.persist(
            first_year + da.floor(self.order_indices / window_size))
        progress(self.years)

    def quantiles(self, ignore_year=None, duplicate_year=None):
        k = abs(self.k)
        if ignore_year is None and duplicate_year is None:
            qi = self.order_statistics[..., k-2:k]
        else:
            offset = (da.sum(self.years[..., :k] == ignore_year, axis=-1)
                      - da.sum(self.years[..., :k] == duplicate_year, axis=-1))
            qi = dask_take_along_axis(
                self.order_statistics,
                da.stack([k-2+offset, k-1+offset], axis=-1),
                axis=-1)
        quantiles = (1.-self.gamma)*qi[..., 0] + self.gamma*qi[..., 1]
        return quantiles
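# Editorial note, not part of the commit: quantiles() appears to implement the
# percentile bootstrap without recomputing order statistics.  `years` records
# which in-base year each persisted order statistic comes from, so when one
# year is ignored and another duplicated, the two interpolation points are
# simply shifted by
#     offset = (#top-k values from ignore_year)
#              - (#top-k values from duplicate_year)
# instead of sorting the resampled data again.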
class TimesHelper:
    def __init__(self, time):
        self.times = time.points
        self.units = str(time.units)

    def __getattr__(self, name):
        return getattr(self.times, name)

    def __len__(self):
        return len(self.times)

    def __getitem__(self, key):
        return self.times[key]
def build_indices(time, max_doy, no_years, window_size):
    """Build indices.

    Given a linear time coordinate, build an index array `idx` of shape
    `(max_doy, no_years, window_size)` such that `idx[doy, yr]` contains the
    indices of the `window_size` days in the time coordinate that should
    contribute to the day of year `doy` for the year `yr`. If `doy` is
    smaller than `window_size`, this will include days from the year
    `yr - 1`. Conversely, if `doy` is larger than `max_doy - window_size`,
    it will include days from `yr + 1`.
    """
    window_width = window_size // 2
    first_year = time.cell(0).point.timetuple()[0]
    np_indices = np.zeros((max_doy, no_years, window_size), int)
    for c in time.cells():
        tt = c.point.timetuple()
        year = tt[0]
        day_of_year = tt[7] - 1
        if day_of_year >= max_doy:
            continue
        idx_y = year - first_year
        days = np.arange(day_of_year - window_width,
                         day_of_year + window_width + 1)
        np_indices[day_of_year, idx_y] = \
            window_width + idx_y * 365 + days
    np_indices[0, 0, :2] = window_width
    np_indices[1, 0, :1] = window_width
    np_indices[-1, -1, -2:] = np_indices[-1, -1, -3]
    np_indices[-2, -1, -1:] = np_indices[-2, -1, -2]
    return np_indices
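# Editorial sketch, not part of the commit: a worked example of the index
# layout, assuming a 365-day year, window_size=5 (window_width=2), and data
# that starts window_width days before the base period:
#     np_indices[10, 1] == window_width + 1 * 365 + [8, 9, 10, 11, 12]
#                       == [375, 376, 377, 378, 379]
# i.e. the 0-based day of year 10 of the second year draws on the five days
# centred on it; the assignments after the loop appear to clamp the few
# windows at the very start and end of the period that would otherwise reach
# outside the available data.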
class PercentileOccurrence(IndexFunction):
    def __init__(self, percentile, condition,
                 reference_period="1961-1991",
                 bootstrapping=True):
        super().__init__(units="%")
        timerange = parse_timerange(reference_period)
        if timerange.climatological:
            raise ValueError('The reference period cannot be climatological')
        self.base = timerange
        percentile.convert_units('1')
        self.percentile = float(percentile.points)
        self.bootstrapping = bootstrapping

    def preprocess(self, cubes, client):
        window_size = 5
        window_width = window_size // 2
        cube = cubes[0]
        time = cube.coord('time')
        time_units = time.units
        times = TimesHelper(time)
        idx_0 = cftime.date2index(self.base.start, times,
                                  calendar=time_units.calendar,
                                  select='after')
        idx_n = cftime.date2index(self.base.end, times,
                                  calendar=time_units.calendar,
                                  select='before')
        if time[idx_n].points[0] == cftime.date2num(
                self.base.end,
                time_units.name,
                calendar=time_units.calendar):
            # if the end date exists exactly in the data, don't include it
            idx_n -= 1
        self.first_year = self.base.start.year
        self.last_year = self.base.end.year - 1
        max_doy = 365
        self.k = calc_index(self.percentile, max_doy)
        self.years = {y: np.arange(i*window_size, (i+1)*window_size)
                      for i, y in enumerate(range(self.first_year,
                                                  self.last_year + 1))}
        np_indices = build_indices(time[idx_0:idx_n+1],
                                   max_doy, len(self.years), window_size)
        all_data = da.moveaxis(
            cube.core_data()[idx_0-window_width:idx_n+window_width+1],
            0, -1)
        data = []
        for idx_d in range(max_doy):
            data.append(all_data[..., np_indices[idx_d].ravel()])
        data = da.stack(data, axis=0)
        data = data.rechunk({0: -1, 1: 'auto', 2: 'auto', 3: -1})
        self.quantiler = BootstrapQuantiles(data, self.percentile,
                                            self.first_year, window_size,
                                            client)
        client.cancel(data)
        logging.info("Starting quantile calculation")
        res = client.persist(self.quantiler.quantiles().rechunk())
        progress(res)
        self.out_of_base_quantiles = res
    def call_func(self, data, axis, **kwargs):
        pass

    def lazy_func(self, data, axis, cube, client, **kwargs):
        year = cube.coord('time').cell(0).point.year
        logging.info(f'Starting year {year}')
        if data.shape[0] > 365:
            data = data[:-1]
        if self.bootstrapping and year in self.years:
            logging.info('Using bootstrapping')
            quantile_years = [y for y in self.years.keys() if y != year]
            counts = []
            for duplicate_year in quantile_years:
                quantiles = self.quantiler.quantiles(
                    ignore_year=year,
                    duplicate_year=duplicate_year)
                cond = data[...] < quantiles
                count = da.count_nonzero(cond, axis=0)
                counts.append(count)
            counts = da.stack(counts, axis=-1)
            avg_counts = counts.mean(axis=-1)
            percents = avg_counts/(data.shape[0]/100.)
        else:
            logging.info('Not using bootstrapping')
            cond = data < self.out_of_base_quantiles
            counts = da.count_nonzero(cond, axis=0).astype(np.float32)
            percents = counts/(data.shape[0]/100.)
        return percents
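# Editorial note, not part of the commit: lazy_func converts the day counts
# into a percentage of the days in the processed year.  For example, if 18 of
# 90 days fall below the percentile threshold,
#     percents = 18 / (90 / 100.) == 20.0
# i.e. a 20 % occurrence; the bootstrapping branch averages the counts over
# the duplicate years before applying the same conversion.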
# -*- coding: utf-8 -*-
from collections import namedtuple
from datetime import datetime
import operator
from cf_units import Unit
@@ -69,6 +71,41 @@ DASK_REDUCERS = {
}
TimeRange = namedtuple('TimeRange', ['start', 'end', 'climatological'])
TIME_FORMATS = {
    4: '%Y',
    6: '%Y%m',
    8: '%Y%m%d',
    10: '%Y%m%d%H',
    12: '%Y%m%d%H%M',
    14: '%Y%m%d%H%M%S',
}
def parse_timerange(timerange):
    parts = timerange.split('-')
    n = len(parts)
    if n < 2 or n > 3:
        raise ValueError(f'Invalid timerange {timerange}')
    if n == 3:
        if parts[2] != 'clim':
            raise ValueError(f'Invalid timerange {timerange}')
        climatological = True
    else:
        climatological = False
    n_start = len(parts[0])
    n_end = len(parts[1])
    if n_start != n_end:
        raise ValueError(f'Start and end time must have the same '
                         f'resolution in {timerange}')
    format_string = TIME_FORMATS[n_start]
    start_time = datetime.strptime(parts[0], format_string)
    end_time = datetime.strptime(parts[1], format_string)
    return TimeRange(start_time, end_time, climatological)
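# Editorial sketch, not part of the commit: two illustrative calls, using only
# the formats defined in TIME_FORMATS above:
#     parse_timerange('1961-1991')
#         == TimeRange(datetime(1961, 1, 1), datetime(1991, 1, 1), False)
#     parse_timerange('19610101-19901231-clim')
#         == TimeRange(datetime(1961, 1, 1), datetime(1990, 12, 31), True)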
def normalize_axis(axis, ndim):
    if isinstance(axis, list) and len(axis) == 1:
        axis = axis[0]
@@ -85,6 +122,9 @@ class IndexFunction:
        self.units = units
        self.extra_coords = []

    def preprocess(self, cubes, client):
        pass

    def prepare(self, input_cubes):
        pass
@@ -104,21 +104,22 @@ def multicube_aggregated_by(cubes, coords, aggregator, **kwargs):
     if len(cubes) == 1:
         groupby_subcubes = map(
-            lambda groupby_slice: data_getter(
-                ref_cube[front_slice + (groupby_slice,) + back_slice]),
+            lambda groupby_slice:
+            ref_cube[front_slice + (groupby_slice,) + back_slice],
             groupby.group(),)
     else:
         groupby_subcubes = map(
             lambda groupby_slice: {
-                argname:
-                data_getter(cube[
+                argname: cube[
                     front_slice + (groupby_slice,) + back_slice
-                ]) for argname, cube in cubes.items()},
+                ] for argname, cube in cubes.items()},
             groupby.group(),
         )

-    def agg(data):
-        result = aggregate(data, axis=dimension_to_groupby, **kwargs)
+    def agg(cube):
+        data = data_getter(cube)
+        result = aggregate(data,
+                           axis=dimension_to_groupby, cube=cube, **kwargs)
         return result
     result = list(map(agg, groupby_subcubes))
     aggregateby_data = stack(result, axis=dimension_to_groupby)
@@ -42,6 +42,7 @@ setuptools.setup(
        'regex',
        'sentry-sdk',
        'scitools-iris>=2.2.0',
        'sparse',
    ],
    extras_require={
        'editor': ['pyexcel', 'pyexcel-xls', 'jinja2']
@@ -68,6 +69,8 @@ setuptools.setup(
            'climix.index_functions:LastOccurrence',
            'percentile='
            'climix.index_functions:Percentile',
            'percentile_occurrence='
            'climix.index_functions:PercentileOccurrence',
            'spell_length='
            'climix.index_functions:SpellLength',
            'statistics='