Last active
December 18, 2018 06:32
-
-
Save tokestermw/cd34fde292ef5e7737e70133e0e379e1 to your computer and use it in GitHub Desktop.
Using the new Dataset API from TensorFlow 1.2.0, return padded and batched tensors from text data where each line is a sentence.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import tensorflow as tf | |
# Guard against TensorFlow builds older than the release named in the message.
# Only the major and minor components are parsed, so a patch component that
# carries a suffix (e.g. "0rc1") cannot break the int() conversion.
_major_version, _minor_version = (
    int(part) for part in tf.__version__.split('-')[0].split('.')[:2])
# Compare as a tuple: testing major and minor independently would reject
# releases such as 2.0, whose minor number is below 2.
assert (_major_version, _minor_version) >= (1, 2), \
    "requires TensorFlow 1.2.0 and above"
# Path of the input text file that the dataset reads line by line.
text_data_path = "./z_sentences.txt"

# Maximum number of token ids kept per sentence; also used as the padded
# width when batching.
MAX_SEQUENCE_LENGTH = 10

# Token -> integer id table. "_OOV" is the id substituted for tokens that are
# not listed here; "_PAD" is presumably the id used for padding positions.
vocab = {
    "_PAD": 0,
    "_OOV": 1,
    "how": 2,
    "are": 3,
    "doing": 4,
    "let": 5,
    "me": 6,
    "know": 7,
}
def tokenize(text):
    """Split a sentence into whitespace-delimited tokens.

    Args:
        text: The sentence, either as text or as UTF-8 encoded ``bytes``.

    Returns:
        A list of text tokens; empty when ``text`` is empty or blank.
    """
    # Under Python 3, tf.py_func passes string tensors in as bytes. Decode
    # them so the tokens compare equal to the text keys of a vocabulary.
    # Undecodable bytes become U+FFFD instead of raising.
    if isinstance(text, bytes):
        text = text.decode("utf-8", errors="replace")
    return text.split()
def vectorize(tokens, vocab, max_length=None):
    """Map tokens to integer ids, truncated to a maximum length.

    Args:
        tokens: Iterable of tokens to look up.
        vocab: Mapping from token to id. It needs an "_OOV" entry, whose id
            stands in for any token missing from the mapping.
        max_length: Maximum number of ids to return. When omitted, the
            module-level ``MAX_SEQUENCE_LENGTH`` is used.

    Returns:
        A list of at most ``max_length`` ids, in token order.
    """
    if max_length is None:
        max_length = MAX_SEQUENCE_LENGTH
    vector = [vocab.get(word, vocab["_OOV"]) for word in tokens]
    return vector[:max_length]
def _featurize_py_func(text):
    """Convert one line of text into a vector of token ids.

    The line is split with ``tokenize`` and mapped to ids with ``vectorize``
    using the module-level ``vocab``.

    Returns:
        A 1-D ``np.int32`` array of the resulting ids.
    """
    ids = vectorize(tokenize(text), vocab)
    return np.array(ids, dtype=np.int32)
def make_dataset(batch_size=32, epoch_size=2):
    """Build the padded, batched dataset of token ids from ``text_data_path``.

    Args:
        batch_size: Number of sentences combined into each padded batch.
        epoch_size: Number of times the batched data is repeated.

    Returns:
        The dataset at the end of the pipeline: lines read from the text
        file, featurized, padded and batched, repeated, then shuffled.
    """
    def _to_ids(text):
        # The featurizer is plain Python, so it is wrapped in tf.py_func to
        # run inside the dataset pipeline.
        return tf.py_func(_featurize_py_func, [text], [tf.int32])

    lines = tf.contrib.data.TextLineDataset([text_data_path])
    featurized = lines.map(_to_ids).skip(0)
    # TODO: doesn't work with padded_shapes=[None]
    batched = featurized.padded_batch(
        batch_size, padded_shapes=[MAX_SEQUENCE_LENGTH])
    repeated = batched.repeat(epoch_size)
    # NOTE(review): shuffle is applied last, i.e. to already-batched and
    # repeated elements rather than to individual sentences -- confirm this
    # ordering is intended.
    return repeated.shuffle(buffer_size=10000)
if __name__ == "__main__":
    dataset = make_dataset()
    iterator = dataset.make_one_shot_iterator()
    next_element = iterator.get_next()

    with tf.Session() as sess:
        # Print up to 1000 elements, stopping early once the iterator
        # reports that the data is exhausted.
        for i in range(1000):
            try:
                # Only the first component of each element is evaluated.
                element = sess.run(next_element[0])
            except tf.errors.OutOfRangeError:
                print("end of data")
                break
            print(i, element)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
how are you ?
let me know .
this is a long sentence , relatively speaking .
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment