Created
February 8, 2024 17:56
Revisions
-
jpcbertoldo created this gist
Feb 8, 2024 .There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -0,0 +1,60 @@ # %% # Scratch of spro on a batch of images import torch # ---- Dummy data ---- # a batch of predictions for two images # shape: (2, 1, 256, 256) predictions = torch.rand(2, 1, 256, 256) # masks as list[list[Tensor]] # external list is for the batch # internal lists are for the multiple masks of the image masks = [ # each internal list corresponds to one image (index in `preds`) [ # each tensor corresponds to one file (ie. one spro curve / mask) torch.rand(256, 256) > .99, torch.rand(256, 256) > .90, ], [ torch.rand(256, 256) > .80, ], ] # ---- Concat masks & repeat predictions ---- # Concatenate all masks into a single tensor # shape: (3, 256, 256) masks_concatenated = torch.cat([torch.cat(image_masks) for image_masks in masks]) # make a predictions tensor where the prediction of an image with N masks is repeated N times # (ie. the same prediction is repeated for each mask of the image) # shape: (3, 1, 256, 256) predictions_repeated = torch.cat([ pred.repeat(len(image_masks), 1, 1, 1) for pred, image_masks in zip(predictions, masks) ]) # so the indexes in `predictions_repeated` match the indexes in `masks_concatenated` # but does NOT the indexes in `predictions` # mapping would be # `predictions_repeated[0]` -> `predictions[0]` # `predictions_repeated[1]` -> `predictions[0]` # `predictions_repeated[2]` -> `predictions[1]` # ---- SPRO ---- # ... # a fake result # 10_000 is the number of points (thresholds) in the spro curve spro_curves = torch.rand(3, 10_000) #... # this should happen with all curves from all batches (NOT per batch), Eq. 1 in the paper spro = spro_curves.mean(dim=0) # something weird happens here: batch size of the input is 2, but there are 3 curves