Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save dsdanielpark/f265df42c1f597368250997c22db6a8d to your computer and use it in GitHub Desktop.
Save dsdanielpark/f265df42c1f597368250997c22db6a8d to your computer and use it in GitHub Desktop.
Rate-Limit-Safe Sequential Download of Large Hugging Face Models
"""
Sequentially download all files from a Hugging Face model repository
while preserving default cache behavior and robust retry logic.
"""
import os
import time
import argparse
import logging
from huggingface_hub import HfApi, hf_hub_download

# Root logging setup: emit INFO and above, each record prefixed with its
# timestamp and level name.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
# Module-level logger shared by the helper functions in this file.
logger = logging.getLogger(__name__)

# Environment variables to optimize download behavior
ENV_VARS = {
    'HF_HUB_ENABLE_HF_TRANSFER': '0',      # disable parallel transfer
    'HF_HUB_DOWNLOAD_TIMEOUT': '900',      # 15 minutes timeout
    'HF_HUB_ENABLE_CHUNK_DOWNLOAD': '1',   # enable chunked downloads
    'HF_HUB_DOWNLOAD_CHUNK_SIZE': '100000000',  # 100 MB chunks
    'HF_HUB_DOWNLOAD_RETRY_COUNT': '5'     # maximum retry attempts
}


def setup_environment():
    """Apply the download-tuning defaults from ``ENV_VARS`` to ``os.environ``.

    A variable that is already set in the environment is left untouched, so
    the user can override any default without editing this file, e.g.
    ``HF_HUB_DOWNLOAD_RETRY_COUNT=10 python script.py``. Previously every
    default was written unconditionally, silently discarding such overrides.

    NOTE(review): ``huggingface_hub`` is imported at the top of this module,
    i.e. before this function runs. Settings the library reads at import time
    (possibly ``HF_HUB_DOWNLOAD_TIMEOUT`` / ``HF_HUB_ENABLE_HF_TRANSFER``) may
    therefore not take effect when set here — confirm, or export them before
    starting Python.
    """
    for key, value in ENV_VARS.items():
        # setdefault: an explicitly exported value wins over the built-in default.
        os.environ.setdefault(key, value)
    logger.info("Environment configured for sequential download.")


def list_repository_files(repo_id: str, token: str = None) -> list:
    """Fetch the paths of all files stored in a Hugging Face repository.

    Args:
        repo_id: Repository identifier, e.g. ``"org/model"``.
        token: Optional access token passed to ``HfApi.list_repo_files``.

    Returns:
        The file paths reported by the Hub. If the listing call raises, the
        error is logged and an empty list is returned instead of propagating.
    """
    hub = HfApi()
    try:
        repo_files = hub.list_repo_files(repo_id, token=token)
        logger.info(f"Discovered {len(repo_files)} files in '{repo_id}'.")
    except Exception as exc:
        logger.error(f"Failed to list files for '{repo_id}': {exc}")
        repo_files = []
    return repo_files


def download_file_with_retries(
    repo_id: str,
    filename: str,
    token: str = None
) -> bool:
    """Download one file from a Hugging Face repository, retrying on failure.

    The file goes to the default Hugging Face cache because no ``local_dir``
    is passed. The number of attempts comes from the
    ``HF_HUB_DOWNLOAD_RETRY_COUNT`` environment variable (default 5). An
    unparsable value falls back to 5, and values below 1 are raised to 1, so
    the file is always tried at least once. After failed attempt *k*, the
    function sleeps ``10 * k`` seconds before the next attempt.

    Args:
        repo_id: Repository identifier, e.g. ``"org/model"``.
        filename: Path of the file inside the repository.
        token: Optional access token forwarded to ``hf_hub_download``.

    Returns:
        True as soon as one attempt succeeds, False if every attempt raised.
    """
    raw_count = os.getenv('HF_HUB_DOWNLOAD_RETRY_COUNT', '5')
    try:
        retries = int(raw_count)
    except ValueError:
        logger.warning(
            f"Invalid HF_HUB_DOWNLOAD_RETRY_COUNT={raw_count!r}; using 5 attempts."
        )
        retries = 5
    # Without this clamp, '0' or a negative value would skip the loop and
    # report the file as failed without ever trying to download it.
    retries = max(1, retries)

    for attempt in range(1, retries + 1):
        try:
            logger.info(f"Downloading '{filename}' (Attempt {attempt}/{retries})")
            # `resume_download` and `local_dir_use_symlinks` are deliberately
            # not passed. `local_dir_use_symlinks` only matters when
            # `local_dir` is set, and it is not set here.
            # NOTE(review): recent huggingface_hub releases deprecate both
            # arguments (resuming is automatic there), and releases that
            # remove them would raise TypeError on every attempt. Confirm
            # against the installed version.
            path = hf_hub_download(
                repo_id=repo_id,
                filename=filename,
                token=token,
            )
        except Exception as err:
            # Any failure from hf_hub_download is logged and retried after a
            # linear back-off.
            # NOTE(review): this also retries errors that cannot succeed on a
            # retry (missing file, no access).
            logger.error(f"Error downloading '{filename}': {err}")
            if attempt < retries:
                sleep_time = attempt * 10
                logger.info(f"Retrying in {sleep_time}s...")
                time.sleep(sleep_time)
        else:
            logger.info(f"Successfully downloaded '{filename}' to {path}")
            return True
    logger.error(f"Exhausted retries for '{filename}'.")
    return False


def download_model_sequentially(
    repo_id: str,
    token: str = None,
    *,
    weight_extensions: tuple = (
        '.bin', '.safetensors', '.gguf', '.pt', '.pth',
        '.ckpt', '.h5', '.msgpack', '.onnx',
    ),
) -> list:
    """Download every file of a Hugging Face repository, one file at a time.

    Files whose names do not end in a weight suffix (configs and the like)
    are fetched first, then the weight files. A file that fails all of its
    retries is logged and skipped instead of aborting the run. At the end, a
    summary reports which files, if any, were not downloaded.

    Args:
        repo_id: Repository identifier, e.g. ``"org/model"``.
        token: Optional access token forwarded to the Hub calls.
        weight_extensions: Filename suffixes, matched case-insensitively, that
            mark weight files to be downloaded last. A single string is also
            accepted.

    Returns:
        The filenames that could not be downloaded, in download order. An
        empty list means every listed file was fetched, or that the
        repository listing was empty or failed (a warning is logged then).
    """
    setup_environment()

    files = list_repository_files(repo_id, token)
    if not files:
        logger.warning("No files to download. Exiting.")
        return []

    if isinstance(weight_extensions, str):
        weight_extensions = (weight_extensions,)
    # Lower-case both sides so '.SafeTensors' and '.safetensors' match alike.
    suffixes = tuple(ext.lower() for ext in weight_extensions)

    # Separate config files from weight files; configs are downloaded first.
    config_files = [f for f in files if not f.lower().endswith(suffixes)]
    weight_files = [f for f in files if f.lower().endswith(suffixes)]

    failed = []
    for filename in config_files + weight_files:
        if not download_file_with_retries(repo_id, filename, token):
            logger.warning(f"Skipping '{filename}' after failures.")
            failed.append(filename)

    # Report failures explicitly rather than ending with an unconditional
    # "Completed" message that hides skipped files.
    if failed:
        logger.error(
            f"Finished '{repo_id}' with {len(failed)} of {len(files)} file(s) "
            f"not downloaded: {', '.join(failed)}"
        )
    else:
        logger.info(f"Completed downloading repository '{repo_id}'.")
    return failed


def parse_arguments(argv=None):
    """Parse the command-line arguments for the downloader.

    Args:
        argv: Sequence of argument strings to parse, excluding the program
            name. ``None`` (the default) makes argparse read ``sys.argv``,
            which is what the command-line entry point relies on. Passing a
            list allows programmatic use and testing.

    Returns:
        An ``argparse.Namespace`` with ``model_id`` (required) and ``token``
        (``None`` when the flag is omitted).
    """
    parser = argparse.ArgumentParser(
        description="Sequentially download Hugging Face model repository files"
    )
    parser.add_argument(
        '--model_id',
        type=str,
        required=True,
        help="Repository ID (e.g., 'Qwen/Qwen2-VL-72B-Instruct')"
    )
    parser.add_argument(
        '--token',
        type=str,
        default=None,
        help="Optional Hugging Face API token"
    )
    return parser.parse_args(argv)


def main():
    """Command-line entry point: read the flags, then download the repository."""
    cli = parse_arguments()
    download_model_sequentially(repo_id=cli.model_id, token=cli.token)


# Run the downloader only when executed as a script, not on import.
if __name__ == '__main__':
    main()
    
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment