main.py 10.6 KB
Newer Older
Klaus Zimmermann's avatar
Klaus Zimmermann committed
1
2
3
4
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import argparse
5
import copy
Klaus Zimmermann's avatar
Klaus Zimmermann committed
6
import os
7
import re
8
import threading
Klaus Zimmermann's avatar
Klaus Zimmermann committed
9
10
11
import time

import iris
12
from iris.experimental.equalise_cubes import equalise_attributes
13
import netCDF4
14
import numpy as np
15
import pkg_resources
16
import sentry_sdk
17
import six
Klaus Zimmermann's avatar
Klaus Zimmermann committed
18
19
20
import yaml

import climix
21
from .dask_setup import SCHEDULERS, setup_scheduler
22
from .index import Index
23
from .index_functions import SUPPORTED_OPERATORS, SUPPORTED_REDUCERS
24
from .period import PeriodSpecification
Klaus Zimmermann's avatar
Klaus Zimmermann committed
25

26
27
28
# Configure root logging early so all subsequent module-level calls
# (scheduler setup, index calculation) emit INFO-level messages.
import logging
logging.basicConfig(level=logging.INFO)

Klaus Zimmermann's avatar
Klaus Zimmermann committed
29

30
31
32
sentry_sdk.init("https://d3ac73a62877407b848dfc3f318bed85@sentry.io/1458386")


Klaus Zimmermann's avatar
Klaus Zimmermann committed
33
34
35
36
37
MISSVAL = 1.0e20


def parse_args():
    """Parse the command line arguments of the climix tool.

    Returns
    -------
    argparse.Namespace
        Parsed arguments with attributes ``dask_scheduler``, ``keep_open``,
        ``sliced_mode``, ``output_template``, ``indices``, and ``datafiles``.
    """
    parser = argparse.ArgumentParser(
        description=(f'A climate index thing, version {climix.__version__}.'))
    parser.add_argument('-d', '--dask-scheduler', choices=SCHEDULERS.keys(),
                        default='distributed-local-cluster')
    parser.add_argument('-k', '--keep-open', action='store_true',
                        help='keep climix running until key press '
                        '(useful for debugging)')
    parser.add_argument('-s', '--sliced-mode', action='store_true',
                        help='activate calculation per period to avoid memory '
                        'problems')
    parser.add_argument('-o', '--output', dest='output_template',
                        help='output filename')
    parser.add_argument('-x', '--index', action='append',
                        required=True, metavar='INDEX', dest='indices',
                        # fixed typo in help text: "calculcate" -> "calculate"
                        help='the index to calculate')
    parser.add_argument('datafiles', nargs='+', metavar="DATAFILE",
                        help='the input data files')
    return parser.parse_args()


def ignore_cb(cube, field, filename):
    """Iris load callback that drops per-file provenance attributes.

    'creation_date' and 'tracking_id' differ between the files of one
    dataset and would otherwise prevent the cubes from being merged.
    """
    for attribute in ('creation_date', 'tracking_id'):
        cube.attributes.pop(attribute, None)
Klaus Zimmermann's avatar
Klaus Zimmermann committed
60
61
62
63
64
65
66
67
68
69


def load_metadata():
    """Load the bundled index metadata from ``etc/metadata.yml``.

    The file is looked up relative to this module, so it works for any
    installation location of the package.
    """
    path = os.path.join(os.path.dirname(__file__), 'etc', 'metadata.yml')
    with open(path) as stream:
        return yaml.safe_load(stream)


70
71
72
73
74
75
76
def build_parameters(parameters_metadata):
    """Construct concrete index-function parameters from their metadata.

    Each entry's ``kind`` selects how it is materialized:
    ``quantity`` becomes an `iris.cube.Cube`, while ``operator`` and
    ``reducer`` are validated against the supported names and passed
    through as plain strings.

    Note: the ``kind`` key is popped from each parameter's metadata dict,
    i.e. the input mapping is modified in place.

    Raises
    ------
    ValueError
        For an unknown kind, operator, or reducer.
    """
    parameters = {}
    for name, metadata in parameters_metadata.items():
        kind = metadata.pop('kind')
        if kind == 'quantity':
            value = iris.cube.Cube(**metadata)
        elif kind == 'operator':
            value = metadata['operator']
            if value not in SUPPORTED_OPERATORS:
                raise ValueError(f'Unknown operator <{value}>')
        elif kind == 'reducer':
            value = metadata['reducer']
            if value not in SUPPORTED_REDUCERS:
                raise ValueError(f'Unknown reducer <{value}>')
        else:
            raise ValueError(f'Unknown parameter kind <{kind}>')
        parameters[name] = value
    return parameters


92
93
94
95
96
def build_index_function(spec):
    """Locate and instantiate the index function described by *spec*.

    The implementation is discovered via the ``climix.index_functions``
    entry point group; exactly one installed implementation must match
    ``spec['name']``. Its factory is called with the parameters built
    from ``spec['parameters']``.

    Raises
    ------
    ValueError
        If no implementation, or more than one, is found.
    """
    name = spec['name']
    candidates = list(pkg_resources.iter_entry_points('climix.index_functions',
                                                      name=name))
    if not candidates:
        raise ValueError(f'No implementation found for index_function <{name}>')
    if len(candidates) > 1:
        distributions = [candidate.dist for candidate in candidates]
        raise ValueError(
            f'Found several implementations for index_function <{name}>. '
            f'Please make sure only one is installed at any time. '
            f'The implementations come from the distributions {distributions}')
    candidate = candidates[0]
    logging.info(f'Found implementation for index_function <{name}> '
                 f'from distribution <{candidate.dist}>')
    factory = candidate.load()
    parameters = build_parameters(spec['parameters'])
    return factory(**parameters)


113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
def build_template_index(index_definitions):
    """Index parametrized index-name templates by their literal signature.

    A template such as ``'tx{thresh}d'`` is split at its ``{...}``
    placeholders; the tuple of literal fragments (``('tx', 'd')``) forms
    the signature, and the placeholder names are kept alongside the
    original template name so a concrete request like ``'tx30d'`` can be
    matched back to its template later.
    """
    placeholder = re.compile(r'{([^}]+)}')
    templates = {}
    for name in index_definitions:
        parts = placeholder.split(name)
        if len(parts) > 1:
            signature = tuple(parts[::2])
            templates[signature] = (name, parts[1::2])
    return templates


def replace_parameters_in_dict(dictionary, parameter_dict):
    """Recursively substitute parameter values into *dictionary* in place.

    Two substitution forms are handled:

    - a nested dict of the exact form ``{param_name: None}`` whose single
      key is a known parameter is replaced by that parameter's value;
    - string values are formatted with ``str.format(**parameter_dict)``,
      filling in ``'{param_name}'`` placeholders.

    All other nested dicts are processed recursively; non-dict, non-string
    values are left untouched.
    """
    for key, value in dictionary.items():
        if isinstance(value, dict):
            items = list(value.items())
            # Guard len(items) == 1 first: also protects against an
            # IndexError on an empty nested dict (the original indexed
            # items[0] unconditionally).
            if (len(items) == 1 and items[0][1] is None
                    and items[0][0] in parameter_dict):
                dictionary[key] = parameter_dict[items[0][0]]
            else:
                replace_parameters_in_dict(value, parameter_dict)
        elif isinstance(value, str):
            # str replaces six.string_types: the module is Python-3-only
            # (it uses f-strings throughout), so six is unnecessary here.
            dictionary[key] = dictionary[key].format(**parameter_dict)


def get_index_definition(index_definitions, index):
    """Look up the definition of *index*, resolving parametrized templates.

    A direct hit in *index_definitions* is returned as-is. Otherwise the
    numeric runs in *index* (e.g. the ``30`` in ``'tx30d'``) are treated
    as template parameter values: the matching template definition is
    deep-copied and the values are substituted into the copy, so the
    stored definitions are never modified.

    Raises
    ------
    KeyError
        If the index is unknown and contains no numeric parts, or no
        template matches its signature.
    """
    try:
        return index_definitions[index]
    except KeyError:
        digits = re.compile(r'(\d+)')
        parts = digits.split(index)
        if len(parts) == 1:
            # No numbers to interpret as parameters; re-raise the lookup
            # failure unchanged.
            raise
        templates = build_template_index(index_definitions)
        template, parameter_names = templates[tuple(parts[::2])]
        substitutions = {name: int(value)
                         for name, value in zip(parameter_names, parts[1::2])}
        definition = copy.deepcopy(index_definitions[template])
        replace_parameters_in_dict(definition, substitutions)
        return definition


Klaus Zimmermann's avatar
Klaus Zimmermann committed
158
def prepare_indices(index_definitions, requested_indices):
    """Build an :class:`Index` object for each requested index name."""
    def default_period_spec(period_metadata):
        # Use the definition's default period together with the metadata
        # registered for it under 'allowed'.
        label = period_metadata['default']
        return PeriodSpecification(label, period_metadata['allowed'][label])

    prepared = []
    for name in requested_indices:
        definition = get_index_definition(index_definitions, name)
        period_spec = default_period_spec(definition['period'])
        index_function = build_index_function(definition['index_function'])
        prepared.append(
            Index(index_function, definition['output'], period_spec))
    return prepared
Klaus Zimmermann's avatar
Klaus Zimmermann committed
171
172
173


def prepare_input_data(datafiles):
    """Load *datafiles* and merge them into a single cube.

    Per-file provenance attributes are dropped at load time (see
    ``ignore_cb``) and time units and remaining attributes are harmonised
    so the raw cubes can be concatenated into one cube.
    """
    raw_cubes = iris.load_raw(datafiles, callback=ignore_cb)
    iris.util.unify_time_units(raw_cubes)
    equalise_attributes(raw_cubes)
    return raw_cubes.concatenate_cube()


def guess_output_template(datafiles):
    """Derive an output filename template from the input filenames.

    Input files are assumed to follow the CMIP-style convention
    ``<variable>_<base...>_<timerange>.<ext>``. If all files share the
    same base, the returned template embeds that base and the overall
    covered time range; otherwise — or if any time part cannot be
    parsed, or no files are given — a generic template is returned.

    Parameters
    ----------
    datafiles : list of str
        Paths of the input data files.

    Returns
    -------
    str
        A format-string template with ``{var_name}`` and ``{frequency}``
        slots.
    """
    fallback_template = '{var_name}_{frequency}.nc'

    def filename_stripper(path):
        # remove directory part...
        basename = os.path.basename(path)
        # ...and extension
        root, ext = os.path.splitext(basename)
        # split at _...
        parts = root.split('_')
        # and remove
        # first part (usually variable)
        # and last part (usually time)
        base = '_'.join(parts[1:-1])
        try:
            time = [int(t) for t in parts[-1].split('-')]
            if len(time) == 1:
                # A single year stands for a start == end range.
                time *= 2
        except ValueError:
            # Unparseable time part; mark start and end as unknown.
            time = [None, None]
        return (base, time[0], time[1])

    files = [filename_stripper(p) for p in datafiles]
    if not files:
        # zip(*[]) below would fail to unpack; nothing to guess from.
        return fallback_template
    bases, starts, ends = zip(*files)
    unique_bases = set(bases)
    # Only build a specific template when the files agree on a common base
    # AND every time range could be parsed — comparing None with int in
    # min()/max() would raise a TypeError otherwise.
    if len(unique_bases) == 1 and None not in starts and None not in ends:
        base = unique_bases.pop()
        start = min(starts)
        end = max(ends)
        return f'{{var_name}}_{base}_{{frequency}}_{start}-{end}.nc'
    return fallback_template


def build_output_filename(index, datafiles, output_template):
    """Render the output filename for *index*.

    Falls back to a template guessed from *datafiles* when no explicit
    *output_template* is given. The template's ``{frequency}`` slot is
    filled from the index's period label; all other slots come from the
    index's output metadata.
    """
    template = (guess_output_template(datafiles)
                if output_template is None
                else output_template)
    return template.format(frequency=index.period.label,
                           **index.output_metadata)
Klaus Zimmermann's avatar
Klaus Zimmermann committed
217
218


219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
def save(result, output_filename, sliced_mode=False):
    """Compute *result* and write it to *output_filename* as netCDF.

    In normal mode the whole result is saved in one go. In sliced mode
    the result is written one time slice at a time to limit peak memory
    use: the file is first created from a cube with uninitialized data,
    then each slice is computed and handed to a background thread for
    storage while the next slice is being computed.
    """
    if sliced_mode:
        logging.info('Performing aggregation in sliced mode')
        # Keep a reference to the (lazy) data, then save a cube holding
        # only an uninitialized array of the right shape/dtype — this lays
        # out the netCDF file without computing anything.
        data = result.core_data()
        logging.info('creating empty data')
        result.data = np.empty(data.shape, data.dtype)
        # NOTE(review): this bare attribute access looks like a no-op
        # after the assignment above — confirm intent.
        result.data
        logging.info('saving empty cube')
        iris.save(result, output_filename, fill_value=MISSVAL,
                  local_keys=['proposed_standard_name'])
        logging.info('opening')
        # Restore the lazy data; it is realised slice by slice below.
        result.data = data
        with netCDF4.Dataset(output_filename, 'a') as ds:
            var = ds[result.var_name]
            time_dim = result.coord_dims('time')[0]
            no_slices = result.shape[time_dim]
            # Writer run on a background thread so slice i can be stored
            # while slice i+1 is being computed.
            def store(i, data):
                var[i, ...] = data
            # Dummy thread so the first loop iteration can join()
            # unconditionally.
            thread = threading.Thread()
            thread.start()
            start = time.time()
            for i, result_cube in enumerate(result.slices_over(time_dim)):
                logging.info('Starting with {}'.format(result_cube.coord('time')))
                result_cube.data
                logging.info('Waiting for previous save to finish')
                thread.join()
                thread = threading.Thread(target=store,
                                          args=(i, result_cube.data))
                thread.start()
                end = time.time()
                partial = end - start
                # Linear extrapolation of elapsed time to all slices.
                total_estimate = partial/(i+1)*no_slices
                logging.info(f'Finished {i+1}/{no_slices} in {partial:4f}s. '
                             f'Estimated total time is {total_estimate:4f}s.')
            # Wait for the last slice's store to complete before closing.
            thread.join()
    else:
        logging.info('Performing aggregation in normal mode')
        iris.save(result, output_filename, fill_value=MISSVAL,
                  local_keys=['proposed_standard_name'])


260
def do_main(requested_indices, datafiles, output_template, sliced_mode):
    """Calculate every requested index over *datafiles* and save each result.

    Parameters
    ----------
    requested_indices : list of str
        Index names as given on the command line.
    datafiles : list of str
        Paths of the input data files.
    output_template : str or None
        Output filename template; guessed from the input filenames when
        None (see build_output_filename).
    sliced_mode : bool
        Passed through to save() to write per-slice instead of at once.
    """
    logging.info('Loading metadata')
    metadata = load_metadata()
    logging.info('Preparing indices')
    indices = prepare_indices(metadata['indices'], requested_indices)
    for index in indices:
        logging.info(f'Starting calculations for index {index}')
        logging.info('Building output filename')
        output_filename = build_output_filename(index, datafiles, output_template)
        logging.info('Preparing input data')
        # NOTE(review): the input data is reloaded for every index; if
        # deliberate (e.g. to release memory between indices), confirm —
        # otherwise it could be hoisted out of the loop.
        input_data = prepare_input_data(datafiles)
        logging.info('Calculating index')
        result = index(input_data)
        logging.info('Saving result')
        save(result, output_filename, sliced_mode=sliced_mode)
Klaus Zimmermann's avatar
Klaus Zimmermann committed
275
276
277
278


def main():
    """Entry point: parse arguments, start the scheduler, run, and time it."""
    args = parse_args()
    with setup_scheduler(args) as scheduler:
        logging.info('Scheduler ready; starting main program.')
        start_time = time.time()
        try:
            do_main(args.indices, args.datafiles,
                    args.output_template, args.sliced_mode)
        finally:
            # Report the elapsed time even when the calculation fails.
            elapsed = time.time() - start_time
            logging.info(f'Calculation took {elapsed:.4f} seconds.')
        if args.keep_open:
            # Block before the context manager shuts the cluster down, so
            # its state can be inspected (e.g. the dask dashboard).
            input('Press enter to close the cluster ')
Klaus Zimmermann's avatar
Klaus Zimmermann committed
290
291
292
293


if __name__ == "__main__":
    main()