Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@

using System.Collections.Generic;
using System.Collections.Immutable;
using System.Diagnostics;
using System.Runtime.InteropServices;

using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.Diagnostics;
Expand All @@ -18,6 +16,13 @@ public class ConvertToGeneratedDllImportAnalyzer : DiagnosticAnalyzer
{
private const string Category = "Interoperability";

private static readonly string[] s_unsupportedTypeNames = new string[]
{
"System.Runtime.InteropServices.CriticalHandle",
"System.Runtime.InteropServices.HandleRef",
"System.Text.StringBuilder"
};

public static readonly DiagnosticDescriptor ConvertToGeneratedDllImport =
new DiagnosticDescriptor(
Ids.ConvertToGeneratedDllImport,
Expand All @@ -43,15 +48,21 @@ public override void Initialize(AnalysisContext context)
if (generatedDllImportAttrType == null)
return;

INamedTypeSymbol? dllImportAttrType = compilationContext.Compilation.GetTypeByMetadataName(typeof(DllImportAttribute).FullName);
if (dllImportAttrType == null)
return;
List<ITypeSymbol> knownUnsupportedTypes = new List<ITypeSymbol>();
foreach (string typeName in s_unsupportedTypeNames)
{
INamedTypeSymbol? unsupportedType = compilationContext.Compilation.GetTypeByMetadataName(typeName);
if (unsupportedType != null)
{
knownUnsupportedTypes.Add(unsupportedType);
}
}

compilationContext.RegisterSymbolAction(symbolContext => AnalyzeSymbol(symbolContext, dllImportAttrType), SymbolKind.Method);
compilationContext.RegisterSymbolAction(symbolContext => AnalyzeSymbol(symbolContext, knownUnsupportedTypes), SymbolKind.Method);
});
}

private static void AnalyzeSymbol(SymbolAnalysisContext context, INamedTypeSymbol dllImportAttrType)
private static void AnalyzeSymbol(SymbolAnalysisContext context, List<ITypeSymbol> knownUnsupportedTypes)
{
var method = (IMethodSymbol)context.Symbol;

Expand All @@ -64,6 +75,19 @@ private static void AnalyzeSymbol(SymbolAnalysisContext context, INamedTypeSymbo
if (dllImportData.ModuleName == "QCall")
return;

// Ignore methods with unsupported parameters
foreach (IParameterSymbol parameter in method.Parameters)
{
if (knownUnsupportedTypes.Contains(parameter.Type))
{
return;
}
}

// Ignore methods with unsupported returns
if (method.ReturnsByRef || method.ReturnsByRefReadonly || knownUnsupportedTypes.Contains(method.ReturnType))
return;

context.ReportDiagnostic(method.CreateDiagnostic(ConvertToGeneratedDllImport, method.Name));
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;
using System.Runtime.InteropServices;
using System.Threading;
using System.Threading.Tasks;
Expand All @@ -28,6 +29,18 @@ public sealed class ConvertToGeneratedDllImportFixer : CodeFixProvider
public const string NoPreprocessorDefinesKey = "ConvertToGeneratedDllImport";
public const string WithPreprocessorDefinesKey = "ConvertToGeneratedDllImportPreprocessor";

private static readonly string[] s_preferredAttributeArgumentOrder =
{
nameof(DllImportAttribute.EntryPoint),
nameof(DllImportAttribute.BestFitMapping),
nameof(DllImportAttribute.CallingConvention),
nameof(DllImportAttribute.CharSet),
nameof(DllImportAttribute.ExactSpelling),
nameof(DllImportAttribute.PreserveSig),
nameof(DllImportAttribute.SetLastError),
nameof(DllImportAttribute.ThrowOnUnmappableChar)
};

public override async Task RegisterCodeFixesAsync(CodeFixContext context)
{
// Get the syntax root and semantic model
Expand Down Expand Up @@ -152,8 +165,11 @@ private async Task<Document> ConvertToGeneratedDllImport(
SyntaxFactory.ElasticMarker
}));

// Sort attribute arguments so that GeneratedDllImport and DllImport match
MethodDeclarationSyntax updatedDeclaration = (MethodDeclarationSyntax)generator.ReplaceNode(methodSyntax, dllImportSyntax, SortDllImportAttributeArguments(dllImportSyntax, generator));

// Remove existing leading trivia - it will be on the GeneratedDllImport method
MethodDeclarationSyntax updatedDeclaration = methodSyntax.WithLeadingTrivia();
updatedDeclaration = updatedDeclaration.WithLeadingTrivia();

// #endif
updatedDeclaration = updatedDeclaration.WithTrailingTrivia(
Expand Down Expand Up @@ -225,7 +241,26 @@ private SyntaxNode GetGeneratedDllImportAttribute(
}
}

return generator.RemoveNodes(generatedDllImportSyntax, argumentsToRemove);
generatedDllImportSyntax = generator.RemoveNodes(generatedDllImportSyntax, argumentsToRemove);
return SortDllImportAttributeArguments((AttributeSyntax)generatedDllImportSyntax, generator);
}

private static SyntaxNode SortDllImportAttributeArguments(AttributeSyntax attribute, SyntaxGenerator generator)
{
AttributeArgumentListSyntax updatedArgList = attribute.ArgumentList.WithArguments(
SyntaxFactory.SeparatedList(
attribute.ArgumentList.Arguments.OrderBy(arg =>
{
// Unnamed arguments first
if (arg.NameEquals == null)
return -1;

// Named arguments in specified order, followed by any named arguments with no preferred order
string name = arg.NameEquals.Name.Identifier.Text;
int index = System.Array.IndexOf(s_preferredAttributeArgumentOrder, name);
return index == -1 ? int.MaxValue : index;
})));
return generator.ReplaceNode(attribute, attribute.ArgumentList, updatedArgList);
}

private bool TryCreateUnmanagedCallConvAttributeToEmit(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,13 @@ public static IEnumerable<object[]> NoMarshallingRequiredTypes() => new[]
new object[] { typeof(ConsoleKey) }, // enum
};

public static IEnumerable<object[]> UnsupportedTypes() => new[]
{
new object[] { typeof(System.Runtime.InteropServices.CriticalHandle) },
new object[] { typeof(System.Runtime.InteropServices.HandleRef) },
new object[] { typeof(System.Text.StringBuilder) },
};

[Theory]
[MemberData(nameof(MarshallingRequiredTypes))]
[MemberData(nameof(NoMarshallingRequiredTypes))]
Expand Down Expand Up @@ -134,6 +141,14 @@ await VerifyCS.VerifyAnalyzerAsync(
.WithArguments("Method2"));
}

[Theory]
[MemberData(nameof(UnsupportedTypes))]
public async Task UnsupportedType_NoDiagnostic(Type type)
{
string source = DllImportWithType(type.FullName!);
await VerifyCS.VerifyAnalyzerAsync(source);
}

[Fact]
public async Task NotDllImport_NoDiagnostic()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ partial class Test
[GeneratedDllImport(""DoesNotExist"", EntryPoint = ""Entry"")]
public static partial int {{|CS8795:Method1|}}(out int ret);
#else
[DllImport(""DoesNotExist"", BestFitMapping = false, EntryPoint = ""Entry"")]
[DllImport(""DoesNotExist"", EntryPoint = ""Entry"", BestFitMapping = false)]
public static extern int Method1(out int ret);
#endif

Expand Down Expand Up @@ -306,7 +306,7 @@ partial class Test
[GeneratedDllImport(""DoesNotExist"", EntryPoint = ""Entry"")]
public static partial int {{|CS8795:Method1|}}(out int ret);
#else
[DllImport(""DoesNotExist"", CallingConvention = CallingConvention.Winapi, EntryPoint = ""Entry"")]
[DllImport(""DoesNotExist"", EntryPoint = ""Entry"", CallingConvention = CallingConvention.Winapi)]
public static extern int Method1(out int ret);
#endif
}}" : @$"
Expand Down Expand Up @@ -351,7 +351,7 @@ partial class Test
[UnmanagedCallConv(CallConvs = new System.Type[] {{ typeof({callConvType.FullName}) }})]
public static partial int {{|CS8795:Method1|}}(out int ret);
#else
[DllImport(""DoesNotExist"", CallingConvention = CallingConvention.{callConv}, EntryPoint = ""Entry"")]
[DllImport(""DoesNotExist"", EntryPoint = ""Entry"", CallingConvention = CallingConvention.{callConv})]
public static extern int Method1(out int ret);
#endif
}}" : @$"
Expand All @@ -361,6 +361,44 @@ partial class Test
[GeneratedDllImport(""DoesNotExist"", EntryPoint = ""Entry"")]
[UnmanagedCallConv(CallConvs = new System.Type[] {{ typeof({callConvType.FullName}) }})]
public static partial int {{|CS8795:Method1|}}(out int ret);
}}";
await VerifyCS.VerifyCodeFixAsync(
source,
fixedSource,
usePreprocessorDefines ? WithPreprocessorDefinesKey : NoPreprocessorDefinesKey);
}

[Theory]
[InlineData(true)]
[InlineData(false)]
public async Task PreferredAttributeOrder(bool usePreprocessorDefines)
{
string source = @$"
using System.Runtime.InteropServices;
partial class Test
{{
[DllImport(""DoesNotExist"", SetLastError = true, EntryPoint = ""Entry"", ExactSpelling = true, CharSet = CharSet.Unicode)]
public static extern int [|Method|](out int ret);
}}";
// Fixed source will have CS8795 (Partial method must have an implementation) without generator run
string fixedSource = usePreprocessorDefines
? @$"
using System.Runtime.InteropServices;
partial class Test
{{
#if DLLIMPORTGENERATOR_ENABLED
[GeneratedDllImport(""DoesNotExist"", EntryPoint = ""Entry"", CharSet = CharSet.Unicode, ExactSpelling = true, SetLastError = true)]
public static partial int {{|CS8795:Method|}}(out int ret);
#else
[DllImport(""DoesNotExist"", EntryPoint = ""Entry"", CharSet = CharSet.Unicode, ExactSpelling = true, SetLastError = true)]
public static extern int Method(out int ret);
#endif
}}" : @$"
using System.Runtime.InteropServices;
partial class Test
{{
[GeneratedDllImport(""DoesNotExist"", EntryPoint = ""Entry"", CharSet = CharSet.Unicode, ExactSpelling = true, SetLastError = true)]
public static partial int {{|CS8795:Method|}}(out int ret);
}}";
await VerifyCS.VerifyCodeFixAsync(
source,
Expand Down