Skip to content

Instantly share code, notes, and snippets.

@Jsevillamol
Last active February 11, 2025 20:59
Show Gist options
  • Save Jsevillamol/0daac5a6001843942f91f2a3daea27a7 to your computer and use it in GitHub Desktop.
Save Jsevillamol/0daac5a6001843942f91f2a3daea27a7 to your computer and use it in GitHub Desktop.
ROI Pooling Layer
import tensorflow as tf
from tensorflow.keras.layers import Layer
class ROIPoolingLayer(Layer):
    """ Implements Region Of Interest Max Pooling
        for channel-last feature maps and relative bounding box coordinates

        # Constructor parameters
            pooled_height, pooled_width (int) --
              specify height and width of layer outputs

        Shape of inputs
            [(batch_size, height, width, n_channels),
             (batch_size, num_rois, 4)]

        Shape of output
            (batch_size, num_rois, pooled_height, pooled_width, n_channels)
    """

    def __init__(self, pooled_height, pooled_width, **kwargs):
        # Initialise the base Layer before attaching our own attributes.
        super(ROIPoolingLayer, self).__init__(**kwargs)
        self.pooled_height = pooled_height
        self.pooled_width = pooled_width

    def get_config(self):
        """ Returns the constructor arguments so the layer can be
            serialized and re-created from its config
        """
        config = super(ROIPoolingLayer, self).get_config()
        config.update({
            'pooled_height': self.pooled_height,
            'pooled_width': self.pooled_width,
        })
        return config

    def compute_output_shape(self, input_shape):
        """ Returns the shape of the ROI Layer output

            # Parameters
                input_shape -- pair (feature_map_shape, rois_shape)
            # Output
                (batch_size, n_rois, pooled_height, pooled_width, n_channels)
        """
        feature_map_shape, rois_shape = input_shape
        # Both inputs must describe the same batch.
        assert feature_map_shape[0] == rois_shape[0]
        batch_size = feature_map_shape[0]
        n_rois = rois_shape[1]
        # Channels are read from the last axis (channel-last layout).
        n_channels = feature_map_shape[3]
        return (batch_size, n_rois, self.pooled_height,
                self.pooled_width, n_channels)

    def call(self, x):
        """ Maps the input tensor of the ROI layer to its output

            # Parameters
                x[0] -- Convolutional feature map tensor,
                        shape (batch_size, height, width, n_channels)
                x[1] -- Tensor of region of interests from candidate
                        bounding boxes, shape (batch_size, num_rois, 4)
                        Each region of interest is defined by four relative
                        coordinates between 0 and 1, read in the order
                        (h_start, w_start, h_end, w_end): entries 0 and 2
                        are scaled by the feature map height (axis 0),
                        entries 1 and 3 by its width (axis 1)

            # Output
                pooled_areas -- Tensor with the pooled region of interest,
                    shape (batch_size, num_rois, pooled_height,
                           pooled_width, n_channels)
        """
        def curried_pool_rois(sample):
            # sample is one (feature_map, rois) pair taken along the batch axis
            return ROIPoolingLayer._pool_rois(sample[0], sample[1],
                                              self.pooled_height,
                                              self.pooled_width)

        # fn_output_signature replaces the deprecated dtype= argument.
        pooled_areas = tf.map_fn(curried_pool_rois, x,
                                 fn_output_signature=tf.float32)
        return pooled_areas

    @staticmethod
    def _pool_rois(feature_map, rois, pooled_height, pooled_width):
        """ Applies ROI pooling for a single image and various ROIs
        """
        def curried_pool_roi(roi):
            return ROIPoolingLayer._pool_roi(feature_map, roi,
                                             pooled_height, pooled_width)

        # fn_output_signature replaces the deprecated dtype= argument.
        pooled_areas = tf.map_fn(curried_pool_roi, rois,
                                 fn_output_signature=tf.float32)
        return pooled_areas

    @staticmethod
    def _pool_roi(feature_map, roi, pooled_height, pooled_width):
        """ Applies ROI pooling to a single image and a single region of interest

            # Parameters
                feature_map -- tensor indexed as (height, width, channels)
                roi -- four relative coordinates
                       (h_start, w_start, h_end, w_end)
                pooled_height, pooled_width (int) -- size of the output grid
            # Output
                Tensor of shape (pooled_height, pooled_width, n_channels)
        """
        # Compute the region of interest.
        # The dynamic shape is used so that feature maps whose static
        # height/width is None (variable-sized inputs) are supported.
        dynamic_shape = tf.shape(feature_map)
        feature_map_height = tf.cast(dynamic_shape[0], tf.float32)
        feature_map_width = tf.cast(dynamic_shape[1], tf.float32)
        roi = tf.cast(roi, tf.float32)

        h_start = tf.cast(feature_map_height * roi[0], 'int32')
        w_start = tf.cast(feature_map_width * roi[1], 'int32')
        h_end = tf.cast(feature_map_height * roi[2], 'int32')
        w_end = tf.cast(feature_map_width * roi[3], 'int32')

        region = feature_map[h_start:h_end, w_start:w_end, :]

        # Divide the region into non overlapping areas
        region_height = h_end - h_start
        region_width = w_end - w_start
        h_step = tf.cast(region_height / pooled_height, 'int32')
        w_step = tf.cast(region_width / pooled_width, 'int32')
        # NOTE: when the region has fewer rows/columns than the pooled grid,
        # h_step / w_step become 0 and some areas below are empty slices;
        # callers should supply ROIs covering at least
        # pooled_height x pooled_width feature map cells.

        # Take the maximum of each area and stack the result.
        # The last row/column of areas absorbs the remainder left over
        # by the integer step so the whole region is covered.
        pooled_rows = []
        for i in range(pooled_height):
            h_lo = i * h_step
            h_hi = (i + 1) * h_step if i + 1 < pooled_height else region_height
            pooled_row = []
            for j in range(pooled_width):
                w_lo = j * w_step
                w_hi = (j + 1) * w_step if j + 1 < pooled_width else region_width
                pooled_row.append(
                    tf.math.reduce_max(region[h_lo:h_hi, w_lo:w_hi, :],
                                       axis=[0, 1]))
            pooled_rows.append(tf.stack(pooled_row))

        pooled_features = tf.stack(pooled_rows)
        return pooled_features
@Mole1424
Copy link

Mole1424 commented Feb 11, 2025

Hi,
When working with outputs of arbitrary size (i.e. tensors whose shape contains at least one dimension of size None), using feature_map.shape[0] on line 79 can cause issues. You can instead use tf.shape(feature_map)[0] to rectify this. The same applies to line 80. Thanks for your work!

(I am new to TensorFlow so this may neither be an optimal nor valid fix but it worked for me)

@Mole1424
Copy link

Also, calling map_fn with dtype= is deprecated and will be removed soon, to be replaced with fn_output_signature. This is a simple drop-in replacement, hence lines 58 and 70 should become tf.map_fn(..., fn_output_signature=tf.float32)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment