Skip to content
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,84 @@ public abstract partial class KeyedDependencyInjectionSpecificationTests
{
protected abstract IServiceProvider CreateServiceProvider(IServiceCollection collection);

[Fact]
public void CombinationalRegistration()
{
Service service1 = new();
Service service2 = new();
Service keyedService1 = new();
Service keyedService2 = new();
Service anykeyService1 = new();
Service anykeyService2 = new();
Service nullkeyService1 = new();
Service nullkeyService2 = new();

ServiceCollection serviceCollection = new();
serviceCollection.AddSingleton<IService>(service1);
serviceCollection.AddSingleton<IService>(service2);
serviceCollection.AddKeyedSingleton<IService>(null, nullkeyService1);
serviceCollection.AddKeyedSingleton<IService>(null, nullkeyService2);
serviceCollection.AddKeyedSingleton<IService>(KeyedService.AnyKey, anykeyService1);
serviceCollection.AddKeyedSingleton<IService>(KeyedService.AnyKey, anykeyService2);
serviceCollection.AddKeyedSingleton<IService>("keyedService", keyedService1);
serviceCollection.AddKeyedSingleton<IService>("keyedService", keyedService2);

IServiceProvider provider = CreateServiceProvider(serviceCollection);

/*
* Table for what results are included:
*
* Query | Keyed? | Unkeyed? | AnyKey? | null key?
* -------------------------------------------------------------------
* GetServices(Type) | no | yes | no | yes
* GetService(Type) | no | yes | no | yes
*
* GetKeyedServices(null) | no | yes | no | yes
* GetKeyedService(null) | no | yes | no | yes
*
* GetKeyedServices(AnyKey) | yes | no | no | no
* GetKeyedService(AnyKey) | throw | throw | throw | throw
*
* GetKeyedServices(key) | yes | no | no | no
* GetKeyedService(key) | yes | no | yes | no
*
* Summary:
* - A null key is the same as unkeyed. This allows the KeyServices APIs to support both keyed and unkeyed.
* - AnyKey is a special case of Keyed.
* - AnyKey registrations are not returned with GetKeyedServices(AnyKey) and GetKeyedService(AnyKey) always throws.
* - For IEnumerable, the ordering of the results are in registration order.
* - For a singleton resolve, the last match wins.
*/

// Unkeyed (which is really keyed by Type).
Assert.Equal(
new[] { service1, service2, nullkeyService1, nullkeyService2 },
provider.GetServices<IService>());

Assert.Equal(nullkeyService2, provider.GetService<IService>());

// Null key.
Assert.Equal(
new[] { service1, service2, nullkeyService1, nullkeyService2 },
provider.GetKeyedServices<IService>(null));

Assert.Equal(nullkeyService2, provider.GetKeyedService<IService>(null));

// AnyKey.
Assert.Equal(
new[] { keyedService1, keyedService2 },
provider.GetKeyedServices<IService>(KeyedService.AnyKey));

Assert.Throws<InvalidOperationException>(() => provider.GetKeyedService<IService>(KeyedService.AnyKey));

// Keyed.
Assert.Equal(
new[] { keyedService1, keyedService2 },
provider.GetKeyedServices<IService>("keyedService"));

Assert.Equal(keyedService2, provider.GetKeyedService<IService>("keyedService"));
}

[Fact]
public void ResolveKeyedService()
{
Expand Down Expand Up @@ -158,10 +236,75 @@ public void ResolveKeyedServicesAnyKeyWithAnyKeyRegistration()
_ = provider.GetKeyedService<IService>("something-else");
_ = provider.GetKeyedService<IService>("something-else-again");

// Return all services registered with a non null key, but not the one "created" with KeyedService.AnyKey
// Return all services registered with a non null key, but not the one "created" with KeyedService.AnyKey,
// nor the KeyedService.AnyKey registration
var allServices = provider.GetKeyedServices<IService>(KeyedService.AnyKey).ToList();
Assert.Equal(5, allServices.Count);
Assert.Equal(new[] { service1, service2, service3, service4 }, allServices.Skip(1));
Assert.Equal(4, allServices.Count);
Assert.Equal(new[] { service1, service2, service3, service4 }, allServices);

var someKeyedServices = provider.GetKeyedServices<IService>("service").ToList();
Assert.Equal(new[] { service2, service3, service4 }, someKeyedServices);

var unkeyedServices = provider.GetServices<IService>().ToList();
Assert.Equal(new[] { service5, service6 }, unkeyedServices);
}

[Fact]
public void ResolveKeyedServicesAnyKeyConsistency()
{
var serviceCollection = new ServiceCollection();
var service = new Service("first-service");
serviceCollection.AddKeyedSingleton<IService>("first-service", service);

var provider1 = CreateServiceProvider(serviceCollection);
Assert.Throws<InvalidOperationException>(() => provider1.GetKeyedService<IService>(KeyedService.AnyKey));
// We don't return KeyedService.AnyKey registration when listing services
Assert.Equal(new[] { service }, provider1.GetKeyedServices<IService>(KeyedService.AnyKey));

var provider2 = CreateServiceProvider(serviceCollection);
Assert.Equal(new[] { service }, provider2.GetKeyedServices<IService>(KeyedService.AnyKey));
Assert.Throws<InvalidOperationException>(() => provider2.GetKeyedService<IService>(KeyedService.AnyKey));
}

[Fact]
public void ResolveKeyedServicesAnyKeyConsistencyWithAnyKeyRegistration()
{
var serviceCollection = new ServiceCollection();
var service = new Service("first-service");
var any = new Service("any");
serviceCollection.AddKeyedSingleton<IService>("first-service", service);
serviceCollection.AddKeyedSingleton<IService>(KeyedService.AnyKey, (sp, key) => any);

var provider1 = CreateServiceProvider(serviceCollection);
Assert.Equal(new[] { service }, provider1.GetKeyedServices<IService>(KeyedService.AnyKey));

// Check twice in different order to check caching
var provider2 = CreateServiceProvider(serviceCollection);
Assert.Equal(new[] { service }, provider2.GetKeyedServices<IService>(KeyedService.AnyKey));
Assert.Same(any, provider2.GetKeyedService<IService>(new object()));

Assert.Throws<InvalidOperationException>(() => provider2.GetKeyedService<IService>(KeyedService.AnyKey));
}

[Fact]
public void ResolveKeyedServicesAnyKeyOrdering()
{
var serviceCollection = new ServiceCollection();
var service1 = new Service();
var service2 = new Service();
var service3 = new Service();

serviceCollection.AddKeyedSingleton<IService>("A-service", service1);
serviceCollection.AddKeyedSingleton<IService>("B-service", service2);
serviceCollection.AddKeyedSingleton<IService>("A-service", service3);

var provider = CreateServiceProvider(serviceCollection);

// The order should be in registration order, and not grouped by key for example.
// Although this isn't necessarily a requirement, it is the current behavior.
Assert.Equal(
new[] { service1, service2, service3 },
provider.GetKeyedServices<IService>(KeyedService.AnyKey));
}

[Fact]
Expand Down Expand Up @@ -250,7 +393,7 @@ public void ResolveKeyedServicesSingletonInstanceWithAnyKey()
var provider = CreateServiceProvider(serviceCollection);

var services = provider.GetKeyedServices<IFakeOpenGenericService<PocoClass>>("some-key").ToList();
Assert.Equal(new[] { service1, service2 }, services);
Assert.Equal(new[] { service2 }, services);
}

[Fact]
Expand Down Expand Up @@ -504,6 +647,9 @@ public void ResolveKeyedSingletonFromScopeServiceProvider()
Assert.Null(scopeA.ServiceProvider.GetService<IService>());
Assert.Null(scopeB.ServiceProvider.GetService<IService>());

Assert.Throws<InvalidOperationException>(() => scopeA.ServiceProvider.GetKeyedService<IService>(KeyedService.AnyKey));
Assert.Throws<InvalidOperationException>(() => scopeB.ServiceProvider.GetKeyedService<IService>(KeyedService.AnyKey));

var serviceA1 = scopeA.ServiceProvider.GetKeyedService<IService>("key");
var serviceA2 = scopeA.ServiceProvider.GetKeyedService<IService>("key");

Expand All @@ -528,6 +674,9 @@ public void ResolveKeyedScopedFromScopeServiceProvider()
Assert.Null(scopeA.ServiceProvider.GetService<IService>());
Assert.Null(scopeB.ServiceProvider.GetService<IService>());

Assert.Throws<InvalidOperationException>(() => scopeA.ServiceProvider.GetKeyedService<IService>(KeyedService.AnyKey));
Assert.Throws<InvalidOperationException>(() => scopeB.ServiceProvider.GetKeyedService<IService>(KeyedService.AnyKey));

var serviceA1 = scopeA.ServiceProvider.GetKeyedService<IService>("key");
var serviceA2 = scopeA.ServiceProvider.GetKeyedService<IService>("key");

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -192,4 +192,7 @@
<data name="InvalidServiceKeyType" xml:space="preserve">
<value>The type of the key used for lookup doesn't match the type in the constructor parameter with the ServiceKey attribute.</value>
</data>
<data name="KeyedServiceAnyKeyUsedToResolveService" xml:space="preserve">
<value>KeyedService.AnyKey cannot be used to resolve a single service.</value>
</data>
</root>
Original file line number Diff line number Diff line change
Expand Up @@ -282,11 +282,13 @@ private static bool AreCompatible(DynamicallyAccessedMemberTypes serviceDynamica
CallSiteResultCacheLocation cacheLocation = CallSiteResultCacheLocation.Root;
ServiceCallSite[] callSites;

var isAnyKeyLookup = serviceIdentifier.ServiceKey == KeyedService.AnyKey;

// If item type is not generic we can safely use descriptor cache
// Special case for KeyedService.AnyKey, we don't want to check the cache because a KeyedService.AnyKey registration
// will "hide" all the other service registration
if (!itemType.IsConstructedGenericType &&
!KeyedService.AnyKey.Equals(cacheKey.ServiceKey) &&
!isAnyKeyLookup &&
_descriptorLookup.TryGetValue(cacheKey, out ServiceDescriptorCacheItem descriptors))
{
callSites = new ServiceCallSite[descriptors.Count];
Expand Down Expand Up @@ -317,19 +319,25 @@ private static bool AreCompatible(DynamicallyAccessedMemberTypes serviceDynamica
int slot = 0;
for (int i = _descriptors.Length - 1; i >= 0; i--)
{
if (KeysMatch(_descriptors[i].ServiceKey, cacheKey.ServiceKey))
if (KeysMatch(cacheKey.ServiceKey, _descriptors[i].ServiceKey))
{
if (TryCreateExact(_descriptors[i], cacheKey, callSiteChain, slot) is { } callSite)
// Special case for AnyKey: we don't want to add in cache a mapping AnyKey -> specific type,
// so we need to ask creation with the original identity of the descriptor
var registrationKey = isAnyKeyLookup ? ServiceIdentifier.FromDescriptor(_descriptors[i]) : cacheKey;
if (TryCreateExact(_descriptors[i], registrationKey, callSiteChain, slot) is { } callSite)
{
AddCallSite(callSite, i);
}
}
}
for (int i = _descriptors.Length - 1; i >= 0; i--)
{
if (KeysMatch(_descriptors[i].ServiceKey, cacheKey.ServiceKey))
if (KeysMatch(cacheKey.ServiceKey, _descriptors[i].ServiceKey))
{
if (TryCreateOpenGeneric(_descriptors[i], cacheKey, callSiteChain, slot, throwOnConstraintViolation: false) is { } callSite)
// Special case for AnyKey: we don't want to add in cache a mapping AnyKey -> specific type,
// so we need to ask creation with the original identity of the descriptor
var registrationKey = isAnyKeyLookup ? ServiceIdentifier.FromDescriptor(_descriptors[i]) : cacheKey;
if (TryCreateOpenGeneric(_descriptors[i], registrationKey, callSiteChain, slot, throwOnConstraintViolation: false) is { } callSite)
{
AddCallSite(callSite, i);
}
Expand Down Expand Up @@ -360,6 +368,32 @@ void AddCallSite(ServiceCallSite callSite, int index)
{
callSiteChain.Remove(serviceIdentifier);
}

static bool KeysMatch(object? lookupKey, object? descriptorKey)
{
if (lookupKey == null && descriptorKey == null)
{
// Both are non keyed services
return true;
}

if (lookupKey != null && descriptorKey != null)
{
// Both are keyed services

// We don't want to return AnyKey registration, so ignore it
if (descriptorKey.Equals(KeyedService.AnyKey))
return false;

// Check if both keys are equal, or if the lookup key
// should matches all keys (except AnyKey)
return lookupKey.Equals(descriptorKey)
|| lookupKey.Equals(KeyedService.AnyKey);
}

// One is a keyed service, one is not
return false;
}
}

private static CallSiteResultCacheLocation GetCommonCacheLocation(CallSiteResultCacheLocation locationA, CallSiteResultCacheLocation locationB)
Expand Down Expand Up @@ -693,24 +727,6 @@ internal bool IsService(ServiceIdentifier serviceIdentifier)
serviceType == typeof(IServiceProviderIsKeyedService);
}

/// <summary>
/// Returns true if both keys are null or equals, or if key1 is KeyedService.AnyKey and key2 is not null
/// </summary>
private static bool KeysMatch(object? key1, object? key2)
{
if (key1 == null && key2 == null)
return true;

if (key1 != null && key2 != null)
{
return key1.Equals(key2)
|| key1.Equals(KeyedService.AnyKey)
|| key2.Equals(KeyedService.AnyKey);
}

return false;
}

private struct ServiceDescriptorCacheItem
{
[DisallowNull]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,12 @@ internal static void ThrowObjectDisposedException()
{
throw new ObjectDisposedException(nameof(IServiceProvider));
}

[DoesNotReturn]
[MethodImpl(MethodImplOptions.NoInlining)]
internal static void ThrowInvalidOperationException_KeyedServiceAnyKeyUsedToResolveService()
{
throw new InvalidOperationException(SR.Format(SR.KeyedServiceAnyKeyUsedToResolveService, nameof(IServiceProvider), nameof(IServiceScopeFactory)));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -108,21 +108,46 @@ internal ServiceProvider(ICollection<ServiceDescriptor> serviceDescriptors, Serv
/// <param name="serviceType">The type of the service to get.</param>
/// <param name="serviceKey">The key of the service to get.</param>
/// <returns>The keyed service.</returns>
/// <exception cref="InvalidOperationException">The <see cref="KeyedService.AnyKey"/> value is used for <paramref name="serviceKey"/>
/// when <paramref name="serviceType"/> is not an enumerable based on <see cref="IEnumerable{T}"/>.
/// </exception>
public object? GetKeyedService(Type serviceType, object? serviceKey)
=> GetKeyedService(serviceType, serviceKey, Root);

internal object? GetKeyedService(Type serviceType, object? serviceKey, ServiceProviderEngineScope serviceProviderEngineScope)
=> GetService(new ServiceIdentifier(serviceKey, serviceType), serviceProviderEngineScope);
{
if (serviceKey == KeyedService.AnyKey)
{
if (!serviceType.IsGenericType || serviceType.GetGenericTypeDefinition() != typeof(IEnumerable<>))
{
ThrowHelper.ThrowInvalidOperationException_KeyedServiceAnyKeyUsedToResolveService();
}
}

return GetService(new ServiceIdentifier(serviceKey, serviceType), serviceProviderEngineScope);
}

/// <summary>
/// Gets the service object of the specified type.
/// </summary>
/// <param name="serviceType">The type of the service to get.</param>
/// <param name="serviceKey">The key of the service to get.</param>
/// <returns>The keyed service.</returns>
/// <exception cref="InvalidOperationException">The service wasn't found.</exception>
/// <exception cref="InvalidOperationException">The service wasn't found or the <see cref="KeyedService.AnyKey"/> value is used
/// for <paramref name="serviceKey"/> when <paramref name="serviceType"/> is not an enumerable based on <see cref="IEnumerable{T}"/>.
/// </exception>
public object GetRequiredKeyedService(Type serviceType, object? serviceKey)
=> GetRequiredKeyedService(serviceType, serviceKey, Root);
{
if (serviceKey == KeyedService.AnyKey)
{
if (!serviceType.IsGenericType || serviceType.GetGenericTypeDefinition() != typeof(IEnumerable<>))
{
ThrowHelper.ThrowInvalidOperationException_KeyedServiceAnyKeyUsedToResolveService();
}
}

return GetRequiredKeyedService(serviceType, serviceKey, Root);
}

internal object GetRequiredKeyedService(Type serviceType, object? serviceKey, ServiceProviderEngineScope serviceProviderEngineScope)
{
Expand Down
Loading