diff --git a/src/FlatSharp.Runtime/Vectors/VectorsCommon.cs b/src/FlatSharp.Runtime/Vectors/VectorsCommon.cs index 4734b615..c9de6867 100644 --- a/src/FlatSharp.Runtime/Vectors/VectorsCommon.cs +++ b/src/FlatSharp.Runtime/Vectors/VectorsCommon.cs @@ -21,6 +21,12 @@ namespace FlatSharp.Internal; /// public static class VectorsCommon { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static bool Contains(IList vector, T? item) + { + return IndexOf(vector, item) >= 0; + } + [MethodImpl(MethodImplOptions.AggressiveInlining)] public static int IndexOf(IList vector, T? item) { diff --git a/src/FlatSharp/FlatBufferVectorHelpers.cs b/src/FlatSharp/FlatBufferVectorHelpers.cs index 56030433..930a3859 100644 --- a/src/FlatSharp/FlatBufferVectorHelpers.cs +++ b/src/FlatSharp/FlatBufferVectorHelpers.cs @@ -64,7 +64,8 @@ public static string CreateCommonReadOnlyVectorMethods( string nullableReference = GetNullableReferenceAnnotation(itemTypeModel); return $$""" - public bool Contains({{baseTypeName}}{{nullableReference}} item) => this.IndexOf(item) >= 0; + public bool Contains({{baseTypeName}}{{nullableReference}} item) + => {{typeof(VectorsCommon).GetGlobalCompilableTypeName()}}.Contains(this, item); public int IndexOf({{baseTypeName}}{{nullableReference}} item) => {{typeof(VectorsCommon).GetGlobalCompilableTypeName()}}.IndexOf(this, item); @@ -84,6 +85,7 @@ private static string GetEfficientMultiply( string indexVariableName) { FlatSharpInternal.Assert(inlineSize != 0, "invalid inline size"); + bool isPowerOf2 = (inlineSize & (inlineSize - 1)) == 0; if (!isPowerOf2) { diff --git a/src/FlatSharp/FlatBufferVectorHelpers_LazyUnion.cs b/src/FlatSharp/FlatBufferVectorHelpers_LazyUnion.cs index fbdeb9af..64516528 100644 --- a/src/FlatSharp/FlatBufferVectorHelpers_LazyUnion.cs +++ b/src/FlatSharp/FlatBufferVectorHelpers_LazyUnion.cs @@ -92,7 +92,11 @@ internal sealed class {{className}} public {{baseTypeName}} this[int index] { get => this.SafeParseItem(index); - set => this.WriteThrough(index, value); + set + { + {{nameof(VectorUtilities)}}.{{nameof(VectorUtilities.CheckIndex)}}(index, this.count); + this.WriteThrough(index, value); + } } public int Count => this.count; diff --git a/src/FlatSharp/FlatBufferVectorHelpers_ProgressiveUnion.cs b/src/FlatSharp/FlatBufferVectorHelpers_ProgressiveUnion.cs index e442920a..639a0c1e 100644 --- a/src/FlatSharp/FlatBufferVectorHelpers_ProgressiveUnion.cs +++ b/src/FlatSharp/FlatBufferVectorHelpers_ProgressiveUnion.cs @@ -192,6 +192,8 @@ private static void GetAddress(uint index, out uint rowIndex, out uint colIndex) int absoluteStartIndex = (int)({{GetEfficientMultiply(chunkSize, "rowIndex")}}); int copyCount = {{chunkSize}}; int remainingItems = this.count - absoluteStartIndex; + + {{StrykerSuppressor.SuppressNextLine("equality")}} if (remainingItems < {{chunkSize}}) { copyCount = remainingItems; @@ -235,6 +237,7 @@ private static void GetAddress(uint index, out uint rowIndex, out uint colIndex) private void ProgressiveSet(int index, {{baseTypeName}} value) { + {{nameof(VectorUtilities)}}.{{nameof(VectorUtilities.CheckIndex)}}(index, this.count); {{nameof(VectorUtilities)}}.{{nameof(VectorUtilities.ThrowInlineNotMutableException)}}(); } diff --git a/src/Tests/FlatSharpEndToEndTests/Helpers.cs b/src/Tests/FlatSharpEndToEndTests/Helpers.cs index 6de92522..ea2cac83 100644 --- a/src/Tests/FlatSharpEndToEndTests/Helpers.cs +++ b/src/Tests/FlatSharpEndToEndTests/Helpers.cs @@ -16,9 +16,11 @@ using FlatSharp.Internal; using System; +using System.Collections.Generic; using System.Linq.Expressions; using System.Reflection; using System.Threading; +using Xunit.Abstractions; namespace FlatSharpEndToEndTests; @@ -97,7 +99,8 @@ public static void AssertMutationWorks( TSource parent, bool isWriteThrough, Expression> propertyLambda, - TProperty newValue) + TProperty newValue, + Action? assertEqual = null) { Assert.True(parent is IFlatBufferDeserializedObject); @@ -124,33 +127,26 @@ public static void AssertMutationWorks( MemberExpression member = propertyLambda.Body as MemberExpression; PropertyInfo propInfo = member.Member as PropertyInfo; - Action action = () => propInfo.SetMethod.Invoke(parent, new object[] { newValue }); + + Func get = () => (TProperty)propInfo.GetMethod.Invoke(parent, null); + Action set = () => propInfo.SetMethod.Invoke(parent, new object[] { newValue }); + + // should be equal to itself to start with. + AssertEquality(option, get(), get(), assertEqual); switch (option) { case FlatBufferDeserializationOption.Lazy when isWriteThrough: case FlatBufferDeserializationOption.Progressive when isWriteThrough: case FlatBufferDeserializationOption.GreedyMutable when isWriteThrough is false: - action(); - - // For value types, validate that they are the same. - if (typeof(TProperty).IsValueType) - { - TProperty readValue = (TProperty)propInfo.GetMethod.Invoke(parent, null); - Assert.Equal(newValue, readValue); - } - else if (option != FlatBufferDeserializationOption.Lazy) - { - TProperty readValue = (TProperty)propInfo.GetMethod.Invoke(parent, null); - Assert.True(object.ReferenceEquals(newValue, readValue)); - } - + set(); + AssertEquality(option, newValue, get(), assertEqual); return; default: var ex = Assert.Throws(new Action(() => { - var ex = Assert.Throws(action).InnerException; + var ex = Assert.Throws(set).InnerException; throw ex; })); @@ -169,12 +165,8 @@ public static void ValidateListVector( IList items, T newValue) { - // This can be lots of things: NotMutable, ArgumentOfRange, IndexOutOfRange, etc. - Assert.ThrowsAny(() => items[-1]); - Assert.ThrowsAny(() => items[items.Count]); - Assert.ThrowsAny(() => items[-1] = default); - Assert.ThrowsAny(() => items[items.Count] = default); - + CheckRangeExceptions(option, isWriteThrough, items); + if (items is IFlatBufferDeserializedVector vec) { Assert.ThrowsAny(() => vec.OffsetOf(-1)); @@ -419,4 +411,52 @@ IEnumerator IEnumerable.GetEnumerator() return ((IEnumerable)list).GetEnumerator(); } } + + private static void AssertEquality(FlatBufferDeserializationOption option, T a, T b, Action? assertEqual) + { + if (assertEqual is null) + { + if (typeof(T).IsValueType && (typeof(T).IsPrimitive || typeof(T).IsEnum)) + { + assertEqual = (a, b) => Assert.Equal(a, b); + } + else if (typeof(T) == typeof(string)) + { + assertEqual = (a, b) => Assert.Equal(a, b); + } + } + + if (!typeof(T).IsValueType && option != FlatBufferDeserializationOption.Lazy) + { + Assert.True(object.ReferenceEquals(a, b)); + } + + assertEqual?.Invoke(a, b); + } + + private static void CheckRangeExceptions(FlatBufferDeserializationOption option, bool isWriteThrough, IList list) + { + if (option == FlatBufferDeserializationOption.Lazy || option == FlatBufferDeserializationOption.Progressive) + { + Assert.Throws(() => list[-1]); + Assert.Throws(() => list[list.Count]); + Assert.Throws(() => list[-1] = default); + Assert.Throws(() => list[list.Count] = default); + } + else if (option == FlatBufferDeserializationOption.Greedy + || (isWriteThrough && option == FlatBufferDeserializationOption.GreedyMutable)) + { + Assert.Throws(() => list[-1]); + Assert.Throws(() => list[list.Count]); + Assert.Throws(() => list[-1] = default); + Assert.Throws(() => list[list.Count] = default); + } + else + { + Assert.Throws(() => list[-1]); + Assert.Throws(() => list[list.Count]); + Assert.Throws(() => list[-1] = default); + Assert.Throws(() => list[list.Count] = default); + } + } } \ No newline at end of file diff --git a/src/Tests/Stryker/Tests/FullTreeTests.cs b/src/Tests/Stryker/Tests/FullTreeTests.cs index 6f707d85..3ebece4a 100644 --- a/src/Tests/Stryker/Tests/FullTreeTests.cs +++ b/src/Tests/Stryker/Tests/FullTreeTests.cs @@ -22,10 +22,21 @@ public void RootMutations(FlatBufferDeserializationOption option) [ClassData(typeof(DeserializationOptionClassData))] public void VectorFieldMutations(FlatBufferDeserializationOption option) { + static void AssertMemoryEqual(Memory? a, Memory? b) + { + Assert.Equal(a is null, b is null); + if (a is null) + { + return; + } + + Helpers.AssertSequenceEqual(a.Value.Span, b.Value.Span); + } + Vectors vectors = this.CreateRoot().SerializeAndParse(option).Vectors; Helpers.AssertMutationWorks(option, vectors, false, r => r.Indexed, null); - Helpers.AssertMutationWorks(option, vectors, false, r => r.Memory, null); + Helpers.AssertMutationWorks(option, vectors, false, r => r.Memory, null, AssertMemoryEqual); Helpers.AssertMutationWorks(option, vectors, false, r => r.RefStruct, null); Helpers.AssertMutationWorks(option, vectors, false, r => r.Str, null); Helpers.AssertMutationWorks(option, vectors, false, r => r.Table, null); @@ -33,6 +44,17 @@ public void VectorFieldMutations(FlatBufferDeserializationOption option) Helpers.AssertMutationWorks(option, vectors, false, r => r.ValueStruct, null); } + [Fact] + public void VectorFieldTests_ProgressiveClear() + { + Vectors vectors = this.CreateRoot().SerializeAndParse(FlatBufferDeserializationOption.Progressive, out byte[] data).Vectors; + Helpers.AssertSequenceEqual(new byte[] { 1, 2, 3, 4, }, vectors.Memory.Value.Span); + + data.AsSpan().Clear(); + + Helpers.AssertSequenceEqual(new byte[] { 0, 0, 0, 0, }, vectors.Memory.Value.Span); + } + [Fact] public void GetMaxSize() { diff --git a/src/Tests/Stryker/Tests/ScalarFieldTests.cs b/src/Tests/Stryker/Tests/ScalarFieldTests.cs new file mode 100644 index 00000000..ccb5bb1e --- /dev/null +++ b/src/Tests/Stryker/Tests/ScalarFieldTests.cs @@ -0,0 +1,51 @@ +using FlatSharp.Internal; +using System.Linq.Expressions; + +namespace FlatSharpStrykerTests; + +public class ScalarFieldTests +{ + [Theory] + [ClassData(typeof(DeserializationOptionClassData))] + public void ValueStructTableField(FlatBufferDeserializationOption option) + { + Root r = CreateTableWithScalarField(out byte[] buffer); + Root parsed = r.SerializeAndParse(option, out byte[] actual); + + Assert.Equal(r.Fields.Memory, parsed.Fields.Memory); + Helpers.AssertSequenceEqual(buffer, actual); + } + + private static Root CreateTableWithScalarField(out byte[] expectedBuffer) + { + Root root = new Root + { + Fields = new() + { + Memory = 3, + } + }; + + expectedBuffer = new byte[] + { + 4, 0, 0, 0, // offset to table start + 248, 255, 255, 255, // soffset to vtable. + 12, 0, 0, 0, // uoffset to field 0 (fields table) + 6, 0, // vtable length + 8, 0, // table length + 4, 0, // offset of field 0 + 0, 0, // padding + + 250, 255, 255, 255, // soffset to vtable + 3, 0, + + 10, 0, + 5, 0, + 0, 0, + 0, 0, + 4, 0, + }; + + return root; + } +} diff --git a/src/Tests/Stryker/Tests/StructFieldTests.cs b/src/Tests/Stryker/Tests/StructFieldTests.cs index 25b2d7a8..0bf5c01d 100644 --- a/src/Tests/Stryker/Tests/StructFieldTests.cs +++ b/src/Tests/Stryker/Tests/StructFieldTests.cs @@ -73,7 +73,16 @@ public void ValueStructTableField(FlatBufferDeserializationOption option) }; Helpers.AssertSequenceEqual(expectedBytes, buffer); - Helpers.AssertMutationWorks(option, fields, false, p => p.ValueStruct, default); + Helpers.AssertMutationWorks(option, fields, false, p => p.ValueStruct, new ValueStruct(), (a, b) => + { + var av = a.Value; + var bv = b.Value; + + Assert.Equal(av.A, bv.A); + Assert.Equal(av.B, bv.B); + Assert.Equal(av.C(0), bv.C(0)); + Assert.Equal(av.C(1), bv.C(1)); + }); } } finally @@ -82,6 +91,27 @@ public void ValueStructTableField(FlatBufferDeserializationOption option) } } + [Fact] + public void ValueStructTableField_ProgressiveClear() + { + Root root = new Root() { Fields = new() { ValueStruct = new ValueStruct { A = 5 } } }.SerializeAndParse(FlatBufferDeserializationOption.Progressive, out byte[] buffer); + + var fields = root.Fields; + Assert.Equal(5, fields.ValueStruct.Value.A); + buffer.AsSpan().Clear(); + Assert.Equal(5, fields.ValueStruct.Value.A); + } + + [Fact] + public void ValueStructStructField_ProgressiveClear() + { + Root root = new Root() { Fields = new() { RefStruct = new() { E = new ValueStruct { A = 5 } } } }.SerializeAndParse(FlatBufferDeserializationOption.Progressive, out byte[] buffer); + + var fields = root.Fields.RefStruct; + Assert.Equal(5, fields.E.A); + buffer.AsSpan().Clear(); + Assert.Equal(5, fields.E.A); + } [Theory] [ClassData(typeof(DeserializationOptionClassData))] @@ -130,7 +160,13 @@ public void ReferenceStructWriteThrough(FlatBufferDeserializationOption option) Helpers.AssertMutationWorks(option, rsp, true, rsp => rsp.__flatsharp__C_1, (sbyte)6); Helpers.AssertMutationWorks(option, rsp, false, rsp => rsp.__flatsharp__D_0, (sbyte)3); Helpers.AssertMutationWorks(option, rsp, false, rsp => rsp.__flatsharp__D_1, (sbyte)6); - Helpers.AssertMutationWorks(option, rsp, false, rsp => rsp.E, new ValueStruct()); + Helpers.AssertMutationWorks(option, rsp, false, rsp => rsp.E, new ValueStruct(), (a, b) => + { + Assert.Equal(a.A, b.A); + Assert.Equal(a.B, b.B); + Assert.Equal(a.C(0), b.C(0)); + Assert.Equal(a.C(1), b.C(1)); + }); var parsed2 = Root.Serializer.Parse(buffer, option); diff --git a/src/Tests/Stryker/Tests/UnionFieldTests.cs b/src/Tests/Stryker/Tests/UnionFieldTests.cs index 91720d1b..e53e3905 100644 --- a/src/Tests/Stryker/Tests/UnionFieldTests.cs +++ b/src/Tests/Stryker/Tests/UnionFieldTests.cs @@ -27,6 +27,19 @@ public void InvalidGetters() Assert.Throws(() => a.RefStruct); } + [Fact] + public void ProgressiveClear() + { + Root parsed = new Root { Fields = new() { Union = new FunUnion("hi") } }.SerializeAndParse(FlatBufferDeserializationOption.Progressive, out byte[] buffer); + + Fields f = parsed.Fields; + Assert.Equal("hi", f.Union.Value.str); + + buffer.AsSpan().Clear(); + + Assert.Equal("hi", f.Union.Value.str); + } + [Theory] [ClassData(typeof(DeserializationOptionClassData))] public void StringMember(FlatBufferDeserializationOption option) @@ -44,7 +57,13 @@ public void StringMember(FlatBufferDeserializationOption option) Assert.True(union.TryGet(out string str)); Assert.Equal("hello", str); - Helpers.AssertMutationWorks(option, parsed.Fields, false, f => f.Union, default); + Helpers.AssertMutationWorks(option, parsed.Fields, false, f => f.Union, new FunUnion(string.Empty), (a, b) => + { + var av = a.Value; + var bv = b.Value; + Assert.Equal(av.Discriminator, bv.Discriminator); + Assert.Equal(av.str, bv.str); + }); } [Fact] diff --git a/src/Tests/Stryker/Tests/UnionVectorTests.cs b/src/Tests/Stryker/Tests/UnionVectorTests.cs index f1caa5f0..dba9cd87 100644 --- a/src/Tests/Stryker/Tests/UnionVectorTests.cs +++ b/src/Tests/Stryker/Tests/UnionVectorTests.cs @@ -93,6 +93,46 @@ public void UnionVector(FlatBufferDeserializationOption option) => Helpers.Repea Helpers.ValidateListVector(option, false, unions, new FunUnion("foo")); }); + [Theory] + [ClassData(typeof(DeserializationOptionClassData))] + public void Big(FlatBufferDeserializationOption option) + { + const int ChunkSize = 8; // this test needs to exercise some edge scenarios that require knowing the chunk size of the union vectors. + + Random r = new Random(); + + for (int test = 1; test < 10; ++test) + { + List expected = new(); + + for (int i = 0; i < (ChunkSize * test) + 1; ++i) + { + FunUnion union = (r.Next() % 4) switch + { + 0 => new(new Key { Name = i.ToString() }), + 1 => new(new RefStruct { A = i }), + 2 => new(new ValueStruct { A = i }), + 3 => new(i.ToString()), + _ => throw new Exception(), + }; + + expected.Add(union); + } + + Root root = new Root { Vectors = new() { Union = expected } }.SerializeAndParse(option); + IList parsed = root.Vectors.Union; + + Assert.Equal(expected.Count, parsed.Count); + for (int i = 0; i < expected.Count; ++i) + { + FunUnion e = expected[i]; + FunUnion p = parsed[i]; + + Assert.Equal(e.Discriminator, p.Discriminator); + } + } + } + [Theory] [ClassData(typeof(DeserializationOptionClassData))] public void UnionVector_Invalid_MissingOffset(FlatBufferDeserializationOption option) diff --git a/src/Tests/Stryker/Tests/ValueStructVectorTests.cs b/src/Tests/Stryker/Tests/ValueStructVectorTests.cs index 1a3ae3cd..24b8674a 100644 --- a/src/Tests/Stryker/Tests/ValueStructVectorTests.cs +++ b/src/Tests/Stryker/Tests/ValueStructVectorTests.cs @@ -1,5 +1,6 @@ using FlatSharp.Internal; using System.Linq.Expressions; +using System.Reflection.Metadata.Ecma335; using System.Threading; namespace FlatSharpStrykerTests; @@ -36,6 +37,31 @@ public void Present(FlatBufferDeserializationOption option) => Helpers.Repeat(() Helpers.ValidateListVector(option, true, vsp, new ValueStruct()); }); + [Theory] + [ClassData(typeof(DeserializationOptionClassData))] + public void WriteThroughDisabled(FlatBufferDeserializationOption option) + { + if (option == FlatBufferDeserializationOption.Greedy) + { + // Greedy is special because writethrough has no bearing on it, so it does not store an + // internal field context. + return; + } + + Root root = CreateRoot(out _); + Root parsed = root.SerializeAndParse(option); + + IList list = parsed.Vectors.ValueStruct; + FieldInfo field = list.GetType().GetField("fieldContext", BindingFlags.NonPublic | BindingFlags.Instance); + TableFieldContext context = (TableFieldContext)field.GetValue(list); + + Assert.True(context.WriteThrough); + context = new TableFieldContext(context.FullName, context.SharedString, writeThrough: false); + field.SetValue(list, context); + + Helpers.ValidateListVector(option, false, list, new ValueStruct()); + } + [Theory] [ClassData(typeof(DeserializationOptionClassData))] public void Missing(FlatBufferDeserializationOption option)