Commit 6190a0ef authored by Klaus Zimmermann

Improve sliced mode and saving (closes #182)

parent 75cf8003
 # -*- coding: utf-8 -*-
+import logging
+import time
+import dask.array as da
+from dask.distributed import progress
 from iris.analysis import Aggregator
 from .util import change_units
@@ -22,7 +27,38 @@ class PointLocalAggregator(Aggregator):
         cube.standard_name = self.index_function.standard_name
         cube.units = self.index_function.units

-    def post_process(self, cube, data, coords, **kwargs):
+    def compute_pre_result(self, data, client, sliced_mode):
+        if sliced_mode:
+            logging.info('Computing pre-result in sliced mode')
+            stack = []
+            end = time.time()
+            cumulative = 0.
+            no_slices = len(data)
+            for i, d in enumerate(data):
+                result_id = f'{i+1}/{no_slices}'
+                logging.info(f'Computing partial pre-result {result_id}')
+                d = client.persist(d)
+                progress(d)
+                print()
+                stack.append(d)
+                start = end
+                end = time.time()
+                last = end - start
+                cumulative += last
+                eta = cumulative/(i+1)*no_slices
+                logging.info(f'Finished {result_id} in (last cum eta): '
+                             f'{last:4.0f} {cumulative:4.0f} {eta:4.0f}')
+            data = da.stack(stack, axis=0)
+        else:
+            logging.info('Setting up pre-result in aggregate mode')
+            start = time.time()
+            data = client.persist(data)
+            end = time.time()
+            logging.info(f'Setup completed in {end - start:4.0f}')
+        return data
+
+    def post_process(self, cube, data, coords, client, sliced_mode, **kwargs):
+        data = self.compute_pre_result(data, client, sliced_mode)
         try:
             post_processor = self.index_function.post_process
         except AttributeError:
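In the sliced branch above, compute_pre_result persists one period slice at a time on the dask cluster and then reassembles the pieces with da.stack. Below is a minimal standalone sketch of that pattern, with invented array shapes and chunking (Client, client.persist, progress, and da.stack are the dask APIs used above):

import dask.array as da
from dask.distributed import Client, progress

client = Client()  # small local cluster, for demonstration only

# Stand-ins for the per-period pre-results that sliced mode iterates over.
slices = [da.random.random((100, 100), chunks=(50, 50)) for _ in range(3)]

stack = []
for d in slices:
    d = client.persist(d)  # start computing this slice only
    progress(d)            # block, showing a progress bar, until it is done
    print()                # newline after the progress bar
    stack.append(d)

data = da.stack(stack, axis=0)  # reassembled into one (slice, y, x) array

Persisting slice by slice keeps only one slice's tasks in flight at a time, which is the point of sliced mode: trading some speed for a bounded memory footprint.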
 # -*- coding: utf-8 -*-
 import logging
-import threading
 import time
+from dask.distributed import progress
 import iris
 from iris.experimental.equalise_cubes import equalise_attributes
 import netCDF4
@@ -42,43 +42,47 @@ def prepare_input_data(datafiles):
     return cubes


-def save(result, output_filename, sliced_mode=False, client=None):
-    if sliced_mode:
-        logging.info('Performing aggregation in sliced mode')
-        data = result.core_data()
-        logging.info('creating empty data')
-        result.data = np.empty(data.shape, data.dtype)
-        result.data
-        logging.info('saving empty cube')
+def save(result, output_filename, iterative_storage=False, client=None):
+    data = result.core_data()
+    if iterative_storage:
+        logging.info('Storing iteratively')
+        logging.info('Creating empty data')
+        result.data = np.zeros(data.shape, data.dtype)
+        logging.info('Saving empty cube')
         iris.save(result, output_filename, fill_value=MISSVAL,
                   local_keys=['proposed_standard_name'])
-        logging.info('opening')
+        logging.info('Reopening output file and beginning storage')
         result.data = data
         with netCDF4.Dataset(output_filename, 'a') as ds:
             var = ds[result.var_name]
             time_dim = result.coord_dims('time')[0]
             no_slices = result.shape[time_dim]
-            def store(i, data):
-                var[i, ...] = data
-            thread = threading.Thread()
-            thread.start()
-            start = time.time()
-            for i, result_cube in enumerate(result.slices_over(time_dim)):
-                logging.info(f'Starting with {result_cube.coord("time")}')
-                result_cube.data
-                logging.info('Waiting for previous save to finish')
-                thread.join()
-                thread = threading.Thread(target=store,
-                                          args=(i, result_cube.data))
-                thread.start()
+            end = time.time()
+            cumulative = 0.
+            for i, result_data in enumerate(result.core_data()):
+                result_id = f'{i+1}/{no_slices}'
+                logging.info(f'Computing partial result {result_id}')
+                r = result_data.persist()
+                progress(r)
+                print()
+                logging.info(f'Storing result {result_id}')
+                var[i, ...] = r
+                start = end
                 end = time.time()
-                partial = end - start
-                total_estimate = partial/(i+1)*no_slices
-                logging.info(f'Finished {i+1}/{no_slices} in {partial:4f}s. '
-                             f'Estimated total time is {total_estimate:4f}s.')
-            thread.join()
     else:
-        logging.info('Performing aggregation in normal mode')
+        logging.info('Storing non-iteratively')
+        logging.info('Computing result')
+        r = client.compute(data)
+        progress(r)
+        print()
+        result.data = r.result()
+        logging.info(f'Storing result')
         iris.save(result, output_filename, fill_value=MISSVAL,
                   local_keys=['proposed_standard_name'])
     logging.info('Calculation complete')
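The iterative branch of save above reduces to a simple storage pattern: write the cube once with placeholder data to create the file, reopen it with netCDF4, and fill the variable one time slice at a time. A sketch of just that pattern, with a hypothetical file, dimensions, and variable name (in save itself each slice is a persisted dask array shown with a progress bar, not random numbers):

import netCDF4
import numpy as np

no_slices, nx = 4, 10
with netCDF4.Dataset('out.nc', 'w') as ds:  # hypothetical output file
    ds.createDimension('time', no_slices)
    ds.createDimension('x', nx)
    var = ds.createVariable('index', 'f4', ('time', 'x'))
    for i in range(no_slices):
        # save() assigns one computed slice here per iteration,
        # mirroring the var[i, ...] = r line above.
        var[i, ...] = np.random.random(nx)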
@@ -20,7 +20,7 @@ class Index:
         for key in [iv.var_name] + iv.aliases:
             self.mapping[key] = argname

-    def __call__(self, cubes, client=None):
+    def __call__(self, cubes, client=None, sliced_mode=False):
        cube_mapping = {argname: cube.extract(self.period.constraint)
                        for cube in cubes
                        if (argname := self.mapping.get(cube.var_name))  # noqa
@@ -43,6 +43,7 @@ class Index:
         aggregated = multicube_aggregated_by(cube_mapping, coord_name,
                                              self.aggregator,
                                              period=self.period,
-                                             client=client)
+                                             client=client,
+                                             sliced_mode=sliced_mode)
         aggregated.attributes['frequency'] = self.period.label
         return aggregated
@@ -43,6 +43,8 @@ def parse_args():
     parser.add_argument('-s', '--sliced-mode', action='store_true',
                         help='activate calculation per period to avoid memory '
                         'problems')
+    parser.add_argument('-i', '--iterative-storage', action='store_true',
+                        help='store results iteratively per period')
     parser.add_argument('-o', '--output', dest='output_template',
                         help='output filename')
     parser.add_argument('-x', '--index', action='append',
@@ -111,7 +113,7 @@ def build_output_filename(index, datafiles, output_template):

 def do_main(index_catalog, requested_indices, datafiles,
-            output_template, sliced_mode, scheduler):
+            output_template, sliced_mode, iterative_storage, scheduler):
     logging.info('Preparing indices')
     indices = index_catalog.prepare_indices(requested_indices)
     for index in indices:
@@ -123,10 +125,10 @@ def do_main(index_catalog, requested_indices, datafiles,
         logging.info('Preparing input data')
         input_data = prepare_input_data(datafiles)
         logging.info('Calculating index')
-        result = index(input_data, client=scheduler.client)
+        result = index(input_data,
+                       client=scheduler.client, sliced_mode=sliced_mode)
         logging.info('Saving result')
-        save(result, output_filename, sliced_mode=sliced_mode,
-             client=scheduler.client)
+        save(result, output_filename, iterative_storage, scheduler.client)


 def main():
@@ -146,7 +148,8 @@ def main():
     start = time.time()
     try:
         do_main(index_catalog, args.indices, args.datafiles,
-                args.output_template, args.sliced_mode, scheduler)
+                args.output_template,
+                args.sliced_mode, args.iterative_storage, scheduler)
     finally:
         end = time.time()
         logging.info(f'Calculation took {end-start:.4f} seconds.')
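Both new loops log timing in the same '(last cum eta)' format: last is the duration of the slice that just finished, cum the total time spent so far, and eta the mean slice time extrapolated to all slices. The bookkeeping in isolation, with a sleep standing in for the real per-slice work:

import time

no_slices = 5
end = time.time()
cumulative = 0.
for i in range(no_slices):
    time.sleep(0.1)  # stand-in for persisting and storing one slice
    start = end
    end = time.time()
    last = end - start
    cumulative += last
    eta = cumulative / (i + 1) * no_slices
    print(f'{i+1}/{no_slices} (last cum eta): '
          f'{last:4.1f} {cumulative:4.1f} {eta:4.1f}')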