Skip to content
Merged
Next Next commit
Add SynchronousOperations version of Position for non sendable handlers
  • Loading branch information
0xTim committed Jan 16, 2025
commit 09d6bcd50606705bcd7229b4b0c069237ec2f451
58 changes: 46 additions & 12 deletions Sources/NIOCore/ChannelPipeline.swift
Original file line number Diff line number Diff line change
Expand Up @@ -175,11 +175,12 @@ public final class ChannelPipeline: ChannelInvoker {
) -> EventLoopFuture<Void> {
let future: EventLoopFuture<Void>

let syncPosition = ChannelPipeline.SynchronousOperations.Position(position)
if self.eventLoop.inEventLoop {
future = self.eventLoop.makeCompletedFuture(self.addHandlerSync(handler, name: name, position: position))
future = self.eventLoop.makeCompletedFuture(self.addHandlerSync(handler, name: name, position: syncPosition))
} else {
future = self.eventLoop.submit {
try self.addHandlerSync(handler, name: name, position: position).get()
try self.addHandlerSync(handler, name: name, position: syncPosition).get()
}
}

Expand All @@ -198,7 +199,7 @@ public final class ChannelPipeline: ChannelInvoker {
fileprivate func addHandlerSync(
_ handler: ChannelHandler,
name: String? = nil,
position: ChannelPipeline.Position = .last
position: ChannelPipeline.SynchronousOperations.Position = .last
) -> Result<Void, Error> {
self.eventLoop.assertInEventLoop()

Expand Down Expand Up @@ -1122,11 +1123,12 @@ extension ChannelPipeline {
_ handlers: [ChannelHandler & Sendable],
position: ChannelPipeline.Position
) -> Result<Void, Error> {
switch position {
let syncPosition = ChannelPipeline.SynchronousOperations.Position(position)
switch syncPosition {
case .first, .after:
return self._addHandlersSync(handlers.reversed(), position: position)
return self._addHandlersSync(handlers.reversed(), position: syncPosition)
case .last, .before:
return self._addHandlersSync(handlers, position: position)
return self._addHandlersSync(handlers, position: syncPosition)
}
}

Expand All @@ -1143,7 +1145,7 @@ extension ChannelPipeline {
/// - Returns: A result representing whether the handlers were added or not.
fileprivate func addHandlersSyncNotSendable(
_ handlers: [ChannelHandler],
position: ChannelPipeline.Position
position: ChannelPipeline.SynchronousOperations.Position
) -> Result<Void, Error> {
switch position {
case .first, .after:
Expand All @@ -1162,7 +1164,7 @@ extension ChannelPipeline {
/// - Returns: A result representing whether the handlers were added or not.
private func _addHandlersSync<Handlers: Sequence>(
_ handlers: Handlers,
position: ChannelPipeline.Position
position: ChannelPipeline.SynchronousOperations.Position
) -> Result<Void, Error> where Handlers.Element == ChannelHandler & Sendable {
self.eventLoop.assertInEventLoop()

Expand Down Expand Up @@ -1191,7 +1193,7 @@ extension ChannelPipeline {
/// - Returns: A result representing whether the handlers were added or not.
private func _addHandlersSyncNotSendable<Handlers: Sequence>(
_ handlers: Handlers,
position: ChannelPipeline.Position
position: ChannelPipeline.SynchronousOperations.Position
) -> Result<Void, Error> where Handlers.Element == ChannelHandler {
self.eventLoop.assertInEventLoop()

Expand Down Expand Up @@ -1238,7 +1240,7 @@ extension ChannelPipeline {
public func addHandler(
_ handler: ChannelHandler,
name: String? = nil,
position: ChannelPipeline.Position = .last
position: ChannelPipeline.SynchronousOperations.Position = .last
) throws {
try self._pipeline.addHandlerSync(handler, name: name, position: position).get()
}
Expand All @@ -1251,7 +1253,7 @@ extension ChannelPipeline {
/// - position: The position in the `ChannelPipeline` to add `handlers`. Defaults to `.last`.
public func addHandlers(
_ handlers: [ChannelHandler],
position: ChannelPipeline.Position = .last
position: ChannelPipeline.SynchronousOperations.Position = .last
) throws {
try self._pipeline.addHandlersSyncNotSendable(handlers, position: position).get()
}
Expand All @@ -1264,7 +1266,7 @@ extension ChannelPipeline {
/// - position: The position in the `ChannelPipeline` to add `handlers`. Defaults to `.last`.
public func addHandlers(
_ handlers: ChannelHandler...,
position: ChannelPipeline.Position = .last
position: ChannelPipeline.SynchronousOperations.Position = .last
) throws {
try self._pipeline.addHandlersSyncNotSendable(handlers, position: position).get()
}
Expand Down Expand Up @@ -1574,6 +1576,38 @@ extension ChannelPipeline {
}
}

extension ChannelPipeline.SynchronousOperations {
/// A `Position` within the `ChannelPipeline`'s `SynchronousOperations` used to insert non-sendable handlers
/// into the `ChannelPipeline` at a certain position.
@preconcurrency
public enum Position {
/// The first `ChannelHandler` -- the front of the `ChannelPipeline`.
case first

/// The last `ChannelHandler` -- the back of the `ChannelPipeline`.
case last

/// Before the given `ChannelHandler`.
case before(ChannelHandler)

/// After the given `ChannelHandler`.
case after(ChannelHandler)

package init(_ position: ChannelPipeline.Position) {
switch position {
case .first:
self = .first
case .last:
self = .last
case .before(let handler):
self = .before(handler)
case .after(let handler):
self = .after(handler)
}
}
}
}

/// Special `ChannelHandler` that forwards all events to the `Channel.Unsafe` implementation.
final class HeadChannelHandler: _ChannelOutboundHandler, Sendable {

Expand Down
40 changes: 22 additions & 18 deletions Sources/NIOHTTP1/HTTPPipelineSetup.swift
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,11 @@ extension ChannelPipeline {
) -> EventLoopFuture<Void> {
let future: EventLoopFuture<Void>

let syncPosition = ChannelPipeline.SynchronousOperations.Position(position)
if self.eventLoop.inEventLoop {
let result = Result<Void, Error> {
try self.syncOperations.addHTTPClientHandlers(
position: position,
position: syncPosition,
leftOverBytesStrategy: leftOverBytesStrategy,
withClientUpgrade: upgrade
)
Expand All @@ -95,7 +96,7 @@ extension ChannelPipeline {
} else {
future = self.eventLoop.submit {
try self.syncOperations.addHTTPClientHandlers(
position: position,
position: syncPosition,
leftOverBytesStrategy: leftOverBytesStrategy,
withClientUpgrade: upgrade
)
Expand Down Expand Up @@ -126,10 +127,11 @@ extension ChannelPipeline {
) -> EventLoopFuture<Void> {
let future: EventLoopFuture<Void>

let syncPosition = ChannelPipeline.SynchronousOperations.Position(position)
if self.eventLoop.inEventLoop {
let result = Result<Void, Error> {
try self.syncOperations.addHTTPClientHandlers(
position: position,
position: syncPosition,
leftOverBytesStrategy: leftOverBytesStrategy,
enableOutboundHeaderValidation: enableOutboundHeaderValidation,
withClientUpgrade: upgrade
Expand All @@ -139,7 +141,7 @@ extension ChannelPipeline {
} else {
future = self.eventLoop.submit {
try self.syncOperations.addHTTPClientHandlers(
position: position,
position: syncPosition,
leftOverBytesStrategy: leftOverBytesStrategy,
enableOutboundHeaderValidation: enableOutboundHeaderValidation,
withClientUpgrade: upgrade
Expand Down Expand Up @@ -173,10 +175,11 @@ extension ChannelPipeline {
) -> EventLoopFuture<Void> {
let future: EventLoopFuture<Void>

let syncPosition = ChannelPipeline.SynchronousOperations.Position(position)
if self.eventLoop.inEventLoop {
let result = Result<Void, Error> {
try self.syncOperations.addHTTPClientHandlers(
position: position,
position: syncPosition,
leftOverBytesStrategy: leftOverBytesStrategy,
enableOutboundHeaderValidation: enableOutboundHeaderValidation,
encoderConfiguration: encoderConfiguration,
Expand All @@ -187,7 +190,7 @@ extension ChannelPipeline {
} else {
future = self.eventLoop.submit {
try self.syncOperations.addHTTPClientHandlers(
position: position,
position: syncPosition,
leftOverBytesStrategy: leftOverBytesStrategy,
enableOutboundHeaderValidation: enableOutboundHeaderValidation,
encoderConfiguration: encoderConfiguration,
Expand Down Expand Up @@ -342,10 +345,11 @@ extension ChannelPipeline {
) -> EventLoopFuture<Void> {
let future: EventLoopFuture<Void>

let syncPosition = ChannelPipeline.SynchronousOperations.Position(position)
if self.eventLoop.inEventLoop {
let result = Result<Void, Error> {
try self.syncOperations.configureHTTPServerPipeline(
position: position,
position: syncPosition,
withPipeliningAssistance: pipelining,
withServerUpgrade: upgrade,
withErrorHandling: errorHandling,
Expand All @@ -357,7 +361,7 @@ extension ChannelPipeline {
} else {
future = self.eventLoop.submit {
try self.syncOperations.configureHTTPServerPipeline(
position: position,
position: syncPosition,
withPipeliningAssistance: pipelining,
withServerUpgrade: upgrade,
withErrorHandling: errorHandling,
Expand Down Expand Up @@ -386,7 +390,7 @@ extension ChannelPipeline.SynchronousOperations {
/// - Throws: If the pipeline could not be configured.
@preconcurrency
public func addHTTPClientHandlers(
position: ChannelPipeline.Position = .last,
position: ChannelPipeline.SynchronousOperations.Position = .last,
leftOverBytesStrategy: RemoveAfterUpgradeStrategy = .dropBytes,
withClientUpgrade upgrade: NIOHTTPClientUpgradeConfiguration? = nil
) throws {
Expand All @@ -411,7 +415,7 @@ extension ChannelPipeline.SynchronousOperations {
/// for more details.
/// - Throws: If the pipeline could not be configured.
public func addHTTPClientHandlers(
position: ChannelPipeline.Position = .last,
position: ChannelPipeline.SynchronousOperations.Position = .last,
leftOverBytesStrategy: RemoveAfterUpgradeStrategy = .dropBytes,
enableOutboundHeaderValidation: Bool = true,
withClientUpgrade upgrade: NIOHTTPClientUpgradeConfiguration? = nil
Expand Down Expand Up @@ -439,7 +443,7 @@ extension ChannelPipeline.SynchronousOperations {
/// for more details.
/// - Throws: If the pipeline could not be configured.
public func addHTTPClientHandlers(
position: ChannelPipeline.Position = .last,
position: ChannelPipeline.SynchronousOperations.Position = .last,
leftOverBytesStrategy: RemoveAfterUpgradeStrategy = .dropBytes,
enableOutboundHeaderValidation: Bool = true,
encoderConfiguration: HTTPRequestEncoder.Configuration = .init(),
Expand All @@ -455,7 +459,7 @@ extension ChannelPipeline.SynchronousOperations {
}

private func _addHTTPClientHandlers(
position: ChannelPipeline.Position = .last,
position: ChannelPipeline.SynchronousOperations.Position = .last,
leftOverBytesStrategy: RemoveAfterUpgradeStrategy = .dropBytes,
enableOutboundHeaderValidation: Bool = true,
encoderConfiguration: HTTPRequestEncoder.Configuration = .init(),
Expand All @@ -481,7 +485,7 @@ extension ChannelPipeline.SynchronousOperations {
}

private func _addHTTPClientHandlers(
position: ChannelPipeline.Position,
position: ChannelPipeline.SynchronousOperations.Position,
leftOverBytesStrategy: RemoveAfterUpgradeStrategy,
encoderConfiguration: HTTPRequestEncoder.Configuration
) throws {
Expand All @@ -496,7 +500,7 @@ extension ChannelPipeline.SynchronousOperations {
}

private func _addHTTPClientHandlersFallback(
position: ChannelPipeline.Position,
position: ChannelPipeline.SynchronousOperations.Position,
leftOverBytesStrategy: RemoveAfterUpgradeStrategy,
enableOutboundHeaderValidation: Bool,
encoderConfiguration: HTTPRequestEncoder.Configuration,
Expand Down Expand Up @@ -550,7 +554,7 @@ extension ChannelPipeline.SynchronousOperations {
/// - Throws: If the pipeline could not be configured.
@preconcurrency
public func configureHTTPServerPipeline(
position: ChannelPipeline.Position = .last,
position: ChannelPipeline.SynchronousOperations.Position = .last,
withPipeliningAssistance pipelining: Bool = true,
withServerUpgrade upgrade: NIOHTTPServerUpgradeConfiguration? = nil,
withErrorHandling errorHandling: Bool = true
Expand Down Expand Up @@ -594,7 +598,7 @@ extension ChannelPipeline.SynchronousOperations {
/// spec compliance. Defaults to `true`.
/// - Throws: If the pipeline could not be configured.
public func configureHTTPServerPipeline(
position: ChannelPipeline.Position = .last,
position: ChannelPipeline.SynchronousOperations.Position = .last,
withPipeliningAssistance pipelining: Bool = true,
withServerUpgrade upgrade: NIOHTTPServerUpgradeConfiguration? = nil,
withErrorHandling errorHandling: Bool = true,
Expand Down Expand Up @@ -641,7 +645,7 @@ extension ChannelPipeline.SynchronousOperations {
/// - encoderConfiguration: The configuration for the ``HTTPRequestEncoder``.
/// - Throws: If the pipeline could not be configured.
public func configureHTTPServerPipeline(
position: ChannelPipeline.Position = .last,
position: ChannelPipeline.SynchronousOperations.Position = .last,
withPipeliningAssistance pipelining: Bool = true,
withServerUpgrade upgrade: NIOHTTPServerUpgradeConfiguration? = nil,
withErrorHandling errorHandling: Bool = true,
Expand All @@ -659,7 +663,7 @@ extension ChannelPipeline.SynchronousOperations {
}

private func _configureHTTPServerPipeline(
position: ChannelPipeline.Position = .last,
position: ChannelPipeline.SynchronousOperations.Position = .last,
withPipeliningAssistance pipelining: Bool = true,
withServerUpgrade upgrade: NIOHTTPServerUpgradeConfiguration? = nil,
withErrorHandling errorHandling: Bool = true,
Expand Down