Skip to content

Instantly share code, notes, and snippets.

@magesh-technovator
Last active May 20, 2021 10:39
Show Gist options
  • Save magesh-technovator/68f117517346a5004c7ee9332bc1b505 to your computer and use it in GitHub Desktop.
segmentation_DataLoader
from torch.utils.data import Dataset, DataLoader
import glob
import os
import numpy as np
import cv2
import torch
from torchvision import transforms, utils
from PIL import Image
class SegDataset(Dataset):
    """Segmentation dataset pairing input images with their masks.

    Expects the directory layout::

        root_dir/
            <imageFolder>/Img 1 ... Img N
            <maskFolder>/Mask 1 ... Mask N

    Images and masks are paired by sorted filename order, so both folders
    must contain the same number of files in corresponding order.
    """

    def __init__(self, root_dir, imageFolder, maskFolder, transform=None,
                 seed=None, fraction=None, subset=None,
                 imagecolormode='rgb', maskcolormode='rgb'):
        """
        Args:
            root_dir (string): Directory containing the image and mask folders.
            imageFolder (string): Name of the folder which contains the images.
            maskFolder (string): Name of the folder which contains the masks.
            transform (callable, optional): Optional transform applied on a sample.
            seed (int, optional): Seed for the shuffle used by the train/test
                split; pass the same seed to both subsets so they are
                complementary.
            fraction (float, optional): Value in (0, 1) giving the fraction of
                data held out for the 'Test' subset. When None (or 0), the
                whole folder is used and no split is made.
            subset (string): 'Train' or 'Test'; required when fraction is given.
            imagecolormode (string): 'rgb' or 'grayscale'.
            maskcolormode (string): 'rgb' or 'grayscale'.
        """
        # cv2.imread flags: 1 == IMREAD_COLOR, 0 == IMREAD_GRAYSCALE.
        self.color_dict = {'rgb': 1, 'grayscale': 0}
        assert(imagecolormode in ['rgb', 'grayscale'])
        assert(maskcolormode in ['rgb', 'grayscale'])
        self.imagecolorflag = self.color_dict[imagecolormode]
        self.maskcolorflag = self.color_dict[maskcolormode]
        self.root_dir = root_dir
        self.transform = transform
        if not fraction:
            # No split requested: use every file in each folder.
            self.image_names = sorted(
                glob.glob(os.path.join(self.root_dir, imageFolder, '*')))
            self.mask_names = sorted(
                glob.glob(os.path.join(self.root_dir, maskFolder, '*')))
        else:
            assert(subset in ['Train', 'Test'])
            self.fraction = fraction
            self.image_list = np.array(
                sorted(glob.glob(os.path.join(self.root_dir, imageFolder, '*'))))
            self.mask_list = np.array(
                sorted(glob.glob(os.path.join(self.root_dir, maskFolder, '*'))))
            if seed:
                # Shuffle images and masks with the SAME permutation so the
                # pairing by index is preserved.
                np.random.seed(seed)
                indices = np.arange(len(self.image_list))
                np.random.shuffle(indices)
                self.image_list = self.image_list[indices]
                self.mask_list = self.mask_list[indices]
            # BUGFIX: compute the split point once from the image count.
            # The original derived it independently from each list's own
            # length, which silently misaligns image/mask pairs whenever the
            # two folders contain different numbers of files.
            split = int(np.ceil(len(self.image_list) * (1 - self.fraction)))
            if subset == 'Train':
                self.image_names = self.image_list[:split]
                self.mask_names = self.mask_list[:split]
            else:
                self.image_names = self.image_list[split:]
                self.mask_names = self.mask_list[split:]

    def __len__(self):
        """Return the number of image/mask pairs in this (sub)set."""
        return len(self.image_names)

    def __getitem__(self, idx):
        """Load the idx-th image/mask pair as a {'image', 'mask'} sample.

        Color arrays are returned channel-first (C, H, W); grayscale
        arrays stay 2-D (H, W).
        """
        img_name = self.image_names[idx]
        if self.imagecolorflag:
            # Color read: cv2 gives (H, W, C); move channels first.
            image = cv2.imread(
                img_name, self.imagecolorflag).transpose(2, 0, 1)
        else:
            image = cv2.imread(img_name, self.imagecolorflag)
        msk_name = self.mask_names[idx]
        if self.maskcolorflag:
            mask = cv2.imread(msk_name, self.maskcolorflag).transpose(2, 0, 1)
        else:
            mask = cv2.imread(msk_name, self.maskcolorflag)
        sample = {'image': image, 'mask': mask}
        if self.transform:
            sample = self.transform(sample)
        return sample
class Resize(object):
    """Resize the image and mask of a sample.

    Channel-first (C, H, W) arrays are temporarily moved to (H, W, C)
    because cv2.resize expects channels last, then moved back after the
    resize. 2-D (grayscale) arrays pass through untransposed.
    """

    def __init__(self, imageresize, maskresize):
        """
        Args:
            imageresize (tuple): Target image size as (width, height),
                the order cv2.resize expects.
            maskresize (tuple): Target mask size as (width, height).
        """
        self.imageresize = imageresize
        self.maskresize = maskresize

    def __call__(self, sample):
        image, mask = sample['image'], sample['mask']
        if len(image.shape) == 3:
            image = image.transpose(1, 2, 0)
        if len(mask.shape) == 3:
            mask = mask.transpose(1, 2, 0)
        # BUGFIX: the third positional argument of cv2.resize is `dst`, not
        # `interpolation`, so the original's positional cv2.INTER_AREA was
        # never applied (resize fell back to the default INTER_LINEAR).
        # Pass the flag by keyword so INTER_AREA is actually used.
        mask = cv2.resize(mask, self.maskresize, interpolation=cv2.INTER_AREA)
        image = cv2.resize(image, self.imageresize, interpolation=cv2.INTER_AREA)
        if len(image.shape) == 3:
            image = image.transpose(2, 0, 1)
        if len(mask.shape) == 3:
            mask = mask.transpose(2, 0, 1)
        return {'image': image,
                'mask': mask}
class ToTensor(object):
    """Convert the ndarrays of a sample into torch Tensors.

    Single-channel 2-D arrays gain a leading channel axis, so every
    tensor leaves this transform with an explicit channel dimension.
    """

    def __call__(self, sample, maskresize=None, imageresize=None):
        img, msk = sample['image'], sample['mask']
        # Promote H x W arrays to 1 x H x W before conversion.
        if msk.ndim == 2:
            msk = msk.reshape((1,) + msk.shape)
        if img.ndim == 2:
            img = img.reshape((1,) + img.shape)
        return {'image': torch.from_numpy(img),
                'mask': torch.from_numpy(msk)}
class Normalize(object):
    '''Scale image and mask tensors from [0, 255] into [0, 1] float32.'''

    def __call__(self, sample):
        scaled = {}
        for key in ('image', 'mask'):
            # Cast to float32 first so the division produces fractions.
            scaled[key] = sample[key].type(torch.FloatTensor) / 255
        return scaled
def get_dataloader_single_folder(data_dir, imageFolder='Images', maskFolder='Masks',
                                 fraction=0.2, batch_size=4, seed=100, num_workers=8):
    """Create training and testing dataloaders from a single folder.

    Args:
        data_dir (string): Root directory containing imageFolder and maskFolder.
        imageFolder (string): Sub-folder with the input images.
        maskFolder (string): Sub-folder with the segmentation masks.
        fraction (float): Fraction of the data held out for the 'Test' split.
        batch_size (int): Samples per batch for both loaders.
        seed (int): Seed for the split shuffle; the same seed is passed to both
            SegDataset instances so the Train/Test subsets are complementary.
            (Was hard-coded to 100; parameterized, default unchanged.)
        num_workers (int): Worker processes per DataLoader.
            (Was hard-coded to 8; parameterized, default unchanged.)

    Returns:
        dict: {'Train': DataLoader, 'Test': DataLoader}
    """
    # Identical deterministic pipeline for both subsets.
    data_transforms = {
        x: transforms.Compose(
            [Resize((256, 256), (256, 256)), ToTensor(), Normalize()])
        for x in ['Train', 'Test']
    }
    image_datasets = {
        x: SegDataset(data_dir, imageFolder=imageFolder, maskFolder=maskFolder,
                      seed=seed, fraction=fraction, subset=x,
                      transform=data_transforms[x])
        for x in ['Train', 'Test']
    }
    dataloaders = {
        x: DataLoader(image_datasets[x], batch_size=batch_size,
                      shuffle=True, num_workers=num_workers)
        for x in ['Train', 'Test']
    }
    return dataloaders
# Create the dataloader
# NOTE(review): `data_dir` and `batchsize` are not defined anywhere in this
# snippet — the surrounding script must supply them before this line runs.
# Also note the folder names here are lowercase ('images'/'masks') while the
# function defaults are capitalized ('Images'/'Masks'); presumably this
# matches the on-disk layout — confirm, since paths are case-sensitive on
# most platforms.
dataloaders = get_dataloader_single_folder(
    data_dir, imageFolder='images', maskFolder='masks', fraction=0.2, batch_size=batchsize)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment