diff --git a/src/libraries/System.Private.CoreLib/src/System.Private.CoreLib.Shared.projitems b/src/libraries/System.Private.CoreLib/src/System.Private.CoreLib.Shared.projitems
index 478bd1cb3d008..a374c87fb383f 100644
--- a/src/libraries/System.Private.CoreLib/src/System.Private.CoreLib.Shared.projitems
+++ b/src/libraries/System.Private.CoreLib/src/System.Private.CoreLib.Shared.projitems
@@ -866,6 +866,7 @@
+
diff --git a/src/libraries/System.Private.CoreLib/src/System/Runtime/InteropServices/Marshalling/BStrStringMarshaller.cs b/src/libraries/System.Private.CoreLib/src/System/Runtime/InteropServices/Marshalling/BStrStringMarshaller.cs
new file mode 100644
index 0000000000000..ce5a83a69e1ab
--- /dev/null
+++ b/src/libraries/System.Private.CoreLib/src/System/Runtime/InteropServices/Marshalling/BStrStringMarshaller.cs
@@ -0,0 +1,125 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System.Diagnostics;
+using System.Runtime.CompilerServices;
+using System.Text;
+
+namespace System.Runtime.InteropServices.Marshalling
+{
+ ///
+ /// Marshaller for BSTR strings
+ ///
+ [CLSCompliant(false)]
+ [CustomTypeMarshaller(typeof(string), BufferSize = 0x100,
+ Features = CustomTypeMarshallerFeatures.UnmanagedResources | CustomTypeMarshallerFeatures.TwoStageMarshalling | CustomTypeMarshallerFeatures.CallerAllocatedBuffer)]
+ public unsafe ref struct BStrStringMarshaller
+ {
+ private void* _ptrToFirstChar;
+ private bool _allocated;
+
+ ///
+ /// Initializes a new instance of the .
+ ///
+ /// The string to marshal.
+ public BStrStringMarshaller(string? str)
+ : this(str, default)
+ { }
+
+ ///
+ /// Initializes a new instance of the .
+ ///
+ /// The string to marshal.
+ /// Buffer that may be used for marshalling.
+ ///
+ /// The must not be movable - that is, it should not be
+ /// on the managed heap or it should be pinned.
+ ///
+ ///
+ public BStrStringMarshaller(string? str, Span buffer)
+ {
+ _allocated = false;
+
+ if (str is null)
+ {
+ _ptrToFirstChar = null;
+ return;
+ }
+
+ ushort* ptrToFirstChar;
+ int lengthInBytes = checked(sizeof(char) * str.Length);
+
+ // A caller provided buffer must be at least (lengthInBytes + 6) bytes
+ // in order to be constructed manually. The 6 extra bytes are 4 for byte length and 2 for wide null.
+ int manualBstrNeeds = checked(lengthInBytes + 6);
+ if (manualBstrNeeds > buffer.Length)
+ {
+ // Use precise byte count when the provided stack-allocated buffer is not sufficient
+ ptrToFirstChar = (ushort*)Marshal.AllocBSTRByteLen((uint)lengthInBytes);
+ _allocated = true;
+ }
+ else
+ {
+ // Set length and update buffer target
+ byte* pBuffer = (byte*)Unsafe.AsPointer(ref MemoryMarshal.GetReference(buffer));
+ *((uint*)pBuffer) = (uint)lengthInBytes;
+ ptrToFirstChar = (ushort*)(pBuffer + sizeof(uint));
+ }
+
+ // Confirm the size is properly set for the allocated BSTR.
+ Debug.Assert(lengthInBytes == Marshal.SysStringByteLen((IntPtr)ptrToFirstChar));
+
+ // Copy characters from the managed string
+ str.CopyTo(new Span(ptrToFirstChar, str.Length));
+ ptrToFirstChar[str.Length] = '\0'; // null-terminate
+ _ptrToFirstChar = ptrToFirstChar;
+ }
+
+ ///
+ /// Returns the native value representing the string.
+ ///
+ ///
+ ///
+ ///
+ public void* ToNativeValue() => _ptrToFirstChar;
+
+ ///
+ /// Sets the native value representing the string.
+ ///
+ /// The native value.
+ ///
+ ///
+ ///
+ public void FromNativeValue(void* value)
+ {
+ _ptrToFirstChar = value;
+ _allocated = true;
+ }
+
+ ///
+ /// Returns the managed string.
+ ///
+ ///
+ ///
+ ///
+ public string? ToManaged()
+ {
+ if (_ptrToFirstChar is null)
+ return null;
+
+ return Marshal.PtrToStringBSTR((IntPtr)_ptrToFirstChar);
+ }
+
+ ///
+ /// Frees native resources.
+ ///
+ ///
+ ///
+ ///
+ public void FreeNative()
+ {
+ if (_allocated)
+ Marshal.FreeBSTR((IntPtr)_ptrToFirstChar);
+ }
+ }
+}
diff --git a/src/libraries/System.Private.CoreLib/src/System/Runtime/InteropServices/Marshalling/Utf16StringMarshaller.cs b/src/libraries/System.Private.CoreLib/src/System/Runtime/InteropServices/Marshalling/Utf16StringMarshaller.cs
index e207bec5f9fcf..04ada6d9d5132 100644
--- a/src/libraries/System.Private.CoreLib/src/System/Runtime/InteropServices/Marshalling/Utf16StringMarshaller.cs
+++ b/src/libraries/System.Private.CoreLib/src/System/Runtime/InteropServices/Marshalling/Utf16StringMarshaller.cs
@@ -13,7 +13,7 @@ namespace System.Runtime.InteropServices.Marshalling
Features = CustomTypeMarshallerFeatures.UnmanagedResources | CustomTypeMarshallerFeatures.TwoStageMarshalling)]
public unsafe ref struct Utf16StringMarshaller
{
- private ushort* _nativeValue;
+ private void* _nativeValue;
///
/// Initializes a new instance of the .
@@ -25,7 +25,7 @@ public unsafe ref struct Utf16StringMarshaller
/// The string to marshal.
public Utf16StringMarshaller(string? str)
{
- _nativeValue = (ushort*)Marshal.StringToCoTaskMemUni(str);
+ _nativeValue = (void*)Marshal.StringToCoTaskMemUni(str);
}
///
@@ -34,7 +34,7 @@ public Utf16StringMarshaller(string? str)
///
///
///
- public ushort* ToNativeValue() => _nativeValue;
+ public void* ToNativeValue() => _nativeValue;
///
/// Sets the native value representing the string.
@@ -43,7 +43,7 @@ public Utf16StringMarshaller(string? str)
///
///
///
- public void FromNativeValue(ushort* value) => _nativeValue = value;
+ public void FromNativeValue(void* value) => _nativeValue = value;
///
/// Returns the managed string.
diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/MarshallingAttributeInfo.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/MarshallingAttributeInfo.cs
index 33270eb954550..f2a98d6b246fd 100644
--- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/MarshallingAttributeInfo.cs
+++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/MarshallingAttributeInfo.cs
@@ -11,7 +11,6 @@
namespace Microsoft.Interop
{
-
///
/// Type used to pass on default marshalling details.
///
@@ -72,7 +71,6 @@ public enum CharEncoding
Undefined,
Utf8,
Utf16,
- Ansi,
Custom
}
@@ -761,7 +759,13 @@ private bool TryCreateTypeBasedMarshallingInfo(
}
else
{
- marshallingInfo = CreateStringMarshallingInfo(type, _defaultInfo.CharEncoding);
+ marshallingInfo = _defaultInfo.CharEncoding switch
+ {
+ CharEncoding.Utf16 => CreateStringMarshallingInfo(type, TypeNames.Utf16StringMarshaller),
+ CharEncoding.Utf8 => CreateStringMarshallingInfo(type, TypeNames.Utf8StringMarshaller),
+ _ => throw new InvalidOperationException()
+ };
+
return true;
}
@@ -842,30 +846,25 @@ private MarshallingInfo CreateStringMarshallingInfo(
ITypeSymbol type,
UnmanagedType unmanagedType)
{
- CharEncoding charEncoding = unmanagedType switch
+ string? marshallerName = unmanagedType switch
{
- UnmanagedType.LPStr => CharEncoding.Ansi,
- UnmanagedType.LPTStr or UnmanagedType.LPWStr => CharEncoding.Utf16,
- MarshalAsInfo.UnmanagedType_LPUTF8Str => CharEncoding.Utf8,
- _ => CharEncoding.Undefined
+ UnmanagedType.BStr => TypeNames.BStrStringMarshaller,
+ UnmanagedType.LPStr => TypeNames.AnsiStringMarshaller,
+ UnmanagedType.LPTStr or UnmanagedType.LPWStr => TypeNames.Utf16StringMarshaller,
+ MarshalAsInfo.UnmanagedType_LPUTF8Str => TypeNames.Utf8StringMarshaller,
+ _ => null
};
- if (charEncoding == CharEncoding.Undefined)
+
+ if (marshallerName is null)
return new MarshalAsInfo(unmanagedType, _defaultInfo.CharEncoding);
- return CreateStringMarshallingInfo(type, charEncoding);
+ return CreateStringMarshallingInfo(type, marshallerName);
}
private MarshallingInfo CreateStringMarshallingInfo(
ITypeSymbol type,
- CharEncoding charEncoding)
+ string marshallerName)
{
- string? marshallerName = charEncoding switch
- {
- CharEncoding.Ansi => TypeNames.AnsiStringMarshaller,
- CharEncoding.Utf16 => TypeNames.Utf16StringMarshaller,
- CharEncoding.Utf8 => TypeNames.Utf8StringMarshaller,
- _ => throw new InvalidOperationException()
- };
INamedTypeSymbol? stringMarshaller = _compilation.GetTypeByMetadataName(marshallerName);
if (stringMarshaller is null)
return new MissingSupportMarshallingInfo();
@@ -876,9 +875,9 @@ private MarshallingInfo CreateStringMarshallingInfo(
return CreateNativeMarshallingInfoForValue(
type,
stringMarshaller,
- default,
+ null,
customTypeMarshallerData.Value,
- allowPinningManagedType: charEncoding == CharEncoding.Utf16,
+ allowPinningManagedType: marshallerName is TypeNames.Utf16StringMarshaller,
useDefaultMarshalling: false);
}
diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/TypeNames.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/TypeNames.cs
index 876dc649001f9..778c1cd2beb39 100644
--- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/TypeNames.cs
+++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/TypeNames.cs
@@ -18,6 +18,7 @@ public static class TypeNames
public const string CustomTypeMarshallerAttributeGenericPlaceholder = "System.Runtime.InteropServices.Marshalling.CustomTypeMarshallerAttribute.GenericPlaceholder";
public const string AnsiStringMarshaller = "System.Runtime.InteropServices.Marshalling.AnsiStringMarshaller";
+ public const string BStrStringMarshaller = "System.Runtime.InteropServices.Marshalling.BStrStringMarshaller";
public const string Utf16StringMarshaller = "System.Runtime.InteropServices.Marshalling.Utf16StringMarshaller";
public const string Utf8StringMarshaller = "System.Runtime.InteropServices.Marshalling.Utf8StringMarshaller";
diff --git a/src/libraries/System.Runtime.InteropServices/ref/System.Runtime.InteropServices.cs b/src/libraries/System.Runtime.InteropServices/ref/System.Runtime.InteropServices.cs
index ca618ba5c7b93..895df2ce26a16 100644
--- a/src/libraries/System.Runtime.InteropServices/ref/System.Runtime.InteropServices.cs
+++ b/src/libraries/System.Runtime.InteropServices/ref/System.Runtime.InteropServices.cs
@@ -2103,6 +2103,20 @@ public void FromNativeValue(byte* value) { }
public T[]? ToManaged() { throw null; }
public void FreeNative() { }
}
+ [System.CLSCompliant(false)]
+ [System.Runtime.InteropServices.Marshalling.CustomTypeMarshallerAttribute(typeof(string), BufferSize = 0x100,
+ Features = System.Runtime.InteropServices.Marshalling.CustomTypeMarshallerFeatures.UnmanagedResources
+ | System.Runtime.InteropServices.Marshalling.CustomTypeMarshallerFeatures.CallerAllocatedBuffer
+ | System.Runtime.InteropServices.Marshalling.CustomTypeMarshallerFeatures.TwoStageMarshalling )]
+ public unsafe ref struct BStrStringMarshaller
+ {
+ public BStrStringMarshaller(string? str) { }
+ public BStrStringMarshaller(string? str, System.Span buffer) { }
+ public void* ToNativeValue() { throw null; }
+ public void FromNativeValue(void* value) { }
+ public string? ToManaged() { throw null; }
+ public void FreeNative() { }
+ }
[System.AttributeUsageAttribute(System.AttributeTargets.Struct)]
public sealed partial class CustomTypeMarshallerAttribute : System.Attribute
{
@@ -2197,8 +2211,8 @@ public void FreeNative() { }
public unsafe ref struct Utf16StringMarshaller
{
public Utf16StringMarshaller(string? str) { }
- public ushort* ToNativeValue() { throw null; }
- public void FromNativeValue(ushort* value) { }
+ public void* ToNativeValue() { throw null; }
+ public void FromNativeValue(void* value) { }
public string? ToManaged() { throw null; }
public void FreeNative() { }
}
diff --git a/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.Tests/StringTests.cs b/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.Tests/StringTests.cs
index 9987c687f8455..b518f6ed0d9fa 100644
--- a/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.Tests/StringTests.cs
+++ b/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.Tests/StringTests.cs
@@ -23,6 +23,7 @@ private class EntryPoints
private const string UShortSuffix = "_ushort";
private const string ByteSuffix = "_byte";
+ private const string BStrSuffix = "_bstr";
public class Byte
{
@@ -41,6 +42,15 @@ public class UShort
public const string ReverseInplace = EntryPoints.ReverseInplace + UShortSuffix;
public const string ReverseReplace = EntryPoints.ReverseReplace + UShortSuffix;
}
+
+ public class BStr
+ {
+ public const string ReturnLength = EntryPoints.ReturnLength + BStrSuffix;
+ public const string ReverseReturn = EntryPoints.ReverseReturn + BStrSuffix;
+ public const string ReverseOut = EntryPoints.ReverseOut + BStrSuffix;
+ public const string ReverseInplace = EntryPoints.ReverseInplace + BStrSuffix;
+ public const string ReverseReplace = EntryPoints.ReverseReplace + BStrSuffix;
+ }
}
public partial class Utf16
@@ -185,6 +195,31 @@ public partial class LPStr
public static partial void Reverse_Replace_Ref([MarshalAs(UnmanagedType.LPStr)] ref string s);
}
+ public partial class BStr
+ {
+ [LibraryImport(NativeExportsNE_Binary, EntryPoint = EntryPoints.BStr.ReturnLength)]
+ public static partial int ReturnLength([MarshalAs(UnmanagedType.BStr)] string s);
+
+ [LibraryImport(NativeExportsNE_Binary, EntryPoint = EntryPoints.BStr.ReturnLength, StringMarshalling = StringMarshalling.Utf16)]
+ public static partial int ReturnLength_IgnoreStringMarshalling([MarshalAs(UnmanagedType.BStr)] string s);
+
+ [LibraryImport(NativeExportsNE_Binary, EntryPoint = EntryPoints.BStr.ReverseReturn)]
+ [return: MarshalAs(UnmanagedType.BStr)]
+ public static partial string Reverse_Return([MarshalAs(UnmanagedType.BStr)] string s);
+
+ [LibraryImport(NativeExportsNE_Binary, EntryPoint = EntryPoints.BStr.ReverseOut)]
+ public static partial void Reverse_Out([MarshalAs(UnmanagedType.BStr)] string s, [MarshalAs(UnmanagedType.BStr)] out string ret);
+
+ [LibraryImport(NativeExportsNE_Binary, EntryPoint = EntryPoints.BStr.ReverseInplace)]
+ public static partial void Reverse_Ref([MarshalAs(UnmanagedType.BStr)] ref string s);
+
+ [LibraryImport(NativeExportsNE_Binary, EntryPoint = EntryPoints.BStr.ReverseInplace)]
+ public static partial void Reverse_In([MarshalAs(UnmanagedType.BStr)] in string s);
+
+ [LibraryImport(NativeExportsNE_Binary, EntryPoint = EntryPoints.BStr.ReverseReplace)]
+ public static partial void Reverse_Replace_Ref([MarshalAs(UnmanagedType.BStr)] ref string s);
+ }
+
public partial class StringMarshallingCustomType
{
public partial class Utf16
@@ -418,6 +453,48 @@ public void AnsiStringByRef(string value)
Assert.Equal(expected, refValue);
}
+ [Theory]
+ [MemberData(nameof(UnicodeStrings))]
+ public void BStrStringMarshalledAsExpected(string value)
+ {
+ int expectedLen = value != null ? value.Length : -1;
+
+ Assert.Equal(expectedLen, NativeExportsNE.BStr.ReturnLength(value));
+ Assert.Equal(expectedLen, NativeExportsNE.BStr.ReturnLength_IgnoreStringMarshalling(value));
+ }
+
+ [Theory]
+ [MemberData(nameof(UnicodeStrings))]
+ public void BStrStringReturn(string value)
+ {
+ string expected = ReverseChars(value);
+
+ Assert.Equal(expected, NativeExportsNE.BStr.Reverse_Return(value));
+
+ string ret;
+ NativeExportsNE.BStr.Reverse_Out(value, out ret);
+ Assert.Equal(expected, ret);
+ }
+
+ [Theory]
+ [MemberData(nameof(UnicodeStrings))]
+ public void BStrStringByRef(string value)
+ {
+ string refValue = value;
+ string expected = ReverseChars(value);
+
+ NativeExportsNE.BStr.Reverse_In(in refValue);
+ Assert.Equal(value, refValue); // Should not be updated when using 'in'
+
+ refValue = value;
+ NativeExportsNE.BStr.Reverse_Ref(ref refValue);
+ Assert.Equal(expected, refValue);
+
+ refValue = value;
+ NativeExportsNE.BStr.Reverse_Replace_Ref(ref refValue);
+ Assert.Equal(expected, refValue);
+ }
+
[Theory]
[MemberData(nameof(UnicodeStrings))]
public void StringMarshallingCustomType_MarshalledAsExpected(string value)
diff --git a/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.UnitTests/Compiles.cs b/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.UnitTests/Compiles.cs
index e8ca396d89560..3448c83275caf 100644
--- a/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.UnitTests/Compiles.cs
+++ b/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.UnitTests/Compiles.cs
@@ -108,9 +108,11 @@ public static IEnumerable