Fast Weighted Random Sampling from discrete distributions, i.e. selecting items from a weighted list.
/*
  WRS.c
  (c) AShelly (github.com/ashelly)

  Weighted random sampling with replacement of N items in O(1) time
  (after preparing an O(N)-sized buffer in O(N) time).

  The concept is:
  Randomly select a buffer index. Each index is selected with probability 1/N.
  Each index stores the fraction of hits for which this item should be selected,
  and the index of another item, which will be selected if this one is not.

  Imagine a histogram containing [2,4,5,6,3]. It has 5 buckets, sum = 20.
  We transform it by normalizing, then dividing each bucket's weight by 1/5,
  so a bucket holding exactly the average weight ends up at 1.0.
  Then we "fill up" the underpopulated buckets from the overpopulated ones:

    AA     0.50 -> AACC   0.50            AACC   0.50            AACC   0.50
    BBBB   1.00    BBBB   1.00            BBBB   1.00            BBBB   1.00
    CCCCC  1.25 -> CCC    0.75(+0.50A) -> CCCD   0.75(+0.50A)    CCCD   0.75(+0.50A)
    DDDDDD 1.50    DDDDDD 1.50         -> DDDDD  1.25(+0.25C) -> DDDD   1.00(+0.25C+0.25E)
    EEE    0.75    EEE    0.75            EEE    0.75         -> EEED   0.75

  Independently invented by me in 2014, only 40 years after the concept was
  published by A. J. Walker in 1974 as "New fast method for generating discrete
  random numbers with arbitrary frequency distributions".
  See http://en.wikipedia.org/wiki/Alias_method
*/
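//For the [2,4,5,6,3] example above, wrs_create() below should end up with the
//following tables (values derived from the walkthrough in the header, not printed
//by the code):
//  share = {0.50, 1.00, 0.75, 1.00, 0.75}  // fraction of hits each bucket keeps
//  pair  = {   2,    1,    3,    3,    3}  // where the remaining hits are forwarded
//i.e. A forwards half of its hits to C, C and E each forward a quarter to D,
//and B and D keep all of their own hits.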
#include <stdio.h>
#include <stdlib.h>

//Helper for demo only. Divides by RAND_MAX+1 so the result stays in [0,1).
//!!Not an ideal uniform distribution.!!
double rand_percent() {
  return ((double)rand())/((double)RAND_MAX + 1.0);
}

//WRS data structure
typedef struct wrs_data {
  size_t N;      //number of items
  double *share; //fractional share of hits this bucket keeps
  size_t *pair;  //bucket that receives the remaining hits
} wrs_t;

//Create the pre-processed data structure, allowing fast weighted selection.
//Weights do not have to be normalized; that is done here first.
wrs_t* wrs_create(double* weights, size_t N) {
  //make space
  wrs_t* data = malloc(sizeof(wrs_t));
  data->share = malloc(N * sizeof(double));
  data->pair = malloc(N * sizeof(size_t));
  data->N = N;

  double sum = 0;
  size_t i, j, k;

  //Normalize and find what fraction of the ideal (flat) distribution is in each bucket.
  //Set each bucket's initial partner to itself. This acts as an 'unprocessed' marker and
  //handles small rounding errors: if there are no big buckets left, excess goes back to self.
  for (i=0; i<N; i++) { sum += weights[i]; }
  for (i=0; i<N; i++) {
    data->share[i] = weights[i] / (sum/N);
    data->pair[i] = i;
  }

  //Find first overpopulated bucket
  for (j=0; j<N && !(data->share[j] > 1.0); j++) {/*seek*/}
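  //Pairing pass: for each bucket that is still its own partner, if it is under-full
  //(share < 1.0), point it at the current over-full bucket j and shrink j's share by
  //the deficit. If j drops to 1.0 or below in the process, j itself becomes the next
  //bucket to fill, and the next over-full bucket is sought to supply its excess.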
  for (i=0; i<N; i++) {
    k = i;                            // k is the bucket under consideration
    if (data->pair[k]!=i) continue;   // skip buckets that were already paired up
    //If this bucket holds fewer samples than a flat distribution,
    //its index will be selected more often than its weight warrants.
    double excess = 1.0 - data->share[k];
    while (excess > 0) {
      if (j == N) { break; }          // no more over-full partners; close enough.
      printf("moving %.5f from %zu to %zu\n", excess, k, j);
      data->pair[k] = j;              // send the excess hits to another bucket
      data->share[j] -= excess;       // account for its increased selection rate
      excess = 1.0 - data->share[j];
      //If the new bucket is now at or below 1.0, fill it from the next over-full bucket
      if (excess >= 0) {
        for (k=j++; j<N && !(data->share[j] > 1.0); j++) {/*seek*/}
      }
    }
  }
  return data;
}

//O(1) weighted random sampling.
//Choose a random real number in [0,N). Treat it as a distance into the array.
//If the fractional part is greater than that bucket's share, use its paired bucket instead.
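//For example, with the tables built from {2,4,5,6,3} above, a draw of pct = 2.9
//lands in bucket 2 (C); the fractional part 0.9 exceeds share[2] = 0.75, so the
//paired bucket pair[2] = 3 (D) is returned instead.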
size_t wrs_pick(wrs_t* data)
{
  double pct = rand_percent() * data->N;
  size_t idx = (size_t)pct;
  if (pct - idx > data->share[idx]) { idx = data->pair[idx]; }
  return idx;
}
//Clean up nicely
void wrs_destroy(wrs_t* data){
  free(data->pair);
  free(data->share);
  free(data);
}

//Util to back out each bucket's share of a flat distribution (N * weight[i]/sum)
//from the pre-processed data; divide by N to recover the normalized weights.
double* wrs_norm(wrs_t* data) {
  double* pct = calloc(data->N, sizeof(double));
  size_t i;
  for (i=0; i<data->N; i++) {
    pct[i] += data->share[i];
    pct[data->pair[i]] += 1.0 - data->share[i];
  }
  return pct;
}
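//For the example weights {2,4,5,6,3}, wrs_norm() should return
//{0.50, 1.00, 1.25, 1.50, 0.75}; dividing by N=5 gives the normalized
//weights {0.10, 0.20, 0.25, 0.30, 0.15}, i.e. weight[i]/sum.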
/** sample usage **/
//double weights[]= {20,1,4,10,15,10,16,10,8,6};
double weights[]= {2,4,5,6,3};
#define NW (sizeof(weights)/sizeof(weights[0]))
#define TRIALS 1000

int main(int argc, char* argv[]){
  //pre-process input data
  wrs_t* dist = wrs_create(weights, NW);

  //show normalized weights
  int i;
  double* d = wrs_norm(dist);
  printf("\n");
  for (i=0; i<NW; i++) {
    printf("%.5f ", d[i]/NW);
  }
  free(d);

  //build new histogram
  int samples[NW] = {0};
  for (i=0; i<TRIALS; i++){
    samples[wrs_pick(dist)]++;
  }

  //show that the generated data matches
  printf("\n--------------\n");
  for (i=0; i<NW; i++) { printf("%.5f ", (double)samples[i]/TRIALS); }
  printf("\n");

  //cleanup
  wrs_destroy(dist);
  return 0;
}
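//To try it out (assuming a C99 compiler such as gcc): compile with e.g.
//  gcc -O2 -o wrs WRS.c
//and run ./wrs; the printed normalized weights and the sampled frequencies
//over TRIALS draws should roughly agree.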
Thank you for this :) It's a very clever way to do it. I had a go at this and, while it works, my implementation doesn't perform faster than a linear search. Maybe I'm doing something wrong? I've posted about it here if you are interested: https://stackoverflow.com/questions/62391780/walkers-alias-method-for-weighted-random-selection-isnt-faster-than-a-linear-se