diff --git a/src/Common/src/System/Net/Internals/IPEndPointExtensions.cs b/src/Common/src/System/Net/Internals/IPEndPointExtensions.cs
index ca35f98c7071..e575dd187b92 100644
--- a/src/Common/src/System/Net/Internals/IPEndPointExtensions.cs
+++ b/src/Common/src/System/Net/Internals/IPEndPointExtensions.cs
@@ -9,28 +9,65 @@ internal static class IPEndPointExtensions
{
public static Internals.SocketAddress Serialize(EndPoint endpoint)
{
- Debug.Assert(endpoint is IPEndPoint);
+ Debug.Assert(!(endpoint is DnsEndPoint));
- return new Internals.SocketAddress(((IPEndPoint)endpoint).Address, ((IPEndPoint)endpoint).Port);
+ var ipEndPoint = endpoint as IPEndPoint;
+ if (ipEndPoint != null)
+ {
+ return new Internals.SocketAddress(ipEndPoint.Address, ipEndPoint.Port);
+ }
+
+ System.Net.SocketAddress address = endpoint.Serialize();
+ return GetInternalSocketAddress(address);
}
public static EndPoint Create(this EndPoint thisObj, Internals.SocketAddress socketAddress)
{
- if (socketAddress.Family != thisObj.AddressFamily)
+ AddressFamily family = socketAddress.Family;
+ if (family != thisObj.AddressFamily)
{
- throw new ArgumentException(SR.Format(SR.net_InvalidAddressFamily, socketAddress.Family.ToString(), thisObj.GetType().FullName, thisObj.AddressFamily.ToString()), "socketAddress");
+ throw new ArgumentException(SR.Format(SR.net_InvalidAddressFamily, family.ToString(), thisObj.GetType().FullName, thisObj.AddressFamily.ToString()), "socketAddress");
}
- if (socketAddress.Size < 8)
+
+ if (family == AddressFamily.InterNetwork || family == AddressFamily.InterNetworkV6)
{
- throw new ArgumentException(SR.Format(SR.net_InvalidSocketAddressSize, socketAddress.GetType().FullName, thisObj.GetType().FullName), "socketAddress");
+ if (socketAddress.Size < 8)
+ {
+ throw new ArgumentException(SR.Format(SR.net_InvalidSocketAddressSize, socketAddress.GetType().FullName, thisObj.GetType().FullName), "socketAddress");
+ }
+
+ return socketAddress.GetIPEndPoint();
}
- return socketAddress.GetIPEndPoint();
+ System.Net.SocketAddress address = GetNetSocketAddress(socketAddress);
+ return thisObj.Create(address);
}
internal static IPEndPoint Snapshot(this IPEndPoint thisObj)
{
return new IPEndPoint(thisObj.Address.Snapshot(), thisObj.Port);
}
+
+ private static Internals.SocketAddress GetInternalSocketAddress(System.Net.SocketAddress address)
+ {
+ var result = new Internals.SocketAddress(address.Family, address.Size);
+ for (int index = 0; index < address.Size; index++)
+ {
+ result[index] = address[index];
+ }
+
+ return result;
+ }
+
+ private static System.Net.SocketAddress GetNetSocketAddress(Internals.SocketAddress address)
+ {
+ var result = new System.Net.SocketAddress(address.Family, address.Size);
+ for (int index = 0; index < address.Size; index++)
+ {
+ result[index] = address[index];
+ }
+
+ return result;
+ }
}
}
diff --git a/src/Common/tests/System/Net/Sockets/SocketTestServer.cs b/src/Common/tests/System/Net/Sockets/SocketTestServer.cs
index 4c23bd2b30c5..e44c3a41f15f 100644
--- a/src/Common/tests/System/Net/Sockets/SocketTestServer.cs
+++ b/src/Common/tests/System/Net/Sockets/SocketTestServer.cs
@@ -12,9 +12,9 @@ public abstract partial class SocketTestServer : IDisposable
protected abstract int Port { get; }
- public static SocketTestServer SocketTestServerFactory(EndPoint endpoint)
+ public static SocketTestServer SocketTestServerFactory(EndPoint endpoint, ProtocolType protocolType = ProtocolType.Tcp)
{
- return SocketTestServerFactory(DefaultNumConnections, DefaultReceiveBufferSize, endpoint);
+ return SocketTestServerFactory(DefaultNumConnections, DefaultReceiveBufferSize, endpoint, protocolType);
}
public static SocketTestServer SocketTestServerFactory(IPAddress address, out int port)
@@ -25,13 +25,15 @@ public static SocketTestServer SocketTestServerFactory(IPAddress address, out in
public static SocketTestServer SocketTestServerFactory(
int numConnections,
int receiveBufferSize,
- EndPoint localEndPoint)
+ EndPoint localEndPoint,
+ ProtocolType protocolType = ProtocolType.Tcp)
{
return SocketTestServerFactory(
s_implementationType,
numConnections,
receiveBufferSize,
- localEndPoint);
+ localEndPoint,
+ protocolType);
}
public static SocketTestServer SocketTestServerFactory(
@@ -52,14 +54,15 @@ public static SocketTestServer SocketTestServerFactory(
SocketImplementationType type,
int numConnections,
int receiveBufferSize,
- EndPoint localEndPoint)
+ EndPoint localEndPoint,
+ ProtocolType protocolType = ProtocolType.Tcp)
{
switch (type)
{
case SocketImplementationType.APM:
return new SocketTestServerAPM(numConnections, receiveBufferSize, localEndPoint);
case SocketImplementationType.Async:
- return new SocketTestServerAsync(numConnections, receiveBufferSize, localEndPoint);
+ return new SocketTestServerAsync(numConnections, receiveBufferSize, localEndPoint, protocolType);
default:
throw new ArgumentOutOfRangeException("type");
}
diff --git a/src/Common/tests/System/Net/Sockets/SocketTestServerAsync.cs b/src/Common/tests/System/Net/Sockets/SocketTestServerAsync.cs
index 02b82030e1de..6e983aeebfb9 100644
--- a/src/Common/tests/System/Net/Sockets/SocketTestServerAsync.cs
+++ b/src/Common/tests/System/Net/Sockets/SocketTestServerAsync.cs
@@ -31,12 +31,13 @@ public class SocketTestServerAsync : SocketTestServer
private int _numConnectedSockets; // The total number of clients connected to the server.
private Semaphore _maxNumberAcceptedClientsSemaphore;
private int _acceptRetryCount = 10;
+ private ProtocolType _protocolType;
private object _listenSocketLock = new object();
protected sealed override int Port { get { return ((IPEndPoint)_listenSocket.LocalEndPoint).Port; } }
- public SocketTestServerAsync(int numConnections, int receiveBufferSize, EndPoint localEndPoint)
+ public SocketTestServerAsync(int numConnections, int receiveBufferSize, EndPoint localEndPoint, ProtocolType protocolType = ProtocolType.Tcp)
{
_log = VerboseTestLogging.GetInstance();
_totalBytesRead = 0;
@@ -51,6 +52,7 @@ public SocketTestServerAsync(int numConnections, int receiveBufferSize, EndPoint
_readWritePool = new SocketAsyncEventArgsPool(numConnections);
_maxNumberAcceptedClientsSemaphore = new Semaphore(numConnections, numConnections);
+ _protocolType = protocolType;
Init();
Start(localEndPoint);
}
@@ -108,7 +110,7 @@ private void Init()
private void Start(EndPoint localEndPoint)
{
// Create the socket which listens for incoming connections.
- _listenSocket = new Socket(localEndPoint.AddressFamily, SocketType.Stream, ProtocolType.Tcp);
+ _listenSocket = new Socket(localEndPoint.AddressFamily, SocketType.Stream, _protocolType);
_listenSocket.Bind(localEndPoint);
// Start the server with a listen backlog of 100 connections.
diff --git a/src/Native/System.Native/pal_networking.cpp b/src/Native/System.Native/pal_networking.cpp
index 5ac8cccc6367..a7c8f503f02f 100644
--- a/src/Native/System.Native/pal_networking.cpp
+++ b/src/Native/System.Native/pal_networking.cpp
@@ -1863,6 +1863,10 @@ static bool TryConvertProtocolTypePalToPlatform(int32_t palProtocolType, int* pl
switch (palProtocolType)
{
+ case PAL_PT_UNSPECIFIED:
+ *platformProtocolType = 0;
+ return true;
+
case PAL_PT_ICMP:
*platformProtocolType = IPPROTO_ICMP;
return true;
diff --git a/src/Native/System.Native/pal_networking.h b/src/Native/System.Native/pal_networking.h
index d5579aac0a23..891a3c88673d 100644
--- a/src/Native/System.Native/pal_networking.h
+++ b/src/Native/System.Native/pal_networking.h
@@ -81,6 +81,7 @@ enum SocketType : int32_t
*/
enum ProtocolType : int32_t
{
+ PAL_PT_UNSPECIFIED = 0, // System.Net.ProtocolType.Unspecified
PAL_PT_ICMP = 1, // System.Net.ProtocolType.Icmp
PAL_PT_TCP = 6, // System.Net.ProtocolType.Tcp
PAL_PT_UDP = 17, // System.Net.ProtocolType.Udp
diff --git a/src/System.Net.Sockets/src/System/Net/Sockets/Socket.cs b/src/System.Net.Sockets/src/System/Net/Sockets/Socket.cs
index c036c911d94b..bcb97c8defae 100644
--- a/src/System.Net.Sockets/src/System/Net/Sockets/Socket.cs
+++ b/src/System.Net.Sockets/src/System/Net/Sockets/Socket.cs
@@ -4231,7 +4231,7 @@ public bool ConnectAsync(SocketAsyncEventArgs e)
{
InternalBind(new IPEndPoint(IPAddress.Any, 0));
}
- else
+ else if (endPointSnapshot.AddressFamily != AddressFamily.Unix)
{
InternalBind(new IPEndPoint(IPAddress.IPv6Any, 0));
}
diff --git a/src/System.Net.Sockets/tests/FunctionalTests/System.Net.Sockets.Tests.csproj b/src/System.Net.Sockets/tests/FunctionalTests/System.Net.Sockets.Tests.csproj
index 568a83233e85..9b4e527c2ea1 100644
--- a/src/System.Net.Sockets/tests/FunctionalTests/System.Net.Sockets.Tests.csproj
+++ b/src/System.Net.Sockets/tests/FunctionalTests/System.Net.Sockets.Tests.csproj
@@ -27,7 +27,8 @@
-
+
+
SocketCommon\Configuration.cs
@@ -77,7 +78,7 @@
System.Net.Sockets
-
+
diff --git a/src/System.Net.Sockets/tests/FunctionalTests/UnixDomainSocketTest.cs b/src/System.Net.Sockets/tests/FunctionalTests/UnixDomainSocketTest.cs
new file mode 100644
index 000000000000..8d01972f2fdc
--- /dev/null
+++ b/src/System.Net.Sockets/tests/FunctionalTests/UnixDomainSocketTest.cs
@@ -0,0 +1,223 @@
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+
+using System.IO;
+using System.Net.Test.Common;
+using System.Text;
+using System.Threading;
+
+using Xunit;
+using Xunit.Abstractions;
+
+namespace System.Net.Sockets.Tests
+{
+ public class UnixDomainSocketTest
+ {
+ private readonly ITestOutputHelper _log;
+
+ public UnixDomainSocketTest(ITestOutputHelper output)
+ {
+ _log = TestLogging.GetInstance();
+ }
+
+ private void OnConnectAsyncCompleted(object sender, SocketAsyncEventArgs args)
+ {
+ ManualResetEvent complete = (ManualResetEvent)args.UserToken;
+ complete.Set();
+ }
+
+ [Fact]
+ [PlatformSpecific(PlatformID.Windows)]
+ public void Socket_CreateUnixDomainSocket_Throws_OnWindows()
+ {
+ // Throws SocketException with this message "An address incompatible with the requested protocol was used"
+ Assert.Throws(() => new Socket(AddressFamily.Unix, SocketType.Stream, ProtocolType.Unspecified));
+ }
+
+ [Fact]
+ [PlatformSpecific(PlatformID.Linux | PlatformID.OSX)]
+ public void Socket_ConnectAsyncUnixDomainSocketEndPoint_Success()
+ {
+ string path = null;
+ SocketTestServer server = null;
+ UnixDomainSocketEndPoint endPoint = null;
+
+ for (int attempt = 0; attempt < 5; attempt++)
+ {
+ path = GetRandomNonExistingFilePath();
+ endPoint = new UnixDomainSocketEndPoint(path);
+ try
+ {
+ server = SocketTestServer.SocketTestServerFactory(endPoint, ProtocolType.Unspecified);
+ break;
+ }
+ catch (SocketException)
+ {
+ // Path selection is contingent on a successful Bind().
+ // If it fails, the next iteration will try another path.
+ }
+ }
+
+ try
+ {
+ Assert.NotNull(server);
+
+ SocketAsyncEventArgs args = new SocketAsyncEventArgs();
+ args.RemoteEndPoint = endPoint;
+ args.Completed += OnConnectAsyncCompleted;
+
+ ManualResetEvent complete = new ManualResetEvent(false);
+ args.UserToken = complete;
+
+ Socket sock = new Socket(AddressFamily.Unix, SocketType.Stream, ProtocolType.Unspecified);
+ Assert.True(sock.ConnectAsync(args));
+
+ complete.WaitOne();
+
+ Assert.Equal(SocketError.Success, args.SocketError);
+ Assert.Null(args.ConnectByNameError);
+
+ complete.Dispose();
+ sock.Dispose();
+ server.Dispose();
+ }
+ finally
+ {
+ File.Delete(path);
+ }
+ }
+
+ [Fact]
+ [PlatformSpecific(PlatformID.Linux | PlatformID.OSX)]
+ public void Socket_ConnectAsyncUnixDomainSocketEndPoint_NotServer()
+ {
+ string path = GetRandomNonExistingFilePath();
+ var endPoint = new UnixDomainSocketEndPoint(path);
+
+ SocketAsyncEventArgs args = new SocketAsyncEventArgs();
+ args.RemoteEndPoint = endPoint;
+ args.Completed += OnConnectAsyncCompleted;
+
+ ManualResetEvent complete = new ManualResetEvent(false);
+ args.UserToken = complete;
+
+ Socket sock = new Socket(AddressFamily.Unix, SocketType.Stream, ProtocolType.Unspecified);
+
+ bool willRaiseEvent = sock.ConnectAsync(args);
+ if (willRaiseEvent)
+ {
+ complete.WaitOne();
+ }
+
+ Assert.Equal(SocketError.SocketError, args.SocketError);
+
+ complete.Dispose();
+ sock.Dispose();
+ }
+
+ private static string GetRandomNonExistingFilePath()
+ {
+ string result;
+ do
+ {
+ result = Path.Combine(Path.GetTempPath(), Path.GetRandomFileName());
+ }
+ while (File.Exists(result));
+
+ return result;
+ }
+
+ private class UnixDomainSocketEndPoint : EndPoint
+ {
+ private static readonly Encoding PathEncoding = Encoding.UTF8;
+
+ private const int MaxPathLength = 92; // sockaddr_un.sun_path at http://pubs.opengroup.org/onlinepubs/9699919799/basedefs/sys_un.h.html
+ private const int PathOffset = 2; // = offsetof(struct sockaddr_un, sun_path). It's the same on Linux and OSX
+ private const int MaxSocketAddressSize = PathOffset + MaxPathLength;
+ private const int MinSocketAddressSize = PathOffset + 2; // +1 for one character and +1 for \0 ending
+ private const AddressFamily EndPointAddressFamily = AddressFamily.Unix;
+
+ private readonly string _path;
+ private readonly byte[] _encodedPath;
+
+ public UnixDomainSocketEndPoint(string path)
+ {
+ if (path == null)
+ {
+ throw new ArgumentNullException("path");
+ }
+
+ if (path.Length == 0 || PathEncoding.GetByteCount(path) >= MaxPathLength)
+ {
+ throw new ArgumentOutOfRangeException("path");
+ }
+
+ _path = path;
+ _encodedPath = PathEncoding.GetBytes(_path);
+ }
+
+ internal UnixDomainSocketEndPoint(SocketAddress socketAddress)
+ {
+ if (socketAddress == null)
+ {
+ throw new ArgumentNullException("socketAddress");
+ }
+
+ if (socketAddress.Family != EndPointAddressFamily || socketAddress.Size < MinSocketAddressSize || socketAddress.Size > MaxSocketAddressSize)
+ {
+ throw new ArgumentException("socketAddress");
+ }
+
+ _encodedPath = new byte[socketAddress.Size - PathOffset];
+ for (int index = 0; index < socketAddress.Size - PathOffset; index++)
+ {
+ _encodedPath[index] = socketAddress[PathOffset + index];
+ }
+
+ _path = PathEncoding.GetString(_encodedPath);
+ }
+
+ public string Path
+ {
+ get
+ {
+ return _path;
+ }
+ }
+
+ public override AddressFamily AddressFamily
+ {
+ get
+ {
+ return EndPointAddressFamily;
+ }
+ }
+
+ public override SocketAddress Serialize()
+ {
+ SocketAddress result = new SocketAddress(AddressFamily.Unix, MaxSocketAddressSize);
+
+ // Ctor has already checked that PathOffset + _encodedPath.Length < MaxSocketAddressSize
+ for (int index = 0; index < _encodedPath.Length; index++)
+ {
+ result[PathOffset + index] = _encodedPath[index];
+ }
+
+ // The path must be ending with \0
+ result[PathOffset + _encodedPath.Length] = 0;
+
+ return result;
+ }
+
+ public override EndPoint Create(SocketAddress socketAddress)
+ {
+ return new UnixDomainSocketEndPoint(socketAddress);
+ }
+
+ public override string ToString()
+ {
+ return Path;
+ }
+ }
+ }
+}