Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Switching to just LINQ for grabbing all the actions
  • Loading branch information
jbogard committed Jul 29, 2022
commit d9a8d1cbaad3d80808a1afb552af76f842a8995e
2 changes: 1 addition & 1 deletion src/MediatR.Contracts/MediatR.Contracts.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
<Authors>Jimmy Bogard</Authors>
<Description>Contracts package for requests, responses, and notifications</Description>
<Copyright>Copyright Jimmy Bogard</Copyright>
<TargetFrameworks>netstandard2.0;net461;</TargetFrameworks>
<TargetFramework>netstandard2.0</TargetFramework>
<Nullable>enable</Nullable>
<Features>strict</Features>
<PackageTags>mediator;request;response;queries;commands;notifications</PackageTags>
Expand Down
70 changes: 41 additions & 29 deletions src/MediatR/Pipeline/RequestExceptionActionProcessorBehavior.cs
Original file line number Diff line number Diff line change
Expand Up @@ -31,50 +31,62 @@ public async Task<TResponse> Handle(TRequest request, CancellationToken cancella
}
catch (Exception exception)
{
var attemptedActionTypes = new List<Type>();
var exceptionTypes = GetExceptionTypes(exception.GetType());

for (Type exceptionType = exception.GetType(); exceptionType != typeof(object); exceptionType = exceptionType.BaseType)
{
var actionsForException = GetActionsForException(exceptionType, request, out MethodInfo actionMethod);
var actionsForException = exceptionTypes
.SelectMany(exceptionType => GetActionsForException(exceptionType, request))
.GroupBy(actionForException => actionForException.Action.GetType())
.Select(actionForException => actionForException.First())
.Select(actionForException => (MethodInfo: GetMethodInfoForAction(actionForException.ExceptionType), actionForException.Action))
.ToList();

foreach (var actionForException in actionsForException)
foreach (var actionForException in actionsForException)
{
try
{
var actionType = actionForException.GetType();
if (attemptedActionTypes.Contains(actionType))
{
continue;
}
else
{
attemptedActionTypes.Add(actionType);
}

try
{
await ((Task)(actionMethod.Invoke(actionForException, new object[] { request, exception, cancellationToken })
?? throw new InvalidOperationException($"Could not create task for action method {actionMethod}."))).ConfigureAwait(false);
}
catch (TargetInvocationException invocationException) when (invocationException.InnerException != null)
{
// Unwrap invocation exception to throw the actual error
ExceptionDispatchInfo.Capture(invocationException.InnerException).Throw();
}
await ((Task)(actionForException.MethodInfo.Invoke(actionForException.Action, new object[] { request, exception, cancellationToken })
?? throw new InvalidOperationException($"Could not create task for action method {actionForException.MethodInfo}."))).ConfigureAwait(false);
}
catch (TargetInvocationException invocationException) when (invocationException.InnerException != null)
{
// Unwrap invocation exception to throw the actual error
ExceptionDispatchInfo.Capture(invocationException.InnerException).Throw();
}
}

throw;
}
}

private IList<object> GetActionsForException(Type exceptionType, TRequest request, out MethodInfo actionMethodInfo)
private static IEnumerable<Type> GetExceptionTypes(Type? exceptionType)
{
while (exceptionType != null && exceptionType != typeof(object))
{
yield return exceptionType;
exceptionType = exceptionType.BaseType;
}
}

private IEnumerable<(Type ExceptionType, object Action)> GetActionsForException(Type exceptionType, TRequest request)
{
var exceptionActionInterfaceType = typeof(IRequestExceptionAction<,>).MakeGenericType(typeof(TRequest), exceptionType);
var enumerableExceptionActionInterfaceType = typeof(IEnumerable<>).MakeGenericType(exceptionActionInterfaceType);
actionMethodInfo = exceptionActionInterfaceType.GetMethod(nameof(IRequestExceptionAction<TRequest, Exception>.Execute))
?? throw new InvalidOperationException($"Could not find method {nameof(IRequestExceptionAction<TRequest, Exception>.Execute)} on type {exceptionActionInterfaceType}");

var actionsForException = (IEnumerable<object>)_serviceFactory(enumerableExceptionActionInterfaceType);

return HandlersOrderer.Prioritize(actionsForException.ToList(), request);
return HandlersOrderer.Prioritize(actionsForException.ToList(), request)
.Select(action => (exceptionType, action));
}

private static MethodInfo GetMethodInfoForAction(Type exceptionType)
{
var exceptionActionInterfaceType = typeof(IRequestExceptionAction<,>).MakeGenericType(typeof(TRequest), exceptionType);

var actionMethodInfo =
exceptionActionInterfaceType.GetMethod(nameof(IRequestExceptionAction<TRequest, Exception>.Execute))
?? throw new InvalidOperationException(
$"Could not find method {nameof(IRequestExceptionAction<TRequest, Exception>.Execute)} on type {exceptionActionInterfaceType}");

return actionMethodInfo;
}
}