Prim's Algorithm in Java

Finding the cost of a minimum spanning tree

Key points of Prim’s algorithm

  • Prim’s algorithm finds the cost of a minimum spanning tree of a connected, weighted, undirected graph.
  • Prim’s algorithm begins at an arbitrarily chosen vertex and adds the least expensive edge leaving this vertex to the spanning tree. It then repeatedly adds the least expensive edge that connects a vertex already in the spanning tree to a vertex not yet in it, and terminates when all the vertices have been added.
  • The algorithm is greedy: at every step it selects the least expensive edge leaving the vertices already added to the spanning tree.

Algorithm : Prim's minimum spanning tree ( Graph G, Source_Node S )

1.  Create a priority queue Q of NodeCost objects ( node, cost ).
2.  Push [ S, 0 ] ( node, cost ) into the priority queue Q, i.e. the cost of reaching the node S from the source node S is zero.
3.  While ( ! Q.empty() )
4.       Object = Q.top(); Q.pop()
5.       Node N = Object.Node and Cost C = Object.Cost
6.       If the node N is not present in the spanning tree
7.           Add node N to the spanning tree.
8.           Cost of the spanning tree += Cost C
9.           For every node adjacent to node N that is not in the spanning tree
10.               Push object ( adjacent node, cost of the edge from N to it ) into Q
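The steps above track only the total cost. As a minimal sketch (not the article's implementation), each queue entry could also carry the node the edge came from, so that the chosen edges can be listed as well. The names `PrimEdgesSketch`, `Edge`, `connect`, `sampleGraph` and `mstEdges`, and the 4-node sample graph, are hypothetical and introduced only for illustration; the graph is assumed to be connected and undirected.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.PriorityQueue;

class PrimEdgesSketch {

    // Hypothetical helper type: an edge (from, to) with its cost.
    static final class Edge {
        final int from, to, cost;
        Edge(int from, int to, int cost) { this.from = from; this.to = to; this.cost = cost; }
        @Override public String toString() { return from + "-" + to + "(" + cost + ")"; }
    }

    // Returns the edges picked by Prim's algorithm when started from `source`.
    // Assumes a connected, undirected graph stored as adjacency lists.
    static List<Edge> mstEdges(int source, List<List<Edge>> graph) {
        PriorityQueue<Edge> pq = new PriorityQueue<>(Comparator.comparingInt((Edge e) -> e.cost));
        boolean[] added = new boolean[graph.size()];
        List<Edge> picked = new ArrayList<>();

        pq.add(new Edge(-1, source, 0)); // from = -1 marks the dummy entry for the source
        while (!pq.isEmpty()) {
            Edge e = pq.poll();
            if (added[e.to]) continue;        // node already in the spanning tree
            added[e.to] = true;
            if (e.from != -1) picked.add(e);  // skip the dummy source entry
            for (Edge next : graph.get(e.to)) {
                if (!added[next.to]) pq.add(next);
            }
        }
        return picked;
    }

    // Adds an undirected edge by storing it in both adjacency lists.
    static void connect(List<List<Edge>> g, int u, int v, int cost) {
        g.get(u).add(new Edge(u, v, cost));
        g.get(v).add(new Edge(v, u, cost));
    }

    // Hypothetical 4-node sample graph with distinct edge costs.
    static List<List<Edge>> sampleGraph() {
        List<List<Edge>> g = new ArrayList<>();
        for (int i = 0; i < 4; i++) g.add(new ArrayList<>());
        connect(g, 0, 1, 4);
        connect(g, 0, 2, 1);
        connect(g, 1, 2, 2);
        connect(g, 1, 3, 5);
        connect(g, 2, 3, 3);
        return g;
    }

    public static void main(String[] args) {
        // Prints [0-2(1), 2-1(2), 2-3(3)]
        System.out.println(mstEdges(0, sampleGraph()));
    }
}
```

Summing the costs of the listed edges (1 + 2 + 3 = 6) gives the same value that the cost-only steps would return for this sample graph.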


Example of finding the minimum spanning tree using Prim’s algorithm (figure: Prims_MST_Java)

Time complexity of Prim’s algorithm : O((E+V)log(V))

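Why is the time complexity of Prim's algorithm O((E+V)log(V))?

Each of the V nodes is added to the spanning tree once, and each of the E edges is pushed into the priority queue at most once from each of its endpoints, so the queue never holds more than O(E) entries. A push or a pop therefore costs O(log(E)), which is O(log(V)) because E is at most V^2. Summing over the O(E+V) queue operations gives O((E+V)log(V)).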
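One way to get a feel for the bound is to count the queue operations directly. The sketch below is an illustration under stated assumptions, not the article's program: `PrimPushCountSketch`, `countPushes`, `connect` and `sampleGraph` are hypothetical names, and adjacency entries are stored as `{adjacent node, cost}` arrays for brevity. Because a neighbour is pushed only while it is still outside the spanning tree, each undirected edge is pushed once, so a connected graph with E edges leads to E + 1 pushes (the extra one is the source entry), each costing O(log(V)).

```java
import java.util.ArrayList;
import java.util.List;
import java.util.PriorityQueue;

class PrimPushCountSketch {

    // Runs Prim's algorithm and returns how many entries were pushed
    // into the priority queue. Each adjacency entry is {adjacent node, cost}.
    static int countPushes(int source, List<List<int[]>> graph) {
        PriorityQueue<int[]> pq = new PriorityQueue<>((a, b) -> Integer.compare(a[1], b[1]));
        boolean[] added = new boolean[graph.size()];
        int pushes = 0;

        pq.add(new int[] {source, 0});
        pushes++;
        while (!pq.isEmpty()) {
            int[] item = pq.poll();
            int node = item[0];
            if (added[node]) continue;   // stale entry, node was added earlier
            added[node] = true;
            for (int[] adj : graph.get(node)) {
                if (!added[adj[0]]) {
                    pq.add(adj);
                    pushes++;
                }
            }
        }
        return pushes;
    }

    // Adds an undirected edge by storing it in both adjacency lists.
    static void connect(List<List<int[]>> g, int u, int v, int cost) {
        g.get(u).add(new int[] {v, cost});
        g.get(v).add(new int[] {u, cost});
    }

    // Hypothetical sample graph: 4 nodes, 5 undirected edges.
    static List<List<int[]>> sampleGraph() {
        List<List<int[]>> g = new ArrayList<>();
        for (int i = 0; i < 4; i++) g.add(new ArrayList<>());
        connect(g, 0, 1, 4);
        connect(g, 0, 2, 1);
        connect(g, 1, 2, 2);
        connect(g, 1, 3, 5);
        connect(g, 2, 3, 3);
        return g;
    }

    public static void main(String[] args) {
        // 5 edges plus the source entry: prints "Pushes: 6"
        System.out.println("Pushes: " + countPushes(0, sampleGraph()));
    }
}
```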


Java Prim’s minimum spanning tree algorithm

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Collections;
import java.util.PriorityQueue;
import java.util.Comparator;

class NodeCost {

    int node; // Adjacent node
    int cost; // Cost of the edge to the adjacent node

    NodeCost (int node, int cost) {
        this.node = node;
        this.cost = cost;
    }
}

class Prims {

    int Find_MST(int source_node, List<List<NodeCost>> graph) {

        // Comparator that orders the priority queue entries by cost in ascending order.
        // Integer.compare avoids the overflow that subtracting two ints can cause.
        Comparator<NodeCost> NodeCostComparator = (obj1, obj2) -> Integer.compare(obj1.cost, obj2.cost);

        // The priority queue stores node-cost objects with
        // the smallest-cost object at the top.
        PriorityQueue<NodeCost> pq = new PriorityQueue<>(NodeCostComparator);

        // The cost of the source node to itself is 0
        pq.add(new NodeCost(source_node, 0));

        boolean added[] = new boolean[graph.size()];
        Arrays.fill(added, false);

        int mst_cost = 0;

        while (!pq.isEmpty()) {

            // Remove the item <node, cost> with the minimum cost from the queue
            NodeCost item = pq.poll();

            int node = item.node;
            int cost = item.cost;

            // If the node is not yet added to the minimum spanning tree, add it and increment the cost.
            if ( !added[node] ) {
                mst_cost += cost;
                added[node] = true;

                // Iterate through all the nodes adjacent to the node taken out of the priority queue.
                // Push only those nodes (node, cost) that are not yet present in the minimum spanning tree.
                for (NodeCost pair_node_cost : graph.get(node)) {
                    int adj_node = pair_node_cost.node;
                    if (added[adj_node] == false) {
                        pq.add(pair_node_cost);
                    }
                }
            }
        }
        return mst_cost;
    }

    public static void main(String args[]) {

        Prims p = new Prims();

        int num_nodes = 6; // Nodes (0, 1, 2, 3, 4, 5)

        List<List<NodeCost>> graph_1 = new ArrayList<>(num_nodes);
        for (int i=0; i < num_nodes; i++) {
            graph_1.add(new ArrayList<>());
        }

        // Node 0
        Collections.addAll(graph_1.get(0), new NodeCost(1, 4), new NodeCost(2, 1), new NodeCost(3, 5));
        // Node 1
        Collections.addAll(graph_1.get(1), new NodeCost(0, 4), new NodeCost(3, 2), new NodeCost(4, 3),
                                           new NodeCost(5, 3));
        // Node 2
        Collections.addAll(graph_1.get(2), new NodeCost(0, 1), new NodeCost(3, 2), new NodeCost(4, 8));
        // Node 3
        Collections.addAll(graph_1.get(3), new NodeCost(0, 5), new NodeCost(1, 2), new NodeCost(2, 2), 
                                           new NodeCost(4, 1));
        // Node 4
        Collections.addAll(graph_1.get(4), new NodeCost(1, 3), new NodeCost(2, 8), new NodeCost(3, 1), 
                                           new NodeCost(5, 4));
        // Node 5
        Collections.addAll(graph_1.get(5), new NodeCost(1, 3), new NodeCost(4, 4));

        // Start adding nodes to the minimum spanning tree with 0 as the source node
        System.out.println("Cost of the minimum spanning tree in graph 1 : " + p.Find_MST(0, graph_1));

        // Outgoing edges from each node, stored as <adjacent_node, cost>, in graph 2.
        num_nodes = 7; // Nodes (0, 1, 2, 3, 4, 5, 6)

        List<List<NodeCost>> graph_2 = new ArrayList<>(num_nodes);
        for (int i=0; i < num_nodes; i++) {
            graph_2.add(new ArrayList<>());
        }

        // Node 0
        Collections.addAll(graph_2.get(0), new NodeCost(1, 1), new NodeCost(2, 2), new NodeCost(3, 1), 
                                           new NodeCost(4, 1), new NodeCost(5, 2), new NodeCost(6, 1));
        // Node 1
        Collections.addAll(graph_2.get(1), new NodeCost(0, 1), new NodeCost(2, 2), new NodeCost(6, 2));
        // Node 2
        Collections.addAll(graph_2.get(2), new NodeCost(0, 2), new NodeCost(1, 2), new NodeCost(3, 1));
        // Node 3
        Collections.addAll(graph_2.get(3), new NodeCost(0, 1), new NodeCost(2, 1), new NodeCost(4, 2));
        // Node 4
        Collections.addAll(graph_2.get(4), new NodeCost(0, 1), new NodeCost(3, 2), new NodeCost(5, 2));
        // Node 5
        Collections.addAll(graph_2.get(5), new NodeCost(0, 2), new NodeCost(4, 2), new NodeCost(6, 1));
        // Node 6
        Collections.addAll(graph_2.get(6), new NodeCost(0, 1), new NodeCost(1, 2), new NodeCost(5, 1));


        // Start adding nodes to the minimum spanning tree with 0 as the source node
        System.out.println("Cost of the minimum spanning tree in graph 2 : " + p.Find_MST(0, graph_2));
    }
}

Output

Cost of the minimum spanning tree in graph 1 : 9
Cost of the minimum spanning tree in graph 2 : 6
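Both sample graphs are connected, so every node ends up in the spanning tree. If some node cannot be reached from the source, the loop simply stops when the queue runs empty and the returned value covers only part of the graph. A hedged sketch of one way to detect this is to count the nodes that were added; `PrimConnectivitySketch`, `mstCost`, `emptyGraph` and `connect` are hypothetical names, and adjacency entries are `{adjacent node, cost}` arrays for brevity.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.OptionalInt;
import java.util.PriorityQueue;

class PrimConnectivitySketch {

    // Returns the cost of the minimum spanning tree, or an empty OptionalInt
    // when some node cannot be reached from `source` (no spanning tree exists).
    static OptionalInt mstCost(int source, List<List<int[]>> graph) {
        PriorityQueue<int[]> pq = new PriorityQueue<>((a, b) -> Integer.compare(a[1], b[1]));
        boolean[] added = new boolean[graph.size()];
        int addedCount = 0;
        int cost = 0;

        pq.add(new int[] {source, 0});
        // Stop early once every node is in the spanning tree.
        while (!pq.isEmpty() && addedCount < graph.size()) {
            int[] item = pq.poll();
            if (added[item[0]]) continue;
            added[item[0]] = true;
            addedCount++;
            cost += item[1];
            for (int[] adj : graph.get(item[0])) {
                if (!added[adj[0]]) pq.add(adj);
            }
        }
        return addedCount == graph.size() ? OptionalInt.of(cost) : OptionalInt.empty();
    }

    // Creates a graph of n nodes with no edges.
    static List<List<int[]>> emptyGraph(int n) {
        List<List<int[]>> g = new ArrayList<>();
        for (int i = 0; i < n; i++) g.add(new ArrayList<>());
        return g;
    }

    // Adds an undirected edge by storing it in both adjacency lists.
    static void connect(List<List<int[]>> g, int u, int v, int cost) {
        g.get(u).add(new int[] {v, cost});
        g.get(v).add(new int[] {u, cost});
    }

    public static void main(String[] args) {
        List<List<int[]>> triangle = emptyGraph(3);
        connect(triangle, 0, 1, 2);
        connect(triangle, 1, 2, 3);
        connect(triangle, 0, 2, 4);
        System.out.println(mstCost(0, triangle).isPresent());     // true, cost is 2 + 3 = 5

        List<List<int[]>> disconnected = emptyGraph(3);
        connect(disconnected, 0, 1, 2);                           // node 2 has no edges
        System.out.println(mstCost(0, disconnected).isPresent()); // false
    }
}
```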


Copyright (c) 2019-2024, Algotree.org.
All rights reserved.