"""
*************************************************************************
// Example program using OpenCV library
//      python >3.8 - Keras-Tensorflow (numpy,tf2onnx, titere)
// @file	ek2.py
// @author Luis M. Jimenez
// @date 2022
//
// @brief Course: Computer Vision (1782)
// Dept. of Systems Engineering and Automation
// Automation, Robotics and Computer Vision Lab (ARVC)
// http://arvc.umh.es
// University Miguel Hernandez
//
// @note Description:
//  - Keras Functional API example - Transfer Learning
//  - Load pretrained DNN VGG16 - Show layer structure
//  - Build new network replacing last three classification layers
//  - Load image dataset
//  - Train new  network classification layers 
//  - Test network and Show Learning curves
//
*************************************************************************
"""

# Filter out logging messages
# Accepted values for TF_CPP_MIN_LOG_LEVEL:
#-------|------------------|------------------------------------
#  0    | DEBUG            | [Default] Print all messages
#  1    | INFO             | Filter out INFO messages
#  2    | WARNING          | Filter out INFO & WARNING messages
#  3    | ERROR            | Filter out all messages
# NOTE: both variables are set here, before the `import tensorflow` statement further down;
# presumably they are only honoured when set ahead of that import, so keep this order.
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # or any {'0', '1', '2', '3'}
# NOTE(review): looks like this turns XLA auto-JIT clustering off (level 0) - confirm against the TensorFlow XLA docs
os.environ['TF_XLA_FLAGS'] = '--tf_xla_auto_jit=0'

import keras
import tensorflow as tf
import numpy as np
import tf2onnx
import time

# local package for Computer Vision Course (1782): plot
from titere import  plotLearningCurves

# Class labels: maps each class key (the string keys also used to index this
# table when the label files are written) to a human-readable object name.
CLASSES_TEXT = {
    "7": "Shell",
    "8": "Green Cube",
    "249": "Red Cube",
    "9": "Blue Shoe",
    "15": "Piolin",
    "35": "Cup",
    "62": "Duck",
    "69": "Tomato",
    "291": "Ball",
    "138": "Blue Car",
    "160": "White Car",
    "156": "Red Clock",
    "233": "Corn",
    "323": "Packet",
    "332": "Vase",
    "950": "Banana",
}

# Number of output classes of the new classifier (one per table entry)
NUM_CLASSES = len(CLASSES_TEXT)

# Report the GPU devices TensorFlow can see (CUDA acceleration check)
gpus = tf.config.list_physical_devices('GPU')
print("GPUs available:", len(gpus))
for device in gpus:
    # one line per detected device: its name and device type
    print(f"GPU: {device.name}, Type: {device.device_type}")

# -------------------------
# LOAD/MODIFY MODEL SECTION
# -------------------------

# Input tensor axis order: Keras uses (n,h,w,c), the OpenCV blob format is (n,c,h,w) 'channels_first'.
# VGG16 is trained with the (n,h,w,c) 'channels_last' order, so that layout is selected here.
keras.backend.set_image_data_format('channels_last')

# Option A) download the model (base network)
#model = tf.keras.applications.VGG16()

# Option B) alternatively use a pre-downloaded model file (http://umh1782.umh.es/python)
base_model_path = '../tensorflow-keras/vgg16.h5'
model = keras.models.load_model(base_model_path)

# Print the layer structure of the base network
model.summary()
#keras.utils.plot_model(model, to_file='net.png', show_shapes=True)    # prints to image file model graph  (Requires pydot and graphviz)

# Mark every layer of the imported model as not trainable (its weights won't be trained)
model.trainable = False

# Build new model: keep the base network up to its 'flatten' layer and
# append a fresh classification head on top of it.

# Base network input tensor, taken from the first layer.
# Note: current Keras versions do not recognize the input tensor of the first layer (use its output instead)
base_in_tsr = model.get_layer(index=0).output

# Output tensor of the base network's 'flatten' layer (end of the convolutional part)
base_out_tsr = model.get_layer(name='flatten').output

print("# layers: ", len(model.layers))
print("Net Input  shape: ",  base_in_tsr.shape)
print("Net Flatten-Conv output  shape: ", base_out_tsr.shape)

# Input image size: axes (2,3) of (n,c,h,w) or axes (1,2) of (n,h,w,c)
input_dims = tuple(base_in_tsr.shape)
channels_first = keras.backend.image_data_format() == 'channels_first'
image_size = input_dims[2:4] if channels_first else input_dims[1:3]

# New classification layers (trainable by default), chained one after another:
# a Dropout regularization layer, two tanh Dense layers and the softmax output.
head_tsr = keras.layers.Dropout(rate=0.2, name='reg')(base_out_tsr)
for units, layer_name in ((1000, 'fc1'), (100, 'fc2')):
    head_tsr = keras.layers.Dense(units=units, activation='tanh', name=layer_name)(head_tsr)
output_tsr = keras.layers.Dense(units=NUM_CLASSES, activation='softmax', name='prediction')(head_tsr)

# Assemble and compile the new model (from the base input to the new prediction layer)
model = keras.models.Model(inputs=base_in_tsr, outputs=output_tsr, name='TunnedVGG16')
model.compile(loss='categorical_crossentropy', optimizer='sgd', metrics=['accuracy'])

model.summary()     # prints model summary

# ------------------------------------
# LOAD TRAIN/VALIDATION DATA SECTION
# ------------------------------------
# Load the images from '../images/' and split them 80/20 into a training and a
# test/validation dataset (one-hot 'categorical' labels, batches of 10 images
# resized/cropped to the network input size).
trainDataset, testDataset = tf.keras.utils.image_dataset_from_directory(
    '../images/',
    batch_size=10, label_mode='categorical',
    image_size=image_size, color_mode="rgb", crop_to_aspect_ratio=True,
    shuffle=True, validation_split=0.2, seed=1, subset='both')

# Label index -> class key list ([idx] -> label). Read it here, before .map()
# below replaces the dataset objects.
class_names = trainDataset.class_names

# Fail fast: the prediction layer has NUM_CLASSES units and the label files
# written at the end look every class key up in CLASSES_TEXT. Detect a
# mismatch now instead of after (or in the middle of) a full training run.
unknown_classes = [name for name in class_names if name not in CLASSES_TEXT]
if unknown_classes:
    raise ValueError(f"Dataset classes missing from CLASSES_TEXT: {unknown_classes}")
if len(class_names) != NUM_CLASSES:
    raise ValueError(
        f"Dataset has {len(class_names)} classes but the network expects {NUM_CLASSES}")

print("Batch size: ",  len(next(iter(trainDataset))[0]))
print("TrainDataset batches: ",  trainDataset.cardinality().numpy())
print("TrainDataset images: ",  sum(len(batch[0]) for batch in trainDataset))
print("TestDataset batches: ",  testDataset.cardinality().numpy())
print("TestDataset images: ",  sum(len(batch[0]) for batch in testDataset))

print(f"ClassNames ({len(class_names)}):", class_names)


def _vgg16_preprocess(images, labels):
    """Apply the VGG16-specific input colour normalization; labels pass through unchanged."""
    return tf.keras.applications.vgg16.preprocess_input(images), labels


# normalize input image color (net specific) on both datasets
trainDataset = trainDataset.map(_vgg16_preprocess)
testDataset = testDataset.map(_vgg16_preprocess)



# -----------------------
# MODEL TRAINING SECTION
# -----------------------
# Elapsed time is measured with perf_counter(): it is monotonic, so the result
# is not affected by system clock adjustments during a long training run.
start_time = time.perf_counter()

# train the network (only the new layers are trainable)
train_res = model.fit(trainDataset, epochs=5, verbose=1, validation_data=testDataset)

train_time = time.perf_counter() - start_time
print(f"Training time: {train_time:.2f}s")

# plot the per-epoch training history
plotLearningCurves(train_res.history)

# Evaluate classifier on the test dataset
loss_and_metrics = model.evaluate(testDataset)
print("Loss and accuracy: ", loss_and_metrics)


# -----------------------
# INFERENCE SECTION
# -----------------------
# Prediction
# Predictions and ground-truth labels are collected in ONE pass over the
# dataset, batch by batch, so each prediction is paired with the label of the
# very same image. (Two separate passes only line up if the dataset yields its
# elements in the same order every time it is iterated.)
batch_probs = []
batch_labels = []
for images, labels in testDataset.as_numpy_iterator():
    batch_probs.append(model.predict_on_batch(images))
    batch_labels.append(labels)

# stack the per-batch results: (num_batches, batch, classes) -> (images, classes)
classes = np.concatenate(batch_probs, axis=0)

# convert categorical prediction matrix to vector: search for max response across classes (columns)
y_pred = classes.argmax(axis=1)
prob = classes.max(axis=1)

# one-hot labels -> class index per image
y_test = np.concatenate(batch_labels, axis=0).argmax(axis=1)

# Accuracy (1 - error rate): fraction of images whose predicted index matches the label
accuracy = np.mean(y_pred == y_test)

print(f"Test Accuracy: {accuracy}")
print("(Predicted,Actual): \n", list(zip(y_pred, y_test)))
print("Prob:      ", prob)


# -----------------------
# STORE MODEL SECTION
# -----------------------
# save trained model
model.save('MyNet_vgg16.h5')        # HDF5 Keras legacy format
model.save('MyNet_vgg16.keras')     # New Keras  format

# convert model to standard format ONNX
tf2onnx.convert.from_keras(model, output_path="MyNet_vgg16.onnx")

# save keys/class names file (one class key per line, in label-index order)
# The encoding is given explicitly so the files do not depend on the platform locale.
with open("aloi-16-keys-labels.txt", "w", encoding="utf-8") as f:
    f.writelines(f"{label}\n" for label in class_names)

# Resolve every description BEFORE opening the output file: a class key missing
# from CLASSES_TEXT then raises KeyError without leaving a truncated file behind.
label_texts = [CLASSES_TEXT[label] for label in class_names]

# save human-readable labels file (same order as the keys file)
with open("aloi-16-labels.txt", "w", encoding="utf-8") as f:
    f.writelines(f"{text}\n" for text in label_texts)


# Wait for console input before the script ends, to keep the plot on screen
# (input() blocks until the user submits a line on the console).
input('click a key')