Last active
July 14, 2023 13:41
-
-
Save tmorioka/e4d127cacd5a814c5b36737297d98449 to your computer and use it in GitHub Desktop.
Unofficial re-implementation of the paper “'Low-Resource' Text Classification: A Parameter-Free Classification Method with Compressors”, plus a rough example of classification on the livedoor news corpus (ldcc) dataset.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import gzip
import heapq
from collections import Counter
class PFCC:
    """Parameter-free k-NN text classifier using gzip compression distance.

    Unofficial re-implementation of the compressor-based classifier from
    "'Low-Resource' Text Classification: A Parameter-Free Classification
    Method with Compressors". The distance between two texts is the
    normalized compression distance (NCD) computed from gzip-compressed
    lengths, and a query is labelled by majority vote over its
    ``n_neighbors`` closest training texts.
    """

    def __init__(self, n_neighbors: int = 5):
        """Create an untrained classifier.

        Args:
            n_neighbors: Number of nearest training texts that vote on a label.

        Raises:
            ValueError: If ``n_neighbors`` is smaller than 1 (no neighbor could
                vote, so ``predict`` would have nothing to choose from).
        """
        if n_neighbors < 1:
            raise ValueError(f"n_neighbors must be >= 1, got {n_neighbors}.")
        self._n_neighbors = n_neighbors
        self._X_train = None  # training texts
        self._C_train = None  # cached compressed length of each training text
        self._y_train = None  # label of each training text
        self._N_train = None  # number of training texts

    @staticmethod
    def _compressed_length(text: str) -> int:
        """Return the byte size of the gzip-compressed UTF-8 encoding of *text*."""
        return len(gzip.compress(text.encode()))

    def fit(self, X: list[str], y: list[str]):
        """Memorize the training texts and labels.

        k-NN builds no model; only the compressed length of each training
        text is precomputed so that ``predict`` does not redo it per query.

        Args:
            X: Training texts (any iterable; it is materialized into a list).
            y: Label of each text in ``X``, in the same order.

        Returns:
            ``self``, so calls can be chained: ``PFCC().fit(X, y).predict(Q)``.

        Raises:
            ValueError: If ``X`` is empty or ``X`` and ``y`` differ in length.
        """
        # Copy into lists: a generator would otherwise be stored exhausted,
        # and later mutation of the caller's lists would not leak in.
        X = list(X)
        y = list(y)
        if not X:
            raise ValueError("Training data must contain at least one sample.")
        if len(X) != len(y):
            raise ValueError(
                f"X and y must have the same length, got {len(X)} and {len(y)}."
            )
        self._X_train = X
        self._C_train = [self._compressed_length(xi) for xi in X]
        self._y_train = y
        self._N_train = len(X)
        return self

    def predict(self, X: list[str]) -> list[str]:
        """Predict a label for each text in ``X``.

        For each query ``xi`` and training text ``xj``, the distance is
        ``(C(xi + " " + xj) - min(C(xi), C(xj))) / max(C(xi), C(xj))``,
        where ``C`` is the gzip-compressed length. The most frequent label
        among the ``n_neighbors`` closest training texts wins. A tie in the
        vote goes to the label whose member appears first in distance order,
        and equal distances keep training order.

        Args:
            X: Texts to classify.

        Returns:
            One predicted label per input text, in input order.

        Raises:
            RuntimeError: If ``fit`` has not been called.
        """
        if not self._is_trained():
            raise RuntimeError("Please call 'fit(X: list[str], y: list[str])' method in advance.")
        y_pred = []
        for xi in X:
            ci = self._compressed_length(xi)
            distance = [
                (self._compressed_length(" ".join([xi, xj])) - min(ci, cj)) / max(ci, cj)
                for xj, cj in zip(self._X_train, self._C_train)
            ]
            # nsmallest == sorted(...)[:k] (stable), but O(N log k) instead of O(N log N).
            nearest = heapq.nsmallest(
                self._n_neighbors, range(self._N_train), key=distance.__getitem__
            )
            votes = Counter(self._y_train[j] for j in nearest)
            # most_common keeps first-inserted order on ties, i.e. the closest label.
            label, _ = votes.most_common(1)[0]
            y_pred.append(label)
        return y_pred

    def _is_trained(self) -> bool:
        """Return True once ``fit`` has populated every training attribute."""
        return all(
            attr is not None
            for attr in (self._X_train, self._C_train, self._y_train, self._N_train)
        )
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import glob | |
import random | |
import os | |
from sklearn.model_selection import train_test_split | |
from sklearn.metrics import classification_report | |
from sklearn.linear_model import SGDClassifier | |
from sklearn.feature_extraction.text import CountVectorizer | |
from pfcc import PFCC | |
def _read_ldcc_corpus(prefix: str = "./", n_per_label: int = 100, seed=None):
    """Load a class-balanced sample of the livedoor news corpus and split it.

    For each of the nine category directories under ``<prefix>/text/``, up to
    ``n_per_label`` ``*.txt`` files (excluding ``LICENSE.txt``) are sampled at
    random. The result is split 50/50 into stratified train and test sets.

    Args:
        prefix: Directory containing the corpus' ``text/`` folder.
        n_per_label: Maximum number of documents sampled per category.
        seed: Optional int seed for the per-category file sampling. ``None``
            keeps the original behavior of using the module-level ``random``
            state; pass an int to make the sample reproducible.

    Returns:
        ``(X_train, y_train, X_test, y_test)``. Note that this order differs
        from the ``(X_train, X_test, y_train, y_test)`` order returned by
        ``train_test_split``.
    """
    labels = (
        "dokujo-tsushin",
        "it-life-hack",
        "kaden-channel",
        "livedoor-homme",
        "movie-enter",
        "peachy",
        "smax",
        "sports-watch",
        "topic-news",
    )
    # The `random` module and a `random.Random` instance both expose shuffle().
    rng = random if seed is None else random.Random(seed)
    X = []
    y = []
    for label in labels:
        pattern = os.path.join(prefix, "text", label, "*.txt")
        # glob order is filesystem-dependent; sort so that a fixed seed
        # selects the same files on every machine.
        filenames = sorted(
            filename
            for filename in glob.glob(pattern)
            if os.path.basename(filename) != "LICENSE.txt"
        )
        rng.shuffle(filenames)
        for filename in filenames[:n_per_label]:
            # Read explicitly as UTF-8 (assumed corpus encoding) rather than
            # the platform default, which is not UTF-8 on e.g. Windows.
            with open(filename, encoding="utf-8") as fi:
                # Drop the first two lines of each file (presumably the article
                # URL and timestamp -- TODO confirm against the corpus format).
                text = "".join(fi.readlines()[2:])
            X.append(text)
            y.append(label)
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.5, random_state=42, shuffle=True, stratify=y,
    )
    return X_train, y_train, X_test, y_test
def _print_report(title: str, y_true, y_pred):
    """Print a banner line with *title*, then sklearn's classification report."""
    print("=" * 40, title, "=" * 40)
    print(classification_report(y_true, y_pred))


def _benchmark_ldcc():
    """Compare PFCC with a linear baseline on the ldcc train/test split.

    Both classifiers are trained on the split returned by
    ``_read_ldcc_corpus()``, and a classification report is printed for each.
    """
    X_train, y_train, X_test, y_test = _read_ldcc_corpus()

    pfcc = PFCC(5).fit(X_train, y_train)
    _print_report("PFCC", y_test, pfcc.predict(X_test))

    # Token linear baseline: an online SVM (hinge-loss SGD) over character
    # 2-5-grams. analyzer="char" is required. CountVectorizer defaults to word
    # n-grams, and its default token pattern splits only on non-word characters,
    # so unspaced Japanese text would yield very long, rarely repeated tokens.
    vectorizer = CountVectorizer(analyzer="char", ngram_range=(2, 5))
    _X_train = vectorizer.fit_transform(X_train)
    _X_test = vectorizer.transform(X_test)
    # Seeded like the train/test split so that repeated runs are comparable.
    sgd_classifier = SGDClassifier(loss="hinge", random_state=42).fit(_X_train, y_train)
    _print_report("SGDClassifier(hinge)", y_test, sgd_classifier.predict(_X_test))


if __name__ == "__main__":
    _benchmark_ldcc()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment