Prim's algorithm implemented in Python

Prim’s Minimum Spanning Tree Algorithm

Prim’s algorithm builds a minimum spanning tree of a connected, weighted, undirected graph; the implementation below reports the total cost of that tree. The algorithm starts from an arbitrary vertex and adds the least expensive edge from that vertex to the spanning tree. It then keeps adding the least expensive edge that joins a vertex already in the tree to a vertex not yet in it, growing the tree one vertex at a time, and terminates when every vertex has been added. The algorithm is greedy: at each step it takes the cheapest edge leaving the current tree and never reconsiders an earlier choice.
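To make the greedy step concrete, here is a minimal sketch that repeatedly scans for the cheapest edge leaving the current tree. It assumes the adjacency-list format used by the program in this article (adjlist[u] holds (adjacent_node, cost) pairs); the helper names cheapest_crossing_edge and grow_tree and the small graph toy are hypothetical. It rescans every tree vertex at each step, so it is much slower than a priority-queue version and is meant only to illustrate the idea.

```python
# Hypothetical helpers illustrating the greedy step of Prim's algorithm.
# Assumes adjlist[u] is a list of (adjacent_node, cost) pairs.

def cheapest_crossing_edge(adjlist, in_tree):
    """Return (cost, u, v) for the cheapest edge with u inside the tree
    and v outside it, or None if no edge leaves the tree."""
    best = None
    for u in in_tree:
        for v, cost in adjlist[u]:
            if v not in in_tree and (best is None or cost < best[0]):
                best = (cost, u, v)
    return best


def grow_tree(adjlist, start):
    """Grow the tree from start by always adding the cheapest crossing edge.
    Returns (total_cost, vertices_in_tree)."""
    in_tree = {start}
    total = 0
    while True:
        edge = cheapest_crossing_edge(adjlist, in_tree)
        if edge is None:  # no edge leaves the tree: all reachable vertices added
            return total, in_tree
        cost, _, v = edge
        in_tree.add(v)
        total += cost


# A small hypothetical undirected graph (each edge listed at both endpoints)
toy = {
    0: [(1, 3), (2, 1)],
    1: [(0, 3), (2, 1), (3, 4)],
    2: [(0, 1), (1, 1), (3, 5)],
    3: [(1, 4), (2, 5)],
}
print(grow_tree(toy, 0)[0])  # 6
```

From vertex 0 the sketch picks edge 0-2 (cost 1), then 2-1 (cost 1), then 1-3 (cost 4), for a total of 6.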


Algorithm : Prim’s minimum spanning tree ( Graph G, Source_Node S )

1.  Create a dictionary PQ (used as a priority queue) to hold ( node, cost ) pairs.
2.  Push the pair ( S, 0 ) into PQ, i.e. the cost of reaching the source node S from itself is zero.
3.  While PQ is not empty :
4.       Get the node V ( key ) with the smallest edge cost ( value ) from PQ.
5.       Cost C = PQ [ V ]
6.       Delete the key-value pair ( V, C ) from PQ.
7.       If node V has not yet been added to the spanning tree :
8.           Add node V to the spanning tree.
9.           Cost of the spanning tree += C
10.          For each vertex adjacent to V that is not yet in the spanning tree :
11.              Push the pair ( adjacent node, cost of the edge from V to it ) into PQ.
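The numbered steps above can also be sketched with Python's heapq module, a binary min-heap, in place of the dictionary. This is a swapped-in technique, not the article's method: heapq.heappop returns the cheapest entry in O(log n) time, and a list can hold duplicate (cost, node) tuples, so no wrapper class is needed for the keys. The function name prims_mst_cost is hypothetical; the sketch assumes a connected graph whose adjlist[u] holds (adjacent_node, cost) pairs, and the example uses graph 2 from the program below.

```python
import heapq


def prims_mst_cost(adjlist, source):
    """Total cost of a minimum spanning tree, following steps 1-11 above,
    with a heap of (cost, node) tuples as the priority queue."""
    pq = [(0, source)]   # step 2: reaching the source costs 0
    added = set()        # vertices already in the spanning tree
    total = 0
    while pq:                                # step 3
        cost, v = heapq.heappop(pq)          # steps 4-6: take the cheapest entry
        if v in added:                       # step 7: stale entry, skip it
            continue
        added.add(v)                         # step 8
        total += cost                        # step 9
        for adj, adjcost in adjlist[v]:      # step 10
            if adj not in added:
                heapq.heappush(pq, (adjcost, adj))   # step 11
    return total


# Graph 2 from the program in this article
graph2 = {
    0: [(1, 4), (2, 1), (3, 5)],
    1: [(0, 4), (3, 2), (4, 3), (5, 3)],
    2: [(0, 1), (3, 2), (4, 8)],
    3: [(0, 5), (1, 2), (2, 2), (4, 1)],
    4: [(1, 3), (2, 8), (3, 1), (5, 3)],
    5: [(1, 3), (4, 3)],
}
print(prims_mst_cost(graph2, 0))  # 9
```

Entries for vertices that are already in the tree stay in the heap and are simply skipped when popped, the same lazy deletion the dictionary version relies on. Starting from a different source vertex gives the same total, since every minimum spanning tree of a graph has the same weight.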


[Figure: Example of finding the minimum spanning tree using Prim’s algorithm]


Python : Prim’s minimum spanning tree algorithm implemented in Python 3.7+ (the dataclasses module was added in Python 3.7)

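from dataclasses import dataclass

# With eq=False, @dataclass generates no __eq__ and leaves __hash__ alone, so
# Node objects are compared and hashed by identity: two Node(3) objects are
# two different dictionary keys. This lets the dictionary hold several entries
# for the same vertex with different costs. (With eq=True and frozen=True,
# Node would instead be immutable and hashed by value, so a second Node(3)
# entry would overwrite the first.)
@dataclass(eq=False)
class Node :
    idnum : int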

@dataclass
class Graph :
    source  : int
    adjlist : dict

    def PrimsMST(self):

        # The priority queue is a dictionary that maps a Node object to the
        # cost of the edge that reaches it. The queue may hold several entries
        # for the same vertex with different costs, and a dictionary cannot
        # have duplicate keys, so each entry gets its own Node object as the
        # key instead of a plain integer.

        priority_queue = { Node(self.source) : 0 }
        added = [False] * len(self.adjlist)
        min_span_tree_cost = 0

        while priority_queue :
            # Choose the node with the least edge cost (a linear scan over all
            # entries, so each extraction costs O(size of the queue))
            node = min(priority_queue, key=priority_queue.get)
            cost = priority_queue[node]

            # Remove the chosen entry from the priority queue
            del priority_queue[node]

            # Skip stale entries for vertices that are already in the tree
            if not added[node.idnum] :
                min_span_tree_cost += cost
                added[node.idnum] = True
                print("Added Node : " + str(node.idnum) + ", cost now : "+str(min_span_tree_cost))

                # Push every edge that leads to a vertex outside the tree
                for adjnode, adjcost in self.adjlist[node.idnum] :
                    if not added[adjnode] :
                        priority_queue[Node(adjnode)] = adjcost

        return min_span_tree_cost

def main() :

    g1_edges_from_node = {}

    # Outgoing edges from the node: (adjacent_node, cost) in graph 1.
    g1_edges_from_node[0] = [ (1,1), (2,2), (3,1), (4,1), (5,2), (6,1) ]
    g1_edges_from_node[1] = [ (0,1), (2,2), (6,2) ]
    g1_edges_from_node[2] = [ (0,2), (1,2), (3,1) ]
    g1_edges_from_node[3] = [ (0,1), (2,1), (4,2) ]
    g1_edges_from_node[4] = [ (0,1), (3,2), (5,2) ]
    g1_edges_from_node[5] = [ (0,2), (4,2), (6,1) ]
    g1_edges_from_node[6] = [ (0,1), (1,2), (5,1) ]

    g1 = Graph(0, g1_edges_from_node)
    cost = g1.PrimsMST()
    print("Cost of the minimum spanning tree in graph 1 : " + str(cost) +"\n")

    # Outgoing edges from the node: (adjacent_node, cost) in graph 2.
    g2_edges_from_node = {}
    g2_edges_from_node[0] = [ (1,4), (2,1), (3,5) ]
    g2_edges_from_node[1] = [ (0,4), (3,2), (4,3), (5,3) ]
    g2_edges_from_node[2] = [ (0,1), (3,2), (4,8) ]
    g2_edges_from_node[3] = [ (0,5), (1,2), (2,2), (4,1) ]
    g2_edges_from_node[4] = [ (1,3), (2,8), (3,1), (5,3) ]
    g2_edges_from_node[5] = [ (1,3), (4,3) ]

    g2 = Graph(0, g2_edges_from_node)
    cost = g2.PrimsMST()
    print("Cost of the minimum spanning tree in graph 2 : " + str(cost))

if __name__ == "__main__" :
    main()

Output of Prim’s minimum spanning tree algorithm implemented in Python

Added Node : 0, cost now : 0
Added Node : 1, cost now : 1
Added Node : 3, cost now : 2
Added Node : 4, cost now : 3
Added Node : 6, cost now : 4
Added Node : 2, cost now : 5
Added Node : 5, cost now : 6
Cost of the minimum spanning tree in graph 1 : 6

Added Node : 0, cost now : 0
Added Node : 2, cost now : 1
Added Node : 3, cost now : 3
Added Node : 4, cost now : 4
Added Node : 1, cost now : 6
Added Node : 5, cost now : 9
Cost of the minimum spanning tree in graph 2 : 9
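The comment on the Node class depends on how @dataclass generates __eq__ and __hash__. The short check below illustrates the three relevant cases; the class names IdNode, ValueNode and PlainNode are hypothetical and exist only for this demonstration.

```python
from dataclasses import dataclass


@dataclass(eq=False)
class IdNode:      # like the article's Node: compared and hashed by identity
    idnum: int


@dataclass(frozen=True)
class ValueNode:   # eq=True (default) with frozen=True: hashed by value
    idnum: int


@dataclass
class PlainNode:   # eq=True, not frozen: __hash__ is set to None
    idnum: int


by_identity = {IdNode(3): 5}
by_identity[IdNode(3)] = 2       # a second, separate key for vertex 3
print(len(by_identity))          # 2

by_value = {ValueNode(3): 5}
by_value[ValueNode(3)] = 2       # same key: the cost 5 is overwritten
print(len(by_value), by_value[ValueNode(3)])  # 1 2

try:
    {PlainNode(3): 5}            # instances of PlainNode are unhashable
except TypeError as err:
    print("TypeError:", err)
```

Only the identity-hashed variant lets the dictionary keep several pending entries for the same vertex, which is why the program passes eq=False.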

Copyright (c) 2020, Algotree.org.
All rights reserved.