"""
*************************************************************************
// Example program using OpenCV library
//      python >3.8 - torch, torchvision (matplotlib,numpy, torch_util)
// @file	et2.py
// @author Luis M. Jimenez
// @date 2024
//
// @brief Course: Computer Vision (1782)
// Dept. of Systems Engineering and Automation
// Automation, Robotics and Computer Vision Lab (ARVC)
// http://arvc.umh.es
// University Miguel Hernandez
//
// @note Description:
//	- Torch API example - Transfer Learning
//	- Load pretrained DNN VGG16 - Show layer structure
//  - Build new network replacing  classification layers
//	- Load image dataset
//  - Train new network classification layers
//  - Test network and Show Learning curves
//
*************************************************************************
"""

# Filter out logging messages (TF_CPP_MIN_LOG_LEVEL is read by TensorFlow; it has no effect on PyTorch itself)
#-------|------------------|------------------------------------
#  0    | DEBUG            | [Default] Print all messages
#  1    | INFO             | Filter out INFO messages
#  2    | WARNING          | Filter out INFO & WARNING messages
#  3    | ERROR            | Filter out all messages
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # or any {'0', '1', '2', '3'}

import numpy as np
import torch
import torchvision
import torchinfo
import time
import inspect      # to show model code for forward() method

# local package for Computer Vision Course (1782): train, predict, plot metrics
from torch_util import trainModel, predictModel, plotLearningCurves


# Human-readable names for the ALOI object classes used in this exercise.
# Keys are the dataset's class-folder names (ALOI object ids, as strings);
# values are the display labels shown to the user.
CLASSES_TEXT = {
    "7":   "Shell",     "8":   "Green Cube",  "249": "Red Cube",
    "9":   "Blue Shoe", "15":  "Piolin",      "35":  "Cup",
    "62":  "Duck",      "69":  "Tomato",      "291": "Ball",
    "138": "Blue Car",  "160": "White Car",   "156": "Red Clock",
    "233": "Corn",      "323": "Packet",      "332": "Vase",
    "950": "Banana",
}

# Number of output classes for the new classifier head
NUM_CLASSES = len(CLASSES_TEXT)


# -------------------------
# LOAD/MODIFY MODEL SECTION
# -------------------------

# Option A) Build model and download weights (base network)
#model = torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.DEFAULT)

# Option B) Build model and load weights from a local file (base network)
#model = torchvision.models.vgg16()
#model.load_state_dict(torch.load("../pytorch-models/vgg16-weights.pth"))

#torch.save(model, "../pytorch-models/vgg16_full_model.pth")    # Save model with weights

# Option C) alternatively we can use pre-downloaded models (http://umh1782.umh.es/python)
# NOTE: loading a fully pickled model can execute arbitrary code from the file,
# so only load checkpoints from trusted sources. weights_only=False is passed
# explicitly because torch >= 2.6 defaults to weights_only=True, which refuses
# full-model pickles and would make this line fail.
model = torch.load("../pytorch-models/vgg16_full_model.pth", weights_only=False)

# prints model summary
print("VGG16 model summary:", model)
# Detailed model info (per-layer shapes, parameter counts, trainable flags)
torchinfo.summary(model, col_names=["input_size", "output_size", "num_params", "trainable"],
                    input_size=(1, 3, 224, 224), col_width=15)

# Freeze the convolutional feature extractor so only the (new) classifier
# layers are updated during training
for param in model.features.parameters():
    param.requires_grad = False

# Replace the last block of the VGG network (the classifier) with a new head
# sized for our NUM_CLASSES problem. The input width is read from the first
# Linear layer of the *original* classifier (the flattened feature size),
# evaluated before the assignment replaces it.
#
# CrossEntropyLoss applies LogSoftmax internally, so the head must end with
# raw logits (no torch.nn.Softmax(dim=1) output layer) during training;
# predictModel introduces a softmax at inference time to get probabilities.
new_head = [
    torch.nn.Dropout(0.5),
    torch.nn.Linear(in_features=model.classifier[0].in_features, out_features=1000),
    torch.nn.LeakyReLU(),
    torch.nn.Linear(in_features=1000, out_features=100),
    torch.nn.LeakyReLU(),
    torch.nn.Linear(in_features=100, out_features=NUM_CLASSES),
]
model.classifier = torch.nn.Sequential(*new_head)

# prints model summary
print(f"Modified VGG16 model summary:\n", model)
print(inspect.getsource(model.forward))     # show model code for forward() method

# Detailed model info
torchinfo.summary(model, col_names=["input_size", "output_size", "num_params", "trainable"],
                    input_size=(1, 3, 224, 224), col_width=18)

# ------------------------------------
# LOAD TRAIN/VALIDATION DATA SECTION
# ------------------------------------

# Input preprocessing: resize/crop to 224x224 and normalize colors with the
# network-specific statistics [scale*(x-mean)].
# OPTION A: use the default VGG16 preprocess transform shipped with the weights
preprocess_vgg16 = torchvision.models.VGG16_Weights.DEFAULT.transforms()
print("VGG16 Preprocess Transform: ", preprocess_vgg16)

# OPTION B: manually compose the equivalent preprocess transformation
# preprocess_vgg16 = torchvision.transforms.Compose([
#     torchvision.transforms.Resize(size=256),
#     torchvision.transforms.CenterCrop(size=(224, 224)),
#     torchvision.transforms.ToTensor(),
#     torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
# ])


# Image dataset: one subfolder per class under ../images/
full_dataset = torchvision.datasets.ImageFolder(root="../images/", transform=preprocess_vgg16)

# dictionary mapping class-folder name (label) -> class index
class_names = full_dataset.class_to_idx
print(f"ClassNames ({len(class_names)}):", class_names)


# 80% train / 20% validation split; both subsets are lightweight views
# that reference the same underlying full_dataset
n_total = len(full_dataset)
n_train = int(0.8 * n_total)
train_dataset, val_dataset = torch.utils.data.random_split(
    full_dataset, [n_train, n_total - n_train])

# Data Loaders (shuffle only the training stream)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=10, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=10, shuffle=False)


# -----------------------
# MODEL TRAINING SECTION
# -----------------------

# Report CUDA-capable GPUs available for acceleration
n_gpus = torch.cuda.device_count()
print("GPUs available:", n_gpus)
for i in range(n_gpus):
    print(f"GPU: {torch.cuda.get_device_name(i)}, Type: {torch.cuda.get_device_capability(i)}")

# Select the model device (CPU/CUDA) after the model has been built and modified
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(DEVICE)

# Loss function and optimizer
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
# To speed up training we can focus optimizer only on classifier parameters: model.classifier.parameters()


# Train the model -- only a few epochs to start with, since training
# without GPU acceleration will be slow
start_time = time.time()
history = trainModel(model=model, train_loader=train_loader, val_loader=val_loader,
                     criterion=criterion, optimizer=optimizer, epochs=5, device=DEVICE)
print(f"Training completed in: {time.time() - start_time:.2f}s")

# Show Learning curves
plotLearningCurves(history)


# -----------------------
# INFERENCE SECTION
# -----------------------

# Run the trained model over the validation set
y_pred, prob, y_test = predictModel(model, val_loader, DEVICE)

# Accuracy = fraction of predictions that match the ground-truth labels
# (i.e. 1 - error rate)
matches = np.array(y_pred) == np.array(y_test)
accuracy = np.sum(matches) / len(y_pred)
print(f"Test Accuracy: {accuracy}")

print(f"(Predicted,Actual): \n", list(zip(y_pred, y_test)))
print(f"Prob:      ", prob)

# -----------------------
# STORE MODEL SECTION
# -----------------------
# Save the whole model (architecture + weights). NOTE: this is pickle-based;
# saving model.state_dict() instead is the more portable alternative.
torch.save(model, "MyNet_vgg16.pth")


# Save the class keys (folder names), one per line, in dataset class order.
# Iterate the dict keys directly -- the index values are not needed here.
with open("aloi-16-keys-labels.txt", "w") as f:
    for label in class_names:
        f.write(label + '\n')

# Save the matching human-readable labels, one per line (same order)
with open("aloi-16-labels.txt", "w") as f:
    for label in class_names:
        f.write(CLASSES_TEXT[label] + '\n')

# wait for a key press so the plot stays on screen
input('click a key')