Last active
April 5, 2025 03:48
-
-
Save antonagestam/8476ada7d74cce93af0339cf32c62ae2 to your computer and use it in GitHub Desktop.
Merge results from multiple async generators into a single stream.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import asyncio | |
import random | |
from typing import TypeVar, AsyncGenerator | |
T = TypeVar("T")


async def read_into_queue(
    task: AsyncGenerator[T, None],
    queue: asyncio.Queue[T],
    done: asyncio.Semaphore,
) -> None:
    """Copy every item produced by *task* into *queue*, then acquire *done* once.

    The acquire sits in a ``finally`` clause, so it also happens when *task*
    raises or this coroutine is cancelled. A caller counting finished
    producers through the semaphore therefore never waits on a producer that
    has already died. Exceptions from *task* still propagate to the caller.
    """
    try:
        async for item in task:
            # Blocks while the queue is full, applying back-pressure to *task*.
            await queue.put(item)
    finally:
        # This producer is finished (successfully or not): decrease the
        # semaphore by one.
        await done.acquire()


async def join(*generators: AsyncGenerator[T, None]) -> AsyncGenerator[T, None]:
    """Merge the items of several async generators into one async stream.

    Each generator is drained by its own task into a shared queue that holds
    at most one item. Items are yielded in the order the producers hand them
    over. The stream ends once every generator is exhausted and the queue is
    empty.

    If a generator raises, that exception is re-raised from this stream.
    When the stream finishes, fails, or is closed early, any producer tasks
    still running are cancelled and awaited.
    """
    queue: asyncio.Queue[T] = asyncio.Queue(maxsize=1)
    # Required by read_into_queue's contract (acquired once per finished
    # producer). join itself tracks completion on the tasks, which lets it
    # sleep until something happens instead of polling.
    done_semaphore = asyncio.Semaphore(len(generators))
    # Read from each given generator into the shared queue.
    produce_tasks = [
        asyncio.create_task(read_into_queue(gen, queue, done_semaphore))
        for gen in generators
    ]
    running = set(produce_tasks)
    getter = None  # outstanding queue.get() task, if any
    try:
        # Continue while some producer may still put items, or one is queued.
        while running or not queue.empty():
            if getter is None:
                getter = asyncio.create_task(queue.get())
            finished, _ = await asyncio.wait(
                {getter, *running}, return_when=asyncio.FIRST_COMPLETED
            )
            for producer in finished & running:
                running.discard(producer)
                # Re-raises the producer's exception if its generator failed.
                producer.result()
            if getter in finished:
                item = getter.result()
                getter = None
                yield item
    finally:
        leftovers = list(produce_tasks)
        if getter is not None:
            # A get() that has not completed has not removed any item.
            leftovers.append(getter)
        for leftover in leftovers:
            leftover.cancel()
        if leftovers:
            # Let cancelled tasks unwind, and retrieve exceptions of other
            # failed producers so they are not reported as never retrieved.
            await asyncio.gather(*leftovers, return_exceptions=True)
# --- | |
async def produce(
    i: int, *, count: int = 5, max_delay: float = 2.0
) -> AsyncGenerator[int, None]:
    """Yield ``count`` consecutive integers starting at *i*.

    After each item the generator sleeps for a random duration in
    ``[0, max_delay)`` seconds, simulating a producer that emits at
    irregular intervals. The defaults reproduce the original demo
    behavior (5 items, up to 2 seconds apart).
    """
    # Use a distinct loop variable instead of shadowing the parameter ``i``.
    for value in range(i, i + count):
        yield value
        await asyncio.sleep(max_delay * random.random())
async def consume(source: AsyncGenerator[int, None]) -> None:
    """Drain *source*, printing each item as ``item=<repr>`` on its own line."""
    iterator = source.__aiter__()
    while True:
        try:
            item = await iterator.__anext__()
        except StopAsyncIteration:
            break
        print(f"{item=}")
async def main() -> None:
    """Demo: merge three randomly-paced producers and print items as they arrive."""
    await consume(
        join(
            produce(10),
            produce(20),
            produce(30),
        )
    )


if __name__ == "__main__":
    # Run the demo only when executed as a script, not on import. This also
    # keeps ``main`` bound to the coroutine function rather than to the return
    # value of asyncio.run (None).
    asyncio.run(main())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
See the `finally` clause above: it ensures that a generator that raises both propagates its exception and marks the semaphore as done.