import json
from datetime import datetime, date
from pathlib import Path
from typing import Dict, Any, Union
from urllib.parse import unquote

import yaml

def load_spec(file_path: str) -> Dict[str, Any]:
    """Load an OpenAPI spec from a YAML or JSON file.

    The parser is chosen from the file extension (case-insensitively):
    ``.yaml`` / ``.yml`` files are parsed as YAML, anything else as JSON.

    Args:
        file_path: Path to the specification file.

    Returns:
        The parsed document as returned by the YAML or JSON parser.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-YAML file is not valid JSON.
        yaml.YAMLError: If a YAML file cannot be parsed.
    """
    path = Path(file_path)
    # Compare the suffix case-insensitively so "SPEC.YAML" is not fed to the
    # JSON parser by mistake.
    is_yaml = path.suffix.lower() in ('.yaml', '.yml')
    # Decode explicitly instead of relying on the platform default encoding;
    # "utf-8-sig" reads plain UTF-8 and also drops a leading byte-order mark,
    # which json.load() would otherwise reject.
    with open(path, 'r', encoding='utf-8-sig') as f:
        if is_yaml:
            return yaml.safe_load(f)
        return json.load(f)

def resolve_ref(ref: str, spec: Dict[str, Any]) -> Dict[str, Any]:
    """Resolve a local ``$ref`` (a JSON Pointer URI fragment) in the spec.

    Each path segment is percent-decoded and then JSON-Pointer-unescaped
    (``~1`` -> ``/``, ``~0`` -> ``~``), so references such as
    ``#/paths/~1users~1{id}/get`` reach the ``/users/{id}`` path item.
    Segments may address mapping keys or list indices.

    Args:
        ref: The reference string; must start with ``#/``.
        spec: The document the reference is resolved against.

    Returns:
        The value the reference points to.

    Raises:
        ValueError: If the reference is not local or cannot be resolved.
    """
    if not ref.startswith('#/'):
        raise ValueError(f"Only local references are supported: {ref}")

    current = spec
    # Slice off the "#/" prefix; str.lstrip('#/') would strip *any* run of
    # '#' and '/' characters and so mangle segments that start with them.
    for raw_part in ref[2:].split('/'):
        # Order matters: "~1" must be decoded before "~0" so that "~01"
        # yields "~1" rather than "/".
        part = unquote(raw_part).replace('~1', '/').replace('~0', '~')
        is_index = part.isascii() and part.isdigit()

        if isinstance(current, dict):
            if part in current:
                current = current[part]
                continue
            # YAML parses unquoted keys such as `200:` as integers, so fall
            # back to the integer form of a numeric segment.
            if is_index and int(part) in current:
                current = current[int(part)]
                continue
        elif isinstance(current, list):
            if is_index and int(part) < len(current):
                current = current[int(part)]
                continue

        # Reached for missing keys, out-of-range indices, and attempts to
        # descend into a scalar value.
        raise ValueError(f"Reference {ref} cannot be resolved. Part '{part}' not found.")

    return current

def resolve_refs_in_object(
    obj: Union[Dict[str, Any], list],
    spec: Dict[str, Any],
    _active_refs: frozenset = frozenset(),
) -> Union[Dict[str, Any], list]:
    """Recursively resolve all $ref references in an object.

    A mapping that contains a string ``$ref`` is replaced entirely by the
    (recursively resolved) target of that reference; any sibling keys are
    dropped. Other mappings and lists are rebuilt with their members
    resolved; every other value is returned unchanged.

    Circular references cannot be expanded to a finite document, so a
    ``$ref`` that points back at a reference currently being expanded is
    kept as ``{'$ref': ...}`` instead of recursing forever.

    Args:
        obj: The value to process (mapping, list, or scalar).
        spec: The full document that references are resolved against.
        _active_refs: Internal bookkeeping: the references being expanded
            on the current path. Callers should not pass this.

    Returns:
        A new structure with references inlined.

    Raises:
        ValueError: Propagated from ``resolve_ref`` for unresolvable or
            non-local references.
    """
    if isinstance(obj, dict):
        ref = obj.get('$ref')
        # Only a string can be a reference; a non-string value under a key
        # that happens to be named "$ref" is treated as ordinary data.
        if isinstance(ref, str):
            if ref in _active_refs:
                # Cycle: leave the reference in place rather than recursing.
                return {'$ref': ref}
            target = resolve_ref(ref, spec)
            # Keep resolving: the target may itself contain references.
            return resolve_refs_in_object(target, spec, _active_refs | {ref})
        return {
            key: resolve_refs_in_object(value, spec, _active_refs)
            for key, value in obj.items()
        }
    if isinstance(obj, list):
        return [resolve_refs_in_object(item, spec, _active_refs) for item in obj]
    return obj

def resolve_all_refs(spec: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of the specification with every reference resolved.

    Each top-level field is processed independently, always resolving
    against the original, unmodified ``spec``.
    """
    return {
        field: resolve_refs_in_object(content, spec)
        for field, content in spec.items()
    }

class NoAliasDumper(yaml.SafeDumper):
    """SafeDumper variant that writes no anchors or aliases."""

    def ignore_aliases(self, data):
        """Report every value as alias-exempt, whatever ``data`` is."""
        return True

class OpenAPIJSONEncoder(json.JSONEncoder):
    """JSON encoder that serializes date and datetime values as ISO 8601 strings."""

    def default(self, obj):
        """Convert ``obj`` to a JSON-serializable value.

        Dates and datetimes become their ``isoformat()`` string; every other
        type is handed to the base class, which raises ``TypeError``.
        """
        if not isinstance(obj, (datetime, date)):
            return super().default(obj)
        return obj.isoformat()

def main():
    """Command-line entry point: load a spec, inline its references, write it out.

    Takes the input path, the output path and an optional ``--format``
    (``yaml`` or ``json``, defaulting to ``yaml``) from the command line.
    """
    import argparse

    cli = argparse.ArgumentParser(description='Resolve all references in an OpenAPI specification')
    cli.add_argument('input_file', help='Path to input OpenAPI spec file (YAML or JSON)')
    cli.add_argument('output_file', help='Path to output the resolved spec')
    cli.add_argument(
        '--format',
        choices=['yaml', 'json'],
        default='yaml',
        help='Output format (default: yaml)',
    )
    options = cli.parse_args()

    # Read the input and inline every reference before touching the output file.
    resolved_spec = resolve_all_refs(load_spec(options.input_file))

    with open(options.output_file, 'w') as out:
        if options.format == 'json':
            json.dump(resolved_spec, out, indent=2, cls=OpenAPIJSONEncoder)
        else:
            # argparse restricts --format to 'yaml' or 'json', so this is the YAML case.
            yaml.dump(resolved_spec, out, sort_keys=False, Dumper=NoAliasDumper)

    print(f"Successfully resolved all references and saved to {options.output_file}")

# Run the command-line interface only when this file is executed as a script.
if __name__ == '__main__':
    main()