Prim's Algorithm in Python

Key points
- Prim’s algorithm finds a minimum spanning tree of a connected, weighted, undirected graph; the implementation below computes the total cost of that tree.
- It starts from an arbitrary vertex and adds the least expensive edge from this vertex to the spanning tree.
- It then repeatedly adds the least expensive edge that connects a vertex already in the spanning tree to a vertex not yet in it, and terminates once every vertex has been added.
- The algorithm is greedy: at every step it takes the cheapest edge leaving the current tree.
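The greedy step in these key points can be sketched directly over an edge list. This is a minimal illustration, not the article's implementation: the helper names `cheapest_crossing_edge` and `greedy_mst_cost` are hypothetical, and the sketch rescans every edge on each round, so it favours clarity over speed.

```python
from typing import Iterable, List, Optional, Set, Tuple

Edge = Tuple[int, int, int]  # (u, v, weight) for an undirected edge

def cheapest_crossing_edge(in_tree: Set[int], edges: Iterable[Edge]) -> Optional[Edge]:
    """Return the lightest edge with exactly one endpoint in the tree, or None."""
    best: Optional[Edge] = None
    for u, v, w in edges:
        # The edge leaves the tree when exactly one of its endpoints is inside it.
        if (u in in_tree) != (v in in_tree):
            if best is None or w < best[2]:
                best = (u, v, w)
    return best

def greedy_mst_cost(num_vertices: int, edges: List[Edge], start: int = 0) -> int:
    """Grow a tree from start by repeatedly taking the cheapest crossing edge."""
    in_tree = {start}
    total = 0
    while len(in_tree) < num_vertices:
        edge = cheapest_crossing_edge(in_tree, edges)
        if edge is None:  # no edge leaves the tree: the graph is disconnected
            break
        u, v, w = edge
        in_tree.add(v if u in in_tree else u)  # add the endpoint outside the tree
        total += w
    return total

# Graph 2 from the example program, written as a list of undirected edges.
g2_edges = [(0, 1, 4), (0, 2, 1), (0, 3, 5), (1, 3, 2), (1, 4, 3),
            (1, 5, 3), (2, 3, 2), (2, 4, 8), (3, 4, 1), (4, 5, 3)]
print(greedy_mst_cost(6, g2_edges))  # 9
```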

Algorithm : Prim's minimum spanning tree ( Graph G, Source_Node S )

1.  Create a dictionary PQ (used as a priority queue) to hold ( node, cost ) pairs.
2.  Push [ S, 0 ] ( node, cost ) into PQ, i.e. the cost of reaching vertex S from source node S is zero.
3.  While PQ contains ( V, C ) pairs :
4.       Get the node V ( key ) with the smallest edge cost ( value ) from PQ.
5.       Cost C = PQ [ V ]
6.       Delete the key-value pair ( V, C ) from PQ.
7.       If node V is not yet added to the spanning tree :
8.           Add node V to the spanning tree.
9.           Cost of the spanning tree += C
10.          For each vertex adjacent to V that is not yet added to the spanning tree :
11.              Push the pair ( adjacent node, cost of the edge from V to it ) into PQ.
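The pseudocode scans the dictionary for its minimum on every iteration. A common alternative, swapped in here, is Python's standard `heapq` binary heap. The sketch below follows the numbered steps under that assumption; the function name `prims_mst` and the `(cost, node, parent)` entry layout are hypothetical, and, beyond what the pseudocode does, it also records which edge brought each vertex into the tree.

```python
import heapq
from typing import Dict, List, Tuple

Adj = Dict[int, List[Tuple[int, int]]]  # node -> [(adjacent_node, edge cost), ...]

def prims_mst(adj: Adj, source: int) -> Tuple[int, List[Tuple[int, int, int]]]:
    """Return (total cost, tree edges as (parent, node, cost)) for source's component."""
    # Heap entries are (cost, node, parent); heapq pops the smallest cost first.
    heap: List[Tuple[int, int, int]] = [(0, source, -1)]  # -1 marks "no parent"
    in_tree = set()
    total = 0
    tree_edges: List[Tuple[int, int, int]] = []
    while heap:
        cost, v, parent = heapq.heappop(heap)  # steps 4-6: take and remove the cheapest entry
        if v in in_tree:                       # step 7: skip stale duplicate entries
            continue
        in_tree.add(v)                         # step 8
        total += cost                          # step 9
        if parent != -1:
            tree_edges.append((parent, v, cost))
        for u, w in adj[v]:                    # steps 10-11
            if u not in in_tree:
                heapq.heappush(heap, (w, u, v))
    return total, tree_edges

# Graph 2 from the example program.
g2 = {
    0: [(1, 4), (2, 1), (3, 5)],
    1: [(0, 4), (3, 2), (4, 3), (5, 3)],
    2: [(0, 1), (3, 2), (4, 8)],
    3: [(0, 5), (1, 2), (2, 2), (4, 1)],
    4: [(1, 3), (2, 8), (3, 1), (5, 3)],
    5: [(1, 3), (4, 3)],
}
cost, edges = prims_mst(g2, 0)
print(cost)  # 9
print(edges)
```

Because heap entries are plain tuples, several entries for the same vertex can coexist without a wrapper class; the `if v in in_tree` check discards the stale ones, and each pop costs time logarithmic in the heap size rather than a full scan.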


[Figure: Example of finding the minimum spanning tree using Prim’s algorithm]



Python3 : Prim’s minimum spanning tree algorithm.

from typing import List, Dict, Tuple # For type annotations

class Node :

    def __init__(self, arg_id) :
        self._id = arg_id

class Graph :

    # adj_list maps each node to a list of (adjacent_node, edge_cost) pairs.
    def __init__(self, source : int, adj_list : Dict[int, List[Tuple[int, int]]]) :
        self.source = source
        self.adjlist = adj_list

    def PrimsMST (self) -> int :

        # The priority queue is a dictionary whose keys are 'Node' objects and
        # whose values are the cost of the edge that reaches that node.
        # The queue may hold several entries for the same node with different
        # costs. A dictionary cannot hold duplicate keys, so each entry gets its
        # own Node object: Node defines no __eq__ or __hash__, so two Node
        # objects with the same _id are still distinct keys.

        # The distance of source node from itself is 0. Add source node as the first node
        # in the priority queue
        priority_queue = { Node(self.source) : 0 }
        # Assumes the nodes are numbered 0 .. n-1.
        added = [False] * len(self.adjlist)
        min_span_tree_cost = 0

        while priority_queue :
            # Choose the node with the least edge cost; min() scans the
            # whole dictionary, so each extraction is linear in its size.
            node = min (priority_queue, key=priority_queue.get)
            cost = priority_queue[node]

            # Remove the node from the priority queue
            del priority_queue[node]

            if not added[node._id] :
                min_span_tree_cost += cost
                added[node._id] = True
                print("Added Node : " + str(node._id) + ", cost now : "+str(min_span_tree_cost))

                # Queue every edge leading to a node that is not yet in the tree.
                for adjnode, adjcost in self.adjlist[node._id] :
                    if not added[adjnode] :
                        priority_queue[Node(adjnode)] = adjcost

        return min_span_tree_cost

def main() :

    g1_edges_from_node = {}

    # Outgoing edges from the node: (adjacent_node, cost) in graph 1.
    g1_edges_from_node[0] = [ (1,1), (2,2), (3,1), (4,1), (5,2), (6,1) ]
    g1_edges_from_node[1] = [ (0,1), (2,2), (6,2) ]
    g1_edges_from_node[2] = [ (0,2), (1,2), (3,1) ]
    g1_edges_from_node[3] = [ (0,1), (2,1), (4,2) ]
    g1_edges_from_node[4] = [ (0,1), (3,2), (5,2) ]
    g1_edges_from_node[5] = [ (0,2), (4,2), (6,1) ]
    g1_edges_from_node[6] = [ (0,1), (2,2), (5,1) ]

    g1 = Graph(0, g1_edges_from_node)
    cost = g1.PrimsMST()
    print("Cost of the minimum spanning tree in graph 1 : " + str(cost) +"\n")

    # Outgoing edges from the node: (adjacent_node, cost) in graph 2.
    g2_edges_from_node = {}
    g2_edges_from_node[0] = [ (1,4), (2,1), (3,5) ]
    g2_edges_from_node[1] = [ (0,4), (3,2), (4,3), (5,3) ]
    g2_edges_from_node[2] = [ (0,1), (3,2), (4,8) ]
    g2_edges_from_node[3] = [ (0,5), (1,2), (2,2), (4,1) ]
    g2_edges_from_node[4] = [ (1,3), (2,8), (3,1), (5,3) ]
    g2_edges_from_node[5] = [ (1,3), (4,3) ]

    g2 = Graph(0, g2_edges_from_node)
    cost = g2.PrimsMST()
    print("Cost of the minimum spanning tree in graph 2 : " + str(cost))

if __name__ == "__main__" :
    main()
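The comment in PrimsMST explains that Node objects are used as dictionary keys so that several entries for the same vertex can coexist. The standalone snippet below demonstrates why, assuming a Node class shaped like the article's (no `__eq__` or `__hash__`, so Python hashes instances by identity); the variable names are illustrative only.

```python
class Node:
    # Same shape as the article's Node: no __eq__/__hash__, so each instance
    # hashes by identity and is a distinct dictionary key.
    def __init__(self, arg_id):
        self._id = arg_id

# Plain int keys: the later, more expensive entry for vertex 3 overwrites
# the cheaper one, so the cheap edge would be lost.
int_keyed = {}
int_keyed[3] = 2
int_keyed[3] = 5
print(int_keyed)  # {3: 5}

# Node keys: both candidate edges into vertex 3 survive,
# and min() still finds the cheaper one.
node_keyed = {}
node_keyed[Node(3)] = 2
node_keyed[Node(3)] = 5
best = min(node_keyed, key=node_keyed.get)
print(len(node_keyed), best._id, node_keyed[best])  # 2 3 2
```

Another design would keep int keys and overwrite an entry only when the new cost is smaller, which keeps at most one entry per vertex.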

Output

Added Node : 0, cost now : 0
Added Node : 1, cost now : 1
Added Node : 3, cost now : 2
Added Node : 4, cost now : 3
Added Node : 6, cost now : 4
Added Node : 2, cost now : 5
Added Node : 5, cost now : 6
Cost of the minimum spanning tree in graph 1 : 6

Added Node : 0, cost now : 0
Added Node : 2, cost now : 1
Added Node : 3, cost now : 3
Added Node : 4, cost now : 4
Added Node : 1, cost now : 6
Added Node : 5, cost now : 9
Cost of the minimum spanning tree in graph 2 : 9
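Since the graphs are undirected, each edge should appear in the adjacency lists of both its endpoints. A small check, sketched below with a hypothetical helper `missing_reverse_edges`, flags entries whose reverse is absent. On graph 1 as listed in main(), node 1 lists (6, 2) and node 6 lists (2, 2), but neither reverse entry is present. The reported cost of 6 would not change if they were added, because the six weight-1 edges (0-1, 0-3, 0-4, 0-6, 2-3, 5-6) already connect all seven vertices.

```python
from typing import Dict, List, Tuple

Adj = Dict[int, List[Tuple[int, int]]]  # node -> [(adjacent_node, cost), ...]

def missing_reverse_edges(adj: Adj) -> List[Tuple[int, int, int]]:
    """Return each entry (u, v, w) whose reverse entry (v, u, w) is absent."""
    missing = []
    for u, neighbours in adj.items():
        for v, w in neighbours:
            if (u, w) not in adj.get(v, []):
                missing.append((u, v, w))
    return missing

# Graph 1 exactly as listed in main().
g1 = {
    0: [(1, 1), (2, 2), (3, 1), (4, 1), (5, 2), (6, 1)],
    1: [(0, 1), (2, 2), (6, 2)],
    2: [(0, 2), (1, 2), (3, 1)],
    3: [(0, 1), (2, 1), (4, 2)],
    4: [(0, 1), (3, 2), (5, 2)],
    5: [(0, 2), (4, 2), (6, 1)],
    6: [(0, 1), (2, 2), (5, 1)],
}
print(missing_reverse_edges(g1))  # [(1, 6, 2), (6, 2, 2)]
```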


Copyright (c) 2019-2024, Algotree.org.
All rights reserved.