Mean attention distance
# Copyright 2022 Google LLC.
# SPDX-License-Identifier: Apache-2.0
# Author: Maithra Raghu <[email protected]>
import numpy as np


def compute_distance_matrix(patch_size, num_patches, length):
    """Helper function to compute the pairwise spatial distance between patches."""
    distance_matrix = np.zeros((num_patches, num_patches))
    for i in range(num_patches):
        for j in range(num_patches):
            if i == j:  # zero distance
                continue
            xi, yi = (int(i / length)), (i % length)
            xj, yj = (int(j / length)), (j % length)
            distance_matrix[i, j] = patch_size * np.linalg.norm([xi - xj, yi - yj])
    return distance_matrix


def compute_mean_attention_dist(patch_size, attention_weights):
    num_patches = attention_weights.shape[-1]
    length = int(np.sqrt(num_patches))
    assert length**2 == num_patches, "Num patches is not a perfect square"

    distance_matrix = compute_distance_matrix(patch_size, num_patches, length)
    h, w = distance_matrix.shape
    distance_matrix = distance_matrix.reshape((1, 1, h, w))

    mean_distances = attention_weights * distance_matrix
    mean_distances = np.sum(mean_distances, axis=-1)  # sum along last axis to get average distance per token
    mean_distances = np.mean(mean_distances, axis=-1)  # now average across all the tokens
    return mean_distances
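
A minimal usage sketch. It assumes attention weights of shape (batch, heads, num_patches, num_patches) covering patch tokens only; the random weights below are just a stand-in for a real model's softmax output.

# Usage sketch (assumed shapes, not from the gist): 224x224 image, 16x16 patches
# -> 196 patch tokens, 12 attention heads.
patch_size = 16
num_patches = 196
attn = np.random.rand(1, 12, num_patches, num_patches)
attn = attn / attn.sum(axis=-1, keepdims=True)  # each row sums to 1, like a softmax output

mean_dist = compute_mean_attention_dist(patch_size, attn)
print(mean_dist.shape)  # (1, 12): one mean attention distance per head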
@anika81199
Hey! I had a quick question regarding the attention distance calculation. Since the attention scores are softmaxed over all tokens (including the CLS token), removing the CLS token would mean we're slicing out a portion of the normalized distribution. That would make the attention weights over the 196 patch tokens no longer sum to 1.
Shouldn't we re-normalize the remaining 196 attention weights after excluding the CLS token? Otherwise, the computed mean attention distance wouldn’t truly reflect a proper expectation under a valid probability distribution. Curious if this step was intentionally skipped or if there's a rationale for not re-normalizing.

Thanks in advance!
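
One way to do the re-normalization described above is sketched here. It assumes the CLS token sits at index 0 of a (batch, heads, 197, 197) attention tensor; drop_cls_and_renormalize is a hypothetical helper, not part of the gist.

def drop_cls_and_renormalize(attn_with_cls):
    # Keep only patch-to-patch attention, discarding the CLS row and column.
    patch_attn = attn_with_cls[..., 1:, 1:]
    # Re-normalize so each remaining row is again a valid probability distribution.
    row_sums = patch_attn.sum(axis=-1, keepdims=True)
    return patch_attn / row_sums

# mean_dist = compute_mean_attention_dist(16, drop_cls_and_renormalize(attn_with_cls))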
