diff --git a/src/libraries/Common/src/Interop/Windows/SspiCli/ISSPIInterface.cs b/src/libraries/Common/src/Interop/Windows/SspiCli/ISSPIInterface.cs index 2b50412e8fc6c9..20cd29588f9a10 100644 --- a/src/libraries/Common/src/Interop/Windows/SspiCli/ISSPIInterface.cs +++ b/src/libraries/Common/src/Interop/Windows/SspiCli/ISSPIInterface.cs @@ -15,8 +15,8 @@ internal interface ISSPIInterface int AcquireCredentialsHandle(string moduleName, Interop.SspiCli.CredentialUse usage, ref SafeSspiAuthDataHandle authdata, out SafeFreeCredentials outCredential); int AcquireCredentialsHandle(string moduleName, Interop.SspiCli.CredentialUse usage, ref Interop.SspiCli.SCHANNEL_CRED authdata, out SafeFreeCredentials outCredential); int AcquireDefaultCredential(string moduleName, Interop.SspiCli.CredentialUse usage, out SafeFreeCredentials outCredential); - int AcceptSecurityContext(SafeFreeCredentials credential, ref SafeDeleteSslContext context, ReadOnlySpan inputBuffers, Interop.SspiCli.ContextFlags inFlags, Interop.SspiCli.Endianness endianness, ref SecurityBuffer outputBuffer, ref Interop.SspiCli.ContextFlags outFlags); - int InitializeSecurityContext(ref SafeFreeCredentials credential, ref SafeDeleteSslContext context, string targetName, Interop.SspiCli.ContextFlags inFlags, Interop.SspiCli.Endianness endianness, ReadOnlySpan inputBuffers, ref SecurityBuffer outputBuffer, ref Interop.SspiCli.ContextFlags outFlags); + int AcceptSecurityContext(SafeFreeCredentials credential, ref SafeDeleteSslContext context, InputSecurityBuffers inputBuffers, Interop.SspiCli.ContextFlags inFlags, Interop.SspiCli.Endianness endianness, ref SecurityBuffer outputBuffer, ref Interop.SspiCli.ContextFlags outFlags); + int InitializeSecurityContext(ref SafeFreeCredentials credential, ref SafeDeleteSslContext context, string targetName, Interop.SspiCli.ContextFlags inFlags, Interop.SspiCli.Endianness endianness, InputSecurityBuffers inputBuffers, ref SecurityBuffer outputBuffer, ref Interop.SspiCli.ContextFlags outFlags); int EncryptMessage(SafeDeleteContext context, ref Interop.SspiCli.SecBufferDesc inputOutput, uint sequenceNumber); int DecryptMessage(SafeDeleteContext context, ref Interop.SspiCli.SecBufferDesc inputOutput, uint sequenceNumber); int MakeSignature(SafeDeleteContext context, ref Interop.SspiCli.SecBufferDesc inputOutput, uint sequenceNumber); diff --git a/src/libraries/Common/src/Interop/Windows/SspiCli/SSPIAuthType.cs b/src/libraries/Common/src/Interop/Windows/SspiCli/SSPIAuthType.cs index de1b873ef43c86..10e9e6d0321dde 100644 --- a/src/libraries/Common/src/Interop/Windows/SspiCli/SSPIAuthType.cs +++ b/src/libraries/Common/src/Interop/Windows/SspiCli/SSPIAuthType.cs @@ -45,12 +45,12 @@ public int AcquireCredentialsHandle(string moduleName, Interop.SspiCli.Credentia return SafeFreeCredentials.AcquireCredentialsHandle(moduleName, usage, ref authdata, out outCredential); } - public int AcceptSecurityContext(SafeFreeCredentials credential, ref SafeDeleteSslContext context, ReadOnlySpan inputBuffers, Interop.SspiCli.ContextFlags inFlags, Interop.SspiCli.Endianness endianness, ref SecurityBuffer outputBuffer, ref Interop.SspiCli.ContextFlags outFlags) + public int AcceptSecurityContext(SafeFreeCredentials credential, ref SafeDeleteSslContext context, InputSecurityBuffers inputBuffers, Interop.SspiCli.ContextFlags inFlags, Interop.SspiCli.Endianness endianness, ref SecurityBuffer outputBuffer, ref Interop.SspiCli.ContextFlags outFlags) { return SafeDeleteContext.AcceptSecurityContext(ref credential, ref context, inFlags, endianness, inputBuffers, ref outputBuffer, ref outFlags); } - public int InitializeSecurityContext(ref SafeFreeCredentials credential, ref SafeDeleteSslContext context, string targetName, Interop.SspiCli.ContextFlags inFlags, Interop.SspiCli.Endianness endianness, ReadOnlySpan inputBuffers, ref SecurityBuffer outputBuffer, ref Interop.SspiCli.ContextFlags outFlags) + public int InitializeSecurityContext(ref SafeFreeCredentials credential, ref SafeDeleteSslContext context, string targetName, Interop.SspiCli.ContextFlags inFlags, Interop.SspiCli.Endianness endianness, InputSecurityBuffers inputBuffers, ref SecurityBuffer outputBuffer, ref Interop.SspiCli.ContextFlags outFlags) { return SafeDeleteContext.InitializeSecurityContext(ref credential, ref context, targetName, inFlags, endianness, inputBuffers, ref outputBuffer, ref outFlags); } diff --git a/src/libraries/Common/src/Interop/Windows/SspiCli/SSPISecureChannelType.cs b/src/libraries/Common/src/Interop/Windows/SspiCli/SSPISecureChannelType.cs index 13e32d4d5e4fcc..e3832d8c3cdba5 100644 --- a/src/libraries/Common/src/Interop/Windows/SspiCli/SSPISecureChannelType.cs +++ b/src/libraries/Common/src/Interop/Windows/SspiCli/SSPISecureChannelType.cs @@ -45,12 +45,12 @@ public int AcquireCredentialsHandle(string moduleName, Interop.SspiCli.Credentia return SafeFreeCredentials.AcquireCredentialsHandle(moduleName, usage, ref authdata, out outCredential); } - public int AcceptSecurityContext(SafeFreeCredentials credential, ref SafeDeleteSslContext context, ReadOnlySpan inputBuffers, Interop.SspiCli.ContextFlags inFlags, Interop.SspiCli.Endianness endianness, ref SecurityBuffer outputBuffer, ref Interop.SspiCli.ContextFlags outFlags) + public int AcceptSecurityContext(SafeFreeCredentials credential, ref SafeDeleteSslContext context, InputSecurityBuffers inputBuffers, Interop.SspiCli.ContextFlags inFlags, Interop.SspiCli.Endianness endianness, ref SecurityBuffer outputBuffer, ref Interop.SspiCli.ContextFlags outFlags) { return SafeDeleteContext.AcceptSecurityContext(ref credential, ref context, inFlags, endianness, inputBuffers, ref outputBuffer, ref outFlags); } - public int InitializeSecurityContext(ref SafeFreeCredentials credential, ref SafeDeleteSslContext context, string targetName, Interop.SspiCli.ContextFlags inFlags, Interop.SspiCli.Endianness endianness, ReadOnlySpan inputBuffers, ref SecurityBuffer outputBuffer, ref Interop.SspiCli.ContextFlags outFlags) + public int InitializeSecurityContext(ref SafeFreeCredentials credential, ref SafeDeleteSslContext context, string targetName, Interop.SspiCli.ContextFlags inFlags, Interop.SspiCli.Endianness endianness, InputSecurityBuffers inputBuffers, ref SecurityBuffer outputBuffer, ref Interop.SspiCli.ContextFlags outFlags) { return SafeDeleteContext.InitializeSecurityContext(ref credential, ref context, targetName, inFlags, endianness, inputBuffers, ref outputBuffer, ref outFlags); } diff --git a/src/libraries/Common/src/Interop/Windows/SspiCli/SSPIWrapper.cs b/src/libraries/Common/src/Interop/Windows/SspiCli/SSPIWrapper.cs index 30d705123c3ec0..79e004c6476fa2 100644 --- a/src/libraries/Common/src/Interop/Windows/SspiCli/SSPIWrapper.cs +++ b/src/libraries/Common/src/Interop/Windows/SspiCli/SSPIWrapper.cs @@ -140,24 +140,24 @@ public static SafeFreeCredentials AcquireCredentialsHandle(ISSPIInterface secMod return outCredential; } - internal static int InitializeSecurityContext(ISSPIInterface secModule, ref SafeFreeCredentials credential, ref SafeDeleteSslContext context, string targetName, Interop.SspiCli.ContextFlags inFlags, Interop.SspiCli.Endianness datarep, ReadOnlySpan inputBuffers, ref SecurityBuffer outputBuffer, ref Interop.SspiCli.ContextFlags outFlags) + internal static int InitializeSecurityContext(ISSPIInterface secModule, ref SafeFreeCredentials credential, ref SafeDeleteSslContext context, string targetName, Interop.SspiCli.ContextFlags inFlags, Interop.SspiCli.Endianness datarep, InputSecurityBuffers inputBuffers, ref SecurityBuffer outputBuffer, ref Interop.SspiCli.ContextFlags outFlags) { if (NetEventSource.IsEnabled) NetEventSource.Log.InitializeSecurityContext(credential, context, targetName, inFlags); int errorCode = secModule.InitializeSecurityContext(ref credential, ref context, targetName, inFlags, datarep, inputBuffers, ref outputBuffer, ref outFlags); - if (NetEventSource.IsEnabled) NetEventSource.Log.SecurityContextInputBuffers(nameof(InitializeSecurityContext), inputBuffers.Length, outputBuffer.size, (Interop.SECURITY_STATUS)errorCode); + if (NetEventSource.IsEnabled) NetEventSource.Log.SecurityContextInputBuffers(nameof(InitializeSecurityContext), inputBuffers.Count, outputBuffer.size, (Interop.SECURITY_STATUS)errorCode); return errorCode; } - internal static int AcceptSecurityContext(ISSPIInterface secModule, SafeFreeCredentials credential, ref SafeDeleteSslContext context, Interop.SspiCli.ContextFlags inFlags, Interop.SspiCli.Endianness datarep, ReadOnlySpan inputBuffers, ref SecurityBuffer outputBuffer, ref Interop.SspiCli.ContextFlags outFlags) + internal static int AcceptSecurityContext(ISSPIInterface secModule, SafeFreeCredentials credential, ref SafeDeleteSslContext context, Interop.SspiCli.ContextFlags inFlags, Interop.SspiCli.Endianness datarep, InputSecurityBuffers inputBuffers, ref SecurityBuffer outputBuffer, ref Interop.SspiCli.ContextFlags outFlags) { if (NetEventSource.IsEnabled) NetEventSource.Log.AcceptSecurityContext(credential, context, inFlags); int errorCode = secModule.AcceptSecurityContext(credential, ref context, inputBuffers, inFlags, datarep, ref outputBuffer, ref outFlags); - if (NetEventSource.IsEnabled) NetEventSource.Log.SecurityContextInputBuffers(nameof(AcceptSecurityContext), inputBuffers.Length, outputBuffer.size, (Interop.SECURITY_STATUS)errorCode); + if (NetEventSource.IsEnabled) NetEventSource.Log.SecurityContextInputBuffers(nameof(AcceptSecurityContext), inputBuffers.Count, outputBuffer.size, (Interop.SECURITY_STATUS)errorCode); return errorCode; } diff --git a/src/libraries/Common/src/Interop/Windows/SspiCli/SecuritySafeHandles.cs b/src/libraries/Common/src/Interop/Windows/SspiCli/SecuritySafeHandles.cs index e5b024948ee541..02d939392b030e 100644 --- a/src/libraries/Common/src/Interop/Windows/SspiCli/SecuritySafeHandles.cs +++ b/src/libraries/Common/src/Interop/Windows/SspiCli/SecuritySafeHandles.cs @@ -396,7 +396,7 @@ internal static unsafe int InitializeSecurityContext( string targetName, Interop.SspiCli.ContextFlags inFlags, Interop.SspiCli.Endianness endianness, - ReadOnlySpan inSecBuffers, + InputSecurityBuffers inSecBuffers, ref SecurityBuffer outSecBuffer, ref Interop.SspiCli.ContextFlags outFlags) { @@ -413,7 +413,8 @@ internal static unsafe int InitializeSecurityContext( throw new ArgumentNullException(nameof(inCredentials)); } - Interop.SspiCli.SecBufferDesc inSecurityBufferDescriptor = new Interop.SspiCli.SecBufferDesc(inSecBuffers.Length); + Debug.Assert(inSecBuffers.Count <= 3); + Interop.SspiCli.SecBufferDesc inSecurityBufferDescriptor = new Interop.SspiCli.SecBufferDesc(inSecBuffers.Count); Interop.SspiCli.SecBufferDesc outSecurityBufferDescriptor = new Interop.SspiCli.SecBufferDesc(1); // Actually, this is returned in outFlags. @@ -431,34 +432,41 @@ internal static unsafe int InitializeSecurityContext( SafeFreeContextBuffer outFreeContextBuffer = null; try { - Span inUnmanagedBuffer = stackalloc Interop.SspiCli.SecBuffer[inSecurityBufferDescriptor.cBuffers]; - inUnmanagedBuffer.Clear(); + Span inUnmanagedBuffer = stackalloc Interop.SspiCli.SecBuffer[3]; fixed (void* inUnmanagedBufferPtr = inUnmanagedBuffer) - fixed (void* pinnedToken0 = inSecBuffers.Length > 0 ? inSecBuffers[0].token : null) - fixed (void* pinnedToken1 = inSecBuffers.Length > 1 ? inSecBuffers[1].token : null) - fixed (void* pinnedToken2 = inSecBuffers.Length > 2 ? inSecBuffers[2].token : null) // pin all buffers, even if null or not used, to avoid needing to allocate GCHandles + fixed (void* pinnedToken0 = inSecBuffers._item0.Token) + fixed (void* pinnedToken1 = inSecBuffers._item1.Token) + fixed (void* pinnedToken2 = inSecBuffers._item2.Token) { - Debug.Assert(inSecBuffers.Length <= 3); - // Fix Descriptor pointer that points to unmanaged SecurityBuffers. inSecurityBufferDescriptor.pBuffers = inUnmanagedBufferPtr; - for (int index = 0; index < inSecurityBufferDescriptor.cBuffers; ++index) + // Updated pvBuffer with pinned address. UnmanagedToken takes precedence. + if (inSecBuffers.Count > 2) { - ref readonly SecurityBuffer securityBuffer = ref inSecBuffers[index]; + inUnmanagedBuffer[2].BufferType = inSecBuffers._item2.Type; + inUnmanagedBuffer[2].cbBuffer = inSecBuffers._item2.Token.Length; + inUnmanagedBuffer[2].pvBuffer = inSecBuffers._item2.UnmanagedToken != null ? + (IntPtr)inSecBuffers._item2.UnmanagedToken.DangerousGetHandle() : + (IntPtr)pinnedToken2; + } - // Copy the SecurityBuffer content into unmanaged place holder. - inUnmanagedBuffer[index].cbBuffer = securityBuffer.size; - inUnmanagedBuffer[index].BufferType = securityBuffer.type; - - // Use the unmanaged token if it's not null; otherwise use the managed buffer. - inUnmanagedBuffer[index].pvBuffer = - securityBuffer.unmanagedToken != null ? securityBuffer.unmanagedToken.DangerousGetHandle() : - securityBuffer.token == null || securityBuffer.token.Length == 0 ? IntPtr.Zero : - Marshal.UnsafeAddrOfPinnedArrayElement(securityBuffer.token, securityBuffer.offset); -#if TRACE_VERBOSE - if (NetEventSource.IsEnabled) NetEventSource.Info(null, $"SecBuffer: cbBuffer:{securityBuffer.size} BufferType:{securityBuffer.type}"); -#endif + if (inSecBuffers.Count > 1) + { + inUnmanagedBuffer[1].BufferType = inSecBuffers._item1.Type; + inUnmanagedBuffer[1].cbBuffer = inSecBuffers._item1.Token.Length; + inUnmanagedBuffer[1].pvBuffer = inSecBuffers._item1.UnmanagedToken != null ? + (IntPtr)inSecBuffers._item1.UnmanagedToken.DangerousGetHandle() : + (IntPtr)pinnedToken1; + } + + if (inSecBuffers.Count > 0) + { + inUnmanagedBuffer[0].BufferType = inSecBuffers._item0.Type; + inUnmanagedBuffer[0].cbBuffer = inSecBuffers._item0.Token.Length; + inUnmanagedBuffer[0].pvBuffer = inSecBuffers._item0.UnmanagedToken != null ? + (IntPtr)inSecBuffers._item0.UnmanagedToken.DangerousGetHandle() : + (IntPtr)pinnedToken0; } fixed (byte* pinnedOutBytes = outSecBuffer.token) @@ -626,7 +634,7 @@ internal static unsafe int AcceptSecurityContext( ref SafeDeleteSslContext refContext, Interop.SspiCli.ContextFlags inFlags, Interop.SspiCli.Endianness endianness, - ReadOnlySpan inSecBuffers, + InputSecurityBuffers inSecBuffers, ref SecurityBuffer outSecBuffer, ref Interop.SspiCli.ContextFlags outFlags) { @@ -643,7 +651,8 @@ internal static unsafe int AcceptSecurityContext( throw new ArgumentNullException(nameof(inCredentials)); } - Interop.SspiCli.SecBufferDesc inSecurityBufferDescriptor = new Interop.SspiCli.SecBufferDesc(inSecBuffers.Length); + Debug.Assert(inSecBuffers.Count <= 3); + Interop.SspiCli.SecBufferDesc inSecurityBufferDescriptor = new Interop.SspiCli.SecBufferDesc(inSecBuffers.Count); Interop.SspiCli.SecBufferDesc outSecurityBufferDescriptor = new Interop.SspiCli.SecBufferDesc(count: 2); // Actually, this is returned in outFlags. @@ -663,35 +672,42 @@ internal static unsafe int AcceptSecurityContext( outUnmanagedBuffer[1].pvBuffer = IntPtr.Zero; try { - Span inUnmanagedBuffer = stackalloc Interop.SspiCli.SecBuffer[inSecurityBufferDescriptor.cBuffers]; - inUnmanagedBuffer.Clear(); + // Allocate always maximum to allow better code optimization. + Span inUnmanagedBuffer = stackalloc Interop.SspiCli.SecBuffer[3]; fixed (void* inUnmanagedBufferPtr = inUnmanagedBuffer) fixed (void* outUnmanagedBufferPtr = outUnmanagedBuffer) - fixed (void* pinnedToken0 = inSecBuffers.Length > 0 ? inSecBuffers[0].token : null) - fixed (void* pinnedToken1 = inSecBuffers.Length > 1 ? inSecBuffers[1].token : null) - fixed (void* pinnedToken2 = inSecBuffers.Length > 2 ? inSecBuffers[2].token : null) // pin all buffers, even if null or not used, to avoid needing to allocate GCHandles + fixed (void* pinnedToken0 = inSecBuffers._item0.Token) + fixed (void* pinnedToken1 = inSecBuffers._item1.Token) + fixed (void* pinnedToken2 = inSecBuffers._item2.Token) { - Debug.Assert(inSecBuffers.Length <= 3); - - // Fix Descriptor pointer that points to unmanaged SecurityBuffers. inSecurityBufferDescriptor.pBuffers = inUnmanagedBufferPtr; - for (int index = 0; index < inSecurityBufferDescriptor.cBuffers; ++index) + // Updated pvBuffer with pinned address. UnmanagedToken takes precedence. + if (inSecBuffers.Count > 2) { - ref readonly SecurityBuffer securityBuffer = ref inSecBuffers[index]; + inUnmanagedBuffer[2].BufferType = inSecBuffers._item2.Type; + inUnmanagedBuffer[2].cbBuffer = inSecBuffers._item2.Token.Length; + inUnmanagedBuffer[2].pvBuffer = inSecBuffers._item2.UnmanagedToken != null ? + (IntPtr)inSecBuffers._item2.UnmanagedToken.DangerousGetHandle() : + (IntPtr)pinnedToken2; + } - // Copy the SecurityBuffer content into unmanaged place holder. - inUnmanagedBuffer[index].cbBuffer = securityBuffer.size; - inUnmanagedBuffer[index].BufferType = securityBuffer.type; - - // Use the unmanaged token if it's not null; otherwise use the managed buffer. - inUnmanagedBuffer[index].pvBuffer = - securityBuffer.unmanagedToken != null ? securityBuffer.unmanagedToken.DangerousGetHandle() : - securityBuffer.token == null || securityBuffer.token.Length == 0 ? IntPtr.Zero : - Marshal.UnsafeAddrOfPinnedArrayElement(securityBuffer.token, securityBuffer.offset); -#if TRACE_VERBOSE - if (NetEventSource.IsEnabled) NetEventSource.Info(null, $"SecBuffer: cbBuffer:{securityBuffer.size} BufferType:{securityBuffer.type}"); -#endif + if (inSecBuffers.Count > 1) + { + inUnmanagedBuffer[1].BufferType = inSecBuffers._item1.Type; + inUnmanagedBuffer[1].cbBuffer = inSecBuffers._item1.Token.Length; + inUnmanagedBuffer[1].pvBuffer = inSecBuffers._item1.UnmanagedToken != null ? + (IntPtr)inSecBuffers._item1.UnmanagedToken.DangerousGetHandle() : + (IntPtr)pinnedToken1; + } + + if (inSecBuffers.Count > 0) + { + inUnmanagedBuffer[0].BufferType = inSecBuffers._item0.Type; + inUnmanagedBuffer[0].cbBuffer = inSecBuffers._item0.Token.Length; + inUnmanagedBuffer[0].pvBuffer = inSecBuffers._item0.UnmanagedToken != null ? + (IntPtr)inSecBuffers._item0.UnmanagedToken.DangerousGetHandle() : + (IntPtr)pinnedToken0; } fixed (byte* pinnedOutBytes = outSecBuffer.token) diff --git a/src/libraries/Common/src/System/Net/Security/NegotiateStreamPal.Windows.cs b/src/libraries/Common/src/System/Net/Security/NegotiateStreamPal.Windows.cs index f2f868b2da35cf..96690d1a3bc601 100644 --- a/src/libraries/Common/src/System/Net/Security/NegotiateStreamPal.Windows.cs +++ b/src/libraries/Common/src/System/Net/Security/NegotiateStreamPal.Windows.cs @@ -80,31 +80,17 @@ internal static SecurityStatusPal InitializeSecurityContext( ref byte[] resultBlob, ref ContextFlagsPal contextFlags) { -#if NETSTANDARD2_0 - Span inSecurityBufferSpan = new SecurityBuffer[2]; -#else - TwoSecurityBuffers twoSecurityBuffers = default; - Span inSecurityBufferSpan = MemoryMarshal.CreateSpan(ref twoSecurityBuffers._item0, 2); -#endif - int inSecurityBufferSpanLength = 0; - if (incomingBlob != null && channelBinding != null) + InputSecurityBuffers inputBuffers = default; + if (incomingBlob != null) { - inSecurityBufferSpan[0] = new SecurityBuffer(incomingBlob, SecurityBufferType.SECBUFFER_TOKEN); - inSecurityBufferSpan[1] = new SecurityBuffer(channelBinding); - inSecurityBufferSpanLength = 2; + inputBuffers.SetNextBuffer(new InputSecurityBuffer(incomingBlob, SecurityBufferType.SECBUFFER_TOKEN)); } - else if (incomingBlob != null) - { - inSecurityBufferSpan[0] = new SecurityBuffer(incomingBlob, SecurityBufferType.SECBUFFER_TOKEN); - inSecurityBufferSpanLength = 1; - } - else if (channelBinding != null) + + if (channelBinding != null) { - inSecurityBufferSpan[0] = new SecurityBuffer(channelBinding); - inSecurityBufferSpanLength = 1; + inputBuffers.SetNextBuffer(new InputSecurityBuffer(channelBinding)); } - inSecurityBufferSpan = inSecurityBufferSpan.Slice(0, inSecurityBufferSpanLength); var outSecurityBuffer = new SecurityBuffer(resultBlob, SecurityBufferType.SECBUFFER_TOKEN); @@ -118,7 +104,7 @@ internal static SecurityStatusPal InitializeSecurityContext( spn, ContextFlagsAdapterPal.GetInteropFromContextFlagsPal(requestedContextFlags), Interop.SspiCli.Endianness.SECURITY_NETWORK_DREP, - inSecurityBufferSpan, + inputBuffers, ref outSecurityBuffer, ref outContextFlags); securityContext = sslContext; @@ -151,31 +137,16 @@ internal static SecurityStatusPal AcceptSecurityContext( ref byte[] resultBlob, ref ContextFlagsPal contextFlags) { -#if NETSTANDARD2_0 - Span inSecurityBufferSpan = new SecurityBuffer[2]; -#else - TwoSecurityBuffers twoSecurityBuffers = default; - Span inSecurityBufferSpan = MemoryMarshal.CreateSpan(ref twoSecurityBuffers._item0, 2); -#endif - - int inSecurityBufferSpanLength = 0; - if (incomingBlob != null && channelBinding != null) + InputSecurityBuffers inputBuffers = default; + if (incomingBlob != null) { - inSecurityBufferSpan[0] = new SecurityBuffer(incomingBlob, SecurityBufferType.SECBUFFER_TOKEN); - inSecurityBufferSpan[1] = new SecurityBuffer(channelBinding); - inSecurityBufferSpanLength = 2; + inputBuffers.SetNextBuffer(new InputSecurityBuffer(incomingBlob, SecurityBufferType.SECBUFFER_TOKEN)); } - else if (incomingBlob != null) - { - inSecurityBufferSpan[0] = new SecurityBuffer(incomingBlob, SecurityBufferType.SECBUFFER_TOKEN); - inSecurityBufferSpanLength = 1; - } - else if (channelBinding != null) + + if (channelBinding != null) { - inSecurityBufferSpan[0] = new SecurityBuffer(channelBinding); - inSecurityBufferSpanLength = 1; + inputBuffers.SetNextBuffer(new InputSecurityBuffer(channelBinding)); } - inSecurityBufferSpan = inSecurityBufferSpan.Slice(0, inSecurityBufferSpanLength); var outSecurityBuffer = new SecurityBuffer(resultBlob, SecurityBufferType.SECBUFFER_TOKEN); @@ -188,7 +159,7 @@ internal static SecurityStatusPal AcceptSecurityContext( ref sslContext, ContextFlagsAdapterPal.GetInteropFromContextFlagsPal(requestedContextFlags), Interop.SspiCli.Endianness.SECURITY_NETWORK_DREP, - inSecurityBufferSpan, + inputBuffers, ref outSecurityBuffer, ref outContextFlags); diff --git a/src/libraries/Common/src/System/Net/Security/SecurityBuffer.Windows.cs b/src/libraries/Common/src/System/Net/Security/SecurityBuffer.Windows.cs index d076942ebbd35f..9ba6ed506cf538 100644 --- a/src/libraries/Common/src/System/Net/Security/SecurityBuffer.Windows.cs +++ b/src/libraries/Common/src/System/Net/Security/SecurityBuffer.Windows.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using System.Diagnostics; using System.Runtime.InteropServices; using System.Security.Authentication.ExtendedProtection; @@ -28,6 +29,56 @@ internal ref struct ThreeSecurityBuffers private SecurityBuffer _item2; } + [StructLayout(LayoutKind.Sequential)] + internal ref struct InputSecurityBuffers + { + internal int Count; + internal InputSecurityBuffer _item0; + internal InputSecurityBuffer _item1; + internal InputSecurityBuffer _item2; + + internal void SetNextBuffer(InputSecurityBuffer buffer) + { + Debug.Assert(Count >= 0 && Count < 3); + if (Count == 0) + { + _item0 = buffer; + } + else if (Count == 1) + { + _item1 = buffer; + } + else + { + _item2 = buffer; + } + + Count++; + } + } + + [StructLayout(LayoutKind.Auto)] + internal readonly ref struct InputSecurityBuffer + { + public readonly SecurityBufferType Type; + public readonly ReadOnlySpan Token; + public readonly SafeHandle UnmanagedToken; + + public InputSecurityBuffer(ReadOnlySpan data, SecurityBufferType tokentype) + { + Token = data; + Type = tokentype; + UnmanagedToken = null; + } + + public InputSecurityBuffer(ChannelBinding binding) + { + Type = SecurityBufferType.SECBUFFER_CHANNEL_BINDINGS; + Token = default; + UnmanagedToken = binding; + } + } + [StructLayout(LayoutKind.Auto)] internal struct SecurityBuffer { diff --git a/src/libraries/System.Net.Security/src/System/Net/Security/SecureChannel.cs b/src/libraries/System.Net.Security/src/System/Net/Security/SecureChannel.cs index d039fbcdcbd2b1..8f3140e516329e 100644 --- a/src/libraries/System.Net.Security/src/System/Net/Security/SecureChannel.cs +++ b/src/libraries/System.Net.Security/src/System/Net/Security/SecureChannel.cs @@ -783,6 +783,7 @@ private SecurityStatusPal GenerateToken(byte[] input, int offset, int count, ref SecurityStatusPal status = default; bool cachedCreds = false; byte[] thumbPrint = null; + ReadOnlySpan inputBuffer = new ReadOnlySpan(input, offset, count); // // Looping through ASC or ISC with potentially cached credential that could have been @@ -796,7 +797,7 @@ private SecurityStatusPal GenerateToken(byte[] input, int offset, int count, ref if (_refreshCredentialNeeded) { cachedCreds = _sslAuthenticationOptions.IsServer - ? AcquireServerCredentials(ref thumbPrint, new ReadOnlySpan(input, offset, count)) + ? AcquireServerCredentials(ref thumbPrint, inputBuffer) : AcquireClientCredentials(ref thumbPrint); } @@ -805,7 +806,7 @@ private SecurityStatusPal GenerateToken(byte[] input, int offset, int count, ref status = SslStreamPal.AcceptSecurityContext( ref _credentialsHandle, ref _securityContext, - input, offset, count, + inputBuffer, ref result, _sslAuthenticationOptions); } @@ -815,7 +816,7 @@ private SecurityStatusPal GenerateToken(byte[] input, int offset, int count, ref ref _credentialsHandle, ref _securityContext, _sslAuthenticationOptions.TargetHost, - input, offset, count, + inputBuffer, ref result, _sslAuthenticationOptions); } diff --git a/src/libraries/System.Net.Security/src/System/Net/Security/SslStreamPal.OSX.cs b/src/libraries/System.Net.Security/src/System/Net/Security/SslStreamPal.OSX.cs index eda667a9a9264f..c91cf4665cb25d 100644 --- a/src/libraries/System.Net.Security/src/System/Net/Security/SslStreamPal.OSX.cs +++ b/src/libraries/System.Net.Security/src/System/Net/Security/SslStreamPal.OSX.cs @@ -35,22 +35,22 @@ public static void VerifyPackageInfo() public static SecurityStatusPal AcceptSecurityContext( ref SafeFreeCredentials credential, ref SafeDeleteSslContext context, - byte[] inputBuffer, int offset, int count, + ReadOnlySpan inputBuffer, ref byte[] outputBuffer, SslAuthenticationOptions sslAuthenticationOptions) { - return HandshakeInternal(credential, ref context, new ReadOnlySpan(inputBuffer, offset, count), ref outputBuffer, sslAuthenticationOptions); + return HandshakeInternal(credential, ref context, inputBuffer, ref outputBuffer, sslAuthenticationOptions); } public static SecurityStatusPal InitializeSecurityContext( ref SafeFreeCredentials credential, ref SafeDeleteSslContext context, string targetName, - byte[] inputBuffer, int offset, int count, + ReadOnlySpan inputBuffer, ref byte[] outputBuffer, SslAuthenticationOptions sslAuthenticationOptions) { - return HandshakeInternal(credential, ref context, new ReadOnlySpan(inputBuffer, offset, count), ref outputBuffer, sslAuthenticationOptions); + return HandshakeInternal(credential, ref context, inputBuffer, ref outputBuffer, sslAuthenticationOptions); } public static SafeFreeCredentials AcquireCredentialsHandle( diff --git a/src/libraries/System.Net.Security/src/System/Net/Security/SslStreamPal.Unix.cs b/src/libraries/System.Net.Security/src/System/Net/Security/SslStreamPal.Unix.cs index 0a975e9abaab7c..9d0d78dcebff16 100644 --- a/src/libraries/System.Net.Security/src/System/Net/Security/SslStreamPal.Unix.cs +++ b/src/libraries/System.Net.Security/src/System/Net/Security/SslStreamPal.Unix.cs @@ -26,15 +26,15 @@ public static void VerifyPackageInfo() } public static SecurityStatusPal AcceptSecurityContext(ref SafeFreeCredentials credential, ref SafeDeleteSslContext context, - byte[] inputBuffer, int offset, int count, ref byte[] outputBuffer, SslAuthenticationOptions sslAuthenticationOptions) + ReadOnlySpan inputBuffer, ref byte[] outputBuffer, SslAuthenticationOptions sslAuthenticationOptions) { - return HandshakeInternal(credential, ref context, new ReadOnlySpan(inputBuffer, offset, count), ref outputBuffer, sslAuthenticationOptions); + return HandshakeInternal(credential, ref context, inputBuffer, ref outputBuffer, sslAuthenticationOptions); } public static SecurityStatusPal InitializeSecurityContext(ref SafeFreeCredentials credential, ref SafeDeleteSslContext context, string targetName, - byte[] inputBuffer, int offset, int count, ref byte[] outputBuffer, SslAuthenticationOptions sslAuthenticationOptions) + ReadOnlySpan inputBuffer, ref byte[] outputBuffer, SslAuthenticationOptions sslAuthenticationOptions) { - return HandshakeInternal(credential, ref context, new ReadOnlySpan(inputBuffer, offset, count), ref outputBuffer, sslAuthenticationOptions); + return HandshakeInternal(credential, ref context, inputBuffer, ref outputBuffer, sslAuthenticationOptions); } public static SafeFreeCredentials AcquireCredentialsHandle(X509Certificate certificate, diff --git a/src/libraries/System.Net.Security/src/System/Net/Security/SslStreamPal.Windows.cs b/src/libraries/System.Net.Security/src/System/Net/Security/SslStreamPal.Windows.cs index 607836d38297b0..aac05484241872 100644 --- a/src/libraries/System.Net.Security/src/System/Net/Security/SslStreamPal.Windows.cs +++ b/src/libraries/System.Net.Security/src/System/Net/Security/SslStreamPal.Windows.cs @@ -46,17 +46,19 @@ public static byte[] ConvertAlpnProtocolListToByteArray(List inputBuffer, ref byte[] outputBuffer, SslAuthenticationOptions sslAuthenticationOptions) { Interop.SspiCli.ContextFlags unusedAttributes = default; - ArraySegment input = inputBuffer != null ? new ArraySegment(inputBuffer, offset, count) : default; - ThreeSecurityBuffers threeSecurityBuffers = default; - SecurityBuffer? incomingSecurity = input.Array != null ? - new SecurityBuffer(input.Array, input.Offset, input.Count, SecurityBufferType.SECBUFFER_TOKEN) : - (SecurityBuffer?)null; - Span inputBuffers = MemoryMarshal.CreateSpan(ref threeSecurityBuffers._item0, 3); - GetIncomingSecurityBuffers(sslAuthenticationOptions, in incomingSecurity, ref inputBuffers); + InputSecurityBuffers inputBuffers = default; + inputBuffers.SetNextBuffer(new InputSecurityBuffer(inputBuffer, SecurityBufferType.SECBUFFER_TOKEN)); + inputBuffers.SetNextBuffer(new InputSecurityBuffer(default, SecurityBufferType.SECBUFFER_EMPTY)); + + if (sslAuthenticationOptions.ApplicationProtocols != null && sslAuthenticationOptions.ApplicationProtocols.Count != 0) + { + byte[] alpnBytes = ConvertAlpnProtocolListToByteArray(sslAuthenticationOptions.ApplicationProtocols); + inputBuffers.SetNextBuffer(new InputSecurityBuffer(new ReadOnlySpan(alpnBytes), SecurityBufferType.SECBUFFER_APPLICATION_PROTOCOLS)); + } var resultBuffer = new SecurityBuffer(outputBuffer, SecurityBufferType.SECBUFFER_TOKEN); @@ -74,17 +76,18 @@ public static SecurityStatusPal AcceptSecurityContext(ref SafeFreeCredentials cr return SecurityStatusAdapterPal.GetSecurityStatusPalFromNativeInt(errorCode); } - public static SecurityStatusPal InitializeSecurityContext(ref SafeFreeCredentials credentialsHandle, ref SafeDeleteSslContext context, string targetName, byte[] inputBuffer, int offset, int count, ref byte[] outputBuffer, SslAuthenticationOptions sslAuthenticationOptions) + public static SecurityStatusPal InitializeSecurityContext(ref SafeFreeCredentials credentialsHandle, ref SafeDeleteSslContext context, string targetName, ReadOnlySpan inputBuffer, ref byte[] outputBuffer, SslAuthenticationOptions sslAuthenticationOptions) { Interop.SspiCli.ContextFlags unusedAttributes = default; - ArraySegment input = inputBuffer != null ? new ArraySegment(inputBuffer, offset, count) : default; - ThreeSecurityBuffers threeSecurityBuffers = default; - SecurityBuffer? incomingSecurity = input.Array != null ? - new SecurityBuffer(input.Array, input.Offset, input.Count, SecurityBufferType.SECBUFFER_TOKEN) : - (SecurityBuffer?)null; - Span inputBuffers = MemoryMarshal.CreateSpan(ref threeSecurityBuffers._item0, 3); - GetIncomingSecurityBuffers(sslAuthenticationOptions, in incomingSecurity, ref inputBuffers); + InputSecurityBuffers inputBuffers = default; + inputBuffers.SetNextBuffer(new InputSecurityBuffer(inputBuffer, SecurityBufferType.SECBUFFER_TOKEN)); + inputBuffers.SetNextBuffer(new InputSecurityBuffer(default, SecurityBufferType.SECBUFFER_EMPTY)); + if (sslAuthenticationOptions.ApplicationProtocols != null && sslAuthenticationOptions.ApplicationProtocols.Count != 0) + { + byte[] alpnBytes = ConvertAlpnProtocolListToByteArray(sslAuthenticationOptions.ApplicationProtocols); + inputBuffers.SetNextBuffer(new InputSecurityBuffer(new ReadOnlySpan(alpnBytes), SecurityBufferType.SECBUFFER_APPLICATION_PROTOCOLS)); + } var resultBuffer = new SecurityBuffer(outputBuffer, SecurityBufferType.SECBUFFER_TOKEN); @@ -103,45 +106,6 @@ public static SecurityStatusPal InitializeSecurityContext(ref SafeFreeCredential return SecurityStatusAdapterPal.GetSecurityStatusPalFromNativeInt(errorCode); } - private static void GetIncomingSecurityBuffers(SslAuthenticationOptions options, in SecurityBuffer? incomingSecurity, ref Span incomingSecurityBuffers) - { - SecurityBuffer? alpnBuffer = null; - - if (options.ApplicationProtocols != null && options.ApplicationProtocols.Count != 0) - { - byte[] alpnBytes = ConvertAlpnProtocolListToByteArray(options.ApplicationProtocols); - alpnBuffer = new SecurityBuffer(alpnBytes, 0, alpnBytes.Length, SecurityBufferType.SECBUFFER_APPLICATION_PROTOCOLS); - } - - if (incomingSecurity != null) - { - if (alpnBuffer != null) - { - Debug.Assert(incomingSecurityBuffers.Length >= 3); - incomingSecurityBuffers[0] = incomingSecurity.GetValueOrDefault(); - incomingSecurityBuffers[1] = new SecurityBuffer(null, 0, 0, SecurityBufferType.SECBUFFER_EMPTY); - incomingSecurityBuffers[2] = alpnBuffer.GetValueOrDefault(); - incomingSecurityBuffers = incomingSecurityBuffers.Slice(0, 3); - } - else - { - Debug.Assert(incomingSecurityBuffers.Length >= 2); - incomingSecurityBuffers[0] = incomingSecurity.GetValueOrDefault(); - incomingSecurityBuffers[1] = new SecurityBuffer(null, 0, 0, SecurityBufferType.SECBUFFER_EMPTY); - incomingSecurityBuffers = incomingSecurityBuffers.Slice(0, 2); - } - } - else if (alpnBuffer != null) - { - incomingSecurityBuffers[0] = alpnBuffer.GetValueOrDefault(); - incomingSecurityBuffers = incomingSecurityBuffers.Slice(0, 1); - } - else - { - incomingSecurityBuffers = default; - } - } - public static SafeFreeCredentials AcquireCredentialsHandle(X509Certificate certificate, SslProtocols protocols, EncryptionPolicy policy, bool isServer) { int protocolFlags = GetProtocolFlagsFromSslProtocols(protocols, isServer);