"""
// Python module with functions to read data files from Titere/Weka
//      python >3.8
// @file	titere.py
// @author Luis M. Jimenez
// @date 2022
// Dept. of Systems Engineering and Automation
// Automation, Robotics and Computer Vision Lab (ARVC)
// http://arvc.umh.es
// University Miguel Hernandez
// Functions:
    readWekaData(file, labelColum="", attributes=[], labelRange=[])
    readWekaDataTitere(file, labelRange=[])
    showPerformance(label, prediction, labelRange)
    plotLearningCurves(metrics)
"""

# Example:
# import numpy as np
# import titere
#
# attributes = ('Compacidad', 'Excentricidad',  'Rel_Invar_1', 'Rel_Invar_2')
# labelColum = 'Pieza'
# data = titere.readWekaData('datos1.arff', labelColum, attributes)

# display data
# print(f"{data['attributes']} | {data['labelColum']} (Num/Label)")
# print(np.column_stack( (data['dataMat'], data['label'], data['labelName']) ))
# print(f"{data['labelRange']=}")


import numpy as np
from scipy.io import arff
import matplotlib.pyplot as plt
import matplotlib

"""
Read Weka arff data file
inputs: 
    file: arff file 
    labelColum:  str column name for response (label) (if empty uses the last column)
    attributes: (list/tuple of str) with selected column names (all if it is None or empty list)
    labelRange: list with the range of labels (if empty, it is obtained from data) numeric/string label association
outputs: dictionary with the following elements:
        'dataMat': dataMat,     # Matrix with attributes data
        'label': label,         # vector with numeric labels starting at 1
        'labelName': labelName, # vector with string labels  
        'labelRange': labelRange,    # list with the numeric/string label association
        'attributes': attributes,   # selected attributes
        'labelColum': labelColum   # selected label colum
"""
def readWekaData(file, labelColum="", attributes=None, labelRange=None):
    """Read a Weka .arff data file and split it into an attribute matrix and labels.

    inputs:
        file: arff file (anything accepted by scipy.io.arff.loadarff)
        labelColum: str column name for the response (label);
                    if empty or None the last column of the file is used
        attributes: list/tuple of str with the selected column names;
                    if None or empty, every column except labelColum is used
        labelRange: list with the range of labels (numeric/string label association);
                    if None or empty it is obtained from the data (sorted unique labels)
    outputs: dictionary with the following elements:
        'dataMat': float32 matrix with the attribute data (one column per attribute)
        'label': int vector with numeric labels starting at 1, referred to
                 labelRange (0 for samples whose label is not in labelRange)
        'labelName': vector with string labels
        'labelRange': list with the numeric/string label association
        'attributes': selected attributes
        'labelColum': selected label colum
    raises:
        ValueError: if labelColum is not a column of the file
    """
    # Load weka data file (.arff): structured array + metadata with the column names
    data, metadata = arff.loadarff(file)
    attrNames = metadata.names()

    # Set label colum if empty: use the last colum
    if not labelColum:
        labelColum = attrNames[-1]
    if labelColum not in attrNames:
        raise ValueError(f"label column '{labelColum}' not found in {attrNames}")

    # No selection: take all columns but the label one.
    # A NEW list is built on every call; the argument (or a shared default)
    # is never mutated, so successive calls cannot contaminate each other.
    if attributes is None or len(attributes) == 0:
        attributes = [name for name in attrNames if name != labelColum]

    # Convert the structured array (indexable by column name) to a standard np.array
    dataMat = np.empty((len(data), len(attributes)), dtype=np.float32)
    for col, name in enumerate(attributes):
        dataMat[:, col] = data[name].astype(np.float32)

    # Label vector string
    labelName = data[labelColum].astype(str)

    # Range of labels taken from the data when it is not given: sorted unique labels
    if labelRange is None or len(labelRange) == 0:
        labelRange = sorted(set(labelName))

    # Label vector (int) class number (starting at 1) referred to labelRange;
    # entries whose label is not in labelRange keep the initial value 0
    label = np.zeros(len(labelName), dtype=int)
    for idx, item in enumerate(labelRange):
        label[labelName == item] = idx + 1

    # build result dictionary
    return {
        'dataMat': dataMat,          # Matrix with attributes data
        'label': label,              # vector with numeric labels starting at 1
        'labelName': labelName,      # vector with string labels
        'labelRange': labelRange,    # list with the numeric/string label association
        'attributes': attributes,    # selected attributes
        'labelColum': labelColum,    # selected label colum
    }
    
# End of Function readWekaData


"""
Read Weka arff data file: retrieves only these columns generated by TITERE:
        ('Compacidad', 'Excentricidad',  'Rel_Invar_1', 'Rel_Invar_2')
    Label colum: 'Pieza'
"""
def readWekaDataTitere(file, labelRange=[]):
    """Read a Weka .arff file keeping only the columns generated by TITERE.

    inputs:
        file: arff file, forwarded unchanged to readWekaData
        labelRange: list with the range of labels, forwarded unchanged to readWekaData
    outputs: the value returned by readWekaData for
        attributes ('Compacidad', 'Excentricidad', 'Rel_Invar_1', 'Rel_Invar_2')
        and label colum 'Pieza'
    """
    # Fixed column selection produced by TITERE
    titereColumns = ('Compacidad', 'Excentricidad',  'Rel_Invar_1', 'Rel_Invar_2')

    return readWekaData(file, 'Pieza', titereColumns, labelRange)

# End of Function readWekaDataTitere


"""
Calculates Confusion Matrix, Precision/Recall/F-Score, TP-Rate, FP-Rate
Shows performance
inputs: 
    label: list/ndarray test label (row vector) 
    prediction:  list/ndarray  predicted label (row vector)
    labelRange: list with the range of labels (numeric/string label association)
outputs: accuracy, confusionMat
"""
def _buildConfusionMatrix(label, prediction, nc):
    """Count (actual, predicted) class pairs into an nc x nc int32 matrix.

    label and prediction are flattened and cast to integers, so lists, 1-D
    arrays and 2-D row/column vectors are all accepted.
    Raises ValueError on a size mismatch and IndexError when a class number
    is outside 1..nc.
    """
    actual = np.asarray(label).ravel().astype(np.intp)
    predicted = np.asarray(prediction).ravel().astype(np.intp)

    if actual.size != predicted.size:
        raise ValueError(
            "label and prediction must have the same number of elements "
            f"({actual.size} != {predicted.size})")

    # Class numbers start at 1. Without this check a class number 0 would
    # become index -1 and be silently counted as the LAST class.
    outOfRange = (actual < 1) | (actual > nc) | (predicted < 1) | (predicted > nc)
    if outOfRange.any():
        raise IndexError(f"class numbers must be in the range 1..{nc}")

    confusionMat = np.zeros((nc, nc), dtype=np.int32)
    # np.add.at accumulates repeated (row, col) pairs; plain fancy-index "+= 1" does not
    np.add.at(confusionMat, (actual - 1, predicted - 1), 1)
    return confusionMat


def showPerformance(label, prediction, labelRange):
    """Compute and print Confusion Matrix, F-Score/Precision/Recall, FP-Rate and Accuracy.

    inputs:
        label: list/ndarray with the actual class numbers (1..len(labelRange))
        prediction: list/ndarray with the predicted class numbers, same size as label
        labelRange: list with the range of labels (numeric/string label association);
                    entry k is printed as the name of class number k+1
    outputs: (accuracy, confusionMat)
        accuracy: fraction of samples on the diagonal of the confusion matrix
                  (0.0 when there are no samples)
        confusionMat: int32 matrix (rows: actual class, cols: predicted class)
    raises:
        ValueError: label and prediction have a different number of elements
        IndexError: a class number is outside 1..len(labelRange)
    """
    nc = len(labelRange)

    # Confusion matrix: (rows: actual class, cols: predicted class)
    confusionMat = _buildConfusionMatrix(label, prediction, nc)
    rowTotals = confusionMat.sum(axis=1)    # samples per actual class
    colTotals = confusionMat.sum(axis=0)    # samples per predicted class

    test_count = confusionMat.sum()
    # sum of diagonal elements / number of samples (guarded against empty input: no 0/0)
    accuracy = confusionMat.trace(dtype=np.float32) / test_count if test_count != 0 else 0.0

    # Show class number / class name association
    print("\n##################")
    print("Class Labels:")
    print("-------------")
    for r, name in enumerate(labelRange):
        print(f"Class ({r+1}): {name}")

    # Show confusion matrix data
    margin = "         "
    ruler = margin + " ----" * nc + "    -----"

    print("\nConfusion Matrix:")
    print("-----------------")
    print("                Prediction (colums)")
    print(margin + "".join(f"  ({c+1})" for c in range(nc)) + "   |Total")
    print(ruler)

    for r in range(nc):
        cells = ", ".join(f"{count:3}" for count in confusionMat[r])
        print(f"Class ({r+1}): {cells}   | ({rowTotals[r]:2})")

    print(ruler)
    print("  Total  " + "".join(f" ({total:2})" for total in colTotals)
          + f"   | ({test_count:2})")

    # Per-class metrics (one-vs-rest)
    print("\n           F-Score | Precision | Recall/TPRate |  FPRate  ")
    print("----------------------------------------------------------")
    for r in range(nc):
        truePos = confusionMat[r, r]

        # Metrics with a zero denominator are reported as 0.0
        recall = truePos / rowTotals[r] if rowTotals[r] != 0 else 0.0
        precision = truePos / colTotals[r] if colTotals[r] != 0 else 0.0
        fScore = 0.0
        if (recall + precision) != 0:
            fScore = 2 * recall * precision / (recall + precision)

        # FP rate: samples of the other classes predicted as r / samples of the other classes
        falsePos = colTotals[r] - truePos
        negatives = test_count - rowTotals[r]
        fpRate = falsePos / negatives if negatives != 0 else 1.0   # 1.0 when there are no other-class samples

        print(f"Class ({r + 1}): {fScore:7.3} | {precision:9.3} | {recall:9.3}     | {fpRate:7.3}")

    print(f"\nAccuracy: {accuracy*100}%")

    return accuracy, confusionMat

# End of Function showPerformance



# Plot Training Metrics (dictionary)
def plotLearningCurves(metrics):
    """Plot all the training metrics as curves in one non-blocking figure.

    inputs:
        metrics: dictionary {metric name: sequence of values}; each entry is
                 drawn as a '.-' line against its index (x axis labelled
                 'Epoch') and listed in the legend under its name
    """
    # interactive mode, so that plotting does not block the caller
    plt.ion()

    for curve_name, curve_values in metrics.items():
        plt.plot(curve_values, '.-', label=curve_name)

    # Axes decoration
    plt.xlabel('Epoch')
    plt.ylabel('Value')
    plt.title('Training Metrics')
    plt.legend(loc='best')

    # short non blocking call to the GUI event manager to update the window
    plt.pause(0.1)
    
# End of Function plotLearningCurves