"""
//*************************************************************************
// Example program using OpenCV library
//					
// @file	mean-shift.py					
// @author Luis M. Jimenez
// @date 2022
//
// @brief Course: Computer Vision (1782)
// Dept. of Systems Engineering and Automation
// Automation, Robotics and Computer Vision Lab (ARVC)
// http://arvc.umh.es
// University Miguel Hernandez
//
// @note Description: 
//	- Mean-shift clustering: points are iteratively shifted along the density gradient towards the density maxima (modes)
//
//  Dependencies:
//     scipy, sklearn, matplotlib, numpy, argparse
//*************************************************************************
"""
import numpy as np
from scipy.spatial import distance
from sklearn import cluster
from sklearn import datasets
import matplotlib.pyplot as plt
import argparse

# Default bandwidth for the distance kernel (0.0 means it is estimated from the data).
# The bandwidth determines the number of clusters found.
bandwidth = 0.0

# Read command-line parameters (the help text is always printed so the options are visible)
parser = argparse.ArgumentParser(description='Mean-Shift Clustering')
parser.add_argument('-b', dest='bandwidth', type=float, default=bandwidth,  help='Bandwidth for  distance kernel')
parser.print_help()

bandwidth = parser.parse_args().bandwidth

# Plot Data / Clusters
def plotDataClusters(data, labels=None, centroids=None, new_centroids=None, fig=None, title="Data", show_ids=False):
    """Clear figure `fig` and draw the samples plus optional labels/centroids.

    Args:
        data: sample matrix; columns 0 and 1 are plotted as x1 / x2.
        labels: optional per-sample cluster index used to colour the samples.
        centroids: optional centroid matrix (columns 0 and 1 plotted). When
            `labels` is also given, centroid i is drawn with the colour of label i;
            otherwise centroids are drawn in red.
        new_centroids: optional updated mean-shift centroids, drawn in blue.
        fig: figure identifier passed to `plt.figure`.
        title: plot title.
        show_ids: if True, annotate every sample with its row index.
    """
    plt.figure(fig, figsize=(8, 6))
    plt.clf()
    plt.title(title)
    plt.xlabel("x1")
    plt.ylabel("x2")
    plt.gcf().canvas.manager.set_window_title('Mean-Shift Clustering')

    # Shared colour scale for samples and centroids. Without explicit vmin/vmax each
    # scatter call autoscales to its own min/max, so a label value missing from
    # `labels` would make sample colours disagree with the centroid colours.
    color_range = {}
    if labels is not None:
        labels = np.asarray(labels)
        lower = 0
        upper = len(centroids) - 1 if centroids is not None else 0
        if labels.size:
            lower = min(lower, labels.min())
            upper = max(upper, labels.max())
        color_range = dict(vmin=lower, vmax=upper)
        plt.scatter(data[:, 0], data[:, 1], marker='.', c=labels, cmap='rainbow', label="Data", **color_range)
    else:
        plt.scatter(data[:, 0], data[:, 1], marker='.', label="Data")

    # Plot item ids
    if show_ids:
        for i, (x, y) in enumerate(zip(data[:, 0], data[:, 1])):
            plt.annotate(str(i), xy=(x, y), xytext=(-3, 3), textcoords='offset points', ha='right', va='bottom')

    # Plot centroids (coloured like their cluster when labels are known)
    if centroids is not None:
        if labels is not None:
            labelRange = np.arange(len(centroids))
            plt.scatter(centroids[:, 0], centroids[:, 1], marker='x', s=40, c=labelRange, cmap='rainbow',
                        label="Centroids", **color_range)
        else:
            plt.scatter(centroids[:, 0], centroids[:, 1], marker='x', s=40, c='red', alpha=0.6, label="Centroids")

    # Plot updated mean-shift centroids
    if new_centroids is not None:
        plt.scatter(new_centroids[:, 0], new_centroids[:, 1], marker='x', s=40, c='blue', alpha=0.8, label="New Centroids")

    # Legend only when something besides the raw data is drawn
    if centroids is not None or new_centroids is not None:
        plt.legend(loc='lower right')

    plt.draw()
    plt.pause(0.1)  # gives control to GUI event manager to show the plot
# end plotDataClusters


# Step-by-step Mean-Shift clustering
def meanShiftStep(data, bandwidth=0.0, fig=1, max_iter=50, tol=1e-3, quantile=0.2):
    """Run an interactive, step-by-step Gaussian-kernel mean-shift clustering.

    Every sample starts as a candidate centroid. Each iteration moves every
    candidate to the Gaussian-weighted mean of all samples, plots the old and
    new positions, and waits for the user to press return. Iteration stops when
    the largest shift is below `tol` or after `max_iter` iterations. Candidates
    closer than `bandwidth / 2` to an already accepted centroid are then merged,
    and each sample is labelled with its nearest centroid.

    Args:
        data: sample matrix (rows are samples; columns 0 and 1 are plotted).
        bandwidth: Gaussian kernel width. 0.0 means estimate it with
            `cluster.estimate_bandwidth(data, quantile=quantile)`.
        fig: matplotlib figure identifier used for the plots.
        max_iter: maximum number of mean-shift iterations.
        tol: convergence threshold on the largest centroid shift.
        quantile: quantile passed to `estimate_bandwidth` when bandwidth is 0.0.

    Returns:
        (compactness, labels, centroids): mean distance from each sample to its
        nearest centroid, the per-sample cluster index, and the merged centroids.

    Raises:
        ValueError: if the (estimated) bandwidth is not a positive number.
    """
    # Work in floating point: with integer input the shifted centroids would be truncated
    data = np.asarray(data, dtype=float)

    plotDataClusters(data, fig=fig, title="Data")  # plot data
    input("...Press return to continue...")

    # bandwidth for distance kernel. if 0 it is estimated with sklearn.cluster.estimate_bandwidth
    if bandwidth == 0.0:
        bandwidth = cluster.estimate_bandwidth(data, quantile=quantile)
    print(f"{bandwidth=:0.2f}")
    if not bandwidth > 0:
        # e.g. all samples identical: the Gaussian kernel would divide by zero (or NaN)
        raise ValueError(f"bandwidth must be positive, got {bandwidth}")

    centroids = np.copy(data)  # start with every sample as a candidate centroid
    for iteration in range(max_iter):
        # Gaussian kernel weight of every sample (columns) w.r.t. every centroid (rows).
        # Uses the `data` argument, never a module-level variable.
        D = distance.cdist(centroids, data, metric='euclidean')
        weights = np.exp(-0.5 * (D ** 2) / (bandwidth ** 2))
        # Shift each centroid to the kernel-weighted mean of the samples
        new_centroids = (weights @ data) / weights.sum(axis=1, keepdims=True)

        shift_distances = np.linalg.norm(new_centroids - centroids, axis=1)

        plotDataClusters(data, centroids=centroids, new_centroids=new_centroids, fig=fig,
                         title=f"Mean-Shift Clustering - Step {iteration+1:2} - {bandwidth=:.3f}")

        input('\n...Press return to continue...')

        centroids = new_centroids  # keep the latest (converged) positions
        if np.max(shift_distances) < tol:
            break

    # Group very close centroids (< bandwidth/2): keep the first of each group
    final_centroids = []
    for c in centroids:
        if not any(np.linalg.norm(c - fc) < bandwidth / 2 for fc in final_centroids):
            final_centroids.append(c)
    centroids = np.array(final_centroids)
    numClusters = len(centroids)

    # Label samples: distance matrix between samples (rows) and centroids (columns)
    D = distance.cdist(data, centroids, metric='euclidean')

    # Nearest centroid for every sample
    labels = D.argmin(axis=1)
    distances = D.min(axis=1)
    compactness = np.mean(distances)

    print(f"{compactness=:.2f}")
    print(f"Centroids:\n{centroids}")
    # plot cluster labels
    plotDataClusters(data, centroids=centroids, labels=labels, fig=fig,
                     title=f"Mean-Shift Clustering:  Compactness={compactness:.2f} - {numClusters} clusters")
    input('...Press return to continue...')

    return compactness, labels, centroids
#end meanShiftStep()


# Mean-Shift clustering sklearn.cluster
def meanShift(data, bandwidth=0.0, fig=1, quantile=0.35):
    """Cluster `data` with sklearn.cluster.MeanShift and plot the result.

    Args:
        data: sample matrix (rows are samples; columns 0 and 1 are plotted).
        bandwidth: kernel bandwidth. 0.0 means estimate it with
            `cluster.estimate_bandwidth(data, quantile=quantile)`.
        fig: matplotlib figure identifier used for the plots.
        quantile: quantile passed to `estimate_bandwidth` when bandwidth is 0.0.

    Returns:
        (compactness, labels, centroids): mean distance from each sample to its
        nearest centroid, sklearn's per-sample labels, and the cluster centres.
    """
    plotDataClusters(data, fig=fig, title="Data")  # plot data
    input("...Press return to continue...")

    # Mean-Shift clustering:  sklearn.cluster
    # bandwidth for distance kernel. if 0 it is estimated with sklearn.cluster.estimate_bandwidth
    if bandwidth == 0.0:
        bandwidth = cluster.estimate_bandwidth(data, quantile=quantile)
    print(f"{bandwidth=:.2f}")

    print(f"Mean-Shift clustering (scklearn): {bandwidth=:.3f} ")
    mean_shift = cluster.MeanShift(bandwidth=bandwidth)
    mean_shift.fit(data)

    # clusters and centroids
    labels = mean_shift.labels_  # cluster label of each sample
    centroids = mean_shift.cluster_centers_  # one row per cluster centre
    numClusters = len(centroids)  # effective number of clusters

    print(f"{numClusters} clusters: {labels}")

    # Compactness: mean distance from every sample to its nearest centroid
    # (distance matrix samples x centroids via scipy.spatial.distance)
    D = distance.cdist(data, centroids, metric='euclidean')
    distances = D.min(axis=1)
    compactness = np.mean(distances)

    print(f"{compactness=:.2f}")
    print(f"Centroids:\n{centroids}")
    plotDataClusters(data, labels=labels, centroids=centroids, fig=fig,
                     title=f"Mean-Shift Clustering (scklearn) - Compactness={compactness:.2f} - {numClusters} clusters")

    return compactness, labels, centroids
# end meanShift()


#########################################

if __name__ == "__main__":
    # Load data.
    # Alternative small hand-written dataset:
    # X = np.array([[5,3], [10,15], [15,12], [24,10], [30,30], [15,60], [25,75], [15,92],
    #               [30,55], [35,85], [85,70], [71,80], [60,78], [65,55], [80,91],]).astype(float)

    # Synthetic dataset from sklearn.datasets: 100 samples around 3 centres
    X, _ = datasets.make_blobs(n_samples=100, centers=3, cluster_std=1.5, random_state=45)

    print('\nMean-Shift Clustering Step by Step:\n')
    meanShiftStep(X, bandwidth=bandwidth)

    # The program continues with the sklearn version, so do not announce the end yet
    input('\n...Press return to continue...')

    print('\nMean-Shift Clustering (sklearn.cluster):\n')
    meanShift(X, bandwidth=bandwidth)

    input('\n...Press return to finish program...')


