Created
July 25, 2024 04:41
Revisions
-
ucalyptus2 created this gist
Jul 25, 2024. There are no files selected for viewing.
# NOTE: the two comment lines below are banner / diff-header text carried over
# from the gist web page this file was copied from; they are not part of the code.
# This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode characters
# Original file line number Diff line number Diff line change @@ -0,0 +1,74 @@
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset


class Spectrum:
    """Rank weight matrices by a singular-value SNR and fine-tune only the top ones.

    Typical flow: ``select_layers()`` -> ``freeze_layers(...)`` -> ``train(...)``
    -> ``validate()``.

    Attributes:
        model: The ``nn.Module`` being tuned (moved to ``device`` on construction).
        train_loader: Iterable of ``(inputs, labels)`` batches used by ``train``.
        val_loader: Iterable of ``(inputs, labels)`` batches used by ``validate``.
        device: Device that the model and every batch are moved to.
        snr_threshold: Singular-value cut-off separating "signal" from "noise".
            ``None`` (the default) means a per-matrix threshold is estimated
            from the singular values themselves.
    """

    def __init__(self, model, train_loader, val_loader, device='cuda'):
        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.device = device
        self.snr_threshold = None

    @staticmethod
    def _estimate_threshold(singular_values, shape):
        """Estimate a noise cut-off for one matrix from its singular values.

        Heuristic: a robust scale estimate (IQR / 1.349, the IQR-to-sigma
        ratio of a normal distribution) multiplied by ``1 + sqrt(beta)``,
        where ``beta`` is the matrix aspect ratio (short side / long side).
        NOTE(review): this is modelled on a Marchenko-Pastur style upper edge;
        confirm it matches the thresholding recipe you intend to reproduce.
        """
        iqr = torch.quantile(singular_values, 0.75) - torch.quantile(singular_values, 0.25)
        sigma = iqr / 1.349
        rows, cols = shape
        beta = min(rows, cols) / max(rows, cols)
        return sigma * (1 + beta ** 0.5)

    def compute_snr(self, weight_matrix):
        """Return the signal-to-noise ratio of a weight tensor as a 0-dim tensor.

        Singular values above the threshold count as signal and the rest as
        noise; the result is ``signal.sum() / (noise.sum() + 1e-5)``.

        Args:
            weight_matrix: Tensor with at least two dimensions. Tensors with
                more than two dimensions (e.g. convolution kernels) are
                flattened to ``(shape[0], -1)`` before the decomposition.

        Raises:
            ValueError: If ``weight_matrix`` has fewer than two dimensions.
        """
        if weight_matrix.dim() < 2:
            raise ValueError(
                'compute_snr needs a tensor with at least 2 dimensions, '
                f'got shape {tuple(weight_matrix.shape)}'
            )
        # float() so half-precision weights can be decomposed on any backend.
        matrix = weight_matrix.detach().reshape(weight_matrix.shape[0], -1).float()
        # Only the singular values are needed, so skip computing U and V.
        singular_values = torch.linalg.svdvals(matrix)

        threshold = self.snr_threshold
        if threshold is None:
            # The attribute defaults to None; comparing a tensor against None
            # would raise TypeError, so derive a threshold for this matrix.
            threshold = self._estimate_threshold(singular_values, matrix.shape)

        signal = singular_values[singular_values > threshold]
        noise = singular_values[singular_values <= threshold]
        # The epsilon keeps the ratio finite when nothing falls below the threshold.
        return signal.sum() / (noise.sum() + 1e-5)

    def select_layers(self, top_fraction=0.25):
        """Return the names of the trainable weight matrices with the highest SNR.

        Args:
            top_fraction: Fraction (0 < f <= 1) of candidate matrices to keep.
                At least one name is returned whenever a candidate exists, so
                small models are not frozen entirely.

        Returns:
            List of parameter names, highest SNR first.

        Raises:
            ValueError: If ``top_fraction`` is outside ``(0, 1]``.
        """
        if not 0 < top_fraction <= 1:
            raise ValueError(f'top_fraction must be in (0, 1], got {top_fraction}')

        snr_values = {}
        for name, param in self.model.named_parameters():
            if 'weight' not in name or not param.requires_grad:
                continue
            if param.dim() < 2:
                # 1-D weights (e.g. normalisation scales) have no singular values.
                continue
            # Store plain floats so sorting does not compare tensors.
            snr_values[name] = float(self.compute_snr(param))

        if not snr_values:
            return []

        ranked = sorted(snr_values.items(), key=lambda item: item[1], reverse=True)
        keep = max(1, int(top_fraction * len(ranked)))
        return [name for name, _ in ranked[:keep]]

    def freeze_layers(self, top_layers):
        """Disable gradients for every parameter whose name is not in ``top_layers``.

        Parameters named in ``top_layers`` are left untouched.
        """
        selected = set(top_layers)
        for name, param in self.model.named_parameters():
            if name not in selected:
                param.requires_grad = False

    def train(self, num_epochs, learning_rate):
        """Train the unfrozen parameters with Adam and cross-entropy loss.

        Prints the mean batch loss after every epoch.

        Raises:
            ValueError: If the model has no trainable parameters.
        """
        trainable = [p for p in self.model.parameters() if p.requires_grad]
        if not trainable:
            raise ValueError(
                'No trainable parameters: every parameter is frozen. '
                'Check the names passed to freeze_layers().'
            )

        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(trainable, lr=learning_rate)

        for epoch in range(num_epochs):
            self.model.train()
            running_loss = 0.0
            num_batches = 0
            for inputs, labels in self.train_loader:
                inputs, labels = inputs.to(self.device), labels.to(self.device)
                optimizer.zero_grad()
                outputs = self.model(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()
                running_loss += loss.item()
                num_batches += 1
            # max(..., 1) avoids ZeroDivisionError on an empty loader.
            print(f'Epoch {epoch + 1}, Loss: {running_loss / max(num_batches, 1)}')

    def validate(self):
        """Evaluate on ``val_loader``, print the accuracy and return it.

        Returns:
            Accuracy in percent, or ``0.0`` when the loader yields no samples.
        """
        self.model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, labels in self.val_loader:
                inputs, labels = inputs.to(self.device), labels.to(self.device)
                outputs = self.model(inputs)
                predicted = outputs.argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        accuracy = 100 * correct / total if total else 0.0
        print(f'Accuracy: {accuracy}%')
        return accuracy


def main():
    """Usage example on a toy model with random data.

    Replace the stand-ins below with your own objects:
        model = ...                    # Define your model
        train_loader = DataLoader(...) # Define your training data loader
        val_loader = DataLoader(...)   # Define your validation data loader
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    torch.manual_seed(0)

    model = nn.Sequential(
        nn.Linear(20, 64), nn.ReLU(),
        nn.Linear(64, 64), nn.ReLU(),
        nn.Linear(64, 64), nn.ReLU(),
        nn.Linear(64, 4),
    )
    features = torch.randn(256, 20)
    labels = torch.randint(0, 4, (256,))
    train_loader = DataLoader(
        TensorDataset(features[:192], labels[:192]), batch_size=32, shuffle=True
    )
    val_loader = DataLoader(TensorDataset(features[192:], labels[192:]), batch_size=32)

    spectrum = Spectrum(model, train_loader, val_loader, device=device)
    top_layers = spectrum.select_layers()
    spectrum.freeze_layers(top_layers)
    spectrum.train(num_epochs=10, learning_rate=1e-5)
    spectrum.validate()


if __name__ == '__main__':
    main()