
Trusted answers to developer questions

What is the weighted random selection algorithm?

Ani Tumanyan


Problem statement

We are given an array of positive integers representing the weights of their respective indices. We need to implement an algorithm that selects indices randomly with respect to their weights.

Perform weighted random selection of indices

Let:

  • $w$ be an array of weights
  • $sum(w)$ be the sum of the weights
  • $length(w)$ be the length of the array

Here, we need to implement the pickIndex() function that returns an index $i \in [0, length(w) - 1]$ with the probability $p_i = w_i / sum(w)$.

pickIndex() can be called multiple times on the given distribution of weights.

Example

For $w = \{2, 1, 4, 5\}$, pickIndex() must have:

  • $\frac{2}{2+1+4+5} = \frac{2}{12} = \frac{1}{6} \approx 16.7\%$ chance of picking index 0

  • $\frac{1}{2+1+4+5} = \frac{1}{12} \approx 8.3\%$ chance of picking index 1

  • $\frac{4}{12} = \frac{1}{3} \approx 33.3\%$ chance of picking index 2

  • $\frac{5}{12} \approx 41.7\%$ chance of picking index 3
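
The percentages above can be reproduced with a few lines of C++. This is only a minimal sketch: the helper name `selectionProbabilities` is hypothetical and is not part of the implementation shown later in this answer.

```cpp
#include <cassert>
#include <numeric>
#include <vector>

// Hypothetical helper (an assumption for illustration):
// returns p_i = w_i / sum(w) for every index i.
std::vector<double> selectionProbabilities(const std::vector<int>& w) {
    assert(!w.empty());
    // Accumulate in long long so that large weights do not overflow int.
    const long long total = std::accumulate(w.begin(), w.end(), 0LL);
    assert(total > 0);
    std::vector<double> p;
    p.reserve(w.size());
    for (int wi : w) {
        p.push_back(static_cast<double>(wi) / static_cast<double>(total));
    }
    return p;
}
```

Calling `selectionProbabilities({2, 1, 4, 5})` yields the four fractions listed above, and they sum to 1.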

Algorithm

First, we preprocess the array of weights and calculate the prefix sums. Let $s$ be the array of prefix sums, where $s_i = \sum_{j=0}^{i} w_j$.
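
As a minimal sketch of this preprocessing step, the function below builds the prefix sums with a single pass. The name `prefixSums` and the use of `long long` for the running total are assumptions made for this illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper (an assumption for illustration):
// returns s where s[i] = w[0] + w[1] + ... + w[i].
std::vector<long long> prefixSums(const std::vector<int>& w) {
    assert(!w.empty());
    std::vector<long long> s(w.size());
    long long running = 0;  // wider than int to reduce the risk of overflow
    for (std::size_t i = 0; i < w.size(); ++i) {
        running += w[i];
        s[i] = running;
    }
    return s;
}
```

For the weights 2, 1, 4, 5 this produces the prefix sums 2, 3, 7, 12, and the last element equals the sum of all weights.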

For random selection with particular weights, the following technique can then be used:

  1. Generate a random integer $x$ uniformly between $0$ and $sum(w) - 1$, inclusive.
  2. Find the smallest index whose prefix sum is strictly greater than $x$.

Index 2 would be selected with probability $w_2 / sum(w)$

Because the weights are positive integers, the array of prefix sums is strictly increasing, so we can use a binary search for the second step.
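
The lookup on its own can be sketched without any randomness, which makes it easy to check by hand. The helper name `indexForValue` is hypothetical. It relies on `std::upper_bound`, which returns an iterator to the first element strictly greater than the given value.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper (an assumption for illustration): maps a value x in
// [0, s.back() - 1] to the smallest index i such that s[i] > x.
std::size_t indexForValue(const std::vector<long long>& s, long long x) {
    assert(!s.empty());
    assert(x >= 0 && x < s.back());
    // Binary search over the sorted prefix sums.
    auto it = std::upper_bound(s.begin(), s.end(), x);
    return static_cast<std::size_t>(it - s.begin());
}
```

With the prefix sums 2, 3, 7, 12, the values 0 and 1 map to index 0, the value 2 maps to index 1, the values 3 through 6 map to index 2, and the values 7 through 11 map to index 3.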

Hence, we detect the $i$-th interval of prefix sums in which the randomly chosen value lies. The length of the $i$-th interval is $w_i$. Therefore, the probability that the $i$-th index is selected is $\frac{w_i}{sum(w)}$.
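
As a worked instance of this argument, take $w = \{2, 1, 4, 5\}$, so that $s = (2, 3, 7, 12)$ and $x$ is drawn uniformly from $\{0, 1, \dots, 11\}$. The convention $s_{-1} = 0$ is an assumption introduced here to cover the first interval.

```latex
\begin{align*}
P(i = 0) &= P(0 \le x < 2) = \frac{2}{12} \\
P(i = 1) &= P(2 \le x < 3) = \frac{1}{12} \\
P(i = 2) &= P(3 \le x < 7) = \frac{4}{12} \\
P(i = 3) &= P(7 \le x < 12) = \frac{5}{12} \\
P(i) &= \frac{s_i - s_{i-1}}{sum(w)} = \frac{w_i}{sum(w)}, \qquad s_{-1} = 0
\end{align*}
```

Each interval contains exactly $w_i$ of the 12 equally likely values of $x$, which matches the probabilities in the example.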

Code

#include <iostream>
#include <vector>
#include <algorithm>
#include <iterator>
#include <random>
using namespace std;

class WeightedRandomGenerator {
public:
    WeightedRandomGenerator(const vector<int>& w):
        randomGen_{random_device{}()}
    {
        if(w.empty())
            return;
        //calculate prefix sums for given weights
        prefixSums_.resize(w.size());
        prefixSums_[0] = w[0];
        for(size_t i = 1; i < w.size(); ++i)
        {
            prefixSums_[i] = prefixSums_[i-1] + w[i];
        }
    }

    int pickIndex()
    {
        //no weights were given, so there is no index to pick
        if(prefixSums_.empty())
            return -1;
        //get sum of all weights
        int totalSum = prefixSums_.back();
        //get random value between 0 and totalSum - 1
        int value = getRandom(0, totalSum - 1);
        //get index of the first element in prefix sums greater than value
        auto it = upper_bound(prefixSums_.begin(), prefixSums_.end(), value);
        return static_cast<int>(distance(prefixSums_.begin(), it));
    }

private:
    int getRandom(int left, int right)
    {
        //C++11's random number generation facilities
        std::uniform_int_distribution<> distrib(left, right);
        return distrib(randomGen_);
    }

private:
    vector<int> prefixSums_;
    std::mt19937 randomGen_;
};

int main()
{
    //create vector with specific weights distribution
    vector<int> w = {2, 1, 4, 5};
    //create generator object
    WeightedRandomGenerator generator(w);
    //test generator: calculate frequencies of choosing each index during 100 calls
    vector<int> counts(w.size());
    for(int i = 0; i < 100; ++i)
    {
        ++counts[generator.pickIndex()];
    }
    //output testing results
    for(size_t j = 0; j < counts.size(); ++j)
    {
        cout << "Index " << j << " is chosen " << counts[j] << " times." << endl;
    }
    return 0;
}
Implementation of weighted random selection algorithm in C++

Complexity analysis

Let $n = length(w)$ be the number of weights. The preprocessing step (construction of the WeightedRandomGenerator object) has $O(n)$ time complexity, because we iterate over the weights once to calculate the prefix sums.

pickIndex() searches for a random number in the sorted array of prefix sums using a binary search. Therefore, the time complexity of each pickIndex() call is logarithmic: $O(\log n)$.

The space complexity is $O(n)$ for storing the array of prefix sums.
