diff --git a/Sources/NIOSSH/Connection State Machine/Operations/AcceptsVersionMessages.swift b/Sources/NIOSSH/Connection State Machine/Operations/AcceptsVersionMessages.swift index 874763d..2b5e92c 100644 --- a/Sources/NIOSSH/Connection State Machine/Operations/AcceptsVersionMessages.swift +++ b/Sources/NIOSSH/Connection State Machine/Operations/AcceptsVersionMessages.swift @@ -15,18 +15,71 @@ protocol AcceptsVersionMessages {} extension AcceptsVersionMessages { - func receiveVersionMessage(_ version: String) throws { - try self.validateVersion(version) + func receiveVersionMessage(_ banner: String, role: SSHConnectionRole) throws { + try self.validateBanner(banner, role: role) } - private func validateVersion(_ version: String) throws { - guard version.count > 7, version.hasPrefix("SSH-") else { - throw NIOSSHError.unsupportedVersion(version) + // From RFC 4253: + // + // > Protocol Version Exchange + // > + // > When the connection has been established, both sides MUST send an + // > identification string. This identification string MUST be + // > SSH-protoversion-softwareversion SP comments CR LF + // > Since the protocol being defined in this set of documents is version + // > 2.0, the 'protoversion' MUST be "2.0". The 'comments' string is + // > OPTIONAL. If the 'comments' string is included, a 'space' character + // > (denoted above as SP, ASCII 32) MUST separate the 'softwareversion' + // > and 'comments' strings. The identification MUST be terminated by a + // > single Carriage Return (CR) and a single Line Feed (LF) character + // > (ASCII 13 and 10, respectively). Implementers who wish to maintain + // > compatibility with older, undocumented versions of this protocol may + // > want to process the identification string without expecting the + // > presence of the carriage return character for reasons described in + // > Section 5 of this document. The null character MUST NOT be sent. + // > The maximum length of the string is 255 characters, including the + // > Carriage Return and Line Feed. + // > + // > The part of the identification string preceding the Carriage Return + // > and Line Feed is used in the Diffie-Hellman key exchange (see Section + // > 8). + // > + // > The server MAY send other lines of data before sending the version + // > string. Each line SHOULD be terminated by a Carriage Return and Line + // > Feed. Such lines MUST NOT begin with "SSH-", and SHOULD be encoded + // > in ISO-10646 UTF-8 [RFC3629] (language is not specified). Clients + // > MUST be able to process such lines. Such lines MAY be silently + // > ignored, or MAY be displayed to the client user. If they are + // > displayed, control character filtering, as discussed in [SSH-ARCH], + // > SHOULD be used. The primary use of this feature is to allow TCP- + // > wrappers to display an error message before disconnecting + private func validateBanner(_ banner: String, role: SSHConnectionRole) throws { + switch role { + case .client: + // split by \n + let lineFeed = UInt8(ascii: "\n") + for line in banner.utf8.split(separator: lineFeed) { + if try self.validateVersion(line) { + return + } + } + throw NIOSSHError.protocolViolation(protocolName: "version exchange", violation: "version string not found") + case .server: + guard try self.validateVersion(Substring(banner).utf8) else { + throw NIOSSHError.protocolViolation(protocolName: "version exchange", violation: "version string not found") + } } - let start = version.index(version.startIndex, offsetBy: 4) - let end = version.index(start, offsetBy: 3) - guard version[start ..< end] == "2.0" else { - throw NIOSSHError.unsupportedVersion(version) + } + + private func validateVersion(_ version: Substring.UTF8View) throws -> Bool { + if version.count > 7, version.starts(with: "SSH-".utf8) { + let start = version.index(version.startIndex, offsetBy: 4) + let end = version.index(start, offsetBy: 3) + guard version[start ..< end].elementsEqual(Substring("2.0").utf8) else { + throw NIOSSHError.unsupportedVersion(String(decoding: version, as: UTF8.self)) + } + return true } + return false } } diff --git a/Sources/NIOSSH/Connection State Machine/SSHConnectionStateMachine.swift b/Sources/NIOSSH/Connection State Machine/SSHConnectionStateMachine.swift index 2556353..f8318aa 100644 --- a/Sources/NIOSSH/Connection State Machine/SSHConnectionStateMachine.swift +++ b/Sources/NIOSSH/Connection State Machine/SSHConnectionStateMachine.swift @@ -134,7 +134,7 @@ struct SSHConnectionStateMachine { switch message { case .version(let version): - try state.receiveVersionMessage(version) + try state.receiveVersionMessage(version, role: self.role) let newState = KeyExchangeState(sentVersionState: state, allocator: allocator, loop: loop, remoteVersion: version) let message = newState.keyExchangeStateMachine.createKeyExchangeMessage() self.state = .keyExchange(newState) diff --git a/Tests/NIOSSHTests/SSHConnectionStateMachineTests.swift b/Tests/NIOSSHTests/SSHConnectionStateMachineTests.swift index dba36c9..2513275 100644 --- a/Tests/NIOSSHTests/SSHConnectionStateMachineTests.swift +++ b/Tests/NIOSSHTests/SSHConnectionStateMachineTests.swift @@ -424,6 +424,92 @@ final class SSHConnectionStateMachineTests: XCTestCase { XCTAssertNil(client.start()) } + + func testClientToleratesLinesBeforeVersion() throws { + let allocator = ByteBufferAllocator() + let loop = EmbeddedEventLoop() + var client = SSHConnectionStateMachine(role: .client(.init(userAuthDelegate: InfinitePasswordDelegate(), serverAuthDelegate: AcceptAllHostKeysDelegate()))) + + let message = client.start() + guard case message = Optional.some(SSHMultiMessage(SSHMessage.version(Constants.version))) else { + XCTFail("Unexpected message") + return + } + + var buffer = allocator.buffer(capacity: 42) + XCTAssertNoThrow(try client.processOutboundMessage(SSHMessage.version(Constants.version), buffer: &buffer, allocator: allocator, loop: loop)) + + var version = ByteBuffer(string: "xxxx\nyyy\nSSH-2.0-OpenSSH_8.1\r\n") + client.bufferInboundData(&version) + + XCTAssertNoThrow(try client.processInboundMessage(allocator: allocator, loop: loop)) + } + + func testServerRejectsLinesBeforeVersion() throws { + let allocator = ByteBufferAllocator() + let loop = EmbeddedEventLoop() + var server = SSHConnectionStateMachine(role: .server(.init(hostKeys: [NIOSSHPrivateKey(ed25519Key: .init())], userAuthDelegate: DenyThenAcceptDelegate(messagesToDeny: 1)))) + + let message = server.start() + guard case message = Optional.some(SSHMultiMessage(SSHMessage.version(Constants.version))) else { + XCTFail("Unexpected message") + return + } + + var buffer = allocator.buffer(capacity: 42) + XCTAssertNoThrow(try server.processOutboundMessage(SSHMessage.version(Constants.version), buffer: &buffer, allocator: allocator, loop: loop)) + + var version = ByteBuffer(string: "xxxx\nyyy\nSSH-2.0-OpenSSH_8.1\r\n") + server.bufferInboundData(&version) + + XCTAssertThrowsError(try server.processInboundMessage(allocator: allocator, loop: loop)) { error in + XCTAssertEqual((error as? NIOSSHError)?.type, .protocolViolation) + } + } + + func testClintVersionNotFound() throws { + let allocator = ByteBufferAllocator() + let loop = EmbeddedEventLoop() + var client = SSHConnectionStateMachine(role: .client(.init(userAuthDelegate: InfinitePasswordDelegate(), serverAuthDelegate: AcceptAllHostKeysDelegate()))) + + let message = client.start() + guard case message = Optional.some(SSHMultiMessage(SSHMessage.version(Constants.version))) else { + XCTFail("Unexpected message") + return + } + + var buffer = allocator.buffer(capacity: 42) + XCTAssertNoThrow(try client.processOutboundMessage(SSHMessage.version(Constants.version), buffer: &buffer, allocator: allocator, loop: loop)) + + var version = ByteBuffer(string: "SSH-\r\n") + client.bufferInboundData(&version) + + XCTAssertThrowsError(try client.processInboundMessage(allocator: allocator, loop: loop)) { error in + XCTAssertEqual((error as? NIOSSHError)?.type, .protocolViolation) + } + } + + func testVersionNotSupported() throws { + let allocator = ByteBufferAllocator() + let loop = EmbeddedEventLoop() + var client = SSHConnectionStateMachine(role: .client(.init(userAuthDelegate: InfinitePasswordDelegate(), serverAuthDelegate: AcceptAllHostKeysDelegate()))) + + let message = client.start() + guard case message = Optional.some(SSHMultiMessage(SSHMessage.version(Constants.version))) else { + XCTFail("Unexpected message") + return + } + + var buffer = allocator.buffer(capacity: 42) + XCTAssertNoThrow(try client.processOutboundMessage(SSHMessage.version(Constants.version), buffer: &buffer, allocator: allocator, loop: loop)) + + var version = ByteBuffer(string: "SSH-1.0-OpenSSH_8.1\r\n") + client.bufferInboundData(&version) + + XCTAssertThrowsError(try client.processInboundMessage(allocator: allocator, loop: loop)) { error in + XCTAssertEqual((error as? NIOSSHError)?.type, .unsupportedVersion) + } + } } extension Optional where Wrapped == SSHMultiMessage {