"""
*************************************************************************
// Example program using the PyTorch library
//      python >3.8 - torch, torchvision (matplotlib, numpy, titere, torch_util)
// @file	et1.py
// @author Luis M. Jimenez
// @date 2024
//
// @brief Course: Computer Vision (1782)
// Dept. of Systems Engineering and Automation
// Automation, Robotics and Computer Vision Lab (ARVC)
// http://arvc.umh.es
// University Miguel Hernandez
//
// @note Description:
//	- Torch Sequential API example
//	- Build a MLP (dense layers) as a torch Sequential model
//  - Show layer structure
//	- Load Titere classification training data
//  - Train network
//  - Test network and Show Learning curves
//
*************************************************************************
"""

# Filter out logging messages
#-------|------------------|------------------------------------
#  0    | DEBUG            | [Default] Print all messages
#  1    | INFO             | Filter out INFO messages
#  2    | WARNING          | Filter out INFO & WARNING messages
#  3    | ERROR            | Filter out all messages
# NOTE(review): TF_CPP_MIN_LOG_LEVEL is a TensorFlow environment variable;
# this script imports torch/torchvision, not TensorFlow, so the setting
# presumably has no effect here — confirm before removing.
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # or any {'0', '1', '2', '3'}

import torch
import torchvision
import torchinfo

import numpy as np
import time
import argparse

# local package for Computer Vision Course (1782): train, predict, plot metrics
from torch_util import trainModel, predictModel, showPerformance, plotLearningCurves

# local package for Computer Vision Course (1782): titere data import
import titere


DATA_FILE = 'train.arff'             # default train data file
TEST_FILE = 'test.arff'             # default test data file

# check command line parameters (dataFile, testFile); both are optional
# positionals that fall back to the defaults above
parser = argparse.ArgumentParser(description='OpenCV example: classification')
parser.add_argument('dataFile', nargs='?', default=DATA_FILE,  help='training file')
parser.add_argument('testFile', nargs='?', default=TEST_FILE,  help='Test file')

# parse the command line once (the original called parse_args() twice,
# re-parsing sys.argv for each attribute)
args = parser.parse_args()
DATA_FILE = args.dataFile
TEST_FILE = args.testFile


# ------------------------------------
# LOAD TRAIN/VALIDATION DATA SECTION
# ------------------------------------
# Import the titere dataset and convert list/numpy data to pytorch tensors.
# Reduced training data keeps only 4 columns:
# ('Compacidad', 'Excentricidad', 'Rel_Invar_1', 'Rel_Invar_2')

trainDataDict = titere.readWekaDataTitere(DATA_FILE)
trainData = torch.tensor(trainDataDict['dataMat'])
trainLabels = torch.tensor(trainDataDict['label'])  # numeric class labels (first class is 1)
labelRange = trainDataDict['labelRange']            # numeric <-> string label association

# reuse the label mapping from the training set when importing the test set
testDataDict = titere.readWekaDataTitere(TEST_FILE, labelRange)
testData = torch.tensor(testDataDict['dataMat'])
testLabels = torch.tensor(testDataDict['label'])

numClasses = len(labelRange)
numFeatures = trainData.size(-1)    # last dimension of the train data tensor

# Uncomment to inspect the imported data
# print(f"{trainDataDict['attributes']} | Clase (Num/Label)")
# print(np.column_stack( (trainData, trainLabels, trainDataDict['labelName']) ))
# print(f"{labelRange=}")

# One-hot encode the label vectors into float32 matrices with one column per
# class. Labels start at 1, so shift them to a 0-based index first.
trainLabels = torch.nn.functional.one_hot(trainLabels.long() - 1, num_classes=numClasses).float()
testLabels = torch.nn.functional.one_hot(testLabels.long() - 1, num_classes=numClasses).float()

# DataLoaders manage batching / shuffling / data augmentation internally
train_loader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(trainData, trainLabels),
    batch_size=2, shuffle=True)
val_loader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(testData, testLabels),
    batch_size=1, shuffle=True)

# -------------------------
# BUILD/LOAD MODEL SECTION
# -------------------------
# Sequential torch model: MLP with one hidden Tanh layer
# (numFeatures inputs -> 20 hidden units -> numClasses outputs).
# NOTE: CrossEntropyLoss applies LogSoftmax internally, so no final
# torch.nn.Softmax(dim=1) layer is added here.
model = torch.nn.Sequential(
    torch.nn.Linear(in_features=numFeatures, out_features=20),
    torch.nn.Tanh(),
    torch.nn.Linear(in_features=20, out_features=numClasses)
)

# print the layer structure
print(f"MLP Sequential Model:\n", model)

# detailed per-layer summary (shapes and parameter counts)
torchinfo.summary(model,
                  input_size=(1, 4),
                  col_names=["input_size", "output_size", "num_params", "trainable"],
                  col_width=15)

# -----------------------
# MODEL TRAINING SECTION
# -----------------------

# report the GPUs available for CUDA acceleration
print("GPUs available:", torch.cuda.device_count())
for i in range(torch.cuda.device_count()):
    print(f"GPU: {torch.cuda.get_device_name(i)}, Type: {torch.cuda.get_device_capability(i)}")

# move the model to the selected device (CPU/CUDA) once it has been built
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(DEVICE)

# loss function and optimizer configuration
criterion = torch.nn.CrossEntropyLoss()
# lr: learning rate, momentum: gradient inertia
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.5)

# train the model, timing the whole run
t0 = time.time()
history = trainModel(model=model,
                     train_loader=train_loader,
                     val_loader=val_loader,
                     criterion=criterion,
                     optimizer=optimizer,
                     epochs=100,
                     device=DEVICE)
print(f"Training completed in: {time.time() - t0:.2f}s")

# plot the training/validation learning curves
plotLearningCurves(history)

# -----------------------
# INFERENCE SECTION
# -----------------------

# run the trained model over the validation dataset:
# predicted class, class probability, and ground-truth label per sample
y_pred, prob, y_test = predictModel(model, val_loader, DEVICE)

print(f"Predicted: ", y_pred)
print(f"Actual:    ", y_test)
print(f"Prob:      ", prob)

# confusion matrix / metrics report
showPerformance(y_test, y_pred, labelRange)

# block until a key press so the plots stay on screen
input('click a key')