Last active
July 10, 2018 12:34
-
-
Save daniel-j-h/18b255ad5c82927413ea71bd830c9c51 to your computer and use it in GitHub Desktop.
Test-time augmentation utility helper for https://github.com/mapbox/robosat
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
''' | |
pip install tqdm pillow mercantile | |
''' | |
'''
Simple image rotation script for test-time augmentation.

Rotates every tile image in a slippy map directory in place, overwriting the original files.

Usage:
- predict on the original slippy map dir with image tiles, save the probabilities in the probs0 directory
- copy the original slippy map dir three times (rotation happens in place), then use this script to rotate the copies by 90, 180, 270, respectively
- predict on the three new slippy map dirs, save the probabilities in the probs1, probs2, probs3 directories
- use this script to rotate the probabilities back by 270, 180, 90, respectively, so they line up with probs0 again
- use `rs masks` on probs0 probs1 probs2 probs3; it already handles weighted-average soft-voting
'''
import os | |
import sys | |
import argparse | |
import concurrent.futures | |
import mercantile | |
from PIL import Image | |
from tqdm import tqdm | |
def tiles_from_slippy_map(root):
    """Yield every tile stored in a slippy map directory.

    The directory is expected to follow the ``root/z/x/y.ext`` layout.
    Entries whose z, x or y component is not an integer, or which are not
    the expected kind of filesystem entry (e.g. ``.DS_Store`` files or
    stray directories), are skipped. They no longer abort the whole run
    with a ValueError or NotADirectoryError.

    Args:
        root: path to the slippy map directory.

    Yields:
        ``(tile, path)`` pairs where ``tile`` is a ``mercantile.Tile`` and
        ``path`` is the file path of that tile's file.
    """

    def as_int(value):
        # Integer value of a path component, or None when it is not numeric.
        try:
            return int(value)
        except ValueError:
            return None

    for z in os.listdir(root):
        z_dir = os.path.join(root, z)
        if as_int(z) is None or not os.path.isdir(z_dir):
            continue

        for x in os.listdir(z_dir):
            x_dir = os.path.join(z_dir, x)
            if as_int(x) is None or not os.path.isdir(x_dir):
                continue

            for name in os.listdir(x_dir):
                path = os.path.join(x_dir, name)
                y = as_int(os.path.splitext(name)[0])
                if y is None or not os.path.isfile(path):
                    continue

                yield mercantile.Tile(x=int(x), y=y, z=int(z)), path
def _rotate_tile(path, method):
    """Rotate the image at ``path`` in place with PIL transpose ``method``.

    Returns None on success, or the exception that prevented the rotation
    (best effort: the caller reports it and carries on with other tiles).
    """
    try:
        # Close the source handle before overwriting the same path; the
        # transposed image is an independent copy of the pixel data.
        with Image.open(path) as image:
            rotated = image.transpose(method)
        rotated.save(path, optimize=True)
    except Exception as e:  # narrower than bare except: Ctrl-C still works
        return e
    return None


def main():
    """Rotate every tile image of a slippy map directory in place.

    Command line arguments:
        degree: rotation angle, one of 90, 180 or 270, mapped to PIL's
            ROTATE_90 / ROTATE_180 / ROTATE_270 transpose constants.
        root: slippy map directory whose tile files are overwritten.
        --threads: number of worker threads (default 1).

    Tiles that fail to rotate are reported on stderr and skipped.
    """
    rotations = {90: Image.ROTATE_90, 180: Image.ROTATE_180, 270: Image.ROTATE_270}

    parser = argparse.ArgumentParser()
    parser.add_argument('degree', type=int, choices=sorted(rotations))
    parser.add_argument('root', type=str)
    parser.add_argument('--threads', type=int, default=1)
    args = parser.parse_args()

    method = rotations[args.degree]
    tiles = list(tiles_from_slippy_map(args.root))

    def worker(tile):
        _, path = tile
        return tile, _rotate_tile(path, method)

    # The progress bar is updated only from this (consuming) thread and is
    # closed by the context manager once all tiles have been handled.
    with tqdm(total=len(tiles), desc='Rotating', unit='tile', ascii=True) as progress:
        with concurrent.futures.ThreadPoolExecutor(args.threads) as executor:
            for tile, error in executor.map(worker, tiles):
                progress.update()
                if error is not None:
                    # tqdm.write keeps the warning from garbling the progress bar.
                    tqdm.write("Warning: {} failed, skipping: {}".format(tile, error), file=sys.stderr)


if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment