// Based on the great open source library ompi-lda:
// http://code.google.com/p/ompi-lda/
// I have only tested this in local mode and don't know whether it
// works in other deployment modes.

import spark.broadcast._
import spark.SparkContext
import spark.SparkContext._
import spark.RDD
import spark.storage.StorageLevel

import scala.util.Random
import scala.math.{sqrt, log, pow, abs, exp, min, max}
import scala.collection.mutable.HashMap
import scala.collection.mutable.ListBuffer
import java.io._

object SparkLocalLDA {

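  // Rebuilds the count model from the current topic assignments:
  //   topicCountDistribution(word)(topic) = number of times `word` is assigned to `topic`
  //   globalDistribution(topic)           = total number of word occurrences assigned to `topic`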
  def getDistributionModel(sc: SparkContext, documents: Seq[(Long, List[(Int, Int)], Array[Int])],
              vSize: Int, kTopic: Int,
              topicCountDistribution: Array[Array[Int]], globalDistribution: Array[Int]) = {

    for (k <- 0 until kTopic) {
      globalDistribution(k) = 0
      for (v <- 0 until vSize) { topicCountDistribution(v)(k) = 0 }
    }
    for ((_, wordsTopic, _) <- documents) {
      for ((word, topic) <- wordsTopic) {
        topicCountDistribution(word)(topic) += 1
        globalDistribution(topic) += 1
      }
    }

    (topicCountDistribution, globalDistribution)
  }

  def restartSpark(sc: SparkContext, scMaster: String) = {
    // After many iterations Spark has created a lot of RDDs and I only have 4 GB of memory for it,
    // so I have to restart Spark periodically. The Thread.sleep gives Akka time to shut down.
    sc.stop()
    Thread.sleep(2000)
    new SparkContext(scMaster, "SparkLocalLDA")
  }

  def main(args: Array[String]) {
    System.setProperty("file.encoding", "UTF-8")
    System.setProperty("spark.serializer", "spark.KryoSerializer")
    System.setProperty("spark.local.dir", "/opt/tmp")

    val kTopic = 40     // number of topics
    val alpha = 0.45    // document-topic Dirichlet prior
    val beta = 0.01     // topic-word Dirichlet prior
    val maxIter = 200   // number of Gibbs sampling iterations

    // Read the corpus from stdin, since I don't know how to handle UTF-8 text files with Spark.
    // Each input line has the format:
    //     documentId<TAB>word,word,word,word,word
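    //     e.g. (hypothetical sample) 42<TAB>spark,lda,topic,model,spark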
    var lines = new ListBuffer[String]()
    var done = false
    while (!done) {
      val line = readLine()
      done = line == null
      if (!done) lines += line
    }

    val scMaster = args(0)  // e.g. local[4]
    var sc = new SparkContext(scMaster, "SparkLocalLDA")
    var menus = sc.parallelize(lines).map { line =>
      val vs = line.split("\t")
      var wordList = List[(String, Int)]()
      if (vs.length > 1 && vs(1).length > 0) {
        val words = vs(1).split(",").toList
        var wordCount = new HashMap[String, Int]() withDefaultValue 0
        for (word <- words) { wordCount(word) += 1 }
        wordList = wordCount.toList.sortWith((a, b) => a._2 > b._2)
      }
      (vs(0).toLong, wordList)
    }.filter(_._2.length > 0)

    val allWords = menus.map {
      case (docId, words) =>
        for ((word, _) <- words) yield word
    }.flatMap(p => p).distinct().collect().toList.sortWith(_ < _)

    val vSize = allWords.length
    println("Vocabulary Size: " + vSize)

    var wordIndexMap = HashMap[String, Int]()
    for (i <- 0 until allWords.length) { wordIndexMap(allWords(i)) = i }
    var bWordIndexMap = sc.broadcast(wordIndexMap)

    // Initialise the topic assignments: every word occurrence gets a random topic, and each document
    // keeps (docId, one (wordIndex, topic) pair per occurrence, per-document topic counts).
    var documents = menus.map {
      case (docId, words) =>
        var wordsTopic = List[(Int, Int)]()
        for ((word, count) <- words) {
          val wordIndex = bWordIndexMap.value(word)
          wordsTopic = wordsTopic ++ List.fill(count) { (wordIndex, Random.nextInt(kTopic)) }
        }
        var topicDistribution = Array.fill(kTopic)(0)
        for ((word, topic) <- wordsTopic) { topicDistribution(topic) += 1 }
        (docId, wordsTopic, topicDistribution)
    }.collect()

    val _topicCountDistribution = Array.fill(vSize, kTopic)(0)  // word-topic counts, vSize x kTopic
    val _globalDistribution = Array.fill(kTopic)(0)             // total word occurrences per topic

    // start the iteration
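    // Collapsed Gibbs sampling: each iteration rebuilds the count model from the current assignments,
    // broadcasts it, and then re-samples the topic of every word occurrence with probability
    //   p(z = k) proportional to (n(word, k) + beta) * (n(doc, k) + alpha) / (n(k) + vSize * beta)
    // where the counts exclude the occurrence currently being sampled.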
    for (iter <- 1 to maxIter) {
      getDistributionModel(sc, documents, vSize, kTopic, _topicCountDistribution, _globalDistribution)
      val topicCountDistribution = sc.broadcast(_topicCountDistribution)
      val globalDistribution = sc.broadcast(_globalDistribution)

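      // NOTE: the sampler below updates the broadcast arrays (topicCountDistribution.value and
      // globalDistribution.value) in place; as noted at the top of the file, this has only been
      // verified to behave as intended in local mode.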
      var _documents = sc.parallelize(documents).map {
        case (docId, wordsTopic, topicDistribution) =>
          // keep one (wordIndex, topic) entry per word occurrence so multiplicities are preserved
          val newWordsTopic = wordsTopic.toArray

          for (i <- 0 until newWordsTopic.length) {
            val (word, topic) = newWordsTopic(i)

            // compute the sampling distribution for this word occurrence
            val newTopicDist = Array.fill(kTopic)(0.0)
            val wordDist = topicCountDistribution.value(word)
            for (k <- 0 until kTopic) {
              val topicDelta = if (k == topic) -1 else 0  // exclude the current assignment from the counts
              val topicWordFactor: Double = wordDist(k) + topicDelta
              val globalTopicFactor: Double = globalDistribution.value(k) + topicDelta
              val docTopicFactor: Double = topicDistribution(k) + topicDelta
              newTopicDist(k) = (topicWordFactor + beta) * (docTopicFactor + alpha) / (globalTopicFactor + vSize * beta)
            }

            // sample the new topic proportionally to newTopicDist
            val distSum = newTopicDist.foldLeft(0.0)(_ + _)
            val choice = Random.nextDouble() * distSum
            var localSum = 0.0
            val newTopic = newTopicDist.map { dist =>
              localSum += dist
              localSum
            }.indexWhere(_ >= choice)

            // if sampling succeeded, move the occurrence to its new topic
            if (newTopic != -1) {
              // update the word-topic and global topic counts
              topicCountDistribution.value(word)(topic) -= 1
              globalDistribution.value(topic) -= 1
              topicCountDistribution.value(word)(newTopic) += 1
              globalDistribution.value(newTopic) += 1

              // update the per-document topic counts and the assignment itself
              topicDistribution(topic) -= 1
              topicDistribution(newTopic) += 1
              newWordsTopic(i) = (word, newTopic)
            }
          }
          (docId, newWordsTopic.toList, topicDistribution)
      }

      documents = _documents.collect()
      _documents = null  // drop the reference so the previous RDD can be garbage collected

      println("\nFinish running the ith iteration with i = " + iter)
      println

      if (iter % 30 == 0) {
        println("Restart spark.")
        sc = restartSpark(sc, scMaster)
      }
    }

    getDistributionModel(sc, documents, vSize, kTopic, _topicCountDistribution, _globalDistribution)

    val topicCountDistribution = sc.broadcast(_topicCountDistribution)
    val globalDistribution = sc.broadcast(_globalDistribution)
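    // For every topic, write the documents with the highest p(topic | doc) to the "doc_dist" output.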
    sc.parallelize(documents).map {
      case (docId, wordsTopic, topicDist) =>
        val docLength = topicDist.foldLeft(0)(_+_)
        // compute p(z|d): the smoothed fraction of the document's words assigned to each topic
        var pTopicGivenDoc = Array.fill(kTopic)(0.0)
        for (k <- 0 until kTopic) {
          pTopicGivenDoc(k) = (topicDist(k) + alpha) / (docLength + alpha * kTopic)
        }
        for (k <- 0 until kTopic
             if (topicDist(k) > 0)) yield (k, (docId, pTopicGivenDoc(k)))
    }.flatMap(p => p).groupByKey().map {
      case (topic, docs) =>
        // get the top 500 docs for each topic
        val topDocs = docs.toList.filter(_._2 >= 0.000001).sortWith(_._2 > _._2).slice(0, 500).map(p => p._1 + "," + p._2)
        topic + "\t" + topDocs.mkString("\t")
    }.saveAsTextFile("doc_dist")

  }

}