Skip to content

Instantly share code, notes, and snippets.

@xiabingquan
Created August 28, 2025 04:04
Show Gist options
  • Save xiabingquan/fec67c744ad59c5ee7ebc6c70ec3bac9 to your computer and use it in GitHub Desktop.
Save xiabingquan/fec67c744ad59c5ee7ebc6c70ec3bac9 to your computer and use it in GitHub Desktop.
A minimal script to visualize the long-term decay behaviour of RoPE (rotary position embedding)
import os
import tqdm
import numpy as np
import matplotlib.pyplot as plt
from loguru import logger
from joblib import Parallel, delayed
def vallina_theta_b1e4(i, d):
    """Return the vanilla RoPE rotation frequency for index ``i`` with base 10000.

    Computes ``10000 ** (-2 * i / d)``: index 0 yields 1.0 and each larger
    index yields a geometrically smaller frequency.

    Args:
        i: the dimension (pair) index; the statistic in this script uses
            indices 0 .. d // 2 - 1.
        d: the total feature dimension.
    """
    exponent = -2 * i / d
    return 10000 ** exponent
def vallina_theta_b1e6(i, d):
    """Return the vanilla RoPE rotation frequency for index ``i`` with base 1e6.

    Identical in form to the base-10000 variant, ``base ** (-2 * i / d)``,
    but a larger base makes the frequencies shrink more slowly with ``i``.

    Args:
        i: the dimension (pair) index.
        d: the total feature dimension.
    """
    base = 1_000_000
    return base ** (-2 * i / d)
# Registry mapping each theta-function name (the values accepted as
# ``theta_func_name`` by run_one_example) to the function that computes it.
THETA_FUNCS = dict(
    vallina_b1e4=vallina_theta_b1e4,
    vallina_b1e6=vallina_theta_b1e6,
)
def run_one_example(d: int, context: int, theta_func_name: str, fig_path: str) -> None:
    """Compute RoPE's relative-distance statistic f(m), plot it, and save the figure.

    The statistic is

        f(m) = (1 / (d/2)) * sum_{j=0}^{d//2-1} | sum_{i=0}^{j} exp(1j * m * theta_i) |

    with ``theta_i = THETA_FUNCS[theta_func_name](i, d)``. It is evaluated
    for every integer relative distance ``m`` in ``[0, context)``.

    Args:
        d: feature dimension (the statistic uses ``d // 2`` frequencies).
        context: the maximum context length; ``m`` ranges over 0 .. context-1.
        theta_func_name: key into ``THETA_FUNCS`` selecting the theta function.
        fig_path: the path to save the figure; missing parent directories
            are created.

    Raises:
        KeyError: if ``theta_func_name`` is not a key of ``THETA_FUNCS``.
    """
    theta_func = THETA_FUNCS[theta_func_name]
    # The frequencies do not depend on m, so evaluate them once up front
    # instead of inside the per-m double loop.
    thetas = np.array([theta_func(i, d) for i in range(d // 2)], dtype=np.float64)

    def f(m):
        # For j = 0 .. d//2-1 the inner sums are exactly the prefix sums of
        # exp(1j * m * theta_i), so one cumsum yields all of them in O(d)
        # rather than re-summing every prefix from scratch (O(d^2)).
        prefix_sums = np.cumsum(np.exp(1j * m * thetas))
        # Sum of the norms, averaged by d/2.
        return np.abs(prefix_sums).sum() / (d / 2)

    # Relative distances 0 .. context-1.
    m_values = np.arange(0, context, 1)
    f_values = [f(m) for m in m_values]

    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        ax.plot(m_values, f_values, linewidth=2)
        ax.set_xlabel('relative distance (m)')
        ax.set_ylabel('Statistics f(m)')
        ax.set_title(f'd={d}, context={context}, theta_func={theta_func_name}')
        ax.grid(True, alpha=0.3)
        fig.tight_layout()
        os.makedirs(os.path.dirname(fig_path) or './', exist_ok=True)
        fig.savefig(fig_path, dpi=600)
    finally:
        # pyplot keeps every figure alive until it is closed; a worker
        # process that runs several tasks would otherwise accumulate them.
        plt.close(fig)
if __name__ == "__main__":
    # Sweep over feature dimensions, context lengths and theta functions,
    # rendering one figure per combination in parallel. Figures that already
    # exist on disk are skipped so an interrupted sweep can be resumed.
    save_dir = './rope_figs'
    dims = (32, 64, 128, 512, 1024, 2048)
    contexts = (256, 512, 2048, 4096, 8192, 16384, 32768)
    theta_names = ('vallina_b1e4', 'vallina_b1e6')

    tasks = []
    for d in dims:
        for context in contexts:
            for theta_name in theta_names:
                fig_path = os.path.join(
                    save_dir, f'dimension-{d}', f'd{d}_c{context}_{theta_name}.png'
                )
                if not os.path.exists(fig_path):
                    tasks.append((d, context, theta_name, fig_path))
                else:
                    logger.info(f"Figure already exists: {fig_path}, skipping...")

    n_workers = 8
    runner = Parallel(n_jobs=n_workers)
    jobs = (
        delayed(run_one_example)(*task)
        for task in tqdm.tqdm(tasks, desc="Generating figures")
    )
    runner(jobs)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment