Created
December 12, 2016 05:23
-
-
Save dilipbobby/bb38acae7d6ca33ba49130a23e9b7e0b to your computer and use it in GitHub Desktop.
LatentDirichletAllocation test code.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import com.datumbox.framework.common.Configuration; | |
import com.datumbox.framework.common.dataobjects.Dataframe; | |
import com.datumbox.framework.common.dataobjects.Record; | |
import com.datumbox.framework.core.machinelearning.classification.SoftMaxRegression; | |
import com.datumbox.framework.core.machinelearning.topicmodeling.LatentDirichletAllocation; | |
import com.datumbox.framework.core.utilities.text.extractors.UniqueWordSequenceExtractor; | |
import java.io.UncheckedIOException; | |
import java.net.URI; | |
import java.net.URISyntaxException; | |
import java.util.HashMap; | |
import java.util.Map; | |
public class LatentDirichletAllocationTest { | |
/** | |
* class LatentDirichletAllocation. | |
*/ | |
public static void main(String args[]) { | |
Configuration conf = Configuration.getConfiguration(); | |
String dbName = "LatentDirichletAllocation"; | |
Map<Object, URI> dataset = new HashMap<>(); | |
try { | |
dataset.put("positive", LatentDirichletAllocationTest.class.getClassLoader().getResource("datasets/sentipos").toURI()); | |
dataset.put("negative", LatentDirichletAllocationTest.class.getClassLoader().getResource("datasets/sentineg").toURI()); | |
} | |
catch(UncheckedIOException | URISyntaxException ex) { | |
//logger.warn("Unable to download datasets, skipping test."); | |
throw new RuntimeException(ex); | |
} | |
UniqueWordSequenceExtractor wsExtractor = new UniqueWordSequenceExtractor(new UniqueWordSequenceExtractor.Parameters()); | |
Dataframe trainingData = Dataframe.Builder.parseTextFiles(dataset, wsExtractor, conf); | |
LatentDirichletAllocation lda = new LatentDirichletAllocation(dbName, conf); | |
LatentDirichletAllocation.TrainingParameters trainingParameters = new LatentDirichletAllocation.TrainingParameters(); | |
trainingParameters.setMaxIterations(15); | |
trainingParameters.setAlpha(0.01); | |
trainingParameters.setBeta(0.01); | |
trainingParameters.setK(25); | |
lda.fit(trainingData, trainingParameters); | |
lda.validate(trainingData); | |
Dataframe reducedTrainingData = new Dataframe(conf); | |
for(Record r : trainingData) { | |
//take the topic assignments and convert them into a new Record | |
reducedTrainingData.add(new Record(r.getYPredictedProbabilities(), r.getY())); | |
System.out.println(r); | |
//System.out.println(r.getY()); | |
// System.out.println(r.getYPredictedProbabilities()); | |
} | |
reducedTrainingData.get(0); | |
SoftMaxRegression smr = new SoftMaxRegression(dbName, conf); | |
SoftMaxRegression.TrainingParameters tp = new SoftMaxRegression.TrainingParameters(); | |
tp.setLearningRate(1.0); | |
tp.setTotalIterations(50); | |
SoftMaxRegression.ValidationMetrics vm = smr.kFoldCrossValidation(reducedTrainingData, tp, 1); | |
Dataframe redtestData = new Dataframe(conf); | |
// double expResult = 0.6843125117743629; | |
double result = vm.getMacroF1(); | |
System.out.println(result); | |
// assertEquals(expResult, result, Constants.DOUBLE_ACCURACY_HIGH); | |
smr.delete(); | |
lda.delete(); | |
reducedTrainingData.delete(); | |
trainingData.delete(); | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment