Skip to content

Instantly share code, notes, and snippets.

@HarryR
Created June 18, 2025 03:43
Show Gist options
  • Save HarryR/e5ac1105d630e118e330966c1a7f5054 to your computer and use it in GitHub Desktop.
Save HarryR/e5ac1105d630e118e330966c1a7f5054 to your computer and use it in GitHub Desktop.
Make a 3D graph from nodes and edges using igraph and Plotly.
###############################################################################
import sys
from collections import namedtuple
from collections.abc import Iterable
from dataclasses import dataclass
# Container types used by Graph.
# Edges are (source_name, target_name) pairs.
GraphEdgesT = set[tuple[str, str]]
# Node name -> Node. "Node" is a forward reference to the Node dataclass
# defined below, so the alias names the class Graph actually stores.
GraphNodesT = dict[str, "Node"]
@dataclass
class Node:
    """A single graph vertex.

    ``name`` becomes the node's hover label and ``group`` is the value the
    node is colored by when the graph is drawn.
    """

    name: str   # display label for the vertex
    group: str  # group / color key for the vertex
@dataclass
class Graph:
    """A set of named nodes plus (source, target) edges between node names.

    Graphs can be merged with ``+`` (which delegates to ``graph_combine``).
    """

    # Node name -> Node.
    nodes: GraphNodesT
    # (source_name, target_name) pairs; a set, so adding an edge twice is a no-op.
    edges: GraphEdgesT

    @classmethod
    def default(cls):
        """Return a new graph with no nodes and no edges."""
        return cls(dict(), set())

    def add(self, node_name: str, color: str, targets: Iterable[str] = ()):
        """Add a node (if not already present) and edges from it to *targets*.

        When *node_name* already exists its stored group is kept and *color*
        is ignored. Edges are recorded even if a target has not been added as
        a node yet; it must be added before the graph is linearized.

        Args:
            node_name: Name of the source node.
            color: Group value stored on the node when it is first created.
            targets: Names of the nodes this node links to. Defaults to an
                immutable empty tuple (never a shared mutable list).
        """
        if node_name not in self.nodes:
            self.nodes[node_name] = Node(node_name, color)
        for t in targets:
            self.edges.add((node_name, t))

    def __add__(self, other: 'Graph') -> 'Graph':
        """Return a new graph holding the nodes and edges of both operands.

        On a node-name collision, the node from *other* is kept.
        """
        return graph_combine(self, other)
def graph_combine(*args: Graph) -> Graph:
    """Merge any number of graphs into a new Graph.

    Node maps are merged left to right, so if several arguments define the
    same node name, the node from the last of them wins. Edge sets are
    unioned. The input graphs are not modified.

    Args:
        *args: Graphs to merge. With no arguments an empty Graph is returned.
    """
    nodes = {}
    edges = set()
    for a in args:
        nodes.update(a.nodes)
        # Update in place; rebinding to edges.union(...) would copy the whole
        # accumulated set on every iteration.
        edges.update(a.edges)
    return Graph(nodes, edges)
def graph_linearize(args: "Graph"):
    """Convert a graph into the index-based form igraph expects.

    Nodes are numbered in the insertion order of ``args.nodes``. Edges are
    translated to ``(source_index, target_index)`` pairs, ordered by sorting
    the original ``(source_name, target_name)`` tuples, so the output is
    deterministic.

    Args:
        args: The graph to convert (a single graph, despite the name).

    Returns:
        ``((nodes, labels, groups), edges)``, where ``nodes`` is ``args.nodes``
        itself, ``labels`` and ``groups`` are the per-index node names and
        groups, and ``edges`` is the list of index pairs.

    Raises:
        KeyError: if any edge endpoint is not a key of ``args.nodes``.
    """
    nodes = args.nodes
    name_to_index = {name: i for i, name in enumerate(nodes)}

    # Report every dangling endpoint at once instead of a bare KeyError from
    # the lookup below (e.g. Graph.add was given a target that was never
    # itself added as a node).
    endpoints = {name for edge in args.edges for name in edge}
    missing = endpoints.difference(name_to_index)
    if missing:
        raise KeyError(f"edges reference unknown node(s): {sorted(missing)!r}")

    edges = [(name_to_index[source], name_to_index[target])
             for source, target in sorted(args.edges)]
    labels = [node.name for node in nodes.values()]
    groups = [node.group for node in nodes.values()]
    return (nodes, labels, groups), edges
def flatten(xss):
    """Concatenate an iterable of iterables into one flat list (one level deep)."""
    result = []
    for inner in xss:
        result.extend(inner)
    return result
def squares(x, p):
    """Infinitely yield successive squares of *x* reduced modulo *p*.

    Each yielded value is the previous one squared, mod *p*; the first value
    is ``x * x % p``.
    """
    current = x
    while True:
        current = current * current % p
        yield current
###############################################################################
import plotly.graph_objects as go
import igraph as ig
def draw_graph(output_filename: str, title: str, graph: Graph,
               modifiers_by_color: dict[str, tuple[float, float]]):
    """Render *graph* as an interactive 3D Plotly figure saved to an HTML file.

    Node positions come from igraph's 3D Kamada-Kawai layout. Edges are drawn
    as grey line segments; nodes are drawn as markers colored by their group
    and labelled with their name on hover.

    Args:
        output_filename: Path passed to ``fig.write_html``.
        title: Figure title.
        graph: Graph to draw. Every edge endpoint must be a node of the graph,
            otherwise a ``KeyError`` is raised while linearizing it.
        modifiers_by_color: Maps a node's ``group`` to ``(zoffset, multiplier)``.
            For matching nodes, ``zoffset`` is added to the z coordinate and
            then all three coordinates are scaled by ``multiplier``. Groups
            missing from the mapping are left unchanged.
    """
    (nodes, labels, groups), edges = graph_linearize(graph)
    N = len(nodes)

    # Pass the vertex count explicitly: from an edge list alone igraph sizes
    # the graph as (largest index in edges) + 1, dropping isolated nodes at the
    # end of the ordering (all nodes when there are no edges), so G.vs would
    # be shorter than labels/groups.
    G = ig.Graph(n=N, edges=edges, directed=False)
    G.vs['name'] = labels
    G.vs['group'] = groups
    layt = G.layout('kk', dim=3, kkconst=N * 5)

    # Optional per-group adjustment: shift z first, then scale the position.
    for i, color in enumerate(groups):
        if color not in modifiers_by_color:
            continue
        zoffset, multiplier = modifiers_by_color[color]
        layt[i][2] += zoffset
        layt[i][0] *= multiplier
        layt[i][1] *= multiplier
        layt[i][2] *= multiplier

    def axis_segments(axis):
        # One [start, end, None] triple per edge; the None breaks the line so
        # Plotly draws each edge as a separate segment.
        return flatten([[layt[s][axis], layt[t][axis], None] for s, t in edges])

    Xn = [layt[k][0] for k in range(N)]
    Yn = [layt[k][1] for k in range(N)]
    Zn = [layt[k][2] for k in range(N)]
    Xe, Ye, Ze = axis_segments(0), axis_segments(1), axis_segments(2)

    trace1 = go.Scatter3d(x=Xe,
                          y=Ye,
                          z=Ze,
                          mode='lines',
                          line=dict(color='rgb(125,125,125)', width=1),
                          hoverinfo='none')
    trace2 = go.Scatter3d(x=Xn,
                          y=Yn,
                          z=Zn,
                          mode='markers',
                          name='actors',
                          marker=dict(symbol='circle',
                                      size=6,
                                      color=groups,
                                      colorscale='Viridis',
                                      line=dict(color='rgb(50,50,50)', width=0.5)),
                          text=labels,
                          hoverinfo='text')
    # Hide every visual aid on all three axes.
    axis = dict(showbackground=False,
                showline=False,
                zeroline=False,
                showgrid=False,
                showticklabels=False,
                title='')
    layout = go.Layout(
        title=title,
        width=1000,
        height=1000,
        showlegend=False,
        scene=dict(xaxis=dict(axis), yaxis=dict(axis), zaxis=dict(axis)),
        margin=dict(t=100),
        hovermode='closest')
    fig = go.Figure(data=[trace1, trace2], layout=layout)
    fig.write_html(output_filename)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment