@mijia
Created December 6, 2012 13:30
Running LDA on the Spark platform, but only tested in local mode
// Based on the great open source library ompi-lda:
// http://code.google.com/p/ompi-lda/
// I have only experimented with running this in local mode and
// don't know whether it will work in other deployment modes.
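// A rough usage sketch (the jar name and classpath below are illustrative
// assumptions for a Spark 0.6-era build, not tested commands):
//   cat corpus.txt | scala -cp spark-core-assembly.jar SparkLocalLDA local[4]
// The corpus is read from stdin and the Spark master URL is args(0).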
import spark.broadcast._
import spark.SparkContext
import spark.SparkContext._
import spark.RDD
import spark.storage.StorageLevel
import scala.util.Random
import scala.math.{sqrt, log, pow, abs, exp, min, max}
import scala.collection.mutable.HashMap
import scala.collection.mutable.ListBuffer
import java.io._
object SparkLocalLDA {

  // Rebuild the word-topic and global topic count matrices from the
  // current (word, topic) assignments of all documents.
  def getDistributionModel(sc: SparkContext, documents: Seq[(Long, List[(Int, Int)], Array[Int])],
                           vSize: Int, kTopic: Int,
                           topicCountDistribution: Array[Array[Int]], globalDistribution: Array[Int]) = {
    for (k <- 0 until kTopic) {
      globalDistribution(k) = 0
      for (v <- 0 until vSize) { topicCountDistribution(v)(k) = 0 }
    }
    for ((_, wordsTopic, _) <- documents) {
      for ((word, topic) <- wordsTopic) {
        topicCountDistribution(word)(topic) += 1
        globalDistribution(topic) += 1
      }
    }
    (topicCountDistribution, globalDistribution)
  }
  def restartSpark(sc: SparkContext, scMaster: String) = {
    // Over many iterations Spark creates a lot of RDDs and I only have 4 GB
    // of memory for it, so I have to restart Spark periodically. The
    // Thread.sleep gives Akka time to shut down cleanly.
    sc.stop()
    Thread.sleep(2000)
    new SparkContext(scMaster, "SparkLocalLDA")
  }
  def main(args: Array[String]) {
    System.setProperty("file.encoding", "UTF-8")
    System.setProperty("spark.serializer", "spark.KryoSerializer")
    System.setProperty("spark.local.dir", "/opt/tmp")
    val kTopic = 40
    val alpha = 0.45
    val beta = 0.01
    val maxIter = 200
    // Read the corpus from stdin, since I don't know how to handle UTF-8
    // text files with Spark. The input format is one document per line:
    // documentId<TAB>word,word,word,word,word
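    // For example (the id and words are made-up sample data):
    //   42<TAB>pasta,pizza,pasta,salad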
    var lines = new ListBuffer[String]()
    var done = false
    while (!done) {
      val line = readLine()
      done = line == null
      if (!done) lines += line
    }
    val scMaster = args(0) // e.g. local[4]
    var sc = new SparkContext(scMaster, "SparkLocalLDA")
    val menus = sc.parallelize(lines).map { line =>
      val vs = line.split("\t")
      var wordList = List[(String, Int)]()
      if (vs.length > 1 && vs(1).length > 0) {
        val words = vs(1).split(",").toList
        val wordCount = new HashMap[String, Int]() withDefaultValue 0
        for (word <- words) { wordCount(word) += 1 }
        wordList = wordCount.toList.sortWith((a, b) => a._2 > b._2)
      }
      (vs(0).toLong, wordList)
    }.filter(_._2.length > 0)
    val allWords = menus.map {
      case (docId, words) =>
        for ((word, _) <- words) yield word
    }.flatMap(p => p).distinct().collect().toList.sortWith(_ < _)
    val vSize = allWords.length
    println("Vocabulary Size: " + vSize)
    val wordIndexMap = HashMap[String, Int]()
    for (i <- 0 until allWords.length) { wordIndexMap(allWords(i)) = i }
    val bWordIndexMap = sc.broadcast(wordIndexMap)
    // init the topic assignments: each occurrence of each word gets a
    // uniformly random topic
    var documents = menus.map {
      case (docId, words) =>
        var wordsTopic = List[(Int, Int)]()
        for ((word, count) <- words) {
          val wordIndex = bWordIndexMap.value(word)
          wordsTopic = wordsTopic ++ List.fill(count) { (wordIndex, Random.nextInt(kTopic)) }
        }
        val topicDistribution = Array.fill(kTopic)(0)
        for ((word, topic) <- wordsTopic) { topicDistribution(topic) += 1 }
        (docId, wordsTopic, topicDistribution)
    }.collect()
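    // The assignments are collected to the driver and re-parallelized on
    // every iteration so that all count updates happen in one JVM; this is
    // part of what ties the implementation to local mode.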
    val _topicCountDistribution = Array.fill(vSize, kTopic)(0)
    val _globalDistribution = Array.fill(kTopic)(0)
    // start the iterations
    for (iter <- 1 to maxIter) {
      getDistributionModel(sc, documents, vSize, kTopic, _topicCountDistribution, _globalDistribution)
      val topicCountDistribution = sc.broadcast(_topicCountDistribution)
      val globalDistribution = sc.broadcast(_globalDistribution)
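      // Note: the sampling step below mutates these broadcast arrays in
      // place; that only works because local mode keeps the broadcast
      // values in the driver's JVM.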
      var _documents = sc.parallelize(documents).map {
        case (docId, wordsTopic, topicDistribution) =>
          // collect the resampled (word, topic) pairs in a list so that
          // repeated words keep their individual topic assignments
          val newWordsTopic = new ListBuffer[(Int, Int)]()
          for ((word, topic) <- wordsTopic) {
            // generate the new topic distribution for this word occurrence
            val newTopicDist = Array.fill(kTopic)(0.0)
            val wordDist = topicCountDistribution.value(word)
            for (k <- 0 until kTopic) {
              // exclude this word's current assignment from the counts
              val topicDelta = if (k == topic) -1 else 0
              val topicWordFactor: Double = wordDist(k) + topicDelta
              val globalTopicFactor: Double = globalDistribution.value(k) + topicDelta
              val docTopicFactor: Double = topicDistribution(k) + topicDelta
              newTopicDist(k) = (topicWordFactor + beta) * (docTopicFactor + alpha) / (globalTopicFactor + vSize * beta)
            }
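            // The loop above computes the collapsed Gibbs sampling update
            // for LDA:
            //   p(z_i = k | z_-i, w) ∝ (n_wk + beta) * (n_dk + alpha) / (n_k + V * beta)
            // where n_wk, n_dk and n_k are the word-topic, document-topic and
            // global topic counts with word i's current assignment removed.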
            // choose the new topic by inverse-CDF sampling from the
            // unnormalized distribution
            val distSum = newTopicDist.foldLeft(0.0)(_ + _)
            val choice = Random.nextDouble() * distSum
            var localSum = 0.0
            val newTopic = newTopicDist.map { dist =>
              localSum += dist
              localSum
            }.indexWhere(_ >= choice)
            // if we got a valid new topic then update the local model
            if (newTopic != -1) {
              // update the model distribution
              topicCountDistribution.value(word)(topic) -= 1
              globalDistribution.value(topic) -= 1
              topicCountDistribution.value(word)(newTopic) += 1
              globalDistribution.value(newTopic) += 1
              // update the document distribution
              topicDistribution(topic) -= 1
              topicDistribution(newTopic) += 1
              newWordsTopic += ((word, newTopic))
            } else {
              newWordsTopic += ((word, topic))
            }
          }
          (docId, newWordsTopic.toList, topicDistribution)
      }
      documents = _documents.collect()
      _documents = null // drop the reference so the old RDD can be reclaimed
      println("\nFinished iteration " + iter)
      println()
      if (iter % 30 == 0) {
        println("Restarting Spark.")
        sc = restartSpark(sc, scMaster)
      }
    }
    getDistributionModel(sc, documents, vSize, kTopic, _topicCountDistribution, _globalDistribution)
    val topicCountDistribution = sc.broadcast(_topicCountDistribution)
    val globalDistribution = sc.broadcast(_globalDistribution)
    sc.parallelize(documents).map {
      case (docId, wordsTopic, topicDist) =>
        val docLength = topicDist.foldLeft(0)(_ + _)
        // compute p(z|d)
        val pTopicGivenDoc = Array.fill(kTopic)(0.0)
        for (k <- 0 until kTopic) {
          pTopicGivenDoc(k) = (topicDist(k) + alpha) / (docLength + alpha * kTopic)
        }
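        // This is the usual smoothed estimate of the document-topic
        // distribution: p(z = k | d) = (n_dk + alpha) / (|d| + K * alpha)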
        for (k <- 0 until kTopic if topicDist(k) > 0)
          yield (k, (docId, pTopicGivenDoc(k)))
    }.flatMap(p => p).groupByKey().map {
      case (topic, docs) =>
        // keep the top 500 docs for each topic, one output line per topic:
        // topic<TAB>docId,prob<TAB>docId,prob...
        val topDocs = docs.toList
          .filter(_._2 >= 0.000001)
          .sortWith(_._2 > _._2)
          .slice(0, 500)
          .map(p => p._1 + "," + p._2)
        topic + "\t" + topDocs.mkString("\t")
    }.saveAsTextFile("doc_dist")
  }
}