Skip to content

Instantly share code, notes, and snippets.

@skywalkerisnull
Created July 1, 2019 04:47
Show Gist options
  • Save skywalkerisnull/cebc1fc2b00fa76da92173d2baa21714 to your computer and use it in GitHub Desktop.
Save skywalkerisnull/cebc1fc2b00fa76da92173d2baa21714 to your computer and use it in GitHub Desktop.
Be able to use the multi-gpu on Keras 2.2.4
"""
Mask R-CNN
Multi-GPU Support for Keras.
Copyright (c) 2017 Matterport, Inc.
Licensed under the MIT License (see LICENSE for details)
Written by Waleed Abdulla
Ideas and small code snippets from these sources:
https://github.com/fchollet/keras/issues/2436
https://medium.com/@kuza55/transparent-multi-gpu-training-on-tensorflow-with-keras-8b0016fd9012
https://github.com/avolkov1/keras_experiments/blob/master/keras_exp/multigpu/
https://github.com/fchollet/keras/blob/master/keras/utils/training_utils.py
"""
import tensorflow as tf
import keras.backend as K
import keras.layers as KL
import keras.models as KM
class ParallelModel(KM.Model):
    """Subclasses the standard Keras Model and adds multi-GPU support.

    It works by creating a copy of the model on each GPU. Then it slices
    the inputs and sends a slice to each copy of the model, and then
    merges the outputs together and applies the loss on the combined
    outputs.
    """

    def __init__(self, keras_model, gpu_count):
        """Class constructor.

        keras_model: The Keras model to parallelize
        gpu_count: Number of GPUs. Must be > 1
        """
        # make_parallel() needs inner_model and gpu_count, and its result is
        # an argument of the base constructor, so both attributes must be
        # set before super().__init__() is called.
        # Keras 2.2.x tracks Layer/Model attributes in Model.__setattr__ and
        # raises a RuntimeError when one is assigned before the base
        # constructor has run. Going through object.__setattr__ skips that
        # check; on older Keras versions it is the same as a plain
        # assignment.
        object.__setattr__(self, "inner_model", keras_model)
        self.gpu_count = gpu_count
        merged_outputs = self.make_parallel()
        super(ParallelModel, self).__init__(inputs=self.inner_model.inputs,
                                            outputs=merged_outputs)

    def __getattribute__(self, attrname):
        """Redirect loading and saving methods to the inner model. That's
        where the weights are stored.

        Any attribute whose name contains 'load' or 'save' is looked up on
        the inner model; everything else is resolved normally.
        """
        if 'load' in attrname or 'save' in attrname:
            return getattr(self.inner_model, attrname)
        return super(ParallelModel, self).__getattribute__(attrname)

    def summary(self, *args, **kwargs):
        """Override summary() to display summaries of both, the wrapper
        and inner models. Arguments are forwarded unchanged to both calls."""
        super(ParallelModel, self).summary(*args, **kwargs)
        self.inner_model.summary(*args, **kwargs)

    def make_parallel(self):
        """Creates a new wrapper model that consists of multiple replicas of
        the original model placed on different GPUs.

        Returns a list with one merged tensor per output of the inner model.
        """
        input_names = self.inner_model.input_names
        input_tensors = self.inner_model.inputs

        # Slice every input once, up front (outside the per-GPU device
        # scopes), so each replica only receives its own slice rather than
        # a copy of the full inputs. Saves on bandwidth and memory.
        input_slices = {name: tf.split(x, self.gpu_count)
                        for name, x in zip(input_names, input_tensors)}

        output_names = self.inner_model.output_names
        # One bucket per model output; each bucket collects that output
        # from every replica.
        outputs_all = [[] for _ in self.inner_model.outputs]

        # Run the model call() on each GPU to place the ops there
        for i in range(self.gpu_count):
            with tf.device('/gpu:%d' % i):
                with tf.name_scope('tower_%d' % i):
                    # Run a slice of inputs through this replica.
                    # The slice is bound as a default argument so that each
                    # Lambda keeps its own tensor instead of looking up the
                    # loop variables whenever the function happens to run.
                    inputs = [
                        KL.Lambda(
                            lambda s, piece=input_slices[name][i]: piece,
                            output_shape=lambda s: (None,) + s[1:])(tensor)
                        for name, tensor in zip(input_names, input_tensors)]
                    # Create the model replica and get the outputs
                    outputs = self.inner_model(inputs)
                    if not isinstance(outputs, list):
                        outputs = [outputs]
                    # Save the outputs for merging back together later
                    for idx, out in enumerate(outputs):
                        outputs_all[idx].append(out)

        # Merge outputs on CPU
        with tf.device('/cpu:0'):
            merged = []
            for outputs, name in zip(outputs_all, output_names):
                # Concatenate or average outputs?
                # Outputs usually have a batch dimension and we concatenate
                # across it. If they don't, then the output is likely a loss
                # or a metric value that gets averaged across the batch.
                # Keras expects losses and metrics to be scalars.
                if K.int_shape(outputs[0]) == ():
                    # Average. The replica count is bound as a default
                    # argument for the same reason as the slices above.
                    m = KL.Lambda(
                        lambda o, count=len(outputs): tf.add_n(o) / count,
                        name=name)(outputs)
                else:
                    # Concatenate
                    m = KL.Concatenate(axis=0, name=name)(outputs)
                merged.append(m)
        return merged
if __name__ == "__main__":
    # Testing code below. It creates a simple model to train on MNIST and
    # tries to run it on 2 GPUs. It saves the graph so it can be viewed
    # in TensorBoard. Run it as:
    #
    # python3 parallel_model.py

    import os
    import numpy as np
    import keras.optimizers
    from keras.datasets import mnist
    from keras.preprocessing.image import ImageDataGenerator

    GPU_COUNT = 2

    # Root directory of the project
    ROOT_DIR = os.path.abspath("../")

    # Directory to save logs and trained model
    MODEL_DIR = os.path.join(ROOT_DIR, "logs")

    def build_model(x_train, num_classes):
        """Build a small convolutional classifier.

        x_train: Training images; only its shape (minus the first axis) is
            used, to size the input layer.
        num_classes: Number of units in the final softmax layer.
        Returns the Keras model.
        """
        # Reset default graph. Keras leaves old ops in the graph,
        # which are ignored for execution but clutter graph
        # visualization in TensorBoard.
        tf.reset_default_graph()

        inputs = KL.Input(shape=x_train.shape[1:], name="input_image")
        x = KL.Conv2D(32, (3, 3), activation='relu', padding="same",
                      name="conv1")(inputs)
        x = KL.Conv2D(64, (3, 3), activation='relu', padding="same",
                      name="conv2")(x)
        x = KL.MaxPooling2D(pool_size=(2, 2), name="pool1")(x)
        x = KL.Flatten(name="flat1")(x)
        x = KL.Dense(128, activation='relu', name="dense1")(x)
        x = KL.Dense(num_classes, activation='softmax', name="dense2")(x)

        # The name must be passed by keyword. With three positional
        # arguments Keras 2.2.x does not recognise the (inputs, outputs)
        # signature and builds an empty subclassed model, which later fails
        # with "'Model' object has no attribute 'input_names'".
        return KM.Model(inputs, x, name="digit_classifier_model")

    # Load MNIST Data
    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    x_train = np.expand_dims(x_train, -1).astype('float32') / 255
    x_test = np.expand_dims(x_test, -1).astype('float32') / 255

    print('x_train shape:', x_train.shape)
    print('x_test shape:', x_test.shape)

    # Build data generator and model
    datagen = ImageDataGenerator()
    model = build_model(x_train, 10)

    # Add multi-GPU support.
    model = ParallelModel(model, GPU_COUNT)

    optimizer = keras.optimizers.SGD(lr=0.01, momentum=0.9, clipnorm=5.0)

    model.compile(loss='sparse_categorical_crossentropy',
                  optimizer=optimizer, metrics=['accuracy'])

    model.summary()

    # Train
    model.fit_generator(
        datagen.flow(x_train, y_train, batch_size=64),
        steps_per_epoch=50, epochs=10, verbose=1,
        validation_data=(x_test, y_test),
        callbacks=[keras.callbacks.TensorBoard(log_dir=MODEL_DIR,
                                               write_graph=True)]
    )
@BenoCharlo
Copy link

Great, it's working now. I have upgraded to Python 3.6 with Keras 2.2.4 and the tensorflow and tensorflow-gpu packages. I have also increased the size of the cluster I'm using (that was the key element).
Thank you for your time

@skywalkerisnull
Copy link
Author

No worries,

But for anyone else that may come across this issue in the future, I was using Python 3.6.8 and this is the Conda and Pip dump:

Conda:

_tflow_select             2.1.0                       gpu  
absl-py                   0.7.1                    py36_0  
affine                    2.2.2                    py36_0  
astor                     0.7.1                    py36_0  
attrs                     19.1.0                   py36_1  
blas                      1.0                         mkl  
bzip2                     1.0.8                he774522_0  
ca-certificates           2019.8.28                     0  
certifi                   2019.9.11                py36_0  
chardet                   3.0.4                    pypi_0    pypi
click                     7.0                      pypi_0    pypi
click-plugins             1.1.1                      py_0  
cligj                     0.5.0                    py36_0  
cloudpickle               1.2.1                      py_0    conda-forge  
cudatoolkit               10.0.130                      0  
cudnn                     7.6.0                cuda10.0_0  
curl                      7.65.2               h2a8f88b_0  
cycler                    0.10.0                     py_1    conda-forge  
cytoolz                   0.9.0.1         py36hfa6e2cd_1001    conda-forge
dask-core                 1.2.2                      py_0    conda-forge  
decorator                 4.4.0                      py_0    conda-forge  
expat                     2.2.5                he025d50_0  
flask                     1.0.3                    pypi_0    pypi
flask-cors                3.0.7                    pypi_0    pypi
freetype                  2.10.0               h5db478b_0    conda-forge  
freexl                    1.0.5                hfa6e2cd_0  
gast                      0.2.2                    py36_0  
geojson                   2.5.0                    pypi_0    pypi
geos                      3.7.1             he025d50_1000    conda-forge  
grpcio                    1.16.1           py36h351948d_1  
h5py                      2.9.0            py36h5e291fa_0  
hdf4                      4.2.13               h712560f_2  
hdf5                      1.10.4               h7ebc959_0  
icc_rt                    2019.0.0             h0cc432a_1  
icu                       58.2                 ha66f8fd_1  
idna                      2.8                      pypi_0    pypi
imageio                   2.5.0                    py36_0    conda-forge  
imantics                  0.1.10                   pypi_0    pypi
imgaug                    0.2.9                      py_0    conda-forge  
intel-openmp              2019.4                      245  
intel-tensorflow          0.0.1                    pypi_0    pypi
itsdangerous              1.1.0                    pypi_0    pypi
jinja2                    2.10.1                   pypi_0    pypi
joblib                    0.13.2                   pypi_0    pypi
jpeg                      9c                hfa6e2cd_1001    conda-forge  
kealib                    1.4.7                h07cbb95_6  
keras-applications        1.0.7                      py_0  
keras-base                2.2.4                    py36_0  
keras-gpu                 2.2.4                         0  
keras-preprocessing       1.0.9                      py_0  
keras2onnx                1.5.0                    pypi_0    pypi
kiwisolver                1.1.0            py36he980bc4_0    conda-forge  
krb5                      1.16.1               hc04afaa_7  
libblas                   3.8.0                     8_mkl    conda-forge  
libboost                  1.67.0               hd9e427e_4  
libcblas                  3.8.0                     8_mkl    conda-forge  
libcurl                   7.65.2               h2a8f88b_0
libgdal                   2.3.3                h10f50ba_0
libiconv                  1.15                 h1df5818_7
libkml                    1.3.0                he5f2a48_4
liblapack                 3.8.0                     8_mkl    conda-forge
liblapacke                3.8.0                     8_mkl    conda-forge
libnetcdf                 4.6.1                h411e497_2
libpng                    1.6.37               h7602738_0    conda-forge
libpq                     11.2                 h3235a2c_0
libprotobuf               3.7.1                h7bd577a_0
libspatialite             4.3.0a              hc36aec2_19
libssh2                   1.8.2                h7a1dbc1_0
libtiff                   4.0.10            h6512ee2_1003    conda-forge
libwebp                   1.0.2                hfa6e2cd_2    conda-forge
libxml2                   2.9.9                h464c3ec_0
lxml                      4.4.1                    pypi_0    pypi
lz4-c                     1.8.3             he025d50_1001    conda-forge
markdown                  3.1                      py36_0
markupsafe                1.1.1                    pypi_0    pypi
matplotlib                3.1.0                    py36_1    conda-forge
matplotlib-base           3.1.0            py36h2852a4a_1    conda-forge
mkl                       2019.4                      245
mkl_fft                   1.0.12           py36h14836fe_0
mkl_random                1.0.2            py36h343c172_0
mock                      3.0.5                    py36_0
networkx                  2.3                        py_0    conda-forge
numpy                     1.16.4           py36h19fb1c0_0
numpy-base                1.16.4           py36hc3f5095_0
olefile                   0.46                       py_0    conda-forge
onnx                      1.5.0                    pypi_0    pypi
onnxconverter-common      1.5.0                    pypi_0    pypi
onnxmltools               1.4.1                    pypi_0    pypi
opencv                    4.1.0            py36hb4945ee_5    conda-forge
opencv-python             4.1.0.25                 pypi_0    pypi
openssl                   1.1.1d               he774522_0
pandas                    0.25.0                   pypi_0    pypi
pcre                      8.43                 ha925a31_0
pillow                    6.0.0                    pypi_0    pypi
pip                       19.1.1                   py36_0
proj4                     5.2.0                ha925a31_1
protobuf                  3.7.1            py36h33f27b4_0
pyparsing                 2.4.0                      py_0    conda-forge
pyqt                      5.9.2            py36h6538335_0    conda-forge
pyreadline                2.1                      py36_1
python                    3.6.8                h9f7ef89_7
python-dateutil           2.8.0                      py_0    conda-forge
pytz                      2019.2                   pypi_0    pypi
pywavelets                1.0.3            py36h452e1ab_1    conda-forge
pyyaml                    5.1              py36he774522_0
qt                        5.9.7                hc6833c9_1    conda-forge
rasterio                  1.0.21           py36h6bd7d87_0
requests                  2.22.0                   pypi_0    pypi
rope                      0.14.0                   pypi_0    pypi
scikit-image              0.15.0                   pypi_0    pypi
scikit-learn              0.21.2                   pypi_0    pypi
scipy                     1.2.1            py36h29ff71c_0
setuptools                41.0.1                   py36_0
shapely                   1.6.4           py36h8921fb9_1004    conda-forge
sip                       4.19.8          py36h6538335_1000    conda-forge
six                       1.12.0                   py36_0
skl2onnx                  1.4.9                    pypi_0    pypi
snuggs                    1.4.6                      py_0
sqlite                    3.28.0               he774522_0
tenacity                  5.0.4                    pypi_0    pypi
tensorboard               1.13.1           py36h33f27b4_0
tensorflow                1.13.1          gpu_py36h9006a92_0
tensorflow-base           1.13.1          gpu_py36h871c8ca_0
tensorflow-estimator      1.13.0                     py_0
tensorflow-gpu            1.13.1               h0d30ee6_0
tensorflow-serving-api    1.13.0                   pypi_0    pypi
termcolor                 1.1.0                    py36_1
tk                        8.6.9             hfa6e2cd_1002    conda-forge
toolz                     0.9.0                      py_1    conda-forge
tornado                   6.0.2            py36hfa6e2cd_0    conda-forge
typing                    3.6.6                    pypi_0    pypi
typing-extensions         3.7.2                    pypi_0    pypi
urllib3                   1.25.3                   pypi_0    pypi
vc                        14.1                 h0510ff6_4
vs2015_runtime            14.16.27012          hf0eaf9b_0
werkzeug                  0.15.2                     py_0
wheel                     0.33.4                   py36_0
wincertstore              0.2              py36h7fe50ca_0
xerces-c                  3.2.2                ha925a31_0
xz                        5.2.4             h2fa13f4_1001    conda-forge
yaml                      0.1.7                hc54c509_2
zlib                      1.2.11               h62dcd97_3
zstd                      1.4.0                hd8a0e53_0    conda-forge

pip: 

Package                Version    
---------------------- -----------
absl-py                0.7.1
affine                 2.2.2
astor                  0.7.1      
attrs                  19.1.0
certifi                2019.9.11
chardet                3.0.4
Click                  7.0
click-plugins          1.1.1
cligj                  0.5.0
cloudpickle            1.2.1
cycler                 0.10.0
cytoolz                0.9.0.1
dask                   1.2.2
decorator              4.4.0
Flask                  1.0.3
Flask-Cors             3.0.7
gast                   0.2.2
geojson                2.5.0
grpcio                 1.16.1
h5py                   2.9.0
idna                   2.8
imageio                2.5.0
imantics               0.1.10
imgaug                 0.2.9
intel-tensorflow       0.0.1
itsdangerous           1.1.0
Jinja2                 2.10.1
joblib                 0.13.2
Keras                  2.2.4
Keras-Applications     1.0.7
Keras-Preprocessing    1.0.9
keras2onnx             1.5.0
kiwisolver             1.1.0
lxml                   4.4.1
Markdown               3.1
MarkupSafe             1.1.1
matplotlib             3.1.0
mkl-fft                1.0.12
mkl-random             1.0.2
mock                   3.0.5
networkx               2.3
numpy                  1.16.4
olefile                0.46
onnx                   1.5.0
onnxconverter-common   1.5.0
onnxmltools            1.4.1
opencv-python          4.1.0.25
pandas                 0.25.0
Pillow                 6.0.0
pip                    19.1.1
protobuf               3.7.1
pyparsing              2.4.0
pyreadline             2.1
python-dateutil        2.8.0
pytz                   2019.2
PyWavelets             1.0.3
PyYAML                 5.1
rasterio               1.0.21
requests               2.22.0
rope                   0.14.0
scikit-image           0.15.0
scikit-learn           0.21.2
scipy                  1.2.1
setuptools             41.0.1
Shapely                1.6.4.post2
six                    1.12.0
skl2onnx               1.4.9
snuggs                 1.4.6
tenacity               5.0.4
tensorboard            1.13.1
tensorflow             1.13.1
tensorflow-estimator   1.13.0
tensorflow-serving-api 1.13.0
termcolor              1.1.0
toolz                  0.9.0
tornado                6.0.2
typing                 3.6.6
typing-extensions      3.7.2
urllib3                1.25.3
Werkzeug               0.15.2
wheel                  0.33.4
wincertstore           0.2

@zcunyi
Copy link

zcunyi commented Apr 19, 2021

Hello, I've implemented this function to use multi-GPU with Keras 2.2.4, but there is still a problem when I try to train the model. I get this error:

AttributeError: 'Model' object has no attribute 'input_names'

I don't really know how to fix this. Do you have any idea? Thanks.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment