diff --git a/src/libraries/System.Private.CoreLib/src/System/SpanHelpers.T.cs b/src/libraries/System.Private.CoreLib/src/System/SpanHelpers.T.cs index 877c21311876d5..2d0d68b8b55679 100644 --- a/src/libraries/System.Private.CoreLib/src/System/SpanHelpers.T.cs +++ b/src/libraries/System.Private.CoreLib/src/System/SpanHelpers.T.cs @@ -6,6 +6,7 @@ using System.Numerics; using System.Runtime.CompilerServices; using System.Runtime.Intrinsics; +using System.Runtime.Intrinsics.Arm; namespace System { @@ -1543,7 +1544,7 @@ private static int IndexOfValueType(ref TValue searchSpace, TV do { equals = TNegator.NegateIfNeeded(Vector128.Equals(values, Vector128.LoadUnsafe(ref currentSearchSpace))); - if (equals == Vector128.Zero) + if (!VectorContainsMatch(equals)) { currentSearchSpace = ref Unsafe.Add(ref currentSearchSpace, Vector128.Count); continue; @@ -1557,7 +1558,7 @@ private static int IndexOfValueType(ref TValue searchSpace, TV if ((uint)length % Vector128.Count != 0) { equals = TNegator.NegateIfNeeded(Vector128.Equals(values, Vector128.LoadUnsafe(ref oneVectorAwayFromEnd))); - if (equals != Vector128.Zero) + if (VectorContainsMatch(equals)) { return ComputeFirstIndex(ref searchSpace, ref oneVectorAwayFromEnd, equals); } @@ -1691,7 +1692,7 @@ private static int IndexOfAnyValueType(ref TValue searchSpace, { current = Vector128.LoadUnsafe(ref currentSearchSpace); equals = TNegator.NegateIfNeeded(Vector128.Equals(values0, current) | Vector128.Equals(values1, current)); - if (equals == Vector128.Zero) + if (!VectorContainsMatch(equals)) { currentSearchSpace = ref Unsafe.Add(ref currentSearchSpace, Vector128.Count); continue; @@ -1706,7 +1707,7 @@ private static int IndexOfAnyValueType(ref TValue searchSpace, { current = Vector128.LoadUnsafe(ref oneVectorAwayFromEnd); equals = TNegator.NegateIfNeeded(Vector128.Equals(values0, current) | Vector128.Equals(values1, current)); - if (equals != Vector128.Zero) + if (VectorContainsMatch(equals)) { return ComputeFirstIndex(ref 
searchSpace, ref oneVectorAwayFromEnd, equals); } @@ -1835,7 +1836,7 @@ private static int IndexOfAnyValueType(ref TValue searchSpace, { current = Vector128.LoadUnsafe(ref currentSearchSpace); equals = TNegator.NegateIfNeeded(Vector128.Equals(values0, current) | Vector128.Equals(values1, current) | Vector128.Equals(values2, current)); - if (equals == Vector128.Zero) + if (!VectorContainsMatch(equals)) { currentSearchSpace = ref Unsafe.Add(ref currentSearchSpace, Vector128.Count); continue; @@ -1850,7 +1851,7 @@ private static int IndexOfAnyValueType(ref TValue searchSpace, { current = Vector128.LoadUnsafe(ref oneVectorAwayFromEnd); equals = TNegator.NegateIfNeeded(Vector128.Equals(values0, current) | Vector128.Equals(values1, current) | Vector128.Equals(values2, current)); - if (equals != Vector128.Zero) + if (VectorContainsMatch(equals)) { return ComputeFirstIndex(ref searchSpace, ref oneVectorAwayFromEnd, equals); } @@ -1954,7 +1955,7 @@ private static int IndexOfAnyValueType(ref TValue searchSpace, current = Vector128.LoadUnsafe(ref currentSearchSpace); equals = TNegator.NegateIfNeeded(Vector128.Equals(values0, current) | Vector128.Equals(values1, current) | Vector128.Equals(values2, current) | Vector128.Equals(values3, current)); - if (equals == Vector128.Zero) + if (!VectorContainsMatch(equals)) { currentSearchSpace = ref Unsafe.Add(ref currentSearchSpace, Vector128.Count); continue; @@ -1970,7 +1971,7 @@ private static int IndexOfAnyValueType(ref TValue searchSpace, current = Vector128.LoadUnsafe(ref oneVectorAwayFromEnd); equals = TNegator.NegateIfNeeded(Vector128.Equals(values0, current) | Vector128.Equals(values1, current) | Vector128.Equals(values2, current) | Vector128.Equals(values3, current)); - if (equals != Vector128.Zero) + if (VectorContainsMatch(equals)) { return ComputeFirstIndex(ref searchSpace, ref oneVectorAwayFromEnd, equals); } @@ -2067,7 +2068,7 @@ internal static int IndexOfAnyValueType(ref T searchSpace, T value0, T value1 current = 
Vector128.LoadUnsafe(ref currentSearchSpace); equals = Vector128.Equals(values0, current) | Vector128.Equals(values1, current) | Vector128.Equals(values2, current) | Vector128.Equals(values3, current) | Vector128.Equals(values4, current); - if (equals == Vector128.Zero) + if (!VectorContainsMatch(equals)) { currentSearchSpace = ref Unsafe.Add(ref currentSearchSpace, Vector128.Count); continue; @@ -2083,7 +2084,7 @@ internal static int IndexOfAnyValueType(ref T searchSpace, T value0, T value1 current = Vector128.LoadUnsafe(ref oneVectorAwayFromEnd); equals = Vector128.Equals(values0, current) | Vector128.Equals(values1, current) | Vector128.Equals(values2, current) | Vector128.Equals(values3, current) | Vector128.Equals(values4, current); - if (equals != Vector128.Zero) + if (VectorContainsMatch(equals)) { return ComputeFirstIndex(ref searchSpace, ref oneVectorAwayFromEnd, equals); }
@@ -2614,11 +2615,41 @@ private static int LastIndexOfAnyValueType(ref TValue searchSp
             return -1;
         }
 
+        [MethodImpl(MethodImplOptions.AggressiveInlining)]
+        private static bool VectorContainsMatch<TValue>(Vector128<TValue> inputVector) where TValue : struct, INumber<TValue>
+        {
+            // Returns true if any element matched: inputVector is a comparison mask in which a matching element has all of its bits set.
+            if (AdvSimd.Arm64.IsSupported)
+            {
+                // MaxPairwise keeps the larger byte of each adjacent pair of bytes. Thus, if any of the 16 bytes of inputVector
+                // is set, representing a matched element, it is reflected in the bottom half (8 bytes) of the result.
+                // A match exists exactly when those bottom 8 bytes, read as a single 64-bit scalar, are non-zero.
+                return AdvSimd.Arm64.MaxPairwise(inputVector.AsByte(), inputVector.AsByte()).AsUInt64().ToScalar() != 0;
+            }
+            else
+            {
+                return (inputVector != Vector128<TValue>.Zero);
+            }
+        }
+
         [MethodImpl(MethodImplOptions.AggressiveInlining)]
         private static int ComputeFirstIndex<T>(ref T searchSpace, ref T current, Vector128<T> equals) where T : struct
         {
-            uint notEqualsElements = equals.ExtractMostSignificantBits();
-            int index = BitOperations.TrailingZeroCount(notEqualsElements);
+            int index;
+            if (AdvSimd.Arm64.IsSupported && Unsafe.SizeOf<T>() > 1)
+            {
+                // MaxPairwise keeps the larger byte of each adjacent pair, so the bottom 64 bits of the result hold
+                // every element at half its original size (4 bits per input byte). One-byte elements are excluded
+                // because a pair of them would collapse into a single byte. The trailing zero count >> 2 is the byte
+                // offset of the first matching element, and dividing it by the element size gives its index.
+                index = (BitOperations.TrailingZeroCount(
+                    AdvSimd.Arm64.MaxPairwise(equals.AsByte(), equals.AsByte()).AsUInt64().ToScalar()) >> 2) / Unsafe.SizeOf<T>();
+            }
+            else
+            {
+                uint notEqualsElements = equals.ExtractMostSignificantBits();
+                index = BitOperations.TrailingZeroCount(notEqualsElements);
+            }
             return index + (int)(Unsafe.ByteOffset(ref searchSpace, ref current) / Unsafe.SizeOf<T>());
         }