Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 16 additions & 7 deletions Sources/NIOSSH/SSHPacketParser.swift
Original file line number Diff line number Diff line change
Expand Up @@ -112,13 +112,22 @@ struct SSHPacketParser {
}

private mutating func readVersion() throws -> String? {
// Looking for a string ending with \r\n
let slice = self.buffer.readableBytesView
if let cr = slice.firstIndex(of: 13), cr.advanced(by: 1) < slice.endIndex, slice[cr.advanced(by: 1)] == 10 {
let version = String(decoding: slice[slice.startIndex ..< cr], as: UTF8.self)
// read \r\n
self.buffer.moveReaderIndex(forwardBy: slice.startIndex.distance(to: cr).advanced(by: 2))
return version
let carriageReturn = UInt8(ascii: "\r")
let lineFeed = UInt8(ascii: "\n")

// Search for version line, which starts with "SSH-". Lines without this prefix may come before the version line.
var slice = self.buffer.readableBytesView
while let lfIndex = slice.firstIndex(of: lineFeed), lfIndex < slice.endIndex {
if slice.starts(with: "SSH-".utf8) {
// Return all data upto the last LF we found, excluding the last [CR]LF.
slice = self.buffer.readableBytesView
let versionEndIndex = slice[lfIndex.advanced(by: -1)] == carriageReturn ? lfIndex.advanced(by: -1) : lfIndex
let version = String(decoding: slice[slice.startIndex ..< versionEndIndex], as: UTF8.self)
self.buffer.moveReaderIndex(forwardBy: slice.startIndex.distance(to: lfIndex).advanced(by: 1))
return version
} else {
slice = slice[slice.index(after: lfIndex)...]
}
}
return nil
}
Expand Down
57 changes: 57 additions & 0 deletions Tests/NIOSSHTests/SSHPacketParserTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,63 @@ final class SSHPacketParserTests: XCTestCase {
}
}

func testReadVersionWithExtraLines() throws {
var parser = SSHPacketParser(allocator: ByteBufferAllocator())

var part1 = ByteBuffer.of(string: "xxxx\r\nyyyy\r\nSSH-2.0-")
parser.append(bytes: &part1)

XCTAssertNil(try parser.nextPacket())

var part2 = ByteBuffer.of(string: "OpenSSH_7.9\r\n")
parser.append(bytes: &part2)

switch try parser.nextPacket() {
case .version(let string):
XCTAssertEqual(string, "xxxx\r\nyyyy\r\nSSH-2.0-OpenSSH_7.9")
default:
XCTFail("Expecting .version")
}
}

func testReadVersionWithoutCarriageReturn() throws {
var parser = SSHPacketParser(allocator: ByteBufferAllocator())

var part1 = ByteBuffer.of(string: "SSH-2.0-")
parser.append(bytes: &part1)

XCTAssertNil(try parser.nextPacket())

var part2 = ByteBuffer.of(string: "OpenSSH_7.4\n")
parser.append(bytes: &part2)

switch try parser.nextPacket() {
case .version(let string):
XCTAssertEqual(string, "SSH-2.0-OpenSSH_7.4")
default:
XCTFail("Expecting .version")
}
}

func testReadVersionWithExtraLinesWithoutCarriageReturn() throws {
var parser = SSHPacketParser(allocator: ByteBufferAllocator())

var part1 = ByteBuffer.of(string: "xxxx\nyyyy\nSSH-2.0-")
parser.append(bytes: &part1)

XCTAssertNil(try parser.nextPacket())

var part2 = ByteBuffer.of(string: "OpenSSH_7.4\n")
parser.append(bytes: &part2)

switch try parser.nextPacket() {
case .version(let string):
XCTAssertEqual(string, "xxxx\nyyyy\nSSH-2.0-OpenSSH_7.4")
default:
XCTFail("Expecting .version")
}
}

func testBinaryInParts() throws {
var parser = SSHPacketParser(allocator: ByteBufferAllocator())
self.feedVersion(to: &parser)
Expand Down