diff --git a/src/libraries/System.Private.CoreLib/src/System/Runtime/InteropServices/CollectionsMarshal.cs b/src/libraries/System.Private.CoreLib/src/System/Runtime/InteropServices/CollectionsMarshal.cs index a64f44683323e2..ccf5a9a2b44b33 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Runtime/InteropServices/CollectionsMarshal.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Runtime/InteropServices/CollectionsMarshal.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.Collections.Generic; +using Internal.Runtime.CompilerServices; namespace System.Runtime.InteropServices { @@ -18,6 +19,25 @@ public static class CollectionsMarshal public static Span AsSpan(List? list) => list is null ? default : new Span(list._items, 0, list._size); + /// + /// Gets a ref to a in the . + /// + /// The dictionary to get the ref to from. + /// The key used for lookup. + /// Items should not be added or removed from the while the ref is in use. + /// Thrown when does not exist in the . + public static ref TValue GetValueRef(Dictionary dictionary, TKey key) where TKey : notnull + { + ref TValue valueRef = ref dictionary.FindValue(key); + + if (Unsafe.IsNullRef(ref valueRef)) + { + ThrowHelper.ThrowKeyNotFoundException(key); + } + + return ref valueRef; + } + /// /// Gets either a ref to a in the or a ref null if it does not exist in the . /// diff --git a/src/libraries/System.Runtime.InteropServices/ref/System.Runtime.InteropServices.cs b/src/libraries/System.Runtime.InteropServices/ref/System.Runtime.InteropServices.cs index 02e739ad731733..12d191daf9c18e 100644 --- a/src/libraries/System.Runtime.InteropServices/ref/System.Runtime.InteropServices.cs +++ b/src/libraries/System.Runtime.InteropServices/ref/System.Runtime.InteropServices.cs @@ -176,6 +176,7 @@ public CoClassAttribute(System.Type coClass) { } public static partial class CollectionsMarshal { public static System.Span AsSpan(System.Collections.Generic.List? list) { throw null; } + public static ref TValue GetValueRef(System.Collections.Generic.Dictionary dictionary, TKey key) where TKey : notnull { throw null; } public static ref TValue GetValueRefOrNullRef(System.Collections.Generic.Dictionary dictionary, TKey key) where TKey : notnull { throw null; } } [System.AttributeUsageAttribute(System.AttributeTargets.Field | System.AttributeTargets.Parameter | System.AttributeTargets.Property | System.AttributeTargets.ReturnValue, Inherited=false)] diff --git a/src/libraries/System.Runtime.InteropServices/tests/System/Runtime/InteropServices/CollectionsMarshalTests.cs b/src/libraries/System.Runtime.InteropServices/tests/System/Runtime/InteropServices/CollectionsMarshalTests.cs index 5c5116b2a72004..eda15955f0adea 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/System/Runtime/InteropServices/CollectionsMarshalTests.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/System/Runtime/InteropServices/CollectionsMarshalTests.cs @@ -144,6 +144,159 @@ public void ListAsSpanLinkBreaksOnResize() } } + [Fact] + public void GetValueRefValueType() + { + var dict = new Dictionary + { + { 1, default }, + { 2, default } + }; + + Assert.Equal(2, dict.Count); + + Assert.Equal(0, dict[1].Value); + Assert.Equal(0, dict[1].Property); + + var itemVal = dict[1]; + itemVal.Value = 1; + itemVal.Property = 2; + + // Does not change values in dictionary + Assert.Equal(0, dict[1].Value); + Assert.Equal(0, dict[1].Property); + + CollectionsMarshal.GetValueRef(dict, 1).Value = 3; + CollectionsMarshal.GetValueRef(dict, 1).Property = 4; + + Assert.Equal(3, dict[1].Value); + Assert.Equal(4, dict[1].Property); + + ref var itemRef = ref CollectionsMarshal.GetValueRef(dict, 2); + + Assert.Equal(0, itemRef.Value); + Assert.Equal(0, itemRef.Property); + + itemRef.Value = 5; + itemRef.Property = 6; + + Assert.Equal(5, itemRef.Value); + Assert.Equal(6, itemRef.Property); + Assert.Equal(dict[2].Value, itemRef.Value); + Assert.Equal(dict[2].Property, itemRef.Property); + + itemRef = new() { Value = 7, Property = 8 }; + + Assert.Equal(7, itemRef.Value); + Assert.Equal(8, itemRef.Property); + Assert.Equal(dict[2].Value, itemRef.Value); + Assert.Equal(dict[2].Property, itemRef.Property); + + // Check for exception + + Assert.Throws(() => CollectionsMarshal.GetValueRef(dict, 3)); + + Assert.Equal(2, dict.Count); + } + + [Fact] + public void GetValueRefClass() + { + var dict = new Dictionary + { + { 1, new() }, + { 2, new() } + }; + + Assert.Equal(2, dict.Count); + + Assert.Equal(0, dict[1].Value); + Assert.Equal(0, dict[1].Property); + + var itemVal = dict[1]; + itemVal.Value = 1; + itemVal.Property = 2; + + // Does change values in dictionary + Assert.Equal(1, dict[1].Value); + Assert.Equal(2, dict[1].Property); + + CollectionsMarshal.GetValueRef(dict, 1).Value = 3; + CollectionsMarshal.GetValueRef(dict, 1).Property = 4; + + Assert.Equal(3, dict[1].Value); + Assert.Equal(4, dict[1].Property); + + ref var itemRef = ref CollectionsMarshal.GetValueRef(dict, 2); + + Assert.Equal(0, itemRef.Value); + Assert.Equal(0, itemRef.Property); + + itemRef.Value = 5; + itemRef.Property = 6; + + Assert.Equal(5, itemRef.Value); + Assert.Equal(6, itemRef.Property); + Assert.Equal(dict[2].Value, itemRef.Value); + Assert.Equal(dict[2].Property, itemRef.Property); + + itemRef = new() { Value = 7, Property = 8 }; + + Assert.Equal(7, itemRef.Value); + Assert.Equal(8, itemRef.Property); + Assert.Equal(dict[2].Value, itemRef.Value); + Assert.Equal(dict[2].Property, itemRef.Property); + + // Check for exception + + Assert.Throws(() => CollectionsMarshal.GetValueRef(dict, 3)); + + Assert.Equal(2, dict.Count); + } + + [Fact] + public void GetValueRefLinkBreaksOnResize() + { + var dict = new Dictionary + { + { 1, new() } + }; + + Assert.Equal(1, dict.Count); + + ref var itemRef = ref CollectionsMarshal.GetValueRef(dict, 1); + + Assert.Equal(0, itemRef.Value); + Assert.Equal(0, itemRef.Property); + + itemRef.Value = 1; + itemRef.Property = 2; + + Assert.Equal(1, itemRef.Value); + Assert.Equal(2, itemRef.Property); + Assert.Equal(dict[1].Value, itemRef.Value); + Assert.Equal(dict[1].Property, itemRef.Property); + + // Resize + dict.EnsureCapacity(100); + for (int i = 2; i <= 50; i++) + { + dict.Add(i, new()); + } + + itemRef.Value = 3; + itemRef.Property = 4; + + Assert.Equal(3, itemRef.Value); + Assert.Equal(4, itemRef.Property); + + // Check connection broken + Assert.NotEqual(dict[1].Value, itemRef.Value); + Assert.NotEqual(dict[1].Property, itemRef.Property); + + Assert.Equal(50, dict.Count); + } + [Fact] public void GetValueRefOrNullRefValueType() {