Created September 2, 2021 14:33
Data loaders for DLVU assignment 3B (recurrent neural nets)
import wget, os, gzip, pickle, random, re, sys

IMDB_URL = 'http://dlvu.github.io/data/imdb.{}.pkl.gz'
IMDB_FILE = 'imdb.{}.pkl.gz'

PAD, START, END, UNK = '.pad', '.start', '.end', '.unk'
def load_imdb(final=False, val=5000, seed=0, voc=None, char=False):
    """Load the IMDB sentiment data (word- or character-level), downloading it if necessary.

    Returns (x_train, y_train), (x_val, y_val) -- or the test set if final=True --
    followed by the vocabularies (i2w, w2i) and the number of classes (2).
    """

    cst = 'char' if char else 'word'

    imdb_url = IMDB_URL.format(cst)
    imdb_file = IMDB_FILE.format(cst)

    if not os.path.exists(imdb_file):
        wget.download(imdb_url)

    with gzip.open(imdb_file) as file:
        sequences, labels, i2w, w2i = pickle.load(file)

    if voc is not None and voc < len(i2w):
        # Truncate the vocabulary to the first `voc` tokens (assumed to be sorted by
        # frequency); any token outside this range is mapped to '.unk'.
        nw_sequences = {}

        i2w = i2w[:voc]
        w2i = {w: i for i, w in enumerate(i2w)}

        mx, unk = voc, w2i['.unk']
        for key, seqs in sequences.items():
            nw_sequences[key] = []
            for seq in seqs:
                seq = [s if s < mx else unk for s in seq]
                nw_sequences[key].append(seq)

        sequences = nw_sequences

    if final:
        return (sequences['train'], labels['train']), (sequences['test'], labels['test']), (i2w, w2i), 2

    # Make a validation split
    random.seed(seed)

    x_train, y_train = [], []
    x_val, y_val = [], []

    val_ind = set(random.sample(range(len(sequences['train'])), k=val))
    for i, (s, l) in enumerate(zip(sequences['train'], labels['train'])):
        if i in val_ind:
            x_val.append(s)
            y_val.append(l)
        else:
            x_train.append(s)
            y_train.append(l)

    return (x_train, y_train), \
           (x_val, y_val), \
           (i2w, w2i), 2
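
# A minimal usage sketch (assumes the IMDB pickle downloads correctly; the variable
# names below are illustrative, not part of the assignment):
#
#   (x_train, y_train), (x_val, y_val), (i2w, w2i), num_cls = load_imdb(final=False)
#   print([i2w[w] for w in x_train[0]], y_train[0])  # first review as tokens, with its label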
def gen_sentence(sent, g):
    # Repeatedly expand the leftmost nonterminal (a '_'-prefixed symbol) using a random
    # production from grammar g, until no nonterminals remain.
    symb = '_[a-z]*'

    while True:
        match = re.search(symb, sent)
        if match is None:
            return sent

        s = match.span()
        sent = sent[:s[0]] + random.choice(g[sent[s[0]:s[1]]]) + sent[s[1]:]
def gen_dyck(p):
    # Sample a balanced bracket (Dyck) string: with probability p open a new bracket,
    # otherwise close one, until all open brackets are closed.
    num_open = 1
    sent = '('
    while num_open > 0:
        if random.random() < p:
            sent += '('
            num_open += 1
        else:
            sent += ')'
            num_open -= 1

    return sent
def gen_ndfa(p):
    # Sample a string from a simple nondeterministic automaton: zero or more repetitions
    # of one of three words, wrapped in 's' characters.
    word = random.choice(['abc!', 'uvw!', 'klm!'])

    s = ''
    while True:
        if random.random() < p:
            return 's' + s + 's'
        else:
            s += word
def load_brackets(n=50_000, seed=0):
    return load_toy(n, char=True, seed=seed, name='dyck')

def load_ndfa(n=50_000, seed=0):
    return load_toy(n, char=True, seed=seed, name='ndfa')
def load_toy(n=50_000, char=True, seed=0, name='lang'):
    """Generate a toy dataset of n sentences from the language `name` ('lang', 'dyck' or 'ndfa').

    Returns the sentences as lists of token indices, sorted by length, together with the
    vocabularies (i2t, t2i).
    """

    random.seed(seed)

    if name == 'lang':
        sent = '_s'

        toy = {
            '_s': ['_s _adv', '_np _vp', '_np _vp _prep _np', '_np _vp ( _prep _np )', '_np _vp _con _s', '_np _vp ( _con _s )'],
            '_adv': ['briefly', 'quickly', 'impatiently'],
            '_np': ['a _noun', 'the _noun', 'a _adj _noun', 'the _adj _noun'],
            '_prep': ['on', 'with', 'to'],
            '_con': ['while', 'but'],
            '_noun': ['mouse', 'bunny', 'cat', 'dog', 'man', 'woman', 'person'],
            '_vp': ['walked', 'walks', 'ran', 'runs', 'goes', 'went'],
            '_adj': ['short', 'quick', 'busy', 'nice', 'gorgeous']
        }

        sentences = [gen_sentence(sent, toy) for _ in range(n)]
        sentences.sort(key=lambda s: len(s))

    elif name == 'dyck':
        sentences = [gen_dyck(7./16.) for _ in range(n)]
        sentences.sort(key=lambda s: len(s))

    elif name == 'ndfa':
        sentences = [gen_ndfa(1./4.) for _ in range(n)]
        sentences.sort(key=lambda s: len(s))

    else:
        raise Exception(f'Language name {name} not recognized.')

    tokens = set()
    for s in sentences:
        if char:
            for c in s:
                tokens.add(c)
        else:
            for w in s.split():
                tokens.add(w)

    i2t = [PAD, START, END, UNK] + list(tokens)
    t2i = {t: i for i, t in enumerate(i2t)}

    sequences = []
    for s in sentences:
        if char:
            tok = list(s)
        else:
            tok = s.split()

        sequences.append([t2i[t] for t in tok])

    return sequences, (i2t, t2i)
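
For reference, a minimal usage sketch for the toy loaders above (assuming the functions in this file are in scope; the printed strings are illustrative):

x, (i2t, t2i) = load_brackets(n=50_000)
print(''.join(i2t[t] for t in x[0]))  # shortest bracket string, e.g. '()'

x, (i2t, t2i) = load_ndfa(n=50_000)
print(''.join(i2t[t] for t in x[0]))  # shortest ndfa string, e.g. 'ss'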