Skip to content

Instantly share code, notes, and snippets.

@bollwyvl
Last active January 23, 2026 22:03
Show Gist options
  • Select an option

  • Save bollwyvl/34761e760011a68990a0ab0079f1b661 to your computer and use it in GitHub Desktop.

Select an option

Save bollwyvl/34761e760011a68990a0ab0079f1b661 to your computer and use it in GitHub Desktop.
fat feedstock saver
"""Update a ``recipe.yaml``'s sources and licenses from a checkout with submodules.

Expects a ``recipe.yaml`` with the following structure:

    context:
      version: <the version>
    # if found, sources with the same `target_directory` will preserve patches
    # source:
    #   - patches: []
    about:
      repository: <the github URL>
    extra:
      feedstock-name: <a name>

... and the following will be overwritten:

    source:
      - url: <the tarball with ${{ version }}>
        sha256: <the sha>
        # patches: [<patches>]
      - target_directory: <the target>  # <ref>
        url: <the submodule>
        sha256: <the sha>
    about:
      repository: <the github URL>
      license_file:
        - <the file>
"""
from __future__ import annotations
import sys
from ruamel.yaml import YAML, CommentedMap
from pathlib import Path
from typing import Any
from io import StringIO
from copy import deepcopy
from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
import tomllib
import subprocess
from urllib.request import urlretrieve
from dataclasses import dataclass, field
from logging import getLogger, Logger
from hashlib import sha256
import functools
import re
from collections.abc import Iterator
HERE = Path(__file__).parent
GH_TARBALL = "{repo}/archive/{ref}.tar.gz"
RE_GH_URL = r"https://github.com/(?P<org>.*?)/(?P<repo>.*?)/archive/(?P<ref>.*).tar.gz"
RE_SUBMOD_LINE = r"^(path|url)\s*=\s*(.*)$"
# factories
def _get_yaml_parser() -> YAML:
    """Return a round-trip (``rt``) ruamel YAML instance tuned toward prettier output.

    Existing quoting is kept, block style is always used, no ``---`` document
    marker is emitted, and the line width is large enough that lines are not
    wrapped in practice.
    """
    parser = YAML(typ="rt")
    # 2-space mappings; sequence dashes offset by 2, item content indented 4
    parser.indent(mapping=2, sequence=4, offset=2)
    parser.explicit_start = False
    parser.default_flow_style = False
    parser.preserve_quotes = True
    parser.width = 4096
    return parser
@dataclass
class FatFeedstockSaver:
    """Regenerate a ``recipe.yaml``'s ``source`` list and ``about.license_file``.

    The upstream repository (``about.repository``) is cloned into
    ``work_dir / extra.feedstock-name``, checked out at the tag rendered from
    ``tag_pattern`` and ``context.version``, and its submodules are initialized.
    A GitHub archive tarball is downloaded (and cached under
    ``work_dir / "tarballs"``) for the main repo and each top-level submodule;
    their URLs and sha256 digests become the new ``source`` list. Other keys of
    an existing source with the same ``target_directory`` (e.g. ``patches``)
    are carried over.
    """

    recipe_path: Path
    work_dir: Path
    tag_pattern: str
    recursive: bool
    # ``None`` (what the CLI passes for an ``append`` option never given) is
    # normalized to an empty list in ``__post_init__``
    license_globs: list[str]
    ignore_paths: list[str]
    # lazy/derived
    yaml: YAML = field(default_factory=_get_yaml_parser)
    log: Logger = field(default_factory=lambda: getLogger(__name__))
    # the stripped original recipe text, filled in by ``__post_init__``
    recipe_text: str = ""
    recipe: dict[str, Any] = field(default_factory=dict)
    # ``old:new`` substring replacements applied to submodule tarball URLs
    url_rewrites: list[str] = field(default_factory=list)
    # existing recipe sources keyed by ``target_directory`` (``None`` = main)
    sources_by_target: dict[str | None, CommentedMap] = field(default_factory=dict)

    def __post_init__(self):
        self.license_globs = list(self.license_globs or [])
        self.ignore_paths = list(self.ignore_paths or [])
        self.url_rewrites = list(self.url_rewrites or [])
        self.recipe_text = self.recipe_path.read_text(encoding="utf-8").strip()
        self.recipe = self.yaml.load(self.recipe_text)
        sources = self.recipe.get("source") or []
        # accept a single source mapping as well as a list of them
        if isinstance(sources, dict):
            sources = [sources]
        self.sources_by_target = {src.get("target_directory"): src for src in sources}

    def run(self) -> int:
        """Update the recipe with all the tarballs; return an exit code (always 0)."""
        self.ensure_repo()
        sources = [self.main_tarball_source(), *self.submodules_sources()]
        self.update_recipe(sources)
        return 0

    def ensure_repo(self) -> None:
        """Clone ``repo_url`` if needed, force-checkout ``tag``, and init submodules.

        Submodules are updated with ``--recursive`` when ``self.recursive`` is set.
        """
        if not self.repo_dir.exists():
            self.git("clone", self.repo_url, self.repo_dir)
        git = functools.partial(self.git, cwd=self.repo_dir)
        git("fetch", "origin", self.tag)
        git("checkout", "--force", self.tag)
        recursive = ["--recursive"] if self.recursive else []
        git("submodule", "update", "--init", *recursive)

    def submodules_sources(self) -> Iterator[CommentedMap]:
        """Yield one tarball source per submodule, sorted by path.

        URLs come from ``.gitmodules``; the checked-out commit and its ref come
        from ``git submodule status``. Submodules under any ``ignore_paths``
        prefix are skipped entirely.

        NOTE(review): only top-level submodules are listed here, even when
        ``recursive`` is set; nested ones are checked out (so their licenses are
        collected) but get no source entry — confirm this is intended.
        """
        known = {
            mod["path"]: {"url": mod["url"].removesuffix(".git"), "name": name}
            for name, mod in self.submodule_toml().items()
        }
        ignored = tuple(self.ignore_paths)
        raw_status = self.git("submodule", "status", capture=True, cwd=self.repo_dir)
        found: dict[str, dict[str, Any]] = {}
        for line in raw_status.splitlines():
            if not line.strip():
                continue
            parsed = self._parse_status_line(line)
            if parsed is None:
                self.log.warning("... cannot parse submodule status line %r", line)
                continue
            commit, path, ref = parsed
            if ignored and path.startswith(ignored):
                self.log.warning("... ignoring source for %s", path)
                continue
            if path not in known:
                self.log.warning("... %s is not listed in .gitmodules, skipping", path)
                continue
            found[path] = {**known[path], "commit": commit, "ref": ref}
        yield from [
            self.one_submodule_source(path, sub) for path, sub in sorted(found.items())
        ]

    @staticmethod
    def _parse_status_line(line: str) -> tuple[str, str, str] | None:
        """Split a ``git submodule status`` line into ``(commit, path, ref)``.

        A one-character state flag (space, ``-``, ``+`` or ``U``) may precede the
        commit and is dropped. ``ref`` keeps its surrounding parentheses, and is
        ``""`` when absent (e.g. for an uninitialized submodule). Returns
        ``None`` if the line does not have that shape.
        """
        match = re.match(
            r"^[ +\-U]?(?P<commit>[0-9a-f]+) (?P<path>.+?)(?: (?P<ref>\(.*\)))?$",
            line.rstrip(),
        )
        if match is None:
            return None
        return match["commit"], match["path"], match["ref"] or ""

    def submodule_toml(self) -> dict[str, dict[str, Any]]:
        """Parse ``.gitmodules`` by munging it into TOML.

        Returns a mapping of submodule name to its settings (``path``, ``url``,
        ...), or an empty mapping when the checkout has no ``.gitmodules``.

        Raises:
            tomllib.TOMLDecodeError: if the munged text is still not valid TOML
                (the munged text is logged first).
        """
        gitmodules = self.repo_dir / ".gitmodules"
        if not gitmodules.exists():
            return {}
        raw_mods = gitmodules.read_text(encoding="utf-8")
        # ``[submodule "name"]`` -> ``[submodule."name"]``, a TOML table header
        mods = raw_mods.replace("submodule ", "submodule.")
        # wrap the bare values in TOML literal-string quotes
        mods = re.sub(RE_SUBMOD_LINE, r"\1 = '\2'", mods, flags=re.MULTILINE)
        try:
            return tomllib.loads(mods).get("submodule", {})
        except tomllib.TOMLDecodeError:
            self.log.error("%s", mods)
            raise

    def one_submodule_source(self, path: str, sub: dict[str, Any]) -> CommentedMap:
        """Build the source entry for one submodule, commented with its ref.

        Raises:
            ValueError: if a ``url_rewrites`` entry has no ``:`` separator.
        """
        url = GH_TARBALL.format(repo=sub["url"], ref=sub["commit"])
        for rewrite in self.url_rewrites:
            old, sep, new = rewrite.partition(":")
            if not sep:
                msg = f"url rewrite must look like 'old:new', got {rewrite!r}"
                raise ValueError(msg)
            url = url.replace(old, new)
        src = self.tarball_source(url, target=path)
        if sub.get("ref"):
            src.yaml_add_eol_comment(f"""{sub["ref"]}""", "target_directory")
        return src

    def update_recipe(self, sources: list[CommentedMap]) -> None:
        """Replace ``source`` and ``about.license_file``; write only if the text changed."""
        new_recipe = deepcopy(self.recipe)
        new_recipe["source"] = sources
        new_recipe["about"]["license_file"] = sorted(set(self.all_license_files()))
        buffer = StringIO()
        self.yaml.dump(new_recipe, buffer)
        new_text = buffer.getvalue().strip()
        if new_text == self.recipe_text:
            self.log.warning("recipe did not change")
            return
        self.log.warning("updating recipe")
        self.recipe_path.write_text(new_text + "\n", encoding="utf-8")

    def all_license_files(self) -> Iterator[str]:
        """Yield checkout-relative POSIX paths of files matching ``license_globs``.

        Files under any ``ignore_paths`` prefix are skipped. A file matched by
        several globs is yielded several times; the caller deduplicates.
        """
        ignored = tuple(self.ignore_paths)
        for glob in self.license_globs:
            for path in self.repo_dir.rglob(glob):
                if not path.is_file():
                    continue
                rel = path.relative_to(self.repo_dir).as_posix()
                if ignored and rel.startswith(ignored):
                    self.log.warning("... ignoring license %s", rel)
                    continue
                yield rel

    def main_tarball_source(self) -> CommentedMap:
        """Build the source entry for the main repo's tag tarball.

        The sha256 is computed from the concrete tag's tarball, but the written
        URL uses ``tag_pattern`` rendered with the literal ``${{ version }}``.
        """
        url = GH_TARBALL.format(repo=self.repo_url, ref=f"refs/tags/{self.tag}")
        src = self.tarball_source(url)
        # rebuild the URL so only the ref is templated, not look-alike text
        # elsewhere (e.g. an org or repo name that contains the tag string)
        templated_tag = self.tag_pattern.format(version="${{ version }}")
        src["url"] = GH_TARBALL.format(
            repo=self.repo_url, ref=f"refs/tags/{templated_tag}"
        )
        return src

    def tarball_source(self, url: str, target: str | None = None) -> CommentedMap:
        """Download (once) the GitHub tarball at ``url`` and return a source entry.

        Keys of an existing recipe source with the same ``target_directory`` are
        carried over; ``url`` and ``sha256`` are overwritten.

        Raises:
            ValueError: if ``url`` does not match ``RE_GH_URL``.
        """
        match = re.search(RE_GH_URL, url)
        if match is None:
            msg = f"cannot build predictable filename for {url}"
            raise ValueError(msg)
        fname = "{org}--{repo}--{ref}.tar.gz".format(**match.groupdict())
        # lowercase first: otherwise uppercase letters fall outside the allowed
        # class and are collapsed to ``--``, mangling (and possibly colliding) names
        fname = re.sub(r"[^a-z\d_\-\.]+", "--", fname.lower())
        dest = self.tarball_dir / fname
        if not dest.exists():
            dest.parent.mkdir(parents=True, exist_ok=True)
            self.log.warning(
                "... fetching %s to\n\t%s\n\tas %s", url, fname, target or "."
            )
            # download under a temporary name so an interrupted fetch is never
            # mistaken for a complete cached tarball on a later run
            partial = dest.with_name(f"{dest.name}.part")
            urlretrieve(url, partial)
            partial.replace(dest)
        src = {**self.sources_by_target.get(target, {})}
        if target:
            src.update(target_directory=target)
        src.update(url=url, sha256=sha256(dest.read_bytes()).hexdigest())
        return CommentedMap(src)

    def git(
        self, *args: Any, cwd: Path | None = None, capture: bool = False
    ) -> str | int:
        """Run ``git`` with ``args`` (each converted with ``str``).

        Returns decoded stdout when ``capture`` is true, otherwise the exit code.

        Raises:
            subprocess.CalledProcessError: if git exits non-zero.
        """
        cmd = [str(arg) for arg in ["git", *args]]
        self.log.warning("in %s\n\t>>> %s", cwd, " ".join(cmd))
        if capture:
            return subprocess.check_output(cmd, cwd=cwd, encoding="utf-8")
        return subprocess.check_call(cmd, cwd=cwd)

    @property
    def repo_url(self) -> str:
        """Upstream repository URL from ``about.repository``, without a trailing ``/``."""
        return str(self.recipe["about"]["repository"]).rstrip("/")

    @property
    def version(self) -> str:
        """The recipe's ``context.version``."""
        return self.recipe["context"]["version"]

    @property
    def tag(self) -> str:
        """The upstream tag: ``tag_pattern`` rendered with ``version``."""
        return self.tag_pattern.format(version=self.version)

    @property
    def name(self) -> str:
        """The recipe's ``extra.feedstock-name``."""
        return self.recipe["extra"]["feedstock-name"]

    @property
    def repo_dir(self) -> Path:
        """Where the upstream checkout lives."""
        return self.work_dir / self.name

    @property
    def tarball_dir(self) -> Path:
        """Where downloaded tarballs are cached."""
        return self.work_dir / "tarballs"
if __name__ == "__main__":
    # only the CLI needs this
    from argparse import BooleanOptionalAction

    parser = ArgumentParser(
        usage=__doc__, formatter_class=ArgumentDefaultsHelpFormatter
    )
    parser.add_argument("recipe_path", type=Path, help="path to a recipe.yaml")
    # required: both the checkout and the tarball cache live under this directory
    parser.add_argument(
        "--work-dir", type=Path, required=True, help="path to working files"
    )
    parser.add_argument(
        "--tag-pattern",
        default="v{version}",
        help="python format string for upstream tag",
    )
    # recursive by default; ``--no-recursive`` turns it off
    parser.add_argument(
        "--recursive",
        action=BooleanOptionalAction,
        default=True,
        help="whether to update submodules recursively",
    )
    parser.add_argument(
        "--ignore-path",
        dest="ignore_paths",
        action="append",
        help="paths in a checkout that should not be included as source/license",
    )
    parser.add_argument(
        "--license-glob",
        dest="license_globs",
        action="append",
        default=["LICENSE*", "LICENCE*"],
        help="globs to add to license_file",
    )
    parser.add_argument(
        "--url-rewrite",
        action="append",
        dest="url_rewrites",
        help="replace strings in repos. e.g. old-org/old-repo:new-org/new-repo",
    )
    ffs = FatFeedstockSaver(**vars(parser.parse_args()))
    sys.exit(ffs.run())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment