Skip to content
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -158,10 +158,54 @@ public void ResolveKeyedServicesAnyKeyWithAnyKeyRegistration()
_ = provider.GetKeyedService<IService>("something-else");
_ = provider.GetKeyedService<IService>("something-else-again");

// Return all services registered with a non null key, but not the one "created" with KeyedService.AnyKey
// Return all services registered with a non null key, but not the one "created" with KeyedService.AnyKey,
// nor the KeyedService.AnyKey registration
var allServices = provider.GetKeyedServices<IService>(KeyedService.AnyKey).ToList();
Assert.Equal(5, allServices.Count);
Assert.Equal(new[] { service1, service2, service3, service4 }, allServices.Skip(1));
Assert.Equal(4, allServices.Count);
Assert.Equal(new[] { service1, service2, service3, service4 }, allServices);

var someKeyedServices = provider.GetKeyedServices<IService>("service").ToList();
Assert.Equal(new[] { service2, service3, service4 }, someKeyedServices);

var unkeyedServices = provider.GetServices<IService>().ToList();
Assert.Equal(new[] { service5, service6 }, unkeyedServices);
}

[Fact]
public void ResolveKeyedServicesAnyKeyConsistency()
{
var serviceCollection = new ServiceCollection();
var service = new Service("first-service");
serviceCollection.AddKeyedSingleton<IService>("first-service", service);

var provider1 = CreateServiceProvider(serviceCollection);
Assert.Throws<InvalidOperationException>(() => provider1.GetKeyedService<IService>(KeyedService.AnyKey));
// We don't return KeyedService.AnyKey registration when listing services
Assert.Equal(new[] { service }, provider1.GetKeyedServices<IService>(KeyedService.AnyKey));

var provider2 = CreateServiceProvider(serviceCollection);
Assert.Equal(new[] { service }, provider2.GetKeyedServices<IService>(KeyedService.AnyKey));
Assert.Throws<InvalidOperationException>(() => provider2.GetKeyedService<IService>(KeyedService.AnyKey));
}

[Fact]
public void ResolveKeyedServicesAnyKeyConsistencyWithAnyKeyRegistration()
{
var serviceCollection = new ServiceCollection();
var service = new Service("first-service");
var any = new Service("any");
serviceCollection.AddKeyedSingleton<IService>("first-service", service);
serviceCollection.AddKeyedSingleton<IService>(KeyedService.AnyKey, (sp, key) => any);

var provider1 = CreateServiceProvider(serviceCollection);
Assert.Equal(new[] { service }, provider1.GetKeyedServices<IService>(KeyedService.AnyKey));

// Check twice in different order to check caching
var provider2 = CreateServiceProvider(serviceCollection);
Assert.Equal(new[] { service }, provider2.GetKeyedServices<IService>(KeyedService.AnyKey));
Assert.Same(any, provider2.GetKeyedService<IService>(new object()));

Assert.Throws<InvalidOperationException>(() => provider2.GetKeyedService<IService>(KeyedService.AnyKey));
}

[Fact]
Expand Down Expand Up @@ -250,7 +294,7 @@ public void ResolveKeyedServicesSingletonInstanceWithAnyKey()
var provider = CreateServiceProvider(serviceCollection);

var services = provider.GetKeyedServices<IFakeOpenGenericService<PocoClass>>("some-key").ToList();
Assert.Equal(new[] { service1, service2 }, services);
Assert.Equal(new[] { service2 }, services);
}

[Fact]
Expand Down Expand Up @@ -504,6 +548,9 @@ public void ResolveKeyedSingletonFromScopeServiceProvider()
Assert.Null(scopeA.ServiceProvider.GetService<IService>());
Assert.Null(scopeB.ServiceProvider.GetService<IService>());

Assert.Throws<InvalidOperationException>(() => scopeA.ServiceProvider.GetKeyedService<IService>(KeyedService.AnyKey));
Assert.Throws<InvalidOperationException>(() => scopeB.ServiceProvider.GetKeyedService<IService>(KeyedService.AnyKey));

var serviceA1 = scopeA.ServiceProvider.GetKeyedService<IService>("key");
var serviceA2 = scopeA.ServiceProvider.GetKeyedService<IService>("key");

Expand All @@ -528,6 +575,9 @@ public void ResolveKeyedScopedFromScopeServiceProvider()
Assert.Null(scopeA.ServiceProvider.GetService<IService>());
Assert.Null(scopeB.ServiceProvider.GetService<IService>());

Assert.Throws<InvalidOperationException>(() => scopeA.ServiceProvider.GetKeyedService<IService>(KeyedService.AnyKey));
Assert.Throws<InvalidOperationException>(() => scopeB.ServiceProvider.GetKeyedService<IService>(KeyedService.AnyKey));

var serviceA1 = scopeA.ServiceProvider.GetKeyedService<IService>("key");
var serviceA2 = scopeA.ServiceProvider.GetKeyedService<IService>("key");

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -192,4 +192,7 @@
<data name="InvalidServiceKeyType" xml:space="preserve">
<value>The type of the key used for lookup doesn't match the type in the constructor parameter with the ServiceKey attribute.</value>
</data>
<data name="KeyedServiceAnyKeyUsedToResolveService" xml:space="preserve">
<value>KeyedService.AnyKey cannot be used to resolve a single service.</value>
</data>
</root>
Original file line number Diff line number Diff line change
Expand Up @@ -282,11 +282,13 @@ private static bool AreCompatible(DynamicallyAccessedMemberTypes serviceDynamica
CallSiteResultCacheLocation cacheLocation = CallSiteResultCacheLocation.Root;
ServiceCallSite[] callSites;

var isAnyKeyLookup = serviceIdentifier.ServiceKey == KeyedService.AnyKey;

// If item type is not generic we can safely use descriptor cache
// Special case for KeyedService.AnyKey, we don't want to check the cache because a KeyedService.AnyKey registration
// will "hide" all the other service registration
if (!itemType.IsConstructedGenericType &&
!KeyedService.AnyKey.Equals(cacheKey.ServiceKey) &&
!isAnyKeyLookup &&
_descriptorLookup.TryGetValue(cacheKey, out ServiceDescriptorCacheItem descriptors))
{
callSites = new ServiceCallSite[descriptors.Count];
Expand Down Expand Up @@ -317,19 +319,25 @@ private static bool AreCompatible(DynamicallyAccessedMemberTypes serviceDynamica
int slot = 0;
for (int i = _descriptors.Length - 1; i >= 0; i--)
{
if (KeysMatch(_descriptors[i].ServiceKey, cacheKey.ServiceKey))
if (KeysMatch(cacheKey.ServiceKey, _descriptors[i].ServiceKey))
{
if (TryCreateExact(_descriptors[i], cacheKey, callSiteChain, slot) is { } callSite)
// Special case for AnyKey: we don't want to add in cache a mapping AnyKey -> specific type,
// so we need to ask creation with the original identity of the descriptor
var registrationKey = isAnyKeyLookup ? ServiceIdentifier.FromDescriptor(_descriptors[i]) : cacheKey;
if (TryCreateExact(_descriptors[i], registrationKey, callSiteChain, slot) is { } callSite)
{
AddCallSite(callSite, i);
}
}
}
for (int i = _descriptors.Length - 1; i >= 0; i--)
{
if (KeysMatch(_descriptors[i].ServiceKey, cacheKey.ServiceKey))
if (KeysMatch(cacheKey.ServiceKey, _descriptors[i].ServiceKey))
{
if (TryCreateOpenGeneric(_descriptors[i], cacheKey, callSiteChain, slot, throwOnConstraintViolation: false) is { } callSite)
// Special case for AnyKey: we don't want to add in cache a mapping AnyKey -> specific type,
// so we need to ask creation with the original identity of the descriptor
var registrationKey = isAnyKeyLookup ? ServiceIdentifier.FromDescriptor(_descriptors[i]) : cacheKey;
if (TryCreateOpenGeneric(_descriptors[i], registrationKey, callSiteChain, slot, throwOnConstraintViolation: false) is { } callSite)
{
AddCallSite(callSite, i);
}
Expand Down Expand Up @@ -360,6 +368,32 @@ void AddCallSite(ServiceCallSite callSite, int index)
{
callSiteChain.Remove(serviceIdentifier);
}

static bool KeysMatch(object? lookupKey, object? descriptorKey)
{
if (lookupKey == null && descriptorKey == null)
{
// Both are non keyed services
return true;
}

if (lookupKey != null && descriptorKey != null)
{
// Both are keyed services

// We don't want to return AnyKey registration, so ignore it
if (descriptorKey.Equals(KeyedService.AnyKey))
return false;

// Check if both keys are equal, or if the lookup key
// should matches all keys (except AnyKey)
return lookupKey.Equals(descriptorKey)
|| lookupKey.Equals(KeyedService.AnyKey);
}

// One is a keyed service, one is not
return false;
}
}

private static CallSiteResultCacheLocation GetCommonCacheLocation(CallSiteResultCacheLocation locationA, CallSiteResultCacheLocation locationB)
Expand Down Expand Up @@ -693,24 +727,6 @@ internal bool IsService(ServiceIdentifier serviceIdentifier)
serviceType == typeof(IServiceProviderIsKeyedService);
}

/// <summary>
/// Returns true if both keys are null or equals, or if key1 is KeyedService.AnyKey and key2 is not null
/// </summary>
private static bool KeysMatch(object? key1, object? key2)
{
if (key1 == null && key2 == null)
return true;

if (key1 != null && key2 != null)
{
return key1.Equals(key2)
|| key1.Equals(KeyedService.AnyKey)
|| key2.Equals(KeyedService.AnyKey);
}

return false;
}

private struct ServiceDescriptorCacheItem
{
[DisallowNull]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,12 @@ internal static void ThrowObjectDisposedException()
{
throw new ObjectDisposedException(nameof(IServiceProvider));
}

[DoesNotReturn]
[MethodImpl(MethodImplOptions.NoInlining)]
internal static void ThrowInvalidOperationException_KeyedServiceAnyKeyUsedToResolveService()
{
throw new InvalidOperationException(SR.Format(SR.KeyedServiceAnyKeyUsedToResolveService, nameof(IServiceProvider), nameof(IServiceScopeFactory)));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,17 @@ internal ServiceProvider(ICollection<ServiceDescriptor> serviceDescriptors, Serv
=> GetKeyedService(serviceType, serviceKey, Root);

internal object? GetKeyedService(Type serviceType, object? serviceKey, ServiceProviderEngineScope serviceProviderEngineScope)
=> GetService(new ServiceIdentifier(serviceKey, serviceType), serviceProviderEngineScope);
{
if (serviceKey == KeyedService.AnyKey)
{
if (!serviceType.IsGenericType || serviceType.GetGenericTypeDefinition() != typeof(IEnumerable<>))
{
ThrowHelper.ThrowInvalidOperationException_KeyedServiceAnyKeyUsedToResolveService();
}
}

return GetService(new ServiceIdentifier(serviceKey, serviceType), serviceProviderEngineScope);
}

/// <summary>
/// Gets the service object of the specified type.
Expand All @@ -122,7 +132,17 @@ internal ServiceProvider(ICollection<ServiceDescriptor> serviceDescriptors, Serv
/// <returns>The keyed service.</returns>
/// <exception cref="InvalidOperationException">The service wasn't found.</exception>
public object GetRequiredKeyedService(Type serviceType, object? serviceKey)
=> GetRequiredKeyedService(serviceType, serviceKey, Root);
{
if (serviceKey == KeyedService.AnyKey)
{
if (!serviceType.IsGenericType || serviceType.GetGenericTypeDefinition() != typeof(IEnumerable<>))
{
ThrowHelper.ThrowInvalidOperationException_KeyedServiceAnyKeyUsedToResolveService();
}
}

return GetRequiredKeyedService(serviceType, serviceKey, Root);
}

internal object GetRequiredKeyedService(Type serviceType, object? serviceKey, ServiceProviderEngineScope serviceProviderEngineScope)
{
Expand Down
Loading