From a9c4798ddb2f5ed52c7bcdac6b83510f62996795 Mon Sep 17 00:00:00 2001 From: Bardur Arantsson Date: Mon, 15 Jul 2013 22:17:36 +0200 Subject: Make Weighted Random Selection algorithm higher level --- src/skills.cc | 79 +++++++++++++++++++++++------------------------------------ 1 file changed, 30 insertions(+), 49 deletions(-) diff --git a/src/skills.cc b/src/skills.cc index 221a5e27..b5b90d77 100644 --- a/src/skills.cc +++ b/src/skills.cc @@ -12,12 +12,12 @@ #include "angband.h" -#include #include "hooks.h" #include "util.hpp" #include #include +#include #include #include #include @@ -1240,19 +1240,18 @@ void init_skill(s32b value, s32b mod, int i) } /* - * Perform weighted random selection without replacement according to - * the algorithm given in "Weighted Random Sampling" (2005, Eframidis, - * Spirakis) + * Perform weighted random shuffle according to the algorithm given in + * "Weighted Random Sampling" (2005, Eframidis, Spirakis). * - * @param k is the total number of items to choose. This MUST be smaller than or equal to the number of weights. - * @param weights is the array of weights. - * @return an output vector of size k containing the chosen indices. + * @param weights is the vector of weights. + * @return an output vector of the same size as the input weights vector containing, + * in order of selection, the indices to select. For example, if you + * need to choose two items, you would use v[0], v[1] as the indices + * to pick. */ -static std::vector wrs_without_replacement(size_t k, const std::vector &unscaled_weights) +static std::vector wrs(const std::vector &unscaled_weights) { - size_t n = unscaled_weights.size(); - - assert(k <= n); + const size_t n = unscaled_weights.size(); /* Rescale weights into unit interval for numerical stability */ std::vector weights(unscaled_weights.size()); @@ -1270,47 +1269,30 @@ static std::vector wrs_without_replacement(size_t k, const std::vector keys(unscaled_weights.size()); + /* Generate the keys and indexes to use for selection. This + is the only randomized portion of the algorithm. */ + std::vector> keys_and_indexes(unscaled_weights.size()); for (size_t i = 0; i < n; i++) { - double u = ((double) rand_int(100000)) / ((double) 100000); - keys[i] = pow(u, 1/weights[i]); + /* Randomized keys according to the algorithm. */ + double u = static_cast(rand_int(100000)) / 100000; + double k = std::pow(u, 1/weights[i]); + /* Set up the key and index. We negate the k value + here so that keys will be sorted in descending + order rather than ascending order. */ + keys_and_indexes[i] = std::make_tuple(-k, i); } - /* Generate the initial permutation */ - std::vector permutation(unscaled_weights.size()); - for (size_t i = 0; i < n; i++) { - permutation[i] = i; - } + /* Sort indexes according to keys. Since the keys have been + negated and we're using a lexicographical sort, we're + effectively sorting in descending order of key. */ + std::sort(std::begin(keys_and_indexes), + std::end(keys_and_indexes)); - /* Select the k indexes with the largest keys */ + /* Produce the output vector consisting of indexes only. */ std::vector indexes; - for (size_t i = 0; i < k; i++) { - /* Find maximal value and its index */ - size_t max_idx = i; - double max_value = keys[max_idx]; - for (size_t j = i + 1; j < n; j++) { - if (keys[j] > max_value) { - max_idx = j; - max_value = keys[j]; - } - } - - /* Swap into k'th position */ - if (max_idx != i) { - /* Swap keys */ - std::swap(keys[i], keys[max_idx]); - /* Swap indexes in permutation */ - std::swap(permutation[i], permutation[max_idx]); - } - - /* Output the k'th choice. We can do this already - since we'll never revisit the i'th position in - permutation vector. */ - indexes.push_back(permutation[i]); + for (auto const &key_and_index: keys_and_indexes) { + indexes.push_back(std::get<1>(key_and_index)); } - return indexes; } @@ -1341,9 +1323,8 @@ void do_get_new_skill() weights.push_back(s_info[available_skills[i]].random_gain_chance); } - std::vector indexes = - wrs_without_replacement(LOST_SWORD_NSKILLS, weights); - assert(indexes.size() == LOST_SWORD_NSKILLS); + std::vector indexes = wrs(weights); + assert(indexes.size() >= LOST_SWORD_NSKILLS); /* Extract the information needed from the skills */ for (i = 0; i < LOST_SWORD_NSKILLS; i++) -- cgit v1.2.3