@nilsleh
Last active January 23, 2024 07:44
Laplace Example Constant UQ Reproduce
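"""Reproduce near-constant Laplace predictive uncertainty on the Tropical Cyclone dataset.

Builds a sequence variant of the torchgeo TropicalCyclone dataset, loads a trained
deterministic ResNet-18 regressor, fits a Laplace approximation with laplace-torch,
and inspects the predictive standard deviation on a validation batch.
"""
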
import os
from functools import lru_cache
from typing import Any, Dict, Optional
import kornia.augmentation as K
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import timm
import torch
from laplace import Laplace
from matplotlib.figure import Figure
from PIL import Image
from sklearn.model_selection import train_test_split
from torch import Tensor
from torch.utils.data import DataLoader, Subset
from torchgeo.datamodules import NonGeoDataModule
from torchgeo.datamodules.utils import group_shuffle_split
from torchgeo.datasets import TropicalCyclone
from torchgeo.transforms import AugmentationSequential
class TropicalCycloneSequence(TropicalCyclone):
    """Tropical Cyclone dataset adapted for loading sequences."""

    valid_tasks = ["regression", "classification"]

    # wind speed bins (kt), based on https://www.nhc.noaa.gov/climo/?text
    class_bins = {
        "tropical_depression": (0, 33),
        "tropical_storm": (34, 63),
        "hurr_1": (64, 82),
        "hurr_2": (83, 95),
        "hurr_3": (96, 112),
        "hurr_4": (113, 136),
        "hurr_5": (137, np.inf),
    }
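
    # Example: a 70 kt wind speed falls into the "hurr_1" bin (64-82 kt),
    # i.e. class index 2 when no bins are filtered out by ``min_wind_speed``.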
    def __init__(
        self,
        root: str = "data",
        split: str = "train",
        min_wind_speed: float = 0.0,
        task: str = "regression",
        seq_len: int = 3,
        download: bool = False,
        api_key: Optional[str] = None,
        checksum: bool = False,
    ) -> None:
        """Initialize a new Tropical Cyclone Wind Estimation Competition Dataset.

        Args:
            root: root directory where dataset can be found
            split: one of "train" or "test"
            min_wind_speed: minimum wind speed to include in the dataset
            task: one of "regression" or "classification"
            seq_len: length of the image sequences to construct
            download: if True, download dataset and store it in the root directory
            api_key: a RadiantEarth MLHub API key to use for downloading the dataset
            checksum: if True, check the MD5 of the downloaded files (may be slow)

        Raises:
            AssertionError: if ``split`` or ``task`` argument is invalid
            RuntimeError: if ``download=False`` but dataset is missing or checksum fails
        """
        super().__init__(root, split, None, download, api_key, checksum)
        assert (
            task in self.valid_tasks
        ), f"invalid task '{task}', please choose one of {self.valid_tasks}"
        self.task = task
        self.min_wind_speed = min_wind_speed
        self.seq_len = seq_len
        self.sequence_df = self.construct_sequences()
        print(f"Num samples: {len(self.sequence_df)}")
    def construct_sequences(self) -> pd.DataFrame:
        """Construct the sequence collection used for data loading.

        Returns:
            dataframe with one row per subsequence sample
        """
        df = pd.read_csv(os.path.join(self.root, f"{self.split}_info.csv"))
        df = df[df["wind_speed"] >= self.min_wind_speed]

        # set up df for a possible classification task
        filtered_class_bins = {
            k: v for k, v in self.class_bins.items() if v[1] > self.min_wind_speed
        }
        filtered_class_bins = dict(
            sorted(filtered_class_bins.items(), key=lambda item: item[1][0])
        )

        def assign_class(wind_speed):
            """Assign a class index to a wind speed."""
            for i, (class_name, (min_speed, max_speed)) in enumerate(
                filtered_class_bins.items()
            ):
                # if wind_speed is within the range of the class, return the class index
                if min_speed <= wind_speed <= max_speed:
                    return i
            # fall back to the highest class for out-of-range wind speeds
            return len(filtered_class_bins) - 1

        df["class_index"] = df["wind_speed"].apply(assign_class)
        self.class_to_name = {
            i: class_name for i, class_name in enumerate(filtered_class_bins.keys())
        }
        # the sequence id is encoded in the image path
        df["seq_id"] = (
            df["path"]
            .str.split("/", expand=True)[0]
            .str.split("_", expand=True)[7]
            .astype(int)
        )
        self.target_mean = df["wind_speed"].mean()
        self.target_std = df["wind_speed"].std()

        def get_subsequences(df: pd.DataFrame, k: int) -> dict[str, Any]:
            """Generate all possible subsequences of length k for a given group.

            Args:
                df: grouped dataframe of a single storm
                k: length of the subsequences to generate

            Returns:
                all possible subsequences of length k for a given storm id
            """
            min_seq_id = df["seq_id"].min()
            max_seq_id = df["seq_id"].max()
            # generate possible subsequences of length k for the group
            subsequences = [
                list(range(i, i + k)) for i in range(min_seq_id, max_seq_id - k + 2)
            ]
            # keep only subsequences whose time steps are all present
            filtered_subsequences: list[list[int]] = [
                subseq for subseq in subsequences if set(subseq).issubset(df["seq_id"])
            ]
            # the target of a subsequence is the label of its last time step
            wind_speeds = [
                df.loc[df["seq_id"] == subseq[-1], "wind_speed"].values[0]
                for subseq in filtered_subsequences
                if subseq
            ]
            class_labels = [
                df.loc[df["seq_id"] == subseq[-1], "class_index"].values[0]
                for subseq in filtered_subsequences
                if subseq
            ]
            return {
                "storm_id": df["storm_id"].iloc[0],
                "subsequences": filtered_subsequences,
                "wind_speed": wind_speeds,
                "class_label": class_labels,
            }

        # group by storm id and find consecutive subsequences for each group
        cons_sequences = (
            df.groupby("storm_id").apply(get_subsequences, k=self.seq_len).tolist()
        )
        # drop the empty sequences
        sequence_df = (
            pd.DataFrame(cons_sequences)
            .explode(["subsequences", "wind_speed", "class_label"])
            .reset_index(drop=True)
            .dropna()
        )
        return sequence_df
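
    # ``sequence_df`` ends up with one row per sample, e.g. (illustrative
    # values only): storm_id="abc123", subsequences=[4, 5, 6], wind_speed=45,
    # class_label=1, where the labels belong to the last frame of the window.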
    def __getitem__(self, index: int) -> dict[str, Any]:
        """Return a sample within the dataset.

        Args:
            index: index to return

        Returns:
            data and labels at that index
        """
        storm_id = self.sequence_df.iloc[index].storm_id
        subsequence = self.sequence_df.iloc[index].subsequences

        imgs: list[Tensor] = []
        for time_idx in subsequence:
            directory = os.path.join(
                self.root,
                "_".join([self.collection_id, self.split, "{0}"]),
                "_".join(
                    [
                        self.collection_id,
                        self.split,
                        "{0}",
                        storm_id,
                        str(time_idx).zfill(3),
                    ]
                ),
            )
            imgs.append(self._load_image(directory))

        sample: dict[str, Any] = {"input": torch.stack(imgs, 0)}
        # features (and therefore the label) come from the last time step
        sample.update(self._load_features(directory))

        if self.task == "classification":
            sample["target"] = (
                torch.tensor(int(self.sequence_df.iloc[index].class_label))
                .squeeze()
                .long()
            )
        else:
            sample["target"] = (
                torch.tensor(int(self.sequence_df.iloc[index].wind_speed))
                .float()
                .unsqueeze(-1)
            )
        sample["index"] = index

        # already stored under "target"
        del sample["label"]
        del sample["wind_speed"]

        if self.transforms is not None:
            sample = self.transforms(sample)
        return sample
    @lru_cache
    def _load_image(self, directory: str) -> Tensor:
        """Load a single image.

        Args:
            directory: directory containing image

        Returns:
            the image
        """
        filename = os.path.join(directory.format("source"), "image.jpg")
        with Image.open(filename) as img:
            if img.height != self.size or img.width != self.size:
                # Image.Resampling was added in PIL 9.1.0
                try:
                    resample = Image.Resampling.BILINEAR
                except AttributeError:
                    resample = Image.BILINEAR
                img = img.resize(size=(self.size, self.size), resample=resample)
            array: "np.typing.NDArray[np.int_]" = np.array(img)
            tensor = torch.from_numpy(array).float()
            # not all images have the same shape: some carry an extra channel
            # dimension, so keep only the first band
            if tensor.dim() != 2:
                tensor = tensor[:, :, 0]
            return tensor
    def __len__(self) -> int:
        """Return the number of data points in the dataset.

        Returns:
            length of the dataset
        """
        return len(self.sequence_df)
    def plot(
        self,
        sample: dict[str, Any],
        show_titles: bool = True,
        suptitle: Optional[str] = None,
    ) -> Figure:
        """Plot a sample from the dataset.

        Args:
            sample: a sample returned by :meth:`__getitem__`
            show_titles: flag indicating whether to show titles above each panel
            suptitle: optional suptitle to use for figure

        Returns:
            a matplotlib Figure with the rendered sample
        """
        image, label = sample["input"] / 255, sample["target"]

        showing_predictions = "prediction" in sample
        if showing_predictions:
            prediction = sample["prediction"].item()

        fig, ax = plt.subplots(1, 1, figsize=(10, 10))
        ax.imshow(image.permute(1, 2, 0))
        ax.axis("off")

        if show_titles:
            title = f"Label: {label}"
            if showing_predictions:
                title += f"\nPrediction: {prediction}"
            ax.set_title(title, fontsize=20)

        if suptitle is not None:
            plt.suptitle(suptitle)
        return fig
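
# A minimal, hypothetical usage sketch (assumes the data is already available
# under ``root``); ``_example_inspect_dataset`` is not part of the reproduction
# script itself. Each sample stacks ``seq_len`` consecutive frames of one storm
# and the regression target is the wind speed at the last frame.
def _example_inspect_dataset(root: str = "data") -> None:
    ds = TropicalCycloneSequence(root=root, split="train", seq_len=3)
    sample = ds[0]
    # "input" has shape (seq_len, H, W); "target" is the wind speed in knots
    print(sample["input"].shape, sample["target"])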
class TropicalCycloneSequenceDataModule(NonGeoDataModule):
    """LightningDataModule implementation for the NASA Cyclone dataset.

    Implements 80/20 train/val splits based on hurricane storm ids.
    See :func:`setup` for more details.
    """

    input_mean = torch.Tensor([0.28154722, 0.28071895, 0.27990073])
    input_std = torch.Tensor([0.23435517, 0.23392765, 0.23351675])

    valid_tasks = ["regression", "classification"]

    def __init__(
        self,
        task: str = "regression",
        batch_size: int = 64,
        num_workers: int = 0,
        **kwargs: Any,
    ) -> None:
        """Initialize a new TropicalCycloneSequenceDataModule instance.

        Args:
            task: one of "regression" or "classification"
            batch_size: size of each mini-batch
            num_workers: number of workers for parallel data loading
            **kwargs: additional keyword arguments passed to
                :class:`TropicalCycloneSequence`
        """
        super().__init__(TropicalCycloneSequence, batch_size, num_workers, **kwargs)
        assert (
            task in self.valid_tasks
        ), f"invalid task '{task}', please choose one of {self.valid_tasks}"
        self.task = task

        self.dataset = TropicalCycloneSequence(split="train", **self.kwargs)
        # target mean and std can change between setups because the minimum
        # wind speed is a variable
        self.target_mean = torch.Tensor([self.dataset.target_mean])
        self.target_std = torch.Tensor([self.dataset.target_std])

        self.train_aug = AugmentationSequential(
            # the first Normalize rescales raw pixels with the NonGeoDataModule
            # defaults, the second standardizes with the dataset statistics
            K.Normalize(mean=self.mean, std=self.std),
            K.Normalize(mean=self.input_mean, std=self.input_std),
            K.Resize(224),
            K.RandomHorizontalFlip(p=0.5),
            K.RandomVerticalFlip(p=0.5),
            K.RandomRotation(degrees=(90, 91), p=0.5),
            K.RandomRotation(degrees=(270, 271), p=0.5),
            data_keys=["input"],
        )
        self.aug = AugmentationSequential(
            K.Normalize(mean=self.mean, std=self.std),
            K.Normalize(mean=self.input_mean, std=self.input_std),
            K.Resize(224),
            data_keys=["input"],
        )
    def setup(self, stage: str) -> None:
        """Set up datasets.

        Args:
            stage: either 'fit', 'validate', 'test', or 'predict'
        """
        if stage in ["fit", "validate"]:
            self.dataset = TropicalCycloneSequence(
                split="train", task=self.task, **self.kwargs
            )
            # split by storm id so frames of one storm never cross splits
            train_indices, val_indices = group_shuffle_split(
                self.dataset.sequence_df.storm_id, test_size=0.20, random_state=0
            )
            # hold out part of the validation set for calibration
            validation_indices, calibration_indices = train_test_split(
                val_indices, test_size=0.20, random_state=0
            )
            self.train_dataset = Subset(self.dataset, train_indices)
            self.val_dataset = Subset(self.dataset, validation_indices)
            self.calibration_dataset = Subset(self.dataset, calibration_indices)
        if stage in ["test"]:
            self.test_dataset = TropicalCycloneSequence(
                split="test", task=self.task, **self.kwargs
            )
    def calibration_dataloader(self) -> DataLoader:
        """Return a dataloader for the calibration dataset."""
        return DataLoader(
            self.calibration_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            collate_fn=self.collate_fn,
            shuffle=False,
        )
    def on_after_batch_transfer(
        self, batch: Dict[str, Tensor], dataloader_idx: int
    ) -> Dict[str, Tensor]:
        """Apply batch augmentations after the batch is transferred to the device.

        Args:
            batch: a batch of data that needs to be altered or augmented
            dataloader_idx: the index of the dataloader to which the batch belongs

        Returns:
            the augmented batch
        """
        if self.task == "regression":
            new_batch = {
                "input": self.aug({"input": batch["input"].float()})["input"],
                # standardize regression targets with the training statistics
                "target": (batch["target"].float() - self.target_mean)
                / self.target_std,
            }
        else:
            new_batch = {
                "input": self.aug({"input": batch["input"].float()})["input"],
                "target": batch["target"].long(),
            }
        return new_batch
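
# A minimal, hypothetical sketch of the datamodule in isolation (assumes a
# local ``root`` with the data); it is not called by the reproduction script.
# After ``on_after_batch_transfer`` the regression targets are standardized,
# so predictions must later be rescaled with ``target_mean``/``target_std``.
def _example_inspect_datamodule(root: str = "data") -> None:
    dm = TropicalCycloneSequenceDataModule(root=root, batch_size=4, num_workers=0)
    dm.setup("fit")
    batch = next(iter(dm.train_dataloader()))
    batch = dm.on_after_batch_transfer(batch, dataloader_idx=0)
    # inputs are normalized and resized to 224x224, targets are standardized
    print(batch["input"].shape, batch["target"].squeeze())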
def run():
    """Run the experiment."""
    datamodule = TropicalCycloneSequenceDataModule(
        root="/p/project/hai_uqmethodbox/data/tropical_cyclone",
        batch_size=64,
        num_workers=12,
        task="regression",
        min_wind_speed=0.0,
        seq_len=3,
    )
    datamodule.setup("fit")

    target_mean = datamodule.target_mean.cpu()
    target_std = datamodule.target_std.cpu()

    def collate_fn_laplace_torch(batch):
        """Collate function for the laplace-torch tuple convention.

        Args:
            batch: input batch

        Returns:
            an (inputs, targets) tuple for Laplace
        """
        # extract images and labels from the batch dictionary
        images = [item["input"] for item in batch]
        labels = [item["target"] for item in batch]

        # stack images and labels into tensors
        inputs = torch.stack(images)
        targets = torch.stack(labels)

        # apply the datamodule augmentation and target normalization
        return (
            datamodule.aug({"input": inputs.float()})["input"],
            (targets.float() - target_mean) / target_std,
        )

    train_loader = datamodule.train_dataloader()
    train_loader.collate_fn = collate_fn_laplace_torch

    # load the trained deterministic regression model
    det_model = timm.create_model("resnet18", in_chans=3, num_classes=1)
    det_model.load_state_dict(torch.load("resnet18.ckpt"))
    det_model = det_model.to("cuda")

    # fit the Laplace approximation on the training data
    la = Laplace(det_model, likelihood="regression")
    la.fit(train_loader)

    # optimize prior precision and observation noise via the marginal likelihood
    log_prior, log_sigma = (
        torch.ones(1, requires_grad=True),
        torch.ones(1, requires_grad=True),
    )
    hyper_optimizer = torch.optim.Adam([log_prior, log_sigma], lr=1e-1)
    for i in range(100):
        hyper_optimizer.zero_grad()
        neg_marglik = -la.log_marginal_likelihood(log_prior.exp(), log_sigma.exp())
        neg_marglik.backward()
        hyper_optimizer.step()

    val_loader = datamodule.val_dataloader()
    val_loader.collate_fn = collate_fn_laplace_torch

    batch = next(iter(val_loader))
    X = batch[0].to("cuda")

    # predictive mean and variance under the Laplace approximation
    f_mu, f_var = la(X)
    f_mu = f_mu.squeeze().detach().cpu().numpy()
    f_sigma = f_var.squeeze().sqrt().cpu().numpy()
    pred_std = np.sqrt(f_sigma**2 + la.sigma_noise.item() ** 2)
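
    # Hedged follow-up check (rescaling back to knots is an assumption about
    # the intended use): if the Laplace predictive uncertainty is constant,
    # ``pred_std`` is (near-)identical for every sample in the batch.
    pred_mean_kt = f_mu * target_std.numpy() + target_mean.numpy()
    pred_std_kt = pred_std * target_std.numpy()
    print("predicted wind speeds (kt):", pred_mean_kt)
    print("predictive stds (kt):", pred_std_kt)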
if __name__ == "__main__":
    run()