Running LDA on the Spark platform, but only in local mode
// Referring to the great open source library ompi-lda:
// http://code.google.com/p/ompi-lda/
// I have only experimented with running this in local mode and don't
// know whether it will work in other deployment modes.
import spark.broadcast._
import spark.SparkContext
import spark.SparkContext._
import spark.RDD
import spark.storage.StorageLevel

import scala.util.Random
import scala.math.{sqrt, log, pow, abs, exp, min, max}
import scala.collection.mutable.HashMap
import scala.collection.mutable.ListBuffer

import java.io._

object SparkLocalLDA {

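  // Rebuild the word-topic count matrix and the per-topic global counts
  // from the current topic assignments of all word occurrences.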
  def getDistributionModel(sc: SparkContext, documents: Seq[(Long, List[(Int, Int)], Array[Int])],
                           vSize: Int, kTopic: Int,
                           topicCountDistribution: Array[Array[Int]], globalDistribution: Array[Int]) = {
    for (k <- 0 until kTopic) {
      globalDistribution(k) = 0
      for (v <- 0 until vSize) { topicCountDistribution(v)(k) = 0 }
    }
    for ((_, wordsTopic, _) <- documents) {
      for ((word, topic) <- wordsTopic) {
        topicCountDistribution(word)(topic) += 1
        globalDistribution(topic) += 1
      }
    }
    (topicCountDistribution, globalDistribution)
  }

  def restartSpark(sc: SparkContext, scMaster: String) = {
    // After many iterations Spark will have created a lot of RDDs and I only have 4 GB
    // of memory for it, so I have to restart Spark. The Thread.sleep gives Akka time to shut down.
    sc.stop()
    Thread.sleep(2000)
    new SparkContext(scMaster, "SparkLocalLDA")
  }

  def main(args: Array[String]) {
    System.setProperty("file.encoding", "UTF-8")
    System.setProperty("spark.serializer", "spark.KryoSerializer")
    System.setProperty("spark.local.dir", "/opt/tmp")

    val kTopic = 40
    val alpha = 0.45
    val beta = 0.01
    val maxIter = 200

    // Read the content from stdin, since I don't know how to handle UTF-8 text files with Spark.
    // The input format is one document per line:
    //   documentId<TAB>word,word,word,word,word
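    // A hypothetical sample line (the id and the comma-separated word list are tab-separated):
    //   10001<TAB>noodle,beef,noodle,soup,scallion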
    var lines = new ListBuffer[String]()
    var done = false
    while (!done) {
      val line = readLine()
      done = line == null
      if (!done) lines += line
    }

    val scMaster = args(0) // e.g. local[4]
    var sc = new SparkContext(scMaster, "SparkLocalLDA")

    var menus = sc.parallelize(lines).map { line =>
      val vs = line.split("\t")
      var wordList = List[(String, Int)]()
      if (vs.length > 1 && vs(1).length > 0) {
        val words = vs(1).split(",").toList
        var wordCount = new HashMap[String, Int]() withDefaultValue 0
        for (word <- words) { wordCount(word) += 1 }
        wordList = wordCount.toList.sortWith((a, b) => a._2 > b._2)
      }
      (vs(0).toLong, wordList)
    }.filter(_._2.length > 0)

    val allWords = menus.map {
      case (docId, words) =>
        for ((word, _) <- words) yield word
    }.flatMap(p => p).distinct().collect().toList.sortWith(_ < _)
    val vSize = allWords.length
    println("Vocabulary Size: " + vSize)

    var wordIndexMap = HashMap[String, Int]()
    for (i <- 0 until allWords.length) { wordIndexMap(allWords(i)) = i }
    var bWordIndexMap = sc.broadcast(wordIndexMap)

    // init topic distribution
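    // Assign a random topic to each word occurrence and count, per document,
    // how many occurrences currently belong to each topic.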
    var documents = menus.map {
      case (docId, words) =>
        var wordsTopic = List[(Int, Int)]()
        for ((word, count) <- words) {
          val wordIndex = bWordIndexMap.value(word)
          wordsTopic = wordsTopic ++ List.fill(count) { (wordIndex, Random.nextInt(kTopic)) }
        }
        var topicDistribution = Array.fill(kTopic)(0)
        for ((word, topic) <- wordsTopic) { topicDistribution(topic) += 1 }
        (docId, wordsTopic, topicDistribution)
    }.collect()

    val _topicCountDistribution = Array.fill(vSize, kTopic)(0)
    val _globalDistribution = Array.fill(kTopic)(0)

    // start the iteration
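    // Each iteration: rebuild and broadcast the count model, resample a topic for
    // every word occurrence inside the map below, then collect the updated assignments.
    // Note that the sampling step mutates the broadcast arrays in place; this relies on
    // local mode, where all tasks share the same in-process copy of a broadcast value.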
    for (iter <- 1 to maxIter) {
      getDistributionModel(sc, documents, vSize, kTopic, _topicCountDistribution, _globalDistribution)
      val topicCountDistribution = sc.broadcast(_topicCountDistribution)
      val globalDistribution = sc.broadcast(_globalDistribution)

      var _documents = sc.parallelize(documents).map {
        case (docId, wordsTopic, topicDistribution) =>
          var newWordsTopic = HashMap[Int, Int]()
          for ((word, topic) <- wordsTopic) { newWordsTopic(word) = topic }
          for ((word, topic) <- wordsTopic) {
            // generate new word topic distribution
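            // Collapsed Gibbs sampling: with this occurrence's current assignment
            // removed from the counts, the unnormalized probability of topic k is
            //   p(z = k) ∝ (n_wk + beta) * (n_dk + alpha) / (n_k + vSize * beta)
            // where n_wk is the word-topic count, n_dk the document-topic count,
            // and n_k the global topic count.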
            var newTopicDist = Array.fill(kTopic)(0.0)
            val wordDist = topicCountDistribution.value(word)
            for (k <- 0 until kTopic) {
              val topicDelta = if (k == topic) -1 else 0
              val topicWordFactor: Double = wordDist(k) + topicDelta
              val globalTopicFactor: Double = globalDistribution.value(k) + topicDelta
              val docTopicFactor: Double = topicDistribution(k) + topicDelta
              newTopicDist(k) = (topicWordFactor + beta) * (docTopicFactor + alpha) / (globalTopicFactor + vSize * beta)
            }
            // choose the new topic
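            // Draw the new topic proportionally to newTopicDist: pick a uniform point
            // in [0, distSum) and take the first index whose cumulative weight reaches it.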
            var distSum = newTopicDist.foldLeft(0.0)(_ + _)
            val choice = Random.nextDouble() * distSum
            var localSum = 0.0
            var newTopic = newTopicDist.map { dist =>
              localSum += dist
              localSum
            }.indexWhere(_ >= choice)
            // if we got a new topic then update the local model
            if (newTopic != -1) {
              // update model distribution
              topicCountDistribution.value(word)(topic) += -1
              globalDistribution.value(topic) += -1
              topicCountDistribution.value(word)(newTopic) += 1
              globalDistribution.value(newTopic) += 1
              // update document distribution
              topicDistribution(topic) -= 1
              topicDistribution(newTopic) += 1
              newWordsTopic(word) = newTopic
            }
          }
          (docId, newWordsTopic.toList, topicDistribution)
      }
      documents = _documents.collect()
      _documents = null
      println("\nFinished iteration " + iter)
      println()
      if (iter % 30 == 0) {
        println("Restarting Spark.")
        sc = restartSpark(sc, scMaster)
      }
    }

    getDistributionModel(sc, documents, vSize, kTopic, _topicCountDistribution, _globalDistribution)
    val topicCountDistribution = sc.broadcast(_topicCountDistribution)
    val globalDistribution = sc.broadcast(_globalDistribution)
    sc.parallelize(documents).map {
      case (docId, wordsTopic, topicDist) =>
        val docLength = topicDist.foldLeft(0)(_ + _)
        // compute p(z|d)
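        // p(z = k | d) = (n_dk + alpha) / (docLength + kTopic * alpha),
        // the smoothed per-document topic proportion.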
        var pTopicGivenDoc = Array.fill(kTopic)(0.0)
        for (k <- 0 until kTopic) {
          pTopicGivenDoc(k) = (topicDist(k) + alpha) / (docLength + alpha * kTopic)
        }
        for (k <- 0 until kTopic if topicDist(k) > 0)
          yield (k, (docId, pTopicGivenDoc(k)))
    }.flatMap(p => p).groupByKey().map {
      case (topic, docs) =>
        // get the top 500 docs for each topic
        val topDocs = docs.toList.filter(_._2 >= 0.000001).sortWith(_._2 > _._2).slice(0, 500).map(p => p._1 + "," + p._2)
        topic + "\t" + topDocs.mkString("\t")
    }.saveAsTextFile("doc_dist")
  }
}