diff --git a/TUnit.Core.SourceGenerator/Generators/HookMetadataGenerator.cs b/TUnit.Core.SourceGenerator/Generators/HookMetadataGenerator.cs index 5363990590..99f8af2eba 100644 --- a/TUnit.Core.SourceGenerator/Generators/HookMetadataGenerator.cs +++ b/TUnit.Core.SourceGenerator/Generators/HookMetadataGenerator.cs @@ -1,6 +1,7 @@ using System.Collections.Immutable; using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp.Syntax; +using TUnit.Core.SourceGenerator.CodeGenerators.Helpers; using TUnit.Core.SourceGenerator.CodeGenerators.Writers; using TUnit.Core.SourceGenerator.Extensions; @@ -136,6 +137,7 @@ public int GetHashCode(HookMethodMetadata? obj) var lineNumber = location.GetLineSpan().StartLinePosition.Line + 1; var order = GetHookOrder(hookAttribute); + var hookExecutor = GetHookExecutorType(methodSymbol); return new HookMethodMetadata { @@ -146,7 +148,8 @@ public int GetHashCode(HookMethodMetadata? obj) HookKind = hookKind, HookType = hookType, Order = order, - Context = context + Context = context, + HookExecutor = hookExecutor }; } @@ -237,6 +240,62 @@ private static int GetHookOrder(AttributeData attribute) return 0; } + private static string? GetHookExecutorType(IMethodSymbol methodSymbol) + { + var hookExecutorAttribute = methodSymbol.GetAttributes() + .FirstOrDefault(a => a.AttributeClass?.Name == "HookExecutorAttribute" || + (a.AttributeClass?.IsGenericType == true && + a.AttributeClass?.ConstructedFrom?.Name == "HookExecutorAttribute")); + + if (hookExecutorAttribute == null) + { + return null; + } + + // For generic HookExecutorAttribute, get the type argument + if (hookExecutorAttribute.AttributeClass?.IsGenericType == true) + { + var typeArg = hookExecutorAttribute.AttributeClass.TypeArguments.FirstOrDefault(); + return typeArg?.GloballyQualified(); + } + + // For non-generic HookExecutorAttribute(Type type), get the constructor argument + var typeArgument = hookExecutorAttribute.ConstructorArguments.FirstOrDefault(); + if (typeArgument.Value is ITypeSymbol typeSymbol) + { + return typeSymbol.GloballyQualified(); + } + + return null; + } + + private static string GetConcreteHookType(string dictionaryName, bool isInstance) + { + if (isInstance) + { + return "InstanceHookMethod"; + } + + return dictionaryName switch + { + "BeforeClassHooks" => "BeforeClassHookMethod", + "AfterClassHooks" => "AfterClassHookMethod", + "BeforeAssemblyHooks" => "BeforeAssemblyHookMethod", + "AfterAssemblyHooks" => "AfterAssemblyHookMethod", + "BeforeTestSessionHooks" => "BeforeTestSessionHookMethod", + "AfterTestSessionHooks" => "AfterTestSessionHookMethod", + "BeforeTestDiscoveryHooks" => "BeforeTestDiscoveryHookMethod", + "AfterTestDiscoveryHooks" => "AfterTestDiscoveryHookMethod", + "BeforeEveryTestHooks" => "BeforeTestHookMethod", + "AfterEveryTestHooks" => "AfterTestHookMethod", + "BeforeEveryClassHooks" => "BeforeClassHookMethod", + "AfterEveryClassHooks" => "AfterClassHookMethod", + "BeforeEveryAssemblyHooks" => "BeforeAssemblyHookMethod", + "AfterEveryAssemblyHooks" => "AfterAssemblyHookMethod", + _ => throw new ArgumentException($"Unknown dictionary name: {dictionaryName}") + }; + } + private static void GenerateHookRegistry(SourceProductionContext context, ImmutableArray hooks) { try @@ -663,7 +722,8 @@ private static void GenerateHookDelegate(CodeWriter writer, HookMethodMetadata h private static void GenerateHookListPopulation(CodeWriter writer, string dictionaryName, string typeDisplay, List hooks, bool isInstance) { - writer.AppendLine($"global::TUnit.Core.Sources.{dictionaryName}.GetOrAdd(typeof({typeDisplay}), _ => new global::System.Collections.Concurrent.ConcurrentBag")}>());"); + var hookType = GetConcreteHookType(dictionaryName, isInstance); + writer.AppendLine($"global::TUnit.Core.Sources.{dictionaryName}.GetOrAdd(typeof({typeDisplay}), _ => new global::System.Collections.Concurrent.ConcurrentBag());"); foreach (var hook in hooks.OrderBy(h => h.Order)) { @@ -679,7 +739,8 @@ private static void GenerateHookListPopulation(CodeWriter writer, string diction private static void GenerateAssemblyHookListPopulation(CodeWriter writer, string dictionaryName, string assemblyVarName, List hooks) { var assemblyVar = assemblyVarName.Replace(".", "_") + "_assembly"; - writer.AppendLine($"global::TUnit.Core.Sources.{dictionaryName}.GetOrAdd({assemblyVar}, _ => new global::System.Collections.Concurrent.ConcurrentBag>());"); + var hookType = GetConcreteHookType(dictionaryName, false); + writer.AppendLine($"global::TUnit.Core.Sources.{dictionaryName}.GetOrAdd({assemblyVar}, _ => new global::System.Collections.Concurrent.ConcurrentBag());"); foreach (var hook in hooks.OrderBy(h => h.Order)) { @@ -722,7 +783,7 @@ private static void GenerateHookObject(CodeWriter writer, HookMethodMetadata hoo writer.Append("MethodInfo = "); SourceInformationWriter.GenerateMethodInformation(writer, hook.Context.SemanticModel.Compilation, hook.TypeSymbol, hook.MethodSymbol, null, ','); writer.AppendLine(); - writer.AppendLine("HookExecutor = null!,"); + writer.AppendLine($"HookExecutor = {HookExecutorHelper.GetHookExecutor(hook.HookExecutor)},"); writer.AppendLine($"Order = {hook.Order},"); writer.AppendLine($"Body = {delegateKey}_Body" + (isInstance ? "" : ",")); @@ -835,4 +896,5 @@ public class HookMethodMetadata public required string HookType { get; init; } public required int Order { get; init; } public required GeneratorAttributeSyntaxContext Context { get; init; } + public string? HookExecutor { get; init; } } diff --git a/TUnit.Core/Sources.cs b/TUnit.Core/Sources.cs index ecec1c8800..988c8dd73a 100644 --- a/TUnit.Core/Sources.cs +++ b/TUnit.Core/Sources.cs @@ -15,23 +15,23 @@ public static class Sources public static readonly ConcurrentDictionary> BeforeTestHooks = new(); public static readonly ConcurrentDictionary> AfterTestHooks = new(); - public static readonly ConcurrentBag> BeforeEveryTestHooks = new(); - public static readonly ConcurrentBag> AfterEveryTestHooks = new(); + public static readonly ConcurrentBag BeforeEveryTestHooks = new(); + public static readonly ConcurrentBag AfterEveryTestHooks = new(); - public static readonly ConcurrentDictionary>> BeforeClassHooks = new(); - public static readonly ConcurrentDictionary>> AfterClassHooks = new(); - public static readonly ConcurrentBag> BeforeEveryClassHooks = new(); - public static readonly ConcurrentBag> AfterEveryClassHooks = new(); + public static readonly ConcurrentDictionary> BeforeClassHooks = new(); + public static readonly ConcurrentDictionary> AfterClassHooks = new(); + public static readonly ConcurrentBag BeforeEveryClassHooks = new(); + public static readonly ConcurrentBag AfterEveryClassHooks = new(); - public static readonly ConcurrentDictionary>> BeforeAssemblyHooks = new(); - public static readonly ConcurrentDictionary>> AfterAssemblyHooks = new(); - public static readonly ConcurrentBag> BeforeEveryAssemblyHooks = new(); - public static readonly ConcurrentBag> AfterEveryAssemblyHooks = new(); + public static readonly ConcurrentDictionary> BeforeAssemblyHooks = new(); + public static readonly ConcurrentDictionary> AfterAssemblyHooks = new(); + public static readonly ConcurrentBag BeforeEveryAssemblyHooks = new(); + public static readonly ConcurrentBag AfterEveryAssemblyHooks = new(); - public static readonly ConcurrentBag> BeforeTestSessionHooks = []; - public static readonly ConcurrentBag> AfterTestSessionHooks = []; - public static readonly ConcurrentBag> BeforeTestDiscoveryHooks = []; - public static readonly ConcurrentBag> AfterTestDiscoveryHooks = []; + public static readonly ConcurrentBag BeforeTestSessionHooks = []; + public static readonly ConcurrentBag AfterTestSessionHooks = []; + public static readonly ConcurrentBag BeforeTestDiscoveryHooks = []; + public static readonly ConcurrentBag AfterTestDiscoveryHooks = []; public static readonly ConcurrentQueue> GlobalInitializers = []; public static readonly ConcurrentQueue PropertySources = []; diff --git a/TUnit.Engine/Discovery/ReflectionHookDiscoveryService.cs b/TUnit.Engine/Discovery/ReflectionHookDiscoveryService.cs deleted file mode 100644 index f424b49fc5..0000000000 --- a/TUnit.Engine/Discovery/ReflectionHookDiscoveryService.cs +++ /dev/null @@ -1,460 +0,0 @@ -using System.Collections.Concurrent; -using System.Diagnostics.CodeAnalysis; -using System.Reflection; -using System.Runtime.CompilerServices; -using TUnit.Core; -using TUnit.Core.Hooks; -using TUnit.Core.Interfaces; -using TUnit.Engine.Building; -using TUnit.Engine.Services; - -namespace TUnit.Engine.Discovery; - -/// -/// Discovers hooks from assemblies using reflection and populates the Sources dictionaries -/// -[RequiresUnreferencedCode("Reflection-based hook discovery requires unreferenced code")] -internal sealed class ReflectionHookDiscoveryService -{ - private static readonly HashSet _scannedAssemblies = new(); - private static readonly object _lock = new(); - private readonly IHookExecutor _hookExecutor; - - public ReflectionHookDiscoveryService(IHookExecutor hookExecutor) - { - _hookExecutor = hookExecutor; - } - - public Task DiscoverHooksInAssembliesAsync(IEnumerable assemblies) - { - var assembliesToScan = assemblies.Where(assembly => - { - lock (_lock) - { - return _scannedAssemblies.Add(assembly); - } - }).ToList(); - - foreach (var assembly in assembliesToScan) - { - DiscoverHooksInAssembly(assembly); - } - - return Task.CompletedTask; - } - - [UnconditionalSuppressMessage("Trimming", "IL2026:Using member 'System.Reflection.Assembly.GetExportedTypes()' which has 'RequiresUnreferencedCodeAttribute' can break functionality when trimming application code")] - [UnconditionalSuppressMessage("Trimming", "IL2070:'this' argument does not satisfy 'DynamicallyAccessedMemberTypes.PublicMethods' in call to 'System.Type.GetMethods(BindingFlags)'")] - private void DiscoverHooksInAssembly(Assembly assembly) - { - try - { - // In single file mode, GetExportedTypes might miss some types - // Use GetTypes() instead which gets all types including nested ones - var allTypes = assembly.GetTypes(); - var types = allTypes - .Where(t => t is { IsClass: true, IsAbstract: false } && !IsCompilerGenerated(t)); - - foreach (var type in types) - { - DiscoverHooksInType(type, assembly); - } - - // Discover assembly-level hooks (static methods in any type) - DiscoverAssemblyLevelHooks(assembly, types); - } - catch (Exception ex) - { - Console.WriteLine($"Warning: Failed to discover hooks in assembly {assembly.FullName}: {ex.Message}"); - } - } - - private void DiscoverHooksInType(Type type, Assembly assembly) - { - try - { - var methods = type.GetMethods( - BindingFlags.Public | BindingFlags.NonPublic | - BindingFlags.Instance | BindingFlags.Static | - BindingFlags.DeclaredOnly); - - foreach (var method in methods) - { - // Process Before attributes - var beforeAttrs = method.GetCustomAttributes(inherit: false); - foreach (var attr in beforeAttrs) - { - if (attr.HookType.HasFlag(HookType.TestDiscovery)) - { - AddTestDiscoveryHook(method, attr, isBeforeHook: true); - } - else - { - AddRuntimeHook(type, method, attr, isBeforeHook: true); - } - } - - // Process After attributes - var afterAttrs = method.GetCustomAttributes(inherit: false); - foreach (var attr in afterAttrs) - { - if (attr.HookType.HasFlag(HookType.TestDiscovery)) - { - AddTestDiscoveryHook(method, attr, isBeforeHook: false); - } - else - { - AddRuntimeHook(type, method, attr, isBeforeHook: false); - } - } - } - } - catch (Exception ex) - { - Console.WriteLine($"Warning: Failed to discover hooks in type {type.FullName}: {ex.Message}"); - } - } - - private void DiscoverAssemblyLevelHooks(Assembly assembly, IEnumerable types) - { - foreach (var type in types) - { - try - { - var methods = type.GetMethods( - BindingFlags.Public | BindingFlags.NonPublic | - BindingFlags.Static | BindingFlags.DeclaredOnly); - - foreach (var method in methods) - { - // Check for assembly-level Before hooks - var beforeAttrs = method.GetCustomAttributes() - .Where(a => a.HookType.HasFlag(HookType.Assembly)); - - foreach (var attr in beforeAttrs) - { - if (!attr.HookType.HasFlag(HookType.TestDiscovery)) - { - AddAssemblyHook(assembly, method, attr, isBeforeHook: true); - } - } - - // Check for assembly-level After hooks - var afterAttrs = method.GetCustomAttributes() - .Where(a => a.HookType.HasFlag(HookType.Assembly)); - - foreach (var attr in afterAttrs) - { - if (!attr.HookType.HasFlag(HookType.TestDiscovery)) - { - AddAssemblyHook(assembly, method, attr, isBeforeHook: false); - } - } - } - } - catch (Exception ex) - { - Console.WriteLine($"Warning: Failed to discover assembly hooks in type {type.FullName}: {ex.Message}"); - } - } - } - - private void AddRuntimeHook(Type type, MethodInfo method, HookAttribute attr, bool isBeforeHook) - { - var methodMetadata = ReflectionMetadataBuilder.CreateMethodMetadata(type, method); - - if (attr.HookType.HasFlag(HookType.Class)) - { - var hook = new BeforeClassHookMethod - { - MethodInfo = methodMetadata, - HookExecutor = _hookExecutor, - Order = attr.Order, - FilePath = "Unknown", // Assembly.Location not available in single-file apps - LineNumber = 0, - Body = CreateClassHookBody(method) - }; - - if (isBeforeHook) - { - Sources.BeforeClassHooks.GetOrAdd(type, _ => new ConcurrentBag>()).Add(hook); - } - else - { - Sources.AfterClassHooks.GetOrAdd(type, _ => new ConcurrentBag>()).Add(hook); - } - } - - if (attr.HookType.HasFlag(HookType.Test)) - { - if (method.IsStatic) - { - var hook = new BeforeTestHookMethod - { - MethodInfo = methodMetadata, - HookExecutor = _hookExecutor, - Order = attr.Order, - FilePath = "Unknown", // Assembly.Location not available in single-file apps - LineNumber = 0, - Body = CreateStaticTestHookBody(method) - }; - - if (isBeforeHook) - { - Sources.BeforeEveryTestHooks.Add(hook); - } - else - { - Sources.AfterEveryTestHooks.Add(hook); - } - } - else - { - var hook = new InstanceHookMethod - { - ClassType = type, - MethodInfo = methodMetadata, - HookExecutor = _hookExecutor, - Order = attr.Order, - Body = CreateInstanceHookBody(method) - }; - - if (isBeforeHook) - { - Sources.BeforeTestHooks.GetOrAdd(type, _ => new ConcurrentBag()).Add(hook); - } - else - { - Sources.AfterTestHooks.GetOrAdd(type, _ => new ConcurrentBag()).Add(hook); - } - } - } - - if (attr.HookType.HasFlag(HookType.TestSession)) - { - var hook = new BeforeTestSessionHookMethod - { - MethodInfo = methodMetadata, - HookExecutor = _hookExecutor, - Order = attr.Order, - FilePath = "Unknown", // Assembly.Location not available in single-file apps - LineNumber = 0, - Body = CreateTestSessionHookBody(method) - }; - - if (isBeforeHook) - { - Sources.BeforeTestSessionHooks.Add(hook); - } - else - { - Sources.AfterTestSessionHooks.Add(hook); - } - } - } - - private void AddAssemblyHook(Assembly assembly, MethodInfo method, HookAttribute attr, bool isBeforeHook) - { - var methodMetadata = ReflectionMetadataBuilder.CreateMethodMetadata(method.DeclaringType!, method); - - var hook = new BeforeAssemblyHookMethod - { - MethodInfo = methodMetadata, - HookExecutor = _hookExecutor, - Order = attr.Order, - FilePath = "Unknown", // Assembly.Location not available in single-file apps - LineNumber = 0, - Body = CreateAssemblyHookBody(method) - }; - - if (isBeforeHook) - { - Sources.BeforeAssemblyHooks.GetOrAdd(assembly, _ => new ConcurrentBag>()).Add(hook); - } - else - { - Sources.AfterAssemblyHooks.GetOrAdd(assembly, _ => new ConcurrentBag>()).Add(hook); - } - } - - private void AddTestDiscoveryHook(MethodInfo method, HookAttribute attr, bool isBeforeHook) - { - var methodMetadata = ReflectionMetadataBuilder.CreateMethodMetadata(method.DeclaringType!, method); - - if (isBeforeHook) - { - var hook = new BeforeTestDiscoveryHookMethod - { - MethodInfo = methodMetadata, - HookExecutor = _hookExecutor, - Order = attr.Order, - FilePath = "Unknown", // Assembly.Location not available in single-file apps - LineNumber = 0, - Body = CreateBeforeTestDiscoveryHookBody(method) - }; - Sources.BeforeTestDiscoveryHooks.Add(hook); - } - else - { - var hook = new AfterTestDiscoveryHookMethod - { - MethodInfo = methodMetadata, - HookExecutor = _hookExecutor, - Order = attr.Order, - FilePath = "Unknown", // Assembly.Location not available in single-file apps - LineNumber = 0, - Body = CreateAfterTestDiscoveryHookBody(method) - }; - Sources.AfterTestDiscoveryHooks.Add(hook); - } - } - - [UnconditionalSuppressMessage("Trimming", "IL2070:Target method parameter does not satisfy annotation requirements")] - private Func? CreateInstanceHookBody(MethodInfo method) - { - return async (instance, context, cancellationToken) => - { - var parameters = BuildHookParameters(method, context, cancellationToken); - var result = method.Invoke(instance, parameters); - await HandleHookResult(result); - }; - } - - [UnconditionalSuppressMessage("Trimming", "IL2070:Target method parameter does not satisfy annotation requirements")] - private Func? CreateStaticTestHookBody(MethodInfo method) - { - return async (context, cancellationToken) => - { - var parameters = BuildHookParameters(method, context, cancellationToken); - var result = method.Invoke(null, parameters); - await HandleHookResult(result); - }; - } - - [UnconditionalSuppressMessage("Trimming", "IL2070:Target method parameter does not satisfy annotation requirements")] - private Func? CreateClassHookBody(MethodInfo method) - { - return async (context, cancellationToken) => - { - var parameters = BuildHookParameters(method, context, cancellationToken); - var result = method.Invoke(null, parameters); - await HandleHookResult(result); - }; - } - - [UnconditionalSuppressMessage("Trimming", "IL2070:Target method parameter does not satisfy annotation requirements")] - private Func? CreateAssemblyHookBody(MethodInfo method) - { - return async (context, cancellationToken) => - { - var parameters = BuildHookParameters(method, context, cancellationToken); - var result = method.Invoke(null, parameters); - await HandleHookResult(result); - }; - } - - [UnconditionalSuppressMessage("Trimming", "IL2070:Target method parameter does not satisfy annotation requirements")] - private Func? CreateTestSessionHookBody(MethodInfo method) - { - return async (context, cancellationToken) => - { - var parameters = BuildHookParameters(method, context, cancellationToken); - var result = method.Invoke(null, parameters); - await HandleHookResult(result); - }; - } - - [UnconditionalSuppressMessage("Trimming", "IL2070:Target method parameter does not satisfy annotation requirements")] - private Func? CreateBeforeTestDiscoveryHookBody(MethodInfo method) - { - return async (context, cancellationToken) => - { - var parameters = BuildHookParameters(method, context, cancellationToken); - var result = method.Invoke(null, parameters); - await HandleHookResult(result); - }; - } - - [UnconditionalSuppressMessage("Trimming", "IL2070:Target method parameter does not satisfy annotation requirements")] - private Func? CreateAfterTestDiscoveryHookBody(MethodInfo method) - { - return async (context, cancellationToken) => - { - var parameters = BuildHookParameters(method, context, cancellationToken); - var result = method.Invoke(null, parameters); - await HandleHookResult(result); - }; - } - - private object?[] BuildHookParameters(MethodInfo method, object context, CancellationToken cancellationToken) - { - var parameters = method.GetParameters(); - var args = new object?[parameters.Length]; - - for (var i = 0; i < parameters.Length; i++) - { - var paramType = parameters[i].ParameterType; - - if (paramType == typeof(CancellationToken)) - { - args[i] = cancellationToken; - } - else if (paramType == context.GetType()) - { - args[i] = context; - } - else if (context is TestContext testContext) - { - if (paramType == typeof(ClassHookContext)) - { - args[i] = testContext.ClassContext; - } - else if (paramType == typeof(AssemblyHookContext)) - { - args[i] = testContext.ClassContext.AssemblyContext; - } - else if (paramType == typeof(TestSessionContext)) - { - args[i] = testContext.ClassContext.AssemblyContext.TestSessionContext; - } - } - else if (context is ClassHookContext classContext) - { - if (paramType == typeof(AssemblyHookContext)) - { - args[i] = classContext.AssemblyContext; - } - else if (paramType == typeof(TestSessionContext)) - { - args[i] = classContext.AssemblyContext.TestSessionContext; - } - } - else if (context is AssemblyHookContext assemblyContext) - { - if (paramType == typeof(TestSessionContext)) - { - args[i] = assemblyContext.TestSessionContext; - } - } - } - - return args; - } - - private static async Task HandleHookResult(object? result) - { - if (result is Task task) - { - await task; - } - else if (result is ValueTask valueTask) - { - await valueTask; - } - } - - private static bool IsCompilerGenerated(Type type) - { - return type.IsDefined(typeof(CompilerGeneratedAttribute), inherit: false); - } -} diff --git a/TUnit.Engine/Services/HookCollectionService.cs b/TUnit.Engine/Services/HookCollectionService.cs index 41a7844f32..0f80798f0e 100644 --- a/TUnit.Engine/Services/HookCollectionService.cs +++ b/TUnit.Engine/Services/HookCollectionService.cs @@ -487,19 +487,7 @@ private static Func CreateInstanceHookDele { return async (context, cancellationToken) => { - // Skip instance hooks if this is a pre-skipped test - if (context.TestDetails.ClassInstance is SkippedTestInstance) - { - return; - } - - if (hook.Body != null) - { - await hook.Body( - context.TestDetails.ClassInstance, - context, - cancellationToken); - } + await hook.ExecuteAsync(context, cancellationToken); }; } @@ -507,10 +495,7 @@ private static Func CreateStaticHookDelega { return async (context, cancellationToken) => { - if (hook.Body != null) - { - await hook.Body(context, cancellationToken); - } + await hook.ExecuteAsync(context, cancellationToken); }; } @@ -518,10 +503,7 @@ private static Func CreateClassHookDe { return async (context, cancellationToken) => { - if (hook.Body != null) - { - await hook.Body(context, cancellationToken); - } + await hook.ExecuteAsync(context, cancellationToken); }; } @@ -529,10 +511,7 @@ private static Func CreateAssembly { return async (context, cancellationToken) => { - if (hook.Body != null) - { - await hook.Body(context, cancellationToken); - } + await hook.ExecuteAsync(context, cancellationToken); }; } @@ -540,10 +519,7 @@ private static Func CreateTestSessi { return async (context, cancellationToken) => { - if (hook.Body != null) - { - await hook.Body(context, cancellationToken); - } + await hook.ExecuteAsync(context, cancellationToken); }; } @@ -551,10 +527,7 @@ private static Func CreateB { return async (context, cancellationToken) => { - if (hook.Body != null) - { - await hook.Body(context, cancellationToken); - } + await hook.ExecuteAsync(context, cancellationToken); }; } @@ -562,10 +535,7 @@ private static Func CreateTestDis { return async (context, cancellationToken) => { - if (hook.Body != null) - { - await hook.Body(context, cancellationToken); - } + await hook.ExecuteAsync(context, cancellationToken); }; } diff --git a/TUnit.Engine/Services/HookOrchestrator.cs b/TUnit.Engine/Services/HookOrchestrator.cs index f82afde5dd..5ed0bf7b55 100644 --- a/TUnit.Engine/Services/HookOrchestrator.cs +++ b/TUnit.Engine/Services/HookOrchestrator.cs @@ -109,11 +109,11 @@ private Task GetOrCreateBeforeClassTask( exceptions.Add(ex); } } - + if (exceptions.Count > 0) { - throw exceptions.Count == 1 - ? new HookFailedException(exceptions[0]) + throw exceptions.Count == 1 + ? new HookFailedException(exceptions[0]) : new HookFailedException("Multiple AfterTestSession hooks failed", new AggregateException(exceptions)); } @@ -167,11 +167,11 @@ private Task GetOrCreateBeforeClassTask( exceptions.Add(ex); } } - + if (exceptions.Count > 0) { - throw exceptions.Count == 1 - ? new HookFailedException(exceptions[0]) + throw exceptions.Count == 1 + ? new HookFailedException(exceptions[0]) : new HookFailedException("Multiple AfterTestDiscovery hooks failed", new AggregateException(exceptions)); } @@ -184,6 +184,11 @@ private Task GetOrCreateBeforeClassTask( public async Task OnTestStartingAsync(AbstractExecutableTest test, CancellationToken cancellationToken) { + if (test.Context.TestDetails.ClassInstance is SkippedTestInstance) + { + return ExecutionContext.Capture()!; + } + var testClassType = test.Metadata.TestClassType; var assemblyName = testClassType.Assembly.GetName().Name ?? "Unknown"; @@ -210,6 +215,11 @@ public async Task OnTestStartingAsync(AbstractExecutableTest t public async Task OnTestCompletedAsync(AbstractExecutableTest test, CancellationToken cancellationToken) { + if (test.Context.TestDetails.ClassInstance is SkippedTestInstance) + { + return; + } + var testClassType = test.Metadata.TestClassType; var assemblyName = testClassType.Assembly.GetName().Name ?? "Unknown"; @@ -309,11 +319,11 @@ private async Task ExecuteAfterAssemblyHooksAsync(Assembly ass exceptions.Add(ex); } } - + if (exceptions.Count > 0) { - throw exceptions.Count == 1 - ? new HookFailedException(exceptions[0]) + throw exceptions.Count == 1 + ? new HookFailedException(exceptions[0]) : new HookFailedException("Multiple AfterAssembly hooks failed", new AggregateException(exceptions)); } @@ -398,11 +408,11 @@ private async Task ExecuteAfterClassHooksAsync( exceptions.Add(ex); } } - + if (exceptions.Count > 0) { - throw exceptions.Count == 1 - ? new HookFailedException(exceptions[0]) + throw exceptions.Count == 1 + ? new HookFailedException(exceptions[0]) : new HookFailedException("Multiple AfterClass hooks failed", new AggregateException(exceptions)); } @@ -446,11 +456,11 @@ private async Task ExecuteAfterEveryTestHooksAsync(Type testClassType, TestConte exceptions.Add(ex); } } - + if (exceptions.Count > 0) { - throw exceptions.Count == 1 - ? new HookFailedException(exceptions[0]) + throw exceptions.Count == 1 + ? new HookFailedException(exceptions[0]) : new HookFailedException("Multiple AfterEveryTest hooks failed", new AggregateException(exceptions)); } } diff --git a/TUnit.TestProject/HookExecutorTests.cs b/TUnit.TestProject/HookExecutorTests.cs new file mode 100644 index 0000000000..163bbd6627 --- /dev/null +++ b/TUnit.TestProject/HookExecutorTests.cs @@ -0,0 +1,145 @@ +using System.Diagnostics.CodeAnalysis; +using TUnit.Core.Executors; + +namespace TUnit.TestProject; + +[UnconditionalSuppressMessage("Interoperability", "CA1416:Validate platform compatibility")] +public class HookExecutorTests +{ + // Test Session Hooks + [Before(TestSession)] + [HookExecutor] + public static async Task BeforeTestSessionWithSTA(TestSessionContext context) + { + await Assert.That(Thread.CurrentThread.GetApartmentState()).IsEquatableOrEqualTo(ApartmentState.STA); + + var test = context.AllTests.FirstOrDefault(x => + x.TestDetails.TestName == nameof(VerifyBeforeTestSessionSTAExecuted)); + test?.ObjectBag.Add("BeforeTestSessionSTAExecuted", true); + } + + [After(TestSession)] + [HookExecutor] + public static async Task AfterTestSessionWithSTA(TestSessionContext context) + { + await Assert.That(Thread.CurrentThread.GetApartmentState()).IsEquatableOrEqualTo(ApartmentState.STA); + } + + // Test Discovery Hooks + [Before(TestDiscovery)] + [HookExecutor] + public static async Task BeforeTestDiscoveryWithSTA(BeforeTestDiscoveryContext context) + { + await Assert.That(Thread.CurrentThread.GetApartmentState()).IsEquatableOrEqualTo(ApartmentState.STA); + } + + [After(TestDiscovery)] + [HookExecutor] + public static async Task AfterTestDiscoveryWithSTA(TestDiscoveryContext context) + { + await Assert.That(Thread.CurrentThread.GetApartmentState()).IsEquatableOrEqualTo(ApartmentState.STA); + } + + // Assembly Hooks + private static bool _beforeAssemblySTAExecuted; + + [Before(Assembly)] + [HookExecutor] + public static async Task BeforeAssemblyWithSTA(AssemblyHookContext context) + { + await Assert.That(Thread.CurrentThread.GetApartmentState()).IsEquatableOrEqualTo(ApartmentState.STA); + _beforeAssemblySTAExecuted = true; + } + + [After(Assembly)] + [HookExecutor] + public static async Task AfterAssemblyWithSTA(AssemblyHookContext context) + { + await Assert.That(Thread.CurrentThread.GetApartmentState()).IsEquatableOrEqualTo(ApartmentState.STA); + } + + // Class Hooks + private static bool _beforeClassSTAExecuted; + + [Before(Class)] + [HookExecutor] + public static async Task BeforeClassWithSTA(ClassHookContext context) + { + await Assert.That(Thread.CurrentThread.GetApartmentState()).IsEquatableOrEqualTo(ApartmentState.STA); + _beforeClassSTAExecuted = true; + } + + [After(Class)] + [HookExecutor] + public static async Task AfterClassWithSTA(ClassHookContext context) + { + await Assert.That(Thread.CurrentThread.GetApartmentState()).IsEquatableOrEqualTo(ApartmentState.STA); + } + + // Test Hooks - Instance methods + [Before(Test)] + [HookExecutor] + public async Task BeforeTestWithSTA(TestContext context) + { + await Assert.That(Thread.CurrentThread.GetApartmentState()).IsEquatableOrEqualTo(ApartmentState.STA); + context.ObjectBag.Add("BeforeTestSTAExecuted", true); + } + + [After(Test)] + [HookExecutor] + public async Task AfterTestWithSTA(TestContext context) + { + await Assert.That(Thread.CurrentThread.GetApartmentState()).IsEquatableOrEqualTo(ApartmentState.STA); + } + + // Test Hooks - Static methods + [BeforeEvery(Test)] + [HookExecutor] + public static async Task BeforeEveryTestWithSTA(TestContext context) + { + await Assert.That(Thread.CurrentThread.GetApartmentState()).IsEquatableOrEqualTo(ApartmentState.STA); + + if (context.TestDetails.TestName == nameof(VerifyStaticTestHooksSTAExecuted)) + { + context.ObjectBag.Add("BeforeEveryTestSTAExecuted", true); + } + } + + [AfterEvery(Test)] + [HookExecutor] + public static async Task AfterEveryTestWithSTA(TestContext context) + { + await Assert.That(Thread.CurrentThread.GetApartmentState()).IsEquatableOrEqualTo(ApartmentState.STA); + } + + // Tests to verify hooks executed + [Test] + public async Task VerifyBeforeTestSessionSTAExecuted() + { + await Assert.That(TestContext.Current?.ObjectBag["BeforeTestSessionSTAExecuted"]).IsEquatableOrEqualTo(true); + } + + [Test] + public async Task VerifyBeforeAssemblySTAExecuted() + { + await Assert.That(_beforeAssemblySTAExecuted).IsTrue(); + } + + [Test] + public async Task VerifyBeforeClassSTAExecuted() + { + await Assert.That(_beforeClassSTAExecuted).IsTrue(); + } + + [Test] + public async Task VerifyBeforeTestSTAExecuted() + { + await Assert.That(TestContext.Current?.ObjectBag["BeforeTestSTAExecuted"]).IsEquatableOrEqualTo(true); + } + + [Test] + public async Task VerifyStaticTestHooksSTAExecuted() + { + await Assert.That(TestContext.Current?.ObjectBag["BeforeEveryTestSTAExecuted"]).IsEquatableOrEqualTo(true); + } +} \ No newline at end of file