spaCy model builder
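Builds a blank spaCy model for a given language from a word frequency table, optional Brown clusters, and optional word vectors, then serializes it to an output directory that `spacy.load()` can read.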
# coding: utf8
from __future__ import unicode_literals

import gzip
import math
from ast import literal_eval
from pathlib import Path

import numpy
import plac
from tqdm import tqdm
from preshed.counter import PreshCounter

import spacy
from spacy.compat import fix_text
from spacy.vectors import Vectors
from spacy.util import prints, ensure_path
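
# plac builds the command-line interface from the annotations below:
# lang, output_dir and freqs_loc are positional arguments, the rest
# become options (prune_vectors is abbreviated to -V).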
@plac.annotations(
    lang=("model language", "positional", None, str),
    output_dir=("model output directory", "positional", None, Path),
    freqs_loc=("location of word frequencies data", "positional",
               None, Path),
    clusters_loc=("optional: location of clusters data",
                  "option", None, str),
    vectors_loc=("optional: location of vectors data",
                 "option", None, str),
    prune_vectors=("optional: number of vectors to prune to",
                   "option", "V", int)
)
def main(lang, output_dir, freqs_loc, clusters_loc=None, vectors_loc=None,
         prune_vectors=-1):
    if not freqs_loc.exists():
        prints(freqs_loc, title="Can't find word frequencies data", exits=1)
    clusters_loc = ensure_path(clusters_loc)
    vectors_loc = ensure_path(vectors_loc)
    print('Reading freqs...')
    probs, oov_prob = read_freqs(freqs_loc)
    print('Reading vectors...')
    # Fall back to an empty vector table when no vectors file is given,
    # so the tuple unpacking and the len() check in create_model() stay valid.
    if vectors_loc is not None:
        vectors_data, vector_keys = read_vectors(vectors_loc)
    else:
        vectors_data, vector_keys = numpy.zeros((0, 0), dtype='f'), []
    print('Reading clusters...')
    clusters = read_clusters(clusters_loc) if clusters_loc else {}
    nlp = create_model(lang, probs, oov_prob, clusters, vectors_data,
                       vector_keys, prune_vectors)
    if not output_dir.exists():
        output_dir.mkdir()
    nlp.to_disk(output_dir)
    return nlp
def create_model(lang, probs, oov_prob, clusters, vectors_data, vector_keys,
                 prune_vectors):
    print('Creating model...')
    nlp = spacy.blank(lang)
    for lexeme in nlp.vocab:
        lexeme.rank = 0
    lex_added = 0
    words_by_freq = sorted(probs.items(), key=lambda item: item[1],
                           reverse=True)
    for i, (word, prob) in enumerate(tqdm(words_by_freq)):
        lexeme = nlp.vocab[word]
        lexeme.rank = i
        lexeme.prob = prob
        lexeme.is_oov = False
        # Decode as a little-endian string, so that we can do & 15 to get
        # the first 4 bits. See _parse_features.pyx
        if word in clusters:
            lexeme.cluster = int(clusters[word][::-1], 2)
        else:
            lexeme.cluster = 0
        lex_added += 1
    nlp.vocab.cfg.update({'oov_prob': oov_prob})
    if len(vectors_data):
        nlp.vocab.vectors = Vectors(data=vectors_data, keys=vector_keys)
    if prune_vectors >= 1:
        nlp.vocab.prune_vectors(prune_vectors)
    vec_added = len(nlp.vocab.vectors)
    prints("{} entries, {} vectors".format(lex_added, vec_added),
           title="Successfully compiled vocab")
    return nlp
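
# read_vectors() expects the word2vec text format: a header line with the
# number of rows and the vector width, followed by one whitespace-separated
# "word v1 v2 ..." entry per line.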
def read_vectors(vectors_loc):
    with vectors_loc.open() as f:
        shape = tuple(int(size) for size in f.readline().split())
        vectors_data = numpy.zeros(shape=shape, dtype='f')
        vectors_keys = []
        for i, line in enumerate(tqdm(f)):
            pieces = line.split()
            word = pieces.pop(0)
            vectors_data[i] = numpy.array([float(val) for val in pieces],
                                          dtype='f')
            vectors_keys.append(word)
    return vectors_data, vectors_keys
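
# read_freqs() makes two passes over the frequencies file: the first
# accumulates counts into a PreshCounter so it can fit the smoothing model,
# the second converts each surviving entry into a smoothed log-probability.
# Expected line format: "freq<TAB>doc_freq<TAB>key", where key is a Python
# string literal (hence literal_eval).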
def read_freqs(freqs_path, max_length=100, min_doc_freq=5, min_freq=50):
    counts = PreshCounter()
    total = 0
    freqs_file = check_unzip(freqs_path)
    for i, line in enumerate(freqs_file):
        freq, doc_freq, key = line.rstrip().split('\t', 2)
        freq = int(freq)
        counts.inc(i + 1, freq)
        total += freq
    counts.smooth()
    log_total = math.log(total)
    freqs_file = check_unzip(freqs_path)
    probs = {}
    for line in tqdm(freqs_file):
        freq, doc_freq, key = line.rstrip().split('\t', 2)
        doc_freq = int(doc_freq)
        freq = int(freq)
        if doc_freq >= min_doc_freq and freq >= min_freq \
                and len(key) < max_length:
            word = literal_eval(key)
            smooth_count = counts.smoother(int(freq))
            probs[word] = math.log(smooth_count) - log_total
    oov_prob = math.log(counts.smoother(0)) - log_total
    return probs, oov_prob
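
# read_clusters() expects Brown cluster output: one whitespace-separated
# "bitstring word frequency" triple per line.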
def read_clusters(clusters_path):
    clusters = {}
    with clusters_path.open() as f:
        for line in tqdm(f):
            try:
                cluster, word, freq = line.split()
                word = fix_text(word)
            except ValueError:
                continue
            # If the clusterer has only seen the word a few times, its
            # cluster is unreliable.
            if int(freq) >= 3:
                clusters[word] = cluster
            else:
                clusters[word] = '0'
    # Expand clusters with re-casing
    for word, cluster in list(clusters.items()):
        if word.lower() not in clusters:
            clusters[word.lower()] = cluster
        if word.title() not in clusters:
            clusters[word.title()] = cluster
        if word.upper() not in clusters:
            clusters[word.upper()] = cluster
    return clusters
def check_unzip(file_path):
    file_path_str = file_path.as_posix()
    if file_path_str.endswith('gz'):
        # Open in text mode so gzipped input yields str lines on Python 3.
        return gzip.open(file_path_str, 'rt')
    else:
        return file_path.open()
if __name__ == '__main__':
    plac.call(main)
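
For reference, a minimal usage sketch. The module name `build_model` and all input paths below are hypothetical placeholders; only the three positional parameters and the `-V` abbreviation for `prune_vectors` are fixed by the annotations above. From a shell, the equivalent invocation would be along the lines of `python build_model.py ru /tmp/ru_model freqs.txt.gz -V 20000`, with plac deriving the remaining option flags from the parameter names.

# A minimal sketch, assuming this gist is saved as build_model.py and the
# input files below actually exist; all paths are hypothetical.
from pathlib import Path

from build_model import main

nlp = main(
    'ru',                         # language code passed to spacy.blank()
    Path('/tmp/ru_model'),        # output directory for nlp.to_disk()
    Path('freqs.txt.gz'),         # "freq<TAB>doc_freq<TAB>key" lines
    clusters_loc='clusters.txt',  # optional Brown clusters file
    vectors_loc='vectors.txt',    # optional word2vec-format text vectors
    prune_vectors=20000,          # prune the vector table to 20k rows
)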