Last active
December 22, 2015 10:59
-
-
Save dylon/6462512 to your computer and use it in GitHub Desktop.
Demonstrates how to randomly assign elements to classes, where each class may have a different weight (all weights must be non-negative and sum to 1.0)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import java.util.EnumMap;
import java.util.HashMap;
import java.util.Locale;
import java.util.Map;
/**
 * Demonstrates weighted random classification: {@link #classify()} picks one
 * {@link Classification} with probability equal to its weight, and {@link #main(String[])}
 * draws many samples and prints the observed frequency of each class next to its weight.
 */
public class WeightedRandomTest {
    /** Number of random draws used to build the histogram. */
    private static final int NUM_CLASSIFICATIONS = 1_000_000;

    /**
     * Draws a classification at random, weighted by {@link Classification#getWeight()}.
     *
     * @return the chosen classification; never {@code null}
     */
    private static Classification classify() {
        return classify(Math.random());
    }

    /**
     * Maps a sample in {@code [0, 1)} to a classification by subtracting each class's weight,
     * in declaration order, until the remainder is no longer positive.
     *
     * @param sample a value in {@code [0, 1)}, such as one returned by {@link Math#random()}
     * @return the classification whose cumulative-weight interval contains {@code sample};
     *     never {@code null}
     */
    private static Classification classify(final double sample) {
        final Classification[] classifications = Classification.values();
        double remaining = sample;
        for (final Classification classification : classifications) {
            remaining -= classification.getWeight();
            if (remaining <= 0.0) {
                return classification;
            }
        }
        // Floating-point rounding can leave a tiny positive remainder even when the weights
        // sum to 1.0 (e.g. ten weights of 0.1 add up to 0.9999999999999999). Returning null
        // here would make main() throw a NullPointerException when it unboxes the histogram
        // count, so the leftover probability mass is assigned to the last class.
        return classifications[classifications.length - 1];
    }

    /**
     * Draws {@link #NUM_CLASSIFICATIONS} random classifications and prints, for each class,
     * its configured weight alongside the fraction of draws that selected it.
     *
     * @param args ignored
     */
    public static void main(final String[] args) {
        final Map<Classification, Integer> histogram = new EnumMap<>(Classification.class);
        for (final Classification classification : Classification.values()) {
            histogram.put(classification, 0);
        }
        for (int i = 0; i < NUM_CLASSIFICATIONS; ++i) {
            final Classification classification = classify();
            histogram.put(classification, 1 + histogram.get(classification));
        }
        for (final Classification classification : Classification.values()) {
            // Locale.ROOT keeps the decimal separator a '.', and %n emits the platform newline.
            System.out.printf(Locale.ROOT, "%s{expected=%.5f, actual=%.5f}%n",
                    classification,
                    classification.getWeight(),
                    (double) histogram.get(classification) / NUM_CLASSIFICATIONS);
        }
    }

    /** The classes an element can be assigned to, each with its selection probability. */
    private enum Classification {
        TRAIN(0.60),
        TEST(0.20),
        CROSS_VALIDATION(0.20);

        /** Selection probability; weights are expected to be non-negative and sum to 1.0. */
        private final double weight;

        Classification(final double weight) {
            this.weight = weight;
        }

        /** Returns the probability with which {@link #classify()} selects this class. */
        public double getWeight() {
            return weight;
        }
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Sample output: