Skip to content

Instantly share code, notes, and snippets.

@daniellerch
Created February 16, 2018 20:36
Show Gist options
  • Save daniellerch/41eb92aaf9529d4ca8c8b0f4cb2b4fbd to your computer and use it in GitHub Desktop.
Save daniellerch/41eb92aaf9529d4ca8c8b0f4cb2b4fbd to your computer and use it in GitHub Desktop.
Keras/Finetuning
from keras.preprocessing.image import ImageDataGenerator
from keras.optimizers import SGD
from keras.models import Model
from keras.utils import np_utils
from keras.preprocessing import image
from keras.applications.inception_resnet_v2 import InceptionResNetV2
from keras.applications.inception_resnet_v2 import preprocess_input, decode_predictions
from keras.layers import Dense, GlobalAveragePooling2D
from sklearn.model_selection import train_test_split
import numpy as np
import os
import glob
def load_data(path, pattern, target_size=(224, 224)):
    """Load a one-directory-per-class image dataset, augmented with rotations.

    Every entry directly under ``path`` is treated as a class named after
    its basename. Each file in it matching ``pattern`` is loaded, resized to
    ``target_size``, run through the InceptionResNetV2 ``preprocess_input``
    and appended four times, rotated by 0, 1, 2 and 3 quarter turns
    (``np.rot90``). The four copies of a photo are consecutive in the
    returned arrays.

    Args:
        path: root directory containing one subdirectory per class.
        pattern: glob pattern selecting image files inside each class
            directory, e.g. ``'*.jpg'``.
        target_size: ``(height, width)`` passed to ``image.load_img``;
            defaults to the previously hard-coded ``(224, 224)``.

    Returns:
        ``(x, y, class_names)``: ``x`` is the array of preprocessed images,
        ``y`` the integer class id of each sample, and ``class_names`` maps
        class directory name -> class id. Ids are assigned in sorted
        directory-name order, and only to directories that contain at least
        one matching file.
    """
    class_names = {}
    x = []
    y = []
    # glob() order is filesystem-dependent; sorting makes the
    # name -> id mapping (and sample order) reproducible across runs.
    for class_dir in sorted(glob.glob(os.path.join(path, '*'))):
        class_name = os.path.basename(class_dir)
        for filename in sorted(glob.glob(os.path.join(class_dir, pattern))):
            if class_name not in class_names:
                class_names[class_name] = len(class_names)
            img = image.load_img(filename, target_size=target_size)
            base = preprocess_input(image.img_to_array(img))
            # Rotate the unrotated image by k quarter turns each time, so the
            # copies come out in order 0, 90, 180, 270 degrees.
            for k in range(4):
                x.append(np.rot90(base, k))
                y.append(class_names[class_name])
    return np.array(x), np.array(y), class_names
x, y, class_names = load_data('flower_photos', '*.jpg')
num_classes = len(class_names)

# load_data appends ROTATIONS_PER_IMAGE consecutive rotated copies of every
# photo. Splitting individual samples would put rotations of the same photo
# into both the train and the test set and inflate the reported accuracy, so
# split photo indices instead and expand each photo back to all its copies.
ROTATIONS_PER_IMAGE = 4
n_images = len(x) // ROTATIONS_PER_IMAGE
train_images, test_images = train_test_split(np.arange(n_images), test_size=0.1)
copy_offsets = np.arange(ROTATIONS_PER_IMAGE)
# train_test_split shuffles the photo order; the copies of each photo stay
# adjacent, so a trailing validation_split slice holds (nearly) whole photos.
train_idx = (train_images[:, np.newaxis] * ROTATIONS_PER_IMAGE + copy_offsets).ravel()
test_idx = (test_images[:, np.newaxis] * ROTATIONS_PER_IMAGE + copy_offsets).ravel()
x_train, x_test = x[train_idx], x[test_idx]
y_train, y_test = y[train_idx], y[test_idx]

# One-hot encode the integer labels for categorical_crossentropy.
y_train = np_utils.to_categorical(y_train, num_classes)
y_test = np_utils.to_categorical(y_test, num_classes)
# ImageNet-pretrained InceptionResNetV2 without its classification head.
base_model = InceptionResNetV2(weights='imagenet', include_top=False)
# New head: global average pooling over the backbone's feature map, then a
# softmax classifier with one unit per class.
features = GlobalAveragePooling2D()(base_model.output)
predictions = Dense(num_classes, activation='softmax')(features)
# `inputs`/`outputs` are the Keras 2 keyword names; `input`/`output` were
# legacy spellings accepted only through a deprecation shim.
model = Model(inputs=base_model.input, outputs=predictions)
# Freeze the first LAYERS_TO_FREEZE layers of the combined model; later layers
# (presumably the tail of the backbone plus the new head, assuming the
# backbone has more than 700 layers -- confirm) remain trainable.
LAYERS_TO_FREEZE = 700
for layer in model.layers[:LAYERS_TO_FREEZE]:
    layer.trainable = False
# Compile after setting `trainable` flags so the freezing takes effect.
model.compile(optimizer="adam", loss='categorical_crossentropy', metrics=['accuracy'])
# validation_split holds out the last 10% of x_train/y_train (taken before
# Keras' per-epoch shuffling) to report validation metrics.
model.fit(x_train, y_train, batch_size=128, epochs=1, verbose=1, validation_split=0.1)
# With one metric requested, evaluate returns [loss, accuracy].
score = model.evaluate(x_test, y_test, verbose=0)
# Single-argument call form: prints the same text on Python 2 and Python 3
# (the bare `print` statement is a SyntaxError on Python 3).
print('Testing set accuracy: {}'.format(score[1]))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment