"""
*************************************************************************
// Example program using OpenCV library
//      python >=3.8 - OpenCV 4.5
// @file	e6seg.py
// @author Luis M. Jimenez
// @date 2024
//
// @brief Course: Computer Vision (1782)
// Dept. of Systems Engineering and Automation
// Automation, Robotics and Computer Vision Lab (ARVC)
// http://arvc.umh.es
// University Miguel Hernandez
//
// @note Description:
//  - Load pretrained Detection/Segmentation Mask-RCNN
//  - Test Detection and show segmentation masks with camera images
//
//  - Using base cv::dnn::Net -> returns detection boxes and a mask for every box and object class
//  - Box: has normalized coordinates [0-1] that must be scaled to the image size
//  - Mask: a 15x15 matrix with the segmentation probability of each pixel block;
//          its coordinates span [0-15] and must be scaled to the box size
//  - Segmentation Mask resolution: 15x15 for each box
//
//  - Using  cv:dnn:SegmentationModel ->  returns a mask of 15x15 resolution for all the image
//           with max probability class index for each pixel block
*************************************************************************
"""

# Import libraries
import cv2 as cv
import numpy as np
import argparse


# -----------------------------------------
# Global variables
# -----------------------------------------

WINDOW_CAMERA1 = '(W1) Camera 1'   # window id
CAMERA_ID = 0                      # default camera

# Mask-RCNN detection and segmentation model files
MODEL_FILE = '../Mask-RCNN/mask_rcnn_frozen_inference_graph.pb'
CONFIG_FILE = '../Mask-RCNN/mask_rcnn_inception_v2_coco.pbtxt'
LABELS_FILE = '../Mask-RCNN/coco.names'
BLOB_SIZE = (256, 256)      # input layer size as (width, height), the order cv.dnn.blobFromImage expects

conf_probability = 0.5  # Minimum probability for box object detection
conf_threshold = 0.3    # Segmentation threshold for each pixel block in the mask

# Command line parameters: camera id and detection/segmentation thresholds.
# The defaults are the values defined above.
parser = argparse.ArgumentParser(description='OpenCV example: Mask-RCNN detection and segmentation')
parser.add_argument('-c', dest='cameraID', type=int, default=CAMERA_ID, metavar='id', help='camera id')
parser.add_argument('-p', dest='conf_probability', type=float, default=conf_probability, metavar='probability', help='Minimum probability for box object detection')
parser.add_argument('-t', dest='conf_threshold', type=float, default=conf_threshold, metavar='threshold', help='Segmentation Threshold')

# Parse the command line once and read every option from the same result
args = parser.parse_args()
CAMERA_ID = args.cameraID
conf_probability = args.conf_probability
conf_threshold = args.conf_threshold

# -----------------------------------------
# Put here the code to Initialize objects
# -----------------------------------------

# TickMeter object to calculate FPS (it is started/stopped around inference in the main loop)
tm = cv.TickMeter()

# Load CNN data: TensorFlow frozen graph + text graph description
network = cv.dnn.readNet(model=MODEL_FILE, config=CONFIG_FILE, framework='Tensorflow')

# Load labels: one class name per line (trailing newlines are stripped)
with open(LABELS_FILE, 'r', encoding='utf-8') as file:
    labels = file.read().rstrip('\n').split('\n')

# Show network model data
# print(network.dump())   # uncomment for a full dump of the network
print(f"#Classes: {len(labels)}")
print(f"#Layers: {len(network.getLayerNames())}")
print(f"#OutputLayers: {len(network.getUnconnectedOutLayers())} - ", end='')
print(network.getUnconnectedOutLayersNames(), network.getUnconnectedOutLayers())


# Create a random BGR color for each label (values < 200 keep white text readable on it)
colors = np.random.randint(low=0, high=200, size=(len(labels), 3), dtype=np.uint8).tolist()

# Open camera object
camera = cv.VideoCapture(CAMERA_ID)
if not camera.isOpened():
    print("you need to connect a camera, sorry.")
    raise SystemExit(1)   # non-zero status: the program could not run

# Increase camera resolution (the camera may not honour the request)
camera.set(cv.CAP_PROP_FRAME_WIDTH, 960)
camera.set(cv.CAP_PROP_FRAME_HEIGHT, 720)

# Getting the resolution the camera actually delivers
cameraWidth = int(camera.get(cv.CAP_PROP_FRAME_WIDTH))
cameraHeight = int(camera.get(cv.CAP_PROP_FRAME_HEIGHT))

# Set BLOB_SIZE to the camera image size.
# cv.dnn.blobFromImage takes size as (width, height); with the frame size the
# image is fed to the network without being resized or distorted.
BLOB_SIZE = (cameraWidth, cameraHeight)      # input layer size

# Creating visualization windows
cv.namedWindow(WINDOW_CAMERA1, cv.WINDOW_AUTOSIZE)

print(f"Capturing images from camera {CAMERA_ID} ({cameraWidth},{cameraHeight})")
print("...Hit q/Q/Esc to exit.")


# -----------------------------------------
# Main Loop
# while there are images ...
# -----------------------------------------
while True:
    # Capture frame-by-frame
    ret, capture = camera.read()

    # if frame is read correctly ret is True
    if not ret:
        print("Can't receive frame (stream end?). Exiting ...")
        break

    # -----------------------------------------
    # Put your image processing code here
    # -----------------------------------------

    tm.start()  # start processing cycle

    # Configuring INPUT BLOB/Tensor pre-processing parameters (BGR -> RGB, no crop)
    blob = cv.dnn.blobFromImage(capture, scalefactor=1.0, size=BLOB_SIZE, swapRB=True, crop=False)
    network.setInput(blob)

    # Infer prediction for both output tensors (boxes, masks)
    boxes, masks = network.forward(['detection_out_final', 'detection_masks'])

    tm.stop()  # end processing cycle
    # Show FPS on image
    cv.putText(capture, f"FPS: {tm.getFPS():.1f} ", org=(3, 20),
               fontFace=cv.FONT_HERSHEY_DUPLEX, fontScale=0.5, color=(0, 0, 255), thickness=1, lineType=cv.LINE_AA)

    # Scale normalized boxes with the real frame size (it may differ from the requested resolution)
    frame_height, frame_width = capture.shape[:2]
    box_scale = np.array([frame_width, frame_height, frame_width, frame_height])

    # Show Detected Boxes/Masks
    # Each detection row holds classId at [1], confidence at [2] and the normalized box at [3:7]
    for i, detection in enumerate(boxes[0, 0]):
        classId = int(detection[1])
        confidence = detection[2]

        if classId == 0:    # Background
            continue

        # Show only those objects (boxes) with enough detection probability;
        # testing this first avoids resizing masks that would be discarded anyway
        if not confidence > conf_probability:
            continue

        # Rescale detection box from normalized [0-1] to pixel coordinates
        x_start, y_start, x_end, y_end = map(int, detection[3:7] * box_scale)
        box_width = x_end - x_start
        box_height = y_end - y_start
        if box_width <= 0 or box_height <= 0:
            continue    # degenerate box: cv.resize cannot produce an empty mask

        # Segmentation mask: one mask for every box and class.
        # It is a 15x15 matrix with the segmentation probability for each pixel block
        mask = masks[i, classId]
        # Rescale segmentation mask: its coordinates span [0-15] and must be scaled to the box size
        mask = cv.resize(mask, (box_width, box_height), interpolation=cv.INTER_NEAREST)

        # Keep only those pixels in the mask that have enough detection probability
        mask = mask > conf_threshold

        # Clip the box to the frame and crop the mask by the same amount so both stay aligned;
        # without this a box partly outside the image gives a ROI/mask shape mismatch
        x0, y0 = max(x_start, 0), max(y_start, 0)
        x1, y1 = min(x_end, frame_width), min(y_end, frame_height)
        if x1 <= x0 or y1 <= y0:
            continue    # box lies completely outside the frame
        mask = mask[y0 - y_start:y1 - y_start, x0 - x_start:x1 - x_start]

        color = colors[classId]
        label = labels[classId]

        # Extract Region of Interest (ROI) for the detected box; basic slicing returns a
        # view, so writing into it modifies `capture`
        roi = capture[y0:y1, x0:x1]

        # Blend class color with image color on the masked pixels to visualize the segmentation
        roi[mask] = (0.34 * np.array(color) + 0.66 * roi[mask]).astype('uint8')

        cv.rectangle(capture, (x_start, y_start), (x_end, y_end), color, thickness=2)
        text = f"{label}: {confidence*100:.1f}%"
        textSize, baseline = cv.getTextSize(text, cv.FONT_HERSHEY_DUPLEX, 0.3, 1)
        cv.rectangle(capture, (x_start, y_start), (x_start + 10 + textSize[0], y_start - 10 - textSize[1]), color, thickness=cv.FILLED)
        cv.putText(capture, text, (x_start + 5, y_start - 5), cv.FONT_HERSHEY_DUPLEX, 0.3, (255, 255, 255), 1, cv.LINE_AA)

    # -----------------------------------------
    # Put your visualization code here
    # -----------------------------------------
    cv.imshow(WINDOW_CAMERA1, capture)     # Display the resulting frame

    # check keystroke to exit (image window must be on focus)
    key = cv.pollKey()
    if key in (ord('q'), ord('Q'), 27):
        break

# End while (main loop)

# -----------------------------------------
# free windows and camera resources
# -----------------------------------------
cv.destroyAllWindows()
if camera.isOpened():
    camera.release()
