Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
feedback from review
  • Loading branch information
wfurt committed Feb 4, 2022
commit 99240d4baeaefe53aaf1791c4075393f6dfd8aae
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ internal static partial class OpenSsl
{
private const string DisableTlsResumeCtxSwitch = "System.Net.Security.DisableTlsResume";
private const string DisableTlsResumeEnvironmentVariable = "DOTNET_SYSTEM_NET_SECURITY_DISABLETLSRESUME";
private const SslProtocols FakeAlpnSslProtocol = (SslProtocols)1; // used to distinguish server sessions with ALPN
private static readonly IdnMapping s_idnMapping = new IdnMapping();
private static readonly ConcurrentDictionary<SslProtocols, SafeSslContextHandle> s_clientSslContexts = new ConcurrentDictionary<SslProtocols, SafeSslContextHandle>();

Expand Down Expand Up @@ -89,7 +90,7 @@ private static bool DisableTlsResume
private static SslProtocols CalculateEffectiveProtocols(SslAuthenticationOptions sslAuthenticationOptions)
{
// make sure low bit is not set since we use it in context dictionary to distinguish use with ALPN
Debug.Assert(((int)sslAuthenticationOptions.EnabledSslProtocols & 1) == 0);
Debug.Assert((sslAuthenticationOptions.EnabledSslProtocols & FakeAlpnSslProtocol) == 0);
SslProtocols protocols = sslAuthenticationOptions.EnabledSslProtocols & ~((SslProtocols)1);

if (!Interop.Ssl.Capabilities.Tls13Supported)
Expand Down Expand Up @@ -198,7 +199,8 @@ internal static unsafe SafeSslContextHandle AllocateSslContext(SafeFreeSslCreden
}
else
{
Ssl.SslCtxSetCaching(sslCtx, 1, &NewSessionCallback, &RemoveSessionCallback);
int result = Ssl.SslCtxSetCaching(sslCtx, 1, &NewSessionCallback, &RemoveSessionCallback);
Debug.Assert(result == 1);
sslCtx.EnableSessionCache();
}
}
Expand Down Expand Up @@ -285,22 +287,39 @@ internal static SafeSslHandle AllocateSslHandle(SafeFreeSslCredentials credentia
SafeSslContextHandle? newCtxHandle = null;
SslProtocols protocols = CalculateEffectiveProtocols(sslAuthenticationOptions);
bool hasAlpn = sslAuthenticationOptions.ApplicationProtocols != null && sslAuthenticationOptions.ApplicationProtocols.Count != 0;
bool cacheSslContext = !DisableTlsResume && sslAuthenticationOptions.EncryptionPolicy == EncryptionPolicy.RequireEncryption && sslAuthenticationOptions.CipherSuitesPolicy == null &&
((sslAuthenticationOptions.IsServer &&
sslAuthenticationOptions.CertificateContext != null &&
sslAuthenticationOptions.CertificateContext.SslContexts != null) ||
(sslAuthenticationOptions.IsClient &&
// since SNI us our key, we don't wont to resume sessions without it.
!string.IsNullOrEmpty(sslAuthenticationOptions.TargetHost) &&
// on client avoid resume with certificates
sslAuthenticationOptions.CertificateContext == null &&
sslAuthenticationOptions.CertSelectionDelegate == null));
bool cacheSslContext = !DisableTlsResume && sslAuthenticationOptions.EncryptionPolicy == EncryptionPolicy.RequireEncryption && sslAuthenticationOptions.CipherSuitesPolicy == null;

if (cacheSslContext)
{
if (sslAuthenticationOptions.IsClient)
{
// we don't want to try on emtpy TargetName since that is our key.
// And we don't want to mess up with client authentication. It may be possible
// but it seems safe to get full new session.
if (string.IsNullOrEmpty(sslAuthenticationOptions.TargetHost) ||
sslAuthenticationOptions.CertificateContext != null ||
sslAuthenticationOptions.CertSelectionDelegate != null)
{
cacheSslContext = false;
}
}
else
{
// Server should always have certificate
Debug.Assert(sslAuthenticationOptions.CertificateContext != null);
if (sslAuthenticationOptions.CertificateContext == null ||
sslAuthenticationOptions.CertificateContext.SslContexts == null)
{
cacheSslContext = false;
}
}
}

if (cacheSslContext)
{
if (sslAuthenticationOptions.IsServer)
{
sslAuthenticationOptions.CertificateContext!.SslContexts!.TryGetValue(protocols | (SslProtocols)(hasAlpn ? 1 : 0), out sslCtxHandle);
sslAuthenticationOptions.CertificateContext!.SslContexts!.TryGetValue(protocols | (hasAlpn ? FakeAlpnSslProtocol : SslProtocols.None), out sslCtxHandle);
}
else
{
Expand All @@ -317,8 +336,8 @@ internal static SafeSslHandle AllocateSslHandle(SafeFreeSslCredentials credentia
if (cacheSslContext)
{
bool added = sslAuthenticationOptions.IsServer ?
sslAuthenticationOptions.CertificateContext!.SslContexts!.TryAdd(protocols | (SslProtocols)(hasAlpn ? 1 : 0), newCtxHandle) :
s_clientSslContexts.TryAdd(protocols, newCtxHandle);
sslAuthenticationOptions.CertificateContext!.SslContexts!.TryAdd(protocols | (SslProtocols)(hasAlpn ? 1 : 0), newCtxHandle) :
s_clientSslContexts.TryAdd(protocols, newCtxHandle);
if (added)
{
newCtxHandle = null;
Expand All @@ -341,6 +360,7 @@ internal static SafeSslHandle AllocateSslHandle(SafeFreeSslCredentials credentia
{
if (sslAuthenticationOptions.IsServer)
{
Debug.Assert(Interop.Ssl.SslGetData(sslHandle) == IntPtr.Zero);
alpnHandle = GCHandle.Alloc(sslAuthenticationOptions.ApplicationProtocols);
Interop.Ssl.SslSetData(sslHandle, GCHandle.ToIntPtr(alpnHandle));
sslHandle.AlpnHandle = alpnHandle;
Expand Down Expand Up @@ -675,13 +695,16 @@ private static unsafe int AlpnServerSelectCallback(IntPtr ssl, byte** outp, byte
// If we return 1, the ownership is transfered to us and we will need to call SessionFree().
private static unsafe int NewSessionCallback(IntPtr ssl, IntPtr session)
{
Debug.Assert(ssl != IntPtr.Zero);
Debug.Assert(session != IntPtr.Zero);

IntPtr ptr = Ssl.SslGetData(ssl);
Debug.Assert(ptr != IntPtr.Zero);
GCHandle gch = (GCHandle)ptr;
GCHandle gch = GCHandle.FromIntPtr(ptr);

SafeSslContextHandle? ctxHandle = gch.Target as SafeSslContextHandle;
Debug.Assert(ctxHandle != null);

// There is no relation between SafeSslContextHandle and SafeSslHandle so the handle
// may be released while the ssl session is still active.
if (ctxHandle != null && ctxHandle.TryAddSession(Ssl.SslGetServerName(ssl), session))
{
// offered session was stored in our cache.
Expand All @@ -698,25 +721,23 @@ private static unsafe void RemoveSessionCallback(IntPtr ctx, IntPtr session)
Debug.Assert(ctx != IntPtr.Zero && session != IntPtr.Zero);

IntPtr ptr = Ssl.SslCtxGetData(ctx);
Debug.Assert(ptr != IntPtr.Zero);
GCHandle gch = (GCHandle)ptr;
if (!gch.IsAllocated)
if (ptr == IntPtr.Zero)
{
// Same as above, SafeSslContextHandle could be released while OpenSSL still holds refferecne.
return;
}

GCHandle gch = GCHandle.FromIntPtr(ptr);
SafeSslContextHandle? ctxHandle = gch.Target as SafeSslContextHandle;
Debug.Assert(ctxHandle != null);
if (ctxHandle == null)
{
return;
}

string? name = Marshal.PtrToStringAnsi(Ssl.SessionGetHostname(session));
if (!string.IsNullOrEmpty(name))
{
ctxHandle.Remove(name, session);
}
//string? name = Marshal.PtrToStringAnsi(Ssl.SessionGetHostname(session));a
IntPtr name = Ssl.SessionGetHostname(session);
Debug.Assert(name != IntPtr.Zero);
ctxHandle.RemoveSession(name, session);
}

private static int BioRead(SafeBioHandle bio, byte[] buffer, int count)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,9 @@ internal static partial class Ssl
[GeneratedDllImport(Libraries.CryptoNative, EntryPoint = "CryptoNative_SslGetData")]
internal static partial IntPtr SslGetData(IntPtr ssl);

[GeneratedDllImport(Libraries.CryptoNative, EntryPoint = "CryptoNative_SslGetData")]
internal static partial IntPtr SslGetData(SafeSslHandle ssl);

[GeneratedDllImport(Libraries.CryptoNative, EntryPoint = "CryptoNative_SslSetData")]
internal static partial int SslSetData(SafeSslHandle ssl, IntPtr data);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Diagnostics;
using System.Net.Security;
Expand Down Expand Up @@ -34,7 +33,7 @@ internal static partial class Ssl
internal static unsafe partial void SslCtxSetAlpnSelectCb(SafeSslContextHandle ctx, delegate* unmanaged<IntPtr, byte**, byte*, byte*, uint, IntPtr, int> callback, IntPtr arg);

[GeneratedDllImport(Libraries.CryptoNative, EntryPoint = "CryptoNative_SslCtxSetCaching")]
internal static unsafe partial void SslCtxSetCaching(SafeSslContextHandle ctx, int mode, delegate* unmanaged<IntPtr, IntPtr, int> neewSessionCallback, delegate* unmanaged<IntPtr, IntPtr, void> removeSessionCallback);
internal static unsafe partial int SslCtxSetCaching(SafeSslContextHandle ctx, int mode, delegate* unmanaged<IntPtr, IntPtr, int> neewSessionCallback, delegate* unmanaged<IntPtr, IntPtr, void> removeSessionCallback);

internal static bool AddExtraChainCertificates(SafeSslContextHandle ctx, X509Certificate2[] chain)
{
Expand All @@ -61,7 +60,8 @@ namespace Microsoft.Win32.SafeHandles
{
internal sealed class SafeSslContextHandle : SafeHandle
{
private ConcurrentDictionary<string, IntPtr>? _sslSessions;
// This is session cache keyed by SNI e.g. TargetHost
private Dictionary<string, IntPtr>? _sslSessions;
private GCHandle _gch;

public SafeSslContextHandle()
Expand All @@ -81,60 +81,74 @@ public override bool IsInvalid

protected override bool ReleaseHandle()
{
Interop.Ssl.SslCtxDestroy(handle);
SetHandle(IntPtr.Zero);
if (_gch.IsAllocated)
{
//Interop.Ssl.SslCtxSetData(this, (IntPtr)_gch);
_gch.Free();
}

if (_sslSessions != null)
{
// The SSL_CTX is ref counted and may not immediately die when we call SslCtxDestroy()
// Since there is no relation between SafeSslContextHandle and SafeSslHandle `this` can be release
// while we still have SSL session using it.
Interop.Ssl.SslCtxSetData(handle, IntPtr.Zero);

lock (_sslSessions)
{
foreach (string name in _sslSessions.Keys)
foreach (IntPtr session in _sslSessions.Values)
{
_sslSessions.Remove(name, out IntPtr session);
Interop.Ssl.SessionFree(session);
}

_sslSessions.Clear();
}

Debug.Assert(_gch.IsAllocated);
_gch.Free();
}

Interop.Ssl.SslCtxDestroy(handle);
SetHandle(IntPtr.Zero);

return true;
}

public void EnableSessionCache()
internal void EnableSessionCache()
{
_sslSessions = new ConcurrentDictionary<string, IntPtr>();
Debug.Assert(_sslSessions == null);

_sslSessions = new Dictionary<string, IntPtr>();
_gch = GCHandle.Alloc(this);
// This is needed so we can find the handle from session remove callback.
Debug.Assert(_gch.IsAllocated);
// This is needed so we can find the handle from session in SessionRemove callback.
Interop.Ssl.SslCtxSetData(this, (IntPtr)_gch);
}

public bool TryAddSession(IntPtr serverName, IntPtr session)
internal bool TryAddSession(IntPtr namePtr, IntPtr session)
{
Debug.Assert(_sslSessions != null && session != IntPtr.Zero);

if (_sslSessions == null || serverName == IntPtr.Zero)
if (_sslSessions == null || namePtr == IntPtr.Zero)
{
return false;
}

string? name = Marshal.PtrToStringAnsi(serverName);
if (!string.IsNullOrEmpty(name))
string? targetName = Marshal.PtrToStringAnsi(namePtr);
Debug.Assert(targetName != null);

if (!string.IsNullOrEmpty(targetName))
{
Interop.Ssl.SessionSetHostname(session, serverName);
// We do this only for lookup in RemoveSession.
// Since this is part of chache manipulation and no function impact it is done here.
// This will use strdup() so it is safe to pass in raw pointer.
Interop.Ssl.SessionSetHostname(session, namePtr);

lock (_sslSessions)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The locking around _sslSessions makes sense, since you're manipulating state depending on how the dictionary performed.

But, since you're already locking it, it feels like you want a non-Concurrent dictionary.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

agreed. I was also thinking about grabbing extra reference on the session. That would allow me to use ConcurrentDictionary without locking as the session would never be released in the middle.
Do you have preference/recommendation @bartonjs ?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

grabbing extra reference on the session

Something like

Interop.Ssl.SslSessionUpRef(session);

if (!_sslSessions.TryAdd(...))
{
    // Undo the upref since it's not in the dictionary
    Interop.Ssl.SslSessionFree(session);
}

? (Upref inside has a race condition with the cleanup in ReleaseHandle)

That would get a little weird since in the cleanup you'd need to call free twice, I think?

The fact that we wrote ConcurrentDictionary suggests that it gives better perf (on average) than manual locking... but if the code to interact with it is doing memory/lifetime management and it becomes unreadable with the gymnastics... then locking is better for maintainability. (If it's clean code and more performant, than by all means use upref+ConcurrentDictionary)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes. and then call free twice. I'm inclined to good with the lock and better maintainability as the perf does not depend on this. This happens one in while - not even for each SSL session. I started with ConcurrentDictionary but you are right - we don't need it at this point.

{
IntPtr oldSession = _sslSessions.GetOrAdd(name, session);
if (oldSession != session)
if (!_sslSessions.TryAdd(targetName, session))
{
_sslSessions.Remove(name, out oldSession);
Interop.Ssl.SessionFree(oldSession);
oldSession = _sslSessions.GetOrAdd(name, session);
Debug.Assert(oldSession == session);
if (_sslSessions.Remove(targetName, out IntPtr oldSession))
{
Interop.Ssl.SessionFree(oldSession);
}

bool added = _sslSessions.TryAdd(targetName, session);
Debug.Assert(added);
}
}

Expand All @@ -144,21 +158,33 @@ public bool TryAddSession(IntPtr serverName, IntPtr session)
return false;
}

public void Remove(string name, IntPtr session)
internal void RemoveSession(IntPtr namePtr, IntPtr session)
{
if (_sslSessions != null)
Debug.Assert(_sslSessions != null);

string? targetName = Marshal.PtrToStringAnsi(namePtr);
Debug.Assert(targetName != null);

if (_sslSessions != null && targetName != null)
{
IntPtr oldSession;
bool removed;
lock (_sslSessions)
{
if (!_sslSessions.Remove(name, out IntPtr oldSession))
{
Interop.Ssl.SessionFree(oldSession);
}
removed = _sslSessions.Remove(targetName, out oldSession);
}

if (removed)
{
// It seems like we may be called more than once. Since we grabbed only one refference
// when added to Dictionary, we will also drop exactly one when removed.
Interop.Ssl.SessionFree(oldSession);
}

}
}

public bool TrySetSession(SafeSslHandle sslHandle, string name)
internal bool TrySetSession(SafeSslHandle sslHandle, string name)
{
Debug.Assert(_sslSessions != null);

Expand All @@ -169,6 +195,7 @@ public bool TrySetSession(SafeSslHandle sslHandle, string name)

// even if we don't have matching session, we can get new one and we need
// way how to link SSL back to `this`.
Debug.Assert(Interop.Ssl.SslGetData(sslHandle) == IntPtr.Zero);
Interop.Ssl.SslSetData(sslHandle, (IntPtr)_gch);

lock (_sslSessions)
Expand Down
22 changes: 0 additions & 22 deletions src/native/libs/System.Security.Cryptography.Native/apibridge.c
Original file line number Diff line number Diff line change
Expand Up @@ -889,26 +889,4 @@ int local_EVP_PKEY_public_check(EVP_PKEY_CTX* ctx)
return -1;
}
}

const char * local_SSL_SESSION_get0_hostname(const SSL_SESSION *s)
{
return s->tlsext_hostname;
}

int local_SSL_SESSION_set1_hostname(SSL_SESSION *s, const char *hostname)
{
if (s->tlsext_hostname != NULL)
{
OPENSSL_free(s->tlsext_hostname);
}

if (hostname == NULL) {
s->tlsext_hostname = NULL;
return 1;
}

s->tlsext_hostname = OPENSSL_strdup(hostname);
return s->tlsext_hostname != NULL;
}

#endif
Original file line number Diff line number Diff line change
Expand Up @@ -57,5 +57,3 @@ const X509_ALGOR* local_X509_get0_tbs_sigalg(const X509* x509);
X509_PUBKEY* local_X509_get_X509_PUBKEY(const X509* x509);
int32_t local_X509_get_version(const X509* x509);
int32_t local_X509_up_ref(X509* x509);
const char * local_SSL_SESSION_get0_hostname(const SSL_SESSION *s);
int local_SSL_SESSION_set1_hostname(SSL_SESSION *s, const char *hostname);
Loading