Skip to content

Instantly share code, notes, and snippets.

@ucalyptus2
Created July 25, 2024 04:41

Revisions

  1. ucalyptus2 created this gist Jul 25, 2024.
    74 changes: 74 additions & 0 deletions spectrum_gpt4o.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,74 @@
    import torch
    import torch.nn as nn
    import torch.optim as optim
    from torch.utils.data import DataLoader

    class Spectrum:
    def __init__(self, model, train_loader, val_loader, device='cuda'):
    self.model = model.to(device)
    self.train_loader = train_loader
    self.val_loader = val_loader
    self.device = device
    self.snr_threshold = None

    def compute_snr(self, weight_matrix):
    U, S, V = torch.svd(weight_matrix)
    signal = S[S > self.snr_threshold]
    noise = S[S <= self.snr_threshold]
    snr = signal.sum() / (noise.sum() + 1e-5)
    return snr

    def select_layers(self):
    snr_values = {}
    for name, param in self.model.named_parameters():
    if 'weight' in name and param.requires_grad:
    snr = self.compute_snr(param.data)
    snr_values[name] = snr
    sorted_snr = sorted(snr_values.items(), key=lambda item: item[1], reverse=True)
    top_layers = [name for name, _ in sorted_snr[:int(0.25 * len(sorted_snr))]]
    return top_layers

    def freeze_layers(self, top_layers):
    for name, param in self.model.named_parameters():
    if name not in top_layers:
    param.requires_grad = False

    def train(self, num_epochs, learning_rate):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(filter(lambda p: p.requires_grad, self.model.parameters()), lr=learning_rate)
    for epoch in range(num_epochs):
    self.model.train()
    running_loss = 0.0
    for i, (inputs, labels) in enumerate(self.train_loader):
    inputs, labels = inputs.to(self.device), labels.to(self.device)
    optimizer.zero_grad()
    outputs = self.model(inputs)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()
    running_loss += loss.item()
    print(f'Epoch {epoch + 1}, Loss: {running_loss / len(self.train_loader)}')

    def validate(self):
    self.model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
    for inputs, labels in self.val_loader:
    inputs, labels = inputs.to(self.device), labels.to(self.device)
    outputs = self.model(inputs)
    _, predicted = torch.max(outputs.data, 1)
    total += labels.size(0)
    correct += (predicted == labels).sum().item()
    print(f'Accuracy: {100 * correct / total}%')

    # Usage example
    # model = ... # Define your model
    # train_loader = DataLoader(...) # Define your training data loader
    # val_loader = DataLoader(...) # Define your validation data loader

    spectrum = Spectrum(model, train_loader, val_loader)
    top_layers = spectrum.select_layers()
    spectrum.freeze_layers(top_layers)
    spectrum.train(num_epochs=10, learning_rate=1e-5)
    spectrum.validate()