Skip to content

Instantly share code, notes, and snippets.

@kgutwin
Last active June 4, 2024 16:08
Show Gist options
  • Save kgutwin/efe5f03df5ff930d899249018a0a551b to your computer and use it in GitHub Desktop.
Save kgutwin/efe5f03df5ff930d899249018a0a551b to your computer and use it in GitHub Desktop.
PRQL lineage visualization
#!/usr/bin/env python3
r"""Create a GraphViz `dot` rendering of `prqlc debug lineage` output.
Typical usage of this script would be:
prqlc debug lineage --format json my-query.prql \
| python3 lineage_dot.py | dot -Tpng -o out.png
You must have GraphViz installed. Download and installation instructions for
GraphViz are available at https://graphviz.org/download/.
"""
import json
import sys
import argparse
import typing as ty
class Node(dict[str, str]):
    """A GraphViz node statement stored as a mapping of DOT attributes.

    The ``id`` entry becomes the node identifier. Every other entry whose
    value is truthy is emitted as a ``key="value"`` attribute. Entries with
    falsy values (e.g. an empty ``color``) are omitted entirely.
    """

    def __str__(self) -> str:
        attrs = []
        for key, value in self.items():
            if key == "id" or not value:
                continue
            # Inside a DOT quoted string only `"` must be escaped; backslash
            # sequences such as \N and \n in labels are passed through on
            # purpose so GraphViz interprets them.
            escaped = str(value).replace('"', '\\"')
            attrs.append(f'{key}="{escaped}"')
        kv = ",".join(attrs)
        return f"{self['id']} [{kv}];"
class Target:
def __init__(self, left: int | str, right: int | str):
self.left = left
self.right = right
def __str__(self) -> str:
return f"{self.left} -> {self.right};"
class Parent(Target):
    """A directed DOT edge rendered with a dotted line style."""

    def __str__(self) -> str:
        edge = f"{self.left} -> {self.right}"
        return edge + " [style=dotted];"
# Anything that may appear as one entry of a (sub)graph body: a raw DOT
# statement string (e.g. "rankdir = BT;"), a Node, or an edge. Parent is
# listed explicitly even though it already subclasses Target.
GraphPart: ty.TypeAlias = str | Node | Target | Parent
class Subgraph(list[GraphPart]):
    """A DOT ``subgraph cluster_<i>`` containing nodes, edges and raw lines.

    Args:
        i: Suffix of the cluster name; callers read it back via ``_i``.
        label: Optional caption rendered for the cluster.
        parts: Initial contents. They are copied into this list, so the
            caller's iterable is never shared or mutated.
    """

    def __init__(
        self,
        i: int | str,
        label: str | None = None,
        parts: ty.Iterable[GraphPart] | None = None,
    ):
        self._i = i
        self._label = label
        # A None sentinel instead of a mutable `[]` default argument.
        super().__init__(parts if parts is not None else ())

    def __str__(self) -> str:
        rv = f"subgraph cluster_{self._i} {{\n"
        if self._label:
            # Escape `"` so a label cannot terminate the DOT quoted string.
            escaped = str(self._label).replace('"', '\\"')
            rv += f' label = "{escaped}"\n'
        rv += " " + "\n ".join(str(part) for part in self)
        rv += "\n }"
        return rv
class Digraph(list[GraphPart | Subgraph]):
    """The top-level ``digraph G``; each item is rendered on its own line."""

    def __str__(self) -> str:
        body = "\n ".join(map(str, self))
        return "digraph G {\n " + body + "\n}"
# GraphViz color name used for each lineage node kind (and for frame columns).
NODE_COLOR_MAP = dict(
    Ident="darkgreen",
    Literal="blue3",
    Tuple="chocolate4",
    Array="coral3",
    TransformCall="deeppink2",
    SString="darkcyan",
    FString="cyan3",
    Case="darkorchid",
    RqOperator="crimson",
    Table="darkgreen",
    Column="cornflowerblue",
)
def _parse_args() -> argparse.Namespace:
    """Parse this script's command-line options."""
    argparser = argparse.ArgumentParser(
        description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
    )
    argparser.add_argument(
        "--no-legend",
        help="Do not include legend in output",
        action="store_true",
        default=False,
    )
    argparser.add_argument(
        "--all-frames",
        help="Show all frames in output",
        action="store_true",
        default=False,
    )
    return argparser.parse_args()


def _legend() -> "Subgraph":
    """Build the legend cluster explaining node shapes and edge styles."""
    return Subgraph(
        "legend",
        "Legend",
        [
            Node(
                id="l1",
                label="TransformCall",
                shape="house",
                color=NODE_COLOR_MAP["TransformCall"],
            ),
            Node(
                id="l2",
                label="Table",
                shape="box",
                color=NODE_COLOR_MAP["Table"],
            ),
            Node(id="l3", label="Expression", shape="ellipse"),
            Node(
                id="l4",
                label="Frame Column",
                shape="trapezium",
                color=NODE_COLOR_MAP["Column"],
                style="filled",
            ),
            "l4 -> l3 -> l2 [label=Target];",
            "l1 -> l3 [label=Parent,style=dotted];",
        ],
    )


def _add_frames(
    digraph: "Digraph", graph: dict[str, ty.Any], all_frames: bool
) -> tuple[dict[ty.Any, "Subgraph"], dict[ty.Any, dict[str, ty.Any]]]:
    """Add one cluster per rendered frame, holding that frame's columns.

    Only the last frame is rendered unless ``all_frames`` is set. Returns the
    frame clusters keyed by the frame's span, and the frames' input tables
    keyed by table id.
    """
    # Each JSON frame is a two-element list: [span, lineage].
    frames: ty.Iterable[tuple[int, ty.Sequence[ty.Any]]]
    json_frames = graph["frames"]
    if all_frames:
        frames = enumerate(json_frames, start=1)
    elif json_frames:
        frames = [(len(json_frames), json_frames[-1])]
    else:
        # No frames at all: nothing to draw (previously an IndexError).
        frames = []

    frame_subgraphs: dict[ty.Any, Subgraph] = {}
    tables: dict[ty.Any, dict[str, ty.Any]] = {}
    for frame_n, (span, lineage) in frames:
        frame_subgraph = Subgraph(frame_n)
        frame_subgraphs[span] = frame_subgraph
        digraph.append(frame_subgraph)
        for column_n, column in enumerate(lineage["columns"]):
            column_id = f"col_{frame_n}_{column_n}"
            # NOTE(review): a "Single" column with an empty name gets neither
            # a node nor an edge -- confirm that omission is intended.
            if "Single" in column and column["Single"]["name"]:
                single = column["Single"]
                frame_subgraph.append(
                    Node(
                        id=column_id,
                        label=".".join(single["name"]),
                        shape="trapezium",
                        color=NODE_COLOR_MAP["Column"],
                        style="filled",
                    )
                )
                digraph.append(Target(column_id, single["target_id"]))
            elif "All" in column:
                label = "All"
                if column["All"]["except"]:
                    label += r"\nExcept: " + repr(column["All"]["except"])
                frame_subgraph.append(
                    Node(
                        id=column_id,
                        label=label,
                        shape="invtrapezium",
                        color=NODE_COLOR_MAP["Column"],
                        style="filled",
                    )
                )
                digraph.append(Target(column_id, column["All"]["input_id"]))
        for table in lineage["inputs"]:
            tables[table["id"]] = table
    return frame_subgraphs, tables


def _add_nodes(
    digraph: "Digraph",
    graph: dict[str, ty.Any],
    frame_subgraphs: dict[ty.Any, "Subgraph"],
    tables: dict[ty.Any, dict[str, ty.Any]],
) -> None:
    """Add a DOT node for every lineage node, plus its target/parent edges.

    A node whose span matches a rendered frame goes into that frame's
    cluster; every other node goes into the top-level graph.
    """
    for node in graph["nodes"]:
        n = Node(
            id=node["id"],
            shape="ellipse",
            label=r"\N\n" + node["kind"] + r"\n",
            color=NODE_COLOR_MAP.get(node["kind"], ""),
        )
        subgraph = None
        if node["kind"].startswith("TransformCall"):
            n["shape"] = "house"
            n["color"] = NODE_COLOR_MAP["TransformCall"]
        span = node.get("span")
        if span is not None:
            n["label"] += span + r"\n"
            subgraph = frame_subgraphs.get(span)
            if subgraph is not None:
                n["label"] += f"Frame {subgraph._i}\\n"
                n["shape"] = "house"
        if "ident" in node:
            n["label"] += ".".join(node["ident"]["Ident"]) + r"\n"
        if "alias" in node:
            n["label"] += f"alias: {node['alias']}\\n"
        if node["id"] in tables:
            table = tables[node["id"]]
            table_name = ".".join(table["table"])
            n["label"] += f"Table {table_name}\\n"
            if table["name"] != table_name:
                n["label"] += f"alias: {table['name']}\\n"
            n["shape"] = "box"
        # Every fragment above ends in a literal \n; drop the last one.
        n["label"] = n["label"].removesuffix(r"\n")
        # An empty Subgraph is falsy, so `subgraph or digraph` would misroute
        # the node; compare against None explicitly.
        (subgraph if subgraph is not None else digraph).append(n)
        for target in node.get("targets", []):
            digraph.append(Target(node["id"], target))
        if "parent" in node:
            digraph.append(Parent(node["parent"], node["id"]))


def main() -> None:
    """Read `prqlc debug lineage` JSON from stdin and print a DOT digraph.

    The input is expected to have the ``frames`` and ``nodes`` keys used by
    the helpers above. The output is written to stdout, ready to be piped
    into ``dot``.
    """
    args = _parse_args()
    graph = json.load(sys.stdin)

    digraph = Digraph()
    digraph.append("rankdir = BT;")
    digraph.append("node [penwidth=2];")
    if not args.no_legend:
        digraph.append(_legend())

    frame_subgraphs, tables = _add_frames(digraph, graph, args.all_frames)
    _add_nodes(digraph, graph, frame_subgraphs, tables)

    # output the graph
    print(digraph)
if __name__ == "__main__":
    # Run as a script: render the graph, then exit with status 0.
    main()
    raise SystemExit(0)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment