// Based on the great open source library ompi-lda:
// http://code.google.com/p/ompi-lda/
// I have only experimented with running this in local mode and don't
// know whether it works in other deployment modes.
import spark.broadcast._
import spark.SparkContext
import spark.SparkContext._
import spark.RDD
import spark.storage.StorageLevel

import scala.util.Random
import scala.math.{sqrt, log, pow, abs, exp, min, max}
import scala.collection.mutable.HashMap
import scala.collection.mutable.ListBuffer

import java.io._

object SparkLocalLDA {

  // Rebuild the word-topic count matrix and the global per-topic counts
  // from the current topic assignments of all documents.
  def getDistributionModel(
      sc: SparkContext,
      documents: Seq[(Long, List[(Int, Int)], Array[Int])],
      vSize: Int,
      kTopic: Int,
      topicCountDistribution: Array[Array[Int]],
      globalDistribution: Array[Int]) = {
    for (k <- 0 until kTopic) {
      globalDistribution(k) = 0
      for (v <- 0 until vSize) {
        topicCountDistribution(v)(k) = 0
      }
    }
    for ((_, wordsTopic, _) <- documents) {
      for ((word, topic) <- wordsTopic) {
        topicCountDistribution(word)(topic) += 1
        globalDistribution(topic) += 1
      }
    }
    (topicCountDistribution, globalDistribution)
  }

  def restartSpark(sc: SparkContext, scMaster: String) = {
    // Over many iterations Spark accumulates a lot of RDDs and I only have 4 GB of
    // memory for it, so I have to restart Spark periodically. The Thread.sleep gives
    // Akka time to shut down.
    sc.stop()
    Thread.sleep(2000)
    new SparkContext(scMaster, "SparkLocalLDA")
  }

  def main(args: Array[String]) {
    System.setProperty("file.encoding", "UTF-8")
    System.setProperty("spark.serializer", "spark.KryoSerializer")
    System.setProperty("spark.local.dir", "/opt/tmp")

    val kTopic = 40
    val alpha = 0.45
    val beta = 0.01
    val maxIter = 200

    // Read the corpus from stdin, since I don't know how to handle UTF-8 text files
    // with Spark. The input format is one document per line:
    //   documentId<TAB>word,word,word,word,word
    val lines = new ListBuffer[String]()
    var done = false
    while (!done) {
      val line = readLine()
      done = line == null
      if (!done) lines += line
    }
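    // A concrete input line might look like this (an assumed illustration, not from
    // the original):
    //   1001<TAB>coffee,toast,coffee,eggs
    // A hedged alternative to the stdin loop above: if the corpus lives in a local
    // file, it could be read with explicit UTF-8 decoding, e.g. (inputPath is a
    // hypothetical path, not part of the original code):
    //   val lines = scala.io.Source.fromFile(inputPath, "UTF-8").getLines().toList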
    val scMaster = args(0) // e.g. "local[4]"
    var sc = new SparkContext(scMaster, "SparkLocalLDA")

    // Parse each line into (documentId, list of (word, count)).
    val menus = sc.parallelize(lines).map { line =>
      val vs = line.split("\t")
      var wordList = List[(String, Int)]()
      if (vs.length > 1 && vs(1).length > 0) {
        val words = vs(1).split(",").toList
        val wordCount = new HashMap[String, Int]().withDefaultValue(0)
        for (word <- words) {
          wordCount(word) += 1
        }
        wordList = wordCount.toList.sortWith((a, b) => a._2 > b._2)
      }
      (vs(0).toLong, wordList)
    }.filter(_._2.length > 0)

    // Build the vocabulary and a word -> index map.
    val allWords = menus.map { case (docId, words) =>
      for ((word, _) <- words) yield word
    }.flatMap(p => p).distinct().collect().toList.sortWith(_ < _)
    val vSize = allWords.length
    println("Vocabulary Size: " + vSize)

    val wordIndexMap = HashMap[String, Int]()
    for (i <- 0 until allWords.length) {
      wordIndexMap(allWords(i)) = i
    }
    val bWordIndexMap = sc.broadcast(wordIndexMap)

    // Initialize the topic assignments: every word occurrence gets a random topic,
    // and topicDistribution(k) counts how many occurrences in the document have topic k.
    var documents = menus.map { case (docId, words) =>
      var wordsTopic = List[(Int, Int)]()
      for ((word, count) <- words) {
        val wordIndex = bWordIndexMap.value(word)
        wordsTopic = wordsTopic ++ List.fill(count) {
          (wordIndex, Random.nextInt(kTopic))
        }
      }
      val topicDistribution = Array.fill(kTopic)(0)
      for ((word, topic) <- wordsTopic) {
        topicDistribution(topic) += 1
      }
      (docId, wordsTopic, topicDistribution)
    }.collect()

    val _topicCountDistribution = Array.fill(vSize, kTopic)(0)
    val _globalDistribution = Array.fill(kTopic)(0)

    // Run the collapsed Gibbs sampling iterations.
    for (iter <- 1 to maxIter) {
      getDistributionModel(sc, documents, vSize, kTopic,
        _topicCountDistribution, _globalDistribution)
      val topicCountDistribution = sc.broadcast(_topicCountDistribution)
      val globalDistribution = sc.broadcast(_globalDistribution)

      var _documents = sc.parallelize(documents).map {
        case (docId, wordsTopic, topicDistribution) =>
          // Keep one topic assignment per word occurrence (the original HashMap-keyed
          // version collapsed repeated words into a single entry, which made the counts
          // drift out of sync with topicDistribution).
          val newWordsTopic = wordsTopic.toArray
          for (i <- 0 until newWordsTopic.length) {
            val (word, topic) = newWordsTopic(i)
            // Compute the (unnormalized) sampling distribution over topics for this
            // occurrence, with the current assignment removed from the counts.
            val newTopicDist = Array.fill(kTopic)(0.0)
            val wordDist = topicCountDistribution.value(word)
            for (k <- 0 until kTopic) {
              val topicDelta = if (k == topic) -1 else 0
              val topicWordFactor: Double = wordDist(k) + topicDelta
              val globalTopicFactor: Double = globalDistribution.value(k) + topicDelta
              val docTopicFactor: Double = topicDistribution(k) + topicDelta
              newTopicDist(k) = (topicWordFactor + beta) * (docTopicFactor + alpha) /
                (globalTopicFactor + vSize * beta)
            }
            // Sample the new topic from that distribution.
            val distSum = newTopicDist.foldLeft(0.0)(_ + _)
            val choice = Random.nextDouble() * distSum
            var localSum = 0.0
            val newTopic = newTopicDist.map { dist =>
              localSum += dist
              localSum
            }.indexWhere(_ >= choice)
            // If we sampled a valid new topic, update the local model.
            if (newTopic != -1) {
              // Update the word-topic and global topic counts.
              topicCountDistribution.value(word)(topic) -= 1
              globalDistribution.value(topic) -= 1
              topicCountDistribution.value(word)(newTopic) += 1
              globalDistribution.value(newTopic) += 1
              // Update the document's topic counts.
              topicDistribution(topic) -= 1
              topicDistribution(newTopic) += 1
              newWordsTopic(i) = (word, newTopic)
            }
          }
          (docId, newWordsTopic.toList, topicDistribution)
      }
      documents = _documents.collect()
      _documents = null // drop the reference so the old RDD can be collected
      println("\nFinished iteration " + iter + "\n")
      if (iter % 30 == 0) {
        println("Restarting Spark.")
        sc = restartSpark(sc, scMaster)
      }
    }

    // Rebuild the final model from the last topic assignments.
    getDistributionModel(sc, documents, vSize, kTopic,
      _topicCountDistribution, _globalDistribution)
    val topicCountDistribution = sc.broadcast(_topicCountDistribution)
    val globalDistribution = sc.broadcast(_globalDistribution)
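    // A hedged sketch (not part of the original output): the same model arrays could
    // also be used to dump the smoothed topic-word distribution p(w|z), so that the
    // topics themselves can be inspected. The "word_dist" output name and the top-20
    // cutoff are assumptions.
    // val topicWords = for (k <- 0 until kTopic) yield {
    //   val topWords = (0 until vSize).map { v =>
    //     (allWords(v), (_topicCountDistribution(v)(k) + beta) /
    //                   (_globalDistribution(k) + vSize * beta))
    //   }.sortWith(_._2 > _._2).take(20).map(p => p._1 + "," + p._2)
    //   k + "\t" + topWords.mkString("\t")
    // }
    // sc.parallelize(topicWords).saveAsTextFile("word_dist")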
    // Score every document against every topic and save, per topic, the documents
    // with the highest p(topic | doc).
    sc.parallelize(documents).map { case (docId, wordsTopic, topicDist) =>
      val docLength = topicDist.foldLeft(0)(_ + _)
      // compute p(z = k | d) = (topicDist(k) + alpha) / (docLength + alpha * kTopic)
      val pTopicGivenDoc = Array.fill(kTopic)(0.0)
      for (k <- 0 until kTopic) {
        pTopicGivenDoc(k) = (topicDist(k) + alpha) / (docLength + alpha * kTopic)
      }
      for (k <- 0 until kTopic if topicDist(k) > 0) yield (k, (docId, pTopicGivenDoc(k)))
    }.flatMap(p => p).groupByKey().map { case (topic, docs) =>
      // Keep the top 500 documents for each topic.
      val topDocs = docs.toList.filter(_._2 >= 0.000001).sortWith(_._2 > _._2)
        .slice(0, 500).map(p => p._1 + "," + p._2)
      topic + "\t" + topDocs.mkString("\t")
    }.saveAsTextFile("doc_dist")
  }
}
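// Usage sketch (assumed; the exact launch command depends on how your Spark 0.x build
// is set up, so treat the runner and file name below as placeholders):
//   cat documents.tsv | <your-spark-runner> SparkLocalLDA "local[4]"
// The job writes a "doc_dist" directory in the working directory with one line per
// topic:
//   <topic>\t<docId>,<p(topic|doc)>\t<docId>,<p(topic|doc)>\t...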