Mean attention distance
# Copyright 2022 Google LLC.
# SPDX-License-Identifier: Apache-2.0
# Author: Maithra Raghu <[email protected]>

import numpy as np


def compute_distance_matrix(patch_size, num_patches, length):
    """Helper function to compute the pairwise patch distance matrix (in pixels)."""
    distance_matrix = np.zeros((num_patches, num_patches))
    for i in range(num_patches):
        for j in range(num_patches):
            if i == j:  # zero distance
                continue
            xi, yi = int(i / length), (i % length)
            xj, yj = int(j / length), (j % length)
            distance_matrix[i, j] = patch_size * np.linalg.norm([xi - xj, yi - yj])
    return distance_matrix


def compute_mean_attention_dist(patch_size, attention_weights):
    num_patches = attention_weights.shape[-1]
    length = int(np.sqrt(num_patches))
    assert length**2 == num_patches, "Num patches is not a perfect square"

    distance_matrix = compute_distance_matrix(patch_size, num_patches, length)
    h, w = distance_matrix.shape
    distance_matrix = distance_matrix.reshape((1, 1, h, w))

    # Each row of attention_weights sums to 1, so weighting the pairwise
    # distances and summing gives an expected (mean) distance per query token.
    mean_distances = attention_weights * distance_matrix
    mean_distances = np.sum(mean_distances, axis=-1)  # sum along last axis to get average distance per token
    mean_distances = np.mean(mean_distances, axis=-1)  # now average across all the tokens
    return mean_distances
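For context, a minimal usage sketch (not part of the original gist): it assumes you already have softmaxed attention maps of shape `(batch, num_heads, num_patches, num_patches)`, e.g. from a ViT-B/16 on 224×224 inputs with the CLS token already dropped, leaving 196 patch tokens. Random weights stand in for real attention here just to show the expected shapes.

```python
# Hypothetical usage sketch -- `attn` stands in for softmaxed attention weights
# of shape (batch, num_heads, num_patches, num_patches).
import numpy as np

batch, num_heads, num_patches = 2, 12, 196
attn = np.random.rand(batch, num_heads, num_patches, num_patches)
attn = attn / attn.sum(axis=-1, keepdims=True)  # rows sum to 1, like a softmax output

# One mean attention distance (in pixels) per (example, head).
mean_dists = compute_mean_attention_dist(patch_size=16, attention_weights=attn)
print(mean_dists.shape)  # (2, 12)
```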
Hey! I had a quick question regarding the attention distance calculation. Since the attention scores are softmaxed over all tokens (including the CLS token), removing the CLS token would mean we're slicing out a portion of the normalized distribution. That would make the attention weights over the 196 patch tokens no longer sum to 1.
Shouldn't we re-normalize the remaining 196 attention weights after excluding the CLS token? Otherwise, the computed mean attention distance wouldn’t truly reflect a proper expectation under a valid probability distribution. Curious if this step was intentionally skipped or if there's a rationale for not re-normalizing.
Thanks in advance!
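For reference, a hedged sketch of the re-normalization being asked about, assuming the full attention map `attn` has shape `(batch, num_heads, 197, 197)` with the CLS token at index 0 (these names and shapes are illustrative, not from the gist):

```python
# Hypothetical sketch of re-normalizing after dropping the CLS token.
patch_attn = attn[..., 1:, 1:]  # drop the CLS row and column
patch_attn = patch_attn / patch_attn.sum(axis=-1, keepdims=True)  # rows sum to 1 again

mean_dists = compute_mean_attention_dist(patch_size=16, attention_weights=patch_attn)
```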