Skip to content

Commit 173b83e

Browse files
Validate Using CancellationToken
1 parent 3acad10 commit 173b83e

File tree

3 files changed

+35
-8
lines changed

3 files changed

+35
-8
lines changed

src/Immediate.Cache.Shared/ApplicationCacheBase.cs

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -86,11 +86,14 @@ private CacheValue GetCacheValue(TRequest request)
8686
/// <param name="request">
8787
/// The request payload to be cached.
8888
/// </param>
89+
/// <param name="cancellationToken">
90+
/// The token to monitor for cancellation requests. The default value is <see cref="CancellationToken.None"/>.
91+
/// </param>
8992
/// <returns>
9093
/// The response payload from executing the handler.
9194
/// </returns>
92-
public ValueTask<TResponse> GetValue(TRequest request) =>
93-
GetCacheValue(request).GetValue();
95+
public ValueTask<TResponse> GetValue(TRequest request, CancellationToken cancellationToken = default) =>
96+
GetCacheValue(request).GetValue(cancellationToken);
9497

9598
/// <summary>
9699
/// Sets the value for a particular cache entry, bypassing the execution of the handler.
@@ -124,9 +127,9 @@ Owned<IHandler<TRequest, TResponse>> handler
124127
private TaskCompletionSource<TResponse>? _responseSource;
125128
private readonly Lock _lock = new();
126129

127-
public async ValueTask<TResponse> GetValue()
130+
public async ValueTask<TResponse> GetValue(CancellationToken cancellationToken)
128131
{
129-
if (!TryAcquireResponseSource())
132+
if (!TryAcquireResponseSource(cancellationToken))
130133
return await _responseSource.Task.ConfigureAwait(false);
131134

132135
var token = _tokenSource.Token;
@@ -151,7 +154,7 @@ public async ValueTask<TResponse> GetValue()
151154
}
152155
}
153156
}
154-
catch (OperationCanceledException) when (_responseSource is not null)
157+
catch (OperationCanceledException) when (!cancellationToken.IsCancellationRequested && _responseSource is not null)
155158
{
156159
return await _responseSource.Task.ConfigureAwait(false);
157160
}
@@ -160,7 +163,7 @@ public async ValueTask<TResponse> GetValue()
160163
[MemberNotNull(nameof(_responseSource))]
161164
[MemberNotNullWhen(true, nameof(_tokenSource))]
162165
[SuppressMessage("Maintainability", "CA1508:Avoid dead conditional code", Justification = "Double-checked lock pattern")]
163-
private bool TryAcquireResponseSource()
166+
private bool TryAcquireResponseSource(CancellationToken cancellationToken)
164167
{
165168
if (_responseSource is not null)
166169
return false;
@@ -170,7 +173,7 @@ private bool TryAcquireResponseSource()
170173
if (_responseSource is not null)
171174
return false;
172175

173-
_tokenSource = new CancellationTokenSource();
176+
_tokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
174177
_responseSource = new TaskCompletionSource<TResponse>();
175178
return true;
176179
}

tests/Immediate.Cache.FunctionalTests/ApplicationCacheTests.cs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,4 +122,24 @@ public async Task SimultaneousAccessIsSerialized()
122122
// ensure both responses get the same response back
123123
Assert.Equal(response1.RandomValue, response2.RandomValue);
124124
}
125+
126+
[Test]
127+
public async Task ProperlyUsesCancellationToken()
128+
{
129+
var request = new DelayGetValue.Query()
130+
{
131+
Value = 4,
132+
Name = "Request4",
133+
CompletionSource = new(),
134+
};
135+
136+
using var tcs = new CancellationTokenSource();
137+
var cache = _serviceProvider.GetRequiredService<DelayGetValueCache>();
138+
var responseTask = cache.GetValue(request, tcs.Token);
139+
140+
await tcs.CancelAsync();
141+
request.CompletionSource.SetResult();
142+
143+
Assert.True(responseTask.IsCanceled);
144+
}
125145
}

tests/Immediate.Cache.FunctionalTests/DelayGetValue.cs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ public sealed class Query
1111
public required string Name { get; init; }
1212
public required TaskCompletionSource CompletionSource { get; init; }
1313
public int TimesExecuted { get; set; }
14+
public CancellationToken CancellationToken { get; set; }
1415
}
1516

1617
public sealed record Response(int Value, bool ExecutedHandler, Guid RandomValue);
@@ -19,10 +20,13 @@ public sealed record Response(int Value, bool ExecutedHandler, Guid RandomValue)
1920

2021
private static async ValueTask<Response> HandleAsync(
2122
Query query,
22-
CancellationToken _
23+
CancellationToken token
2324
)
2425
{
26+
query.CancellationToken = token;
2527
await query.CompletionSource.Task;
28+
token.ThrowIfCancellationRequested();
29+
2630
lock (s_lock)
2731
query.TimesExecuted++;
2832

0 commit comments

Comments
 (0)