Created
April 3, 2018 11:21
-
-
Save darrengarvey/ff05fbe28ab2061c101fe64353b467ff to your computer and use it in GitHub Desktop.
Perf test for `tensorflow.data.Dataset.list_files()`
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
from __future__ import print_function | |
'''
Perf test for `tensorflow.data.Dataset.list_files()`.
At least on TF 1.7 all file names are loaded into memory before
training starts. This can be a bottleneck for slower disks, especially
with large datasets.
'''
import os | |
import shutil | |
import tempfile | |
import tensorflow as tf | |
import time | |
from functools import partial | |
# Command-line flags controlling where the test tree is created and its shape.
FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string('dir', '/tmp/', 'Base directory to perf test')
tf.app.flags.DEFINE_integer('width', 10, 'Number of dirs at each level')
tf.app.flags.DEFINE_integer('depth', 2, 'Number of nested dirs')
def timeit(fn, msg, N=0):
    '''
    Call `fn()` once, print how long it took, and return its result.

    Args:
      fn: zero-argument callable to time.
      msg: label printed in front of the timing.
      N: optional iteration count; when non-zero, the per-iteration time
        (`runtime / N`) is appended to the printed message.

    Returns:
      Whatever `fn()` returned.
    '''
    # Imported before the clock starts so the import is not part of the
    # measurement. `default_timer` is the benchmark clock the standard
    # library recommends, unlike `time.time()` which can be adjusted while
    # the measurement is running.
    from timeit import default_timer

    start = default_timer()
    res = fn()
    end = default_timer()
    # Elapsed seconds converted to milliseconds.
    runtime = (end - start) * 1000
    msg = '{}: time: {:.2f} ms'.format(msg, runtime)
    if N:
        msg += ' ({:.2f} ms per iteration)'.format(runtime / N)
    print(msg)
    return res
def load_data():
    '''
    Create a fresh temporary directory tree of empty files to list.

    Under `FLAGS.dir` a new temporary directory is created. It receives
    `FLAGS.width` sub-directories named `0`, `1`, ...; each of those holds a
    chain of `FLAGS.depth - 1` nested directories (`0/1/...`) with a single
    empty `stuff.txt` at the bottom.

    Returns:
      Path of the temporary base directory that was created.
    '''
    if not os.path.exists(FLAGS.dir):
        os.makedirs(FLAGS.dir)
    # Pass the flag as `dir`, not `prefix`: as a prefix, a value without a
    # trailing separator (e.g. `/data`) would create `/dataXXXX` next to the
    # requested directory instead of inside it.
    base = tempfile.mkdtemp(dir=FLAGS.dir)
    print('saving files to dir: {}'.format(base))
    # The nested chain is the same for every top-level dir, so build it once.
    nested = [str(j) for j in range(FLAGS.depth - 1)]
    for i in range(FLAGS.width):
        new_base = os.path.join(base, str(i), *nested)
        # `base` was just created and every `i` is distinct, so this path
        # cannot exist yet.
        os.makedirs(new_base)
        # Create the empty file and close the handle straight away.
        with open(os.path.join(new_base, 'stuff.txt'), 'w'):
            pass
    return base
def prep_data(base):
    '''
    Build the op that yields the next file name listed under `base`.

    The glob pattern has one `**` component per directory level followed by
    `*.txt`, and is handed to `tf.data.Dataset.list_files()`.

    Args:
      base: root directory of the tree created by `load_data()`.

    Returns:
      The `get_next()` op of a one-shot iterator over the listed files.
    '''
    # `load_data()` always writes at least one directory level below `base`,
    # even when `depth` is below 1, so never use fewer than one wildcard
    # (an empty list would also make `os.path.join` raise).
    levels = max(FLAGS.depth, 1)
    wildcards = os.path.join(*(['**'] * levels))
    pattern = '{}/{}/*.txt'.format(base, wildcards)
    dataset = tf.data.Dataset.list_files(pattern)
    return dataset.make_one_shot_iterator().get_next()
def read_data(data, sess, N=1):
    '''
    Run `data` in `sess` N times, discarding the results.

    Args:
      data: fetch passed unchanged to `sess.run()`.
      sess: object exposing a `run(fetch)` method (a `tf.Session` in `main`).
      N: number of times to run the fetch; nothing is run when N <= 0.
    '''
    for _ in range(N):
        sess.run(data)
def main(_):
    '''
    Run the benchmark: create the files, build the dataset, time the reads.

    Prints one timing line per phase: creating the tree, building the
    dataset op, reading the first file name, reading the second, and reading
    the remaining `FLAGS.width - 2` names. The temporary tree is removed
    afterwards, including when one of the phases raises.

    Args:
      _: unused positional argument supplied by `tf.app.run()`.
    '''
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'  # hide some uninteresting logs
    base = timeit(load_data, 'load data')
    try:
        data = timeit(partial(prep_data, base), 'prep data')
        with tf.Session() as sess:
            # The first two reads are timed separately from the rest.
            timeit(partial(read_data, data, sess), 'read first filename')
            timeit(partial(read_data, data, sess), 'read second filename')
            # Two names were consumed above; read whatever is left.
            N = (FLAGS.width) - 2
            timeit(partial(read_data, data, sess, N), 'read {} more filenames'.format(N), N)
    finally:
        # Always remove the generated tree so a failed run does not leave
        # files behind under FLAGS.dir.
        shutil.rmtree(base)
# Script entry point: let TensorFlow parse the flags, then call main().
if __name__ == '__main__':
    tf.app.run(main)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment