"""
*************************************************************************
// Example program using OpenCV library
//      python >3.7 - OpenCV 4.5
// @file	e6.py
// @author Luis M. Jimenez
// @date 2022
//
// @brief Course: Computer Vision (1782)
// Dept. of Systems Engineering and Automation
// Automation, Robotics and Computer Vision Lab (ARVC)
// http://arvc.umh.es
// University Miguel Hernandez
//
// @note Description:
//  - Load training data from Titere/Weka
//  - Train classifiers: SVM, Bayes, MLP, DTree
//  - Validate classifier: Confusion matrix/Precision/Recall/TP-Rate/FP-Rate/F-Score
//
*************************************************************************
"""

# Import libraries
import cv2 as cv
import numpy as np
import argparse
import titere

# -----------------------------------------
# Global variables
# -----------------------------------------

DATA_FILE = 'train.arff'        # default training data file
TEST_FILE = 'test.arff'         # default test data file
CLASSIFIER = 'mlp'              # svm | bayes | dtree | mlp

# check command line parameters (dataFile, testFile, -c classifier)
parser = argparse.ArgumentParser(description='OpenCV example: classification')
parser.add_argument('dataFile', nargs='?', default=DATA_FILE,  help='training file')
parser.add_argument('testFile', nargs='?', default=TEST_FILE,  help='Test file')
parser.add_argument('-c', dest='classifier', type=str, default=CLASSIFIER, metavar='classifier', help='svm | bayes | dtree | mlp')

# parse argv once and reuse the namespace (original re-parsed it three times)
args = parser.parse_args()
DATA_FILE = args.dataFile
TEST_FILE = args.testFile
CLASSIFIER = args.classifier.lower()    # classifier name (lowercase)

# -----------------------------------------
# Initialize objects: load training and test data sets
# -----------------------------------------
# Feature column names and class-label column, as named in the ARFF files
attributes = ('Compacidad', 'Excentricidad',  'Rel_Invar_1', 'Rel_Invar_2')
labelColum = 'Pieza'

# Load training set (titere helper parses the Weka/ARFF file into a dict)
trainDataDict = titere.readWekaData(DATA_FILE, labelColum, attributes)
trainData = trainDataDict['dataMat']        # feature matrix, one row per sample
trainLabels = trainDataDict['label']        # vector with numeric labels (starts at 1)
trainLabelsName = trainDataDict['labelName'] # vector with string labels
labelRange = trainDataDict['labelRange']    # list with the numeric/string label association

# Load test set, passing the training labelRange so numeric labels stay consistent
testDataDict = titere.readWekaData(TEST_FILE, labelColum, attributes, labelRange)
testData = testDataDict['dataMat']
testLabels = testDataDict['label']          # vector with numeric labels (starts at 1)
testLabelsName = testDataDict['labelName'] # vector with string labels

# display loaded training data: features | numeric label | string label
print(f"{attributes} | {labelColum} (Num/Label)")
print(np.column_stack((trainData, trainLabels, trainLabelsName)))
print(f"{labelRange=}")


# -----------------------------------------
# Put your image processing code here
# -----------------------------------------

if CLASSIFIER == 'svm':
    # ----------------
    # SVM classifier
    # ----------------

    # create classifier
    svm = cv.ml.SVM.create()

    # Configure SVM Classifier

    # SVM Type: cv.ml.SVM_C_SVC | SVM_NU_SVC | SVM_ONE_CLASS | SVM_EPS_SVR | SVM_NU_SVR
    svm.setType(cv.ml.SVM_C_SVC)

    # Kernel type: cv.ml.SVM_LINEAR | SVM_POLY | SVM_RBF | SVM_SIGMOID | SVM_CHI2 | SVM_INTER | SVM_CUSTOM
    svm.setKernel(cv.ml.SVM_RBF)

    svm.setC(1.0)         # penalty multiplier for outliers (starting value; trainAuto tunes it)
    svm.setGamma(1.0)     # RBF kernel parameter (starting value; trainAuto tunes it)
    svm.setTermCriteria((cv.TermCriteria_MAX_ITER | cv.TermCriteria_EPS, 10000, 1e-6))

    # Train SVM with automatic tuning of C and Gamma, 5-fold cross-validation
    retval = svm.trainAuto(trainData, cv.ml.ROW_SAMPLE, trainLabels, 5)
    print(f"SVM - C: {svm.getC()} Gamma: {svm.getGamma()}")

    # Save trained classifier
    svm.save("SVM.json")

    # get the support vectors
    sv = svm.getSupportVectors()
    print("\n##################")
    print("SVM classifier")
    print("Support Vectors: ", len(sv))

    # Test SVM (Prediction): one predicted class per test row
    retval, prediction = svm.predict(testData)

    prediction = prediction[:, 0].astype(int)   # float column matrix -> integer label vector
    print("Predicted: ", prediction)
    print("Real:      ", testLabels)

    # Confusion matrix / precision / recall / F-score
    titere.showPerformance(testLabels, prediction, labelRange)

elif CLASSIFIER == 'bayes':
    #----------------
    # Bayes classifier
    #----------------
    bayes = cv.ml.NormalBayesClassifier.create()
    retval = bayes.train(trainData, cv.ml.ROW_SAMPLE, trainLabels)
    bayes.save("NormalBayes.json")    # Save trained classifier

    # Test Normal Bayes (Prediction)
    #retval, prediction = bayes.predict(testData)
    retval, prediction, prob = bayes.predictProb(testData)

    prediction = prediction[:, 0].astype(int)   # convert to integer vector
    print("\n##################")
    print("Normal Bayes classifier")
    print(f"Predicted: ", prediction)
    print(f"Real:      ", testLabels)
    print(f"Probability: [", end='')
    for r in range(len(prediction)):
        print(f"{prob[r,prediction[r]-1]:.3}, ", end='')
    print("]")

    titere.showPerformance(testLabels, prediction, labelRange)

elif CLASSIFIER == 'dtree':
    #------------------------
    # Decision Tree classifier
    #-----------------------
    dtree = cv.ml.DTrees.create()

    # Set up DTree's parameters
    dtree.setCVFolds(0)        #  If cv_folds > 1 then prune a tree with K-fold cross-validation where K is equal to cv_folds.
    dtree.setMaxDepth(10)       # Max depth of the decision tree
    dtree.setUse1SERule(False)     # If true then a pruning will be harsher. This will make a tree more compact and more resistant to the training data noise but a bit less accurate.
    dtree.setTruncatePrunedTree(True)  # If true then pruned branches are physically removed from the tree.
    dtree.setUseSurrogates(False)   #  If true then surrogate splits will be built. These splits allow to work with missing data and compute variable importance correctly.

    retval = dtree.train(trainData, cv.ml.ROW_SAMPLE, trainLabels)
    dtree.save("DTree.json")    # Save trained classifier

    # Test Decision Tree (Prediction)
    retval, prediction = dtree.predict(testData)

    prediction = prediction[:, 0].astype(int)   # convert to integer vector
    print("\n##################")
    print("Decision Tree classifier")
    print(f"Predicted: ", prediction)
    print(f"Real:      ", testLabels)

    titere.showPerformance(testLabels, prediction, labelRange)

elif CLASSIFIER == 'mlp':
    # ------------------------
    # MLP Neural Network classifier
    # -----------------------
    mlp = cv.ml.ANN_MLP_create()

    # Set up MLP parameters
    # FIRST layerSizes: integer vector specifying the number of neurons in each layer including the input and output layers.
    mlp.setLayerSizes(np.array([trainData.shape[1], 5, len(labelRange)]))  # 1 hidden layer  with 5 neurons

    # Activation Function: cv.ml.ANN_MLP_IDENTITY | ANN_MLP_SIGMOID_SYM | ANN_MLP_GAUSSIAN  | ANN_MLP_RELU  | ANN_MLP_LEAKYRELU
    # SIGMOID: param1: alpha, param2: beta
    mlp.setActivationFunction(cv.ml.ANN_MLP_SIGMOID_SYM, param1=0.1, param2=1.5)

    # Training methods: cv.ml.ANN_MLP_BACKPROP |  ANN_MLP_RPROP  | ANN_MLP_ANNEAL
    mlp.setTrainMethod(cv.ml.ANN_MLP_BACKPROP)

    mlp.setTermCriteria((cv.TermCriteria_MAX_ITER | cv.TermCriteria_EPS, 10000, 1e-6))

    # adapt trainLabels vector to a float32 matrix with one colum per class
    trainLabelsMat = np.zeros((len(trainLabels), len(labelRange)), dtype=np.float32)
    for row, val in enumerate(trainLabels):
        trainLabelsMat[row, val-1] = 1.0

    retval = mlp.train(trainData, cv.ml.ROW_SAMPLE, trainLabelsMat)

    # Alternative using a  TrainData Object
    # trainDataObj = cv.ml.TrainData.create(trainData, cv.ml.ROW_SAMPLE, trainLabelsMat)
    #retval = mlp.train(trainDataObj)
    
    mlp.save("MLP.json")    # Save trained classifier

    # Test MLP Neural Network (Prediction)
    retval, predictionMat = mlp.predict(testData)

    # convert predictionMat to vector: search for max response across classes (colums)
    prediction = predictionMat.argmax(axis=1) + 1

    print("\n##################")
    print("MLP Neural Network classifier")
    print(f"Predicted: ", prediction)
    print(f"Real:      ", testLabels)

    titere.showPerformance(testLabels, prediction, labelRange)

# -----------------------------------------
# Put your visualization code here
# -----------------------------------------




# -----------------------------------------
# free windows and camera resources
# -----------------------------------------
# nothing to release: this example opens no windows or camera
pass
