Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
ec13dc8
Deduplicate base interfaces from another assembly in COM generator
jtschuster Jan 28, 2025
47651ef
PR feedback: Sort usings, generate new guid
jtschuster Jan 28, 2025
92efab7
Fix test
jtschuster Jan 29, 2025
a3204be
Use max(vtableIndex) for vtable allocation rather than methods length
jtschuster Jan 30, 2025
9877bf7
Merge branch 'main' of https://github.com/dotnet/runtime into TwoInhe…
jtschuster Mar 26, 2025
52c9a22
Account for 0 methods
jtschuster Mar 26, 2025
c99bf15
Merge branch 'main' of https://github.com/dotnet/runtime into TwoInhe…
jtschuster Jun 11, 2025
0b349a7
Use VTableSize property instead of ad-hoc calculations
jtschuster Jun 11, 2025
7f83082
Merge branch 'main' of https://github.com/dotnet/runtime into TwoInhe…
jtschuster Jul 11, 2025
5f29e70
Move ComInterfaces folder to Common/
jtschuster Jul 14, 2025
5a619e3
Working implementation
jtschuster Jul 14, 2025
1c09a78
Add tests for cross assembly inheritance
jtschuster Jul 15, 2025
63315f4
Extract identical code for CalculateStubInformation
jtschuster Jul 15, 2025
602d2dc
Remove comments, use SkippedStubContext instead of throwing
jtschuster Jul 15, 2025
dfcc019
Remove unnecessary changes
jtschuster Jul 16, 2025
1e6b8a8
Merge branch 'main' into TwoInheritingIfacesFromDll
jtschuster Jul 18, 2025
f8d7e97
Cast to nullable instead of using 'as'
jtschuster Jul 18, 2025
eac65bb
Update src/libraries/System.Runtime.InteropServices/gen/ComInterfaceG…
jtschuster Jul 18, 2025
e2290b3
Merge branch 'main' of https://github.com/dotnet/runtime into TwoInhe…
jtschuster Jul 28, 2025
9711cdc
Fix build issues after merge
jtschuster Jul 28, 2025
b1c0f74
Remove duplicate assertion, uncomment test code
jtschuster Jul 28, 2025
a24cf60
Merge branch 'main' into TwoInheritingIfacesFromDll
jtschuster Jul 29, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,23 @@ internal sealed record ComInterfaceAndMethodsContext(ComInterfaceContext Interfa
/// <summary>
/// COM methods that require shadowing declarations on the derived interface.
/// </summary>
public IEnumerable<ComMethodContext> ShadowingMethods => Methods.Where(m => m.IsInheritedMethod && !m.IsHiddenOnDerivedInterface);
public IEnumerable<ComMethodContext> ShadowingMethods => Methods.Where(m => m.IsInheritedMethod && !m.IsHiddenOnDerivedInterface && !m.IsExternallyDefined);

/// <summary>
/// COM methods that are declared on an interface the interface inherits from.
/// </summary>
public IEnumerable<ComMethodContext> InheritedMethods => Methods.Where(m => m.IsInheritedMethod);

/// <summary>
/// The size of the vtable for this interface, including the base interface methods and IUnknown methods.
/// </summary>
public int VTableSize => Methods.Length == 0
? IUnknownConstants.VTableSize
: 1 + Methods.Max(m => m.GenerationContext.VtableIndexData.Index);

/// <summary>
/// The size of the vtable for the base interface, including it's base interface methods and IUnknown methods.
/// </summary>
public int BaseVTableSize => VTableSize - DeclaredMethods.Count();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
var externalInterfaceSymbols = attributedInterfaces.SelectMany(static (data, ct) =>
{
return ComInterfaceInfo.CreateInterfaceInfoForBaseInterfacesInOtherCompilations(data.Symbol);
});
}).Collect().SelectMany(static (data, ct) => data.Distinct(ComInterfaceInfo.EqualityComparerForExternalIfaces.Instance));

var interfaceSymbolsWithoutDiagnostics = interfaceSymbolsToGenerateWithoutDiagnostics.Concat(externalInterfaceSymbols);

Expand Down Expand Up @@ -84,11 +84,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
.SelectMany(static (data, ct) =>
{
return ComMethodContext.CalculateAllMethods(data, ct);
})
// Now that we've determined method offsets, we can remove all externally defined methods.
// We'll also filter out methods originally declared on externally defined base interfaces
// as we may not be able to emit them into our assembly.
.Where(context => !context.Method.OriginalDeclaringInterface.IsExternallyDefined);
});

// Now that we've determined method offsets, we can remove all externally defined interfaces.
var interfaceContextsToGenerate = interfaceContexts.Where(context => !context.IsExternallyDefined);
Expand All @@ -107,13 +103,20 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
return new ComMethodContext(
data.Method,
data.OwningInterface,
CalculateStubInformation(data.Method.MethodInfo.Syntax, symbolMap[data.Method.MethodInfo], data.Method.Index, env, data.OwningInterface.Info, ct));
CalculateStubInformation(
data.Method.MethodInfo.Syntax,
symbolMap[data.Method.MethodInfo],
data.Method.Index,
env,
data.OwningInterface.Info,
ct));
}).WithTrackingName(StepNames.CalculateStubInformation);

var interfaceAndMethodsContexts = comMethodContexts
.Collect()
.Combine(interfaceContextsToGenerate.Collect())
.SelectMany((data, ct) => GroupComContextsForInterfaceGeneration(data.Left, data.Right, ct));
.SelectMany((data, ct) =>
GroupComContextsForInterfaceGeneration(data.Left, data.Right, ct));

// Generate the code for the managed-to-unmanaged stubs.
var managedToNativeInterfaceImplementations = interfaceAndMethodsContexts
Expand Down Expand Up @@ -256,12 +259,22 @@ private static bool IsHResultLikeType(ManagedTypeInfo type)
|| typeName.Equals("hresult", StringComparison.OrdinalIgnoreCase);
}

private static IncrementalMethodStubGenerationContext CalculateStubInformation(MethodDeclarationSyntax syntax, IMethodSymbol symbol, int index, StubEnvironment environment, ComInterfaceInfo owningInterfaceInfo, CancellationToken ct)
/// <summary>
/// Calculates the shared information needed for both source-available and sourceless stub generation.
/// </summary>
private static IncrementalMethodStubGenerationContext CalculateSharedStubInformation(
IMethodSymbol symbol,
int index,
StubEnvironment environment,
ISignatureDiagnosticLocations diagnosticLocations,
ComInterfaceInfo owningInterfaceInfo,
CancellationToken ct)
{
ct.ThrowIfCancellationRequested();
INamedTypeSymbol? lcidConversionAttrType = environment.LcidConversionAttrType;
INamedTypeSymbol? suppressGCTransitionAttrType = environment.SuppressGCTransitionAttrType;
INamedTypeSymbol? unmanagedCallConvAttrType = environment.UnmanagedCallConvAttrType;

// Get any attributes of interest on the method
AttributeData? lcidConversionAttr = null;
AttributeData? suppressGCTransitionAttribute = null;
Expand All @@ -282,8 +295,7 @@ private static IncrementalMethodStubGenerationContext CalculateStubInformation(M
}
}

var locations = new MethodSignatureDiagnosticLocations(syntax);
var generatorDiagnostics = new GeneratorDiagnosticsBag(new DiagnosticDescriptorProvider(), locations, SR.ResourceManager, typeof(FxResources.Microsoft.Interop.ComInterfaceGenerator.SR));
var generatorDiagnostics = new GeneratorDiagnosticsBag(new DiagnosticDescriptorProvider(), diagnosticLocations, SR.ResourceManager, typeof(FxResources.Microsoft.Interop.ComInterfaceGenerator.SR));

if (lcidConversionAttr is not null)
{
Expand All @@ -293,8 +305,8 @@ private static IncrementalMethodStubGenerationContext CalculateStubInformation(M

GeneratedComInterfaceCompilationData.TryGetGeneratedComInterfaceAttributeFromInterface(symbol.ContainingType, out var generatedComAttribute);
var generatedComInterfaceAttributeData = GeneratedComInterfaceCompilationData.GetDataFromAttribute(generatedComAttribute);
// Create the stub.

// Create the stub.
var signatureContext = SignatureContext.Create(
symbol,
DefaultMarshallingInfoParser.Create(
Expand Down Expand Up @@ -387,21 +399,14 @@ private static IncrementalMethodStubGenerationContext CalculateStubInformation(M
GeneratorDiagnostics.SizeOfInCollectionMustBeDefinedAtCallReturnValue);
}

var containingSyntaxContext = new ContainingSyntaxContext(syntax);

var methodSyntaxTemplate = new ContainingSyntax(new SyntaxTokenList(syntax.Modifiers.Where(static m => !m.IsKind(SyntaxKind.NewKeyword))).StripAccessibilityModifiers(), SyntaxKind.MethodDeclaration, syntax.Identifier, syntax.TypeParameterList);

ImmutableArray<FunctionPointerUnmanagedCallingConventionSyntax> callConv = VirtualMethodPointerStubGenerator.GenerateCallConvSyntaxFromAttributes(
suppressGCTransitionAttribute,
unmanagedCallConvAttribute,
ImmutableArray.Create(FunctionPointerUnmanagedCallingConvention(Identifier("MemberFunction"))));

var declaringType = ManagedTypeInfo.CreateTypeInfoForTypeSymbol(symbol.ContainingType);

var virtualMethodIndexData = new VirtualMethodIndexData(index, ImplicitThisParameter: true, direction, true, ExceptionMarshalling.Com);

MarshallingInfo exceptionMarshallingInfo;

if (generatedComInterfaceAttributeData.ExceptionToUnmanagedMarshaller is null)
{
exceptionMarshallingInfo = new ComExceptionMarshalling();
Expand All @@ -418,11 +423,9 @@ private static IncrementalMethodStubGenerationContext CalculateStubInformation(M

return new IncrementalMethodStubGenerationContext(
signatureContext,
containingSyntaxContext,
methodSyntaxTemplate,
locations,
diagnosticLocations,
callConv.ToSequenceEqualImmutableArray(SyntaxEquivalentComparer.Instance),
virtualMethodIndexData,
new VirtualMethodIndexData(index, ImplicitThisParameter: true, direction, true, ExceptionMarshalling.Com),
exceptionMarshallingInfo,
environment.EnvironmentFlags,
owningInterfaceInfo.Type,
Expand All @@ -431,6 +434,45 @@ private static IncrementalMethodStubGenerationContext CalculateStubInformation(M
ComInterfaceDispatchMarshallingInfo.Instance);
}

private static IncrementalMethodStubGenerationContext CalculateStubInformation(MethodDeclarationSyntax? syntax, IMethodSymbol symbol, int index, StubEnvironment environment, ComInterfaceInfo owningInterface, CancellationToken ct)
{
ISignatureDiagnosticLocations locations = syntax is null
? NoneSignatureDiagnosticLocations.Instance
: new MethodSignatureDiagnosticLocations(syntax);

var sourcelessStubInformation = CalculateSharedStubInformation(
symbol,
index,
environment,
locations,
owningInterface,
ct);

if (syntax is null)
return sourcelessStubInformation;

var containingSyntaxContext = new ContainingSyntaxContext(syntax);
var methodSyntaxTemplate = new ContainingSyntax(
new SyntaxTokenList(syntax.Modifiers.Where(static m => !m.IsKind(SyntaxKind.NewKeyword))).StripAccessibilityModifiers(),
SyntaxKind.MethodDeclaration,
syntax.Identifier,
syntax.TypeParameterList);

return new SourceAvailableIncrementalMethodStubGenerationContext(
sourcelessStubInformation.SignatureContext,
containingSyntaxContext,
methodSyntaxTemplate,
locations,
sourcelessStubInformation.CallingConvention,
sourcelessStubInformation.VtableIndexData,
sourcelessStubInformation.ExceptionMarshallingInfo,
sourcelessStubInformation.EnvironmentFlags,
sourcelessStubInformation.TypeKeyOwner,
sourcelessStubInformation.DeclaringType,
sourcelessStubInformation.Diagnostics,
ComInterfaceDispatchMarshallingInfo.Instance);
}

private static MarshalDirection GetDirectionFromOptions(ComInterfaceOptions options)
{
if (options.HasFlag(ComInterfaceOptions.ManagedObjectWrapper | ComInterfaceOptions.ComObjectWrapper))
Expand Down Expand Up @@ -520,12 +562,12 @@ static bool MethodEquals(ComMethodContext a, ComMethodContext b)
private static InterfaceDeclarationSyntax GenerateImplementationInterface(ComInterfaceAndMethodsContext interfaceGroup, CancellationToken _)
{
var definingType = interfaceGroup.Interface.Info.Type;
var shadowImplementations = interfaceGroup.InheritedMethods.Select(m => (Method: m, ManagedToUnmanagedStub: m.ManagedToUnmanagedStub))
var shadowImplementations = interfaceGroup.InheritedMethods.Where(m => !m.IsExternallyDefined).Select(m => (Method: m, ManagedToUnmanagedStub: m.ManagedToUnmanagedStub))
.Where(p => p.ManagedToUnmanagedStub is GeneratedStubCodeContext)
.Select(ctx => ((GeneratedStubCodeContext)ctx.ManagedToUnmanagedStub).Stub.Node
.WithExplicitInterfaceSpecifier(
ExplicitInterfaceSpecifier(ParseName(definingType.FullTypeName))));
var inheritedStubs = interfaceGroup.InheritedMethods.Select(m => m.UnreachableExceptionStub);
var inheritedStubs = interfaceGroup.InheritedMethods.Where(m => !m.IsExternallyDefined).Select(m => m.UnreachableExceptionStub);
return ImplementationInterfaceTemplate
.AddBaseListTypes(SimpleBaseType(definingType.Syntax))
.WithMembers(
Expand Down Expand Up @@ -661,7 +703,6 @@ private static InterfaceDeclarationSyntax GenerateImplementationVTable(ComInterf

BlockSyntax fillBaseInterfaceSlots;


if (interfaceMethods.Interface.Base is null)
{
// If we don't have a base interface, we need to manually fill in the base iUnknown slots.
Expand Down Expand Up @@ -740,7 +781,7 @@ private static InterfaceDeclarationSyntax GenerateImplementationVTable(ComInterf
}
else
{
// NativeMemory.Copy(StrategyBasedComWrappers.DefaultIUnknownInteraceDetailsStrategy.GetIUnknownDerivedDetails(typeof(<baseInterfaceType>).TypeHandle).ManagedVirtualMethodTable, vtable, (nuint)(sizeof(void*) * <startingOffset>));
// NativeMemory.Copy(StrategyBasedComWrappers.DefaultIUnknownInterfaceDetailsStrategy.GetIUnknownDerivedDetails(typeof(<baseInterfaceType>).TypeHandle).ManagedVirtualMethodTable, vtable, (nuint)(sizeof(void*) * <baseVTableSize>));
fillBaseInterfaceSlots = Block(
MethodInvocationStatement(
TypeSyntaxes.System_Runtime_InteropServices_NativeMemory,
Expand All @@ -750,7 +791,7 @@ private static InterfaceDeclarationSyntax GenerateImplementationVTable(ComInterf
TypeSyntaxes.StrategyBasedComWrappers
.Dot(IdentifierName("DefaultIUnknownInterfaceDetailsStrategy")),
IdentifierName("GetIUnknownDerivedDetails"),
Argument( //baseInterfaceTypeInfo.BaseInterface.FullTypeName)),
Argument(
TypeOfExpression(ParseTypeName(interfaceMethods.Interface.Base.Info.Type.FullTypeName))
.Dot(IdentifierName("TypeHandle"))))
.Dot(IdentifierName("ManagedVirtualMethodTable"))),
Expand All @@ -767,7 +808,7 @@ private static InterfaceDeclarationSyntax GenerateImplementationVTable(ComInterf
ParenthesizedExpression(
BinaryExpression(SyntaxKind.MultiplyExpression,
SizeOfExpression(PointerType(PredefinedType(Token(SyntaxKind.VoidKeyword)))),
LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(interfaceMethods.InheritedMethods.Count() + 3))))))));
LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(interfaceMethods.BaseVTableSize))))))));
}

var validDeclaredMethods = interfaceMethods.DeclaredMethods
Expand All @@ -787,7 +828,7 @@ private static InterfaceDeclarationSyntax GenerateImplementationVTable(ComInterf
IdentifierName($"{declaredMethodContext.MethodInfo.MethodName}_{declaredMethodContext.GenerationContext.VtableIndexData.Index}")),
PrefixUnaryExpression(
SyntaxKind.AddressOfExpression,
IdentifierName($"ABI_{declaredMethodContext.GenerationContext.StubMethodSyntaxTemplate.Identifier}")))));
IdentifierName($"ABI_{((SourceAvailableIncrementalMethodStubGenerationContext)declaredMethodContext.GenerationContext).StubMethodSyntaxTemplate.Identifier}")))));
}

return ImplementationInterfaceTemplate
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,16 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Threading;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using InterfaceInfo = (Microsoft.Interop.ComInterfaceInfo InterfaceInfo, Microsoft.CodeAnalysis.INamedTypeSymbol Symbol);
using DiagnosticOrInterfaceInfo = Microsoft.Interop.DiagnosticOr<(Microsoft.Interop.ComInterfaceInfo InterfaceInfo, Microsoft.CodeAnalysis.INamedTypeSymbol Symbol)>;
using System.Diagnostics;

namespace Microsoft.Interop
{
Expand Down Expand Up @@ -176,6 +177,13 @@ public static ImmutableArray<InterfaceInfo> CreateInterfaceInfoForBaseInterfaces
return builder.ToImmutable();
}

internal sealed class EqualityComparerForExternalIfaces : IEqualityComparer<(ComInterfaceInfo InterfaceInfo, INamedTypeSymbol Symbol)>
{
public bool Equals((ComInterfaceInfo, INamedTypeSymbol) x, (ComInterfaceInfo, INamedTypeSymbol) y) => SymbolEqualityComparer.Default.Equals(x.Item2, y.Item2);
public int GetHashCode((ComInterfaceInfo, INamedTypeSymbol) obj) => SymbolEqualityComparer.Default.GetHashCode(obj.Item2);
public static readonly EqualityComparerForExternalIfaces Instance = new();
}

private static bool IsInPartialContext(INamedTypeSymbol symbol, InterfaceDeclarationSyntax syntax, [NotNullWhen(false)] out DiagnosticInfo? diagnostic)
{
// Verify that the types the interface is declared in are marked partial.
Expand Down
Loading
Loading