Last active
March 10, 2021 08:57
-
-
Save jkbjh/fd807bd4eb4afab28727efbb82c1f0d6 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# A selection of small numpy helper functions.
import joblib
import numpy as np
def apply_along_axes(func, data, axes, n_jobs=None):
    """
    Apply ``func`` to every sub-array of ``data`` obtained by fixing the
    indices along ``axes``, evaluating the calls through ``joblib.Parallel``.

    For each index tuple ``k`` over the dimensions listed in ``axes``,
    ``func`` receives ``data`` with axis ``axes[m]`` fixed to ``k[m]``; all
    other axes are passed through whole, in their original order.  The result
    has shape ``tuple(data.shape[a] for a in axes)``, followed by the shape of
    ``func``'s return value when that is an array.  The dtype is inferred from
    the returned values.

    Only worthwhile when each call is expensive: all results are held in
    memory before being assembled into the output array.

    :param func: callable taking a single array (or scalar) argument.
    :param data: array-like input.
    :param axes: iterable of distinct axis indices (negative values allowed).
    :param n_jobs: forwarded to ``joblib.Parallel``.  The default ``None``
        lets joblib decide, which runs sequentially unless an enclosing joblib
        backend/config context requests more workers.
    :returns: numpy array of results.
    :raises ValueError: if an axis is out of range or repeated.
    """
    data = np.asarray(data)
    ndim = data.ndim
    axes = list(axes)
    for ax in axes:
        if not -ndim <= ax < ndim:
            raise ValueError(
                f"axis {ax} is out of bounds for array of dimension {ndim}"
            )
    normalized = [ax % ndim for ax in axes]
    if len(set(normalized)) != len(normalized):
        raise ValueError(f"repeated axis in axes={axes}")
    sizes = tuple(data.shape[ax] for ax in normalized)

    def _select(index):
        # Fix the chosen axes (which need not be the leading ones) and keep
        # every other axis whole, in its original order.
        selector = [slice(None)] * ndim
        for ax, k in zip(normalized, index):
            selector[ax] = k
        return data[tuple(selector)]

    tasks = [joblib.delayed(func)(_select(index)) for index in np.ndindex(*sizes)]
    results = joblib.Parallel(n_jobs=n_jobs)(tasks)
    if not results:
        # A selected axis has length 0, so func was never called and the
        # element type is unknown; fall back to numpy's default dtype.
        return np.empty(sizes)
    # np.ndindex yields indices in C order, which is the order reshape uses.
    stacked = np.asarray(results)
    return stacked.reshape(sizes + stacked.shape[1:])
def unravel(data):
    """
    Merge an (object) array whose cells are arrays of one common shape into a
    single dense array.

    For ``data`` of shape ``S`` whose cells all have shape ``C``, the result
    has shape ``S + C`` and ``result[s] == data[s]`` for every index ``s`` in
    ``S``.  The result dtype is the common dtype of the cells, so complex,
    integer and boolean cells are not coerced to float.

    Only cells exposing a non-empty ``.shape`` (e.g. ndarrays) trigger a
    merge.  If the first cell is scalar-like (this includes plain Python
    lists), or ``data`` is empty, ``data`` itself is returned unchanged.
    Nested object arrays may need repeated application.

    :param data: numpy array, typically of dtype ``object``.
    :returns: merged array, or ``data`` itself when there is nothing to merge.
    :raises ValueError: if the cells do not all share the same shape.
    """
    if data.size == 0:
        # There is no cell to inspect in a zero-size array.
        return data
    cell = data.flat[0]
    cell_shape = getattr(cell, "shape", ())
    if len(cell_shape) == 0:
        return data
    # np.stack rejects cells of differing shapes and picks their common dtype;
    # .flat and reshape both use C order, so each cell lands at its own index.
    stacked = np.stack(list(data.flat))
    return stacked.reshape(tuple(data.shape) + tuple(cell_shape))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment