"""
*************************************************************************
// Example program using OpenCV library
//      python >3.7 - OpenCV 4.5
// @file	e6c.py
// @author Luis M. Jimenez
// @date 2022
//
// @brief Course: Computer Vision (1782)
// Dept. of Systems Engineering and Automation
// Automation, Robotics and Computer Vision Lab (ARVC)
// http://arvc.umh.es
// University Miguel Hernandez
//
// @note Description:
//  - Load Classification CNN (Caffe Models)
//  - AlexNet / GoogLeNet / CaffeNet / R-CNN / Finetune-Flickr
//  - Test Detection  with camera images
//
*************************************************************************
"""

# Import libraries
import cv2 as cv
import numpy as np
import argparse


# -----------------------------------------
# Global variables
# -----------------------------------------

WINDOW_CAMERA1 = '(W1) Camera 1'   # window id
CAMERA_ID = 0                      # default camera id
ALGORITHM = 'model'                # 'net': cv.dnn_Net class | 'model': cv.dnn_ClassificationModel
NETWORK = "anet"                   # default pretrained CNN (see table below)

# Pretrained Caffe CNNs:
#   key -> (model weights, deploy prototxt, class-labels file, input layer size)
CAFFE_MODELS = {
    # AlexNet (ImageNet dataset, 1000 classes)
    "anet": ('../caffe/bvlc_alexnet/bvlc_alexnet.caffemodel',
             '../caffe/bvlc_alexnet/deploy.prototxt',
             '../caffe/bvlc_alexnet/imagenet-labels.txt',
             (227, 227)),
    # GoogLeNet (ImageNet dataset, 1000 classes)
    "gnet": ('../caffe/bvlc_googlenet/bvlc_googlenet.caffemodel',
             '../caffe/bvlc_googlenet/deploy.prototxt',
             '../caffe/bvlc_googlenet/imagenet-labels.txt',
             (224, 224)),
    # CaffeNet, an AlexNet variation (ImageNet dataset, 1000 classes)
    "cnet": ('../caffe/bvlc_reference_caffenet/bvlc_reference_caffenet.caffemodel',
             '../caffe/bvlc_reference_caffenet/deploy.prototxt',
             '../caffe/bvlc_reference_caffenet/imagenet-labels.txt',
             (227, 227)),
    # Fine-tuning example, Flickr style (PASCAL VOC-12 dataset, 20 classes)
    "fnet": ('../caffe/finetune_flickr_style/finetune_flickr_style.caffemodel',
             '../caffe/finetune_flickr_style/deploy.prototxt',
             '../caffe/finetune_flickr_style/pascalvoc12-20-labels.txt',
             (227, 227)),
    # R-CNN (ILSVRC13 dataset, 200 classes)
    "rnet": ('../caffe/bvlc_reference_rcnn_ilsvrc13/bvlc_reference_rcnn_ilsvrc13.caffemodel',
             '../caffe/bvlc_reference_rcnn_ilsvrc13/deploy.prototxt',
             '../caffe/bvlc_reference_rcnn_ilsvrc13/ilsvrc13-200-labels.txt',
             (227, 227)),
}

# Check command-line parameters.
# NOTE: arguments are parsed BEFORE selecting the model files — the original
# code parsed them afterwards, so the -n option never had any effect.
parser = argparse.ArgumentParser(description='OpenCV example: classification')
parser.add_argument('-c', dest='cameraID', type=int, default=CAMERA_ID, metavar='id', help='camera id')
parser.add_argument('-a', dest='classType', type=str, default=ALGORITHM, metavar='classType', help='net: cv.dnn_Net class  or model: cv_dnn_ClassificationModel')
parser.add_argument('-n', dest='netType', type=str, default=NETWORK, metavar='netType',
                    choices=sorted(CAFFE_MODELS),  # reject unknown nets early (was a NameError later)
                    help='anet: AlexNet, gnet: GoogLeNet, cnet: CaffeNet, fnet: FineTunning, rnet: R-CNN')
args = parser.parse_args()  # single parse (the original called parse_args() three times)
CAMERA_ID = args.cameraID
ALGORITHM = args.classType
NETWORK = args.netType

# Select the Caffe model files for the requested network
MODEL_FILE, CONFIG_FILE, LABELS_FILE, BLOB_SIZE = CAFFE_MODELS[NETWORK]

# -----------------------------------------
# Put here the code to initialize objects
# -----------------------------------------

# TickMeter object to calculate FPS
tm = cv.TickMeter()

# Load the pretrained CNN from the Caffe files and wrap it in the
# high-level classification API (both objects are used below, depending
# on the -a command-line option).
network = cv.dnn.readNet(model=MODEL_FILE, config=CONFIG_FILE, framework='Caffe')
networkModel = cv.dnn.ClassificationModel(network)

# Read the class labels, one class per line
with open(LABELS_FILE, 'r') as labels_fh:
    raw_labels = labels_fh.read().rstrip('\n').split('\n')

# Final class names: keep only the first of the comma-separated
# synonyms that ImageNet lists for each class.
labels = []
for entry in raw_labels:
    labels.append(entry.strip(' "').split(',')[0])

# Show network model data
# print(network.dump())
print(f"#Classes: {len(labels)}")
print(f"#Layers: {len(network.getLayerNames())}")
print(f"#OutputLayers: {len(network.getUnconnectedOutLayers())} - ", end='')
print(network.getUnconnectedOutLayersNames(), network.getUnconnectedOutLayers())


# Open the capture device; abort when no camera is available
camera = cv.VideoCapture(CAMERA_ID)
if not camera.isOpened():
    print("you need to connect a camera, sorry.")
    exit()

# Request a 640x480 capture resolution
camera.set(cv.CAP_PROP_FRAME_WIDTH, 640)
camera.set(cv.CAP_PROP_FRAME_HEIGHT, 480)

# Read back the resolution the camera actually granted
cameraWidth = int(camera.get(cv.CAP_PROP_FRAME_WIDTH))
cameraHeight = int(camera.get(cv.CAP_PROP_FRAME_HEIGHT))

# Create the visualization window
cv.namedWindow(WINDOW_CAMERA1, cv.WINDOW_AUTOSIZE)

print(f"Capturing images from camera {CAMERA_ID} ({cameraWidth},{cameraHeight})")
print("...Hit q/Q/Esc to exit.")


# -----------------------------------------
# Main Loop
# while there are images ...
# -----------------------------------------
while True:
    # Grab the next frame; stop when the stream ends
    ok, frame = camera.read()
    if not ok:
        print("Can't receive frame (stream end?). Exiting ...")
        break

    # -----------------------------------------
    # Image processing: classify the whole frame
    # -----------------------------------------
    tm.start()  # begin timing this processing cycle

    if ALGORITHM == 'model':
        # High-level API: ClassificationModel builds the blob and runs the net.
        # The per-frame channel mean is subtracted as preprocessing.
        networkModel.setInputParams(scale=1.0, size=BLOB_SIZE,
                                    mean=frame.mean(axis=(0, 1)),
                                    swapRB=True, crop=True)
        classId, conf = networkModel.classify(frame)
    else:
        # Low-level API: build the input blob and run the forward pass manually,
        # then pick the highest-scoring class from the output layer.
        blob = cv.dnn.blobFromImage(frame, scalefactor=1.0, size=BLOB_SIZE,
                                    mean=frame.mean(axis=(0, 1)),
                                    swapRB=True, crop=True)
        network.setInput(blob)
        scores = network.forward()[0]
        classId = np.argmax(scores)
        conf = np.max(scores)

    tm.stop()  # end timing this processing cycle

    cv.putText(frame, f"FPS: {tm.getFPS():.1f} ", org=(3, 20),
               fontFace=cv.FONT_HERSHEY_DUPLEX, fontScale=0.5,
               color=(0, 0, 255), thickness=1, lineType=cv.LINE_AA)

    # Only annotate reasonably confident predictions
    if conf > 0.3:
        cv.putText(frame, f"({classId}) {conf*100:5.2f}% - {labels[classId]} ", org=(90, 20),
                   fontFace=cv.FONT_HERSHEY_DUPLEX, fontScale=0.5,
                   color=(0, 255, 0), thickness=1, lineType=cv.LINE_AA)

    # -----------------------------------------
    # Visualization
    # -----------------------------------------
    cv.imshow(WINDOW_CAMERA1, frame)

    # q/Q/Esc exits (image window must have keyboard focus)
    if cv.pollKey() in (ord('q'), ord('Q'), 27):
        break

# End while (main loop)

# -----------------------------------------
# Free windows and camera resources
# -----------------------------------------
cv.destroyAllWindows()
if camera.isOpened():
    camera.release()
