GIVT GMM Decoder (GPT-4)
# Collaboration between Claude-3 and GPT-4 to implement https://arxiv.org/pdf/2312.02116.pdf
# This is just the GMM decoder part of the model they propose (which is the new thing).
# This one was mainly generated by GPT-4.
# The AIs provided two implementations of the idea and revised each other's code.
# I tested that the unit tests pass but haven't tried it in a language model yet.
import torch
import torch.nn as nn
import torch.nn.functional as F
class GMMParametersPrediction(nn.Module):
    """Predicts GMM parameters (means, scales, mixture weights) from hidden states."""

    def __init__(self, hidden_dim, output_dim, num_components):
        super(GMMParametersPrediction, self).__init__()
        self.num_components = num_components
        self.mu = nn.Linear(hidden_dim, output_dim * num_components)
        self.log_sigma = nn.Linear(hidden_dim, output_dim * num_components)
        self.logits_pi = nn.Linear(hidden_dim, num_components)

    def forward(self, x):
        # x: (batch, seq_len, hidden_dim)
        mu = self.mu(x).view(x.size(0), x.size(1), self.num_components, -1)
        log_sigma = self.log_sigma(x).view(x.size(0), x.size(1), self.num_components, -1)
        logits_pi = self.logits_pi(x)
        pi = F.softmax(logits_pi, dim=-1)  # mixture weights sum to 1 over components
        sigma = torch.exp(log_sigma)       # exponentiate so scales are strictly positive
        return mu, sigma, pi
class GMMOutput(nn.Module):
    """Alternative implementation of the same GMM head (means, scales, weights)."""

    def __init__(self, hidden_dim, output_dim, num_mixtures):
        super(GMMOutput, self).__init__()
        self.num_mixtures = num_mixtures
        self.fc_means = nn.Linear(hidden_dim, output_dim * num_mixtures)
        self.fc_log_scales = nn.Linear(hidden_dim, output_dim * num_mixtures)
        self.fc_logits_weights = nn.Linear(hidden_dim, num_mixtures)

    def forward(self, x):
        # x: (batch, seq_len, hidden_dim)
        means = self.fc_means(x).view(x.size(0), x.size(1), self.num_mixtures, -1)
        log_scales = self.fc_log_scales(x).view(x.size(0), x.size(1), self.num_mixtures, -1)
        scales = torch.exp(log_scales)  # strictly positive scales
        logits_weights = self.fc_logits_weights(x)
        weights = F.softmax(logits_weights, dim=-1)  # mixture weights sum to 1
        return means, scales, weights
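# Hypothetical training sketch (not from the original gist): GIVT trains the
# decoder with teacher forcing by minimizing the GMM negative log-likelihood
# of the target latent vectors. A minimal sketch using torch.distributions;
# the helper name `gmm_nll_loss` and the target shape are assumptions.
import torch.distributions as D

def gmm_nll_loss(mu, sigma, pi, target):
    # mu, sigma: (batch, seq, num_components, output_dim); pi: (batch, seq, num_components)
    # target: (batch, seq, output_dim) ground-truth latent vectors
    mixture = D.Categorical(probs=pi)                   # per-position component weights
    components = D.Independent(D.Normal(mu, sigma), 1)  # diagonal Gaussians over output_dim
    gmm = D.MixtureSameFamily(mixture, components)
    return -gmm.log_prob(target).mean()                 # scalar NLL to minimize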
import unittest


class TestGMMParametersPrediction(unittest.TestCase):
    def setUp(self):
        self.hidden_dim = 512      # example hidden dimension
        self.output_dim = 256      # example output dimension (e.g., embedding size)
        self.num_components = 10   # number of GMM components
        self.batch_size = 4
        self.seq_length = 7
        self.model = GMMParametersPrediction(self.hidden_dim, self.output_dim, self.num_components)

    def test_output_shapes(self):
        # Simulate a batch of decoder hidden states.
        x = torch.randn(self.batch_size, self.seq_length, self.hidden_dim)
        mu, sigma, pi = self.model(x)
        self.assertEqual(mu.shape, (self.batch_size, self.seq_length, self.num_components, self.output_dim))
        self.assertEqual(sigma.shape, (self.batch_size, self.seq_length, self.num_components, self.output_dim))
        self.assertEqual(pi.shape, (self.batch_size, self.seq_length, self.num_components))

    def test_constraints(self):
        x = torch.randn(self.batch_size, self.seq_length, self.hidden_dim)
        _, sigma, pi = self.model(x)
        self.assertTrue(torch.all(sigma > 0), "All sigma values should be positive.")
        self.assertTrue(torch.allclose(pi.sum(dim=-1), torch.ones(self.batch_size, self.seq_length)),
                        "Mixing coefficients should sum to 1.")


class TestGMMOutput(unittest.TestCase):
    def setUp(self):
        self.hidden_dim = 512
        self.output_dim = 256
        self.num_mixtures = 10
        self.batch_size = 4
        self.seq_length = 7
        self.model = GMMOutput(self.hidden_dim, self.output_dim, self.num_mixtures)

    def test_output_shapes(self):
        x = torch.randn(self.batch_size, self.seq_length, self.hidden_dim)
        means, scales, weights = self.model(x)
        self.assertEqual(means.shape, (self.batch_size, self.seq_length, self.num_mixtures, self.output_dim))
        self.assertEqual(scales.shape, (self.batch_size, self.seq_length, self.num_mixtures, self.output_dim))
        self.assertEqual(weights.shape, (self.batch_size, self.seq_length, self.num_mixtures))

    def test_constraints(self):
        x = torch.randn(self.batch_size, self.seq_length, self.hidden_dim)
        _, scales, weights = self.model(x)
        self.assertTrue(torch.all(scales > 0), "All scale values should be positive.")
        self.assertTrue(torch.allclose(weights.sum(dim=-1), torch.ones(self.batch_size, self.seq_length)),
                        "Mixing coefficients should sum to 1.")


if __name__ == "__main__":
    print("GMM Models Module. Define and test GMM parameter prediction models.")
    unittest.main()
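At inference time, the decoder output for each position would be drawn from the predicted mixture: pick a component according to pi, then sample from that component's Gaussian. A minimal sketch of this under the shapes used above; the helper name sample_from_gmm is illustrative, not from the gist:

import torch

def sample_from_gmm(mu, sigma, pi):
    # mu, sigma: (batch, seq, K, output_dim); pi: (batch, seq, K)
    k = torch.distributions.Categorical(probs=pi).sample()   # (batch, seq) component indices
    idx = k[..., None, None].expand(-1, -1, 1, mu.size(-1))  # index for gather along the K axis
    mu_k = mu.gather(2, idx).squeeze(2)                      # (batch, seq, output_dim)
    sigma_k = sigma.gather(2, idx).squeeze(2)
    return mu_k + sigma_k * torch.randn_like(mu_k)           # draw from the chosen Gaussian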