import glob
import itertools
from pathlib import Path

import torch

from lightglue import LightGlue, SuperPoint, DISK, SIFT, ALIKED, DoGHardNet
from lightglue import viz2d
from lightglue.utils import load_image, rbd

# This script only runs inference, so switch autograd off globally: forward
# passes below will not record computation graphs.
torch.autograd.set_grad_enabled(mode=False)

class ImageMatcher:
    """Match local features between two images with SuperPoint + LightGlue.

    Other extractor/matcher pairs shipped with ``lightglue`` (e.g. DISK,
    ALIKED or SIFT) can be used instead by swapping the extractor class and
    passing the matching ``features`` name to ``LightGlue``.
    """

    def __init__(self, device, max_num_keypoints=2048):
        """Load the extractor and matcher in eval mode onto *device*.

        Args:
            device: torch device (or device string such as ``"cpu"``) that
                the models and the loaded images are moved to.
            max_num_keypoints: forwarded to ``SuperPoint`` as its
                ``max_num_keypoints`` setting (default 2048, the value this
                class previously hard-coded).
        """
        self.device = device
        self.extractor = (
            SuperPoint(max_num_keypoints=max_num_keypoints).eval().to(self.device)
        )
        self.matcher = LightGlue(features="superpoint").eval().to(self.device)

    def compare_images(self, image_path1, image_path2):
        """Match two images, print the match count and plot the results.

        Two figures are drawn with ``viz2d``:
        1. both images with the matched keypoint pairs, annotated with the
           matcher's ``stop`` value ("Stop after N layers");
        2. both images with all keypoints, coloured via ``viz2d.cm_prune``
           from the matcher's ``prune0`` / ``prune1`` outputs.

        Args:
            image_path1: path of the first image, passed to ``load_image``.
            image_path2: path of the second image, passed to ``load_image``.

        Returns:
            The axes returned by ``viz2d.plot_images`` for the first
            (matches) figure.
        """
        # Move each image to the model device exactly once; the extract calls
        # below receive tensors that already live there.
        image0 = load_image(image_path1).to(self.device)
        image1 = load_image(image_path2).to(self.device)

        feats0 = self.extractor.extract(image0)
        feats1 = self.extractor.extract(image1)
        matches01 = self.matcher({"image0": feats0, "image1": feats1})
        # Strip the leading batch dimension from every output dict.
        feats0, feats1, matches01 = [rbd(x) for x in (feats0, feats1, matches01)]

        kpts0 = feats0["keypoints"]
        kpts1 = feats1["keypoints"]
        matches = matches01["matches"]
        # Each row of `matches` pairs an index into kpts0 with one into kpts1.
        m_kpts0 = kpts0[matches[..., 0]]
        m_kpts1 = kpts1[matches[..., 1]]

        print(f"Number of matches; {len(matches)}; {image_path1} ; {image_path2} ")

        # Figure 1: matched keypoint pairs.
        axes = viz2d.plot_images([image0, image1])
        viz2d.plot_matches(m_kpts0, m_kpts1, color="lime", lw=0.2)
        viz2d.add_text(0, f'Stop after {matches01["stop"]} layers', fs=20)

        # Figure 2: every keypoint, coloured from the prune outputs.
        kpc0 = viz2d.cm_prune(matches01["prune0"])
        kpc1 = viz2d.cm_prune(matches01["prune1"])
        viz2d.plot_images([image0, image1])
        viz2d.plot_keypoints([kpts0, kpts1], colors=[kpc0, kpc1], ps=10)

        return axes

if __name__ == "__main__":
    # Only run the (slow, plotting) pairwise comparison when executed as a
    # script or notebook cell, not when this module is imported.
    matcher = ImageMatcher("cpu")

    # glob.glob returns directory-listing order, which is not guaranteed;
    # sort so the comparison order is reproducible.
    image_files = [Path(name) for name in sorted(glob.glob("*.jpg"))]

    # Every ordered pair of distinct images, (a, b) as well as (b, a),
    # the same pairs the previous nested loop produced.
    for one, two in itertools.permutations(image_files, 2):
        print(f"Comparing {one} and {two}")
        matcher.compare_images(one, two)