Created
February 8, 2024 17:56
-
-
Save jpcbertoldo/b766da23f45a4117b428940764c50de3 to your computer and use it in GitHub Desktop.
tmp-spro-scratch.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# %%
# Scratch of sPRO on a batch of images.
#
# Goal: compute one sPRO curve per ground-truth mask. An image may have
# several masks (or none), so every mask is paired with the prediction of the
# image it belongs to.
import torch

# ---- Dummy data ----

# A batch of predictions for two images.
# shape: (2, 1, 256, 256)
predictions = torch.rand(2, 1, 256, 256)

# Masks as list[list[Tensor]]:
#  - the outer list is the batch (same index as in `predictions`);
#  - each inner list holds that image's masks, one (256, 256) bool tensor per
#    mask file (i.e. one sPRO curve per mask).
masks = [
    [
        torch.rand(256, 256) > 0.99,
        torch.rand(256, 256) > 0.90,
    ],
    [
        torch.rand(256, 256) > 0.80,
    ],
]

# ---- Flatten masks & repeat predictions ----

# Stack every mask of every image along a NEW leading axis.
# NOTE: `torch.cat` on a list of (H, W) tensors concatenates along H, giving
# (512, 256) for the first image and (768, 256) overall. That is not a batch
# of masks, so `torch.stack` over the flattened list is used instead. Images
# with no masks simply contribute nothing (at least one mask overall is
# required by `torch.stack`).
# shape: (3, 256, 256)
masks_concatenated = torch.stack(
    [mask for image_masks in masks for mask in image_masks]
)

# Number of masks of each image, e.g. [2, 1].
num_masks_per_image = torch.tensor(
    [len(image_masks) for image_masks in masks], device=predictions.device
)

# Repeat each image's prediction once per mask of that image (the same
# prediction is reused for every mask of the image). `repeat_interleave` with
# per-image counts also copes with images that have 0 masks.
# shape: (3, 1, 256, 256)
predictions_repeated = torch.repeat_interleave(
    predictions, num_masks_per_image, dim=0
)

# Indexes in `predictions_repeated` match the indexes in `masks_concatenated`,
# but NOT the indexes in `predictions`. `image_index_per_mask[i]` is the index
# in `predictions` (and `masks`) that row `i` comes from; here:
#   predictions_repeated[0] -> predictions[0]
#   predictions_repeated[1] -> predictions[0]
#   predictions_repeated[2] -> predictions[1]
image_index_per_mask = torch.repeat_interleave(
    torch.arange(len(masks), device=predictions.device), num_masks_per_image
)

# ---- SPRO ----
# ... (actual sPRO computation elided in this scratch)

# A fake result: one curve per MASK (not per image), each with one point per
# threshold.
num_thresholds = 10_000
spro_curves = torch.rand(masks_concatenated.shape[0], num_thresholds)

# ...

# This average must be taken over all curves from all batches (NOT per batch),
# Eq. 1 in the paper. Across several batches, collect the curves (e.g.
# `torch.cat` them) and take a single mean at the end. A mean of per-batch
# means would weight masks unequally whenever batches hold different numbers
# of masks.
spro = spro_curves.mean(dim=0)

# Why the input batch size is 2 while there are 3 curves: there is one curve
# per mask, and the first image has two masks. Use `image_index_per_mask` to
# map a curve back to its image.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment