Last active
June 26, 2025 08:06
-
-
Save vTuanpham/a5e3eaf44c6e6d180eb41ecbf1100770 to your computer and use it in GitHub Desktop.
A clean patch script that forces higher-resolution segmentation masks by changing the Proto upsample stride and fixing the validation logic to match. Supports full-res (MASK_RATIO=1), half-res (2), or quarter-res (4).
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# ----------------------------------------------------------------------------- | |
# Hardcoded patch for Ultralytics YOLOv8/YOLOv11 to fix segmentation mask | |
# resolution directly inside the model definition. | |
# | |
# This produces high-resolution masks natively, without a mask_ratio
# post-processing step, and keeps ONNX export post-processing minimal.
# | |
# ✱ Only modifies: Proto module (stride), and validation postprocess logic. | |
# ✱ Ignores CLI config, cfg parser, and argument propagation. | |
# ----------------------------------------------------------------------------- | |
import pathlib, re, shutil, subprocess, sys | |
# ───────────── USER CONFIGURATION ────────────────────────────────────────────
MASK_RATIO = 2  # 1 = full-res, 2 = half-res, 4 = quarter-res
REPO_URL = "https://github.com/ultralytics/ultralytics.git"
REPO_REF = "v8.3.159"  # tag or branch name, passed to `git clone --branch`
VERBOSE = True  # False silences every say() message
# -----------------------------------------------------------------------------
# STRIDE = 8 // MASK_RATIO becomes the kernel size and stride of Proto's
# ConvTranspose2d.
# NOTE(review): the factor 8 presumably reflects Proto consuming the stride-8
# feature map, so an upsample of 8 // MASK_RATIO gives masks at 1/MASK_RATIO
# of the input size. Confirm this against the pinned REPO_REF.
# MASK_RATIO must therefore be a positive int that divides 8. The type and
# positivity checks run before the modulo because:
#   - `8 % 0` would raise ZeroDivisionError instead of this ValueError;
#   - negative ratios pass `8 % r` (8 % -2 == 0) and would write a negative
#     stride into the model;
#   - a float such as 2.0 would write "4.0" as a kernel size.
if not isinstance(MASK_RATIO, int) or MASK_RATIO <= 0 or 8 % MASK_RATIO:
    raise ValueError("MASK_RATIO must divide 8 ⇒ choose 1, 2, or 4")
STRIDE = 8 // MASK_RATIO
REPO_DIR = pathlib.Path("ultralytics").resolve()  # resolved against the CWD at import time
# ────────────── UTILS ─────────────────────────────────────────────────────── | |
def say(msg, icon="✓"):
    """Print a status line of the form ``"<icon> <msg>"``.

    Nothing is printed when the module-level VERBOSE flag is falsy.
    """
    if not VERBOSE:
        return
    line = f"{icon} {msg}"
    print(line)
def sh(cmd):
    """Run an argv-style command list.

    Raises subprocess.CalledProcessError on a non-zero exit status.
    """
    subprocess.run(args=cmd, check=True)
def patch(path: pathlib.Path, pattern: str, replace: str, *,
          flags: int = re.M, required: bool = True) -> None:
    """Apply one regex substitution to *path* in place, idempotently.

    Only the first match of *pattern* is replaced (``count=1``). *replace* is
    used as an ``re.subn`` replacement template. If ``replace.strip()``
    already occurs literally in the file, the file is left untouched. That
    check is a plain substring test, so it only detects replacements without
    backreferences.

    Args:
        path: File to rewrite; read and written as UTF-8.
        pattern: Regular expression locating the text to replace.
        replace: Replacement template for the first match.
        flags: ``re`` flags for the search (default ``re.M``).
        required: If True, a missing file or an unmatched pattern is an error.
            If False, a missing file is only reported and an unmatched
            pattern is silently skipped.

    Raises:
        FileNotFoundError: *path* does not exist and *required* is True.
        RuntimeError: *pattern* matched nothing and *required* is True.
    """
    if not path.exists():
        if required:
            # Fail loudly. Otherwise main() goes on to reinstall an unpatched
            # repo and still prints its "Done!" message.
            raise FileNotFoundError(f"Cannot patch {path}: file is missing")
        say(f"✗ {path} missing", "❌")
        return
    txt = path.read_text("utf-8")
    # Idempotency: skip if the replacement text is already present.
    if replace.strip() in txt:
        say(f"✓ Already patched {path.name}")
        return
    new, n = re.subn(pattern, replace, txt, flags=flags, count=1)
    if n:
        path.write_text(new, "utf-8")
        say(f"✓ Patched {path.name}")
    elif required:
        # A warning alone is invisible when VERBOSE is False, and the script
        # would otherwise report success without having changed anything.
        raise RuntimeError(f"⚠ Pattern not found in {path.name}: {pattern!r}")
# ────────────── MAIN PATCHING LOGIC ───────────────────────────────────────── | |
def main() -> None:
    """Clone Ultralytics, hard-code the mask resolution, and reinstall it.

    Steps (paths are relative to REPO_DIR):
      1. Remove any existing REPO_DIR, then shallow-clone REPO_URL at REPO_REF.
      2. In ``ultralytics/nn/modules/block.py``, rewrite the first
         ``self.upsample = nn.ConvTranspose2d(...)`` to use kernel size and
         stride STRIDE.
      3. In ``ultralytics/models/yolo/segment/val.py``, replace the hard-coded
         ``4`` in ``imgsz = [4 * x for x in proto.shape[2:]]`` with MASK_RATIO.
      4. ``pip install -e`` the patched clone into the running interpreter
         (``--no-deps --force-reinstall``).

    Raises:
        subprocess.CalledProcessError: ``git clone`` or ``pip install`` exited
            with a non-zero status (``sh`` runs with ``check=True``).
    """
    say("🚀 Starting hardcoded mask_ratio patch")
    if REPO_DIR.exists():
        # NOTE(review): this unconditionally deletes ./ultralytics under the
        # current working directory, even if it was not created by this
        # script. Run from a scratch directory.
        say("🧹 Removing previous clone")
        shutil.rmtree(REPO_DIR)
    # REPO_REF is a tag or branch name (`git clone --branch` does not accept
    # a commit SHA), so print it in full. Slicing it to 7 characters as if it
    # were a SHA turned "v8.3.159" into the misleading "v8.3.15".
    say(f"📥 Cloning @ {REPO_REF}")
    sh(["git", "clone", "--branch", REPO_REF, "--depth", "1", REPO_URL, str(REPO_DIR)])

    # ── Files ──────────────────────────────────────────────────────────────
    block_py = REPO_DIR / "ultralytics/nn/modules/block.py"
    val_py = REPO_DIR / "ultralytics/models/yolo/segment/val.py"

    # ── [1] Modify Proto upsample to emit masks at H/MASK_RATIO ───────────
    # Rewrite the first `self.upsample = nn.ConvTranspose2d(...)` to:
    #   self.upsample = nn.ConvTranspose2d(c_, c_, STRIDE, STRIDE, bias=True)
    # NOTE(review): `[^)]*` stops at the first ")", so this assumes the
    # original call has no nested parentheses. The replacement also assumes
    # the enclosing __init__ defines `c_`. Confirm both against REPO_REF.
    patch(block_py,
          r"self\.upsample\s*=\s*nn\.ConvTranspose2d\([^)]*\)",
          rf"self.upsample = nn.ConvTranspose2d(c_, c_, {STRIDE}, {STRIDE}, bias=True)")

    # ── [2] Fix val.py's postprocess scale logic ───────────────────────────
    # Replace:
    #   imgsz = [4 * x for x in proto.shape[2:]]
    # with:
    #   imgsz = [<MASK_RATIO> * x for x in proto.shape[2:]]
    # so the image size inferred from the proto map matches the new ratio.
    patch(val_py,
          r"imgsz\s*=\s*\[4\s*\*\s*x\s*for\s*x\s*in\s*proto\.shape\[2:\]\]",
          rf"imgsz = [{MASK_RATIO} * x for x in proto.shape[2:]]")

    # ── [3] Reinstall patched repo ─────────────────────────────────────────
    say("🔄 Reinstalling Ultralytics in editable mode")
    sh([sys.executable, "-m", "pip", "install", "-e", str(REPO_DIR), "--no-deps", "-q", "--force-reinstall"])
    say(f"🎉 Done! Masks will now be generated at 1/{MASK_RATIO} resolution.")


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment