diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection.Specification.Tests/src/KeyedDependencyInjectionSpecificationTests.cs b/src/libraries/Microsoft.Extensions.DependencyInjection.Specification.Tests/src/KeyedDependencyInjectionSpecificationTests.cs index 3664a3b81eafd4..336c3e14b479cf 100644 --- a/src/libraries/Microsoft.Extensions.DependencyInjection.Specification.Tests/src/KeyedDependencyInjectionSpecificationTests.cs +++ b/src/libraries/Microsoft.Extensions.DependencyInjection.Specification.Tests/src/KeyedDependencyInjectionSpecificationTests.cs @@ -6,6 +6,7 @@ using System.Security.Cryptography; using Microsoft.Extensions.DependencyInjection.Specification.Fakes; using Xunit; +using static Microsoft.Extensions.DependencyInjection.Specification.KeyedDependencyInjectionSpecificationTests; namespace Microsoft.Extensions.DependencyInjection.Specification { @@ -152,10 +153,48 @@ public void ResolveKeyedServicesAnyKeyWithAnyKeyRegistration() _ = provider.GetKeyedService("something-else"); _ = provider.GetKeyedService("something-else-again"); - // Return all services registered with a non null key, but not the one "created" with KeyedService.AnyKey + // Return all services registered with a non null key, but not the one "created" with KeyedService.AnyKey, + // nor the KeyedService.AnyKey registration var allServices = provider.GetKeyedServices(KeyedService.AnyKey).ToList(); - Assert.Equal(5, allServices.Count); - Assert.Equal(new[] { service1, service2, service3, service4 }, allServices.Skip(1)); + Assert.Equal(4, allServices.Count); + Assert.Equal(new[] { service1, service2, service3, service4 }, allServices); + } + + [Fact] + public void ResolveKeyedServicesAnyKeyConsistency() + { + var serviceCollection = new ServiceCollection(); + var service = new Service("first-service"); + serviceCollection.AddKeyedSingleton("first-service", service); + + var provider1 = CreateServiceProvider(serviceCollection); + Assert.Null(provider1.GetKeyedService(KeyedService.AnyKey)); + // We don't return KeyedService.AnyKey registration when listing services + Assert.Equal(new[] { service }, provider1.GetKeyedServices(KeyedService.AnyKey)); + + var provider2 = CreateServiceProvider(serviceCollection); + Assert.Equal(new[] { service }, provider2.GetKeyedServices(KeyedService.AnyKey)); + // But we should be able to directly do a lookup on it + Assert.Null(provider2.GetKeyedService(KeyedService.AnyKey)); + } + + [Fact] + public void ResolveKeyedServicesAnyKeyConsistencyWithAnyKeyRegistration() + { + var serviceCollection = new ServiceCollection(); + var service = new Service("first-service"); + var any = new Service("any"); + serviceCollection.AddKeyedSingleton("first-service", service); + serviceCollection.AddKeyedSingleton(KeyedService.AnyKey, (sp, key) => any); + + var provider1 = CreateServiceProvider(serviceCollection); + Assert.Same(any, provider1.GetKeyedService(KeyedService.AnyKey)); + Assert.Equal(new[] { service }, provider1.GetKeyedServices(KeyedService.AnyKey)); + + // Check twice in different order to check caching + var provider2 = CreateServiceProvider(serviceCollection); + Assert.Equal(new[] { service }, provider2.GetKeyedServices(KeyedService.AnyKey)); + Assert.Same(any, provider2.GetKeyedService(KeyedService.AnyKey)); } [Fact] @@ -243,7 +282,7 @@ public void ResolveKeyedServicesSingletonInstanceWithAnyKey() var provider = CreateServiceProvider(serviceCollection); var services = provider.GetKeyedServices>("some-key").ToList(); - Assert.Equal(new[] { service1, service2 }, services); + Assert.Equal(new[] { service2 }, services); } [Fact] diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection/Microsoft.Extensions.DependencyInjection.sln b/src/libraries/Microsoft.Extensions.DependencyInjection/Microsoft.Extensions.DependencyInjection.sln index 9345a6ded34b0d..008b98db7ff0da 100644 --- a/src/libraries/Microsoft.Extensions.DependencyInjection/Microsoft.Extensions.DependencyInjection.sln +++ b/src/libraries/Microsoft.Extensions.DependencyInjection/Microsoft.Extensions.DependencyInjection.sln @@ -159,4 +159,4 @@ Global GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {68A7BDA7-8093-433C-BF7A-8A6A7560BD02} EndGlobalSection -EndGlobal +EndGlobal \ No newline at end of file diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/CallSiteFactory.cs b/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/CallSiteFactory.cs index f9b902dde19f3e..32245626bf59b0 100644 --- a/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/CallSiteFactory.cs +++ b/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/CallSiteFactory.cs @@ -282,11 +282,13 @@ private static bool AreCompatible(DynamicallyAccessedMemberTypes serviceDynamica CallSiteResultCacheLocation cacheLocation = CallSiteResultCacheLocation.Root; ServiceCallSite[] callSites; + var isAnyKeyLookup = serviceIdentifier.ServiceKey == KeyedService.AnyKey; + // If item type is not generic we can safely use descriptor cache // Special case for KeyedService.AnyKey, we don't want to check the cache because a KeyedService.AnyKey registration // will "hide" all the other service registration if (!itemType.IsConstructedGenericType && - !KeyedService.AnyKey.Equals(cacheKey.ServiceKey) && + !isAnyKeyLookup && _descriptorLookup.TryGetValue(cacheKey, out ServiceDescriptorCacheItem descriptors)) { callSites = new ServiceCallSite[descriptors.Count]; @@ -317,9 +319,12 @@ private static bool AreCompatible(DynamicallyAccessedMemberTypes serviceDynamica int slot = 0; for (int i = _descriptors.Length - 1; i >= 0; i--) { - if (KeysMatch(_descriptors[i].ServiceKey, cacheKey.ServiceKey)) + if (KeysMatch(cacheKey.ServiceKey, _descriptors[i].ServiceKey)) { - if (TryCreateExact(_descriptors[i], cacheKey, callSiteChain, slot) is { } callSite) + // Special case for AnyKey: we don't want to add in cache a mapping AnyKey -> specific type, + // so we need to ask creation with the original identity of the descriptor + var registrationKey = isAnyKeyLookup ? ServiceIdentifier.FromDescriptor(_descriptors[i]) : cacheKey; + if (TryCreateExact(_descriptors[i], registrationKey, callSiteChain, slot) is { } callSite) { AddCallSite(callSite, i); } @@ -327,9 +332,12 @@ private static bool AreCompatible(DynamicallyAccessedMemberTypes serviceDynamica } for (int i = _descriptors.Length - 1; i >= 0; i--) { - if (KeysMatch(_descriptors[i].ServiceKey, cacheKey.ServiceKey)) + if (KeysMatch(cacheKey.ServiceKey, _descriptors[i].ServiceKey)) { - if (TryCreateOpenGeneric(_descriptors[i], cacheKey, callSiteChain, slot, throwOnConstraintViolation: false) is { } callSite) + // Special case for AnyKey: we don't want to add in cache a mapping AnyKey -> specific type, + // so we need to ask creation with the original identity of the descriptor + var registrationKey = isAnyKeyLookup ? ServiceIdentifier.FromDescriptor(_descriptors[i]) : cacheKey; + if (TryCreateOpenGeneric(_descriptors[i], registrationKey, callSiteChain, slot, throwOnConstraintViolation: false) is { } callSite) { AddCallSite(callSite, i); } @@ -360,6 +368,32 @@ void AddCallSite(ServiceCallSite callSite, int index) { callSiteChain.Remove(serviceIdentifier); } + + static bool KeysMatch(object? lookupKey, object? descriptorKey) + { + if (lookupKey == null && descriptorKey == null) + { + // Both are non keyed services + return true; + } + + if (lookupKey != null && descriptorKey != null) + { + // Both are keyed services + + // We don't want to return AnyKey registration, so ignore it + if (descriptorKey.Equals(KeyedService.AnyKey)) + return false; + + // Check if both keys are equal, or if the lookup key + // should matches all keys (except AnyKey) + return lookupKey.Equals(descriptorKey) + || lookupKey.Equals(KeyedService.AnyKey); + } + + // One is a keyed service, one is not + return false; + } } private static CallSiteResultCacheLocation GetCommonCacheLocation(CallSiteResultCacheLocation locationA, CallSiteResultCacheLocation locationB) @@ -688,24 +722,6 @@ internal bool IsService(ServiceIdentifier serviceIdentifier) serviceType == typeof(IServiceProviderIsKeyedService); } - /// - /// Returns true if both keys are null or equals, or if key1 is KeyedService.AnyKey and key2 is not null - /// - private static bool KeysMatch(object? key1, object? key2) - { - if (key1 == null && key2 == null) - return true; - - if (key1 != null && key2 != null) - { - return key1.Equals(key2) - || key1.Equals(KeyedService.AnyKey) - || key2.Equals(KeyedService.AnyKey); - } - - return false; - } - private struct ServiceDescriptorCacheItem { [DisallowNull]