"""
*************************************************************************
// Example program using OpenCV library
//      python >3.7 - OpenCV 4.5
// @file	e6d.py
// @author Luis M. Jimenez
// @date 2022
//
// @brief Course: Computer Vision (1782)
// Dept. of Systems Engineering and Automation
// Automation, Robotics and Computer Vision Lab (ARVC)
// http://arvc.umh.es
// University Miguel Hernandez
//
// @note Description:
//  - Load pretrained Detection/Classification CNN Darknet Models (YOLO)
//  - Test Detection  with camera images
//
*************************************************************************
"""

# Import libraries
import argparse
import sys

import cv2 as cv
import numpy as np


# -----------------------------------------
# Global configuration
# -----------------------------------------

WINDOW_CAMERA1 = '(W1) Camera 1'   # visualization window id
CAMERA_ID = 0                      # default camera index

# Darknet model selection ("yoloTiny" or "yolo")
NETWORK = "yolo"

if NETWORK == "yoloTiny":
    # YOLOv3-tiny (COCO dataset, 80 classes)
    MODEL_FILE = '../darknet/yolov3-tiny.weights'
    CONFIG_FILE = '../darknet/yolov3-tiny.cfg'
    LABELS_FILE = '../darknet/coco.names'
    BLOB_SIZE = (416, 416)          # input layer size
else:
    # YOLOv3 (COCO dataset, 80 classes)
    MODEL_FILE = '../darknet/yolov3.weights'
    CONFIG_FILE = '../darknet/yolov3.cfg'
    LABELS_FILE = '../darknet/coco.names'
    #BLOB_SIZE = (608, 608)         # native input layer size
    BLOB_SIZE = (320, 320)          # downsized input layer size (increase FPS)

# Command-line parameters: optional camera id overrides the default
parser = argparse.ArgumentParser(description='OpenCV example: classification')
parser.add_argument('-c', dest='cameraID', type=int, default=CAMERA_ID, metavar='id', help='camera id')
args = parser.parse_args()
CAMERA_ID = args.cameraID

# -----------------------------------------
# Initialize objects
# -----------------------------------------

# TickMeter object to measure per-frame processing time (FPS)
tm = cv.TickMeter()

# Load the pretrained Darknet (YOLO) network and wrap it in a high-level
# DetectionModel, which handles blob creation and output decoding.
network = cv.dnn.readNet(model=MODEL_FILE, config=CONFIG_FILE, framework='Darknet')
networkModel = cv.dnn.DetectionModel(network)

# Load class labels, one label per line
with open(LABELS_FILE, 'r') as file:
    labels = file.read().rstrip('\n').split('\n')

# Show network model data
#print(network.dump())
print(f"#Classes: {len(labels)}")
print(f"#Layers: {len(network.getLayerNames())}")
print(f"#OutputLayers: {len(network.getUnconnectedOutLayers())} - ", end='')
print(network.getUnconnectedOutLayersNames(), network.getUnconnectedOutLayers())


# Create a list of random (dark-ish) BGR colors, one per label
colors = np.random.randint(low=0, high=200, size=(len(labels), 3), dtype=np.uint8).tolist()

# Open camera object
camera = cv.VideoCapture(CAMERA_ID)
if not camera.isOpened():
    print("you need to connect a camera, sorry.")
    # sys.exit() instead of the site-injected exit(): the latter is not
    # guaranteed to exist (e.g. under `python -S` or in frozen builds)
    sys.exit()

# Increase camera resolution (the driver may clamp to a supported mode)
camera.set(cv.CAP_PROP_FRAME_WIDTH, 960)
camera.set(cv.CAP_PROP_FRAME_HEIGHT, 720)

# Getting the resolution actually granted by the camera
cameraWidth = int(camera.get(cv.CAP_PROP_FRAME_WIDTH))
cameraHeight = int(camera.get(cv.CAP_PROP_FRAME_HEIGHT))

# Creating visualization windows
cv.namedWindow(WINDOW_CAMERA1, cv.WINDOW_AUTOSIZE)

print(f"Capturing images from camera {CAMERA_ID} ({cameraWidth},{cameraHeight})")
print("...Hit q/Q/Esc to exit.")


# -----------------------------------------
# Main Loop
# while there are images ...
# -----------------------------------------

# Input pre-processing is identical for every frame, so configure it once,
# outside the loop.
# Important: Darknet YOLO expects pixel values scaled to [0-1] with NO mean
# subtraction (mean=(0,0,0)). Subtracting the per-frame image mean (as the
# original code did) feeds the network values it was never trained on and
# breaks the documented [0-1] input range.
# crop is not set: internal box coordinates are normalized to the blob size.
networkModel.setInputParams(scale=1.0/255, size=BLOB_SIZE, mean=(0, 0, 0), swapRB=True)
networkModel.setNmsAcrossClasses(True)  # True: NMS across classes, False: NMS only for bboxes of the same class

while True:
    # Capture frame-by-frame
    ret, capture = camera.read()

    # if frame is read correctly ret is True
    if not ret:
        print("Can't receive frame (stream end?). Exiting ...")
        break

    # -----------------------------------------
    # Image processing: YOLO detection
    # -----------------------------------------

    tm.start()  # start processing cycle

    # detect() returns parallel sequences: class ids, confidences, boxes (x, y, w, h)
    classIds, confidences, boxes = networkModel.detect(capture, confThreshold=0.3, nmsThreshold=0.2)

    tm.stop()  # end processing cycle

    # Show FPS on image
    cv.putText(capture, f"FPS: {tm.getFPS():.1f} ", org=(3, 20),
               fontFace=cv.FONT_HERSHEY_DUPLEX, fontScale=0.5, color=(0, 0, 255), thickness=1, lineType=cv.LINE_AA)

    # Show detected bounding boxes with class label and confidence
    for classId, confidence, (x, y, w, h) in zip(classIds, confidences, boxes):
        classId = int(classId)      # detect() may yield numpy scalars; plain int for list indexing
        color = colors[classId]
        label = labels[classId]

        cv.rectangle(capture, (x, y), (x + w, y + h), color, thickness=2)
        text = f"{label}: {confidence*100:.1f}%"
        textSize, baseline = cv.getTextSize(text, cv.FONT_HERSHEY_DUPLEX, 0.3, 1)
        # filled background behind the label text for readability
        cv.rectangle(capture, (x, y), (x + 10 + textSize[0], y - 10 - textSize[1]), color, thickness=cv.FILLED)
        cv.putText(capture, text, (x + 5, y - 5), cv.FONT_HERSHEY_DUPLEX, 0.3, (255, 255, 255), 1, cv.LINE_AA)

    # -----------------------------------------
    # Visualization
    # -----------------------------------------
    cv.imshow(WINDOW_CAMERA1, capture)     # Display the resulting frame

    # check keystroke to exit (image window must be on focus)
    key = cv.pollKey()
    if key in (ord('q'), ord('Q'), 27):
        break

# End while (main loop)

# -----------------------------------------
# Free windows and camera resources
# -----------------------------------------
# (the duplicated trailing comment block and dead `pass` were removed)
cv.destroyAllWindows()
if camera.isOpened():
    camera.release()
