"""
//*************************************************************************
// Example program using OpenCV library
//					
// @file	adaboost.py					
// @author Luis M. Jimenez
// @date 2025
//
// @brief Course: Computer Vision (1782)
// Dept. of Systems Engineering and Automation
// Automation, Robotics and Computer Vision Lab (ARVC)
// http://arvc.umh.es
// University Miguel Hernandez
//
// @note Description: 
//	- AdaBoost Classifier:  multiple weak classifiers (single feature binary decision trees) 
//
//  Dependencies:
//     scipy, sklearn, matplotlib, numpy, argparse
//*************************************************************************
"""
import numpy as np
from sklearn import tree
import matplotlib.pyplot as plt
import argparse

# Default number of weak classifiers in the ensemble
numWeakClassifiers = 10

# Command-line interface: '-c' overrides the default ensemble size
parser = argparse.ArgumentParser(description='AdaBoost Classifier')
parser.add_argument('-c', dest='numWeakClassifiers', type=int,
                    default=numWeakClassifiers,
                    help='Number weak classifiers (single feature binary trees) (Default 10)')
parser.print_help()

args = parser.parse_args()
numWeakClassifiers = args.numWeakClassifiers


# Plot Data / Classification rules
def plotData(data, labels=None, weights=None, val_data=None, val_labels=None,
             weakClassifiers=None, fig=None, title="Data", show_ids=False):
    """Plot 2D training data, optional validation data and weak-classifier rules.

    Args:
        data: (n, 2) array of training samples (features x1, x2).
        labels: optional (n,) class labels used to color the training points.
        weights: optional (n,) sample weights, mapped to marker sizes.
        val_data: optional (m, 2) array of validation samples (drawn as 'x').
        val_labels: optional (m,) validation labels (colors of the 'x' markers).
        weakClassifiers: dict (or list of dicts) with keys 'feature',
            'threshold', 'alpha', 'order' describing single-feature stumps;
            each is drawn as a dashed threshold line plus shaded class regions.
        fig: matplotlib figure number to (re)use.
        title: plot title.
        show_ids: if True, annotate each validation point with its index.
    """
    # Prepare the figure
    plt.figure(fig, figsize=(9, 7))
    plt.clf()
    plt.title(title)
    plt.xlabel("x1")
    plt.ylabel("x2")
    plt.gcf().canvas.manager.set_window_title('AdaBoost Classifier')

    # Plot decision rules: dashed threshold line + lightly shaded half-planes
    if weakClassifiers is not None:
        if not isinstance(weakClassifiers, list):
            weakClassifiers = [weakClassifiers]
        # Axis limits from all plotted points (training + validation) with a 1-unit margin
        all_data = data if val_data is None else np.concatenate((data, val_data), axis=0)
        x_min, x_max = all_data[:, 0].min() - 1, all_data[:, 0].max() + 1
        y_min, y_max = all_data[:, 1].min() - 1, all_data[:, 1].max() + 1
        transp = 0.04  # shading transparency; hoisted out of the loop (loop-invariant)
        for wc in weakClassifiers:
            # 'order' True => class 1 (red) lies on the side val > threshold
            colors = ('blue', 'red') if wc['order'] else ('red', 'blue')
            if wc['feature'] == 0:    # decision on x1 (vertical threshold line)
                plt.axvline(x=wc['threshold'], color="black", linestyle="--", label=f"x1 ≤ {wc['threshold']:.2f} ({wc['alpha']:.2f})")
                plt.fill_betweenx([y_min, y_max], x_min, wc['threshold'], color=colors[0], alpha=transp)  # color left side
                plt.fill_betweenx([y_min, y_max], wc['threshold'], x_max, color=colors[1], alpha=transp)  # color right side
            elif wc['feature'] == 1:  # decision on x2 (horizontal threshold line)
                plt.axhline(y=wc['threshold'], color="black", linestyle="--", label=f"x2 ≤ {wc['threshold']:.2f} ({wc['alpha']:.2f})")
                plt.fill_between([x_min, x_max], y_min, wc['threshold'], color=colors[0], alpha=transp)  # color lower side
                plt.fill_between([x_min, x_max], wc['threshold'], y_max, color=colors[1], alpha=transp)  # color upper side

        if len(weakClassifiers) > 1:
            plt.legend(loc='lower right')

    # Map sample weights linearly into [min_size, max_size] so circles stay readable
    min_size = 20; max_size = 500
    if weights is not None:
        min_weight, max_weight = np.min(weights), np.max(weights)
        if min_weight == max_weight:
            # All weights equal: constant mid-small size avoids division by zero
            scaled_weights = np.ones(len(weights)) * min_size*2
        else:
            scaled_weights = min_size + (weights - min_weight) / (max_weight - min_weight) * (max_size - min_size)
    else:
        scaled_weights = np.ones(len(data)) * min_size*2

    # Plot training data (colored by label when labels are available)
    if labels is not None:
        plt.scatter(data[:,0], data[:,1], s=scaled_weights, marker='o', c=labels, cmap='coolwarm', alpha=0.6, edgecolors='k')
    else:
        plt.scatter(data[:,0], data[:,1], s=scaled_weights, marker='o', alpha=0.6, edgecolors='k')

    # Plot validation data as 'x' markers colored by their labels
    if val_data is not None and val_labels is not None:
        plt.scatter(val_data[:,0], val_data[:,1], marker='x', s=40, c=val_labels, cmap='coolwarm', label="Validation Data")

        # Optionally annotate each validation point with its index
        if show_ids:
            for i, (x, y) in enumerate(zip(val_data[:, 0], val_data[:, 1])):
                plt.annotate(str(i), xy=(x, y), xytext=(-3, 3), textcoords='offset points', ha='right', va='bottom')
        plt.legend(loc='lower right')

    plt.draw()
    plt.pause(0.1)  # gives control to GUI event manager to show the plot
# end plotData


class AdaBoost:
    """AdaBoost binary classifier built from single-feature decision stumps.

    Each weak learner is a depth-1 sklearn DecisionTreeClassifier trained on
    weighted samples; its vote is scaled by alpha = 0.5*ln((1-err)/err).
    """

    def __init__(self, n_estimators=50, fig=None):
        self.n_estimators = n_estimators    # maximum number of weak estimators
        self.fig = fig                      # figure number used for plotting
        self.inputLabelsSigned = False      # True if fit() received labels in {-1,1}; predict() mirrors the convention
        self.n_features = 0                 # number of input features (set by fit)
        self.weakModels = []                # trained weak models (binary decision stumps)
        self.weakModelsData = []            # per-stump dicts: 'feature', 'threshold',
                                            # 'alpha' (vote weight), and 'order'
                                            # (True => class 1 on the side feat > threshold)

    def fit(self, X, y, step_delay=0):
        """Train the ensemble on (X, y), plotting each stump as it is added.

        Args:
            X: (n_samples, n_features) training data.
            y: labels in {0,1} or {-1,1}; the convention is remembered so
               predict() returns labels of the same kind.
            step_delay: seconds to pause between stumps; negative waits for
               the user to press Enter instead.

        Returns:
            The number of weak classifiers actually trained.
        """
        n_samples, self.n_features = X.shape

        # Normalize labels to {-1, +1}; remember the caller's convention
        self.inputLabelsSigned = set(np.unique(y)) != {0, 1}
        if not self.inputLabelsSigned:
            y = 2*y - 1

        w = np.ones(n_samples) / n_samples  # uniform initial sample weights

        for i in range(self.n_estimators):
            # Train a one-feature binary stump on the weighted samples
            stump = tree.DecisionTreeClassifier(max_depth=1)
            stump.fit(X, y, sample_weight=w)

            t = stump.tree_
            # Majority class in the right child tells which side holds class 1
            right_class = np.argmax(t.value[t.children_right[0]])
            modelData = {
                'feature': t.feature[0],        # feature index chosen by the stump
                'threshold': t.threshold[0],    # split threshold chosen by the stump
                'order': right_class == 1,      # True => class 1 where feat > threshold
            }

            # Weighted error of this stump; epsilon guards against log(1/0)
            predictions = stump.predict(X)
            err = np.sum(w * (predictions != y)) / np.sum(w)
            modelData['alpha'] = 0.5 * np.log((1 - err) / (err + 1e-10))

            # Report and plot the stump before the weights are updated
            title= f"[Tree {i+1}]   {err=:.2f} - alpha: {modelData['alpha']:.2f} - feature: x{modelData['feature']+1} - threshold: {modelData['threshold']}"
            plotData(X, y, fig=self.fig, weights=w, weakClassifiers=modelData, title=title)
            print(title)

            if step_delay < 0:
                input("Press Enter to continue...")
            else:
                plt.pause(step_delay)  # wait each step

            # Reweight: misclassified samples gain weight, correct ones lose it
            w = w * np.exp(-modelData['alpha'] * y * predictions)
            w = w / np.sum(w)

            # Keep the trained stump and its metadata
            self.weakModels.append(stump)
            self.weakModelsData.append(modelData)

        self.n_estimators = len(self.weakModels)    # actual ensemble size

        title= f"AdaBoost Combined Decision Rule ({self.n_estimators} weak classifiers)"
        plotData(X, y, fig=self.fig, weights=w, weakClassifiers=self.weakModelsData, title=title)
        print(title)
        input("...Press return to continue...")

        return self.n_estimators
    # end fit

    def predict(self, X):
        """Return predicted labels for X, in the label convention seen by fit()."""
        # Weighted vote of all stumps, then take the sign
        votes = np.zeros(X.shape[0])
        for weak, info in zip(self.weakModels, self.weakModelsData):
            votes += info['alpha'] * weak.predict(X)
        votes = np.sign(votes)

        # Map back from {-1,1} to {0,1} if that is what fit() received
        if not self.inputLabelsSigned:
            votes = (votes + 1) / 2

        return votes.astype(int)
    # end predict
# end AdaBoost class


#########################################

# Toy 2D training set: 16 points, two classes (0 and 1)
X = np.array([[5,3], [10,15], [15,12], [24,10], [30,30], [15,60], [25,75], [50,30],
              [15,92], [30,55], [35,85], [85,70], [71,80], [60,78], [65,55], [80,91] ]).astype(float)
y = np.array([ 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1])

# Held-out validation points with their ground-truth labels
X_val = np.array([[40,35],[60,35],[35,65],[15,80]]).astype(float)
y_val = np.array([ 0, 0, 1, 1])

##########################################
print('\nAdaBoost Classifier:\n')
plotData(X, y, fig=1, title="Data")  # show raw data before training
input("...Press return to continue...")

# Train the ensemble step by step (step_delay=-1: wait for Enter after each stump)
model = AdaBoost(n_estimators=numWeakClassifiers)
numWeakClassifiers = model.fit(X, y, step_delay=-1)

# Evaluate on the validation points
predictions = model.predict(X_val)
accuracy = np.mean(predictions == y_val)

title = f"AdaBoost: Validation Data {accuracy=:.2f}  ({numWeakClassifiers} weak classifiers)"
plotData(X, y, val_data=X_val, val_labels=y_val,
         weakClassifiers=model.weakModelsData, title=title)

print(title)
print(f"Test Accuracy: {accuracy}")
print(f"Validation data predictions: {predictions}")
print(f"Validation data labels     : {y_val}")
input('\n...Press return to finish program...')




