Visualizing Prim's algorithm with networkx and matplotlib
Among the programs we write, some (but never enough) perform a precise mathematical function such as sorting or finding the maximum of a sequence of numbers, determining primality, or finding the square root. We call such programs algorithms.
- Alan Perlis, Foreword to the Structure and Interpretation of Computer Programs
Prim's algorithm finds the minimum spanning tree (MST) of a connected, weighted, undirected graph: the set of edges that connects every node in the graph while minimizing the total edge weight.
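Stated formally (the standard textbook definition): for a connected, undirected graph $G = (V, E)$ with edge weights $w(e)$, let $\mathcal{T}(G)$ be the set of spanning trees of $G$. Then

$$
\mathrm{MST}(G) = \underset{T \in \mathcal{T}(G)}{\arg\min} \sum_{e \in T} w(e)
$$

Every spanning tree on $|V|$ nodes has exactly $|V| - 1$ edges, so for a 25-node graph the MST contains 24 edges.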
Computing a graph's MST is, on its surface, a pretty difficult problem to solve. If you were handed a graph on paper with hundreds of nodes and edges, finding the MST without knowing an algorithm seems like it could easily take months (even knowing an algorithm, doing it by hand would be a daunting task).
But Prim's algorithm is a great example of how the right approach and data structures can make a problem much easier to understand and solve. It combines a number of interesting challenges and algorithmic techniques: sorting, searching, greediness, and efficiently processing items by priority. As a bonus, it's a delight to watch in action, as the algorithm starts in the middle of a jumbled mess of edges and nodes and slowly conquers the graph.
We'll gloss over the theory of why Prim's algorithm works, but I'll link some references at the end. We will, however, write it from scratch[1] and watch it in action with matplotlib.
[1] Depending on your definition of "from scratch." We'll use libraries for the graph and the priority queue, which are integral parts of the algorithm. Every time I use this phrase, I think of Carl Sagan saying "If you wish to make an apple pie from scratch, you must first invent the universe."
```python
from queue import PriorityQueue
from random import randint, uniform

import networkx as nx
from matplotlib import animation, rc
import matplotlib.pyplot as plt
```
First, some magic to embed the matplotlib animation in a notebook (thanks to this post for explaining).
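The original setup cell isn't shown here, so the snippet below is only a minimal sketch of what it might look like. It assumes a Jupyter notebook and that the animation should render as an interactive JavaScript player. The `"jshtml"` value is an assumption; `"html5"` is the other built-in option and needs ffmpeg installed.

```python
import matplotlib
from matplotlib import rc

# Assumption: when a FuncAnimation is the last value in a notebook cell,
# render it as an interactive JavaScript player instead of a static image.
rc("animation", html="jshtml")
```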
To keep the visualization readable, we'll create a graph with $25$ nodes and roughly $150$ edges. Each node is represented by an integer in $[0, 25)$, and each edge is given a random weight in $[0, 1]$.
```python
NUM_NODES = 25

def random_node():
    return randint(0, NUM_NODES-1)

def random_weight():
    return uniform(0, 1)
```
We start by creating a graph and adding edges between consecutive nodes, (0, 1), (1, 2), (2, 3), and so on, so that no node is left unconnected. Then we add another 125 edges (`NUM_NODES * 5`) between random pairs of nodes. A random pair can repeat an existing edge (which just overwrites its weight) or pick the same node twice (a self-loop), so the graph ends up with slightly fewer than 150 edges. Each edge is given a random weight between $0$ and $1$.
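```python
graph = nx.Graph()

# Connect consecutive nodes so every node is reachable.
for i in range(1, NUM_NODES):
    graph.add_edge(i-1, i, weight=random_weight())

# Add extra edges between random pairs of nodes.
for _ in range(NUM_NODES * 5):
    graph.add_edge(
        random_node(),
        random_node(),
        weight=random_weight()
    )

# Fixed random positions used to draw every frame of the animation.
pos = nx.random_layout(graph)
```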
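Chaining consecutive nodes is what guarantees a single connected component, no matter where the random edges land. Here is a stdlib-only sketch that checks that claim. It rebuilds the same kind of edge list as a plain adjacency dict and runs a breadth-first search. The fixed seed and the helper names `build_adjacency` and `is_connected` are hypothetical, and weights are left out because connectivity doesn't depend on them. In the notebook itself, `nx.is_connected(graph)` performs the same check.

```python
import random
from collections import deque

NUM_NODES = 25

def build_adjacency(seed=0):
    """Mirror the post's construction: a chain plus random extra edges."""
    rng = random.Random(seed)  # hypothetical fixed seed for reproducibility
    adjacency = {n: set() for n in range(NUM_NODES)}

    def add_edge(u, v):
        adjacency[u].add(v)
        adjacency[v].add(u)

    # Chain edges (0, 1), (1, 2), ... make every node reachable from 0.
    for i in range(1, NUM_NODES):
        add_edge(i - 1, i)
    # Extra random edges; repeats and self-loops can happen, as in the post.
    for _ in range(NUM_NODES * 5):
        add_edge(rng.randint(0, NUM_NODES - 1), rng.randint(0, NUM_NODES - 1))
    return adjacency

def is_connected(adjacency):
    """Breadth-first search from node 0; connected iff every node is reached."""
    seen = {0}
    frontier = deque([0])
    while frontier:
        node = frontier.popleft()
        for neighbor in adjacency[node]:
            if neighbor not in seen:
                seen.add(neighbor)
                frontier.append(neighbor)
    return len(seen) == len(adjacency)

print(is_connected(build_adjacency()))  # True for any seed, thanks to the chain
```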
For the last bit of set-up, we need three sets to store:
- All edges in the graph (`all_edges`)
- The edges of the graph in the MST (`edges_in_mst`)
- The nodes of the graph in the MST (`nodes_on_mst`)
We initialize the two MST sets to be empty, and Prim's algorithm adds new edges and nodes until `nodes_on_mst` contains every node in the graph.
Edges are represented as tuples that hold the two nodes connected by the edge. Because the edges are undirected, an edge between nodes $1$ and $5$ could be `(1, 5)` or `(5, 1)`. To simplify comparing edges between data structures, we'll always store them in sorted order (in this case, `(1, 5)`).
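Here is a tiny sketch of that normalization. `normalize_edge` is a hypothetical helper name; the post just inlines `tuple(sorted(...))`.

```python
def normalize_edge(n1, n2):
    """Return an undirected edge as a tuple with the smaller node first."""
    return tuple(sorted((n1, n2)))

# Both orientations of the same undirected edge collapse to one key,
# so set membership and set difference work regardless of direction.
edges = {normalize_edge(1, 5), normalize_edge(5, 1)}
print(edges)  # {(1, 5)}
```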
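```python
%%capture
# The %%capture cell magic (IPython-only) keeps the notebook from
# displaying the still-empty figure; the animation draws into it later.
all_edges = set(
    tuple(sorted((n1, n2))) for n1, n2 in graph.edges()
)
edges_in_mst = set()
nodes_on_mst = set()

fig, ax = plt.subplots(figsize=(6, 4))
```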
Finally, we're ready to implement Prim's algorithm. The algorithm works on the following principle: if you have a set of nodes and edges that you know are in the MST, then the minimum-weight edge connecting a node in the MST to a node not yet in the MST is guaranteed to be in the MST. (Strictly, it is in some MST; the MST is unique when all edge weights are distinct.)
That's a lot of words, so let's look at a quick example. Take a graph with four nodes connected by edges with the following weights.
| Node 1 | Node 2 | Weight |
|--------|--------|--------|
Let's say we start at Node 1 (it doesn't matter which node we start with). The edge with minimum weight connected to Node 1 is $(1, 2)$ so that must be in the MST.
Now, we want to know the edge with minimum weight that takes us from a node in the MST ($1$ or $2$) to a node that is not in the MST ($3$ or $4$). In our example, it's easy to see that $(1, 3)$ has the next smallest weight and, after that, $(1, 4)$, which brings the last node into the tree. The final MST is $(1, 2)$, $(1, 3)$, and $(1, 4)$.
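The table's weights did not survive, so the sketch below uses hypothetical weights chosen only to reproduce the walk-through. It applies the principle literally: at each step, it scans every edge and takes the cheapest one with exactly one endpoint in the tree. `grow_tree` is an illustrative name and is not part of the post's code.

```python
# Hypothetical weights for a complete graph on four nodes, chosen so the
# walk-through holds: (1, 2) < (1, 3) < (1, 4) < every other edge.
WEIGHTS = {
    (1, 2): 1.0, (1, 3): 2.0, (1, 4): 3.0,
    (2, 3): 4.0, (2, 4): 5.0, (3, 4): 6.0,
}

def grow_tree(weights, start):
    """Repeatedly add the cheapest edge crossing from the tree to the rest."""
    nodes = {n for edge in weights for n in edge}
    in_tree = {start}
    tree_edges = []
    while in_tree != nodes:
        # Edges with exactly one endpoint already in the tree.
        crossing = [e for e in weights if (e[0] in in_tree) != (e[1] in in_tree)]
        best = min(crossing, key=weights.get)
        tree_edges.append(best)
        in_tree.update(best)
    return tree_edges

print(grow_tree(WEIGHTS, start=1))  # [(1, 2), (1, 3), (1, 4)]
```

Each step rescans the whole edge list, which is fine for four nodes but slow on large graphs.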
Our example is simple, but in large graphs with many nodes and edges, the challenge is to efficiently find the edge with the lowest weight. For this, Prim's algorithm uses a minimum priority queue, which always returns the element with the smallest weight next.
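Python's `PriorityQueue` (from the standard `queue` module) is a minimum priority queue: it returns the lowest-valued entry first, and the usual pattern is to store entries as tuples of the form `(priority_value, element)`. In our case, `priority_value` is the edge's weight and `element` is the tuple representing the edge. For example, the edge $(1, 2)$ with a weight of $0.5$ would be added to the priority queue with:
```python
pqueue.put((0.5, (1, 2)))
```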
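Here is a quick sketch of the ordering `PriorityQueue` gives us (the weights are made up). Tuples compare element by element, so the weight decides the order, and the edge tuple only breaks ties between equal weights.

```python
from queue import PriorityQueue

pqueue = PriorityQueue()
# (priority_value, element): the weight first, then the edge.
pqueue.put((0.5, (1, 2)))
pqueue.put((0.2, (3, 4)))
pqueue.put((0.5, (0, 7)))  # same weight as (1, 2); the edge tuple breaks the tie

order = []
while not pqueue.empty():
    weight, edge = pqueue.get()  # always the smallest remaining tuple
    order.append((weight, edge))

print(order)  # [(0.2, (3, 4)), (0.5, (0, 7)), (0.5, (1, 2))]
```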
```python
def prims():
    pqueue = PriorityQueue()

    # Reset the shared sets so the generator starts fresh if
    # FuncAnimation calls it more than once.
    edges_in_mst.clear()
    nodes_on_mst.clear()

    # Start at any random node, mark it as part of the MST, and add
    # all edges connected to it to the priority queue.
    start_node = random_node()
    nodes_on_mst.add(start_node)
    for neighbor in graph.neighbors(start_node):
        edge_data = graph.get_edge_data(start_node, neighbor)
        edge_weight = edge_data["weight"]
        pqueue.put((edge_weight, (start_node, neighbor)))

    # Loop until all nodes are in the MST.
    while len(nodes_on_mst) < NUM_NODES:
        # Get the edge with the smallest weight from the priority queue.
        _, edge = pqueue.get()

        # Find the endpoint that is not yet in the MST.
        if edge[0] not in nodes_on_mst:
            new_node = edge[0]
        elif edge[1] not in nodes_on_mst:
            new_node = edge[1]
        else:
            # If this edge connects two nodes that are already in the
            # MST, skip it and continue to the next edge in the
            # priority queue.
            continue

        # Every time a new node joins the MST, add all edges that
        # touch it to the priority queue.
        for neighbor in graph.neighbors(new_node):
            edge_data = graph.get_edge_data(new_node, neighbor)
            edge_weight = edge_data["weight"]
            pqueue.put((edge_weight, (new_node, neighbor)))

        # Add this edge to the MST.
        edges_in_mst.add(tuple(sorted(edge)))
        nodes_on_mst.add(new_node)

        # Yield edges in the MST to plot.
        yield edges_in_mst
```
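The generator above depends on the notebook's global `graph`, which makes it awkward to test directly. As a sanity check on the logic, here is a stdlib-only sketch that runs the same lazy strategy on a plain weighted edge list: keep a heap of candidate edges and skip any edge whose endpoints are both already in the tree. It uses `heapq`, the module `PriorityQueue` is built on. It then compares the total weight with Kruskal's algorithm, a different MST algorithm swapped in purely as a cross-check. The helper names and seeds are hypothetical.

```python
import heapq
import random

def random_weighted_graph(num_nodes, seed):
    """Chain plus random extra edges, like the post's graph."""
    rng = random.Random(seed)
    weights = {}
    for i in range(1, num_nodes):
        weights[(i - 1, i)] = rng.uniform(0, 1)
    for _ in range(num_nodes * 5):
        u, v = sorted((rng.randrange(num_nodes), rng.randrange(num_nodes)))
        if u != v:  # skip self-loops; they can never be MST edges
            weights[(u, v)] = rng.uniform(0, 1)
    return weights

def lazy_prim(num_nodes, weights, start=0):
    """Heap-based Prim's: entries are (weight, node_in_tree, other_node)."""
    adjacency = {n: [] for n in range(num_nodes)}
    for (u, v), w in weights.items():
        adjacency[u].append((w, u, v))
        adjacency[v].append((w, v, u))
    in_tree = {start}
    heap = list(adjacency[start])
    heapq.heapify(heap)
    mst = set()
    while len(in_tree) < num_nodes:
        w, u, v = heapq.heappop(heap)
        if v in in_tree:
            continue  # both endpoints already in the tree
        in_tree.add(v)
        mst.add(tuple(sorted((u, v))))
        for candidate in adjacency[v]:
            if candidate[2] not in in_tree:
                heapq.heappush(heap, candidate)
    return mst

def kruskal(num_nodes, weights):
    """Kruskal's algorithm with a union-find forest."""
    parent = list(range(num_nodes))

    def find(n):
        while parent[n] != n:
            parent[n] = parent[parent[n]]  # path halving
            n = parent[n]
        return n

    mst = set()
    for (u, v), w in sorted(weights.items(), key=lambda item: item[1]):
        root_u, root_v = find(u), find(v)
        if root_u != root_v:
            parent[root_u] = root_v
            mst.add((u, v))
    return mst

weights = random_weighted_graph(25, seed=1)
prim_total = sum(weights[e] for e in lazy_prim(25, weights))
kruskal_total = sum(weights[e] for e in kruskal(25, weights))
print(abs(prim_total - kruskal_total) < 1e-9)  # True
```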
The last step is to provide the functions that draw the graph and MST in matplotlib. We'll use networkx's `draw_networkx_nodes` and `draw_networkx_edges` functions to draw three elements:
- All nodes in the graph
- The edges in the graph not in the MST, drawn in light green. This is computed by taking the difference between the set of all edges in the graph and the edges in the MST.
- The edges in the graph in the MST, drawn in deep blue.
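The "not in the MST" edge set is plain set arithmetic on the normalized edge tuples. Here is a toy example with made-up edges:

```python
all_edges = {(0, 1), (0, 2), (1, 2), (2, 3)}
mst_edges = {(0, 1), (1, 2), (2, 3)}

# Edges still drawn in light green: everything not yet in the MST.
non_mst_edges = all_edges - mst_edges
print(non_mst_edges)  # {(0, 2)}
```

This works only because both sets store edges in sorted order. An edge stored as `(2, 0)` in one set would not cancel `(0, 2)` in the other.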
```python
def update(mst_edges):
    ax.clear()
    # All nodes in the graph.
    nx.draw_networkx_nodes(graph, pos, node_size=25, ax=ax)
    # Edges not (yet) in the MST, drawn faintly in green.
    nx.draw_networkx_edges(
        graph, pos, edgelist=list(all_edges - mst_edges),
        alpha=0.1, edge_color='g', width=1, ax=ax
    )
    # Edges in the MST, drawn in solid blue.
    nx.draw_networkx_edges(
        graph, pos, edgelist=list(mst_edges),
        alpha=1.0, edge_color='b', width=1, ax=ax
    )

def do_nothing():
    # init_func is optional, but without it FuncAnimation initializes
    # the figure by drawing the first frame of a fresh prims() generator.
    # A no-op keeps all drawing inside update().
    pass
```
That's it, let's see our animation!
```python
ani = animation.FuncAnimation(
    fig,
    update,
    init_func=do_nothing,
    frames=prims,
    interval=500,
)
ani
```
I learned Prim's algorithm from the awesome Algorithms, Part II course on Coursera. I enjoyed everything about this course: the content is presented clearly, the exercises are challenging and rewarding, and the suite of libraries developed for the course is extremely impressive. Proofs of the correctness and complexity of Prim's algorithm are in the course's textbook, Algorithms, 4th Edition. The course website also contains two different implementations of Prim's algorithm in Java.