Created
March 14, 2012 02:35
-
-
Save anonymous/2033568 to your computer and use it in GitHub Desktop.
Weighted Selection
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import scala.util.Random | |
/**
 * Weighted random selection with replacement.
 *
 * Rather than drawing the selections one at a time, the number of times each item is
 * chosen is drawn directly from a binomial distribution, conditioned on how many
 * selections remain after the earlier items. Selections are therefore returned grouped
 * by item, in the order the items were supplied.
 *
 * @author Andrew Conway
 */
object WeightedRandomSelection {

  /**
   * Draws the number of times an event with probability `p` occurs in `N` independent
   * trials, i.e. a sample from Binomial(N, p).
   *
   * If R is the result, P(R=n) = p^n q^(N-n) N! / (n! (N-n)!) with q = 1-p. Hence
   * P(R=0) = q^N and P(R=n+1) = p/q (N-n)/(n+1) P(R=n). Starting at n = 0, the sampler
   * stops at n with probability P(R=n | R>=n) = P(R=n) / (1 - P(R<n)), otherwise moves
   * on to n+1.
   *
   * For p > 0.5 the complementary count is sampled instead, so q >= 0.5 in the walk.
   * When q^N is too small to represent as a normal double (large N), the trials are
   * split into two halves and the two independent samples are added, using
   * Binomial(a, p) + Binomial(b, p) ~ Binomial(a + b, p).
   *
   * @param p probability of the event in one trial; p < 0 yields 0 and p >= 1 yields N
   * @param N number of trials
   * @param r source of randomness
   * @return the sampled count, in 0..N for N >= 0
   */
  def numEntries(p: Double, N: Int, r: Random): Int =
    if (p > 0.5) N - numEntries(1.0 - p, N, r)
    else if (p < 0.0) 0
    else {
      val q = 1.0 - p
      val probZero = math.pow(q, N)
      if (N > 1 && probZero < java.lang.Double.MIN_NORMAL) {
        // q^N underflowed (or lost precision as a subnormal). Starting the walk with
        // P(R=0) == 0 would make the loop below always run to N, so sample two halves.
        val half = N / 2
        numEntries(p, half, r) + numEntries(p, N - half, r)
      } else {
        var n = 0
        var prstop = probZero // P(R = n)
        var cumulative = 0.0  // P(R < n)
        // Continue past n with probability 1 - P(R=n) / (1 - P(R<n)).
        while (n < N && r.nextDouble() * (1 - cumulative) >= prstop) {
          cumulative += prstop
          prstop *= p * (N - n) / (q * (n + 1))
          n += 1
        }
        n
      }
    }

  /** An item together with its (non-negative) selection weight. */
  case class WeightedItem[T](item: T, weight: Double)

  /**
   * Computes a weighted selection from the given items.
   *
   * `cumulativeSum` must have at least as many elements as `items`. Its i-th element is
   * the sum of the weights of item i and every item after it. These suffix sums are built
   * directly rather than by subtracting from a running total. That way the last item with
   * positive weight gets a selection probability of exactly 1, and rounding errors cannot
   * leave selections unassigned.
   *
   * Implemented as a loop (not recursion) so long item lists cannot overflow the stack;
   * `numEntries` is called in the same order and with the same arguments as a recursive
   * formulation would.
   */
  private def weightedSelectionWithCumSum[T](items: Seq[WeightedItem[T]], cumulativeSum: List[Double], numSelections: Int, r: Random): Seq[T] = {
    val selected = List.newBuilder[T]
    val itemIter = items.iterator
    var sums = cumulativeSum
    var remaining = numSelections
    while (remaining > 0) {
      val head = itemIter.next()
      // Conditional on `remaining` selections still to place among this item and the
      // ones after it, this item's count is Binomial(remaining, weight / suffixSum).
      val nhead = numEntries(head.weight / sums.head, remaining, r)
      selected ++= Iterator.fill(nhead)(head.item)
      sums = sums.tail
      remaining -= nhead
    }
    selected.result()
  }

  /**
   * Selects `numSelections` items with replacement, each item chosen with probability
   * proportional to its weight.
   *
   * @param items         the candidates; weights are expected to be non-negative
   * @param numSelections how many selections to make; must be >= 0
   * @param r             source of randomness
   * @return the selected items, grouped by item in the order of `items`
   * @throws IllegalArgumentException if `numSelections` is negative, or if it is positive
   *                                  while the total weight is not positive
   */
  def weightedSelection[T](items: Seq[WeightedItem[T]], numSelections: Int, r: Random): Seq[T] = {
    require(numSelections >= 0, "numSelections must be non-negative, got " + numSelections)
    val cumsum = items.foldRight(List(0.0)) { (wi, acc) => (wi.weight + acc.head) :: acc }
    require(numSelections == 0 || cumsum.head > 0, "total weight must be positive to make a selection")
    weightedSelectionWithCumSum(items, cumsum, numSelections, r)
  }

  /**
   * Empirical check of `weightedSelection`. Runs it `runs` times and, for each item,
   * prints the observed mean and variance of its count. Alongside, it prints the binomial
   * expectation: mean N p and variance N p (1-p), where p is the item's share of the
   * total weight, plus the expected standard error of the observed mean.
   *
   * @param runs number of independent selections to sample (defaults to 10000)
   */
  def testRandomness[T](items: Seq[WeightedItem[T]], numSelections: Int, r: Random, runs: Int = 10000): Unit = {
    val indexOfItem = items.zipWithIndex.map { case (wi, ind) => wi.item -> ind }.toMap
    val numItems = items.length
    val bucketSums = new Array[Double](numItems)
    val bucketSumSqs = new Array[Double](numItems)
    for (_ <- 0 until runs) {
      // Count how often each item appears in one selection.
      val buckets = new Array[Double](numItems)
      for (selected <- weightedSelection(items, numSelections, r)) buckets(indexOfItem(selected)) += 1
      for (i <- 0 until numItems) {
        val count = buckets(i)
        bucketSums(i) += count
        bucketSumSqs(i) += count * count
      }
    }
    val sumWeights = items.foldLeft(0.0)(_ + _.weight)
    for ((item, ind) <- items.zipWithIndex) {
      val p = item.weight / sumWeights
      val mean = bucketSums(ind) / runs
      val variance = bucketSumSqs(ind) / runs - mean * mean
      val expectedMean = numSelections * p
      val expectedVariance = numSelections * p * (1 - p)
      val expectedErrorInMean = math.sqrt(expectedVariance / runs)
      val text = "Item %10s Mean %.3f Expected %.3f±%.3f Variance %.3f expected %.3f".format(item.item, mean, expectedMean, expectedErrorInMean, variance, expectedVariance)
      println(text)
    }
  }

  /** Prints one sample selection and then the randomness statistics for a small example. */
  def main(args: Array[String]): Unit = {
    val items = Seq(WeightedItem("Red", 1d / 6), WeightedItem("Blue", 2d / 6), WeightedItem("Green", 3d / 6))
    println(weightedSelection(items, 6, new Random()))
    testRandomness(items, 6, new Random())
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment