Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Use string for SNI instead of byte[]
  • Loading branch information
twsouthwick committed Feb 26, 2025
commit bbb5636eee6434d44ccaed2741dfe48833e3963a
Original file line number Diff line number Diff line change
Expand Up @@ -515,6 +515,9 @@
<Compile Include="$(CommonSourceRoot)Microsoft\Data\SqlClient\SqlObjectPool.cs">
<Link>Microsoft\Data\SqlClient\SqlObjectPool.cs</Link>
</Compile>
<Compile Include="$(CommonSourceRoot)Microsoft\Data\SqlClient\SqlObjectPools.cs">
<Link>Microsoft\Data\SqlClient\SqlObjectPools.cs</Link>
</Compile>
<Compile Include="$(CommonSourceRoot)Microsoft\Data\SqlClient\SqlParameter.cs">
<Link>Microsoft\Data\SqlClient\SqlParameter.cs</Link>
</Compile>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

using System;
using System.Buffers;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Net;
Expand Down Expand Up @@ -50,7 +51,7 @@ internal static SNIHandle CreateConnectionHandle(
string fullServerName,
TimeoutTimer timeout,
out byte[] instanceName,
ref byte[][] spnBuffer,
ref string[] spnBuffer,
string serverSPN,
bool flushCache,
bool async,
Expand Down Expand Up @@ -114,12 +115,12 @@ internal static SNIHandle CreateConnectionHandle(
return sniHandle;
}

private static byte[][] GetSqlServerSPNs(DataSource dataSource, string serverSPN)
private static string[] GetSqlServerSPNs(DataSource dataSource, string serverSPN)
{
Debug.Assert(!string.IsNullOrWhiteSpace(dataSource.ServerName));
if (!string.IsNullOrWhiteSpace(serverSPN))
{
return new byte[1][] { Encoding.Unicode.GetBytes(serverSPN) };
return new[] { serverSPN };
}

string hostName = dataSource.ServerName;
Expand All @@ -137,7 +138,7 @@ private static byte[][] GetSqlServerSPNs(DataSource dataSource, string serverSPN
return GetSqlServerSPNs(hostName, postfix, dataSource.ResolvedProtocol);
}

private static byte[][] GetSqlServerSPNs(string hostNameOrAddress, string portOrInstanceName, DataSource.Protocol protocol)
private static string[] GetSqlServerSPNs(string hostNameOrAddress, string portOrInstanceName, DataSource.Protocol protocol)
{
Debug.Assert(!string.IsNullOrWhiteSpace(hostNameOrAddress));
IPHostEntry hostEntry = null;
Expand Down Expand Up @@ -168,12 +169,12 @@ private static byte[][] GetSqlServerSPNs(string hostNameOrAddress, string portOr
string serverSpnWithDefaultPort = serverSpn + $":{DefaultSqlServerPort}";
// Set both SPNs with and without Port as Port is optional for default instance
SqlClientEventSource.Log.TryAdvancedTraceEvent("SNIProxy.GetSqlServerSPN | Info | ServerSPNs {0} and {1}", serverSpn, serverSpnWithDefaultPort);
return new byte[][] { Encoding.Unicode.GetBytes(serverSpn), Encoding.Unicode.GetBytes(serverSpnWithDefaultPort) };
return new[] { serverSpn, serverSpnWithDefaultPort };
}
// else Named Pipes do not need to valid port

SqlClientEventSource.Log.TryAdvancedTraceEvent("SNIProxy.GetSqlServerSPN | Info | ServerSPN {0}", serverSpn);
return new byte[][] { Encoding.Unicode.GetBytes(serverSpn) };
return new[] { serverSpn };
}

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,7 @@ internal sealed partial class TdsParser

private bool _is2022 = false;

private byte[][] _sniSpnBuffer = null;
// UNDONE - need to have some for both instances - both command and default???
private string[] _sniSpn = null;

// SqlStatistics
private SqlStatistics _statistics = null;
Expand Down Expand Up @@ -390,7 +389,7 @@ internal void Connect(
}
else
{
_sniSpnBuffer = null;
_sniSpn = null;
SqlClientEventSource.Log.TryTraceEvent("TdsParser.Connect | SEC | Connection Object Id {0}, Authentication Mode: {1}", _connHandler.ObjectID,
authType == SqlAuthenticationMethod.NotSpecified ? SqlAuthenticationMethod.SqlPassword.ToString() : authType.ToString());
}
Expand All @@ -402,7 +401,7 @@ internal void Connect(
SqlClientEventSource.Log.TryTraceEvent("<sc.TdsParser.Connect|SEC> Encryption will be disabled as target server is a SQL Local DB instance.");
}

_sniSpnBuffer = null;
_sniSpn = null;
_authenticationProvider = null;

// AD Integrated behaves like Windows integrated when connecting to a non-fedAuth server
Expand Down Expand Up @@ -441,7 +440,7 @@ internal void Connect(
serverInfo.ExtendedServerName,
timeout,
out instanceName,
ref _sniSpnBuffer,
ref _sniSpn,
false,
true,
fParallel,
Expand All @@ -454,8 +453,6 @@ internal void Connect(
hostNameInCertificate,
serverCertificateFilename);

_authenticationProvider?.Initialize(serverInfo, _physicalStateObj, this);

if (TdsEnums.SNI_SUCCESS != _physicalStateObj.Status)
{
_physicalStateObj.AddError(ProcessSNIError(_physicalStateObj));
Expand All @@ -470,6 +467,8 @@ internal void Connect(
Debug.Fail("SNI returned status != success, but no error thrown?");
}

_authenticationProvider?.Initialize(serverInfo, _physicalStateObj, this);

_server = serverInfo.ResolvedServerName;

if (connHandler.PoolGroupProviderInfo != null)
Expand Down Expand Up @@ -540,7 +539,7 @@ internal void Connect(
_physicalStateObj.CreatePhysicalSNIHandle(
serverInfo.ExtendedServerName,
timeout, out instanceName,
ref _sniSpnBuffer,
ref _sniSpn,
true,
true,
fParallel,
Expand All @@ -553,15 +552,15 @@ internal void Connect(
hostNameInCertificate,
serverCertificateFilename);

_authenticationProvider?.Initialize(serverInfo, _physicalStateObj, this);

if (TdsEnums.SNI_SUCCESS != _physicalStateObj.Status)
{
_physicalStateObj.AddError(ProcessSNIError(_physicalStateObj));
SqlClientEventSource.Log.TryTraceEvent("<sc.TdsParser.Connect|ERR|SEC> Login failure");
ThrowExceptionAndWarning(_physicalStateObj);
}

_authenticationProvider?.Initialize(serverInfo, _physicalStateObj, this);

uint retCode = _physicalStateObj.SniGetConnectionId(ref _connHandler._clientConnectionId);

Debug.Assert(retCode == TdsEnums.SNI_SUCCESS, "Unexpected failure state upon calling SniGetConnectionId");
Expand Down Expand Up @@ -13317,7 +13316,7 @@ internal string TraceString()
_fMARS ? bool.TrueString : bool.FalseString,
_sessionPool == null ? "(null)" : _sessionPool.TraceString(),
_is2005 ? bool.TrueString : bool.FalseString,
_sniSpnBuffer == null ? "(null)" : _sniSpnBuffer.Length.ToString((IFormatProvider)null),
_sniSpn == null ? "(null)" : _sniSpn.Length.ToString((IFormatProvider)null),
_physicalStateObj != null ? "(null)" : _physicalStateObj.ErrorCount.ToString((IFormatProvider)null),
_physicalStateObj != null ? "(null)" : _physicalStateObj.WarningCount.ToString((IFormatProvider)null),
_physicalStateObj != null ? "(null)" : _physicalStateObj.PreAttentionErrorCount.ToString((IFormatProvider)null),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ internal abstract void CreatePhysicalSNIHandle(
string serverName,
TimeoutTimer timeout,
out byte[] instanceName,
ref byte[][] spnBuffer,
ref string[] spn,
bool flushCache,
bool async,
bool fParallel,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ internal override void CreatePhysicalSNIHandle(
string serverName,
TimeoutTimer timeout,
out byte[] instanceName,
ref byte[][] spnBuffer,
ref string[] spn,
bool flushCache,
bool async,
bool parallel,
Expand All @@ -94,7 +94,7 @@ internal override void CreatePhysicalSNIHandle(
string hostNameInCertificate,
string serverCertificateFilename)
{
SNIHandle? sessionHandle = SNIProxy.CreateConnectionHandle(serverName, timeout, out instanceName, ref spnBuffer, serverSPN,
SNIHandle? sessionHandle = SNIProxy.CreateConnectionHandle(serverName, timeout, out instanceName, ref spn, serverSPN,
flushCache, async, parallel, isIntegratedSecurity, iPAddressPreference, cachedFQDN, ref pendingDNSInfo, tlsFirst,
hostNameInCertificate, serverCertificateFilename);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ internal override void CreatePhysicalSNIHandle(
string serverName,
TimeoutTimer timeout,
out byte[] instanceName,
ref byte[][] spnBuffer,
ref string[] spn,
bool flushCache,
bool async,
bool fParallel,
Expand All @@ -157,31 +157,28 @@ internal override void CreatePhysicalSNIHandle(
string hostNameInCertificate,
string serverCertificateFilename)
{
// We assume that the loadSSPILibrary has been called already. now allocate proper length of buffer
spnBuffer = new byte[1][];
if (isIntegratedSecurity)
{
// now allocate proper length of buffer
if (!string.IsNullOrEmpty(serverSPN))
{
// Native SNI requires the Unicode encoding and any other encoding like UTF8 breaks the code.
byte[] srvSPN = Encoding.Unicode.GetBytes(serverSPN);
Trace.Assert(srvSPN.Length <= SniNativeWrapper.SniMaxComposedSpnLength, "Length of the provided SPN exceeded the buffer size.");
spnBuffer[0] = srvSPN;
SqlClientEventSource.Log.TryTraceEvent("<{0}.{1}|SEC> Server SPN `{2}` from the connection string is used.", nameof(TdsParserStateObjectNative), nameof(CreatePhysicalSNIHandle), serverSPN);
}
else
{
spnBuffer[0] = new byte[SniNativeWrapper.SniMaxComposedSpnLength];
// This will signal to the interop layer that we need to retrieve the SPN
serverSPN = string.Empty;
}
}

ConsumerInfo myInfo = CreateConsumerInfo(async);
SQLDNSInfo cachedDNSInfo;
bool ret = SQLFallbackDNSCache.Instance.GetDNSInfo(cachedFQDN, out cachedDNSInfo);

_sessionHandle = new SNIHandle(myInfo, serverName, spnBuffer[0], timeout.MillisecondsRemainingInt, out instanceName,
_sessionHandle = new SNIHandle(myInfo, serverName, ref serverSPN, timeout.MillisecondsRemainingInt, out instanceName,
flushCache, !async, fParallel, ipPreference, cachedDNSInfo, hostNameInCertificate);
spn = new[] { serverSPN.TrimEnd() };
}

protected override uint SNIPacketGetData(PacketHandle packet, byte[] _inBuff, ref uint dataSize)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -337,10 +337,13 @@
<Compile Include="$(CommonSourceRoot)Microsoft\Data\Sql\SqlDataSourceEnumerator.cs">
<Link>Microsoft\Data\Sql\SqlDataSourceEnumerator.cs</Link>
</Compile>
<Compile Include="$(CommonSourceRoot)Microsoft\Data\SqlClient\SqlObjectPool.cs">
<Link>Microsoft\Data\SqlClient\SqlObjectPool.cs</Link>
</Compile>
<Compile Include="$(CommonSourceRoot)Microsoft\Data\SqlClient\AAsyncCallContext.cs">
<Compile Include="$(CommonSourceRoot)Microsoft\Data\SqlClient\SqlObjectPool.cs">
<Link>Microsoft\Data\SqlClient\SqlObjectPool.cs</Link>
</Compile>
<Compile Include="$(CommonSourceRoot)Microsoft\Data\SqlClient\SqlObjectPools.cs">
<Link>Microsoft\Data\SqlClient\SqlObjectPools.cs</Link>
</Compile>
<Compile Include="$(CommonSourceRoot)Microsoft\Data\SqlClient\AAsyncCallContext.cs">
<Link>Microsoft\Data\SqlClient\AAsyncCallContext.cs</Link>
</Compile>
<Compile Include="$(CommonSourceRoot)Microsoft\Data\SqlClient\ActiveDirectoryAuthenticationTimeoutRetryHelper.cs">
Expand Down Expand Up @@ -854,6 +857,7 @@
<Compile Include="Microsoft\Data\Common\DbConnectionString.cs" />
<Compile Include="Microsoft\Data\Common\GreenMethods.cs" />
<Compile Include="Microsoft\Data\SqlClient\assemblycache.cs" />
<Compile Include="Microsoft\Data\SqlClient\BufferWriterExtensions.cs" />
<Compile Include="Microsoft\Data\SqlClient\Reliability\SqlConfigurableRetryLogicManager.LoadType.cs" />
<Compile Include="Microsoft\Data\SqlClient\Server\SmiConnection.cs" />
<Compile Include="Microsoft\Data\SqlClient\Server\SmiContext.cs" />
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
using System.Buffers;
using System.Text;

namespace Microsoft.Data.SqlClient
{
internal static class BufferWriterExtensions
{
internal static long GetBytes(this Encoding encoding, string str, IBufferWriter<byte> bufferWriter)
{
var count = encoding.GetByteCount(str);
var array = ArrayPool<byte>.Shared.Rent(count);

try
{
encoding.GetBytes(str, 0, str.Length, array, 0);
bufferWriter.Write(array);
return count;
}
finally
{
ArrayPool<byte>.Shared.Return(array);
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ internal int ObjectID

private bool _is2022 = false;

private byte[] _sniSpnBuffer = null;
private string _sniSpn = null;

// UNDONE - need to have some for both instances - both command and default???

Expand Down Expand Up @@ -430,27 +430,24 @@ internal void Connect(ServerInfo serverInfo,
// AD Integrated behaves like Windows integrated when connecting to a non-fedAuth server
if (integratedSecurity || authType == SqlAuthenticationMethod.ActiveDirectoryIntegrated)
{
_authenticationProvider = _physicalStateObj.CreateSSPIContextProvider();

if (!string.IsNullOrEmpty(serverInfo.ServerSPN))
{
// Native SNI requires the Unicode encoding and any other encoding like UTF8 breaks the code.
byte[] srvSPN = Encoding.Unicode.GetBytes(serverInfo.ServerSPN);
Trace.Assert(srvSPN.Length <= SniNativeWrapper.SniMaxComposedSpnLength, "The provided SPN length exceeded the buffer size.");
_sniSpnBuffer = srvSPN;
_sniSpn = serverInfo.ServerSPN;
SqlClientEventSource.Log.TryTraceEvent("<sc.TdsParser.Connect|SEC> Server SPN `{0}` from the connection string is used.", serverInfo.ServerSPN);
}
else
{
// now allocate proper length of buffer
_sniSpnBuffer = new byte[SniNativeWrapper.SniMaxComposedSpnLength];
// Empty signifies to interop layer that SNI needs to be generated
_sniSpn = string.Empty;
}

_authenticationProvider = _physicalStateObj.CreateSSPIContextProvider();
SqlClientEventSource.Log.TryTraceEvent("<sc.TdsParser.Connect|SEC> SSPI or Active Directory Authentication Library for SQL Server based integrated authentication");
}
else
{
_authenticationProvider = null;
_sniSpnBuffer = null;
_sniSpn = null;

switch (authType)
{
Expand Down Expand Up @@ -529,7 +526,7 @@ internal void Connect(ServerInfo serverInfo,
serverInfo.ExtendedServerName,
timeout,
out instanceName,
_sniSpnBuffer,
ref _sniSpn,
false,
true,
fParallel,
Expand All @@ -539,8 +536,6 @@ internal void Connect(ServerInfo serverInfo,
FQDNforDNSCache,
hostNameInCertificate);

_authenticationProvider?.Initialize(serverInfo, _physicalStateObj, this);

if (TdsEnums.SNI_SUCCESS != _physicalStateObj.Status)
{
_physicalStateObj.AddError(ProcessSNIError(_physicalStateObj));
Expand All @@ -555,6 +550,8 @@ internal void Connect(ServerInfo serverInfo,
Debug.Fail("SNI returned status != success, but no error thrown?");
}

_authenticationProvider?.Initialize(serverInfo, _physicalStateObj, this);

_server = serverInfo.ResolvedServerName;

if (connHandler.PoolGroupProviderInfo != null)
Expand Down Expand Up @@ -629,7 +626,7 @@ internal void Connect(ServerInfo serverInfo,
serverInfo.ExtendedServerName,
timeout,
out instanceName,
_sniSpnBuffer,
ref _sniSpn,
true,
true,
fParallel,
Expand All @@ -639,15 +636,15 @@ internal void Connect(ServerInfo serverInfo,
serverInfo.ResolvedServerName,
hostNameInCertificate);

_authenticationProvider?.Initialize(serverInfo, _physicalStateObj, this);

if (TdsEnums.SNI_SUCCESS != _physicalStateObj.Status)
{
_physicalStateObj.AddError(ProcessSNIError(_physicalStateObj));
SqlClientEventSource.Log.TryTraceEvent("<sc.TdsParser.Connect|ERR|SEC> Login failure");
ThrowExceptionAndWarning(_physicalStateObj);
}

_authenticationProvider?.Initialize(serverInfo, _physicalStateObj, this);

uint retCode = SniNativeWrapper.SniGetConnectionId(_physicalStateObj.Handle, ref _connHandler._clientConnectionId);
Debug.Assert(retCode == TdsEnums.SNI_SUCCESS, "Unexpected failure state upon calling SniGetConnectionId");
SqlClientEventSource.Log.TryTraceEvent("<sc.TdsParser.Connect|SEC> Sending prelogin handshake");
Expand Down Expand Up @@ -13785,7 +13782,7 @@ internal string TraceString()
_is2000 ? bool.TrueString : bool.FalseString,
_is2000SP1 ? bool.TrueString : bool.FalseString,
_is2005 ? bool.TrueString : bool.FalseString,
_sniSpnBuffer == null ? "(null)" : _sniSpnBuffer.Length.ToString((IFormatProvider)null),
_sniSpn == null ? "(null)" : _sniSpn.Length.ToString((IFormatProvider)null),
_physicalStateObj != null ? "(null)" : _physicalStateObj.ErrorCount.ToString((IFormatProvider)null),
_physicalStateObj != null ? "(null)" : _physicalStateObj.WarningCount.ToString((IFormatProvider)null),
_physicalStateObj != null ? "(null)" : _physicalStateObj.PreAttentionErrorCount.ToString((IFormatProvider)null),
Expand Down
Loading
Loading