diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection.Specification.Tests/src/KeyedDependencyInjectionSpecificationTests.cs b/src/libraries/Microsoft.Extensions.DependencyInjection.Specification.Tests/src/KeyedDependencyInjectionSpecificationTests.cs index ee7daad0be7eb1..7cee6ebd40ad18 100644 --- a/src/libraries/Microsoft.Extensions.DependencyInjection.Specification.Tests/src/KeyedDependencyInjectionSpecificationTests.cs +++ b/src/libraries/Microsoft.Extensions.DependencyInjection.Specification.Tests/src/KeyedDependencyInjectionSpecificationTests.cs @@ -13,6 +13,84 @@ public abstract partial class KeyedDependencyInjectionSpecificationTests { protected abstract IServiceProvider CreateServiceProvider(IServiceCollection collection); + [Fact] + public void CombinationalRegistration() + { + Service service1 = new(); + Service service2 = new(); + Service keyedService1 = new(); + Service keyedService2 = new(); + Service anykeyService1 = new(); + Service anykeyService2 = new(); + Service nullkeyService1 = new(); + Service nullkeyService2 = new(); + + ServiceCollection serviceCollection = new(); + serviceCollection.AddSingleton(service1); + serviceCollection.AddSingleton(service2); + serviceCollection.AddKeyedSingleton(null, nullkeyService1); + serviceCollection.AddKeyedSingleton(null, nullkeyService2); + serviceCollection.AddKeyedSingleton(KeyedService.AnyKey, anykeyService1); + serviceCollection.AddKeyedSingleton(KeyedService.AnyKey, anykeyService2); + serviceCollection.AddKeyedSingleton("keyedService", keyedService1); + serviceCollection.AddKeyedSingleton("keyedService", keyedService2); + + IServiceProvider provider = CreateServiceProvider(serviceCollection); + + /* + * Table for what results are included: + * + * Query | Keyed? | Unkeyed? | AnyKey? | null key? + * ------------------------------------------------------------------- + * GetServices(Type) | no | yes | no | yes + * GetService(Type) | no | yes | no | yes + * + * GetKeyedServices(null) | no | yes | no | yes + * GetKeyedService(null) | no | yes | no | yes + * + * GetKeyedServices(AnyKey) | yes | no | no | no + * GetKeyedService(AnyKey) | throw | throw | throw | throw + * + * GetKeyedServices(key) | yes | no | no | no + * GetKeyedService(key) | yes | no | yes | no + * + * Summary: + * - A null key is the same as unkeyed. This allows the KeyServices APIs to support both keyed and unkeyed. + * - AnyKey is a special case of Keyed. + * - AnyKey registrations are not returned with GetKeyedServices(AnyKey) and GetKeyedService(AnyKey) always throws. + * - For IEnumerable, the ordering of the results are in registration order. + * - For a singleton resolve, the last match wins. + */ + + // Unkeyed (which is really keyed by Type). + Assert.Equal( + new[] { service1, service2, nullkeyService1, nullkeyService2 }, + provider.GetServices()); + + Assert.Equal(nullkeyService2, provider.GetService()); + + // Null key. + Assert.Equal( + new[] { service1, service2, nullkeyService1, nullkeyService2 }, + provider.GetKeyedServices(null)); + + Assert.Equal(nullkeyService2, provider.GetKeyedService(null)); + + // AnyKey. + Assert.Equal( + new[] { keyedService1, keyedService2 }, + provider.GetKeyedServices(KeyedService.AnyKey)); + + Assert.Throws(() => provider.GetKeyedService(KeyedService.AnyKey)); + + // Keyed. + Assert.Equal( + new[] { keyedService1, keyedService2 }, + provider.GetKeyedServices("keyedService")); + + Assert.Equal(keyedService2, provider.GetKeyedService("keyedService")); + } + [Fact] public void ResolveKeyedService() { @@ -158,10 +236,75 @@ public void ResolveKeyedServicesAnyKeyWithAnyKeyRegistration() _ = provider.GetKeyedService("something-else"); _ = provider.GetKeyedService("something-else-again"); - // Return all services registered with a non null key, but not the one "created" with KeyedService.AnyKey + // Return all services registered with a non null key, but not the one "created" with KeyedService.AnyKey, + // nor the KeyedService.AnyKey registration var allServices = provider.GetKeyedServices(KeyedService.AnyKey).ToList(); - Assert.Equal(5, allServices.Count); - Assert.Equal(new[] { service1, service2, service3, service4 }, allServices.Skip(1)); + Assert.Equal(4, allServices.Count); + Assert.Equal(new[] { service1, service2, service3, service4 }, allServices); + + var someKeyedServices = provider.GetKeyedServices("service").ToList(); + Assert.Equal(new[] { service2, service3, service4 }, someKeyedServices); + + var unkeyedServices = provider.GetServices().ToList(); + Assert.Equal(new[] { service5, service6 }, unkeyedServices); + } + + [Fact] + public void ResolveKeyedServicesAnyKeyConsistency() + { + var serviceCollection = new ServiceCollection(); + var service = new Service("first-service"); + serviceCollection.AddKeyedSingleton("first-service", service); + + var provider1 = CreateServiceProvider(serviceCollection); + Assert.Throws(() => provider1.GetKeyedService(KeyedService.AnyKey)); + // We don't return KeyedService.AnyKey registration when listing services + Assert.Equal(new[] { service }, provider1.GetKeyedServices(KeyedService.AnyKey)); + + var provider2 = CreateServiceProvider(serviceCollection); + Assert.Equal(new[] { service }, provider2.GetKeyedServices(KeyedService.AnyKey)); + Assert.Throws(() => provider2.GetKeyedService(KeyedService.AnyKey)); + } + + [Fact] + public void ResolveKeyedServicesAnyKeyConsistencyWithAnyKeyRegistration() + { + var serviceCollection = new ServiceCollection(); + var service = new Service("first-service"); + var any = new Service("any"); + serviceCollection.AddKeyedSingleton("first-service", service); + serviceCollection.AddKeyedSingleton(KeyedService.AnyKey, (sp, key) => any); + + var provider1 = CreateServiceProvider(serviceCollection); + Assert.Equal(new[] { service }, provider1.GetKeyedServices(KeyedService.AnyKey)); + + // Check twice in different order to check caching + var provider2 = CreateServiceProvider(serviceCollection); + Assert.Equal(new[] { service }, provider2.GetKeyedServices(KeyedService.AnyKey)); + Assert.Same(any, provider2.GetKeyedService(new object())); + + Assert.Throws(() => provider2.GetKeyedService(KeyedService.AnyKey)); + } + + [Fact] + public void ResolveKeyedServicesAnyKeyOrdering() + { + var serviceCollection = new ServiceCollection(); + var service1 = new Service(); + var service2 = new Service(); + var service3 = new Service(); + + serviceCollection.AddKeyedSingleton("A-service", service1); + serviceCollection.AddKeyedSingleton("B-service", service2); + serviceCollection.AddKeyedSingleton("A-service", service3); + + var provider = CreateServiceProvider(serviceCollection); + + // The order should be in registration order, and not grouped by key for example. + // Although this isn't necessarily a requirement, it is the current behavior. + Assert.Equal( + new[] { service1, service2, service3 }, + provider.GetKeyedServices(KeyedService.AnyKey)); } [Fact] @@ -250,7 +393,7 @@ public void ResolveKeyedServicesSingletonInstanceWithAnyKey() var provider = CreateServiceProvider(serviceCollection); var services = provider.GetKeyedServices>("some-key").ToList(); - Assert.Equal(new[] { service1, service2 }, services); + Assert.Equal(new[] { service2 }, services); } [Fact] @@ -504,6 +647,9 @@ public void ResolveKeyedSingletonFromScopeServiceProvider() Assert.Null(scopeA.ServiceProvider.GetService()); Assert.Null(scopeB.ServiceProvider.GetService()); + Assert.Throws(() => scopeA.ServiceProvider.GetKeyedService(KeyedService.AnyKey)); + Assert.Throws(() => scopeB.ServiceProvider.GetKeyedService(KeyedService.AnyKey)); + var serviceA1 = scopeA.ServiceProvider.GetKeyedService("key"); var serviceA2 = scopeA.ServiceProvider.GetKeyedService("key"); @@ -528,6 +674,9 @@ public void ResolveKeyedScopedFromScopeServiceProvider() Assert.Null(scopeA.ServiceProvider.GetService()); Assert.Null(scopeB.ServiceProvider.GetService()); + Assert.Throws(() => scopeA.ServiceProvider.GetKeyedService(KeyedService.AnyKey)); + Assert.Throws(() => scopeB.ServiceProvider.GetKeyedService(KeyedService.AnyKey)); + var serviceA1 = scopeA.ServiceProvider.GetKeyedService("key"); var serviceA2 = scopeA.ServiceProvider.GetKeyedService("key"); diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection/src/Resources/Strings.resx b/src/libraries/Microsoft.Extensions.DependencyInjection/src/Resources/Strings.resx index e6f57dcbb5a1ee..2eb2adba7e0060 100644 --- a/src/libraries/Microsoft.Extensions.DependencyInjection/src/Resources/Strings.resx +++ b/src/libraries/Microsoft.Extensions.DependencyInjection/src/Resources/Strings.resx @@ -192,4 +192,7 @@ The type of the key used for lookup doesn't match the type in the constructor parameter with the ServiceKey attribute. + + KeyedService.AnyKey cannot be used to resolve a single service. + \ No newline at end of file diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/CallSiteFactory.cs b/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/CallSiteFactory.cs index aec3f2c6745420..b3897b1ab7dce1 100644 --- a/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/CallSiteFactory.cs +++ b/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/CallSiteFactory.cs @@ -282,11 +282,13 @@ private static bool AreCompatible(DynamicallyAccessedMemberTypes serviceDynamica CallSiteResultCacheLocation cacheLocation = CallSiteResultCacheLocation.Root; ServiceCallSite[] callSites; + var isAnyKeyLookup = serviceIdentifier.ServiceKey == KeyedService.AnyKey; + // If item type is not generic we can safely use descriptor cache // Special case for KeyedService.AnyKey, we don't want to check the cache because a KeyedService.AnyKey registration // will "hide" all the other service registration if (!itemType.IsConstructedGenericType && - !KeyedService.AnyKey.Equals(cacheKey.ServiceKey) && + !isAnyKeyLookup && _descriptorLookup.TryGetValue(cacheKey, out ServiceDescriptorCacheItem descriptors)) { callSites = new ServiceCallSite[descriptors.Count]; @@ -317,9 +319,12 @@ private static bool AreCompatible(DynamicallyAccessedMemberTypes serviceDynamica int slot = 0; for (int i = _descriptors.Length - 1; i >= 0; i--) { - if (KeysMatch(_descriptors[i].ServiceKey, cacheKey.ServiceKey)) + if (KeysMatch(cacheKey.ServiceKey, _descriptors[i].ServiceKey)) { - if (TryCreateExact(_descriptors[i], cacheKey, callSiteChain, slot) is { } callSite) + // Special case for AnyKey: we don't want to add in cache a mapping AnyKey -> specific type, + // so we need to ask creation with the original identity of the descriptor + var registrationKey = isAnyKeyLookup ? ServiceIdentifier.FromDescriptor(_descriptors[i]) : cacheKey; + if (TryCreateExact(_descriptors[i], registrationKey, callSiteChain, slot) is { } callSite) { AddCallSite(callSite, i); } @@ -327,9 +332,12 @@ private static bool AreCompatible(DynamicallyAccessedMemberTypes serviceDynamica } for (int i = _descriptors.Length - 1; i >= 0; i--) { - if (KeysMatch(_descriptors[i].ServiceKey, cacheKey.ServiceKey)) + if (KeysMatch(cacheKey.ServiceKey, _descriptors[i].ServiceKey)) { - if (TryCreateOpenGeneric(_descriptors[i], cacheKey, callSiteChain, slot, throwOnConstraintViolation: false) is { } callSite) + // Special case for AnyKey: we don't want to add in cache a mapping AnyKey -> specific type, + // so we need to ask creation with the original identity of the descriptor + var registrationKey = isAnyKeyLookup ? ServiceIdentifier.FromDescriptor(_descriptors[i]) : cacheKey; + if (TryCreateOpenGeneric(_descriptors[i], registrationKey, callSiteChain, slot, throwOnConstraintViolation: false) is { } callSite) { AddCallSite(callSite, i); } @@ -360,6 +368,32 @@ void AddCallSite(ServiceCallSite callSite, int index) { callSiteChain.Remove(serviceIdentifier); } + + static bool KeysMatch(object? lookupKey, object? descriptorKey) + { + if (lookupKey == null && descriptorKey == null) + { + // Both are non keyed services + return true; + } + + if (lookupKey != null && descriptorKey != null) + { + // Both are keyed services + + // We don't want to return AnyKey registration, so ignore it + if (descriptorKey.Equals(KeyedService.AnyKey)) + return false; + + // Check if both keys are equal, or if the lookup key + // should matches all keys (except AnyKey) + return lookupKey.Equals(descriptorKey) + || lookupKey.Equals(KeyedService.AnyKey); + } + + // One is a keyed service, one is not + return false; + } } private static CallSiteResultCacheLocation GetCommonCacheLocation(CallSiteResultCacheLocation locationA, CallSiteResultCacheLocation locationB) @@ -693,24 +727,6 @@ internal bool IsService(ServiceIdentifier serviceIdentifier) serviceType == typeof(IServiceProviderIsKeyedService); } - /// - /// Returns true if both keys are null or equals, or if key1 is KeyedService.AnyKey and key2 is not null - /// - private static bool KeysMatch(object? key1, object? key2) - { - if (key1 == null && key2 == null) - return true; - - if (key1 != null && key2 != null) - { - return key1.Equals(key2) - || key1.Equals(KeyedService.AnyKey) - || key2.Equals(KeyedService.AnyKey); - } - - return false; - } - private struct ServiceDescriptorCacheItem { [DisallowNull] diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/ThrowHelper.cs b/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/ThrowHelper.cs index ffacbf9bbc0521..cf9bedf8341d78 100644 --- a/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/ThrowHelper.cs +++ b/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/ThrowHelper.cs @@ -15,5 +15,12 @@ internal static void ThrowObjectDisposedException() { throw new ObjectDisposedException(nameof(IServiceProvider)); } + + [DoesNotReturn] + [MethodImpl(MethodImplOptions.NoInlining)] + internal static void ThrowInvalidOperationException_KeyedServiceAnyKeyUsedToResolveService() + { + throw new InvalidOperationException(SR.Format(SR.KeyedServiceAnyKeyUsedToResolveService, nameof(IServiceProvider), nameof(IServiceScopeFactory))); + } } } diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceProvider.cs b/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceProvider.cs index d3177c229e31ec..bb05be370d8303 100644 --- a/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceProvider.cs +++ b/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceProvider.cs @@ -108,11 +108,24 @@ internal ServiceProvider(ICollection serviceDescriptors, Serv /// The type of the service to get. /// The key of the service to get. /// The keyed service. + /// The value is used for + /// when is not an enumerable based on . + /// public object? GetKeyedService(Type serviceType, object? serviceKey) => GetKeyedService(serviceType, serviceKey, Root); internal object? GetKeyedService(Type serviceType, object? serviceKey, ServiceProviderEngineScope serviceProviderEngineScope) - => GetService(new ServiceIdentifier(serviceKey, serviceType), serviceProviderEngineScope); + { + if (serviceKey == KeyedService.AnyKey) + { + if (!serviceType.IsGenericType || serviceType.GetGenericTypeDefinition() != typeof(IEnumerable<>)) + { + ThrowHelper.ThrowInvalidOperationException_KeyedServiceAnyKeyUsedToResolveService(); + } + } + + return GetService(new ServiceIdentifier(serviceKey, serviceType), serviceProviderEngineScope); + } /// /// Gets the service object of the specified type. @@ -120,9 +133,21 @@ internal ServiceProvider(ICollection serviceDescriptors, Serv /// The type of the service to get. /// The key of the service to get. /// The keyed service. - /// The service wasn't found. + /// The service wasn't found or the value is used + /// for when is not an enumerable based on . + /// public object GetRequiredKeyedService(Type serviceType, object? serviceKey) - => GetRequiredKeyedService(serviceType, serviceKey, Root); + { + if (serviceKey == KeyedService.AnyKey) + { + if (!serviceType.IsGenericType || serviceType.GetGenericTypeDefinition() != typeof(IEnumerable<>)) + { + ThrowHelper.ThrowInvalidOperationException_KeyedServiceAnyKeyUsedToResolveService(); + } + } + + return GetRequiredKeyedService(serviceType, serviceKey, Root); + } internal object GetRequiredKeyedService(Type serviceType, object? serviceKey, ServiceProviderEngineScope serviceProviderEngineScope) {