Created July 5, 2023 01:24
Save GRAYgoose124/4083e0203bffd16a7e014cd957e05b50 to your computer and use it in GitHub Desktop.
ThreadPool fun
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import enum
import json
import logging
import threading
import time
import typing
import uuid
from abc import ABCMeta
from dataclasses import asdict, dataclass, make_dataclass
from queue import Queue
from typing import Literal

import pydantic
# Module-level logger; the worker threads report dispatch results and errors here.
logger: logging.Logger = logging.getLogger(__name__)
class BaseActionEnum(str, enum.Enum):
    """Empty ``str``-valued Enum base; the declared type of ``WorkRequest.call``."""
@dataclass
class WorkRequest:
    """One unit of work handed to a QueueingWorkThread.

    All three fields are required and positional, in this order:

    * ``id``   -- request identifier; ``QueueingThreadPool.submit_work`` replaces
      ``None`` with a random hex uuid.
    * ``call`` -- name of the action to run; workers look it up in their
      actions dict.
    * ``args`` -- positional arguments passed to that action.
    """

    id: str
    call: BaseActionEnum
    args: tuple
class EOQTerminateThread:
    """End-of-queue sentinel.

    The class object itself (not an instance) is placed on the work queue; a
    QueueingWorkThread that receives it leaves its run loop.
    """
class QueueingWorkThread(threading.Thread, metaclass=ABCMeta):
    """Worker thread that executes requests pulled from a shared queue.

    Each queued item is expected to expose ``call`` (a key into *actions*) and
    ``args`` (positional arguments for that action). Items whose ``call`` is not
    in *actions* are put back on the queue, and the thread backs off with a
    growing sleep. Receiving the ``EOQTerminateThread`` class itself ends the
    thread.

    Args:
        queue: Queue shared with the producer (and possibly other workers).
        actions: Mapping from call name to the callable that handles it.
    """

    def __init__(self, queue: Queue, actions: dict):
        super().__init__()
        self._queue = queue
        self._actions = actions
        # Back-off in seconds: grows by 0.33 (capped at 20.0) per unknown call,
        # shrinks by 0.025 (floored at 0.0) per dispatched call.
        self.delay = 0.0

    def run(self):
        while True:
            item = self._queue.get()
            try:
                if item is EOQTerminateThread:
                    break
                self._process(item)
            finally:
                # Balance every get() with task_done() -- including the sentinel,
                # re-queued items and failed actions -- so Queue.join() can return.
                self._queue.task_done()

    def _process(self, item):
        """Dispatch one item to its action, or re-queue it if the call is unknown."""
        if item.call not in self._actions:
            self._requeue_unknown(item)
            return
        action = self._actions[item.call]
        name = getattr(action, "__name__", repr(action))
        try:
            result = action(*item.args)
        except Exception:
            # An escaping exception would end run() and silently shrink the pool.
            logger.exception("%s(%s) raised", name, item.args)
        else:
            logger.debug("%s(%s) return %s", name, item.args, result)
        self.delay = max(0.0, self.delay - 0.025)

    def _requeue_unknown(self, item):
        """Put an item with an unknown call back on the queue, then back off."""
        # TODO: Rather than just putting it back on the queue, we should probably
        # request invalidating the request. item needs to become Request and such.
        self._queue.put(item)
        logger.error(f"UKN {item.call=} | {self.__class__.__name__}-{self.ident}")
        time.sleep(self.delay)
        self.delay = min(20.0, self.delay + 0.33)

    def get_work(self) -> WorkRequest:
        """Block until an item is available on the queue and return it."""
        return self._queue.get()

    def submit_work(self, item: WorkRequest):
        """Put *item* on the shared queue."""
        self._queue.put(item)
def create_str_enum_type(name: str, values: list[str]) -> type[enum.Enum]:
    """Build a ``str``-mixin Enum class with one member per entry in *values*.

    Member names are upper-cased (``"pprint"`` -> ``PPRINT``), but member values
    keep the original spelling. A member therefore compares and hashes equal to
    the original string, and can be used directly as a key into a dict keyed by
    those strings (such as an actions dict keyed by function name).

    Args:
        name: Name of the generated Enum class.
        values: Strings to turn into members, in order.

    Returns:
        The new Enum class (a subclass of both ``str`` and ``enum.Enum``).
    """
    members = {value.upper(): value for value in values}
    return enum.Enum(name, members, type=str)
class ActionSettingMeta(type):
    """Metaclass that turns nested declaration classes into class attributes.

    * ``Actions``: every staticmethod defined on it is collected into an
      ``actions`` dict mapping the attribute name to the plain function.
    * ``Settings``: every attribute whose name does not start with ``__`` is
      copied onto the class itself.
    * ``WorkRequest``: replaced by a generated dataclass that derives from the
      module-level ``WorkRequest`` (same fields: id, call, args), with ``call``
      typed as a generated ``ActionEnum`` holding one member per action.

    A subclass that redefines ``Actions`` while inheriting a ``WorkRequest``
    gets a fresh ``ActionEnum``/``WorkRequest`` pair, so both stay in sync with
    its own actions. A class that declares ``WorkRequest`` without ``Actions``
    builds them from the inherited ``actions`` (or an empty set).
    """

    def __new__(cls, name, bases, attrs):
        defines_actions = "Actions" in attrs
        if defines_actions:
            # Map each staticmethod name to its underlying function.
            attrs["actions"] = {
                k: v.__func__
                for k, v in attrs["Actions"].__dict__.items()
                if isinstance(v, staticmethod)
            }
            del attrs["Actions"]

        if "Settings" in attrs:
            # Promote each setting to a class attribute.
            for k, v in attrs["Settings"].__dict__.items():
                if not k.startswith("__"):
                    attrs[k] = v
            del attrs["Settings"]

        inherits_request = any(hasattr(base, "WorkRequest") for base in bases)
        if "WorkRequest" in attrs or (defines_actions and inherits_request):
            actions = attrs.get("actions")
            if actions is None:
                # WorkRequest declared without Actions: use the inherited actions.
                actions = next(
                    (base.actions for base in bases if hasattr(base, "actions")), {}
                )
            action_names = [k for k in actions if not k.startswith("__")]
            action_enum = create_str_enum_type("ActionEnum", action_names)
            attrs["ActionEnum"] = action_enum
            # Redeclaring the base fields keeps the constructor (id, call, args)
            # and records ActionEnum as the type of ``call``.
            attrs["WorkRequest"] = make_dataclass(
                "WorkRequest",
                [("id", str), ("call", action_enum), ("args", tuple)],
                bases=(WorkRequest,),
            )
        return super().__new__(cls, name, bases, attrs)
class QueueingThreadPool(metaclass=ActionSettingMeta):
    """Fixed-size pool of QueueingWorkThread workers sharing one FIFO work queue.

    Subclasses declare callable ``Actions`` and tuning ``Settings`` as nested
    classes. ActionSettingMeta converts these into the ``actions`` dict and the
    ``max_jobs``/``num_threads`` class attributes.
    """

    class Settings:
        max_jobs = 100  # submit_work refuses new items once this many are queued
        num_threads = 4  # worker threads created per pool instance

    class Actions:
        @staticmethod
        def action(arg: str, arg2: int):
            pass

    # NOTE: ActionSettingMeta discards this declaration and substitutes a
    # generated dataclass with required fields (id, call, args); only the
    # presence of the name matters. The defaults declared here are never applied.
    @dataclass
    class WorkRequest(pydantic.BaseModel):
        call: "QueueingThreadPool.ActionEnum" = pydantic.Field(
            default="action", annotation=Literal["action"]
        )
        args: tuple = ("arg", 1)

    def __init__(self):
        self._todo_jobs = Queue()
        # NOTE(review): not read by any code visible here.
        self._running_jobs = {}
        self._threads = [
            QueueingWorkThread(self._todo_jobs, self.actions)
            for _ in range(self.num_threads)
        ]

    def start(self):
        """Start every worker thread."""
        for thread in self._threads:
            thread.start()

    def submit_work(self, item: WorkRequest):
        """Queue *item* for the workers, assigning a random hex id if it has none.

        Raises:
            RuntimeError: ``max_jobs`` items are already queued.
            ValueError: ``item.call`` names no action in this pool. All workers
                share ``self.actions``, so such an item could never run. The
                workers would re-queue it indefinitely until the queue filled up.
        """
        if self._todo_jobs.qsize() >= self.max_jobs:
            raise RuntimeError("Too many jobs in queue")
        # Same membership test the workers use to dispatch.
        if item.call not in self.actions:
            raise ValueError(f"Unknown action {item.call!r}")
        if item.id is None:
            item.id = uuid.uuid4().hex
        self._todo_jobs.put(item)

    def join(self):
        """Stop every worker after the already-queued work, then wait for them."""
        # One sentinel per worker; FIFO order lets earlier items finish first.
        for _ in self._threads:
            self._todo_jobs.put(EOQTerminateThread)
        for thread in self._threads:
            thread.join()
class MyQTP(QueueingThreadPool):
    """Demo pool whose actions wrap ``print``, ``len``, ``str`` and ``sum``."""

    class Settings:
        max_jobs = 100
        num_threads = 4

    class Actions:
        @staticmethod
        def pprint(*values):
            """Print *values*; returns None."""
            print(*values)

        @staticmethod
        def plen(*values):
            """Return ``len`` of the single argument."""
            return len(*values)

        @staticmethod
        def pstr(*values):
            """Return ``str(*values)``."""
            return str(*values)

        @staticmethod
        def psum(*values):
            """Return ``sum(*values)`` (an iterable, optionally followed by a start)."""
            return sum(*values)
def main():
    """Run the MyQTP demo: submit a batch of requests every second until interrupted.

    Each batch includes one request for an action MyQTP does not define, to
    exercise the error path. The workers are always shut down on the way out
    (including on Ctrl-C), so the interpreter can exit.
    """
    logging.basicConfig(level=logging.DEBUG)
    qtp = MyQTP()
    qtp.start()
    try:
        print(f"{qtp.WorkRequest.__dataclass_fields__=}")
        while True:
            # The generated WorkRequest has three required fields (id, call,
            # args); id=None lets submit_work assign a fresh uuid.
            qtp.submit_work(qtp.WorkRequest(None, "pprint", ("Hello",)))
            qtp.submit_work(qtp.WorkRequest(None, "plen", ("Hello",)))
            qtp.submit_work(qtp.WorkRequest(None, "pstr", (123,)))
            qtp.submit_work(qtp.WorkRequest(None, "psum", ([1, 2, 3],)))
            try:
                # Deliberately unknown action name.
                qtp.submit_work(
                    qtp.WorkRequest(None, "notavalidrequest", ([1, 2, 3, 4],))
                )
            except Exception as e:
                print(e)
            time.sleep(1)
    finally:
        # Workers are non-daemon threads blocked on the queue; stop them or the
        # interpreter never exits.
        qtp.join()
if __name__ == "__main__":
    try:
        main()
    except KeyboardInterrupt:
        # Ctrl-C is the expected way to stop the demo; exit without a traceback.
        pass
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.