Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Make CancellationToken available in call credentials interceptor
  • Loading branch information
JamesNK committed Apr 28, 2023
commit 8546264b7e9f8ef6b5836c26c1fd38c8b54c8079
Original file line number Diff line number Diff line change
Expand Up @@ -77,12 +77,7 @@ protected override string PeerCore
get
{
// Follows the standard at https://github.com/grpc/grpc/blob/master/doc/naming.md
if (_peer == null)
{
_peer = BuildPeer();
}

return _peer;
return _peer ??= BuildPeer();
}
}

Expand Down Expand Up @@ -291,10 +286,7 @@ private void EndCallCore()

private void LogCallEnd()
{
if (_activity != null)
{
_activity.AddTag(GrpcServerConstants.ActivityStatusCodeTag, _status.StatusCode.ToTrailerString());
}
_activity?.AddTag(GrpcServerConstants.ActivityStatusCodeTag, _status.StatusCode.ToTrailerString());
if (_status.StatusCode != StatusCode.OK)
{
if (GrpcEventSource.Log.IsEnabled())
Expand Down Expand Up @@ -387,10 +379,7 @@ protected override Task WriteResponseHeadersAsyncCore(Metadata responseHeaders)
public void Initialize(ISystemClock? clock = null)
{
_activity = GetHostActivity();
if (_activity != null)
{
_activity.AddTag(GrpcServerConstants.ActivityMethodTag, MethodCore);
}
_activity?.AddTag(GrpcServerConstants.ActivityMethodTag, MethodCore);

if (GrpcEventSource.Log.IsEnabled())
{
Expand Down
24 changes: 21 additions & 3 deletions src/Grpc.Core.Api/AsyncAuthInterceptor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#endregion

using System.Threading;
using System.Threading.Tasks;
using Grpc.Core.Utils;

Expand All @@ -34,16 +35,25 @@ namespace Grpc.Core;
/// </summary>
public class AuthInterceptorContext
{
readonly string serviceUrl;
readonly string methodName;
private readonly string serviceUrl;
private readonly string methodName;
private readonly CancellationToken cancellationToken;

/// <summary>
/// Initializes a new instance of <c>AuthInterceptorContext</c>.
/// </summary>
public AuthInterceptorContext(string serviceUrl, string methodName)
public AuthInterceptorContext(string serviceUrl, string methodName) : this(serviceUrl, methodName, CancellationToken.None)
{
}

/// <summary>
/// Initializes a new instance of <c>AuthInterceptorContext</c>.
/// </summary>
public AuthInterceptorContext(string serviceUrl, string methodName, CancellationToken cancellationToken)
{
this.serviceUrl = GrpcPreconditions.CheckNotNull(serviceUrl, nameof(serviceUrl));
this.methodName = GrpcPreconditions.CheckNotNull(methodName, nameof(methodName));
this.cancellationToken = cancellationToken;
}

/// <summary>
Expand All @@ -61,4 +71,12 @@ public string MethodName
{
get { return methodName; }
}

/// <summary>
/// The cancellation token of the RPC being called.
/// </summary>
public CancellationToken CancellationToken
{
get { return cancellationToken; }
}
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#region Copyright notice and license
#region Copyright notice and license

// Copyright 2019 The gRPC Authors
//
Expand All @@ -20,15 +20,18 @@

namespace Grpc.Net.Client.Internal;

internal class DefaultCallCredentialsConfigurator : CallCredentialsConfiguratorBase
internal sealed class DefaultCallCredentialsConfigurator : CallCredentialsConfiguratorBase
{
public AsyncAuthInterceptor? Interceptor { get; private set; }
public IReadOnlyList<CallCredentials>? Credentials { get; private set; }
public IReadOnlyList<CallCredentials>? CompositeCredentials { get; private set; }

// A place to cache context to avoid creating a new context for each auth interceptor call.
public AuthInterceptorContext? CachedContext { get; set; }

public void Reset()
{
Interceptor = null;
Credentials = null;
CompositeCredentials = null;
}

public override void SetAsyncAuthInterceptorCredentials(object? state, AsyncAuthInterceptor interceptor)
Expand All @@ -38,6 +41,6 @@ public override void SetAsyncAuthInterceptorCredentials(object? state, AsyncAuth

public override void SetCompositeCredentials(object? state, IReadOnlyList<CallCredentials> credentials)
{
Credentials = credentials;
CompositeCredentials = credentials;
}
}
4 changes: 2 additions & 2 deletions src/Grpc.Net.Client/Internal/GrpcCall.cs
Original file line number Diff line number Diff line change
Expand Up @@ -955,13 +955,13 @@ private async Task ReadCredentials(HttpRequestMessage request)

if (Options.Credentials != null)
{
await GrpcProtocolHelpers.ReadCredentialMetadata(configurator, Channel, request, Method, Options.Credentials).ConfigureAwait(false);
await GrpcProtocolHelpers.ReadCredentialMetadata(configurator, Channel, request, Method, Options.Credentials, _callCts.Token).ConfigureAwait(false);
}
if (Channel.CallCredentials?.Count > 0)
{
foreach (var credentials in Channel.CallCredentials)
{
await GrpcProtocolHelpers.ReadCredentialMetadata(configurator, Channel, request, Method, credentials).ConfigureAwait(false);
await GrpcProtocolHelpers.ReadCredentialMetadata(configurator, Channel, request, Method, credentials, _callCts.Token).ConfigureAwait(false);
}
}
}
Expand Down
56 changes: 39 additions & 17 deletions src/Grpc.Net.Client/Internal/GrpcProtocolHelpers.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#region Copyright notice and license
#region Copyright notice and license

// Copyright 2019 The gRPC Authors
//
Expand Down Expand Up @@ -121,13 +121,34 @@ internal static bool ShouldSkipHeader(string name)
/* round an integer up to the next value with three significant figures */
private static long TimeoutRoundUpToThreeSignificantFigures(long x)
{
if (x < 1000) return x;
if (x < 10000) return RoundUp(x, 10);
if (x < 100000) return RoundUp(x, 100);
if (x < 1000000) return RoundUp(x, 1000);
if (x < 10000000) return RoundUp(x, 10000);
if (x < 100000000) return RoundUp(x, 100000);
if (x < 1000000000) return RoundUp(x, 1000000);
if (x < 1000)
{
return x;
}
if (x < 10000)
{
return RoundUp(x, 10);
}
if (x < 100000)
{
return RoundUp(x, 100);
}
if (x < 1000000)
{
return RoundUp(x, 1000);
}
if (x < 10000000)
{
return RoundUp(x, 10000);
}
if (x < 100000000)
{
return RoundUp(x, 100000);
}
if (x < 1000000000)
{
return RoundUp(x, 1000000);
}
return RoundUp(x, 10000000);

static long RoundUp(long x, long divisor)
Expand Down Expand Up @@ -235,7 +256,7 @@ internal static bool CanWriteCompressed(WriteOptions? writeOptions)
return canCompress;
}

internal static AuthInterceptorContext CreateAuthInterceptorContext(Uri baseAddress, IMethod method)
internal static AuthInterceptorContext CreateAuthInterceptorContext(Uri baseAddress, IMethod method, CancellationToken cancellationToken)
{
var authority = baseAddress.Authority;
if (baseAddress.Scheme == Uri.UriSchemeHttps && authority.EndsWith(":443", StringComparison.Ordinal))
Expand All @@ -252,38 +273,39 @@ internal static AuthInterceptorContext CreateAuthInterceptorContext(Uri baseAddr
serviceUrl += "/";
}
serviceUrl += method.ServiceName;
return new AuthInterceptorContext(serviceUrl, method.Name);
return new AuthInterceptorContext(serviceUrl, method.Name, cancellationToken);
}

internal static async Task ReadCredentialMetadata(
DefaultCallCredentialsConfigurator configurator,
GrpcChannel channel,
HttpRequestMessage message,
IMethod method,
CallCredentials credentials)
CallCredentials credentials,
CancellationToken cancellationToken)
{
credentials.InternalPopulateConfiguration(configurator, null);

if (configurator.Interceptor != null)
{
var authInterceptorContext = GrpcProtocolHelpers.CreateAuthInterceptorContext(channel.Address, method);
configurator.CachedContext ??= CreateAuthInterceptorContext(channel.Address, method, cancellationToken);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looked scary at first, but the call sites don't reuse the DefaultCallCredentialsConfigurator for different channel + method + cancellation combos.

Can we make that more obvious somehow so this code doesn't look scary?

e.g.

ReadCredentialMetadata(...)
{
    ReadCredentialMetadataCore(...);
    configurator.CachedContext = null;
}

// Rename current method
ReadCredentialMetadataInner(...)
{
}

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've updated comments and renamed some fields to make it clearer what is going on.

var metadata = new Metadata();
await configurator.Interceptor(authInterceptorContext, metadata).ConfigureAwait(false);
await configurator.Interceptor(configurator.CachedContext, metadata).ConfigureAwait(false);

foreach (var entry in metadata)
{
AddHeader(message.Headers, entry);
}
}

if (configurator.Credentials != null)
if (configurator.CompositeCredentials != null)
{
// Copy credentials locally. ReadCredentialMetadata will update it.
var callCredentials = configurator.Credentials;
foreach (var c in callCredentials)
var compositeCredentials = configurator.CompositeCredentials;
foreach (var callCredentials in compositeCredentials)
{
configurator.Reset();
await ReadCredentialMetadata(configurator, channel, message, method, c).ConfigureAwait(false);
await ReadCredentialMetadata(configurator, channel, message, method, callCredentials, cancellationToken).ConfigureAwait(false);
}
}
}
Expand Down
55 changes: 52 additions & 3 deletions test/Grpc.Net.Client.Tests/CallCredentialTests.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#region Copyright notice and license
#region Copyright notice and license

// Copyright 2019 The gRPC Authors
//
Expand All @@ -18,6 +18,7 @@

using System.Net;
using System.Net.Http.Headers;
using System.Threading;
using Greet;
using Grpc.Core;
using Grpc.Net.Client.Tests.Infrastructure;
Expand Down Expand Up @@ -79,19 +80,67 @@ public async Task CallCredentialsWithHttps_MetadataOnRequest()
var invoker = HttpClientCallInvokerFactory.Create(httpClient);

// Act
var tcs = new TaskCompletionSource<object?>(TaskCreationOptions.RunContinuationsAsynchronously);
var callCredentials = CallCredentials.FromInterceptor(async (context, metadata) =>
{
// The operation is asynchronous to ensure delegate is awaited
await Task.Delay(50);

// Ensure task hasn't been completed.
Assert.False(tcs.Task.IsCompleted);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a bit odd. I think in theory the compiler or synccontext could choose to run this asynchronously before it gets here in which case it would be a race.

Generally you would instead use 2 tcs's to ensure order.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the test was fine. The tcs couldn't be completed until after the auth interceptor awaits. The assert was to verify that behavior.

But I agree the test looks a little weird. And it's brittle. I've switched it to use a sync point.

// Wait for TCS to be completed.
await tcs.Task;

// Set header.
metadata.Add("authorization", "SECRET_TOKEN");
});
var call = invoker.AsyncUnaryCall<HelloRequest, HelloReply>(ClientTestHelpers.ServiceMethod, string.Empty, new CallOptions(credentials: callCredentials), new HelloRequest());
await call.ResponseAsync.DefaultTimeout();
var responseTask = call.ResponseAsync;

tcs.SetResult(null);
await responseTask.DefaultTimeout();

// Assert
Assert.AreEqual("SECRET_TOKEN", authorizationValue);
}

[Test]
public async Task CallCredentialsWithHttps_CancellationToken()
{
// Arrange
string? authorizationValue = null;
var httpClient = ClientTestHelpers.CreateTestClient(async request =>
{
authorizationValue = request.Headers.GetValues("authorization").Single();

var reply = new HelloReply { Message = "Hello world" };
var streamContent = await ClientTestHelpers.CreateResponseContent(reply).DefaultTimeout();
return ResponseUtils.CreateResponse(HttpStatusCode.OK, streamContent);
});
var invoker = HttpClientCallInvokerFactory.Create(httpClient);

// Act
var unreachableAuthInterceptorSection = false;
var callCredentials = CallCredentials.FromInterceptor(async (context, metadata) =>
{
var tcs = new TaskCompletionSource<object?>(TaskCreationOptions.RunContinuationsAsynchronously);
context.CancellationToken.Register(s => ((TaskCompletionSource<object?>)s!).SetCanceled(), tcs);

await tcs.Task;

unreachableAuthInterceptorSection = true;
});
var call = invoker.AsyncUnaryCall<HelloRequest, HelloReply>(ClientTestHelpers.ServiceMethod, string.Empty, new CallOptions(credentials: callCredentials), new HelloRequest());
var responseTask = call.ResponseAsync;

call.Dispose();

var ex = await ExceptionAssert.ThrowsAsync<RpcException>(() => responseTask).DefaultTimeout();
Assert.AreEqual(StatusCode.Cancelled, ex.StatusCode);

// Assert
Assert.False(unreachableAuthInterceptorSection);
}

[Test]
public async Task CallCredentialsWithHttp_NoMetadataOnRequest()
{
Expand Down