From 21ef37a7a4439125a3bfff9db9a1ed260f897f26 Mon Sep 17 00:00:00 2001 From: Makihiro Date: Mon, 1 Feb 2021 23:12:42 +0900 Subject: [PATCH] Update WeightedSelectMethods - Added the WeightedSelectMethodTests. - All WeightedSelectMethods now support weights less than or equal to 0. - IWeightedSelectMethod now takes a TemporaryArray as an argument instead of a float[]. - Fixed a fatal bug that prevented BinaryWeightedSelectMethod from working. --- .../BinaryWeightedSelectMethod.cs | 62 ++++---- .../IWeightedSelectMethod.cs | 3 +- .../LinearWeightedSelectMethod.cs | 21 +-- .../WeightedSelector/WeightedSelector.cs | 3 +- .../Tests/Editor/EnumerableConversionTests.cs | 40 ++--- .../Tests/Editor/ItemEnumerableGenerator.cs | 49 +++++++ .../Editor/ItemEnumerableGenerator.cs.meta | 11 ++ .../Tests/Editor/WeightedSelectMethodTests.cs | 52 +++++++ .../Editor/WeightedSelectMethodTests.cs.meta | 11 ++ README.md | 138 ++++++++++++++---- 10 files changed, 294 insertions(+), 96 deletions(-) create mode 100644 Assets/MackySoft/MackySoft.Choice/Tests/Editor/ItemEnumerableGenerator.cs create mode 100644 Assets/MackySoft/MackySoft.Choice/Tests/Editor/ItemEnumerableGenerator.cs.meta create mode 100644 Assets/MackySoft/MackySoft.Choice/Tests/Editor/WeightedSelectMethodTests.cs create mode 100644 Assets/MackySoft/MackySoft.Choice/Tests/Editor/WeightedSelectMethodTests.cs.meta diff --git a/Assets/MackySoft/MackySoft.Choice/Runtime/WeightedSelectMethod/BinaryWeightedSelectMethod.cs b/Assets/MackySoft/MackySoft.Choice/Runtime/WeightedSelectMethod/BinaryWeightedSelectMethod.cs index a71bf4e..a66daf0 100644 --- a/Assets/MackySoft/MackySoft.Choice/Runtime/WeightedSelectMethod/BinaryWeightedSelectMethod.cs +++ b/Assets/MackySoft/MackySoft.Choice/Runtime/WeightedSelectMethod/BinaryWeightedSelectMethod.cs @@ -13,40 +13,46 @@ internal sealed class BinaryWeightedSelectMethod : IWeightedSelectMethod { public static readonly BinaryWeightedSelectMethod Instance = new BinaryWeightedSelectMethod(); - public int SelectIndex (float[] weights,float value) { - if (weights == null) { - throw new ArgumentNullException(nameof(weights)); - } - - using TemporaryArray runningTotals = CumulativeSum(weights); - - float targetDistance = value * runningTotals[runningTotals.Length - 1]; - int low = 0; - int high = weights.Length; - - while (low < high) { - int mid = (int)Math.Round((low + high) / 2f); - float distance = runningTotals[mid]; - if (distance < targetDistance) { - low = mid + 1; - } else if (distance > targetDistance) { - high = mid; - } else { - return mid; + public int SelectIndex (TemporaryArray weights,float value) { + CumulativeSum(weights,out var runningTotals,out var indicies); + + using (runningTotals) + using (indicies) { + float targetDistance = value * runningTotals[runningTotals.Length - 1]; + int low = 0; + int high = runningTotals.Length; + + while (low < high) { + int mid = (int)Math.Floor((low + high) / 2f); + float distance = runningTotals[mid]; + if (distance < targetDistance) { + low = mid + 1; + } else if (distance > targetDistance) { + high = mid; + } else { + return indicies[mid]; + } } - } - return low; + return indicies[low]; + } } - static TemporaryArray CumulativeSum (float[] values) { - var results = TemporaryArray.Create(values.Length); + static void CumulativeSum (TemporaryArray weights,out TemporaryArray runningTotals,out TemporaryArray indicies) { + runningTotals = TemporaryArray.CreateAsList(weights.Length); + indicies = TemporaryArray.Create(weights.Length); float sum = 0f; - for (int i = 0;i < results.Length;i++) { - sum += results[i]; - results[i] = sum; + int nonZeroIteration = 0; + for (int i = 0;i < indicies.Length;i++) { + float weight = weights[i]; + if (weight <= 0f) { + continue; + } + sum += weight; + runningTotals.Add(sum); + indicies[nonZeroIteration] = i; + nonZeroIteration++; } - return results; } } diff --git a/Assets/MackySoft/MackySoft.Choice/Runtime/WeightedSelectMethod/IWeightedSelectMethod.cs b/Assets/MackySoft/MackySoft.Choice/Runtime/WeightedSelectMethod/IWeightedSelectMethod.cs index 5e0291b..257c3df 100644 --- a/Assets/MackySoft/MackySoft.Choice/Runtime/WeightedSelectMethod/IWeightedSelectMethod.cs +++ b/Assets/MackySoft/MackySoft.Choice/Runtime/WeightedSelectMethod/IWeightedSelectMethod.cs @@ -1,4 +1,5 @@ ο»Ώusing System; +using MackySoft.Choice.Internal; namespace MackySoft.Choice { @@ -14,7 +15,7 @@ public interface IWeightedSelectMethod { /// /// Selected index from weights. /// - int SelectIndex (float[] weights,float value); + int SelectIndex (TemporaryArray weights,float value); } diff --git a/Assets/MackySoft/MackySoft.Choice/Runtime/WeightedSelectMethod/LinearWeightedSelectMethod.cs b/Assets/MackySoft/MackySoft.Choice/Runtime/WeightedSelectMethod/LinearWeightedSelectMethod.cs index 59484bd..11683fe 100644 --- a/Assets/MackySoft/MackySoft.Choice/Runtime/WeightedSelectMethod/LinearWeightedSelectMethod.cs +++ b/Assets/MackySoft/MackySoft.Choice/Runtime/WeightedSelectMethod/LinearWeightedSelectMethod.cs @@ -1,5 +1,7 @@ ο»Ώusing System; using System.Runtime.CompilerServices; +using MackySoft.Choice.Internal; +using UnityEngine; namespace MackySoft.Choice { @@ -11,27 +13,28 @@ internal sealed class LinearWeightedSelectMethod : IWeightedSelectMethod { public static readonly LinearWeightedSelectMethod Instance = new LinearWeightedSelectMethod(); - public int SelectIndex (float[] weights,float value) { - if (weights == null) { - throw new ArgumentNullException(nameof(weights)); - } - + public int SelectIndex (TemporaryArray weights,float value) { float remainingDistance = value * Sum(weights); for (int i = 0;i < weights.Length;i++) { float weight = weights[i]; + if (weight <= 0f) { + continue; + } + remainingDistance -= weight; - if (remainingDistance < 0f) { + if (remainingDistance <= 0f) { return i; } } + return -1; } [MethodImpl(MethodImplOptions.AggressiveInlining)] - static float Sum (float[] values) { + static float Sum (TemporaryArray weights) { float result = 0f; - for (int i = 0;i < values.Length;i++) { - result += values[i]; + for (int i = 0;i < weights.Length;i++) { + result += weights[i]; } return result; } diff --git a/Assets/MackySoft/MackySoft.Choice/Runtime/WeightedSelector/WeightedSelector.cs b/Assets/MackySoft/MackySoft.Choice/Runtime/WeightedSelector/WeightedSelector.cs index bb15786..3c39bfe 100644 --- a/Assets/MackySoft/MackySoft.Choice/Runtime/WeightedSelector/WeightedSelector.cs +++ b/Assets/MackySoft/MackySoft.Choice/Runtime/WeightedSelector/WeightedSelector.cs @@ -40,7 +40,6 @@ public WeightedSelector (IEnumerable> source,IWeighted m_Method = method; } - internal WeightedSelector (TemporaryArray items,TemporaryArray weights,IWeightedSelectMethod method) { if (items.Length != weights.Length) { throw new ArgumentException(); @@ -94,7 +93,7 @@ public void Clear () { } public TItem SelectItem (float value) { - int index = m_Method.SelectIndex(m_Weights.Array,value); + int index = m_Method.SelectIndex(m_Weights,value); return (index >= 0) ? m_Items[index] : default; } diff --git a/Assets/MackySoft/MackySoft.Choice/Tests/Editor/EnumerableConversionTests.cs b/Assets/MackySoft/MackySoft.Choice/Tests/Editor/EnumerableConversionTests.cs index 41b5bdb..443d3c4 100644 --- a/Assets/MackySoft/MackySoft.Choice/Tests/Editor/EnumerableConversionTests.cs +++ b/Assets/MackySoft/MackySoft.Choice/Tests/Editor/EnumerableConversionTests.cs @@ -1,25 +1,21 @@ using System.Linq; using System.Collections.Generic; -using UnityEngine; using NUnit.Framework; +using MackySoft.Choice.Tests; namespace MackySoft.Choice.Internal.Tests { public class EnumerableConversionTests { - class Item { - public float weight; - } - [Test] public void EnumerableToTemporaryArray () { - Item[] source = GenerateEnumerable().ToArray(); + ItemEntry[] source = ItemEnumerableGenerator.GenerateEnumerable(100).ToArray(); - EnumerableConversion.EnumerableToTemporaryArray(source.Select(x => x),x => x.weight,out var items,out var weights); + EnumerableConversion.EnumerableToTemporaryArray(source.Select(x => x),x => x.item,x => x.weight,out var items,out var weights); for (int i = 0;source.Length > i;i++) { var element = source[i]; - Assert.AreEqual(element,items[i]); + Assert.AreEqual(element.item,items[i]); Assert.AreEqual(element.weight,weights[i]); } @@ -29,13 +25,13 @@ public void EnumerableToTemporaryArray () { [Test] public void ReadOnlyListToTemporaryArray () { - Item[] source = GenerateEnumerable().ToArray(); + ItemEntry[] source = ItemEnumerableGenerator.GenerateEnumerable(100).ToArray(); - EnumerableConversion.EnumerableToTemporaryArray(source,x => x.weight,out var items,out var weights); + EnumerableConversion.EnumerableToTemporaryArray(source,x => x.item,x => x.weight,out var items,out var weights); for (int i = 0;source.Length > i;i++) { var element = source[i]; - Assert.AreEqual(element,items[i]); + Assert.AreEqual(element.item,items[i]); Assert.AreEqual(element.weight,weights[i]); } @@ -45,13 +41,13 @@ public void ReadOnlyListToTemporaryArray () { [Test] public void ListToTemporaryArray () { - List source = GenerateEnumerable().ToList(); + List source = ItemEnumerableGenerator.GenerateEnumerable(100).ToList(); - EnumerableConversion.EnumerableToTemporaryArray(source,x => x.weight,out var items,out var weights); + EnumerableConversion.EnumerableToTemporaryArray(source,x => x.item,x => x.weight,out var items,out var weights); for (int i = 0;source.Count > i;i++) { var element = source[i]; - Assert.AreEqual(element,items[i]); + Assert.AreEqual(element.item,items[i]); Assert.AreEqual(element.weight,weights[i]); } @@ -61,7 +57,7 @@ public void ListToTemporaryArray () { [Test] public void DictionaryToTemporaryArray () { - Dictionary source = GenerateDictionary().ToDictionary(p => p.Key,p => p.Value); + Dictionary source = ItemEnumerableGenerator.GenerateDictionary(100).ToDictionary(p => p.Key,p => p.Value); EnumerableConversion.DictionaryToTemporaryArray(source,out var items,out var weights); @@ -78,7 +74,7 @@ public void DictionaryToTemporaryArray () { [Test] public void DictionaryEnumerableToTemporaryArray () { - Dictionary source = GenerateDictionary().ToDictionary(p => p.Key,p => p.Value); + Dictionary source = ItemEnumerableGenerator.GenerateDictionary(100).ToDictionary(p => p.Key,p => p.Value); EnumerableConversion.DictionaryToTemporaryArray(source.Select(x => x),out var items,out var weights); @@ -93,17 +89,5 @@ public void DictionaryEnumerableToTemporaryArray () { weights.Dispose(); } - static IEnumerable GenerateEnumerable () { - for (int i = 0;100 > i;i++) { - yield return new Item { weight = Random.value }; - } - } - - static IEnumerable> GenerateDictionary () { - for (int i = 0;100 > i;i++) { - yield return new KeyValuePair(new Item(),Random.value); - } - } - } } \ No newline at end of file diff --git a/Assets/MackySoft/MackySoft.Choice/Tests/Editor/ItemEnumerableGenerator.cs b/Assets/MackySoft/MackySoft.Choice/Tests/Editor/ItemEnumerableGenerator.cs new file mode 100644 index 0000000..8995bd1 --- /dev/null +++ b/Assets/MackySoft/MackySoft.Choice/Tests/Editor/ItemEnumerableGenerator.cs @@ -0,0 +1,49 @@ +using System.Collections.Generic; +using UnityEngine; + +namespace MackySoft.Choice.Tests { + + public class Item { + + public int id; + + public override string ToString () { + return "Index: " + id.ToString(); + } + + } + + public struct ItemEntry { + + public Item item; + public float weight; + + public override string ToString () { + return $"{{ {item}: Weight {weight} }}"; + } + + } + + public class ItemEnumerableGenerator { + + public static IEnumerable GenerateEnumerable (int count) { + for (int i = 0;count > i;i++) { + yield return new ItemEntry { + item = new Item { id = i }, + weight = Random.Range(0,10) + }; + } + } + + public static IEnumerable> GenerateDictionary (int count) { + for (int i = 0;count > i;i++) { + yield return new KeyValuePair( + new Item { id = i }, + Random.Range(0,10) + ); + } + } + + } + +} \ No newline at end of file diff --git a/Assets/MackySoft/MackySoft.Choice/Tests/Editor/ItemEnumerableGenerator.cs.meta b/Assets/MackySoft/MackySoft.Choice/Tests/Editor/ItemEnumerableGenerator.cs.meta new file mode 100644 index 0000000..3739019 --- /dev/null +++ b/Assets/MackySoft/MackySoft.Choice/Tests/Editor/ItemEnumerableGenerator.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: d4d9565b73faebf408b7578a4cc2f41e +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/Assets/MackySoft/MackySoft.Choice/Tests/Editor/WeightedSelectMethodTests.cs b/Assets/MackySoft/MackySoft.Choice/Tests/Editor/WeightedSelectMethodTests.cs new file mode 100644 index 0000000..242fe9f --- /dev/null +++ b/Assets/MackySoft/MackySoft.Choice/Tests/Editor/WeightedSelectMethodTests.cs @@ -0,0 +1,52 @@ +using System.Linq; +using System.Collections.Generic; +using UnityEngine; +using NUnit.Framework; + +namespace MackySoft.Choice.Tests { + public class WeightedSelectMethodTests { + + [Test] + public void Linear_ReturnValidValue ([Random(0f,1f,10)] float value) { + var source = ItemEnumerableGenerator.GenerateEnumerable(100).ToArray(); + var weightedSelector = source.ToWeightedSelector(x => x.item,x => x.weight,WeightedSelectMethod.Linear); + Assert.IsNotNull(weightedSelector.SelectItem(value)); + } + + [Test, Repeat(100)] + public void Linear_ReturnValidValue_0 () { + var source = ItemEnumerableGenerator.GenerateEnumerable(100).ToArray(); + var weightedSelector = source.ToWeightedSelector(x => x.item,x => x.weight,WeightedSelectMethod.Linear); + Assert.AreSame(source.FirstOrDefault(x => x.weight > 0f).item,weightedSelector.SelectItem(0f)); + } + + [Test, Repeat(100)] + public void Linear_ReturnValidValue_1 () { + var source = ItemEnumerableGenerator.GenerateEnumerable(100).ToArray(); + var weightedSelector = source.ToWeightedSelector(x => x.item,x => x.weight,WeightedSelectMethod.Linear); + Assert.AreSame(source.LastOrDefault(x => x.weight > 0f).item,weightedSelector.SelectItem(1f)); + } + + [Test] + public void Binary_ReturnValidValue ([Random(0f,1f,10)] float value) { + var source = ItemEnumerableGenerator.GenerateEnumerable(100).ToArray(); + var weightedSelector = source.ToWeightedSelector(x => x.item,x => x.weight,WeightedSelectMethod.Binary); + Assert.IsNotNull(weightedSelector.SelectItem(value)); + } + + [Test, Repeat(100)] + public void Binary_ReturnValidValue_0 () { + var source = ItemEnumerableGenerator.GenerateEnumerable(100).ToArray(); + var weightedSelector = source.ToWeightedSelector(x => x.item,x => x.weight,WeightedSelectMethod.Binary); + Assert.AreSame(weightedSelector.FirstOrDefault(p => p.Value > 0f).Key,weightedSelector.SelectItem(0f)); + } + + [Test, Repeat(100)] + public void Binary_ReturnValidValue_1 () { + var source = ItemEnumerableGenerator.GenerateEnumerable(100).ToArray(); + var weightedSelector = source.ToWeightedSelector(x => x.item,x => x.weight,WeightedSelectMethod.Binary); + Assert.AreSame(source.LastOrDefault(x => x.weight > 0f).item,weightedSelector.SelectItem(1f)); + } + + } +} \ No newline at end of file diff --git a/Assets/MackySoft/MackySoft.Choice/Tests/Editor/WeightedSelectMethodTests.cs.meta b/Assets/MackySoft/MackySoft.Choice/Tests/Editor/WeightedSelectMethodTests.cs.meta new file mode 100644 index 0000000..3a7585d --- /dev/null +++ b/Assets/MackySoft/MackySoft.Choice/Tests/Editor/WeightedSelectMethodTests.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 82c228eeaa8502d4d9607bf259528f11 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/README.md b/README.md index 7b5be91..1396182 100644 --- a/README.md +++ b/README.md @@ -2,6 +2,7 @@ **Created by Hiroya Aramaki ([Makihiro](https://twitter.com/makihiro_dev))** + ## What is Weighted Random Selector ? Weighted Random Selector is an algorithm for randomly selecting elements based on their weights. @@ -16,71 +17,151 @@ Choice is a library that was created to make it easier to implement. Great introduction article on Weighted Random Select: https://blog.bruce-hill.com/a-faster-weighted-random-choice + ## Table of Contents -- [Installation](#installation) -- [Usage](#usage) - - [Select Algorithm](#algorithm) -- [Author Info](#author-info) -- [License](#license) +- [πŸ“₯ Installation](#installation) +- [πŸ”° Usage](#usage) + - [ToWeightedSelector Overloads](#toweightedselector-overloads) + - [LINQ](#linq) + - [Algorithms](#algorithms) +- [πŸ“” Author Info](#author-info) +- [πŸ“œ License](#license) + -# Installation +# πŸ“₯ Installation Download any version from releases. Releases: https://github.com/mackysoft/Choice/releases -# Usage + +# πŸ”° Usage + +```cs +// To use Choice, add this namespace. +using MackySoft.Choice; + +public class WeightedItem { + public string id; + public float weight; +} + +public WeightedItem SelectItem () { + // Prepare weighted items. + var items = new WeightedItem[] { + new WeightedItem { id = "πŸ’", weight = 8f }, + new WeightedItem { id = "🍏", weight = 4f }, + new WeightedItem { id = "🍍", weight = 0f }, + new WeightedItem { id = "πŸ‡", weight = 6f }, + new WeightedItem { id = "🍊", weight = -1f } + }; + + // Create the WeightedSelector. + var weightedSelector = items.ToWeightedSelector(item => item.weight); + + // The probability of each item being selected, + // πŸ’ is 44%, 🍏 is 22%, and πŸ‡ is 33%. + // 🍍 and 🍊 will never be selected because their weights are less or equal to 0. + return weightedSelector.SelectItemWithUnityRandom(); + // Same as weightedSelector.SelectItem(UnityEngine.Random.value); +} +``` + + +## ToWeightedSelector Overloads + +The `ToWeightedSelector` method has many overloads and can be used for a variety of patterns. ```cs +public class WeightedItem { + public string id; + public float weight; +} + public class Item { public string id; - public bool enabled; - public float rarity; } -public Item[] items; +public struct ItemEntry { + public Item item; + public float weight; +} -public IEnumerable SelectItems () { - IWeightedSelector weightedSelector = items - .ToWeightedSelector(item => item.rarity); - - for (int i = 0;i < 1000;i++) { - // Same as weightedSelector.SelectItem(UnityEngine.Random.value) - Item randomSelectedItem = weightedSelector.SelectItemWithUnityRandom(); - yield return randomSelectedItem; - } +public IWeightedSelector WeightedItemPattern () { + var items = new WeightedItem[] { + new WeightedItem { id = "πŸ’", weight = 1f }, + new WeightedItem { id = "🍏", weight = 5f }, + new WeightedItem { id = "🍍", weight = 3f } + }; + + // Create a WeightedSelector using the weight of the WeightedItem. + return fromWeightedItem = items.ToWeightedSelector(weightSelector: item => item.weight); +} + +public IWeightedSelector WeightedEntryPattern () { + var entries = new ItemEntry[] { + new ItemEntry { item = new Item { id = "πŸ’" }, weight = 1f }, + new ItemEntry { item = new Item { id = "🍏" }, weight = 5f }, + new ItemEntry { item = new Item { id = "🍍" }, weight = 3f } + }; + + // Create a WeightedSelector by selecting item and weight from entry respectively. + return entries.ToWeightedSelector( + itemSelector: entry => entry.item, + weightSelector: entry => entry.weight + ); +} + + +public IWeightedSelector DictionaryPattern () { + // This need a Dictionary. (Strictly speaking, IEnumerable>) + var dictionary = new Dictionary( + { new Item { id = "πŸ’" }, 1f }, + { new Item { id = "🍏" }, 5f }, + { new Item { id = "🍍" }, 3f } + ); + + // Create a WeightedSelector with the dictionary key as item and value as weight. + return dictionary.ToWeightedSelector(); } ``` -Since the `ToWeightedSelector` function is defined as an extension of `IEnumerable`, it can be connected from the LINQ syntax. + +## LINQ + +Since the `ToWeightedSelector` method is defined as an extension of `IEnumerable`, it can be connected from the LINQ query operators. ```cs -items - .Where(item => (item != null) && item.enabled) - .ToWeightedSelector(weightSelector: item => item.rarity) +var randomSelectedItem = items + .Where(item => item != null) // null check + .ToWeightedSelector(item => item.weight) .SelectItemWithUnityRandom(); ``` -## Select Algorithm +## Algorithms When creating a WeightedSelector, you can specify the `IWeightedSelectMethod`. ```cs var weightedSelector = items.ToWeightedSelector( - weightSelector: item => item.rarity, - method: WeightedSelectMethod.Binary // Use the binary search algorithm. + item => item.weight, + WeightedSelectMethod.Binary // Use the binary search algorithm. ); ``` +All `ToWeightedSelector` methods can specify `IWeightedSelectMethod`. + If this is not specified, the linear scan algorithm will be used automatically. + ### Linear Scan (`WeightedSelectMethod.Linear`) The simplest algorithm that walks linearly along the weights. This method is an `O(n)` operation, where `n` is number of weights. + ### Binary Search (`WeightedSelectMethod.Binary`) The binary search algorithm that is faster than linear scan by preprocessing to store the current sum of weights. @@ -88,13 +169,14 @@ The binary search algorithm that is faster than linear scan by preprocessing to It has an additional storage cost of `O(n)`, but is accelerated by up to `O(log(n))` for each selection, where `n` is number of weights. -# Author Info +# πŸ“” Author Info Hiroya Aramaki is a indie game developer in Japan. - Blog: [https://mackysoft.net/blog](https://mackysoft.net/blog) - Twitter: [https://twitter.com/makihiro_dev](https://twitter.com/makihiro_dev) -# License + +# πŸ“œ License This library is under the MIT License.