Skip to content
Prev Previous commit
Next Next commit
Fix hook timeout implementation for PR #2891
- Fixed InstanceHookMethod to use InitClassType property instead of trying to override ClassType with init accessor
- Updated source generators to generate InitClassType instead of ClassType for InstanceHookMethod
- Fixed async/await issues in HookCollectionService by restructuring methods to avoid async lambdas
- Added proper DynamicallyAccessedMembers annotations for trimming compatibility
- Fixed EventReceiverOrchestrator to use HookMethod property from HookRegisteredContext
  • Loading branch information
thomhurst committed Aug 13, 2025
commit d0c57d3f591a85ae6f2ee1adf8a441b7d97575cb
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ public static void Execute(ICodeWriter sourceBuilder, HooksDataModel model)

sourceBuilder.Append("new global::TUnit.Core.Hooks.InstanceHookMethod");
sourceBuilder.Append("{");
sourceBuilder.Append($"ClassType = typeof({model.FullyQualifiedTypeName}),");
sourceBuilder.Append($"InitClassType = typeof({model.FullyQualifiedTypeName}),");
sourceBuilder.Append("MethodInfo = ");
SourceInformationWriter.GenerateMethodInformation(sourceBuilder, model.Context.SemanticModel.Compilation, model.ClassType, model.Method, null, ',');

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -803,7 +803,7 @@ private static void GenerateHookObject(CodeWriter writer, HookMethodMetadata hoo

if (isInstance)
{
writer.AppendLine($"ClassType = typeof({hook.TypeSymbol.GloballyQualified()}),");
writer.AppendLine($"InitClassType = typeof({hook.TypeSymbol.GloballyQualified()}),");
}

writer.Append("MethodInfo = ");
Expand Down
1 change: 1 addition & 0 deletions TUnit.Core/Hooks/HookMethod.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ public abstract record HookMethod
[field: AllowNull, MaybeNull]
public string Name => field ??= $"{ClassType.Name}.{MethodInfo.Name}({string.Join(", ", MethodInfo.Parameters.Select(x => x.Name))})";

[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicMethods)]
public abstract Type ClassType { get; }
public virtual Assembly? Assembly => ClassType?.Assembly;

Expand Down
11 changes: 10 additions & 1 deletion TUnit.Core/Hooks/InstanceHookMethod.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,16 @@ namespace TUnit.Core.Hooks;
public record InstanceHookMethod : HookMethod, IExecutableHook<TestContext>
{
[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicMethods)]
public override Type ClassType { get; init; }
private readonly Type _classType = null!;

[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicMethods)]
public override Type ClassType => _classType;

public required Type InitClassType
{
[param: DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicMethods)]
init { _classType = value; }
}

public Func<object, TestContext, CancellationToken, ValueTask>? Body { get; init; }

Expand Down
1 change: 1 addition & 0 deletions TUnit.Core/Hooks/StaticHookMethod.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ public abstract record StaticHookMethod<T> : StaticHookMethod, IExecutableHook<T
#endif
public abstract record StaticHookMethod : HookMethod
{
[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicMethods)]
public override Type ClassType => MethodInfo.Class.Type;

public required string FilePath { get; init; }
Expand Down
11 changes: 2 additions & 9 deletions TUnit.Engine/Services/EventReceiverOrchestrator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -221,16 +221,9 @@ public async ValueTask InvokeHookRegistrationEventReceiversAsync(HookRegisteredC
}

// Apply the timeout from the context back to the hook method
if (hookContext.Timeout.HasValue)
if (hookContext.Timeout.HasValue && hookContext.HookMethod != null)
{
if (hookContext.StaticHookMethod != null)
{
hookContext.StaticHookMethod.Timeout = hookContext.Timeout;
}
else if (hookContext.InstanceHookMethod != null)
{
hookContext.InstanceHookMethod.Timeout = hookContext.Timeout;
}
hookContext.HookMethod.Timeout = hookContext.Timeout;
}
}

Expand Down
68 changes: 48 additions & 20 deletions TUnit.Engine/Services/HookCollectionService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,19 @@ private sealed class CompleteHookChain
];
}

public ValueTask<IReadOnlyList<Func<TestContext, CancellationToken, Task>>> CollectBeforeTestHooksAsync(Type testClassType)
public async ValueTask<IReadOnlyList<Func<TestContext, CancellationToken, Task>>> CollectBeforeTestHooksAsync(Type testClassType)
{
var hooks = _beforeTestHooksCache.GetOrAdd(testClassType, type =>
if (_beforeTestHooksCache.TryGetValue(testClassType, out var cachedHooks))
{
return cachedHooks;
}

var hooks = await BuildBeforeTestHooksAsync(testClassType);
_beforeTestHooksCache.TryAdd(testClassType, hooks);
return hooks;
}

private async Task<IReadOnlyList<Func<TestContext, CancellationToken, Task>>> BuildBeforeTestHooksAsync(Type type)
{
var hooksByType = new List<(Type type, List<(int order, int registrationIndex, Func<TestContext, CancellationToken, Task> hook)> hooks)>();

Expand Down Expand Up @@ -120,14 +130,21 @@ public ValueTask<IReadOnlyList<Func<TestContext, CancellationToken, Task>>> Coll
}

return finalHooks;
});

return new ValueTask<IReadOnlyList<Func<TestContext, CancellationToken, Task>>>(hooks);
}

public ValueTask<IReadOnlyList<Func<TestContext, CancellationToken, Task>>> CollectAfterTestHooksAsync(Type testClassType)
public async ValueTask<IReadOnlyList<Func<TestContext, CancellationToken, Task>>> CollectAfterTestHooksAsync(Type testClassType)
{
var hooks = _afterTestHooksCache.GetOrAdd(testClassType, type =>
if (_afterTestHooksCache.TryGetValue(testClassType, out var cachedHooks))
{
return cachedHooks;
}

var hooks = await BuildAfterTestHooksAsync(testClassType);
_afterTestHooksCache.TryAdd(testClassType, hooks);
return hooks;
}

private async Task<IReadOnlyList<Func<TestContext, CancellationToken, Task>>> BuildAfterTestHooksAsync(Type type)
{
var hooksByType = new List<(Type type, List<(int order, int registrationIndex, Func<TestContext, CancellationToken, Task> hook)> hooks)>();

Expand Down Expand Up @@ -179,14 +196,21 @@ public ValueTask<IReadOnlyList<Func<TestContext, CancellationToken, Task>>> Coll
}

return finalHooks;
});

return new ValueTask<IReadOnlyList<Func<TestContext, CancellationToken, Task>>>(hooks);
}

public ValueTask<IReadOnlyList<Func<TestContext, CancellationToken, Task>>> CollectBeforeEveryTestHooksAsync(Type testClassType)
public async ValueTask<IReadOnlyList<Func<TestContext, CancellationToken, Task>>> CollectBeforeEveryTestHooksAsync(Type testClassType)
{
var hooks = _beforeEveryTestHooksCache.GetOrAdd(testClassType, type =>
if (_beforeEveryTestHooksCache.TryGetValue(testClassType, out var cachedHooks))
{
return cachedHooks;
}

var hooks = await BuildBeforeEveryTestHooksAsync(testClassType);
_beforeEveryTestHooksCache.TryAdd(testClassType, hooks);
return hooks;
}

private async Task<IReadOnlyList<Func<TestContext, CancellationToken, Task>>> BuildBeforeEveryTestHooksAsync(Type type)
{
var allHooks = new List<(int order, int registrationIndex, Func<TestContext, CancellationToken, Task> hook)>();

Expand All @@ -202,14 +226,21 @@ public ValueTask<IReadOnlyList<Func<TestContext, CancellationToken, Task>>> Coll
.ThenBy(h => h.registrationIndex)
.Select(h => h.hook)
.ToList();
});

return new ValueTask<IReadOnlyList<Func<TestContext, CancellationToken, Task>>>(hooks);
}

public ValueTask<IReadOnlyList<Func<TestContext, CancellationToken, Task>>> CollectAfterEveryTestHooksAsync(Type testClassType)
public async ValueTask<IReadOnlyList<Func<TestContext, CancellationToken, Task>>> CollectAfterEveryTestHooksAsync(Type testClassType)
{
var hooks = _afterEveryTestHooksCache.GetOrAdd(testClassType, type =>
if (_afterEveryTestHooksCache.TryGetValue(testClassType, out var cachedHooks))
{
return cachedHooks;
}

var hooks = await BuildAfterEveryTestHooksAsync(testClassType);
_afterEveryTestHooksCache.TryAdd(testClassType, hooks);
return hooks;
}

private async Task<IReadOnlyList<Func<TestContext, CancellationToken, Task>>> BuildAfterEveryTestHooksAsync(Type type)
{
var allHooks = new List<(int order, int registrationIndex, Func<TestContext, CancellationToken, Task> hook)>();

Expand All @@ -225,9 +256,6 @@ public ValueTask<IReadOnlyList<Func<TestContext, CancellationToken, Task>>> Coll
.ThenBy(h => h.registrationIndex)
.Select(h => h.hook)
.ToList();
});

return new ValueTask<IReadOnlyList<Func<TestContext, CancellationToken, Task>>>(hooks);
}

public ValueTask<IReadOnlyList<Func<ClassHookContext, CancellationToken, Task>>> CollectBeforeClassHooksAsync(Type testClassType)
Expand Down
Loading