"""
// Python module with functions to train and predict a torch model and plot metrics
//      python >3.8
// @file	torch_util.py
// @author Luis M. Jimenez
// @date 2025
// Dept. of Systems Engineering and Automation
// Automation, Robotics and Computer Vision Lab (ARVC)
// http://arvc.umh.es
// University Miguel Hernandez
//
// Functions:
    trainModel(model, train_loader, val_loader, criterion, optimizer, epochs, device)
    predictModel(model, data_loader, device)
    showPerformance(label, prediction, labelRange)
    plotLearningCurves(metrics)
"""
import torch
import numpy as np
import matplotlib.pyplot as plt
import matplotlib

"""
Train model Function
inputs: 
    model:          torch.nn.Model 
    train_loader:   traning data torch.utils.data.DataLoader
    val_loader:     validation data torch.utils.data.DataLoader
    criterion:      Loss function torch.nn
    optimizer:      optimizer torch.optim
    epochs:         number of epochs
    device:         acceleration device torch.device('cpu', 'cuda', ...) Default 'cpu'
outputs: dictionary with the training metrics:
    { 'loss':         per epoch training data loss
      'val_loss':     per epoch validation data loss
      'accuracy':     per epoch training data accuracy
      'val_accuracy'  per epoch validation data accuracy
    }
"""
def trainModel(model, train_loader, val_loader, criterion, optimizer, epochs=10, device=torch.device("cpu")):
    # init metrics history
    history = {
        'loss': [],
        'val_loss': [],
        'accuracy': [],
        'val_accuracy': []
    }

    print(f"\nTraining Model:")
    for epoch in range(epochs):
        model.train()   # configure model for training
        running_loss, correct, total = 0.0, 0, 0

        print(f"Epoch [{epoch + 1}/{epochs}] ", end='')
        for batch_idx, (inputs, labels) in enumerate(train_loader):
            inputs, labels = inputs.to(device), labels.to(device)
       
            # Forward pass  
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            # Backward pass
            optimizer.zero_grad()   # clear out batch accumulative gradients before backpropagation
            loss.backward()         # loss backpropagation    
            optimizer.step()        # update model parameters

             # Accumulate batch loss/accuray metrics
            running_loss += loss.item() * inputs.size(0)    # add batch loss metric
            _, predicted = torch.max(outputs, 1)                    # categorical to vector
            if labels.dim() == 2 : _, labels = torch.max(labels, 1) # categorical to vector

            correct += (predicted == labels).sum().item()   # add batch accuracy metric
            total += labels.size(0)

            # Verbose: print current epoch metrics
            print(f"\rEpoch [{epoch + 1}/{epochs}] - Batch [{batch_idx:02}] | "
                  f"Train Loss: {running_loss/total:.4f}, Train Acc: {correct/total:.4f} ", end='', flush=True )
        #end train batch loop

        train_loss = running_loss / total       # training epoch loss
        train_acc = correct / total             # training epoch accuracy

        # Validation metrics
        model.eval()    # configure model for inference
        val_loss, val_correct, val_total = 0.0, 0, 0
        with torch.no_grad():   # disable gradient calculation for inference
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, labels)

                val_loss += loss.item() * inputs.size(0)            # add batch loss metric
                _, predicted = torch.max(outputs, 1)                    # categorical to vector
                if labels.dim() == 2 : _, labels = torch.max(labels, 1) # categorical to vector

                val_correct += (predicted == labels).sum().item()   # add batch accuracy metric
                val_total += labels.size(0)
                print(f".", end='', flush=True)     # print progressing status
            #end validation batch loop
        #end with torch.no_grad()
        
        val_loss /= val_total                   # validation epoch loss
        val_acc = val_correct / val_total       # validation epoch accuracy

        # add epoch metrics to history dictionary
        history['loss'].append(train_loss)
        history['val_loss'].append(val_loss)
        history['accuracy'].append(train_acc)
        history['val_accuracy'].append(val_acc)

        # Verbose: print current epoch metrics
        print(f"\rEpoch [{epoch + 1}/{epochs}] - "
              f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, "
              f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")
    # end for epoch  loop 

    return history
#end def trainModel()


"""
Predict model Function
inputs: 
    model:          torch.nn.Model 
    data_loader:    input data torch.utils.data.DataLoader
    device:         acceleration device torch.device('cpu', 'cuda', ...) Default 'cpu'
outputs:
        y_pred:    list with class predictions
        prob:      list with probabilities
        y_test:    list with actual class
"""
# Predict model Function
def predictModel(model, data_loader, device=torch.device("cpu")):
    # init metrics history
    y_pred = []
    prob = []
    y_test = []

    print(f"\nPredict Model: ", end='', flush=True)
    model.eval()    # configure model for inference
    with torch.no_grad():   # disable gradient calculation for inference
        for inputs, labels in data_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)

            # softmax normalization layer to get probabilities instead of logits
            outputs = torch.softmax(outputs, dim=1)

            # convert categorical prediction outputs to vector: search for max response across classes (colums)
            probability, predicted = torch.max(outputs, 1)              # categorical to vector
            if labels.dim() == 2 : _, labels = torch.max(labels, 1)     # categorical to vector

            y_pred.extend(predicted.tolist())
            prob.extend(probability.tolist())
            y_test.extend(labels.tolist())
            print(f".", end='', flush=True)     # print progressing status
    #end with torch.no_grad()
    
    print(f"\r\n")
    return y_pred, prob, y_test
#end def predictModel()

"""
Calculates Confusion Matrix, Precision/Recall/F-Score, TP-Rate, FP-Rate
Shows performace 
inputs: 
    label: list/ndarray test label (row vector) 
    prediction:  list/ndarray  predicted label (row vector)
    labelRange: list with the range of labels (numeric/string label association)
outputs: accuracy, confusionMat
"""
def showPerformance(label, prediction, labelRange):
    """
    Compute and print the confusion matrix plus per-class F-Score,
    Precision, Recall/TP-Rate and FP-Rate, then the overall accuracy.
    inputs:
        label:      list/ndarray actual class labels (row vector)
        prediction: list/ndarray predicted class labels (row vector)
        labelRange: list with the range of labels (numeric/string label association)
    outputs: (accuracy, confusionMat)
    """
    nc = len(labelRange)                                # number of classes
    confusionMat = np.zeros((nc, nc), dtype=np.int32)   # rows: actual, cols: predicted

    # Confusion matrix: (rows: actual class, cols: predicted class)
    # NOTE(review): the -1 offsets assume 1-based class labels (1..nc); with
    # 0-based labels (e.g. raw argmax indices as returned by predictModel)
    # class 0 would silently wrap to the last row/column through negative
    # indexing -- confirm the callers' label convention.
    for idx,predictedClass in enumerate(prediction):
        actualClass = label[idx]
        confusionMat[actualClass-1, predictedClass-1] += 1

    test_count = confusionMat.sum()     # total number of test samples
    accuracy = confusionMat.trace(dtype=np.float32) / test_count  # sum diagonal elements

    # Show confusion matrix data
    print("\n##################")
    print("Class Labels:")
    print("-------------")
    for r in range(nc):
        print(f"Class ({r+1}): {labelRange[r]}")

    print("\nConfusion Matrix:")
    print("-----------------")
    print(f"                Prediction (colums)")
    print(f"         ", end='')
    for c in range(nc):
        print(f"  ({c+1})", end='')
    print(f"   |Total")

    print("         ", end='')
    for c in range(nc):
        print(" ----", end='')
    print("    -----")

    # matrix body: one row per actual class, trailing row totals
    for r in range(nc):
        print(f"Class ({r+1}): ", end='')
        for c in range(nc):
            print(f"{confusionMat[r,c]:3}", end='')
            if c < nc - 1: print(", ", end='')
        print(f"   | ({confusionMat[r,:].sum():2})")

    print("         ", end='')
    for c in range(nc):
        print(" ----", end='')
    print("    -----")

    # column totals (per predicted class) and grand total
    print("  Total  ", end='')
    for c in range(nc):
        print(f" ({confusionMat[:,c].sum():2})", end='')
    print(f"   | ({confusionMat.sum():2})")


    # Per-class metrics derived from the confusion matrix
    print("\n           F-Score | Precision | Recall/TPRate |  FPRate  ")
    print("----------------------------------------------------------")
    for r in range(nc):
        # guard each ratio against empty rows/columns (division by zero)
        recall = 0.0; precision = 0.0; fScore = 0.0
        if confusionMat[r,:].sum() != 0:  recall = float(confusionMat[r, r]) / confusionMat[r,:].sum()
        if confusionMat[:,r].sum() != 0:  precision = float(confusionMat[r, r]) / confusionMat[:,r].sum()
        if (recall + precision) != 0: fScore = 2 * recall*precision / (recall + precision)

        # FP-Rate: false positives for class r over all samples whose actual class is not r
        fpRateNum = 0.0;	fpRateDen = 0.0;	fpRate = 1.0
        for j in range(nc):
            if j != r:
                fpRateNum += confusionMat[j, r]
                fpRateDen += confusionMat[j,:].sum()

        if fpRateDen != 0: fpRate = fpRateNum / fpRateDen
        print(f"Class ({r + 1}): {fScore:7.3} | {precision:9.3} | {recall:9.3}     | {fpRate:7.3}")
    # end for r in range(nc):

    print(f"\nAccuracy: {accuracy*100}%")
    return accuracy, confusionMat

# End of Function showPerformance



def plotLearningCurves(metrics):
    """
    Plot every training metric series on a single labelled figure.
    inputs:
        metrics: dictionary mapping metric name -> per-epoch list of values
    """
    plt.ion()   # interactive mode so the figure does not block the script

    # one curve per metric, keyed by its dictionary name
    for name in metrics:
        plt.plot(metrics[name], '.-', label=name)

    plt.title('Training Metrics')
    plt.xlabel('Epoch')
    plt.ylabel('Value')
    plt.legend(loc='best')
    plt.pause(0.1)    # non-blocking call so the GUI event loop refreshes the window

# end of plotLearningCurves function

