Commit 75cf8003 authored by Klaus Zimmermann's avatar Klaus Zimmermann
Browse files

Add client passing to increase flexibility in processing (closes #181)

By passing the scheduler.client object throught the processing chain,
we give options to the individual steps, like using client.submit etc.
parent 1cacc1f5
...@@ -23,6 +23,9 @@ class DistributedLocalClusterScheduler: ...@@ -23,6 +23,9 @@ class DistributedLocalClusterScheduler:
class LocalThreadsScheduler: class LocalThreadsScheduler:
def __init__(self):
self.client = None
def __enter__(self): def __enter__(self):
dask.config.set(scheduler='threads') dask.config.set(scheduler='threads')
return self return self
...@@ -32,6 +35,9 @@ class LocalThreadsScheduler: ...@@ -32,6 +35,9 @@ class LocalThreadsScheduler:
class SingleThreadedScheduler: class SingleThreadedScheduler:
def __init__(self):
self.client = None
def __enter__(self): def __enter__(self):
dask.config.set(scheduler='single-threaded') dask.config.set(scheduler='single-threaded')
return self return self
......
...@@ -42,7 +42,7 @@ def prepare_input_data(datafiles): ...@@ -42,7 +42,7 @@ def prepare_input_data(datafiles):
return cubes return cubes
def save(result, output_filename, sliced_mode=False): def save(result, output_filename, sliced_mode=False, client=None):
if sliced_mode: if sliced_mode:
logging.info('Performing aggregation in sliced mode') logging.info('Performing aggregation in sliced mode')
data = result.core_data() data = result.core_data()
......
...@@ -20,7 +20,7 @@ class Index: ...@@ -20,7 +20,7 @@ class Index:
for key in [iv.var_name] + iv.aliases: for key in [iv.var_name] + iv.aliases:
self.mapping[key] = argname self.mapping[key] = argname
def __call__(self, cubes): def __call__(self, cubes, client=None):
cube_mapping = {argname: cube.extract(self.period.constraint) cube_mapping = {argname: cube.extract(self.period.constraint)
for cube in cubes for cube in cubes
if (argname := self.mapping.get(cube.var_name)) # noqa if (argname := self.mapping.get(cube.var_name)) # noqa
...@@ -42,6 +42,7 @@ class Index: ...@@ -42,6 +42,7 @@ class Index:
logging.info('Setting up aggregation') logging.info('Setting up aggregation')
aggregated = multicube_aggregated_by(cube_mapping, coord_name, aggregated = multicube_aggregated_by(cube_mapping, coord_name,
self.aggregator, self.aggregator,
period=self.period) period=self.period,
client=client)
aggregated.attributes['frequency'] = self.period.label aggregated.attributes['frequency'] = self.period.label
return aggregated return aggregated
...@@ -111,7 +111,7 @@ def build_output_filename(index, datafiles, output_template): ...@@ -111,7 +111,7 @@ def build_output_filename(index, datafiles, output_template):
def do_main(index_catalog, requested_indices, datafiles, def do_main(index_catalog, requested_indices, datafiles,
output_template, sliced_mode): output_template, sliced_mode, scheduler):
logging.info('Preparing indices') logging.info('Preparing indices')
indices = index_catalog.prepare_indices(requested_indices) indices = index_catalog.prepare_indices(requested_indices)
for index in indices: for index in indices:
...@@ -123,9 +123,10 @@ def do_main(index_catalog, requested_indices, datafiles, ...@@ -123,9 +123,10 @@ def do_main(index_catalog, requested_indices, datafiles,
logging.info('Preparing input data') logging.info('Preparing input data')
input_data = prepare_input_data(datafiles) input_data = prepare_input_data(datafiles)
logging.info('Calculating index') logging.info('Calculating index')
result = index(input_data) result = index(input_data, client=scheduler.client)
logging.info('Saving result') logging.info('Saving result')
save(result, output_filename, sliced_mode=sliced_mode) save(result, output_filename, sliced_mode=sliced_mode,
client=scheduler.client)
def main(): def main():
...@@ -140,12 +141,12 @@ def main(): ...@@ -140,12 +141,12 @@ def main():
print(list(index_catalog.get_list())) print(list(index_catalog.get_list()))
return return
with setup_scheduler(args): with setup_scheduler(args) as scheduler:
logging.info('Scheduler ready; starting main program.') logging.info('Scheduler ready; starting main program.')
start = time.time() start = time.time()
try: try:
do_main(index_catalog, args.indices, args.datafiles, do_main(index_catalog, args.indices, args.datafiles,
args.output_template, args.sliced_mode) args.output_template, args.sliced_mode, scheduler)
finally: finally:
end = time.time() end = time.time()
logging.info(f'Calculation took {end-start:.4f} seconds.') logging.info(f'Calculation took {end-start:.4f} seconds.')
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment