Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
V2
  • Loading branch information
JamesNK committed May 26, 2025
commit cb4abc2a26fa002e4807888d32e3fe8fcb6dbbc9
17 changes: 14 additions & 3 deletions src/Grpc.AspNetCore.Server/GrpcEndpointRouteBuilderExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ namespace Microsoft.AspNetCore.Builder;
/// </summary>
public static class GrpcEndpointRouteBuilderExtensions
{
private sealed class ServerServiceDefinitionDummyService
{
}

/// <summary>
/// Maps incoming requests to the specified <typeparamref name="TService"/> type.
/// </summary>
Expand All @@ -44,7 +48,7 @@ public static class GrpcEndpointRouteBuilderExtensions
ValidateServicesRegistered(builder.ServiceProvider);

var serviceRouteBuilder = builder.ServiceProvider.GetRequiredService<ServiceRouteBuilder<TService>>();
var endpointConventionBuilders = serviceRouteBuilder.Build(builder);
var endpointConventionBuilders = serviceRouteBuilder.Build(builder, argument: null);

return new GrpcServiceEndpointConventionBuilder(endpointConventionBuilders);
}
Expand All @@ -61,7 +65,10 @@ public static GrpcServiceEndpointConventionBuilder MapGrpcService(this IEndpoint
ArgumentNullException.ThrowIfNull(builder, nameof(builder));
ArgumentNullException.ThrowIfNull(serviceDefinition, nameof(serviceDefinition));

var serviceRouteBuilder = builder.ServiceProvider.GetRequiredService<ServiceRouteBuilder>();
var serviceMethodsRegistry = builder.ServiceProvider.GetRequiredService<ServiceMethodsRegistry>();
serviceMethodsRegistry.ServiceDefinitions.Add(serviceDefinition);

var serviceRouteBuilder = builder.ServiceProvider.GetRequiredService<ServiceRouteBuilder<ServerServiceDefinitionDummyService>>();
var endpointConventionBuilders = serviceRouteBuilder.Build(builder, serviceDefinition);

return new GrpcServiceEndpointConventionBuilder(endpointConventionBuilders);
Expand All @@ -80,7 +87,11 @@ public static GrpcServiceEndpointConventionBuilder MapGrpcService(this IEndpoint
ArgumentNullException.ThrowIfNull(getServiceDefinition, nameof(getServiceDefinition));

var serviceDefinition = getServiceDefinition(builder.ServiceProvider);
var serviceRouteBuilder = builder.ServiceProvider.GetRequiredService<ServiceRouteBuilder>();

var serviceMethodsRegistry = builder.ServiceProvider.GetRequiredService<ServiceMethodsRegistry>();
serviceMethodsRegistry.ServiceDefinitions.Add(serviceDefinition);

var serviceRouteBuilder = builder.ServiceProvider.GetRequiredService<ServiceRouteBuilder<ServerServiceDefinitionDummyService>>();
var endpointConventionBuilders = serviceRouteBuilder.Build(builder, serviceDefinition);

return new GrpcServiceEndpointConventionBuilder(endpointConventionBuilders);
Expand Down
3 changes: 1 addition & 2 deletions src/Grpc.AspNetCore.Server/GrpcServiceExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -68,17 +68,16 @@ public static IGrpcServerBuilder AddGrpc(this IServiceCollection services)
#endif
services.AddOptions();
services.TryAddSingleton<GrpcMarkerService>();
services.TryAddSingleton(typeof(ServerCallHandlerFactory));
services.TryAddSingleton(typeof(ServerCallHandlerFactory<>));
services.TryAddSingleton(typeof(IGrpcServiceActivator<>), typeof(DefaultGrpcServiceActivator<>));
services.TryAddSingleton(typeof(IGrpcInterceptorActivator<>), typeof(DefaultGrpcInterceptorActivator<>));
services.TryAddEnumerable(ServiceDescriptor.Singleton<IConfigureOptions<GrpcServiceOptions>, GrpcServiceOptionsSetup>());

// Model
services.TryAddSingleton<ServiceMethodsRegistry>();
services.TryAddSingleton(typeof(ServiceRouteBuilder));
services.TryAddSingleton(typeof(ServiceRouteBuilder<>));
services.TryAddEnumerable(ServiceDescriptor.Singleton(typeof(IServiceMethodProvider<>), typeof(BinderServiceMethodProvider<>)));
services.TryAddEnumerable(ServiceDescriptor.Singleton(typeof(IServiceMethodProvider<>), typeof(ServiceDefinitionMethodProvider<>)));

return new GrpcServerBuilder(services);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,56 +24,6 @@

namespace Grpc.AspNetCore.Server.Internal.CallHandlers;

internal sealed class ClientStreamingServerCallHandler<TRequest, TResponse> : ServerCallHandlerBase<TRequest, TResponse>
where TRequest : class
where TResponse : class
{
private readonly ClientStreamingServerMethodInvoker<TRequest, TResponse> _invoker;

public ClientStreamingServerCallHandler(
ClientStreamingServerMethodInvoker<TRequest, TResponse> invoker,
ILoggerFactory loggerFactory)
: base(invoker, loggerFactory)
{
_invoker = invoker;
}

protected override async Task HandleCallAsyncCore(HttpContext httpContext, HttpContextServerCallContext serverCallContext)
{
// Disable request body data rate for client streaming
DisableMinRequestBodyDataRateAndMaxRequestBodySize(httpContext);

TResponse? response;

var streamReader = new HttpContextStreamReader<TRequest>(serverCallContext, MethodInvoker.Method.RequestMarshaller.ContextualDeserializer);
try
{
response = await _invoker.Invoke(httpContext, serverCallContext, streamReader);
}
finally
{
streamReader.Complete();
}

if (response == null)
{
// This is consistent with Grpc.Core when a null value is returned
throw new RpcException(new Status(StatusCode.Cancelled, "No message returned from method."));
}

// Check if deadline exceeded while method was invoked. If it has then skip trying to write
// the response message because it will always fail.
// Note that the call is still going so the deadline could still be exceeded after this point.
if (serverCallContext.DeadlineManager?.IsDeadlineExceededStarted ?? false)
{
return;
}

var responseBodyWriter = httpContext.Response.BodyWriter;
await responseBodyWriter.WriteSingleMessageAsync(response, serverCallContext, MethodInvoker.Method.ResponseMarshaller.ContextualSerializer);
}
}

internal sealed class ClientStreamingServerCallHandler<[DynamicallyAccessedMembers(GrpcProtocolConstants.ServiceAccessibility)] TService, TRequest, TResponse> : ServerCallHandlerBase<TService, TRequest, TResponse>
where TRequest : class
where TResponse : class
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,39 +23,6 @@

namespace Grpc.AspNetCore.Server.Internal.CallHandlers;

internal sealed class DuplexStreamingServerCallHandler<TRequest, TResponse> : ServerCallHandlerBase<TRequest, TResponse>
where TRequest : class
where TResponse : class
{
private readonly DuplexStreamingServerMethodInvoker<TRequest, TResponse> _invoker;

public DuplexStreamingServerCallHandler(
DuplexStreamingServerMethodInvoker<TRequest, TResponse> invoker,
ILoggerFactory loggerFactory)
: base(invoker, loggerFactory)
{
_invoker = invoker;
}

protected override async Task HandleCallAsyncCore(HttpContext httpContext, HttpContextServerCallContext serverCallContext)
{
// Disable request body data rate for client streaming
DisableMinRequestBodyDataRateAndMaxRequestBodySize(httpContext);

var streamReader = new HttpContextStreamReader<TRequest>(serverCallContext, MethodInvoker.Method.RequestMarshaller.ContextualDeserializer);
var streamWriter = new HttpContextStreamWriter<TResponse>(serverCallContext, MethodInvoker.Method.ResponseMarshaller.ContextualSerializer);
try
{
await _invoker.Invoke(httpContext, serverCallContext, streamReader, streamWriter);
}
finally
{
streamReader.Complete();
streamWriter.Complete();
}
}
}

internal sealed class DuplexStreamingServerCallHandler<[DynamicallyAccessedMembers(GrpcProtocolConstants.ServiceAccessibility)] TService, TRequest, TResponse> : ServerCallHandlerBase<TService, TRequest, TResponse>
where TRequest : class
where TResponse : class
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,139 +30,6 @@

namespace Grpc.AspNetCore.Server.Internal.CallHandlers;

internal abstract class ServerCallHandlerBase<TRequest, TResponse>
where TRequest : class
where TResponse : class
{
private const string LoggerName = "Grpc.AspNetCore.Server.ServerCallHandler";

protected ServerMethodInvokerBase<TRequest, TResponse> MethodInvoker { get; }
protected ILogger Logger { get; }

protected ServerCallHandlerBase(
ServerMethodInvokerBase<TRequest, TResponse> methodInvoker,
ILoggerFactory loggerFactory)
{
MethodInvoker = methodInvoker;
Logger = loggerFactory.CreateLogger(LoggerName);
}

public Task HandleCallAsync(HttpContext httpContext)
{
if (GrpcProtocolHelpers.IsInvalidContentType(httpContext, out var error))
{
return ProcessInvalidContentTypeRequest(httpContext, error);
}

if (!GrpcProtocolConstants.IsHttp2(httpContext.Request.Protocol)
#if NET6_0_OR_GREATER
&& !GrpcProtocolConstants.IsHttp3(httpContext.Request.Protocol)
#endif
)
{
return ProcessNonHttp2Request(httpContext);
}

var serverCallContext = new HttpContextServerCallContext(httpContext, MethodInvoker.Options, typeof(TRequest), typeof(TResponse), Logger);
httpContext.Features.Set<IServerCallContextFeature>(serverCallContext);

GrpcProtocolHelpers.AddProtocolHeaders(httpContext.Response);

try
{
serverCallContext.Initialize();

var handleCallTask = HandleCallAsyncCore(httpContext, serverCallContext);

if (handleCallTask.IsCompletedSuccessfully)
{
return serverCallContext.EndCallAsync();
}
else
{
return AwaitHandleCall(serverCallContext, MethodInvoker.Method, handleCallTask);
}
}
catch (Exception ex)
{
return serverCallContext.ProcessHandlerErrorAsync(ex, MethodInvoker.Method.Name);
}

static async Task AwaitHandleCall(HttpContextServerCallContext serverCallContext, Method<TRequest, TResponse> method, Task handleCall)
{
try
{
await handleCall;
await serverCallContext.EndCallAsync();
}
catch (Exception ex)
{
await serverCallContext.ProcessHandlerErrorAsync(ex, method.Name);
}
}
}

protected abstract Task HandleCallAsyncCore(HttpContext httpContext, HttpContextServerCallContext serverCallContext);

/// <summary>
/// This should only be called from client streaming calls
/// </summary>
/// <param name="httpContext"></param>
protected void DisableMinRequestBodyDataRateAndMaxRequestBodySize(HttpContext httpContext)
{
var minRequestBodyDataRateFeature = httpContext.Features.Get<IHttpMinRequestBodyDataRateFeature>();
if (minRequestBodyDataRateFeature != null)
{
minRequestBodyDataRateFeature.MinDataRate = null;
}

var maxRequestBodySizeFeature = httpContext.Features.Get<IHttpMaxRequestBodySizeFeature>();
if (maxRequestBodySizeFeature != null)
{
if (!maxRequestBodySizeFeature.IsReadOnly)
{
maxRequestBodySizeFeature.MaxRequestBodySize = null;
}
else
{
// IsReadOnly could be true if middleware has already started reading the request body
// In that case we can't disable the max request body size for the request stream
GrpcServerLog.UnableToDisableMaxRequestBodySize(Logger);
}
}
}

private Task ProcessNonHttp2Request(HttpContext httpContext)
{
GrpcServerLog.UnsupportedRequestProtocol(Logger, httpContext.Request.Protocol);

var protocolError = $"Request protocol '{httpContext.Request.Protocol}' is not supported.";
GrpcProtocolHelpers.BuildHttpErrorResponse(httpContext.Response, StatusCodes.Status426UpgradeRequired, StatusCode.Internal, protocolError);
httpContext.Response.Headers[HeaderNames.Upgrade] = GrpcProtocolConstants.Http2Protocol;
return Task.CompletedTask;
}

private Task ProcessInvalidContentTypeRequest(HttpContext httpContext, string error)
{
// This might be a CORS preflight request and CORS middleware hasn't been configured
if (GrpcProtocolHelpers.IsCorsPreflightRequest(httpContext))
{
GrpcServerLog.UnhandledCorsPreflightRequest(Logger);

GrpcProtocolHelpers.BuildHttpErrorResponse(httpContext.Response, StatusCodes.Status405MethodNotAllowed, StatusCode.Internal, "Unhandled CORS preflight request received. CORS may not be configured correctly in the application.");
httpContext.Response.Headers[HeaderNames.Allow] = HttpMethods.Post;
return Task.CompletedTask;
}
else
{
GrpcServerLog.UnsupportedRequestContentType(Logger, httpContext.Request.ContentType);

GrpcProtocolHelpers.BuildHttpErrorResponse(httpContext.Response, StatusCodes.Status415UnsupportedMediaType, StatusCode.Internal, error);
return Task.CompletedTask;
}
}
}

internal abstract class ServerCallHandlerBase<[DynamicallyAccessedMembers(GrpcProtocolConstants.ServiceAccessibility)]TService, TRequest, TResponse>
where TService : class
where TRequest : class
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,37 +23,6 @@

namespace Grpc.AspNetCore.Server.Internal.CallHandlers;

internal sealed class ServerStreamingServerCallHandler<TRequest, TResponse> : ServerCallHandlerBase<TRequest, TResponse>
where TRequest : class
where TResponse : class
{
private readonly ServerStreamingServerMethodInvoker<TRequest, TResponse> _invoker;

public ServerStreamingServerCallHandler(
ServerStreamingServerMethodInvoker<TRequest, TResponse> invoker,
ILoggerFactory loggerFactory)
: base(invoker, loggerFactory)
{
_invoker = invoker;
}

protected override async Task HandleCallAsyncCore(HttpContext httpContext, HttpContextServerCallContext serverCallContext)
{
// Decode request
var request = await httpContext.Request.BodyReader.ReadSingleMessageAsync<TRequest>(serverCallContext, MethodInvoker.Method.RequestMarshaller.ContextualDeserializer);

var streamWriter = new HttpContextStreamWriter<TResponse>(serverCallContext, MethodInvoker.Method.ResponseMarshaller.ContextualSerializer);
try
{
await _invoker.Invoke(httpContext, serverCallContext, request, streamWriter);
}
finally
{
streamWriter.Complete();
}
}
}

internal sealed class ServerStreamingServerCallHandler<[DynamicallyAccessedMembers(GrpcProtocolConstants.ServiceAccessibility)] TService, TRequest, TResponse> : ServerCallHandlerBase<TService, TRequest, TResponse>
where TRequest : class
where TResponse : class
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,45 +24,6 @@

namespace Grpc.AspNetCore.Server.Internal.CallHandlers;

internal sealed class UnaryServerCallHandler<TRequest, TResponse> : ServerCallHandlerBase<TRequest, TResponse>
where TRequest : class
where TResponse : class
{
private readonly UnaryServerMethodInvoker<TRequest, TResponse> _invoker;

public UnaryServerCallHandler(
UnaryServerMethodInvoker<TRequest, TResponse> invoker,
ILoggerFactory loggerFactory)
: base(invoker, loggerFactory)
{
_invoker = invoker;
}

protected override async Task HandleCallAsyncCore(HttpContext httpContext, HttpContextServerCallContext serverCallContext)
{
var request = await httpContext.Request.BodyReader.ReadSingleMessageAsync<TRequest>(serverCallContext, MethodInvoker.Method.RequestMarshaller.ContextualDeserializer);

var response = await _invoker.Invoke(httpContext, serverCallContext, request);

if (response == null)
{
// This is consistent with Grpc.Core when a null value is returned
throw new RpcException(new Status(StatusCode.Cancelled, "No message returned from method."));
}

// Check if deadline exceeded while method was invoked. If it has then skip trying to write
// the response message because it will always fail.
// Note that the call is still going so the deadline could still be exceeded after this point.
if (serverCallContext.DeadlineManager?.IsDeadlineExceededStarted ?? false)
{
return;
}

var responseBodyWriter = httpContext.Response.BodyWriter;
await responseBodyWriter.WriteSingleMessageAsync(response, serverCallContext, MethodInvoker.Method.ResponseMarshaller.ContextualSerializer);
}
}

internal sealed class UnaryServerCallHandler<[DynamicallyAccessedMembers(GrpcProtocolConstants.ServiceAccessibility)] TService, TRequest, TResponse> : ServerCallHandlerBase<TService, TRequest, TResponse>
where TRequest : class
where TResponse : class
Expand Down
Loading