Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add BStrStringMarshaller to source generator #69213

Merged
merged 12 commits into from
May 20, 2022
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -867,6 +867,7 @@
<Compile Include="$(MSBuildThisFileDirectory)System\Runtime\InteropServices\MarshalDirectiveException.cs" />
<Compile Include="$(MSBuildThisFileDirectory)System\Runtime\InteropServices\Marshalling\AnsiStringMarshaller.cs" />
<Compile Include="$(MSBuildThisFileDirectory)System\Runtime\InteropServices\Marshalling\ArrayMarshaller.cs" />
<Compile Include="$(MSBuildThisFileDirectory)System\Runtime\InteropServices\Marshalling\BStrStringMarshaller.cs" />
<Compile Include="$(MSBuildThisFileDirectory)System\Runtime\InteropServices\Marshalling\CustomTypeMarshallerAttribute.cs" />
<Compile Include="$(MSBuildThisFileDirectory)System\Runtime\InteropServices\Marshalling\CustomTypeMarshallerDirection.cs" />
<Compile Include="$(MSBuildThisFileDirectory)System\Runtime\InteropServices\Marshalling\CustomTypeMarshallerFeatures.cs" />
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Diagnostics;
using System.Runtime.CompilerServices;
using System.Text;

namespace System.Runtime.InteropServices.Marshalling
{
/// <summary>
/// Marshaller for BSTR strings
/// </summary>
[CLSCompliant(false)]
[CustomTypeMarshaller(typeof(string), BufferSize = 0x200,
AaronRobinsonMSFT marked this conversation as resolved.
Show resolved Hide resolved
Features = CustomTypeMarshallerFeatures.UnmanagedResources | CustomTypeMarshallerFeatures.TwoStageMarshalling | CustomTypeMarshallerFeatures.CallerAllocatedBuffer)]
public unsafe ref struct BStrStringMarshaller
{
private ushort* _ptrToFirstChar;
private bool _allocated;

/// <summary>
/// Initializes a new instance of the <see cref="BStrStringMarshaller"/>.
/// </summary>
/// <param name="str">The string to marshal.</param>
public BStrStringMarshaller(string? str)
: this(str, default)
{ }

/// <summary>
/// Initializes a new instance of the <see cref="BStrStringMarshaller"/>.
/// </summary>
/// <param name="str">The string to marshal.</param>
/// <param name="buffer">Buffer that may be used for marshalling.</param>
/// <remarks>
/// The <paramref name="buffer"/> must not be movable - that is, it should not be
/// on the managed heap or it should be pinned.
/// <seealso cref="CustomTypeMarshallerFeatures.CallerAllocatedBuffer"/>
/// </remarks>
public BStrStringMarshaller(string? str, Span<ushort> buffer)
{
_allocated = false;

if (str is null)
{
_ptrToFirstChar = null;
return;
}

int lengthInBytes = checked(sizeof(char) * str.Length);

// A caller provided buffer must be at least (lengthInBytes + 6) bytes
// in order to be constructed manually. The 6 extra bytes are 4 for byte length and 2 for wide null.
int manualBstrNeeds = checked(lengthInBytes + 6);
if (manualBstrNeeds > buffer.Length)
{
// Use precise byte count when the provided stack-allocated buffer is not sufficient
_ptrToFirstChar = (ushort*)Marshal.AllocBSTRByteLen((uint)lengthInBytes);
_allocated = true;
}
else
{
// Set length and update buffer target
byte* pBuffer = (byte*)Unsafe.AsPointer(ref MemoryMarshal.GetReference(buffer));
*((uint*)pBuffer) = (uint)lengthInBytes;
_ptrToFirstChar = (ushort*)(pBuffer + sizeof(uint));
}

// Confirm the size is properly set for the allocated BSTR.
Debug.Assert(lengthInBytes == Marshal.SysStringByteLen((IntPtr)_ptrToFirstChar));

// Copy characters from the managed string
str.CopyTo(new Span<char>(_ptrToFirstChar, str.Length));
_ptrToFirstChar[str.Length] = '\0'; // null-terminate
AaronRobinsonMSFT marked this conversation as resolved.
Show resolved Hide resolved
}

/// <summary>
/// Returns the native value representing the string.
/// </summary>
/// <remarks>
/// <seealso cref="CustomTypeMarshallerFeatures.TwoStageMarshalling"/>
/// </remarks>
public ushort* ToNativeValue() => _ptrToFirstChar;

/// <summary>
/// Sets the native value representing the string.
/// </summary>
/// <param name="value">The native value.</param>
/// <remarks>
/// <seealso cref="CustomTypeMarshallerFeatures.TwoStageMarshalling"/>
/// </remarks>
public void FromNativeValue(ushort* value)
{
_ptrToFirstChar = value;
_allocated = true;
}

/// <summary>
/// Returns the managed string.
/// </summary>
/// <remarks>
/// <seealso cref="CustomTypeMarshallerDirection.Out"/>
/// </remarks>
public string? ToManaged()
{
if (_ptrToFirstChar is null)
return null;
return Marshal.PtrToStringBSTR((IntPtr)_ptrToFirstChar);
AaronRobinsonMSFT marked this conversation as resolved.
Show resolved Hide resolved
}

/// <summary>
/// Frees native resources.
/// </summary>
/// <remarks>
/// <seealso cref="CustomTypeMarshallerFeatures.UnmanagedResources"/>
/// </remarks>
public void FreeNative()
{
if (_allocated)
Marshal.FreeBSTR((IntPtr)_ptrToFirstChar);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@

namespace Microsoft.Interop
{

/// <summary>
/// Type used to pass on default marshalling details.
/// </summary>
Expand Down Expand Up @@ -73,7 +72,6 @@ public enum CharEncoding
Undefined,
Utf8,
Utf16,
Ansi,
Custom
}

Expand Down Expand Up @@ -762,7 +760,13 @@ private bool TryCreateTypeBasedMarshallingInfo(
}
else
{
marshallingInfo = CreateStringMarshallingInfo(type, _defaultInfo.CharEncoding);
marshallingInfo = _defaultInfo.CharEncoding switch
{
CharEncoding.Utf16 => CreateStringMarshallingInfo(type, TypeNames.Utf16StringMarshaller),
CharEncoding.Utf8 => CreateStringMarshallingInfo(type, TypeNames.Utf8StringMarshaller),
_ => throw new InvalidOperationException()
};

return true;
}

Expand Down Expand Up @@ -843,30 +847,25 @@ private MarshallingInfo CreateStringMarshallingInfo(
ITypeSymbol type,
UnmanagedType unmanagedType)
{
CharEncoding charEncoding = unmanagedType switch
string? marshallerName = unmanagedType switch
{
UnmanagedType.LPStr => CharEncoding.Ansi,
UnmanagedType.LPTStr or UnmanagedType.LPWStr => CharEncoding.Utf16,
MarshalAsInfo.UnmanagedType_LPUTF8Str => CharEncoding.Utf8,
_ => CharEncoding.Undefined
UnmanagedType.BStr => TypeNames.BStrStringMarshaller,
UnmanagedType.LPStr => TypeNames.AnsiStringMarshaller,
UnmanagedType.LPTStr or UnmanagedType.LPWStr => TypeNames.Utf16StringMarshaller,
MarshalAsInfo.UnmanagedType_LPUTF8Str => TypeNames.Utf8StringMarshaller,
_ => null
};
if (charEncoding == CharEncoding.Undefined)

if (marshallerName is null)
return new MarshalAsInfo(unmanagedType, _defaultInfo.CharEncoding);

return CreateStringMarshallingInfo(type, charEncoding);
return CreateStringMarshallingInfo(type, marshallerName);
}

private MarshallingInfo CreateStringMarshallingInfo(
ITypeSymbol type,
CharEncoding charEncoding)
string marshallerName)
{
string? marshallerName = charEncoding switch
{
CharEncoding.Ansi => TypeNames.AnsiStringMarshaller,
CharEncoding.Utf16 => TypeNames.Utf16StringMarshaller,
CharEncoding.Utf8 => TypeNames.Utf8StringMarshaller,
_ => throw new InvalidOperationException()
};
INamedTypeSymbol? stringMarshaller = _compilation.GetTypeByMetadataName(marshallerName);
if (stringMarshaller is null)
return new MissingSupportMarshallingInfo();
Expand All @@ -877,9 +876,9 @@ private MarshallingInfo CreateStringMarshallingInfo(
return CreateNativeMarshallingInfoForValue(
type,
stringMarshaller,
default,
null,
customTypeMarshallerData.Value,
allowPinningManagedType: charEncoding == CharEncoding.Utf16,
allowPinningManagedType: marshallerName is TypeNames.Utf16StringMarshaller,
useDefaultMarshalling: false);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ public static class TypeNames
public const string CustomTypeMarshallerAttributeGenericPlaceholder = "System.Runtime.InteropServices.Marshalling.CustomTypeMarshallerAttribute.GenericPlaceholder";

public const string AnsiStringMarshaller = "System.Runtime.InteropServices.Marshalling.AnsiStringMarshaller";
public const string BStrStringMarshaller = "System.Runtime.InteropServices.Marshalling.BStrStringMarshaller";
public const string Utf16StringMarshaller = "System.Runtime.InteropServices.Marshalling.Utf16StringMarshaller";
public const string Utf8StringMarshaller = "System.Runtime.InteropServices.Marshalling.Utf8StringMarshaller";

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2078,6 +2078,20 @@ public void FromNativeValue(byte* value) { }
public string? ToManaged() { throw null; }
public void FreeNative() { }
}
[System.CLSCompliant(false)]
[System.Runtime.InteropServices.Marshalling.CustomTypeMarshallerAttribute(typeof(string), BufferSize = 0x200,
Features = System.Runtime.InteropServices.Marshalling.CustomTypeMarshallerFeatures.UnmanagedResources
| System.Runtime.InteropServices.Marshalling.CustomTypeMarshallerFeatures.CallerAllocatedBuffer
| System.Runtime.InteropServices.Marshalling.CustomTypeMarshallerFeatures.TwoStageMarshalling )]
public unsafe ref struct BStrStringMarshaller
AaronRobinsonMSFT marked this conversation as resolved.
Show resolved Hide resolved
{
public BStrStringMarshaller(string? str) { }
public BStrStringMarshaller(string? str, System.Span<ushort> buffer) { }
public ushort* ToNativeValue() { throw null; }
public void FromNativeValue(ushort* value) { }
public string? ToManaged() { throw null; }
public void FreeNative() { }
}
[System.CLSCompliantAttribute(false)]
[System.Runtime.InteropServices.Marshalling.CustomTypeMarshallerAttribute(typeof(System.Runtime.InteropServices.Marshalling.CustomTypeMarshallerAttribute.GenericPlaceholder[]),
System.Runtime.InteropServices.Marshalling.CustomTypeMarshallerKind.LinearCollection, BufferSize = 0x200,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ private class EntryPoints

private const string UShortSuffix = "_ushort";
private const string ByteSuffix = "_byte";
private const string BStrSuffix = "_bstr";

public class Byte
{
Expand All @@ -41,6 +42,15 @@ public class UShort
public const string ReverseInplace = EntryPoints.ReverseInplace + UShortSuffix;
public const string ReverseReplace = EntryPoints.ReverseReplace + UShortSuffix;
}

public class BStr
{
public const string ReturnLength = EntryPoints.ReturnLength + BStrSuffix;
public const string ReverseReturn = EntryPoints.ReverseReturn + BStrSuffix;
public const string ReverseOut = EntryPoints.ReverseOut + BStrSuffix;
public const string ReverseInplace = EntryPoints.ReverseInplace + BStrSuffix;
public const string ReverseReplace = EntryPoints.ReverseReplace + BStrSuffix;
}
}

public partial class Utf16
Expand Down Expand Up @@ -185,6 +195,31 @@ public partial class LPStr
public static partial void Reverse_Replace_Ref([MarshalAs(UnmanagedType.LPStr)] ref string s);
}

public partial class BStr
{
[LibraryImport(NativeExportsNE_Binary, EntryPoint = EntryPoints.BStr.ReturnLength)]
public static partial int ReturnLength([MarshalAs(UnmanagedType.BStr)] string s);

[LibraryImport(NativeExportsNE_Binary, EntryPoint = EntryPoints.BStr.ReturnLength, StringMarshalling = StringMarshalling.Utf16)]
public static partial int ReturnLength_IgnoreStringMarshalling([MarshalAs(UnmanagedType.BStr)] string s);

[LibraryImport(NativeExportsNE_Binary, EntryPoint = EntryPoints.BStr.ReverseReturn)]
[return: MarshalAs(UnmanagedType.BStr)]
public static partial string Reverse_Return([MarshalAs(UnmanagedType.BStr)] string s);

[LibraryImport(NativeExportsNE_Binary, EntryPoint = EntryPoints.BStr.ReverseOut)]
public static partial void Reverse_Out([MarshalAs(UnmanagedType.BStr)] string s, [MarshalAs(UnmanagedType.BStr)] out string ret);

[LibraryImport(NativeExportsNE_Binary, EntryPoint = EntryPoints.BStr.ReverseInplace)]
public static partial void Reverse_Ref([MarshalAs(UnmanagedType.BStr)] ref string s);

[LibraryImport(NativeExportsNE_Binary, EntryPoint = EntryPoints.BStr.ReverseInplace)]
public static partial void Reverse_In([MarshalAs(UnmanagedType.BStr)] in string s);

[LibraryImport(NativeExportsNE_Binary, EntryPoint = EntryPoints.BStr.ReverseReplace)]
public static partial void Reverse_Replace_Ref([MarshalAs(UnmanagedType.BStr)] ref string s);
}

public partial class StringMarshallingCustomType
{
public partial class Utf16
Expand Down Expand Up @@ -418,6 +453,48 @@ public void AnsiStringByRef(string value)
Assert.Equal(expected, refValue);
}

[Theory]
[MemberData(nameof(UnicodeStrings))]
public void BStrStringMarshalledAsExpected(string value)
{
int expectedLen = value != null ? value.Length : -1;

Assert.Equal(expectedLen, NativeExportsNE.BStr.ReturnLength(value));
Assert.Equal(expectedLen, NativeExportsNE.BStr.ReturnLength_IgnoreStringMarshalling(value));
}

[Theory]
[MemberData(nameof(UnicodeStrings))]
public void BStrStringReturn(string value)
{
string expected = ReverseChars(value);

Assert.Equal(expected, NativeExportsNE.BStr.Reverse_Return(value));

string ret;
NativeExportsNE.BStr.Reverse_Out(value, out ret);
Assert.Equal(expected, ret);
}

[Theory]
[MemberData(nameof(UnicodeStrings))]
public void BStrStringByRef(string value)
{
string refValue = value;
string expected = ReverseChars(value);

NativeExportsNE.BStr.Reverse_In(in refValue);
Assert.Equal(value, refValue); // Should not be updated when using 'in'

refValue = value;
NativeExportsNE.BStr.Reverse_Ref(ref refValue);
Assert.Equal(expected, refValue);

refValue = value;
NativeExportsNE.BStr.Reverse_Replace_Ref(ref refValue);
Assert.Equal(expected, refValue);
}

[Theory]
[MemberData(nameof(UnicodeStrings))]
public void StringMarshallingCustomType_MarshalledAsExpected(string value)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,9 +108,11 @@ public static IEnumerable<object[]> CodeSnippetsToCompile()
yield return new[] { CodeSnippets.MarshalAsParametersAndModifiers<string>(UnmanagedType.LPTStr) };
yield return new[] { CodeSnippets.MarshalAsParametersAndModifiers<string>(UnmanagedType.LPUTF8Str) };
yield return new[] { CodeSnippets.MarshalAsParametersAndModifiers<string>(UnmanagedType.LPStr) };
yield return new[] { CodeSnippets.MarshalAsParametersAndModifiers<string>(UnmanagedType.BStr) };
yield return new[] { CodeSnippets.MarshalAsArrayParameterWithNestedMarshalInfo<string>(UnmanagedType.LPWStr) };
yield return new[] { CodeSnippets.MarshalAsArrayParameterWithNestedMarshalInfo<string>(UnmanagedType.LPUTF8Str) };
yield return new[] { CodeSnippets.MarshalAsArrayParameterWithNestedMarshalInfo<string>(UnmanagedType.LPStr) };
yield return new[] { CodeSnippets.MarshalAsArrayParameterWithNestedMarshalInfo<string>(UnmanagedType.BStr) };

// [In, Out] attributes
// By value non-blittable array
Expand Down
Loading