diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGeneratorHelpers.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGeneratorHelpers.cs index f9f0bf684ff821..409575057ca547 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGeneratorHelpers.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGeneratorHelpers.cs @@ -11,7 +11,7 @@ namespace Microsoft.Interop { internal static class ComInterfaceGeneratorHelpers { - public static MarshallingGeneratorFactoryKey<(TargetFramework, Version)> CreateGeneratorFactory(StubEnvironment env) + public static MarshallingGeneratorFactoryKey<(TargetFramework, Version)> CreateGeneratorFactory(StubEnvironment env, MarshalDirection direction) { IMarshallingGeneratorFactory generatorFactory; @@ -44,7 +44,17 @@ internal static class ComInterfaceGeneratorHelpers generatorFactory = new AttributedMarshallingModelGeneratorFactory( generatorFactory, elementFactory, - new AttributedMarshallingModelOptions(runtimeMarshallingDisabled, MarshalMode.ManagedToUnmanagedIn, MarshalMode.ManagedToUnmanagedRef, MarshalMode.ManagedToUnmanagedOut)); + new AttributedMarshallingModelOptions( + runtimeMarshallingDisabled, + direction == MarshalDirection.ManagedToUnmanaged + ? MarshalMode.ManagedToUnmanagedIn + : MarshalMode.UnmanagedToManagedOut, + direction == MarshalDirection.ManagedToUnmanaged + ? MarshalMode.ManagedToUnmanagedRef + : MarshalMode.UnmanagedToManagedRef, + direction == MarshalDirection.ManagedToUnmanaged + ? MarshalMode.ManagedToUnmanagedOut + : MarshalMode.UnmanagedToManagedIn)); generatorFactory = new ByValueContentsMarshalKindValidator(generatorFactory); diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/VtableIndexStubGenerator.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/VtableIndexStubGenerator.cs index 0805cef07c7eb5..4f06a269dcf4fa 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/VtableIndexStubGenerator.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/VtableIndexStubGenerator.cs @@ -26,7 +26,8 @@ internal sealed record IncrementalStubGenerationContext( MethodSignatureDiagnosticLocations DiagnosticLocation, SequenceEqualImmutableArray CallingConvention, VirtualMethodIndexData VtableIndexData, - MarshallingGeneratorFactoryKey<(TargetFramework TargetFramework, Version TargetFrameworkVersion)> GeneratorFactory, + MarshallingGeneratorFactoryKey<(TargetFramework TargetFramework, Version TargetFrameworkVersion)> ManagedToUnmanagedGeneratorFactory, + MarshallingGeneratorFactoryKey<(TargetFramework TargetFramework, Version TargetFrameworkVersion)> UnmanagedToManagedGeneratorFactory, ManagedTypeInfo TypeKeyType, ManagedTypeInfo TypeKeyOwner, SequenceEqualImmutableArray Diagnostics); @@ -301,7 +302,8 @@ private static IncrementalStubGenerationContext CalculateStubInformation(MethodD new MethodSignatureDiagnosticLocations(syntax), new SequenceEqualImmutableArray(callConv, SyntaxEquivalentComparer.Instance), virtualMethodIndexData, - ComInterfaceGeneratorHelpers.CreateGeneratorFactory(environment), + ComInterfaceGeneratorHelpers.CreateGeneratorFactory(environment, MarshalDirection.ManagedToUnmanaged), + ComInterfaceGeneratorHelpers.CreateGeneratorFactory(environment, MarshalDirection.UnmanagedToManaged), typeKeyType, typeKeyOwner, new SequenceEqualImmutableArray(generatorDiagnostics.Diagnostics.ToImmutableArray())); @@ -337,8 +339,8 @@ private static (MemberDeclarationSyntax, ImmutableArray) GenerateMan // Generate stub code var stubGenerator = new ManagedToNativeVTableMethodGenerator( - methodStub.GeneratorFactory.Key.TargetFramework, - methodStub.GeneratorFactory.Key.TargetFrameworkVersion, + methodStub.ManagedToUnmanagedGeneratorFactory.Key.TargetFramework, + methodStub.ManagedToUnmanagedGeneratorFactory.Key.TargetFrameworkVersion, methodStub.SignatureContext.ElementTypeInformation, methodStub.VtableIndexData.SetLastError, methodStub.VtableIndexData.ImplicitThisParameter, @@ -346,7 +348,7 @@ private static (MemberDeclarationSyntax, ImmutableArray) GenerateMan { diagnostics.ReportMarshallingNotSupported(methodStub.DiagnosticLocation, elementInfo, ex.NotSupportedDetails); }, - methodStub.GeneratorFactory.GeneratorFactory); + methodStub.ManagedToUnmanagedGeneratorFactory.GeneratorFactory); BlockSyntax code = stubGenerator.GenerateStubBody( methodStub.VtableIndexData.Index, @@ -370,19 +372,6 @@ private static (MemberDeclarationSyntax, ImmutableArray) GenerateMan methodStub.Diagnostics.Array.AddRange(diagnostics.Diagnostics)); } - private static bool ShouldVisitNode(SyntaxNode syntaxNode) - { - // We only support C# method declarations. - if (syntaxNode.Language != LanguageNames.CSharp - || !syntaxNode.IsKind(SyntaxKind.MethodDeclaration)) - { - return false; - } - - // Filter out methods with no attributes early. - return ((MethodDeclarationSyntax)syntaxNode).AttributeLists.Count > 0; - } - private static Diagnostic? GetDiagnosticIfInvalidMethodForGeneration(MethodDeclarationSyntax methodSyntax, IMethodSymbol method) { // Verify the method has no generic types or defined implementation diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/AttributedMarshallingModelGeneratorFactory.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/AttributedMarshallingModelGeneratorFactory.cs index 27f13e68272b40..838cabf9576867 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/AttributedMarshallingModelGeneratorFactory.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/AttributedMarshallingModelGeneratorFactory.cs @@ -12,7 +12,7 @@ namespace Microsoft.Interop { - public readonly record struct AttributedMarshallingModelOptions(bool RuntimeMarshallingDisabled, MarshalMode InMode, MarshalMode RefMode, MarshalMode OutMode); + public readonly record struct AttributedMarshallingModelOptions(bool RuntimeMarshallingDisabled, MarshalMode ManagedToUnmanagedMode, MarshalMode BidirectionalMode, MarshalMode UnmanagedToManagedMode); public class AttributedMarshallingModelGeneratorFactory : IMarshallingGeneratorFactory { @@ -126,7 +126,7 @@ ExpressionSyntax GetExpressionForParam(TypePositionInfo paramInfo, out bool isIn { if (marshallingInfo is NativeLinearCollectionMarshallingInfo collectionInfo) { - CustomTypeMarshallerData marshallerData = GetMarshallerDataForTypePositionInfo(collectionInfo.Marshallers, info); + CustomTypeMarshallerData marshallerData = GetMarshallerDataForTypePositionInfo(collectionInfo.Marshallers, info, context); type = marshallerData.CollectionElementType; marshallingInfo = marshallerData.CollectionElementMarshallingInfo; } @@ -200,16 +200,15 @@ private bool ValidateRuntimeMarshallingOptions(CustomTypeMarshallerData marshall return false; } - private CustomTypeMarshallerData GetMarshallerDataForTypePositionInfo(CustomTypeMarshallers marshallers, TypePositionInfo info) + private CustomTypeMarshallerData GetMarshallerDataForTypePositionInfo(CustomTypeMarshallers marshallers, TypePositionInfo info, StubCodeContext context) { - if (info.IsManagedReturnPosition) - return marshallers.GetModeOrDefault(Options.OutMode); + MarshalDirection elementDirection = MarshallerHelpers.GetMarshalDirection(info, context); - return info.RefKind switch + return elementDirection switch { - RefKind.None or RefKind.In => marshallers.GetModeOrDefault(Options.InMode), - RefKind.Ref => marshallers.GetModeOrDefault(Options.RefMode), - RefKind.Out => marshallers.GetModeOrDefault(Options.OutMode), + MarshalDirection.ManagedToUnmanaged => marshallers.GetModeOrDefault(Options.ManagedToUnmanagedMode), + MarshalDirection.Bidirectional => marshallers.GetModeOrDefault(Options.BidirectionalMode), + MarshalDirection.UnmanagedToManaged => marshallers.GetModeOrDefault(Options.UnmanagedToManagedMode), _ => throw new UnreachableException() }; } @@ -218,7 +217,7 @@ private IMarshallingGenerator CreateCustomNativeTypeMarshaller(TypePositionInfo { ValidateCustomNativeTypeMarshallingSupported(info, context, marshalInfo); - CustomTypeMarshallerData marshallerData = GetMarshallerDataForTypePositionInfo(marshalInfo.Marshallers, info); + CustomTypeMarshallerData marshallerData = GetMarshallerDataForTypePositionInfo(marshalInfo.Marshallers, info, context); if (!ValidateRuntimeMarshallingOptions(marshallerData)) { throw new MarshallingNotSupportedException(info, context) @@ -373,9 +372,10 @@ private static TypeSyntax ReplacePlaceholderSyntaxWithUnmanagedTypeSyntax( private void ValidateCustomNativeTypeMarshallingSupported(TypePositionInfo info, StubCodeContext context, NativeMarshallingAttributeInfo marshalInfo) { + MarshalDirection elementDirection = MarshallerHelpers.GetMarshalDirection(info, context); // Marshalling out or return parameter, but no out marshaller is specified - if ((info.RefKind == RefKind.Out || info.IsManagedReturnPosition) - && !marshalInfo.Marshallers.IsDefinedOrDefault(Options.OutMode)) + if (elementDirection == MarshalDirection.UnmanagedToManaged + && !marshalInfo.Marshallers.IsDefinedOrDefault(Options.UnmanagedToManagedMode)) { throw new MarshallingNotSupportedException(info, context) { @@ -384,7 +384,7 @@ private void ValidateCustomNativeTypeMarshallingSupported(TypePositionInfo info, } // Marshalling ref parameter, but no ref marshaller is specified - if (info.RefKind == RefKind.Ref && !marshalInfo.Marshallers.IsDefinedOrDefault(Options.RefMode)) + if (elementDirection == MarshalDirection.Bidirectional && !marshalInfo.Marshallers.IsDefinedOrDefault(Options.BidirectionalMode)) { throw new MarshallingNotSupportedException(info, context) { @@ -393,20 +393,8 @@ private void ValidateCustomNativeTypeMarshallingSupported(TypePositionInfo info, } // Marshalling in parameter, but no in marshaller is specified - if (info.RefKind == RefKind.In - && !marshalInfo.Marshallers.IsDefinedOrDefault(Options.InMode)) - { - throw new MarshallingNotSupportedException(info, context) - { - NotSupportedDetails = SR.Format(SR.ManagedToUnmanagedMissingRequiredMarshaller, marshalInfo.EntryPointType.FullTypeName) - }; - } - - // Marshalling by value, but no in marshaller is specified - if (!info.IsByRef - && !info.IsManagedReturnPosition - && context.SingleFrameSpansNativeContext - && !marshalInfo.Marshallers.IsDefinedOrDefault(Options.InMode)) + if (elementDirection == MarshalDirection.ManagedToUnmanaged + && !marshalInfo.Marshallers.IsDefinedOrDefault(Options.ManagedToUnmanagedMode)) { throw new MarshallingNotSupportedException(info, context) { diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/BlittableMarshaller.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/BlittableMarshaller.cs index bc2de37e5b5486..8d76a667634a58 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/BlittableMarshaller.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/BlittableMarshaller.cs @@ -66,12 +66,14 @@ public IEnumerable Generate(TypePositionInfo info, StubCodeCont yield break; } + MarshalDirection elementMarshalling = MarshallerHelpers.GetMarshalDirection(info, context); + switch (context.CurrentStage) { case StubCodeContext.Stage.Setup: break; case StubCodeContext.Stage.Marshal: - if (info.RefKind == RefKind.Ref) + if (elementMarshalling is MarshalDirection.ManagedToUnmanaged or MarshalDirection.Bidirectional && info.IsByRef) { yield return ExpressionStatement( AssignmentExpression( @@ -82,11 +84,14 @@ public IEnumerable Generate(TypePositionInfo info, StubCodeCont break; case StubCodeContext.Stage.Unmarshal: - yield return ExpressionStatement( - AssignmentExpression( - SyntaxKind.SimpleAssignmentExpression, - IdentifierName(managedIdentifier), - IdentifierName(nativeIdentifier))); + if (elementMarshalling is MarshalDirection.UnmanagedToManaged or MarshalDirection.Bidirectional && info.IsByRef) + { + yield return ExpressionStatement( + AssignmentExpression( + SyntaxKind.SimpleAssignmentExpression, + IdentifierName(managedIdentifier), + IdentifierName(nativeIdentifier))); + } break; default: break; diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/BoolMarshaller.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/BoolMarshaller.cs index 7ed1dadeba8822..12e96294a8dcc4 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/BoolMarshaller.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/BoolMarshaller.cs @@ -52,6 +52,7 @@ public ValueBoundaryBehavior GetValueBoundaryBehavior(TypePositionInfo info, Stu public IEnumerable Generate(TypePositionInfo info, StubCodeContext context) { + MarshalDirection elementMarshalDirection = MarshallerHelpers.GetMarshalDirection(info, context); (string managedIdentifier, string nativeIdentifier) = context.GetIdentifiers(info); switch (context.CurrentStage) { @@ -59,7 +60,7 @@ public IEnumerable Generate(TypePositionInfo info, StubCodeCont break; case StubCodeContext.Stage.Marshal: // = ()( ? _trueValue : _falseValue); - if (info.RefKind != RefKind.Out) + if (elementMarshalDirection is MarshalDirection.ManagedToUnmanaged or MarshalDirection.Bidirectional) { yield return ExpressionStatement( AssignmentExpression( @@ -75,7 +76,7 @@ public IEnumerable Generate(TypePositionInfo info, StubCodeCont break; case StubCodeContext.Stage.Unmarshal: - if (info.IsManagedReturnPosition || (info.IsByRef && info.RefKind != RefKind.In)) + if (elementMarshalDirection is MarshalDirection.UnmanagedToManaged or MarshalDirection.Bidirectional) { // = == _trueValue; // or diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/CharMarshaller.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/CharMarshaller.cs index def4ae4686a0cc..5fcd4056f4275b 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/CharMarshaller.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/CharMarshaller.cs @@ -82,23 +82,30 @@ public IEnumerable Generate(TypePositionInfo info, StubCodeCont yield break; } + MarshalDirection elementMarshalDirection = MarshallerHelpers.GetMarshalDirection(info, context); + switch (context.CurrentStage) { case StubCodeContext.Stage.Setup: break; case StubCodeContext.Stage.Marshal: - if ((info.IsByRef && info.RefKind != RefKind.Out) || !context.SingleFrameSpansNativeContext) + if (elementMarshalDirection is MarshalDirection.ManagedToUnmanaged or MarshalDirection.Bidirectional) { - yield return ExpressionStatement( - AssignmentExpression( - SyntaxKind.SimpleAssignmentExpression, - IdentifierName(nativeIdentifier), - IdentifierName(managedIdentifier))); + // There's an implicit conversion from char to ushort, + // so we simplify the generated code to just pass the char value directly + if (info.IsByRef) + { + yield return ExpressionStatement( + AssignmentExpression( + SyntaxKind.SimpleAssignmentExpression, + IdentifierName(nativeIdentifier), + IdentifierName(managedIdentifier))); + } } break; case StubCodeContext.Stage.Unmarshal: - if (info.IsManagedReturnPosition || (info.IsByRef && info.RefKind != RefKind.In)) + if (elementMarshalDirection is MarshalDirection.UnmanagedToManaged or MarshalDirection.Bidirectional) { yield return ExpressionStatement( AssignmentExpression( diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/CustomTypeMarshallingGenerator.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/CustomTypeMarshallingGenerator.cs index 012dc6fc9cbd75..c84974de06e3c2 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/CustomTypeMarshallingGenerator.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/CustomTypeMarshallingGenerator.cs @@ -45,6 +45,7 @@ public SignatureBehavior GetNativeSignatureBehavior(TypePositionInfo info) public IEnumerable Generate(TypePositionInfo info, StubCodeContext context) { + MarshalDirection elementMarshalDirection = MarshallerHelpers.GetMarshalDirection(info, context); // Although custom native type marshalling doesn't support [In] or [Out] by value marshalling, // other marshallers that wrap this one might, so we handle the correct cases here. switch (context.CurrentStage) @@ -52,45 +53,44 @@ public IEnumerable Generate(TypePositionInfo info, StubCodeCont case StubCodeContext.Stage.Setup: return _nativeTypeMarshaller.GenerateSetupStatements(info, context); case StubCodeContext.Stage.Marshal: - if (!info.IsManagedReturnPosition && info.RefKind != RefKind.Out) + if (elementMarshalDirection is MarshalDirection.ManagedToUnmanaged or MarshalDirection.Bidirectional) { return _nativeTypeMarshaller.GenerateMarshalStatements(info, context); } break; case StubCodeContext.Stage.Pin: - if (!info.IsByRef || info.RefKind == RefKind.In) + if (context.SingleFrameSpansNativeContext && elementMarshalDirection is MarshalDirection.ManagedToUnmanaged) { return _nativeTypeMarshaller.GeneratePinStatements(info, context); } break; case StubCodeContext.Stage.PinnedMarshal: - if (!info.IsManagedReturnPosition && info.RefKind != RefKind.Out) + if (elementMarshalDirection is MarshalDirection.ManagedToUnmanaged or MarshalDirection.Bidirectional) { return _nativeTypeMarshaller.GeneratePinnedMarshalStatements(info, context); } break; case StubCodeContext.Stage.NotifyForSuccessfulInvoke: - if (!info.IsManagedReturnPosition && info.RefKind != RefKind.Out) + if (elementMarshalDirection is MarshalDirection.ManagedToUnmanaged or MarshalDirection.Bidirectional) { return _nativeTypeMarshaller.GenerateNotifyForSuccessfulInvokeStatements(info, context); } break; case StubCodeContext.Stage.UnmarshalCapture: - if (info.IsManagedReturnPosition || (info.IsByRef && info.RefKind != RefKind.In)) + if (elementMarshalDirection is MarshalDirection.UnmanagedToManaged or MarshalDirection.Bidirectional) { return _nativeTypeMarshaller.GenerateUnmarshalCaptureStatements(info, context); } break; case StubCodeContext.Stage.Unmarshal: - if (info.IsManagedReturnPosition || (info.IsByRef && info.RefKind != RefKind.In) + if (elementMarshalDirection is MarshalDirection.UnmanagedToManaged or MarshalDirection.Bidirectional || (_enableByValueContentsMarshalling && !info.IsByRef && info.ByValueContentsMarshalKind.HasFlag(ByValueContentsMarshalKind.Out))) { return _nativeTypeMarshaller.GenerateUnmarshalStatements(info, context); } break; case StubCodeContext.Stage.GuaranteedUnmarshal: - if (info.IsManagedReturnPosition - || (info.IsByRef && info.RefKind != RefKind.In) + if (elementMarshalDirection is MarshalDirection.UnmanagedToManaged or MarshalDirection.Bidirectional || (_enableByValueContentsMarshalling && !info.IsByRef && info.ByValueContentsMarshalKind.HasFlag(ByValueContentsMarshalKind.Out))) { return _nativeTypeMarshaller.GenerateGuaranteedUnmarshalStatements(info, context); diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/DelegateMarshaller.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/DelegateMarshaller.cs index 00c7f6e2b941b0..3276fe37876616 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/DelegateMarshaller.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/DelegateMarshaller.cs @@ -31,15 +31,14 @@ public ValueBoundaryBehavior GetValueBoundaryBehavior(TypePositionInfo info, Stu public IEnumerable Generate(TypePositionInfo info, StubCodeContext context) { - // [TODO] Handle byrefs in a more common place? - // This pattern will become very common (arrays and strings will also use it) + MarshalDirection elementMarshalDirection = MarshallerHelpers.GetMarshalDirection(info, context); (string managedIdentifier, string nativeIdentifier) = context.GetIdentifiers(info); switch (context.CurrentStage) { case StubCodeContext.Stage.Setup: break; case StubCodeContext.Stage.Marshal: - if (info.RefKind != RefKind.Out) + if (elementMarshalDirection is MarshalDirection.ManagedToUnmanaged or MarshalDirection.Bidirectional) { // = != null ? Marshal.GetFunctionPointerForDelegate() : default; yield return ExpressionStatement( @@ -62,7 +61,7 @@ public IEnumerable Generate(TypePositionInfo info, StubCodeCont } break; case StubCodeContext.Stage.Unmarshal: - if (info.IsManagedReturnPosition || (info.IsByRef && info.RefKind != RefKind.In)) + if (elementMarshalDirection is MarshalDirection.UnmanagedToManaged or MarshalDirection.Bidirectional) { // = != default : Marshal.GetDelegateForFunctionPointer<>() : null; yield return ExpressionStatement( @@ -88,7 +87,7 @@ public IEnumerable Generate(TypePositionInfo info, StubCodeCont } break; case StubCodeContext.Stage.NotifyForSuccessfulInvoke: - if (info.RefKind != RefKind.Out) + if (elementMarshalDirection is MarshalDirection.ManagedToUnmanaged or MarshalDirection.Bidirectional) { yield return ExpressionStatement( InvocationExpression( diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/MarshallerHelpers.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/MarshallerHelpers.cs index e633cfcf45f3da..f6129c9a5abb5a 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/MarshallerHelpers.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/MarshallerHelpers.cs @@ -332,5 +332,63 @@ public static StatementSyntax SkipInitOrDefaultInit(TypePositionInfo info, StubC Token(SyntaxKind.DefaultKeyword)))); } } + + /// + /// Get the marshalling direction for a given in a given . + /// For example, an out parameter is marshalled in the direction in a stub, + /// but from in a stub. + /// + /// The info for an element. + /// The context for the stub. + /// The direction the element is marshalled. + public static MarshalDirection GetMarshalDirection(TypePositionInfo info, StubCodeContext context) + { + if (context.Direction is not (MarshalDirection.ManagedToUnmanaged or MarshalDirection.UnmanagedToManaged)) + { + throw new ArgumentException("Stub context direction must not be bidirectional."); + } + + if (context.Direction == MarshalDirection.ManagedToUnmanaged) + { + if (info.IsManagedReturnPosition) + { + return MarshalDirection.UnmanagedToManaged; + } + if (!info.IsByRef) + { + return MarshalDirection.ManagedToUnmanaged; + } + switch (info.RefKind) + { + case RefKind.In: + return MarshalDirection.ManagedToUnmanaged; + case RefKind.Ref: + return MarshalDirection.Bidirectional; + case RefKind.Out: + return MarshalDirection.UnmanagedToManaged; + } + throw new UnreachableException("An element is either a return value or passed by value or by ref."); + } + + + if (info.IsNativeReturnPosition) + { + return MarshalDirection.ManagedToUnmanaged; + } + if (!info.IsByRef) + { + return MarshalDirection.UnmanagedToManaged; + } + switch (info.RefKind) + { + case RefKind.In: + return MarshalDirection.UnmanagedToManaged; + case RefKind.Ref: + return MarshalDirection.Bidirectional; + case RefKind.Out: + return MarshalDirection.ManagedToUnmanaged; + } + throw new UnreachableException("An element is either a return value or passed by value or by ref."); + } } } diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/UnreachableException.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/UnreachableException.cs index 203657801ccc8a..48f0d77b1b4fcb 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/UnreachableException.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/UnreachableException.cs @@ -12,5 +12,8 @@ namespace Microsoft.Interop /// internal sealed class UnreachableException : Exception { + public UnreachableException() { } + + public UnreachableException(string message) : base(message) { } } }