diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection.Abstractions/src/Extensions/ServiceCollectionDescriptorExtensions.Keyed.cs b/src/libraries/Microsoft.Extensions.DependencyInjection.Abstractions/src/Extensions/ServiceCollectionDescriptorExtensions.Keyed.cs
index aad0b4ff2fe414..436f108e5eb5f2 100644
--- a/src/libraries/Microsoft.Extensions.DependencyInjection.Abstractions/src/Extensions/ServiceCollectionDescriptorExtensions.Keyed.cs
+++ b/src/libraries/Microsoft.Extensions.DependencyInjection.Abstractions/src/Extensions/ServiceCollectionDescriptorExtensions.Keyed.cs
@@ -401,7 +401,7 @@ public static IServiceCollection RemoveAllKeyed(this IServiceCollection collecti
for (int i = collection.Count - 1; i >= 0; i--)
{
ServiceDescriptor? descriptor = collection[i];
- if (descriptor.ServiceType == serviceType && descriptor.ServiceKey == serviceKey)
+ if (descriptor.ServiceType == serviceType && object.Equals(descriptor.ServiceKey, serviceKey))
{
collection.RemoveAt(i);
}
diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection.Abstractions/src/Extensions/ServiceCollectionDescriptorExtensions.cs b/src/libraries/Microsoft.Extensions.DependencyInjection.Abstractions/src/Extensions/ServiceCollectionDescriptorExtensions.cs
index f2435dca6cfcee..2d000f17105832 100644
--- a/src/libraries/Microsoft.Extensions.DependencyInjection.Abstractions/src/Extensions/ServiceCollectionDescriptorExtensions.cs
+++ b/src/libraries/Microsoft.Extensions.DependencyInjection.Abstractions/src/Extensions/ServiceCollectionDescriptorExtensions.cs
@@ -67,7 +67,7 @@ public static void TryAdd(
for (int i = 0; i < count; i++)
{
if (collection[i].ServiceType == descriptor.ServiceType
- && collection[i].ServiceKey == descriptor.ServiceKey)
+ && object.Equals(collection[i].ServiceKey, descriptor.ServiceKey))
{
// Already added
return;
@@ -474,7 +474,7 @@ public static void TryAddEnumerable(
ServiceDescriptor service = services[i];
if (service.ServiceType == descriptor.ServiceType &&
service.GetImplementationType() == implementationType &&
- service.ServiceKey == descriptor.ServiceKey)
+ object.Equals(service.ServiceKey, descriptor.ServiceKey))
{
// Already added
return;
@@ -532,7 +532,7 @@ public static IServiceCollection Replace(
int count = collection.Count;
for (int i = 0; i < count; i++)
{
- if (collection[i].ServiceType == descriptor.ServiceType && collection[i].ServiceKey == descriptor.ServiceKey)
+ if (collection[i].ServiceType == descriptor.ServiceType && object.Equals(collection[i].ServiceKey, descriptor.ServiceKey))
{
collection.RemoveAt(i);
break;
diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection.Abstractions/src/Microsoft.Extensions.DependencyInjection.Abstractions.csproj b/src/libraries/Microsoft.Extensions.DependencyInjection.Abstractions/src/Microsoft.Extensions.DependencyInjection.Abstractions.csproj
index 422c5bef17065a..134bae2b58535f 100644
--- a/src/libraries/Microsoft.Extensions.DependencyInjection.Abstractions/src/Microsoft.Extensions.DependencyInjection.Abstractions.csproj
+++ b/src/libraries/Microsoft.Extensions.DependencyInjection.Abstractions/src/Microsoft.Extensions.DependencyInjection.Abstractions.csproj
@@ -4,6 +4,8 @@
$(NetCoreAppCurrent);$(NetCoreAppPrevious);$(NetCoreAppMinimum);netstandard2.1;netstandard2.0;$(NetFrameworkMinimum)
true
true
+ true
+ 1
Abstractions for dependency injection.
Commonly Used Types:
diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Tests/ServiceCollectionKeyedServiceExtensionsTest.cs b/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Tests/ServiceCollectionKeyedServiceExtensionsTest.cs
index 25b61065ef0ba3..6be7e22cce2c35 100644
--- a/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Tests/ServiceCollectionKeyedServiceExtensionsTest.cs
+++ b/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Tests/ServiceCollectionKeyedServiceExtensionsTest.cs
@@ -170,6 +170,7 @@ public static TheoryData TryAddImplementationTypeData
{ collection => collection.TryAddKeyedTransient("key-2"), serviceType, "key-2", implementationType, ServiceLifetime.Transient },
{ collection => collection.TryAddKeyedTransient("key-3"), serviceType, "key-3", serviceType, ServiceLifetime.Transient },
{ collection => collection.TryAddKeyedTransient(implementationType, "key-4"), implementationType, "key-4", implementationType, ServiceLifetime.Transient },
+ { collection => collection.TryAddKeyedTransient(implementationType, 9), implementationType, 9, implementationType, ServiceLifetime.Transient },
{ collection => collection.TryAddKeyedScoped(serviceType, "key-1", implementationType), serviceType, "key-1", implementationType, ServiceLifetime.Scoped },
{ collection => collection.TryAddKeyedScoped("key-2"), serviceType, "key-2", implementationType, ServiceLifetime.Scoped },
@@ -325,6 +326,40 @@ public void TryAddEnumerable_DoesNotAddDuplicate(
Assert.Equal(expectedLifetime, d.Lifetime);
}
+ [Fact]
+ public void TryAddEnumerable_DoesNotAddDuplicateWhenKeyIsInt()
+ {
+ // Arrange
+ var collection = new ServiceCollection();
+ var descriptor1 = ServiceDescriptor.KeyedTransient(1);
+ collection.TryAddEnumerable(descriptor1);
+ var descriptor2 = ServiceDescriptor.KeyedTransient(1);
+
+ // Act
+ collection.TryAddEnumerable(descriptor2);
+
+ // Assert
+ var d = Assert.Single(collection);
+ Assert.Same(descriptor1, d);
+ }
+
+ [Fact]
+ public void TryAddEnumerable_DoesNotAddDuplicateWhenKeyIsString()
+ {
+ // Arrange
+ var collection = new ServiceCollection();
+ var descriptor1 = ServiceDescriptor.KeyedTransient("service1");
+ collection.TryAddEnumerable(descriptor1);
+ var descriptor2 = ServiceDescriptor.KeyedTransient("service1");
+
+ // Act
+ collection.TryAddEnumerable(descriptor2);
+
+ // Assert
+ var d = Assert.Single(collection);
+ Assert.Same(descriptor1, d);
+ }
+
public static TheoryData TryAddEnumerableInvalidImplementationTypeData
{
get
@@ -412,6 +447,24 @@ public void Replace_ReplacesFirstServiceWithMatchingServiceType()
Assert.Equal(new[] { descriptor2, descriptor3 }, collection);
}
+ [Fact]
+ public void Replace_ReplacesFirstServiceWithMatchingServiceTypeWhenKeyIsInt()
+ {
+ // Arrange
+ var collection = new ServiceCollection();
+ var descriptor1 = new ServiceDescriptor(typeof(IFakeService), 1, typeof(FakeService), ServiceLifetime.Transient);
+ var descriptor2 = new ServiceDescriptor(typeof(IFakeService), 1, typeof(FakeService), ServiceLifetime.Transient);
+ collection.Add(descriptor1);
+ collection.Add(descriptor2);
+ var descriptor3 = new ServiceDescriptor(typeof(IFakeService), 1, typeof(FakeService), ServiceLifetime.Singleton);
+
+ // Act
+ collection.Replace(descriptor3);
+
+ // Assert
+ Assert.Equal(new[] { descriptor2, descriptor3 }, collection);
+ }
+
[Fact]
public void RemoveAll_RemovesAllServicesWithMatchingServiceType()
{
@@ -431,6 +484,44 @@ public void RemoveAll_RemovesAllServicesWithMatchingServiceType()
Assert.Equal(new[] { descriptor }, collection);
}
+ private enum ServiceKeyEnum { First, Second }
+
+ [Fact]
+ public void RemoveAll_RemovesAllMatchingServicesWhenKeyIsEnum()
+ {
+ var descriptor = new ServiceDescriptor(typeof(IFakeService), ServiceKeyEnum.First, typeof(FakeService), ServiceLifetime.Transient);
+ var collection = new ServiceCollection
+ {
+ descriptor,
+ new ServiceDescriptor(typeof(IFakeService), ServiceKeyEnum.Second, typeof(FakeService), ServiceLifetime.Transient),
+ new ServiceDescriptor(typeof(IFakeService), ServiceKeyEnum.Second, typeof(FakeService), ServiceLifetime.Transient),
+ };
+
+ // Act
+ collection.RemoveAllKeyed(ServiceKeyEnum.Second);
+
+ // Assert
+ Assert.Equal(new[] { descriptor }, collection);
+ }
+
+ [Fact]
+ public void RemoveAll_RemovesAllMatchingServicesWhenKeyIsInt()
+ {
+ var descriptor = new ServiceDescriptor(typeof(IFakeService), 1, typeof(FakeService), ServiceLifetime.Transient);
+ var collection = new ServiceCollection
+ {
+ descriptor,
+ new ServiceDescriptor(typeof(IFakeService), 2, typeof(FakeService), ServiceLifetime.Transient),
+ new ServiceDescriptor(typeof(IFakeService), 2, typeof(FakeService), ServiceLifetime.Transient),
+ };
+
+ // Act
+ collection.RemoveAllKeyed(2);
+
+ // Assert
+ Assert.Equal(new[] { descriptor }, collection);
+ }
+
public static TheoryData NullServiceKeyData
{
get