"""
*************************************************************************
// Example program using OpenCV library
//      python >3.7 - OpenCV 4.5
// @file	e7c.py
// @author Luis M. Jimenez
// @date 2022
//
// @brief Course: Computer Vision (1782)
// Dept. of Systems Engineering and Automation
// Automation, Robotics and Computer Vision Lab (ARVC)
// http://arvc.umh.es
// University Miguel Hernandez
//
// @note Description:
//	- Shows the use of feature points detectors/descriptors/matchers
//
//	- Capture images from two cameras (decrease resolution)
//	- Detects featured points (ORB/SIFT/SURF) using common interface FeatureDetector
//	- Calculates Descriptor vector (ORB/SIFT/SURF) using common interface DescriptorExtractor
//  - Matches featured points between frames using Brute Force or FLANN algorithms
//  - shows the use of knnMatch, finds the 2 best matches for each keypoint (k=2)
//  - Select those matches with enough distance difference between these 2 matches (ratio test)
//  - Draw Matches and detected points over captured image
//
*************************************************************************
"""

# Import libraries
import cv2 as cv
import numpy as np
import argparse

# -----------------------------------------
# Global variables
# -----------------------------------------

WINDOW_CAMERA1 = '(W1) Cameras (Feature Detector)'    # window id
CAMERA_ID = [0, 1]	              # camera ids
FEATURE_TYPE = 'orb'              # Default Feature Detector sift/surf/orb
MATCHER_TYPE = 'flann'            # Default Feature Matcher bf/flann

# check command line parameters (camera id, feature, descriptor)
parser = argparse.ArgumentParser(description='OpenCV example: Feature Points Detector')
parser.add_argument('-c', dest='cameraID', type=int, default=CAMERA_ID[0], metavar='id', help='camera id')
parser.add_argument('-f', dest='feature', type=str, default=FEATURE_TYPE, metavar='feature', help='sift | surf | orb')
parser.add_argument('-m', dest='matcher', type=str, default=MATCHER_TYPE, metavar='matcher', help='bf (BruteForce) | flann')

# Parse the command line ONCE and reuse the result.
# (The original called parser.parse_args() four times, re-parsing sys.argv
# for every option read.)
args = parser.parse_args()
CAMERA_ID[0] = args.cameraID
CAMERA_ID[1] = args.cameraID + 1        # second camera is always id+1
FEATURE_TYPE = args.feature.lower()     # feature name, normalized to lowercase
MATCHER_TYPE = args.matcher.lower()     # matcher name, normalized to lowercase

# -----------------------------------------
# Put here the code to initialize objects
# -----------------------------------------

# Open camera objects (one VideoCapture per configured id)
cameras = []
for cam_id in CAMERA_ID:        # 'cam_id' instead of 'id' — don't shadow the builtin id()
    camera = cv.VideoCapture(cam_id)
    if not camera.isOpened():
        print("you need to connect a camera, sorry.")
        exit()
    # lower resolution to speed feature extraction
    camera.set(cv.CAP_PROP_FRAME_WIDTH, 640)
    camera.set(cv.CAP_PROP_FRAME_HEIGHT, 480)
    cameras.append(camera)

# Getting camera resolution, queried explicitly from the first camera.
# (The original read the loop-leaked 'camera' variable, i.e. the last camera;
# both are configured identically above, so the value is the same.)
cameraWidth = int(cameras[0].get(cv.CAP_PROP_FRAME_WIDTH))
cameraHeight = int(cameras[0].get(cv.CAP_PROP_FRAME_HEIGHT))

# Creating visualization windows
cv.namedWindow(WINDOW_CAMERA1, cv.WINDOW_AUTOSIZE)

# Init FeaturePoint Detector: dispatch the requested name through a factory
# table; any unrecognized name falls back to SIFT (the default detector).
_DETECTOR_FACTORIES = {
    'sift': lambda: cv.SIFT.create(nfeatures=100),
    'orb':  lambda: cv.ORB.create(nfeatures=100),
    'surf': lambda: cv.xfeatures2d.SURF.create(hessianThreshold=400),
}
detector = _DETECTOR_FACTORIES.get(FEATURE_TYPE, _DETECTOR_FACTORIES['sift'])()

# Detector information
print(f"{FEATURE_TYPE=}")
print(f"{detector.descriptorSize()=}")

# Init Feature Matcher
if MATCHER_TYPE == 'flann':
    if FEATURE_TYPE == 'orb':   # Binary descriptor -> FLANN LSH index
        # algorithm=6 is FLANN_INDEX_LSH (locality-sensitive hashing),
        # the index type suited to binary (Hamming-space) descriptors
        index_params = dict(algorithm=6, table_number=6, key_size=12, multi_probe_level=3)
        matcher = cv.FlannBasedMatcher(index_params)
    else:                       # Scalar descriptor -> default KD-tree index
        matcher = cv.FlannBasedMatcher.create()

else:   # Brute Force Matcher
    # NOTE: crossCheck MUST be False here. The main loop uses knnMatch(k=2)
    # followed by a ratio test; with crossCheck=True BFMatcher returns at
    # most ONE (mutual-best) match per query, so unpacking two candidates
    # for the ratio test raises. Cross-checking and the ratio test are
    # alternative filtering strategies — this program uses the ratio test.
    if FEATURE_TYPE == 'orb':   # Binary descriptor -> NORM_HAMMING
        matcher = cv.BFMatcher.create(normType=cv.NORM_HAMMING, crossCheck=False)
    else:                       # Scalar descriptor -> NORM_L2
        matcher = cv.BFMatcher.create(normType=cv.NORM_L2, crossCheck=False)

print(f"{MATCHER_TYPE=}")

print(f"Capturing images from camera {CAMERA_ID} ({cameraWidth},{cameraHeight})")
print("...Hit q/Q/Esc to exit.")


# -----------------------------------------
# Main Loop
# while there are images ...
# -----------------------------------------
while True:
    # Capture frame-by-frame from every camera
    captures = []
    for camera in cameras:
        ret, capture = camera.read()
        if ret: captures.append(capture)

    # Stop as soon as ANY camera fails to deliver a frame.
    # (The original only tested the LAST camera's return flag; a failure of
    # the first camera would leave a single frame in 'captures' and crash
    # later at captures[1].)
    if len(captures) != len(cameras):
        print("Can't receive frame (stream end?). Exiting ...")
        break

    # -----------------------------------------
    # Image processing
    # -----------------------------------------
    # Grayscale conversion: feature detectors operate on single-channel images
    gray_images = [cv.cvtColor(capture, cv.COLOR_BGR2GRAY) for capture in captures]

    # Detect keypoints and compute their descriptors for each frame
    keypointsList = []
    descriptorsList = []
    for gray_image in gray_images:
        keypoints = detector.detect(gray_image)                              # tuple of cv.KeyPoint
        keypoints, descriptors = detector.compute(gray_image, keypoints)     # descriptor matrix (ndarray)
        keypointsList.append(keypoints)
        descriptorsList.append(descriptors)

    # knnMatch: the 2 best candidate matches for each descriptor (k=2)
    matches = matcher.knnMatch(descriptorsList[0], descriptorsList[1], k=2)

    # Apply Lowe's ratio test to the two best matches. Two robustness fixes:
    #  - FLANN (especially the LSH index used with ORB) may return FEWER
    #    than 2 neighbours for a query; skip those instead of crashing on
    #    tuple unpacking.
    #  - compare multiplicatively (m1 < 0.75*m2) instead of dividing, which
    #    raised ZeroDivisionError when the best distance was 0.
    goodMatches = []
    for pair in matches:
        if len(pair) < 2:
            continue
        m1, m2 = pair[0], pair[1]
        if m1.distance < 0.75 * m2.distance:
            goodMatches.append(m1)

    # Sort the surviving matches by ascending distance (best first)
    matches = sorted(goodMatches, key=lambda m: m.distance)

    # -----------------------------------------
    # Visualization
    # -----------------------------------------

    # Draw the (up to) 20 best matches over the two frames side by side
    dispimage = cv.drawMatches(captures[0], keypointsList[0], captures[1], keypointsList[1],
                               matches[:20], outImg=None, flags=cv.DRAW_MATCHES_FLAGS_DEFAULT)

    # Show #KeyPoints of the second frame (the original read the loop-leaked
    # 'keypoints' variable, i.e. the last processed image — made explicit here)
    cv.putText(dispimage, f"Detected: {len(keypointsList[-1])} {FEATURE_TYPE.upper()} points - Matcher: {MATCHER_TYPE.upper()}",
               org=(5, 15), fontFace=cv.FONT_HERSHEY_DUPLEX, fontScale=0.4,
               color=(0, 0, 255), thickness=1, lineType=cv.LINE_AA)

    cv.imshow(WINDOW_CAMERA1, dispimage)     # Display the resulting frame

    # check keystroke to exit (image window must be on focus)
    key = cv.pollKey()
    if key in (ord('q'), ord('Q'), 27):      # q / Q / Esc -> quit
        break
    elif key in (ord('s'), ord('S')):        # s / S -> pause
        key = cv.waitKey()                   # Stop and wait for a keystroke

# End while (main loop)


# -----------------------------------------
# free windows and camera resources
# -----------------------------------------
cv.destroyAllWindows()
for cam in cameras:
    if cam.isOpened():
        cam.release()
