Skip to content

Commit ac87f00

Browse files
geoffkizerGeoffrey Kizerstephentoub
authored
cleanup cancellation handling in SocketAsyncContext (#53479)
* cleanup cancellation handling in SocketAsyncContext * fix MacOS failback * Apply suggestions from code review Co-authored-by: Stephen Toub <[email protected]> * Update src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncContext.Unix.cs Co-authored-by: Stephen Toub <[email protected]> * address feedback Co-authored-by: Geoffrey Kizer <[email protected]> Co-authored-by: Stephen Toub <[email protected]>
1 parent 14343bd commit ac87f00

File tree

1 file changed

+90
-113
lines changed

1 file changed

+90
-113
lines changed

src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncContext.Unix.cs

Lines changed: 90 additions & 113 deletions
Original file line numberDiff line numberDiff line change
@@ -113,9 +113,10 @@ private abstract class AsyncOperation : IThreadPoolWorkItem
113113
private enum State
114114
{
115115
Waiting = 0,
116-
Running = 1,
117-
Complete = 2,
118-
Cancelled = 3
116+
Running,
117+
RunningWithPendingCancellation,
118+
Complete,
119+
Canceled
119120
}
120121

121122
private int _state; // Actually AsyncOperation.State.
@@ -149,92 +150,103 @@ public void Reset()
149150
#endif
150151
}
151152

152-
public bool TryComplete(SocketAsyncContext context)
153+
public OperationResult TryComplete(SocketAsyncContext context)
153154
{
154155
TraceWithContext(context, "Enter");
155156

156-
bool result = DoTryComplete(context);
157-
158-
TraceWithContext(context, $"Exit, result={result}");
157+
// Set state to Running, unless we've been canceled
158+
int oldState = Interlocked.CompareExchange(ref _state, (int)State.Running, (int)State.Waiting);
159+
if (oldState == (int)State.Canceled)
160+
{
161+
TraceWithContext(context, "Exit, Previously canceled");
162+
return OperationResult.Cancelled;
163+
}
159164

160-
return result;
161-
}
165+
Debug.Assert(oldState == (int)State.Waiting, $"Unexpected operation state: {(State)oldState}");
162166

163-
public bool TrySetRunning()
164-
{
165-
State oldState = (State)Interlocked.CompareExchange(ref _state, (int)State.Running, (int)State.Waiting);
166-
if (oldState == State.Cancelled)
167+
// Try to perform the IO
168+
if (DoTryComplete(context))
167169
{
168-
// This operation has already been cancelled, and had its completion processed.
169-
// Simply return false to indicate no further processing is needed.
170-
return false;
170+
Debug.Assert((State)Volatile.Read(ref _state) is State.Running or State.RunningWithPendingCancellation, "Unexpected operation state");
171+
172+
Volatile.Write(ref _state, (int)State.Complete);
173+
174+
TraceWithContext(context, "Exit, Completed");
175+
return OperationResult.Completed;
171176
}
172177

173-
Debug.Assert(oldState == (int)State.Waiting);
174-
return true;
175-
}
178+
// Set state back to Waiting, unless we were canceled, in which case we have to process cancellation now
179+
int newState;
180+
while (true)
181+
{
182+
int state = Volatile.Read(ref _state);
183+
Debug.Assert(state is (int)State.Running or (int)State.RunningWithPendingCancellation, $"Unexpected operation state: {(State)state}");
176184

177-
public void SetComplete()
178-
{
179-
Debug.Assert(Volatile.Read(ref _state) == (int)State.Running);
185+
newState = (state == (int)State.Running ? (int)State.Waiting : (int)State.Canceled);
186+
if (state == Interlocked.CompareExchange(ref _state, newState, state))
187+
{
188+
break;
189+
}
180190

181-
Volatile.Write(ref _state, (int)State.Complete);
182-
}
191+
// Race to update the state. Loop and try again.
192+
}
183193

184-
public void SetWaiting()
185-
{
186-
Debug.Assert(Volatile.Read(ref _state) == (int)State.Running);
194+
if (newState == (int)State.Canceled)
195+
{
196+
ProcessCancellation();
197+
TraceWithContext(context, "Exit, Newly cancelled");
198+
return OperationResult.Cancelled;
199+
}
187200

188-
Volatile.Write(ref _state, (int)State.Waiting);
201+
TraceWithContext(context, "Exit, Pending");
202+
return OperationResult.Pending;
189203
}
190204

191205
public bool TryCancel()
192206
{
193207
Trace("Enter");
194208

195-
// We're already canceling, so we don't need to still be hooked up to listen to cancellation.
196-
// The cancellation request could also be caused by something other than the token, so it's
197-
// important we clean it up, regardless.
209+
// Note we could be cancelling because of socket close. Regardless, we don't need the registration anymore.
198210
CancellationRegistration.Dispose();
199211

200-
// Try to transition from Waiting to Cancelled
201-
SpinWait spinWait = default;
202-
bool keepWaiting = true;
203-
while (keepWaiting)
212+
int newState;
213+
while (true)
204214
{
205-
int state = Interlocked.CompareExchange(ref _state, (int)State.Cancelled, (int)State.Waiting);
206-
switch ((State)state)
215+
int state = Volatile.Read(ref _state);
216+
if (state is (int)State.Complete or (int)State.Canceled or (int)State.RunningWithPendingCancellation)
207217
{
208-
case State.Running:
209-
// A completion attempt is in progress. Keep busy-waiting.
210-
Trace("Busy wait");
211-
spinWait.SpinOnce();
212-
break;
218+
return false;
219+
}
213220

214-
case State.Complete:
215-
// A completion attempt succeeded. Consider this operation as having completed within the timeout.
216-
Trace("Exit, previously completed");
217-
return false;
221+
newState = (state == (int)State.Waiting ? (int)State.Canceled : (int)State.RunningWithPendingCancellation);
222+
if (state == Interlocked.CompareExchange(ref _state, newState, state))
223+
{
224+
break;
225+
}
218226

219-
case State.Waiting:
220-
// This operation was successfully cancelled.
221-
// Break out of the loop to handle the cancellation
222-
keepWaiting = false;
223-
break;
227+
// Race to update the state. Loop and try again.
228+
}
224229

225-
case State.Cancelled:
226-
// Someone else cancelled the operation.
227-
// The previous canceller will have fired the completion, etc.
228-
Trace("Exit, previously cancelled");
229-
return false;
230-
}
230+
if (newState == (int)State.RunningWithPendingCancellation)
231+
{
232+
// TryComplete will either succeed, or it will see the pending cancellation and deal with it.
233+
return false;
231234
}
232235

233-
Trace("Cancelled, processing completion");
236+
ProcessCancellation();
234237

235-
// The operation successfully cancelled.
236-
// It's our responsibility to set the error code and queue the completion.
237-
DoAbort();
238+
// Note, we leave the operation in the OperationQueue.
239+
// When we get around to processing it, we'll see it's cancelled and skip it.
240+
return true;
241+
}
242+
243+
public void ProcessCancellation()
244+
{
245+
Trace("Enter");
246+
247+
Debug.Assert(_state == (int)State.Canceled);
248+
249+
ErrorCode = SocketError.OperationAborted;
238250

239251
ManualResetEventSlim? e = Event;
240252
if (e != null)
@@ -252,12 +264,6 @@ public bool TryCancel()
252264
// to do further processing on the item that's still in the list.
253265
ThreadPool.UnsafeQueueUserWorkItem(o => ((AsyncOperation)o!).InvokeCallback(allowPooling: false), this);
254266
}
255-
256-
Trace("Exit");
257-
258-
// Note, we leave the operation in the OperationQueue.
259-
// When we get around to processing it, we'll see it's cancelled and skip it.
260-
return true;
261267
}
262268

263269
public void Dispatch()
@@ -306,12 +312,9 @@ void IThreadPoolWorkItem.Execute()
306312
// Called when op is not in the queue yet, so can't be otherwise executing
307313
public void DoAbort()
308314
{
309-
Abort();
310315
ErrorCode = SocketError.OperationAborted;
311316
}
312317

313-
protected abstract void Abort();
314-
315318
protected abstract bool DoTryComplete(SocketAsyncContext context);
316319

317320
public abstract void InvokeCallback(bool allowPooling);
@@ -354,8 +357,6 @@ private abstract class SendOperation : WriteOperation
354357

355358
public SendOperation(SocketAsyncContext context) : base(context) { }
356359

357-
protected sealed override void Abort() { }
358-
359360
public Action<int, byte[]?, int, SocketFlags, SocketError>? Callback { get; set; }
360361

361362
public override void InvokeCallback(bool allowPooling) =>
@@ -442,8 +443,6 @@ private abstract class ReceiveOperation : ReadOperation
442443

443444
public ReceiveOperation(SocketAsyncContext context) : base(context) { }
444445

445-
protected sealed override void Abort() { }
446-
447446
public Action<int, byte[]?, int, SocketFlags, SocketError>? Callback { get; set; }
448447

449448
public override void InvokeCallback(bool allowPooling) =>
@@ -554,8 +553,6 @@ private sealed class ReceiveMessageFromOperation : ReadOperation
554553

555554
public ReceiveMessageFromOperation(SocketAsyncContext context) : base(context) { }
556555

557-
protected sealed override void Abort() { }
558-
559556
public Action<int, byte[], int, SocketFlags, IPPacketInformation, SocketError>? Callback { get; set; }
560557

561558
protected override bool DoTryComplete(SocketAsyncContext context) =>
@@ -579,8 +576,6 @@ private sealed unsafe class BufferPtrReceiveMessageFromOperation : ReadOperation
579576

580577
public BufferPtrReceiveMessageFromOperation(SocketAsyncContext context) : base(context) { }
581578

582-
protected sealed override void Abort() { }
583-
584579
public Action<int, byte[], int, SocketFlags, IPPacketInformation, SocketError>? Callback { get; set; }
585580

586581
protected override bool DoTryComplete(SocketAsyncContext context) =>
@@ -598,9 +593,6 @@ public AcceptOperation(SocketAsyncContext context) : base(context) { }
598593

599594
public Action<IntPtr, byte[], int, SocketError>? Callback { get; set; }
600595

601-
protected override void Abort() =>
602-
AcceptedFileDescriptor = (IntPtr)(-1);
603-
604596
protected override bool DoTryComplete(SocketAsyncContext context)
605597
{
606598
bool completed = SocketPal.TryCompleteAccept(context._socket, SocketAddress!, ref SocketAddressLen, out AcceptedFileDescriptor, out ErrorCode);
@@ -631,8 +623,6 @@ public ConnectOperation(SocketAsyncContext context) : base(context) { }
631623

632624
public Action<SocketError>? Callback { get; set; }
633625

634-
protected override void Abort() { }
635-
636626
protected override bool DoTryComplete(SocketAsyncContext context)
637627
{
638628
bool result = SocketPal.TryCompleteConnect(context._socket, SocketAddressLen, out ErrorCode);
@@ -653,8 +643,6 @@ private sealed class SendFileOperation : WriteOperation
653643

654644
public SendFileOperation(SocketAsyncContext context) : base(context) { }
655645

656-
protected override void Abort() { }
657-
658646
public Action<long, SocketError>? Callback { get; set; }
659647

660648
public override void InvokeCallback(bool allowPooling) =>
@@ -694,6 +682,13 @@ public void Dispose()
694682
}
695683
}
696684

685+
public enum OperationResult
686+
{
687+
Pending = 0,
688+
Completed = 1,
689+
Cancelled = 2
690+
}
691+
697692
private struct OperationQueue<TOperation>
698693
where TOperation : AsyncOperation
699694
{
@@ -864,7 +859,7 @@ public bool StartAsyncOperation(SocketAsyncContext context, TOperation operation
864859
}
865860

866861
// Retry the operation.
867-
if (operation.TryComplete(context))
862+
if (operation.TryComplete(context) != OperationResult.Pending)
868863
{
869864
Trace(context, $"Leave, retry succeeded");
870865
return false;
@@ -880,7 +875,7 @@ static void HandleFailedRegistration(SocketAsyncContext context, TOperation oper
880875
{
881876
// Because the other end close, we expect the operation to complete when we retry it.
882877
// If it doesn't, we fall through and throw an Exception.
883-
if (operation.TryComplete(context))
878+
if (operation.TryComplete(context) != OperationResult.Pending)
884879
{
885880
return;
886881
}
@@ -979,13 +974,6 @@ internal void ProcessAsyncOperation(TOperation op)
979974
}
980975
}
981976

982-
public enum OperationResult
983-
{
984-
Pending = 0,
985-
Completed = 1,
986-
Cancelled = 2
987-
}
988-
989977
public OperationResult ProcessQueuedOperation(TOperation op)
990978
{
991979
SocketAsyncContext context = op.AssociatedContext;
@@ -1010,27 +998,15 @@ public OperationResult ProcessQueuedOperation(TOperation op)
1010998
}
1011999
}
10121000

1013-
bool wasCompleted = false;
1001+
OperationResult result;
10141002
while (true)
10151003
{
1016-
// Try to change the op state to Running.
1017-
// If this fails, it means the operation was previously cancelled,
1018-
// and we should just remove it from the queue without further processing.
1019-
if (!op.TrySetRunning())
1020-
{
1021-
break;
1022-
}
1023-
1024-
// Try to perform the IO
1025-
if (op.TryComplete(context))
1004+
result = op.TryComplete(context);
1005+
if (result != OperationResult.Pending)
10261006
{
1027-
op.SetComplete();
1028-
wasCompleted = true;
10291007
break;
10301008
}
10311009

1032-
op.SetWaiting();
1033-
10341010
// Check for retry and reset queue state.
10351011

10361012
using (Lock())
@@ -1097,7 +1073,8 @@ public OperationResult ProcessQueuedOperation(TOperation op)
10971073

10981074
nextOp?.Dispatch();
10991075

1100-
return (wasCompleted ? OperationResult.Completed : OperationResult.Cancelled);
1076+
Debug.Assert(result != OperationResult.Pending);
1077+
return result;
11011078
}
11021079

11031080
public void CancelAndContinueProcessing(TOperation op)
@@ -1360,9 +1337,9 @@ private void PerformSyncOperation<TOperation>(ref OperationQueue<TOperation> que
13601337
e.Reset();
13611338

13621339
// We've been signalled to try to process the operation.
1363-
OperationQueue<TOperation>.OperationResult result = queue.ProcessQueuedOperation(operation);
1364-
if (result == OperationQueue<TOperation>.OperationResult.Completed ||
1365-
result == OperationQueue<TOperation>.OperationResult.Cancelled)
1340+
OperationResult result = queue.ProcessQueuedOperation(operation);
1341+
if (result == OperationResult.Completed ||
1342+
result == OperationResult.Cancelled)
13661343
{
13671344
break;
13681345
}

0 commit comments

Comments
 (0)