Wingbeats Guided Grad-CAM
## based on https://github.com/jacobgil/keras-grad-cam
## and on https://github.com/vense/keras-grad-cam/blob/master/grad-cam.py
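##
## Loads a pretrained 2D CNN for the Wingbeats mosquito dataset, builds
## spectrograms for two recordings and for their mixture, and overlays a
## Guided Grad-CAM map on the mixture for each single-species prediction.
## Written against Keras 2.x on a TensorFlow 1.x backend (it relies on
## tf.get_default_graph, gradient_override_map and K.learning_phase).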
import keras
from keras.layers import Input
from keras.layers import BatchNormalization
from keras.layers import Conv2D
from keras.layers import Activation
from keras.layers import MaxPooling2D
from keras.layers import GlobalAveragePooling2D
from keras.layers import Dropout
from keras.layers import Dense
from keras.preprocessing import image
from keras.layers.core import Lambda
from tensorflow.python.framework import ops
from keras import Model
import keras.backend as K
import tensorflow as tf
import numpy as np
import cv2
from matplotlib import pyplot as plt
import librosa
from scipy.misc import imsave  # deprecated in SciPy 1.0, removed in 1.2; needs an older SciPy
model_name = 'basic_cnn_2d'
best_weights_path = model_name + '.h5'

SR = 8000
N_FFT = 256
HOP_LEN = N_FFT // 6  # integer division: librosa expects an int hop_length

input_shape = (129, 120, 3)
target_names = ['Ae. aegypti', 'Ae. albopictus', 'An. gambiae', 'An. arabiensis', 'C. pipiens', 'C. quinquefasciatus']
def basic_cnn_2d(rows, cols, channels, num_classes):
    inputs = Input(shape=(rows, cols, channels))
    x = BatchNormalization()(inputs)
    x = Conv2D(16, kernel_size=(3, 3), activation='relu', padding='same', name='conv1')(x)
    x = BatchNormalization()(x)
    x = MaxPooling2D((2, 2))(x)
    x = Conv2D(32, kernel_size=(3, 3), activation='relu', padding='same', name='conv2')(x)
    x = BatchNormalization()(x)
    x = MaxPooling2D((2, 2))(x)
    x = Conv2D(64, kernel_size=(3, 3), activation='relu', padding='same', name='conv3')(x)
    x = BatchNormalization()(x)
    x = MaxPooling2D((2, 2))(x)
    x = Conv2D(128, kernel_size=(3, 3), activation='relu', padding='same', name='conv4')(x)
    x = BatchNormalization()(x)
    x = MaxPooling2D((2, 2))(x)
    x = Conv2D(256, kernel_size=(3, 3), activation='relu', padding='same', name='conv5')(x)
    x = BatchNormalization()(x)
    x = MaxPooling2D((2, 2))(x)
    x = GlobalAveragePooling2D()(x)
    x = Dropout(0.5)(x)
    x = Dense(num_classes)(x)
    outputs = Activation('softmax')(x)
    model = Model(inputs, outputs)
    model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
    return model
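
## For the 129x120 inputs used here, 'conv5' sits after four floor-halving
## pools, so it outputs 8x7x256 feature maps; the Grad-CAM heatmap below is
## therefore an 8x7 map upsampled back to the 129x120 input size.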
def target_category_loss(x, category_index, nb_classes):
    # Keep only the output of the target class; zero out all other classes.
    return tf.multiply(x, K.one_hot([category_index], nb_classes))

def target_category_loss_output_shape(input_shape):
    return input_shape

def normalize(x):
    # Normalize a gradient tensor by its RMS to stabilize its scale.
    return x / (K.sqrt(K.mean(K.square(x))) + 1e-5)
def register_gradient():
    # Guided backprop: let gradient flow only where both the upstream
    # gradient and the forward ReLU input are positive.
    if "GuidedBackProp" not in ops._gradient_registry._registry:
        @ops.RegisterGradient("GuidedBackProp")
        def _GuidedBackProp(op, grad):
            dtype = op.inputs[0].dtype
            return grad * tf.cast(grad > 0., dtype) * \
                tf.cast(op.inputs[0] > 0., dtype)
def compile_saliency_function(model, chosen_layer):
    # Gradient of the chosen layer's channel-wise max activation w.r.t. the input.
    input_img = model.input
    layer_dict = dict([(layer.name, layer) for layer in model.layers[1:]])
    layer_output = layer_dict[chosen_layer].output
    max_output = K.max(layer_output, axis=3)
    saliency = K.gradients(K.sum(max_output), input_img)[0]
    return K.function([input_img, K.learning_phase()], [saliency])
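
## The returned function is called later as saliency_fn([img, 0]); the second
## argument is the Keras learning phase (0 = test mode, so BatchNormalization
## and Dropout run in inference mode).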
def modify_backprop(model, name):
    # Rebuild the model under a graph context that replaces the standard
    # ReLU gradient with the registered GuidedBackProp rule.
    g = tf.get_default_graph()
    with g.gradient_override_map({'Relu': name}):
        # Swap Keras ReLU activations for tf.nn.relu so the override applies.
        layer_dict = [layer for layer in model.layers[1:]
                      if hasattr(layer, 'activation')]
        for layer in layer_dict:
            if layer.activation == keras.activations.relu:
                layer.activation = tf.nn.relu
        # Re-instantiate the model (and reload weights) inside the override map.
        new_model = basic_cnn_2d(input_shape[0], input_shape[1], input_shape[2], len(target_names))
        new_model.load_weights(best_weights_path)
    return new_model
def deprocess_image(x):
    # Convert a gradient tensor to a displayable image: standardize,
    # center at 0.5 with std 0.1, then clip and map to [0, 255].
    if np.ndim(x) > 3:
        x = np.squeeze(x)
    x -= x.mean()
    x /= (x.std() + 1e-5)
    x *= 0.1
    x += 0.5
    x = np.clip(x, 0, 1)
    x *= 255
    x = np.clip(x, 0, 255).astype('uint8')
    return x
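
## Grad-CAM recap: with A^k the feature maps of chosen_layer and y_c the
## score of class c, the weights are alpha_k = global average of dy_c/dA^k,
## and the heatmap is ReLU(sum_k alpha_k * A^k), upsampled to the input size.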
def grad_cam(input_model, img, category_index, chosen_layer):
    nb_classes = len(target_names)
    target_layer = lambda x: target_category_loss(x, category_index, nb_classes)
    x = input_model.layers[-1].output
    x = Lambda(target_layer, output_shape=target_category_loss_output_shape)(x)
    model = keras.models.Model(input_model.layers[0].input, x)
    loss = K.sum(model.layers[-1].output)
    # '==' (not 'is'): layer names must be compared by value, not identity.
    conv_output = [l for l in model.layers if l.name == chosen_layer][0].output
    grads = normalize(K.gradients(loss, conv_output)[0])
    gradient_function = K.function([model.layers[0].input], [conv_output, grads])
    output, grads_val = gradient_function([img])
    output, grads_val = output[0, :], grads_val[0, :, :, :]
    # alpha_k: average-pool the gradients over the spatial dimensions.
    weights = np.mean(grads_val, axis=(0, 1))
    cam = np.ones(output.shape[0:2], dtype=np.float32)
    for i, w in enumerate(weights):
        cam += w * output[:, :, i]
    # Upsample to the input size (cv2.resize takes (width, height)), then ReLU.
    cam = cv2.resize(cam, (120, 129))
    cam = np.maximum(cam, 0)
    heatmap = cam / np.max(cam)
    # Overlay the JET-colored heatmap on the rescaled input image.
    img = img * 255.
    img = np.squeeze(img, axis=0)
    cam = cv2.applyColorMap(np.uint8(255 * heatmap), cv2.COLORMAP_JET)
    cam = np.float32(cam) + np.float32(img)
    cam = 255 * cam / np.max(cam)
    return np.uint8(cam), heatmap
####################################################################################

wav1 = 'Wingbeats/Ae. aegypti/D_16_12_13_00_57_15/F161213_010334_189_G_050.wav'
wav2 = 'Wingbeats/An. arabiensis/D_17_01_31_18_23_02/F170131_195705_105_G_050.wav'

data1, rate = librosa.load(wav1, sr=SR)
X = librosa.stft(data1, n_fft=N_FFT, hop_length=HOP_LEN)
D = librosa.amplitude_to_db(np.abs(X))  # magnitude spectrogram in dB
D = np.flipud(D)
imsave(wav1.split('Wingbeats/')[1].split('/')[0] + '.png', D)

data2, rate = librosa.load(wav2, sr=SR)
X = librosa.stft(data2, n_fft=N_FFT, hop_length=HOP_LEN)
D = librosa.amplitude_to_db(np.abs(X))
D = np.flipud(D)
imsave(wav2.split('Wingbeats/')[1].split('/')[0] + '.png', D)

# Mix the two recordings in the time domain, then take the spectrogram.
mix_data = data1 + data2
mix_data = librosa.stft(mix_data, n_fft=N_FFT, hop_length=HOP_LEN)
mix_data = librosa.amplitude_to_db(np.abs(mix_data))
mix_data = np.flipud(mix_data)
imsave('mix.png', mix_data)
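
## With n_fft = 256, librosa.stft returns n_fft//2 + 1 = 129 frequency bins,
## matching the 129-row input the network expects; image.load_img below then
## resizes each saved spectrogram to the full 129x120 target size.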
####################################################################################

img_name1 = 'Ae. aegypti.png'
img_name2 = 'An. arabiensis.png'
img_name3 = 'mix.png'

data1 = image.load_img(img_name1, target_size=(129, 120))
data1 = image.img_to_array(data1)
data1 /= 255.
plt.imshow(data1[:, :, 0])
plt.title(img_name1.split('.png')[0])
plt.axis('off')
plt.show()

data2 = image.load_img(img_name2, target_size=(129, 120))
data2 = image.img_to_array(data2)
data2 /= 255.
plt.imshow(data2[:, :, 0])
plt.title(img_name2.split('.png')[0])
plt.axis('off')
plt.show()

mix_data = image.load_img(img_name3, target_size=(129, 120))
mix_data = image.img_to_array(mix_data)
mix_data /= 255.
plt.imshow(mix_data[:, :, 0])
plt.title(img_name1.split('.png')[0] + ' + ' + img_name2.split('.png')[0])
plt.axis('off')
plt.show()

data1 = np.expand_dims(data1, axis=0)
data2 = np.expand_dims(data2, axis=0)
mix_data = np.expand_dims(mix_data, axis=0)

model = basic_cnn_2d(input_shape[0], input_shape[1], input_shape[2], len(target_names))
model.load_weights(best_weights_path)
chosen_layer = 'conv5'

predictions_1 = model.predict(data1)
predicted_class_1 = np.argmax(predictions_1)
predictions_2 = model.predict(data2)
predicted_class_2 = np.argmax(predictions_2)
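
## Optional sanity check (not in the original gist): print the top prediction
## and its softmax probability for each single-species spectrogram.
for name, preds in [(img_name1, predictions_1), (img_name2, predictions_2)]:
    print('%s -> %s (%.2f%%)' % (name, target_names[np.argmax(preds)], 100 * np.max(preds)))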
####################################################################################

# Guided Grad-CAM for the first predicted class, evaluated on the mixture.
cam1, heatmap1 = grad_cam(model, mix_data, predicted_class_1, chosen_layer)
register_gradient()
guided_model = modify_backprop(model, 'GuidedBackProp')
saliency_fn = compile_saliency_function(guided_model, chosen_layer)
saliency = saliency_fn([mix_data, 0])
gradcam = saliency[0] * heatmap1[..., np.newaxis]
deprocessed_image = deprocess_image(gradcam)
deprocessed_image = cv2.cvtColor(deprocessed_image, cv2.COLOR_BGR2RGB)

mixture_spec = np.squeeze(mix_data, axis=0)
plt.imshow(mixture_spec[:, :, 0])
plt.imshow(deprocessed_image[:, :, 0], alpha=0.70)
plt.title(img_name1.split('.png')[0] + ' Guided Grad-CAM')
plt.axis('off')
plt.show()
####################################################################################

# Guided Grad-CAM for the second predicted class, on the same mixture.
cam2, heatmap2 = grad_cam(model, mix_data, predicted_class_2, chosen_layer)
register_gradient()
guided_model = modify_backprop(model, 'GuidedBackProp')
saliency_fn = compile_saliency_function(guided_model, chosen_layer)
saliency = saliency_fn([mix_data, 0])
gradcam = saliency[0] * heatmap2[..., np.newaxis]
deprocessed_image = deprocess_image(gradcam)
deprocessed_image = cv2.cvtColor(deprocessed_image, cv2.COLOR_BGR2RGB)

mixture_spec = np.squeeze(mix_data, axis=0)
plt.imshow(mixture_spec[:, :, 0])
plt.imshow(deprocessed_image[:, :, 0], alpha=0.70)
plt.title(img_name2.split('.png')[0] + ' Guided Grad-CAM')
plt.axis('off')
plt.show()
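
## The two blocks above repeat the same steps; a small helper (hypothetical,
## not in the original gist) could wrap them:
def show_guided_gradcam(model, img, class_idx, layer_name, title):
    _, heatmap = grad_cam(model, img, class_idx, layer_name)
    register_gradient()
    guided = modify_backprop(model, 'GuidedBackProp')
    sal = compile_saliency_function(guided, layer_name)([img, 0])
    overlay = deprocess_image(sal[0] * heatmap[..., np.newaxis])
    overlay = cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB)
    plt.imshow(np.squeeze(img, axis=0)[:, :, 0])
    plt.imshow(overlay[:, :, 0], alpha=0.70)
    plt.title(title + ' Guided Grad-CAM')
    plt.axis('off')
    plt.show()

## e.g. show_guided_gradcam(model, mix_data, predicted_class_1, chosen_layer,
##                          img_name1.split('.png')[0])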