from multiprocessing import Manager, Pool, cpu_count
from queue import Empty  # Manager().Queue() raises queue.Empty on timeout
from threading import Thread

from tqdm import tqdm
def parallelize(func: callable, iterable: list, processes: int = None) -> list:
    """Parallelize the execution of a function over a list.

    This method runs the given function over chunks of the given list in
    parallel across multiple processes. The return value is an ordered list
    with the result of the function on each chunk.

    Progress of the overall process is shown via a tqdm progress bar. In
    order to send update messages, the given function must have the
    signature ``(chunk, queue)``, where ``chunk`` is a chunk of the given
    list and ``queue`` is an increment queue. To step the progress bar by
    ``n``, call

        queue.put(n)

    It makes sense to wrap a function around this method only if the result
    of the whole function can be reconstructed from the results on the
    individual chunks.
    """
    def update():
        # Drain increments from the queue and advance the progress bar until
        # a falsey sentinel is received.
        pbar = tqdm(total=len(iterable))
        while True:
            try:
                increment = queue.get(timeout=1)
                if not increment:
                    break
                pbar.update(increment)
            except Empty:
                pass
        pbar.close()
    if processes == 1:
        raise RuntimeError("Call the function directly!")

    processes = processes or cpu_count()
    chunk_size = len(iterable) // processes
    queue = Manager().Queue()
    pbar_thread = Thread(target=update)

    results = []
    with Pool(processes=processes) as pool:
        for i in range(processes - 1):
            results.append(
                pool.apply_async(
                    func, args=(iterable[i * chunk_size : (i + 1) * chunk_size], queue)
                )
            )
        # The last chunk also picks up the remainder left by the integer division.
        results.append(
            pool.apply_async(func, args=(iterable[(i + 1) * chunk_size :], queue))
        )

        pbar_thread.start()

        pool.close()
        pool.join()

    queue.put(False)  # Sentinel: stop the progress bar thread.
    pbar_thread.join()

    return [result.get() for result in results]
# ---- EXAMPLE ----
if __name__ == "__main__":
    import time

    def parallel_sum(arg, queue):
        # Sum the chunk, stepping the progress bar by one for each element.
        a = 0
        for i in arg:
            a += i
            time.sleep(0.1)
            queue.put(1)
        return a

    n = 200
    result = parallelize(func=parallel_sum, iterable=list(range(n)))
    assert sum(result) == ((n * (n - 1)) >> 1)
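
    # ---- ADDITIONAL EXAMPLE (illustrative sketch) ----
    # The docstring notes that the progress bar can be stepped by ``n`` with a
    # single ``queue.put(n)``. The sketch below shows a worker that reports
    # progress in batches of 10 items instead of once per item, which reduces
    # queue traffic. The name ``batched_sum`` and the batch size of 10 are
    # illustrative choices, not part of the original gist.
    def batched_sum(arg, queue):
        a = 0
        pending = 0
        for i in arg:
            a += i
            time.sleep(0.01)
            pending += 1
            if pending == 10:
                queue.put(pending)  # step the bar by a whole batch at once
                pending = 0
        if pending:
            queue.put(pending)  # flush any remainder smaller than a batch
        return a

    result = parallelize(func=batched_sum, iterable=list(range(n)))
    assert sum(result) == ((n * (n - 1)) >> 1)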