Last active
May 20, 2021 10:39
-
-
Save magesh-technovator/68f117517346a5004c7ee9332bc1b505 to your computer and use it in GitHub Desktop.
segmentation_DataLoader
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from torch.utils.data import Dataset, DataLoader | |
import glob | |
import os | |
import numpy as np | |
import cv2 | |
import torch | |
from torchvision import transforms, utils | |
from PIL import Image | |
class SegDataset(Dataset):
    """Segmentation dataset that pairs every image file with a mask file.

    Files are discovered with ``glob`` inside two sub-folders of ``root_dir``
    and paired by their position in the sorted listings, so both folders must
    hold the same number of files in matching sort order.
    """

    def __init__(self, root_dir, imageFolder, maskFolder, transform=None, seed=None, fraction=None, subset=None, imagecolormode='rgb', maskcolormode='rgb'):
        """
        Args:
            root_dir (string): Directory holding the image and mask folders,
                laid out as follows.
                root
                --<imageFolder>
                -----Img 1
                -----Img N
                --<maskFolder>
                -----Mask 1
                -----Mask N
            imageFolder (string): Name of the folder which contains the images,
                e.g. 'Images'.
            maskFolder (string): Name of the folder which contains the masks,
                e.g. 'Masks'.
            transform (callable, optional): Optional transform applied to each
                ``{'image': ..., 'mask': ...}`` sample.
            seed: Seed for the train/test shuffle. Any value other than None
                (including 0) reseeds NumPy's global random generator.
            fraction: Float in (0, 1] giving the share of the data that goes
                to the 'Test' subset. A falsy value (None or 0) disables the
                split and exposes every file.
            subset: 'Train' or 'Test'; required when ``fraction`` is set.
            imagecolormode: 'rgb' or 'grayscale'. 'rgb' loads a 3-channel
                colour image (OpenCV returns the channels in BGR order).
            maskcolormode: 'rgb' or 'grayscale', same meaning as above.

        Raises:
            ValueError: On an unknown colour mode or subset, a fraction
                outside (0, 1], or a different number of images and masks.
        """
        # Values double as the OpenCV imread flags (1 = colour, 0 = grayscale).
        self.color_dict = {'rgb': 1, 'grayscale': 0}
        for label, mode in (('imagecolormode', imagecolormode),
                            ('maskcolormode', maskcolormode)):
            if mode not in self.color_dict:
                raise ValueError(
                    "{} must be 'rgb' or 'grayscale', got {!r}".format(label, mode))
        self.imagecolorflag = self.color_dict[imagecolormode]
        self.maskcolorflag = self.color_dict[maskcolormode]
        self.root_dir = root_dir
        self.transform = transform

        image_paths = sorted(
            glob.glob(os.path.join(self.root_dir, imageFolder, '*')))
        mask_paths = sorted(
            glob.glob(os.path.join(self.root_dir, maskFolder, '*')))
        # Pairing is purely positional, so unequal counts would silently
        # associate images with the wrong masks (or fail much later).
        if len(image_paths) != len(mask_paths):
            raise ValueError(
                "Found {} images but {} masks under {!r}".format(
                    len(image_paths), len(mask_paths), root_dir))

        if not fraction:
            self.image_names = image_paths
            self.mask_names = mask_paths
            return

        if subset not in ('Train', 'Test'):
            raise ValueError(
                "subset must be 'Train' or 'Test', got {!r}".format(subset))
        if not 0 < fraction <= 1:
            raise ValueError(
                "fraction must be in (0, 1], got {!r}".format(fraction))

        self.fraction = fraction
        self.image_list = np.array(image_paths)
        self.mask_list = np.array(mask_paths)
        # `is not None` so that seed=0 is honoured instead of being skipped.
        if seed is not None:
            np.random.seed(seed)
            # The same permutation is applied to both lists to keep the pairs.
            indices = np.arange(len(self.image_list))
            np.random.shuffle(indices)
            self.image_list = self.image_list[indices]
            self.mask_list = self.mask_list[indices]

        # First `split` entries form the training set, the rest the test set.
        split = int(np.ceil(len(self.image_list) * (1 - self.fraction)))
        if subset == 'Train':
            self.image_names = self.image_list[:split]
            self.mask_names = self.mask_list[:split]
        else:
            self.image_names = self.image_list[split:]
            self.mask_names = self.mask_list[split:]

    def __len__(self):
        """Return the number of image/mask pairs in this (sub)set."""
        return len(self.image_names)

    def __getitem__(self, idx):
        """Load pair ``idx`` and return ``{'image': ..., 'mask': ...}``.

        Colour arrays are returned channel-first (C, H, W); grayscale arrays
        stay two-dimensional (H, W). The optional transform is applied last.

        Raises:
            IOError: If OpenCV cannot read one of the two files.
        """
        sample = {
            'image': self._read(self.image_names[idx], self.imagecolorflag),
            'mask': self._read(self.mask_names[idx], self.maskcolorflag),
        }
        if self.transform:
            sample = self.transform(sample)
        return sample

    @staticmethod
    def _read(path, colorflag):
        """Read one file with OpenCV, moving channels first for colour data."""
        array = cv2.imread(str(path), colorflag)
        # imread signals failure by returning None rather than raising.
        if array is None:
            raise IOError("Could not read image file {!r}".format(str(path)))
        if colorflag:
            array = array.transpose(2, 0, 1)
        return array
class Resize(object):
    """Resize the image and the mask of a sample to fixed target sizes."""

    def __init__(self, imageresize, maskresize):
        """
        Args:
            imageresize (tuple): Target size for the image, handed to
                ``cv2.resize`` as ``dsize`` (i.e. ``(width, height)``).
            maskresize (tuple): Target size for the mask, same convention.
        """
        self.imageresize = imageresize
        self.maskresize = maskresize

    def __call__(self, sample):
        """Return a new sample dict holding the resized image and mask.

        Three-dimensional inputs are treated as channel-first (C, H, W) and
        are returned channel-first; two-dimensional inputs stay 2-D.
        """
        return {'image': self._resize(sample['image'], self.imageresize),
                'mask': self._resize(sample['mask'], self.maskresize)}

    @staticmethod
    def _resize(array, size):
        """Resize one array with area interpolation, preserving its layout."""
        # OpenCV expects channels last, so move them for the call.
        if len(array.shape) == 3:
            array = array.transpose(1, 2, 0)
        # The interpolation flag must be passed by keyword: the third
        # positional parameter of cv2.resize is `dst`, not `interpolation`.
        array = cv2.resize(array, size, interpolation=cv2.INTER_AREA)
        # Re-check the rank instead of reusing the earlier test: OpenCV
        # returns a 2-D array when the input had a single channel.
        if len(array.shape) == 3:
            array = array.transpose(2, 0, 1)
        return array
class ToTensor(object):
    """Turn the NumPy arrays of a sample into torch tensors."""

    def __call__(self, sample, maskresize=None, imageresize=None):
        """Return ``{'image': Tensor, 'mask': Tensor}`` built from ``sample``.

        Two-dimensional arrays gain a leading axis of length one, so every
        returned tensor has at least three dimensions. ``maskresize`` and
        ``imageresize`` are accepted but not used.
        """
        converted = {}
        for key in ('image', 'mask'):
            array = sample[key]
            if len(array.shape) == 2:
                # (H, W) -> (1, H, W)
                array = array[None, ...]
            converted[key] = torch.from_numpy(array)
        return converted
class Normalize(object):
    '''Cast the image and mask tensors to float and divide both by 255.'''

    def __call__(self, sample):
        '''Return a new dict with the scaled 'image' and 'mask' tensors.'''
        scaled = {}
        for key in ('image', 'mask'):
            as_float = sample[key].type(torch.FloatTensor)
            scaled[key] = as_float / 255
        return scaled
def get_dataloader_single_folder(data_dir, imageFolder='Images', maskFolder='Masks', fraction=0.2, batch_size=4, num_workers=8, seed=100):
    """
    Create training and testing dataloaders from a single folder.

    Args:
        data_dir (string): Root directory holding the image and mask folders.
        imageFolder (string): Name of the image folder inside ``data_dir``.
        maskFolder (string): Name of the mask folder inside ``data_dir``.
        fraction (float): Share of the data assigned to the 'Test' subset.
            A falsy value (0 or None) disables the split in ``SegDataset``,
            so both loaders then cover the whole dataset.
        batch_size (int): Batch size used by both loaders.
        num_workers (int): Worker processes per loader. Use 0 to load in the
            main process.
        seed: Shuffle seed handed to both datasets. It has to be the same
            value for 'Train' and 'Test' so the two subsets do not overlap;
            passing None makes each dataset shuffle independently.

    Returns:
        dict: ``{'Train': DataLoader, 'Test': DataLoader}``; both loaders
        shuffle and resize every sample to 256x256.
    """
    dataloaders = {}
    for subset in ('Train', 'Test'):
        # Both subsets use the same pipeline: resize, to tensor, scale by 255.
        pipeline = transforms.Compose(
            [Resize((256, 256), (256, 256)), ToTensor(), Normalize()])
        dataset = SegDataset(data_dir, imageFolder=imageFolder,
                             maskFolder=maskFolder, seed=seed,
                             fraction=fraction, subset=subset,
                             transform=pipeline)
        dataloaders[subset] = DataLoader(dataset, batch_size=batch_size,
                                         shuffle=True, num_workers=num_workers)
    return dataloaders
# Build the 'Train' and 'Test' dataloaders with an 80/20 split.
# `data_dir` and `batchsize` are not defined in this snippet and must be
# provided by the surrounding script before this line runs.
dataloaders = get_dataloader_single_folder(
    data_dir,
    imageFolder='images',
    maskFolder='masks',
    fraction=0.2,
    batch_size=batchsize,
)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment