from __future__ import absolute_import, division, print_function, unicode_literals

import os
import pathlib

import tensorflow as tf
from tensorflow import keras

# Helper libraries
import numpy as np
import matplotlib.pyplot as plt

BATCH_SIZE = 32
IMG_HEIGHT = 120
IMG_WIDTH = 120
AUTOTUNE = tf.data.experimental.AUTOTUNE

DATA_DIR = 'data/6x6/trains'
data_dir = pathlib.Path(DATA_DIR)
CLASS_NAMES = np.array([item.name for item in data_dir.glob('*') if item.name != "LICENSE.txt"])
def decode_img(img):
    # Convert the compressed string to a 3-D uint8 tensor.
    img = tf.image.decode_jpeg(img, channels=3)
    # Use `convert_image_dtype` to convert to floats in the [0, 1] range.
    img = tf.image.convert_image_dtype(img, tf.float32)
    # Resize to the desired size; `tf.image.resize` takes [height, width].
    return tf.image.resize(img, [IMG_HEIGHT, IMG_WIDTH])
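
# Optional sanity check -- a minimal sketch, assuming at least one JPEG
# exists under DATA_DIR (uncomment to run). It decodes a single file and
# prints the resulting shape, which should be (IMG_HEIGHT, IMG_WIDTH, 3):
# sample_path = next(data_dir.glob('*/*'))
# sample_img = decode_img(tf.io.read_file(str(sample_path)))
# print(sample_img.shape)  # expected: (120, 120, 3)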
def get_label(file_path):
    # Convert the path to a list of path components.
    parts = tf.strings.split(file_path, os.path.sep)
    # The second-to-last component is the class directory. Return its index
    # in CLASS_NAMES as an integer label, which is what the
    # `sparse_categorical_crossentropy` loss below expects (the original
    # returned the raw boolean one-hot vector, which that loss cannot use).
    return tf.argmax(tf.cast(parts[-2] == CLASS_NAMES, tf.int32))
def process_path(file_path):
    label = get_label(file_path)
    # Load the raw data from the file as a string.
    img = tf.io.read_file(file_path)
    img = decode_img(img)
    return img, label
def prepare_for_training(ds, cache=True, shuffle_buffer_size=1000):
    # This is a small dataset; load it once and keep it in memory.
    # Use `.cache(filename)` to cache preprocessing work for datasets
    # that don't fit in memory.
    if cache:
        if isinstance(cache, str):
            ds = ds.cache(cache)
        else:
            ds = ds.cache()
    ds = ds.shuffle(buffer_size=shuffle_buffer_size)
    # Repeat forever.
    ds = ds.repeat()
    ds = ds.batch(BATCH_SIZE)
    # `prefetch` lets the dataset fetch batches in the background
    # while the model is training.
    ds = ds.prefetch(buffer_size=AUTOTUNE)
    return ds
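
# A minimal usage sketch of the on-disk caching path mentioned above (the
# cache file name is a hypothetical example, not from the original script):
# train_ds = prepare_for_training(labeled_ds, cache='./image_cache.tfcache')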
list_ds = tf.data.Dataset.list_files(str(data_dir/'*/*'))
labeled_ds = list_ds.map(process_path, num_parallel_calls=AUTOTUNE)
train_ds = prepare_for_training(labeled_ds)

image_batch, label_batch = next(iter(train_ds))
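
# Optional: visualize a few images from the batch using the matplotlib
# import above (otherwise unused). A sketch, not part of the pipeline;
# the 3x3 grid and figure size are arbitrary choices. Uncomment to run:
# plt.figure(figsize=(6, 6))
# for n in range(9):
#     plt.subplot(3, 3, n + 1)
#     plt.imshow(image_batch[n])
#     plt.title(CLASS_NAMES[label_batch[n].numpy()])
#     plt.axis('off')
# plt.show()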
model = keras.Sequential([
    keras.layers.Flatten(),
    keras.layers.Dense(128, activation='relu'),
    # Output one unit per class (the original hard-coded 10, which only
    # works if the dataset happens to have exactly 10 class directories).
    keras.layers.Dense(len(CLASS_NAMES), activation='softmax'),
])

model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

# Note: this fits on a single batch only -- a quick smoke test of the
# pipeline rather than real training.
model.fit(image_batch, label_batch, epochs=10)
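
# To train on the full pipeline instead of one batch, Keras needs an explicit
# `steps_per_epoch`, because `ds.repeat()` makes the dataset infinite. A
# sketch, assuming the file count under DATA_DIR defines an epoch (the
# `image_count` computation is not in the original script):
# import math
# image_count = len(list(data_dir.glob('*/*')))
# steps_per_epoch = math.ceil(image_count / BATCH_SIZE)
# model.fit(train_ds, epochs=10, steps_per_epoch=steps_per_epoch)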