"""
//*************************************************************************
// Example program using OpenCV library
//					
// @file	kmeans.py					
// @author Luis M. Jimenez
// @date 2022
//
// @brief Course: Computer Vision (1782)
// Dept. of Systems Engineering and Automation
// Automation, Robotics and Computer Vision Lab (ARVC)
// http://arvc.umh.es
// University Miguel Hernandez
//
// @note Description: 
//	- K-means clustering step by step
//
//  Dependencies:
//     scipy, matplotlib, numpy, argparse
//*************************************************************************
"""
import numpy as np
import matplotlib.pyplot as plt
from scipy.spatial import distance
from scipy import cluster
import argparse

# Fallback number of clusters when -c is not supplied on the command line
numClusters = 3

# Command-line interface: a single optional -c flag selecting k
argParser = argparse.ArgumentParser(description='K-means Clustering')
argParser.add_argument('-c', dest='numClusters', type=int, default=numClusters,  help='Number of Clusters (Default 3)')
argParser.print_help()

# Overwrite the default with whatever the user requested
args = argParser.parse_args()
numClusters = args.numClusters


# Plot Data / Clusters
def plotDataClusters(data, labels=None, centroids=None, fig=None, title="Data", show_ids=True):
    """Scatter-plot 2-D samples, optionally colored by cluster label,
    annotated with their row index and overlaid with cluster centroids.

    data      -- (n, 2) array of samples, one row per point
    labels    -- per-sample cluster index used as color (optional)
    centroids -- (k, 2) array of cluster centers to overlay (optional)
    fig       -- matplotlib figure number to draw into
    title     -- plot title
    show_ids  -- when True, write each sample's index next to its marker
    """
    plt.figure(fig, figsize=(8, 6))
    plt.clf()
    plt.title(title)
    plt.xlabel("x1")
    plt.ylabel("x2")
    plt.gcf().canvas.manager.set_window_title('K-means Clustering')

    # Samples: color-code by label when an assignment is available
    if labels is None:
        plt.scatter(data[:, 0], data[:, 1], marker='.', label="Data")
    else:
        plt.scatter(data[:, 0], data[:, 1], marker='.', c=labels, cmap='rainbow', label="Data")

    # Annotate every sample with its row index
    if show_ids:
        for sample_id, point in enumerate(data):
            plt.annotate(str(sample_id), xy=(point[0], point[1]), xytext=(-3, 3),
                         textcoords='offset points', ha='right', va='bottom')

    # Centroids (one color per cluster); the legend only makes sense here
    if centroids is not None:
        colorRange = range(len(centroids))
        plt.scatter(centroids[:, 0], centroids[:, 1], marker='x', s=40,
                    c=colorRange, cmap='rainbow', label="Centroids")
        plt.legend(loc='lower right')

    plt.draw()
    plt.pause(0.1)  # hand control to the GUI event loop so the figure refreshes
# end plotDataClusters


# K-means clustering step by step
# if randomClusters is True, simple random data is used for seeds
# if randomClusters is False, optimal random data is used for seeds (default)
def kmeansStep(data, numClusters, randomClusters=False, fig=1, max_iter=100, tol=1e-3):
    """Run Lloyd's k-means algorithm one visualized step at a time.

    Parameters
    ----------
    data : (n, 2) array of samples, one row per sample.
    numClusters : number of clusters k.
    randomClusters : if True, pick the k seeds uniformly at random from the
        data; if False, spread the seeds over the feature space by sampling
        each new seed with probability proportional to its mean distance
        to the seeds already chosen (k-means++-like strategy).
    fig : matplotlib figure number used for plotting.
    max_iter : maximum number of assign/update iterations.
    tol : convergence threshold; iteration stops once no centroid moves
        farther than this distance in one update.

    Returns
    -------
    (compactness, labels, centroids) : mean sample-to-centroid distance,
        per-sample cluster index, and the final (k, 2) centroid array.
    """
    plotDataClusters(data, title="Data", fig=fig)
    input('...Press return to continue...')

    # ---- choose the initial centroid seeds ------------------------------
    if randomClusters:  # plain uniform sampling of k distinct data points
        idx = np.random.choice(len(data), numClusters, replace=False)
    else:  # spread seeds over the feature space
        # Full pairwise distance matrix between all samples
        Dm = distance.squareform(distance.pdist(data, metric='euclidean'))

        idx = [np.random.choice(len(data))]  # first seed: any random sample
        for _ in range(1, numClusters):
            # mean distance from the already-chosen seeds to every sample
            dc = Dm[idx].mean(axis=0)
            # zero out already-chosen seeds so they cannot be picked again
            dc[idx] = 0.0
            dc /= dc.sum()  # normalize into a probability distribution

            # sample the next seed with probability proportional to dc
            idx.append(np.random.choice(len(data), p=dc))

    # Work on a float copy: fancy indexing already copies, and the float
    # cast keeps centroid updates exact even if the caller passes int data
    # (assigning a float mean into an int array would silently truncate).
    centroids = data[idx, :].astype(float)

    # Plot data and initial centroids
    title_type = "(Random Centroids)"  if randomClusters else "(Optimal Random Distribution)"
    plotDataClusters(data, centroids=centroids, title=f"Seed Clusters - {title_type}", fig=fig)
    input('...Press return to continue...')

    # ---- Lloyd iterations: assign samples, then update centroids --------
    compactness = 0.0
    step = 0
    end = False
    while not end:
        # Distance from every sample to every centroid
        D = distance.cdist(data, centroids, metric='euclidean')

        # Assignment step: nearest centroid for every sample
        labels = D.argmin(axis=1)
        distances = D.min(axis=1)
        compactness = np.mean(distances)  # mean intra-cluster distance

        # Update step: move each centroid to the mean of its members
        end = True
        for i in range(numClusters):
            members = data[labels == i]
            if len(members) == 0:
                # Empty cluster: keep the previous centroid instead of
                # letting np.mean of an empty slice produce NaN.
                continue
            newCentroid = members.mean(axis=0)
            # if a centroid moves more than (tol) try another step
            if np.linalg.norm(newCentroid - centroids[i]) > tol:
                end = False
            centroids[i] = newCentroid

        # plot current cluster assignment
        plotDataClusters(data, centroids=centroids, labels=labels, fig=fig,
                          title=f"Step: {step+1:2} - Compactness={compactness:.2f} - {numClusters} clusters")
        input('...Press return to continue...')

        step += 1
        if step >= max_iter:  # safety cap: run at most max_iter iterations
            end = True
    # end while loop

    # Report and plot the final clustering
    print(f"{compactness=:.2f}")
    print(f"Centroids:\n{centroids}")
    plotDataClusters(data, centroids=centroids, labels=labels, fig=fig,
                          title=f"{title_type} Final Compactness={compactness:.2f} - {numClusters} clusters")

    return compactness, labels, centroids
# end function kmeansStep


# K-means clustering scipy implementation
def kmeans(data, numClusters, fig=2):
    """Cluster *data* into *numClusters* groups with scipy's k-means.

    Shows the raw samples first, then the resulting clustering, and
    returns (compactness, labels, centroids).
    """
    # Show the raw samples before clustering
    plotDataClusters(data, fig=fig, title="K-means Clustering - scipy")
    input('...Press return to continue...')

    # Run scipy's k-means, then assign every sample to its closest centroid
    centroids, compactness = cluster.vq.kmeans(data, numClusters)
    labels, _distances = cluster.vq.vq(data, centroids)

    # Report and plot the final clustering
    print(f"{compactness=:.2f}")
    print(f"Centroids:\n{centroids}")
    plotDataClusters(data, centroids=centroids, labels=labels, fig=fig,
                          title=f"(Scipy K-means) Compactness={compactness:.2f} - {numClusters} clusters")

    return compactness, labels, centroids
# end of function kmeans

#########################################

# Sample data set: fifteen 2-D points forming three visible groups
samples = np.array([
    [5, 3], [10, 15], [15, 12], [24, 10], [30, 30],
    [15, 60], [25, 75], [15, 92], [30, 55], [35, 85],
    [85, 70], [71, 80], [60, 78], [65, 55], [80, 91],
], dtype=float)

# 1) step-by-step k-means with uniformly random seeds
print(f'\nK-means clustering, random seeds: {numClusters} clusters\n')
kmeansStep(samples, numClusters=numClusters, randomClusters=True, fig=1)

input('\n...Press return to continue...')

# 2) step-by-step k-means with spatially spread random seeds
print(f'\nK-means clustering, optimal spatial distribution random seeds: {numClusters} clusters\n')
kmeansStep(samples, numClusters=numClusters, randomClusters=False, fig=2)

input('\n...Press return to continue...')

# 3) reference run with scipy's built-in k-means
print(f'\nK-means clustering, scipy algorithm: {numClusters} clusters\n')
kmeans(samples, numClusters=numClusters, fig=3)

input('\n...Press return to finish program...')