Skip to content

Commit 3bc2a9f

Browse files
Improve behavior for concurrent access
1 parent 173b83e commit 3bc2a9f

File tree

3 files changed

+125
-30
lines changed

3 files changed

+125
-30
lines changed

src/Immediate.Cache.Shared/ApplicationCacheBase.cs

Lines changed: 40 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
using System.Diagnostics;
12
using System.Diagnostics.CodeAnalysis;
23
using Immediate.Handlers.Shared;
34
using Microsoft.Extensions.Caching.Memory;
@@ -129,11 +130,42 @@ Owned<IHandler<TRequest, TResponse>> handler
129130

130131
public async ValueTask<TResponse> GetValue(CancellationToken cancellationToken)
131132
{
132-
if (!TryAcquireResponseSource(cancellationToken))
133-
return await _responseSource.Task.ConfigureAwait(false);
133+
try
134+
{
135+
return await GetHandlerTask().WaitAsync(cancellationToken).ConfigureAwait(false);
136+
}
137+
catch (OperationCanceledException) when (
138+
!cancellationToken.IsCancellationRequested
139+
&& _responseSource?.Task is { IsCompletedSuccessfully: true } task
140+
)
141+
{
142+
return await task.ConfigureAwait(false);
143+
}
144+
}
134145

135-
var token = _tokenSource.Token;
146+
[SuppressMessage("Maintainability", "CA1508:Avoid dead conditional code", Justification = "Double-checked lock pattern")]
147+
private Task<TResponse> GetHandlerTask()
148+
{
149+
if (_responseSource is not null)
150+
return _responseSource.Task;
151+
152+
lock (_lock)
153+
{
154+
if (_responseSource is not null)
155+
return _responseSource.Task;
156+
157+
_tokenSource = new();
158+
_responseSource = new TaskCompletionSource<TResponse>();
136159

160+
return Task.Run(() => RunHandler(_tokenSource, _responseSource));
161+
}
162+
}
163+
164+
private async Task<TResponse> RunHandler(
165+
CancellationTokenSource tokenSource,
166+
TaskCompletionSource<TResponse> responseSource
167+
)
168+
{
137169
try
138170
{
139171
var scope = handler.GetScope();
@@ -143,42 +175,25 @@ public async ValueTask<TResponse> GetValue(CancellationToken cancellationToken)
143175
var response = await scope.Service
144176
.HandleAsync(
145177
request,
146-
cancellationToken: token
178+
tokenSource.Token
147179
)
148180
.ConfigureAwait(false);
149181

150182
lock (_lock)
151183
{
152-
_responseSource.SetResult(response);
184+
Debug.Assert(_responseSource == responseSource);
185+
responseSource.SetResult(response);
186+
153187
return response;
154188
}
155189
}
156190
}
157-
catch (OperationCanceledException) when (!cancellationToken.IsCancellationRequested && _responseSource is not null)
191+
catch (OperationCanceledException) when (_responseSource is not null)
158192
{
159193
return await _responseSource.Task.ConfigureAwait(false);
160194
}
161195
}
162196

163-
[MemberNotNull(nameof(_responseSource))]
164-
[MemberNotNullWhen(true, nameof(_tokenSource))]
165-
[SuppressMessage("Maintainability", "CA1508:Avoid dead conditional code", Justification = "Double-checked lock pattern")]
166-
private bool TryAcquireResponseSource(CancellationToken cancellationToken)
167-
{
168-
if (_responseSource is not null)
169-
return false;
170-
171-
lock (_lock)
172-
{
173-
if (_responseSource is not null)
174-
return false;
175-
176-
_tokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
177-
_responseSource = new TaskCompletionSource<TResponse>();
178-
return true;
179-
}
180-
}
181-
182197
public void SetValue(TResponse response)
183198
{
184199
lock (_lock)

tests/Immediate.Cache.FunctionalTests/ApplicationCacheTests.cs

Lines changed: 84 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,8 +138,91 @@ public async Task ProperlyUsesCancellationToken()
138138
var responseTask = cache.GetValue(request, tcs.Token);
139139

140140
await tcs.CancelAsync();
141-
request.CompletionSource.SetResult();
142141

143142
Assert.True(responseTask.IsCanceled);
144143
}
144+
145+
[Test]
146+
public async Task CancellingFirstAccessOperatesCorrectly()
147+
{
148+
using var cts1 = new CancellationTokenSource();
149+
var request1 = new DelayGetValue.Query()
150+
{
151+
Value = 1,
152+
Name = "Request1",
153+
CompletionSource = new(),
154+
};
155+
156+
using var cts2 = new CancellationTokenSource();
157+
var request2 = new DelayGetValue.Query()
158+
{
159+
Value = 1,
160+
Name = "Request2",
161+
CompletionSource = new(),
162+
};
163+
164+
var cache = _serviceProvider.GetRequiredService<DelayGetValueCache>();
165+
var response1Task = cache.GetValue(request1, cts1.Token);
166+
var response2Task = cache.GetValue(request2, cts2.Token);
167+
168+
// both waiting until cancellation triggered
169+
Assert.False(response1Task.IsCompleted);
170+
Assert.False(response2Task.IsCompleted);
171+
172+
Assert.Equal(0, request1.TimesExecuted);
173+
Assert.Equal(0, request2.TimesExecuted);
174+
175+
// cancel query1; query2 should remain uncancelled
176+
await cts1.CancelAsync();
177+
178+
Assert.True(response1Task.IsCanceled);
179+
Assert.False(response2Task.IsCanceled);
180+
181+
await cts2.CancelAsync();
182+
183+
Assert.True(response1Task.IsCanceled);
184+
Assert.True(response2Task.IsCanceled);
185+
}
186+
187+
[Test]
188+
public async Task CancellingSecondAccessOperatesCorrectly()
189+
{
190+
using var cts1 = new CancellationTokenSource();
191+
var request1 = new DelayGetValue.Query()
192+
{
193+
Value = 1,
194+
Name = "Request1",
195+
CompletionSource = new(),
196+
};
197+
198+
using var cts2 = new CancellationTokenSource();
199+
var request2 = new DelayGetValue.Query()
200+
{
201+
Value = 1,
202+
Name = "Request2",
203+
CompletionSource = new(),
204+
};
205+
206+
var cache = _serviceProvider.GetRequiredService<DelayGetValueCache>();
207+
var response1Task = cache.GetValue(request1, cts1.Token);
208+
var response2Task = cache.GetValue(request2, cts2.Token);
209+
210+
// both waiting until cancellation triggered
211+
Assert.False(response1Task.IsCompleted);
212+
Assert.False(response2Task.IsCompleted);
213+
214+
Assert.Equal(0, request1.TimesExecuted);
215+
Assert.Equal(0, request2.TimesExecuted);
216+
217+
// cancel query2; query1 should remain uncancelled
218+
await cts2.CancelAsync();
219+
220+
Assert.False(response1Task.IsCanceled);
221+
Assert.True(response2Task.IsCanceled);
222+
223+
await cts1.CancelAsync();
224+
225+
Assert.True(response1Task.IsCanceled);
226+
Assert.True(response2Task.IsCanceled);
227+
}
145228
}

tests/Immediate.Cache.FunctionalTests/DelayGetValue.cs

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@ public sealed class Query
1111
public required string Name { get; init; }
1212
public required TaskCompletionSource CompletionSource { get; init; }
1313
public int TimesExecuted { get; set; }
14-
public CancellationToken CancellationToken { get; set; }
1514
}
1615

1716
public sealed record Response(int Value, bool ExecutedHandler, Guid RandomValue);
@@ -23,9 +22,7 @@ private static async ValueTask<Response> HandleAsync(
2322
CancellationToken token
2423
)
2524
{
26-
query.CancellationToken = token;
27-
await query.CompletionSource.Task;
28-
token.ThrowIfCancellationRequested();
25+
await query.CompletionSource.Task.WaitAsync(token);
2926

3027
lock (s_lock)
3128
query.TimesExecuted++;

0 commit comments

Comments
 (0)