diff --git a/src/Common/src/System/Net/Internals/IPEndPointExtensions.cs b/src/Common/src/System/Net/Internals/IPEndPointExtensions.cs index ca35f98c7071..e575dd187b92 100644 --- a/src/Common/src/System/Net/Internals/IPEndPointExtensions.cs +++ b/src/Common/src/System/Net/Internals/IPEndPointExtensions.cs @@ -9,28 +9,65 @@ internal static class IPEndPointExtensions { public static Internals.SocketAddress Serialize(EndPoint endpoint) { - Debug.Assert(endpoint is IPEndPoint); + Debug.Assert(!(endpoint is DnsEndPoint)); - return new Internals.SocketAddress(((IPEndPoint)endpoint).Address, ((IPEndPoint)endpoint).Port); + var ipEndPoint = endpoint as IPEndPoint; + if (ipEndPoint != null) + { + return new Internals.SocketAddress(ipEndPoint.Address, ipEndPoint.Port); + } + + System.Net.SocketAddress address = endpoint.Serialize(); + return GetInternalSocketAddress(address); } public static EndPoint Create(this EndPoint thisObj, Internals.SocketAddress socketAddress) { - if (socketAddress.Family != thisObj.AddressFamily) + AddressFamily family = socketAddress.Family; + if (family != thisObj.AddressFamily) { - throw new ArgumentException(SR.Format(SR.net_InvalidAddressFamily, socketAddress.Family.ToString(), thisObj.GetType().FullName, thisObj.AddressFamily.ToString()), "socketAddress"); + throw new ArgumentException(SR.Format(SR.net_InvalidAddressFamily, family.ToString(), thisObj.GetType().FullName, thisObj.AddressFamily.ToString()), "socketAddress"); } - if (socketAddress.Size < 8) + + if (family == AddressFamily.InterNetwork || family == AddressFamily.InterNetworkV6) { - throw new ArgumentException(SR.Format(SR.net_InvalidSocketAddressSize, socketAddress.GetType().FullName, thisObj.GetType().FullName), "socketAddress"); + if (socketAddress.Size < 8) + { + throw new ArgumentException(SR.Format(SR.net_InvalidSocketAddressSize, socketAddress.GetType().FullName, thisObj.GetType().FullName), "socketAddress"); + } + + return socketAddress.GetIPEndPoint(); } - return socketAddress.GetIPEndPoint(); + System.Net.SocketAddress address = GetNetSocketAddress(socketAddress); + return thisObj.Create(address); } internal static IPEndPoint Snapshot(this IPEndPoint thisObj) { return new IPEndPoint(thisObj.Address.Snapshot(), thisObj.Port); } + + private static Internals.SocketAddress GetInternalSocketAddress(System.Net.SocketAddress address) + { + var result = new Internals.SocketAddress(address.Family, address.Size); + for (int index = 0; index < address.Size; index++) + { + result[index] = address[index]; + } + + return result; + } + + private static System.Net.SocketAddress GetNetSocketAddress(Internals.SocketAddress address) + { + var result = new System.Net.SocketAddress(address.Family, address.Size); + for (int index = 0; index < address.Size; index++) + { + result[index] = address[index]; + } + + return result; + } } } diff --git a/src/Common/tests/System/Net/Sockets/SocketTestServer.cs b/src/Common/tests/System/Net/Sockets/SocketTestServer.cs index 4c23bd2b30c5..e44c3a41f15f 100644 --- a/src/Common/tests/System/Net/Sockets/SocketTestServer.cs +++ b/src/Common/tests/System/Net/Sockets/SocketTestServer.cs @@ -12,9 +12,9 @@ public abstract partial class SocketTestServer : IDisposable protected abstract int Port { get; } - public static SocketTestServer SocketTestServerFactory(EndPoint endpoint) + public static SocketTestServer SocketTestServerFactory(EndPoint endpoint, ProtocolType protocolType = ProtocolType.Tcp) { - return SocketTestServerFactory(DefaultNumConnections, DefaultReceiveBufferSize, endpoint); + return SocketTestServerFactory(DefaultNumConnections, DefaultReceiveBufferSize, endpoint, protocolType); } public static SocketTestServer SocketTestServerFactory(IPAddress address, out int port) @@ -25,13 +25,15 @@ public static SocketTestServer SocketTestServerFactory(IPAddress address, out in public static SocketTestServer SocketTestServerFactory( int numConnections, int receiveBufferSize, - EndPoint localEndPoint) + EndPoint localEndPoint, + ProtocolType protocolType = ProtocolType.Tcp) { return SocketTestServerFactory( s_implementationType, numConnections, receiveBufferSize, - localEndPoint); + localEndPoint, + protocolType); } public static SocketTestServer SocketTestServerFactory( @@ -52,14 +54,15 @@ public static SocketTestServer SocketTestServerFactory( SocketImplementationType type, int numConnections, int receiveBufferSize, - EndPoint localEndPoint) + EndPoint localEndPoint, + ProtocolType protocolType = ProtocolType.Tcp) { switch (type) { case SocketImplementationType.APM: return new SocketTestServerAPM(numConnections, receiveBufferSize, localEndPoint); case SocketImplementationType.Async: - return new SocketTestServerAsync(numConnections, receiveBufferSize, localEndPoint); + return new SocketTestServerAsync(numConnections, receiveBufferSize, localEndPoint, protocolType); default: throw new ArgumentOutOfRangeException("type"); } diff --git a/src/Common/tests/System/Net/Sockets/SocketTestServerAsync.cs b/src/Common/tests/System/Net/Sockets/SocketTestServerAsync.cs index 02b82030e1de..6e983aeebfb9 100644 --- a/src/Common/tests/System/Net/Sockets/SocketTestServerAsync.cs +++ b/src/Common/tests/System/Net/Sockets/SocketTestServerAsync.cs @@ -31,12 +31,13 @@ public class SocketTestServerAsync : SocketTestServer private int _numConnectedSockets; // The total number of clients connected to the server. private Semaphore _maxNumberAcceptedClientsSemaphore; private int _acceptRetryCount = 10; + private ProtocolType _protocolType; private object _listenSocketLock = new object(); protected sealed override int Port { get { return ((IPEndPoint)_listenSocket.LocalEndPoint).Port; } } - public SocketTestServerAsync(int numConnections, int receiveBufferSize, EndPoint localEndPoint) + public SocketTestServerAsync(int numConnections, int receiveBufferSize, EndPoint localEndPoint, ProtocolType protocolType = ProtocolType.Tcp) { _log = VerboseTestLogging.GetInstance(); _totalBytesRead = 0; @@ -51,6 +52,7 @@ public SocketTestServerAsync(int numConnections, int receiveBufferSize, EndPoint _readWritePool = new SocketAsyncEventArgsPool(numConnections); _maxNumberAcceptedClientsSemaphore = new Semaphore(numConnections, numConnections); + _protocolType = protocolType; Init(); Start(localEndPoint); } @@ -108,7 +110,7 @@ private void Init() private void Start(EndPoint localEndPoint) { // Create the socket which listens for incoming connections. - _listenSocket = new Socket(localEndPoint.AddressFamily, SocketType.Stream, ProtocolType.Tcp); + _listenSocket = new Socket(localEndPoint.AddressFamily, SocketType.Stream, _protocolType); _listenSocket.Bind(localEndPoint); // Start the server with a listen backlog of 100 connections. diff --git a/src/Native/System.Native/pal_networking.cpp b/src/Native/System.Native/pal_networking.cpp index 5ac8cccc6367..a7c8f503f02f 100644 --- a/src/Native/System.Native/pal_networking.cpp +++ b/src/Native/System.Native/pal_networking.cpp @@ -1863,6 +1863,10 @@ static bool TryConvertProtocolTypePalToPlatform(int32_t palProtocolType, int* pl switch (palProtocolType) { + case PAL_PT_UNSPECIFIED: + *platformProtocolType = 0; + return true; + case PAL_PT_ICMP: *platformProtocolType = IPPROTO_ICMP; return true; diff --git a/src/Native/System.Native/pal_networking.h b/src/Native/System.Native/pal_networking.h index d5579aac0a23..891a3c88673d 100644 --- a/src/Native/System.Native/pal_networking.h +++ b/src/Native/System.Native/pal_networking.h @@ -81,6 +81,7 @@ enum SocketType : int32_t */ enum ProtocolType : int32_t { + PAL_PT_UNSPECIFIED = 0, // System.Net.ProtocolType.Unspecified PAL_PT_ICMP = 1, // System.Net.ProtocolType.Icmp PAL_PT_TCP = 6, // System.Net.ProtocolType.Tcp PAL_PT_UDP = 17, // System.Net.ProtocolType.Udp diff --git a/src/System.Net.Sockets/src/System/Net/Sockets/Socket.cs b/src/System.Net.Sockets/src/System/Net/Sockets/Socket.cs index c036c911d94b..bcb97c8defae 100644 --- a/src/System.Net.Sockets/src/System/Net/Sockets/Socket.cs +++ b/src/System.Net.Sockets/src/System/Net/Sockets/Socket.cs @@ -4231,7 +4231,7 @@ public bool ConnectAsync(SocketAsyncEventArgs e) { InternalBind(new IPEndPoint(IPAddress.Any, 0)); } - else + else if (endPointSnapshot.AddressFamily != AddressFamily.Unix) { InternalBind(new IPEndPoint(IPAddress.IPv6Any, 0)); } diff --git a/src/System.Net.Sockets/tests/FunctionalTests/System.Net.Sockets.Tests.csproj b/src/System.Net.Sockets/tests/FunctionalTests/System.Net.Sockets.Tests.csproj index 568a83233e85..9b4e527c2ea1 100644 --- a/src/System.Net.Sockets/tests/FunctionalTests/System.Net.Sockets.Tests.csproj +++ b/src/System.Net.Sockets/tests/FunctionalTests/System.Net.Sockets.Tests.csproj @@ -27,7 +27,8 @@ - + + SocketCommon\Configuration.cs @@ -77,7 +78,7 @@ System.Net.Sockets - + diff --git a/src/System.Net.Sockets/tests/FunctionalTests/UnixDomainSocketTest.cs b/src/System.Net.Sockets/tests/FunctionalTests/UnixDomainSocketTest.cs new file mode 100644 index 000000000000..8d01972f2fdc --- /dev/null +++ b/src/System.Net.Sockets/tests/FunctionalTests/UnixDomainSocketTest.cs @@ -0,0 +1,223 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System.IO; +using System.Net.Test.Common; +using System.Text; +using System.Threading; + +using Xunit; +using Xunit.Abstractions; + +namespace System.Net.Sockets.Tests +{ + public class UnixDomainSocketTest + { + private readonly ITestOutputHelper _log; + + public UnixDomainSocketTest(ITestOutputHelper output) + { + _log = TestLogging.GetInstance(); + } + + private void OnConnectAsyncCompleted(object sender, SocketAsyncEventArgs args) + { + ManualResetEvent complete = (ManualResetEvent)args.UserToken; + complete.Set(); + } + + [Fact] + [PlatformSpecific(PlatformID.Windows)] + public void Socket_CreateUnixDomainSocket_Throws_OnWindows() + { + // Throws SocketException with this message "An address incompatible with the requested protocol was used" + Assert.Throws(() => new Socket(AddressFamily.Unix, SocketType.Stream, ProtocolType.Unspecified)); + } + + [Fact] + [PlatformSpecific(PlatformID.Linux | PlatformID.OSX)] + public void Socket_ConnectAsyncUnixDomainSocketEndPoint_Success() + { + string path = null; + SocketTestServer server = null; + UnixDomainSocketEndPoint endPoint = null; + + for (int attempt = 0; attempt < 5; attempt++) + { + path = GetRandomNonExistingFilePath(); + endPoint = new UnixDomainSocketEndPoint(path); + try + { + server = SocketTestServer.SocketTestServerFactory(endPoint, ProtocolType.Unspecified); + break; + } + catch (SocketException) + { + // Path selection is contingent on a successful Bind(). + // If it fails, the next iteration will try another path. + } + } + + try + { + Assert.NotNull(server); + + SocketAsyncEventArgs args = new SocketAsyncEventArgs(); + args.RemoteEndPoint = endPoint; + args.Completed += OnConnectAsyncCompleted; + + ManualResetEvent complete = new ManualResetEvent(false); + args.UserToken = complete; + + Socket sock = new Socket(AddressFamily.Unix, SocketType.Stream, ProtocolType.Unspecified); + Assert.True(sock.ConnectAsync(args)); + + complete.WaitOne(); + + Assert.Equal(SocketError.Success, args.SocketError); + Assert.Null(args.ConnectByNameError); + + complete.Dispose(); + sock.Dispose(); + server.Dispose(); + } + finally + { + File.Delete(path); + } + } + + [Fact] + [PlatformSpecific(PlatformID.Linux | PlatformID.OSX)] + public void Socket_ConnectAsyncUnixDomainSocketEndPoint_NotServer() + { + string path = GetRandomNonExistingFilePath(); + var endPoint = new UnixDomainSocketEndPoint(path); + + SocketAsyncEventArgs args = new SocketAsyncEventArgs(); + args.RemoteEndPoint = endPoint; + args.Completed += OnConnectAsyncCompleted; + + ManualResetEvent complete = new ManualResetEvent(false); + args.UserToken = complete; + + Socket sock = new Socket(AddressFamily.Unix, SocketType.Stream, ProtocolType.Unspecified); + + bool willRaiseEvent = sock.ConnectAsync(args); + if (willRaiseEvent) + { + complete.WaitOne(); + } + + Assert.Equal(SocketError.SocketError, args.SocketError); + + complete.Dispose(); + sock.Dispose(); + } + + private static string GetRandomNonExistingFilePath() + { + string result; + do + { + result = Path.Combine(Path.GetTempPath(), Path.GetRandomFileName()); + } + while (File.Exists(result)); + + return result; + } + + private class UnixDomainSocketEndPoint : EndPoint + { + private static readonly Encoding PathEncoding = Encoding.UTF8; + + private const int MaxPathLength = 92; // sockaddr_un.sun_path at http://pubs.opengroup.org/onlinepubs/9699919799/basedefs/sys_un.h.html + private const int PathOffset = 2; // = offsetof(struct sockaddr_un, sun_path). It's the same on Linux and OSX + private const int MaxSocketAddressSize = PathOffset + MaxPathLength; + private const int MinSocketAddressSize = PathOffset + 2; // +1 for one character and +1 for \0 ending + private const AddressFamily EndPointAddressFamily = AddressFamily.Unix; + + private readonly string _path; + private readonly byte[] _encodedPath; + + public UnixDomainSocketEndPoint(string path) + { + if (path == null) + { + throw new ArgumentNullException("path"); + } + + if (path.Length == 0 || PathEncoding.GetByteCount(path) >= MaxPathLength) + { + throw new ArgumentOutOfRangeException("path"); + } + + _path = path; + _encodedPath = PathEncoding.GetBytes(_path); + } + + internal UnixDomainSocketEndPoint(SocketAddress socketAddress) + { + if (socketAddress == null) + { + throw new ArgumentNullException("socketAddress"); + } + + if (socketAddress.Family != EndPointAddressFamily || socketAddress.Size < MinSocketAddressSize || socketAddress.Size > MaxSocketAddressSize) + { + throw new ArgumentException("socketAddress"); + } + + _encodedPath = new byte[socketAddress.Size - PathOffset]; + for (int index = 0; index < socketAddress.Size - PathOffset; index++) + { + _encodedPath[index] = socketAddress[PathOffset + index]; + } + + _path = PathEncoding.GetString(_encodedPath); + } + + public string Path + { + get + { + return _path; + } + } + + public override AddressFamily AddressFamily + { + get + { + return EndPointAddressFamily; + } + } + + public override SocketAddress Serialize() + { + SocketAddress result = new SocketAddress(AddressFamily.Unix, MaxSocketAddressSize); + + // Ctor has already checked that PathOffset + _encodedPath.Length < MaxSocketAddressSize + for (int index = 0; index < _encodedPath.Length; index++) + { + result[PathOffset + index] = _encodedPath[index]; + } + + // The path must be ending with \0 + result[PathOffset + _encodedPath.Length] = 0; + + return result; + } + + public override EndPoint Create(SocketAddress socketAddress) + { + return new UnixDomainSocketEndPoint(socketAddress); + } + + public override string ToString() + { + return Path; + } + } + } +}