Skip to content
Merged
Prev Previous commit
Next Next commit
Offload parsing to threadpool as well
  • Loading branch information
rzikm committed Feb 26, 2024
commit 702ea9423320429e335e68a3d8a326a500f995c2
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ internal static class CertificateValidation
private static readonly IdnMapping s_idnMapping = new IdnMapping();

// WARNING: This function will do the verification using OpenSSL. If the intention is to use OS function, caller should use CertificatePal interface.
internal static SslPolicyErrors BuildChainAndVerifyProperties(X509Chain chain, X509Certificate2 remoteCertificate, bool checkCertName, bool _ /*isServer*/, string? hostName, IntPtr certificateBuffer, int bufferLength = 0)
internal static SslPolicyErrors BuildChainAndVerifyProperties(X509Chain chain, X509Certificate2 remoteCertificate, bool checkCertName, bool _ /*isServer*/, string? hostName, Span<byte> certificateBuffer)
{
SslPolicyErrors errors = chain.Build(remoteCertificate) ?
SslPolicyErrors.None :
Expand All @@ -31,15 +31,24 @@ internal static SslPolicyErrors BuildChainAndVerifyProperties(X509Chain chain, X
}

SafeX509Handle certHandle;
if (certificateBuffer != IntPtr.Zero && bufferLength > 0)
unsafe
{
certHandle = Interop.Crypto.DecodeX509(certificateBuffer, bufferLength);
}
else
{
// We dont't have DER encoded buffer.
byte[] der = remoteCertificate.Export(X509ContentType.Cert);
certHandle = Interop.Crypto.DecodeX509(Marshal.UnsafeAddrOfPinnedArrayElement(der, 0), der.Length);
if (certificateBuffer.Length > 0)
{
fixed (byte* pCert = certificateBuffer)
{
certHandle = Interop.Crypto.DecodeX509((IntPtr)pCert, certificateBuffer.Length);
}
}
else
{
// We dont't have DER encoded buffer.
byte[] der = remoteCertificate.Export(X509ContentType.Cert);
fixed (byte* pDer = der)
{
certHandle = Interop.Crypto.DecodeX509((IntPtr)pDer, der.Length);
}
}
}

int hostNameMatch;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ internal static class CertificateValidation
private static readonly IdnMapping s_idnMapping = new IdnMapping();

#pragma warning disable IDE0060
internal static SslPolicyErrors BuildChainAndVerifyProperties(X509Chain chain, X509Certificate2 remoteCertificate, bool checkCertName, bool isServer, string? hostName, IntPtr certificateBuffer, int bufferLength)
internal static SslPolicyErrors BuildChainAndVerifyProperties(X509Chain chain, X509Certificate2 remoteCertificate, bool checkCertName, bool isServer, string? hostName, Span<byte> certificateBuffer)
=> BuildChainAndVerifyProperties(chain, remoteCertificate, checkCertName, isServer, hostName);
#pragma warning restore IDE0060

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ namespace System.Net
internal static partial class CertificateValidation
{
#pragma warning disable IDE0060
internal static SslPolicyErrors BuildChainAndVerifyProperties(X509Chain chain, X509Certificate2 remoteCertificate, bool checkCertName, bool isServer, string? hostName, IntPtr certificateBuffer, int bufferLength)
internal static SslPolicyErrors BuildChainAndVerifyProperties(X509Chain chain, X509Certificate2 remoteCertificate, bool checkCertName, bool isServer, string? hostName, Span<byte> certificateBuffer)
=> BuildChainAndVerifyProperties(chain, remoteCertificate, checkCertName, isServer, hostName);
#pragma warning restore IDE0060

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,38 +63,45 @@ public SslConnectionOptions(QuicConnection connection, bool isClient,
_certificateChainPolicy = certificateChainPolicy;
}

public unsafe int ValidateCertificate(X509Certificate2? certificate, X509Chain? chain)
public unsafe int ValidateCertificate(X509Certificate2? certificate, Span<byte> certData, Span<byte> chainData)
{
SslPolicyErrors sslPolicyErrors = SslPolicyErrors.None;
bool wrapException = false;

X509Chain? chain = null;
try
{
chain ??= new X509Chain();

if (_certificateChainPolicy != null)
{
chain.ChainPolicy = _certificateChainPolicy;
}
else
if (certificate is not null)
{
chain.ChainPolicy.RevocationMode = _revocationMode;
chain.ChainPolicy.RevocationFlag = X509RevocationFlag.ExcludeRoot;
chain = new X509Chain();
if (_certificateChainPolicy != null)
{
chain.ChainPolicy = _certificateChainPolicy;
}
else
{
chain.ChainPolicy.RevocationMode = _revocationMode;
chain.ChainPolicy.RevocationFlag = X509RevocationFlag.ExcludeRoot;

// TODO: configure chain.ChainPolicy.CustomTrustStore to mirror behavior of SslStream.VerifyRemoteCertificate (https://github.com/dotnet/runtime/issues/73053)
}
// TODO: configure chain.ChainPolicy.CustomTrustStore to mirror behavior of SslStream.VerifyRemoteCertificate (https://github.com/dotnet/runtime/issues/73053)
}

// set ApplicationPolicy unless already provided.
if (chain.ChainPolicy.ApplicationPolicy.Count == 0)
{
// Authenticate the remote party: (e.g. when operating in server mode, authenticate the client).
chain.ChainPolicy.ApplicationPolicy.Add(_isClient ? s_serverAuthOid : s_clientAuthOid);
}
// set ApplicationPolicy unless already provided.
if (chain.ChainPolicy.ApplicationPolicy.Count == 0)
{
// Authenticate the remote party: (e.g. when operating in server mode, authenticate the client).
chain.ChainPolicy.ApplicationPolicy.Add(_isClient ? s_serverAuthOid : s_clientAuthOid);
}

if (chainData.Length > 0)
{
X509Certificate2Collection additionalCertificates = new X509Certificate2Collection();
additionalCertificates.Import(chainData);
chain.ChainPolicy.ExtraStore.AddRange(additionalCertificates);
}

if (certificate is not null)
{
bool checkCertName = !chain!.ChainPolicy!.VerificationFlags.HasFlag(X509VerificationFlags.IgnoreInvalidName);
sslPolicyErrors |= CertificateValidation.BuildChainAndVerifyProperties(chain!, certificate, checkCertName, !_isClient, TargetHostNameHelper.NormalizeHostName(_targetHost), IntPtr.Zero, 0);
sslPolicyErrors |= CertificateValidation.BuildChainAndVerifyProperties(chain!, certificate, checkCertName, !_isClient, TargetHostNameHelper.NormalizeHostName(_targetHost), certData);
}
else if (_certificateRequired)
{
Expand Down Expand Up @@ -130,7 +137,6 @@ public unsafe int ValidateCertificate(X509Certificate2? certificate, X509Chain?
}
catch (Exception ex)
{
certificate?.Dispose();
if (wrapException)
{
throw new QuicException(QuicError.CallbackError, null, SR.net_quic_callback_error, ex);
Expand Down
67 changes: 49 additions & 18 deletions src/libraries/System.Net.Quic/src/System/Net/Quic/QuicConnection.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Buffers;
using System.Diagnostics;
using System.Net.Security;
using System.Net.Sockets;
Expand Down Expand Up @@ -571,55 +572,85 @@ private unsafe int HandleEventPeerStreamStarted(ref PEER_STREAM_STARTED_DATA dat
}
private unsafe int HandleEventPeerCertificateReceived(ref PEER_CERTIFICATE_RECEIVED_DATA data)
{
//
// the certificate validation is an expensive operation and we don't want to delay MsQuic
// worker thread. So we offload the validation to the .NET threadpool. Incidentally, this
// also prevents potential user RemoteCertificateValidationCallback from blocking MsQuic
// worker threads.
//
// the provided data pointers are valid only while still inside the callback so they need to be
// copied before handing them to threadpool.
//

X509Certificate2? certificate = null;
X509Chain? chain = null;

QUIC_BUFFER* certificatePtr = (QUIC_BUFFER*)data.Certificate;
QUIC_BUFFER* chainPtr = (QUIC_BUFFER*)data.Chain;
byte[]? certDataRented = null;
Memory<byte> certData = default;
byte[]? chainDataRented = null;
Memory<byte> chainData = default;

if (certificatePtr is not null)
if (data.Certificate != null)
{
chain = new X509Chain();
if (MsQuicApi.UsesSChannelBackend)
{
certificate = new X509Certificate2((IntPtr)certificatePtr);
certificate = new X509Certificate2((IntPtr)data.Certificate);
// TODO: what about chainPtr?
}
else
{
// on non-SChannel backends we specify USE_PORTABLE_CERTIFICATES and the content is buffers
// with DER encoded cert and chain
QUIC_BUFFER* certificatePtr = (QUIC_BUFFER*)data.Certificate;
QUIC_BUFFER* chainPtr = (QUIC_BUFFER*)data.Chain;

if (certificatePtr->Length > 0)
{
certificate = new X509Certificate2(certificatePtr->Span);
certDataRented = ArrayPool<byte>.Shared.Rent((int)certificatePtr->Length);
certData = certDataRented.AsMemory(0, (int)certificatePtr->Length);
certificatePtr->Span.CopyTo(certData.Span);
}

if (chainPtr->Length > 0)
{
X509Certificate2Collection additionalCertificates = new X509Certificate2Collection();
additionalCertificates.Import(chainPtr->Span);
chain.ChainPolicy.ExtraStore.AddRange(additionalCertificates);
chainDataRented = ArrayPool<byte>.Shared.Rent((int)chainPtr->Length);
chainData = chainDataRented.AsMemory(0, (int)chainPtr->Length);
chainPtr->Span.CopyTo(chainData.Span);
}
}
}

//
// the certificate validation is an expensive operation and we don't want to delay MsQuic
// worker thread. So we offload the validation to the .NET threadpool. Incidentally, this
// also prevents potential user RemoteCertificateValidationCallback from blocking MsQuic
// worker threads.
//

_ = Task.Run(() =>
{
int result;
try
{
result = _sslConnectionOptions.ValidateCertificate(certificate, chain);
if (certData.Length > 0)
{
Debug.Assert(certificate == null);
certificate = new X509Certificate2(certData.Span);
}

result = _sslConnectionOptions.ValidateCertificate(certificate, certData.Span, chainData.Span);
_remoteCertificate = certificate;
}
catch (Exception ex)
{
certificate?.Dispose();
_connectedTcs.TrySetException(ex);
result = QUIC_STATUS_HANDSHAKE_FAILURE;
}
finally
{
if (certDataRented != null)
{
ArrayPool<byte>.Shared.Return(certDataRented);
}

if (chainDataRented != null)
{
ArrayPool<byte>.Shared.Return(chainDataRented);
}
}

int status = MsQuicApi.Api.ConnectionCertificateValidationComplete(
_handle,
Expand Down