
What is the weighted random selection algorithm?


Problem statement

We are given an array of positive integers, where each element is the weight of its index. We need to implement an algorithm that selects an index at random with probability proportional to its weight.

Perform weighted random selection of indices

Let:

  • $w$ be the array of weights
  • $\text{sum}(w)$ be the sum of the weights
  • $\text{length}(w)$ be the length of the array

Here, we need to implement the pickIndex() function that returns an index $i \in [0, \text{length}(w) - 1]$ with probability $p_i = w_i / \text{sum}(w)$.

pickIndex() can be called multiple times on the given distribution of weights.

Example

For $w = \{2, 1, 4, 5\}$, pickIndex() must have:

  • $\frac{2}{2+1+4+5} = \frac{2}{12} = \frac{1}{6}$ ≈ 16.7% chance of picking index 0

  • $\frac{1}{2+1+4+5} = \frac{1}{12}$ ≈ 8.3% chance of picking index 1

  • $\frac{4}{12} = \frac{1}{3}$ ≈ 33.3% chance of picking index 2

  • $\frac{5}{12}$ ≈ 41.7% chance of picking index 3

Algorithm

First, we preprocess the array of weights and calculate its prefix sums. Let $s$ be the array of prefix sums, where $s_i = \sum_{j=0}^{i} w_j$.

For random selection with particular weights, the following technique can then be used:

  1. Generate a uniformly random integer $x$ between $0$ and $\text{sum}(w) - 1$, inclusive.
  2. Find the smallest index $i$ whose prefix sum $s_i$ is greater than $x$.
Index 2 would be selected with probability $w_2/\text{sum}(w)$

We can use binary search for the second step because the array of prefix sums is sorted; since all weights are positive, it is in fact strictly increasing.

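In other words, we find the $i$-th interval of prefix sums that contains the random value $x$. Interval $i$ covers the integers from $s_{i-1}$ to $s_i - 1$ (taking $s_{-1} = 0$), so it contains exactly $w_i$ integers. Since $x$ is uniform over the $\text{sum}(w)$ integers in $[0, \text{sum}(w) - 1]$, the probability that index $i$ is selected is $\frac{w_i}{\text{sum}(w)}$.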

Code

#include <iostream>
#include <vector>
#include <algorithm>
#include <iterator>
#include <random>
using namespace std;

class WeightedRandomGenerator {
public:
    WeightedRandomGenerator(const vector<int>& w):
        randomGen_{random_device{}()}
    {
        if(w.empty())
            return;
        //calculate prefix sums for given weights
        prefixSums_.resize(w.size());
        prefixSums_[0] = w[0];
        for(size_t i = 1; i < w.size(); ++i)
        {
            prefixSums_[i] = prefixSums_[i-1] + w[i];
        }
    }

    int pickIndex()
    {
        if(prefixSums_.empty())
            return -1;
        //get sum of all weights
        int totalSum = prefixSums_.back();
        //get random value between 0 and totalSum - 1
        int value = getRandom(0, totalSum - 1);
        //get index of the first element in prefix sums greater than value
        auto it = upper_bound(prefixSums_.begin(), prefixSums_.end(), value);
        return static_cast<int>(distance(prefixSums_.begin(), it));
    }

private:
    int getRandom(int left, int right)
    {
        //C++11's random number generation facilities: uniform integer in [left, right]
        std::uniform_int_distribution<> distrib(left, right);
        return distrib(randomGen_);
    }

private:
    vector<int> prefixSums_;
    std::mt19937 randomGen_;
};

int main()
{
    //create vector with specific weights distribution
    vector<int> w = {2, 1, 4, 5};
    //create generator object
    WeightedRandomGenerator generator(w);
    //test generator: count how often each index is chosen over 100 calls
    vector<int> counts(w.size());
    for(int i = 0; i < 100; ++i)
    {
        ++counts[generator.pickIndex()];
    }
    //output testing results
    for(size_t j = 0; j < counts.size(); ++j)
    {
        cout << "Index " << j << " is chosen " << counts[j] << " times." << endl;
    }
    return 0;
}

Complexity analysis

Let $n = \text{length}(w)$ be the number of weights. The preprocessing step (constructing the WeightedRandomGenerator object) takes $O(n)$ time, because we iterate over the weights once and compute each prefix sum in constant time from the previous one.

pickIndex() locates a random value in the sorted array of prefix sums using binary search. Therefore, each pickIndex() call takes logarithmic time, $O(\log n)$.

The space complexity is $O(n)$ for storing the prefix sums array.
