Created
December 23, 2022 10:20
-
-
Save AlessandroMondin/9c6ca94b57ab5f5b0af41d5e59477d9d to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class Training_Dataset(Dataset):
    """COCO 2017 dataset constructed using the PyTorch built-in functionalities.

    Images are read from ``<root_directory>/images/{train,val}`` and one
    ``.txt`` label file per image is read from
    ``<root_directory>/labels/{train,val}``. A CSV listing
    ``(image name, height, width)`` is cached under ``<root_directory>/labels``
    and rebuilt from the label folder when it is missing.

    Args:
        num_classes: number of classes, stored as ``self.nc``.
        root_directory: dataset root containing ``images`` and ``labels``.
        transform: optional callable invoked as
            ``transform(image=..., bboxes=...)``; its element at index 1 must
            expose a ``p`` attribute (toggled per batch in ``__getitem__``).
        train: selects the ``train`` split when True, ``val`` otherwise.
        rect_training: when True, per-image target sizes are read from the
            annotations table after it is processed by ``adaptive_shape``.
        default_size: square target size used when ``rect_training`` is False.
        bs: batch size, used to derive the batch number of a sample index.
        bboxes_format: ``"coco"`` or ``"yolo"``; layout of the label files.
        ultralytics_loss: when True, labels are returned as an ``(n, 6)``
            tensor whose first column is left at zero.
    """

    def __init__(self,
                 num_classes,
                 root_directory=config.ROOT_DIR,
                 transform=None,
                 train=True,
                 rect_training=False,
                 default_size=640,
                 bs=64,
                 bboxes_format="coco",
                 ultralytics_loss=False,
                 ):
        assert bboxes_format in ["coco", "yolo"], 'bboxes_format must be either "coco" or "yolo"'

        self.bs = bs
        self.batch_range = 64 if bs < 64 else 128
        self.bboxes_format = bboxes_format
        self.ultralytics_loss = ultralytics_loss
        self.root_directory = root_directory
        self.nc = num_classes
        self.transform = transform
        self.rect_training = rect_training
        self.default_size = default_size
        self.train = train

        if train:
            fname = 'images/train'
            annot_file = "annot_train.csv"
            self.annot_folder = "train"
        else:
            fname = 'images/val'
            annot_file = "annot_val.csv"
            self.annot_folder = "val"

        self.fname = fname

        try:
            # The cached CSV is read without a header, using its first column
            # as the index, and sorted on that index.
            self.annotations = pd.read_csv(os.path.join(root_directory, "labels", annot_file),
                                           header=None, index_col=0).sort_values(by=[0])
            # Drops the last row after sorting.
            # NOTE(review): presumably this is the header row emitted by
            # to_csv below, which sorts last - confirm.
            self.annotations = self.annotations.head((len(self.annotations)-1))
        except FileNotFoundError:
            # No cached CSV: build the table from the label folder, skipping
            # label files whose matching .jpg image does not exist.
            annotations = []
            for img_txt in os.listdir(os.path.join(self.root_directory, "labels", self.annot_folder)):
                img = img_txt.split(".txt")[0]
                try:
                    w, h = imagesize.get(os.path.join(self.root_directory, "images", self.annot_folder, f"{img}.jpg"))
                except FileNotFoundError:
                    continue
                annotations.append([str(img) + ".jpg", h, w])
            self.annotations = pd.DataFrame(annotations)
            self.annotations.to_csv(os.path.join(self.root_directory, "labels", annot_file))

        self.len_ann = len(self.annotations)
        if rect_training:
            # NOTE(review): adaptive_shape is not defined in the visible part
            # of this class; it must be provided elsewhere.
            self.annotations = self.adaptive_shape(self.annotations)

    def __len__(self):
        """Return the number of annotated images."""
        return len(self.annotations)

    def __getitem__(self, idx):
        """Load, resize and optionally augment the sample at position ``idx``.

        Returns:
            A tuple ``(image, labels)``. ``image`` is a tensor built from the
            channel-first, contiguous image array. ``labels`` is an ``(n, 6)``
            tensor when ``ultralytics_loss`` is True, otherwise the labels
            array with the class index in the first column.
        """
        img_name = self.annotations.iloc[idx, 0]
        # Without rectangular training every image is resized to a square of
        # side ``default_size`` (previously hard-coded to 640).
        tg_height = self.annotations.iloc[idx, 1] if self.rect_training else self.default_size
        tg_width = self.annotations.iloc[idx, 2] if self.rect_training else self.default_size

        # img_name[:-4] to remove the .jpg or .png which are coco img formats
        label_path = os.path.join(self.root_directory, "labels", self.annot_folder, img_name[:-4] + ".txt")

        # to avoid an annoying "UserWarning: loadtxt: Empty input file"
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            labels = np.loadtxt(fname=label_path, delimiter=" ", ndmin=2)

        # removing annotations with negative values
        labels = labels[np.all(labels >= 0, axis=1), :]
        # to avoid negative values (columns 3:5 are truncated to 3 decimals)
        labels[:, 3:5] = np.floor(labels[:, 3:5] * 1000) / 1000

        # Open the image inside a context manager so the file handle is
        # released as soon as the pixel data has been copied into the array.
        img_path = os.path.join(self.root_directory, self.fname, img_name)
        with Image.open(img_path) as pil_img:
            img = np.array(pil_img.convert("RGB"))

        if self.bboxes_format == "coco":
            labels[:, -1] -= 1  # 0-indexing the classes of coco labels (1-80 --> 0-79)
            # move the class index from the last column to the first one
            labels = np.roll(labels, axis=1, shift=1)
            # normalized coordinates are scale invariant, hence after resizing the img we don't resize labels
            labels[:, 1:] = coco_to_yolo_tensors(labels[:, 1:], w0=img.shape[1], h0=img.shape[0])

        img = resize_image(img, (int(tg_width), int(tg_height)))

        if self.transform:
            # The second transform is applied with probability 1 on
            # even-numbered batches and 0 on odd-numbered ones.
            batch_n = idx // self.bs
            if batch_n % 2 == 0:
                self.transform[1].p = 1
            else:
                self.transform[1].p = 0

            # albumentations requires bboxes to be (x,y,w,h,class_idx)
            augmentations = self.transform(image=img,
                                           bboxes=np.roll(labels, axis=1, shift=4)
                                           )
            img = augmentations["image"]
            labels = np.array(augmentations["bboxes"])
            if len(labels):
                # back to class index in the first column
                labels = np.roll(labels, axis=1, shift=1)

        if self.ultralytics_loss:
            labels = torch.from_numpy(labels)
            # column 0 is left at zero; columns 1: hold the labels
            out_bboxes = torch.zeros((labels.shape[0], 6))
            if len(labels):
                out_bboxes[..., 1:] = labels

        # HWC -> CHW, made contiguous before conversion to a tensor
        img = img.transpose((2, 0, 1))
        img = np.ascontiguousarray(img)

        return torch.from_numpy(img), out_bboxes if self.ultralytics_loss else labels
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment