Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 65 additions & 4 deletions src/CodeAnalysis.Tests/ConventionAnalyzerTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,13 @@
using System.Collections.Immutable;
using System.IO;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Devlooped.Extensions.DependencyInjection;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Testing;
using Microsoft.CodeAnalysis.Diagnostics;
using Microsoft.CodeAnalysis.Testing;
using Xunit;
using Xunit.Abstractions;
Expand Down Expand Up @@ -54,7 +57,7 @@ public static void Main()
.AddPackages(ImmutableArray.Create(
new PackageIdentity("Microsoft.Extensions.DependencyInjection", "8.0.0")))
},
}.WithPreprocessorSymbols();
};

var expected = Verifier.Diagnostic(ConventionsAnalyzer.AssignableTypeOfRequired).WithLocation(0);
test.ExpectedDiagnostics.Add(expected);
Expand Down Expand Up @@ -98,7 +101,7 @@ public static void Main()
.AddPackages(ImmutableArray.Create(
new PackageIdentity("Microsoft.Extensions.DependencyInjection", "8.0.0")))
},
}.WithPreprocessorSymbols();
};

//var expected = Verifier.Diagnostic(ConventionsAnalyzer.AssignableTypeOfRequired).WithLocation(0);
//test.ExpectedDiagnostics.Add(expected);
Expand Down Expand Up @@ -145,12 +148,70 @@ public static void Main()
.AddPackages(ImmutableArray.Create(
new PackageIdentity("Microsoft.Extensions.DependencyInjection", "8.0.0")))
},
}.WithPreprocessorSymbols();
};

var expected = Verifier.Diagnostic(ConventionsAnalyzer.OpenGenericType).WithLocation(0);
test.ExpectedDiagnostics.Add(expected);

await test.RunAsync();
}

}
[Fact]
public async Task WarnIfAmbiguousLifetime()
{
var test = new CSharpSourceGeneratorTest<IncrementalGenerator, DefaultVerifier>
{
TestBehaviors = TestBehaviors.SkipGeneratedSourcesCheck,
TestCode =
"""
using System;
using Microsoft.Extensions.DependencyInjection;

public interface IRepository { }
public class MyRepository : IRepository { }

public static class Program
{
public static void Main()
{
var services = new ServiceCollection();
{|#0:services.AddServices(typeof(IRepository), ServiceLifetime.Scoped)|};
{|#1:services.AddServices("Repository", ServiceLifetime.Singleton)|};
}
}
""",
TestState =
{
AnalyzerConfigFiles =
{
("/.editorconfig",
"""
is_global = true
build_property.AddServicesExtension = true
""")
},
Sources =
{
StaticGenerator.AddServicesExtension,
StaticGenerator.ServiceAttribute,
StaticGenerator.ServiceAttributeT,
},
ReferenceAssemblies = new ReferenceAssemblies(
"net8.0",
new PackageIdentity(
"Microsoft.NETCore.App.Ref", "8.0.0"),
Path.Combine("ref", "net8.0"))
.AddPackages(ImmutableArray.Create(
new PackageIdentity("Microsoft.Extensions.DependencyInjection", "8.0.0")))
},
};

var expected = Verifier.Diagnostic(IncrementalGenerator.AmbiguousLifetime)
.WithArguments("MyRepository", "Scoped, Singleton")
.WithLocation(0).WithLocation(1);

test.ExpectedDiagnostics.Add(expected);

await test.RunAsync();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
<PackageReference Include="System.Composition.AttributedModel" Version="8.0.0" />
<PackageReference Include="System.Composition.Hosting" Version="8.0.0" />
<PackageReference Include="System.Composition.TypedParts" Version="8.0.0" />
<PackageReference Include="Microsoft.Bcl.HashCode" Version="6.0.0" GeneratePathProperty="true" />
<PackageReference Include="Microsoft.Bcl.TimeProvider" Version="8.0.1" GeneratePathProperty="true" />
</ItemGroup>

<ItemGroup>
Expand All @@ -32,6 +34,11 @@
<Import Project="..\DependencyInjection\Devlooped.Extensions.DependencyInjection.targets" />
<Import Project="..\SponsorLink\SponsorLink.Analyzer.Tests.targets" />

<ItemGroup>
<Analyzer Include="$(PkgMicrosoft_Bcl_HashCode)\lib\netstandard2.0\Microsoft.Bcl.HashCode.dll" />
<Analyzer Include="$(PkgMicrosoft_Bcl_TimeProvider)\lib\netstandard2.0\Microsoft.Bcl.TimeProvider.dll" />
</ItemGroup>

<!-- Force immediate reporting of status, no install-time grace period -->
<PropertyGroup>
<SponsorLinkNoInstallGrace>true</SponsorLinkNoInstallGrace>
Expand Down
46 changes: 40 additions & 6 deletions src/DependencyInjection/IncrementalGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Diagnostics;
using Microsoft.Extensions.DependencyInjection;
using KeyedService = (Microsoft.CodeAnalysis.INamedTypeSymbol Type, Microsoft.CodeAnalysis.TypedConstant? Key);

Expand All @@ -22,11 +21,21 @@ namespace Devlooped.Extensions.DependencyInjection;
[Generator(LanguageNames.CSharp)]
public class IncrementalGenerator : IIncrementalGenerator
{
class ServiceSymbol(INamedTypeSymbol type, int lifetime, TypedConstant? key)
public static DiagnosticDescriptor AmbiguousLifetime { get; } =
new DiagnosticDescriptor(
"DDI004",
"Ambiguous lifetime registration.",
"More than one registration matches {0} with lifetimes {1}.",
"Build",
DiagnosticSeverity.Warning,
isEnabledByDefault: true);

class ServiceSymbol(INamedTypeSymbol type, int lifetime, TypedConstant? key, Location? location)
{
public INamedTypeSymbol Type => type;
public int Lifetime => lifetime;
public TypedConstant? Key => key;
public Location? Location => location;

public override bool Equals(object? obj)
{
Expand All @@ -42,7 +51,7 @@ public override int GetHashCode()
=> HashCode.Combine(SymbolEqualityComparer.Default.GetHashCode(type), lifetime, key);
}

record ServiceRegistration(int Lifetime, TypeSyntax? AssignableTo, string? FullNameExpression)
record ServiceRegistration(int Lifetime, TypeSyntax? AssignableTo, string? FullNameExpression, Location? Location)
{
Regex? regex;

Expand Down Expand Up @@ -175,7 +184,7 @@ bool IsExport(AttributeData attr)
}
}

services.Add(new(x, lifetime, key));
services.Add(new(x, lifetime, key, attr.ApplicationSyntaxReference?.GetSyntax().GetLocation()));
}

return services.ToImmutableArray();
Expand Down Expand Up @@ -220,7 +229,7 @@ bool IsExport(AttributeData attr)
if (registration!.FullNameExpression != null && !registration.Regex.IsMatch(typeSymbol.ToFullName(compilation)))
continue;

results.Add(new ServiceSymbol(typeSymbol, registration.Lifetime, null));
results.Add(new ServiceSymbol(typeSymbol, registration.Lifetime, null, registration.Location));
}

return results.ToImmutable();
Expand Down Expand Up @@ -259,6 +268,31 @@ void RegisterServicesOutput(IncrementalGeneratorInitializationContext context, I
context.RegisterImplementationSourceOutput(
services.Where(x => x!.Lifetime == 2 && x.Key is not null).Select((x, _) => new KeyedService(x!.Type, x.Key!)).Collect().Combine(compilation),
(ctx, data) => AddPartial("AddKeyedTransient", ctx, data));

context.RegisterImplementationSourceOutput(services.Collect(), ReportInconsistencies);
}

void ReportInconsistencies(SourceProductionContext context, ImmutableArray<ServiceSymbol> array)
{
var grouped = array.GroupBy(x => x.Type, SymbolEqualityComparer.Default).Where(g => g.Count() > 1).ToImmutableArray();
if (grouped.Length == 0)
return;

foreach (var group in grouped)
{
// report if within the group, there are different lifetimes with the same key (or no key)
foreach (var keyed in group.GroupBy(x => x.Key?.Value).Where(g => g.Count() > 1))
{
var lifetimes = string.Join(", ", keyed.Select(x => x.Lifetime).Distinct()
.Select(x => x switch { 0 => "Singleton", 1 => "Scoped", 2 => "Transient", _ => "Unknown" }));

var location = keyed.Where(x => x.Location != null).FirstOrDefault()?.Location;
var otherLocations = keyed.Where(x => x.Location != null).Skip(1).Select(x => x.Location!);

context.ReportDiagnostic(Diagnostic.Create(AmbiguousLifetime,
location, otherLocations, keyed.First().Type.ToDisplayString(), lifetimes));
}
}
}

static string? GetInvokedMethodName(InvocationExpressionSyntax invocation) => invocation.Expression switch
Expand Down Expand Up @@ -330,7 +364,7 @@ void RegisterServicesOutput(IncrementalGeneratorInitializationContext context, I

if (assignableTo != null || fullNameExpression != null)
{
return new ServiceRegistration(lifetime, assignableTo, fullNameExpression);
return new ServiceRegistration(lifetime, assignableTo, fullNameExpression, invocation.GetLocation());
}
}
return null;
Expand Down
Loading