Helper functions to initialise a Keras Embedding layer with pre-trained GloVe embeddings.
import os
import zipfile
from io import TextIOWrapper

import numpy as np
import tensorflow as tf

# Mapping of pre-trained GloVe names to (download URL, embedding dimension).
GLOVE_EMBEDDINGS = {
    'glove.6B.50d'   : ('http://nlp.stanford.edu/data/glove.6B.zip', 50),
    'glove.6B.100d'  : ('http://nlp.stanford.edu/data/glove.6B.zip', 100),
    'glove.6B.200d'  : ('http://nlp.stanford.edu/data/glove.6B.zip', 200),
    'glove.6B.300d'  : ('http://nlp.stanford.edu/data/glove.6B.zip', 300),
    'glove.42B.300d' : ('http://nlp.stanford.edu/data/glove.42B.300d.zip', 300),
    'glove.840B.300d': ('http://nlp.stanford.edu/data/glove.840B.300d.zip', 300),
}

def download_glove_embeddings_index(pretrained_name):
    """Download the GloVe archive (cached by Keras) and parse it into a
    {word: vector} index. Returns (embeddings_index, dimension)."""
    try:
        url, dimension = GLOVE_EMBEDDINGS[pretrained_name]
    except KeyError:
        raise ValueError(f"Unknown pretrained embedding name '{pretrained_name}'")
    archive_fname = os.path.basename(url)
    archive_path = tf.keras.utils.get_file(archive_fname, url)
    embeddings_index = {}
    with zipfile.ZipFile(archive_path) as z:
        with z.open(f'{pretrained_name}.txt') as f:
            f = TextIOWrapper(f, encoding='utf-8')
            for line in f:
                # Take the last `dimension` fields as the vector so that tokens
                # containing spaces (present e.g. in glove.840B.300d) parse correctly.
                values = line.rstrip().split(' ')
                word = ' '.join(values[:-dimension])
                coefs = np.asarray(values[-dimension:], dtype='float32')
                embeddings_index[word] = coefs
    return embeddings_index, dimension

def get_glove_embeddings(tok, pretrained_name='glove.6B.300d'):
    """Build an embedding matrix for the vocabulary of a fitted Keras
    Tokenizer `tok`. Words not covered by GloVe keep a random initialisation."""
    embeddings_index, embeddings_dim = download_glove_embeddings_index(pretrained_name)
    if tok.num_words:
        max_num_words = tok.num_words
    else:
        max_num_words = len(tok.word_index) + 1
    word_index = tok.word_index
    num_words = min(max_num_words, len(word_index) + 1)
    # Row 0 is left for the padding index; rows for out-of-vocabulary words stay random.
    embeddings_matrix = np.random.uniform(low=-0.25, high=0.25, size=(num_words, embeddings_dim))
    for word, i in word_index.items():
        if i >= max_num_words:
            continue
        embedding_vector = embeddings_index.get(word)
        if embedding_vector is not None:
            embeddings_matrix[i] = embedding_vector
    # return: (input_dim, output_dim, embeddings_matrix)
    return (num_words, embeddings_dim, embeddings_matrix)
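A minimal usage sketch of how the returned tuple can feed a frozen Embedding layer; the corpus, the `texts`/`tok`/`embedding_layer` names, and the choice of 'glove.6B.100d' are illustrative assumptions, not part of the gist.

# Example usage (sketch): build a frozen Embedding layer from the returned matrix.
from tensorflow.keras.preprocessing.text import Tokenizer

texts = ["the quick brown fox", "jumps over the lazy dog"]  # illustrative corpus
tok = Tokenizer(num_words=20000)
tok.fit_on_texts(texts)

input_dim, output_dim, embeddings_matrix = get_glove_embeddings(tok, 'glove.6B.100d')

embedding_layer = tf.keras.layers.Embedding(
    input_dim=input_dim,
    output_dim=output_dim,
    embeddings_initializer=tf.keras.initializers.Constant(embeddings_matrix),
    trainable=False,  # keep the pre-trained GloVe vectors fixed during training
)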