// Counts errors and keeps one sample error record for debugging.
// Uses the old Spark 1.x Accumulator/AccumulatorParam API (deprecated since Spark 2.0).
// WARNING: because this uses accumulators, the semantics around counting are *extremely*
// confusing if the RDD ever gets recomputed, due to shared lineage, cache eviction, or
// stage retries. Use with caution.

import org.apache.spark.{Accumulator, AccumulatorParam, SparkContext}

class ErrorTracker[T] private(val name: String) extends Serializable {

  // Task-local mutable state; per-task instances are merged on the driver by
  // ErrorTrackerAccumulator.addInPlace below.
  var errorCounts = 0L
  var successCounts = 0L
  var errorSample: Option[T] = None

  /** Records one successfully processed record. */
  def ok(): Unit = successCounts += 1

  /** Records one failure, remembering the most recent failure as the sample. */
  def error(t: T): Unit = {
    errorCounts += 1
    errorSample = Some(t)
  }

  override def toString: String = {
    val total = errorCounts + successCounts
    // Guard against NaN when nothing has been recorded yet.
    val frac = if (total == 0) 0.0 else errorCounts.toDouble / total
    f"$errorCounts%d errors / $total%d total ($frac%.2f). " +
      errorSample.map(e => s"Sample error: $e").getOrElse("")
  }
}

object ErrorTracker {
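  /** Registers a fresh, empty ErrorTracker as a named accumulator on the given SparkContext. */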
  def apply[T](name: String, sc: SparkContext): Accumulator[ErrorTracker[T]] = {
    sc.accumulator(new ErrorTracker[T](name), name)(new ErrorTrackerAccumulator[T])
  }
}

private class ErrorTrackerAccumulator[T] extends AccumulatorParam[ErrorTracker[T]] {
  // Merges r2 into r1 in place; AccumulatorParam permits mutating and returning
  // the first argument. Keeps the first non-empty error sample it sees.
  override def addInPlace(r1: ErrorTracker[T], r2: ErrorTracker[T]): ErrorTracker[T] = {
    r1.errorCounts += r2.errorCounts
    r1.successCounts += r2.successCounts
    if (r1.errorSample.isEmpty) {
      r1.errorSample = r2.errorSample
    }
    r1
  }

  // Returning initialValue unchanged is safe here only because ErrorTracker.apply
  // always passes a freshly constructed, empty tracker, and each task deserializes
  // its own copy, so the zero value is never shared or pre-populated.
  override def zero(initialValue: ErrorTracker[T]): ErrorTracker[T] = initialValue
}
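
// Usage sketch (not part of the original file): one plausible way to drive the
// tracker from inside a transformation. Because the ErrorTracker constructor is
// private, tasks update the tracker by mutating the task-local value exposed by
// Accumulable.localValue rather than adding deltas. The input data, app name,
// and parsing logic below are hypothetical.
object ErrorTrackerExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext("local[2]", "error-tracker-example")
    val parseErrors = ErrorTracker[String]("parse errors", sc)

    val parsed = sc.parallelize(Seq("1", "2", "oops", "4")).flatMap { s =>
      try {
        val n = s.toInt
        parseErrors.localValue.ok()
        Some(n)
      } catch {
        case _: NumberFormatException =>
          parseErrors.localValue.error(s)
          None
      }
    }

    // Accumulators are only updated once an action forces the computation.
    parsed.count()
    println(parseErrors.value) // e.g. "1 errors / 4 total (0.25). Sample error: oops"
    sc.stop()
  }
}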