#!/usr/bin/env python3

import argparse
import json
import os
import re
import subprocess
import sys
from functools import lru_cache

from packaging.version import InvalidVersion, Version
from packaging.version import parse as parse_version

def _iter_json_documents(text):
    """Yield each JSON value found in *text*, in order.

    ``gh api --paginate`` can emit one JSON body per page written back to
    back (e.g. ``[...][...]``), which a single ``json.loads`` call rejects
    with "Extra data". Decoding document by document accepts both the
    single-page and the multi-page shape.
    """
    decoder = json.JSONDecoder()
    rest = text.lstrip()
    while rest:
        value, consumed = decoder.raw_decode(rest)
        yield value
        # raw_decode does not skip whitespace, so strip it between documents.
        rest = rest[consumed:].lstrip()


def _tags_pointing_at(pages, sha):
    """Return the names of tag refs in *pages* whose target object SHA is *sha*."""
    prefix = 'refs/tags/'
    names = []
    for page in pages:
        # A page is normally a list of ref objects; tolerate a bare object too.
        entries = page if isinstance(page, list) else [page]
        for entry in entries:
            if not isinstance(entry, dict):
                continue
            target = entry.get('object')
            if not isinstance(target, dict) or target.get('sha') != sha:
                continue
            ref = entry.get('ref', '')
            # Keep the full tag name: tags such as "release/v1.0" contain '/',
            # so taking only the last path component would truncate them.
            names.append(ref[len(prefix):] if ref.startswith(prefix) else ref)
    return names


def _pick_preferred_tag(tags):
    """Pick the tag to report from *tags*.

    Tags that parse as PEP 440 versions win, highest version first; when
    two spell the same version (``v4.0`` vs ``v4.0.0``) the longer release
    spelling is preferred. Otherwise the lexicographically greatest
    non-version tag is returned, or ``None`` if *tags* is empty.
    """
    versioned = []
    others = []
    for tag in tags:
        try:
            version = parse_version(tag)
        except InvalidVersion:
            others.append(tag)
            continue
        # packaging < 22 returns LegacyVersion instead of raising for
        # non-PEP 440 strings; treat those as non-version tags.
        if isinstance(version, Version):
            versioned.append((version, tag))
        else:
            others.append(tag)

    if versioned:
        # Compare Version objects directly: they implement PEP 440 ordering,
        # whereas a tuple of their .pre/.dev/.post fields compares None with
        # tuples/ints and raises TypeError (e.g. "v1.0.0" vs "v1.0.0-rc1").
        return max(versioned, key=lambda pair: (pair[0], len(pair[0].release)))[1]
    if others:
        return max(others)
    return None


@lru_cache(maxsize=32)
def get_tag_version_with_gh(repo, sha, debug):
    """Return the preferred tag name pointing at commit *sha* in *repo*.

    Lists the repository's tag refs through the GitHub CLI
    (``gh api repos/<repo>/git/refs/tags --paginate``) and keeps those whose
    target ``object.sha`` equals *sha*. Version-like tags are preferred
    (highest PEP 440 version wins), otherwise the greatest tag name.

    NOTE(review): annotated tags point at a tag object rather than at the
    commit, so their ``object.sha`` differs from the commit SHA and they are
    not matched here; dereferencing them would cost one extra API call each.

    Args:
        repo: ``owner/name`` of the GitHub repository.
        sha: commit SHA to look up.
        debug: when true, print the candidate tags to stderr.

    Returns:
        The selected tag name, or ``None`` if no tag matches or the lookup
        fails. Results (including ``None``) are cached for up to 32 recent
        ``(repo, sha, debug)`` combinations.
    """
    cmd = ['gh', 'api', f'repos/{repo}/git/refs/tags', '--paginate']
    try:
        # Capture stderr explicitly; without it CalledProcessError.stderr is None.
        raw = subprocess.check_output(cmd, stderr=subprocess.PIPE)
    except subprocess.CalledProcessError as e:
        err = (e.stderr or b'').decode('utf-8', errors='replace').strip()
        print(f"An error occurred: {e}. Standard Error Output: {err}", file=sys.stderr)
        return None
    except OSError as e:
        # e.g. the gh executable is not installed or not on PATH.
        print(f"An error occurred while running {cmd[0]!r}: {e}", file=sys.stderr)
        return None

    try:
        pages = list(_iter_json_documents(raw.decode('utf-8')))
    except ValueError as e:  # covers json.JSONDecodeError and UnicodeDecodeError
        print(f"Could not parse tag list for {repo}: {e}", file=sys.stderr)
        return None

    possible_tags = _tags_pointing_at(pages, sha)

    if debug:
        print(f"Debug: Available tags for {repo}@{sha}: {possible_tags}", file=sys.stderr)

    return _pick_preferred_tag(possible_tags)

def update_yaml_comments(yaml_lines, debug, resolve_tag=None):
    """Append or refresh ``# tag=<name>`` comments on SHA-pinned ``uses:`` lines.

    Each line containing ``uses: owner/repo[/path]@<40-hex sha>`` (the
    reference may be single- or double-quoted) is looked up. When a tag is
    found, any existing ``# tag=...`` comment is removed and a fresh
    ``  # tag=<name>`` is appended, keeping the line's original terminator.
    Other lines, and lines whose lookup yields no tag, are returned as-is.

    Args:
        yaml_lines: lines of a workflow file, each keeping its line
            terminator (as produced by ``readlines()``).
        debug: when true, print the selected tag for each match to stderr.
        resolve_tag: optional callable ``(repo, sha, debug) -> str | None``
            used for the lookup; defaults to ``get_tag_version_with_gh``.

    Returns:
        A new list of lines; *yaml_lines* itself is not modified.
    """
    if resolve_tag is None:
        resolve_tag = get_tag_version_with_gh

    # Named groups "repo" (owner/name, without any sub-path) and "sha".
    pattern = re.compile(
        r"""uses:\s*["']?(?P<repo>[^/@\s"']+/[^/@\s"']+)(?:/[^@\s"']*)?@(?P<sha>[a-f0-9]{40})"""
    )
    # Any previously written tag comment, regardless of the spacing before it.
    existing_tag_comment = re.compile(r"[ \t]*# tag=\S+")

    updated_lines = []
    for line in yaml_lines:
        match = pattern.search(line)
        if not match:
            updated_lines.append(line)
            continue

        repo = match.group("repo")
        sha = match.group("sha")
        tag_version = resolve_tag(repo, sha, debug)

        if tag_version:
            if debug:
                print(f"Debug: Selected tag for {repo}@{sha}: {tag_version}", file=sys.stderr)
            # Split off the original terminator (empty for a final line
            # without one) and re-attach it unchanged. Appending os.linesep
            # instead would produce '\r\r\n' on Windows once text-mode
            # writing translates '\n'.
            body = line.rstrip("\r\n")
            ending = line[len(body):]
            body = existing_tag_comment.sub("", body).rstrip()
            line = f"{body}  # tag={tag_version}{ending}"

        updated_lines.append(line)

    return updated_lines

def main():
    """Command-line entry point: add ``# tag=`` comments to workflow files.

    Each file named on the command line is read as UTF-8 and passed through
    ``update_yaml_comments``. With ``-i/--in-place`` the file is rewritten,
    but only when its content changed; otherwise the result is printed to
    stdout under a ``--- <path> ---`` header. Files that cannot be read or
    written are reported on stderr and skipped; the process then exits with
    status 1 after all files have been handled.
    """
    parser = argparse.ArgumentParser(description='Update GitHub Actions YAML file with tag comments.')
    parser.add_argument('files', type=str, nargs='+', help='Paths to the GitHub Actions YAML files.')
    parser.add_argument('-i', '--in-place', action='store_true', help='Modify the files in-place.')
    parser.add_argument('--debug', action='store_true', help='Enable debug output.')
    args = parser.parse_args()

    failed = False
    for file_path in args.files:
        try:
            # Use an explicit encoding rather than the locale default, which
            # is not UTF-8 on every platform.
            with open(file_path, 'r', encoding='utf-8') as f:
                yaml_lines = f.readlines()
        except (OSError, UnicodeDecodeError) as e:
            print(f"Error: cannot read {file_path}: {e}", file=sys.stderr)
            failed = True
            continue

        updated_yaml_lines = update_yaml_comments(yaml_lines, args.debug)

        if args.in_place:
            # Leave untouched files alone so their mtime does not change.
            if updated_yaml_lines == yaml_lines:
                continue
            try:
                with open(file_path, 'w', encoding='utf-8') as f:
                    f.writelines(updated_yaml_lines)
            except OSError as e:
                print(f"Error: cannot write {file_path}: {e}", file=sys.stderr)
                failed = True
        else:
            print(f"--- {file_path} ---")
            print("".join(updated_yaml_lines))

    if failed:
        sys.exit(1)

if __name__ == '__main__':
    # Run the CLI only when this file is executed as a script, not on import.
    main()