Skip to content

Commit c95ce29

Browse files
Improve preventing concurrent calls to handler
1 parent 046d615 commit c95ce29

File tree

4 files changed

+137
-11
lines changed

4 files changed

+137
-11
lines changed

src/Immediate.Cache.Shared/ApplicationCacheBase.cs

Lines changed: 33 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
using System.Diagnostics.CodeAnalysis;
12
using Immediate.Handlers.Shared;
23
using Microsoft.Extensions.Caching.Memory;
34

@@ -113,22 +114,20 @@ public void SetValue(TRequest request, TResponse value) =>
113114
public void RemoveValue(TRequest request) =>
114115
GetCacheValue(request).RemoveValue();
115116

117+
[SuppressMessage("Design", "CA1001:Types that own disposable fields should be disposable", Justification = "CancellationTokenSource does not need to be disposed here.")]
116118
private sealed class CacheValue(
117119
TRequest request,
118120
Owned<IHandler<TRequest, TResponse>> handler
119121
)
120122
{
121-
private TResponse? _response;
122123
private CancellationTokenSource? _tokenSource;
124+
private TaskCompletionSource<TResponse>? _responseSource;
123125
private readonly object _lock = new();
124126

125127
public async ValueTask<TResponse> GetValue()
126128
{
127-
if (_response is not null)
128-
return _response;
129-
130-
lock (_lock)
131-
_tokenSource ??= new CancellationTokenSource();
129+
if (!TryAcquireResponseSource())
130+
return await _responseSource.Task.ConfigureAwait(false);
132131

133132
var token = _tokenSource.Token;
134133

@@ -146,20 +145,43 @@ public async ValueTask<TResponse> GetValue()
146145
.ConfigureAwait(false);
147146

148147
lock (_lock)
149-
return _response ??= response;
148+
{
149+
_responseSource.SetResult(response);
150+
return response;
151+
}
150152
}
151153
}
152-
catch (OperationCanceledException) when (_response is not null)
154+
catch (OperationCanceledException) when (_responseSource is not null)
155+
{
156+
return await _responseSource.Task.ConfigureAwait(false);
157+
}
158+
}
159+
160+
[MemberNotNull(nameof(_responseSource))]
161+
[MemberNotNullWhen(true, nameof(_tokenSource))]
162+
[SuppressMessage("Maintainability", "CA1508:Avoid dead conditional code", Justification = "Double-checked lock pattern")]
163+
private bool TryAcquireResponseSource()
164+
{
165+
if (_responseSource is not null)
166+
return false;
167+
168+
lock (_lock)
153169
{
154-
return _response;
170+
if (_responseSource is not null)
171+
return false;
172+
173+
_tokenSource = new CancellationTokenSource();
174+
_responseSource = new TaskCompletionSource<TResponse>();
175+
return true;
155176
}
156177
}
157178

158179
public void SetValue(TResponse response)
159180
{
160181
lock (_lock)
161182
{
162-
_response = response;
183+
_responseSource = new TaskCompletionSource<TResponse>();
184+
_responseSource.SetResult(response);
163185
_tokenSource?.Cancel();
164186
}
165187
}
@@ -168,7 +190,7 @@ public void RemoveValue()
168190
{
169191
lock (_lock)
170192
{
171-
_response = null;
193+
_responseSource = null;
172194
_tokenSource?.Cancel();
173195
}
174196
}

tests/Immediate.Cache.FunctionalTests/ApplicationCacheTests.cs

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ public ApplicationCacheTests()
1313
var services = new ServiceCollection();
1414
_ = services.AddHandlers();
1515
_ = services.AddSingleton<GetValueCache>();
16+
_ = services.AddSingleton<DelayGetValueCache>();
1617
_ = services.AddSingleton(typeof(Owned<>));
1718
_ = services.AddMemoryCache();
1819

@@ -71,4 +72,54 @@ public async Task RemoveValueRemovesValue()
7172
Assert.Equal(3, response.Value);
7273
Assert.True(response.ExecutedHandler);
7374
}
75+
76+
[Test]
77+
public async Task SimultaneousAccessIsSerialized()
78+
{
79+
var request1 = new DelayGetValue.Query()
80+
{
81+
Value = 1,
82+
Name = "Request1",
83+
CompletionSource = new(),
84+
};
85+
86+
var request2 = new DelayGetValue.Query()
87+
{
88+
Value = 1,
89+
Name = "Request2",
90+
CompletionSource = new(),
91+
};
92+
93+
var cache = _serviceProvider.GetRequiredService<DelayGetValueCache>();
94+
var response1Task = cache.GetValue(request1);
95+
var response2Task = cache.GetValue(request2);
96+
97+
// both waiting until tcs triggered
98+
Assert.False(response1Task.IsCompleted);
99+
Assert.False(response2Task.IsCompleted);
100+
101+
Assert.Equal(0, request1.TimesExecuted);
102+
Assert.Equal(0, request2.TimesExecuted);
103+
104+
// request2 does nothing at this point
105+
request2.CompletionSource.SetResult();
106+
107+
Assert.False(response1Task.IsCompleted);
108+
Assert.False(response2Task.IsCompleted);
109+
110+
Assert.Equal(0, request1.TimesExecuted);
111+
Assert.Equal(0, request2.TimesExecuted);
112+
113+
// trigger request1, which should run exactly once
114+
request1.CompletionSource.SetResult();
115+
116+
var response1 = await response1Task;
117+
var response2 = await response2Task;
118+
119+
Assert.Equal(1, request1.TimesExecuted);
120+
Assert.Equal(0, request2.TimesExecuted);
121+
122+
// ensure both responses get the same response back
123+
Assert.Equal(response1.RandomValue, response2.RandomValue);
124+
}
74125
}
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
using Immediate.Handlers.Shared;
2+
3+
namespace Immediate.Cache.FunctionalTests;
4+
5+
[Handler]
6+
public static partial class DelayGetValue
7+
{
8+
public sealed class Query
9+
{
10+
public required int Value { get; init; }
11+
public required string Name { get; init; }
12+
public required TaskCompletionSource CompletionSource { get; init; }
13+
public int TimesExecuted { get; set; }
14+
}
15+
16+
public sealed record Response(int Value, bool ExecutedHandler, Guid RandomValue);
17+
18+
private static readonly object s_lock = new();
19+
20+
private static async ValueTask<Response> HandleAsync(
21+
Query query,
22+
CancellationToken _
23+
)
24+
{
25+
await query.CompletionSource.Task;
26+
lock (s_lock)
27+
query.TimesExecuted++;
28+
29+
return new(query.Value, ExecutedHandler: true, RandomValue: Guid.NewGuid());
30+
}
31+
}
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
using System.Diagnostics.CodeAnalysis;
2+
using Immediate.Handlers.Shared;
3+
using Microsoft.Extensions.Caching.Memory;
4+
5+
namespace Immediate.Cache.FunctionalTests;
6+
7+
public sealed class DelayGetValueCache(
8+
IMemoryCache memoryCache,
9+
Owned<IHandler<DelayGetValue.Query, DelayGetValue.Response>> ownedHandler
10+
) : ApplicationCacheBase<DelayGetValue.Query, DelayGetValue.Response>(
11+
memoryCache,
12+
ownedHandler
13+
)
14+
{
15+
[SuppressMessage(
16+
"Design",
17+
"CA1062:Validate arguments of public methods",
18+
Justification = "Not a public method"
19+
)]
20+
protected override string TransformKey(DelayGetValue.Query request) =>
21+
$"GetValue(query: {request.Value})";
22+
}

0 commit comments

Comments
 (0)