Last active: May 16, 2024 16:10
-
-
Save vxgmichel/23469ee482aeba4a1c4d3cd66f1ac6c5 to your computer and use it in GitHub Desktop.
Async map with task limit using anyio
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import asyncio
import math
from contextlib import asynccontextmanager
from typing import AsyncIterable, AsyncIterator, Awaitable, Callable

from anyio import abc, create_memory_object_stream, create_task_group
from anyio.streams.memory import MemoryObjectReceiveStream
@asynccontextmanager | |
async def amap[ | |
A, B | |
]( | |
source: AsyncIterable[A], | |
corofn: Callable[[A], Awaitable[B]], | |
task_limit: float = math.inf, | |
) -> AsyncIterator[MemoryObjectReceiveStream[B]]: | |
async def source_task(task_status: abc.TaskStatus) -> None: | |
async with send_item_stream: | |
task_status.started() | |
async for item in source: | |
await send_token_stream.send(None) | |
await task_group.start(item_task, item) | |
async def item_task(item: A, task_status: abc.TaskStatus) -> None: | |
try: | |
async with send_item_stream.clone() as cloned_stream: | |
task_status.started() | |
result = await corofn(item) | |
await cloned_stream.send(result) | |
finally: | |
await receive_token_stream.receive() | |
send_token_stream, receive_token_stream = create_memory_object_stream[None]( | |
max_buffer_size=task_limit | |
) | |
send_item_stream, receive_item_stream = create_memory_object_stream[B]() | |
async with receive_item_stream: | |
async with create_task_group() as task_group: | |
await task_group.start(source_task) | |
yield receive_item_stream | |
task_group.cancel_scope.cancel() | |
async def main():
    """Demonstrate ``amap`` three ways, printing every result received.

    1. Without a task limit.
    2. With ``task_limit=1``, so at most one ``slow_task`` call runs at a time.
    3. Without a task limit, but the consumer stops as soon as ``"3_loaded"``
       arrives, which leaves the ``amap`` context early.
    """

    async def input_gen() -> AsyncIterator[str]:
        # Produce one character every 0.1 s.
        for character in "abc123xyz789":
            await asyncio.sleep(0.1)
            yield character

    async def slow_task(item: str) -> str:
        # Stand-in for a slow operation: 0.5 s per item.
        await asyncio.sleep(0.5)
        return f"{item}_loaded"

    async def drain(results, stop_at=None):
        # Print each result. Stop once `stop_at` is seen; with the default
        # None that never happens, because every result is a string.
        async for result in results:
            print(f"Received: {result}")
            if result == stop_at:
                break

    print("Running without task limit")
    async with amap(input_gen(), slow_task) as items:
        await drain(items)

    print("Running with task limit = 1")
    async with amap(input_gen(), slow_task, task_limit=1) as items:
        await drain(items)

    print("Stopping after 3_loaded is received")
    async with amap(input_gen(), slow_task) as items:
        await drain(items, stop_at="3_loaded")


if __name__ == "__main__":
    asyncio.run(main())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.