Created
May 23, 2013 07:22
-
-
Save oddskool/5633266 to your computer and use it in GitHub Desktop.
Example for learning from a text stream with sklearn. Inspired by http://stackoverflow.com/questions/12460077/possibility-to-apply-online-algorithms-on-big-data-files-with-sklearn
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from collections import Counter, defaultdict
import re

import numpy as np
from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction import FeatureHasher
from sklearn.linear_model.stochastic_gradient import SGDClassifier
from sklearn.externals import joblib
def tokens(doc):
    """Yield lower-cased word tokens from doc.

    Tokenization is a simple regex word-split; for a more principled
    approach see CountVectorizer or TfidfVectorizer.
    """
    # finditer streams matches lazily, so this stays a generator just
    # like the original findall-based version.
    return (match.group().lower() for match in re.finditer(r"\w+", doc))
def token_freqs(doc):
    """Return a dict mapping each token in doc to its frequency.

    Counter is a dict subclass, so the result is a drop-in replacement
    for the previous defaultdict-based accumulation and is accepted by
    FeatureHasher.transform unchanged.
    """
    return Counter(tokens(doc))
def chunker(seq, size):
    """Iterate over seq by chunks of at most `size` items.

    Simulates reading from a stream by lazily yielding successive
    slices of an in-memory sequence.

    seq must support len() and slicing; the final chunk may be shorter
    than `size`. Uses range (not the Python-2-only xrange) so the
    generator works under both Python 2 and Python 3.
    """
    for pos in range(0, len(seq), size):
        yield seq[pos:pos + size]
categories = [
    'alt.atheism',
    'comp.graphics',
    'comp.sys.ibm.pc.hardware',
    'misc.forsale',
    'rec.autos',
    'sci.space',
    'talk.religion.misc',
]
dataset = fetch_20newsgroups(subset='train', categories=categories)
# list(...) so chunker can len()/slice the pairs on Python 3 as well,
# where zip returns an iterator instead of a list.
classif_data = list(zip(dataset.data, dataset.target))
# partial_fit needs the full label set up front, before the first chunk.
classes = np.array(list(set(dataset.target)))
hasher = FeatureHasher()
classifier = SGDClassifier()
for i, chunk in enumerate(chunker(classif_data, 100)):
    messages, topics = zip(*chunk)
    # Hash token-frequency dicts into a fixed-width sparse matrix so the
    # feature space is stable across chunks without a fitted vocabulary.
    X = hasher.transform(token_freqs(msg) for msg in messages)
    y = np.array(topics)
    # Bug fix: the original passed the raw `topics` tuple and left the
    # computed `y` array unused; feed the array to partial_fit.
    classifier.partial_fit(X,
                           y,
                           classes=classes)
    if i % 100 == 0:
        # dump model to be able to monitor quality and later
        # analyse convergence externally
        joblib.dump(classifier, 'model_%04d.pkl' % i)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment