Created
May 4, 2015 00:21
-
-
Save aron-bordin/0a3c13a508b5246702be to your computer and use it in GitHub Desktop.
Sentiment NLP with Mallet
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package com.aronbordin;

import ca.uwo.csd.ai.nlp.kernel.LinearKernel;
import ca.uwo.csd.ai.nlp.mallet.libsvm.SVMClassifierTrainer;

import cc.mallet.classify.*;
import cc.mallet.pipe.*;
import cc.mallet.pipe.iterator.CsvIterator;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.util.Randoms;

import java.io.*;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.regex.Pattern;
public class Sentiment { | |
protected Pipe pipe; | |
protected InstanceList trainInstances; | |
protected static final String FILE_TRAIN_X = "data/train_labeled.tsv"; | |
protected static final String FILE_TEST = "data/test_data.tsv"; | |
protected static final String FILE_TRAIN_X_BIN = "data/train_labeled.bin"; | |
protected static final String FILE_CLASSIFIER = "data/classifier.bin"; | |
protected static final String FILE_TEST_OUT = "data/test.out"; | |
Sentiment() { | |
pipe = buildPipe(); | |
} | |
protected void importFile() { | |
CsvIterator iter = null; | |
try { | |
iter = new CsvIterator(FILE_TRAIN_X, "\"(\\w+)\"\\s+(\\d)\\s+(.*)", 3, 2, 1); | |
} catch (FileNotFoundException e) { | |
e.printStackTrace(); | |
} | |
trainInstances = new InstanceList(pipe); | |
trainInstances.addThruPipe(iter); | |
} | |
private void saveFile() { | |
trainInstances.save(new File(FILE_TRAIN_X_BIN)); | |
} | |
protected Pipe buildPipe() { | |
ArrayList<Pipe> pipeList = new ArrayList<Pipe>(); | |
pipeList.add(new Input2CharSequence("UTF-8")); | |
Pattern patternToken = Pattern.compile("[\\p{L}\\p{N}_]+"); | |
pipeList.add(new CharSequence2TokenSequence(patternToken)); | |
pipeList.add(new TokenSequenceLowercase()); | |
pipeList.add(new TokenSequenceRemoveStopwords(false, false)); | |
// pipeList.add(new TokenSequenceRemoveNonAlpha()); | |
pipeList.add(new TokenSequence2FeatureSequence()); | |
pipeList.add(new Target2Label()); | |
pipeList.add(new FeatureSequence2FeatureVector()); | |
return new SerialPipes(pipeList); | |
} | |
protected void train() { | |
InstanceList[] splited_data = trainInstances.split( | |
new Randoms(), | |
new double[]{0.7, 0.3} | |
); | |
ClassifierTrainer trainer; | |
Classifier classifier; | |
Trial trial; | |
/* System.out.println("Training with MaxEntTrainer..."); | |
trainer = new MaxEntTrainer(); | |
classifier = trainer.train(splited_data[0]); | |
System.out.print("Done! "); | |
trial = new Trial(classifier, splited_data[1]); | |
System.out.println("Accuracy: " + trial.getAccuracy()); | |
System.out.println("Training with NaiveBayesTrainer..."); | |
trainer = new NaiveBayesTrainer(); | |
classifier = trainer.train(splited_data[0]); | |
System.out.print("Done! "); | |
trial = new Trial(classifier, splited_data[1]); | |
System.out.println("Accuracy: " + trial.getAccuracy());*/ | |
System.out.println("Training with SVMClassifierTrainer..."); | |
trainer = new SVMClassifierTrainer(new LinearKernel()); | |
classifier = trainer.train(splited_data[0]); | |
System.out.print("Done! "); | |
trial = new Trial(classifier, splited_data[1]); | |
System.out.println("Accuracy: " + trial.getAccuracy()); | |
} | |
protected void trainAndPredict() { | |
ClassifierTrainer trainer; | |
Classifier classifier = null; | |
try { | |
if (new File(FILE_CLASSIFIER).exists()) { | |
ObjectInputStream obj = new ObjectInputStream(new FileInputStream(new File(FILE_CLASSIFIER))); | |
classifier = (Classifier) obj.readObject(); | |
obj.close(); | |
} else { | |
// System.out.println("Training with MaxEnt..."); | |
// trainer = new MaxEntTrainer(); | |
// classifier = trainer.train(trainInstances); | |
System.out.println("Training with NaiveBayesTrainer..."); | |
trainer = new NaiveBayesTrainer(); | |
classifier = trainer.train(trainInstances); | |
ObjectOutputStream obj = new ObjectOutputStream(new FileOutputStream(new File(FILE_CLASSIFIER))); | |
obj.writeObject(classifier); | |
obj.close(); | |
} | |
} catch (IOException e) { | |
e.printStackTrace(); | |
} catch (ClassNotFoundException e) { | |
e.printStackTrace(); | |
} | |
CsvIterator iterTest = null; | |
try { | |
iterTest = new CsvIterator(FILE_TEST, "\"(\\w+)\"\\s+(.*)", 2, 0, 1); | |
} catch (FileNotFoundException e) { | |
e.printStackTrace(); | |
} | |
Iterator<Instance> testInstances = classifier.getInstancePipe().newIteratorFrom(iterTest); | |
FileWriter csv = null; | |
try { | |
csv = new FileWriter(new File(FILE_TEST_OUT)); | |
csv.append("\"id\",\"sentiment\"\n"); | |
while (testInstances.hasNext()) { | |
Instance i = testInstances.next(); | |
System.out.println(classifier.classify(i)); | |
csv.append("\"" + i.getName() + "\","); | |
csv.append(classifier.classify(i).getLabelVector().getBestLabel().toString()); | |
csv.append("\n"); | |
} | |
csv.flush(); | |
csv.close(); | |
} catch (IOException e) { | |
e.printStackTrace(); | |
} | |
} | |
protected void readFile() { | |
trainInstances = trainInstances.load(new File(FILE_TRAIN_X_BIN)); | |
} | |
public static void main(String a[]) { | |
System.out.println("Sentiment analysis\n\n\n"); | |
System.out.println("Reading data..."); | |
Sentiment snt = new Sentiment(); | |
if (new File(FILE_TRAIN_X_BIN).exists()) { | |
snt.readFile(); | |
} else { | |
System.out.println("\tVectoring labeled sentences..."); | |
snt.importFile(); | |
snt.saveFile(); | |
System.out.println("\tDone!"); | |
} | |
snt.train(); | |
System.out.println("Done!"); | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment