"""
*************************************************************************
// Example program using OpenCV library
//      python >3.7 - OpenCV 4.5
// @file	e7b.py
// @author Luis M. Jimenez
// @date 2022
//
// @brief Course: Computer Vision (1782)
// Dept. of Systems Engineering and Automation
// Automation, Robotics and Computer Vision Lab (ARVC)
// http://arvc.umh.es
// University Miguel Hernandez
//
// @note Description:
//	- Shows the use of feature points detectors/descriptors/matchers
//
//	- Capture images from a camera (decrease resolution)
//	- Detects featured points (ORB/SIFT/SURF) using common interface FeatureDetector
//	- Calculates Descriptor vector (ORB/SIFT/SURF) using common interface DescriptorExtractor
//  - Matches featured points between frames using Brute Force or FLANN algorithms
//  - Draw Matches and detected points over captured image
//
*************************************************************************
"""

# Import libraries
import cv2 as cv
import numpy as np
import argparse

# -----------------------------------------
# Global variables
# -----------------------------------------

WINDOW_CAMERA1 = '(W1) Camera 1 (Feature Detector)'   # window id
CAMERA_ID = 0	              # default camera
FEATURE_TYPE = 'orb'          # Default Feature Detector sift/surf/orb
MATCHER_TYPE = 'flann'        # Default Feature Matcher bf/flann

# check command line parameters (camera id, feature, descriptor)
parser = argparse.ArgumentParser(description='OpenCV example: Feature Points Detector')
parser.add_argument('-c', dest='cameraID', type=int, default=CAMERA_ID, metavar='id', help='camera id')
parser.add_argument('-f', dest='feature', type=str, default=FEATURE_TYPE, metavar='feature', help='sift | surf | orb')
parser.add_argument('-m', dest='matcher', type=str, default=MATCHER_TYPE, metavar='matcher', help='bf (BruteForce) | flann')

# Parse the command line ONCE (the original called parse_args() three times,
# re-scanning sys.argv for every single option)
_args = parser.parse_args()
CAMERA_ID = _args.cameraID
FEATURE_TYPE = _args.feature.lower()          # feature (lowercase)
MATCHER_TYPE = _args.matcher.lower()          # matcher (lowercase)

# -----------------------------------------
# Initialization of capture and display objects
# -----------------------------------------

# Open the capture device; bail out early when no camera is reachable
camera = cv.VideoCapture(CAMERA_ID)
if not camera.isOpened():
    print("you need to connect a camera, sorry.")
    exit()

# Request a reduced frame size so feature extraction stays fast
for _prop, _value in ((cv.CAP_PROP_FRAME_WIDTH, 640), (cv.CAP_PROP_FRAME_HEIGHT, 480)):
    camera.set(_prop, _value)

# Read back the resolution actually granted by the driver
cameraWidth = int(camera.get(cv.CAP_PROP_FRAME_WIDTH))
cameraHeight = int(camera.get(cv.CAP_PROP_FRAME_HEIGHT))

# Visualization window, sized automatically to the displayed image
cv.namedWindow(WINDOW_CAMERA1, cv.WINDOW_AUTOSIZE)

# Init FeaturePoint Detector
# (in the original, 'sift' and the unknown-name fallback built the exact same
#  SIFT detector, so both cases are folded into the final else branch)
if FEATURE_TYPE == 'orb':
    detector = cv.ORB.create(nfeatures=100)
elif FEATURE_TYPE == 'surf':
    detector = cv.xfeatures2d.SURF.create(hessianThreshold=400)
else:
    # 'sift' or any unrecognized value -> SIFT
    detector = cv.SIFT.create(nfeatures=100)

# Detector information
print(f"{FEATURE_TYPE=}")
print(f"{detector.descriptorSize()=}")

# Init Feature Matcher
if MATCHER_TYPE == 'flann':
    if FEATURE_TYPE == 'orb':
        # ORB produces binary descriptors: FLANN needs an LSH index (algorithm 6)
        lsh_params = dict(algorithm=6, table_number=6, key_size=12, multi_probe_level=3)
        matcher = cv.FlannBasedMatcher(lsh_params)
    else:
        # Float descriptors (SIFT/SURF) use FLANN's default KD-tree index
        matcher = cv.FlannBasedMatcher.create()
else:
    # Brute Force: Hamming distance for binary descriptors, L2 otherwise
    bf_norm = cv.NORM_HAMMING if FEATURE_TYPE == 'orb' else cv.NORM_L2
    matcher = cv.BFMatcher.create(normType=bf_norm, crossCheck=True)

print(f"{MATCHER_TYPE=}")

print(f"Capturing images from camera {CAMERA_ID} ({cameraWidth},{cameraHeight})")
print("...Hit q/Q/Esc to exit.")

new_reference = True		# Take new reference image

# -----------------------------------------
# Main Loop
# while there are images ...
# -----------------------------------------
while True:
    # Capture frame-by-frame
    ret, capture = camera.read()

    # if frame is read correctly ret is True
    if not ret:
        print("Can't receive frame (stream end?). Exiting ...")
        break

    # -----------------------------------------
    # Image processing
    # -----------------------------------------
    gray_image = cv.cvtColor(capture, cv.COLOR_BGR2GRAY)

    # Single-pass detection + description (equivalent to detect() then compute())
    keypoints, descriptors = detector.detectAndCompute(gray_image, None)

    # BUGFIX: with zero detected keypoints (e.g. lens covered) descriptors is
    # None and both .copy() and matcher.match() would raise. Show the raw
    # frame, keep waiting for a usable one, and still honour the quit keys.
    if descriptors is None:
        cv.imshow(WINDOW_CAMERA1, capture)
        key = cv.pollKey()
        if key in (ord('q'), ord('Q'), 27):
            break
        continue

    # Store reference on first frame or Reset
    if new_reference:
        keypoints_ref = keypoints              # Tuples are immutable so reference copy is valid
        descriptors_ref = descriptors.copy()   # Clone ndarray (avoid reference copy)
        image_ref = capture.copy()             # Clone ndarray (avoid reference copy)
        new_reference = False

    # Match descriptors against the stored reference frame.
    matches = matcher.match(descriptors, descriptors_ref)   # list/tuple of matches

    # Sort matches in the order of their distance (best first).
    matches = sorted(matches, key=lambda x: x.distance)

    # -----------------------------------------
    # Visualization
    # -----------------------------------------

    # Draw the 20 best matches between current frame (left) and reference (right).
    dispimage = cv.drawMatches(capture, keypoints, image_ref, keypoints_ref, matches[:min(20, len(matches))],
                               outImg=None, flags=cv.DRAW_MATCHES_FLAGS_DEFAULT)

    # Show #KeyPoints and matcher type as a text overlay
    cv.putText(dispimage, f"Detected: {len(keypoints)} {FEATURE_TYPE.upper()} points - Matcher: {MATCHER_TYPE.upper()}", org=(5, 15),
               fontFace=cv.FONT_HERSHEY_DUPLEX, fontScale=0.4, color=(0, 0, 255), thickness=1, lineType=cv.LINE_AA)

    cv.imshow(WINDOW_CAMERA1, dispimage)     # Display the resulting frame

    # check keystroke to exit (image window must be on focus)
    key = cv.pollKey()
    if key in (ord('q'), ord('Q'), 27):
        break
    elif key in (ord('r'), ord('R'), ord(' ')):
        new_reference = True     # Reset Reference image
    elif key in (ord('s'), ord('S')):
        key = cv.waitKey()     # Stop and wait for a keystroke

# End while (main loop)

# show matches on console
# BUGFIX: if the camera fails on the very first read, the main loop never
# defined keypoints/matches and these prints raised NameError — guard for it.
if 'matches' in globals():
    print(f"Left image points: {len(keypoints)}  - Right image points: {len(keypoints_ref)}")
    print(f"Matched points: {len(matches)}")
    for i in range(min(10, len(matches))):
        print(f"-- Match [{i}] Keypoint Cam: {matches[i].queryIdx}", end='')
        print(f" -- Keypoint Ref: {matches[i].trainIdx} -- dist: {matches[i].distance}")

# -----------------------------------------
# free windows and camera resources
# -----------------------------------------
cv.destroyAllWindows()
if camera.isOpened():  camera.release()
