diff --git a/Sources/NIOSSH/Child Channels/SSHChannelMultiplexer.swift b/Sources/NIOSSH/Child Channels/SSHChannelMultiplexer.swift index 1e307f9..1638460 100644 --- a/Sources/NIOSSH/Child Channels/SSHChannelMultiplexer.swift +++ b/Sources/NIOSSH/Child Channels/SSHChannelMultiplexer.swift @@ -31,6 +31,9 @@ final class SSHChannelMultiplexer { private var childChannelInitializer: SSHChildChannel.Initializer? + /// Whether new channels are allowed. Set to `false` once the parent channel is shut down at the TCP level. + private var canCreateNewChannels: Bool + init(delegate: SSHMultiplexerDelegate, allocator: ByteBufferAllocator, childChannelInitializer: SSHChildChannel.Initializer?) { self.channels = [:] self.channels.reserveCapacity(8) @@ -39,6 +42,7 @@ final class SSHChannelMultiplexer { self.nextChannelID = 0 self.allocator = allocator self.childChannelInitializer = childChannelInitializer + self.canCreateNewChannels = true } // Time to clean up. We drop references to things that may be keeping us alive. @@ -46,6 +50,7 @@ final class SSHChannelMultiplexer { func parentHandlerRemoved() { self.delegate = nil self.childChannelInitializer = nil + self.canCreateNewChannels = false } } @@ -160,6 +165,7 @@ extension SSHChannelMultiplexer { } func parentChannelInactive() { + self.canCreateNewChannels = false for channel in self.channels.values { channel.parentChannelInactive() } @@ -171,6 +177,10 @@ extension SSHChannelMultiplexer { throw NIOSSHError.protocolViolation(protocolName: "channel", violation: "Opening new channel after channel shutdown") } + guard self.canCreateNewChannels else { + throw NIOSSHError.tcpShutdown + } + // TODO: We need a better channel ID system. Maybe use indices into Arrays instead? let channelID = self.nextChannelID self.nextChannelID &+= 1 diff --git a/Tests/NIOSSHTests/ChildChannelMultiplexerTests.swift b/Tests/NIOSSHTests/ChildChannelMultiplexerTests.swift index 33b87ef..b51a9e8 100644 --- a/Tests/NIOSSHTests/ChildChannelMultiplexerTests.swift +++ b/Tests/NIOSSHTests/ChildChannelMultiplexerTests.swift @@ -1552,4 +1552,26 @@ final class ChildChannelMultiplexerTests: XCTestCase { harness.multiplexer.parentChannelReadComplete() XCTAssertEqual(readCounter.readCount, 2) } + + func testTCPCloseBeforeInitializer() throws { + let harness = self.harnessForbiddingInboundChannels() + defer { + harness.finish() + } + + let childPromise: EventLoopPromise = harness.eventLoop.makePromise() + + var childPromiseError: Error? + childPromise.futureResult.whenFailure { error in childPromiseError = error } + + // TCP Close + harness.multiplexer.parentChannelInactive() + harness.multiplexer.createChildChannel(childPromise, channelType: .session) { channel, _ in + channel.eventLoop.makeSucceededFuture(()) + } + harness.eventLoop.run() + + XCTAssertEqual(harness.flushedMessages.count, 0) + XCTAssertEqual((childPromiseError as? NIOSSHError?)??.type, .tcpShutdown) + } }