143 changes: 103 additions & 40 deletions Microsoft.Azure.Cosmos/src/Routing/GatewayAddressCache.cs
@@ -31,6 +31,11 @@ internal class GatewayAddressCache : IAddressCache, IDisposable
private const string AddressResolutionBatchSize = "AddressResolutionBatchSize";
private const int DefaultBatchSize = 50;

+ // This warmup cache and connection timeout is meant to mimic an indefinite timeframe till which
+ // a delay task will run, until a cancellation token is requested to cancel the task. The default
+ // value for this timeout is 45 minutes at the moment.
+ private static readonly TimeSpan WarmupCacheAndOpenConnectionTimeout = TimeSpan.FromMinutes(45);
Member

How was the 45-minute value determined, and how will it be tuned in the future?
The incoming cancellationToken is an option, right?

Member

Incoming cancellation is honored through:

using CancellationTokenSource linkedTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);

Member

True, but wouldn't the incoming cancellationToken alone suffice? Why the extra 45-minute cap?

Member Author

@kundadebdatta Jun 27, 2023

The 45-minute cap bounds the connection-opening time when no cancellation token is provided. To understand the wiring better, please see the code snippet below:

        using CancellationTokenSource linkedTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
        Task timeoutTask = Task.Delay(TimeSpan.FromMinutes(45), linkedTokenSource.Token);
        Task resultTask = await Task.WhenAny(Task.WhenAll(tasks), timeoutTask);

        if (resultTask == timeoutTask)
        {
            // Operation has been cancelled.
            DefaultTrace.TraceWarning("The open connection task was cancelled because the cancellation token was expired. '{0}'",
                System.Diagnostics.Trace.CorrelationManager.ActivityId);
        }
        else
        {
            linkedTokenSource.Cancel();
        }

In the code above, resultTask completes either when timeoutTask completes (i.e. the cancellationToken has expired or the 45-minute delay has elapsed) or when Task.WhenAll(tasks), which opens the Rntbd connections, completes. Today, in .NET Standard, Task.WhenAll() does not accept a cancellation token. Task.WaitAsync, available on .NET 6 and later, does accept a cancellation token, but it is not supported on the frameworks we currently target. This is why we are taking this alternate route to honor the cancellation token.
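
For comparison only, the Task.WaitAsync route mentioned above could be sketched as follows on .NET 6 or later. This is an illustration, not code from this PR: the task array, the 2-second delays and the 200 ms token are hypothetical stand-ins for the real connection-opening work.

```csharp
using System;
using System.Threading;
using System.Threading.Tasks;

// Hypothetical stand-ins for the connection-opening tasks (2 seconds each).
Task[] tasks =
{
    Task.Delay(TimeSpan.FromSeconds(2)),
    Task.Delay(TimeSpan.FromSeconds(2)),
};

// Hypothetical caller token that expires after 200 ms.
using CancellationTokenSource cts = new (TimeSpan.FromMilliseconds(200));

string outcome;
try
{
    // WaitAsync stops *waiting* when the token fires or the timeout elapses;
    // it does not cancel the underlying tasks themselves.
    await Task.WhenAll(tasks).WaitAsync(TimeSpan.FromMinutes(45), cts.Token);
    outcome = "completed";
}
catch (OperationCanceledException)
{
    outcome = "cancelled";
}
catch (TimeoutException)
{
    outcome = "timed out";
}

Console.WriteLine(outcome); // prints "cancelled" with the values above
```

Unlike the Task.WhenAny approach, this surfaces cancellation and timeout as exceptions, so the caller would need to decide whether to swallow or propagate them.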

Member

> True, but wouldn't the incoming cancellationToken alone suffice? Why the extra 45-minute cap?

The incoming cancellation token is optional. What happens if the user does not provide one? How do we timebox this in case something hangs?

Member

Personal opinion: 45 minutes is too long for any startup task (it might be as bad as a deadlock for customers).
It's a new behavior, right? Is it an ask from customers?


private readonly Uri serviceEndpoint;
private readonly Uri addressEndpoint;

@@ -113,7 +118,7 @@ public async Task OpenConnectionsAsync(
bool shouldOpenRntbdChannels,
CancellationToken cancellationToken)
{
- List<Task<TryCatch<DocumentServiceResponse>>> tasks = new ();
+ List<Task> tasks = new ();
int batchSize = GatewayAddressCache.DefaultBatchSize;

#if !(NETSTANDARD15 || NETSTANDARD16)
@@ -147,50 +152,33 @@ public async Task OpenConnectionsAsync(
{
for (int i = 0; i < partitionKeyRangeIdentities.Count; i += batchSize)
{
- tasks
- .Add(this.GetAddressesAsync(
- request: request,
- collectionRid: collection.ResourceId,
- partitionKeyRangeIds: partitionKeyRangeIdentities.Skip(i).Take(batchSize).Select(range => range.PartitionKeyRangeId)));
+ tasks.Add(
+ this.WarmupCachesAndOpenConnectionsAsync(
Member

Should the CancellationToken be wired through here?

Member Author

In my understanding, the cancellationToken should bound the total time taken to open connections to all replicas, not the connection time for each individual batch of partition key ranges. The whole operation is covered by: Task resultTask = await Task.WhenAny(Task.WhenAll(tasks), timeoutTask);

+ request: request,
+ collectionRid: collection.ResourceId,
+ partitionKeyRangeIds: partitionKeyRangeIdentities.Skip(i).Take(batchSize).Select(range => range.PartitionKeyRangeId),
+ containerProperties: collection,
+ shouldOpenRntbdChannels: shouldOpenRntbdChannels));
}
}

- foreach (TryCatch<DocumentServiceResponse> task in await Task.WhenAll(tasks))
- {
- if (task.Failed)
- {
- continue;
- }
+ using CancellationTokenSource linkedTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);

- using (DocumentServiceResponse response = task.Result)
- {
- FeedResource<Address> addressFeed = response.GetResource<FeedResource<Address>>();

- bool inNetworkRequest = this.IsInNetworkRequest(response);
+ // The `timeoutTask` is a background task which adds a delay for a period of WarmupCacheAndOpenConnectionTimeout. The task will
+ // be cancelled either by - a) when `linkedTokenSource` expires, which means the original `cancellationToken` expires or
+ // b) the `linkedTokenSource.Cancel()` is called.
+ Task timeoutTask = Task.Delay(GatewayAddressCache.WarmupCacheAndOpenConnectionTimeout, linkedTokenSource.Token);
+ Task resultTask = await Task.WhenAny(Task.WhenAll(tasks), timeoutTask);

- IEnumerable<Tuple<PartitionKeyRangeIdentity, PartitionAddressInformation>> addressInfos =
- addressFeed.Where(addressInfo => ProtocolFromString(addressInfo.Protocol) == this.protocol)
- .GroupBy(address => address.PartitionKeyRangeId, StringComparer.Ordinal)
- .Select(group => this.ToPartitionAddressAndRange(collection.ResourceId, @group.ToList(), inNetworkRequest));

- foreach (Tuple<PartitionKeyRangeIdentity, PartitionAddressInformation> addressInfo in addressInfos)
- {
- this.serverPartitionAddressCache.Set(
- new PartitionKeyRangeIdentity(collection.ResourceId, addressInfo.Item1.PartitionKeyRangeId),
- addressInfo.Item2);

- // The `shouldOpenRntbdChannels` boolean flag indicates whether the SDK should establish Rntbd connections to the
- // backend replica nodes. For the `CosmosClient.CreateAndInitializeAsync()` flow, the flag should be passed as
- // `true` so that the Rntbd connections to the backend replicas could be established deterministically. For any
- // other flow, the flag should be passed as `false`.
- if (this.openConnectionsHandler != null && shouldOpenRntbdChannels)
- {
- await this.openConnectionsHandler
- .TryOpenRntbdChannelsAsync(
- addresses: addressInfo.Item2.Get(Protocol.Tcp)?.ReplicaTransportAddressUris);
- }
- }
- }
+ if (resultTask == timeoutTask)
+ {
+ // Operation has been cancelled.
+ DefaultTrace.TraceWarning("The open connection task was cancelled because the cancellation token was expired. '{0}'",
+ System.Diagnostics.Trace.CorrelationManager.ActivityId);
+ }
+ else
+ {
+ linkedTokenSource.Cancel();
+ }
}

@@ -350,6 +338,81 @@ public async Task<PartitionAddressInformation> TryGetAddressesAsync(
}
}

+ /// <summary>
+ /// Gets the address information from the gateway using the partition key range ids, and warms up the async non blocking cache
+ /// by inserting them as a key value pair for later lookup. Additionally attempts to establish Rntbd connections to the backend
+ /// replicas based on `shouldOpenRntbdChannels` boolean flag.
+ /// </summary>
+ /// <param name="request">An instance of <see cref="DocumentServiceRequest"/> containing the request payload.</param>
+ /// <param name="collectionRid">A string containing the collection ids.</param>
+ /// <param name="partitionKeyRangeIds">An instance of <see cref="IEnumerable{T}"/> containing the list of partition key range ids.</param>
+ /// <param name="containerProperties">An instance of <see cref="ContainerProperties"/> containing the collection properties.</param>
+ /// <param name="shouldOpenRntbdChannels">A boolean flag indicating whether Rntbd connections are required to be established to the backend replica nodes.</param>
+ private async Task WarmupCachesAndOpenConnectionsAsync(
+ DocumentServiceRequest request,
+ string collectionRid,
+ IEnumerable<string> partitionKeyRangeIds,
+ ContainerProperties containerProperties,
+ bool shouldOpenRntbdChannels)
+ {
+ TryCatch<DocumentServiceResponse> documentServiceResponseWrapper = await this.GetAddressesAsync(
+ request: request,
+ collectionRid: collectionRid,
+ partitionKeyRangeIds: partitionKeyRangeIds);

+ if (documentServiceResponseWrapper.Failed)
+ {
+ return;
+ }

+ try
+ {
+ using (DocumentServiceResponse response = documentServiceResponseWrapper.Result)
+ {
+ FeedResource<Address> addressFeed = response.GetResource<FeedResource<Address>>();

+ bool inNetworkRequest = this.IsInNetworkRequest(response);

+ IEnumerable<Tuple<PartitionKeyRangeIdentity, PartitionAddressInformation>> addressInfos =
+ addressFeed.Where(addressInfo => ProtocolFromString(addressInfo.Protocol) == this.protocol)
+ .GroupBy(address => address.PartitionKeyRangeId, StringComparer.Ordinal)
+ .Select(group => this.ToPartitionAddressAndRange(containerProperties.ResourceId, @group.ToList(), inNetworkRequest));

+ List<Task> openConnectionTasks = new ();
+ foreach (Tuple<PartitionKeyRangeIdentity, PartitionAddressInformation> addressInfo in addressInfos)
+ {
+ this.serverPartitionAddressCache.Set(
Member

Yield based on cancellation after wiring it through?

Member Author

Could you please help me understand this better? Currently there is no need for the cancellationToken inside the WarmupCachesAndOpenConnectionsAsync method.

Member

Stop the foreach if cancellation has been signaled (if ct.IsCancellationRequested, break).

Member

On-cancellation behavior: with the new changes, it tries to complete gracefully, right?
The existing pattern elsewhere is to throw an exception:

cancellationToken.ThrowIfCancellationRequested()
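
The two cancellation behaviors discussed in this thread (breaking out of the loop versus throwing) could be contrasted with a minimal sketch like the one below. The helper names ProcessGracefully and ProcessOrThrow and the integer items are hypothetical and only illustrate the pattern; they are not part of this PR.

```csharp
using System;
using System.Collections.Generic;
using System.Threading;

// Hypothetical token that is already cancelled before the loops start.
using CancellationTokenSource cts = new ();
cts.Cancel();
int[] items = { 1, 2, 3 };

int processedGracefully = ProcessGracefully(items, cts.Token);

bool threw = false;
try
{
    ProcessOrThrow(items, cts.Token);
}
catch (OperationCanceledException)
{
    threw = true;
}

Console.WriteLine($"graceful: {processedGracefully}, threw: {threw}"); // graceful: 0, threw: True

// Option 1: stop quietly once cancellation is signalled and report what was done.
static int ProcessGracefully(IEnumerable<int> source, CancellationToken ct)
{
    int processed = 0;
    foreach (int item in source)
    {
        if (ct.IsCancellationRequested)
        {
            break;
        }

        processed++;
    }

    return processed;
}

// Option 2: surface cancellation to the caller as an OperationCanceledException.
static int ProcessOrThrow(IEnumerable<int> source, CancellationToken ct)
{
    int processed = 0;
    foreach (int item in source)
    {
        ct.ThrowIfCancellationRequested();
        processed++;
    }

    return processed;
}
```

The graceful variant leaves the caller unaware that work was skipped unless it inspects the result, while the throwing variant makes cancellation explicit at the cost of requiring exception handling upstream.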

+ new PartitionKeyRangeIdentity(containerProperties.ResourceId, addressInfo.Item1.PartitionKeyRangeId),
+ addressInfo.Item2);

+ // The `shouldOpenRntbdChannels` boolean flag indicates whether the SDK should establish Rntbd connections to the
+ // backend replica nodes. For the `CosmosClient.CreateAndInitializeAsync()` flow, the flag should be passed as
+ // `true` so that the Rntbd connections to the backend replicas could be established deterministically. For any
+ // other flow, the flag should be passed as `false`.
+ if (this.openConnectionsHandler != null && shouldOpenRntbdChannels)
+ {
+ openConnectionTasks
+ .Add(this.openConnectionsHandler
+ .TryOpenRntbdChannelsAsync(
+ addresses: addressInfo.Item2.Get(Protocol.Tcp)?.ReplicaTransportAddressUris));
+ }
+ }

+ if (openConnectionTasks.Any())
Member

nit: Which is more performant, Any() or Count > 0?

Member

Is the check necessary?

Member Author

Yes, I think we can remove this condition. await Task.WhenAll(openConnectionTasks); handles an empty task list as well.
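
As a quick check of that claim, the sketch below shows that Task.WhenAll over an empty list returns a task that has already run to completion, so awaiting it returns immediately. The empty list here is a hypothetical stand-in for the case where no Rntbd channels need to be opened.

```csharp
using System;
using System.Collections.Generic;
using System.Threading.Tasks;

// Hypothetical empty list, as when no Rntbd channels need to be opened.
List<Task> openConnectionTasks = new ();

// With no tasks supplied, WhenAll returns a task already in the RanToCompletion state.
Task all = Task.WhenAll(openConnectionTasks);
bool completedImmediately = all.Status == TaskStatus.RanToCompletion;

await all; // does not block
Console.WriteLine(completedImmediately); // True
```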

+ {
+ await Task.WhenAll(openConnectionTasks);
+ }
+ }
+ }
+ catch (Exception ex)
+ {
+ DefaultTrace.TraceWarning("Failed to warm-up caches and open connections for the server addresses: {0} with exception: {1}. '{2}'",
+ collectionRid,
+ ex,
+ System.Diagnostics.Trace.CorrelationManager.ActivityId);
+ }
+ }

private static void SetTransportAddressUrisToUnhealthy(
PartitionAddressInformation stalePartitionAddressInformation,
Lazy<HashSet<TransportAddressUri>> failedEndpoints)
@@ -480,6 +480,75 @@ await cache.OpenConnectionsAsync(
expectedTotalSuccessAddressesToOpenCount: 0);
}

+ /// <summary>
+ /// Test to validate that when <see cref="GatewayAddressCache.OpenConnectionsAsync()"/> is called with a valid open connection handler
+ /// and a cancellation token that will expire with a pre-configured time, the handler method is indeed invoked and the open connection
+ /// operation gets cancelled successfully, if the cancellation token expires. The open connection operation succeeds if the operation
+ /// is finished before the cancellation token expiry time.
+ /// </summary>
+ [TestMethod]
+ [Owner("dkunda")]
+ [DataRow(1, 2, 1, 0, 3, 0, true, DisplayName = "Validate that when the cancellation token expiry time (i.e. 1 sec) is smaller than the open connection operation duration (i.e. 2 sec), " +
+ "the open connection operation gets cancelled and the cancellation token is indeed respected and eventually cancelled.")]
+ [DataRow(3, 1, 1, 0, 3, 3, false, DisplayName = "Validate that when the cancellation token expiry time (i.e. 3 sec) is larger than the open connection operation duration (i.e. 1 sec), " +
+ "the open connection operation completes successfully and the cancellation token is not cancelled.")]
+ public async Task OpenConnectionsAsync_WithValidOpenConnectionHandlerAndCancellationTokenExpires_ShouldInvokeHandlerMethodAndCancelToken(
+ int cancellationTokenTimeoutInSeconds,
+ int openConnectionDelayInSeconds,
+ int expectedTotalHandlerInvocationCount,
+ int expectedTotalFailedAddressesToOpenCount,
+ int expectedTotalReceivedAddressesCount,
+ int expectedTotalSuccessAddressesToOpenCount,
+ bool shouldCancelToken)
+ {
+ // Arrange.
+ FakeMessageHandler messageHandler = new ();
+ FakeOpenConnectionHandler fakeOpenConnectionHandler = new (
+ failingIndexes: new HashSet<int>(),
+ openConnectionDelayInSeconds: openConnectionDelayInSeconds);

+ ContainerProperties containerProperties = ContainerProperties.CreateWithResourceId("ccZ1ANCszwk=");
+ containerProperties.Id = "TestId";
+ containerProperties.PartitionKeyPath = "/pk";
+ HttpClient httpClient = new(messageHandler)
+ {
+ Timeout = TimeSpan.FromSeconds(120)
+ };

+ CancellationTokenSource cts = new (TimeSpan.FromSeconds(cancellationTokenTimeoutInSeconds));
+ CancellationToken token = cts.Token;

+ GatewayAddressCache cache = new (
+ new Uri(GatewayAddressCacheTests.DatabaseAccountApiEndpoint),
+ Protocol.Tcp,
+ this.mockTokenProvider.Object,
+ this.mockServiceConfigReader.Object,
+ MockCosmosUtil.CreateCosmosHttpClient(() => httpClient),
+ openConnectionsHandler: fakeOpenConnectionHandler,
+ suboptimalPartitionForceRefreshIntervalInSeconds: 2);

+ // Act.
+ await cache.OpenConnectionsAsync(
+ databaseName: "test-database",
+ collection: containerProperties,
+ partitionKeyRangeIdentities: new List<PartitionKeyRangeIdentity>()
+ {
+ this.testPartitionKeyRangeIdentity
+ },
+ shouldOpenRntbdChannels: true,
+ cancellationToken: token);

+ // Assert.
+ GatewayAddressCacheTests.AssertOpenConnectionHandlerAttributes(
+ fakeOpenConnectionHandler: fakeOpenConnectionHandler,
+ expectedTotalFailedAddressesToOpenCount: expectedTotalFailedAddressesToOpenCount,
+ expectedTotalHandlerInvocationCount: expectedTotalHandlerInvocationCount,
+ expectedTotalReceivedAddressesCount: expectedTotalReceivedAddressesCount,
+ expectedTotalSuccessAddressesToOpenCount: expectedTotalSuccessAddressesToOpenCount);

+ Assert.AreEqual(shouldCancelToken, token.IsCancellationRequested);
+ }

/// <summary>
/// Test to validate that when <see cref="GlobalAddressResolver.OpenConnectionsToAllReplicasAsync()"/> is called with a
/// valid open connection handler, the handler method is indeed invoked and an attempt is made to open
@@ -1074,7 +1143,7 @@ public async Task TryGetAddressesAsync_WhenReplicaVlidationEnabledAndUnhealthyUr
FakeOpenConnectionHandler fakeOpenConnectionHandler = new (
failIndexesByAttempts: new Dictionary<int, HashSet<int>>()
{
- { 0, new HashSet<int>() { 1 } }
+ { 0, new HashSet<int>() { 2 } }
},
manualResetEvent: manualResetEvent);

@@ -1513,16 +1582,19 @@ public class FakeOpenConnectionHandler : IOpenConnectionsHandler
private int successInvocationCounter = 0;
private int totalReceivedAddressesCounter = 0;
private readonly HashSet<int> failingIndexes;
+ private readonly int openConnectionDelayInSeconds;
private readonly bool useAttemptBasedFailingIndexs;
private readonly ManualResetEvent manualResetEvent;
private readonly Dictionary<int, HashSet<int>> failIndexesByAttempts;

public FakeOpenConnectionHandler(
HashSet<int> failingIndexes,
- ManualResetEvent manualResetEvent = null)
+ ManualResetEvent manualResetEvent = null,
+ int openConnectionDelayInSeconds = 0)
{
this.failingIndexes = failingIndexes;
this.manualResetEvent = manualResetEvent;
+ this.openConnectionDelayInSeconds = openConnectionDelayInSeconds;
}

public FakeOpenConnectionHandler(
@@ -1554,11 +1626,18 @@ public int GetTotalMethodInvocationCount()
return this.methodInvocationCounter;
}

- Task IOpenConnectionsHandler.TryOpenRntbdChannelsAsync(
+ async Task IOpenConnectionsHandler.TryOpenRntbdChannelsAsync(
IEnumerable<TransportAddressUri> addresses)
{
int idx = 0;
+ this.methodInvocationCounter++;
this.totalReceivedAddressesCounter += addresses.Count();

+ if (this.openConnectionDelayInSeconds > 0)
+ {
+ await Task.Delay(TimeSpan.FromSeconds(this.openConnectionDelayInSeconds));
+ }

foreach (TransportAddressUri transportAddress in addresses)
{
if (this.useAttemptBasedFailingIndexs)
@@ -1587,9 +1666,7 @@ Task IOpenConnectionsHandler.TryOpenRntbdChannelsAsync(
idx++;
}

- this.methodInvocationCounter++;
this.manualResetEvent?.Set();
- return Task.CompletedTask;
}

private void ExecuteSuccessCondition(