Skip to content
This repository was archived by the owner on Jan 23, 2023. It is now read-only.

Commit 0d07c5c

Browse files
committed
Enable custom socket end points and allow Unix Domain Sockets
With code separation into System.Net.Primitives and System.Net.Sockets, EndPoint extensibility was broken because System.Net.Sockets started to use its own copy of SocketAddress and didn't respect SocketAddress that a custom EndPoint may provide. The fix is to allow conversion between SocketAddress from System.Net.Primitives and System.Net.Sockets. This way custom implementations of EndPoint will be able to provide their own SocketAddress and it'll be honored by the Socket APIs. The fix also allows sockets to use 'Unspecified' protocol type which is needed for Unix Domain Sockets. There are several changes in socket test server to allow tests pass protocol type. Add new unit tests that use end point extensibility to implement Unix Domain Sockets. Fix #4777
1 parent 975c3ca commit 0d07c5c

File tree

8 files changed

+289
-18
lines changed

8 files changed

+289
-18
lines changed

src/Common/src/System/Net/Internals/IPEndPointExtensions.cs

Lines changed: 44 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,28 +9,65 @@ internal static class IPEndPointExtensions
99
{
1010
public static Internals.SocketAddress Serialize(EndPoint endpoint)
1111
{
12-
Debug.Assert(endpoint is IPEndPoint);
12+
Debug.Assert(!(endpoint is DnsEndPoint));
1313

14-
return new Internals.SocketAddress(((IPEndPoint)endpoint).Address, ((IPEndPoint)endpoint).Port);
14+
var ipEndPoint = endpoint as IPEndPoint;
15+
if (ipEndPoint != null)
16+
{
17+
return new Internals.SocketAddress(ipEndPoint.Address, ipEndPoint.Port);
18+
}
19+
20+
System.Net.SocketAddress address = endpoint.Serialize();
21+
return GetInternalSocketAddress(address);
1522
}
1623

1724
public static EndPoint Create(this EndPoint thisObj, Internals.SocketAddress socketAddress)
1825
{
19-
if (socketAddress.Family != thisObj.AddressFamily)
26+
AddressFamily family = socketAddress.Family;
27+
if (family != thisObj.AddressFamily)
2028
{
21-
throw new ArgumentException(SR.Format(SR.net_InvalidAddressFamily, socketAddress.Family.ToString(), thisObj.GetType().FullName, thisObj.AddressFamily.ToString()), "socketAddress");
29+
throw new ArgumentException(SR.Format(SR.net_InvalidAddressFamily, family.ToString(), thisObj.GetType().FullName, thisObj.AddressFamily.ToString()), "socketAddress");
2230
}
23-
if (socketAddress.Size < 8)
31+
32+
if (family == AddressFamily.InterNetwork || family == AddressFamily.InterNetworkV6)
2433
{
25-
throw new ArgumentException(SR.Format(SR.net_InvalidSocketAddressSize, socketAddress.GetType().FullName, thisObj.GetType().FullName), "socketAddress");
34+
if (socketAddress.Size < 8)
35+
{
36+
throw new ArgumentException(SR.Format(SR.net_InvalidSocketAddressSize, socketAddress.GetType().FullName, thisObj.GetType().FullName), "socketAddress");
37+
}
38+
39+
return socketAddress.GetIPEndPoint();
2640
}
2741

28-
return socketAddress.GetIPEndPoint();
42+
System.Net.SocketAddress address = GetNetSocketAddress(socketAddress);
43+
return thisObj.Create(address);
2944
}
3045

3146
internal static IPEndPoint Snapshot(this IPEndPoint thisObj)
3247
{
3348
return new IPEndPoint(thisObj.Address.Snapshot(), thisObj.Port);
3449
}
50+
51+
private static Internals.SocketAddress GetInternalSocketAddress(System.Net.SocketAddress address)
52+
{
53+
var result = new Internals.SocketAddress(address.Family, address.Size);
54+
for (int index = 0; index < address.Size; index++)
55+
{
56+
result[index] = address[index];
57+
}
58+
59+
return result;
60+
}
61+
62+
private static System.Net.SocketAddress GetNetSocketAddress(Internals.SocketAddress address)
63+
{
64+
var result = new System.Net.SocketAddress(address.Family, address.Size);
65+
for (int index = 0; index < address.Size; index++)
66+
{
67+
result[index] = address[index];
68+
}
69+
70+
return result;
71+
}
3572
}
3673
}

src/Common/tests/System/Net/Sockets/SocketTestServer.cs

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@ public abstract partial class SocketTestServer : IDisposable
1212

1313
protected abstract int Port { get; }
1414

15-
public static SocketTestServer SocketTestServerFactory(EndPoint endpoint)
15+
public static SocketTestServer SocketTestServerFactory(EndPoint endpoint, ProtocolType protocolType = ProtocolType.Tcp)
1616
{
17-
return SocketTestServerFactory(DefaultNumConnections, DefaultReceiveBufferSize, endpoint);
17+
return SocketTestServerFactory(DefaultNumConnections, DefaultReceiveBufferSize, endpoint, protocolType);
1818
}
1919

2020
public static SocketTestServer SocketTestServerFactory(IPAddress address, out int port)
@@ -25,13 +25,15 @@ public static SocketTestServer SocketTestServerFactory(IPAddress address, out in
2525
public static SocketTestServer SocketTestServerFactory(
2626
int numConnections,
2727
int receiveBufferSize,
28-
EndPoint localEndPoint)
28+
EndPoint localEndPoint,
29+
ProtocolType protocolType = ProtocolType.Tcp)
2930
{
3031
return SocketTestServerFactory(
3132
s_implementationType,
3233
numConnections,
3334
receiveBufferSize,
34-
localEndPoint);
35+
localEndPoint,
36+
protocolType);
3537
}
3638

3739
public static SocketTestServer SocketTestServerFactory(
@@ -52,14 +54,15 @@ public static SocketTestServer SocketTestServerFactory(
5254
SocketImplementationType type,
5355
int numConnections,
5456
int receiveBufferSize,
55-
EndPoint localEndPoint)
57+
EndPoint localEndPoint,
58+
ProtocolType protocolType = ProtocolType.Tcp)
5659
{
5760
switch (type)
5861
{
5962
case SocketImplementationType.APM:
6063
return new SocketTestServerAPM(numConnections, receiveBufferSize, localEndPoint);
6164
case SocketImplementationType.Async:
62-
return new SocketTestServerAsync(numConnections, receiveBufferSize, localEndPoint);
65+
return new SocketTestServerAsync(numConnections, receiveBufferSize, localEndPoint, protocolType);
6366
default:
6467
throw new ArgumentOutOfRangeException("type");
6568
}

src/Common/tests/System/Net/Sockets/SocketTestServerAsync.cs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,13 @@ public class SocketTestServerAsync : SocketTestServer
3131
private int _numConnectedSockets; // The total number of clients connected to the server.
3232
private Semaphore _maxNumberAcceptedClientsSemaphore;
3333
private int _acceptRetryCount = 10;
34+
private ProtocolType _protocolType;
3435

3536
private object _listenSocketLock = new object();
3637

3738
protected sealed override int Port { get { return ((IPEndPoint)_listenSocket.LocalEndPoint).Port; } }
3839

39-
public SocketTestServerAsync(int numConnections, int receiveBufferSize, EndPoint localEndPoint)
40+
public SocketTestServerAsync(int numConnections, int receiveBufferSize, EndPoint localEndPoint, ProtocolType protocolType = ProtocolType.Tcp)
4041
{
4142
_log = VerboseTestLogging.GetInstance();
4243
_totalBytesRead = 0;
@@ -51,6 +52,7 @@ public SocketTestServerAsync(int numConnections, int receiveBufferSize, EndPoint
5152

5253
_readWritePool = new SocketAsyncEventArgsPool(numConnections);
5354
_maxNumberAcceptedClientsSemaphore = new Semaphore(numConnections, numConnections);
55+
_protocolType = protocolType;
5456
Init();
5557
Start(localEndPoint);
5658
}
@@ -108,7 +110,7 @@ private void Init()
108110
private void Start(EndPoint localEndPoint)
109111
{
110112
// Create the socket which listens for incoming connections.
111-
_listenSocket = new Socket(localEndPoint.AddressFamily, SocketType.Stream, ProtocolType.Tcp);
113+
_listenSocket = new Socket(localEndPoint.AddressFamily, SocketType.Stream, _protocolType);
112114
_listenSocket.Bind(localEndPoint);
113115

114116
// Start the server with a listen backlog of 100 connections.

src/Native/System.Native/pal_networking.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1863,6 +1863,10 @@ static bool TryConvertProtocolTypePalToPlatform(int32_t palProtocolType, int* pl
18631863

18641864
switch (palProtocolType)
18651865
{
1866+
case PAL_PT_UNSPECIFIED:
1867+
*platformProtocolType = 0;
1868+
return true;
1869+
18661870
case PAL_PT_ICMP:
18671871
*platformProtocolType = IPPROTO_ICMP;
18681872
return true;

src/Native/System.Native/pal_networking.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ enum SocketType : int32_t
8181
*/
8282
enum ProtocolType : int32_t
8383
{
84+
PAL_PT_UNSPECIFIED = 0, // System.Net.ProtocolType.Unspecified
8485
PAL_PT_ICMP = 1, // System.Net.ProtocolType.Icmp
8586
PAL_PT_TCP = 6, // System.Net.ProtocolType.Tcp
8687
PAL_PT_UDP = 17, // System.Net.ProtocolType.Udp

src/System.Net.Sockets/src/System/Net/Sockets/Socket.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4231,7 +4231,7 @@ public bool ConnectAsync(SocketAsyncEventArgs e)
42314231
{
42324232
InternalBind(new IPEndPoint(IPAddress.Any, 0));
42334233
}
4234-
else
4234+
else if (endPointSnapshot.AddressFamily != AddressFamily.Unix)
42354235
{
42364236
InternalBind(new IPEndPoint(IPAddress.IPv6Any, 0));
42374237
}

src/System.Net.Sockets/tests/FunctionalTests/System.Net.Sockets.Tests.csproj

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,8 @@
2727
<Compile Include="SocketAsyncExtensions.cs" />
2828
<Compile Include="SocketOptionNameTest.cs" />
2929
<Compile Include="SocketTestServerAPMMock.cs" />
30-
30+
<Compile Include="UnixDomainSocketTest.cs" />
31+
3132
<!-- Common Sockets files -->
3233
<Compile Include="$(CommonTestPath)\System\Net\Sockets\Configuration.cs">
3334
<Link>SocketCommon\Configuration.cs</Link>
@@ -77,7 +78,7 @@
7778
<Name>System.Net.Sockets</Name>
7879
</ProjectReference>
7980
</ItemGroup>
80-
81+
8182
<ItemGroup>
8283
<None Include="project.json" />
8384
</ItemGroup>

0 commit comments

Comments
 (0)