Commit 0e46b6d4 authored by Klaus Zimmermann

Blackify source code (closes #241)

parent a057e6df
 # -*- coding: utf-8 -*-
 from ._version import get_versions
-__version__ = get_versions()['version']
+__version__ = get_versions()["version"]
 del get_versions
 __all__ = [
-    '__version__',
+    "__version__",
 ]
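
For orientation (an illustrative aside, not part of the commit): throughout the `_version.py` hunks below, version information travels as a dict with a fixed set of keys. A minimal sketch of that shape, with hypothetical values:

expected_keys = {"version", "full-revisionid", "dirty", "error", "date"}
example = {
    "version": "0+unknown",   # a real build yields e.g. a tag-derived string
    "full-revisionid": None,  # 40-character git SHA when available
    "dirty": False,
    "error": None,
    "date": None,
}
assert set(example) == expected_keys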
 # This file helps to compute a version number in source trees obtained from
 # git-archive tarball (such as those provided by githubs download-from-tag
 # feature). Distribution tarballs (built by setup.py sdist) and build
@@ -58,17 +57,18 @@ HANDLERS = {}
 def register_vcs_handler(vcs, method):  # decorator
     """Decorator to mark a method as the handler for a particular VCS."""
     def decorate(f):
         """Store f in HANDLERS[vcs][method]."""
         if vcs not in HANDLERS:
             HANDLERS[vcs] = {}
         HANDLERS[vcs][method] = f
         return f
     return decorate
-def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False,
-                env=None):
+def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, env=None):
     """Call the given command(s)."""
     assert isinstance(commands, list)
     p = None
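
A condensed, runnable sketch of what the decorator in the hunk above does (same pattern, simplified body); the real file registers e.g. `pieces_from_vcs` this way, as seen further down:

HANDLERS = {}

def register_vcs_handler(vcs, method):
    def decorate(f):
        HANDLERS.setdefault(vcs, {})[method] = f
        return f
    return decorate

@register_vcs_handler("git", "pieces_from_vcs")
def handler():
    pass

assert HANDLERS["git"]["pieces_from_vcs"] is handler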
@@ -76,10 +76,13 @@ def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False,
         try:
             dispcmd = str([c] + args)
             # remember shell=False, so use git.cmd on windows, not just git
-            p = subprocess.Popen([c] + args, cwd=cwd, env=env,
-                                 stdout=subprocess.PIPE,
-                                 stderr=(subprocess.PIPE if hide_stderr
-                                         else None))
+            p = subprocess.Popen(
+                [c] + args,
+                cwd=cwd,
+                env=env,
+                stdout=subprocess.PIPE,
+                stderr=(subprocess.PIPE if hide_stderr else None),
+            )
             break
         except EnvironmentError:
             e = sys.exc_info()[1]
@@ -116,16 +119,22 @@ def versions_from_parentdir(parentdir_prefix, root, verbose):
     for i in range(3):
         dirname = os.path.basename(root)
         if dirname.startswith(parentdir_prefix):
-            return {"version": dirname[len(parentdir_prefix):],
-                    "full-revisionid": None,
-                    "dirty": False, "error": None, "date": None}
+            return {
+                "version": dirname[len(parentdir_prefix) :],
+                "full-revisionid": None,
+                "dirty": False,
+                "error": None,
+                "date": None,
+            }
         else:
             rootdirs.append(root)
             root = os.path.dirname(root)  # up a level
     if verbose:
-        print("Tried directories %s but none started with prefix %s" %
-              (str(rootdirs), parentdir_prefix))
+        print(
+            "Tried directories %s but none started with prefix %s"
+            % (str(rootdirs), parentdir_prefix)
+        )
     raise NotThisMethod("rootdir doesn't start with parentdir_prefix")
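
To illustrate the parentdir fallback in the hunk above: an sdist unpacked into a prefix-named directory carries its version in the directory name (the prefix and path here are hypothetical):

import os

parentdir_prefix = "climix-"   # hypothetical prefix, for illustration only
root = "/tmp/climix-1.2.3"     # hypothetical unpacked sdist directory
dirname = os.path.basename(root)
assert dirname.startswith(parentdir_prefix)
assert dirname[len(parentdir_prefix):] == "1.2.3"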
@@ -181,7 +190,7 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose):
     # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of
     # just "foo-1.0". If we see a "tag: " prefix, prefer those.
     TAG = "tag: "
-    tags = set([r[len(TAG):] for r in refs if r.startswith(TAG)])
+    tags = set([r[len(TAG) :] for r in refs if r.startswith(TAG)])
     if not tags:
         # Either we're using git < 1.8.3, or there really are no tags. We use
         # a heuristic: assume all version tags have a digit. The old git %d
@@ -190,7 +199,7 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose):
         # between branches and tags. By ignoring refnames without digits, we
         # filter out many common branch names like "release" and
         # "stabilization", as well as "HEAD" and "master".
-        tags = set([r for r in refs if re.search(r'\d', r)])
+        tags = set([r for r in refs if re.search(r"\d", r)])
         if verbose:
             print("discarding '%s', no digits" % ",".join(refs - tags))
     if verbose:
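
A quick, self-contained illustration of the digit heuristic described above (the ref names are hypothetical):

import re

refs = {"HEAD", "master", "release", "v1.2.3", "0.9"}
tags = set([r for r in refs if re.search(r"\d", r)])
assert tags == {"v1.2.3", "0.9"}  # names without digits are discarded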
@@ -198,19 +207,26 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose):
     for ref in sorted(tags):
         # sorting will prefer e.g. "2.0" over "2.0rc1"
         if ref.startswith(tag_prefix):
-            r = ref[len(tag_prefix):]
+            r = ref[len(tag_prefix) :]
             if verbose:
                 print("picking %s" % r)
-            return {"version": r,
-                    "full-revisionid": keywords["full"].strip(),
-                    "dirty": False, "error": None,
-                    "date": date}
+            return {
+                "version": r,
+                "full-revisionid": keywords["full"].strip(),
+                "dirty": False,
+                "error": None,
+                "date": date,
+            }
     # no suitable tags, so version is "0+unknown", but full hex is still there
     if verbose:
         print("no suitable tags, using unknown + full revision id")
-    return {"version": "0+unknown",
-            "full-revisionid": keywords["full"].strip(),
-            "dirty": False, "error": "no suitable tags", "date": None}
+    return {
+        "version": "0+unknown",
+        "full-revisionid": keywords["full"].strip(),
+        "dirty": False,
+        "error": "no suitable tags",
+        "date": None,
+    }
 @register_vcs_handler("git", "pieces_from_vcs")
@@ -225,8 +241,7 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command):
     if sys.platform == "win32":
         GITS = ["git.cmd", "git.exe"]
-    out, rc = run_command(GITS, ["rev-parse", "--git-dir"], cwd=root,
-                          hide_stderr=True)
+    out, rc = run_command(GITS, ["rev-parse", "--git-dir"], cwd=root, hide_stderr=True)
     if rc != 0:
         if verbose:
             print("Directory %s not under git control" % root)
@@ -234,10 +249,19 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command):
     # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty]
     # if there isn't one, this yields HEX[-dirty] (no NUM)
-    describe_out, rc = run_command(GITS, ["describe", "--tags", "--dirty",
-                                          "--always", "--long",
-                                          "--match", "%s*" % tag_prefix],
-                                   cwd=root)
+    describe_out, rc = run_command(
+        GITS,
+        [
+            "describe",
+            "--tags",
+            "--dirty",
+            "--always",
+            "--long",
+            "--match",
+            "%s*" % tag_prefix,
+        ],
+        cwd=root,
+    )
     # --long was added in git-1.5.5
     if describe_out is None:
         raise NotThisMethod("'git describe' failed")
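
The comments above name the two possible `git describe` shapes; the TAG-NUM-gHEX form is what the regex further down picks apart. A sketch with a hypothetical tag:

import re

describe_out = "v1.2.3-4-g1a2b3c4"  # hypothetical: tag v1.2.3, 4 commits ahead
mo = re.search(r"^(.+)-(\d+)-g([0-9a-f]+)$", describe_out)
assert mo.group(1, 2, 3) == ("v1.2.3", "4", "1a2b3c4")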
@@ -260,17 +284,16 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command):
     dirty = git_describe.endswith("-dirty")
     pieces["dirty"] = dirty
     if dirty:
-        git_describe = git_describe[:git_describe.rindex("-dirty")]
+        git_describe = git_describe[: git_describe.rindex("-dirty")]
     # now we have TAG-NUM-gHEX or HEX
     if "-" in git_describe:
         # TAG-NUM-gHEX
-        mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe)
+        mo = re.search(r"^(.+)-(\d+)-g([0-9a-f]+)$", git_describe)
         if not mo:
             # unparseable. Maybe git-describe is misbehaving?
-            pieces["error"] = ("unable to parse git-describe output: '%s'"
-                               % describe_out)
+            pieces["error"] = "unable to parse git-describe output: '%s'" % describe_out
             return pieces
         # tag
@@ -279,10 +302,12 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command):
             if verbose:
                 fmt = "tag '%s' doesn't start with prefix '%s'"
                 print(fmt % (full_tag, tag_prefix))
-            pieces["error"] = ("tag '%s' doesn't start with prefix '%s'"
-                               % (full_tag, tag_prefix))
+            pieces["error"] = "tag '%s' doesn't start with prefix '%s'" % (
+                full_tag,
+                tag_prefix,
+            )
             return pieces
-        pieces["closest-tag"] = full_tag[len(tag_prefix):]
+        pieces["closest-tag"] = full_tag[len(tag_prefix) :]
         # distance: number of commits since tag
         pieces["distance"] = int(mo.group(2))
@@ -293,13 +318,13 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command):
     else:
         # HEX: no tags
         pieces["closest-tag"] = None
-        count_out, rc = run_command(GITS, ["rev-list", "HEAD", "--count"],
-                                    cwd=root)
+        count_out, rc = run_command(GITS, ["rev-list", "HEAD", "--count"], cwd=root)
         pieces["distance"] = int(count_out)  # total number of commits
     # commit date: see ISO-8601 comment in git_versions_from_keywords()
-    date = run_command(GITS, ["show", "-s", "--format=%ci", "HEAD"],
-                       cwd=root)[0].strip()
+    date = run_command(GITS, ["show", "-s", "--format=%ci", "HEAD"], cwd=root)[
+        0
+    ].strip()
     pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1)
     return pieces
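
A worked example of the date normalization above: `%ci` output contains two spaces, and the chained `replace` calls turn the first into a "T" and drop the second, yielding ISO-8601 (timestamp hypothetical):

date = "2020-04-01 12:34:56 +0200"  # hypothetical `git show -s --format=%ci` output
iso = date.strip().replace(" ", "T", 1).replace(" ", "", 1)
assert iso == "2020-04-01T12:34:56+0200"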
@@ -330,8 +355,7 @@ def render_pep440(pieces):
                 rendered += ".dirty"
     else:
         # exception #1
-        rendered = "0+untagged.%d.g%s" % (pieces["distance"],
-                                          pieces["short"])
+        rendered = "0+untagged.%d.g%s" % (pieces["distance"], pieces["short"])
         if pieces["dirty"]:
             rendered += ".dirty"
     return rendered
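
For reference, the "exception #1" branch above renders untagged states as a PEP 440 local version; a small sketch with hypothetical pieces:

pieces = {"distance": 11, "short": "1a2b3c4", "dirty": True}  # hypothetical
rendered = "0+untagged.%d.g%s" % (pieces["distance"], pieces["short"])
if pieces["dirty"]:
    rendered += ".dirty"
assert rendered == "0+untagged.11.g1a2b3c4.dirty"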
@@ -445,11 +469,13 @@ def render_git_describe_long(pieces):
 def render(pieces, style):
     """Render the given version pieces into the requested style."""
     if pieces["error"]:
-        return {"version": "unknown",
-                "full-revisionid": pieces.get("long"),
-                "dirty": None,
-                "error": pieces["error"],
-                "date": None}
+        return {
+            "version": "unknown",
+            "full-revisionid": pieces.get("long"),
+            "dirty": None,
+            "error": pieces["error"],
+            "date": None,
+        }
     if not style or style == "default":
         style = "pep440"  # the default
@@ -469,9 +495,13 @@ def render(pieces, style):
     else:
         raise ValueError("unknown style '%s'" % style)
-    return {"version": rendered, "full-revisionid": pieces["long"],
-            "dirty": pieces["dirty"], "error": None,
-            "date": pieces.get("date")}
+    return {
+        "version": rendered,
+        "full-revisionid": pieces["long"],
+        "dirty": pieces["dirty"],
+        "error": None,
+        "date": pieces.get("date"),
+    }
 def get_versions():
@@ -485,8 +515,7 @@ def get_versions():
     verbose = cfg.verbose
     try:
-        return git_versions_from_keywords(get_keywords(), cfg.tag_prefix,
-                                          verbose)
+        return git_versions_from_keywords(get_keywords(), cfg.tag_prefix, verbose)
     except NotThisMethod:
         pass
@@ -495,13 +524,16 @@ def get_versions():
         # versionfile_source is the relative path from the top of the source
         # tree (where the .git directory might live) to this file. Invert
         # this to find the root from __file__.
-        for i in cfg.versionfile_source.split('/'):
+        for i in cfg.versionfile_source.split("/"):
             root = os.path.dirname(root)
     except NameError:
-        return {"version": "0+unknown", "full-revisionid": None,
-                "dirty": None,
-                "error": "unable to find root of source tree",
-                "date": None}
+        return {
+            "version": "0+unknown",
+            "full-revisionid": None,
+            "dirty": None,
+            "error": "unable to find root of source tree",
+            "date": None,
+        }
     try:
         pieces = git_pieces_from_vcs(cfg.tag_prefix, root, verbose)
@@ -515,6 +547,10 @@ def get_versions():
     except NotThisMethod:
         pass
-    return {"version": "0+unknown", "full-revisionid": None,
-            "dirty": None,
-            "error": "unable to compute version", "date": None}
+    return {
+        "version": "0+unknown",
+        "full-revisionid": None,
+        "dirty": None,
+        "error": "unable to compute version",
+        "date": None,
+    }
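
Seen as a whole, `get_versions()` above tries its strategies in order and treats `NotThisMethod` as "move on to the next one"; a condensed, runnable sketch of that control flow (the strategy bodies are hypothetical stand-ins):

class NotThisMethod(Exception):
    pass

def from_keywords():
    raise NotThisMethod("not a keyword-expanded archive")

def from_vcs():
    return {"version": "1.2.3", "error": None}  # hypothetical success

def get_versions_sketch():
    for strategy in (from_keywords, from_vcs):
        try:
            return strategy()
        except NotThisMethod:
            pass
    return {"version": "0+unknown", "error": "unable to compute version"}

assert get_versions_sketch()["version"] == "1.2.3"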
@@ -12,13 +12,15 @@ from .util import change_units
 class PointLocalAggregator(Aggregator):
     def __init__(self, index_function, output_metadata, **kwargs):
-        cell_method = 'max_spell'
+        cell_method = "max_spell"
         self.index_function = index_function
         self.output_metadata = output_metadata
-        super().__init__(cell_method=cell_method,
-                         call_func=self.index_function.call_func,
-                         lazy_func=self.index_function.lazy_func,
-                         **kwargs)
+        super().__init__(
+            cell_method=cell_method,
+            call_func=self.index_function.call_func,
+            lazy_func=self.index_function.lazy_func,
+            **kwargs,
+        )
     def update_metadata(self, cube, coords, **kwargs):
         super().update_metadata(cube, coords, **kwargs)
@@ -29,14 +31,14 @@ class PointLocalAggregator(Aggregator):
     def compute_pre_result(self, data, client, sliced_mode):
         if sliced_mode:
-            logging.info('Computing pre-result in sliced mode')
+            logging.info("Computing pre-result in sliced mode")
             stack = []
             end = time.time()
-            cumulative = 0.
+            cumulative = 0.0
             no_slices = len(data)
             for i, d in enumerate(data):
-                result_id = f'{i+1}/{no_slices}'
-                logging.info(f'Computing partial pre-result {result_id}')
+                result_id = f"{i+1}/{no_slices}"
+                logging.info(f"Computing partial pre-result {result_id}")
                 d = client.persist(d)
                 progress(d)
                 print()
@@ -45,16 +47,18 @@ class PointLocalAggregator(Aggregator):
                 end = time.time()
                 last = end - start
                 cumulative += last
-                eta = cumulative/(i+1)*no_slices
-                logging.info(f'Finished {result_id} in (last cum eta): '
-                             f'{last:4.0f} {cumulative:4.0f} {eta:4.0f}')
+                eta = cumulative / (i + 1) * no_slices
+                logging.info(
+                    f"Finished {result_id} in (last cum eta): "
+                    f"{last:4.0f} {cumulative:4.0f} {eta:4.0f}"
+                )
             data = da.stack(stack, axis=0)
         else:
-            logging.info('Setting up pre-result in aggregate mode')
+            logging.info("Setting up pre-result in aggregate mode")
             start = time.time()
             data = client.persist(data)
             end = time.time()
-            logging.info(f'Setup completed in {end - start:4.0f}')
+            logging.info(f"Setup completed in {end - start:4.0f}")
         return data
     def post_process(self, cube, data, coords, client, sliced_mode, **kwargs):
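
A worked example of the ETA estimate in the hunk above: after slice i + 1 = 2 of no_slices = 10, with cumulative = 120 s spent so far, the projected total is 120 / 2 * 10 = 600 s (numbers hypothetical):

cumulative, i, no_slices = 120.0, 1, 10  # hypothetical timings
eta = cumulative / (i + 1) * no_slices
assert eta == 600.0  # projected total runtime in seconds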
@@ -75,7 +79,7 @@ class PointLocalAggregator(Aggregator):
         change_units(cube, self.output_metadata.units, unit_standard_name)
         cube.standard_name = standard_name
         if proposed_standard_name is not None:
-            cube.attributes['proposed_standard_name'] = proposed_standard_name
+            cube.attributes["proposed_standard_name"] = proposed_standard_name
         for extra_coord in self.index_function.extra_coords:
             cube.add_aux_coord(extra_coord)
         return cube
......
@@ -8,6 +8,7 @@ import sys
 import dask
 from dask.distributed import Client, LocalCluster, wait, system
 from dask.distributed import progress as distributed_progress
+
 # from dask_jobqueue import SLURMCluster
 import psutil
@@ -23,15 +24,10 @@ def progress(fs):
 def cpu_count_physical():
     # Adapted from psutil
     """Return the number of physical cores in the system."""
-    IDS = ["physical_package_id",
-           "die_id",
-           "core_id",
-           "book_id",
-           "drawer_id"]
+    IDS = ["physical_package_id", "die_id", "core_id", "book_id", "drawer_id"]
     # Method #1
     core_ids = set()
-    for path in glob.glob(
-            "/sys/devices/system/cpu/cpu[0-9]*/topology"):
+    for path in glob.glob("/sys/devices/system/cpu/cpu[0-9]*/topology"):
         core_id = []
         for id in IDS:
             id_path = os.path.join(path, id)
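
The glob above walks the per-CPU sysfs topology directories; a physical core is counted once per distinct ID tuple, so hyperthread siblings collapse onto one entry. A sketch with hypothetical ID tuples:

core_ids = set()
for ids in [("0", "0"), ("0", "0"), ("0", "1"), ("0", "1")]:  # cpu0..cpu3
    core_ids.add(ids)
assert len(core_ids) == 2  # four logical CPUs, two physical cores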
@@ -58,21 +54,21 @@ def hyperthreading_info():
 class DistributedLocalClusterScheduler:
     def __init__(self, threads_per_worker=2, **kwargs):
-        (hyperthreading,
-         no_logical_cpus,
-         no_physical_cpus) = hyperthreading_info()
+        (hyperthreading, no_logical_cpus, no_physical_cpus) = hyperthreading_info()
         if hyperthreading:
             factor = no_logical_cpus // no_physical_cpus
             no_available_physical_cpus = dask.system.CPU_COUNT // factor
             n_workers = no_available_physical_cpus // threads_per_worker
-            memory_limit = (system.MEMORY_LIMIT*.9) / n_workers
+            memory_limit = (system.MEMORY_LIMIT * 0.9) / n_workers
         else:
             # let dask figure it out
             n_workers = None
             memory_limit = None
-        self.cluster = LocalCluster(n_workers=n_workers,
-                                    threads_per_worker=threads_per_worker,
-                                    memory_limit=memory_limit)
+        self.cluster = LocalCluster(
+            n_workers=n_workers,
+            threads_per_worker=threads_per_worker,
+            memory_limit=memory_limit,
+        )
         self.client = Client(self.cluster)
     def __enter__(self):
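
Worked numbers for the sizing logic above (hypothetical machine): 16 logical over 8 physical CPUs gives factor = 2; with CPU_COUNT = 16 that leaves 8 usable physical CPUs, so threads_per_worker = 2 yields 4 workers, each assigned a quarter of 90 % of system memory:

no_logical_cpus, no_physical_cpus, cpu_count = 16, 8, 16  # hypothetical
threads_per_worker = 2
factor = no_logical_cpus // no_physical_cpus
n_workers = (cpu_count // factor) // threads_per_worker
assert (factor, n_workers) == (2, 4)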
@@ -103,7 +99,7 @@ class LocalThreadsScheduler:
         self.client = None
     def __enter__(self):
-        dask.config.set(scheduler='threads')
+        dask.config.set(scheduler="threads")
         return self
     def __exit__(self, type, value, traceback):
@@ -113,13 +109,16 @@ class LocalThreadsScheduler:
 class MPIScheduler:
     def __init__(self, **kwargs):
         from dask_mpi import initialize
         n_workers = 4  # tasks-per-node from scheduler
         n_threads = 4  # cpus-per-task from scheduler
-        memory_limit = (system.MEMORY_LIMIT*.9) / n_workers
-        initialize('ib0',
-                   nthreads=n_threads,
-                   local_directory='/scratch/local',
-                   memory_limit=memory_limit,)
+        memory_limit = (system.MEMORY_LIMIT * 0.9) / n_workers
+        initialize(
+            "ib0",
+            nthreads=n_threads,
+            local_directory="/scratch/local",
+            memory_limit=memory_limit,
+        )
         self.client = Client()
     def __enter__(self):
@@ -135,26 +134,27 @@ class SingleThreadedScheduler:
         self.client = None
     def __enter__(self):
-        dask.config.set(scheduler='single-threaded')
+        dask.config.set(scheduler="single-threaded")
         return self
     def __exit__(self, type, value, traceback):
         pass
-SCHEDULERS = OrderedDict([
-    ('distributed-local-cluster', DistributedLocalClusterScheduler),
-    ('external', ExternalScheduler),
-    ('threaded', LocalThreadsScheduler),
-    ('mpi', MPIScheduler),
-    ('single-threaded', SingleThreadedScheduler),
-])
+SCHEDULERS = OrderedDict(
+    [
+        ("distributed-local-cluster", DistributedLocalClusterScheduler),
+        ("external", ExternalScheduler),
+        ("threaded", LocalThreadsScheduler),
+        ("mpi", MPIScheduler),
+        ("single-threaded", SingleThreadedScheduler),
+    ]
+)
 def setup_scheduler(args):
-    scheduler_spec = args.dask_scheduler.split('@')
+    scheduler_spec = args.dask_scheduler.split("@")
     scheduler_name = scheduler_spec[0]
-    scheduler_kwargs = {k: v for k, v in (e.split('=')
-                                          for e in scheduler_spec[1:])}
+    scheduler_kwargs = {k: v for k, v in (e.split("=") for e in scheduler_spec[1:])}
     scheduler = SCHEDULERS[scheduler_name]
     return scheduler(**scheduler_kwargs)
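
The spec string parsed above uses "@" to split the scheduler name from key=value arguments; note the parsed values stay strings. A sketch with a hypothetical spec:

spec = "distributed-local-cluster@threads_per_worker=4".split("@")
name, kwargs = spec[0], {k: v for k, v in (e.split("=") for e in spec[1:])}
assert name == "distributed-local-cluster"
assert kwargs == {"threads_per_worker": "4"}  # values arrive as strings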
@@ -32,8 +32,10 @@ def dask_take_along_axis(x, index, axis):
     assert 0 <= axis < x.ndim
-    assert (x.shape[:axis]+x.shape[axis+1:]
-            == index.shape[:axis]+index.shape[axis+1:])
+    assert (
+        x.shape[:axis] + x.shape[axis + 1 :]
+        == index.shape[:axis] + index.shape[axis + 1 :]
+    )
     if np.isnan(x.chunks[axis]).any():
         raise NotImplementedError(
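
The assertion above encodes the take_along_axis contract: x and index must agree on every axis except the selection axis. Illustrative shapes:

x_shape, index_shape, axis = (4, 5, 6), (4, 7, 6), 1  # hypothetical shapes
assert (x_shape[:axis] + x_shape[axis + 1:]
        == index_shape[:axis] + index_shape[axis + 1:])  # (4, 6) == (4, 6)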
@@ -53,9 +55,9 @@ def dask_take_along_axis(x, index, axis):
     # Define axis labels for blockwise
     x_axes = tuple(range(x.ndim))
     idx_label = (x.ndim,)  # arbitrary unused
-    index_axes = x_axes[:axis] + idx_label + x_axes[axis+1:]
+    index_axes = x_axes[:axis] + idx_label + x_axes[axis + 1 :]
     offset_axes = (axis,)
-    p_axes = x_axes[:axis + 1] + idx_label + x_axes[axis + 1:]
+    p_axes = x_axes[: axis + 1] + idx_label + x_axes[axis + 1 :]
     # Calculate the cartesian product of every chunk of x vs
     # every chunk of index
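
To make the label bookkeeping above concrete, for a 3-d x with axis = 1 (blockwise labels are plain integers here):

ndim, axis = 3, 1  # hypothetical
x_axes = tuple(range(ndim))  # (0, 1, 2)
idx_label = (ndim,)          # (3,), an arbitrary unused label
index_axes = x_axes[:axis] + idx_label + x_axes[axis + 1:]
p_axes = x_axes[:axis + 1] + idx_label + x_axes[axis + 1:]
assert (index_axes, p_axes) == ((0, 3, 2), (0, 1, 3, 2))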
@@ -70,9 +72,11 @@ def dask_take_along_axis(x, index, axis):
         offset_axes,
         x_size=x.shape[axis],
         axis=axis,
-        meta=sparse.COO(np.empty((0, 0), dtype=int),
-                        np.empty((), dtype=x.dtype),
-                        shape=(0,)*len(p_axes)),
+        meta=sparse.COO(
+            np.empty((0, 0), dtype=int),
+            np.empty((), dtype=x.dtype),
+            shape=(0,) * len(p_axes),
+        ),
         dtype=x.dtype,
     )
......
@@ -22,10 +22,10 @@ def ignore_cb(cube, field, filename):
     """
     Callback to ignore certain common global attributes in data files.
     """
-    cube.attributes.pop('creation_date', None)
-    cube.attributes.pop('tracking_id', None)
-    cube.attributes.pop('history', None)
-    cube.attributes.pop('history_of_appended_files', None)
+    cube.attributes.pop("creation_date", None)
+    cube.attributes.pop("tracking_id", None)
+    cube.attributes.pop("history", None)
+    cube.attributes.pop("history_of_appended_files", None)
 def prepare_input_data(datafiles):
@@ -70,14 +70,16 @@ def prepare_input_data(datafiles):
     for var_name, cubes in cubes_per_var_name.items():
         if len(cubes) > 1:
             inconsistent_var_names.append(var_name)
-            logger.error('Found more than one cube for var_name "{}".\n'
-                         '{}'.format(
-                             var_name,
-                             '\n'.join(map(lambda c: str(c), cubes))))
-            raise ValueError('Found too many cubes for var_names {}. '
-                             'See log for details.'.format(inconsistent_var_names))
+            logger.error(
+                'Found more than one cube for var_name "{}".\n'
+                "{}".format(var_name, "\n".join(map(lambda c: str(c), cubes)))
+            )
+            raise ValueError(
+                "Found too many cubes for var_names {}. "
+                "See log for details.".format(inconsistent_var_names)
+            )
     for c in cubes:
-        time = c.coord('time')
+        time = c.coord("time")
         if not time.has_bounds():
             time.guess_bounds()
     return cubes
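
`guess_bounds()` above fills in missing cell bounds halfway between neighbouring time points; a minimal, iris-free sketch of the idea:

points = [0.5, 1.5, 2.5]  # hypothetical daily time points
mids = [(a + b) / 2 for a, b in zip(points, points[1:])]  # [1.0, 2.0]
bounds = list(zip([points[0] - 0.5] + mids, mids + [points[-1] + 0.5]))
assert bounds == [(0.0, 1.0), (1.0, 2.0), (2.0, 3.0)]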
@@ -118,44 +120,54 @@ def save(result, output_filename, iterative_storage=False, client=None):
     """
     data = result.core_data().rechunk()
     if iterative_storage:
-        logger.info('Storing iteratively')
-        logger.info('Creating empty data')
+        logger.info("Storing iteratively")
+        logger.info("Creating empty data")
         result.data = np.zeros(data.shape, data.dtype)
-        logger.info('Saving empty cube')
-        iris.save(result, output_filename, fill_value=MISSVAL,
-                  local_keys=['proposed_standard_name'])
-        logger.info('Reopening output file and beginning storage')
-        with netCDF4.Dataset(output_filename, 'a') as ds:
+        logger.info("Saving empty cube")
+        iris.save(
+            result,
+            output_filename,
+            fill_value=MISSVAL,
+            local_keys=["proposed_standard_name"],
+        )
+        logger.info("Reopening output file and beginning storage")
+        with netCDF4.Dataset(output_filename, "a") as ds:
             var = ds[result.var_name]
-            time_dim = result.coord_dims('time')[0]
+            time_dim = result.coord_dims("time")[0]
             no_slices = result.shape[time_dim]