@jakobamb
Created September 11, 2024 09:39
RankMe implementation
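For reference, this is the quantity the code computes, as we read Garrido et al. (2023). Take an embedding matrix Z with N rows (embeddings) and K columns (embedding dimensions), and let σ_1(Z), …, σ_min(N,K)(Z) be its singular values. RankMe is the exponential of the Shannon entropy of the singular values after L1 normalisation. Because singular values are non-negative, the L1 norm is simply their sum. The paper adds ε to p_k. The code below instead adds `epsilon` to the normaliser and inside the logarithm, which gives practically the same value.

```latex
\mathrm{RankMe}(Z) = \exp\left( -\sum_{k=1}^{\min(N,K)} p_k \log p_k \right),
\qquad
p_k = \frac{\sigma_k(Z)}{\lVert \sigma(Z) \rVert_1} + \varepsilon
```

The score is close to 1 when a single direction dominates and reaches min(N, K) when all singular values are equal.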
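import torch


def rankMe(z: torch.Tensor, epsilon: float = 1e-7) -> torch.Tensor:
    """
    Implementation of RankMe by Garrido et al. 2023 (https://proceedings.mlr.press/v202/garrido23a.html)

    param z: a tensor of shape (B, N, K), where B is the batch size, N is the
        number of embeddings per batch element and K is the embedding dimension.
    param epsilon: small constant for numerical stability.
    returns: a tensor of shape (B,) with one RankMe score per batch element,
        approximately in the range [1, min(N, K)].
    """
    # Singular values of each (N, K) matrix, shape (B, min(N, K)).
    # svdvals equals torch.linalg.svd(z, full_matrices=False).S but skips
    # computing U and V and has more stable gradients.
    singular_values = torch.linalg.svdvals(z)
    # Normalise the singular values into a distribution p_k (L1 normalisation).
    sum_singular_values = torch.sum(singular_values, dim=-1, keepdim=True)
    pk = singular_values / (sum_singular_values + epsilon)
    # RankMe is the exponential of the Shannon entropy of p_k.
    log_pk = torch.log(pk + epsilon)
    score = torch.exp(-torch.sum(pk * log_pk, dim=-1))
    return score


if __name__ == "__main__":
    t = torch.randn((64, 10, 5))
    print(rankMe(t))  # shape (64,); each score lies roughly between 1 and 5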
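The following is a quick sanity check, sketched under the assumption that a lower RankMe should flag dimensional collapse. Embeddings confined to a 2-dimensional subspace should score at most about 2. Isotropic Gaussian embeddings should score close to the embedding dimension. The function body repeats the gist's computation so that the block runs on its own. The sizes (1000 samples, 16 dimensions), the float64 dtype and the seed are arbitrary illustrative choices.

```python
import torch


def rankMe(z: torch.Tensor, epsilon: float = 1e-7) -> torch.Tensor:
    # Same computation as the gist: exp(entropy) of L1-normalised singular values.
    singular_values = torch.linalg.svdvals(z)
    pk = singular_values / (singular_values.sum(dim=-1, keepdim=True) + epsilon)
    return torch.exp(-torch.sum(pk * torch.log(pk + epsilon), dim=-1))


torch.manual_seed(0)
n_samples, dim = 1000, 16

# Isotropic Gaussian embeddings: all singular values are similar, so RankMe ~ dim.
full = torch.randn(1, n_samples, dim, dtype=torch.float64)

# Collapsed embeddings: every row lies in the same 2-D subspace (rank 2).
basis = torch.randn(2, dim, dtype=torch.float64)
coeffs = torch.randn(1, n_samples, 2, dtype=torch.float64)
collapsed = coeffs @ basis  # shape (1, n_samples, dim)

full_score = rankMe(full).item()
collapsed_score = rankMe(collapsed).item()
print(f"full rank: {full_score:.2f}, collapsed: {collapsed_score:.2f}")
```

This gap between the two scores is what lets RankMe compare runs without labels, as proposed in the paper. The exact full-rank value depends on N, K and the random draw. Because the singular-value routines also accept a plain 2-D tensor, a single (N, K) embedding matrix can be passed directly and yields a 0-dimensional score.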