diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection.Abstractions/src/Extensions/ServiceCollectionDescriptorExtensions.Keyed.cs b/src/libraries/Microsoft.Extensions.DependencyInjection.Abstractions/src/Extensions/ServiceCollectionDescriptorExtensions.Keyed.cs index aad0b4ff2fe414..436f108e5eb5f2 100644 --- a/src/libraries/Microsoft.Extensions.DependencyInjection.Abstractions/src/Extensions/ServiceCollectionDescriptorExtensions.Keyed.cs +++ b/src/libraries/Microsoft.Extensions.DependencyInjection.Abstractions/src/Extensions/ServiceCollectionDescriptorExtensions.Keyed.cs @@ -401,7 +401,7 @@ public static IServiceCollection RemoveAllKeyed(this IServiceCollection collecti for (int i = collection.Count - 1; i >= 0; i--) { ServiceDescriptor? descriptor = collection[i]; - if (descriptor.ServiceType == serviceType && descriptor.ServiceKey == serviceKey) + if (descriptor.ServiceType == serviceType && object.Equals(descriptor.ServiceKey, serviceKey)) { collection.RemoveAt(i); } diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection.Abstractions/src/Extensions/ServiceCollectionDescriptorExtensions.cs b/src/libraries/Microsoft.Extensions.DependencyInjection.Abstractions/src/Extensions/ServiceCollectionDescriptorExtensions.cs index f2435dca6cfcee..2d000f17105832 100644 --- a/src/libraries/Microsoft.Extensions.DependencyInjection.Abstractions/src/Extensions/ServiceCollectionDescriptorExtensions.cs +++ b/src/libraries/Microsoft.Extensions.DependencyInjection.Abstractions/src/Extensions/ServiceCollectionDescriptorExtensions.cs @@ -67,7 +67,7 @@ public static void TryAdd( for (int i = 0; i < count; i++) { if (collection[i].ServiceType == descriptor.ServiceType - && collection[i].ServiceKey == descriptor.ServiceKey) + && object.Equals(collection[i].ServiceKey, descriptor.ServiceKey)) { // Already added return; @@ -474,7 +474,7 @@ public static void TryAddEnumerable( ServiceDescriptor service = services[i]; if (service.ServiceType == descriptor.ServiceType && service.GetImplementationType() == implementationType && - service.ServiceKey == descriptor.ServiceKey) + object.Equals(service.ServiceKey, descriptor.ServiceKey)) { // Already added return; @@ -532,7 +532,7 @@ public static IServiceCollection Replace( int count = collection.Count; for (int i = 0; i < count; i++) { - if (collection[i].ServiceType == descriptor.ServiceType && collection[i].ServiceKey == descriptor.ServiceKey) + if (collection[i].ServiceType == descriptor.ServiceType && object.Equals(collection[i].ServiceKey, descriptor.ServiceKey)) { collection.RemoveAt(i); break; diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection.Abstractions/src/Microsoft.Extensions.DependencyInjection.Abstractions.csproj b/src/libraries/Microsoft.Extensions.DependencyInjection.Abstractions/src/Microsoft.Extensions.DependencyInjection.Abstractions.csproj index 422c5bef17065a..134bae2b58535f 100644 --- a/src/libraries/Microsoft.Extensions.DependencyInjection.Abstractions/src/Microsoft.Extensions.DependencyInjection.Abstractions.csproj +++ b/src/libraries/Microsoft.Extensions.DependencyInjection.Abstractions/src/Microsoft.Extensions.DependencyInjection.Abstractions.csproj @@ -4,6 +4,8 @@ $(NetCoreAppCurrent);$(NetCoreAppPrevious);$(NetCoreAppMinimum);netstandard2.1;netstandard2.0;$(NetFrameworkMinimum) true true + true + 1 Abstractions for dependency injection. Commonly Used Types: diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Tests/ServiceCollectionKeyedServiceExtensionsTest.cs b/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Tests/ServiceCollectionKeyedServiceExtensionsTest.cs index 25b61065ef0ba3..6be7e22cce2c35 100644 --- a/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Tests/ServiceCollectionKeyedServiceExtensionsTest.cs +++ b/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Tests/ServiceCollectionKeyedServiceExtensionsTest.cs @@ -170,6 +170,7 @@ public static TheoryData TryAddImplementationTypeData { collection => collection.TryAddKeyedTransient("key-2"), serviceType, "key-2", implementationType, ServiceLifetime.Transient }, { collection => collection.TryAddKeyedTransient("key-3"), serviceType, "key-3", serviceType, ServiceLifetime.Transient }, { collection => collection.TryAddKeyedTransient(implementationType, "key-4"), implementationType, "key-4", implementationType, ServiceLifetime.Transient }, + { collection => collection.TryAddKeyedTransient(implementationType, 9), implementationType, 9, implementationType, ServiceLifetime.Transient }, { collection => collection.TryAddKeyedScoped(serviceType, "key-1", implementationType), serviceType, "key-1", implementationType, ServiceLifetime.Scoped }, { collection => collection.TryAddKeyedScoped("key-2"), serviceType, "key-2", implementationType, ServiceLifetime.Scoped }, @@ -325,6 +326,40 @@ public void TryAddEnumerable_DoesNotAddDuplicate( Assert.Equal(expectedLifetime, d.Lifetime); } + [Fact] + public void TryAddEnumerable_DoesNotAddDuplicateWhenKeyIsInt() + { + // Arrange + var collection = new ServiceCollection(); + var descriptor1 = ServiceDescriptor.KeyedTransient(1); + collection.TryAddEnumerable(descriptor1); + var descriptor2 = ServiceDescriptor.KeyedTransient(1); + + // Act + collection.TryAddEnumerable(descriptor2); + + // Assert + var d = Assert.Single(collection); + Assert.Same(descriptor1, d); + } + + [Fact] + public void TryAddEnumerable_DoesNotAddDuplicateWhenKeyIsString() + { + // Arrange + var collection = new ServiceCollection(); + var descriptor1 = ServiceDescriptor.KeyedTransient("service1"); + collection.TryAddEnumerable(descriptor1); + var descriptor2 = ServiceDescriptor.KeyedTransient("service1"); + + // Act + collection.TryAddEnumerable(descriptor2); + + // Assert + var d = Assert.Single(collection); + Assert.Same(descriptor1, d); + } + public static TheoryData TryAddEnumerableInvalidImplementationTypeData { get @@ -412,6 +447,24 @@ public void Replace_ReplacesFirstServiceWithMatchingServiceType() Assert.Equal(new[] { descriptor2, descriptor3 }, collection); } + [Fact] + public void Replace_ReplacesFirstServiceWithMatchingServiceTypeWhenKeyIsInt() + { + // Arrange + var collection = new ServiceCollection(); + var descriptor1 = new ServiceDescriptor(typeof(IFakeService), 1, typeof(FakeService), ServiceLifetime.Transient); + var descriptor2 = new ServiceDescriptor(typeof(IFakeService), 1, typeof(FakeService), ServiceLifetime.Transient); + collection.Add(descriptor1); + collection.Add(descriptor2); + var descriptor3 = new ServiceDescriptor(typeof(IFakeService), 1, typeof(FakeService), ServiceLifetime.Singleton); + + // Act + collection.Replace(descriptor3); + + // Assert + Assert.Equal(new[] { descriptor2, descriptor3 }, collection); + } + [Fact] public void RemoveAll_RemovesAllServicesWithMatchingServiceType() { @@ -431,6 +484,44 @@ public void RemoveAll_RemovesAllServicesWithMatchingServiceType() Assert.Equal(new[] { descriptor }, collection); } + private enum ServiceKeyEnum { First, Second } + + [Fact] + public void RemoveAll_RemovesAllMatchingServicesWhenKeyIsEnum() + { + var descriptor = new ServiceDescriptor(typeof(IFakeService), ServiceKeyEnum.First, typeof(FakeService), ServiceLifetime.Transient); + var collection = new ServiceCollection + { + descriptor, + new ServiceDescriptor(typeof(IFakeService), ServiceKeyEnum.Second, typeof(FakeService), ServiceLifetime.Transient), + new ServiceDescriptor(typeof(IFakeService), ServiceKeyEnum.Second, typeof(FakeService), ServiceLifetime.Transient), + }; + + // Act + collection.RemoveAllKeyed(ServiceKeyEnum.Second); + + // Assert + Assert.Equal(new[] { descriptor }, collection); + } + + [Fact] + public void RemoveAll_RemovesAllMatchingServicesWhenKeyIsInt() + { + var descriptor = new ServiceDescriptor(typeof(IFakeService), 1, typeof(FakeService), ServiceLifetime.Transient); + var collection = new ServiceCollection + { + descriptor, + new ServiceDescriptor(typeof(IFakeService), 2, typeof(FakeService), ServiceLifetime.Transient), + new ServiceDescriptor(typeof(IFakeService), 2, typeof(FakeService), ServiceLifetime.Transient), + }; + + // Act + collection.RemoveAllKeyed(2); + + // Assert + Assert.Equal(new[] { descriptor }, collection); + } + public static TheoryData NullServiceKeyData { get