Last active
December 16, 2022 18:46
-
-
Save lucasdavid/c48d39399b0bbe167cc2f5c056beadf0 to your computer and use it in GitHub Desktop.
Test which option translates segmentation-mask pixels faster
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse | |
import os | |
import time | |
import numpy as np | |
import tensorflow as tf | |
# Command-line options; the __main__ section parses them into the
# module-level `args` namespace that the functions below read.
parser = argparse.ArgumentParser(
    description='Benchmark two strategies for translating Pascal VOC '
                'color-coded segmentation masks into per-pixel labels.')
parser.add_argument(
    '--opt', default='a', choices=['a', 'b'],
    help="decoding strategy: 'a' = lookup table indexed by every packed "
         "24-bit RGB code, 'b' = compare against the palette codes and take "
         "the argmax (default: a)")
parser.add_argument(
    # '--batch-size' is an alias so the flag style matches '--data-dir';
    # the dest stays 'batch_size' (taken from the first long option).
    '--batch_size', '--batch-size', type=int, default=32,
    help='number of masks decoded per call (default: 32)')
parser.add_argument(
    '--iterations', type=int, default=1000,
    help='number of timed calls per decoding function (default: 1000)')
parser.add_argument(
    '--data-dir', default='~/VOCdevkit/VOC2012',
    help='VOC2012 root directory; masks are read from its SegmentationClass '
         'sub-folder (default: ~/VOCdevkit/VOC2012)')
# --- Benchmark ---
def mask_decode_py(mask: tf.Tensor) -> tf.Tensor:
    """Translate color-coded segmentation masks into per-pixel labels.

    Each RGB pixel is first packed into a single integer code
    ``(R * 256 + G) * 256 + B``. How that code becomes a label depends on the
    module-level ``args.opt``:

    * ``'a'``: ``COLOR_MAP`` is a lookup table indexed by the packed code, so
      the label is one ``tf.gather`` (codes absent from the palette map to
      whatever the table holds for them).
    * ``'b'``: ``COLOR_MAP`` holds the packed code of each palette color and
      the label is the position of the matching entry, found with
      ``tf.argmax`` over the equality test. A pixel matching no entry gets an
      unspecified position, since ``tf.argmax`` does not guarantee which
      index wins a tie.

    Both branches return ``int32`` so the two strategies produce outputs of
    the same type and size.

    Args:
        mask: int32 tensor whose last axis holds R, G, B. The tf.function
            wrappers declare it as ``[batch, height, width, 3]``.

    Returns:
        int32 label tensor shaped like ``mask`` with the RGB axis replaced by
        a size-1 axis.
    """
    code = (mask[..., 0] * 256 + mask[..., 1]) * 256 + mask[..., 2]
    if args.opt == 'a':
        labels = tf.gather(COLOR_MAP, code)
    else:
        # Broadcast every pixel code against the palette: (..., num_colors).
        matches = code[..., tf.newaxis] == COLOR_MAP
        # tf.argmax defaults to int64; int32 matches branch 'a'.
        labels = tf.argmax(matches, axis=-1, output_type=tf.int32)
    return labels[..., tf.newaxis]
# Signature shared by both compiled variants: a batch of int32 RGB masks of
# any height and width.
py_fn_input_signature = [
    tf.TensorSpec(shape=[None, None, None, 3], dtype=tf.int32),
]

# Graph-compiled variant (no XLA). The explicit __name__ labels its line in
# the benchmark output.
mask_decode_tf = tf.function(
    mask_decode_py,
    input_signature=py_fn_input_signature,
    reduce_retracing=True,
)
mask_decode_tf.__name__ = 'mask_decode_tf'

# Same function, additionally compiled with XLA (jit_compile=True).
mask_decode_jit = tf.function(
    mask_decode_py,
    input_signature=py_fn_input_signature,
    reduce_retracing=True,
    jit_compile=True,
)
mask_decode_jit.__name__ = 'mask_decode_jit'
def sample_from(masks, size=(512, 512)):
    """Load a random batch of segmentation masks as an int32 RGB tensor.

    Draws ``args.batch_size`` distinct file names from ``masks`` (names
    relative to ``MASKS_DIR``), decodes each PNG as RGB and resizes it to
    ``size`` so the masks can be stacked into one batch.

    Args:
        masks: sequence of mask file names inside ``MASKS_DIR``.
        size: ``(height, width)`` every mask is resized to (default 512x512).

    Returns:
        int32 tensor of shape ``[args.batch_size, height, width, 3]``.

    Raises:
        ValueError: from ``np.random.choice`` when ``masks`` holds fewer than
            ``args.batch_size`` names (sampling is without replacement).
    """
    chosen = np.random.choice(masks, size=args.batch_size, replace=False)
    batch = []
    for name in chosen:
        png = tf.io.read_file(os.path.join(MASKS_DIR, name))
        # channels=3 guarantees the trailing RGB axis required by the
        # decoders' input signature, whatever color type the PNG stores.
        rgb = tf.io.decode_png(png, channels=3)
        # Nearest-neighbour keeps every pixel an exact palette color. The
        # default bilinear filter blends neighbouring colors at region
        # borders into codes that belong to no class, or to the wrong one.
        batch.append(tf.image.resize(rgb, size, method='nearest'))
    return tf.cast(tf.stack(batch), tf.int32)
def benchmark(fn, masks, warmup=3, iterations=None):
    """Time ``fn`` on freshly sampled mask batches and print mean latencies.

    Every call gets a new batch from ``sample_from``. Only the ``fn`` call,
    plus waiting for its result, is timed; batch loading is excluded. The
    first ``warmup`` calls absorb one-off costs such as tracing or
    compilation, and their mean is reported separately from the mean of the
    measured calls.

    Args:
        fn: decoder taking an int32 mask batch; its ``__name__`` labels the
            printed line.
        masks: mask file names forwarded to ``sample_from``.
        warmup: number of warm-up calls (default 3).
        iterations: number of measured calls; defaults to ``args.iterations``.

    Returns:
        The output of the last call made, or ``None`` if no call was made.
    """
    if iterations is None:
        iterations = args.iterations

    def _timed_call():
        # Run fn once on a new batch; return (output, elapsed seconds).
        batch = sample_from(masks)
        start = time.perf_counter()
        out = fn(batch)
        # Device work may still be in flight when fn returns. Fetching a
        # single element blocks until `out` is computed, without copying the
        # whole result back to the host.
        tf.reshape(out, [-1])[:1].numpy()
        return out, time.perf_counter() - start

    def _mean(total, count):
        # Avoid ZeroDivisionError when a phase is disabled (count == 0).
        return total / count if count else float('nan')

    result = None
    warmup_time = 0.0
    for _ in range(warmup):
        result, elapsed = _timed_call()
        warmup_time += elapsed

    total_time = 0.0
    for _ in range(iterations):
        result, elapsed = _timed_call()
        total_time += elapsed

    # Both figures are per-call means (seconds), not totals.
    print(f"{fn.__name__} | "
          f"warmup: {_mean(warmup_time, warmup):7.4f} s | "
          f"avg: {_mean(total_time, iterations):7.4f} s")
    return result
def build_color_map(opt, voc_colormap):
    """Build the decoding data ``mask_decode_py`` uses for option ``opt``.

    Args:
        opt: ``'a'`` or ``'b'``, as accepted by ``--opt``.
        voc_colormap: int array of shape ``(num_colors, 3)`` with RGB palette
            colors; row ``i`` is the color of label ``i``.

    Returns:
        int32 numpy array.
        For ``'a'``: a table of length ``256**3`` mapping every packed code
        ``(R * 256 + G) * 256 + B`` to its label. Colors not in the palette
        map to 0, and the special color [224, 224, 192] maps to 255.
        For ``'b'``: the packed code of every palette color, in palette order.
    """
    colors = np.asarray(voc_colormap, dtype=np.int32)
    codes = (colors[:, 0] * 256 + colors[:, 1]) * 256 + colors[:, 2]
    if opt == 'a':
        # A numpy array converts to a tensor far faster than a 16.7M-element
        # Python list.
        table = np.zeros(256**3, dtype=np.int32)
        for label, code in enumerate(codes):
            table[code] = label
        # There is a special mapping with [224, 224, 192] -> 255
        table[(224 * 256 + 224) * 256 + 192] = 255
        return table
    # Option 'b' compares pixel codes against these entries, so every entry
    # must remain a color code. Overwriting the [224, 224, 192] entry with
    # the label 255 would leave those pixels matching nothing. They decode
    # to that entry's palette position instead (21 for the VOC palette).
    return codes


if __name__ == "__main__":
    args = parser.parse_args()
    MASKS_DIR = os.path.expanduser(os.path.join(args.data_dir, 'SegmentationClass'))
    ALL_MASKS = os.listdir(MASKS_DIR)

    # Row i is the RGB color for label i; the last row is the special
    # [224, 224, 192] color that option 'a' maps to 255.
    VOC_COLORMAP = np.array([
        [0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0], [0, 0, 128],
        [128, 0, 128], [0, 128, 128], [128, 128, 128], [64, 0, 0], [192, 0, 0],
        [64, 128, 0], [192, 128, 0], [64, 0, 128], [192, 0, 128], [64, 128, 128],
        [192, 128, 128], [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0],
        [0, 64, 128], [224, 224, 192]
    ], dtype='int32')

    # Module-level global read by mask_decode_py.
    COLOR_MAP = tf.constant(build_color_map(args.opt, VOC_COLORMAP), dtype=tf.int32)

    print(f'Benchmarking opt {args.opt} iterations:{args.iterations}')
    benchmark(mask_decode_py, ALL_MASKS)
    benchmark(mask_decode_tf, ALL_MASKS)
    benchmark(mask_decode_jit, ALL_MASKS)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment