Prim's Minimum Spanning Tree
by Isai Damier, Android Engineer @ Google

/**************************************************************************
 * Author: Isai Damier
 * Title: Prim's Minimum Spanning Tree
 * Project: geekviewpoint
 * Package: algorithms.graph
 *
 * Statement:
 *   Given a connected undirected graph, find a tree that comprises all the
 *   vertices such that the total weight of all the edges in the tree
 *   is minimal.
 *
 *   The difference between a graph and a tree is this:
 *   in a tree there is exactly one sequence of edges between
 *   any two vertices; no such guarantee is made for a graph.
 *
 *   The minimum spanning tree (MST) algorithm assumes the worst
 *   about the input graph: that there is more than one path
 *   (i.e. sequence of edges) between any two vertices.
 *   Jarnik-Prim's algorithm then proceeds as follows:
 *
 *   1] Start with an empty tree T
 *   2] Remove an arbitrary vertex from the graph G and add it
 *      to the tree T.
 *   3] a) While vertices remain on the graph G,
 *      b) find the vertex v on the graph that is joined to some
 *         vertex already on the tree by the lightest single edge
 *         (edge weight, not path length as in Dijkstra's algorithm).
 *      c) Remove v from the graph and add it to the tree.
 **************************************************************************/ 
 
package algorithms.graph;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.PriorityQueue;

public class PrimMinimumSpanningTree {

  /******************************************************************
   * Function: minimumSpanningTree
   *
   * @param G: adjacency matrix representing the weighted undirected
   *           graph. This is perforce a square matrix: G[x][y] is the
   *           weight of the edge linking vertex x and vertex y, or
   *           null when there is no such edge. For example,
   *           G[1][5]=23 means an edge of weight 23 connects vertex 1
   *           and vertex 5. The matrix is assumed to be symmetric
   *           (G[x][y] == G[y][x]); this is not checked.
   *
   * @return List<Integer[]> holding one element {v, x, edge} per tree
   *         edge, ordered by v, where x is the vertex through which v
   *         joined the tree and edge = G[x][v]. Root vertices have no
   *         parent and contribute no element. If G is disconnected the
   *         result is a minimum spanning forest (one tree per connected
   *         component). An empty G yields an empty list.
   *
   * @throws IllegalArgumentException if G is not square.
   *
   * Implementation Details:
   *   1) Arrays prev[] and link[] are used to prepare each vertex
   *      before it is added to the MST: prev[v] is the tree vertex v
   *      will hang from (-1 while v has none, e.g. for a root) and
   *      link[v] is the weight of that edge (Integer.MAX_VALUE, i.e.
   *      "infinity", while no edge to the tree is known).
   *   2) A vertex is considered in the MST once it is removed from
   *      the priority queue pq and marked in inTree[].
   *   3) pq holds {vertex, link} snapshots. The priority of an element
   *      that is already inside a java.util.PriorityQueue must never be
   *      changed in place: the heap is not re-ordered, so the queue
   *      could hand back a vertex that is not the closest one. Instead,
   *      whenever link[u] improves a fresh entry for u is enqueued, and
   *      outdated entries (for vertices already on the tree) are simply
   *      discarded when they reach the head ("lazy deletion").
   *
   *  1] Start with an empty tree T: for all v, inTree[v] = false,
   *     prev[v] = -1 and link[v] = infinity.
   *
   *  2] Remove an arbitrary vertex from the graph G and add it to T:
   *     the lowest-numbered vertex not on any tree yet becomes a root
   *     with link[root] = 0 (vertex 0 for a connected graph). It is
   *     actually moved onto the tree when it is dequeued from pq.
   *
   *  3a] For each remaining vertex v on the graph G: loop until pq is
   *      empty.
   *
   *  3b-c] The head of pq is the vertex v joined to the tree by the
   *      lightest edge; dequeue it and add it to T. Then relax every
   *      neighbour u that is not on T yet: if G[v][u] < link[u], set
   *      link[u] = G[v][u], prev[u] = v and enqueue {u, link[u]}.
   *
   *  Every vertex scans its full matrix row once and every relaxation
   *  enqueues one entry, so the running time is O(n^2 log n).
   *****************************************************************/
  public List<Integer[]> minimumSpanningTree(Integer[][] G) {
    final int n = G.length;
    for (Integer[] row : G) {
      if (row.length != n) {
        throw new IllegalArgumentException(
            "adjacency matrix must be square: expected " + n
                + " columns, got " + row.length);
      }
    }

    int[] prev = new int[n];
    Arrays.fill(prev, -1);
    int[] link = new int[n];
    Arrays.fill(link, Integer.MAX_VALUE);
    boolean[] inTree = new boolean[n];

    // {vertex, link} entries, lightest link first; ties are broken by
    // vertex number so the resulting tree is deterministic.
    PriorityQueue<int[]> pq = new PriorityQueue<>(
        Comparator.comparingInt((int[] entry) -> entry[1])
            .thenComparingInt(entry -> entry[0]));

    // Each pass grows one tree; a second pass only happens when G is
    // disconnected, which turns the result into a spanning forest.
    for (int root = 0; root < n; root++) {
      if (inTree[root]) {
        continue;
      }
      link[root] = 0;
      pq.add(new int[]{root, 0});

      while (!pq.isEmpty()) {
        int v = pq.remove()[0];
        if (inTree[v]) {
          continue; // outdated entry: v already joined through a lighter edge
        }
        inTree[v] = true;
        for (int u = 0; u < n; u++) {
          Integer w = G[v][u]; // null means "no edge"
          // Tree vertices must not be relaxed, or their parent edge
          // would be overwritten after they were already placed.
          if (w != null && !inTree[u] && w < link[u]) {
            link[u] = w;
            prev[u] = v;
            pq.add(new int[]{u, w});
          }
        }
      }
    }//for each root

    /* How you implement the following is up to you.
     * You can print it or whatever. I am creating
     * a list of array elements {v,x,edge} where edge
     * is the edge connecting vertex x to vertex v.
     * Roots are recognized by prev[v] == -1 (not by link[v] == 0),
     * so zero-weight tree edges are reported as well.
     */
    List<Integer[]> result = new ArrayList<>();
    for (int v = 0; v < n; v++) {
      if (prev[v] != -1) {
        result.add(new Integer[]{v, prev[v], link[v]});
      }
    }
    return result;
  }// minimumSpanningTree
package algorithms.graph;

import java.util.List;
import org.junit.After;
import org.junit.Test;
import static org.junit.Assert.*;
import org.junit.Before;

public class PrimMinimumSpanningTreeTest {

  PrimMinimumSpanningTree mst;

  @Before
  public void setUp() throws Exception {
    mst = new PrimMinimumSpanningTree();
  }

  @After
  public void tearDown() throws Exception {
    mst = null;
  }

  /** Expected {v, parent, weight} triples for sampleGraph(), ordered by v. */
  private static final Integer[][] SAMPLE_MST = {{1, 0, 6}, {2, 0, 5},
    {3, 2, 9}, {4, 5, 15}, {5, 2, 7}, {6, 0, 8}, {7, 6, 3}};

  /** Returns a fresh copy of the 8-vertex sample graph (null = no edge). */
  private static Integer[][] sampleGraph() {
    return new Integer[][]{{null, 6, 5, null, null, null, 8, 14},
      {6, null, 12, null, null, null, null, null},
      {5, 12, null, 9, null, 7, null, null},
      {null, null, 9, null, null, null, null, null},
      {null, null, null, null, null, 15, null, null},
      {null, null, 7, null, 15, null, 10, null},
      {8, null, null, null, null, 10, null, 3},
      {14, null, null, null, null, null, 3, null}};
  }

  @Test
  public void testMinimumSpanningTree() {
    assertEdges(SAMPLE_MST, mst.minimumSpanningTree(sampleGraph()));
  }

  @Test
  public void testMinimumSpanningTreeWithLargeWeights() {
    // Multiplying every weight by the same positive factor must not change
    // the tree. Weights above 127 also make sure the assertions compare
    // values rather than boxed-Integer identities (the Integer cache only
    // covers -128..127).
    final int scale = 1000;
    Integer[][] graph = sampleGraph();
    for (Integer[] row : graph) {
      for (int k = 0; k < row.length; k++) {
        if (row[k] != null) {
          row[k] *= scale;
        }
      }
    }
    Integer[][] expected = new Integer[SAMPLE_MST.length][];
    for (int i = 0; i < SAMPLE_MST.length; i++) {
      Integer[] edge = SAMPLE_MST[i];
      expected[i] = new Integer[]{edge[0], edge[1], edge[2] * scale};
    }
    assertEdges(expected, mst.minimumSpanningTree(graph));
  }

  @Test
  public void testSingleVertexHasNoEdges() {
    assertTrue(mst.minimumSpanningTree(new Integer[][]{{null}}).isEmpty());
  }

  /**
   * Asserts that result contains exactly the expected {v, parent, weight}
   * triples, in order. Elements are compared with equals(), never with
   * == / != on boxed Integers.
   */
  private static void assertEdges(Integer[][] expected, List<Integer[]> result) {
    assertEquals("number of tree edges", expected.length, result.size());
    for (int i = 0; i < expected.length; i++) {
      assertArrayEquals("edge " + i, expected[i], result.get(i));
    }
  }
}