Created
November 4, 2024 14:35
-
-
Save zzstoatzz/1a6bda6e00a9c6dc41e57c919d69493b to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# /// script
# dependencies = [
#     "prefect",
#     "networkx",
# ]
# ///
from typing import Any, Dict | |
import networkx as nx | |
from pydantic_core import from_json | |
from prefect import Task, flow, task | |
from prefect.futures import PrefectFuture | |
class DAGRunner:
    """Build a task graph from a dictionary definition and run it.

    Each task in the definition names a function in the task registry,
    optional upstream task ids (``depends_on``) and optional keyword
    parameters (``params``). A string parameter value starting with ``$``
    is a reference to another task: the future of that task is passed as
    the argument instead of the string.
    """

    def __init__(self, dag_definition: Dict[str, Any], task_registry: Dict[str, Task]):
        """Initialize DAG runner with a DAG definition and task registry.

        Args:
            dag_definition: Dictionary containing task definitions and their
                dependencies.
            task_registry: Dictionary mapping function names to task objects.

        Raises:
            ValueError: If a task depends on or references an undefined task,
                or if the resulting graph contains a cycle.

        Format of dag_definition:
            {
                "tasks": {
                    "task_id": {
                        "function": "function_name",
                        "depends_on": ["upstream_task_id"],
                        "params": {"param1": "value1"}
                    }
                }
            }
        """
        self.dag_definition: Dict[str, Any] = dag_definition
        self.task_registry: Dict[str, Task] = task_registry
        self.graph: nx.DiGraph = self._build_graph()

    @staticmethod
    def _is_reference(value: Any) -> bool:
        """Return True if ``value`` is a ``$task_id`` style parameter reference."""
        return isinstance(value, str) and value.startswith("$")

    def _build_graph(self) -> nx.DiGraph:
        """Construct a NetworkX DiGraph from the DAG definition.

        Nodes are task ids carrying their task definition as node attributes.
        An edge ``a -> b`` is added when ``b`` lists ``a`` in ``depends_on``
        or when one of ``b``'s parameters is the reference ``"$a"``.

        Raises:
            ValueError: On an unknown upstream task id or a cyclic graph.
        """
        tasks = self.dag_definition["tasks"]
        graph = nx.DiGraph()

        # Add all tasks as nodes first so edges never create implicit,
        # attribute-less nodes.
        for task_id, task_def in tasks.items():
            graph.add_node(task_id, **task_def)

        for task_id, task_def in tasks.items():
            # `or` guards against an explicit null in the definition.
            upstream = list(task_def.get("depends_on") or ())
            # A "$other" parameter needs `other` to be submitted first, so it
            # is an ordering dependency even when absent from depends_on.
            upstream.extend(
                value[1:]
                for value in (task_def.get("params") or {}).values()
                if self._is_reference(value)
            )
            for dep in upstream:
                if dep not in tasks:
                    raise ValueError(
                        f"Task {task_id!r} depends on unknown task {dep!r}"
                    )
                graph.add_edge(dep, task_id)

        # Validate that it's a DAG
        if not nx.is_directed_acyclic_graph(graph):
            raise ValueError("Task graph contains cycles")
        return graph

    def execute(self) -> Dict[str, Any]:
        """Execute the DAG and return results.

        Tasks are submitted in topological order, each waiting on the futures
        of its graph predecessors.

        Returns:
            Dictionary mapping every task id to the result of its future.
        """
        futures: Dict[str, PrefectFuture] = {}

        # Topological order guarantees every predecessor's future already
        # exists when a task is submitted.
        for task_id in nx.topological_sort(self.graph):
            task_def = self.graph.nodes[task_id]
            task_func = self.task_registry[task_def["function"]]

            # Swap "$task_id" references for the referenced task's future;
            # building a new dict leaves the definition untouched.
            params = {
                key: futures[value[1:]] if self._is_reference(value) else value
                for key, value in (task_def.get("params") or {}).items()
            }
            wait_for = [futures[dep] for dep in self.graph.predecessors(task_id)]

            futures[task_id] = task_func.submit(**params, wait_for=wait_for)

        return {task_id: future.result() for task_id, future in futures.items()}
@task
def fetch_data(count: int) -> int:
    """Stand in for a data fetch: return ``count`` doubled."""
    doubled = count * 2
    return doubled
@task
def process_data(data: int) -> int:
    """Stand in for a processing step: return ``data`` plus ten."""
    processed = data + 10
    return processed
@task
def summarize(data: int, factor: int = 1) -> int:
    """Stand in for a summary step: return ``data`` scaled by ``factor``."""
    scaled = data * factor
    return scaled
# Maps the "function" names used in DAG definitions to the task objects
# defined above.
TASK_REGISTRY: Dict[str, Task] = dict(
    fetch_data=fetch_data,
    process_data=process_data,
    summarize=summarize,
)
@flow
def main(dag_definition: Dict[str, Any]):
    """Run ``dag_definition`` against ``TASK_REGISTRY`` and return the per-task results."""
    runner = DAGRunner(dag_definition, TASK_REGISTRY)
    return runner.execute()
if __name__ == "__main__":
    # Example pipeline: fetch -> process -> two independent summaries that
    # both consume the output of "process".
    example_dag_json = """
    {
        "tasks": {
            "fetch": {
                "function": "fetch_data",
                "params": {"count": 5}
            },
            "process": {
                "function": "process_data",
                "params": {"data": "$fetch"},
                "depends_on": ["fetch"]
            },
            "summary_a": {
                "function": "summarize",
                "params": {"data": "$process", "factor": 2},
                "depends_on": ["process"]
            },
            "summary_b": {
                "function": "summarize",
                "params": {"data": "$process", "factor": 3},
                "depends_on": ["process"]
            }
        }
    }
    """
    dag_definition = from_json(example_dag_json)
    print(main(dag_definition))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
uv run prefect_dag.py