Created
May 17, 2011 08:17
-
-
Save spranesh/976142 to your computer and use it in GitHub Desktop.
Directed Graph class
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" Implementation of a simple directed graph with no weights. | |
Test by either running python on this file, | |
or by calling nosetests on this file. | |
""" | |
import unittest | |
import collections | |
class DirectedGraph:
    """A simple unweighted directed graph stored as an adjacency list.

    ``self.graph`` maps each node to the list of its successors. Both
    endpoints of every edge are registered as nodes, so sink nodes (nodes
    without outgoing edges) are present with an empty successor list.
    Parallel edges are rejected (no multigraphs).
    """

    def __init__(self):
        # node -> list of successor nodes, in the order the edges were added.
        self.graph = collections.defaultdict(list)

    def __repr__(self):
        return repr(self.graph)

    def __str__(self):
        return str(self.graph)

    def AddEdge(self, a, b):
        """Add the directed edge a -> b; both a and b become nodes.

        Raises:
            AssertionError: if the edge a -> b already exists.
        """
        assert b not in self.graph[a]  # No multi graphs allowed
        self.graph[a].append(b)
        # Register b as a node as well, so DFS(b) / GetReachabilityGraph see
        # sink nodes without relying on some earlier lookup inserting them.
        self.graph.setdefault(b, [])

    def HasEdge(self, a, b):
        """Return True if the edge a -> b exists.

        Uses ``dict.get`` so querying an unknown node does not insert it into
        the underlying defaultdict.
        """
        return b in self.graph.get(a, ())

    def RemoveEdge(self, a, b):
        """Remove the directed edge a -> b. O(neighbours(a)) operation.

        The nodes a and b stay in the graph even if they lose all edges.

        Raises:
            AssertionError: if the edge a -> b does not exist.
        """
        assert self.HasEdge(a, b)
        self.graph[a].remove(b)

    def DFS(self, start):
        """Yield every node reachable from ``start``, each exactly once.

        ``start`` itself is yielded first. The traversal uses an explicit
        stack and marks nodes visited when they are pushed, so each node is
        pushed at most once; the yield order may therefore differ from a
        textbook recursive depth-first preorder.

        Raises:
            AssertionError: on the first iteration, if ``start`` is not a node.
        """
        assert start in self.graph
        visited = {start}
        stack = [start]
        while stack:
            current_node = stack.pop()
            # .get: traversal must never insert keys into self.graph, otherwise
            # GetReachabilityGraph would mutate the dict it is iterating over.
            neighbours = self.graph.get(current_node, ())
            stack.extend([n for n in neighbours if n not in visited])
            visited.update(neighbours)
            yield current_node

    def GetReachabilityGraph(self):
        """Return a dict mapping each node to the nodes reachable from it.

        The node itself is excluded from its own list, even when it lies on a
        cycle. One traversal is run per node: O(n * (n + m)) for n nodes and
        m edges.
        """
        reachability = {}
        for node in self.graph:
            # DFS yields the start node first; drop it.
            reachability[node] = list(self.DFS(node))[1:]
        return reachability
class TestDirectedGraph(unittest.TestCase):
    """Unit tests for DirectedGraph."""

    def setUp(self):
        """Build the graph g shared by every test.

        Runs before each test, so tests are independent of one another.
        Edges: 1 -> 2, 1 -> 3, 2 -> 1, 4 -> 5.
        """
        self.g = DirectedGraph()
        self.g.AddEdge(1, 2)
        self.g.AddEdge(1, 3)
        self.g.AddEdge(2, 1)
        self.g.AddEdge(4, 5)

    def testHasEdge(self):
        self.assertTrue(self.g.HasEdge(1, 2))
        self.assertTrue(self.g.HasEdge(1, 3))
        self.assertTrue(self.g.HasEdge(2, 1))
        self.assertTrue(self.g.HasEdge(4, 5))
        self.assertFalse(self.g.HasEdge(5, 6))

    def testAddEdge(self):
        self.assertFalse(self.g.HasEdge(5, 6))
        self.g.AddEdge(5, 6)
        self.assertTrue(self.g.HasEdge(5, 6))

    def testRemoveEdge(self):
        self.assertTrue(self.g.HasEdge(1, 2))
        self.g.RemoveEdge(1, 2)
        self.assertFalse(self.g.HasEdge(1, 2))

    def testDFS(self):
        # The traversal order is unspecified, so compare sorted results; list
        # equality also catches a node being yielded more than once.
        self.assertEqual(sorted(self.g.DFS(1)), [1, 2, 3])
        self.assertEqual(sorted(self.g.DFS(2)), [1, 2, 3])
        self.assertEqual(list(self.g.DFS(3)), [3])
        self.assertEqual(sorted(self.g.DFS(4)), [4, 5])
        self.assertEqual(list(self.g.DFS(5)), [5])

    def testGetReachabilityGraph(self):
        # Try on a simpler graph: f <-> g, h -> f.
        g = DirectedGraph()
        g.AddEdge('f', 'g')
        g.AddEdge('g', 'f')
        g.AddEdge('h', 'f')
        reachability_graph = g.GetReachabilityGraph()
        self.assertEqual(reachability_graph['f'], ['g'])
        self.assertEqual(reachability_graph['g'], ['f'])
        self.assertEqual(sorted(reachability_graph['h']), ['f', 'g'])
if __name__ == "__main__":
    # Executing this file directly runs the test suite defined above.
    unittest.main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment