Algorithms · Animation

How does Kruskal’s Algorithm progress?

As a continuation of the Prim’s Algorithm animation, I have also implemented Kruskal’s Algorithm and applied it to randomly distributed points. Instead of growing a tree from an origin point, Kruskal’s Algorithm starts by computing the distance between every pair of points and sorting those distances in increasing order. Then, starting from the smallest distance, point pairs are added to disjoint sets, so that independent trees grow and merge until they form a single tree (hence a single cluster). One important point in the algorithm is that a pair whose points already belong to the same tree is skipped, because adding it would create a cycle; this is what guarantees that the total length of the resulting tree is minimal. Below are some end results of the trees formed for given sets of randomly distributed points:

Also, with a very small twist, the very same algorithm can be used to cluster the points so that the clusters are as well separated as possible: the shortest distance between any two points in different clusters (the spacing) is maximized. To that end, instead of merging all the points into a single tree, the iteration stops once k trees remain, where k is the desired number of clusters. As an example, below is the final result of dividing the given points into 4 clusters.

Clustering via Kruskal’s algorithm

The code that implements Kruskal’s algorithm in C++ is provided below. Note that it can also stop once the desired number of clusters is reached for the given points. Please share your comments/questions below and stay tuned for the next post.

#include <cassert>
#include <cmath>
#include <cstdio>
#include <algorithm>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <set>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
using std::vector;
using std::pair;
using namespace std;

// Diagnostic switch: when true, intermediate edge lists and set states are printed.
// main() turns it on when "-verbose" is given on the command line.
bool bVerbose{false};
// Added to 0-based point indices when printing, so points appear numbered from 1.
int offset{1};

// Auxiliary functions
/// Appends one "point1 point2" line to `filename`.
///
/// Called once per accepted edge, so the file records the order in which the
/// edges were added. The file is opened in append mode, so earlier contents
/// are kept.
///
/// @param point1, point2  the (0-based) indices of the connected points
/// @param filename        output file path; the empty default cannot be opened
/// @throws std::runtime_error if the file cannot be opened
void dumpConnectedPointPairs(int point1, int point2, const std::string & filename = "") {
  std::ofstream outFile(filename.c_str(), std::ios::out | std::ios::app);
  if (!outFile.is_open()) {
    throw std::runtime_error("Could not open file: " + filename);
  }
  outFile << point1 << ' ' << point2 << '\n';
  // No explicit close(): the ofstream destructor flushes and closes the file.
}

void displayVectorofPair(vector< pair<double, pair<int,int>> > & dvp,const string & str=""){

  size_t size= dvp.size();
  if (bVerbose){  cout<<"size= "<<size<<endl; }

  cout<<"Point 1 : ";
  for (auto pr : dvp){
    pair<int,int> point = pr.second;
    int firstpoint= point.first;
    cout<<firstpoint+offset<<' ';
  cout<<"\nPoint 2 : ";
  for (auto pr : dvp){
    pair<int,int> point = pr.second;
    int secondpoint= point.second;
    cout<<secondpoint+offset<<' ';
  cout<<"\nDistance: ";
  for (auto pr : dvp){
    double dist = pr.first;
    cout<<dist<<' ';


/// One node of the disjoint-set forest.
/// A struct (public members) because DisjointSets reads and writes the
/// fields directly; as a `class` every member and the constructor were private.
struct DisjointSetsElement {
  int size, parent, rank;  // parent == own index marks a root; -1 means "not yet assigned"

  DisjointSetsElement(int size = 0, int parent = -1, int rank = 0):
    size(size), parent(parent), rank(rank) {}
};

class DisjointSets {
  int size;
  int max_table_size;
  vector <DisjointSetsElement> sets;

  explicit DisjointSets(int size):  size(size), max_table_size(0), sets(size) {
    for (int i = 0; i < size; i++){ // makeSet(i) operation, i.e. create a singleton set       sets[i].parent = i; //at first parent is assigned as itself (self-parenting [self-loop])       sets[i].rank   = 0; // MY: Although already taken care of by the default value assignment. Done for the sake of clarity     }   }   int getParent(int i) { // similar to Find(). Find parent and compress path. [This will later allow log*()     if ( i != sets[i].parent) {  //time as the tree depth does not increase much by compression]       sets[i].parent = getParent(sets[i].parent );     }     return sets[i].parent ;   }   void merge(int i, int j){ // merging or union of two sets (union_rank) - this makes the algorithm log*() complexity     int i_id = getParent(i);     int j_id = getParent(j);     if (i_id == j_id)       return;     if ( sets[i_id].rank > sets[j_id].rank ){
      sets[j_id].parent = i_id;
    } else {
      sets[i_id].parent = j_id;
      if (sets[i_id].rank == sets[j_id].rank){
        sets[j_id].rank +=1;

  // Print the disjoint set info for visualization/debugging
  void printSets(){ //print info on the sets

    cout<<"Vertex: ";
    for (int i=0; i<size; ++i){   cout<< i+offset <<' '; }

    cout<<"\nParent: ";
    for (auto &s : sets){ cout<< s.parent+offset <<' '; }

    cout<<"\nRank  : ";
    for (auto &s : sets){ cout<<s.rank<<' '; }

  // This is used later on decide on whether the number of clusters required by the user is achieved or not
  int findNumberofUniqueParents(){ //Not a classical disjoint set or Kruskal algorithm component

    std::set<int> local_set; //since sets can store unique elements, it is a natural choice here
    for (auto &s : sets){

    if (bVerbose) {
      cout<<"Unique Number of Parents : "<<local_set.size()<<"  - Unique parent vertices: ";
      for (auto & s : local_set){
        cout<< s +offset<<' ';

    return static_cast<int>(local_set.size());


/// Orders edges (distance, (point1, point2)) by increasing distance only.
/// Used as the std::sort comparator. It takes const references: a comparator
/// must not modify its arguments and must accept const elements.
bool compare(const std::pair<double, std::pair<int,int> > & lhs,
             const std::pair<double, std::pair<int,int> > & rhs) {
  return lhs.first < rhs.first;
}

double clustering(vector<int> x, vector<int> y, int desiredNoClusters ) {
  // At the beginning, there are edges among all points, we make sure that we utilize edges only once (i.e. edge
  // between point 1 and point 2 is taken but not between Point 2 & Point 1.
  // Also self edges, i.e. Point i to Point i not included as it is zero
  size_t nVertex = x.size();
  vector<  pair<double,pair<int,int>>  > distVector; // A vector of pair where each pair stores the distance (cost)
                               // of each distinct edge and info of that edges on between which points it is as a pair if <int,int>
                               // e.g. let distance between Point j and k be 1.3, then a pair of < 1.3, pair<j,k> > will be pushed to vector
  for (size_t i=0; i<nVertex; ++i){
    for (size_t j=i+1; j<nVertex; ++j){
      double dist = sqrt( (x[i]-x[j])*(x[i]-x[j]) + (y[i]-y[j])*(y[i]-y[j]) );

      pair<int,int> point = make_pair(i,j);
      pair<double, pair<int,int> > distPoint = make_pair(dist,point);


  if (bVerbose){ displayVectorofPair(distVector,"Dist Point Pair Vector:");   }

  // now sort the vector with respect to distance (in increasing distance order)
  std::sort(distVector.begin(), distVector.end(),compare); 

  if (bVerbose){ displayVectorofPair(distVector,"Dist Point Pair Vector (Sorted):"); }

  // For all vertices (points), make singleton sets and display for visualization
  DisjointSets allsets(nVertex); // makeSet() operation is done here
  if (bVerbose) {
      cout<<"\nAt the beginning, the sets of points"<<endl;

  // delete the parent.txt file if it already exists

  size_t inext=0;
  for (size_t i=0; i<distVector.size(); ++i){

    int point1 = (distVector[i].second).first;
    int point2 = (distVector[i].second).second;
    int parent1 = allsets.getParent(point1);
    int parent2 = allsets.getParent(point2);

    // Check whether they belong to the same parent. If so, that means we cannot merge
    // them since it would create a cycle which we don't want as we want minimum spanning tree
    bool sameParentsAlready = (parent1 == parent2);
    if (!sameParentsAlready ) { //do merge only if they are not connected already

      allsets.merge( point1, point2 ); // if they are not connected already (i.e. not have the same parents), merge them
      if (bVerbose) { allsets.printSets(); }

      // Dump the connection info between points after each merge (for visualization)

      // This is used to stop merging when the desired number of clusters are reached
      int currentNoClusters = allsets.findNumberofUniqueParents();
      if (currentNoClusters == desiredNoClusters ){
        inext = i+1;

  // Now that we have reached to the desired number of clusters, but within the clusters there can be point pairs
  // where the distances are still smaller than those between the clusters. But because they would create cycles
  // in the tree, we sweep them until we reach to the distance that is really between the clusters and not creating
  // a cycle. At that point, we stop and that distance is the distance we are looking for.
  double finalDist =0.0 ;
  for (size_t i=inext; i<distVector.size(); ++i){

    int point1 = (distVector[i].second).first;
    int point2 = (distVector[i].second).second;
    int parent1 = allsets.getParent(point1);
    int parent2 = allsets.getParent(point2);

    bool sameParentsAlready = (parent1 == parent2);
    if (sameParentsAlready ) {
      allsets.merge( point1, point2 );
     } else {
      finalDist = distVector[i].first;

  cout<<"Final minimum Distance between clusters = "<<finalDist<<endl;

  return finalDist;

int main(int argc, char** argv) {

  for (int i=0; i<argc;++i){
    string str1=argv[i];
    if ("-verbose") == 0){
      cout<<"Verbose option is requested"<<endl;       bVerbose = true;     }   }   size_t n;   int k;   std::cin >> n;
  vector<int> x(n), y(n);
  for (size_t i = 0; i < n; i++) {     std::cin >> x[i] >> y[i];
  std::cin >> k;
  std::cout << std::setprecision(10) << clustering(x, y, k) << std::endl;

The program reads its input from standard input, so redirect the input file to it when running. Once compiled and run, the above code will dump a text file (connectedpointpairs.txt) containing the step-by-step progression of the edges added between the points. The input file format is as follows:

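Input file format:
N
x1 y1
...
xN yN
k

where N is the number of points, each following line holds the integer coordinates of one point, and k is the desired number of clusters.

Sample input file (3 points, 2 clusters):
3
1 2
4 5
2 3
2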

The produced output file and the original input file can then be fed to the Python script below, which visualizes and animates that progression as shown in the video above. Run it with:

./<script_name>.py -input testinput1 -connectedpoints connectedpointpairs.txt

#!/usr/bin/env python
'''Plot a set of points and the paths between them that form their
minimum spanning tree, saving one frame per added edge.
'''
import os
from pylab import matplotlib, plt, sqrt

def FindFirstandLastCost(xpoints, ypoints, parentlines, nParents):
  '''Return (cost_first, cost_last): total edge length of two parent assignments.

  Can be used when visualizing the cost function, to find its max and min
  extents (e.g. for the y-limits of a cost-vs-iteration plot).

  parentlines -- lines of whitespace-separated integers; entry i of a line is
                 the index of point i's parent. Line 0 is a dummy and is not
                 read. "First" is parentlines[1] and "last" is parentlines[-1].
  nParents    -- entries 1 .. nParents-1 are summed; point 0 is skipped
                 (presumably the root -- TODO confirm).
  '''
  import math

  def _cost(line):
    # Sum of Euclidean distances from each point i (i >= 1) to its parent.
    parents = [int(tok) for tok in line.split()]
    return sum(math.hypot(xpoints[i] - xpoints[parents[i]],
                          ypoints[i] - ypoints[parents[i]])
               for i in range(1, nParents))

  return (_cost(parentlines[1]), _cost(parentlines[-1]))

def PlotPointsOnly(ax, xpoints, ypoints, nPoints, k=0):
  '''Draw the first nPoints points on `ax` as circle markers and title the plot.

  Only the points are drawn; the connecting paths are drawn by the caller.
  k is the iteration number shown in the title. The axis limits are the data
  extents padded by 5 units on each side.
  '''
  # The original if/else issued identical plot calls on both branches; one call suffices.
  for i in range(nPoints):
    ax.plot(xpoints[i], ypoints[i], 'mo', markersize=10)

  ax.set_title('Iteration # ' + str(k) + ' of ' + str(nPoints), fontsize=16)

  # min()/max() raise on empty sequences, so limits are only set when points exist.
  if xpoints and ypoints:
    if ARGS.verbose:
      print("min(x) = {} max(x)= {}".format(min(xpoints), max(xpoints)))
      print("min(y) = {} max(y)= {}".format(min(ypoints), max(ypoints)))
    ax.set_xlim(min(xpoints) - 5, max(xpoints) + 5)
    ax.set_ylim(min(ypoints) - 5, max(ypoints) + 5)


def GetPointsAndPaths():
  '''Read the input files named in ARGS and return the data needed for plotting.

  ARGS.input: the point count N on the first line, followed by N lines of
  "x y" coordinates. Any later lines (e.g. the cluster count read by the C++
  program) are ignored here.
  ARGS.connectedpoints: one "point1 point2" line per accepted edge. Blank
  lines are dropped.

  Returns (x, y, connectedpointslines, nPoints, nConnectedPoints).
  '''
  with open(ARGS.input) as f:
    lines = f.read().splitlines()

  nPoints = int(lines[0])

  if ARGS.verbose:
    print("lines= {}".format(lines))
    print("nPoints = {}".format(nPoints))

  if len(lines) < nPoints + 1:
    raise ValueError("{} declares {} points but has only {} point lines".format(
        ARGS.input, nPoints, len(lines) - 1))

  x, y = [], []
  for line in lines[1:nPoints + 1]:
    fields = line.split()
    x.append(float(fields[0]))
    y.append(float(fields[1]))

  if ARGS.verbose and x:
    print("x= {} min(x) = {} max(x) = {}".format(x, min(x), max(x)))
    print("y= {} min(y) = {} max(y) = {}".format(y, min(y), max(y)))

  with open(ARGS.connectedpoints) as f:
    # Drop blank lines (e.g. a trailing newline) so every entry parses as a pair.
    connectedpointslines = [line for line in f.read().splitlines() if line.strip()]

  if ARGS.verbose:
    print("connectedpointslines  = {}".format(connectedpointslines))

  nConnectedPoints = len(connectedpointslines)
  print("nConnectedPoints = {}".format(nConnectedPoints))

  return (x, y, connectedpointslines, nPoints, nConnectedPoints)

# --------------------------------------------------
def main():
  '''Save one PNG frame per accepted edge.

  Each frame shows the points, all edges accepted so far, and a
  cost-vs-iteration curve. Also saves an initial frame with the points only.
  '''
  x, y, connectedpointslines, nPoints, nConnectedPoints = GetPointsAndPaths()

  # Parse every "i j" pair once, instead of re-parsing all earlier lines on
  # every iteration. Blank lines are skipped.
  pairs = []
  for line in connectedpointslines:
    if line.strip():
      m, n = [int(tok) for tok in line.split()[:2]]
      pairs.append((m, n))
  nEdges = len(pairs)

  # Figure is 16x9 inches; the pixel size is (16x9)*dpi set in savefig.
  plt.figure(5, figsize=(16, 9))

  # Initial frame: the points only, no connecting paths.
  ax = plt.subplot2grid((1, 3), (0, 0), colspan=2)
  PlotPointsOnly(ax, x, y, nPoints)
  plt.savefig('test_dpi240_16x9__initial.png', facecolor='w', dpi=240)

  edge_lengths = [sqrt((x[m] - x[n])**2 + (y[m] - y[n])**2) for m, n in pairs]
  # The cumulative cost only grows, so the total cost bounds the cost plot
  # (replaces a hard-coded upper limit of 700).
  total_cost = sum(edge_lengths)
  cost_ymax = total_cost * 1.05 if total_cost > 0 else 1.0

  print("Start iterations:")

  cost_arr = []
  cost_k = 0
  for k in range(nEdges):
    print("k= {} of {}".format(k, nEdges))

    # Start each frame from an empty figure, so axes do not pile up across iterations.
    plt.clf()

    ax = plt.subplot2grid((1, 3), (0, 0), colspan=2)
    PlotPointsOnly(ax, x, y, nPoints, k)

    # Draw every edge accepted so far (edges 0..k).
    for m, n in pairs[:k + 1]:
      plt.plot([x[m], x[n]], [y[m], y[n]], '-k', linewidth=2)

    cost_k += edge_lengths[k]
    cost_arr.append(cost_k)
    print("cost_k = {}".format(cost_k))

    # Plot the cost vs iterations
    ax = plt.subplot2grid((1, 3), (0, 2), colspan=1)
    plt.plot(cost_arr, '-mo', markersize=6)
    axes = plt.gca()
    axes.set_xlim(-1, k + 2)
    axes.set_ylim(0, cost_ymax)

    plt.title(r'Cost: $\sum_{\forall\, i,j} (P_{i}-P_{j})_{connected}$ = ' + str(int(cost_k*10)/10.0), fontsize=15, y=1.02)
    plt.xlabel('Number of iterations', fontsize=14)
    plt.ylabel('Cost: Sum of all connected distances', fontsize=14)

    # Dump png file for later video processing
    plt.savefig('test_dpi240_16x9_' + str(k) + '.png', facecolor='w', dpi=240)

  # Keep the last frame on screen for interactive inspection.
  plt.show()

# -- Parse the input ---------------------------------------------------------
def ParseInput():
  '''Parse and return the command-line arguments of this script.

  Both -input (the original point file) and -connectedpoints (the pairs file
  written by the C++ program) are required. If either is missing, argparse
  prints a usage message and exits with status 2 (SystemExit).
  '''
  import argparse
  parser = argparse.ArgumentParser()
  parser.add_argument("-v", "--verbose", help="Increase output verbosity", action="store_true")
  parser.add_argument("-input", type=str, default=None, help="enter original input")
  parser.add_argument("-connectedpoints", type=str, default=None, help="Enter cost and distance file")

  args = parser.parse_args()

  # Previously this only printed a message and continued, which later failed inside open(None).
  if not args.input or not args.connectedpoints:
    parser.error("Enter cost and input files")

  if args.verbose:
    print("input= {}".format(args.input))
    print("connectedpoints= {}".format(args.connectedpoints))

  return args
# -----------------------------------------------------------------------------

# This is the standard boilerplate that calls the main() function.
if __name__ == '__main__':
  # ARGS is a module-level global read by the plotting helpers.
  ARGS = ParseInput()
  main()


For a similar animation of how Prim’s algorithm works, please check this post.

Алгоритм Краскала, Algoritmo de Kruskal, 크러스컬 알고리즘, 克鲁斯克尔演算法

One thought on “How does Kruskal’s Algorithm progress?”

Leave a Reply

Fill in your details below or click an icon to log in: Logo

You are commenting using your account. Log Out / Change )

Twitter picture

You are commenting using your Twitter account. Log Out / Change )

Facebook photo

You are commenting using your Facebook account. Log Out / Change )

Google+ photo

You are commenting using your Google+ account. Log Out / Change )

Connecting to %s