Skip to content

Commit

Permalink
Add BStrStringMarshaller to source generator (#69213)
Browse files Browse the repository at this point in the history
* Add BStrStringMarshaller to source generator

* Convert to use void* for BStr and Utf16 marshallers' native types.

Co-authored-by: Jan Kotas <jkotas@microsoft.com>
  • Loading branch information
AaronRobinsonMSFT and jkotas authored May 20, 2022
1 parent de27aa1 commit 810a7f9
Show file tree
Hide file tree
Showing 9 changed files with 307 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -866,6 +866,7 @@
<Compile Include="$(MSBuildThisFileDirectory)System\Runtime\InteropServices\MarshalDirectiveException.cs" />
<Compile Include="$(MSBuildThisFileDirectory)System\Runtime\InteropServices\Marshalling\AnsiStringMarshaller.cs" />
<Compile Include="$(MSBuildThisFileDirectory)System\Runtime\InteropServices\Marshalling\ArrayMarshaller.cs" />
<Compile Include="$(MSBuildThisFileDirectory)System\Runtime\InteropServices\Marshalling\BStrStringMarshaller.cs" />
<Compile Include="$(MSBuildThisFileDirectory)System\Runtime\InteropServices\Marshalling\CustomTypeMarshallerAttribute.cs" />
<Compile Include="$(MSBuildThisFileDirectory)System\Runtime\InteropServices\Marshalling\CustomTypeMarshallerDirection.cs" />
<Compile Include="$(MSBuildThisFileDirectory)System\Runtime\InteropServices\Marshalling\CustomTypeMarshallerFeatures.cs" />
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Diagnostics;
using System.Runtime.CompilerServices;
using System.Text;

namespace System.Runtime.InteropServices.Marshalling
{
/// <summary>
/// Marshaller for BSTR strings
/// </summary>
[CLSCompliant(false)]
[CustomTypeMarshaller(typeof(string), BufferSize = 0x100,
Features = CustomTypeMarshallerFeatures.UnmanagedResources | CustomTypeMarshallerFeatures.TwoStageMarshalling | CustomTypeMarshallerFeatures.CallerAllocatedBuffer)]
public unsafe ref struct BStrStringMarshaller
{
private void* _ptrToFirstChar;
private bool _allocated;

/// <summary>
/// Initializes a new instance of the <see cref="BStrStringMarshaller"/>.
/// </summary>
/// <param name="str">The string to marshal.</param>
public BStrStringMarshaller(string? str)
: this(str, default)
{ }

/// <summary>
/// Initializes a new instance of the <see cref="BStrStringMarshaller"/>.
/// </summary>
/// <param name="str">The string to marshal.</param>
/// <param name="buffer">Buffer that may be used for marshalling.</param>
/// <remarks>
/// The <paramref name="buffer"/> must not be movable - that is, it should not be
/// on the managed heap or it should be pinned.
/// <seealso cref="CustomTypeMarshallerFeatures.CallerAllocatedBuffer"/>
/// </remarks>
public BStrStringMarshaller(string? str, Span<ushort> buffer)
{
_allocated = false;

if (str is null)
{
_ptrToFirstChar = null;
return;
}

ushort* ptrToFirstChar;
int lengthInBytes = checked(sizeof(char) * str.Length);

// A caller provided buffer must be at least (lengthInBytes + 6) bytes
// in order to be constructed manually. The 6 extra bytes are 4 for byte length and 2 for wide null.
int manualBstrNeeds = checked(lengthInBytes + 6);
if (manualBstrNeeds > buffer.Length)
{
// Use precise byte count when the provided stack-allocated buffer is not sufficient
ptrToFirstChar = (ushort*)Marshal.AllocBSTRByteLen((uint)lengthInBytes);
_allocated = true;
}
else
{
// Set length and update buffer target
byte* pBuffer = (byte*)Unsafe.AsPointer(ref MemoryMarshal.GetReference(buffer));
*((uint*)pBuffer) = (uint)lengthInBytes;
ptrToFirstChar = (ushort*)(pBuffer + sizeof(uint));
}

// Confirm the size is properly set for the allocated BSTR.
Debug.Assert(lengthInBytes == Marshal.SysStringByteLen((IntPtr)ptrToFirstChar));

// Copy characters from the managed string
str.CopyTo(new Span<char>(ptrToFirstChar, str.Length));
ptrToFirstChar[str.Length] = '\0'; // null-terminate
_ptrToFirstChar = ptrToFirstChar;
}

/// <summary>
/// Returns the native value representing the string.
/// </summary>
/// <remarks>
/// <seealso cref="CustomTypeMarshallerFeatures.TwoStageMarshalling"/>
/// </remarks>
public void* ToNativeValue() => _ptrToFirstChar;

/// <summary>
/// Sets the native value representing the string.
/// </summary>
/// <param name="value">The native value.</param>
/// <remarks>
/// <seealso cref="CustomTypeMarshallerFeatures.TwoStageMarshalling"/>
/// </remarks>
public void FromNativeValue(void* value)
{
_ptrToFirstChar = value;
_allocated = true;
}

/// <summary>
/// Returns the managed string.
/// </summary>
/// <remarks>
/// <seealso cref="CustomTypeMarshallerDirection.Out"/>
/// </remarks>
public string? ToManaged()
{
if (_ptrToFirstChar is null)
return null;

return Marshal.PtrToStringBSTR((IntPtr)_ptrToFirstChar);
}

/// <summary>
/// Frees native resources.
/// </summary>
/// <remarks>
/// <seealso cref="CustomTypeMarshallerFeatures.UnmanagedResources"/>
/// </remarks>
public void FreeNative()
{
if (_allocated)
Marshal.FreeBSTR((IntPtr)_ptrToFirstChar);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ namespace System.Runtime.InteropServices.Marshalling
Features = CustomTypeMarshallerFeatures.UnmanagedResources | CustomTypeMarshallerFeatures.TwoStageMarshalling)]
public unsafe ref struct Utf16StringMarshaller
{
private ushort* _nativeValue;
private void* _nativeValue;

/// <summary>
/// Initializes a new instance of the <see cref="Utf16StringMarshaller"/>.
Expand All @@ -25,7 +25,7 @@ public unsafe ref struct Utf16StringMarshaller
/// <param name="str">The string to marshal.</param>
public Utf16StringMarshaller(string? str)
{
_nativeValue = (ushort*)Marshal.StringToCoTaskMemUni(str);
_nativeValue = (void*)Marshal.StringToCoTaskMemUni(str);
}

/// <summary>
Expand All @@ -34,7 +34,7 @@ public Utf16StringMarshaller(string? str)
/// <remarks>
/// <seealso cref="CustomTypeMarshallerFeatures.TwoStageMarshalling"/>
/// </remarks>
public ushort* ToNativeValue() => _nativeValue;
public void* ToNativeValue() => _nativeValue;

/// <summary>
/// Sets the native value representing the string.
Expand All @@ -43,7 +43,7 @@ public Utf16StringMarshaller(string? str)
/// <remarks>
/// <seealso cref="CustomTypeMarshallerFeatures.TwoStageMarshalling"/>
/// </remarks>
public void FromNativeValue(ushort* value) => _nativeValue = value;
public void FromNativeValue(void* value) => _nativeValue = value;

/// <summary>
/// Returns the managed string.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@

namespace Microsoft.Interop
{

/// <summary>
/// Type used to pass on default marshalling details.
/// </summary>
Expand Down Expand Up @@ -72,7 +71,6 @@ public enum CharEncoding
Undefined,
Utf8,
Utf16,
Ansi,
Custom
}

Expand Down Expand Up @@ -761,7 +759,13 @@ private bool TryCreateTypeBasedMarshallingInfo(
}
else
{
marshallingInfo = CreateStringMarshallingInfo(type, _defaultInfo.CharEncoding);
marshallingInfo = _defaultInfo.CharEncoding switch
{
CharEncoding.Utf16 => CreateStringMarshallingInfo(type, TypeNames.Utf16StringMarshaller),
CharEncoding.Utf8 => CreateStringMarshallingInfo(type, TypeNames.Utf8StringMarshaller),
_ => throw new InvalidOperationException()
};

return true;
}

Expand Down Expand Up @@ -842,30 +846,25 @@ private MarshallingInfo CreateStringMarshallingInfo(
ITypeSymbol type,
UnmanagedType unmanagedType)
{
CharEncoding charEncoding = unmanagedType switch
string? marshallerName = unmanagedType switch
{
UnmanagedType.LPStr => CharEncoding.Ansi,
UnmanagedType.LPTStr or UnmanagedType.LPWStr => CharEncoding.Utf16,
MarshalAsInfo.UnmanagedType_LPUTF8Str => CharEncoding.Utf8,
_ => CharEncoding.Undefined
UnmanagedType.BStr => TypeNames.BStrStringMarshaller,
UnmanagedType.LPStr => TypeNames.AnsiStringMarshaller,
UnmanagedType.LPTStr or UnmanagedType.LPWStr => TypeNames.Utf16StringMarshaller,
MarshalAsInfo.UnmanagedType_LPUTF8Str => TypeNames.Utf8StringMarshaller,
_ => null
};
if (charEncoding == CharEncoding.Undefined)

if (marshallerName is null)
return new MarshalAsInfo(unmanagedType, _defaultInfo.CharEncoding);

return CreateStringMarshallingInfo(type, charEncoding);
return CreateStringMarshallingInfo(type, marshallerName);
}

private MarshallingInfo CreateStringMarshallingInfo(
ITypeSymbol type,
CharEncoding charEncoding)
string marshallerName)
{
string? marshallerName = charEncoding switch
{
CharEncoding.Ansi => TypeNames.AnsiStringMarshaller,
CharEncoding.Utf16 => TypeNames.Utf16StringMarshaller,
CharEncoding.Utf8 => TypeNames.Utf8StringMarshaller,
_ => throw new InvalidOperationException()
};
INamedTypeSymbol? stringMarshaller = _compilation.GetTypeByMetadataName(marshallerName);
if (stringMarshaller is null)
return new MissingSupportMarshallingInfo();
Expand All @@ -876,9 +875,9 @@ private MarshallingInfo CreateStringMarshallingInfo(
return CreateNativeMarshallingInfoForValue(
type,
stringMarshaller,
default,
null,
customTypeMarshallerData.Value,
allowPinningManagedType: charEncoding == CharEncoding.Utf16,
allowPinningManagedType: marshallerName is TypeNames.Utf16StringMarshaller,
useDefaultMarshalling: false);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ public static class TypeNames
public const string CustomTypeMarshallerAttributeGenericPlaceholder = "System.Runtime.InteropServices.Marshalling.CustomTypeMarshallerAttribute.GenericPlaceholder";

public const string AnsiStringMarshaller = "System.Runtime.InteropServices.Marshalling.AnsiStringMarshaller";
public const string BStrStringMarshaller = "System.Runtime.InteropServices.Marshalling.BStrStringMarshaller";
public const string Utf16StringMarshaller = "System.Runtime.InteropServices.Marshalling.Utf16StringMarshaller";
public const string Utf8StringMarshaller = "System.Runtime.InteropServices.Marshalling.Utf8StringMarshaller";

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2103,6 +2103,20 @@ public void FromNativeValue(byte* value) { }
public T[]? ToManaged() { throw null; }
public void FreeNative() { }
}
[System.CLSCompliant(false)]
[System.Runtime.InteropServices.Marshalling.CustomTypeMarshallerAttribute(typeof(string), BufferSize = 0x100,
Features = System.Runtime.InteropServices.Marshalling.CustomTypeMarshallerFeatures.UnmanagedResources
| System.Runtime.InteropServices.Marshalling.CustomTypeMarshallerFeatures.CallerAllocatedBuffer
| System.Runtime.InteropServices.Marshalling.CustomTypeMarshallerFeatures.TwoStageMarshalling )]
public unsafe ref struct BStrStringMarshaller
{
public BStrStringMarshaller(string? str) { }
public BStrStringMarshaller(string? str, System.Span<ushort> buffer) { }
public void* ToNativeValue() { throw null; }
public void FromNativeValue(void* value) { }
public string? ToManaged() { throw null; }
public void FreeNative() { }
}
[System.AttributeUsageAttribute(System.AttributeTargets.Struct)]
public sealed partial class CustomTypeMarshallerAttribute : System.Attribute
{
Expand Down Expand Up @@ -2197,8 +2211,8 @@ public void FreeNative() { }
public unsafe ref struct Utf16StringMarshaller
{
public Utf16StringMarshaller(string? str) { }
public ushort* ToNativeValue() { throw null; }
public void FromNativeValue(ushort* value) { }
public void* ToNativeValue() { throw null; }
public void FromNativeValue(void* value) { }
public string? ToManaged() { throw null; }
public void FreeNative() { }
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ private class EntryPoints

private const string UShortSuffix = "_ushort";
private const string ByteSuffix = "_byte";
private const string BStrSuffix = "_bstr";

public class Byte
{
Expand All @@ -41,6 +42,15 @@ public class UShort
public const string ReverseInplace = EntryPoints.ReverseInplace + UShortSuffix;
public const string ReverseReplace = EntryPoints.ReverseReplace + UShortSuffix;
}

public class BStr
{
public const string ReturnLength = EntryPoints.ReturnLength + BStrSuffix;
public const string ReverseReturn = EntryPoints.ReverseReturn + BStrSuffix;
public const string ReverseOut = EntryPoints.ReverseOut + BStrSuffix;
public const string ReverseInplace = EntryPoints.ReverseInplace + BStrSuffix;
public const string ReverseReplace = EntryPoints.ReverseReplace + BStrSuffix;
}
}

public partial class Utf16
Expand Down Expand Up @@ -185,6 +195,31 @@ public partial class LPStr
public static partial void Reverse_Replace_Ref([MarshalAs(UnmanagedType.LPStr)] ref string s);
}

public partial class BStr
{
[LibraryImport(NativeExportsNE_Binary, EntryPoint = EntryPoints.BStr.ReturnLength)]
public static partial int ReturnLength([MarshalAs(UnmanagedType.BStr)] string s);

[LibraryImport(NativeExportsNE_Binary, EntryPoint = EntryPoints.BStr.ReturnLength, StringMarshalling = StringMarshalling.Utf16)]
public static partial int ReturnLength_IgnoreStringMarshalling([MarshalAs(UnmanagedType.BStr)] string s);

[LibraryImport(NativeExportsNE_Binary, EntryPoint = EntryPoints.BStr.ReverseReturn)]
[return: MarshalAs(UnmanagedType.BStr)]
public static partial string Reverse_Return([MarshalAs(UnmanagedType.BStr)] string s);

[LibraryImport(NativeExportsNE_Binary, EntryPoint = EntryPoints.BStr.ReverseOut)]
public static partial void Reverse_Out([MarshalAs(UnmanagedType.BStr)] string s, [MarshalAs(UnmanagedType.BStr)] out string ret);

[LibraryImport(NativeExportsNE_Binary, EntryPoint = EntryPoints.BStr.ReverseInplace)]
public static partial void Reverse_Ref([MarshalAs(UnmanagedType.BStr)] ref string s);

[LibraryImport(NativeExportsNE_Binary, EntryPoint = EntryPoints.BStr.ReverseInplace)]
public static partial void Reverse_In([MarshalAs(UnmanagedType.BStr)] in string s);

[LibraryImport(NativeExportsNE_Binary, EntryPoint = EntryPoints.BStr.ReverseReplace)]
public static partial void Reverse_Replace_Ref([MarshalAs(UnmanagedType.BStr)] ref string s);
}

public partial class StringMarshallingCustomType
{
public partial class Utf16
Expand Down Expand Up @@ -418,6 +453,48 @@ public void AnsiStringByRef(string value)
Assert.Equal(expected, refValue);
}

[Theory]
[MemberData(nameof(UnicodeStrings))]
public void BStrStringMarshalledAsExpected(string value)
{
int expectedLen = value != null ? value.Length : -1;

Assert.Equal(expectedLen, NativeExportsNE.BStr.ReturnLength(value));
Assert.Equal(expectedLen, NativeExportsNE.BStr.ReturnLength_IgnoreStringMarshalling(value));
}

[Theory]
[MemberData(nameof(UnicodeStrings))]
public void BStrStringReturn(string value)
{
string expected = ReverseChars(value);

Assert.Equal(expected, NativeExportsNE.BStr.Reverse_Return(value));

string ret;
NativeExportsNE.BStr.Reverse_Out(value, out ret);
Assert.Equal(expected, ret);
}

[Theory]
[MemberData(nameof(UnicodeStrings))]
public void BStrStringByRef(string value)
{
string refValue = value;
string expected = ReverseChars(value);

NativeExportsNE.BStr.Reverse_In(in refValue);
Assert.Equal(value, refValue); // Should not be updated when using 'in'

refValue = value;
NativeExportsNE.BStr.Reverse_Ref(ref refValue);
Assert.Equal(expected, refValue);

refValue = value;
NativeExportsNE.BStr.Reverse_Replace_Ref(ref refValue);
Assert.Equal(expected, refValue);
}

[Theory]
[MemberData(nameof(UnicodeStrings))]
public void StringMarshallingCustomType_MarshalledAsExpected(string value)
Expand Down
Loading

0 comments on commit 810a7f9

Please sign in to comment.