diff --git a/Sources/NIOSSH/Connection State Machine/SSHConnectionStateMachine.swift b/Sources/NIOSSH/Connection State Machine/SSHConnectionStateMachine.swift index 6ca7a12..815c7ac 100644 --- a/Sources/NIOSSH/Connection State Machine/SSHConnectionStateMachine.swift +++ b/Sources/NIOSSH/Connection State Machine/SSHConnectionStateMachine.swift @@ -37,6 +37,21 @@ struct SSHConnectionStateMachine { /// The SSH connection is active. case active(ActiveState) + /// We have received a KeyExchangeInit message from the active state, and so will be rekeying shortly. + case receivedKexInitWhenActive(ReceivedKexInitWhenActiveState) + + /// We have sent a KeyExchangeInit message from the active state, and so will be rekeying shortly. + case sentKexInitWhenActive(SentKexInitWhenActiveState) + + /// Both peers have sent KeyExchangeInit and are actively engaged in a key exchange to rekey. + case rekeying(RekeyingState) + + /// We have sent our newKeys message, but have not yet received the peer newKeys for this rekeying operation. + case rekeyingSentNewKeysState(RekeyingSentNewKeysState) + + /// We have received the peer newKeys message, but not yet sent our own for this rekeying operation. + case rekeyingReceivedNewKeysState(RekeyingReceivedNewKeysState) + case receivedDisconnect(SSHConnectionRole) case sentDisconnect(SSHConnectionRole) @@ -57,7 +72,9 @@ struct SSHConnectionStateMachine { switch self.state { case .idle: return SSHMultiMessage(SSHMessage.version(Constants.version)) - case .sentVersion, .keyExchange, .sentNewKeys, .receivedNewKeys, .userAuthentication, .active, .receivedDisconnect, .sentDisconnect: + case .sentVersion, .keyExchange, .sentNewKeys, .receivedNewKeys, .userAuthentication, + .active, .receivedKexInitWhenActive, .sentKexInitWhenActive, .rekeying, .rekeyingReceivedNewKeysState, + .rekeyingSentNewKeysState, .receivedDisconnect, .sentDisconnect: preconditionFailure("Cannot call start twice, state \(self.state)") } } @@ -84,6 +101,21 @@ struct SSHConnectionStateMachine { case .active(var state): state.parser.append(bytes: &data) self.state = .active(state) + case .receivedKexInitWhenActive(var state): + state.parser.append(bytes: &data) + self.state = .receivedKexInitWhenActive(state) + case .sentKexInitWhenActive(var state): + state.parser.append(bytes: &data) + self.state = .sentKexInitWhenActive(state) + case .rekeying(var state): + state.parser.append(bytes: &data) + self.state = .rekeying(state) + case .rekeyingReceivedNewKeysState(var state): + state.parser.append(bytes: &data) + self.state = .rekeyingReceivedNewKeysState(state) + case .rekeyingSentNewKeysState(var state): + state.parser.append(bytes: &data) + self.state = .rekeyingSentNewKeysState(state) case .receivedDisconnect, .sentDisconnect: // No more I/O, we're done. break @@ -103,10 +135,10 @@ struct SSHConnectionStateMachine { switch message { case .version(let version): try state.receiveVersionMessage(version) - var newState = KeyExchangeState(sentVersionState: state, allocator: allocator, remoteVersion: version) - let message = newState.keyExchangeStateMachine.startKeyExchange() + let newState = KeyExchangeState(sentVersionState: state, allocator: allocator, remoteVersion: version) + let message = newState.keyExchangeStateMachine.createKeyExchangeMessage() self.state = .keyExchange(newState) - return .emitMessage(message) + return .emitMessage(SSHMultiMessage(.keyExchange(message))) case .disconnect: self.state = .receivedDisconnect(state.role) @@ -321,8 +353,6 @@ struct SSHConnectionStateMachine { } switch message { - // TODO(cory): One day soon we'll need to support re-keying in this state. - // For now we only support channel messages. case .channelOpen(let message): try state.receiveChannelOpen(message) case .channelOpenConfirmation(let message): @@ -357,6 +387,12 @@ struct SSHConnectionStateMachine { try state.receiveRequestFailure() self.state = .active(state) return .globalRequestResponse(.failure) + case .keyExchange(let message): + // Attempting to rekey. + var newState = ReceivedKexInitWhenActiveState(state, allocator: allocator) + let result = try newState.receiveKeyExchangeMessage(message) + self.state = .receivedKexInitWhenActive(newState) + return result case .disconnect: self.state = .receivedDisconnect(state.role) return .disconnect @@ -374,6 +410,281 @@ struct SSHConnectionStateMachine { self.state = .active(state) return .forwardToMultiplexer(message) + case .receivedKexInitWhenActive(var state): + // We've received a key exchange packet. We only expect the first two messages (key exchange and key exchange init) before + // we have sent a reply. + guard let message = try state.parser.nextPacket() else { + return nil + } + + switch message { + case .keyExchange(let message): + let result = try state.receiveKeyExchangeMessage(message) + self.state = .receivedKexInitWhenActive(state) + return result + case .keyExchangeInit(let message): + let result = try state.receiveKeyExchangeInitMessage(message) + self.state = .receivedKexInitWhenActive(state) + return result + case .disconnect: + self.state = .receivedDisconnect(state.role) + return .disconnect + case .ignore, .debug: + // Ignore these + self.state = .receivedKexInitWhenActive(state) + return .noMessage + case .unimplemented(let unimplemented): + throw NIOSSHError.remotePeerDoesNotSupportMessage(unimplemented) + default: + // TODO: enforce RFC 4253: + // + // > Once a party has sent a SSH_MSG_KEXINIT message for key exchange or + // > re-exchange, until it has sent a SSH_MSG_NEWKEYS message (Section + // > 7.3), it MUST NOT send any messages other than: + // > + // > o Transport layer generic messages (1 to 19) (but + // > SSH_MSG_SERVICE_REQUEST and SSH_MSG_SERVICE_ACCEPT MUST NOT be + // > sent); + // > + // > o Algorithm negotiation messages (20 to 29) (but further + // > SSH_MSG_KEXINIT messages MUST NOT be sent); + // > + // > o Specific key exchange method messages (30 to 49). + // + // We should enforce that, but right now we don't have a good mechanism by which to do so. + throw NIOSSHError.protocolViolation(protocolName: "key exchange", violation: "Unexpected message: \(message)") + } + + case .sentKexInitWhenActive(var state): + // We've sent a key exchange packet. We expect channel messages _or_ a kexinit packet. + guard let message = try state.parser.nextPacket() else { + return nil + } + + switch message { + case .keyExchange(let message): + let result = try state.receiveKeyExchangeMessage(message) + self.state = .rekeying(.init(state)) + return result + + case .channelOpen(let message): + try state.receiveChannelOpen(message) + case .channelOpenConfirmation(let message): + try state.receiveChannelOpenConfirmation(message) + case .channelOpenFailure(let message): + try state.receiveChannelOpenFailure(message) + case .channelEOF(let message): + try state.receiveChannelEOF(message) + case .channelClose(let message): + try state.receiveChannelClose(message) + case .channelWindowAdjust(let message): + try state.receiveChannelWindowAdjust(message) + case .channelData(let message): + try state.receiveChannelData(message) + case .channelExtendedData(let message): + try state.receiveChannelExtendedData(message) + case .channelRequest(let message): + try state.receiveChannelRequest(message) + case .channelSuccess(let message): + try state.receiveChannelSuccess(message) + case .channelFailure(let message): + try state.receiveChannelFailure(message) + case .globalRequest(let message): + try state.receiveGlobalRequest(message) + self.state = .sentKexInitWhenActive(state) + return .globalRequest(message) + case .requestSuccess(let message): + try state.receiveRequestSuccess(message) + self.state = .sentKexInitWhenActive(state) + return .globalRequestResponse(.success(message)) + case .requestFailure: + try state.receiveRequestFailure() + self.state = .sentKexInitWhenActive(state) + return .globalRequestResponse(.failure) + case .disconnect: + self.state = .receivedDisconnect(state.role) + return .disconnect + case .ignore, .debug: + // Ignore these + self.state = .sentKexInitWhenActive(state) + return .noMessage + case .unimplemented(let unimplemented): + throw NIOSSHError.remotePeerDoesNotSupportMessage(unimplemented) + + default: + throw NIOSSHError.protocolViolation(protocolName: "connection", violation: "Unexpected inbound message: \(message)") + } + + self.state = .sentKexInitWhenActive(state) + return .forwardToMultiplexer(message) + + case .rekeying(var state): + // This is basically the regular key exchange state. + guard let message = try state.parser.nextPacket() else { + return nil + } + + switch message { + case .keyExchange(let message): + let result = try state.receiveKeyExchangeMessage(message) + self.state = .rekeying(state) + return result + case .keyExchangeInit(let message): + let result = try state.receiveKeyExchangeInitMessage(message) + self.state = .rekeying(state) + return result + case .keyExchangeReply(let message): + let result = try state.receiveKeyExchangeReplyMessage(message) + self.state = .rekeying(state) + return result + case .newKeys: + try state.receiveNewKeysMessage() + let newState = RekeyingReceivedNewKeysState(state) + self.state = .rekeyingReceivedNewKeysState(newState) + return .noMessage + case .disconnect: + self.state = .receivedDisconnect(state.role) + return .disconnect + case .ignore, .debug: + // Ignore these + self.state = .rekeying(state) + return .noMessage + case .unimplemented(let unimplemented): + throw NIOSSHError.remotePeerDoesNotSupportMessage(unimplemented) + default: + // TODO: enforce RFC 4253: + // + // > Once a party has sent a SSH_MSG_KEXINIT message for key exchange or + // > re-exchange, until it has sent a SSH_MSG_NEWKEYS message (Section + // > 7.3), it MUST NOT send any messages other than: + // > + // > o Transport layer generic messages (1 to 19) (but + // > SSH_MSG_SERVICE_REQUEST and SSH_MSG_SERVICE_ACCEPT MUST NOT be + // > sent); + // > + // > o Algorithm negotiation messages (20 to 29) (but further + // > SSH_MSG_KEXINIT messages MUST NOT be sent); + // > + // > o Specific key exchange method messages (30 to 49). + // + // We should enforce that, but right now we don't have a good mechanism by which to do so. + throw NIOSSHError.protocolViolation(protocolName: "user auth", violation: "Unexpected user auth message: \(message)") + } + + case .rekeyingReceivedNewKeysState(var state): + // This is basically a regular active state. + guard let message = try state.parser.nextPacket() else { + return nil + } + + switch message { + // TODO(cory): One day soon we'll need to support re-keying in this state. + // For now we only support channel messages. + case .channelOpen(let message): + try state.receiveChannelOpen(message) + case .channelOpenConfirmation(let message): + try state.receiveChannelOpenConfirmation(message) + case .channelOpenFailure(let message): + try state.receiveChannelOpenFailure(message) + case .channelEOF(let message): + try state.receiveChannelEOF(message) + case .channelClose(let message): + try state.receiveChannelClose(message) + case .channelWindowAdjust(let message): + try state.receiveChannelWindowAdjust(message) + case .channelData(let message): + try state.receiveChannelData(message) + case .channelExtendedData(let message): + try state.receiveChannelExtendedData(message) + case .channelRequest(let message): + try state.receiveChannelRequest(message) + case .channelSuccess(let message): + try state.receiveChannelSuccess(message) + case .channelFailure(let message): + try state.receiveChannelFailure(message) + case .globalRequest(let message): + try state.receiveGlobalRequest(message) + self.state = .rekeyingReceivedNewKeysState(state) + return .globalRequest(message) + case .requestSuccess(let message): + try state.receiveRequestSuccess(message) + self.state = .rekeyingReceivedNewKeysState(state) + return .globalRequestResponse(.success(message)) + case .requestFailure: + try state.receiveRequestFailure() + self.state = .rekeyingReceivedNewKeysState(state) + return .globalRequestResponse(.failure) + case .disconnect: + self.state = .receivedDisconnect(state.role) + return .disconnect + case .ignore, .debug: + // Ignore these + self.state = .rekeyingReceivedNewKeysState(state) + return .noMessage + case .unimplemented(let unimplemented): + throw NIOSSHError.remotePeerDoesNotSupportMessage(unimplemented) + + default: + throw NIOSSHError.protocolViolation(protocolName: "connection", violation: "Unexpected inbound message: \(message)") + } + + self.state = .rekeyingReceivedNewKeysState(state) + return .forwardToMultiplexer(message) + + case .rekeyingSentNewKeysState(var state): + // This is key exchange state. + guard let message = try state.parser.nextPacket() else { + return nil + } + + switch message { + case .keyExchange(let message): + let result = try state.receiveKeyExchangeMessage(message) + self.state = .rekeyingSentNewKeysState(state) + return result + case .keyExchangeInit(let message): + let result = try state.receiveKeyExchangeInitMessage(message) + self.state = .rekeyingSentNewKeysState(state) + return result + case .keyExchangeReply(let message): + let result = try state.receiveKeyExchangeReplyMessage(message) + self.state = .rekeyingSentNewKeysState(state) + return result + case .newKeys: + try state.receiveNewKeysMessage() + let newState = ActiveState(state) + self.state = .active(newState) + return .noMessage + case .disconnect: + self.state = .receivedDisconnect(state.role) + return .disconnect + case .ignore, .debug: + // Ignore these + self.state = .rekeyingSentNewKeysState(state) + return .noMessage + case .unimplemented(let unimplemented): + throw NIOSSHError.remotePeerDoesNotSupportMessage(unimplemented) + + default: + // TODO: enforce RFC 4253: + // + // > Once a party has sent a SSH_MSG_KEXINIT message for key exchange or + // > re-exchange, until it has sent a SSH_MSG_NEWKEYS message (Section + // > 7.3), it MUST NOT send any messages other than: + // > + // > o Transport layer generic messages (1 to 19) (but + // > SSH_MSG_SERVICE_REQUEST and SSH_MSG_SERVICE_ACCEPT MUST NOT be + // > sent); + // > + // > o Algorithm negotiation messages (20 to 29) (but further + // > SSH_MSG_KEXINIT messages MUST NOT be sent); + // > + // > o Specific key exchange method messages (30 to 49). + // + // We should enforce that, but right now we don't have a good mechanism by which to do so. + throw NIOSSHError.protocolViolation(protocolName: "key exchange", violation: "Unexpected message: \(message)") + } + case .receivedDisconnect, .sentDisconnect: // We do no further I/O in these states. return nil @@ -525,8 +836,6 @@ struct SSHConnectionStateMachine { case .active(var state): switch message { - // TODO(cory): One day soon we'll need to support re-keying in this state. - // For now we only support channel messages. case .channelOpen(let message): try state.writeChannelOpen(message, into: &buffer) case .channelOpenConfirmation(let message): @@ -567,6 +876,140 @@ struct SSHConnectionStateMachine { self.state = .active(state) + case .receivedKexInitWhenActive(var state): + // In this state we only allow sending key exchange messages. In particular, the key exchange message is the only allowed one. + switch message { + case .keyExchange(let keyExchangeMessage): + try state.writeKeyExchangeMessage(keyExchangeMessage, into: &buffer) + self.state = .rekeying(.init(state)) + + case .disconnect: + try state.serializer.serialize(message: message, to: &buffer) + self.state = .sentDisconnect(state.role) + + case .ignore, .debug, .unimplemented: + try state.serializer.serialize(message: message, to: &buffer) + self.state = .receivedKexInitWhenActive(state) + + default: + throw NIOSSHError.protocolViolation(protocolName: "key exchange", violation: "Sent unexpected message type: \(message)") + } + + case .sentKexInitWhenActive(var state): + // In this state we've send a key exchange init message, but not received one from the peer. We have nothing to send. + switch message { + case .disconnect: + try state.serializer.serialize(message: message, to: &buffer) + self.state = .sentDisconnect(state.role) + + case .ignore, .debug, .unimplemented: + try state.serializer.serialize(message: message, to: &buffer) + self.state = .sentKexInitWhenActive(state) + + default: + throw NIOSSHError.protocolViolation(protocolName: "key exchange", violation: "Sent unexpected message type: \(message)") + } + + case .rekeying(var state): + // This is a full key exchange state. + switch message { + case .keyExchange(let keyExchangeMessage): + try state.writeKeyExchangeMessage(keyExchangeMessage, into: &buffer) + self.state = .rekeying(state) + case .keyExchangeInit(let kexInit): + try state.writeKeyExchangeInitMessage(kexInit, into: &buffer) + self.state = .rekeying(state) + case .keyExchangeReply(let kexReply): + try state.writeKeyExchangeReplyMessage(kexReply, into: &buffer) + self.state = .rekeying(state) + case .newKeys: + try state.writeNewKeysMessage(into: &buffer) + self.state = .rekeyingSentNewKeysState(.init(state)) + + case .disconnect: + try state.serializer.serialize(message: message, to: &buffer) + self.state = .sentDisconnect(state.role) + + case .ignore, .debug, .unimplemented: + try state.serializer.serialize(message: message, to: &buffer) + self.state = .rekeying(state) + + default: + throw NIOSSHError.protocolViolation(protocolName: "key exchange", violation: "Sent unexpected message type: \(message)") + } + + case .rekeyingReceivedNewKeysState(var state): + // We may be doing any part of key exchange still. + switch message { + case .keyExchange(let keyExchangeMessage): + try state.writeKeyExchangeMessage(keyExchangeMessage, into: &buffer) + self.state = .rekeyingReceivedNewKeysState(state) + case .keyExchangeInit(let kexInit): + try state.writeKeyExchangeInitMessage(kexInit, into: &buffer) + self.state = .rekeyingReceivedNewKeysState(state) + case .keyExchangeReply(let kexReply): + try state.writeKeyExchangeReplyMessage(kexReply, into: &buffer) + self.state = .rekeyingReceivedNewKeysState(state) + case .newKeys: + try state.writeNewKeysMessage(into: &buffer) + self.state = .active(.init(state)) + + case .disconnect: + try state.serializer.serialize(message: message, to: &buffer) + self.state = .sentDisconnect(state.role) + + case .ignore, .debug, .unimplemented: + try state.serializer.serialize(message: message, to: &buffer) + self.state = .rekeyingReceivedNewKeysState(state) + + default: + throw NIOSSHError.protocolViolation(protocolName: "key exchange", violation: "Sent unexpected message type: \(message)") + } + + case .rekeyingSentNewKeysState(var state): + // We can send channel messages again. + switch message { + case .channelOpen(let message): + try state.writeChannelOpen(message, into: &buffer) + case .channelOpenConfirmation(let message): + try state.writeChannelOpenConfirmation(message, into: &buffer) + case .channelOpenFailure(let message): + try state.writeChannelOpenFailure(message, into: &buffer) + case .channelEOF(let message): + try state.writeChannelEOF(message, into: &buffer) + case .channelClose(let message): + try state.writeChannelClose(message, into: &buffer) + case .channelWindowAdjust(let message): + try state.writeChannelWindowAdjust(message, into: &buffer) + case .channelData(let message): + try state.writeChannelData(message, into: &buffer) + case .channelExtendedData(let message): + try state.writeChannelExtendedData(message, into: &buffer) + case .channelRequest(let message): + try state.writeChannelRequest(message, into: &buffer) + case .channelSuccess(let message): + try state.writeChannelSuccess(message, into: &buffer) + case .channelFailure(let message): + try state.writeChannelFailure(message, into: &buffer) + case .globalRequest(let message): + try state.writeGlobalRequest(message, into: &buffer) + case .requestSuccess(let message): + try state.writeRequestSuccess(message, into: &buffer) + case .requestFailure: + try state.writeRequestFailure(into: &buffer) + case .disconnect: + try state.serializer.serialize(message: message, to: &buffer) + self.state = .sentDisconnect(state.role) + return + case .ignore, .debug, .unimplemented: + try state.serializer.serialize(message: message, to: &buffer) + self.state = .rekeyingSentNewKeysState(state) + default: + throw NIOSSHError.protocolViolation(protocolName: "connection", violation: "Sent unexpected message type: \(message)") + } + + self.state = .rekeyingSentNewKeysState(state) + case .sentDisconnect, .receivedDisconnect: // We don't allow more messages once disconnect has occured throw NIOSSHError.protocolViolation(protocolName: "transport", violation: "I/O after disconnect") @@ -597,6 +1040,27 @@ extension SSHConnectionStateMachine { } } +// MARK: Rekeying + +extension SSHConnectionStateMachine { + // Called when we wish to re-key the connection. + mutating func beginRekeying(buffer: inout ByteBuffer, allocator: ByteBufferAllocator) throws { + switch self.state { + case .active(let state): + // Trying to rekey. + var newState = SentKexInitWhenActiveState(state, allocator: allocator) + let message = newState.keyExchangeStateMachine.createKeyExchangeMessage() + try newState.writeKeyExchangeMessage(message, into: &buffer) + self.state = .sentKexInitWhenActive(newState) + return + case .idle, .sentVersion, .keyExchange, .receivedNewKeys, .sentNewKeys, .userAuthentication, + .receivedKexInitWhenActive, .sentKexInitWhenActive, .rekeying, .rekeyingReceivedNewKeysState, + .rekeyingSentNewKeysState, .receivedDisconnect, .sentDisconnect: + preconditionFailure("May not rekey in this state: \(self.state)") + } + } +} + // MARK: Helper properties extension SSHConnectionStateMachine { @@ -604,7 +1068,9 @@ extension SSHConnectionStateMachine { switch self.state { case .active: return true - case .idle, .sentVersion, .keyExchange, .receivedNewKeys, .sentNewKeys, .userAuthentication, .receivedDisconnect, .sentDisconnect: + case .idle, .sentVersion, .keyExchange, .receivedNewKeys, .sentNewKeys, .userAuthentication, + .receivedKexInitWhenActive, .sentKexInitWhenActive, .rekeying, .rekeyingReceivedNewKeysState, + .rekeyingSentNewKeysState, .receivedDisconnect, .sentDisconnect: return false } } @@ -613,7 +1079,9 @@ extension SSHConnectionStateMachine { switch self.state { case .receivedDisconnect, .sentDisconnect: return true - case .idle, .sentVersion, .keyExchange, .receivedNewKeys, .sentNewKeys, .userAuthentication, .active: + case .idle, .sentVersion, .keyExchange, .receivedNewKeys, .sentNewKeys, .userAuthentication, .active, + .receivedKexInitWhenActive, .sentKexInitWhenActive, .rekeying, .rekeyingReceivedNewKeysState, + .rekeyingSentNewKeysState: return false } } @@ -634,6 +1102,16 @@ extension SSHConnectionStateMachine { return state.role case .active(let state): return state.role + case .receivedKexInitWhenActive(let state): + return state.role + case .sentKexInitWhenActive(let state): + return state.role + case .rekeying(let state): + return state.role + case .rekeyingReceivedNewKeysState(let state): + return state.role + case .rekeyingSentNewKeysState(let state): + return state.role case .receivedDisconnect(let role): return role case .sentDisconnect(let role): diff --git a/Sources/NIOSSH/Connection State Machine/States/ActiveState.swift b/Sources/NIOSSH/Connection State Machine/States/ActiveState.swift index a14a573..c243fac 100644 --- a/Sources/NIOSSH/Connection State Machine/States/ActiveState.swift +++ b/Sources/NIOSSH/Connection State Machine/States/ActiveState.swift @@ -11,6 +11,7 @@ // SPDX-License-Identifier: Apache-2.0 // //===----------------------------------------------------------------------===// +import NIO extension SSHConnectionStateMachine { /// The state of a state machine that has completed user auth and key exchange and is @@ -24,10 +25,37 @@ extension SSHConnectionStateMachine { internal var parser: SSHPacketParser + internal var remoteVersion: String + + internal var protectionSchemes: [NIOSSHTransportProtection.Type] + + internal var sessionIdentifier: ByteBuffer + init(_ previous: UserAuthenticationState) { self.role = previous.role self.serializer = previous.serializer self.parser = previous.parser + self.remoteVersion = previous.remoteVersion + self.protectionSchemes = previous.protectionSchemes + self.sessionIdentifier = previous.sessionIdentifier + } + + init(_ previous: RekeyingReceivedNewKeysState) { + self.role = previous.role + self.serializer = previous.serializer + self.parser = previous.parser + self.remoteVersion = previous.remoteVersion + self.protectionSchemes = previous.protectionSchemes + self.sessionIdentifier = previous.sessionIdentifier + } + + init(_ previous: RekeyingSentNewKeysState) { + self.role = previous.role + self.serializer = previous.serializer + self.parser = previous.parser + self.remoteVersion = previous.remoteVersion + self.protectionSchemes = previous.protectionSchemes + self.sessionIdentifier = previous.sessionIdentifier } } } diff --git a/Sources/NIOSSH/Connection State Machine/States/KeyExchangeState.swift b/Sources/NIOSSH/Connection State Machine/States/KeyExchangeState.swift index 3d4b80a..dfc13f4 100644 --- a/Sources/NIOSSH/Connection State Machine/States/KeyExchangeState.swift +++ b/Sources/NIOSSH/Connection State Machine/States/KeyExchangeState.swift @@ -26,6 +26,10 @@ extension SSHConnectionStateMachine { /// The packet serializer used by this state machine. var serializer: SSHPacketSerializer + var remoteVersion: String + + var protectionSchemes: [NIOSSHTransportProtection.Type] + /// The backing state machine. var keyExchangeStateMachine: SSHKeyExchangeStateMachine @@ -33,7 +37,9 @@ extension SSHConnectionStateMachine { self.role = state.role self.parser = state.parser self.serializer = state.serializer - self.keyExchangeStateMachine = SSHKeyExchangeStateMachine(allocator: allocator, role: state.role, remoteVersion: remoteVersion, protectionSchemes: state.protectionSchemes) + self.remoteVersion = remoteVersion + self.protectionSchemes = state.protectionSchemes + self.keyExchangeStateMachine = SSHKeyExchangeStateMachine(allocator: allocator, role: state.role, remoteVersion: remoteVersion, protectionSchemes: state.protectionSchemes, previousSessionIdentifier: nil) } } } diff --git a/Sources/NIOSSH/Connection State Machine/States/ReceivedKexInitWhenActiveState.swift b/Sources/NIOSSH/Connection State Machine/States/ReceivedKexInitWhenActiveState.swift new file mode 100644 index 0000000..9101058 --- /dev/null +++ b/Sources/NIOSSH/Connection State Machine/States/ReceivedKexInitWhenActiveState.swift @@ -0,0 +1,53 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftNIO open source project +// +// Copyright (c) 2020 Apple Inc. and the SwiftNIO project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftNIO project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// +import NIO + +extension SSHConnectionStateMachine { + /// The state of a state machine that has received a KeyExchangeInit message after + /// having been active. In this state, no further channel messages may be sent by the + /// remote peer until key exchange is done. We can send channel messages _and_ key exchange init. + struct ReceivedKexInitWhenActiveState { + /// The role of the connection + let role: SSHConnectionRole + + /// The packet serializer used by this state machine. + internal var serializer: SSHPacketSerializer + + internal var parser: SSHPacketParser + + internal var remoteVersion: String + + internal var protectionSchemes: [NIOSSHTransportProtection.Type] + + internal var keyExchangeStateMachine: SSHKeyExchangeStateMachine + + internal var sessionIdentifier: ByteBuffer + + init(_ previous: ActiveState, allocator: ByteBufferAllocator) { + self.role = previous.role + self.serializer = previous.serializer + self.parser = previous.parser + self.remoteVersion = previous.remoteVersion + self.protectionSchemes = previous.protectionSchemes + self.sessionIdentifier = previous.sessionIdentifier + self.keyExchangeStateMachine = SSHKeyExchangeStateMachine(allocator: allocator, role: previous.role, remoteVersion: previous.remoteVersion, protectionSchemes: previous.protectionSchemes, previousSessionIdentifier: self.sessionIdentifier) + } + } +} + +extension SSHConnectionStateMachine.ReceivedKexInitWhenActiveState: AcceptsKeyExchangeMessages {} + +extension SSHConnectionStateMachine.ReceivedKexInitWhenActiveState: SendsChannelMessages {} + +extension SSHConnectionStateMachine.ReceivedKexInitWhenActiveState: SendsKeyExchangeMessages {} diff --git a/Sources/NIOSSH/Connection State Machine/States/ReceivedNewKeysState.swift b/Sources/NIOSSH/Connection State Machine/States/ReceivedNewKeysState.swift index 7ee37e3..e7ac34a 100644 --- a/Sources/NIOSSH/Connection State Machine/States/ReceivedNewKeysState.swift +++ b/Sources/NIOSSH/Connection State Machine/States/ReceivedNewKeysState.swift @@ -27,6 +27,12 @@ extension SSHConnectionStateMachine { /// The packet serializer used by this state machine. var serializer: SSHPacketSerializer + var remoteVersion: String + + var protectionSchemes: [NIOSSHTransportProtection.Type] + + var sessionIdentifier: ByteBuffer + /// The backing state machine. var keyExchangeStateMachine: SSHKeyExchangeStateMachine @@ -38,12 +44,15 @@ extension SSHConnectionStateMachine { self.role = state.role self.parser = state.parser self.serializer = state.serializer + self.remoteVersion = state.remoteVersion + self.protectionSchemes = state.protectionSchemes self.keyExchangeStateMachine = state.keyExchangeStateMachine // We force unwrap the session ID because it's programmer error to not have it at this time. + self.sessionIdentifier = state.keyExchangeStateMachine.sessionID! self.userAuthStateMachine = UserAuthenticationStateMachine(role: self.role, loop: loop, - sessionID: state.keyExchangeStateMachine.sessionID!) + sessionID: self.sessionIdentifier) } } } diff --git a/Sources/NIOSSH/Connection State Machine/States/RekeyingReceivedNewKeysState.swift b/Sources/NIOSSH/Connection State Machine/States/RekeyingReceivedNewKeysState.swift new file mode 100644 index 0000000..37a9c5b --- /dev/null +++ b/Sources/NIOSSH/Connection State Machine/States/RekeyingReceivedNewKeysState.swift @@ -0,0 +1,53 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftNIO open source project +// +// Copyright (c) 2020 Apple Inc. and the SwiftNIO project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftNIO project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import NIO + +extension SSHConnectionStateMachine { + /// The state of a state machine that has received new keys after a key exchange operation from active, + /// but has not yet sent its new keys to the peer. + struct RekeyingReceivedNewKeysState { + /// The role of the connection + let role: SSHConnectionRole + + /// The packet parser. + var parser: SSHPacketParser + + /// The packet serializer used by this state machine. + var serializer: SSHPacketSerializer + + var remoteVersion: String + + var protectionSchemes: [NIOSSHTransportProtection.Type] + + var sessionIdentifier: ByteBuffer + + /// The backing state machine. + var keyExchangeStateMachine: SSHKeyExchangeStateMachine + + init(_ previousState: RekeyingState) { + self.role = previousState.role + self.parser = previousState.parser + self.serializer = previousState.serializer + self.remoteVersion = previousState.remoteVersion + self.protectionSchemes = previousState.protectionSchemes + self.sessionIdentifier = previousState.sessionIdentifier + self.keyExchangeStateMachine = previousState.keyExchangeStateMachine + } + } +} + +extension SSHConnectionStateMachine.RekeyingReceivedNewKeysState: SendsKeyExchangeMessages {} + +extension SSHConnectionStateMachine.RekeyingReceivedNewKeysState: AcceptsChannelMessages {} diff --git a/Sources/NIOSSH/Connection State Machine/States/RekeyingSentNewKeysState.swift b/Sources/NIOSSH/Connection State Machine/States/RekeyingSentNewKeysState.swift new file mode 100644 index 0000000..9e6af70 --- /dev/null +++ b/Sources/NIOSSH/Connection State Machine/States/RekeyingSentNewKeysState.swift @@ -0,0 +1,53 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftNIO open source project +// +// Copyright (c) 2020 Apple Inc. and the SwiftNIO project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftNIO project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import NIO + +extension SSHConnectionStateMachine { + /// The state of a state machine that has sent new keys after a key exchange operation from an active channel, + /// but has not yet received the new keys from the peer. + struct RekeyingSentNewKeysState { + /// The role of the connection + let role: SSHConnectionRole + + /// The packet parser. + var parser: SSHPacketParser + + /// The packet serializer used by this state machine. + var serializer: SSHPacketSerializer + + var remoteVersion: String + + var protectionSchemes: [NIOSSHTransportProtection.Type] + + var sessionIdentifier: ByteBuffer + + /// The backing state machine. + var keyExchangeStateMachine: SSHKeyExchangeStateMachine + + init(_ previousState: RekeyingState) { + self.role = previousState.role + self.parser = previousState.parser + self.serializer = previousState.serializer + self.remoteVersion = previousState.remoteVersion + self.protectionSchemes = previousState.protectionSchemes + self.sessionIdentifier = previousState.sessionIdentifier + self.keyExchangeStateMachine = previousState.keyExchangeStateMachine + } + } +} + +extension SSHConnectionStateMachine.RekeyingSentNewKeysState: AcceptsKeyExchangeMessages {} + +extension SSHConnectionStateMachine.RekeyingSentNewKeysState: SendsChannelMessages {} diff --git a/Sources/NIOSSH/Connection State Machine/States/RekeyingState.swift b/Sources/NIOSSH/Connection State Machine/States/RekeyingState.swift new file mode 100644 index 0000000..000e415 --- /dev/null +++ b/Sources/NIOSSH/Connection State Machine/States/RekeyingState.swift @@ -0,0 +1,62 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftNIO open source project +// +// Copyright (c) 2020 Apple Inc. and the SwiftNIO project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftNIO project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import NIO + +extension SSHConnectionStateMachine { + /// The state of a state machine that is actively engaged in a key exchange operation having been active before. + struct RekeyingState { + /// The role of the connection + let role: SSHConnectionRole + + /// The packet parser. + var parser: SSHPacketParser + + /// The packet serializer used by this state machine. + var serializer: SSHPacketSerializer + + var remoteVersion: String + + var protectionSchemes: [NIOSSHTransportProtection.Type] + + var sessionIdentifier: ByteBuffer + + /// The backing state machine. + var keyExchangeStateMachine: SSHKeyExchangeStateMachine + + init(_ previousState: ReceivedKexInitWhenActiveState) { + self.role = previousState.role + self.parser = previousState.parser + self.serializer = previousState.serializer + self.remoteVersion = previousState.remoteVersion + self.protectionSchemes = previousState.protectionSchemes + self.sessionIdentifier = previousState.sessionIdentifier + self.keyExchangeStateMachine = previousState.keyExchangeStateMachine + } + + init(_ previousState: SentKexInitWhenActiveState) { + self.role = previousState.role + self.parser = previousState.parser + self.serializer = previousState.serializer + self.remoteVersion = previousState.remoteVersion + self.protectionSchemes = previousState.protectionSchemes + self.sessionIdentifier = previousState.sessionIdentitifier + self.keyExchangeStateMachine = previousState.keyExchangeStateMachine + } + } +} + +extension SSHConnectionStateMachine.RekeyingState: AcceptsKeyExchangeMessages {} + +extension SSHConnectionStateMachine.RekeyingState: SendsKeyExchangeMessages {} diff --git a/Sources/NIOSSH/Connection State Machine/States/SentKexInitWhenActiveState.swift b/Sources/NIOSSH/Connection State Machine/States/SentKexInitWhenActiveState.swift new file mode 100644 index 0000000..fd7d848 --- /dev/null +++ b/Sources/NIOSSH/Connection State Machine/States/SentKexInitWhenActiveState.swift @@ -0,0 +1,53 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftNIO open source project +// +// Copyright (c) 2020 Apple Inc. and the SwiftNIO project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftNIO project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// +import NIO + +extension SSHConnectionStateMachine { + /// The state of a state machine that has sent a KeyExchangeInit message after + /// having been active. In this state, no further channel messages may be sent by the + /// us until key exchange is done. We can receive channel messages _and_ key exchange init. + struct SentKexInitWhenActiveState { + /// The role of the connection + let role: SSHConnectionRole + + /// The packet serializer used by this state machine. + internal var serializer: SSHPacketSerializer + + internal var parser: SSHPacketParser + + internal var remoteVersion: String + + internal var protectionSchemes: [NIOSSHTransportProtection.Type] + + internal var sessionIdentitifier: ByteBuffer + + internal var keyExchangeStateMachine: SSHKeyExchangeStateMachine + + init(_ previous: ActiveState, allocator: ByteBufferAllocator) { + self.role = previous.role + self.serializer = previous.serializer + self.parser = previous.parser + self.remoteVersion = previous.remoteVersion + self.protectionSchemes = previous.protectionSchemes + self.sessionIdentitifier = previous.sessionIdentifier + self.keyExchangeStateMachine = SSHKeyExchangeStateMachine(allocator: allocator, role: self.role, remoteVersion: self.remoteVersion, protectionSchemes: self.protectionSchemes, previousSessionIdentifier: previous.sessionIdentifier) + } + } +} + +extension SSHConnectionStateMachine.SentKexInitWhenActiveState: AcceptsKeyExchangeMessages {} + +extension SSHConnectionStateMachine.SentKexInitWhenActiveState: AcceptsChannelMessages {} + +extension SSHConnectionStateMachine.SentKexInitWhenActiveState: SendsKeyExchangeMessages {} diff --git a/Sources/NIOSSH/Connection State Machine/States/SentNewKeysState.swift b/Sources/NIOSSH/Connection State Machine/States/SentNewKeysState.swift index 12b99e3..5452fb5 100644 --- a/Sources/NIOSSH/Connection State Machine/States/SentNewKeysState.swift +++ b/Sources/NIOSSH/Connection State Machine/States/SentNewKeysState.swift @@ -27,6 +27,12 @@ extension SSHConnectionStateMachine { /// The packet serializer used by this state machine. var serializer: SSHPacketSerializer + var remoteVersion: String + + var protectionSchemes: [NIOSSHTransportProtection.Type] + + var sessionIdentifier: ByteBuffer + /// The backing state machine. var keyExchangeStateMachine: SSHKeyExchangeStateMachine @@ -39,11 +45,14 @@ extension SSHConnectionStateMachine { self.parser = state.parser self.serializer = state.serializer self.keyExchangeStateMachine = state.keyExchangeStateMachine + self.remoteVersion = state.remoteVersion + self.protectionSchemes = state.protectionSchemes // We force unwrap the session ID here because it's programmer error to not have it at this stage. + self.sessionIdentifier = self.keyExchangeStateMachine.sessionID! self.userAuthStateMachine = UserAuthenticationStateMachine(role: self.role, loop: loop, - sessionID: self.keyExchangeStateMachine.sessionID!) + sessionID: self.sessionIdentifier) } } } diff --git a/Sources/NIOSSH/Connection State Machine/States/UserAuthenticationState.swift b/Sources/NIOSSH/Connection State Machine/States/UserAuthenticationState.swift index 1cfe752..7b6d2b7 100644 --- a/Sources/NIOSSH/Connection State Machine/States/UserAuthenticationState.swift +++ b/Sources/NIOSSH/Connection State Machine/States/UserAuthenticationState.swift @@ -26,6 +26,12 @@ extension SSHConnectionStateMachine { /// The packet serializer used by this state machine. var serializer: SSHPacketSerializer + var remoteVersion: String + + var protectionSchemes: [NIOSSHTransportProtection.Type] + + var sessionIdentifier: ByteBuffer + /// The backing state machine. var userAuthStateMachine: UserAuthenticationStateMachine @@ -34,6 +40,9 @@ extension SSHConnectionStateMachine { self.parser = state.parser self.serializer = state.serializer self.userAuthStateMachine = state.userAuthStateMachine + self.remoteVersion = state.remoteVersion + self.protectionSchemes = state.protectionSchemes + self.sessionIdentifier = state.sessionIdentifier } init(receivedNewKeysState state: ReceivedNewKeysState) { @@ -41,6 +50,9 @@ extension SSHConnectionStateMachine { self.parser = state.parser self.serializer = state.serializer self.userAuthStateMachine = state.userAuthStateMachine + self.remoteVersion = state.remoteVersion + self.protectionSchemes = state.protectionSchemes + self.sessionIdentifier = state.sessionIdentifier } } } diff --git a/Sources/NIOSSH/Key Exchange/SSHKeyExchangeStateMachine.swift b/Sources/NIOSSH/Key Exchange/SSHKeyExchangeStateMachine.swift index 89fb29b..22c160a 100644 --- a/Sources/NIOSSH/Key Exchange/SSHKeyExchangeStateMachine.swift +++ b/Sources/NIOSSH/Key Exchange/SSHKeyExchangeStateMachine.swift @@ -27,12 +27,19 @@ struct SSHKeyExchangeStateMachine { /// We've sent our key exchange message. /// /// Either clients or servers can send this message: they are entitled to race. Thus, either - /// party can enter this state. We assume they will so there is no equivalent keyExchangeReceived - /// state. + /// party can enter this state. /// /// We store the message we sent for later. case keyExchangeSent(message: SSHMessage.KeyExchangeMessage) + /// We've received a key exchange message. + /// + /// Either clients or servers can send this message: they are entitled to race. Thus, either + /// party can enter this state. The remote peer may be sending a guess as well. + /// + /// We store the message we sent for later. + case keyExchangeReceived(exchange: Curve25519KeyExchange, negotiated: NegotiationResult, expectingGuess: Bool) + /// The peer has guessed what key exchange init packet is coming, and guessed wrong. We need to wait for them to send that packet. case awaitingKeyExchangeInitInvalidGuess(exchange: Curve25519KeyExchange, negotiated: NegotiationResult) @@ -63,13 +70,15 @@ struct SSHKeyExchangeStateMachine { private var state: State private var initialExchangeBytes: ByteBuffer private var protectionSchemes: [NIOSSHTransportProtection.Type] + private var previousSessionIdentifier: ByteBuffer? - init(allocator: ByteBufferAllocator, role: SSHConnectionRole, remoteVersion: String, protectionSchemes: [NIOSSHTransportProtection.Type]) { + init(allocator: ByteBufferAllocator, role: SSHConnectionRole, remoteVersion: String, protectionSchemes: [NIOSSHTransportProtection.Type], previousSessionIdentifier: ByteBuffer?) { self.allocator = allocator self.role = role self.initialExchangeBytes = allocator.buffer(capacity: 1024) self.state = .idle self.protectionSchemes = protectionSchemes + self.previousSessionIdentifier = previousSessionIdentifier switch self.role { case .client: @@ -83,13 +92,13 @@ struct SSHKeyExchangeStateMachine { /// Currently we statically only use a single key exchange message. In future this will expand out to /// support arbitrary SSHTransportProtection schemes. - private func createKeyExchangeMessage() -> SSHMessage { + func createKeyExchangeMessage() -> SSHMessage.KeyExchangeMessage { var rng = CSPRNG() let encryptionAlgorithms = self.supportedEncryptionAlgorithms let macAlgorithms = self.supportedMacAlgorithms - return .keyExchange(.init( + return .init( cookie: rng.randomCookie(allocator: self.allocator), keyExchangeAlgorithms: Self.supportedKeyExchangeAlgorithms, serverHostKeyAlgorithms: self.supportedHostKeyAlgorithms, @@ -102,18 +111,7 @@ struct SSHKeyExchangeStateMachine { languagesClientToServer: [], languagesServerToClient: [], firstKexPacketFollows: false - )) - } - - /// Begins the key exchange process. This may be called by both clients and servers to speed up the key exchange process. - mutating func startKeyExchange() -> SSHMultiMessage { - switch self.state { - case .idle: - return SSHMultiMessage(self.createKeyExchangeMessage()) - - case .keyExchangeSent, .awaitingKeyExchangeInitInvalidGuess, .keyExchangeInitSent, .awaitingKeyExchangeInit, .keyExchangeInitReceived, .keysExchanged, .newKeysSent, .newKeysReceived, .complete: - preconditionFailure("Duplicate call to startKeyExchange") - } + ) } mutating func handle(keyExchange message: SSHMessage.KeyExchangeMessage) throws -> SSHMultiMessage? { @@ -147,8 +145,32 @@ struct SSHKeyExchangeStateMachine { return nil } case .idle: - preconditionFailure("Received the key exchange message before we sent our own") - case .awaitingKeyExchangeInit, .awaitingKeyExchangeInitInvalidGuess, .keyExchangeInitReceived, .keyExchangeInitSent, .keysExchanged, .newKeysSent, .newKeysReceived, .complete: + // We received a key exchange message while idle. We will need to send our own key exchange message back, + // and also follow immediately up with our own key exchange init message. + let ourMessage = self.createKeyExchangeMessage() + + switch self.role { + case .client: + self.addKeyExchangeInitMessagesToExchangeBytes(clientsMessage: ourMessage, serversMessage: message) + case .server: + self.addKeyExchangeInitMessagesToExchangeBytes(clientsMessage: message, serversMessage: ourMessage) + } + + let negotiated = try self.negotiatedAlgorithms(message) + let exchanger = self.exchangerForAlgorithm(negotiated.negotiatedKeyExchangeAlgorithm) + + let result: SSHMultiMessage + switch self.role { + case .client: + result = SSHMultiMessage(.keyExchange(ourMessage), SSHMessage.keyExchangeInit(exchanger.initiateKeyExchangeClientSide(allocator: self.allocator))) + case .server: + result = SSHMultiMessage(.keyExchange(ourMessage)) + } + + self.state = .keyExchangeReceived(exchange: exchanger, negotiated: negotiated, expectingGuess: self.expectingIncorrectGuess(message)) + return result + + case .keyExchangeReceived, .awaitingKeyExchangeInit, .awaitingKeyExchangeInitInvalidGuess, .keyExchangeInitReceived, .keyExchangeInitSent, .keysExchanged, .newKeysSent, .newKeysReceived, .complete: throw SSHKeyExchangeError.unexpectedMessage } } @@ -157,6 +179,20 @@ struct SSHKeyExchangeStateMachine { switch self.state { case .idle: self.state = .keyExchangeSent(message: message) + case .keyExchangeReceived(let exchanger, let negotiated, let expectingGuess): + switch self.role { + case .server: + // Ok, we're waiting for a key exchange init message. + if expectingGuess { + self.state = .awaitingKeyExchangeInitInvalidGuess(exchange: exchanger, negotiated: negotiated) + } else { + self.state = .awaitingKeyExchangeInit(exchange: exchanger, negotiated: negotiated) + } + + case .client: + // We're going to send a key exchange init message. + self.state = .awaitingKeyExchangeInit(exchange: exchanger, negotiated: negotiated) + } case .keyExchangeSent, .awaitingKeyExchangeInit, .awaitingKeyExchangeInitInvalidGuess, .keyExchangeInitReceived, .keyExchangeInitSent, .keysExchanged, .newKeysSent, .newKeysReceived, .complete: // This is a precondition not a throw because we control the sending of this message. preconditionFailure("Cannot send key exchange message after idle") @@ -187,7 +223,7 @@ struct SSHKeyExchangeStateMachine { self.state = .keyExchangeInitReceived(result: result, negotiated: negotiated) return SSHMultiMessage(message, .newKeys) } - case .idle, .keyExchangeSent, .keyExchangeInitReceived, .keyExchangeInitSent, .keysExchanged, .newKeysSent, .newKeysReceived, .complete: + case .idle, .keyExchangeSent, .keyExchangeReceived, .keyExchangeInitReceived, .keyExchangeInitSent, .keysExchanged, .newKeysSent, .newKeysReceived, .complete: throw SSHKeyExchangeError.unexpectedMessage } } @@ -197,7 +233,7 @@ struct SSHKeyExchangeStateMachine { case .awaitingKeyExchangeInit(exchange: let exchanger, negotiated: let negotiated): precondition(self.role.isClient, "Servers must not send ecdh key exchange init messages") self.state = .keyExchangeInitSent(exchange: exchanger, negotiated: negotiated) - case .idle, .keyExchangeSent, .awaitingKeyExchangeInitInvalidGuess, .keyExchangeInitSent, .keyExchangeInitReceived, .keysExchanged, .newKeysSent, .newKeysReceived, .complete: + case .idle, .keyExchangeSent, .keyExchangeReceived, .awaitingKeyExchangeInitInvalidGuess, .keyExchangeInitSent, .keyExchangeInitReceived, .keysExchanged, .newKeysSent, .newKeysReceived, .complete: // This is a precondition not a throw because we control the sending of this message. preconditionFailure("Cannot send ECDH key exchange message in state \(self.state)") } @@ -220,7 +256,7 @@ struct SSHKeyExchangeStateMachine { case .server: preconditionFailure("Servers cannot enter key exchange init sent.") } - case .idle, .keyExchangeSent, .awaitingKeyExchangeInit, .awaitingKeyExchangeInitInvalidGuess, .keyExchangeInitReceived, .keysExchanged, .newKeysSent, .newKeysReceived, .complete: + case .idle, .keyExchangeSent, .keyExchangeReceived, .awaitingKeyExchangeInit, .awaitingKeyExchangeInitInvalidGuess, .keyExchangeInitReceived, .keysExchanged, .newKeysSent, .newKeysReceived, .complete: throw SSHKeyExchangeError.unexpectedMessage } } @@ -230,7 +266,7 @@ struct SSHKeyExchangeStateMachine { case .keyExchangeInitReceived(result: let result, negotiated: let negotiated): precondition(self.role.isServer, "Clients cannot enter key exchange init received") self.state = .keysExchanged(result: result, protection: try negotiated.negotiatedProtection.init(initialKeys: result.keys), negotiated: negotiated) - case .idle, .keyExchangeSent, .awaitingKeyExchangeInit, .awaitingKeyExchangeInitInvalidGuess, .keyExchangeInitSent, .keysExchanged, .newKeysSent, .newKeysReceived, .complete: + case .idle, .keyExchangeSent, .keyExchangeReceived, .awaitingKeyExchangeInit, .awaitingKeyExchangeInitInvalidGuess, .keyExchangeInitSent, .keysExchanged, .newKeysSent, .newKeysReceived, .complete: // This is a precondition not a throw because we control the sending of this message. preconditionFailure("Cannot send ECDH key exchange message in state \(self.state)") } @@ -244,7 +280,7 @@ struct SSHKeyExchangeStateMachine { case .newKeysSent(result: let result, protection: let protection, _): self.state = .complete(result: result) return protection - case .idle, .keyExchangeSent, .awaitingKeyExchangeInit, .awaitingKeyExchangeInitInvalidGuess, .keyExchangeInitSent, .keyExchangeInitReceived, .newKeysReceived, .complete: + case .idle, .keyExchangeSent, .keyExchangeReceived, .awaitingKeyExchangeInit, .awaitingKeyExchangeInitInvalidGuess, .keyExchangeInitSent, .keyExchangeInitReceived, .newKeysReceived, .complete: throw SSHKeyExchangeError.unexpectedMessage } } @@ -257,7 +293,7 @@ struct SSHKeyExchangeStateMachine { case .newKeysReceived(result: let result, protection: let protection, _): self.state = .complete(result: result) return protection - case .idle, .keyExchangeSent, .awaitingKeyExchangeInit, .awaitingKeyExchangeInitInvalidGuess, .keyExchangeInitSent, .keyExchangeInitReceived, .newKeysSent, .complete: + case .idle, .keyExchangeSent, .keyExchangeReceived, .awaitingKeyExchangeInit, .awaitingKeyExchangeInitInvalidGuess, .keyExchangeInitSent, .keyExchangeInitReceived, .newKeysSent, .complete: // This is a precondition not a throw because we control the sending of this message. preconditionFailure("Cannot send ECDH key exchange message in state \(self.state)") } @@ -393,7 +429,7 @@ struct SSHKeyExchangeStateMachine { assert(Self.supportedKeyExchangeAlgorithms.contains(algorithm)) // We only support Curve25519 right now, so we up this to a precondition. precondition(Self.supportedKeyExchangeAlgorithms.contains(algorithm)) - return Curve25519KeyExchange(ourRole: self.role, previousSessionIdentifier: nil) + return Curve25519KeyExchange(ourRole: self.role, previousSessionIdentifier: self.previousSessionIdentifier) } private func expectingIncorrectGuess(_ kexMessage: SSHMessage.KeyExchangeMessage) -> Bool { @@ -467,7 +503,7 @@ extension SSHKeyExchangeStateMachine { .complete(result: let result): return result.sessionID - case .idle, .keyExchangeSent, .awaitingKeyExchangeInit, .awaitingKeyExchangeInitInvalidGuess, .keyExchangeInitSent: + case .idle, .keyExchangeSent, .keyExchangeReceived, .awaitingKeyExchangeInit, .awaitingKeyExchangeInitInvalidGuess, .keyExchangeInitSent: return nil } } diff --git a/Sources/NIOSSH/NIOSSHHandler.swift b/Sources/NIOSSH/NIOSSHHandler.swift index a33b43d..a4b22b1 100644 --- a/Sources/NIOSSH/NIOSSHHandler.swift +++ b/Sources/NIOSSH/NIOSSHHandler.swift @@ -350,6 +350,20 @@ extension NIOSSHHandler { } } +// MARK: Initiate rekeying + +extension NIOSSHHandler { + // This function mostly exists for testing purposes: we don't initiate re-keying today because it's not + // well-supported by evidence. But we want to be able to test against implementations who do, so we have support for + // kicking it off. + internal func _rekey() throws { + // As this is test-only there are a bunch of preconditions in here, we don't really mind if we hit them in testing. + var buffer = self.context!.channel.allocator.buffer(capacity: 1024) + try self.stateMachine.beginRekeying(buffer: &buffer, allocator: self.context!.channel.allocator) + self.context!.writeAndFlush(self.wrapOutboundOut(buffer), promise: nil) + } +} + // MARK: Functions called from the multiplexer extension NIOSSHHandler: SSHMultiplexerDelegate { diff --git a/Sources/NIOSSH/SSHPacketParser.swift b/Sources/NIOSSH/SSHPacketParser.swift index 450fc20..68f3727 100644 --- a/Sources/NIOSSH/SSHPacketParser.swift +++ b/Sources/NIOSSH/SSHPacketParser.swift @@ -47,7 +47,9 @@ struct SSHPacketParser { switch self.state { case .cleartextWaitingForLength: self.state = .encryptedWaitingForLength(protection) - case .cleartextWaitingForBytes, .initialized, .encryptedWaitingForLength, .encryptedWaitingForBytes: + case .encryptedWaitingForLength: + self.state = .encryptedWaitingForLength(protection) + case .cleartextWaitingForBytes, .initialized, .encryptedWaitingForBytes: preconditionFailure("Adding encryption in invalid state: \(self.state)") } } diff --git a/Sources/NIOSSH/SSHPacketSerializer.swift b/Sources/NIOSSH/SSHPacketSerializer.swift index 9b39f56..825408c 100644 --- a/Sources/NIOSSH/SSHPacketSerializer.swift +++ b/Sources/NIOSSH/SSHPacketSerializer.swift @@ -24,12 +24,13 @@ struct SSHPacketSerializer { private var state: State = .initialized /// Encryption schemes can be added to a packet serializer whenever encryption is negotiated. - /// They may only be added once, while the serializer is in an idle state. mutating func addEncryption(_ protection: NIOSSHTransportProtection) { switch self.state { case .cleartext: self.state = .encrypted(protection) - case .initialized, .encrypted: + case .encrypted: + self.state = .encrypted(protection) + case .initialized: preconditionFailure("Adding encryption in invalid state: \(self.state)") } } diff --git a/Sources/NIOSSH/TransportProtection/AESGCM.swift b/Sources/NIOSSH/TransportProtection/AESGCM.swift index 6bba7a5..ecc8f7b 100644 --- a/Sources/NIOSSH/TransportProtection/AESGCM.swift +++ b/Sources/NIOSSH/TransportProtection/AESGCM.swift @@ -119,9 +119,9 @@ extension AESGCMTransportProtection: NIOSSHTransportProtection { func encryptPacket(_ packet: NIOSSHEncryptablePayload, to outboundBuffer: inout ByteBuffer) throws { // Keep track of where the length is going to be written. - let packetLengthIndex = outboundBuffer.readerIndex + let packetLengthIndex = outboundBuffer.writerIndex let packetLengthLength = MemoryLayout.size - let packetPaddingIndex = outboundBuffer.readerIndex + packetLengthLength + let packetPaddingIndex = outboundBuffer.writerIndex + packetLengthLength let packetPaddingLength = MemoryLayout.size outboundBuffer.moveWriterIndex(forwardBy: packetLengthLength + packetPaddingLength) diff --git a/Tests/NIOSSHTests/EndToEndTests.swift b/Tests/NIOSSHTests/EndToEndTests.swift index 2e0656e..b3d169b 100644 --- a/Tests/NIOSSHTests/EndToEndTests.swift +++ b/Tests/NIOSSHTests/EndToEndTests.swift @@ -14,7 +14,7 @@ import Crypto import NIO -import NIOSSH +@testable import NIOSSH import XCTest enum EndToEndTestError: Error { @@ -395,4 +395,36 @@ class EndToEndTests: XCTestCase { XCTAssertEqual(self.channel.activeServerChannels.count, 1) #endif } + + func testSupportClientInitiatedRekeying() throws { + XCTAssertNoThrow(try self.channel.configureWithHarness(TestHarness())) + XCTAssertNoThrow(try self.channel.activate()) + XCTAssertNoThrow(try self.channel.interactInMemory()) + + // Initiate re-keying on the client. + XCTAssertNoThrow(try self.channel.clientSSHHandler!._rekey()) + XCTAssertNoThrow(try self.channel.interactInMemory()) + + // We should be able to send a message here. + XCTAssertEqual(self.channel.activeServerChannels.count, 0) + self.channel.clientSSHHandler?.createChannel(nil, nil) + XCTAssertNoThrow(try self.channel.interactInMemory()) + XCTAssertEqual(self.channel.activeServerChannels.count, 1) + } + + func testSupportServerInitiatedRekeying() throws { + XCTAssertNoThrow(try self.channel.configureWithHarness(TestHarness())) + XCTAssertNoThrow(try self.channel.activate()) + XCTAssertNoThrow(try self.channel.interactInMemory()) + + // Initiate re-keying on the server. + XCTAssertNoThrow(try self.channel.serverSSHHandler!._rekey()) + XCTAssertNoThrow(try self.channel.interactInMemory()) + + // We should be able to send a message here. + XCTAssertEqual(self.channel.activeServerChannels.count, 0) + self.channel.clientSSHHandler?.createChannel(nil, nil) + XCTAssertNoThrow(try self.channel.interactInMemory()) + XCTAssertEqual(self.channel.activeServerChannels.count, 1) + } } diff --git a/Tests/NIOSSHTests/SSHConnectionStateMachineTests.swift b/Tests/NIOSSHTests/SSHConnectionStateMachineTests.swift index 9d44a81..7f2d873 100644 --- a/Tests/NIOSSHTests/SSHConnectionStateMachineTests.swift +++ b/Tests/NIOSSHTests/SSHConnectionStateMachineTests.swift @@ -377,4 +377,23 @@ final class SSHConnectionStateMachineTests: XCTestCase { try assertSuccessfulConnection(client: &client, server: &server, allocator: allocator, loop: loop) try self.assertUnimplementedCausesError(sequenceNumber: 0, sender: &client, receiver: &server, allocator: allocator, loop: loop) } + + func testWeTolerateMessagesAfterSendingKexInit() throws { + let allocator = ByteBufferAllocator() + let loop = EmbeddedEventLoop() + var client = SSHConnectionStateMachine(role: .client(.init(userAuthDelegate: InfinitePasswordDelegate()))) + var server = SSHConnectionStateMachine(role: .server(.init(hostKeys: [NIOSSHPrivateKey(ed25519Key: .init())], userAuthDelegate: DenyThenAcceptDelegate(messagesToDeny: 1)))) + + try assertSuccessfulConnection(client: &client, server: &server, allocator: allocator, loop: loop) + + // Ok, the server is going to try to rekey. + var buffer = allocator.buffer(capacity: 1024) + XCTAssertNoThrow(try server.beginRekeying(buffer: &buffer, allocator: allocator)) + + // We're not passing this to the client though. Now we'll send all the channel messages through: the server should tolerate them + // all. + for message in self.channelMessages { + XCTAssertNoThrow(try self.assertForwardsToMultiplexer(message, sender: &client, receiver: &server, allocator: allocator, loop: loop)) + } + } } diff --git a/Tests/NIOSSHTests/SSHKeyExchangeStateMachineTests.swift b/Tests/NIOSSHTests/SSHKeyExchangeStateMachineTests.swift index 5f9e3b9..b813de6 100644 --- a/Tests/NIOSSHTests/SSHKeyExchangeStateMachineTests.swift +++ b/Tests/NIOSSHTests/SSHKeyExchangeStateMachineTests.swift @@ -24,8 +24,11 @@ final class SSHKeyExchangeStateMachineTests: XCTestCase { case unexpectedMissingMessage } - private func assertGeneratesKeyExchangeMessage(_ messageFactory: @autoclosure () throws -> SSHMultiMessage) throws -> SSHMessage.KeyExchangeMessage { - let message = try assertNoThrowWithValue(messageFactory()) + private func assertGeneratesKeyExchangeMessage(_ messageFactory: @autoclosure () throws -> SSHMultiMessage?) throws -> SSHMessage.KeyExchangeMessage { + guard let message = try assertNoThrowWithValue(messageFactory()) else { + XCTFail("Unexpected missing message") + throw AssertionFailure.unexpectedMissingMessage + } guard message.count == 1 else { XCTFail("Unexpected multiple message (found \(message.count))") @@ -177,18 +180,20 @@ final class SSHKeyExchangeStateMachineTests: XCTestCase { allocator: allocator, role: .client(.init(userAuthDelegate: ExplodingAuthDelegate())), remoteVersion: Constants.version, - protectionSchemes: [AES256GCMOpenSSHTransportProtection.self] + protectionSchemes: [AES256GCMOpenSSHTransportProtection.self], + previousSessionIdentifier: nil ) var server = SSHKeyExchangeStateMachine( allocator: allocator, role: .server(.init(hostKeys: [.init(ed25519Key: .init())], userAuthDelegate: DenyAllServerAuthDelegate())), remoteVersion: Constants.version, - protectionSchemes: [AES256GCMOpenSSHTransportProtection.self] + protectionSchemes: [AES256GCMOpenSSHTransportProtection.self], + previousSessionIdentifier: nil ) // Both sides begin by generating a key exchange message. - let serverMessage = try assertGeneratesKeyExchangeMessage(server.startKeyExchange()) - let clientMessage = try assertGeneratesKeyExchangeMessage(client.startKeyExchange()) + let serverMessage = server.createKeyExchangeMessage() + let clientMessage = client.createKeyExchangeMessage() server.send(keyExchange: serverMessage) client.send(keyExchange: clientMessage) @@ -238,18 +243,20 @@ final class SSHKeyExchangeStateMachineTests: XCTestCase { allocator: allocator, role: .client(.init(userAuthDelegate: ExplodingAuthDelegate())), remoteVersion: Constants.version, - protectionSchemes: [AES256GCMOpenSSHTransportProtection.self] + protectionSchemes: [AES256GCMOpenSSHTransportProtection.self], + previousSessionIdentifier: nil ) var server = SSHKeyExchangeStateMachine( allocator: allocator, role: .server(.init(hostKeys: [.init(ed25519Key: .init())], userAuthDelegate: DenyAllServerAuthDelegate())), remoteVersion: Constants.version, - protectionSchemes: [AES256GCMOpenSSHTransportProtection.self] + protectionSchemes: [AES256GCMOpenSSHTransportProtection.self], + previousSessionIdentifier: nil ) // Both sides begin by generating a key exchange message. - let serverMessage = try assertGeneratesKeyExchangeMessage(server.startKeyExchange()) - let clientMessage = try assertGeneratesKeyExchangeMessage(client.startKeyExchange()) + let serverMessage = server.createKeyExchangeMessage() + let clientMessage = client.createKeyExchangeMessage() server.send(keyExchange: serverMessage) client.send(keyExchange: clientMessage) @@ -286,10 +293,11 @@ final class SSHKeyExchangeStateMachineTests: XCTestCase { role: .server(.init(hostKeys: [.init(ed25519Key: .init())], userAuthDelegate: DenyAllServerAuthDelegate())), remoteVersion: Constants.version, - protectionSchemes: [AES256GCMOpenSSHTransportProtection.self]) + protectionSchemes: [AES256GCMOpenSSHTransportProtection.self], + previousSessionIdentifier: nil) // Server generates a key exchange message. - let serverMessage = try assertGeneratesKeyExchangeMessage(server.startKeyExchange()) + let serverMessage = server.createKeyExchangeMessage() server.send(keyExchange: serverMessage) // Client sends a key exchange that is _subtly_ different from the server (we just add a different key exchange mechanism to the front). @@ -321,13 +329,14 @@ final class SSHKeyExchangeStateMachineTests: XCTestCase { // This test verifies that the state machine forbids extra key exchange messages. // We get the key exchange message out of the server because it's a pain to build by hand. let allocator = ByteBufferAllocator() - var server = SSHKeyExchangeStateMachine( + let server = SSHKeyExchangeStateMachine( allocator: allocator, role: .server(.init(hostKeys: [.init(ed25519Key: .init())], userAuthDelegate: DenyAllServerAuthDelegate())), remoteVersion: Constants.version, - protectionSchemes: [AES256GCMOpenSSHTransportProtection.self] + protectionSchemes: [AES256GCMOpenSSHTransportProtection.self], + previousSessionIdentifier: nil ) - let serverMessage = try assertGeneratesKeyExchangeMessage(server.startKeyExchange()) + let serverMessage = server.createKeyExchangeMessage() try self.assertSendingExtraMessageFails(message: SSHMessage.keyExchange(serverMessage), allowedStages: [.beforeReceiveKeyExchangeClient, .beforeReceiveKeyExchangeServer]) } @@ -353,18 +362,20 @@ final class SSHKeyExchangeStateMachineTests: XCTestCase { allocator: allocator, role: .client(.init(userAuthDelegate: ExplodingAuthDelegate())), remoteVersion: Constants.version, - protectionSchemes: [AES256GCMOpenSSHTransportProtection.self] + protectionSchemes: [AES256GCMOpenSSHTransportProtection.self], + previousSessionIdentifier: nil ) var server = SSHKeyExchangeStateMachine( allocator: allocator, role: .server(.init(hostKeys: [.init(ed25519Key: .init())], userAuthDelegate: DenyAllServerAuthDelegate())), remoteVersion: Constants.version, - protectionSchemes: [AES256GCMOpenSSHTransportProtection.self] + protectionSchemes: [AES256GCMOpenSSHTransportProtection.self], + previousSessionIdentifier: nil ) // Both sides begin by generating a key exchange message. - let serverMessage = try assertGeneratesKeyExchangeMessage(server.startKeyExchange()) - let clientMessage = try assertGeneratesKeyExchangeMessage(client.startKeyExchange()) + let serverMessage = server.createKeyExchangeMessage() + let clientMessage = client.createKeyExchangeMessage() server.send(keyExchange: serverMessage) client.send(keyExchange: clientMessage) @@ -402,18 +413,20 @@ final class SSHKeyExchangeStateMachineTests: XCTestCase { allocator: allocator, role: .client(.init(userAuthDelegate: ExplodingAuthDelegate())), remoteVersion: Constants.version, - protectionSchemes: [AES256GCMOpenSSHTransportProtection.self] + protectionSchemes: [AES256GCMOpenSSHTransportProtection.self], + previousSessionIdentifier: nil ) var server = SSHKeyExchangeStateMachine( allocator: allocator, role: .server(.init(hostKeys: [.init(p256Key: .init())], userAuthDelegate: DenyAllServerAuthDelegate())), remoteVersion: Constants.version, - protectionSchemes: [AES256GCMOpenSSHTransportProtection.self] + protectionSchemes: [AES256GCMOpenSSHTransportProtection.self], + previousSessionIdentifier: nil ) // Both sides begin by generating a key exchange message. - let serverMessage = try assertGeneratesKeyExchangeMessage(server.startKeyExchange()) - let clientMessage = try assertGeneratesKeyExchangeMessage(client.startKeyExchange()) + let serverMessage = server.createKeyExchangeMessage() + let clientMessage = client.createKeyExchangeMessage() server.send(keyExchange: serverMessage) client.send(keyExchange: clientMessage) @@ -447,13 +460,14 @@ final class SSHKeyExchangeStateMachineTests: XCTestCase { var cookies: [ByteBuffer] = [] for _ in 0 ..< 5 { let allocator = ByteBufferAllocator() - var client = SSHKeyExchangeStateMachine( + let client = SSHKeyExchangeStateMachine( allocator: allocator, role: .client(.init(userAuthDelegate: ExplodingAuthDelegate())), remoteVersion: Constants.version, - protectionSchemes: [AES256GCMOpenSSHTransportProtection.self] + protectionSchemes: [AES256GCMOpenSSHTransportProtection.self], + previousSessionIdentifier: nil ) - let clientMessage = try assertGeneratesKeyExchangeMessage(client.startKeyExchange()) + let clientMessage = client.createKeyExchangeMessage() cookies.append(clientMessage.cookie) } @@ -474,18 +488,20 @@ final class SSHKeyExchangeStateMachineTests: XCTestCase { allocator: allocator, role: .client(.init(userAuthDelegate: ExplodingAuthDelegate())), remoteVersion: Constants.version, - protectionSchemes: [AES256GCMOpenSSHTransportProtection.self] + protectionSchemes: [AES256GCMOpenSSHTransportProtection.self], + previousSessionIdentifier: nil ) var server = SSHKeyExchangeStateMachine( allocator: allocator, role: .server(.init(hostKeys: [.init(ed25519Key: .init())], userAuthDelegate: DenyAllServerAuthDelegate())), remoteVersion: Constants.version, - protectionSchemes: [AES128GCMOpenSSHTransportProtection.self] + protectionSchemes: [AES128GCMOpenSSHTransportProtection.self], + previousSessionIdentifier: nil ) // Both sides begin by generating a key exchange message. - let serverMessage = try assertGeneratesKeyExchangeMessage(server.startKeyExchange()) - let clientMessage = try assertGeneratesKeyExchangeMessage(client.startKeyExchange()) + let serverMessage = server.createKeyExchangeMessage() + let clientMessage = client.createKeyExchangeMessage() server.send(keyExchange: serverMessage) client.send(keyExchange: clientMessage) @@ -509,18 +525,20 @@ final class SSHKeyExchangeStateMachineTests: XCTestCase { allocator: allocator, role: .client(.init(userAuthDelegate: ExplodingAuthDelegate())), remoteVersion: Constants.version, - protectionSchemes: [AES128GCMOpenSSHTransportProtection.self, AES256GCMOpenSSHTransportProtection.self] + protectionSchemes: [AES128GCMOpenSSHTransportProtection.self, AES256GCMOpenSSHTransportProtection.self], + previousSessionIdentifier: nil ) var server = SSHKeyExchangeStateMachine( allocator: allocator, role: .server(.init(hostKeys: [.init(ed25519Key: .init())], userAuthDelegate: DenyAllServerAuthDelegate())), remoteVersion: Constants.version, - protectionSchemes: [AES256GCMOpenSSHTransportProtection.self, AES128GCMOpenSSHTransportProtection.self] + protectionSchemes: [AES256GCMOpenSSHTransportProtection.self, AES128GCMOpenSSHTransportProtection.self], + previousSessionIdentifier: nil ) // Both sides begin by generating a key exchange message. - let serverMessage = try assertGeneratesKeyExchangeMessage(server.startKeyExchange()) - let clientMessage = try assertGeneratesKeyExchangeMessage(client.startKeyExchange()) + let serverMessage = server.createKeyExchangeMessage() + let clientMessage = client.createKeyExchangeMessage() server.send(keyExchange: serverMessage) client.send(keyExchange: clientMessage) @@ -549,6 +567,121 @@ final class SSHKeyExchangeStateMachineTests: XCTestCase { self.assertCompatibleProtection(client: clientInboundProtection, server: serverInboundProtection) XCTAssertTrue(clientInboundProtection is AES128GCMOpenSSHTransportProtection) } + + func testWeCanReExchangeKeysClientInitiates() throws { + // This tests the specific message flow for re-exchange: namely, a key exchange message arrives in idle. + // We should tolerate this flow. + let allocator = ByteBufferAllocator() + + var client = SSHKeyExchangeStateMachine( + allocator: allocator, + role: .client(.init(userAuthDelegate: ExplodingAuthDelegate())), + remoteVersion: Constants.version, + protectionSchemes: [AES128GCMOpenSSHTransportProtection.self], + previousSessionIdentifier: nil + ) + var server = SSHKeyExchangeStateMachine( + allocator: allocator, + role: .server(.init(hostKeys: [.init(ed25519Key: .init())], userAuthDelegate: DenyAllServerAuthDelegate())), + remoteVersion: Constants.version, + protectionSchemes: [AES128GCMOpenSSHTransportProtection.self], + previousSessionIdentifier: nil + ) + + // Here only the client generates a key exchange message. + let clientMessage = client.createKeyExchangeMessage() + client.send(keyExchange: clientMessage) + + // The server generates a key exchange message in response. + let serverMessage = try self.assertGeneratesKeyExchangeMessage(server.handle(keyExchange: clientMessage)) + server.send(keyExchange: serverMessage) + let ecdhInit = try assertGeneratesECDHKeyExchangeInit(client.handle(keyExchange: serverMessage)) + client.send(keyExchangeInit: ecdhInit) + + // From here on, the flow is the same as in the regular key exchange. + // Now the server receives the ECDH init message and generates the reply, as well as the newKeys message. + let ecdhReply = try assertGeneratesECDHKeyExchangeReplyAndNewKeys(server.handle(keyExchangeInit: ecdhInit)) + XCTAssertNoThrow(try server.send(keyExchangeReply: ecdhReply)) + let serverOutboundProtection = server.sendNewKeys() + + // Now the client receives the reply, and generates a newKeys message. + try self.assertGeneratesNewKeys(client.handle(keyExchangeReply: ecdhReply)) + let clientOutboundProtection = client.sendNewKeys() + + // Both peers receive the newKeys messages. + let clientInboundProtection = try assertNoThrowWithValue(client.handleNewKeys()) + let serverInboundProtection = try assertNoThrowWithValue(server.handleNewKeys()) + + // Each peer has generated the exact same protection object for both directions. + XCTAssertTrue(clientInboundProtection === clientOutboundProtection) + XCTAssertTrue(serverInboundProtection === serverOutboundProtection) + + self.assertCompatibleProtection(client: clientInboundProtection, server: serverInboundProtection) + XCTAssertTrue(clientInboundProtection is AES128GCMOpenSSHTransportProtection) + } + + func testWeCanReExchangeKeysServerInitiates() throws { + // This tests the specific message flow for re-exchange: namely, a key exchange message arrives in idle. + // We should tolerate this flow. + let allocator = ByteBufferAllocator() + + var client = SSHKeyExchangeStateMachine( + allocator: allocator, + role: .client(.init(userAuthDelegate: ExplodingAuthDelegate())), + remoteVersion: Constants.version, + protectionSchemes: [AES128GCMOpenSSHTransportProtection.self], + previousSessionIdentifier: nil + ) + var server = SSHKeyExchangeStateMachine( + allocator: allocator, + role: .server(.init(hostKeys: [.init(ed25519Key: .init())], userAuthDelegate: DenyAllServerAuthDelegate())), + remoteVersion: Constants.version, + protectionSchemes: [AES128GCMOpenSSHTransportProtection.self], + previousSessionIdentifier: nil + ) + + // Here only the server generates a key exchange message. + let serverMessage = server.createKeyExchangeMessage() + server.send(keyExchange: serverMessage) + + // The client generates both a key exchange message and a ecdh key exchange init message. + guard let clientMessages = try client.handle(keyExchange: serverMessage) else { + XCTFail("Client didn't generate messages") + return + } + guard let firstClientMessage = clientMessages.first, let secondClientMessage = clientMessages.dropFirst().first else { + XCTFail("Unexpected number of client messages: \(clientMessages)") + return + } + let clientInitMessage = try self.assertGeneratesKeyExchangeMessage(SSHMultiMessage(firstClientMessage)) + let clientEcdhInitMessage = try self.assertGeneratesECDHKeyExchangeInit(SSHMultiMessage(secondClientMessage)) + client.send(keyExchange: clientInitMessage) + client.send(keyExchangeInit: clientEcdhInitMessage) + + // The server now receives these messages. + try self.assertGeneratesNoMessage(server.handle(keyExchange: clientInitMessage)) + + // From here on, the flow is the same as in the regular key exchange. + // Now the server receives the ECDH init message and generates the reply, as well as the newKeys message. + let ecdhReply = try assertGeneratesECDHKeyExchangeReplyAndNewKeys(server.handle(keyExchangeInit: clientEcdhInitMessage)) + XCTAssertNoThrow(try server.send(keyExchangeReply: ecdhReply)) + let serverOutboundProtection = server.sendNewKeys() + + // Now the client receives the reply, and generates a newKeys message. + try self.assertGeneratesNewKeys(client.handle(keyExchangeReply: ecdhReply)) + let clientOutboundProtection = client.sendNewKeys() + + // Both peers receive the newKeys messages. + let clientInboundProtection = try assertNoThrowWithValue(client.handleNewKeys()) + let serverInboundProtection = try assertNoThrowWithValue(server.handleNewKeys()) + + // Each peer has generated the exact same protection object for both directions. + XCTAssertTrue(clientInboundProtection === clientOutboundProtection) + XCTAssertTrue(serverInboundProtection === serverOutboundProtection) + + self.assertCompatibleProtection(client: clientInboundProtection, server: serverInboundProtection) + XCTAssertTrue(clientInboundProtection is AES128GCMOpenSSHTransportProtection) + } } extension SSHKeyExchangeStateMachineTests {