"""
*************************************************************************
// Example program using the Keras/TensorFlow library
//      python >3.8 - Keras-Tensorflow
// @file	ek1.py
// @author Luis M. Jimenez
// @date 2022
//
// @brief Course: Computer Vision (1782)
// Dept. of Systems Engineering and Automation
// Automation, Robotics and Computer Vision Lab (ARVC)
// http://arvc.umh.es
// University Miguel Hernandez
//
// @note Description:
//	- Keras Sequential API example
//	- Build an MLP (dense layers) as a Sequential Keras model
//	- Show the layer structure
//	- Load the Titere classification training/test data
//	- Train the network
//	- Test the network and show the learning curves
//
*************************************************************************
"""

# Filter out logging messages
#-------|------------------|------------------------------------
#  0    | DEBUG            | [Default] Print all messages
#  1    | INFO             | Filter out INFO messages
#  2    | WARNING          | Filter out INFO & WARNING messages
#  3    | ERROR            | Filter out INFO, WARNING & ERROR messages
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # or any {'0', '1', '2', '3'}

import keras
import tensorflow as tf
import numpy as np
import argparse

# local package for Computer Vision Course (1782): titere data import, plot
import titere

DATA_FILE = 'train.arff'             # default train data file
TEST_FILE = 'test.arff'             # default test data file

# check command line parameters (imageFile)
parser = argparse.ArgumentParser(description='OpenCV example: classification')
parser.add_argument('dataFile', nargs='?', default=DATA_FILE,  help='training file')
parser.add_argument('testFile', nargs='?', default=TEST_FILE,  help='Test file')

DATA_FILE = parser.parse_args().dataFile
TEST_FILE = parser.parse_args().testFile

# ------------------------------------
# LOAD TRAIN/VALIDATION DATA SECTION
# ------------------------------------
# Import the reduced Titere data set with only 4 columns:
# ('Compacidad', 'Excentricidad', 'Rel_Invar_1', 'Rel_Invar_2')

trainDataDict = titere.readWekaDataTitere(DATA_FILE)
trainData, trainLabels, labelRange = (
    trainDataDict['dataMat'],
    trainDataDict['label'],        # numeric labels, starting at 1
    trainDataDict['labelRange'],   # numeric <-> string label association
)

# The training labelRange is passed in so the test file is read with the
# same numeric/string label association as the training file.
testDataDict = titere.readWekaDataTitere(TEST_FILE, labelRange)
testData, testLabels = testDataDict['dataMat'], testDataDict['label']

# Network dimensions: features = last axis of trainData, one output per class.
numFeatures, numClasses = trainData.shape[-1], len(labelRange)

# Display the training data: feature columns, numeric label and label name.
header = f"{trainDataDict['attributes']} | Clase (Num/Label)"
table = np.column_stack( (trainData, trainLabels, trainDataDict['labelName']) )
print(header)
print(table)
print(f"{labelRange=}")


# One-hot encode both label vectors: a float32 matrix with one column per
# class. Labels are shifted by -1 because they start at 1 and
# to_categorical expects class indices starting at 0.
trainLabels, testLabels = (
    keras.utils.to_categorical(labels - 1, num_classes=numClasses)
    for labels in (trainLabels, testLabels)
)


# -------------------------
# BUILD/LOAD MODEL SECTION
# -------------------------
# Sequential Keras model declared in one list, in order:
#   input (numFeatures floats) -> 20-unit tanh hidden layer
#   -> softmax output layer with one unit per class
model = keras.models.Sequential([
    keras.Input(shape=(numFeatures,), name='input', dtype='float32'),
    keras.layers.Dense(units=20, activation='tanh', name='hidden'),
    keras.layers.Dense(units=numClasses, activation='softmax', name='output'),
])

# Softmax output + one-hot targets -> categorical cross-entropy loss.
model.compile( optimizer='sgd', loss='categorical_crossentropy', metrics=['accuracy'] )

model.summary()
# Optional: save the model graph as an image (requires pydot and graphviz,
# https://graphviz.gitlab.io/download/ )
#keras.utils.plot_model(model, to_file='net.png', show_shapes=True)

# -----------------------
# MODEL TRAINING SECTION
# -----------------------
# Train the network: 100 epochs with batch_size=1, i.e. the weights are
# updated after every single training sample. The test set is also passed as
# validation_data, so train_res.history records per-epoch validation
# (test-set) metrics next to the training ones for the learning curves.
train_res = model.fit(trainData, trainLabels, epochs=100, batch_size=1, verbose=1,
                      validation_data=(testData, testLabels))

titere.plotLearningCurves(train_res.history)

# Evaluate the classifier on the test set. With metrics=['accuracy'] in
# compile(), evaluate() returns [loss, accuracy].
loss_and_metrics = model.evaluate(testData, testLabels)
print("Loss and accuracy: ", loss_and_metrics)

# -----------------------
# INFERENCE SECTION
# -----------------------
# Prediction: one row per test sample, one column per class (softmax scores).
classes = model.predict(testData)

# Convert the score matrix to label vectors: the column with the maximum
# response is the predicted class. The +1 restores the 1-based numeric labels
# (the one-hot encoding was built from labels - 1).
prediction = classes.argmax(axis=1) + 1
prob = classes.max(axis=1)              # score of the winning class per sample
y_test = testLabels.argmax(axis=1) + 1  # undo the one-hot encoding of the truth

print("Predicted: ", prediction)
print("Actual:    ", y_test)
print("Prob:      ", prob)

titere.showPerformance(y_test, prediction, labelRange)


# Keep the plots on screen until the user is done. input() only returns
# when Enter is pressed, not on an arbitrary key.
input('Press Enter to exit')


