Skip to content

Instantly share code, notes, and snippets.

@jpcbertoldo
Created February 8, 2024 17:56

Revisions

  1. jpcbertoldo created this gist Feb 8, 2024.
    60 changes: 60 additions & 0 deletions gistfile1.txt
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,60 @@
    # %%
    # Scratch of SPRO on a batch of images:
    # how to lay out predictions and masks when an image can have several masks.

    import torch

    # ---- Dummy data ----

    # a batch of predictions for two images
    # shape: (2, 1, 256, 256)
    predictions = torch.rand(2, 1, 256, 256)

    # masks as list[list[Tensor]]
    # the outer list is for the batch
    # the inner lists hold the (possibly multiple) masks of each image
    masks = [
        # each inner list corresponds to one image (same index as in `predictions`)
        [
            # each tensor corresponds to one file (ie. one spro curve / mask)
            torch.rand(256, 256) > .99,
            torch.rand(256, 256) > .90,
        ],
        [
            torch.rand(256, 256) > .80,
        ],
    ]

    # ---- Stack masks & repeat predictions ----

    # Stack all masks of all images into a single tensor.
    # NOTE: this must be `torch.stack` (which adds a new leading dim), not `torch.cat`:
    # each mask is 2D (256, 256), so `cat` would glue them along the height
    # and yield (768, 256) instead of one entry per mask.
    # Flattening first also lets an image with an empty mask list contribute nothing,
    # which is consistent with `pred.repeat(0, 1, 1, 1)` below.
    # shape: (3, 256, 256)
    masks_concatenated = torch.stack([
        mask
        for image_masks in masks
        for mask in image_masks
    ])

    # make a predictions tensor where the prediction of an image with N masks is repeated N times
    # (ie. the same prediction is repeated for each mask of the image)
    # `pred` has shape (1, 256, 256), so repeating with 4 factors prepends the new dim
    # shape: (3, 1, 256, 256)
    predictions_repeated = torch.cat([
        pred.repeat(len(image_masks), 1, 1, 1)
        for pred, image_masks in zip(predictions, masks)
    ])
    # so the indexes in `predictions_repeated` match the indexes in `masks_concatenated`
    # but do NOT match the indexes in `predictions`
    # the mapping is
    # `predictions_repeated[0]` -> `predictions[0]`
    # `predictions_repeated[1]` -> `predictions[0]`
    # `predictions_repeated[2]` -> `predictions[1]`

    # ---- SPRO ----
    # ...

    # a fake result: one curve per mask (NOT per image)
    # 10_000 is the number of points (thresholds) in the spro curve
    # shape: (3, 10_000)
    spro_curves = torch.rand(3, 10_000)

    # ...
    # this should happen with all curves from all batches (NOT per batch), Eq. 1 in the paper
    # shape: (10_000,)
    spro = spro_curves.mean(dim=0)

    # NOTE: the batch size of the input is 2, but there are 3 curves,
    # because there is one curve per mask and the first image has two masks;
    # the number of curves is therefore not the batch size of `predictions`.