spell_functions.py 1.92 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
from cf_units import Unit
import dask.array as da
import numpy as np

from .spell_kernels import make_spell_length_kernels
from .support import (normalize_axis,
                      IndexFunction,
                      ThresholdMixin, ReducerMixin)


class SpellLength(ThresholdMixin, ReducerMixin, IndexFunction):
    """Index function measuring spell lengths along the time axis.

    A time step is evaluated with ``condition(data, threshold.points)``.
    The boolean series is summarised per block by ``chunk_kernel``; block
    summaries are merged by ``combine_kernel``. The final value is obtained
    by applying ``reducer`` over the last axis of the summary, skipping
    its first entry.

    Results are returned as ``float32`` in units of days. They are masked
    wherever any input value along the reduction axis is masked.

    NOTE(review): the meaning of the individual summary entries is defined
    in ``spell_kernels``; presumably entry 0 is bookkeeping (e.g. the
    chunk length) and the rest are spell lengths -- confirm there.
    """

    def __init__(self, threshold, condition, reducer):
        super().__init__(threshold, condition, reducer, units=Unit('days'))
        # The kernels are specialised on the scalar form of the reducer
        # (``scalar_reducer``, presumably provided by ReducerMixin).
        kernels = make_spell_length_kernels(self.scalar_reducer)
        self.chunk_kernel, self.combine_kernel = kernels

    def call_func(self, data, axis, **kwargs):
        """Compute the index eagerly on an in-memory array.

        This runs the same chunk/aggregate pipeline as ``lazy_func``,
        treating the whole array as a single chunk along ``axis``. It does
        not rely on the instance being callable, as
        ``np.apply_along_axis(self, ...)`` previously did.

        Parameters
        ----------
        data : numpy.ndarray or numpy.ma.MaskedArray
            Input data.
        axis : int
            Axis to reduce over, normalised via ``normalize_axis``.

        Returns
        -------
        numpy.ma.MaskedArray
            ``float32`` result with ``axis`` removed. Masked where any
            value along ``axis`` is masked.
        """
        axis = normalize_axis(axis, data.ndim)
        mask = np.ma.getmaskarray(data).any(axis=axis)
        # The kernels work with the reduction axis last, exactly as in
        # lazy_func; np.moveaxis keeps masked arrays masked.
        data = np.moveaxis(data, axis, -1)
        chunk_res = self.chunk(data, axis=-1, keepdims=True)
        # A single (non-list) summary goes straight to the final reduction.
        res = self.aggregate(chunk_res, axis=-1, keepdims=False)
        res = np.ma.masked_array(np.ma.getdata(res), mask)
        return res.astype('float32')

    def lazy_func(self, data, axis, **kwargs):
        """Compute the index lazily on a dask array.

        Parameters
        ----------
        data : dask.array.Array
            Input data, possibly backed by masked chunks.
        axis : int
            Axis to reduce over, normalised via ``normalize_axis``.

        Returns
        -------
        dask.array.Array
            Lazy ``float32`` masked result with ``axis`` removed. Masked
            where any value along ``axis`` is masked.
        """
        axis = normalize_axis(axis, data.ndim)
        # Compute the mask before moving the axis, so it lines up with the
        # remaining (order-preserved) axes of the result.
        mask = da.ma.getmaskarray(data).any(axis=axis)
        data = da.moveaxis(data, axis, -1)
        # concatenate=False: combine/aggregate receive lists of per-chunk
        # summaries instead of a concatenated array.
        res = da.reduction(data, self.chunk, self.aggregate,
                           axis=-1, dtype=int, combine=self.combine,
                           concatenate=False)
        res = da.ma.masked_array(da.ma.getdata(res), mask)
        return res.astype('float32')

    def chunk(self, raw_data, axis, keepdims, computing_meta=False):
        """Summarise one block of data whose reduction axis is last.

        Parameters
        ----------
        raw_data : array
            One block of input data.
        axis, keepdims :
            Supplied by ``dask.array.reduction``; unused here because the
            reduction axis is always last.
        computing_meta : bool
            True when dask only wants metadata (dtype/rank) of the result.

        Returns
        -------
        array
            Output of ``chunk_kernel`` on the thresholded data, or an empty
            integer array when ``computing_meta`` is set.
        """
        if computing_meta:
            # Metadata only: an empty integer array of the same rank as the
            # input. (np.array has no ``ndim`` keyword; using one raises.)
            return np.empty((0,) * np.ndim(raw_data), dtype=int)

        data = self.condition(raw_data, self.threshold.points)
        chunk_res = self.chunk_kernel(data)
        return chunk_res

    def combine(self, x_chunk, axis, keepdims):
        """Merge per-chunk summaries along the reduction axis.

        A non-list input is taken to be a single summary and is returned
        unchanged. A list is stacked on a new leading axis and passed to
        ``combine_kernel``.
        """
        if not isinstance(x_chunk, list):
            return x_chunk
        return self.combine_kernel(np.array(x_chunk))

    def aggregate(self, x_chunk, axis, keepdims):
        """Merge summaries and reduce them to the final index value.

        Entry 0 of the summary's last axis is skipped. The remaining
        entries are reduced with ``reducer`` over that axis, which removes
        it from the result.
        """
        res = self.combine(x_chunk, axis, keepdims)
        res = self.reducer(res[..., 1:], axis=-1)
        return res