Helper functions to initialise a Keras Embedding layer with pre-trained GloVe embeddings.
import os
import zipfile
from io import TextIOWrapper

import numpy as np
import tensorflow as tf

# Supported pre-trained GloVe embeddings: name -> (download URL, embedding dimension).
GLOVE_EMBEDDINGS = {
    'glove.6B.50d':    ('http://nlp.stanford.edu/data/glove.6B.zip', 50),
    'glove.6B.100d':   ('http://nlp.stanford.edu/data/glove.6B.zip', 100),
    'glove.6B.200d':   ('http://nlp.stanford.edu/data/glove.6B.zip', 200),
    'glove.6B.300d':   ('http://nlp.stanford.edu/data/glove.6B.zip', 300),
    'glove.42B.300d':  ('http://nlp.stanford.edu/data/glove.42B.300d.zip', 300),
    'glove.840B.300d': ('http://nlp.stanford.edu/data/glove.840B.300d.zip', 300),
}

def download_glove_embeddings_index(pretrained_name):
    """Download the GloVe archive (if not cached) and parse it into a word -> vector dict."""
    try:
        url, dimension = GLOVE_EMBEDDINGS[pretrained_name]
    except KeyError:
        raise ValueError(f"Unknown pretrained embedding name '{pretrained_name}'")
    # Download into the Keras cache (~/.keras/datasets) unless the archive is already there.
    archive_fname = os.path.basename(url)
    archive_fname = tf.keras.utils.get_file(archive_fname, url)
    # Read the embeddings file straight out of the zip archive, one "word v1 v2 ..." line at a time.
    embeddings_index = {}
    with zipfile.ZipFile(archive_fname) as z:
        with z.open(f'{pretrained_name}.txt') as f:
            f = TextIOWrapper(f, encoding='utf-8')
            for line in f:
                values = line.split()
                word = values[0]
                coefs = np.asarray(values[1:], dtype='float32')
                embeddings_index[word] = coefs
    return embeddings_index, dimension
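
# Illustrative sanity check (not part of the gist). Note that the first call downloads the
# archive, which is several hundred megabytes for the 6B set.
#
#     index, dim = download_glove_embeddings_index('glove.6B.50d')
#     print(dim)                  # 50
#     print(index['king'].shape)  # (50,)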

def get_glove_embeddings(tok, pretrained_name='glove.6B.300d'):
    """Build a GloVe coefficient matrix aligned with a fitted Keras Tokenizer."""
    embeddings_index, embeddings_dim = download_glove_embeddings_index(pretrained_name)
    # Respect the tokenizer's vocabulary cap if one was set, else use the full vocabulary.
    if tok.num_words:
        max_num_words = tok.num_words
    else:
        max_num_words = len(tok.word_index) + 1
    word_index = tok.word_index
    num_words = min(max_num_words, len(word_index) + 1)
    # Initialise all rows with small random values; rows for words found in GloVe are
    # overwritten with their pre-trained vectors below.
    embeddings_matrix = np.random.uniform(low=-0.25, high=0.25, size=(num_words, embeddings_dim))
    for word, i in word_index.items():
        if i >= max_num_words:
            continue
        embedding_vector = embeddings_index.get(word)
        if embedding_vector is not None:
            embeddings_matrix[i] = embedding_vector
    # return: (input_dim, output_dim, embeddings_matrix)
    return (num_words, embeddings_dim, embeddings_matrix)
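
A minimal end-to-end sketch of how the returned triple can feed a Keras Embedding layer. The example texts, num_words value and maxlen below are illustrative assumptions, not part of the gist.

# Usage sketch (illustrative): fit a tokenizer, build the GloVe matrix, wire it into an Embedding layer.
texts = ["the quick brown fox", "jumped over the lazy dog"]

tok = tf.keras.preprocessing.text.Tokenizer(num_words=20000)
tok.fit_on_texts(texts)

input_dim, output_dim, embeddings_matrix = get_glove_embeddings(tok, 'glove.6B.100d')

# Initialise the Embedding layer with the pre-trained matrix and freeze it.
embedding_layer = tf.keras.layers.Embedding(
    input_dim,
    output_dim,
    embeddings_initializer=tf.keras.initializers.Constant(embeddings_matrix),
    trainable=False,
)

sequences = tok.texts_to_sequences(texts)
padded = tf.keras.preprocessing.sequence.pad_sequences(sequences, maxlen=10)
embedded = embedding_layer(padded)  # shape: (2, 10, output_dim)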