Skip to content

Instantly share code, notes, and snippets.

@thewisenerd
Created June 11, 2025 23:16
Show Gist options
  • Save thewisenerd/60be82e9dc2f611ba6a540e008579314 to your computer and use it in GitHub Desktop.
Save thewisenerd/60be82e9dc2f611ba6a540e008579314 to your computer and use it in GitHub Desktop.
a dynamically sized worker pool (of sorts)
import asyncio
import typing
from dataclasses import dataclass
# NOTE: a coroutine/asyncio task is created for every child task as soon as
# it is discovered. The semaphore only limits how many of them run at the
# same time, not how many exist, so task creation is effectively unbounded.
# Check memory usage when the task tree is large.

# Type variables for the generic containers in this module:
#   T - value type carried by Result; task type in TaskResult
#   R - value type of the Result stored in TaskResult
#   C - type of the child tasks listed in TaskResult
T = typing.TypeVar("T")
R = typing.TypeVar("R")
C = typing.TypeVar("C")
class Result(typing.Generic[T]):
def __init__(self, ok: bool, value: T | None, reason: str | None = None):
self.ok = ok
self.value = value
self.reason = reason
@staticmethod
def success(value: T) -> "Result[T]":
return Result(ok=True, value=value, reason=None)
@staticmethod
def error(reason: str) -> "Result[T]":
return Result(ok=False, value=None, reason=reason)
@dataclass
class TaskResult(typing.Generic[T, R, C]):
    """What a single task run hands back to the scheduler."""

    # The task this record belongs to.
    task: T
    # Result produced for the task, or None when there is nothing to report.
    result: typing.Optional[Result[R]]
    # Follow-up tasks to be scheduled after this one.
    children: list[C]
# Union of the task types accepted by `worker`.
# NOTE(review): RootTask and ChildTask are not defined in the visible part of
# this file; they appear to be placeholders for the caller's own task classes.
TaskType = typing.Union[RootTask, ChildTask] # example..
async def worker(
    sem: asyncio.Semaphore,
    task: TaskType,
) -> TaskResult[TaskType, Project, TaskType]:
    """Run one task while holding a slot of the shared semaphore.

    Args:
        sem: Semaphore that caps how many tasks execute concurrently.
        task: The task to execute.

    Returns:
        Whatever ``task_impl(task)`` returns. The annotation's second type
        argument is ``Project`` (not ``Result[Project]``) because
        ``TaskResult.result`` already wraps it in ``Result``; this matches the
        ``Result[Project]`` items that ``main_impl`` yields.

    NOTE(review): ``task_impl`` is not defined in the visible part of this
    file; it is assumed to be a coroutine function returning a TaskResult.
    """
    # The coroutine exists before the slot is acquired; only execution of
    # task_impl is gated by the semaphore.
    async with sem:
        return await task_impl(task)
async def main_impl(
    concurrency: int
) -> typing.AsyncGenerator[Result[Project], None]:
    """Run the task tree rooted at ``RootTask()`` and yield results as they finish.

    A task is scheduled for the root; every completed task may return child
    tasks, each of which is scheduled immediately. At most ``concurrency``
    tasks execute at once (enforced by the semaphore passed to ``worker``),
    but the number of scheduled tasks is not limited.

    Args:
        concurrency: Maximum number of tasks executing at the same time.

    Yields:
        The ``result`` of each completed task whose result is not None, in
        completion order.

    Raises:
        ValueError: If ``concurrency`` is less than 1 (raised on first
            iteration, since this is an async generator).
        TypeError: If a worker returns something that is not a TaskResult.
        Exception: Any exception raised by a worker propagates unchanged.

    If iteration stops early (a worker raised, or the consumer closed the
    generator), still-pending tasks are cancelled and the progress bar is
    closed.
    """
    if concurrency < 1:
        raise ValueError("concurrency must be >= 1")

    sem = asyncio.Semaphore(concurrency)
    pending: set[asyncio.Task] = set()
    total = 1  # the root task

    # NOTE(review): `tqdm` is not among the imports visible in this file;
    # confirm that `import tqdm` exists at module level.
    # Created before any task so a failure here cannot orphan a running task.
    pbar = tqdm.tqdm(total=total, bar_format='[{bar}] {n_fmt}/{total_fmt}')
    try:
        pending.add(asyncio.create_task(worker(sem, RootTask())))

        while pending:
            done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
            for finished in done:
                # Already complete: this only fetches the result or re-raises
                # the worker's exception.
                task_result = await finished
                pbar.update(1)

                if not isinstance(task_result, TaskResult):
                    raise TypeError(f"Expected TaskResult, got {type(task_result)}")

                if task_result.result is not None:
                    yield task_result.result

                if task_result.children:
                    for child_task in task_result.children:
                        pending.add(asyncio.create_task(worker(sem, child_task)))
                        total += 1
                    # The bar does not redraw on a total change by itself.
                    pbar.total = total
                    pbar.refresh()
    finally:
        # Empty on normal completion; on early exit, stop outstanding work so
        # no orphaned tasks keep running in the background.
        for leftover in pending:
            leftover.cancel()
        pbar.close()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment