Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
climix
climix
Commits
0e46b6d4
Commit
0e46b6d4
authored
May 06, 2021
by
Klaus Zimmermann
Browse files
Blackify source code (closes
#241
)
parent
a057e6df
Changes
23
Hide whitespace changes
Inline
Side-by-side
climix/__init__.py
View file @
0e46b6d4
# -*- coding: utf-8 -*-
from
._version
import
get_versions
__version__
=
get_versions
()[
'version'
]
__version__
=
get_versions
()[
"version"
]
del
get_versions
__all__
=
[
'
__version__
'
,
"
__version__
"
,
]
climix/_version.py
View file @
0e46b6d4
# This file helps to compute a version number in source trees obtained from
# git-archive tarball (such as those provided by githubs download-from-tag
# feature). Distribution tarballs (built by setup.py sdist) and build
...
...
@@ -58,17 +57,18 @@ HANDLERS = {}
def
register_vcs_handler
(
vcs
,
method
):
# decorator
"""Decorator to mark a method as the handler for a particular VCS."""
def
decorate
(
f
):
"""Store f in HANDLERS[vcs][method]."""
if
vcs
not
in
HANDLERS
:
HANDLERS
[
vcs
]
=
{}
HANDLERS
[
vcs
][
method
]
=
f
return
f
return
decorate
def
run_command
(
commands
,
args
,
cwd
=
None
,
verbose
=
False
,
hide_stderr
=
False
,
env
=
None
):
def
run_command
(
commands
,
args
,
cwd
=
None
,
verbose
=
False
,
hide_stderr
=
False
,
env
=
None
):
"""Call the given command(s)."""
assert
isinstance
(
commands
,
list
)
p
=
None
...
...
@@ -76,10 +76,13 @@ def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False,
try
:
dispcmd
=
str
([
c
]
+
args
)
# remember shell=False, so use git.cmd on windows, not just git
p
=
subprocess
.
Popen
([
c
]
+
args
,
cwd
=
cwd
,
env
=
env
,
stdout
=
subprocess
.
PIPE
,
stderr
=
(
subprocess
.
PIPE
if
hide_stderr
else
None
))
p
=
subprocess
.
Popen
(
[
c
]
+
args
,
cwd
=
cwd
,
env
=
env
,
stdout
=
subprocess
.
PIPE
,
stderr
=
(
subprocess
.
PIPE
if
hide_stderr
else
None
),
)
break
except
EnvironmentError
:
e
=
sys
.
exc_info
()[
1
]
...
...
@@ -116,16 +119,22 @@ def versions_from_parentdir(parentdir_prefix, root, verbose):
for
i
in
range
(
3
):
dirname
=
os
.
path
.
basename
(
root
)
if
dirname
.
startswith
(
parentdir_prefix
):
return
{
"version"
:
dirname
[
len
(
parentdir_prefix
):],
"full-revisionid"
:
None
,
"dirty"
:
False
,
"error"
:
None
,
"date"
:
None
}
return
{
"version"
:
dirname
[
len
(
parentdir_prefix
)
:],
"full-revisionid"
:
None
,
"dirty"
:
False
,
"error"
:
None
,
"date"
:
None
,
}
else
:
rootdirs
.
append
(
root
)
root
=
os
.
path
.
dirname
(
root
)
# up a level
if
verbose
:
print
(
"Tried directories %s but none started with prefix %s"
%
(
str
(
rootdirs
),
parentdir_prefix
))
print
(
"Tried directories %s but none started with prefix %s"
%
(
str
(
rootdirs
),
parentdir_prefix
)
)
raise
NotThisMethod
(
"rootdir doesn't start with parentdir_prefix"
)
...
...
@@ -181,7 +190,7 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose):
# starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of
# just "foo-1.0". If we see a "tag: " prefix, prefer those.
TAG
=
"tag: "
tags
=
set
([
r
[
len
(
TAG
):]
for
r
in
refs
if
r
.
startswith
(
TAG
)])
tags
=
set
([
r
[
len
(
TAG
)
:]
for
r
in
refs
if
r
.
startswith
(
TAG
)])
if
not
tags
:
# Either we're using git < 1.8.3, or there really are no tags. We use
# a heuristic: assume all version tags have a digit. The old git %d
...
...
@@ -190,7 +199,7 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose):
# between branches and tags. By ignoring refnames without digits, we
# filter out many common branch names like "release" and
# "stabilization", as well as "HEAD" and "master".
tags
=
set
([
r
for
r
in
refs
if
re
.
search
(
r
'
\d
'
,
r
)])
tags
=
set
([
r
for
r
in
refs
if
re
.
search
(
r
"
\d
"
,
r
)])
if
verbose
:
print
(
"discarding '%s', no digits"
%
","
.
join
(
refs
-
tags
))
if
verbose
:
...
...
@@ -198,19 +207,26 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose):
for
ref
in
sorted
(
tags
):
# sorting will prefer e.g. "2.0" over "2.0rc1"
if
ref
.
startswith
(
tag_prefix
):
r
=
ref
[
len
(
tag_prefix
):]
r
=
ref
[
len
(
tag_prefix
)
:]
if
verbose
:
print
(
"picking %s"
%
r
)
return
{
"version"
:
r
,
"full-revisionid"
:
keywords
[
"full"
].
strip
(),
"dirty"
:
False
,
"error"
:
None
,
"date"
:
date
}
return
{
"version"
:
r
,
"full-revisionid"
:
keywords
[
"full"
].
strip
(),
"dirty"
:
False
,
"error"
:
None
,
"date"
:
date
,
}
# no suitable tags, so version is "0+unknown", but full hex is still there
if
verbose
:
print
(
"no suitable tags, using unknown + full revision id"
)
return
{
"version"
:
"0+unknown"
,
"full-revisionid"
:
keywords
[
"full"
].
strip
(),
"dirty"
:
False
,
"error"
:
"no suitable tags"
,
"date"
:
None
}
return
{
"version"
:
"0+unknown"
,
"full-revisionid"
:
keywords
[
"full"
].
strip
(),
"dirty"
:
False
,
"error"
:
"no suitable tags"
,
"date"
:
None
,
}
@
register_vcs_handler
(
"git"
,
"pieces_from_vcs"
)
...
...
@@ -225,8 +241,7 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command):
if
sys
.
platform
==
"win32"
:
GITS
=
[
"git.cmd"
,
"git.exe"
]
out
,
rc
=
run_command
(
GITS
,
[
"rev-parse"
,
"--git-dir"
],
cwd
=
root
,
hide_stderr
=
True
)
out
,
rc
=
run_command
(
GITS
,
[
"rev-parse"
,
"--git-dir"
],
cwd
=
root
,
hide_stderr
=
True
)
if
rc
!=
0
:
if
verbose
:
print
(
"Directory %s not under git control"
%
root
)
...
...
@@ -234,10 +249,19 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command):
# if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty]
# if there isn't one, this yields HEX[-dirty] (no NUM)
describe_out
,
rc
=
run_command
(
GITS
,
[
"describe"
,
"--tags"
,
"--dirty"
,
"--always"
,
"--long"
,
"--match"
,
"%s*"
%
tag_prefix
],
cwd
=
root
)
describe_out
,
rc
=
run_command
(
GITS
,
[
"describe"
,
"--tags"
,
"--dirty"
,
"--always"
,
"--long"
,
"--match"
,
"%s*"
%
tag_prefix
,
],
cwd
=
root
,
)
# --long was added in git-1.5.5
if
describe_out
is
None
:
raise
NotThisMethod
(
"'git describe' failed"
)
...
...
@@ -260,17 +284,16 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command):
dirty
=
git_describe
.
endswith
(
"-dirty"
)
pieces
[
"dirty"
]
=
dirty
if
dirty
:
git_describe
=
git_describe
[:
git_describe
.
rindex
(
"-dirty"
)]
git_describe
=
git_describe
[:
git_describe
.
rindex
(
"-dirty"
)]
# now we have TAG-NUM-gHEX or HEX
if
"-"
in
git_describe
:
# TAG-NUM-gHEX
mo
=
re
.
search
(
r
'
^(.+)-(\d+)-g([0-9a-f]+)$
'
,
git_describe
)
mo
=
re
.
search
(
r
"
^(.+)-(\d+)-g([0-9a-f]+)$
"
,
git_describe
)
if
not
mo
:
# unparseable. Maybe git-describe is misbehaving?
pieces
[
"error"
]
=
(
"unable to parse git-describe output: '%s'"
%
describe_out
)
pieces
[
"error"
]
=
"unable to parse git-describe output: '%s'"
%
describe_out
return
pieces
# tag
...
...
@@ -279,10 +302,12 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command):
if
verbose
:
fmt
=
"tag '%s' doesn't start with prefix '%s'"
print
(
fmt
%
(
full_tag
,
tag_prefix
))
pieces
[
"error"
]
=
(
"tag '%s' doesn't start with prefix '%s'"
%
(
full_tag
,
tag_prefix
))
pieces
[
"error"
]
=
"tag '%s' doesn't start with prefix '%s'"
%
(
full_tag
,
tag_prefix
,
)
return
pieces
pieces
[
"closest-tag"
]
=
full_tag
[
len
(
tag_prefix
):]
pieces
[
"closest-tag"
]
=
full_tag
[
len
(
tag_prefix
)
:]
# distance: number of commits since tag
pieces
[
"distance"
]
=
int
(
mo
.
group
(
2
))
...
...
@@ -293,13 +318,13 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command):
else
:
# HEX: no tags
pieces
[
"closest-tag"
]
=
None
count_out
,
rc
=
run_command
(
GITS
,
[
"rev-list"
,
"HEAD"
,
"--count"
],
cwd
=
root
)
count_out
,
rc
=
run_command
(
GITS
,
[
"rev-list"
,
"HEAD"
,
"--count"
],
cwd
=
root
)
pieces
[
"distance"
]
=
int
(
count_out
)
# total number of commits
# commit date: see ISO-8601 comment in git_versions_from_keywords()
date
=
run_command
(
GITS
,
[
"show"
,
"-s"
,
"--format=%ci"
,
"HEAD"
],
cwd
=
root
)[
0
].
strip
()
date
=
run_command
(
GITS
,
[
"show"
,
"-s"
,
"--format=%ci"
,
"HEAD"
],
cwd
=
root
)[
0
].
strip
()
pieces
[
"date"
]
=
date
.
strip
().
replace
(
" "
,
"T"
,
1
).
replace
(
" "
,
""
,
1
)
return
pieces
...
...
@@ -330,8 +355,7 @@ def render_pep440(pieces):
rendered
+=
".dirty"
else
:
# exception #1
rendered
=
"0+untagged.%d.g%s"
%
(
pieces
[
"distance"
],
pieces
[
"short"
])
rendered
=
"0+untagged.%d.g%s"
%
(
pieces
[
"distance"
],
pieces
[
"short"
])
if
pieces
[
"dirty"
]:
rendered
+=
".dirty"
return
rendered
...
...
@@ -445,11 +469,13 @@ def render_git_describe_long(pieces):
def
render
(
pieces
,
style
):
"""Render the given version pieces into the requested style."""
if
pieces
[
"error"
]:
return
{
"version"
:
"unknown"
,
"full-revisionid"
:
pieces
.
get
(
"long"
),
"dirty"
:
None
,
"error"
:
pieces
[
"error"
],
"date"
:
None
}
return
{
"version"
:
"unknown"
,
"full-revisionid"
:
pieces
.
get
(
"long"
),
"dirty"
:
None
,
"error"
:
pieces
[
"error"
],
"date"
:
None
,
}
if
not
style
or
style
==
"default"
:
style
=
"pep440"
# the default
...
...
@@ -469,9 +495,13 @@ def render(pieces, style):
else
:
raise
ValueError
(
"unknown style '%s'"
%
style
)
return
{
"version"
:
rendered
,
"full-revisionid"
:
pieces
[
"long"
],
"dirty"
:
pieces
[
"dirty"
],
"error"
:
None
,
"date"
:
pieces
.
get
(
"date"
)}
return
{
"version"
:
rendered
,
"full-revisionid"
:
pieces
[
"long"
],
"dirty"
:
pieces
[
"dirty"
],
"error"
:
None
,
"date"
:
pieces
.
get
(
"date"
),
}
def
get_versions
():
...
...
@@ -485,8 +515,7 @@ def get_versions():
verbose
=
cfg
.
verbose
try
:
return
git_versions_from_keywords
(
get_keywords
(),
cfg
.
tag_prefix
,
verbose
)
return
git_versions_from_keywords
(
get_keywords
(),
cfg
.
tag_prefix
,
verbose
)
except
NotThisMethod
:
pass
...
...
@@ -495,13 +524,16 @@ def get_versions():
# versionfile_source is the relative path from the top of the source
# tree (where the .git directory might live) to this file. Invert
# this to find the root from __file__.
for
i
in
cfg
.
versionfile_source
.
split
(
'/'
):
for
i
in
cfg
.
versionfile_source
.
split
(
"/"
):
root
=
os
.
path
.
dirname
(
root
)
except
NameError
:
return
{
"version"
:
"0+unknown"
,
"full-revisionid"
:
None
,
"dirty"
:
None
,
"error"
:
"unable to find root of source tree"
,
"date"
:
None
}
return
{
"version"
:
"0+unknown"
,
"full-revisionid"
:
None
,
"dirty"
:
None
,
"error"
:
"unable to find root of source tree"
,
"date"
:
None
,
}
try
:
pieces
=
git_pieces_from_vcs
(
cfg
.
tag_prefix
,
root
,
verbose
)
...
...
@@ -515,6 +547,10 @@ def get_versions():
except
NotThisMethod
:
pass
return
{
"version"
:
"0+unknown"
,
"full-revisionid"
:
None
,
"dirty"
:
None
,
"error"
:
"unable to compute version"
,
"date"
:
None
}
return
{
"version"
:
"0+unknown"
,
"full-revisionid"
:
None
,
"dirty"
:
None
,
"error"
:
"unable to compute version"
,
"date"
:
None
,
}
climix/aggregators.py
View file @
0e46b6d4
...
...
@@ -12,13 +12,15 @@ from .util import change_units
class
PointLocalAggregator
(
Aggregator
):
def
__init__
(
self
,
index_function
,
output_metadata
,
**
kwargs
):
cell_method
=
'
max_spell
'
cell_method
=
"
max_spell
"
self
.
index_function
=
index_function
self
.
output_metadata
=
output_metadata
super
().
__init__
(
cell_method
=
cell_method
,
call_func
=
self
.
index_function
.
call_func
,
lazy_func
=
self
.
index_function
.
lazy_func
,
**
kwargs
)
super
().
__init__
(
cell_method
=
cell_method
,
call_func
=
self
.
index_function
.
call_func
,
lazy_func
=
self
.
index_function
.
lazy_func
,
**
kwargs
,
)
def
update_metadata
(
self
,
cube
,
coords
,
**
kwargs
):
super
().
update_metadata
(
cube
,
coords
,
**
kwargs
)
...
...
@@ -29,14 +31,14 @@ class PointLocalAggregator(Aggregator):
def
compute_pre_result
(
self
,
data
,
client
,
sliced_mode
):
if
sliced_mode
:
logging
.
info
(
'
Computing pre-result in sliced mode
'
)
logging
.
info
(
"
Computing pre-result in sliced mode
"
)
stack
=
[]
end
=
time
.
time
()
cumulative
=
0.
cumulative
=
0.
0
no_slices
=
len
(
data
)
for
i
,
d
in
enumerate
(
data
):
result_id
=
f
'
{
i
+
1
}
/
{
no_slices
}
'
logging
.
info
(
f
'
Computing partial pre-result
{
result_id
}
'
)
result_id
=
f
"
{
i
+
1
}
/
{
no_slices
}
"
logging
.
info
(
f
"
Computing partial pre-result
{
result_id
}
"
)
d
=
client
.
persist
(
d
)
progress
(
d
)
print
()
...
...
@@ -45,16 +47,18 @@ class PointLocalAggregator(Aggregator):
end
=
time
.
time
()
last
=
end
-
start
cumulative
+=
last
eta
=
cumulative
/
(
i
+
1
)
*
no_slices
logging
.
info
(
f
'Finished
{
result_id
}
in (last cum eta): '
f
'
{
last
:
4.0
f
}
{
cumulative
:
4.0
f
}
{
eta
:
4.0
f
}
'
)
eta
=
cumulative
/
(
i
+
1
)
*
no_slices
logging
.
info
(
f
"Finished
{
result_id
}
in (last cum eta): "
f
"
{
last
:
4.0
f
}
{
cumulative
:
4.0
f
}
{
eta
:
4.0
f
}
"
)
data
=
da
.
stack
(
stack
,
axis
=
0
)
else
:
logging
.
info
(
'
Setting up pre-result in aggregate mode
'
)
logging
.
info
(
"
Setting up pre-result in aggregate mode
"
)
start
=
time
.
time
()
data
=
client
.
persist
(
data
)
end
=
time
.
time
()
logging
.
info
(
f
'
Setup completed in
{
end
-
start
:
4.0
f
}
'
)
logging
.
info
(
f
"
Setup completed in
{
end
-
start
:
4.0
f
}
"
)
return
data
def
post_process
(
self
,
cube
,
data
,
coords
,
client
,
sliced_mode
,
**
kwargs
):
...
...
@@ -75,7 +79,7 @@ class PointLocalAggregator(Aggregator):
change_units
(
cube
,
self
.
output_metadata
.
units
,
unit_standard_name
)
cube
.
standard_name
=
standard_name
if
proposed_standard_name
is
not
None
:
cube
.
attributes
[
'
proposed_standard_name
'
]
=
proposed_standard_name
cube
.
attributes
[
"
proposed_standard_name
"
]
=
proposed_standard_name
for
extra_coord
in
self
.
index_function
.
extra_coords
:
cube
.
add_aux_coord
(
extra_coord
)
return
cube
...
...
climix/dask_setup.py
View file @
0e46b6d4
...
...
@@ -8,6 +8,7 @@ import sys
import
dask
from
dask.distributed
import
Client
,
LocalCluster
,
wait
,
system
from
dask.distributed
import
progress
as
distributed_progress
# from dask_jobqueue import SLURMCluster
import
psutil
...
...
@@ -23,15 +24,10 @@ def progress(fs):
def
cpu_count_physical
():
# Adapted from psutil
"""Return the number of physical cores in the system."""
IDS
=
[
"physical_package_id"
,
"die_id"
,
"core_id"
,
"book_id"
,
"drawer_id"
]
IDS
=
[
"physical_package_id"
,
"die_id"
,
"core_id"
,
"book_id"
,
"drawer_id"
]
# Method #1
core_ids
=
set
()
for
path
in
glob
.
glob
(
"/sys/devices/system/cpu/cpu[0-9]*/topology"
):
for
path
in
glob
.
glob
(
"/sys/devices/system/cpu/cpu[0-9]*/topology"
):
core_id
=
[]
for
id
in
IDS
:
id_path
=
os
.
path
.
join
(
path
,
id
)
...
...
@@ -58,21 +54,21 @@ def hyperthreading_info():
class
DistributedLocalClusterScheduler
:
def
__init__
(
self
,
threads_per_worker
=
2
,
**
kwargs
):
(
hyperthreading
,
no_logical_cpus
,
no_physical_cpus
)
=
hyperthreading_info
()
(
hyperthreading
,
no_logical_cpus
,
no_physical_cpus
)
=
hyperthreading_info
()
if
hyperthreading
:
factor
=
no_logical_cpus
//
no_physical_cpus
no_available_physical_cpus
=
dask
.
system
.
CPU_COUNT
//
factor
n_workers
=
no_available_physical_cpus
//
threads_per_worker
memory_limit
=
(
system
.
MEMORY_LIMIT
*
.
9
)
/
n_workers
memory_limit
=
(
system
.
MEMORY_LIMIT
*
0
.9
)
/
n_workers
else
:
# let dask figure it out
n_workers
=
None
memory_limit
=
None
self
.
cluster
=
LocalCluster
(
n_workers
=
n_workers
,
threads_per_worker
=
threads_per_worker
,
memory_limit
=
memory_limit
)
self
.
cluster
=
LocalCluster
(
n_workers
=
n_workers
,
threads_per_worker
=
threads_per_worker
,
memory_limit
=
memory_limit
,
)
self
.
client
=
Client
(
self
.
cluster
)
def
__enter__
(
self
):
...
...
@@ -103,7 +99,7 @@ class LocalThreadsScheduler:
self
.
client
=
None
def
__enter__
(
self
):
dask
.
config
.
set
(
scheduler
=
'
threads
'
)
dask
.
config
.
set
(
scheduler
=
"
threads
"
)
return
self
def
__exit__
(
self
,
type
,
value
,
traceback
):
...
...
@@ -113,13 +109,16 @@ class LocalThreadsScheduler:
class
MPIScheduler
:
def
__init__
(
self
,
**
kwargs
):
from
dask_mpi
import
initialize
n_workers
=
4
# tasks-per-node from scheduler
n_threads
=
4
# cpus-per-task from scheduler
memory_limit
=
(
system
.
MEMORY_LIMIT
*
.
9
)
/
n_workers
initialize
(
'ib0'
,
nthreads
=
n_threads
,
local_directory
=
'/scratch/local'
,
memory_limit
=
memory_limit
,)
memory_limit
=
(
system
.
MEMORY_LIMIT
*
0.9
)
/
n_workers
initialize
(
"ib0"
,
nthreads
=
n_threads
,
local_directory
=
"/scratch/local"
,
memory_limit
=
memory_limit
,
)
self
.
client
=
Client
()
def
__enter__
(
self
):
...
...
@@ -135,26 +134,27 @@ class SingleThreadedScheduler:
self
.
client
=
None
def
__enter__
(
self
):
dask
.
config
.
set
(
scheduler
=
'
single-threaded
'
)
dask
.
config
.
set
(
scheduler
=
"
single-threaded
"
)
return
self
def
__exit__
(
self
,
type
,
value
,
traceback
):
pass
SCHEDULERS
=
OrderedDict
([
(
'distributed-local-cluster'
,
DistributedLocalClusterScheduler
),
(
'external'
,
ExternalScheduler
),
(
'threaded'
,
LocalThreadsScheduler
),
(
'mpi'
,
MPIScheduler
),
(
'single-threaded'
,
SingleThreadedScheduler
),
])
SCHEDULERS
=
OrderedDict
(
[
(
"distributed-local-cluster"
,
DistributedLocalClusterScheduler
),
(
"external"
,
ExternalScheduler
),
(
"threaded"
,
LocalThreadsScheduler
),
(
"mpi"
,
MPIScheduler
),
(
"single-threaded"
,
SingleThreadedScheduler
),
]
)
def
setup_scheduler
(
args
):
scheduler_spec
=
args
.
dask_scheduler
.
split
(
'@'
)
scheduler_spec
=
args
.
dask_scheduler
.
split
(
"@"
)
scheduler_name
=
scheduler_spec
[
0
]
scheduler_kwargs
=
{
k
:
v
for
k
,
v
in
(
e
.
split
(
'='
)
for
e
in
scheduler_spec
[
1
:])}
scheduler_kwargs
=
{
k
:
v
for
k
,
v
in
(
e
.
split
(
"="
)
for
e
in
scheduler_spec
[
1
:])}
scheduler
=
SCHEDULERS
[
scheduler_name
]
return
scheduler
(
**
scheduler_kwargs
)
climix/dask_take_along_axis.py
View file @
0e46b6d4
...
...
@@ -32,8 +32,10 @@ def dask_take_along_axis(x, index, axis):
assert
0
<=
axis
<
x
.
ndim
assert
(
x
.
shape
[:
axis
]
+
x
.
shape
[
axis
+
1
:]
==
index
.
shape
[:
axis
]
+
index
.
shape
[
axis
+
1
:])
assert
(
x
.
shape
[:
axis
]
+
x
.
shape
[
axis
+
1
:]
==
index
.
shape
[:
axis
]
+
index
.
shape
[
axis
+
1
:]
)
if
np
.
isnan
(
x
.
chunks
[
axis
]).
any
():
raise
NotImplementedError
(
...
...
@@ -53,9 +55,9 @@ def dask_take_along_axis(x, index, axis):
# Define axis labels for blockwise
x_axes
=
tuple
(
range
(
x
.
ndim
))
idx_label
=
(
x
.
ndim
,)
# arbitrary unused
index_axes
=
x_axes
[:
axis
]
+
idx_label
+
x_axes
[
axis
+
1
:]
index_axes
=
x_axes
[:
axis
]
+
idx_label
+
x_axes
[
axis
+
1
:]
offset_axes
=
(
axis
,)
p_axes
=
x_axes
[:
axis
+
1
]
+
idx_label
+
x_axes
[
axis
+
1
:]
p_axes
=
x_axes
[:
axis
+
1
]
+
idx_label
+
x_axes
[
axis
+
1
:]
# Calculate the cartesian product of every chunk of x vs
# every chunk of index
...
...
@@ -70,9 +72,11 @@ def dask_take_along_axis(x, index, axis):
offset_axes
,
x_size
=
x
.
shape
[
axis
],
axis
=
axis
,
meta
=
sparse
.
COO
(
np
.
empty
((
0
,
0
),
dtype
=
int
),
np
.
empty
((),
dtype
=
x
.
dtype
),
shape
=
(
0
,)
*
len
(
p_axes
)),
meta
=
sparse
.
COO
(
np
.
empty
((
0
,
0
),
dtype
=
int
),
np
.
empty
((),
dtype
=
x
.
dtype
),
shape
=
(
0
,)
*
len
(
p_axes
),
),
dtype
=
x
.
dtype
,
)
...
...
climix/datahandling.py
View file @
0e46b6d4
...
...
@@ -22,10 +22,10 @@ def ignore_cb(cube, field, filename):
"""
Callback to ignore certain common global attributes in data files.
"""
cube
.
attributes
.
pop
(
'
creation_date
'
,
None
)
cube
.
attributes
.
pop
(
'
tracking_id
'
,
None
)
cube
.
attributes
.
pop
(
'
history
'
,
None
)
cube
.
attributes
.
pop
(
'
history_of_appended_files
'
,
None
)
cube
.
attributes
.
pop
(
"
creation_date
"
,
None
)
cube
.
attributes
.
pop
(
"
tracking_id
"
,
None
)
cube
.
attributes
.
pop
(
"
history
"
,
None
)
cube
.
attributes
.
pop
(
"
history_of_appended_files
"
,
None
)
def
prepare_input_data
(
datafiles
):
...
...
@@ -70,14 +70,16 @@ def prepare_input_data(datafiles):
for
var_name
,
cubes
in
cubes_per_var_name
.
items
():
if
len
(
cubes
)
>
1
:
inconsistent_var_names
.
append
(
var_name
)
logger
.
error
(
'Found more than one cube for var_name "{}".
\n
'
'{}'
.
format
(
var_name
,
'
\n
'
.
join
(
map
(
lambda
c
:
str
(
c
),
cubes
))))
raise
ValueError
(
'Found too many cubes for var_names {}. '
'See log for details.'
.
format
(
inconsistent_var_names
))
logger
.
error
(
'Found more than one cube for var_name "{}".
\n
'
"{}"
.
format
(
var_name
,
"
\n
"
.
join
(
map
(
lambda
c
:
str
(
c
),
cubes
)))
)
raise
ValueError
(
"Found too many cubes for var_names {}. "
"See log for details."
.
format
(
inconsistent_var_names
)
)
for
c
in
cubes
:
time
=
c
.
coord
(
'
time
'
)
time
=
c
.
coord
(
"
time
"
)
if
not
time
.
has_bounds
():
time
.
guess_bounds
()
return
cubes
...
...
@@ -118,44 +120,54 @@ def save(result, output_filename, iterative_storage=False, client=None):
"""
data
=
result
.
core_data
().
rechunk
()
if
iterative_storage
:
logger
.
info
(
'
Storing iteratively
'
)
logger
.
info
(
'
Creating empty data
'
)
logger
.
info
(
"
Storing iteratively
"
)
logger
.
info
(
"
Creating empty data
"
)
result
.
data
=
np
.
zeros
(
data
.
shape
,
data
.
dtype
)
logger
.
info
(
'Saving empty cube'
)
iris
.
save
(
result
,
output_filename
,
fill_value
=
MISSVAL
,
local_keys
=
[
'proposed_standard_name'
])
logger
.
info
(
'Reopening output file and beginning storage'
)
with
netCDF4
.
Dataset
(
output_filename
,
'a'
)
as
ds
:
logger
.
info
(
"Saving empty cube"
)
iris
.
save
(
result
,
output_filename
,
fill_value
=
MISSVAL
,
local_keys
=
[
"proposed_standard_name"
],
)
logger
.
info
(
"Reopening output file and beginning storage"
)
with
netCDF4
.
Dataset
(
output_filename
,
"a"
)
as
ds
:
var
=
ds
[
result
.
var_name
]
time_dim
=
result
.
coord_dims
(
'
time
'
)[
0
]
time_dim
=
result
.
coord_dims
(
"
time
"
)[
0
]
no_slices
=
result
.
shape
[
time_dim
]
end
=
time
.
time
()
cumulative
=
0.