"""
//*************************************************************************
// Example program using OpenCV library
//					
// @file	agnes.py					
// @author Luis M. Jimenez
// @date 2022
//
// @brief Course: Computer Vision (1782)
// Dept. of Systems Engineering and Automation
// Automation, Robotics and Computer Vision Lab (ARVC)
// http://arvc.umh.es
// University Miguel Hernandez
//
// @note Description: 
//	- Hierarchical clustering: Agglomerative Nesting (AGNES) with centroid selection
//
//  Dependencies:
//     scipy, matplotlib, numpy, argparse
//*************************************************************************
"""
import numpy as np
from scipy.spatial import distance
from scipy import cluster
import matplotlib.pyplot as plt
import argparse

# Default number of clusters (used when -c is not given on the command line)
numClusters = 3

# Read parameters
parser = argparse.ArgumentParser(description='AGNES Hierarchical Clustering')
# '%(default)s' is expanded by argparse, so the help text always matches the real default
parser.add_argument('-c', dest='numClusters', type=int, default=numClusters,  help='Number of Clusters (Default %(default)s)')
parser.print_help()

numClusters = parser.parse_args().numClusters
# fcluster(criterion='maxclust') needs at least one cluster; reject bad input up front
# (parser.error prints the usage message and exits the program)
if numClusters < 1:
    parser.error(f"number of clusters must be >= 1 (got {numClusters})")

# Plot Data / Clusters
def plotDataClusters(data, labels=None, centroids=None, fig=None, title="Data", show_ids=True):
    """Draw a 2-D scatter plot of the samples, optionally coloured by cluster.

    data      : sample array; columns 0 and 1 are plotted as x1 / x2
    labels    : optional per-sample colour values (e.g. cluster labels), 'rainbow' colormap
    centroids : optional array of centroid points, drawn as 'x' markers plus a legend
    fig       : matplotlib figure number/identifier to reuse (it is cleared first)
    title     : plot title
    show_ids  : when True, every sample is annotated with its row index
    """
    plt.figure(fig, figsize=(8, 6))
    plt.clf()
    plt.title(title)
    plt.xlabel("x1")
    plt.ylabel("x2")
    plt.gcf().canvas.manager.set_window_title('AGNES: Aglomerative Nesting Clustering')

    xs, ys = data[:, 0], data[:, 1]

    # Samples: colour them only when labels were supplied
    colour_kwargs = {} if labels is None else {"c": labels, "cmap": "rainbow"}
    plt.scatter(xs, ys, marker='.', label="Data", **colour_kwargs)

    # Sample ids next to each point
    if show_ids:
        for sample_id, point in enumerate(zip(xs, ys)):
            plt.annotate(str(sample_id), xy=point, xytext=(-3, 3), textcoords='offset points',
                         ha='right', va='bottom')

    # Centroids are coloured by their position in the array, then the legend is shown
    if centroids is not None:
        plt.scatter(centroids[:, 0], centroids[:, 1], marker='x', s=40,
                    c=range(len(centroids)), cmap='rainbow', label="Centroids")
        plt.legend(loc='lower right')

    plt.draw()
    plt.pause(0.1)  # hand control to the GUI event loop so the window is refreshed
# end plotDataClusters

# Plot Dendrogram
def plotDendogram(Z, labels=None, threshold=None, fig=2, title="Dendogram"):
    """Plot the clustering tree encoded in a linkage matrix as a dendrogram.

    Z         : linkage matrix, as returned by scipy.cluster.hierarchy.linkage
    labels    : optional leaf labels forwarded to scipy's dendrogram
                (None: scipy labels the leaves with the observation indices)
    threshold : if not None, height at which a dashed horizontal reference line is drawn
    fig       : matplotlib figure number/identifier to (re)use
    title     : plot title
    """
    plt.figure(fig, figsize=(10, 7))
    # Clear the figure so repeated calls on the same figure number do not stack several
    # dendrograms / threshold lines on top of each other (plotDataClusters does the same).
    plt.clf()
    cluster.hierarchy.dendrogram(Z, orientation='top', labels=labels, distance_sort='descending', show_leaf_counts=True)
    plt.title(title)
    plt.xlabel("sample id")
    plt.ylabel("distance")
    plt.gcf().canvas.manager.set_window_title('AGNES: Aglomerative Nesting Clustering')

    if threshold is not None:
        plt.axhline(y=threshold, color="black", linestyle="--")
    plt.draw()
    plt.pause(0.1)  # gives control to GUI event manager to show the plot
# end plotDendogram


# AGNES  Hierarchical clustering 
def agnes(data, numClusters, fig=1, method='single', threshold=24):
    """Run AGNES (agglomerative nesting) hierarchical clustering and plot each stage.

    The function is interactive: after each plot it blocks on input() until the user
    presses return.

    Parameters
    ----------
    data : sample array, one row per sample (the plots use its first two columns)
    numClusters : requested maximum number of flat clusters (fcluster 'maxclust' criterion)
    fig : figure number for the data plots; the dendrogram is drawn on figure fig+1
    method : linkage method forwarded to scipy.cluster.hierarchy.linkage (default 'single')
    threshold : height of the dashed reference line drawn on the dendrogram, or None for
        no line. The default 24 is a fixed value; it is not derived from numClusters.

    Returns
    -------
    compactness : average, over clusters, of the chosen centroid's mean distance to
        the members of its cluster
    labels : flat cluster label of each sample, as returned by fcluster
    centroids : array with one centroid per cluster, in increasing label order. Each
        centroid is an actual sample: the member with minimum mean distance.
    """
    plotDataClusters(data, fig=fig, title="Data")  # plot data
    input("...Press return to continue...")

    # Hierarchical clustering: Agglomerative Nesting
    # condensed (upper-triangular) Euclidean distance matrix
    D = distance.pdist(data, metric='euclidean')

    # build the grouping tree (dendrogram)
    Z = cluster.hierarchy.linkage(D, method=method)

    plotDendogram(Z, threshold=threshold, fig=fig+1)   # plot dendrogram
    input("...Press return to continue...")

    # cluster data from the dendrogram
    # criterion: 'maxclust'  -  t: max number of clusters
    # criterion: 'distance'  -  t: max cophenetic distance
    labels = cluster.hierarchy.fcluster(Z, t=numClusters, criterion='maxclust')
    # np.unique returns the labels sorted, so centroids[i] belongs to the i-th smallest
    # label. plotDataClusters colours centroids by position (range(len(centroids))), so this
    # keeps centroid colours matched to cluster colours. Iterating a set gave no such guarantee.
    labelRange = np.unique(labels)
    numClusters = len(labelRange)   # effective number of clusters (may be below the request)

    print(f"{numClusters} clusters: {labels}")
    plotDataClusters(data, labels, fig=fig,
                    title=f"Hierarchical Clustering - {numClusters} clusters")
    input("...Press return to continue...")

    # select each cluster's centroid (member with minimum mean distance)
    print("Centroids:")
    Dm = distance.squareform(D)     # full symmetric n x n distance matrix
    centroids = []
    compactness = 0.0
    for idLabel in labelRange:
        idx = np.flatnonzero(labels == idLabel)   # sample indices belonging to this cluster
        cluster_dist = Dm[np.ix_(idx, idx)]       # intra-cluster distance sub-matrix

        # mean distance from each member to all members of the cluster
        # (its own zero self-distance is included in the average)
        distMed = np.mean(cluster_dist, axis=0)
        minId = distMed.argmin()    # position (within idx) of the most central member
        compactness += distMed.min()

        print(f"{idLabel=} - {idx[minId]=:2} - {distMed.min()=}")

        centroids.append(data[idx[minId]])
    # end for

    centroids = np.array(centroids)
    compactness /= numClusters

    print(f"{compactness=:.2f}")
    plotDataClusters(data, labels, centroids, fig=fig,
                    title=f"Hierarchical Clustering - Compactness={compactness:.2f} - {numClusters} clusters")

    return compactness, labels, centroids
# end agnes()


#########################################

# Sample dataset: 15 two-dimensional points, stored as a float array
X = np.array([
    [5, 3], [10, 15], [15, 12], [24, 10], [30, 30],
    [15, 60], [25, 75], [15, 92], [30, 55], [35, 85],
    [85, 70], [71, 80], [60, 78], [65, 55], [80, 91],
], dtype=float)

# Run the interactive AGNES demo with the cluster count read from the command line
print(f'\nAGNES Hierarchical clustering: try with {numClusters} clusters\n')
agnes(X, numClusters=numClusters, fig=1)

# Keep the plot windows open until the user confirms
input('\n...Press return to finish program...')




