diff --git a/Sources/NIOSSH/SSHPacketParser.swift b/Sources/NIOSSH/SSHPacketParser.swift index 787dbe6..1d70c94 100644 --- a/Sources/NIOSSH/SSHPacketParser.swift +++ b/Sources/NIOSSH/SSHPacketParser.swift @@ -112,13 +112,22 @@ struct SSHPacketParser { } private mutating func readVersion() throws -> String? { - // Looking for a string ending with \r\n - let slice = self.buffer.readableBytesView - if let cr = slice.firstIndex(of: 13), cr.advanced(by: 1) < slice.endIndex, slice[cr.advanced(by: 1)] == 10 { - let version = String(decoding: slice[slice.startIndex ..< cr], as: UTF8.self) - // read \r\n - self.buffer.moveReaderIndex(forwardBy: slice.startIndex.distance(to: cr).advanced(by: 2)) - return version + let carriageReturn = UInt8(ascii: "\r") + let lineFeed = UInt8(ascii: "\n") + + // Search for version line, which starts with "SSH-". Lines without this prefix may come before the version line. + var slice = self.buffer.readableBytesView + while let lfIndex = slice.firstIndex(of: lineFeed), lfIndex < slice.endIndex { + if slice.starts(with: "SSH-".utf8) { + // Return all data upto the last LF we found, excluding the last [CR]LF. + slice = self.buffer.readableBytesView + let versionEndIndex = slice[lfIndex.advanced(by: -1)] == carriageReturn ? lfIndex.advanced(by: -1) : lfIndex + let version = String(decoding: slice[slice.startIndex ..< versionEndIndex], as: UTF8.self) + self.buffer.moveReaderIndex(forwardBy: slice.startIndex.distance(to: lfIndex).advanced(by: 1)) + return version + } else { + slice = slice[slice.index(after: lfIndex)...] + } } return nil } diff --git a/Tests/NIOSSHTests/SSHPacketParserTests.swift b/Tests/NIOSSHTests/SSHPacketParserTests.swift index 65fe814..0a7192c 100644 --- a/Tests/NIOSSHTests/SSHPacketParserTests.swift +++ b/Tests/NIOSSHTests/SSHPacketParserTests.swift @@ -54,6 +54,63 @@ final class SSHPacketParserTests: XCTestCase { } } + func testReadVersionWithExtraLines() throws { + var parser = SSHPacketParser(allocator: ByteBufferAllocator()) + + var part1 = ByteBuffer.of(string: "xxxx\r\nyyyy\r\nSSH-2.0-") + parser.append(bytes: &part1) + + XCTAssertNil(try parser.nextPacket()) + + var part2 = ByteBuffer.of(string: "OpenSSH_7.9\r\n") + parser.append(bytes: &part2) + + switch try parser.nextPacket() { + case .version(let string): + XCTAssertEqual(string, "xxxx\r\nyyyy\r\nSSH-2.0-OpenSSH_7.9") + default: + XCTFail("Expecting .version") + } + } + + func testReadVersionWithoutCarriageReturn() throws { + var parser = SSHPacketParser(allocator: ByteBufferAllocator()) + + var part1 = ByteBuffer.of(string: "SSH-2.0-") + parser.append(bytes: &part1) + + XCTAssertNil(try parser.nextPacket()) + + var part2 = ByteBuffer.of(string: "OpenSSH_7.4\n") + parser.append(bytes: &part2) + + switch try parser.nextPacket() { + case .version(let string): + XCTAssertEqual(string, "SSH-2.0-OpenSSH_7.4") + default: + XCTFail("Expecting .version") + } + } + + func testReadVersionWithExtraLinesWithoutCarriageReturn() throws { + var parser = SSHPacketParser(allocator: ByteBufferAllocator()) + + var part1 = ByteBuffer.of(string: "xxxx\nyyyy\nSSH-2.0-") + parser.append(bytes: &part1) + + XCTAssertNil(try parser.nextPacket()) + + var part2 = ByteBuffer.of(string: "OpenSSH_7.4\n") + parser.append(bytes: &part2) + + switch try parser.nextPacket() { + case .version(let string): + XCTAssertEqual(string, "xxxx\nyyyy\nSSH-2.0-OpenSSH_7.4") + default: + XCTFail("Expecting .version") + } + } + func testBinaryInParts() throws { var parser = SSHPacketParser(allocator: ByteBufferAllocator()) self.feedVersion(to: &parser)