Skip to content

Instantly share code, notes, and snippets.

@zzstoatzz
Created November 4, 2024 14:35
Show Gist options
  • Save zzstoatzz/1a6bda6e00a9c6dc41e57c919d69493b to your computer and use it in GitHub Desktop.
Save zzstoatzz/1a6bda6e00a9c6dc41e57c919d69493b to your computer and use it in GitHub Desktop.
# /// script
# dependencies = [
# "prefect",
# "networkx",
# ]
# ///
from typing import Any, Dict
import networkx as nx
from pydantic_core import from_json
from prefect import Task, flow, task
from prefect.futures import PrefectFuture
class DAGRunner:
    """Run a dictionary-defined DAG of Prefect tasks in dependency order.

    The DAG definition is validated when the runner is constructed, so a
    malformed definition fails before any task has been submitted.
    """

    def __init__(self, dag_definition: Dict[str, Any], task_registry: Dict[str, Task]):
        """Initialize DAG runner with a DAG definition and task registry.

        Args:
            dag_definition: Dictionary containing task definitions and their
                dependencies.
            task_registry: Dictionary mapping task names to task functions.

        Raises:
            ValueError: If the definition references an unknown task or
                function, or if the dependencies contain a cycle.

        Format of dag_definition::

            {
                "tasks": {
                    "task_id": {
                        "function": "function_name",
                        "depends_on": ["upstream_task_id"],
                        "params": {"param1": "value1"}
                    }
                }
            }

        A string parameter value starting with ``$`` (e.g. ``"$task_id"``) is
        replaced at execution time by the future of the task it names.
        """
        self.dag_definition: Dict[str, Any] = dag_definition
        self.task_registry: Dict[str, Task] = task_registry
        # Built last: graph validation reads both attributes above.
        self.graph: nx.DiGraph = self._build_graph()

    @staticmethod
    def _is_reference(value: Any) -> bool:
        """Return True if a parameter value is a ``$task_id`` reference."""
        return isinstance(value, str) and value.startswith("$")

    def _build_graph(self) -> nx.DiGraph:
        """Construct and validate a NetworkX DiGraph from the DAG definition.

        Edges come from each task's ``depends_on`` list and from ``$task_id``
        parameter references, so a referenced task is always ordered before
        the task that consumes its result.

        Returns:
            A directed acyclic graph whose nodes are task ids carrying the
            task definition as node attributes.

        Raises:
            ValueError: If a task names a function missing from the registry,
                depends on or references an undefined task, or the graph
                contains a cycle.
        """
        tasks = self.dag_definition["tasks"]
        graph = nx.DiGraph()

        # Add all tasks as nodes, checking each one can actually be run.
        for task_id, task_def in tasks.items():
            function_name = task_def.get("function")
            if function_name not in self.task_registry:
                raise ValueError(
                    f"Task {task_id!r} uses unknown function {function_name!r}"
                )
            graph.add_node(task_id, **task_def)

        # Add edges for explicit dependencies and for "$task" references.
        for task_id, task_def in tasks.items():
            upstream = list(task_def.get("depends_on") or [])
            upstream.extend(
                value[1:]  # drop the "$" prefix
                for value in (task_def.get("params") or {}).values()
                if self._is_reference(value)
            )
            for dep in upstream:
                # add_edge would silently create a node with no attributes,
                # which would only fail later, mid-run, with a bare KeyError.
                if dep not in tasks:
                    raise ValueError(
                        f"Task {task_id!r} depends on unknown task {dep!r}"
                    )
                graph.add_edge(dep, task_id)

        # Validate that it's a DAG
        if not nx.is_directed_acyclic_graph(graph):
            raise ValueError("Task graph contains cycles")
        return graph

    def execute(self) -> Dict[str, Any]:
        """Execute the DAG and return results.

        Tasks are submitted in topological order. Each submission receives
        the futures of its graph predecessors as ``wait_for``, and any
        ``$task_id`` parameter is replaced by that task's future.

        Returns:
            Dictionary mapping each task id to the result of its future.
        """
        futures: Dict[str, PrefectFuture] = {}

        for task_id in nx.topological_sort(self.graph):
            task_def = self.graph.nodes[task_id]
            task_func = self.task_registry[task_def["function"]]

            # Build a fresh dict rather than mutating the definition's params.
            params = {
                key: futures[value[1:]] if self._is_reference(value) else value
                for key, value in (task_def.get("params") or {}).items()
            }

            # Topological order guarantees every predecessor has a future.
            wait_for = [futures[dep] for dep in self.graph.predecessors(task_id)]

            futures[task_id] = task_func.submit(**params, wait_for=wait_for)

        return {task_id: future.result() for task_id, future in futures.items()}
@task
def fetch_data(count: int) -> int:
    """Stand in for a real data fetch by returning ``count`` doubled."""
    fetched = count * 2
    return fetched
@task
def process_data(data: int) -> int:
    """Stand in for a processing step by adding 10 to ``data``."""
    processed = data + 10
    return processed
@task
def summarize(data: int, factor: int = 1) -> int:
    """Stand in for a summary step by scaling ``data`` by ``factor``."""
    summary = data * factor
    return summary
# Lookup table from the "function" names used in a DAG definition to the
# task objects defined above; this is the registry handed to DAGRunner.
TASK_REGISTRY: Dict[str, Task] = {
    "fetch_data": fetch_data,
    "process_data": process_data,
    "summarize": summarize,
}
@flow
def main(dag_definition: Dict[str, Any]):
    """Run ``dag_definition`` against TASK_REGISTRY and return its results.

    Args:
        dag_definition: DAG definition dictionary with a top-level "tasks"
            mapping, in the format DAGRunner expects.

    Returns:
        Whatever ``DAGRunner.execute`` returns for this definition.
    """
    runner = DAGRunner(dag_definition, TASK_REGISTRY)
    return runner.execute()
if __name__ == "__main__":
    # Example DAG: "fetch" feeds "process", which fans out into two
    # independent summaries that differ only in their "factor" parameter.
    dag_json = """
{
"tasks": {
"fetch": {
"function": "fetch_data",
"params": {"count": 5}
},
"process": {
"function": "process_data",
"params": {"data": "$fetch"},
"depends_on": ["fetch"]
},
"summary_a": {
"function": "summarize",
"params": {"data": "$process", "factor": 2},
"depends_on": ["process"]
},
"summary_b": {
"function": "summarize",
"params": {"data": "$process", "factor": 3},
"depends_on": ["process"]
}
}
}
"""
    dag_definition = from_json(dag_json)
    results = main(dag_definition)
    print(results)
@zzstoatzz
Copy link
Author

zzstoatzz commented Nov 4, 2024

uv run prefect_dag.py

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment