# standard library imports
import enum
import math
from datetime import datetime

# typing
from numpy.random.mtrand import Sequence
from typing import Optional

# image manipulation
import picamzero
import cv2
import numpy as np
from exif import Image

# the logger
from PiLogger import PiLogger

# Module-wide tuning constants
FLANN_INDEX_KDTREE: int = 1  # algorithm id placed in the FLANN index parameters (KD-tree index) by ImageManager.get_matches
FEATURE_COUNT: int = 10000  # number of features requested from cv2.SIFT_create
MIN_SQR_SIZE: int = 16  # smallest half-size (square side = 2 * size + 1) used when comparing pixels around matched key points
MAX_SQR_SIZE: int = 64  # largest such half-size, so the square comparisons stay cheap

class ErrorType(enum.Enum):
    """
    Measures used to compare the square regions of pixels around key point matches,
    so the strongest matches can be weighted highest.

    RMSE was found to work best experimentally.
    """

    # dot product of the two squares viewed as vectors (a cosine similarity when the squares are normalized)
    CosineDifference = 0
    # root mean squared error between individual pixels
    RMSE = 1
    # mean absolute error between individual pixels
    MAE = 2

# Disabled debugging helper. It is kept inside a string literal, so the function is never defined or run.
# When enabled, it draws the key-point matches between two images with cv2.drawMatches,
# resizes the result to 800x300, shows it in a 'matches' window and waits for a key press.
"""
# A helper function for debugging.

def display_matches(image_1_cv, key_points_1, image_2_cv, key_points_2, matches):
    match_img = cv2.drawMatches(image_1_cv, key_points_1, image_2_cv, key_points_2, matches, None)
    resize = cv2.resize(match_img, (800,300), interpolation = cv2.INTER_AREA)
    cv2.imshow('matches', resize)
    cv2.waitKey(0)
    cv2.destroyWindow('matches')
"""

class ImageManager:
    """
    A robust image management class for the Astro-Pi Mission Space Lab.

    Handles image capture and the feature-matching pipeline that measures how far matched
    features moved between two images. Displaying was also supported, but is commented out
    for submission.
    """

    def __init__(self, mylogger: PiLogger) -> None:
        """
        Initialize the ImageManager with camera setup and logging.

        Args:
            mylogger: PiLogger instance for recording operations and errors

        Raises:
            RuntimeError: if the camera could not be initialized. The original exception
                is chained as the cause.
        """
        self.logger = mylogger
        mylogger.operation_start("Initializing ImageManager")
        try:
            # Initialize camera
            self.camera: picamzero.Camera = picamzero.Camera()
            self.logger.info("Camera initialized successfully")

        except Exception as e:
            # Re-raise as a more specific exception with context (to be caught at a higher level).
            # Chaining keeps the original traceback available for debugging.
            raise RuntimeError(f"ImageManager initialization failed: {str(e)}") from e

        mylogger.operation_end()

    def capture(self) -> tuple[np.ndarray, datetime]:
        """
        Captures an image.

        The timestamp is taken immediately before the capture. This is faster to access and
        more accurate than reading the time back from EXIF data.
        :return: An RGB image, time taken.
        """
        start = datetime.now()
        image = self.camera.capture_array()  # captures directly from camera without writing to disk

        return image, start

    @staticmethod
    def get_coords_from_kps(kps: list[cv2.KeyPoint]) -> list[Sequence[float]]:
        """
        Extracts the (x, y) sub-pixel position of every key point.

        :param kps: key points returned by the feature detector
        :return: the key points' ``pt`` values, in the same order
        """
        return [kp.pt for kp in kps]

    def calc_weighted_mean_feat_dist(self, im1: np.ndarray, im2: np.ndarray, error_type: ErrorType, norm_sqrs: bool, softmax: bool) -> float:
        """
        Calculates an accurate feature distance from matched key points. Each match is weighted by
        the similarity of the pixels around its two key points, measured with the provided error function.

        :param im1: the first image to find features in
        :param im2: the second image to find features in
        :param error_type: the type of comparison used to weigh the strengths of feature matches
        :param norm_sqrs: performs vector normalization on the comparison samples (sqrs) of the images to reduce the effect of differences in brightness between images
        :param softmax: uses exponential function to concentrate weighting on one match above others.
        :return: the weighted mean feature distance in pixels, or -1 if no usable matches were found
        """
        # get a mask for feature recognition for black regions in case any black border is present around the image
        black_mask_im1 = self.mask_black(im1)
        black_mask_im2 = self.mask_black(im2)

        # detect features
        im1_kps, im1_desc = self.get_features(im1, black_mask_im1)
        im2_kps, im2_desc = self.get_features(im2, black_mask_im2)

        # match features (yields no matches when either image has no descriptors)
        matches = self.get_matches(im1_desc, im2_desc)

        # get matched feature coordinates: coords1[i] and coords2[i] belong to the same match
        im1_coords = self.get_coords_from_kps(im1_kps)
        im2_coords = self.get_coords_from_kps(im2_kps)
        coords1, coords2 = self.get_match_coords(im1_coords, im2_coords, matches)

        # get unweighted distances. They are deliberately NOT sorted: the weights below are
        # computed in match order, and each weight must multiply the distance of its own match.
        distances = self.calc_feature_distances(coords1, coords2)

        # Handle the extreme case
        if not distances:
            return -1  # check for negative distance on a higher level

        # NOTE: trimming the distances to their inter-quartile range was tried and abandoned. The
        # similarity weighting corrects for extremes much better and limits the possibility of a 'perfect match'.

        # get weights by comparing similarity of squares around matching key points
        sqr_size = self.determine_sqr_size(coords1, coords2, im1, im2)
        weights = self.get_dist_weights(im1, im2, coords1, coords2, sqr_size, error_type, norm_sqrs, softmax)

        # weights sum to 1, so this is the weighted mean distance
        return sum(w * d for w, d in zip(weights, distances))

    @staticmethod
    def determine_sqr_size(coords1: list[tuple[float, float]], coords2: list[tuple[float, float]], im1: np.ndarray, im2: np.ndarray) -> int:
        """
        Chooses the half-size of the pixel squares compared around every matched key point.

        The size is the largest one that fits around all key points in both images, clamped to
        [MIN_SQR_SIZE, MAX_SQR_SIZE]. The coordinate lists are not modified.

        :param coords1: (x, y) key point coordinates in the first image
        :param coords2: (x, y) key point coordinates in the second image
        :param im1: the first image
        :param im2: the second image
        :return: the square half-size in pixels
        """
        # Key point coordinates are (x, y) = (column, row), while numpy shapes are (rows, cols, channels).
        # Reorder to (width, height) to match get_max_square_size.
        resolution_im1 = (im1.shape[1], im1.shape[0])
        resolution_im2 = (im2.shape[1], im2.shape[0])

        # get bounds within images
        sqr_size_im1 = ImageManager.get_max_square_sizes(coords1, resolution_im1)
        sqr_size_im2 = ImageManager.get_max_square_sizes(coords2, resolution_im2)

        # combine results and set upper and lower bounds
        sqr_size = min(sqr_size_im1, sqr_size_im2)
        sqr_size = max(sqr_size, MIN_SQR_SIZE)  # Makes sure that 1 feature at the edge of the image does not reduce the sqr size to 1
        sqr_size = min(sqr_size, MAX_SQR_SIZE)  # Makes sure that we don't spend forever on comparing large squares

        # Gets rid of floating point square size (courtesy of fp coordinates provided by SIFT feature recognition)
        return round(sqr_size)

    @staticmethod
    def get_max_square_sizes(coords: list[tuple[float, float]], resolution: tuple[int, int]) -> float:
        """
        Finds the largest square half-size that fits inside the image around every coordinate.

        The input list is not modified.

        :param coords: (x, y) coordinates
        :param resolution: (width, height) of the image
        :return: the minimum of the per-coordinate limits
        :raises ValueError: if coords is empty
        """
        return min(ImageManager.get_max_square_size(coord, resolution) for coord in coords)

    @staticmethod
    def get_max_square_size(coord: tuple[float, float], resolution: tuple[int, int]) -> float:
        """
        Finds the largest square half-size that fits inside the image around one coordinate.

        :param coord: (x, y) coordinate
        :param resolution: (width, height) of the image
        :return: the distance to the nearest image edge, assuming a zero indexed image
        """
        x, y = coord
        width, height = resolution

        # horizontal room: to the left edge is x, to the right edge is width - x - 1
        max_width = min(x, width - x - 1)

        # vertical room: to the top edge is y, to the bottom edge is height - y - 1
        max_height = min(y, height - y - 1)

        return min(max_width, max_height)

    @staticmethod
    def calc_feature_distances(coords1, coords2) -> list[float]:
        """
        Computes the Euclidean pixel distance between each pair of matched coordinates.

        :param coords1: (x, y) coordinates in the first image
        :param coords2: (x, y) coordinates in the second image, paired by index with coords1
        :return: one distance per pair, in input order
        """
        return [math.hypot(c1[0] - c2[0], c1[1] - c2[1]) for c1, c2 in zip(coords1, coords2)]

    @staticmethod
    def get_match_coords(coords1, coords2, matches):
        """
        Looks up the coordinates of both key points of every match.

        :param coords1: coordinates of the query (first image) key points
        :param coords2: coordinates of the train (second image) key points
        :param matches: DMatch objects linking queryIdx into coords1 and trainIdx into coords2
        :return: two lists, aligned by match order
        """
        res1 = [coords1[m.queryIdx] for m in matches]
        res2 = [coords2[m.trainIdx] for m in matches]
        return res1, res2

    _flann: Optional[cv2.FlannBasedMatcher] = None
    def get_matches(self, desc1: np.ndarray, desc2: np.ndarray, ratio: float = 0.8) -> list[cv2.DMatch]:
        """
        Matches descriptors with FLANN and keeps the matches that pass Lowe's ratio test.

        :param desc1: descriptors of the first image (may be None when no features were found)
        :param desc2: descriptors of the second image (may be None when no features were found)
        :param ratio: a match is kept when its distance is below ratio * the second-best distance
            (0.8 as per the original paper)
        :return: the accepted matches
        """
        # Nothing to match without descriptors. The ratio test also needs at least 2 candidates in the second image.
        if desc1 is None or desc2 is None or len(desc1) == 0 or len(desc2) < 2:
            return []

        # create matcher if not created already
        if self._flann is None:
            index_params = dict(algorithm=FLANN_INDEX_KDTREE, trees=5)
            self._flann = cv2.FlannBasedMatcher(index_params)

        # find the 2 nearest neighbours of every descriptor
        matches: Sequence[Sequence[cv2.DMatch]] = self._flann.knnMatch(desc1, desc2, 2)

        res = []
        for pair in matches:
            # a query can come back with fewer than 2 neighbours; it cannot be ratio-tested
            if len(pair) < 2:
                continue
            m, n = pair
            if m.distance < ratio * n.distance:
                res.append(m)

        return res

    _detector: Optional[cv2.SIFT] = None
    def get_features(self, image: np.ndarray, mask: Optional[np.ndarray] = None):
        """
        Detects SIFT key points and descriptors.

        :param image: the image to search
        :param mask: optional array; features are only searched where it is non-zero. It is not modified.
        :return: (key points, descriptors) as returned by detectAndCompute
        """
        # create a detector if not already created
        if self._detector is None:
            self._detector = cv2.SIFT_create(FEATURE_COUNT)

        # Build an 8-bit 0/255 mask from the non-zero entries. This avoids changing the caller's array.
        cv_mask = None
        if mask is not None:
            cv_mask = np.where(mask > 0, 255, 0).astype(np.uint8)

        return self._detector.detectAndCompute(image, cv_mask)

    def mask_black(self, image: np.ndarray) -> np.ndarray:
        """
        Builds a mask that is 1 wherever the grey version of the image is non-black and 0 elsewhere.

        :param image: a 3-channel image (converted with COLOR_BGR2GRAY)
        :return: a mask with the same shape and dtype as the grey image
        """
        grey = self.cv_to_grey(image)
        mask = np.zeros_like(grey)
        mask[grey > 0] = 1
        return mask

    @staticmethod
    def get_dist_weights(im1, im2, coords1, coords2, sqr_size, error_type: ErrorType, norm_sqrs = True, softmax = True):
        """
        Weighs every match by how similar the pixel squares around its two key points are.

        :return: one weight per match, summing to 1. If a perfect match is found, it gets weight 1
            and every other match gets 0.
        """
        # Normalise the sqrs if needed
        if norm_sqrs:
            sqrs1 = ImageManager.get_normed_sqrs(im1, coords1, sqr_size)
            sqrs2 = ImageManager.get_normed_sqrs(im2, coords2, sqr_size)
        else:
            sqrs1 = ImageManager.get_squares_about(im1, coords1, sqr_size)
            sqrs2 = ImageManager.get_squares_about(im2, coords2, sqr_size)

        similarity: list[float] = []

        for idx, (sqr1, sqr2) in enumerate(zip(sqrs1, sqrs2)):
            res = ImageManager.find_similarity(sqr1, sqr2, error_type)

            # A perfect match becomes the only number taken into consideration: mask all other matches out.
            if res is True:
                weights = [0] * len(sqrs1)
                weights[idx] = 1
                return weights

            similarity.append(res)

        # normalize the similarity scores to make them into weights
        return ImageManager.normalize_weights(similarity, softmax)

    @staticmethod
    def find_similarity(sqr1, sqr2, error_type):
        """
        Scores how similar two pixel squares are (higher is more similar).

        :return: 0 for squares of different shape or empty squares. ``True`` when an error-based
            measure finds zero error (a perfect match). Otherwise the similarity value.
        :raises ValueError: for an unsupported error_type
        """
        # Do not compare sqrs with different shapes (could be caused by one square clipping the border of the image).
        # Do not compare empty sqrs either: they carry no evidence, and their zero error would look like a perfect match.
        if sqr1.shape != sqr2.shape or sqr1.size == 0:
            return 0

        if error_type == ErrorType.CosineDifference:
            return ImageManager.get_dot(sqr1, sqr2)

        if error_type == ErrorType.RMSE:
            error = ImageManager.get_rms(sqr1, sqr2)
        elif error_type == ErrorType.MAE:
            error = ImageManager.get_mae(sqr1, sqr2)
        else:
            raise ValueError(f"Unsupported error type: {error_type}")

        # Zero error means the same square of pixels shifted along: confidence goes to infinity.
        # Returning True also avoids dividing by zero.
        if error == 0:
            return True

        return 1 / error  # inverse of error is similarity

    @staticmethod
    def normalize_weights(weights, softmax) -> list[float]:
        """
        Turns similarity scores into weights that sum to 1.

        :param weights: similarity scores
        :param softmax: apply a softmax (exponentiate first) instead of dividing by the plain sum
        :return: the normalized weights (empty for empty input)
        """
        values = [float(w) for w in weights]
        if not values:
            return []

        if softmax:
            # Shift by the maximum before exponentiating. The softmax is mathematically unchanged,
            # but math.exp can no longer overflow for large similarities (e.g. 1 / tiny RMSE).
            peak = max(values)
            values = [math.exp(v - peak) for v in values]

        total = sum(values)
        if total == 0:
            # every similarity was 0 (e.g. all squares clipped by the border): fall back to a plain mean
            return [1 / len(values)] * len(values)

        return [v / total for v in values]

    # the dot product of the two squares (imagine them as 2 vectors of numbers being compared);
    # for normalized squares this is their cosine similarity
    @staticmethod
    def get_dot(sqr1, sqr2):
        # cast to float first: multiplying uint8 pixels would silently wrap around at 256
        a = np.asarray(sqr1, dtype=np.float64)
        b = np.asarray(sqr2, dtype=np.float64)
        return float(np.sum(a * b))

    # the standard Root Mean Squared Error in stats (penalizes outliers more than linearly)
    @staticmethod
    def get_rms(sqr1, sqr2):
        # cast to float first: subtracting uint8 pixels would silently wrap around
        dif = np.asarray(sqr1, dtype=np.float64) - np.asarray(sqr2, dtype=np.float64)

        # an empty comparison gives no evidence of similarity
        if dif.size == 0:
            return math.inf

        # mean over every element (all pixels and channels), not just over the rows
        return math.sqrt(float(np.mean(dif ** 2)))

    # the standard Mean Absolute Error in stats (penalizes outliers linearly)
    @staticmethod
    def get_mae(sqr1, sqr2):
        # cast to float first: subtracting uint8 pixels would silently wrap around
        dif = np.asarray(sqr1, dtype=np.float64) - np.asarray(sqr2, dtype=np.float64)

        # an empty comparison gives no evidence of similarity
        if dif.size == 0:
            return math.inf

        # mean over every element (all pixels and channels), not just over the rows
        return float(np.mean(np.abs(dif)))

    @staticmethod
    def get_normed_sqrs(image: np.ndarray, coords: list[tuple[float, float]], sqr_size: int) -> list[np.ndarray]:
        """Returns the unit-length normalized square around every coordinate, in order."""
        return [ImageManager.get_normed_sqr(image, coord, sqr_size) for coord in coords]

    @staticmethod
    def get_normed_sqr(image: np.ndarray, coord: tuple[float, float], sqr_size: int) -> np.ndarray:
        """
        Returns the square around coord, scaled to [0, 1] and normalized to unit length as if it were a vector.

        An all-black (zero-length) square is returned unnormalized instead of dividing by zero.
        """
        # get rid of the 255 factor and also turn it into a float array
        sqr = ImageManager.get_square_about(image, coord, sqr_size) / 255

        # normalize it as though it were a vector
        vec_len = float(np.linalg.norm(sqr))
        if vec_len == 0:
            return sqr

        return sqr / vec_len

    @staticmethod
    def get_squares_about(image: np.ndarray, coords: list[tuple[float, float]], size: int):
        """Returns the square around every coordinate, in order."""
        return [ImageManager.get_square_about(image, coord, size) for coord in coords]

    @staticmethod
    def get_square_about(image: np.ndarray, coord: tuple[float, float], size: int):
        """
        Returns the (2 * size + 1)-wide square of pixels centred on coord.

        Squares are clipped at the image border, so they may come back smaller (or empty).

        :param image: the image, indexed [row, column]
        :param coord: (x, y) = (column, row) centre, e.g. a key point position
        :param size: half-size of the square
        """
        x, y = coord

        # gets rid of floating point coordinates in a pixelated image (courtesy of feature recognition)
        x = round(x)
        y = round(y)

        # Calculate the bounds. Clamp at 0 so that negative indices do not wrap to the far side of the image.
        left = max(x - size, 0)
        top = max(y - size, 0)
        right = max(x + size + 1, 0)
        bottom = max(y + size + 1, 0)

        # numpy images are indexed [row, column] = [y, x]
        return image[top:bottom, left:right]

    @staticmethod
    def cv_to_grey(image: np.ndarray) -> np.ndarray:
        """Converts a BGR (OpenCV order) image to a single-channel grey image."""
        return cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)


    # Un-used utility functions (useful for debugging but not while running)
    """
    
    @staticmethod
    def read(file_path:str) -> np.ndarray:
        return cv2.imread(file_path)

    @staticmethod
    def get_time(image):
        with open(image, 'rb') as image_file:
            img = Image(image_file)
            time_str = img.get("datetime_original")
            time = datetime.strptime(time_str, '%Y:%m:%d %H:%M:%S')
        return time
    
    def CVToHist(self, image: np.ndarray):
        grey = self.CVToGrey(image)
        self.GreyCVToHist(grey)
    
    # requires scikit-image which does not exist in astro-pi-replay
    def GreyCVToHist(self, image: np.ndarray):
        hist_data = sci.exposure.histogram(image)
        matplotlib.pyplot.plot(hist_data[0][1:])
        matplotlib.pyplot.show()


    def RGBToCV(self, image: np.ndarray) -> np.ndarray:
        return cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

    def CVToRGB(self, image: np.ndarray) -> np.ndarray:
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    def display_images(
            self,
            images: List[np.ndarray],
            titles: List[str]
    ) -> None:
        \"""
        Display images using OpenCV.

        Args:
            images: List of OpenCV images
            titles: List of titles for the images
        \"""
        try:
            for img, title in zip(images, titles):
                cv2.imshow(title, img)

            cv2.waitKey(0)

            for title in titles:
                cv2.destroyWindow(title)

        except Exception as e:
            self.logger.error(f"OpenCV display error: {str(e)}")
            raise RuntimeError("Could not display given images")

    def display_image(
            self,
            image: np.ndarray,
            title: str
    ) -> None:
        \"""
        Display images using OpenCV.

        Args:
            image: The OpenCV image
            title: The title for the image
        \"""
        try:
            cv2.imshow(title, image)
            cv2.waitKey(0)
            cv2.destroyWindow(title)

        except Exception as e:
            self.logger.error(f"OpenCV display error: {str(e)}")
            raise RuntimeError("Could not display given images")
    """