diff --git a/src/libraries/Common/src/Interop/Windows/SspiCli/SecuritySafeHandles.cs b/src/libraries/Common/src/Interop/Windows/SspiCli/SecuritySafeHandles.cs index 382d1f6a2979ed..19224c956140e1 100644 --- a/src/libraries/Common/src/Interop/Windows/SspiCli/SecuritySafeHandles.cs +++ b/src/libraries/Common/src/Interop/Windows/SspiCli/SecuritySafeHandles.cs @@ -446,28 +446,50 @@ internal static unsafe int InitializeSecurityContext( if (inSecBuffers.Count > 2) { inUnmanagedBuffer[2].BufferType = inSecBuffers._item2.Type; - inUnmanagedBuffer[2].cbBuffer = inSecBuffers._item2.Token.Length; - inUnmanagedBuffer[2].pvBuffer = inSecBuffers._item2.UnmanagedToken != null ? - (IntPtr)inSecBuffers._item2.UnmanagedToken.DangerousGetHandle() : - (IntPtr)pinnedToken2; + if (inSecBuffers._item2.UnmanagedToken != null) + { + Debug.Assert(inSecBuffers._item2.Type == SecurityBufferType.SECBUFFER_CHANNEL_BINDINGS); + inUnmanagedBuffer[2].pvBuffer = (IntPtr)inSecBuffers._item2.UnmanagedToken.DangerousGetHandle(); + inUnmanagedBuffer[2].cbBuffer = ((ChannelBinding)inSecBuffers._item2.UnmanagedToken).Size; + } + else + { + inUnmanagedBuffer[2].cbBuffer = inSecBuffers._item2.Token.Length; + inUnmanagedBuffer[2].pvBuffer = (IntPtr)pinnedToken2; + } + } if (inSecBuffers.Count > 1) { inUnmanagedBuffer[1].BufferType = inSecBuffers._item1.Type; - inUnmanagedBuffer[1].cbBuffer = inSecBuffers._item1.Token.Length; - inUnmanagedBuffer[1].pvBuffer = inSecBuffers._item1.UnmanagedToken != null ? - (IntPtr)inSecBuffers._item1.UnmanagedToken.DangerousGetHandle() : - (IntPtr)pinnedToken1; + if (inSecBuffers._item1.UnmanagedToken != null) + { + Debug.Assert(inSecBuffers._item1.Type == SecurityBufferType.SECBUFFER_CHANNEL_BINDINGS); + inUnmanagedBuffer[1].pvBuffer = (IntPtr)inSecBuffers._item1.UnmanagedToken.DangerousGetHandle(); + inUnmanagedBuffer[1].cbBuffer = ((ChannelBinding)inSecBuffers._item1.UnmanagedToken).Size; + } + else + { + inUnmanagedBuffer[1].cbBuffer = inSecBuffers._item1.Token.Length; + inUnmanagedBuffer[1].pvBuffer = (IntPtr)pinnedToken1; + } } if (inSecBuffers.Count > 0) { inUnmanagedBuffer[0].BufferType = inSecBuffers._item0.Type; - inUnmanagedBuffer[0].cbBuffer = inSecBuffers._item0.Token.Length; - inUnmanagedBuffer[0].pvBuffer = inSecBuffers._item0.UnmanagedToken != null ? - (IntPtr)inSecBuffers._item0.UnmanagedToken.DangerousGetHandle() : - (IntPtr)pinnedToken0; + if (inSecBuffers._item0.UnmanagedToken != null) + { + Debug.Assert(inSecBuffers._item0.Type == SecurityBufferType.SECBUFFER_CHANNEL_BINDINGS); + inUnmanagedBuffer[0].pvBuffer = (IntPtr)inSecBuffers._item0.UnmanagedToken.DangerousGetHandle(); + inUnmanagedBuffer[0].cbBuffer = ((ChannelBinding)inSecBuffers._item0.UnmanagedToken).Size; + } + else + { + inUnmanagedBuffer[0].cbBuffer = inSecBuffers._item0.Token.Length; + inUnmanagedBuffer[0].pvBuffer = (IntPtr)pinnedToken0; + } } fixed (byte* pinnedOutBytes = outSecBuffer.token) @@ -687,28 +709,50 @@ internal static unsafe int AcceptSecurityContext( if (inSecBuffers.Count > 2) { inUnmanagedBuffer[2].BufferType = inSecBuffers._item2.Type; - inUnmanagedBuffer[2].cbBuffer = inSecBuffers._item2.Token.Length; - inUnmanagedBuffer[2].pvBuffer = inSecBuffers._item2.UnmanagedToken != null ? - (IntPtr)inSecBuffers._item2.UnmanagedToken.DangerousGetHandle() : - (IntPtr)pinnedToken2; + if (inSecBuffers._item2.UnmanagedToken != null) + { + Debug.Assert(inSecBuffers._item2.Type == SecurityBufferType.SECBUFFER_CHANNEL_BINDINGS); + inUnmanagedBuffer[2].pvBuffer = (IntPtr)inSecBuffers._item2.UnmanagedToken.DangerousGetHandle(); + inUnmanagedBuffer[2].cbBuffer = ((ChannelBinding)inSecBuffers._item2.UnmanagedToken).Size; + } + else + { + inUnmanagedBuffer[2].cbBuffer = inSecBuffers._item2.Token.Length; + inUnmanagedBuffer[2].pvBuffer = (IntPtr)pinnedToken2; + } + } if (inSecBuffers.Count > 1) { inUnmanagedBuffer[1].BufferType = inSecBuffers._item1.Type; - inUnmanagedBuffer[1].cbBuffer = inSecBuffers._item1.Token.Length; - inUnmanagedBuffer[1].pvBuffer = inSecBuffers._item1.UnmanagedToken != null ? - (IntPtr)inSecBuffers._item1.UnmanagedToken.DangerousGetHandle() : - (IntPtr)pinnedToken1; + if (inSecBuffers._item1.UnmanagedToken != null) + { + Debug.Assert(inSecBuffers._item1.Type == SecurityBufferType.SECBUFFER_CHANNEL_BINDINGS); + inUnmanagedBuffer[1].pvBuffer = (IntPtr)inSecBuffers._item1.UnmanagedToken.DangerousGetHandle(); + inUnmanagedBuffer[1].cbBuffer = ((ChannelBinding)inSecBuffers._item1.UnmanagedToken).Size; + } + else + { + inUnmanagedBuffer[1].cbBuffer = inSecBuffers._item1.Token.Length; + inUnmanagedBuffer[1].pvBuffer = (IntPtr)pinnedToken1; + } } if (inSecBuffers.Count > 0) { inUnmanagedBuffer[0].BufferType = inSecBuffers._item0.Type; - inUnmanagedBuffer[0].cbBuffer = inSecBuffers._item0.Token.Length; - inUnmanagedBuffer[0].pvBuffer = inSecBuffers._item0.UnmanagedToken != null ? - (IntPtr)inSecBuffers._item0.UnmanagedToken.DangerousGetHandle() : - (IntPtr)pinnedToken0; + if (inSecBuffers._item0.UnmanagedToken != null) + { + Debug.Assert(inSecBuffers._item0.Type == SecurityBufferType.SECBUFFER_CHANNEL_BINDINGS); + inUnmanagedBuffer[0].pvBuffer = (IntPtr)inSecBuffers._item0.UnmanagedToken.DangerousGetHandle(); + inUnmanagedBuffer[0].cbBuffer = ((ChannelBinding)inSecBuffers._item0.UnmanagedToken).Size; + } + else + { + inUnmanagedBuffer[0].cbBuffer = inSecBuffers._item0.Token.Length; + inUnmanagedBuffer[0].pvBuffer = (IntPtr)pinnedToken0; + } } fixed (byte* pinnedOutBytes = outSecBuffer.token)