Commit 75cf8003 authored by Klaus Zimmermann
Browse files

Add client passing to increase flexibility in processing (closes #181)

By passing the scheduler.client object through the processing chain,
we give the individual steps more options, such as using client.submit.
parent 1cacc1f5
......@@ -23,6 +23,9 @@ class DistributedLocalClusterScheduler:
class LocalThreadsScheduler:
def __init__(self):
self.client = None
def __enter__(self):
dask.config.set(scheduler='threads')
return self
......@@ -32,6 +35,9 @@ class LocalThreadsScheduler:
class SingleThreadedScheduler:
def __init__(self):
self.client = None
def __enter__(self):
dask.config.set(scheduler='single-threaded')
return self
......
......@@ -42,7 +42,7 @@ def prepare_input_data(datafiles):
return cubes
def save(result, output_filename, sliced_mode=False):
def save(result, output_filename, sliced_mode=False, client=None):
if sliced_mode:
logging.info('Performing aggregation in sliced mode')
data = result.core_data()
......
......@@ -20,7 +20,7 @@ class Index:
for key in [iv.var_name] + iv.aliases:
self.mapping[key] = argname
def __call__(self, cubes):
def __call__(self, cubes, client=None):
cube_mapping = {argname: cube.extract(self.period.constraint)
for cube in cubes
if (argname := self.mapping.get(cube.var_name)) # noqa
......@@ -42,6 +42,7 @@ class Index:
logging.info('Setting up aggregation')
aggregated = multicube_aggregated_by(cube_mapping, coord_name,
self.aggregator,
period=self.period)
period=self.period,
client=client)
aggregated.attributes['frequency'] = self.period.label
return aggregated
......@@ -111,7 +111,7 @@ def build_output_filename(index, datafiles, output_template):
def do_main(index_catalog, requested_indices, datafiles,
output_template, sliced_mode):
output_template, sliced_mode, scheduler):
logging.info('Preparing indices')
indices = index_catalog.prepare_indices(requested_indices)
for index in indices:
......@@ -123,9 +123,10 @@ def do_main(index_catalog, requested_indices, datafiles,
logging.info('Preparing input data')
input_data = prepare_input_data(datafiles)
logging.info('Calculating index')
result = index(input_data)
result = index(input_data, client=scheduler.client)
logging.info('Saving result')
save(result, output_filename, sliced_mode=sliced_mode)
save(result, output_filename, sliced_mode=sliced_mode,
client=scheduler.client)
def main():
......@@ -140,12 +141,12 @@ def main():
print(list(index_catalog.get_list()))
return
with setup_scheduler(args):
with setup_scheduler(args) as scheduler:
logging.info('Scheduler ready; starting main program.')
start = time.time()
try:
do_main(index_catalog, args.indices, args.datafiles,
args.output_template, args.sliced_mode)
args.output_template, args.sliced_mode, scheduler)
finally:
end = time.time()
logging.info(f'Calculation took {end-start:.4f} seconds.')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment