Last active: November 4, 2020 20:05
Revisions
- santhalakshminarayana revised this gist on Nov 4, 2020: 1 changed file with 1 addition and 1 deletion.
- santhalakshminarayana revised this gist on Jan 6, 2020: 1 changed file with 1 addition and 1 deletion.
- santhalakshminarayana created this gist on Jan 6, 2020.
Full contents of the gist (59 lines added):

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

# `device`, `X`, `Y`, `tot_seq` and `vocab_len` are assumed to be defined by the
# data-preparation code that precedes this gist.

def get_batches_x(tot_seq, batch_size):
    # Shuffle the sequence indices and yield mini-batches of (input, target) pairs.
    ind = np.random.permutation(tot_seq).tolist()
    for i in range(0, tot_seq, batch_size):
        batch_ids = ind[i:i+batch_size]
        yield X[batch_ids], Y[batch_ids]

class Quote_Generator(nn.Module):
    def __init__(self, embed_size, hidden_size, vocab_len):
        super(Quote_Generator, self).__init__()
        self.hidden_size = hidden_size
        self.lstm = nn.LSTM(embed_size, hidden_size, batch_first = True).to(device)
        self.dropout = nn.Dropout(0.4)
        # The dense layer reads the LSTM outputs of all 5 time steps at once.
        self.dense = nn.Linear(hidden_size*5, vocab_len).to(device)

    def forward(self, x, prev_state):
        # Thread the previous (hidden, cell) state through the LSTM so it
        # carries across batches within an epoch.
        output, state = self.lstm(x, prev_state)
        output = self.dropout(output)
        # Flatten (batch, seq_len=5, hidden_size) -> (batch, hidden_size*5).
        logits = self.dense(output.reshape(-1, self.hidden_size*5))
        return logits, state

    def zero_states(self, batch_size):
        # Initial (hidden, cell) states for the single-layer LSTM.
        return (torch.zeros(1, batch_size, self.hidden_size).to(device),
                torch.zeros(1, batch_size, self.hidden_size).to(device))

def entropy_loss(y, y_hat):
    # Cross-entropy between one-hot targets `y` and raw logits `y_hat`.
    y_hat = F.softmax(y_hat, dim = 1)
    ll = - (y * torch.log(y_hat))
    return torch.sum(ll, dim = 1).mean().to(device)

def qt_train(qt_gen):
    epochs = 101
    batch_size = 4096
    losses = []
    optimizer = torch.optim.Adam(qt_gen.parameters(), lr=0.001)
    for epoch in tqdm(range(epochs)):
        batches = get_batches_x(tot_seq, batch_size)
        h_h, h_c = qt_gen.zero_states(batch_size)
        for x, y in batches:
            qt_gen.train()
            optimizer.zero_grad()
            x = torch.tensor(x).float().to(device)
            y = torch.tensor(y).long().to(device)
            # The last batch can be smaller than batch_size; resize the state to match.
            if x.shape[0] != h_h.shape[1]:
                h_h, h_c = qt_gen.zero_states(x.shape[0])
            logits, (h_h, h_c) = qt_gen(x, (h_h, h_c))
            loss = entropy_loss(y, logits)
            # detach() is not in-place: reassign so back-propagation is
            # truncated at the current step.
            h_h = h_h.detach()
            h_c = h_c.detach()
            loss.backward()
            _ = nn.utils.clip_grad_norm_(qt_gen.parameters(), 5)
            optimizer.step()
            losses.append(loss.item())
        if epoch % 10 == 0:
            print(f"Epoch : {epoch} ----> Loss : {np.array(losses).mean()}")
            losses = []

embed_size = 128
hidden_size = 64
qt_gen = Quote_Generator(embed_size, hidden_size, vocab_len).to(device)
qt_train(qt_gen)
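The script relies on a handful of globals (`device`, `X`, `Y`, `tot_seq`, `vocab_len`) that are prepared elsewhere in the accompanying post. Below is a minimal sketch of what that setup could look like, with hypothetical shapes and sizes inferred from the model above (5-step sequences of 128-dimensional embeddings, one-hot targets over the vocabulary); this is a placeholder, not the author's actual data pipeline:

import numpy as np
import torch

# Hypothetical placeholder setup -- the real arrays come from the post's
# word-embedding / data-preparation step, not from this gist.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

vocab_len = 1000                # assumed vocabulary size
tot_seq = 8192                  # assumed number of training sequences
seq_len, embed_size = 5, 128    # matches the model: 5 time steps of 128-d embeddings
X = np.random.rand(tot_seq, seq_len, embed_size).astype(np.float32)
Y = np.eye(vocab_len, dtype=np.int64)[np.random.randint(0, vocab_len, tot_seq)]  # one-hot next-word targets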