diff --git a/Sources/NIOSSH/Connection State Machine/States/SentVersionState.swift b/Sources/NIOSSH/Connection State Machine/States/SentVersionState.swift index c855ad5..c3f9818 100644 --- a/Sources/NIOSSH/Connection State Machine/States/SentVersionState.swift +++ b/Sources/NIOSSH/Connection State Machine/States/SentVersionState.swift @@ -35,7 +35,7 @@ extension SSHConnectionStateMachine { self.serializer = state.serializer self.protectionSchemes = state.protectionSchemes - self.parser = SSHPacketParser(allocator: allocator) + self.parser = SSHPacketParser(isServer: self.role.isServer, allocator: allocator) self.allocator = allocator } diff --git a/Sources/NIOSSH/SSHPacketParser.swift b/Sources/NIOSSH/SSHPacketParser.swift index 7e739ee..cd06163 100644 --- a/Sources/NIOSSH/SSHPacketParser.swift +++ b/Sources/NIOSSH/SSHPacketParser.swift @@ -23,6 +23,7 @@ struct SSHPacketParser { case encryptedWaitingForBytes(UInt32, NIOSSHTransportProtection) } + private let isServer: Bool private var buffer: ByteBuffer private var state: State private(set) var sequenceNumber: UInt32 @@ -32,7 +33,8 @@ struct SSHPacketParser { self.buffer.readerIndex } - init(allocator: ByteBufferAllocator) { + init(isServer: Bool, allocator: ByteBufferAllocator) { + self.isServer = isServer self.buffer = allocator.buffer(capacity: 0) self.state = .initialized self.sequenceNumber = 0 @@ -121,18 +123,31 @@ struct SSHPacketParser { let carriageReturn = UInt8(ascii: "\r") let lineFeed = UInt8(ascii: "\n") - // Search for version line, which starts with "SSH-". Lines without this prefix may come before the version line. - var slice = self.buffer.readableBytesView - while let lfIndex = slice.firstIndex(of: lineFeed), lfIndex < slice.endIndex { - if slice.starts(with: "SSH-".utf8) { - // Return all data upto the last LF we found, excluding the last [CR]LF. - slice = self.buffer.readableBytesView + // Per RFC 4253 ยง4.2: + // The server MAY send other lines of data before sending the version string. + // This means that server does not expect any lines before version so we will return all data before first line feed + if self.isServer { + // Looking for a string ending with \r\n + let slice = self.buffer.readableBytesView + if let lfIndex = slice.firstIndex(of: lineFeed), lfIndex < slice.endIndex { let versionEndIndex = slice[lfIndex.advanced(by: -1)] == carriageReturn ? lfIndex.advanced(by: -1) : lfIndex let version = String(decoding: slice[slice.startIndex ..< versionEndIndex], as: UTF8.self) self.buffer.moveReaderIndex(forwardBy: slice.startIndex.distance(to: lfIndex).advanced(by: 1)) return version - } else { - slice = slice[slice.index(after: lfIndex)...] + } + } else { + // Search for version line, which starts with "SSH-". Lines without this prefix may come before the version line. + var slice = self.buffer.readableBytesView + let startIndex = slice.startIndex + while let lfIndex = slice.firstIndex(of: lineFeed), lfIndex < slice.endIndex { + if slice.starts(with: "SSH-".utf8) { + let versionEndIndex = slice[lfIndex.advanced(by: -1)] == carriageReturn ? lfIndex.advanced(by: -1) : lfIndex + let version = String(decoding: slice[slice.startIndex ..< versionEndIndex], as: UTF8.self) + self.buffer.moveReaderIndex(forwardBy: startIndex.distance(to: lfIndex).advanced(by: 1)) + return version + } else { + slice = slice[slice.index(after: lfIndex)...] + } } } return nil diff --git a/Sources/NIOSSHClient/ExecHandler.swift b/Sources/NIOSSHClient/ExecHandler.swift index 2bca12d..8dd244e 100644 --- a/Sources/NIOSSHClient/ExecHandler.swift +++ b/Sources/NIOSSHClient/ExecHandler.swift @@ -48,7 +48,7 @@ final class ExampleExecHandler: ChannelDuplexHandler { DispatchQueue(label: "pipe bootstrap").async { bootstrap.channelOption(ChannelOptions.allowRemoteHalfClosure, value: true).channelInitializer { channel in channel.pipeline.addHandler(theirs) - }.withPipes(inputDescriptor: 0, outputDescriptor: 1).whenComplete { result in + }.takingOwnershipOfDescriptors(input: 0, output: 1).whenComplete { result in switch result { case .success: // We need to exec a thing. diff --git a/Sources/NIOSSHServer/ExecHandler.swift b/Sources/NIOSSHServer/ExecHandler.swift index a3cc644..0970ba5 100644 --- a/Sources/NIOSSHServer/ExecHandler.swift +++ b/Sources/NIOSSHServer/ExecHandler.swift @@ -120,7 +120,7 @@ final class ExampleExecHandler: ChannelDuplexHandler { .channelOption(ChannelOptions.allowRemoteHalfClosure, value: true) .channelInitializer { pipeChannel in pipeChannel.pipeline.addHandler(theirs) - }.withPipes(inputDescriptor: dup(outPipe.fileHandleForReading.fileDescriptor), outputDescriptor: dup(inPipe.fileHandleForWriting.fileDescriptor)).wait() + }.takingOwnershipOfDescriptors(input: dup(outPipe.fileHandleForReading.fileDescriptor), output: dup(inPipe.fileHandleForWriting.fileDescriptor)).wait() // Ok, great, we've sorted stdout and stdin. For stderr we need a different strategy: we just park a thread for this. DispatchQueue(label: "stderrorwhatever").async { diff --git a/Tests/NIOSSHTests/SSHEncryptedTrafficTests.swift b/Tests/NIOSSHTests/SSHEncryptedTrafficTests.swift index bf9be04..d9e39a3 100644 --- a/Tests/NIOSSHTests/SSHEncryptedTrafficTests.swift +++ b/Tests/NIOSSHTests/SSHEncryptedTrafficTests.swift @@ -25,7 +25,7 @@ final class SSHEncryptedTrafficTests: XCTestCase { override func setUp() { self.serializer = SSHPacketSerializer() - self.parser = SSHPacketParser(allocator: .init()) + self.parser = SSHPacketParser(isServer: false, allocator: .init()) self.assertPacketRoundTrips(.version("SSH-2.0-SwiftSSH_1.0")) } diff --git a/Tests/NIOSSHTests/SSHPackerSerializerTests.swift b/Tests/NIOSSHTests/SSHPackerSerializerTests.swift index fdef158..42e8add 100644 --- a/Tests/NIOSSHTests/SSHPackerSerializerTests.swift +++ b/Tests/NIOSSHTests/SSHPackerSerializerTests.swift @@ -51,7 +51,7 @@ final class SSHPacketSerializerTests: XCTestCase { let message = SSHMessage.disconnect(.init(reason: 42, description: "description", tag: "tag")) let allocator = ByteBufferAllocator() var serializer = SSHPacketSerializer() - var parser = SSHPacketParser(allocator: allocator) + var parser = SSHPacketParser(isServer: false, allocator: allocator) self.runVersionHandshake(serializer: &serializer, parser: &parser) @@ -74,7 +74,7 @@ final class SSHPacketSerializerTests: XCTestCase { let message = SSHMessage.serviceRequest(.init(service: "ssh-userauth")) let allocator = ByteBufferAllocator() var serializer = SSHPacketSerializer() - var parser = SSHPacketParser(allocator: allocator) + var parser = SSHPacketParser(isServer: false, allocator: allocator) self.runVersionHandshake(serializer: &serializer, parser: &parser) @@ -97,7 +97,7 @@ final class SSHPacketSerializerTests: XCTestCase { let message = SSHMessage.serviceAccept(.init(service: "ssh-userauth")) let allocator = ByteBufferAllocator() var serializer = SSHPacketSerializer() - var parser = SSHPacketParser(allocator: allocator) + var parser = SSHPacketParser(isServer: false, allocator: allocator) self.runVersionHandshake(serializer: &serializer, parser: &parser) @@ -133,7 +133,7 @@ final class SSHPacketSerializerTests: XCTestCase { )) let allocator = ByteBufferAllocator() var serializer = SSHPacketSerializer() - var parser = SSHPacketParser(allocator: allocator) + var parser = SSHPacketParser(isServer: false, allocator: allocator) self.runVersionHandshake(serializer: &serializer, parser: &parser) @@ -165,7 +165,7 @@ final class SSHPacketSerializerTests: XCTestCase { let message = SSHMessage.keyExchangeInit(.init(publicKey: ByteBuffer.of(bytes: [42]))) let allocator = ByteBufferAllocator() var serializer = SSHPacketSerializer() - var parser = SSHPacketParser(allocator: allocator) + var parser = SSHPacketParser(isServer: false, allocator: allocator) self.runVersionHandshake(serializer: &serializer, parser: &parser) @@ -193,7 +193,7 @@ final class SSHPacketSerializerTests: XCTestCase { )) let allocator = ByteBufferAllocator() var serializer = SSHPacketSerializer() - var parser = SSHPacketParser(allocator: allocator) + var parser = SSHPacketParser(isServer: false, allocator: allocator) self.runVersionHandshake(serializer: &serializer, parser: &parser) @@ -226,7 +226,7 @@ final class SSHPacketSerializerTests: XCTestCase { let message = SSHMessage.newKeys let allocator = ByteBufferAllocator() var serializer = SSHPacketSerializer() - var parser = SSHPacketParser(allocator: allocator) + var parser = SSHPacketParser(isServer: false, allocator: allocator) self.runVersionHandshake(serializer: &serializer, parser: &parser) @@ -247,7 +247,7 @@ final class SSHPacketSerializerTests: XCTestCase { let message = SSHMessage.newKeys let allocator = ByteBufferAllocator() var serializer = SSHPacketSerializer() - var parser = SSHPacketParser(allocator: allocator) + var parser = SSHPacketParser(isServer: false, allocator: allocator) self.runVersionHandshake(serializer: &serializer, parser: &parser) diff --git a/Tests/NIOSSHTests/SSHPacketParserTests.swift b/Tests/NIOSSHTests/SSHPacketParserTests.swift index cd5c960..c3c9b45 100644 --- a/Tests/NIOSSHTests/SSHPacketParserTests.swift +++ b/Tests/NIOSSHTests/SSHPacketParserTests.swift @@ -38,7 +38,7 @@ final class SSHPacketParserTests: XCTestCase { } func testReadVersion() throws { - var parser = SSHPacketParser(allocator: ByteBufferAllocator()) + var parser = SSHPacketParser(isServer: false, allocator: ByteBufferAllocator()) var part1 = ByteBuffer.of(string: "SSH-2.0-") parser.append(bytes: &part1) @@ -57,8 +57,8 @@ final class SSHPacketParserTests: XCTestCase { } } - func testReadVersionWithExtraLines() throws { - var parser = SSHPacketParser(allocator: ByteBufferAllocator()) + func testReadVersionWithExtraLinesOnClient() throws { + var parser = SSHPacketParser(isServer: false, allocator: ByteBufferAllocator()) var part1 = ByteBuffer.of(string: "xxxx\r\nyyyy\r\nSSH-2.0-") parser.append(bytes: &part1) @@ -70,14 +70,33 @@ final class SSHPacketParserTests: XCTestCase { switch try parser.nextPacket() { case .version(let string): - XCTAssertEqual(string, "xxxx\r\nyyyy\r\nSSH-2.0-OpenSSH_7.9") + XCTAssertEqual(string, "SSH-2.0-OpenSSH_7.9") + default: + XCTFail("Expecting .version") + } + } + + func testReadVersionWithExtraLinesOnServer() throws { + var parser = SSHPacketParser(isServer: true, allocator: ByteBufferAllocator()) + + var part1 = ByteBuffer.of(string: "xx") + parser.append(bytes: &part1) + + XCTAssertNil(try parser.nextPacket()) + + var part2 = ByteBuffer.of(string: "xx\r\nyyyy\r\nSSH-2.0-OpenSSH_7.9\r\n") + parser.append(bytes: &part2) + + switch try parser.nextPacket() { + case .version(let string): + XCTAssertEqual(string, "xxxx") default: XCTFail("Expecting .version") } } func testReadVersionWithoutCarriageReturn() throws { - var parser = SSHPacketParser(allocator: ByteBufferAllocator()) + var parser = SSHPacketParser(isServer: false, allocator: ByteBufferAllocator()) var part1 = ByteBuffer.of(string: "SSH-2.0-") parser.append(bytes: &part1) @@ -95,8 +114,8 @@ final class SSHPacketParserTests: XCTestCase { } } - func testReadVersionWithExtraLinesWithoutCarriageReturn() throws { - var parser = SSHPacketParser(allocator: ByteBufferAllocator()) + func testReadVersionWithExtraLinesWithoutCarriageReturnOnClient() throws { + var parser = SSHPacketParser(isServer: false, allocator: ByteBufferAllocator()) var part1 = ByteBuffer.of(string: "xxxx\nyyyy\nSSH-2.0-") parser.append(bytes: &part1) @@ -108,14 +127,33 @@ final class SSHPacketParserTests: XCTestCase { switch try parser.nextPacket() { case .version(let string): - XCTAssertEqual(string, "xxxx\nyyyy\nSSH-2.0-OpenSSH_7.4") + XCTAssertEqual(string, "SSH-2.0-OpenSSH_7.4") + default: + XCTFail("Expecting .version") + } + } + + func testReadVersionWithExtraLinesWithoutCarriageReturnOnServer() throws { + var parser = SSHPacketParser(isServer: true, allocator: ByteBufferAllocator()) + + var part1 = ByteBuffer.of(string: "xx") + parser.append(bytes: &part1) + + XCTAssertNil(try parser.nextPacket()) + + var part2 = ByteBuffer.of(string: "xx\nyyyy\nSSH-2.0-OpenSSH_7.4\n") + parser.append(bytes: &part2) + + switch try parser.nextPacket() { + case .version(let string): + XCTAssertEqual(string, "xxxx") default: XCTFail("Expecting .version") } } func testBinaryInParts() throws { - var parser = SSHPacketParser(allocator: ByteBufferAllocator()) + var parser = SSHPacketParser(isServer: false, allocator: ByteBufferAllocator()) self.feedVersion(to: &parser) var part1 = ByteBuffer.of(bytes: [0, 0, 0]) @@ -148,7 +186,7 @@ final class SSHPacketParserTests: XCTestCase { } func testBinaryFull() throws { - var parser = SSHPacketParser(allocator: ByteBufferAllocator()) + var parser = SSHPacketParser(isServer: false, allocator: ByteBufferAllocator()) self.feedVersion(to: &parser) var part1 = ByteBuffer.of(bytes: [0, 0, 0, 28, 10, 5, 0, 0, 0, 12, 115, 115, 104, 45, 117, 115, 101, 114, 97, 117, 116, 104, 42, 111, 216, 12, 226, 248, 144, 175, 157, 207]) @@ -164,7 +202,7 @@ final class SSHPacketParserTests: XCTestCase { } func testBinaryTwoMessages() throws { - var parser = SSHPacketParser(allocator: ByteBufferAllocator()) + var parser = SSHPacketParser(isServer: false, allocator: ByteBufferAllocator()) self.feedVersion(to: &parser) var part = ByteBuffer.of(bytes: [0, 0, 0, 28, 10, 5, 0, 0, 0, 12, 115, 115, 104, 45, 117, 115, 101, 114, 97, 117, 116, 104, 42, 111, 216, 12, 226, 248, 144, 175, 157, 207, 0, 0, 0, 28, 10, 5, 0, 0, 0, 12, 115, 115, 104, 45, 117, 115, 101, 114, 97, 117, 116, 104, 42, 111, 216, 12, 226, 248, 144, 175, 157, 207]) @@ -187,7 +225,7 @@ final class SSHPacketParserTests: XCTestCase { } func testWeReclaimStorage() throws { - var parser = SSHPacketParser(allocator: ByteBufferAllocator()) + var parser = SSHPacketParser(isServer: false, allocator: ByteBufferAllocator()) self.feedVersion(to: &parser) XCTAssertNoThrow(try parser.nextPacket()) @@ -213,7 +251,7 @@ final class SSHPacketParserTests: XCTestCase { func testSequencePreservedBetweenPlainAndCypher() throws { let allocator = ByteBufferAllocator() - var parser = SSHPacketParser(allocator: allocator) + var parser = SSHPacketParser(isServer: false, allocator: allocator) self.feedVersion(to: &parser) var part = ByteBuffer(bytes: [0, 0, 0, 12, 10, 21, 41, 114, 125, 250, 3, 79, 3, 217, 166, 136])