Skip to content

Commit

Permalink
Update WeightedSelectMethods
Browse files Browse the repository at this point in the history
- Added the WeightedSelectMethodTests.
- All WeightedSelectMethods now support weights less than or equal to 0.
- IWeightedSelectMethod now takes a TemporaryArray<float> as an argument instead of a float[].
- Fixed a fatal bug that prevented BinaryWeightedSelectMethod from working.
  • Loading branch information
mackysoft committed Feb 1, 2021
1 parent 98091d3 commit 21ef37a
Show file tree
Hide file tree
Showing 10 changed files with 294 additions and 96 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,40 +13,46 @@ internal sealed class BinaryWeightedSelectMethod : IWeightedSelectMethod {

public static readonly BinaryWeightedSelectMethod Instance = new BinaryWeightedSelectMethod();

public int SelectIndex (float[] weights,float value) {
if (weights == null) {
throw new ArgumentNullException(nameof(weights));
}

using TemporaryArray<float> runningTotals = CumulativeSum(weights);

float targetDistance = value * runningTotals[runningTotals.Length - 1];
int low = 0;
int high = weights.Length;

while (low < high) {
int mid = (int)Math.Round((low + high) / 2f);
float distance = runningTotals[mid];
if (distance < targetDistance) {
low = mid + 1;
} else if (distance > targetDistance) {
high = mid;
} else {
return mid;
public int SelectIndex (TemporaryArray<float> weights,float value) {
CumulativeSum(weights,out var runningTotals,out var indicies);

using (runningTotals)
using (indicies) {
float targetDistance = value * runningTotals[runningTotals.Length - 1];
int low = 0;
int high = runningTotals.Length;

while (low < high) {
int mid = (int)Math.Floor((low + high) / 2f);
float distance = runningTotals[mid];
if (distance < targetDistance) {
low = mid + 1;
} else if (distance > targetDistance) {
high = mid;
} else {
return indicies[mid];
}
}
}

return low;
return indicies[low];
}
}

static TemporaryArray<float> CumulativeSum (float[] values) {
var results = TemporaryArray<float>.Create(values.Length);
static void CumulativeSum (TemporaryArray<float> weights,out TemporaryArray<float> runningTotals,out TemporaryArray<int> indicies) {
runningTotals = TemporaryArray<float>.CreateAsList(weights.Length);
indicies = TemporaryArray<int>.Create(weights.Length);
float sum = 0f;
for (int i = 0;i < results.Length;i++) {
sum += results[i];
results[i] = sum;
int nonZeroIteration = 0;
for (int i = 0;i < indicies.Length;i++) {
float weight = weights[i];
if (weight <= 0f) {
continue;
}
sum += weight;
runningTotals.Add(sum);
indicies[nonZeroIteration] = i;
nonZeroIteration++;
}
return results;
}

}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System;
using MackySoft.Choice.Internal;

namespace MackySoft.Choice {

Expand All @@ -14,7 +15,7 @@ public interface IWeightedSelectMethod {
/// </param>
/// <returns> Selected index from weights. </returns>
/// <exception cref="ArgumentNullException"></exception>
int SelectIndex (float[] weights,float value);
int SelectIndex (TemporaryArray<float> weights,float value);

}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
using System;
using System.Runtime.CompilerServices;
using MackySoft.Choice.Internal;
using UnityEngine;

namespace MackySoft.Choice {

Expand All @@ -11,27 +13,28 @@ internal sealed class LinearWeightedSelectMethod : IWeightedSelectMethod {

public static readonly LinearWeightedSelectMethod Instance = new LinearWeightedSelectMethod();

public int SelectIndex (float[] weights,float value) {
if (weights == null) {
throw new ArgumentNullException(nameof(weights));
}

public int SelectIndex (TemporaryArray<float> weights,float value) {
float remainingDistance = value * Sum(weights);
for (int i = 0;i < weights.Length;i++) {
float weight = weights[i];
if (weight <= 0f) {
continue;
}

remainingDistance -= weight;
if (remainingDistance < 0f) {
if (remainingDistance <= 0f) {
return i;
}
}

return -1;
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
static float Sum (float[] values) {
static float Sum (TemporaryArray<float> weights) {
float result = 0f;
for (int i = 0;i < values.Length;i++) {
result += values[i];
for (int i = 0;i < weights.Length;i++) {
result += weights[i];
}
return result;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ public WeightedSelector (IEnumerable<KeyValuePair<TItem,float>> source,IWeighted
m_Method = method;
}


internal WeightedSelector (TemporaryArray<TItem> items,TemporaryArray<float> weights,IWeightedSelectMethod method) {
if (items.Length != weights.Length) {
throw new ArgumentException();
Expand Down Expand Up @@ -94,7 +93,7 @@ public void Clear () {
}

public TItem SelectItem (float value) {
int index = m_Method.SelectIndex(m_Weights.Array,value);
int index = m_Method.SelectIndex(m_Weights,value);
return (index >= 0) ? m_Items[index] : default;
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,25 +1,21 @@
using System.Linq;
using System.Collections.Generic;
using UnityEngine;
using NUnit.Framework;
using MackySoft.Choice.Tests;

namespace MackySoft.Choice.Internal.Tests {

public class EnumerableConversionTests {

class Item {
public float weight;
}

[Test]
public void EnumerableToTemporaryArray () {
Item[] source = GenerateEnumerable().ToArray();
ItemEntry[] source = ItemEnumerableGenerator.GenerateEnumerable(100).ToArray();

EnumerableConversion.EnumerableToTemporaryArray(source.Select(x => x),x => x.weight,out var items,out var weights);
EnumerableConversion.EnumerableToTemporaryArray(source.Select(x => x),x => x.item,x => x.weight,out var items,out var weights);

for (int i = 0;source.Length > i;i++) {
var element = source[i];
Assert.AreEqual(element,items[i]);
Assert.AreEqual(element.item,items[i]);
Assert.AreEqual(element.weight,weights[i]);
}

Expand All @@ -29,13 +25,13 @@ public void EnumerableToTemporaryArray () {

[Test]
public void ReadOnlyListToTemporaryArray () {
Item[] source = GenerateEnumerable().ToArray();
ItemEntry[] source = ItemEnumerableGenerator.GenerateEnumerable(100).ToArray();

EnumerableConversion.EnumerableToTemporaryArray(source,x => x.weight,out var items,out var weights);
EnumerableConversion.EnumerableToTemporaryArray(source,x => x.item,x => x.weight,out var items,out var weights);

for (int i = 0;source.Length > i;i++) {
var element = source[i];
Assert.AreEqual(element,items[i]);
Assert.AreEqual(element.item,items[i]);
Assert.AreEqual(element.weight,weights[i]);
}

Expand All @@ -45,13 +41,13 @@ public void ReadOnlyListToTemporaryArray () {

[Test]
public void ListToTemporaryArray () {
List<Item> source = GenerateEnumerable().ToList();
List<ItemEntry> source = ItemEnumerableGenerator.GenerateEnumerable(100).ToList();

EnumerableConversion.EnumerableToTemporaryArray(source,x => x.weight,out var items,out var weights);
EnumerableConversion.EnumerableToTemporaryArray(source,x => x.item,x => x.weight,out var items,out var weights);

for (int i = 0;source.Count > i;i++) {
var element = source[i];
Assert.AreEqual(element,items[i]);
Assert.AreEqual(element.item,items[i]);
Assert.AreEqual(element.weight,weights[i]);
}

Expand All @@ -61,7 +57,7 @@ public void ListToTemporaryArray () {

[Test]
public void DictionaryToTemporaryArray () {
Dictionary<Item,float> source = GenerateDictionary().ToDictionary(p => p.Key,p => p.Value);
Dictionary<Item,float> source = ItemEnumerableGenerator.GenerateDictionary(100).ToDictionary(p => p.Key,p => p.Value);

EnumerableConversion.DictionaryToTemporaryArray(source,out var items,out var weights);

Expand All @@ -78,7 +74,7 @@ public void DictionaryToTemporaryArray () {

[Test]
public void DictionaryEnumerableToTemporaryArray () {
Dictionary<Item,float> source = GenerateDictionary().ToDictionary(p => p.Key,p => p.Value);
Dictionary<Item,float> source = ItemEnumerableGenerator.GenerateDictionary(100).ToDictionary(p => p.Key,p => p.Value);

EnumerableConversion.DictionaryToTemporaryArray(source.Select(x => x),out var items,out var weights);

Expand All @@ -93,17 +89,5 @@ public void DictionaryEnumerableToTemporaryArray () {
weights.Dispose();
}

static IEnumerable<Item> GenerateEnumerable () {
for (int i = 0;100 > i;i++) {
yield return new Item { weight = Random.value };
}
}

static IEnumerable<KeyValuePair<Item,float>> GenerateDictionary () {
for (int i = 0;100 > i;i++) {
yield return new KeyValuePair<Item,float>(new Item(),Random.value);
}
}

}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
using System.Collections.Generic;
using UnityEngine;

namespace MackySoft.Choice.Tests {

public class Item {

public int id;

public override string ToString () {
return "Index: " + id.ToString();
}

}

public struct ItemEntry {

public Item item;
public float weight;

public override string ToString () {
return $"{{ {item}: Weight {weight} }}";
}

}

public class ItemEnumerableGenerator {

public static IEnumerable<ItemEntry> GenerateEnumerable (int count) {
for (int i = 0;count > i;i++) {
yield return new ItemEntry {
item = new Item { id = i },
weight = Random.Range(0,10)
};
}
}

public static IEnumerable<KeyValuePair<Item,float>> GenerateDictionary (int count) {
for (int i = 0;count > i;i++) {
yield return new KeyValuePair<Item,float>(
new Item { id = i },
Random.Range(0,10)
);
}
}

}

}

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
using System.Linq;
using System.Collections.Generic;
using UnityEngine;
using NUnit.Framework;

namespace MackySoft.Choice.Tests {
public class WeightedSelectMethodTests {

[Test]
public void Linear_ReturnValidValue ([Random(0f,1f,10)] float value) {
var source = ItemEnumerableGenerator.GenerateEnumerable(100).ToArray();
var weightedSelector = source.ToWeightedSelector(x => x.item,x => x.weight,WeightedSelectMethod.Linear);
Assert.IsNotNull(weightedSelector.SelectItem(value));
}

[Test, Repeat(100)]
public void Linear_ReturnValidValue_0 () {
var source = ItemEnumerableGenerator.GenerateEnumerable(100).ToArray();
var weightedSelector = source.ToWeightedSelector(x => x.item,x => x.weight,WeightedSelectMethod.Linear);
Assert.AreSame(source.FirstOrDefault(x => x.weight > 0f).item,weightedSelector.SelectItem(0f));
}

[Test, Repeat(100)]
public void Linear_ReturnValidValue_1 () {
var source = ItemEnumerableGenerator.GenerateEnumerable(100).ToArray();
var weightedSelector = source.ToWeightedSelector(x => x.item,x => x.weight,WeightedSelectMethod.Linear);
Assert.AreSame(source.LastOrDefault(x => x.weight > 0f).item,weightedSelector.SelectItem(1f));
}

[Test]
public void Binary_ReturnValidValue ([Random(0f,1f,10)] float value) {
var source = ItemEnumerableGenerator.GenerateEnumerable(100).ToArray();
var weightedSelector = source.ToWeightedSelector(x => x.item,x => x.weight,WeightedSelectMethod.Binary);
Assert.IsNotNull(weightedSelector.SelectItem(value));
}

[Test, Repeat(100)]
public void Binary_ReturnValidValue_0 () {
var source = ItemEnumerableGenerator.GenerateEnumerable(100).ToArray();
var weightedSelector = source.ToWeightedSelector(x => x.item,x => x.weight,WeightedSelectMethod.Binary);
Assert.AreSame(weightedSelector.FirstOrDefault(p => p.Value > 0f).Key,weightedSelector.SelectItem(0f));
}

[Test, Repeat(100)]
public void Binary_ReturnValidValue_1 () {
var source = ItemEnumerableGenerator.GenerateEnumerable(100).ToArray();
var weightedSelector = source.ToWeightedSelector(x => x.item,x => x.weight,WeightedSelectMethod.Binary);
Assert.AreSame(source.LastOrDefault(x => x.weight > 0f).item,weightedSelector.SelectItem(1f));
}

}
}

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 21ef37a

Please sign in to comment.