Commit a0236320 authored by Klaus Zimmermann
Browse files

Improve spell length calculation to deal with period boundary-crossing spells (closes #183)

parent 6190a0ef
from cf_units import Unit
import dask
import dask.array as da
import numpy as np
......@@ -11,8 +12,11 @@ from .support import (normalize_axis,
class SpellLength(ThresholdMixin, ReducerMixin, IndexFunction):
def __init__(self, threshold, condition, reducer):
    """Set up the spell length index function.

    The three arguments are forwarded unchanged to the parent
    initialiser, together with ``units=Unit('days')``.

    The kernels are built exactly once from ``self.scalar_reducer``
    (presumably provided by the parent initialiser -- confirm) and kept
    as a single ``Kernels`` tuple on ``self.kernels``.
    """
    super().__init__(threshold, condition, reducer, units=Unit('days'))
    # Flag read elsewhere; judging by the commit message it marks that
    # spells may cross period boundaries -- NOTE(review): confirm consumer.
    self.spanning_spells = True
    self.kernels = make_spell_length_kernels(self.scalar_reducer)
    # The kernels tuple has four fields, so it can no longer be unpacked
    # into two names.  Keep the old attribute names as aliases so code
    # that still refers to them keeps working.
    self.chunk_kernel = self.kernels.chunk
    self.combine_kernel = self.kernels.aggregate
def pre_aggregate_shape(self, *args, **kwargs):
    """Return the shape of the per-period result before final aggregation.

    All positional and keyword arguments are accepted and ignored.  The
    result is always the one-element tuple ``(4,)``, i.e. four values per
    period.
    """
    values_per_period = 4
    return (values_per_period,)
def call_func(self, data, axis, **kwargs):
axis = normalize_axis(axis, data.ndim)
......@@ -26,9 +30,11 @@ class SpellLength(ThresholdMixin, ReducerMixin, IndexFunction):
mask = da.ma.getmaskarray(data).any(axis=axis)
data = da.moveaxis(data, axis, -1)
res = da.reduction(data, self.chunk, self.aggregate,
axis=-1, dtype=int, combine=self.combine,
concatenate=False)
res = da.ma.masked_array(da.ma.getdata(res), mask)
keepdims=True, output_size=4,
axis=-1, dtype=int, concatenate=False)
res = da.ma.masked_array(da.ma.getdata(res),
np.broadcast_to(mask[..., np.newaxis],
res.shape))
return res.astype('float32')
def chunk(self, raw_data, axis, keepdims, computing_meta=False):
......@@ -36,15 +42,54 @@ class SpellLength(ThresholdMixin, ReducerMixin, IndexFunction):
return np.array((0,), ndim=1, dtype=int)
data = self.condition(raw_data, self.threshold.points)
chunk_res = self.chunk_kernel(data)
chunk_res = self.kernels.chunk(data)
return chunk_res
def aggregate(self, x_chunk, axis, keepdims):
    """Merge the partial results of several chunks into one result.

    ``x_chunk`` is either a single array, which is returned untouched,
    or a list of per-chunk arrays, which is converted to one array and
    handed to the aggregate kernel.  ``axis`` and ``keepdims`` are
    accepted for the reduction callback signature but are not used.

    No reducer is applied here; the merged per-period values are
    returned as produced by the kernel.
    """
    if not isinstance(x_chunk, list):
        # Nothing to merge: only one chunk was passed in.
        return x_chunk
    return self.kernels.aggregate(np.array(x_chunk))

def combine(self, x_chunk, axis, keepdims):
    """Backward-compatible alias that delegates to :meth:`aggregate`."""
    return self.aggregate(x_chunk, axis, keepdims)
def post_process(self, cube, data, coords, period, **kwargs):
    """Fuse per-period spell statistics into one spell length per period.

    ``data`` is iterated along its first axis, one slice per period.
    The last axis of each slice is read at four positions:

    * index 0 -- compared with the head; where both are equal the tail
      carried over from the previous period is added to the own tail
      (presumably the period is one unbroken spell -- confirm against
      the kernels),
    * index 1 -- head, the spell touching the start of the period,
    * index 2 -- internal spell value,
    * index 3 -- tail, the spell touching the end of the period.

    A spell that runs into the next period is not counted in the period
    where it starts: its tail is zeroed there and its length is added to
    the head of the following period.

    ``coords`` and ``period`` are accepted but not used.

    Returns the unchanged ``cube`` and the stacked per-period results.
    """
    def fuse(this, next_chunk, previous_tail):
        """Reduce one period that has a successor; also return its tail."""
        own_mask = da.ma.getmaskarray(this[..., 0])
        own_head = this[..., 1]
        # Where index 0 equals the head, the carried-over tail is passed
        # on through this period on top of its own tail.
        own_tail = da.where(this[..., 0] == this[..., 1],
                            previous_tail + this[..., 3],
                            this[..., 3])
        next_head = next_chunk[..., 1]
        # A non-zero head continues the spell that ended the last period.
        head = da.where(own_head, previous_tail + own_head, 0.)
        internal = this[..., 2]
        # If the next period starts inside a spell, the tail is accounted
        # for there and must not be counted here.
        tail = da.where(next_head, 0., own_tail)
        parts = da.stack([head, internal, tail], axis=-1)
        spell_length = da.ma.masked_array(
            self.lazy_reducer(parts, axis=-1),
            own_mask)
        return spell_length, own_tail

    def fuse_last(this, previous_tail):
        """Reduce the final period, which has no successor."""
        own_head = this[..., 1]
        tail = this[..., 3]
        head = da.where(own_head, previous_tail + own_head, 0.)
        internal = this[..., 2]
        parts = da.stack([head, internal, tail], axis=-1)
        # NOTE(review): unlike `fuse`, the result is not re-masked here;
        # confirm whether that is intended.
        spell_length = self.lazy_reducer(parts, axis=-1)
        return spell_length

    results = []
    this = data[0]
    slice_shape = this.shape[:-1]
    # Nothing precedes the first period, so it starts with a zero tail.
    previous_tail = da.ma.masked_array(
        da.zeros(slice_shape, dtype=np.float32),
        da.ma.getmaskarray(data[0, ..., 3]))
    for next_chunk in data[1:]:
        spell_length, previous_tail = fuse(this, next_chunk, previous_tail)
        results.append(spell_length)
        this = next_chunk
    # `this` always refers to the last period at this point, also when
    # `data` holds a single period and the loop body never ran (the loop
    # variable would be unbound in that case).
    results.append(fuse_last(this, previous_tail))
    res_data = da.stack(results, axis=0)
    return cube, res_data
from collections import namedtuple
from numba import jit
import numpy as np
# Container for the kernel callables of an index function.  `defaults`
# binds to the rightmost fields, so `combine` and `post_process` are
# optional and fall back to None.
Kernels = namedtuple(
    'Kernels',
    ('chunk', 'aggregate', 'combine', 'post_process'),
    defaults=(None, None),
)
def make_spell_length_kernels(reducer):
# The gufunc support in numba is lacking right now.
# Once numba supports NEP-20 style signatures for
......@@ -58,7 +65,7 @@ def make_spell_length_kernels(reducer):
return res
@jit(nopython=True)
def combine(x_chunk):
def aggregate(x_chunk):
# start with the first chunk and merge all others subsequently
res = x_chunk[0].copy()
# mark where this chunk is completely covered by a spell
......@@ -109,4 +116,5 @@ def make_spell_length_kernels(reducer):
# and the tail is the new tail
res[ind_tail] = next_chunk[ind_tail]
return res
return (chunk, combine)
return Kernels(chunk, aggregate)
......@@ -6,3 +6,4 @@ channels:
dependencies:
- iris
- numba
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment