Representing graphs (data structure) in Python

Even though this is a somewhat old question, I thought I’d give a practical answer for anyone stumbling across this.

Let’s say you get your input data for your connections as a list of tuples like so:

[('A', 'B'), ('B', 'C'), ('B', 'D'), ('C', 'D'), ('E', 'F'), ('F', 'C')]

The data structure I’ve found to be most useful and efficient for graphs in Python is a dict of sets: each key is a node, and its value is the set of nodes it connects to. This will be the underlying structure for our Graph class. You also have to know whether these connections are arcs (directed, connect one way) or edges (undirected, connect both ways). We’ll handle that by adding a directed parameter to the Graph.__init__ method. We’ll also add some other helpful methods.

import pprint
from collections import defaultdict


class Graph(object):
    """ Graph data structure, undirected by default.

    The graph is stored as a dict of sets: each key is a node and its value
    is the set of nodes reachable from it through a single connection.
    """

    def __init__(self, connections, directed=False):
        """ Build a graph from an iterable of (node1, node2) pairs.

        If ``directed`` is true each pair is a one-way arc from node1 to
        node2; otherwise each pair connects both ways.
        """

        self._graph = defaultdict(set)
        self._directed = directed
        self.add_connections(connections)

    def add_connections(self, connections):
        """ Add connections (list of tuple pairs) to graph """

        for node1, node2 in connections:
            self.add(node1, node2)

    def add(self, node1, node2):
        """ Add connection between node1 and node2

        In a directed graph only the node1 -> node2 arc is stored, so node2
        does not get its own entry unless it also has outgoing connections.
        """

        self._graph[node1].add(node2)
        if not self._directed:
            self._graph[node2].add(node1)

    def remove(self, node):
        """ Remove all references to node

        Does nothing if the node is not in the graph.
        """

        # discard() is a no-op when the node is absent from a neighbour set.
        for neighbours in self._graph.values():
            neighbours.discard(node)
        self._graph.pop(node, None)

    def is_connected(self, node1, node2):
        """ Is node1 directly connected to node2 """

        # Check membership first so the defaultdict does not create an
        # empty entry for an unknown node1.
        return node1 in self._graph and node2 in self._graph[node1]

    def find_path(self, node1, node2, path=None):
        """ Find any path between node1 and node2 (may not be shortest)

        ``path`` is an optional list of nodes already walked; it is
        prepended to the result and its nodes are never revisited. The
        caller's list is not modified.

        Returns the path as a list of nodes ending in node2, or None when
        node2 cannot be reached from node1.
        """

        trail = ([] if path is None else list(path)) + [node1]
        if node1 == node2:
            return trail
        if node1 not in self._graph:
            return None

        # Iterative depth-first search: an explicit stack of neighbour
        # iterators replaces recursion so long paths cannot exhaust the
        # interpreter's recursion limit, and the visited set guarantees
        # each node is expanded at most once.
        visited = set(trail)
        stack = [iter(self._graph[node1])]
        while stack:
            for neighbour in stack[-1]:
                if neighbour in visited:
                    continue
                trail.append(neighbour)
                if neighbour == node2:
                    return trail
                visited.add(neighbour)
                # .get() avoids creating empty entries in the defaultdict
                # for nodes that have no outgoing connections.
                stack.append(iter(self._graph.get(neighbour, ())))
                break
            else:
                # Every neighbour of the current node is exhausted: back up.
                stack.pop()
                trail.pop()
        return None

    def __str__(self):
        return '{}({})'.format(self.__class__.__name__, dict(self._graph))

I’ll leave it as an “exercise for the reader” to create a find_shortest_path and other methods.

Let’s see this in action though…

>>> connections = [('A', 'B'), ('B', 'C'), ('B', 'D'),
                   ('C', 'D'), ('E', 'F'), ('F', 'C')]
>>> g = Graph(connections, directed=True)
>>> pretty_print = pprint.PrettyPrinter()
>>> pretty_print.pprint(g._graph)
{'A': {'B'},
 'B': {'D', 'C'},
 'C': {'D'},
 'E': {'F'},
 'F': {'C'}}

>>> g = Graph(connections)  # undirected
>>> pretty_print = pprint.PrettyPrinter()
>>> pretty_print.pprint(g._graph)
{'A': {'B'},
 'B': {'D', 'A', 'C'},
 'C': {'D', 'F', 'B'},
 'D': {'C', 'B'},
 'E': {'F'},
 'F': {'E', 'C'}}

>>> g.add('E', 'D')
>>> pretty_print.pprint(g._graph)
{'A': {'B'},
 'B': {'D', 'A', 'C'},
 'C': {'D', 'F', 'B'},
 'D': {'C', 'E', 'B'},
 'E': {'D', 'F'},
 'F': {'E', 'C'}}

>>> g.remove('A')
>>> pretty_print.pprint(g._graph)
{'B': {'D', 'C'},
 'C': {'D', 'F', 'B'},
 'D': {'C', 'E', 'B'},
 'E': {'D', 'F'},
 'F': {'E', 'C'}}

>>> g.add('G', 'B')
>>> pretty_print.pprint(g._graph)
{'B': {'D', 'G', 'C'},
 'C': {'D', 'F', 'B'},
 'D': {'C', 'E', 'B'},
 'E': {'D', 'F'},
 'F': {'E', 'C'},
 'G': {'B'}}

>>> g.find_path('G', 'E')
['G', 'B', 'D', 'C', 'F', 'E']

Leave a Comment