Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Add AVX2 version of adler
  • Loading branch information
brianpopow committed Feb 21, 2022
commit 27fc3b01c6678ed5c96e5c137250028b779e449c
155 changes: 121 additions & 34 deletions src/ImageSharp/Compression/Zlib/Adler32.cs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ internal static class Adler32
#if SUPPORTS_RUNTIME_INTRINSICS
private const int MinBufferSize = 64;

private const int BLOCK_SIZE = 1 << 5;

// The C# compiler emits this as a compile-time constant embedded in the PE file.
private static ReadOnlySpan<byte> Tap1Tap2 => new byte[]
{
Expand Down Expand Up @@ -63,6 +65,11 @@ public static uint Calculate(uint adler, ReadOnlySpan<byte> buffer)
}

#if SUPPORTS_RUNTIME_INTRINSICS
if (Avx2.IsSupported && buffer.Length >= MinBufferSize)
{
return CalculateAvx2(adler, buffer);
}

if (Ssse3.IsSupported && buffer.Length >= MinBufferSize)
{
return CalculateSse(adler, buffer);
Expand All @@ -83,8 +90,6 @@ private static unsafe uint CalculateSse(uint adler, ReadOnlySpan<byte> buffer)
uint s2 = (adler >> 16) & 0xFFFF;

// Process the data in blocks.
const int BLOCK_SIZE = 1 << 5;

uint length = (uint)buffer.Length;
uint blocks = length / BLOCK_SIZE;
length -= blocks * BLOCK_SIZE;
Expand Down Expand Up @@ -164,45 +169,127 @@ private static unsafe uint CalculateSse(uint adler, ReadOnlySpan<byte> buffer)

if (length > 0)
{
if (length >= 16)
{
s2 += s1 += localBufferPtr[0];
s2 += s1 += localBufferPtr[1];
s2 += s1 += localBufferPtr[2];
s2 += s1 += localBufferPtr[3];
s2 += s1 += localBufferPtr[4];
s2 += s1 += localBufferPtr[5];
s2 += s1 += localBufferPtr[6];
s2 += s1 += localBufferPtr[7];
s2 += s1 += localBufferPtr[8];
s2 += s1 += localBufferPtr[9];
s2 += s1 += localBufferPtr[10];
s2 += s1 += localBufferPtr[11];
s2 += s1 += localBufferPtr[12];
s2 += s1 += localBufferPtr[13];
s2 += s1 += localBufferPtr[14];
s2 += s1 += localBufferPtr[15];

localBufferPtr += 16;
length -= 16;
}
HandleLeftOver(localBufferPtr, length, ref s1, ref s2);
}

while (length-- > 0)
{
s2 += s1 += *localBufferPtr++;
}
return s1 | (s2 << 16);
}
}
}

if (s1 >= BASE)
{
s1 -= BASE;
}
/// <summary>
/// Calculates the Adler-32 checksum of <paramref name="buffer"/> with AVX2 instructions,
/// continuing from the running checksum <paramref name="adler"/>.
/// </summary>
/// <remarks>
/// Based on: https://github.com/zlib-ng/zlib-ng/blob/develop/arch/x86/adler32_avx2.c
/// This method does not test <c>Avx2.IsSupported</c> itself; the caller has to do so
/// before dispatching here.
/// </remarks>
/// <param name="adler">The running checksum: s1 in the low 16 bits, s2 in the high 16 bits.</param>
/// <param name="buffer">The bytes to fold into the checksum.</param>
/// <returns>The updated checksum, packed as <c>s1 | (s2 &lt;&lt; 16)</c>.</returns>
[MethodImpl(InliningOptions.HotPath | InliningOptions.ShortMethod)]
public static unsafe uint CalculateAvx2(uint adler, ReadOnlySpan<byte> buffer)
{
    uint s1 = adler & 0xFFFF;
    uint s2 = (adler >> 16) & 0xFFFF;
    uint length = (uint)buffer.Length;

    fixed (byte* bufferPtr = buffer)
    {
        byte* localBufferPtr = bufferPtr;

        Vector256<byte> zero = Vector256<byte>.Zero;

        // All ones: lets MultiplyAddAdjacent act as a plain pairwise add of 16-bit lanes.
        var dot3v = Vector256.Create((short)1);

        // Per-byte weights for the s2 contribution of one block: the first byte of a
        // block is counted 32 times, the last byte once.
        var dot2v = Vector256.Create(32, 31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1);

        // Process n blocks of data. At most NMAX data bytes can be
        // processed before s2 must be reduced modulo BASE.
        var vs1 = Vector256.CreateScalar(s1);
        var vs2 = Vector256.CreateScalar(s2);

        while (length >= BLOCK_SIZE)
        {
            // Number of bytes for this round: at most NMAX, rounded down to whole blocks.
            int k = length < NMAX ? (int)length : (int)NMAX;
            k -= k % BLOCK_SIZE;
            length -= (uint)k;

            // vs10 holds the s1 lanes as they were before the current block;
            // vs3 accumulates those values, one per processed block.
            Vector256<uint> vs10 = vs1;
            Vector256<uint> vs3 = Vector256<uint>.Zero;

            while (k >= BLOCK_SIZE)
            {
                // Load 32 input bytes.
                Vector256<byte> block = Avx.LoadVector256(localBufferPtr);

                // Sum of absolute differences against zero, i.e. the plain byte sums:
                // one sum per group of 8 bytes, stored in the low bits of each 64-bit lane.
                Vector256<ushort> vs1sad = Avx2.SumAbsoluteDifferences(block, zero);

                vs1 = Avx2.Add(vs1, vs1sad.AsUInt32());
                vs3 = Avx2.Add(vs3, vs10);

                // Multiply each byte by its weight and sum adjacent pairs: 32 uint8s to 16 shorts.
                Vector256<short> vshortsum2 = Avx2.MultiplyAddAdjacent(block, dot2v);

                // Sum adjacent pairs again: 16 shorts to 8 int32s.
                Vector256<int> vsum2 = Avx2.MultiplyAddAdjacent(vshortsum2, dot3v);

                vs2 = Avx2.Add(vsum2.AsUInt32(), vs2);
                vs10 = vs1;

                localBufferPtr += BLOCK_SIZE;
                k -= BLOCK_SIZE;
            }

            // Every block adds 32 times the preceding s1 to s2. The multiplication
            // with 32 (a shift by 5) is deferred to outside of the inner loop.
            vs3 = Avx2.ShiftLeftLogical(vs3, 5);
            vs2 = Avx2.Add(vs2, vs3);

            // Collapse the vector accumulators back to scalars.
            // NOTE(review): EvenReduceSum presumably adds only the even-indexed 32-bit
            // lanes, which is where the byte sums above end up - confirm in Numerics.
            s1 = (uint)Numerics.EvenReduceSum(vs1.AsInt32());
            s2 = (uint)Numerics.ReduceSum(vs2.AsInt32());

            s1 %= BASE;
            s2 %= BASE;

            // Restart the vector accumulators from the reduced scalars for the next round.
            vs1 = Vector256.CreateScalar(s1);
            vs2 = Vector256.CreateScalar(s2);
        }

        // Fewer than BLOCK_SIZE bytes remain; fold them in with scalar code.
        if (length > 0)
        {
            HandleLeftOver(localBufferPtr, length, ref s1, ref s2);
        }

        return s1 | (s2 << 16);
    }
}

/// <summary>
/// Folds the trailing bytes that did not fill a whole vector block into the running sums.
/// </summary>
/// <param name="localBufferPtr">Pointer to the first unprocessed byte.</param>
/// <param name="length">The number of bytes readable at <paramref name="localBufferPtr"/>.</param>
/// <param name="s1">The running byte sum; updated in place.</param>
/// <param name="s2">The running sum of <paramref name="s1"/> values; updated in place and reduced modulo BASE.</param>
private static unsafe void HandleLeftOver(byte* localBufferPtr, uint length, ref uint s1, ref uint s2)
{
    // One unrolled pass over 16 bytes when at least that many are left.
    if (length >= 16)
    {
        s1 += localBufferPtr[0];
        s2 += s1;
        s1 += localBufferPtr[1];
        s2 += s1;
        s1 += localBufferPtr[2];
        s2 += s1;
        s1 += localBufferPtr[3];
        s2 += s1;
        s1 += localBufferPtr[4];
        s2 += s1;
        s1 += localBufferPtr[5];
        s2 += s1;
        s1 += localBufferPtr[6];
        s2 += s1;
        s1 += localBufferPtr[7];
        s2 += s1;
        s1 += localBufferPtr[8];
        s2 += s1;
        s1 += localBufferPtr[9];
        s2 += s1;
        s1 += localBufferPtr[10];
        s2 += s1;
        s1 += localBufferPtr[11];
        s2 += s1;
        s1 += localBufferPtr[12];
        s2 += s1;
        s1 += localBufferPtr[13];
        s2 += s1;
        s1 += localBufferPtr[14];
        s2 += s1;
        s1 += localBufferPtr[15];
        s2 += s1;

        localBufferPtr += 16;
        length -= 16;
    }

    // Whatever is left is handled one byte at a time.
    for (uint i = 0; i < length; i++)
    {
        s1 += localBufferPtr[i];
        s2 += s1;
    }

    // s1 is brought back into range with a single subtraction.
    // NOTE(review): this assumes s1 was below BASE on entry and grew by less than BASE here - confirm at the call sites.
    s1 = s1 >= BASE ? s1 - BASE : s1;
    s2 = s2 % BASE;
}
#endif

Expand Down