From 024a2ca55e1fdbc669877e6854fcb4002605da80 Mon Sep 17 00:00:00 2001 From: Tom Longhurst <30480171+thomhurst@users.noreply.github.com> Date: Sun, 7 Dec 2025 11:49:05 +0000 Subject: [PATCH 01/20] feat: implement IAsyncDiscoveryInitializer and related classes for improved test discovery handling --- TUnit.Engine/Building/TestBuilder.cs | 8 +++- TUnit.Engine/Services/PropertyInjector.cs | 19 ++++++++- TUnit.TestProject/Bugs/3992/BugRecreation.cs | 41 +++++++++++++++++++ TUnit.TestProject/Bugs/3992/DummyContainer.cs | 22 ++++++++++ 4 files changed, 86 insertions(+), 4 deletions(-) create mode 100644 TUnit.TestProject/Bugs/3992/BugRecreation.cs create mode 100644 TUnit.TestProject/Bugs/3992/DummyContainer.cs diff --git a/TUnit.Engine/Building/TestBuilder.cs b/TUnit.Engine/Building/TestBuilder.cs index 30dc1bcb7e..c94a4b6ded 100644 --- a/TUnit.Engine/Building/TestBuilder.cs +++ b/TUnit.Engine/Building/TestBuilder.cs @@ -45,7 +45,9 @@ public TestBuilder( } /// - /// Initializes any IAsyncInitializer objects in class data that were deferred during registration. + /// Initializes any IAsyncDiscoveryInitializer objects in class data during test building. + /// Regular IAsyncInitializer objects are NOT initialized here - they are deferred to test execution + /// via ObjectLifecycleService to avoid premature initialization during discovery. /// private async Task InitializeDeferredClassDataAsync(object?[] classData) { @@ -56,7 +58,9 @@ private async Task InitializeDeferredClassDataAsync(object?[] classData) foreach (var data in classData) { - if (data is IAsyncInitializer asyncInitializer && data is not IDataSourceAttribute) + // Only initialize IAsyncDiscoveryInitializer during discovery/building. + // Regular IAsyncInitializer objects are initialized during test execution. + if (data is IAsyncDiscoveryInitializer && data is not IDataSourceAttribute) { if (!ObjectInitializer.IsInitialized(data)) { diff --git a/TUnit.Engine/Services/PropertyInjector.cs b/TUnit.Engine/Services/PropertyInjector.cs index 249f19207b..36179c779e 100644 --- a/TUnit.Engine/Services/PropertyInjector.cs +++ b/TUnit.Engine/Services/PropertyInjector.cs @@ -525,15 +525,30 @@ private async Task ResolveAndCacheReflectionPropertyAsync( if (value != null) { - // Ensure nested objects are initialized - if (PropertyInjectionCache.HasInjectableProperties(value.GetType()) || value is IAsyncInitializer) + // Handle property injection and initialization appropriately during discovery + var hasInjectableProperties = PropertyInjectionCache.HasInjectableProperties(value.GetType()); + var isDiscoveryInitializer = value is IAsyncDiscoveryInitializer; + + if (isDiscoveryInitializer) { + // Full initialization during discovery (property injection + IAsyncInitializer.InitializeAsync) + // for objects that explicitly opt-in via IAsyncDiscoveryInitializer await _objectLifecycleService.Value.EnsureInitializedAsync( value, context.ObjectBag, context.MethodMetadata, context.Events); } + else if (hasInjectableProperties) + { + // Property injection only, IAsyncInitializer.InitializeAsync deferred to execution + // Regular IAsyncInitializer objects are initialized during test execution by ObjectLifecycleService + await _objectLifecycleService.Value.InjectPropertiesAsync( + value, + context.ObjectBag, + context.MethodMetadata, + context.Events); + } return value; } diff --git a/TUnit.TestProject/Bugs/3992/BugRecreation.cs b/TUnit.TestProject/Bugs/3992/BugRecreation.cs new file mode 100644 index 0000000000..fbc042d9fa --- /dev/null +++ b/TUnit.TestProject/Bugs/3992/BugRecreation.cs @@ -0,0 +1,41 @@ +using TUnit.TestProject.Attributes; + +namespace TUnit.TestProject.Bugs._3992; + +/// +/// Once this is discovered during test discovery, containers spin up +/// +[EngineTest(ExpectedResult.Pass)] +public sealed class BugRecreation +{ + //Docker container + [ClassDataSource(Shared = SharedType.PerClass)] + public required DummyContainer Container { get; init; } + + public IEnumerable> Executions + => Container.Ints.Select(e => new Func(() => e)); + + [Before(Class)] + public static Task BeforeClass(ClassHookContext context) => NotInitialised(context.Tests); + + [After(TestDiscovery)] + public static Task AfterDiscovery(TestDiscoveryContext context) => NotInitialised(context.AllTests); + + public static async Task NotInitialised(IEnumerable tests) + { + var bugRecreations = tests.Select(x => x.Metadata.TestDetails.ClassInstance).OfType(); + + foreach (var bugRecreation in bugRecreations) + { + await Assert.That(bugRecreation.Container).IsNotNull(); + await Assert.That(DummyContainer.NumberOfInits).IsEqualTo(0); + } + } + + [Test, Arguments(1)] + public async Task Test(int value, CancellationToken token) + { + await Assert.That(value).IsNotDefault(); + await Assert.That(DummyContainer.NumberOfInits).IsEqualTo(1); + } +} diff --git a/TUnit.TestProject/Bugs/3992/DummyContainer.cs b/TUnit.TestProject/Bugs/3992/DummyContainer.cs new file mode 100644 index 0000000000..b075f5e9f1 --- /dev/null +++ b/TUnit.TestProject/Bugs/3992/DummyContainer.cs @@ -0,0 +1,22 @@ +using TUnit.Core.Interfaces; + +namespace TUnit.TestProject.Bugs._3992; + +public class DummyContainer : IAsyncInitializer, IAsyncDisposable +{ + public Task InitializeAsync() + { + NumberOfInits++; + Ints = [1, 2, 3, 4, 5, 6]; + return Task.CompletedTask; + } + + public int[] Ints { get; private set; } = null!; + + public static int NumberOfInits { get; private set; } + + public ValueTask DisposeAsync() + { + return default; + } +} From 7cdc0fc4960e7cb8710e1adc98ce7cd18b7f48c6 Mon Sep 17 00:00:00 2001 From: Tom Longhurst <30480171+thomhurst@users.noreply.github.com> Date: Sun, 7 Dec 2025 13:28:49 +0000 Subject: [PATCH 02/20] feat: implement IAsyncDiscoveryInitializer and related classes for improved test discovery handling --- TUnit.Core/Helpers/DataSourceHelpers.cs | 40 ++---- TUnit.Core/ObjectInitializer.cs | 61 +++++++-- TUnit.Core/TestBuilderContext.cs | 3 +- .../Tracking/TrackableObjectGraphProvider.cs | 62 ++++++++- TUnit.Engine/Building/TestBuilder.cs | 79 +++--------- .../Services/ObjectGraphDiscoveryService.cs | 120 +++++++++++++----- .../Services/ObjectLifecycleService.cs | 91 ++++++++----- TUnit.Engine/Services/PropertyInjector.cs | 32 ++--- .../Services/TestExecution/TestCoordinator.cs | 4 +- TUnit.Engine/TestExecutor.cs | 8 +- TUnit.Engine/TestInitializer.cs | 13 +- 11 files changed, 319 insertions(+), 194 deletions(-) diff --git a/TUnit.Core/Helpers/DataSourceHelpers.cs b/TUnit.Core/Helpers/DataSourceHelpers.cs index 6802e5add5..fc7ba158e6 100644 --- a/TUnit.Core/Helpers/DataSourceHelpers.cs +++ b/TUnit.Core/Helpers/DataSourceHelpers.cs @@ -178,12 +178,9 @@ public static T InvokeIfFunc(object? value) // If it's a Func, invoke it first var actualData = InvokeIfFunc(data); - // Only initialize during discovery if explicitly opted-in via IAsyncDiscoveryInitializer - // Regular IAsyncInitializer objects are initialized during test execution by ObjectLifecycleService - if (actualData is IAsyncDiscoveryInitializer) - { - await ObjectInitializer.InitializeAsync(actualData); - } + // During discovery, only IAsyncDiscoveryInitializer objects are initialized. + // Regular IAsyncInitializer objects are deferred to Execution phase. + await ObjectInitializer.InitializeForDiscoveryAsync(actualData); return actualData; } @@ -202,11 +199,8 @@ public static T InvokeIfFunc(object? value) if (enumerator.MoveNext()) { var value = enumerator.Current; - // Only initialize during discovery if explicitly opted-in via IAsyncDiscoveryInitializer - if (value is IAsyncDiscoveryInitializer) - { - await ObjectInitializer.InitializeAsync(value); - } + // Discovery: only IAsyncDiscoveryInitializer + await ObjectInitializer.InitializeForDiscoveryAsync(value); return value; } @@ -233,22 +227,16 @@ public static T InvokeIfFunc(object? value) if (enumerator.MoveNext()) { var value = enumerator.Current; - // Only initialize during discovery if explicitly opted-in via IAsyncDiscoveryInitializer - if (value is IAsyncDiscoveryInitializer) - { - await ObjectInitializer.InitializeAsync(value); - } + // Discovery: only IAsyncDiscoveryInitializer + await ObjectInitializer.InitializeForDiscoveryAsync(value); return value; } return null; } - // Only initialize during discovery if explicitly opted-in via IAsyncDiscoveryInitializer - // Regular IAsyncInitializer objects are initialized during test execution by ObjectLifecycleService - if (actualData is IAsyncDiscoveryInitializer) - { - await ObjectInitializer.InitializeAsync(actualData); - } + // During discovery, only IAsyncDiscoveryInitializer objects are initialized. + // Regular IAsyncInitializer objects are deferred to Execution phase. + await ObjectInitializer.InitializeForDiscoveryAsync(actualData); return actualData; } @@ -596,12 +584,8 @@ public static void RegisterTypeCreator(Func> { var value = args[0]; - // Only initialize during discovery if explicitly opted-in via IAsyncDiscoveryInitializer - // Regular IAsyncInitializer objects are initialized during test execution by ObjectLifecycleService - if (value is IAsyncDiscoveryInitializer) - { - await ObjectInitializer.InitializeAsync(value); - } + // Discovery: only IAsyncDiscoveryInitializer + await ObjectInitializer.InitializeForDiscoveryAsync(value); return value; } diff --git a/TUnit.Core/ObjectInitializer.cs b/TUnit.Core/ObjectInitializer.cs index 362445816e..33da8a31d6 100644 --- a/TUnit.Core/ObjectInitializer.cs +++ b/TUnit.Core/ObjectInitializer.cs @@ -1,35 +1,76 @@ -using System.Runtime.CompilerServices; +using System.Runtime.CompilerServices; using TUnit.Core.Interfaces; namespace TUnit.Core; +/// +/// Centralized service for initializing objects that implement IAsyncInitializer. +/// Provides thread-safe, deduplicated initialization with explicit phase control. +/// +/// Use InitializeForDiscoveryAsync during test discovery - only IAsyncDiscoveryInitializer objects are initialized. +/// Use InitializeAsync during test execution - all IAsyncInitializer objects are initialized. +/// public static class ObjectInitializer { private static readonly ConditionalWeakTable _initializationTasks = new(); private static readonly Lock _lock = new(); - internal static bool IsInitialized(object? obj) + /// + /// Initializes an object during the discovery phase. + /// Only objects implementing IAsyncDiscoveryInitializer are initialized. + /// Regular IAsyncInitializer objects are skipped (deferred to execution phase). + /// Thread-safe with deduplication - safe to call multiple times. + /// + /// The object to potentially initialize. + /// Cancellation token. + public static ValueTask InitializeForDiscoveryAsync(object? obj, CancellationToken cancellationToken = default) { - if (obj is not IAsyncInitializer) + // During discovery, only initialize IAsyncDiscoveryInitializer + if (obj is not IAsyncDiscoveryInitializer asyncDiscoveryInitializer) { - return false; + return default; } - lock (_lock) + return InitializeCoreAsync(obj, asyncDiscoveryInitializer, cancellationToken); + } + + /// + /// Initializes an object during the execution phase. + /// All objects implementing IAsyncInitializer are initialized. + /// Thread-safe with deduplication - safe to call multiple times. + /// + /// The object to potentially initialize. + /// Cancellation token. + public static ValueTask InitializeAsync(object? obj, CancellationToken cancellationToken = default) + { + if (obj is not IAsyncInitializer asyncInitializer) { - return _initializationTasks.TryGetValue(obj, out var task) && task.IsCompleted; + return default; } + + return InitializeCoreAsync(obj, asyncInitializer, cancellationToken); } - public static async ValueTask InitializeAsync(object? obj, CancellationToken cancellationToken = default) + /// + /// Checks if an object has been initialized by ObjectInitializer. + /// + internal static bool IsInitialized(object? obj) { - if (obj is IAsyncInitializer asyncInitializer) + if (obj is not IAsyncInitializer) + { + return false; + } + + lock (_lock) { - await GetInitializationTask(obj, asyncInitializer, cancellationToken); + return _initializationTasks.TryGetValue(obj, out var task) && task.IsCompleted; } } - private static async Task GetInitializationTask(object obj, IAsyncInitializer asyncInitializer, CancellationToken cancellationToken) + private static async ValueTask InitializeCoreAsync( + object obj, + IAsyncInitializer asyncInitializer, + CancellationToken cancellationToken) { Task initializationTask; diff --git a/TUnit.Core/TestBuilderContext.cs b/TUnit.Core/TestBuilderContext.cs index 25b20c65d4..f6ec4b0eef 100644 --- a/TUnit.Core/TestBuilderContext.cs +++ b/TUnit.Core/TestBuilderContext.cs @@ -49,7 +49,8 @@ public void RegisterForInitialization(object? obj) { Events.OnInitialize += async (sender, args) => { - await ObjectInitializer.InitializeAsync(obj); + // Discovery: only IAsyncDiscoveryInitializer + await ObjectInitializer.InitializeForDiscoveryAsync(obj); }; } diff --git a/TUnit.Core/Tracking/TrackableObjectGraphProvider.cs b/TUnit.Core/Tracking/TrackableObjectGraphProvider.cs index 97301e53f8..8f1e5bb7a1 100644 --- a/TUnit.Core/Tracking/TrackableObjectGraphProvider.cs +++ b/TUnit.Core/Tracking/TrackableObjectGraphProvider.cs @@ -1,5 +1,7 @@ using System.Collections.Concurrent; using System.Diagnostics.CodeAnalysis; +using System.Reflection; +using TUnit.Core.Interfaces; using TUnit.Core.PropertyInjection; using TUnit.Core.StaticProperties; @@ -64,6 +66,7 @@ public IEnumerable GetStaticPropertyTrackableObjects() } } + [UnconditionalSuppressMessage("Trimming", "IL2026", Justification = "Property discovery handles both AOT and reflection modes")] private void AddNestedTrackableObjects(object obj, ConcurrentDictionary> visitedObjects, int currentDepth) { var plan = PropertyInjectionCache.GetOrCreatePlan(obj.GetType()); @@ -85,11 +88,6 @@ private void AddNestedTrackableObjects(object obj, ConcurrentDictionary + /// Discovers nested objects that implement IAsyncInitializer from all readable properties. + /// This is separate from injectable property discovery to handle objects without data source attributes. + /// This is a best-effort fallback - in AOT scenarios, properties with data source attributes are discovered via source generation. + /// + [UnconditionalSuppressMessage("Trimming", "IL2026", Justification = "Reflection fallback for nested initializers. In AOT, source-gen handles primary discovery.")] + [UnconditionalSuppressMessage("Trimming", "IL2070", Justification = "Reflection fallback for nested initializers. In AOT, source-gen handles primary discovery.")] + [UnconditionalSuppressMessage("Trimming", "IL2075", Justification = "Reflection fallback for nested initializers. In AOT, source-gen handles primary discovery.")] + private void AddNestedInitializerObjects(object obj, ConcurrentDictionary> visitedObjects, int currentDepth) + { + var type = obj.GetType(); + + // Skip primitive types, strings, and system types + if (type.IsPrimitive || type == typeof(string) || type.Namespace?.StartsWith("System") == true) + { + return; + } + + // Get all readable instance properties + var properties = type.GetProperties(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance); + + foreach (var property in properties) + { + if (!property.CanRead || property.GetIndexParameters().Length > 0) + { + continue; + } + + try + { + var value = property.GetValue(obj); + if (value == null) { continue; } - AddNestedTrackableObjects(value, visitedObjects, currentDepth + 1); + // Only discover if it implements IAsyncInitializer and hasn't been visited + if (value is IAsyncInitializer && visitedObjects.GetOrAdd(currentDepth, []).Add(value)) + { + // Recursively discover nested objects + AddNestedTrackableObjects(value, visitedObjects, currentDepth + 1); + } + } + catch + { + // Ignore properties that throw exceptions when accessed } } } diff --git a/TUnit.Engine/Building/TestBuilder.cs b/TUnit.Engine/Building/TestBuilder.cs index c94a4b6ded..a8ebc2afbc 100644 --- a/TUnit.Engine/Building/TestBuilder.cs +++ b/TUnit.Engine/Building/TestBuilder.cs @@ -45,11 +45,11 @@ public TestBuilder( } /// - /// Initializes any IAsyncDiscoveryInitializer objects in class data during test building. - /// Regular IAsyncInitializer objects are NOT initialized here - they are deferred to test execution - /// via ObjectLifecycleService to avoid premature initialization during discovery. + /// Initializes class data objects during test building. + /// Only IAsyncDiscoveryInitializer objects are initialized during discovery. + /// Regular IAsyncInitializer objects are deferred to execution phase. /// - private async Task InitializeDeferredClassDataAsync(object?[] classData) + private static async Task InitializeClassDataAsync(object?[] classData) { if (classData == null || classData.Length == 0) { @@ -58,46 +58,16 @@ private async Task InitializeDeferredClassDataAsync(object?[] classData) foreach (var data in classData) { - // Only initialize IAsyncDiscoveryInitializer during discovery/building. - // Regular IAsyncInitializer objects are initialized during test execution. - if (data is IAsyncDiscoveryInitializer && data is not IDataSourceAttribute) - { - if (!ObjectInitializer.IsInitialized(data)) - { - await ObjectInitializer.InitializeAsync(data); - } - } - } - } - - /// - /// Initializes any IAsyncDiscoveryInitializer objects in class data during test discovery. - /// This is called BEFORE method data sources are evaluated, enabling data sources - /// to access initialized shared objects (like Docker containers). - /// - private static async Task InitializeDiscoveryObjectsAsync(object?[] classData) - { - if (classData == null || classData.Length == 0) - { - return; - } - - foreach (var data in classData) - { - if (data is IAsyncDiscoveryInitializer) - { - // Uses ObjectInitializer which handles deduplication. - // This also prevents double-init during execution since ObjectInitializer - // tracks initialized objects. - await ObjectInitializer.InitializeAsync(data); - } + // Discovery: only IAsyncDiscoveryInitializer objects are initialized. + // Regular IAsyncInitializer objects are deferred to execution phase. + await ObjectInitializer.InitializeForDiscoveryAsync(data); } } private async Task CreateInstance(TestMetadata metadata, Type[] resolvedClassGenericArgs, object?[] classData, TestBuilderContext builderContext) { // Initialize any deferred IAsyncInitializer objects in class data - await InitializeDeferredClassDataAsync(classData); + await InitializeClassDataAsync(classData); // First try to create instance with ClassConstructor attribute // Use attributes from context if available @@ -234,9 +204,9 @@ public async Task> BuildTestsFromMetadataAsy var classDataResult = await classDataFactory() ?? []; var classData = DataUnwrapper.Unwrap(classDataResult); - // Initialize IAsyncDiscoveryInitializer objects before method data sources are evaluated. - // This enables InstanceMethodDataSource to access initialized shared objects. - await InitializeDiscoveryObjectsAsync(classData); + // Initialize objects before method data sources are evaluated. + // ObjectInitializer is phase-aware and will only initialize IAsyncDiscoveryInitializer during Discovery. + await InitializeClassDataAsync(classData); var needsInstanceForMethodDataSources = metadata.DataSources.Any(ds => ds is IAccessesInstanceData); @@ -298,11 +268,8 @@ await _objectLifecycleService.RegisterObjectAsync( metadata.MethodMetadata, tempEvents); - // Initialize the test class instance if it implements IAsyncDiscoveryInitializer - if (instanceForMethodDataSources is IAsyncDiscoveryInitializer) - { - await ObjectInitializer.InitializeAsync(instanceForMethodDataSources); - } + // Discovery: only IAsyncDiscoveryInitializer is initialized + await ObjectInitializer.InitializeForDiscoveryAsync(instanceForMethodDataSources); } } catch (Exception ex) @@ -351,8 +318,8 @@ await _objectLifecycleService.RegisterObjectAsync( classData = DataUnwrapper.Unwrap(await classDataFactory() ?? []); var methodData = DataUnwrapper.UnwrapWithTypes(await methodDataFactory() ?? [], metadata.MethodMetadata.Parameters); - // Initialize any IAsyncDiscoveryInitializer objects in method data - await InitializeDiscoveryObjectsAsync(methodData); + // Initialize method data objects (ObjectInitializer is phase-aware) + await InitializeClassDataAsync(methodData); // For concrete generic instantiations, check if the data is compatible with the expected types if (metadata.GenericMethodTypeArguments is { Length: > 0 }) @@ -1427,9 +1394,8 @@ public async IAsyncEnumerable BuildTestsStreamingAsync( var classData = DataUnwrapper.Unwrap(await classDataFactory() ?? []); - // Initialize IAsyncDiscoveryInitializer objects before method data sources are evaluated. - // This enables InstanceMethodDataSource to access initialized shared objects. - await InitializeDiscoveryObjectsAsync(classData); + // Initialize objects before method data sources are evaluated (ObjectInitializer is phase-aware) + await InitializeClassDataAsync(classData); // Handle instance creation for method data sources var needsInstanceForMethodDataSources = metadata.DataSources.Any(ds => ds is IAccessesInstanceData); @@ -1456,11 +1422,8 @@ await _objectLifecycleService.RegisterObjectAsync( metadata.MethodMetadata, tempEvents); - // Initialize the test class instance if it implements IAsyncDiscoveryInitializer - if (instanceForMethodDataSources is IAsyncDiscoveryInitializer) - { - await ObjectInitializer.InitializeAsync(instanceForMethodDataSources); - } + // Discovery: only IAsyncDiscoveryInitializer is initialized + await ObjectInitializer.InitializeForDiscoveryAsync(instanceForMethodDataSources); } // Stream through method data sources @@ -1571,8 +1534,8 @@ await _objectLifecycleService.RegisterObjectAsync( var methodData = DataUnwrapper.UnwrapWithTypes(await methodDataFactory() ?? [], metadata.MethodMetadata.Parameters); - // Initialize any IAsyncDiscoveryInitializer objects in method data - await InitializeDiscoveryObjectsAsync(methodData); + // Initialize method data objects (ObjectInitializer is phase-aware) + await InitializeClassDataAsync(methodData); // Check data compatibility for generic methods if (metadata.GenericMethodTypeArguments is { Length: > 0 }) diff --git a/TUnit.Engine/Services/ObjectGraphDiscoveryService.cs b/TUnit.Engine/Services/ObjectGraphDiscoveryService.cs index 8d5e46f7f9..4819a275f7 100644 --- a/TUnit.Engine/Services/ObjectGraphDiscoveryService.cs +++ b/TUnit.Engine/Services/ObjectGraphDiscoveryService.cs @@ -1,6 +1,8 @@ using System.Collections.Concurrent; using System.Diagnostics.CodeAnalysis; +using System.Reflection; using TUnit.Core; +using TUnit.Core.Interfaces; using TUnit.Core.PropertyInjection; namespace TUnit.Engine.Services; @@ -81,7 +83,9 @@ public ObjectGraph DiscoverNestedObjectGraph(object rootObject) } /// - /// Recursively discovers nested objects that have injectable properties. + /// Recursively discovers nested objects that have injectable properties OR implement IAsyncInitializer. + /// This ensures that all nested objects that need initialization are discovered, + /// even if they don't have explicit data source attributes. /// [UnconditionalSuppressMessage("Trimming", "IL2026", Justification = "Property discovery handles both AOT and reflection modes")] private void DiscoverNestedObjects( @@ -93,57 +97,113 @@ private void DiscoverNestedObjects( { var plan = PropertyInjectionCache.GetOrCreatePlan(obj.GetType()); - if (!plan.HasProperties) + // First, discover objects from injectable properties (data source attributes) + if (plan.HasProperties) { - return; - } - - // Use source-generated properties if available, otherwise fall back to reflection - if (plan.SourceGeneratedProperties.Length > 0) - { - foreach (var metadata in plan.SourceGeneratedProperties) + // Use source-generated properties if available, otherwise fall back to reflection + if (plan.SourceGeneratedProperties.Length > 0) { - var property = metadata.ContainingType.GetProperty(metadata.PropertyName); - if (property == null || !property.CanRead) + foreach (var metadata in plan.SourceGeneratedProperties) { - continue; - } + var property = metadata.ContainingType.GetProperty(metadata.PropertyName); + if (property == null || !property.CanRead) + { + continue; + } - var value = property.GetValue(obj); - if (value == null || !visitedObjects.Add(value)) - { - continue; - } + var value = property.GetValue(obj); + if (value == null || !visitedObjects.Add(value)) + { + continue; + } - AddToDepth(objectsByDepth, currentDepth, value); - allObjects.Add(value); + AddToDepth(objectsByDepth, currentDepth, value); + allObjects.Add(value); - // Recursively discover if this value has injectable properties - if (PropertyInjectionCache.HasInjectableProperties(value.GetType())) + // Recursively discover nested objects + DiscoverNestedObjects(value, objectsByDepth, visitedObjects, allObjects, currentDepth + 1); + } + } + else if (plan.ReflectionProperties.Length > 0) + { + foreach (var (property, _) in plan.ReflectionProperties) { + var value = property.GetValue(obj); + if (value == null || !visitedObjects.Add(value)) + { + continue; + } + + AddToDepth(objectsByDepth, currentDepth, value); + allObjects.Add(value); + + // Recursively discover nested objects DiscoverNestedObjects(value, objectsByDepth, visitedObjects, allObjects, currentDepth + 1); } } } - else if (plan.ReflectionProperties.Length > 0) + + // Also discover nested IAsyncInitializer objects from ALL properties + // This handles cases where nested objects don't have data source attributes + // but still implement IAsyncInitializer and need to be initialized + DiscoverNestedInitializerObjects(obj, objectsByDepth, visitedObjects, allObjects, currentDepth); + } + + /// + /// Discovers nested objects that implement IAsyncInitializer from all readable properties. + /// This is separate from injectable property discovery to handle objects without data source attributes. + /// This is a best-effort fallback - in AOT scenarios, properties with data source attributes are discovered via source generation. + /// + [UnconditionalSuppressMessage("Trimming", "IL2026", Justification = "Reflection fallback for nested initializers. In AOT, source-gen handles primary discovery.")] + [UnconditionalSuppressMessage("Trimming", "IL2070", Justification = "Reflection fallback for nested initializers. In AOT, source-gen handles primary discovery.")] + [UnconditionalSuppressMessage("Trimming", "IL2075", Justification = "Reflection fallback for nested initializers. In AOT, source-gen handles primary discovery.")] + private void DiscoverNestedInitializerObjects( + object obj, + ConcurrentDictionary> objectsByDepth, + HashSet visitedObjects, + HashSet allObjects, + int currentDepth) + { + var type = obj.GetType(); + + // Skip primitive types, strings, and system types + if (type.IsPrimitive || type == typeof(string) || type.Namespace?.StartsWith("System") == true) + { + return; + } + + // Get all readable instance properties + var properties = type.GetProperties(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance); + + foreach (var property in properties) { - foreach (var (property, _) in plan.ReflectionProperties) + if (!property.CanRead || property.GetIndexParameters().Length > 0) + { + continue; + } + + try { var value = property.GetValue(obj); - if (value == null || !visitedObjects.Add(value)) + if (value == null) { continue; } - AddToDepth(objectsByDepth, currentDepth, value); - allObjects.Add(value); - - // Recursively discover if this value has injectable properties - if (PropertyInjectionCache.HasInjectableProperties(value.GetType())) + // Only discover if it implements IAsyncInitializer and hasn't been visited + if (value is IAsyncInitializer && visitedObjects.Add(value)) { + AddToDepth(objectsByDepth, currentDepth, value); + allObjects.Add(value); + + // Recursively discover nested objects DiscoverNestedObjects(value, objectsByDepth, visitedObjects, allObjects, currentDepth + 1); } } + catch + { + // Ignore properties that throw exceptions when accessed + } } } diff --git a/TUnit.Engine/Services/ObjectLifecycleService.cs b/TUnit.Engine/Services/ObjectLifecycleService.cs index 33d50d9b08..a18938dd29 100644 --- a/TUnit.Engine/Services/ObjectLifecycleService.cs +++ b/TUnit.Engine/Services/ObjectLifecycleService.cs @@ -112,10 +112,11 @@ public async Task RegisterArgumentsAsync( /// /// Prepares a test for execution. - /// Sets already-resolved cached property values on the current instance and initializes tracked objects. + /// Sets already-resolved cached property values on the current instance. /// This is needed because retries create new instances that don't have properties set yet. + /// Does NOT call IAsyncInitializer - that is deferred until after BeforeClass hooks via InitializeTestObjectsAsync. /// - public async Task PrepareTestAsync(TestContext testContext, CancellationToken cancellationToken) + public void PrepareTest(TestContext testContext) { var testClassInstance = testContext.Metadata.TestDetails.ClassInstance; @@ -123,7 +124,14 @@ public async Task PrepareTestAsync(TestContext testContext, CancellationToken ca // Properties were resolved and cached during RegisterTestAsync, so shared objects are already created // We just need to set them on the actual test instance (retries create new instances) SetCachedPropertiesOnInstance(testClassInstance, testContext); + } + /// + /// Initializes test objects (IAsyncInitializer) after BeforeClass hooks have run. + /// This ensures resources like Docker containers are not started until needed. + /// + public async Task InitializeTestObjectsAsync(TestContext testContext, CancellationToken cancellationToken) + { // Initialize all tracked objects (IAsyncInitializer) depth-first await InitializeTrackedObjectsAsync(testContext, cancellationToken); } @@ -175,6 +183,7 @@ private void SetCachedPropertiesOnInstance(object instance, TestContext testCont /// /// Initializes all tracked objects depth-first (deepest objects first). + /// This is called during test execution (after BeforeClass hooks) to initialize IAsyncInitializer objects. /// private async Task InitializeTrackedObjectsAsync(TestContext testContext, CancellationToken cancellationToken) { @@ -184,23 +193,41 @@ private async Task InitializeTrackedObjectsAsync(TestContext testContext, Cancel { var objectsAtLevel = testContext.TrackedObjects[level]; - // Initialize all objects at this depth in parallel - await Task.WhenAll(objectsAtLevel.Select(obj => - EnsureInitializedAsync( - obj, - testContext.StateBag.Items, - testContext.Metadata.TestDetails.MethodMetadata, - testContext.InternalEvents, - cancellationToken).AsTask())); + // Initialize each tracked object and its nested objects + foreach (var obj in objectsAtLevel) + { + // First initialize nested objects depth-first + await InitializeNestedObjectsForExecutionAsync(obj, cancellationToken); + + // Then initialize the object itself + await ObjectInitializer.InitializeAsync(obj, cancellationToken); + } } - // Finally initialize the test class itself - await EnsureInitializedAsync( - testContext.Metadata.TestDetails.ClassInstance, - testContext.StateBag.Items, - testContext.Metadata.TestDetails.MethodMetadata, - testContext.InternalEvents, - cancellationToken); + // Finally initialize the test class and its nested objects + var classInstance = testContext.Metadata.TestDetails.ClassInstance; + await InitializeNestedObjectsForExecutionAsync(classInstance, cancellationToken); + await ObjectInitializer.InitializeAsync(classInstance, cancellationToken); + } + + /// + /// Initializes nested objects during execution phase - all IAsyncInitializer objects. + /// + private async Task InitializeNestedObjectsForExecutionAsync(object rootObject, CancellationToken cancellationToken) + { + var graph = _objectGraphDiscoveryService.DiscoverNestedObjectGraph(rootObject); + + // Initialize from deepest to shallowest (skip depth 0 which is the root itself) + foreach (var depth in graph.GetDepthsDescending()) + { + if (depth == 0) continue; // Root handled separately + + var objectsAtDepth = graph.GetObjectsAtDepth(depth); + + // Initialize all IAsyncInitializer objects at this depth + await Task.WhenAll(objectsAtDepth + .Select(obj => ObjectInitializer.InitializeAsync(obj, cancellationToken).AsTask())); + } } #endregion @@ -234,6 +261,7 @@ public async ValueTask InjectPropertiesAsync( /// /// Ensures an object is fully initialized (property injection + IAsyncInitializer). /// Thread-safe with fast-path for already-initialized objects. + /// Called during test execution to initialize all IAsyncInitializer objects. /// public async ValueTask EnsureInitializedAsync( T obj, @@ -247,13 +275,17 @@ public async ValueTask EnsureInitializedAsync( throw new ArgumentNullException(nameof(obj)); } - // Fast path: already initialized + // Fast path: already processed by this service if (_initializationTasks.TryGetValue(obj, out var existingTcs) && existingTcs.Task.IsCompleted) { if (existingTcs.Task.IsFaulted) { await existingTcs.Task.ConfigureAwait(false); } + + // EnsureInitializedAsync is only called during discovery (from PropertyInjector). + // If the object is shared and has already been processed, just return it. + // Regular IAsyncInitializer objects will be initialized during execution via InitializeTrackedObjectsAsync. return obj; } @@ -284,7 +316,8 @@ public async ValueTask EnsureInitializedAsync( } /// - /// Core initialization: property injection + nested objects + IAsyncInitializer. + /// Core initialization: property injection + IAsyncDiscoveryInitializer only. + /// Regular IAsyncInitializer objects are NOT initialized here - they are deferred to execution phase. /// private async Task InitializeObjectCoreAsync( object obj, @@ -301,14 +334,12 @@ private async Task InitializeObjectCoreAsync( // Step 1: Inject properties await PropertyInjector.InjectPropertiesAsync(obj, objectBag, methodMetadata, events); - // Step 2: Initialize nested objects depth-first - await InitializeNestedObjectsAsync(obj, cancellationToken); + // Step 2: Initialize nested objects depth-first (discovery-only) + await InitializeNestedObjectsForDiscoveryAsync(obj, cancellationToken); - // Step 3: Call IAsyncInitializer on the object itself - if (obj is IAsyncInitializer asyncInitializer) - { - await ObjectInitializer.InitializeAsync(asyncInitializer, cancellationToken); - } + // Step 3: Call IAsyncDiscoveryInitializer only (not regular IAsyncInitializer) + // Regular IAsyncInitializer objects are deferred to execution phase via InitializeTestObjectsAsync + await ObjectInitializer.InitializeForDiscoveryAsync(obj, cancellationToken); } catch (Exception ex) { @@ -318,9 +349,9 @@ private async Task InitializeObjectCoreAsync( } /// - /// Initializes nested objects depth-first using the centralized ObjectGraphDiscoveryService. + /// Initializes nested objects during discovery phase - only IAsyncDiscoveryInitializer objects. /// - private async Task InitializeNestedObjectsAsync(object rootObject, CancellationToken cancellationToken) + private async Task InitializeNestedObjectsForDiscoveryAsync(object rootObject, CancellationToken cancellationToken) { var graph = _objectGraphDiscoveryService.DiscoverNestedObjectGraph(rootObject); @@ -331,9 +362,9 @@ private async Task InitializeNestedObjectsAsync(object rootObject, CancellationT var objectsAtDepth = graph.GetObjectsAtDepth(depth); + // Only initialize IAsyncDiscoveryInitializer objects during discovery await Task.WhenAll(objectsAtDepth - .Where(obj => obj is IAsyncInitializer) - .Select(obj => ObjectInitializer.InitializeAsync(obj, cancellationToken).AsTask())); + .Select(obj => ObjectInitializer.InitializeForDiscoveryAsync(obj, cancellationToken).AsTask())); } } diff --git a/TUnit.Engine/Services/PropertyInjector.cs b/TUnit.Engine/Services/PropertyInjector.cs index 36179c779e..54cbcf6573 100644 --- a/TUnit.Engine/Services/PropertyInjector.cs +++ b/TUnit.Engine/Services/PropertyInjector.cs @@ -525,30 +525,14 @@ private async Task ResolveAndCacheReflectionPropertyAsync( if (value != null) { - // Handle property injection and initialization appropriately during discovery - var hasInjectableProperties = PropertyInjectionCache.HasInjectableProperties(value.GetType()); - var isDiscoveryInitializer = value is IAsyncDiscoveryInitializer; - - if (isDiscoveryInitializer) - { - // Full initialization during discovery (property injection + IAsyncInitializer.InitializeAsync) - // for objects that explicitly opt-in via IAsyncDiscoveryInitializer - await _objectLifecycleService.Value.EnsureInitializedAsync( - value, - context.ObjectBag, - context.MethodMetadata, - context.Events); - } - else if (hasInjectableProperties) - { - // Property injection only, IAsyncInitializer.InitializeAsync deferred to execution - // Regular IAsyncInitializer objects are initialized during test execution by ObjectLifecycleService - await _objectLifecycleService.Value.InjectPropertiesAsync( - value, - context.ObjectBag, - context.MethodMetadata, - context.Events); - } + // EnsureInitializedAsync handles property injection and initialization. + // ObjectInitializer is phase-aware: during Discovery phase, only IAsyncDiscoveryInitializer + // objects are initialized; regular IAsyncInitializer objects are deferred to Execution phase. + await _objectLifecycleService.Value.EnsureInitializedAsync( + value, + context.ObjectBag, + context.MethodMetadata, + context.Events); return value; } diff --git a/TUnit.Engine/Services/TestExecution/TestCoordinator.cs b/TUnit.Engine/Services/TestExecution/TestCoordinator.cs index b448e459e6..683a9e8f4c 100644 --- a/TUnit.Engine/Services/TestExecution/TestCoordinator.cs +++ b/TUnit.Engine/Services/TestExecution/TestCoordinator.cs @@ -126,9 +126,9 @@ await TimeoutHelper.ExecuteWithTimeoutAsync( try { - await _testInitializer.InitializeTest(test, ct).ConfigureAwait(false); + _testInitializer.PrepareTest(test, ct); test.Context.RestoreExecutionContext(); - await _testExecutor.ExecuteAsync(test, ct).ConfigureAwait(false); + await _testExecutor.ExecuteAsync(test, _testInitializer, ct).ConfigureAwait(false); } finally { diff --git a/TUnit.Engine/TestExecutor.cs b/TUnit.Engine/TestExecutor.cs index e7d65b92f0..dbb80cc3f6 100644 --- a/TUnit.Engine/TestExecutor.cs +++ b/TUnit.Engine/TestExecutor.cs @@ -62,7 +62,7 @@ await _beforeHookTaskCache.GetOrCreateBeforeTestSessionTask( /// Creates a test executor delegate that wraps the provided executor with hook orchestration. /// Uses focused services that follow SRP to manage lifecycle and execution. /// - public async ValueTask ExecuteAsync(AbstractExecutableTest executableTest, CancellationToken cancellationToken) + public async ValueTask ExecuteAsync(AbstractExecutableTest executableTest, TestInitializer testInitializer, CancellationToken cancellationToken) { var testClass = executableTest.Metadata.TestClassType; @@ -112,6 +112,12 @@ await _eventReceiverOrchestrator.InvokeFirstTestInClassEventReceiversAsync( executableTest.Context.ClassContext.RestoreExecutionContext(); + // Initialize test objects (IAsyncInitializer) AFTER BeforeClass hooks + // This ensures resources like Docker containers are not started until needed + await testInitializer.InitializeTestObjectsAsync(executableTest, cancellationToken).ConfigureAwait(false); + + executableTest.Context.RestoreExecutionContext(); + // Early stage test start receivers run before instance-level hooks await _eventReceiverOrchestrator.InvokeTestStartEventReceiversAsync(executableTest.Context, cancellationToken, EventReceiverStage.Early).ConfigureAwait(false); diff --git a/TUnit.Engine/TestInitializer.cs b/TUnit.Engine/TestInitializer.cs index de117c6746..73fbb7fbc2 100644 --- a/TUnit.Engine/TestInitializer.cs +++ b/TUnit.Engine/TestInitializer.cs @@ -20,12 +20,19 @@ public TestInitializer( _objectLifecycleService = objectLifecycleService; } - public async ValueTask InitializeTest(AbstractExecutableTest test, CancellationToken cancellationToken) + public void PrepareTest(AbstractExecutableTest test, CancellationToken cancellationToken) { // Register event receivers _eventReceiverOrchestrator.RegisterReceivers(test.Context, cancellationToken); - // Prepare test: inject properties, track objects, initialize (IAsyncInitializer) - await _objectLifecycleService.PrepareTestAsync(test.Context, cancellationToken); + // Prepare test: set cached property values on the instance + // Does NOT call IAsyncInitializer - that is deferred until after BeforeClass hooks + _objectLifecycleService.PrepareTest(test.Context); + } + + public async ValueTask InitializeTestObjectsAsync(AbstractExecutableTest test, CancellationToken cancellationToken) + { + // Initialize test objects (IAsyncInitializer) - called after BeforeClass hooks + await _objectLifecycleService.InitializeTestObjectsAsync(test.Context, cancellationToken); } } From bebb3e2c79e5749e227c165c59c354e7894da2be Mon Sep 17 00:00:00 2001 From: Tom Longhurst <30480171+thomhurst@users.noreply.github.com> Date: Sun, 7 Dec 2025 14:10:33 +0000 Subject: [PATCH 03/20] feat: implement IAsyncDiscoveryInitializer and related classes for improved test discovery handling --- TUnit.Core/Discovery/ObjectGraph.cs | 44 ++ TUnit.Core/Discovery/ObjectGraphDiscoverer.cs | 392 ++++++++++++++++++ .../Interfaces/IInitializationCallback.cs | 33 ++ .../Interfaces/IObjectGraphDiscoverer.cs | 80 ++++ .../IObjectInitializationService.cs | 67 +++ TUnit.Core/ObjectInitializer.cs | 69 +-- .../PropertySetterFactory.cs | 39 +- .../Services/ObjectInitializationService.cs | 88 ++++ .../Tracking/TrackableObjectGraphProvider.cs | 177 ++------ .../Framework/TUnitServiceProvider.cs | 6 +- .../Services/ObjectGraphDiscoveryService.cs | 243 +---------- .../Services/ObjectLifecycleService.cs | 6 +- TUnit.Engine/Services/PropertyInjector.cs | 14 +- 13 files changed, 853 insertions(+), 405 deletions(-) create mode 100644 TUnit.Core/Discovery/ObjectGraph.cs create mode 100644 TUnit.Core/Discovery/ObjectGraphDiscoverer.cs create mode 100644 TUnit.Core/Interfaces/IInitializationCallback.cs create mode 100644 TUnit.Core/Interfaces/IObjectGraphDiscoverer.cs create mode 100644 TUnit.Core/Interfaces/IObjectInitializationService.cs create mode 100644 TUnit.Core/Services/ObjectInitializationService.cs diff --git a/TUnit.Core/Discovery/ObjectGraph.cs b/TUnit.Core/Discovery/ObjectGraph.cs new file mode 100644 index 0000000000..be8f8315bb --- /dev/null +++ b/TUnit.Core/Discovery/ObjectGraph.cs @@ -0,0 +1,44 @@ +using System.Collections.Concurrent; +using TUnit.Core.Interfaces; + +namespace TUnit.Core.Discovery; + +/// +/// Represents a discovered object graph organized by depth level. +/// +public sealed class ObjectGraph : IObjectGraph +{ + /// + /// Creates a new object graph from the discovered objects. + /// + /// Objects organized by depth level. + /// All unique objects in the graph. + public ObjectGraph(ConcurrentDictionary> objectsByDepth, HashSet allObjects) + { + ObjectsByDepth = objectsByDepth; + AllObjects = allObjects; + // Use IsEmpty for thread-safe check before accessing Keys + MaxDepth = objectsByDepth.IsEmpty ? -1 : objectsByDepth.Keys.Max(); + } + + /// + public ConcurrentDictionary> ObjectsByDepth { get; } + + /// + public HashSet AllObjects { get; } + + /// + public int MaxDepth { get; } + + /// + public IEnumerable GetObjectsAtDepth(int depth) + { + return ObjectsByDepth.TryGetValue(depth, out var objects) ? objects : []; + } + + /// + public IEnumerable GetDepthsDescending() + { + return ObjectsByDepth.Keys.OrderByDescending(d => d); + } +} diff --git a/TUnit.Core/Discovery/ObjectGraphDiscoverer.cs b/TUnit.Core/Discovery/ObjectGraphDiscoverer.cs new file mode 100644 index 0000000000..445dd3cc8b --- /dev/null +++ b/TUnit.Core/Discovery/ObjectGraphDiscoverer.cs @@ -0,0 +1,392 @@ +using System.Collections.Concurrent; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Reflection; +using TUnit.Core.Helpers; +using TUnit.Core.Interfaces; +using TUnit.Core.PropertyInjection; + +namespace TUnit.Core.Discovery; + +/// +/// Centralized service for discovering and organizing object graphs. +/// Consolidates duplicate graph traversal logic from ObjectGraphDiscoveryService and TrackableObjectGraphProvider. +/// Follows Single Responsibility Principle - only discovers objects, doesn't modify them. +/// +/// +/// +/// This class is thread-safe and uses cached reflection for performance. +/// Objects are organized by their nesting depth in the hierarchy: +/// +/// +/// Depth 0: Root objects (class args, method args, property values) +/// Depth 1+: Nested objects found in properties of objects at previous depth +/// +/// +public sealed class ObjectGraphDiscoverer : IObjectGraphDiscoverer +{ + // Cache for GetProperties() results per type - eliminates repeated reflection calls + private static readonly ConcurrentDictionary PropertyCache = new(); + + // Reference equality comparer for object tracking (ignores Equals overrides) + private static readonly Helpers.ReferenceEqualityComparer ReferenceComparer = new(); + + // Types to skip during discovery (primitives, strings, system types) + private static readonly HashSet SkipTypes = + [ + typeof(string), + typeof(decimal), + typeof(DateTime), + typeof(DateTimeOffset), + typeof(TimeSpan), + typeof(Guid) + ]; + + /// + public IObjectGraph DiscoverObjectGraph(TestContext testContext) + { + var objectsByDepth = new ConcurrentDictionary>(); + var allObjects = new HashSet(); + // Use ConcurrentDictionary for thread-safe visited tracking with reference equality + var visitedObjects = new ConcurrentDictionary(ReferenceComparer); + + var testDetails = testContext.Metadata.TestDetails; + + // Collect root-level objects (depth 0) + foreach (var classArgument in testDetails.TestClassArguments) + { + if (classArgument != null && visitedObjects.TryAdd(classArgument, 0)) + { + AddToDepth(objectsByDepth, 0, classArgument); + allObjects.Add(classArgument); + DiscoverNestedObjects(classArgument, objectsByDepth, visitedObjects, allObjects, currentDepth: 1); + } + } + + foreach (var methodArgument in testDetails.TestMethodArguments) + { + if (methodArgument != null && visitedObjects.TryAdd(methodArgument, 0)) + { + AddToDepth(objectsByDepth, 0, methodArgument); + allObjects.Add(methodArgument); + DiscoverNestedObjects(methodArgument, objectsByDepth, visitedObjects, allObjects, currentDepth: 1); + } + } + + foreach (var property in testDetails.TestClassInjectedPropertyArguments.Values) + { + if (property != null && visitedObjects.TryAdd(property, 0)) + { + AddToDepth(objectsByDepth, 0, property); + allObjects.Add(property); + DiscoverNestedObjects(property, objectsByDepth, visitedObjects, allObjects, currentDepth: 1); + } + } + + return new ObjectGraph(objectsByDepth, allObjects); + } + + /// + public IObjectGraph DiscoverNestedObjectGraph(object rootObject) + { + var objectsByDepth = new ConcurrentDictionary>(); + var allObjects = new HashSet(); + var visitedObjects = new ConcurrentDictionary(ReferenceComparer); + + if (visitedObjects.TryAdd(rootObject, 0)) + { + AddToDepth(objectsByDepth, 0, rootObject); + allObjects.Add(rootObject); + DiscoverNestedObjects(rootObject, objectsByDepth, visitedObjects, allObjects, currentDepth: 1); + } + + return new ObjectGraph(objectsByDepth, allObjects); + } + + /// + /// Discovers objects and adds them to the existing tracked objects dictionary. + /// Used by TrackableObjectGraphProvider to populate TestContext.TrackedObjects. + /// + /// The test context to discover objects from. + /// The tracked objects dictionary (same as testContext.TrackedObjects). + public ConcurrentDictionary> DiscoverAndTrackObjects(TestContext testContext) + { + var visitedObjects = testContext.TrackedObjects; + var testDetails = testContext.Metadata.TestDetails; + + foreach (var classArgument in testDetails.TestClassArguments) + { + if (classArgument != null && GetOrAddHashSet(visitedObjects, 0).Add(classArgument)) + { + DiscoverNestedObjectsForTracking(classArgument, visitedObjects, 1); + } + } + + foreach (var methodArgument in testDetails.TestMethodArguments) + { + if (methodArgument != null && GetOrAddHashSet(visitedObjects, 0).Add(methodArgument)) + { + DiscoverNestedObjectsForTracking(methodArgument, visitedObjects, 1); + } + } + + foreach (var property in testDetails.TestClassInjectedPropertyArguments.Values) + { + if (property != null && GetOrAddHashSet(visitedObjects, 0).Add(property)) + { + DiscoverNestedObjectsForTracking(property, visitedObjects, 1); + } + } + + return visitedObjects; + } + + /// + /// Recursively discovers nested objects that have injectable properties OR implement IAsyncInitializer. + /// + [UnconditionalSuppressMessage("Trimming", "IL2026", Justification = "Property discovery handles both AOT and reflection modes")] + private void DiscoverNestedObjects( + object obj, + ConcurrentDictionary> objectsByDepth, + ConcurrentDictionary visitedObjects, + HashSet allObjects, + int currentDepth) + { + var plan = PropertyInjectionCache.GetOrCreatePlan(obj.GetType()); + + // First, discover objects from injectable properties (data source attributes) + if (plan.HasProperties) + { + // Use source-generated properties if available, otherwise fall back to reflection + if (plan.SourceGeneratedProperties.Length > 0) + { + foreach (var metadata in plan.SourceGeneratedProperties) + { + var property = metadata.ContainingType.GetProperty(metadata.PropertyName); + if (property == null || !property.CanRead) + { + continue; + } + + var value = property.GetValue(obj); + if (value == null || !visitedObjects.TryAdd(value, 0)) + { + continue; + } + + AddToDepth(objectsByDepth, currentDepth, value); + allObjects.Add(value); + DiscoverNestedObjects(value, objectsByDepth, visitedObjects, allObjects, currentDepth + 1); + } + } + else if (plan.ReflectionProperties.Length > 0) + { + foreach (var (property, _) in plan.ReflectionProperties) + { + var value = property.GetValue(obj); + if (value == null || !visitedObjects.TryAdd(value, 0)) + { + continue; + } + + AddToDepth(objectsByDepth, currentDepth, value); + allObjects.Add(value); + DiscoverNestedObjects(value, objectsByDepth, visitedObjects, allObjects, currentDepth + 1); + } + } + } + + // Also discover nested IAsyncInitializer objects from ALL properties + DiscoverNestedInitializerObjects(obj, objectsByDepth, visitedObjects, allObjects, currentDepth); + } + + /// + /// Discovers nested objects for tracking (uses HashSet pattern for compatibility with TestContext.TrackedObjects). + /// + [UnconditionalSuppressMessage("Trimming", "IL2026", Justification = "Property discovery handles both AOT and reflection modes")] + private void DiscoverNestedObjectsForTracking( + object obj, + ConcurrentDictionary> visitedObjects, + int currentDepth) + { + var plan = PropertyInjectionCache.GetOrCreatePlan(obj.GetType()); + + // Check SourceRegistrar.IsEnabled for compatibility with existing TrackableObjectGraphProvider behavior + if (!SourceRegistrar.IsEnabled) + { + foreach (var prop in plan.ReflectionProperties) + { + var value = prop.Property.GetValue(obj); + if (value == null) + { + continue; + } + + if (!GetOrAddHashSet(visitedObjects, currentDepth).Add(value)) + { + continue; + } + + DiscoverNestedObjectsForTracking(value, visitedObjects, currentDepth + 1); + } + } + else + { + foreach (var metadata in plan.SourceGeneratedProperties) + { + var property = metadata.ContainingType.GetProperty(metadata.PropertyName); + if (property == null || !property.CanRead) + { + continue; + } + + var value = property.GetValue(obj); + if (value == null) + { + continue; + } + + if (!GetOrAddHashSet(visitedObjects, currentDepth).Add(value)) + { + continue; + } + + DiscoverNestedObjectsForTracking(value, visitedObjects, currentDepth + 1); + } + } + + // Also discover nested IAsyncInitializer objects from ALL properties + DiscoverNestedInitializerObjectsForTracking(obj, visitedObjects, currentDepth); + } + + /// + /// Discovers nested objects that implement IAsyncInitializer from all readable properties. + /// Uses cached reflection for performance. + /// + [UnconditionalSuppressMessage("Trimming", "IL2026", Justification = "Reflection fallback for nested initializers. In AOT, source-gen handles primary discovery.")] + [UnconditionalSuppressMessage("Trimming", "IL2070", Justification = "Reflection fallback for nested initializers. In AOT, source-gen handles primary discovery.")] + [UnconditionalSuppressMessage("Trimming", "IL2075", Justification = "Reflection fallback for nested initializers. In AOT, source-gen handles primary discovery.")] + private void DiscoverNestedInitializerObjects( + object obj, + ConcurrentDictionary> objectsByDepth, + ConcurrentDictionary visitedObjects, + HashSet allObjects, + int currentDepth) + { + var type = obj.GetType(); + + // Skip types that don't need discovery + if (ShouldSkipType(type)) + { + return; + } + + // Use cached properties for performance + var properties = GetCachedProperties(type); + + foreach (var property in properties) + { + try + { + var value = property.GetValue(obj); + if (value == null) + { + continue; + } + + // Only discover if it implements IAsyncInitializer and hasn't been visited + if (value is IAsyncInitializer && visitedObjects.TryAdd(value, 0)) + { + AddToDepth(objectsByDepth, currentDepth, value); + allObjects.Add(value); + DiscoverNestedObjects(value, objectsByDepth, visitedObjects, allObjects, currentDepth + 1); + } + } + catch (Exception ex) + { + // Log instead of silently swallowing - helps with debugging + Debug.WriteLine($"[ObjectGraphDiscoverer] Failed to access property '{property.Name}' on type '{type.Name}': {ex.Message}"); + } + } + } + + /// + /// Discovers nested IAsyncInitializer objects for tracking (uses HashSet pattern). + /// + [UnconditionalSuppressMessage("Trimming", "IL2026", Justification = "Reflection fallback for nested initializers. In AOT, source-gen handles primary discovery.")] + [UnconditionalSuppressMessage("Trimming", "IL2070", Justification = "Reflection fallback for nested initializers. In AOT, source-gen handles primary discovery.")] + [UnconditionalSuppressMessage("Trimming", "IL2075", Justification = "Reflection fallback for nested initializers. In AOT, source-gen handles primary discovery.")] + private void DiscoverNestedInitializerObjectsForTracking( + object obj, + ConcurrentDictionary> visitedObjects, + int currentDepth) + { + var type = obj.GetType(); + + if (ShouldSkipType(type)) + { + return; + } + + var properties = GetCachedProperties(type); + + foreach (var property in properties) + { + try + { + var value = property.GetValue(obj); + if (value == null) + { + continue; + } + + if (value is IAsyncInitializer && GetOrAddHashSet(visitedObjects, currentDepth).Add(value)) + { + DiscoverNestedObjectsForTracking(value, visitedObjects, currentDepth + 1); + } + } + catch (Exception ex) + { + Debug.WriteLine($"[ObjectGraphDiscoverer] Failed to access property '{property.Name}' on type '{type.Name}': {ex.Message}"); + } + } + } + + /// + /// Gets cached properties for a type, filtering to only readable non-indexed properties. + /// + [UnconditionalSuppressMessage("Trimming", "IL2070", Justification = "Reflection fallback for nested initializers. In AOT, source-gen handles primary discovery.")] + private static PropertyInfo[] GetCachedProperties(Type type) + { + return PropertyCache.GetOrAdd(type, t => + t.GetProperties(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance) + .Where(p => p.CanRead && p.GetIndexParameters().Length == 0) + .ToArray()); + } + + /// + /// Checks if a type should be skipped during discovery. + /// + private static bool ShouldSkipType(Type type) + { + return type.IsPrimitive || + SkipTypes.Contains(type) || + type.Namespace?.StartsWith("System") == true; + } + + /// + /// Adds an object to the specified depth level. + /// + private static void AddToDepth(ConcurrentDictionary> objectsByDepth, int depth, object obj) + { + objectsByDepth.GetOrAdd(depth, _ => []).Add(obj); + } + + /// + /// Gets or creates a HashSet at the specified depth (thread-safe). + /// + private static HashSet GetOrAddHashSet(ConcurrentDictionary> dict, int depth) + { + return dict.GetOrAdd(depth, _ => []); + } +} diff --git a/TUnit.Core/Interfaces/IInitializationCallback.cs b/TUnit.Core/Interfaces/IInitializationCallback.cs new file mode 100644 index 0000000000..4e6449dd6e --- /dev/null +++ b/TUnit.Core/Interfaces/IInitializationCallback.cs @@ -0,0 +1,33 @@ +using System.Collections.Concurrent; + +namespace TUnit.Core.Interfaces; + +/// +/// Defines a callback interface for object initialization during property injection. +/// +/// +/// +/// This interface is used to break circular dependencies between property injection +/// and initialization services. Property injectors can call back to the initialization +/// service without directly depending on it. +/// +/// +internal interface IInitializationCallback +{ + /// + /// Ensures an object is fully initialized (property injection + IAsyncInitializer). + /// + /// The type of object to initialize. + /// The object to initialize. + /// Shared object bag for the test context. + /// Method metadata for the test. Can be null. + /// Test context events for tracking. + /// A token to monitor for cancellation requests. + /// The initialized object. + ValueTask EnsureInitializedAsync( + T obj, + ConcurrentDictionary? objectBag = null, + MethodMetadata? methodMetadata = null, + TestContextEvents? events = null, + CancellationToken cancellationToken = default) where T : notnull; +} diff --git a/TUnit.Core/Interfaces/IObjectGraphDiscoverer.cs b/TUnit.Core/Interfaces/IObjectGraphDiscoverer.cs new file mode 100644 index 0000000000..285d677643 --- /dev/null +++ b/TUnit.Core/Interfaces/IObjectGraphDiscoverer.cs @@ -0,0 +1,80 @@ +using System.Collections.Concurrent; + +namespace TUnit.Core.Interfaces; + +/// +/// Defines a contract for discovering object graphs from test contexts. +/// +/// +/// +/// Object graph discovery is used to find all objects that need initialization or disposal, +/// organized by their nesting depth in the object hierarchy. +/// +/// +/// The discoverer traverses: +/// +/// Test class constructor arguments +/// Test method arguments +/// Injected property values +/// Nested objects that implement +/// +/// +/// +public interface IObjectGraphDiscoverer +{ + /// + /// Discovers all objects from a test context, organized by depth level. + /// + /// The test context to discover objects from. + /// + /// An containing all discovered objects organized by depth. + /// Depth 0 contains root objects (arguments and property values). + /// Higher depths contain nested objects. + /// + IObjectGraph DiscoverObjectGraph(TestContext testContext); + + /// + /// Discovers nested objects from a single root object, organized by depth. + /// + /// The root object to discover nested objects from. + /// + /// An containing all discovered objects organized by depth. + /// Depth 0 contains the root object itself. + /// Higher depths contain nested objects. + /// + IObjectGraph DiscoverNestedObjectGraph(object rootObject); +} + +/// +/// Represents a discovered object graph organized by depth level. +/// +public interface IObjectGraph +{ + /// + /// Gets objects organized by depth (0 = root arguments, 1+ = nested). + /// + ConcurrentDictionary> ObjectsByDepth { get; } + + /// + /// Gets all unique objects in the graph. + /// + HashSet AllObjects { get; } + + /// + /// Gets the maximum nesting depth (-1 if empty). + /// + int MaxDepth { get; } + + /// + /// Gets objects at a specific depth level. + /// + /// The depth level to retrieve objects from. + /// An enumerable of objects at the specified depth, or empty if none exist. + IEnumerable GetObjectsAtDepth(int depth); + + /// + /// Gets depth levels in descending order (deepest first). + /// + /// An enumerable of depth levels ordered from deepest to shallowest. + IEnumerable GetDepthsDescending(); +} diff --git a/TUnit.Core/Interfaces/IObjectInitializationService.cs b/TUnit.Core/Interfaces/IObjectInitializationService.cs new file mode 100644 index 0000000000..4cf5d26bba --- /dev/null +++ b/TUnit.Core/Interfaces/IObjectInitializationService.cs @@ -0,0 +1,67 @@ +namespace TUnit.Core.Interfaces; + +/// +/// Defines a contract for managing object initialization with phase awareness. +/// +/// +/// +/// This service provides thread-safe, deduplicated initialization of objects that implement +/// or . +/// +/// +/// The service supports two initialization phases: +/// +/// Discovery phase: Only objects are initialized +/// Execution phase: All objects are initialized +/// +/// +/// +public interface IObjectInitializationService +{ + /// + /// Initializes an object during the execution phase. + /// + /// The object to initialize. If null or not an , no action is taken. + /// A token to monitor for cancellation requests. + /// A representing the asynchronous operation. + /// + /// + /// This method is thread-safe and ensures that each object is initialized exactly once. + /// Multiple concurrent calls for the same object will share the same initialization task. + /// + /// + ValueTask InitializeAsync(object? obj, CancellationToken cancellationToken = default); + + /// + /// Initializes an object during the discovery phase. + /// + /// The object to initialize. If null or not an , no action is taken. + /// A token to monitor for cancellation requests. + /// A representing the asynchronous operation. + /// + /// + /// Only objects implementing are initialized during discovery. + /// Regular objects are deferred to execution phase. + /// + /// + ValueTask InitializeForDiscoveryAsync(object? obj, CancellationToken cancellationToken = default); + + /// + /// Checks if an object has been successfully initialized. + /// + /// The object to check. + /// True if the object has been initialized successfully; otherwise, false. + /// + /// Returns false if the object is null, not an , + /// has not been initialized yet, or if initialization failed. + /// + bool IsInitialized(object? obj); + + /// + /// Clears the initialization cache. + /// + /// + /// Called at the end of a test session to release resources. + /// + void ClearCache(); +} diff --git a/TUnit.Core/ObjectInitializer.cs b/TUnit.Core/ObjectInitializer.cs index 33da8a31d6..e1a53768b4 100644 --- a/TUnit.Core/ObjectInitializer.cs +++ b/TUnit.Core/ObjectInitializer.cs @@ -1,19 +1,29 @@ -using System.Runtime.CompilerServices; +using System.Collections.Concurrent; +using TUnit.Core.Helpers; using TUnit.Core.Interfaces; +using TUnit.Core.Services; namespace TUnit.Core; /// -/// Centralized service for initializing objects that implement IAsyncInitializer. +/// Static facade for initializing objects that implement . /// Provides thread-safe, deduplicated initialization with explicit phase control. -/// -/// Use InitializeForDiscoveryAsync during test discovery - only IAsyncDiscoveryInitializer objects are initialized. -/// Use InitializeAsync during test execution - all IAsyncInitializer objects are initialized. /// +/// +/// +/// Use during test discovery - only objects are initialized. +/// Use during test execution - all objects are initialized. +/// +/// +/// For dependency injection scenarios, use directly. +/// +/// public static class ObjectInitializer { - private static readonly ConditionalWeakTable _initializationTasks = new(); - private static readonly Lock _lock = new(); + // Use ConcurrentDictionary with reference equality for thread-safe tracking + // This replaces ConditionalWeakTable which doesn't support explicit cleanup + private static readonly ConcurrentDictionary InitializationTasks = + new(new Helpers.ReferenceEqualityComparer()); /// /// Initializes an object during the discovery phase. @@ -52,8 +62,14 @@ public static ValueTask InitializeAsync(object? obj, CancellationToken cancellat } /// - /// Checks if an object has been initialized by ObjectInitializer. + /// Checks if an object has been successfully initialized by ObjectInitializer. /// + /// The object to check. + /// True if the object has been initialized successfully; otherwise, false. + /// + /// Returns false if the object is null, not an , + /// has not been initialized yet, or if initialization failed. + /// internal static bool IsInitialized(object? obj) { if (obj is not IAsyncInitializer) @@ -61,10 +77,20 @@ internal static bool IsInitialized(object? obj) return false; } - lock (_lock) - { - return _initializationTasks.TryGetValue(obj, out var task) && task.IsCompleted; - } + // Use Status == RanToCompletion to ensure we don't return true for faulted/canceled tasks + // (IsCompletedSuccessfully is not available in netstandard2.0) + return InitializationTasks.TryGetValue(obj, out var task) && task.Status == TaskStatus.RanToCompletion; + } + + /// + /// Clears the initialization cache. + /// + /// + /// Called at the end of a test session to release resources. + /// + internal static void ClearCache() + { + InitializationTasks.Clear(); } private static async ValueTask InitializeCoreAsync( @@ -72,20 +98,11 @@ private static async ValueTask InitializeCoreAsync( IAsyncInitializer asyncInitializer, CancellationToken cancellationToken) { - Task initializationTask; - - lock (_lock) - { - if (_initializationTasks.TryGetValue(obj, out var existingTask)) - { - initializationTask = existingTask; - } - else - { - initializationTask = asyncInitializer.InitializeAsync(); - _initializationTasks.Add(obj, initializationTask); - } - } + // Use GetOrAdd for thread-safe deduplication without holding lock during async + // Note: GetOrAdd's factory may be called multiple times under contention, + // but only one result is stored. Multiple InitializeAsync() calls is safe + // since we only await the winning task. + var initializationTask = InitializationTasks.GetOrAdd(obj, _ => asyncInitializer.InitializeAsync()); // Wait for initialization with cancellation support await initializationTask.WaitAsync(cancellationToken); diff --git a/TUnit.Core/PropertyInjection/PropertySetterFactory.cs b/TUnit.Core/PropertyInjection/PropertySetterFactory.cs index 3d9dce651d..8bc1662cdf 100644 --- a/TUnit.Core/PropertyInjection/PropertySetterFactory.cs +++ b/TUnit.Core/PropertyInjection/PropertySetterFactory.cs @@ -1,21 +1,56 @@ -using System.Diagnostics.CodeAnalysis; +using System.Collections.Concurrent; +using System.Diagnostics.CodeAnalysis; using System.Reflection; namespace TUnit.Core.PropertyInjection; /// -/// Factory for creating property setters. +/// Factory for creating property setters with caching for performance. /// Consolidates all property setter creation logic in one place following DRY principle. /// +/// +/// Setters are cached using the PropertyInfo as the key to avoid repeated reflection calls. +/// This significantly improves performance when the same property is accessed multiple times +/// (e.g., in test retries or shared test data scenarios). +/// internal static class PropertySetterFactory { + // Cache setters per PropertyInfo to avoid repeated reflection + private static readonly ConcurrentDictionary> SetterCache = new(); + + /// + /// Gets or creates a setter delegate for the given property. + /// Uses caching to avoid repeated reflection calls. + /// + #if NET6_0_OR_GREATER + [RequiresUnreferencedCode("Backing field access for init-only properties requires reflection")] + #endif + public static Action GetOrCreateSetter(PropertyInfo property) + { + return SetterCache.GetOrAdd(property, CreateSetterCore); + } + /// /// Creates a setter delegate for the given property. + /// Consider using for better performance through caching. /// #if NET6_0_OR_GREATER [RequiresUnreferencedCode("Backing field access for init-only properties requires reflection")] #endif public static Action CreateSetter(PropertyInfo property) + { + // Delegate to cached version for consistency + return GetOrCreateSetter(property); + } + + /// + /// Core implementation for creating a setter delegate. + /// Called by GetOrCreateSetter for caching. + /// + #if NET6_0_OR_GREATER + [RequiresUnreferencedCode("Backing field access for init-only properties requires reflection")] + #endif + private static Action CreateSetterCore(PropertyInfo property) { if (property.CanWrite && property.SetMethod != null) { diff --git a/TUnit.Core/Services/ObjectInitializationService.cs b/TUnit.Core/Services/ObjectInitializationService.cs new file mode 100644 index 0000000000..8264e3d326 --- /dev/null +++ b/TUnit.Core/Services/ObjectInitializationService.cs @@ -0,0 +1,88 @@ +using System.Collections.Concurrent; +using TUnit.Core.Helpers; +using TUnit.Core.Interfaces; + +namespace TUnit.Core.Services; + +/// +/// Thread-safe service for initializing objects that implement . +/// Provides deduplicated initialization with explicit phase control. +/// +/// +/// +/// This service replaces the static for dependency injection scenarios. +/// It uses with reference equality for thread-safe +/// deduplication without lock contention during async operations. +/// +/// +public sealed class ObjectInitializationService : IObjectInitializationService +{ + // Use ConcurrentDictionary with reference equality for thread-safe tracking + // This replaces ConditionalWeakTable which doesn't support explicit cleanup + private readonly ConcurrentDictionary _initializationTasks; + + /// + /// Creates a new instance of the initialization service. + /// + public ObjectInitializationService() + { + _initializationTasks = new ConcurrentDictionary(new Helpers.ReferenceEqualityComparer()); + } + + /// + public ValueTask InitializeForDiscoveryAsync(object? obj, CancellationToken cancellationToken = default) + { + // During discovery, only initialize IAsyncDiscoveryInitializer + if (obj is not IAsyncDiscoveryInitializer asyncDiscoveryInitializer) + { + return default; + } + + return InitializeCoreAsync(obj, asyncDiscoveryInitializer, cancellationToken); + } + + /// + public ValueTask InitializeAsync(object? obj, CancellationToken cancellationToken = default) + { + if (obj is not IAsyncInitializer asyncInitializer) + { + return default; + } + + return InitializeCoreAsync(obj, asyncInitializer, cancellationToken); + } + + /// + public bool IsInitialized(object? obj) + { + if (obj is not IAsyncInitializer) + { + return false; + } + + // Use Status == RanToCompletion to ensure we don't return true for faulted/canceled tasks + // (IsCompletedSuccessfully is not available in netstandard2.0) + return _initializationTasks.TryGetValue(obj, out var task) && task.Status == TaskStatus.RanToCompletion; + } + + /// + public void ClearCache() + { + _initializationTasks.Clear(); + } + + private async ValueTask InitializeCoreAsync( + object obj, + IAsyncInitializer asyncInitializer, + CancellationToken cancellationToken) + { + // Use GetOrAdd for thread-safe deduplication without holding lock during async + // Note: GetOrAdd's factory may be called multiple times under contention, + // but only one result is stored. Multiple InitializeAsync() calls is safe + // since we only await the winning task. + var initializationTask = _initializationTasks.GetOrAdd(obj, _ => asyncInitializer.InitializeAsync()); + + // Wait for initialization with cancellation support + await initializationTask.WaitAsync(cancellationToken); + } +} diff --git a/TUnit.Core/Tracking/TrackableObjectGraphProvider.cs b/TUnit.Core/Tracking/TrackableObjectGraphProvider.cs index 8f1e5bb7a1..585434ff2d 100644 --- a/TUnit.Core/Tracking/TrackableObjectGraphProvider.cs +++ b/TUnit.Core/Tracking/TrackableObjectGraphProvider.cs @@ -1,176 +1,71 @@ -using System.Collections.Concurrent; -using System.Diagnostics.CodeAnalysis; -using System.Reflection; +using System.Collections.Concurrent; +using TUnit.Core.Discovery; using TUnit.Core.Interfaces; -using TUnit.Core.PropertyInjection; using TUnit.Core.StaticProperties; namespace TUnit.Core.Tracking; +/// +/// Provides trackable objects from test contexts for lifecycle management. +/// Delegates to for the actual discovery logic. +/// internal class TrackableObjectGraphProvider { - public ConcurrentDictionary> GetTrackableObjects(TestContext testContext) - { - var visitedObjects = testContext.TrackedObjects; - - var testDetails = testContext.Metadata.TestDetails; - - foreach (var classArgument in testDetails.TestClassArguments) - { - if (classArgument != null && visitedObjects.GetOrAdd(0, []).Add(classArgument)) - { - AddNestedTrackableObjects(classArgument, visitedObjects, 1); - } - } + private readonly IObjectGraphDiscoverer _discoverer; - foreach (var methodArgument in testDetails.TestMethodArguments) - { - if (methodArgument != null && visitedObjects.GetOrAdd(0, []).Add(methodArgument)) - { - AddNestedTrackableObjects(methodArgument, visitedObjects, 1); - } - } - - foreach (var property in testDetails.TestClassInjectedPropertyArguments.Values) - { - if (property != null && visitedObjects.GetOrAdd(0, []).Add(property)) - { - AddNestedTrackableObjects(property, visitedObjects, 1); - } - } - - return visitedObjects; - } - - private static void AddToLevel(Dictionary> objectsByLevel, int level, object obj) + /// + /// Creates a new instance with the default discoverer. + /// + public TrackableObjectGraphProvider() : this(new ObjectGraphDiscoverer()) { - if (!objectsByLevel.TryGetValue(level, out var list)) - { - list = []; - objectsByLevel[level] = list; - } - list.Add(obj); } /// - /// Get trackable objects for static properties (session-level) + /// Creates a new instance with a custom discoverer (for testing). /// - public IEnumerable GetStaticPropertyTrackableObjects() + public TrackableObjectGraphProvider(IObjectGraphDiscoverer discoverer) { - foreach (var value in StaticPropertyRegistry.GetAllInitializedValues()) - { - if (value != null) - { - yield return value; - } - } + _discoverer = discoverer; } - [UnconditionalSuppressMessage("Trimming", "IL2026", Justification = "Property discovery handles both AOT and reflection modes")] - private void AddNestedTrackableObjects(object obj, ConcurrentDictionary> visitedObjects, int currentDepth) + /// + /// Gets trackable objects from a test context, organized by depth level. + /// Delegates to the shared ObjectGraphDiscoverer to eliminate code duplication. + /// + public ConcurrentDictionary> GetTrackableObjects(TestContext testContext) { - var plan = PropertyInjectionCache.GetOrCreatePlan(obj.GetType()); - - if(!SourceRegistrar.IsEnabled) + // Use the ObjectGraphDiscoverer's specialized method that populates TrackedObjects directly + if (_discoverer is ObjectGraphDiscoverer concreteDiscoverer) { - foreach (var prop in plan.ReflectionProperties) - { - var value = prop.Property.GetValue(obj); - - if (value == null) - { - continue; - } + return concreteDiscoverer.DiscoverAndTrackObjects(testContext); + } - // Check if already visited before yielding to prevent duplicates - if (!visitedObjects.GetOrAdd(currentDepth, []).Add(value)) - { - continue; - } + // Fallback for custom implementations (testing) + var graph = _discoverer.DiscoverObjectGraph(testContext); + var trackedObjects = testContext.TrackedObjects; - AddNestedTrackableObjects(value, visitedObjects, currentDepth + 1); - } - } - else + foreach (var (depth, objects) in graph.ObjectsByDepth) { - foreach (var metadata in plan.SourceGeneratedProperties) + var depthSet = trackedObjects.GetOrAdd(depth, _ => []); + foreach (var obj in objects) { - var property = metadata.ContainingType.GetProperty(metadata.PropertyName); - - if (property == null || !property.CanRead) - { - continue; - } - - var value = property.GetValue(obj); - - if (value == null) - { - continue; - } - - // Check if already visited before yielding to prevent duplicates - if (!visitedObjects.GetOrAdd(currentDepth, []).Add(value)) - { - continue; - } - - AddNestedTrackableObjects(value, visitedObjects, currentDepth + 1); + depthSet.Add(obj); } } - // Also discover nested IAsyncInitializer objects from ALL properties - // This handles cases where nested objects don't have data source attributes - // but still implement IAsyncInitializer and need to be tracked for disposal - AddNestedInitializerObjects(obj, visitedObjects, currentDepth); + return trackedObjects; } /// - /// Discovers nested objects that implement IAsyncInitializer from all readable properties. - /// This is separate from injectable property discovery to handle objects without data source attributes. - /// This is a best-effort fallback - in AOT scenarios, properties with data source attributes are discovered via source generation. + /// Gets trackable objects for static properties (session-level). /// - [UnconditionalSuppressMessage("Trimming", "IL2026", Justification = "Reflection fallback for nested initializers. In AOT, source-gen handles primary discovery.")] - [UnconditionalSuppressMessage("Trimming", "IL2070", Justification = "Reflection fallback for nested initializers. In AOT, source-gen handles primary discovery.")] - [UnconditionalSuppressMessage("Trimming", "IL2075", Justification = "Reflection fallback for nested initializers. In AOT, source-gen handles primary discovery.")] - private void AddNestedInitializerObjects(object obj, ConcurrentDictionary> visitedObjects, int currentDepth) + public IEnumerable GetStaticPropertyTrackableObjects() { - var type = obj.GetType(); - - // Skip primitive types, strings, and system types - if (type.IsPrimitive || type == typeof(string) || type.Namespace?.StartsWith("System") == true) - { - return; - } - - // Get all readable instance properties - var properties = type.GetProperties(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance); - - foreach (var property in properties) + foreach (var value in StaticPropertyRegistry.GetAllInitializedValues()) { - if (!property.CanRead || property.GetIndexParameters().Length > 0) - { - continue; - } - - try - { - var value = property.GetValue(obj); - if (value == null) - { - continue; - } - - // Only discover if it implements IAsyncInitializer and hasn't been visited - if (value is IAsyncInitializer && visitedObjects.GetOrAdd(currentDepth, []).Add(value)) - { - // Recursively discover nested objects - AddNestedTrackableObjects(value, visitedObjects, currentDepth + 1); - } - } - catch + if (value != null) { - // Ignore properties that throw exceptions when accessed + yield return value; } } } diff --git a/TUnit.Engine/Framework/TUnitServiceProvider.cs b/TUnit.Engine/Framework/TUnitServiceProvider.cs index f64d64d004..9a31e5eb99 100644 --- a/TUnit.Engine/Framework/TUnitServiceProvider.cs +++ b/TUnit.Engine/Framework/TUnitServiceProvider.cs @@ -114,9 +114,11 @@ public TUnitServiceProvider(IExtension extension, var objectTracker = new ObjectTracker(trackableObjectGraphProvider, disposer); // Use Lazy to break circular dependency between PropertyInjector and ObjectLifecycleService + // PropertyInjector now depends on IInitializationCallback interface (implemented by ObjectLifecycleService) + // This follows Dependency Inversion Principle and improves testability ObjectLifecycleService? objectLifecycleServiceInstance = null; - var lazyObjectLifecycleService = new Lazy(() => objectLifecycleServiceInstance!); - var lazyPropertyInjector = new Lazy(() => new PropertyInjector(lazyObjectLifecycleService, TestSessionId)); + var lazyInitializationCallback = new Lazy(() => objectLifecycleServiceInstance!); + var lazyPropertyInjector = new Lazy(() => new PropertyInjector(lazyInitializationCallback, TestSessionId)); objectLifecycleServiceInstance = new ObjectLifecycleService(lazyPropertyInjector, objectGraphDiscoveryService, objectTracker); ObjectLifecycleService = Register(objectLifecycleServiceInstance); diff --git a/TUnit.Engine/Services/ObjectGraphDiscoveryService.cs b/TUnit.Engine/Services/ObjectGraphDiscoveryService.cs index 4819a275f7..c7ddb36ffd 100644 --- a/TUnit.Engine/Services/ObjectGraphDiscoveryService.cs +++ b/TUnit.Engine/Services/ObjectGraphDiscoveryService.cs @@ -1,258 +1,45 @@ -using System.Collections.Concurrent; -using System.Diagnostics.CodeAnalysis; -using System.Reflection; using TUnit.Core; +using TUnit.Core.Discovery; using TUnit.Core.Interfaces; -using TUnit.Core.PropertyInjection; namespace TUnit.Engine.Services; /// -/// Centralized service for discovering and organizing object graphs. -/// Eliminates duplicate graph traversal logic that was scattered across -/// PropertyInjectionService, DataSourceInitializer, and TrackableObjectGraphProvider. -/// Follows Single Responsibility Principle - only discovers objects, doesn't modify them. +/// Service for discovering and organizing object graphs in TUnit.Engine. +/// Delegates to in TUnit.Core for the actual discovery logic. /// internal sealed class ObjectGraphDiscoveryService { - /// - /// Discovers all objects from test context arguments and properties, organized by depth level. - /// Depth 0 = root objects (class args, method args, property values) - /// Depth 1+ = nested objects found in properties of objects at previous depth - /// - public ObjectGraph DiscoverObjectGraph(TestContext testContext) - { - var objectsByDepth = new ConcurrentDictionary>(); - var allObjects = new HashSet(); - var visitedObjects = new HashSet(); - - var testDetails = testContext.Metadata.TestDetails; - - // Collect root-level objects (depth 0) - foreach (var classArgument in testDetails.TestClassArguments) - { - if (classArgument != null && visitedObjects.Add(classArgument)) - { - AddToDepth(objectsByDepth, 0, classArgument); - allObjects.Add(classArgument); - DiscoverNestedObjects(classArgument, objectsByDepth, visitedObjects, allObjects, currentDepth: 1); - } - } - - foreach (var methodArgument in testDetails.TestMethodArguments) - { - if (methodArgument != null && visitedObjects.Add(methodArgument)) - { - AddToDepth(objectsByDepth, 0, methodArgument); - allObjects.Add(methodArgument); - DiscoverNestedObjects(methodArgument, objectsByDepth, visitedObjects, allObjects, currentDepth: 1); - } - } - - foreach (var property in testDetails.TestClassInjectedPropertyArguments.Values) - { - if (property != null && visitedObjects.Add(property)) - { - AddToDepth(objectsByDepth, 0, property); - allObjects.Add(property); - DiscoverNestedObjects(property, objectsByDepth, visitedObjects, allObjects, currentDepth: 1); - } - } - - return new ObjectGraph(objectsByDepth, allObjects); - } - - /// - /// Discovers nested objects from a single root object, organized by depth. - /// Used for discovering objects within a data source or property value. - /// - public ObjectGraph DiscoverNestedObjectGraph(object rootObject) - { - var objectsByDepth = new ConcurrentDictionary>(); - var allObjects = new HashSet(); - var visitedObjects = new HashSet(); - - if (visitedObjects.Add(rootObject)) - { - AddToDepth(objectsByDepth, 0, rootObject); - allObjects.Add(rootObject); - DiscoverNestedObjects(rootObject, objectsByDepth, visitedObjects, allObjects, currentDepth: 1); - } - - return new ObjectGraph(objectsByDepth, allObjects); - } + private readonly IObjectGraphDiscoverer _discoverer; /// - /// Recursively discovers nested objects that have injectable properties OR implement IAsyncInitializer. - /// This ensures that all nested objects that need initialization are discovered, - /// even if they don't have explicit data source attributes. + /// Creates a new instance with the default discoverer. /// - [UnconditionalSuppressMessage("Trimming", "IL2026", Justification = "Property discovery handles both AOT and reflection modes")] - private void DiscoverNestedObjects( - object obj, - ConcurrentDictionary> objectsByDepth, - HashSet visitedObjects, - HashSet allObjects, - int currentDepth) + public ObjectGraphDiscoveryService() : this(new ObjectGraphDiscoverer()) { - var plan = PropertyInjectionCache.GetOrCreatePlan(obj.GetType()); - - // First, discover objects from injectable properties (data source attributes) - if (plan.HasProperties) - { - // Use source-generated properties if available, otherwise fall back to reflection - if (plan.SourceGeneratedProperties.Length > 0) - { - foreach (var metadata in plan.SourceGeneratedProperties) - { - var property = metadata.ContainingType.GetProperty(metadata.PropertyName); - if (property == null || !property.CanRead) - { - continue; - } - - var value = property.GetValue(obj); - if (value == null || !visitedObjects.Add(value)) - { - continue; - } - - AddToDepth(objectsByDepth, currentDepth, value); - allObjects.Add(value); - - // Recursively discover nested objects - DiscoverNestedObjects(value, objectsByDepth, visitedObjects, allObjects, currentDepth + 1); - } - } - else if (plan.ReflectionProperties.Length > 0) - { - foreach (var (property, _) in plan.ReflectionProperties) - { - var value = property.GetValue(obj); - if (value == null || !visitedObjects.Add(value)) - { - continue; - } - - AddToDepth(objectsByDepth, currentDepth, value); - allObjects.Add(value); - - // Recursively discover nested objects - DiscoverNestedObjects(value, objectsByDepth, visitedObjects, allObjects, currentDepth + 1); - } - } - } - - // Also discover nested IAsyncInitializer objects from ALL properties - // This handles cases where nested objects don't have data source attributes - // but still implement IAsyncInitializer and need to be initialized - DiscoverNestedInitializerObjects(obj, objectsByDepth, visitedObjects, allObjects, currentDepth); } /// - /// Discovers nested objects that implement IAsyncInitializer from all readable properties. - /// This is separate from injectable property discovery to handle objects without data source attributes. - /// This is a best-effort fallback - in AOT scenarios, properties with data source attributes are discovered via source generation. + /// Creates a new instance with a custom discoverer (for testing). /// - [UnconditionalSuppressMessage("Trimming", "IL2026", Justification = "Reflection fallback for nested initializers. In AOT, source-gen handles primary discovery.")] - [UnconditionalSuppressMessage("Trimming", "IL2070", Justification = "Reflection fallback for nested initializers. In AOT, source-gen handles primary discovery.")] - [UnconditionalSuppressMessage("Trimming", "IL2075", Justification = "Reflection fallback for nested initializers. In AOT, source-gen handles primary discovery.")] - private void DiscoverNestedInitializerObjects( - object obj, - ConcurrentDictionary> objectsByDepth, - HashSet visitedObjects, - HashSet allObjects, - int currentDepth) - { - var type = obj.GetType(); - - // Skip primitive types, strings, and system types - if (type.IsPrimitive || type == typeof(string) || type.Namespace?.StartsWith("System") == true) - { - return; - } - - // Get all readable instance properties - var properties = type.GetProperties(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance); - - foreach (var property in properties) - { - if (!property.CanRead || property.GetIndexParameters().Length > 0) - { - continue; - } - - try - { - var value = property.GetValue(obj); - if (value == null) - { - continue; - } - - // Only discover if it implements IAsyncInitializer and hasn't been visited - if (value is IAsyncInitializer && visitedObjects.Add(value)) - { - AddToDepth(objectsByDepth, currentDepth, value); - allObjects.Add(value); - - // Recursively discover nested objects - DiscoverNestedObjects(value, objectsByDepth, visitedObjects, allObjects, currentDepth + 1); - } - } - catch - { - // Ignore properties that throw exceptions when accessed - } - } - } - - private static void AddToDepth(ConcurrentDictionary> objectsByDepth, int depth, object obj) - { - objectsByDepth.GetOrAdd(depth, _ => []).Add(obj); - } -} - -/// -/// Represents a discovered object graph organized by depth level. -/// -internal sealed class ObjectGraph -{ - public ObjectGraph(ConcurrentDictionary> objectsByDepth, HashSet allObjects) + public ObjectGraphDiscoveryService(IObjectGraphDiscoverer discoverer) { - ObjectsByDepth = objectsByDepth; - AllObjects = allObjects; - MaxDepth = objectsByDepth.Count > 0 ? objectsByDepth.Keys.Max() : -1; + _discoverer = discoverer; } /// - /// Objects organized by depth (0 = root arguments, 1+ = nested). - /// - public ConcurrentDictionary> ObjectsByDepth { get; } - - /// - /// All unique objects in the graph. - /// - public HashSet AllObjects { get; } - - /// - /// Maximum nesting depth (-1 if empty). - /// - public int MaxDepth { get; } - - /// - /// Gets objects at a specific depth level. + /// Discovers all objects from test context arguments and properties, organized by depth level. /// - public IEnumerable GetObjectsAtDepth(int depth) + public IObjectGraph DiscoverObjectGraph(TestContext testContext) { - return ObjectsByDepth.TryGetValue(depth, out var objects) ? objects : []; + return _discoverer.DiscoverObjectGraph(testContext); } /// - /// Gets depth levels in descending order (deepest first). + /// Discovers nested objects from a single root object, organized by depth. /// - public IEnumerable GetDepthsDescending() + public IObjectGraph DiscoverNestedObjectGraph(object rootObject) { - return ObjectsByDepth.Keys.OrderByDescending(d => d); + return _discoverer.DiscoverNestedObjectGraph(rootObject); } } diff --git a/TUnit.Engine/Services/ObjectLifecycleService.cs b/TUnit.Engine/Services/ObjectLifecycleService.cs index a18938dd29..e3c5938536 100644 --- a/TUnit.Engine/Services/ObjectLifecycleService.cs +++ b/TUnit.Engine/Services/ObjectLifecycleService.cs @@ -17,7 +17,11 @@ namespace TUnit.Engine.Services; /// Uses Lazy<T> for dependencies to break circular references without manual Initialize() calls. /// Follows clear phase separation: Register → Inject → Initialize → Cleanup. /// -internal sealed class ObjectLifecycleService : IObjectRegistry +/// +/// Implements to allow PropertyInjector to call back for initialization +/// without creating a direct dependency (breaking the circular reference pattern). +/// +internal sealed class ObjectLifecycleService : IObjectRegistry, IInitializationCallback { private readonly Lazy _propertyInjector; private readonly ObjectGraphDiscoveryService _objectGraphDiscoveryService; diff --git a/TUnit.Engine/Services/PropertyInjector.cs b/TUnit.Engine/Services/PropertyInjector.cs index 54cbcf6573..4de82d7d5f 100644 --- a/TUnit.Engine/Services/PropertyInjector.cs +++ b/TUnit.Engine/Services/PropertyInjector.cs @@ -14,17 +14,21 @@ namespace TUnit.Engine.Services; /// Follows Single Responsibility Principle - only injects property values, doesn't initialize objects. /// Uses Lazy initialization to break circular dependencies without manual Initialize() calls. /// +/// +/// Depends on rather than a concrete service, +/// enabling testability and following Dependency Inversion Principle. +/// internal sealed class PropertyInjector { - private readonly Lazy _objectLifecycleService; + private readonly Lazy _initializationCallback; private readonly string _testSessionId; // Object pool for visited dictionaries to reduce allocations private static readonly ConcurrentBag> _visitedObjectsPool = new(); - public PropertyInjector(Lazy objectLifecycleService, string testSessionId) + public PropertyInjector(Lazy initializationCallback, string testSessionId) { - _objectLifecycleService = objectLifecycleService; + _initializationCallback = initializationCallback; _testSessionId = testSessionId; } @@ -528,7 +532,7 @@ private async Task ResolveAndCacheReflectionPropertyAsync( // EnsureInitializedAsync handles property injection and initialization. // ObjectInitializer is phase-aware: during Discovery phase, only IAsyncDiscoveryInitializer // objects are initialized; regular IAsyncInitializer objects are deferred to Execution phase. - await _objectLifecycleService.Value.EnsureInitializedAsync( + await _initializationCallback.Value.EnsureInitializedAsync( value, context.ObjectBag, context.MethodMetadata, @@ -560,7 +564,7 @@ await _objectLifecycleService.Value.EnsureInitializedAsync( } // Ensure the data source is initialized - return await _objectLifecycleService.Value.EnsureInitializedAsync( + return await _initializationCallback.Value.EnsureInitializedAsync( dataSource, context.ObjectBag, context.MethodMetadata, From 0041ee5085ac0828ca2be53aee2af78833b1f5b2 Mon Sep 17 00:00:00 2001 From: Tom Longhurst <30480171+thomhurst@users.noreply.github.com> Date: Sun, 7 Dec 2025 14:22:07 +0000 Subject: [PATCH 04/20] feat: implement IAsyncDiscoveryInitializer and related classes for improved test discovery handling --- .../TestBuildContextOutputCaptureTests.cs | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/TUnit.TestProject/TestBuildContextOutputCaptureTests.cs b/TUnit.TestProject/TestBuildContextOutputCaptureTests.cs index cb11a76453..246e622391 100644 --- a/TUnit.TestProject/TestBuildContextOutputCaptureTests.cs +++ b/TUnit.TestProject/TestBuildContextOutputCaptureTests.cs @@ -30,9 +30,12 @@ public DataSourceWithConstructorOutput() } /// - /// Data source that writes to console in async initializer + /// Data source that writes to console in async initializer. + /// Uses IAsyncDiscoveryInitializer so it initializes during test discovery/building, + /// allowing the output to be captured in the test's build context. + /// Note: Regular IAsyncInitializer only runs during test execution (per issue #3992 fix). /// - public class DataSourceWithAsyncInitOutput : IAsyncInitializer + public class DataSourceWithAsyncInitOutput : IAsyncDiscoveryInitializer { public string Value { get; private set; } = "Uninitialized"; @@ -88,8 +91,9 @@ public async Task Test_CapturesConstructorOutput_InTestResults(DataSourceWithCon [ClassDataSource] public async Task Test_CapturesAsyncInitializerOutput_InTestResults(DataSourceWithAsyncInitOutput data) { - // The InitializeAsync output should be captured during test building - // and included in the test's output + // The InitializeAsync output should be captured during test building. + // Note: This uses IAsyncDiscoveryInitializer which runs during discovery. + // Regular IAsyncInitializer runs during execution only (per issue #3992 fix). // Get the test output var output = TestContext.Current!.GetStandardOutput(); From 8e97f3584e1b8f8fa361752205cf0087ea4b776b Mon Sep 17 00:00:00 2001 From: Tom Longhurst <30480171+thomhurst@users.noreply.github.com> Date: Sun, 7 Dec 2025 14:50:12 +0000 Subject: [PATCH 05/20] feat: implement IAsyncDiscoveryInitializer and related classes for improved test discovery handling --- TUnit.Core/Discovery/ObjectGraph.cs | 42 ++++++-- TUnit.Core/Discovery/ObjectGraphDiscoverer.cs | 102 +++++++++++++++--- .../Interfaces/IObjectGraphDiscoverer.cs | 20 +++- TUnit.Core/ObjectInitializer.cs | 43 ++++++-- .../Services/ObjectInitializationService.cs | 62 ++--------- .../Services/ObjectGraphDiscoveryService.cs | 8 +- .../Services/ObjectLifecycleService.cs | 29 +++-- 7 files changed, 197 insertions(+), 109 deletions(-) diff --git a/TUnit.Core/Discovery/ObjectGraph.cs b/TUnit.Core/Discovery/ObjectGraph.cs index be8f8315bb..458cb5a50b 100644 --- a/TUnit.Core/Discovery/ObjectGraph.cs +++ b/TUnit.Core/Discovery/ObjectGraph.cs @@ -1,4 +1,5 @@ using System.Collections.Concurrent; +using System.Collections.ObjectModel; using TUnit.Core.Interfaces; namespace TUnit.Core.Discovery; @@ -6,8 +7,19 @@ namespace TUnit.Core.Discovery; /// /// Represents a discovered object graph organized by depth level. /// +/// +/// Internal collections are stored privately and exposed as read-only views +/// to prevent callers from corrupting internal state. +/// public sealed class ObjectGraph : IObjectGraph { + private readonly ConcurrentDictionary> _objectsByDepth; + private readonly HashSet _allObjects; + + // Cached read-only views (created lazily on first access) + private IReadOnlyDictionary>? _readOnlyObjectsByDepth; + private IReadOnlyCollection? _readOnlyAllObjects; + /// /// Creates a new object graph from the discovered objects. /// @@ -15,17 +27,35 @@ public sealed class ObjectGraph : IObjectGraph /// All unique objects in the graph. public ObjectGraph(ConcurrentDictionary> objectsByDepth, HashSet allObjects) { - ObjectsByDepth = objectsByDepth; - AllObjects = allObjects; + _objectsByDepth = objectsByDepth; + _allObjects = allObjects; // Use IsEmpty for thread-safe check before accessing Keys MaxDepth = objectsByDepth.IsEmpty ? -1 : objectsByDepth.Keys.Max(); } /// - public ConcurrentDictionary> ObjectsByDepth { get; } + public IReadOnlyDictionary> ObjectsByDepth + { + get + { + // Create read-only view lazily and cache it + // Note: This creates a snapshot - subsequent modifications to internal collections won't be reflected + return _readOnlyObjectsByDepth ??= new ReadOnlyDictionary>( + _objectsByDepth.ToDictionary( + kvp => kvp.Key, + kvp => (IReadOnlyCollection)kvp.Value.ToArray())); + } + } /// - public HashSet AllObjects { get; } + public IReadOnlyCollection AllObjects + { + get + { + // Create read-only view lazily and cache it + return _readOnlyAllObjects ??= _allObjects.ToArray(); + } + } /// public int MaxDepth { get; } @@ -33,12 +63,12 @@ public ObjectGraph(ConcurrentDictionary> objectsByDepth, Ha /// public IEnumerable GetObjectsAtDepth(int depth) { - return ObjectsByDepth.TryGetValue(depth, out var objects) ? objects : []; + return _objectsByDepth.TryGetValue(depth, out var objects) ? objects : []; } /// public IEnumerable GetDepthsDescending() { - return ObjectsByDepth.Keys.OrderByDescending(d => d); + return _objectsByDepth.Keys.OrderByDescending(d => d); } } diff --git a/TUnit.Core/Discovery/ObjectGraphDiscoverer.cs b/TUnit.Core/Discovery/ObjectGraphDiscoverer.cs index 445dd3cc8b..cff7be33fd 100644 --- a/TUnit.Core/Discovery/ObjectGraphDiscoverer.cs +++ b/TUnit.Core/Discovery/ObjectGraphDiscoverer.cs @@ -25,6 +25,18 @@ namespace TUnit.Core.Discovery; /// public sealed class ObjectGraphDiscoverer : IObjectGraphDiscoverer { + /// + /// Maximum recursion depth for object graph discovery. + /// Prevents stack overflow on deep or circular object graphs. + /// + private const int MaxRecursionDepth = 50; + + /// + /// Maximum size for the property cache before cleanup is triggered. + /// Prevents unbounded memory growth in long-running test sessions. + /// + private const int MaxCacheSize = 10000; + // Cache for GetProperties() results per type - eliminates repeated reflection calls private static readonly ConcurrentDictionary PropertyCache = new(); @@ -43,7 +55,7 @@ public sealed class ObjectGraphDiscoverer : IObjectGraphDiscoverer ]; /// - public IObjectGraph DiscoverObjectGraph(TestContext testContext) + public IObjectGraph DiscoverObjectGraph(TestContext testContext, CancellationToken cancellationToken = default) { var objectsByDepth = new ConcurrentDictionary>(); var allObjects = new HashSet(); @@ -55,31 +67,34 @@ public IObjectGraph DiscoverObjectGraph(TestContext testContext) // Collect root-level objects (depth 0) foreach (var classArgument in testDetails.TestClassArguments) { + cancellationToken.ThrowIfCancellationRequested(); if (classArgument != null && visitedObjects.TryAdd(classArgument, 0)) { AddToDepth(objectsByDepth, 0, classArgument); allObjects.Add(classArgument); - DiscoverNestedObjects(classArgument, objectsByDepth, visitedObjects, allObjects, currentDepth: 1); + DiscoverNestedObjects(classArgument, objectsByDepth, visitedObjects, allObjects, currentDepth: 1, cancellationToken); } } foreach (var methodArgument in testDetails.TestMethodArguments) { + cancellationToken.ThrowIfCancellationRequested(); if (methodArgument != null && visitedObjects.TryAdd(methodArgument, 0)) { AddToDepth(objectsByDepth, 0, methodArgument); allObjects.Add(methodArgument); - DiscoverNestedObjects(methodArgument, objectsByDepth, visitedObjects, allObjects, currentDepth: 1); + DiscoverNestedObjects(methodArgument, objectsByDepth, visitedObjects, allObjects, currentDepth: 1, cancellationToken); } } foreach (var property in testDetails.TestClassInjectedPropertyArguments.Values) { + cancellationToken.ThrowIfCancellationRequested(); if (property != null && visitedObjects.TryAdd(property, 0)) { AddToDepth(objectsByDepth, 0, property); allObjects.Add(property); - DiscoverNestedObjects(property, objectsByDepth, visitedObjects, allObjects, currentDepth: 1); + DiscoverNestedObjects(property, objectsByDepth, visitedObjects, allObjects, currentDepth: 1, cancellationToken); } } @@ -87,7 +102,7 @@ public IObjectGraph DiscoverObjectGraph(TestContext testContext) } /// - public IObjectGraph DiscoverNestedObjectGraph(object rootObject) + public IObjectGraph DiscoverNestedObjectGraph(object rootObject, CancellationToken cancellationToken = default) { var objectsByDepth = new ConcurrentDictionary>(); var allObjects = new HashSet(); @@ -97,7 +112,7 @@ public IObjectGraph DiscoverNestedObjectGraph(object rootObject) { AddToDepth(objectsByDepth, 0, rootObject); allObjects.Add(rootObject); - DiscoverNestedObjects(rootObject, objectsByDepth, visitedObjects, allObjects, currentDepth: 1); + DiscoverNestedObjects(rootObject, objectsByDepth, visitedObjects, allObjects, currentDepth: 1, cancellationToken); } return new ObjectGraph(objectsByDepth, allObjects); @@ -150,8 +165,18 @@ private void DiscoverNestedObjects( ConcurrentDictionary> objectsByDepth, ConcurrentDictionary visitedObjects, HashSet allObjects, - int currentDepth) + int currentDepth, + CancellationToken cancellationToken) { + // Guard against excessive recursion to prevent stack overflow + if (currentDepth > MaxRecursionDepth) + { + Debug.WriteLine($"[ObjectGraphDiscoverer] Max recursion depth ({MaxRecursionDepth}) reached for type '{obj.GetType().Name}'"); + return; + } + + cancellationToken.ThrowIfCancellationRequested(); + var plan = PropertyInjectionCache.GetOrCreatePlan(obj.GetType()); // First, discover objects from injectable properties (data source attributes) @@ -162,6 +187,7 @@ private void DiscoverNestedObjects( { foreach (var metadata in plan.SourceGeneratedProperties) { + cancellationToken.ThrowIfCancellationRequested(); var property = metadata.ContainingType.GetProperty(metadata.PropertyName); if (property == null || !property.CanRead) { @@ -176,13 +202,14 @@ private void DiscoverNestedObjects( AddToDepth(objectsByDepth, currentDepth, value); allObjects.Add(value); - DiscoverNestedObjects(value, objectsByDepth, visitedObjects, allObjects, currentDepth + 1); + DiscoverNestedObjects(value, objectsByDepth, visitedObjects, allObjects, currentDepth + 1, cancellationToken); } } else if (plan.ReflectionProperties.Length > 0) { foreach (var (property, _) in plan.ReflectionProperties) { + cancellationToken.ThrowIfCancellationRequested(); var value = property.GetValue(obj); if (value == null || !visitedObjects.TryAdd(value, 0)) { @@ -191,13 +218,13 @@ private void DiscoverNestedObjects( AddToDepth(objectsByDepth, currentDepth, value); allObjects.Add(value); - DiscoverNestedObjects(value, objectsByDepth, visitedObjects, allObjects, currentDepth + 1); + DiscoverNestedObjects(value, objectsByDepth, visitedObjects, allObjects, currentDepth + 1, cancellationToken); } } } // Also discover nested IAsyncInitializer objects from ALL properties - DiscoverNestedInitializerObjects(obj, objectsByDepth, visitedObjects, allObjects, currentDepth); + DiscoverNestedInitializerObjects(obj, objectsByDepth, visitedObjects, allObjects, currentDepth, cancellationToken); } /// @@ -209,6 +236,13 @@ private void DiscoverNestedObjectsForTracking( ConcurrentDictionary> visitedObjects, int currentDepth) { + // Guard against excessive recursion to prevent stack overflow + if (currentDepth > MaxRecursionDepth) + { + Debug.WriteLine($"[ObjectGraphDiscoverer] Max recursion depth ({MaxRecursionDepth}) reached for type '{obj.GetType().Name}'"); + return; + } + var plan = PropertyInjectionCache.GetOrCreatePlan(obj.GetType()); // Check SourceRegistrar.IsEnabled for compatibility with existing TrackableObjectGraphProvider behavior @@ -271,8 +305,18 @@ private void DiscoverNestedInitializerObjects( ConcurrentDictionary> objectsByDepth, ConcurrentDictionary visitedObjects, HashSet allObjects, - int currentDepth) + int currentDepth, + CancellationToken cancellationToken) { + // Guard against excessive recursion to prevent stack overflow + if (currentDepth > MaxRecursionDepth) + { + Debug.WriteLine($"[ObjectGraphDiscoverer] Max recursion depth ({MaxRecursionDepth}) reached for type '{obj.GetType().Name}'"); + return; + } + + cancellationToken.ThrowIfCancellationRequested(); + var type = obj.GetType(); // Skip types that don't need discovery @@ -286,6 +330,7 @@ private void DiscoverNestedInitializerObjects( foreach (var property in properties) { + cancellationToken.ThrowIfCancellationRequested(); try { var value = property.GetValue(obj); @@ -299,9 +344,13 @@ private void DiscoverNestedInitializerObjects( { AddToDepth(objectsByDepth, currentDepth, value); allObjects.Add(value); - DiscoverNestedObjects(value, objectsByDepth, visitedObjects, allObjects, currentDepth + 1); + DiscoverNestedObjects(value, objectsByDepth, visitedObjects, allObjects, currentDepth + 1, cancellationToken); } } + catch (OperationCanceledException) + { + throw; // Propagate cancellation + } catch (Exception ex) { // Log instead of silently swallowing - helps with debugging @@ -321,6 +370,13 @@ private void DiscoverNestedInitializerObjectsForTracking( ConcurrentDictionary> visitedObjects, int currentDepth) { + // Guard against excessive recursion to prevent stack overflow + if (currentDepth > MaxRecursionDepth) + { + Debug.WriteLine($"[ObjectGraphDiscoverer] Max recursion depth ({MaxRecursionDepth}) reached for type '{obj.GetType().Name}'"); + return; + } + var type = obj.GetType(); if (ShouldSkipType(type)) @@ -354,16 +410,38 @@ private void DiscoverNestedInitializerObjectsForTracking( /// /// Gets cached properties for a type, filtering to only readable non-indexed properties. + /// Includes periodic cache cleanup to prevent unbounded memory growth. /// [UnconditionalSuppressMessage("Trimming", "IL2070", Justification = "Reflection fallback for nested initializers. In AOT, source-gen handles primary discovery.")] private static PropertyInfo[] GetCachedProperties(Type type) { + // Periodic cleanup if cache grows too large to prevent memory leaks + if (PropertyCache.Count > MaxCacheSize) + { + // Clear half the cache (simple approach - non-LRU for performance) + var keysToRemove = PropertyCache.Keys.Take(MaxCacheSize / 2).ToList(); + foreach (var key in keysToRemove) + { + PropertyCache.TryRemove(key, out _); + } + + Debug.WriteLine($"[ObjectGraphDiscoverer] PropertyCache exceeded {MaxCacheSize} entries, cleared {keysToRemove.Count} entries"); + } + return PropertyCache.GetOrAdd(type, t => t.GetProperties(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance) .Where(p => p.CanRead && p.GetIndexParameters().Length == 0) .ToArray()); } + /// + /// Clears the property cache. Called at end of test session to release memory. + /// + public static void ClearCache() + { + PropertyCache.Clear(); + } + /// /// Checks if a type should be skipped during discovery. /// diff --git a/TUnit.Core/Interfaces/IObjectGraphDiscoverer.cs b/TUnit.Core/Interfaces/IObjectGraphDiscoverer.cs index 285d677643..b2e2842d70 100644 --- a/TUnit.Core/Interfaces/IObjectGraphDiscoverer.cs +++ b/TUnit.Core/Interfaces/IObjectGraphDiscoverer.cs @@ -26,39 +26,51 @@ public interface IObjectGraphDiscoverer /// Discovers all objects from a test context, organized by depth level. /// /// The test context to discover objects from. + /// Optional cancellation token for long-running discovery. /// /// An containing all discovered objects organized by depth. /// Depth 0 contains root objects (arguments and property values). /// Higher depths contain nested objects. /// - IObjectGraph DiscoverObjectGraph(TestContext testContext); + IObjectGraph DiscoverObjectGraph(TestContext testContext, CancellationToken cancellationToken = default); /// /// Discovers nested objects from a single root object, organized by depth. /// /// The root object to discover nested objects from. + /// Optional cancellation token for long-running discovery. /// /// An containing all discovered objects organized by depth. /// Depth 0 contains the root object itself. /// Higher depths contain nested objects. /// - IObjectGraph DiscoverNestedObjectGraph(object rootObject); + IObjectGraph DiscoverNestedObjectGraph(object rootObject, CancellationToken cancellationToken = default); } /// /// Represents a discovered object graph organized by depth level. /// +/// +/// Collections are exposed as read-only to prevent callers from corrupting internal state. +/// Use and for safe iteration. +/// public interface IObjectGraph { /// /// Gets objects organized by depth (0 = root arguments, 1+ = nested). /// - ConcurrentDictionary> ObjectsByDepth { get; } + /// + /// Returns a read-only view. Use for iteration. + /// + IReadOnlyDictionary> ObjectsByDepth { get; } /// /// Gets all unique objects in the graph. /// - HashSet AllObjects { get; } + /// + /// Returns a read-only view to prevent modification. + /// + IReadOnlyCollection AllObjects { get; } /// /// Gets the maximum nesting depth (-1 if empty). diff --git a/TUnit.Core/ObjectInitializer.cs b/TUnit.Core/ObjectInitializer.cs index e1a53768b4..889a7f4dfa 100644 --- a/TUnit.Core/ObjectInitializer.cs +++ b/TUnit.Core/ObjectInitializer.cs @@ -20,9 +20,10 @@ namespace TUnit.Core; /// public static class ObjectInitializer { - // Use ConcurrentDictionary with reference equality for thread-safe tracking - // This replaces ConditionalWeakTable which doesn't support explicit cleanup - private static readonly ConcurrentDictionary InitializationTasks = + // Use Lazy pattern to ensure InitializeAsync is called exactly once per object, + // even under contention. GetOrAdd's factory can be called multiple times, but with + // Lazy + ExecutionAndPublication mode, only one initialization actually runs. + private static readonly ConcurrentDictionary> InitializationTasks = new(new Helpers.ReferenceEqualityComparer()); /// @@ -79,7 +80,10 @@ internal static bool IsInitialized(object? obj) // Use Status == RanToCompletion to ensure we don't return true for faulted/canceled tasks // (IsCompletedSuccessfully is not available in netstandard2.0) - return InitializationTasks.TryGetValue(obj, out var task) && task.Status == TaskStatus.RanToCompletion; + // With Lazy, we need to check if the Lazy has a value AND that value completed successfully + return InitializationTasks.TryGetValue(obj, out var lazyTask) && + lazyTask.IsValueCreated && + lazyTask.Value.Status == TaskStatus.RanToCompletion; } /// @@ -98,13 +102,30 @@ private static async ValueTask InitializeCoreAsync( IAsyncInitializer asyncInitializer, CancellationToken cancellationToken) { - // Use GetOrAdd for thread-safe deduplication without holding lock during async - // Note: GetOrAdd's factory may be called multiple times under contention, - // but only one result is stored. Multiple InitializeAsync() calls is safe - // since we only await the winning task. - var initializationTask = InitializationTasks.GetOrAdd(obj, _ => asyncInitializer.InitializeAsync()); + // Use Lazy with ExecutionAndPublication mode to ensure InitializeAsync + // is called exactly once, even under contention. GetOrAdd's factory may be + // called multiple times, but Lazy ensures only one initialization runs. + var lazyTask = InitializationTasks.GetOrAdd(obj, + _ => new Lazy( + () => asyncInitializer.InitializeAsync(), + LazyThreadSafetyMode.ExecutionAndPublication)); - // Wait for initialization with cancellation support - await initializationTask.WaitAsync(cancellationToken); + try + { + // Wait for initialization with cancellation support + await lazyTask.Value.WaitAsync(cancellationToken); + } + catch (OperationCanceledException) + { + // Propagate cancellation without modification + throw; + } + catch + { + // Remove failed initialization from cache to allow retry + // This is important for transient failures that may succeed on retry + InitializationTasks.TryRemove(obj, out _); + throw; + } } } diff --git a/TUnit.Core/Services/ObjectInitializationService.cs b/TUnit.Core/Services/ObjectInitializationService.cs index 8264e3d326..d033605fd1 100644 --- a/TUnit.Core/Services/ObjectInitializationService.cs +++ b/TUnit.Core/Services/ObjectInitializationService.cs @@ -1,5 +1,3 @@ -using System.Collections.Concurrent; -using TUnit.Core.Helpers; using TUnit.Core.Interfaces; namespace TUnit.Core.Services; @@ -10,79 +8,33 @@ namespace TUnit.Core.Services; /// /// /// -/// This service replaces the static for dependency injection scenarios. -/// It uses with reference equality for thread-safe -/// deduplication without lock contention during async operations. +/// This service delegates to the static to ensure consistent +/// behavior and avoid duplicate caches. This consolidates initialization tracking in one place. /// /// public sealed class ObjectInitializationService : IObjectInitializationService { - // Use ConcurrentDictionary with reference equality for thread-safe tracking - // This replaces ConditionalWeakTable which doesn't support explicit cleanup - private readonly ConcurrentDictionary _initializationTasks; - /// /// Creates a new instance of the initialization service. /// public ObjectInitializationService() { - _initializationTasks = new ConcurrentDictionary(new Helpers.ReferenceEqualityComparer()); + // No local cache needed - delegates to static ObjectInitializer } /// public ValueTask InitializeForDiscoveryAsync(object? obj, CancellationToken cancellationToken = default) - { - // During discovery, only initialize IAsyncDiscoveryInitializer - if (obj is not IAsyncDiscoveryInitializer asyncDiscoveryInitializer) - { - return default; - } - - return InitializeCoreAsync(obj, asyncDiscoveryInitializer, cancellationToken); - } + => ObjectInitializer.InitializeForDiscoveryAsync(obj, cancellationToken); /// public ValueTask InitializeAsync(object? obj, CancellationToken cancellationToken = default) - { - if (obj is not IAsyncInitializer asyncInitializer) - { - return default; - } - - return InitializeCoreAsync(obj, asyncInitializer, cancellationToken); - } + => ObjectInitializer.InitializeAsync(obj, cancellationToken); /// public bool IsInitialized(object? obj) - { - if (obj is not IAsyncInitializer) - { - return false; - } - - // Use Status == RanToCompletion to ensure we don't return true for faulted/canceled tasks - // (IsCompletedSuccessfully is not available in netstandard2.0) - return _initializationTasks.TryGetValue(obj, out var task) && task.Status == TaskStatus.RanToCompletion; - } + => ObjectInitializer.IsInitialized(obj); /// public void ClearCache() - { - _initializationTasks.Clear(); - } - - private async ValueTask InitializeCoreAsync( - object obj, - IAsyncInitializer asyncInitializer, - CancellationToken cancellationToken) - { - // Use GetOrAdd for thread-safe deduplication without holding lock during async - // Note: GetOrAdd's factory may be called multiple times under contention, - // but only one result is stored. Multiple InitializeAsync() calls is safe - // since we only await the winning task. - var initializationTask = _initializationTasks.GetOrAdd(obj, _ => asyncInitializer.InitializeAsync()); - - // Wait for initialization with cancellation support - await initializationTask.WaitAsync(cancellationToken); - } + => ObjectInitializer.ClearCache(); } diff --git a/TUnit.Engine/Services/ObjectGraphDiscoveryService.cs b/TUnit.Engine/Services/ObjectGraphDiscoveryService.cs index c7ddb36ffd..1a0636f605 100644 --- a/TUnit.Engine/Services/ObjectGraphDiscoveryService.cs +++ b/TUnit.Engine/Services/ObjectGraphDiscoveryService.cs @@ -30,16 +30,16 @@ public ObjectGraphDiscoveryService(IObjectGraphDiscoverer discoverer) /// /// Discovers all objects from test context arguments and properties, organized by depth level. /// - public IObjectGraph DiscoverObjectGraph(TestContext testContext) + public IObjectGraph DiscoverObjectGraph(TestContext testContext, CancellationToken cancellationToken = default) { - return _discoverer.DiscoverObjectGraph(testContext); + return _discoverer.DiscoverObjectGraph(testContext, cancellationToken); } /// /// Discovers nested objects from a single root object, organized by depth. /// - public IObjectGraph DiscoverNestedObjectGraph(object rootObject) + public IObjectGraph DiscoverNestedObjectGraph(object rootObject, CancellationToken cancellationToken = default) { - return _discoverer.DiscoverNestedObjectGraph(rootObject); + return _discoverer.DiscoverNestedObjectGraph(rootObject, cancellationToken); } } diff --git a/TUnit.Engine/Services/ObjectLifecycleService.cs b/TUnit.Engine/Services/ObjectLifecycleService.cs index e3c5938536..4e9b97fb76 100644 --- a/TUnit.Engine/Services/ObjectLifecycleService.cs +++ b/TUnit.Engine/Services/ObjectLifecycleService.cs @@ -219,7 +219,7 @@ private async Task InitializeTrackedObjectsAsync(TestContext testContext, Cancel /// private async Task InitializeNestedObjectsForExecutionAsync(object rootObject, CancellationToken cancellationToken) { - var graph = _objectGraphDiscoveryService.DiscoverNestedObjectGraph(rootObject); + var graph = _objectGraphDiscoveryService.DiscoverNestedObjectGraph(rootObject, cancellationToken); // Initialize from deepest to shallowest (skip depth 0 which is the root itself) foreach (var depth in graph.GetDepthsDescending()) @@ -333,23 +333,18 @@ private async Task InitializeObjectCoreAsync( objectBag ??= new ConcurrentDictionary(); events ??= new TestContextEvents(); - try - { - // Step 1: Inject properties - await PropertyInjector.InjectPropertiesAsync(obj, objectBag, methodMetadata, events); + // Let exceptions propagate naturally - don't wrap in InvalidOperationException + // This aligns with ObjectInitializer behavior and provides cleaner stack traces - // Step 2: Initialize nested objects depth-first (discovery-only) - await InitializeNestedObjectsForDiscoveryAsync(obj, cancellationToken); + // Step 1: Inject properties + await PropertyInjector.InjectPropertiesAsync(obj, objectBag, methodMetadata, events); - // Step 3: Call IAsyncDiscoveryInitializer only (not regular IAsyncInitializer) - // Regular IAsyncInitializer objects are deferred to execution phase via InitializeTestObjectsAsync - await ObjectInitializer.InitializeForDiscoveryAsync(obj, cancellationToken); - } - catch (Exception ex) - { - throw new InvalidOperationException( - $"Failed to initialize object of type '{obj.GetType().Name}': {ex.Message}", ex); - } + // Step 2: Initialize nested objects depth-first (discovery-only) + await InitializeNestedObjectsForDiscoveryAsync(obj, cancellationToken); + + // Step 3: Call IAsyncDiscoveryInitializer only (not regular IAsyncInitializer) + // Regular IAsyncInitializer objects are deferred to execution phase via InitializeTestObjectsAsync + await ObjectInitializer.InitializeForDiscoveryAsync(obj, cancellationToken); } /// @@ -357,7 +352,7 @@ private async Task InitializeObjectCoreAsync( /// private async Task InitializeNestedObjectsForDiscoveryAsync(object rootObject, CancellationToken cancellationToken) { - var graph = _objectGraphDiscoveryService.DiscoverNestedObjectGraph(rootObject); + var graph = _objectGraphDiscoveryService.DiscoverNestedObjectGraph(rootObject, cancellationToken); // Initialize from deepest to shallowest (skip depth 0 which is the root itself) foreach (var depth in graph.GetDepthsDescending()) From 07b221504190c329fe43085fb7872d3c55b81be1 Mon Sep 17 00:00:00 2001 From: Tom Longhurst <30480171+thomhurst@users.noreply.github.com> Date: Sun, 7 Dec 2025 15:13:54 +0000 Subject: [PATCH 06/20] feat: implement IAsyncDiscoveryInitializer and related classes for improved test discovery handling --- TUnit.Core/Discovery/ObjectGraph.cs | 78 +++++++---- TUnit.Core/Discovery/ObjectGraphDiscoverer.cs | 122 +++++++++++++++--- .../Helpers/ReferenceEqualityComparer.cs | 12 +- 3 files changed, 164 insertions(+), 48 deletions(-) diff --git a/TUnit.Core/Discovery/ObjectGraph.cs b/TUnit.Core/Discovery/ObjectGraph.cs index 458cb5a50b..e0a86f3558 100644 --- a/TUnit.Core/Discovery/ObjectGraph.cs +++ b/TUnit.Core/Discovery/ObjectGraph.cs @@ -10,15 +10,19 @@ namespace TUnit.Core.Discovery; /// /// Internal collections are stored privately and exposed as read-only views /// to prevent callers from corrupting internal state. +/// Uses Lazy<T> for thread-safe lazy initialization of read-only views. /// public sealed class ObjectGraph : IObjectGraph { private readonly ConcurrentDictionary> _objectsByDepth; private readonly HashSet _allObjects; - // Cached read-only views (created lazily on first access) - private IReadOnlyDictionary>? _readOnlyObjectsByDepth; - private IReadOnlyCollection? _readOnlyAllObjects; + // Thread-safe lazy initialization of read-only views + private readonly Lazy>> _lazyReadOnlyObjectsByDepth; + private readonly Lazy> _lazyReadOnlyAllObjects; + + // Cached sorted depths (computed once in constructor) + private readonly int[] _sortedDepthsDescending; /// /// Creates a new object graph from the discovered objects. @@ -29,33 +33,28 @@ public ObjectGraph(ConcurrentDictionary> objectsByDepth, Ha { _objectsByDepth = objectsByDepth; _allObjects = allObjects; + // Use IsEmpty for thread-safe check before accessing Keys MaxDepth = objectsByDepth.IsEmpty ? -1 : objectsByDepth.Keys.Max(); + + // Cache sorted depths (computed once, reused on each call to GetDepthsDescending) + _sortedDepthsDescending = objectsByDepth.Keys.OrderByDescending(d => d).ToArray(); + + // Use Lazy with ExecutionAndPublication for thread-safe single initialization + _lazyReadOnlyObjectsByDepth = new Lazy>>( + CreateReadOnlyObjectsByDepth, + LazyThreadSafetyMode.ExecutionAndPublication); + + _lazyReadOnlyAllObjects = new Lazy>( + () => _allObjects.ToArray(), + LazyThreadSafetyMode.ExecutionAndPublication); } /// - public IReadOnlyDictionary> ObjectsByDepth - { - get - { - // Create read-only view lazily and cache it - // Note: This creates a snapshot - subsequent modifications to internal collections won't be reflected - return _readOnlyObjectsByDepth ??= new ReadOnlyDictionary>( - _objectsByDepth.ToDictionary( - kvp => kvp.Key, - kvp => (IReadOnlyCollection)kvp.Value.ToArray())); - } - } + public IReadOnlyDictionary> ObjectsByDepth => _lazyReadOnlyObjectsByDepth.Value; /// - public IReadOnlyCollection AllObjects - { - get - { - // Create read-only view lazily and cache it - return _readOnlyAllObjects ??= _allObjects.ToArray(); - } - } + public IReadOnlyCollection AllObjects => _lazyReadOnlyAllObjects.Value; /// public int MaxDepth { get; } @@ -63,12 +62,41 @@ public IReadOnlyCollection AllObjects /// public IEnumerable GetObjectsAtDepth(int depth) { - return _objectsByDepth.TryGetValue(depth, out var objects) ? objects : []; + if (!_objectsByDepth.TryGetValue(depth, out var objects)) + { + return []; + } + + // Lock and copy to prevent concurrent modification issues + lock (objects) + { + return objects.ToArray(); + } } /// public IEnumerable GetDepthsDescending() { - return _objectsByDepth.Keys.OrderByDescending(d => d); + // Return cached sorted depths (computed once in constructor) + return _sortedDepthsDescending; + } + + /// + /// Creates a thread-safe read-only snapshot of objects by depth. + /// + private IReadOnlyDictionary> CreateReadOnlyObjectsByDepth() + { + var dict = new Dictionary>(_objectsByDepth.Count); + + foreach (var kvp in _objectsByDepth) + { + // Lock each HashSet while copying to ensure consistency + lock (kvp.Value) + { + dict[kvp.Key] = kvp.Value.ToArray(); + } + } + + return new ReadOnlyDictionary>(dict); } } diff --git a/TUnit.Core/Discovery/ObjectGraphDiscoverer.cs b/TUnit.Core/Discovery/ObjectGraphDiscoverer.cs index cff7be33fd..2783c8169b 100644 --- a/TUnit.Core/Discovery/ObjectGraphDiscoverer.cs +++ b/TUnit.Core/Discovery/ObjectGraphDiscoverer.cs @@ -40,6 +40,9 @@ public sealed class ObjectGraphDiscoverer : IObjectGraphDiscoverer // Cache for GetProperties() results per type - eliminates repeated reflection calls private static readonly ConcurrentDictionary PropertyCache = new(); + // Flag to coordinate cache cleanup (prevents multiple threads cleaning simultaneously) + private static int _cleanupInProgress; + // Reference equality comparer for object tracking (ignores Equals overrides) private static readonly Helpers.ReferenceEqualityComparer ReferenceComparer = new(); @@ -131,7 +134,7 @@ public ConcurrentDictionary> DiscoverAndTrackObjects(TestCo foreach (var classArgument in testDetails.TestClassArguments) { - if (classArgument != null && GetOrAddHashSet(visitedObjects, 0).Add(classArgument)) + if (classArgument != null && TryAddToHashSet(visitedObjects, 0, classArgument)) { DiscoverNestedObjectsForTracking(classArgument, visitedObjects, 1); } @@ -139,7 +142,7 @@ public ConcurrentDictionary> DiscoverAndTrackObjects(TestCo foreach (var methodArgument in testDetails.TestMethodArguments) { - if (methodArgument != null && GetOrAddHashSet(visitedObjects, 0).Add(methodArgument)) + if (methodArgument != null && TryAddToHashSet(visitedObjects, 0, methodArgument)) { DiscoverNestedObjectsForTracking(methodArgument, visitedObjects, 1); } @@ -147,7 +150,7 @@ public ConcurrentDictionary> DiscoverAndTrackObjects(TestCo foreach (var property in testDetails.TestClassInjectedPropertyArguments.Values) { - if (property != null && GetOrAddHashSet(visitedObjects, 0).Add(property)) + if (property != null && TryAddToHashSet(visitedObjects, 0, property)) { DiscoverNestedObjectsForTracking(property, visitedObjects, 1); } @@ -171,7 +174,9 @@ private void DiscoverNestedObjects( // Guard against excessive recursion to prevent stack overflow if (currentDepth > MaxRecursionDepth) { +#if DEBUG Debug.WriteLine($"[ObjectGraphDiscoverer] Max recursion depth ({MaxRecursionDepth}) reached for type '{obj.GetType().Name}'"); +#endif return; } @@ -239,7 +244,9 @@ private void DiscoverNestedObjectsForTracking( // Guard against excessive recursion to prevent stack overflow if (currentDepth > MaxRecursionDepth) { +#if DEBUG Debug.WriteLine($"[ObjectGraphDiscoverer] Max recursion depth ({MaxRecursionDepth}) reached for type '{obj.GetType().Name}'"); +#endif return; } @@ -256,7 +263,7 @@ private void DiscoverNestedObjectsForTracking( continue; } - if (!GetOrAddHashSet(visitedObjects, currentDepth).Add(value)) + if (!TryAddToHashSet(visitedObjects, currentDepth, value)) { continue; } @@ -280,7 +287,7 @@ private void DiscoverNestedObjectsForTracking( continue; } - if (!GetOrAddHashSet(visitedObjects, currentDepth).Add(value)) + if (!TryAddToHashSet(visitedObjects, currentDepth, value)) { continue; } @@ -311,7 +318,9 @@ private void DiscoverNestedInitializerObjects( // Guard against excessive recursion to prevent stack overflow if (currentDepth > MaxRecursionDepth) { +#if DEBUG Debug.WriteLine($"[ObjectGraphDiscoverer] Max recursion depth ({MaxRecursionDepth}) reached for type '{obj.GetType().Name}'"); +#endif return; } @@ -353,8 +362,12 @@ private void DiscoverNestedInitializerObjects( } catch (Exception ex) { +#if DEBUG // Log instead of silently swallowing - helps with debugging Debug.WriteLine($"[ObjectGraphDiscoverer] Failed to access property '{property.Name}' on type '{type.Name}': {ex.Message}"); +#endif + // Continue discovery despite property access failures + _ = ex; } } } @@ -373,7 +386,9 @@ private void DiscoverNestedInitializerObjectsForTracking( // Guard against excessive recursion to prevent stack overflow if (currentDepth > MaxRecursionDepth) { +#if DEBUG Debug.WriteLine($"[ObjectGraphDiscoverer] Max recursion depth ({MaxRecursionDepth}) reached for type '{obj.GetType().Name}'"); +#endif return; } @@ -396,14 +411,22 @@ private void DiscoverNestedInitializerObjectsForTracking( continue; } - if (value is IAsyncInitializer && GetOrAddHashSet(visitedObjects, currentDepth).Add(value)) + if (value is IAsyncInitializer && TryAddToHashSet(visitedObjects, currentDepth, value)) { DiscoverNestedObjectsForTracking(value, visitedObjects, currentDepth + 1); } } + catch (OperationCanceledException) + { + throw; // Propagate cancellation + } catch (Exception ex) { +#if DEBUG Debug.WriteLine($"[ObjectGraphDiscoverer] Failed to access property '{property.Name}' on type '{type.Name}': {ex.Message}"); +#endif + // Continue discovery despite property access failures + _ = ex; } } } @@ -416,22 +439,70 @@ private void DiscoverNestedInitializerObjectsForTracking( private static PropertyInfo[] GetCachedProperties(Type type) { // Periodic cleanup if cache grows too large to prevent memory leaks - if (PropertyCache.Count > MaxCacheSize) + // Use Interlocked to ensure only one thread performs cleanup at a time + if (PropertyCache.Count > MaxCacheSize && + Interlocked.CompareExchange(ref _cleanupInProgress, 1, 0) == 0) { - // Clear half the cache (simple approach - non-LRU for performance) - var keysToRemove = PropertyCache.Keys.Take(MaxCacheSize / 2).ToList(); - foreach (var key in keysToRemove) + try { - PropertyCache.TryRemove(key, out _); - } + // Double-check after acquiring cleanup flag + if (PropertyCache.Count > MaxCacheSize) + { + var keysToRemove = new List(MaxCacheSize / 2); + var count = 0; + foreach (var key in PropertyCache.Keys) + { + if (count++ >= MaxCacheSize / 2) + { + break; + } - Debug.WriteLine($"[ObjectGraphDiscoverer] PropertyCache exceeded {MaxCacheSize} entries, cleared {keysToRemove.Count} entries"); + keysToRemove.Add(key); + } + + foreach (var key in keysToRemove) + { + PropertyCache.TryRemove(key, out _); + } +#if DEBUG + Debug.WriteLine($"[ObjectGraphDiscoverer] PropertyCache exceeded {MaxCacheSize} entries, cleared {keysToRemove.Count} entries"); +#endif + } + } + finally + { + Interlocked.Exchange(ref _cleanupInProgress, 0); + } } - return PropertyCache.GetOrAdd(type, t => - t.GetProperties(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance) - .Where(p => p.CanRead && p.GetIndexParameters().Length == 0) - .ToArray()); + return PropertyCache.GetOrAdd(type, static t => + { + // Use explicit loops instead of LINQ to avoid allocations in hot path + var allProps = t.GetProperties(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance); + + // First pass: count eligible properties + var count = 0; + foreach (var p in allProps) + { + if (p.CanRead && p.GetIndexParameters().Length == 0) + { + count++; + } + } + + // Second pass: fill result array + var result = new PropertyInfo[count]; + var i = 0; + foreach (var p in allProps) + { + if (p.CanRead && p.GetIndexParameters().Length == 0) + { + result[i++] = p; + } + } + + return result; + }); } /// @@ -454,17 +525,26 @@ private static bool ShouldSkipType(Type type) /// /// Adds an object to the specified depth level. + /// Thread-safe: uses lock to protect HashSet modifications. /// private static void AddToDepth(ConcurrentDictionary> objectsByDepth, int depth, object obj) { - objectsByDepth.GetOrAdd(depth, _ => []).Add(obj); + var hashSet = objectsByDepth.GetOrAdd(depth, _ => new HashSet(ReferenceComparer)); + lock (hashSet) + { + hashSet.Add(obj); + } } /// - /// Gets or creates a HashSet at the specified depth (thread-safe). + /// Thread-safe add to HashSet at specified depth. Returns true if added (not duplicate). /// - private static HashSet GetOrAddHashSet(ConcurrentDictionary> dict, int depth) + private static bool TryAddToHashSet(ConcurrentDictionary> dict, int depth, object obj) { - return dict.GetOrAdd(depth, _ => []); + var hashSet = dict.GetOrAdd(depth, _ => new HashSet(ReferenceComparer)); + lock (hashSet) + { + return hashSet.Add(obj); + } } } diff --git a/TUnit.Core/Helpers/ReferenceEqualityComparer.cs b/TUnit.Core/Helpers/ReferenceEqualityComparer.cs index 2c26bfb5d1..7639a04c17 100644 --- a/TUnit.Core/Helpers/ReferenceEqualityComparer.cs +++ b/TUnit.Core/Helpers/ReferenceEqualityComparer.cs @@ -1,5 +1,11 @@ -namespace TUnit.Core.Helpers; +using System.Runtime.CompilerServices; +namespace TUnit.Core.Helpers; + +/// +/// Compares objects by reference identity, not value equality. +/// Uses RuntimeHelpers.GetHashCode to get identity-based hash codes. +/// public class ReferenceEqualityComparer : IEqualityComparer { public new bool Equals(object? x, object? y) @@ -9,6 +15,8 @@ public class ReferenceEqualityComparer : IEqualityComparer public int GetHashCode(object obj) { - return obj.GetHashCode(); + // Use RuntimeHelpers.GetHashCode for identity-based hash code + // This returns the same value as Object.GetHashCode() would if not overridden + return RuntimeHelpers.GetHashCode(obj); } } From 8589baa483013aa316e1d6e1f8ce2ed444da75a9 Mon Sep 17 00:00:00 2001 From: Tom Longhurst <30480171+thomhurst@users.noreply.github.com> Date: Sun, 7 Dec 2025 15:37:18 +0000 Subject: [PATCH 07/20] feat: implement IAsyncDiscoveryInitializer and related classes for improved test discovery handling --- TUnit.Core/Discovery/ObjectGraph.cs | 30 ++++++- TUnit.Core/Discovery/ObjectGraphDiscoverer.cs | 33 +++++--- TUnit.Core/Helpers/Disposer.cs | 9 ++ TUnit.Core/Tracking/ObjectTracker.cs | 84 ++++++++++++++++--- .../Tracking/TrackableObjectGraphProvider.cs | 17 ++-- .../Services/ObjectLifecycleService.cs | 79 +++++++++++++---- TUnit.Engine/Services/PropertyInjector.cs | 64 ++++++++++---- 7 files changed, 252 insertions(+), 64 deletions(-) diff --git a/TUnit.Core/Discovery/ObjectGraph.cs b/TUnit.Core/Discovery/ObjectGraph.cs index e0a86f3558..3d79a52785 100644 --- a/TUnit.Core/Discovery/ObjectGraph.cs +++ b/TUnit.Core/Discovery/ObjectGraph.cs @@ -34,11 +34,33 @@ public ObjectGraph(ConcurrentDictionary> objectsByDepth, Ha _objectsByDepth = objectsByDepth; _allObjects = allObjects; - // Use IsEmpty for thread-safe check before accessing Keys - MaxDepth = objectsByDepth.IsEmpty ? -1 : objectsByDepth.Keys.Max(); + // Compute MaxDepth and sorted depths without LINQ to reduce allocations + var keyCount = objectsByDepth.Count; + if (keyCount == 0) + { + MaxDepth = -1; + _sortedDepthsDescending = []; + } + else + { + var keys = new int[keyCount]; + objectsByDepth.Keys.CopyTo(keys, 0); - // Cache sorted depths (computed once, reused on each call to GetDepthsDescending) - _sortedDepthsDescending = objectsByDepth.Keys.OrderByDescending(d => d).ToArray(); + // Find max manually + var maxDepth = int.MinValue; + foreach (var key in keys) + { + if (key > maxDepth) + { + maxDepth = key; + } + } + MaxDepth = maxDepth; + + // Sort in descending order using Array.Sort with reverse comparison + Array.Sort(keys, (a, b) => b.CompareTo(a)); + _sortedDepthsDescending = keys; + } // Use Lazy with ExecutionAndPublication for thread-safe single initialization _lazyReadOnlyObjectsByDepth = new Lazy>>( diff --git a/TUnit.Core/Discovery/ObjectGraphDiscoverer.cs b/TUnit.Core/Discovery/ObjectGraphDiscoverer.cs index 2783c8169b..c3ebfb211f 100644 --- a/TUnit.Core/Discovery/ObjectGraphDiscoverer.cs +++ b/TUnit.Core/Discovery/ObjectGraphDiscoverer.cs @@ -126,33 +126,37 @@ public IObjectGraph DiscoverNestedObjectGraph(object rootObject, CancellationTok /// Used by TrackableObjectGraphProvider to populate TestContext.TrackedObjects. /// /// The test context to discover objects from. + /// Cancellation token for the operation. /// The tracked objects dictionary (same as testContext.TrackedObjects). - public ConcurrentDictionary> DiscoverAndTrackObjects(TestContext testContext) + public ConcurrentDictionary> DiscoverAndTrackObjects(TestContext testContext, CancellationToken cancellationToken = default) { var visitedObjects = testContext.TrackedObjects; var testDetails = testContext.Metadata.TestDetails; foreach (var classArgument in testDetails.TestClassArguments) { + cancellationToken.ThrowIfCancellationRequested(); if (classArgument != null && TryAddToHashSet(visitedObjects, 0, classArgument)) { - DiscoverNestedObjectsForTracking(classArgument, visitedObjects, 1); + DiscoverNestedObjectsForTracking(classArgument, visitedObjects, 1, cancellationToken); } } foreach (var methodArgument in testDetails.TestMethodArguments) { + cancellationToken.ThrowIfCancellationRequested(); if (methodArgument != null && TryAddToHashSet(visitedObjects, 0, methodArgument)) { - DiscoverNestedObjectsForTracking(methodArgument, visitedObjects, 1); + DiscoverNestedObjectsForTracking(methodArgument, visitedObjects, 1, cancellationToken); } } foreach (var property in testDetails.TestClassInjectedPropertyArguments.Values) { + cancellationToken.ThrowIfCancellationRequested(); if (property != null && TryAddToHashSet(visitedObjects, 0, property)) { - DiscoverNestedObjectsForTracking(property, visitedObjects, 1); + DiscoverNestedObjectsForTracking(property, visitedObjects, 1, cancellationToken); } } @@ -239,7 +243,8 @@ private void DiscoverNestedObjects( private void DiscoverNestedObjectsForTracking( object obj, ConcurrentDictionary> visitedObjects, - int currentDepth) + int currentDepth, + CancellationToken cancellationToken) { // Guard against excessive recursion to prevent stack overflow if (currentDepth > MaxRecursionDepth) @@ -250,6 +255,8 @@ private void DiscoverNestedObjectsForTracking( return; } + cancellationToken.ThrowIfCancellationRequested(); + var plan = PropertyInjectionCache.GetOrCreatePlan(obj.GetType()); // Check SourceRegistrar.IsEnabled for compatibility with existing TrackableObjectGraphProvider behavior @@ -257,6 +264,7 @@ private void DiscoverNestedObjectsForTracking( { foreach (var prop in plan.ReflectionProperties) { + cancellationToken.ThrowIfCancellationRequested(); var value = prop.Property.GetValue(obj); if (value == null) { @@ -268,13 +276,14 @@ private void DiscoverNestedObjectsForTracking( continue; } - DiscoverNestedObjectsForTracking(value, visitedObjects, currentDepth + 1); + DiscoverNestedObjectsForTracking(value, visitedObjects, currentDepth + 1, cancellationToken); } } else { foreach (var metadata in plan.SourceGeneratedProperties) { + cancellationToken.ThrowIfCancellationRequested(); var property = metadata.ContainingType.GetProperty(metadata.PropertyName); if (property == null || !property.CanRead) { @@ -292,12 +301,12 @@ private void DiscoverNestedObjectsForTracking( continue; } - DiscoverNestedObjectsForTracking(value, visitedObjects, currentDepth + 1); + DiscoverNestedObjectsForTracking(value, visitedObjects, currentDepth + 1, cancellationToken); } } // Also discover nested IAsyncInitializer objects from ALL properties - DiscoverNestedInitializerObjectsForTracking(obj, visitedObjects, currentDepth); + DiscoverNestedInitializerObjectsForTracking(obj, visitedObjects, currentDepth, cancellationToken); } /// @@ -381,7 +390,8 @@ private void DiscoverNestedInitializerObjects( private void DiscoverNestedInitializerObjectsForTracking( object obj, ConcurrentDictionary> visitedObjects, - int currentDepth) + int currentDepth, + CancellationToken cancellationToken) { // Guard against excessive recursion to prevent stack overflow if (currentDepth > MaxRecursionDepth) @@ -392,6 +402,8 @@ private void DiscoverNestedInitializerObjectsForTracking( return; } + cancellationToken.ThrowIfCancellationRequested(); + var type = obj.GetType(); if (ShouldSkipType(type)) @@ -403,6 +415,7 @@ private void DiscoverNestedInitializerObjectsForTracking( foreach (var property in properties) { + cancellationToken.ThrowIfCancellationRequested(); try { var value = property.GetValue(obj); @@ -413,7 +426,7 @@ private void DiscoverNestedInitializerObjectsForTracking( if (value is IAsyncInitializer && TryAddToHashSet(visitedObjects, currentDepth, value)) { - DiscoverNestedObjectsForTracking(value, visitedObjects, currentDepth + 1); + DiscoverNestedObjectsForTracking(value, visitedObjects, currentDepth + 1, cancellationToken); } } catch (OperationCanceledException) diff --git a/TUnit.Core/Helpers/Disposer.cs b/TUnit.Core/Helpers/Disposer.cs index 772b49de07..1bafb121ca 100644 --- a/TUnit.Core/Helpers/Disposer.cs +++ b/TUnit.Core/Helpers/Disposer.cs @@ -4,6 +4,10 @@ namespace TUnit.Core.Helpers; internal class Disposer(ILogger logger) { + /// + /// Disposes an object and propagates any exceptions. + /// Exceptions are logged but NOT swallowed - callers must handle them. + /// public async ValueTask DisposeAsync(object? obj) { try @@ -19,10 +23,15 @@ public async ValueTask DisposeAsync(object? obj) } catch (Exception e) { + // Log the error for diagnostics if (logger != null) { await logger.LogErrorAsync(e); } + + // Propagate the exception - don't silently swallow disposal failures + // Callers can catch and aggregate if disposing multiple objects + throw; } } } diff --git a/TUnit.Core/Tracking/ObjectTracker.cs b/TUnit.Core/Tracking/ObjectTracker.cs index 8e702a6262..d957f3e7e6 100644 --- a/TUnit.Core/Tracking/ObjectTracker.cs +++ b/TUnit.Core/Tracking/ObjectTracker.cs @@ -7,19 +7,46 @@ namespace TUnit.Core.Tracking; /// /// Pure reference counting object tracker for disposable objects. /// Objects are disposed when their reference count reaches zero, regardless of sharing type. +/// Uses ReferenceEqualityComparer to track objects by identity, not value equality. /// internal class ObjectTracker(TrackableObjectGraphProvider trackableObjectGraphProvider, Disposer disposer) { - private static readonly ConcurrentDictionary _trackedObjects = new(); + // Use ReferenceEqualityComparer to prevent objects with custom Equals from sharing state + private static readonly ConcurrentDictionary _trackedObjects = + new(new Helpers.ReferenceEqualityComparer()); public void TrackObjects(TestContext testContext) { - var alreadyTracked = testContext.TrackedObjects.SelectMany(x => x.Value).ToHashSet(); + // Build alreadyTracked set without LINQ to reduce allocations + var alreadyTracked = new HashSet(new Helpers.ReferenceEqualityComparer()); + foreach (var kvp in testContext.TrackedObjects) + { + // Lock while iterating to prevent concurrent modification + lock (kvp.Value) + { + foreach (var obj in kvp.Value) + { + alreadyTracked.Add(obj); + } + } + } - var newTrackableObjects = trackableObjectGraphProvider.GetTrackableObjects(testContext) - .SelectMany(x => x.Value) - .Except(alreadyTracked) - .ToHashSet(); + // Get new trackable objects without LINQ + var newTrackableObjects = new HashSet(new Helpers.ReferenceEqualityComparer()); + var trackableDict = trackableObjectGraphProvider.GetTrackableObjects(testContext); + foreach (var kvp in trackableDict) + { + lock (kvp.Value) + { + foreach (var obj in kvp.Value) + { + if (!alreadyTracked.Contains(obj)) + { + newTrackableObjects.Add(obj); + } + } + } + } foreach (var obj in newTrackableObjects) { @@ -29,9 +56,20 @@ public void TrackObjects(TestContext testContext) public async ValueTask UntrackObjects(TestContext testContext, List cleanupExceptions) { - foreach (var obj in testContext.TrackedObjects - .SelectMany(x => x.Value) - .ToHashSet()) + // Build objects set without LINQ to reduce allocations and with proper locking + var objectsToUntrack = new HashSet(new Helpers.ReferenceEqualityComparer()); + foreach (var kvp in testContext.TrackedObjects) + { + lock (kvp.Value) + { + foreach (var obj in kvp.Value) + { + objectsToUntrack.Add(obj); + } + } + } + + foreach (var obj in objectsToUntrack) { try { @@ -129,13 +167,37 @@ public static void OnDisposedAsync(object? o, Func asyncAction) return; } + // Avoid async void pattern by wrapping in fire-and-forget with exception handling _trackedObjects.GetOrAdd(o, static _ => new Counter()) - .OnCountChanged += async (_, count) => + .OnCountChanged += (_, count) => { if (count == 0) { - await asyncAction().ConfigureAwait(false); + // Fire-and-forget with exception handling to avoid unobserved exceptions + _ = SafeExecuteAsync(asyncAction); } }; } + + /// + /// Executes an async action safely, catching and logging any exceptions + /// to avoid unobserved task exceptions from fire-and-forget patterns. + /// + private static async Task SafeExecuteAsync(Func asyncAction) + { + try + { + await asyncAction().ConfigureAwait(false); + } + catch (Exception ex) + { + // Log to debug in DEBUG builds, otherwise swallow to prevent crashes + // The disposal itself already logged any errors +#if DEBUG + System.Diagnostics.Debug.WriteLine($"[ObjectTracker] Exception in OnDisposedAsync callback: {ex.Message}"); +#endif + // Prevent unobserved task exception from crashing the application + _ = ex; + } + } } diff --git a/TUnit.Core/Tracking/TrackableObjectGraphProvider.cs b/TUnit.Core/Tracking/TrackableObjectGraphProvider.cs index 585434ff2d..121c9da013 100644 --- a/TUnit.Core/Tracking/TrackableObjectGraphProvider.cs +++ b/TUnit.Core/Tracking/TrackableObjectGraphProvider.cs @@ -32,24 +32,31 @@ public TrackableObjectGraphProvider(IObjectGraphDiscoverer discoverer) /// Gets trackable objects from a test context, organized by depth level. /// Delegates to the shared ObjectGraphDiscoverer to eliminate code duplication. /// - public ConcurrentDictionary> GetTrackableObjects(TestContext testContext) + /// The test context to get trackable objects from. + /// Optional cancellation token for long-running discovery. + public ConcurrentDictionary> GetTrackableObjects(TestContext testContext, CancellationToken cancellationToken = default) { // Use the ObjectGraphDiscoverer's specialized method that populates TrackedObjects directly if (_discoverer is ObjectGraphDiscoverer concreteDiscoverer) { - return concreteDiscoverer.DiscoverAndTrackObjects(testContext); + return concreteDiscoverer.DiscoverAndTrackObjects(testContext, cancellationToken); } // Fallback for custom implementations (testing) - var graph = _discoverer.DiscoverObjectGraph(testContext); + var graph = _discoverer.DiscoverObjectGraph(testContext, cancellationToken); var trackedObjects = testContext.TrackedObjects; foreach (var (depth, objects) in graph.ObjectsByDepth) { + cancellationToken.ThrowIfCancellationRequested(); var depthSet = trackedObjects.GetOrAdd(depth, _ => []); - foreach (var obj in objects) + // Lock to ensure thread-safe HashSet modification + lock (depthSet) { - depthSet.Add(obj); + foreach (var obj in objects) + { + depthSet.Add(obj); + } } } diff --git a/TUnit.Engine/Services/ObjectLifecycleService.cs b/TUnit.Engine/Services/ObjectLifecycleService.cs index 4e9b97fb76..30fbb40f9b 100644 --- a/TUnit.Engine/Services/ObjectLifecycleService.cs +++ b/TUnit.Engine/Services/ObjectLifecycleService.cs @@ -1,6 +1,7 @@ using System.Collections.Concurrent; using System.Diagnostics.CodeAnalysis; using TUnit.Core; +using TUnit.Core.Helpers; using TUnit.Core.Interfaces; using TUnit.Core.PropertyInjection; using TUnit.Core.PropertyInjection.Initialization; @@ -28,7 +29,9 @@ internal sealed class ObjectLifecycleService : IObjectRegistry, IInitializationC private readonly ObjectTracker _objectTracker; // Track initialization state per object - private readonly ConcurrentDictionary> _initializationTasks = new(); + // Use ReferenceEqualityComparer to prevent objects with custom Equals from sharing initialization state + private readonly ConcurrentDictionary> _initializationTasks = + new(new Core.Helpers.ReferenceEqualityComparer()); public ObjectLifecycleService( Lazy propertyInjector, @@ -98,7 +101,8 @@ public async Task RegisterArgumentsAsync( return; } - var tasks = new List(); + // Pre-allocate with expected capacity to avoid resizing + var tasks = new List(arguments.Length); foreach (var argument in arguments) { if (argument != null) @@ -107,7 +111,10 @@ public async Task RegisterArgumentsAsync( } } - await Task.WhenAll(tasks); + if (tasks.Count > 0) + { + await Task.WhenAll(tasks); + } } #endregion @@ -191,20 +198,40 @@ private void SetCachedPropertiesOnInstance(object instance, TestContext testCont /// private async Task InitializeTrackedObjectsAsync(TestContext testContext, CancellationToken cancellationToken) { - var levels = testContext.TrackedObjects.Keys.OrderByDescending(level => level); + // Get levels without LINQ - use Array.Sort with reverse comparison for descending order + var trackedObjects = testContext.TrackedObjects; + var levelCount = trackedObjects.Count; - foreach (var level in levels) + if (levelCount > 0) { - var objectsAtLevel = testContext.TrackedObjects[level]; + var levels = new int[levelCount]; + trackedObjects.Keys.CopyTo(levels, 0); + Array.Sort(levels, (a, b) => b.CompareTo(a)); // Descending order - // Initialize each tracked object and its nested objects - foreach (var obj in objectsAtLevel) + foreach (var level in levels) { - // First initialize nested objects depth-first - await InitializeNestedObjectsForExecutionAsync(obj, cancellationToken); + if (!trackedObjects.TryGetValue(level, out var objectsAtLevel)) + { + continue; + } + + // Copy to array under lock to prevent concurrent modification + object[] objectsCopy; + lock (objectsAtLevel) + { + objectsCopy = new object[objectsAtLevel.Count]; + objectsAtLevel.CopyTo(objectsCopy); + } - // Then initialize the object itself - await ObjectInitializer.InitializeAsync(obj, cancellationToken); + // Initialize each tracked object and its nested objects + foreach (var obj in objectsCopy) + { + // First initialize nested objects depth-first + await InitializeNestedObjectsForExecutionAsync(obj, cancellationToken); + + // Then initialize the object itself + await ObjectInitializer.InitializeAsync(obj, cancellationToken); + } } } @@ -228,9 +255,17 @@ private async Task InitializeNestedObjectsForExecutionAsync(object rootObject, C var objectsAtDepth = graph.GetObjectsAtDepth(depth); - // Initialize all IAsyncInitializer objects at this depth - await Task.WhenAll(objectsAtDepth - .Select(obj => ObjectInitializer.InitializeAsync(obj, cancellationToken).AsTask())); + // Pre-allocate task list without LINQ Select + var tasks = new List(); + foreach (var obj in objectsAtDepth) + { + tasks.Add(ObjectInitializer.InitializeAsync(obj, cancellationToken).AsTask()); + } + + if (tasks.Count > 0) + { + await Task.WhenAll(tasks); + } } } @@ -361,9 +396,17 @@ private async Task InitializeNestedObjectsForDiscoveryAsync(object rootObject, C var objectsAtDepth = graph.GetObjectsAtDepth(depth); - // Only initialize IAsyncDiscoveryInitializer objects during discovery - await Task.WhenAll(objectsAtDepth - .Select(obj => ObjectInitializer.InitializeForDiscoveryAsync(obj, cancellationToken).AsTask())); + // Pre-allocate task list without LINQ Select + var tasks = new List(); + foreach (var obj in objectsAtDepth) + { + tasks.Add(ObjectInitializer.InitializeForDiscoveryAsync(obj, cancellationToken).AsTask()); + } + + if (tasks.Count > 0) + { + await Task.WhenAll(tasks); + } } } diff --git a/TUnit.Engine/Services/PropertyInjector.cs b/TUnit.Engine/Services/PropertyInjector.cs index 4de82d7d5f..1073da6ed4 100644 --- a/TUnit.Engine/Services/PropertyInjector.cs +++ b/TUnit.Engine/Services/PropertyInjector.cs @@ -128,17 +128,29 @@ public async Task InjectPropertiesIntoArgumentsAsync( return; } - var injectableArgs = arguments - .Where(arg => arg != null && PropertyInjectionCache.HasInjectableProperties(arg.GetType())) - .ToArray(); + // Build list of injectable args without LINQ + var injectableArgs = new List(arguments.Length); + foreach (var arg in arguments) + { + if (arg != null && PropertyInjectionCache.HasInjectableProperties(arg.GetType())) + { + injectableArgs.Add(arg); + } + } - if (injectableArgs.Length == 0) + if (injectableArgs.Count == 0) { return; } - await Task.WhenAll(injectableArgs.Select(arg => - InjectPropertiesAsync(arg!, objectBag, methodMetadata, events))); + // Build task list without LINQ Select + var tasks = new List(injectableArgs.Count); + foreach (var arg in injectableArgs) + { + tasks.Add(InjectPropertiesAsync(arg, objectBag, methodMetadata, events)); + } + + await Task.WhenAll(tasks); } private async Task InjectPropertiesRecursiveAsync( @@ -201,9 +213,13 @@ private async Task InjectSourceGeneratedPropertiesAsync( return; } - // Initialize properties in parallel - await Task.WhenAll(properties.Select(metadata => - InjectSourceGeneratedPropertyAsync(instance, metadata, objectBag, methodMetadata, events, visitedObjects))); + // Initialize properties in parallel without LINQ Select + var tasks = new Task[properties.Length]; + for (var i = 0; i < properties.Length; i++) + { + tasks[i] = InjectSourceGeneratedPropertyAsync(instance, properties[i], objectBag, methodMetadata, events, visitedObjects); + } + await Task.WhenAll(tasks); } [UnconditionalSuppressMessage("Trimming", "IL2026", Justification = "Source-gen properties are AOT-safe")] @@ -289,8 +305,14 @@ private async Task InjectReflectionPropertiesAsync( return; } - await Task.WhenAll(properties.Select(pair => - InjectReflectionPropertyAsync(instance, pair.Property, pair.DataSource, objectBag, methodMetadata, events, visitedObjects))); + // Initialize properties in parallel without LINQ Select + var tasks = new Task[properties.Length]; + for (var i = 0; i < properties.Length; i++) + { + var pair = properties[i]; + tasks[i] = InjectReflectionPropertyAsync(instance, pair.Property, pair.DataSource, objectBag, methodMetadata, events, visitedObjects); + } + await Task.WhenAll(tasks); } [UnconditionalSuppressMessage("Trimming", "IL2026", Justification = "Reflection mode is not used in AOT")] @@ -396,9 +418,13 @@ private async Task ResolveAndCacheSourceGeneratedPropertiesAsync( return; } - // Resolve properties in parallel - await Task.WhenAll(properties.Select(metadata => - ResolveAndCacheSourceGeneratedPropertyAsync(metadata, objectBag, methodMetadata, events, testContext))); + // Resolve properties in parallel without LINQ Select + var tasks = new Task[properties.Length]; + for (var i = 0; i < properties.Length; i++) + { + tasks[i] = ResolveAndCacheSourceGeneratedPropertyAsync(properties[i], objectBag, methodMetadata, events, testContext); + } + await Task.WhenAll(tasks); } [UnconditionalSuppressMessage("Trimming", "IL2026", Justification = "Source-gen properties are AOT-safe")] @@ -455,8 +481,14 @@ private async Task ResolveAndCacheReflectionPropertiesAsync( return; } - await Task.WhenAll(properties.Select(pair => - ResolveAndCacheReflectionPropertyAsync(pair.Property, pair.DataSource, objectBag, methodMetadata, events, testContext))); + // Resolve properties in parallel without LINQ Select + var tasks = new Task[properties.Length]; + for (var i = 0; i < properties.Length; i++) + { + var pair = properties[i]; + tasks[i] = ResolveAndCacheReflectionPropertyAsync(pair.Property, pair.DataSource, objectBag, methodMetadata, events, testContext); + } + await Task.WhenAll(tasks); } [UnconditionalSuppressMessage("Trimming", "IL2026", Justification = "Reflection mode is not used in AOT")] From 5c78262540d285868287f67cea17380a97997ebf Mon Sep 17 00:00:00 2001 From: Tom Longhurst <30480171+thomhurst@users.noreply.github.com> Date: Sun, 7 Dec 2025 15:56:09 +0000 Subject: [PATCH 08/20] feat: implement IAsyncDiscoveryInitializer and related classes for improved test discovery handling --- TUnit.Core/Discovery/ObjectGraphDiscoverer.cs | 531 ++++++++---------- .../Helpers/ReferenceEqualityComparer.cs | 14 +- .../Interfaces/IObjectGraphDiscoverer.cs | 11 + TUnit.Core/ObjectInitializer.cs | 2 +- TUnit.Core/Tracking/ObjectTracker.cs | 8 +- .../Tracking/TrackableObjectGraphProvider.cs | 29 +- .../Services/ObjectLifecycleService.cs | 2 +- 7 files changed, 276 insertions(+), 321 deletions(-) diff --git a/TUnit.Core/Discovery/ObjectGraphDiscoverer.cs b/TUnit.Core/Discovery/ObjectGraphDiscoverer.cs index c3ebfb211f..13c75b926c 100644 --- a/TUnit.Core/Discovery/ObjectGraphDiscoverer.cs +++ b/TUnit.Core/Discovery/ObjectGraphDiscoverer.cs @@ -44,7 +44,7 @@ public sealed class ObjectGraphDiscoverer : IObjectGraphDiscoverer private static int _cleanupInProgress; // Reference equality comparer for object tracking (ignores Equals overrides) - private static readonly Helpers.ReferenceEqualityComparer ReferenceComparer = new(); + private static readonly Helpers.ReferenceEqualityComparer ReferenceComparer = Helpers.ReferenceEqualityComparer.Instance; // Types to skip during discovery (primitives, strings, system types) private static readonly HashSet SkipTypes = @@ -57,49 +57,48 @@ public sealed class ObjectGraphDiscoverer : IObjectGraphDiscoverer typeof(Guid) ]; + /// + /// Delegate for adding discovered objects to collections. + /// Returns true if the object was newly added (not a duplicate). + /// + private delegate bool TryAddObjectFunc(object obj, int depth); + + /// + /// Delegate for recursive discovery after an object is added. + /// + private delegate void RecurseFunc(object obj, int depth); + + /// + /// Delegate for processing a root object after it's been added. + /// + private delegate void RootObjectCallback(object obj); + /// public IObjectGraph DiscoverObjectGraph(TestContext testContext, CancellationToken cancellationToken = default) { var objectsByDepth = new ConcurrentDictionary>(); var allObjects = new HashSet(); - // Use ConcurrentDictionary for thread-safe visited tracking with reference equality var visitedObjects = new ConcurrentDictionary(ReferenceComparer); - var testDetails = testContext.Metadata.TestDetails; - - // Collect root-level objects (depth 0) - foreach (var classArgument in testDetails.TestClassArguments) + // Standard mode add callback + bool TryAddStandard(object obj, int depth) { - cancellationToken.ThrowIfCancellationRequested(); - if (classArgument != null && visitedObjects.TryAdd(classArgument, 0)) + if (!visitedObjects.TryAdd(obj, 0)) { - AddToDepth(objectsByDepth, 0, classArgument); - allObjects.Add(classArgument); - DiscoverNestedObjects(classArgument, objectsByDepth, visitedObjects, allObjects, currentDepth: 1, cancellationToken); + return false; } - } - foreach (var methodArgument in testDetails.TestMethodArguments) - { - cancellationToken.ThrowIfCancellationRequested(); - if (methodArgument != null && visitedObjects.TryAdd(methodArgument, 0)) - { - AddToDepth(objectsByDepth, 0, methodArgument); - allObjects.Add(methodArgument); - DiscoverNestedObjects(methodArgument, objectsByDepth, visitedObjects, allObjects, currentDepth: 1, cancellationToken); - } + AddToDepth(objectsByDepth, depth, obj); + allObjects.Add(obj); + return true; } - foreach (var property in testDetails.TestClassInjectedPropertyArguments.Values) - { - cancellationToken.ThrowIfCancellationRequested(); - if (property != null && visitedObjects.TryAdd(property, 0)) - { - AddToDepth(objectsByDepth, 0, property); - allObjects.Add(property); - DiscoverNestedObjects(property, objectsByDepth, visitedObjects, allObjects, currentDepth: 1, cancellationToken); - } - } + // Collect root-level objects and discover nested objects + CollectRootObjects( + testContext.Metadata.TestDetails, + TryAddStandard, + obj => DiscoverNestedObjects(obj, objectsByDepth, visitedObjects, allObjects, currentDepth: 1, cancellationToken), + cancellationToken); return new ObjectGraph(objectsByDepth, allObjects); } @@ -131,42 +130,21 @@ public IObjectGraph DiscoverNestedObjectGraph(object rootObject, CancellationTok public ConcurrentDictionary> DiscoverAndTrackObjects(TestContext testContext, CancellationToken cancellationToken = default) { var visitedObjects = testContext.TrackedObjects; - var testDetails = testContext.Metadata.TestDetails; - - foreach (var classArgument in testDetails.TestClassArguments) - { - cancellationToken.ThrowIfCancellationRequested(); - if (classArgument != null && TryAddToHashSet(visitedObjects, 0, classArgument)) - { - DiscoverNestedObjectsForTracking(classArgument, visitedObjects, 1, cancellationToken); - } - } - - foreach (var methodArgument in testDetails.TestMethodArguments) - { - cancellationToken.ThrowIfCancellationRequested(); - if (methodArgument != null && TryAddToHashSet(visitedObjects, 0, methodArgument)) - { - DiscoverNestedObjectsForTracking(methodArgument, visitedObjects, 1, cancellationToken); - } - } - foreach (var property in testDetails.TestClassInjectedPropertyArguments.Values) - { - cancellationToken.ThrowIfCancellationRequested(); - if (property != null && TryAddToHashSet(visitedObjects, 0, property)) - { - DiscoverNestedObjectsForTracking(property, visitedObjects, 1, cancellationToken); - } - } + // Collect root-level objects and discover nested objects for tracking + CollectRootObjects( + testContext.Metadata.TestDetails, + (obj, depth) => TryAddToHashSet(visitedObjects, depth, obj), + obj => DiscoverNestedObjectsForTracking(obj, visitedObjects, 1, cancellationToken), + cancellationToken); return visitedObjects; } /// /// Recursively discovers nested objects that have injectable properties OR implement IAsyncInitializer. + /// Uses consolidated TraverseInjectableProperties and TraverseInitializerProperties methods. /// - [UnconditionalSuppressMessage("Trimming", "IL2026", Justification = "Property discovery handles both AOT and reflection modes")] private void DiscoverNestedObjects( object obj, ConcurrentDictionary> objectsByDepth, @@ -175,273 +153,73 @@ private void DiscoverNestedObjects( int currentDepth, CancellationToken cancellationToken) { - // Guard against excessive recursion to prevent stack overflow - if (currentDepth > MaxRecursionDepth) + if (!CheckRecursionDepth(obj, currentDepth)) { -#if DEBUG - Debug.WriteLine($"[ObjectGraphDiscoverer] Max recursion depth ({MaxRecursionDepth}) reached for type '{obj.GetType().Name}'"); -#endif return; } cancellationToken.ThrowIfCancellationRequested(); - var plan = PropertyInjectionCache.GetOrCreatePlan(obj.GetType()); - - // First, discover objects from injectable properties (data source attributes) - if (plan.HasProperties) + // Standard mode add callback: visitedObjects + objectsByDepth + allObjects + bool TryAddStandard(object value, int depth) { - // Use source-generated properties if available, otherwise fall back to reflection - if (plan.SourceGeneratedProperties.Length > 0) + if (!visitedObjects.TryAdd(value, 0)) { - foreach (var metadata in plan.SourceGeneratedProperties) - { - cancellationToken.ThrowIfCancellationRequested(); - var property = metadata.ContainingType.GetProperty(metadata.PropertyName); - if (property == null || !property.CanRead) - { - continue; - } - - var value = property.GetValue(obj); - if (value == null || !visitedObjects.TryAdd(value, 0)) - { - continue; - } - - AddToDepth(objectsByDepth, currentDepth, value); - allObjects.Add(value); - DiscoverNestedObjects(value, objectsByDepth, visitedObjects, allObjects, currentDepth + 1, cancellationToken); - } + return false; } - else if (plan.ReflectionProperties.Length > 0) - { - foreach (var (property, _) in plan.ReflectionProperties) - { - cancellationToken.ThrowIfCancellationRequested(); - var value = property.GetValue(obj); - if (value == null || !visitedObjects.TryAdd(value, 0)) - { - continue; - } - AddToDepth(objectsByDepth, currentDepth, value); - allObjects.Add(value); - DiscoverNestedObjects(value, objectsByDepth, visitedObjects, allObjects, currentDepth + 1, cancellationToken); - } - } + AddToDepth(objectsByDepth, depth, value); + allObjects.Add(value); + return true; } - // Also discover nested IAsyncInitializer objects from ALL properties - DiscoverNestedInitializerObjects(obj, objectsByDepth, visitedObjects, allObjects, currentDepth, cancellationToken); - } - - /// - /// Discovers nested objects for tracking (uses HashSet pattern for compatibility with TestContext.TrackedObjects). - /// - [UnconditionalSuppressMessage("Trimming", "IL2026", Justification = "Property discovery handles both AOT and reflection modes")] - private void DiscoverNestedObjectsForTracking( - object obj, - ConcurrentDictionary> visitedObjects, - int currentDepth, - CancellationToken cancellationToken) - { - // Guard against excessive recursion to prevent stack overflow - if (currentDepth > MaxRecursionDepth) + // Recursive callback + void Recurse(object value, int depth) { -#if DEBUG - Debug.WriteLine($"[ObjectGraphDiscoverer] Max recursion depth ({MaxRecursionDepth}) reached for type '{obj.GetType().Name}'"); -#endif - return; + DiscoverNestedObjects(value, objectsByDepth, visitedObjects, allObjects, depth, cancellationToken); } - cancellationToken.ThrowIfCancellationRequested(); - - var plan = PropertyInjectionCache.GetOrCreatePlan(obj.GetType()); - - // Check SourceRegistrar.IsEnabled for compatibility with existing TrackableObjectGraphProvider behavior - if (!SourceRegistrar.IsEnabled) - { - foreach (var prop in plan.ReflectionProperties) - { - cancellationToken.ThrowIfCancellationRequested(); - var value = prop.Property.GetValue(obj); - if (value == null) - { - continue; - } - - if (!TryAddToHashSet(visitedObjects, currentDepth, value)) - { - continue; - } - - DiscoverNestedObjectsForTracking(value, visitedObjects, currentDepth + 1, cancellationToken); - } - } - else - { - foreach (var metadata in plan.SourceGeneratedProperties) - { - cancellationToken.ThrowIfCancellationRequested(); - var property = metadata.ContainingType.GetProperty(metadata.PropertyName); - if (property == null || !property.CanRead) - { - continue; - } - - var value = property.GetValue(obj); - if (value == null) - { - continue; - } - - if (!TryAddToHashSet(visitedObjects, currentDepth, value)) - { - continue; - } - - DiscoverNestedObjectsForTracking(value, visitedObjects, currentDepth + 1, cancellationToken); - } - } + // Traverse injectable properties (useSourceRegistrarCheck = false) + TraverseInjectableProperties(obj, TryAddStandard, Recurse, currentDepth, cancellationToken, useSourceRegistrarCheck: false); // Also discover nested IAsyncInitializer objects from ALL properties - DiscoverNestedInitializerObjectsForTracking(obj, visitedObjects, currentDepth, cancellationToken); + TraverseInitializerProperties(obj, TryAddStandard, Recurse, currentDepth, cancellationToken); } /// - /// Discovers nested objects that implement IAsyncInitializer from all readable properties. - /// Uses cached reflection for performance. + /// Discovers nested objects for tracking (uses HashSet pattern for compatibility with TestContext.TrackedObjects). + /// Uses consolidated TraverseInjectableProperties and TraverseInitializerProperties methods. /// - [UnconditionalSuppressMessage("Trimming", "IL2026", Justification = "Reflection fallback for nested initializers. In AOT, source-gen handles primary discovery.")] - [UnconditionalSuppressMessage("Trimming", "IL2070", Justification = "Reflection fallback for nested initializers. In AOT, source-gen handles primary discovery.")] - [UnconditionalSuppressMessage("Trimming", "IL2075", Justification = "Reflection fallback for nested initializers. In AOT, source-gen handles primary discovery.")] - private void DiscoverNestedInitializerObjects( + private void DiscoverNestedObjectsForTracking( object obj, - ConcurrentDictionary> objectsByDepth, - ConcurrentDictionary visitedObjects, - HashSet allObjects, + ConcurrentDictionary> visitedObjects, int currentDepth, CancellationToken cancellationToken) { - // Guard against excessive recursion to prevent stack overflow - if (currentDepth > MaxRecursionDepth) + if (!CheckRecursionDepth(obj, currentDepth)) { -#if DEBUG - Debug.WriteLine($"[ObjectGraphDiscoverer] Max recursion depth ({MaxRecursionDepth}) reached for type '{obj.GetType().Name}'"); -#endif return; } cancellationToken.ThrowIfCancellationRequested(); - var type = obj.GetType(); - - // Skip types that don't need discovery - if (ShouldSkipType(type)) + // Tracking mode add callback: TryAddToHashSet only + bool TryAddTracking(object value, int depth) { - return; + return TryAddToHashSet(visitedObjects, depth, value); } - // Use cached properties for performance - var properties = GetCachedProperties(type); - - foreach (var property in properties) + // Recursive callback + void Recurse(object value, int depth) { - cancellationToken.ThrowIfCancellationRequested(); - try - { - var value = property.GetValue(obj); - if (value == null) - { - continue; - } - - // Only discover if it implements IAsyncInitializer and hasn't been visited - if (value is IAsyncInitializer && visitedObjects.TryAdd(value, 0)) - { - AddToDepth(objectsByDepth, currentDepth, value); - allObjects.Add(value); - DiscoverNestedObjects(value, objectsByDepth, visitedObjects, allObjects, currentDepth + 1, cancellationToken); - } - } - catch (OperationCanceledException) - { - throw; // Propagate cancellation - } - catch (Exception ex) - { -#if DEBUG - // Log instead of silently swallowing - helps with debugging - Debug.WriteLine($"[ObjectGraphDiscoverer] Failed to access property '{property.Name}' on type '{type.Name}': {ex.Message}"); -#endif - // Continue discovery despite property access failures - _ = ex; - } + DiscoverNestedObjectsForTracking(value, visitedObjects, depth, cancellationToken); } - } - /// - /// Discovers nested IAsyncInitializer objects for tracking (uses HashSet pattern). - /// - [UnconditionalSuppressMessage("Trimming", "IL2026", Justification = "Reflection fallback for nested initializers. In AOT, source-gen handles primary discovery.")] - [UnconditionalSuppressMessage("Trimming", "IL2070", Justification = "Reflection fallback for nested initializers. In AOT, source-gen handles primary discovery.")] - [UnconditionalSuppressMessage("Trimming", "IL2075", Justification = "Reflection fallback for nested initializers. In AOT, source-gen handles primary discovery.")] - private void DiscoverNestedInitializerObjectsForTracking( - object obj, - ConcurrentDictionary> visitedObjects, - int currentDepth, - CancellationToken cancellationToken) - { - // Guard against excessive recursion to prevent stack overflow - if (currentDepth > MaxRecursionDepth) - { -#if DEBUG - Debug.WriteLine($"[ObjectGraphDiscoverer] Max recursion depth ({MaxRecursionDepth}) reached for type '{obj.GetType().Name}'"); -#endif - return; - } - - cancellationToken.ThrowIfCancellationRequested(); - - var type = obj.GetType(); - - if (ShouldSkipType(type)) - { - return; - } - - var properties = GetCachedProperties(type); - - foreach (var property in properties) - { - cancellationToken.ThrowIfCancellationRequested(); - try - { - var value = property.GetValue(obj); - if (value == null) - { - continue; - } + // Traverse injectable properties (useSourceRegistrarCheck = true for tracking mode) + TraverseInjectableProperties(obj, TryAddTracking, Recurse, currentDepth, cancellationToken, useSourceRegistrarCheck: true); - if (value is IAsyncInitializer && TryAddToHashSet(visitedObjects, currentDepth, value)) - { - DiscoverNestedObjectsForTracking(value, visitedObjects, currentDepth + 1, cancellationToken); - } - } - catch (OperationCanceledException) - { - throw; // Propagate cancellation - } - catch (Exception ex) - { -#if DEBUG - Debug.WriteLine($"[ObjectGraphDiscoverer] Failed to access property '{property.Name}' on type '{type.Name}': {ex.Message}"); -#endif - // Continue discovery despite property access failures - _ = ex; - } - } + // Also discover nested IAsyncInitializer objects from ALL properties + TraverseInitializerProperties(obj, TryAddTracking, Recurse, currentDepth, cancellationToken); } /// @@ -560,4 +338,181 @@ private static bool TryAddToHashSet(ConcurrentDictionary> d return hashSet.Add(obj); } } + + #region Consolidated Traversal Methods (DRY) + + /// + /// Checks recursion depth guard. Returns false if depth exceeded (caller should return early). + /// + private static bool CheckRecursionDepth(object obj, int currentDepth) + { + if (currentDepth > MaxRecursionDepth) + { +#if DEBUG + Debug.WriteLine($"[ObjectGraphDiscoverer] Max recursion depth ({MaxRecursionDepth}) reached for type '{obj.GetType().Name}'"); +#endif + return false; + } + + return true; + } + + /// + /// Unified traversal for injectable properties (from PropertyInjectionCache). + /// Eliminates duplicate code between DiscoverNestedObjects and DiscoverNestedObjectsForTracking. + /// + [UnconditionalSuppressMessage("Trimming", "IL2026", Justification = "Property discovery handles both AOT and reflection modes")] + private static void TraverseInjectableProperties( + object obj, + TryAddObjectFunc tryAdd, + RecurseFunc recurse, + int currentDepth, + CancellationToken cancellationToken, + bool useSourceRegistrarCheck) + { + var plan = PropertyInjectionCache.GetOrCreatePlan(obj.GetType()); + + if (!plan.HasProperties && !useSourceRegistrarCheck) + { + return; + } + + // The two modes differ in how they choose source-gen vs reflection: + // - Standard mode: Uses plan.SourceGeneratedProperties.Length > 0 + // - Tracking mode: Uses SourceRegistrar.IsEnabled + bool useSourceGen = useSourceRegistrarCheck + ? SourceRegistrar.IsEnabled + : plan.SourceGeneratedProperties.Length > 0; + + if (useSourceGen) + { + foreach (var metadata in plan.SourceGeneratedProperties) + { + cancellationToken.ThrowIfCancellationRequested(); + var property = metadata.ContainingType.GetProperty(metadata.PropertyName); + if (property == null || !property.CanRead) + { + continue; + } + + var value = property.GetValue(obj); + if (value != null && tryAdd(value, currentDepth)) + { + recurse(value, currentDepth + 1); + } + } + } + else + { + // Reflection path - use the appropriate property collection + var reflectionProps = useSourceRegistrarCheck + ? plan.ReflectionProperties + : (plan.ReflectionProperties.Length > 0 ? plan.ReflectionProperties : []); + + foreach (var prop in reflectionProps) + { + cancellationToken.ThrowIfCancellationRequested(); + var value = prop.Property.GetValue(obj); + if (value != null && tryAdd(value, currentDepth)) + { + recurse(value, currentDepth + 1); + } + } + } + } + + /// + /// Unified traversal for IAsyncInitializer objects (from all properties). + /// Eliminates duplicate code between DiscoverNestedInitializerObjects and DiscoverNestedInitializerObjectsForTracking. + /// + [UnconditionalSuppressMessage("Trimming", "IL2026", Justification = "Reflection fallback for nested initializers. In AOT, source-gen handles primary discovery.")] + [UnconditionalSuppressMessage("Trimming", "IL2070", Justification = "Reflection fallback for nested initializers. In AOT, source-gen handles primary discovery.")] + [UnconditionalSuppressMessage("Trimming", "IL2075", Justification = "Reflection fallback for nested initializers. In AOT, source-gen handles primary discovery.")] + private static void TraverseInitializerProperties( + object obj, + TryAddObjectFunc tryAdd, + RecurseFunc recurse, + int currentDepth, + CancellationToken cancellationToken) + { + var type = obj.GetType(); + + if (ShouldSkipType(type)) + { + return; + } + + var properties = GetCachedProperties(type); + + foreach (var property in properties) + { + cancellationToken.ThrowIfCancellationRequested(); + try + { + var value = property.GetValue(obj); + if (value == null) + { + continue; + } + + // Only discover IAsyncInitializer objects + if (value is IAsyncInitializer && tryAdd(value, currentDepth)) + { + recurse(value, currentDepth + 1); + } + } + catch (OperationCanceledException) + { + throw; // Propagate cancellation + } + catch (Exception ex) + { +#if DEBUG + Debug.WriteLine($"[ObjectGraphDiscoverer] Failed to access property '{property.Name}' on type '{type.Name}': {ex.Message}"); +#endif + // Continue discovery despite property access failures + _ = ex; + } + } + } + + /// + /// Collects root-level objects (class args, method args, properties) from test details. + /// Eliminates duplicate loops in DiscoverObjectGraph and DiscoverAndTrackObjects. + /// + private static void CollectRootObjects( + TestDetails testDetails, + TryAddObjectFunc tryAdd, + RootObjectCallback onRootObjectAdded, + CancellationToken cancellationToken) + { + foreach (var classArgument in testDetails.TestClassArguments) + { + cancellationToken.ThrowIfCancellationRequested(); + if (classArgument != null && tryAdd(classArgument, 0)) + { + onRootObjectAdded(classArgument); + } + } + + foreach (var methodArgument in testDetails.TestMethodArguments) + { + cancellationToken.ThrowIfCancellationRequested(); + if (methodArgument != null && tryAdd(methodArgument, 0)) + { + onRootObjectAdded(methodArgument); + } + } + + foreach (var property in testDetails.TestClassInjectedPropertyArguments.Values) + { + cancellationToken.ThrowIfCancellationRequested(); + if (property != null && tryAdd(property, 0)) + { + onRootObjectAdded(property); + } + } + } + + #endregion } diff --git a/TUnit.Core/Helpers/ReferenceEqualityComparer.cs b/TUnit.Core/Helpers/ReferenceEqualityComparer.cs index 7639a04c17..e2058527f0 100644 --- a/TUnit.Core/Helpers/ReferenceEqualityComparer.cs +++ b/TUnit.Core/Helpers/ReferenceEqualityComparer.cs @@ -6,8 +6,20 @@ namespace TUnit.Core.Helpers; /// Compares objects by reference identity, not value equality. /// Uses RuntimeHelpers.GetHashCode to get identity-based hash codes. /// -public class ReferenceEqualityComparer : IEqualityComparer +public sealed class ReferenceEqualityComparer : IEqualityComparer { + /// + /// Singleton instance to avoid repeated allocations. + /// + public static readonly ReferenceEqualityComparer Instance = new(); + + /// + /// Private constructor to enforce singleton pattern. + /// + private ReferenceEqualityComparer() + { + } + public new bool Equals(object? x, object? y) { return ReferenceEquals(x, y); diff --git a/TUnit.Core/Interfaces/IObjectGraphDiscoverer.cs b/TUnit.Core/Interfaces/IObjectGraphDiscoverer.cs index b2e2842d70..6d064b4be0 100644 --- a/TUnit.Core/Interfaces/IObjectGraphDiscoverer.cs +++ b/TUnit.Core/Interfaces/IObjectGraphDiscoverer.cs @@ -45,6 +45,17 @@ public interface IObjectGraphDiscoverer /// Higher depths contain nested objects. /// IObjectGraph DiscoverNestedObjectGraph(object rootObject, CancellationToken cancellationToken = default); + + /// + /// Discovers objects and populates the test context's tracked objects dictionary directly. + /// Used for efficient object tracking without intermediate allocations. + /// + /// The test context to discover objects from and populate. + /// Optional cancellation token for long-running discovery. + /// + /// The tracked objects dictionary (same as testContext.TrackedObjects) populated with discovered objects. + /// + ConcurrentDictionary> DiscoverAndTrackObjects(TestContext testContext, CancellationToken cancellationToken = default); } /// diff --git a/TUnit.Core/ObjectInitializer.cs b/TUnit.Core/ObjectInitializer.cs index 889a7f4dfa..24efeeed71 100644 --- a/TUnit.Core/ObjectInitializer.cs +++ b/TUnit.Core/ObjectInitializer.cs @@ -24,7 +24,7 @@ public static class ObjectInitializer // even under contention. GetOrAdd's factory can be called multiple times, but with // Lazy + ExecutionAndPublication mode, only one initialization actually runs. private static readonly ConcurrentDictionary> InitializationTasks = - new(new Helpers.ReferenceEqualityComparer()); + new(Helpers.ReferenceEqualityComparer.Instance); /// /// Initializes an object during the discovery phase. diff --git a/TUnit.Core/Tracking/ObjectTracker.cs b/TUnit.Core/Tracking/ObjectTracker.cs index d957f3e7e6..86a8afab6d 100644 --- a/TUnit.Core/Tracking/ObjectTracker.cs +++ b/TUnit.Core/Tracking/ObjectTracker.cs @@ -13,12 +13,12 @@ internal class ObjectTracker(TrackableObjectGraphProvider trackableObjectGraphPr { // Use ReferenceEqualityComparer to prevent objects with custom Equals from sharing state private static readonly ConcurrentDictionary _trackedObjects = - new(new Helpers.ReferenceEqualityComparer()); + new(Helpers.ReferenceEqualityComparer.Instance); public void TrackObjects(TestContext testContext) { // Build alreadyTracked set without LINQ to reduce allocations - var alreadyTracked = new HashSet(new Helpers.ReferenceEqualityComparer()); + var alreadyTracked = new HashSet(Helpers.ReferenceEqualityComparer.Instance); foreach (var kvp in testContext.TrackedObjects) { // Lock while iterating to prevent concurrent modification @@ -32,7 +32,7 @@ public void TrackObjects(TestContext testContext) } // Get new trackable objects without LINQ - var newTrackableObjects = new HashSet(new Helpers.ReferenceEqualityComparer()); + var newTrackableObjects = new HashSet(Helpers.ReferenceEqualityComparer.Instance); var trackableDict = trackableObjectGraphProvider.GetTrackableObjects(testContext); foreach (var kvp in trackableDict) { @@ -57,7 +57,7 @@ public void TrackObjects(TestContext testContext) public async ValueTask UntrackObjects(TestContext testContext, List cleanupExceptions) { // Build objects set without LINQ to reduce allocations and with proper locking - var objectsToUntrack = new HashSet(new Helpers.ReferenceEqualityComparer()); + var objectsToUntrack = new HashSet(Helpers.ReferenceEqualityComparer.Instance); foreach (var kvp in testContext.TrackedObjects) { lock (kvp.Value) diff --git a/TUnit.Core/Tracking/TrackableObjectGraphProvider.cs b/TUnit.Core/Tracking/TrackableObjectGraphProvider.cs index 121c9da013..2460406c6f 100644 --- a/TUnit.Core/Tracking/TrackableObjectGraphProvider.cs +++ b/TUnit.Core/Tracking/TrackableObjectGraphProvider.cs @@ -30,37 +30,14 @@ public TrackableObjectGraphProvider(IObjectGraphDiscoverer discoverer) /// /// Gets trackable objects from a test context, organized by depth level. - /// Delegates to the shared ObjectGraphDiscoverer to eliminate code duplication. + /// Delegates to the shared IObjectGraphDiscoverer to eliminate code duplication. /// /// The test context to get trackable objects from. /// Optional cancellation token for long-running discovery. public ConcurrentDictionary> GetTrackableObjects(TestContext testContext, CancellationToken cancellationToken = default) { - // Use the ObjectGraphDiscoverer's specialized method that populates TrackedObjects directly - if (_discoverer is ObjectGraphDiscoverer concreteDiscoverer) - { - return concreteDiscoverer.DiscoverAndTrackObjects(testContext, cancellationToken); - } - - // Fallback for custom implementations (testing) - var graph = _discoverer.DiscoverObjectGraph(testContext, cancellationToken); - var trackedObjects = testContext.TrackedObjects; - - foreach (var (depth, objects) in graph.ObjectsByDepth) - { - cancellationToken.ThrowIfCancellationRequested(); - var depthSet = trackedObjects.GetOrAdd(depth, _ => []); - // Lock to ensure thread-safe HashSet modification - lock (depthSet) - { - foreach (var obj in objects) - { - depthSet.Add(obj); - } - } - } - - return trackedObjects; + // OCP-compliant: Use the interface method directly instead of type-checking + return _discoverer.DiscoverAndTrackObjects(testContext, cancellationToken); } /// diff --git a/TUnit.Engine/Services/ObjectLifecycleService.cs b/TUnit.Engine/Services/ObjectLifecycleService.cs index 30fbb40f9b..109e66eb02 100644 --- a/TUnit.Engine/Services/ObjectLifecycleService.cs +++ b/TUnit.Engine/Services/ObjectLifecycleService.cs @@ -31,7 +31,7 @@ internal sealed class ObjectLifecycleService : IObjectRegistry, IInitializationC // Track initialization state per object // Use ReferenceEqualityComparer to prevent objects with custom Equals from sharing initialization state private readonly ConcurrentDictionary> _initializationTasks = - new(new Core.Helpers.ReferenceEqualityComparer()); + new(Core.Helpers.ReferenceEqualityComparer.Instance); public ObjectLifecycleService( Lazy propertyInjector, From 579cd7539f8c05e28e84c87992b3ada1697e2075 Mon Sep 17 00:00:00 2001 From: Tom Longhurst <30480171+thomhurst@users.noreply.github.com> Date: Sun, 7 Dec 2025 16:11:16 +0000 Subject: [PATCH 09/20] feat: implement IAsyncDiscoveryInitializer and related classes for improved test discovery handling --- TUnit.Core/Discovery/ObjectGraphDiscoverer.cs | 167 +++++++----------- TUnit.Core/Discovery/PropertyCacheManager.cs | 123 +++++++++++++ .../Interfaces/IObjectGraphDiscoverer.cs | 29 +++ TUnit.Core/Tracking/ObjectTracker.cs | 64 ++++--- .../Services/ObjectLifecycleService.cs | 54 +++--- 5 files changed, 285 insertions(+), 152 deletions(-) create mode 100644 TUnit.Core/Discovery/PropertyCacheManager.cs diff --git a/TUnit.Core/Discovery/ObjectGraphDiscoverer.cs b/TUnit.Core/Discovery/ObjectGraphDiscoverer.cs index 13c75b926c..fac430419b 100644 --- a/TUnit.Core/Discovery/ObjectGraphDiscoverer.cs +++ b/TUnit.Core/Discovery/ObjectGraphDiscoverer.cs @@ -8,6 +8,15 @@ namespace TUnit.Core.Discovery; +/// +/// Represents an error that occurred during object graph discovery. +/// +/// The name of the type being inspected. +/// The name of the property that failed to access. +/// The error message. +/// The exception that occurred. +public readonly record struct DiscoveryError(string TypeName, string PropertyName, string ErrorMessage, Exception Exception); + /// /// Centralized service for discovering and organizing object graphs. /// Consolidates duplicate graph traversal logic from ObjectGraphDiscoveryService and TrackableObjectGraphProvider. @@ -22,8 +31,12 @@ namespace TUnit.Core.Discovery; /// Depth 0: Root objects (class args, method args, property values) /// Depth 1+: Nested objects found in properties of objects at previous depth /// +/// +/// Discovery errors (e.g., property access failures) are collected in +/// rather than thrown, allowing discovery to continue despite individual property failures. +/// /// -public sealed class ObjectGraphDiscoverer : IObjectGraphDiscoverer +public sealed class ObjectGraphDiscoverer : IObjectGraphTracker { /// /// Maximum recursion depth for object graph discovery. @@ -31,18 +44,6 @@ public sealed class ObjectGraphDiscoverer : IObjectGraphDiscoverer /// private const int MaxRecursionDepth = 50; - /// - /// Maximum size for the property cache before cleanup is triggered. - /// Prevents unbounded memory growth in long-running test sessions. - /// - private const int MaxCacheSize = 10000; - - // Cache for GetProperties() results per type - eliminates repeated reflection calls - private static readonly ConcurrentDictionary PropertyCache = new(); - - // Flag to coordinate cache cleanup (prevents multiple threads cleaning simultaneously) - private static int _cleanupInProgress; - // Reference equality comparer for object tracking (ignores Equals overrides) private static readonly Helpers.ReferenceEqualityComparer ReferenceComparer = Helpers.ReferenceEqualityComparer.Instance; @@ -57,6 +58,27 @@ public sealed class ObjectGraphDiscoverer : IObjectGraphDiscoverer typeof(Guid) ]; + // Thread-safe collection of discovery errors for diagnostics + private static readonly ConcurrentBag DiscoveryErrors = []; + + /// + /// Gets all discovery errors that occurred during object graph traversal. + /// Useful for debugging and diagnostics when property access fails. + /// + /// A read-only list of discovery errors. + public static IReadOnlyList GetDiscoveryErrors() + { + return DiscoveryErrors.ToArray(); + } + + /// + /// Clears all recorded discovery errors. Call at end of test session. + /// + public static void ClearDiscoveryErrors() + { + DiscoveryErrors.Clear(); + } + /// /// Delegate for adding discovered objects to collections. /// Returns true if the object was newly added (not a duplicate). @@ -77,10 +99,11 @@ public sealed class ObjectGraphDiscoverer : IObjectGraphDiscoverer public IObjectGraph DiscoverObjectGraph(TestContext testContext, CancellationToken cancellationToken = default) { var objectsByDepth = new ConcurrentDictionary>(); - var allObjects = new HashSet(); + var allObjects = new HashSet(ReferenceComparer); + var allObjectsLock = new object(); // Thread-safety for allObjects HashSet var visitedObjects = new ConcurrentDictionary(ReferenceComparer); - // Standard mode add callback + // Standard mode add callback (thread-safe) bool TryAddStandard(object obj, int depth) { if (!visitedObjects.TryAdd(obj, 0)) @@ -89,7 +112,11 @@ bool TryAddStandard(object obj, int depth) } AddToDepth(objectsByDepth, depth, obj); - allObjects.Add(obj); + lock (allObjectsLock) + { + allObjects.Add(obj); + } + return true; } @@ -97,7 +124,7 @@ bool TryAddStandard(object obj, int depth) CollectRootObjects( testContext.Metadata.TestDetails, TryAddStandard, - obj => DiscoverNestedObjects(obj, objectsByDepth, visitedObjects, allObjects, currentDepth: 1, cancellationToken), + obj => DiscoverNestedObjects(obj, objectsByDepth, visitedObjects, allObjects, allObjectsLock, currentDepth: 1, cancellationToken), cancellationToken); return new ObjectGraph(objectsByDepth, allObjects); @@ -107,14 +134,19 @@ bool TryAddStandard(object obj, int depth) public IObjectGraph DiscoverNestedObjectGraph(object rootObject, CancellationToken cancellationToken = default) { var objectsByDepth = new ConcurrentDictionary>(); - var allObjects = new HashSet(); + var allObjects = new HashSet(ReferenceComparer); + var allObjectsLock = new object(); // Thread-safety for allObjects HashSet var visitedObjects = new ConcurrentDictionary(ReferenceComparer); if (visitedObjects.TryAdd(rootObject, 0)) { AddToDepth(objectsByDepth, 0, rootObject); - allObjects.Add(rootObject); - DiscoverNestedObjects(rootObject, objectsByDepth, visitedObjects, allObjects, currentDepth: 1, cancellationToken); + lock (allObjectsLock) + { + allObjects.Add(rootObject); + } + + DiscoverNestedObjects(rootObject, objectsByDepth, visitedObjects, allObjects, allObjectsLock, currentDepth: 1, cancellationToken); } return new ObjectGraph(objectsByDepth, allObjects); @@ -150,6 +182,7 @@ private void DiscoverNestedObjects( ConcurrentDictionary> objectsByDepth, ConcurrentDictionary visitedObjects, HashSet allObjects, + object allObjectsLock, int currentDepth, CancellationToken cancellationToken) { @@ -160,7 +193,7 @@ private void DiscoverNestedObjects( cancellationToken.ThrowIfCancellationRequested(); - // Standard mode add callback: visitedObjects + objectsByDepth + allObjects + // Standard mode add callback: visitedObjects + objectsByDepth + allObjects (thread-safe) bool TryAddStandard(object value, int depth) { if (!visitedObjects.TryAdd(value, 0)) @@ -169,14 +202,18 @@ bool TryAddStandard(object value, int depth) } AddToDepth(objectsByDepth, depth, value); - allObjects.Add(value); + lock (allObjectsLock) + { + allObjects.Add(value); + } + return true; } // Recursive callback void Recurse(object value, int depth) { - DiscoverNestedObjects(value, objectsByDepth, visitedObjects, allObjects, depth, cancellationToken); + DiscoverNestedObjects(value, objectsByDepth, visitedObjects, allObjects, allObjectsLock, depth, cancellationToken); } // Traverse injectable properties (useSourceRegistrarCheck = false) @@ -223,85 +260,12 @@ void Recurse(object value, int depth) } /// - /// Gets cached properties for a type, filtering to only readable non-indexed properties. - /// Includes periodic cache cleanup to prevent unbounded memory growth. - /// - [UnconditionalSuppressMessage("Trimming", "IL2070", Justification = "Reflection fallback for nested initializers. In AOT, source-gen handles primary discovery.")] - private static PropertyInfo[] GetCachedProperties(Type type) - { - // Periodic cleanup if cache grows too large to prevent memory leaks - // Use Interlocked to ensure only one thread performs cleanup at a time - if (PropertyCache.Count > MaxCacheSize && - Interlocked.CompareExchange(ref _cleanupInProgress, 1, 0) == 0) - { - try - { - // Double-check after acquiring cleanup flag - if (PropertyCache.Count > MaxCacheSize) - { - var keysToRemove = new List(MaxCacheSize / 2); - var count = 0; - foreach (var key in PropertyCache.Keys) - { - if (count++ >= MaxCacheSize / 2) - { - break; - } - - keysToRemove.Add(key); - } - - foreach (var key in keysToRemove) - { - PropertyCache.TryRemove(key, out _); - } -#if DEBUG - Debug.WriteLine($"[ObjectGraphDiscoverer] PropertyCache exceeded {MaxCacheSize} entries, cleared {keysToRemove.Count} entries"); -#endif - } - } - finally - { - Interlocked.Exchange(ref _cleanupInProgress, 0); - } - } - - return PropertyCache.GetOrAdd(type, static t => - { - // Use explicit loops instead of LINQ to avoid allocations in hot path - var allProps = t.GetProperties(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance); - - // First pass: count eligible properties - var count = 0; - foreach (var p in allProps) - { - if (p.CanRead && p.GetIndexParameters().Length == 0) - { - count++; - } - } - - // Second pass: fill result array - var result = new PropertyInfo[count]; - var i = 0; - foreach (var p in allProps) - { - if (p.CanRead && p.GetIndexParameters().Length == 0) - { - result[i++] = p; - } - } - - return result; - }); - } - - /// - /// Clears the property cache. Called at end of test session to release memory. + /// Clears all caches. Called at end of test session to release memory. /// public static void ClearCache() { - PropertyCache.Clear(); + PropertyCacheManager.ClearCache(); + ClearDiscoveryErrors(); } /// @@ -442,7 +406,7 @@ private static void TraverseInitializerProperties( return; } - var properties = GetCachedProperties(type); + var properties = PropertyCacheManager.GetCachedProperties(type); foreach (var property in properties) { @@ -467,11 +431,12 @@ private static void TraverseInitializerProperties( } catch (Exception ex) { + // Record error for diagnostics (available via GetDiscoveryErrors()) + DiscoveryErrors.Add(new DiscoveryError(type.Name, property.Name, ex.Message, ex)); #if DEBUG Debug.WriteLine($"[ObjectGraphDiscoverer] Failed to access property '{property.Name}' on type '{type.Name}': {ex.Message}"); #endif // Continue discovery despite property access failures - _ = ex; } } } diff --git a/TUnit.Core/Discovery/PropertyCacheManager.cs b/TUnit.Core/Discovery/PropertyCacheManager.cs new file mode 100644 index 0000000000..9df656caa7 --- /dev/null +++ b/TUnit.Core/Discovery/PropertyCacheManager.cs @@ -0,0 +1,123 @@ +using System.Collections.Concurrent; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Reflection; + +namespace TUnit.Core.Discovery; + +/// +/// Manages cached property reflection results for object graph discovery. +/// Extracted from ObjectGraphDiscoverer to follow Single Responsibility Principle. +/// +/// +/// +/// This class caches arrays per type to avoid repeated reflection calls. +/// Includes automatic cache cleanup when size exceeds to prevent memory leaks. +/// +/// +/// Thread-safe: Uses and for coordination. +/// +/// +internal static class PropertyCacheManager +{ + /// + /// Maximum size for the property cache before cleanup is triggered. + /// Prevents unbounded memory growth in long-running test sessions. + /// + private const int MaxCacheSize = 10000; + + // Cache for GetProperties() results per type - eliminates repeated reflection calls + private static readonly ConcurrentDictionary PropertyCache = new(); + + // Flag to coordinate cache cleanup (prevents multiple threads cleaning simultaneously) + private static int _cleanupInProgress; + + /// + /// Gets cached properties for a type, filtering to only readable non-indexed properties. + /// Includes periodic cache cleanup to prevent unbounded memory growth. + /// + /// The type to get properties for. + /// An array of readable, non-indexed properties for the type. + [UnconditionalSuppressMessage("Trimming", "IL2070", Justification = "Reflection fallback for nested initializers. In AOT, source-gen handles primary discovery.")] + public static PropertyInfo[] GetCachedProperties(Type type) + { + // Periodic cleanup if cache grows too large to prevent memory leaks + // Use Interlocked to ensure only one thread performs cleanup at a time + if (PropertyCache.Count > MaxCacheSize && + Interlocked.CompareExchange(ref _cleanupInProgress, 1, 0) == 0) + { + try + { + // Double-check after acquiring cleanup flag + if (PropertyCache.Count > MaxCacheSize) + { + var keysToRemove = new List(MaxCacheSize / 2); + var count = 0; + foreach (var key in PropertyCache.Keys) + { + if (count++ >= MaxCacheSize / 2) + { + break; + } + + keysToRemove.Add(key); + } + + foreach (var key in keysToRemove) + { + PropertyCache.TryRemove(key, out _); + } +#if DEBUG + Debug.WriteLine($"[PropertyCacheManager] PropertyCache exceeded {MaxCacheSize} entries, cleared {keysToRemove.Count} entries"); +#endif + } + } + finally + { + Interlocked.Exchange(ref _cleanupInProgress, 0); + } + } + + return PropertyCache.GetOrAdd(type, static t => + { + // Use explicit loops instead of LINQ to avoid allocations in hot path + var allProps = t.GetProperties(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance); + + // First pass: count eligible properties + var eligibleCount = 0; + foreach (var p in allProps) + { + if (p.CanRead && p.GetIndexParameters().Length == 0) + { + eligibleCount++; + } + } + + // Second pass: fill result array + var result = new PropertyInfo[eligibleCount]; + var i = 0; + foreach (var p in allProps) + { + if (p.CanRead && p.GetIndexParameters().Length == 0) + { + result[i++] = p; + } + } + + return result; + }); + } + + /// + /// Clears the property cache. Called at end of test session to release memory. + /// + public static void ClearCache() + { + PropertyCache.Clear(); + } + + /// + /// Gets the current number of cached types. Useful for diagnostics. + /// + public static int CacheCount => PropertyCache.Count; +} diff --git a/TUnit.Core/Interfaces/IObjectGraphDiscoverer.cs b/TUnit.Core/Interfaces/IObjectGraphDiscoverer.cs index 6d064b4be0..c6525cf634 100644 --- a/TUnit.Core/Interfaces/IObjectGraphDiscoverer.cs +++ b/TUnit.Core/Interfaces/IObjectGraphDiscoverer.cs @@ -4,6 +4,7 @@ namespace TUnit.Core.Interfaces; /// /// Defines a contract for discovering object graphs from test contexts. +/// Pure query interface - only reads and returns data, does not modify state. /// /// /// @@ -19,6 +20,9 @@ namespace TUnit.Core.Interfaces; /// Nested objects that implement /// /// +/// +/// For tracking operations that modify TestContext.TrackedObjects, see . +/// /// public interface IObjectGraphDiscoverer { @@ -55,9 +59,34 @@ public interface IObjectGraphDiscoverer /// /// The tracked objects dictionary (same as testContext.TrackedObjects) populated with discovered objects. /// + /// + /// This method modifies testContext.TrackedObjects directly. For pure query operations, + /// use instead. + /// ConcurrentDictionary> DiscoverAndTrackObjects(TestContext testContext, CancellationToken cancellationToken = default); } +/// +/// Marker interface for object graph tracking operations. +/// Extends with operations that modify state. +/// +/// +/// +/// This interface exists to support Interface Segregation Principle: +/// clients that only need query operations can depend on , +/// while clients that need tracking can depend on . +/// +/// +/// Currently inherits all methods from . +/// The distinction exists for semantic clarity and future extensibility. +/// +/// +public interface IObjectGraphTracker : IObjectGraphDiscoverer +{ + // All methods inherited from IObjectGraphDiscoverer + // This interface provides semantic clarity for tracking operations +} + /// /// Represents a discovered object graph organized by depth level. /// diff --git a/TUnit.Core/Tracking/ObjectTracker.cs b/TUnit.Core/Tracking/ObjectTracker.cs index 86a8afab6d..6a229f5fdc 100644 --- a/TUnit.Core/Tracking/ObjectTracker.cs +++ b/TUnit.Core/Tracking/ObjectTracker.cs @@ -9,29 +9,51 @@ namespace TUnit.Core.Tracking; /// Objects are disposed when their reference count reaches zero, regardless of sharing type. /// Uses ReferenceEqualityComparer to track objects by identity, not value equality. /// +/// +/// The static s_trackedObjects dictionary is shared across all tests. +/// Call at the end of a test session to release memory. +/// internal class ObjectTracker(TrackableObjectGraphProvider trackableObjectGraphProvider, Disposer disposer) { // Use ReferenceEqualityComparer to prevent objects with custom Equals from sharing state - private static readonly ConcurrentDictionary _trackedObjects = + private static readonly ConcurrentDictionary s_trackedObjects = new(Helpers.ReferenceEqualityComparer.Instance); - public void TrackObjects(TestContext testContext) + /// + /// Clears all static tracking state. Call at the end of a test session to release memory. + /// + public static void ClearStaticTracking() { - // Build alreadyTracked set without LINQ to reduce allocations - var alreadyTracked = new HashSet(Helpers.ReferenceEqualityComparer.Instance); - foreach (var kvp in testContext.TrackedObjects) + s_trackedObjects.Clear(); + } + + /// + /// Flattens a ConcurrentDictionary of depth-keyed HashSets into a single HashSet. + /// Thread-safe: locks each HashSet while copying. + /// + private static HashSet FlattenTrackedObjects(ConcurrentDictionary> trackedObjects) + { + var result = new HashSet(Helpers.ReferenceEqualityComparer.Instance); + foreach (var kvp in trackedObjects) { - // Lock while iterating to prevent concurrent modification lock (kvp.Value) { foreach (var obj in kvp.Value) { - alreadyTracked.Add(obj); + result.Add(obj); } } } - // Get new trackable objects without LINQ + return result; + } + + public void TrackObjects(TestContext testContext) + { + // Get already tracked objects (DRY: use helper method) + var alreadyTracked = FlattenTrackedObjects(testContext.TrackedObjects); + + // Get new trackable objects var newTrackableObjects = new HashSet(Helpers.ReferenceEqualityComparer.Instance); var trackableDict = trackableObjectGraphProvider.GetTrackableObjects(testContext); foreach (var kvp in trackableDict) @@ -56,18 +78,8 @@ public void TrackObjects(TestContext testContext) public async ValueTask UntrackObjects(TestContext testContext, List cleanupExceptions) { - // Build objects set without LINQ to reduce allocations and with proper locking - var objectsToUntrack = new HashSet(Helpers.ReferenceEqualityComparer.Instance); - foreach (var kvp in testContext.TrackedObjects) - { - lock (kvp.Value) - { - foreach (var obj in kvp.Value) - { - objectsToUntrack.Add(obj); - } - } - } + // Get all objects to untrack (DRY: use helper method) + var objectsToUntrack = FlattenTrackedObjects(testContext.TrackedObjects); foreach (var obj in objectsToUntrack) { @@ -108,7 +120,7 @@ private void TrackObject(object? obj) return; } - var counter = _trackedObjects.GetOrAdd(obj, static _ => new Counter()); + var counter = s_trackedObjects.GetOrAdd(obj, static _ => new Counter()); counter.Increment(); } @@ -119,7 +131,7 @@ private async ValueTask UntrackObject(object? obj) return; } - if (_trackedObjects.TryGetValue(obj, out var counter)) + if (s_trackedObjects.TryGetValue(obj, out var counter)) { var count = counter.Decrement(); @@ -145,12 +157,12 @@ private static bool ShouldSkipTracking(object? obj) public static void OnDisposed(object? o, Action action) { - if(o is not IDisposable and not IAsyncDisposable) + if (o is not IDisposable and not IAsyncDisposable) { return; } - _trackedObjects.GetOrAdd(o, static _ => new Counter()) + s_trackedObjects.GetOrAdd(o, static _ => new Counter()) .OnCountChanged += (_, count) => { if (count == 0) @@ -162,13 +174,13 @@ public static void OnDisposed(object? o, Action action) public static void OnDisposedAsync(object? o, Func asyncAction) { - if(o is not IDisposable and not IAsyncDisposable) + if (o is not IDisposable and not IAsyncDisposable) { return; } // Avoid async void pattern by wrapping in fire-and-forget with exception handling - _trackedObjects.GetOrAdd(o, static _ => new Counter()) + s_trackedObjects.GetOrAdd(o, static _ => new Counter()) .OnCountChanged += (_, count) => { if (count == 0) diff --git a/TUnit.Engine/Services/ObjectLifecycleService.cs b/TUnit.Engine/Services/ObjectLifecycleService.cs index 109e66eb02..1d72da3098 100644 --- a/TUnit.Engine/Services/ObjectLifecycleService.cs +++ b/TUnit.Engine/Services/ObjectLifecycleService.cs @@ -244,29 +244,12 @@ private async Task InitializeTrackedObjectsAsync(TestContext testContext, Cancel /// /// Initializes nested objects during execution phase - all IAsyncInitializer objects. /// - private async Task InitializeNestedObjectsForExecutionAsync(object rootObject, CancellationToken cancellationToken) + private Task InitializeNestedObjectsForExecutionAsync(object rootObject, CancellationToken cancellationToken) { - var graph = _objectGraphDiscoveryService.DiscoverNestedObjectGraph(rootObject, cancellationToken); - - // Initialize from deepest to shallowest (skip depth 0 which is the root itself) - foreach (var depth in graph.GetDepthsDescending()) - { - if (depth == 0) continue; // Root handled separately - - var objectsAtDepth = graph.GetObjectsAtDepth(depth); - - // Pre-allocate task list without LINQ Select - var tasks = new List(); - foreach (var obj in objectsAtDepth) - { - tasks.Add(ObjectInitializer.InitializeAsync(obj, cancellationToken).AsTask()); - } - - if (tasks.Count > 0) - { - await Task.WhenAll(tasks); - } - } + return InitializeNestedObjectsAsync( + rootObject, + ObjectInitializer.InitializeAsync, + cancellationToken); } #endregion @@ -385,14 +368,35 @@ private async Task InitializeObjectCoreAsync( /// /// Initializes nested objects during discovery phase - only IAsyncDiscoveryInitializer objects. /// - private async Task InitializeNestedObjectsForDiscoveryAsync(object rootObject, CancellationToken cancellationToken) + private Task InitializeNestedObjectsForDiscoveryAsync(object rootObject, CancellationToken cancellationToken) + { + return InitializeNestedObjectsAsync( + rootObject, + ObjectInitializer.InitializeForDiscoveryAsync, + cancellationToken); + } + + /// + /// Shared implementation for nested object initialization (DRY). + /// Discovers nested objects and initializes them depth-first using the provided initializer. + /// + /// The root object to discover nested objects from. + /// The initializer function to call for each object. + /// Cancellation token. + private async Task InitializeNestedObjectsAsync( + object rootObject, + Func initializer, + CancellationToken cancellationToken) { var graph = _objectGraphDiscoveryService.DiscoverNestedObjectGraph(rootObject, cancellationToken); // Initialize from deepest to shallowest (skip depth 0 which is the root itself) foreach (var depth in graph.GetDepthsDescending()) { - if (depth == 0) continue; // Root handled separately + if (depth == 0) + { + continue; // Root handled separately + } var objectsAtDepth = graph.GetObjectsAtDepth(depth); @@ -400,7 +404,7 @@ private async Task InitializeNestedObjectsForDiscoveryAsync(object rootObject, C var tasks = new List(); foreach (var obj in objectsAtDepth) { - tasks.Add(ObjectInitializer.InitializeForDiscoveryAsync(obj, cancellationToken).AsTask()); + tasks.Add(initializer(obj, cancellationToken).AsTask()); } if (tasks.Count > 0) From 8c661b32b632ce9b232917c280448a3577100dfa Mon Sep 17 00:00:00 2001 From: Tom Longhurst <30480171+thomhurst@users.noreply.github.com> Date: Sun, 7 Dec 2025 16:49:19 +0000 Subject: [PATCH 10/20] feat: implement IAsyncDiscoveryInitializer and related classes for improved test discovery handling --- .../Conditions/EqualsAssertion.cs | 14 +- .../Helpers/StructuralEqualityComparer.cs | 8 +- .../NotStructuralEquivalencyAssertion.cs | 9 +- .../StructuralEquivalencyAssertion.cs | 9 +- TUnit.Core/Discovery/ObjectGraphDiscoverer.cs | 117 ++++++++----- TUnit.Core/Discovery/PropertyCacheManager.cs | 21 +-- TUnit.Core/Helpers/Counter.cs | 17 +- TUnit.Core/Helpers/Disposer.cs | 9 +- TUnit.Core/Helpers/ParallelTaskHelper.cs | 165 ++++++++++++++++++ TUnit.Core/Interfaces/IDisposer.cs | 16 ++ .../PropertyInitializationContext.cs | 121 +++++++++++++ .../PropertyCacheKeyGenerator.cs | 36 ++++ .../PropertyInjectionPlanBuilder.cs | 97 ++++++++++ TUnit.Core/TestContext.cs | 3 +- TUnit.Core/Tracking/ObjectTracker.cs | 76 ++++++-- .../Services/ObjectLifecycleService.cs | 4 +- TUnit.Engine/Services/PropertyInjector.cs | 75 ++------ 17 files changed, 631 insertions(+), 166 deletions(-) create mode 100644 TUnit.Core/Helpers/ParallelTaskHelper.cs create mode 100644 TUnit.Core/Interfaces/IDisposer.cs create mode 100644 TUnit.Core/PropertyInjection/PropertyCacheKeyGenerator.cs diff --git a/TUnit.Assertions/Conditions/EqualsAssertion.cs b/TUnit.Assertions/Conditions/EqualsAssertion.cs index 821a2b4e7f..350c439fad 100644 --- a/TUnit.Assertions/Conditions/EqualsAssertion.cs +++ b/TUnit.Assertions/Conditions/EqualsAssertion.cs @@ -3,6 +3,7 @@ using System.Reflection; using System.Text; using TUnit.Assertions.Attributes; +using TUnit.Assertions.Conditions.Helpers; using TUnit.Assertions.Core; namespace TUnit.Assertions.Conditions; @@ -84,7 +85,7 @@ protected override Task CheckAsync(EvaluationMetadata m if (_ignoredTypes.Count > 0) { // Use reference-based tracking to detect cycles - var visited = new HashSet(new ReferenceEqualityComparer()); + var visited = new HashSet(ReferenceEqualityComparer.Instance); var result = DeepEquals(value, _expected, _ignoredTypes, visited); if (result.IsSuccess) { @@ -213,15 +214,4 @@ private static (bool IsSuccess, string? Message) DeepEquals(object? actual, obje } protected override string GetExpectation() => $"to be equal to {(_expected is string s ? $"\"{s}\"" : _expected)}"; - - /// - /// Comparer that uses reference equality instead of value equality. - /// Used for cycle detection in deep comparison. - /// - private sealed class ReferenceEqualityComparer : IEqualityComparer - { - public new bool Equals(object? x, object? y) => ReferenceEquals(x, y); - - public int GetHashCode(object obj) => System.Runtime.CompilerServices.RuntimeHelpers.GetHashCode(obj); - } } diff --git a/TUnit.Assertions/Conditions/Helpers/StructuralEqualityComparer.cs b/TUnit.Assertions/Conditions/Helpers/StructuralEqualityComparer.cs index af1a1876e9..9ed8e63f3d 100644 --- a/TUnit.Assertions/Conditions/Helpers/StructuralEqualityComparer.cs +++ b/TUnit.Assertions/Conditions/Helpers/StructuralEqualityComparer.cs @@ -41,7 +41,7 @@ public bool Equals(T? x, T? y) return EqualityComparer.Default.Equals(x, y); } - return CompareStructurally(x, y, new HashSet(new ReferenceEqualityComparer())); + return CompareStructurally(x, y, new HashSet(ReferenceEqualityComparer.Instance)); } public int GetHashCode(T obj) @@ -154,10 +154,4 @@ private static List GetMembersToCompare([DynamicallyAccessedMembers( _ => throw new InvalidOperationException($"Unknown member type: {member.GetType()}") }; } - - private sealed class ReferenceEqualityComparer : IEqualityComparer - { - public new bool Equals(object? x, object? y) => ReferenceEquals(x, y); - public int GetHashCode(object obj) => System.Runtime.CompilerServices.RuntimeHelpers.GetHashCode(obj); - } } diff --git a/TUnit.Assertions/Conditions/NotStructuralEquivalencyAssertion.cs b/TUnit.Assertions/Conditions/NotStructuralEquivalencyAssertion.cs index 4e58fd0bb6..32399096b5 100644 --- a/TUnit.Assertions/Conditions/NotStructuralEquivalencyAssertion.cs +++ b/TUnit.Assertions/Conditions/NotStructuralEquivalencyAssertion.cs @@ -1,5 +1,6 @@ using System.Diagnostics.CodeAnalysis; using System.Text; +using TUnit.Assertions.Conditions.Helpers; using TUnit.Assertions.Core; namespace TUnit.Assertions.Conditions; @@ -90,7 +91,7 @@ protected override Task CheckAsync(EvaluationMetadata m foreach (var type in _ignoredTypes) tempAssertion.IgnoringType(type); - var result = tempAssertion.CompareObjects(value, _notExpected, "", new HashSet(new ReferenceEqualityComparer())); + var result = tempAssertion.CompareObjects(value, _notExpected, "", new HashSet(ReferenceEqualityComparer.Instance)); // Invert the result - we want them to NOT be equivalent if (result.IsPassed) @@ -101,12 +102,6 @@ protected override Task CheckAsync(EvaluationMetadata m return Task.FromResult(AssertionResult.Passed); } - private sealed class ReferenceEqualityComparer : IEqualityComparer - { - public new bool Equals(object? x, object? y) => ReferenceEquals(x, y); - public int GetHashCode(object obj) => System.Runtime.CompilerServices.RuntimeHelpers.GetHashCode(obj); - } - protected override string GetExpectation() { // Extract the source variable name from the expression builder diff --git a/TUnit.Assertions/Conditions/StructuralEquivalencyAssertion.cs b/TUnit.Assertions/Conditions/StructuralEquivalencyAssertion.cs index 8591bf4da9..e17718a046 100644 --- a/TUnit.Assertions/Conditions/StructuralEquivalencyAssertion.cs +++ b/TUnit.Assertions/Conditions/StructuralEquivalencyAssertion.cs @@ -2,6 +2,7 @@ using System.Diagnostics.CodeAnalysis; using System.Reflection; using System.Text; +using TUnit.Assertions.Conditions.Helpers; using TUnit.Assertions.Core; namespace TUnit.Assertions.Conditions; @@ -77,7 +78,7 @@ protected override Task CheckAsync(EvaluationMetadata m return Task.FromResult(AssertionResult.Failed($"threw {exception.GetType().Name}: {exception.Message}")); } - var result = CompareObjects(value, _expected, "", new HashSet(new ReferenceEqualityComparer())); + var result = CompareObjects(value, _expected, "", new HashSet(ReferenceEqualityComparer.Instance)); return Task.FromResult(result); } @@ -366,10 +367,4 @@ private static string ExtractSourceVariable(string expression) return "value"; } - - private sealed class ReferenceEqualityComparer : IEqualityComparer - { - public new bool Equals(object? x, object? y) => ReferenceEquals(x, y); - public int GetHashCode(object obj) => System.Runtime.CompilerServices.RuntimeHelpers.GetHashCode(obj); - } } diff --git a/TUnit.Core/Discovery/ObjectGraphDiscoverer.cs b/TUnit.Core/Discovery/ObjectGraphDiscoverer.cs index fac430419b..27c8f811b3 100644 --- a/TUnit.Core/Discovery/ObjectGraphDiscoverer.cs +++ b/TUnit.Core/Discovery/ObjectGraphDiscoverer.cs @@ -4,6 +4,7 @@ using System.Reflection; using TUnit.Core.Helpers; using TUnit.Core.Interfaces; +using TUnit.Core.Interfaces.SourceGenerator; using TUnit.Core.PropertyInjection; namespace TUnit.Core.Discovery; @@ -350,37 +351,67 @@ private static void TraverseInjectableProperties( if (useSourceGen) { - foreach (var metadata in plan.SourceGeneratedProperties) - { - cancellationToken.ThrowIfCancellationRequested(); - var property = metadata.ContainingType.GetProperty(metadata.PropertyName); - if (property == null || !property.CanRead) - { - continue; - } - - var value = property.GetValue(obj); - if (value != null && tryAdd(value, currentDepth)) - { - recurse(value, currentDepth + 1); - } - } + TraverseSourceGeneratedProperties(obj, plan.SourceGeneratedProperties, tryAdd, recurse, currentDepth, cancellationToken); } else { - // Reflection path - use the appropriate property collection var reflectionProps = useSourceRegistrarCheck ? plan.ReflectionProperties : (plan.ReflectionProperties.Length > 0 ? plan.ReflectionProperties : []); - foreach (var prop in reflectionProps) + TraverseReflectionProperties(obj, reflectionProps, tryAdd, recurse, currentDepth, cancellationToken); + } + } + + /// + /// Traverses source-generated properties and discovers nested objects. + /// Extracted for reduced complexity in TraverseInjectableProperties. + /// + [UnconditionalSuppressMessage("Trimming", "IL2026", Justification = "Property discovery handles both AOT and reflection modes")] + private static void TraverseSourceGeneratedProperties( + object obj, + PropertyInjectionMetadata[] sourceGeneratedProperties, + TryAddObjectFunc tryAdd, + RecurseFunc recurse, + int currentDepth, + CancellationToken cancellationToken) + { + foreach (var metadata in sourceGeneratedProperties) + { + cancellationToken.ThrowIfCancellationRequested(); + var property = metadata.ContainingType.GetProperty(metadata.PropertyName); + if (property == null || !property.CanRead) + { + continue; + } + + var value = property.GetValue(obj); + if (value != null && tryAdd(value, currentDepth)) { - cancellationToken.ThrowIfCancellationRequested(); - var value = prop.Property.GetValue(obj); - if (value != null && tryAdd(value, currentDepth)) - { - recurse(value, currentDepth + 1); - } + recurse(value, currentDepth + 1); + } + } + } + + /// + /// Traverses reflection-based properties and discovers nested objects. + /// Extracted for reduced complexity in TraverseInjectableProperties. + /// + private static void TraverseReflectionProperties( + object obj, + (PropertyInfo Property, IDataSourceAttribute DataSource)[] reflectionProperties, + TryAddObjectFunc tryAdd, + RecurseFunc recurse, + int currentDepth, + CancellationToken cancellationToken) + { + foreach (var prop in reflectionProperties) + { + cancellationToken.ThrowIfCancellationRequested(); + var value = prop.Property.GetValue(obj); + if (value != null && tryAdd(value, currentDepth)) + { + recurse(value, currentDepth + 1); } } } @@ -451,30 +482,32 @@ private static void CollectRootObjects( RootObjectCallback onRootObjectAdded, CancellationToken cancellationToken) { - foreach (var classArgument in testDetails.TestClassArguments) - { - cancellationToken.ThrowIfCancellationRequested(); - if (classArgument != null && tryAdd(classArgument, 0)) - { - onRootObjectAdded(classArgument); - } - } + // Process class arguments + ProcessRootCollection(testDetails.TestClassArguments, tryAdd, onRootObjectAdded, cancellationToken); - foreach (var methodArgument in testDetails.TestMethodArguments) - { - cancellationToken.ThrowIfCancellationRequested(); - if (methodArgument != null && tryAdd(methodArgument, 0)) - { - onRootObjectAdded(methodArgument); - } - } + // Process method arguments + ProcessRootCollection(testDetails.TestMethodArguments, tryAdd, onRootObjectAdded, cancellationToken); - foreach (var property in testDetails.TestClassInjectedPropertyArguments.Values) + // Process injected property values + ProcessRootCollection(testDetails.TestClassInjectedPropertyArguments.Values, tryAdd, onRootObjectAdded, cancellationToken); + } + + /// + /// Processes a collection of root objects, adding them to the graph and invoking callback. + /// Extracted to eliminate duplicate iteration patterns in CollectRootObjects. + /// + private static void ProcessRootCollection( + IEnumerable collection, + TryAddObjectFunc tryAdd, + RootObjectCallback onRootObjectAdded, + CancellationToken cancellationToken) + { + foreach (var item in collection) { cancellationToken.ThrowIfCancellationRequested(); - if (property != null && tryAdd(property, 0)) + if (item != null && tryAdd(item, 0)) { - onRootObjectAdded(property); + onRootObjectAdded(item); } } } diff --git a/TUnit.Core/Discovery/PropertyCacheManager.cs b/TUnit.Core/Discovery/PropertyCacheManager.cs index 9df656caa7..12319608cc 100644 --- a/TUnit.Core/Discovery/PropertyCacheManager.cs +++ b/TUnit.Core/Discovery/PropertyCacheManager.cs @@ -51,24 +51,17 @@ public static PropertyInfo[] GetCachedProperties(Type type) // Double-check after acquiring cleanup flag if (PropertyCache.Count > MaxCacheSize) { - var keysToRemove = new List(MaxCacheSize / 2); - var count = 0; - foreach (var key in PropertyCache.Keys) - { - if (count++ >= MaxCacheSize / 2) - { - break; - } - - keysToRemove.Add(key); - } + // Use ToArray() to get a true snapshot for thread-safe enumeration + // This prevents issues with concurrent modifications during iteration + var allKeys = PropertyCache.Keys.ToArray(); + var removeCount = Math.Min(allKeys.Length / 2, MaxCacheSize / 2); - foreach (var key in keysToRemove) + for (var i = 0; i < removeCount; i++) { - PropertyCache.TryRemove(key, out _); + PropertyCache.TryRemove(allKeys[i], out _); } #if DEBUG - Debug.WriteLine($"[PropertyCacheManager] PropertyCache exceeded {MaxCacheSize} entries, cleared {keysToRemove.Count} entries"); + Debug.WriteLine($"[PropertyCacheManager] PropertyCache exceeded {MaxCacheSize} entries, cleared {removeCount} entries"); #endif } } diff --git a/TUnit.Core/Helpers/Counter.cs b/TUnit.Core/Helpers/Counter.cs index d3df14af6f..6fb4080833 100644 --- a/TUnit.Core/Helpers/Counter.cs +++ b/TUnit.Core/Helpers/Counter.cs @@ -2,6 +2,11 @@ namespace TUnit.Core.Helpers; +/// +/// Thread-safe counter with event notification. +/// Captures event handler BEFORE state change to prevent race conditions +/// where subscribers miss notifications that occur during subscription. +/// [DebuggerDisplay("Count = {CurrentCount}")] public class Counter { @@ -11,9 +16,11 @@ public class Counter public int Increment() { + // Capture handler BEFORE state change to ensure all subscribers + // at the time of the change are notified (prevents TOCTOU race) + var handler = _onCountChanged; var newCount = Interlocked.Increment(ref _count); - var handler = _onCountChanged; handler?.Invoke(this, newCount); return newCount; @@ -21,9 +28,11 @@ public int Increment() public int Decrement() { + // Capture handler BEFORE state change to ensure all subscribers + // at the time of the change are notified (prevents TOCTOU race) + var handler = _onCountChanged; var newCount = Interlocked.Decrement(ref _count); - var handler = _onCountChanged; handler?.Invoke(this, newCount); return newCount; @@ -31,9 +40,11 @@ public int Decrement() public int Add(int value) { + // Capture handler BEFORE state change to ensure all subscribers + // at the time of the change are notified (prevents TOCTOU race) + var handler = _onCountChanged; var newCount = Interlocked.Add(ref _count, value); - var handler = _onCountChanged; handler?.Invoke(this, newCount); return newCount; diff --git a/TUnit.Core/Helpers/Disposer.cs b/TUnit.Core/Helpers/Disposer.cs index 1bafb121ca..a70033bb79 100644 --- a/TUnit.Core/Helpers/Disposer.cs +++ b/TUnit.Core/Helpers/Disposer.cs @@ -1,8 +1,13 @@ -using TUnit.Core.Logging; +using TUnit.Core.Interfaces; +using TUnit.Core.Logging; namespace TUnit.Core.Helpers; -internal class Disposer(ILogger logger) +/// +/// Disposes objects asynchronously with logging. +/// Implements IDisposer for Dependency Inversion Principle compliance. +/// +internal class Disposer(ILogger logger) : IDisposer { /// /// Disposes an object and propagates any exceptions. diff --git a/TUnit.Core/Helpers/ParallelTaskHelper.cs b/TUnit.Core/Helpers/ParallelTaskHelper.cs new file mode 100644 index 0000000000..81f05f06d4 --- /dev/null +++ b/TUnit.Core/Helpers/ParallelTaskHelper.cs @@ -0,0 +1,165 @@ +namespace TUnit.Core.Helpers; + +/// +/// Helper methods for parallel task execution without LINQ allocations. +/// Provides optimized patterns for executing async operations in parallel. +/// Exceptions are aggregated in AggregateException when multiple tasks fail. +/// +public static class ParallelTaskHelper +{ + /// + /// Executes an async action for each item in an array, in parallel. + /// Uses pre-allocated task array to avoid LINQ allocations. + /// + /// The type of items to process. + /// The array of items to process. + /// The async action to execute for each item. + /// A task that completes when all items have been processed. + public static async Task ForEachAsync(T[] items, Func action) + { + if (items.Length == 0) + { + return; + } + + var tasks = new Task[items.Length]; + for (var i = 0; i < items.Length; i++) + { + tasks[i] = action(items[i]); + } + + await Task.WhenAll(tasks); + } + + /// + /// Executes an async action for each item in an array, in parallel, with cancellation support. + /// Uses pre-allocated task array to avoid LINQ allocations. + /// + /// The type of items to process. + /// The array of items to process. + /// The async action to execute for each item. + /// Token to cancel the operation. + /// A task that completes when all items have been processed. + public static async Task ForEachAsync(T[] items, Func action, CancellationToken cancellationToken) + { + if (items.Length == 0) + { + return; + } + + cancellationToken.ThrowIfCancellationRequested(); + + var tasks = new Task[items.Length]; + for (var i = 0; i < items.Length; i++) + { + cancellationToken.ThrowIfCancellationRequested(); + tasks[i] = action(items[i], cancellationToken); + } + + await Task.WhenAll(tasks); + } + + /// + /// Executes an async action for each item in an array, in parallel, with an index. + /// Uses pre-allocated task array to avoid LINQ allocations. + /// + /// The type of items to process. + /// The array of items to process. + /// The async action to execute for each item with its index. + /// A task that completes when all items have been processed. + public static async Task ForEachWithIndexAsync(T[] items, Func action) + { + if (items.Length == 0) + { + return; + } + + var tasks = new Task[items.Length]; + for (var i = 0; i < items.Length; i++) + { + tasks[i] = action(items[i], i); + } + + await Task.WhenAll(tasks); + } + + /// + /// Executes an async action for each item in an array, in parallel, with an index and cancellation support. + /// Uses pre-allocated task array to avoid LINQ allocations. + /// + /// The type of items to process. + /// The array of items to process. + /// The async action to execute for each item with its index. + /// Token to cancel the operation. + /// A task that completes when all items have been processed. + public static async Task ForEachWithIndexAsync(T[] items, Func action, CancellationToken cancellationToken) + { + if (items.Length == 0) + { + return; + } + + cancellationToken.ThrowIfCancellationRequested(); + + var tasks = new Task[items.Length]; + for (var i = 0; i < items.Length; i++) + { + cancellationToken.ThrowIfCancellationRequested(); + tasks[i] = action(items[i], i, cancellationToken); + } + + await Task.WhenAll(tasks); + } + + /// + /// Executes an async action for each item in a list, in parallel. + /// Uses pre-allocated task array to avoid LINQ allocations. + /// + /// The type of items to process. + /// The list of items to process. + /// The async action to execute for each item. + /// A task that completes when all items have been processed. + public static async Task ForEachAsync(IReadOnlyList items, Func action) + { + if (items.Count == 0) + { + return; + } + + var tasks = new Task[items.Count]; + for (var i = 0; i < items.Count; i++) + { + tasks[i] = action(items[i]); + } + + await Task.WhenAll(tasks); + } + + /// + /// Executes an async action for each item in a list, in parallel, with cancellation support. + /// Uses pre-allocated task array to avoid LINQ allocations. + /// + /// The type of items to process. + /// The list of items to process. + /// The async action to execute for each item. + /// Token to cancel the operation. + /// A task that completes when all items have been processed. + public static async Task ForEachAsync(IReadOnlyList items, Func action, CancellationToken cancellationToken) + { + if (items.Count == 0) + { + return; + } + + cancellationToken.ThrowIfCancellationRequested(); + + var tasks = new Task[items.Count]; + for (var i = 0; i < items.Count; i++) + { + cancellationToken.ThrowIfCancellationRequested(); + tasks[i] = action(items[i], cancellationToken); + } + + await Task.WhenAll(tasks); + } +} diff --git a/TUnit.Core/Interfaces/IDisposer.cs b/TUnit.Core/Interfaces/IDisposer.cs new file mode 100644 index 0000000000..039665dbae --- /dev/null +++ b/TUnit.Core/Interfaces/IDisposer.cs @@ -0,0 +1,16 @@ +namespace TUnit.Core.Interfaces; + +/// +/// Interface for disposing objects. +/// Follows Dependency Inversion Principle - high-level modules depend on this abstraction. +/// +public interface IDisposer +{ + /// + /// Disposes an object asynchronously. + /// Implementations should propagate exceptions - callers handle aggregation. + /// + /// The object to dispose. + /// A task representing the disposal operation. + ValueTask DisposeAsync(object? obj); +} diff --git a/TUnit.Core/PropertyInjection/Initialization/PropertyInitializationContext.cs b/TUnit.Core/PropertyInjection/Initialization/PropertyInitializationContext.cs index 12aef89941..85d4f6ae80 100644 --- a/TUnit.Core/PropertyInjection/Initialization/PropertyInitializationContext.cs +++ b/TUnit.Core/PropertyInjection/Initialization/PropertyInitializationContext.cs @@ -7,6 +7,7 @@ namespace TUnit.Core.PropertyInjection.Initialization; /// /// Encapsulates all context needed for property initialization. /// Follows Single Responsibility Principle by being a pure data container. +/// Provides factory methods to reduce duplication when creating contexts (DRY). /// internal sealed class PropertyInitializationContext { @@ -84,4 +85,124 @@ internal sealed class PropertyInitializationContext /// Parent object for nested properties. /// public object? ParentInstance { get; init; } + + #region Factory Methods (DRY) + + /// + /// Creates a context for source-generated property injection. + /// + public static PropertyInitializationContext ForSourceGenerated( + object instance, + PropertyInjectionMetadata metadata, + ConcurrentDictionary objectBag, + MethodMetadata? methodMetadata, + TestContextEvents events, + ConcurrentDictionary visitedObjects, + TestContext? testContext, + bool isNestedProperty = false) + { + return new PropertyInitializationContext + { + Instance = instance, + SourceGeneratedMetadata = metadata, + PropertyName = metadata.PropertyName, + PropertyType = metadata.PropertyType, + PropertySetter = metadata.SetProperty, + ObjectBag = objectBag, + MethodMetadata = methodMetadata, + Events = events, + VisitedObjects = visitedObjects, + TestContext = testContext, + IsNestedProperty = isNestedProperty + }; + } + + /// + /// Creates a context for reflection-based property injection. + /// + public static PropertyInitializationContext ForReflection( + object instance, + PropertyInfo property, + IDataSourceAttribute dataSource, + Action propertySetter, + ConcurrentDictionary objectBag, + MethodMetadata? methodMetadata, + TestContextEvents events, + ConcurrentDictionary visitedObjects, + TestContext? testContext, + bool isNestedProperty = false) + { + return new PropertyInitializationContext + { + Instance = instance, + PropertyInfo = property, + DataSource = dataSource, + PropertyName = property.Name, + PropertyType = property.PropertyType, + PropertySetter = propertySetter, + ObjectBag = objectBag, + MethodMetadata = methodMetadata, + Events = events, + VisitedObjects = visitedObjects, + TestContext = testContext, + IsNestedProperty = isNestedProperty + }; + } + + /// + /// Creates a context for caching during registration (uses placeholder instance). + /// + public static PropertyInitializationContext ForCaching( + PropertyInjectionMetadata metadata, + ConcurrentDictionary objectBag, + MethodMetadata? methodMetadata, + TestContextEvents events, + TestContext testContext) + { + return new PropertyInitializationContext + { + Instance = PlaceholderInstance.Instance, + SourceGeneratedMetadata = metadata, + PropertyName = metadata.PropertyName, + PropertyType = metadata.PropertyType, + PropertySetter = metadata.SetProperty, + ObjectBag = objectBag, + MethodMetadata = methodMetadata, + Events = events, + VisitedObjects = new ConcurrentDictionary(), + TestContext = testContext, + IsNestedProperty = false + }; + } + + /// + /// Creates a context for reflection caching during registration (uses placeholder instance). + /// + public static PropertyInitializationContext ForReflectionCaching( + PropertyInfo property, + IDataSourceAttribute dataSource, + Action propertySetter, + ConcurrentDictionary objectBag, + MethodMetadata? methodMetadata, + TestContextEvents events, + TestContext testContext) + { + return new PropertyInitializationContext + { + Instance = PlaceholderInstance.Instance, + PropertyInfo = property, + DataSource = dataSource, + PropertyName = property.Name, + PropertyType = property.PropertyType, + PropertySetter = propertySetter, + ObjectBag = objectBag, + MethodMetadata = methodMetadata, + Events = events, + VisitedObjects = new ConcurrentDictionary(), + TestContext = testContext, + IsNestedProperty = false + }; + } + + #endregion } \ No newline at end of file diff --git a/TUnit.Core/PropertyInjection/PropertyCacheKeyGenerator.cs b/TUnit.Core/PropertyInjection/PropertyCacheKeyGenerator.cs new file mode 100644 index 0000000000..c98abadadb --- /dev/null +++ b/TUnit.Core/PropertyInjection/PropertyCacheKeyGenerator.cs @@ -0,0 +1,36 @@ +using System.Reflection; +using TUnit.Core.Interfaces.SourceGenerator; + +namespace TUnit.Core.PropertyInjection; + +/// +/// Generates consistent cache keys for property injection values. +/// Centralizes cache key generation to ensure consistency across the codebase (DRY principle). +/// +/// +/// Cache keys are formatted as "{DeclaringTypeName}.{PropertyName}" to uniquely identify +/// properties across different types. This format is used for storing and retrieving +/// injected property values in test contexts. +/// +public static class PropertyCacheKeyGenerator +{ + /// + /// Generates a cache key from source-generated property metadata. + /// + /// The property injection metadata from source generation. + /// A unique cache key string for the property. + public static string GetCacheKey(PropertyInjectionMetadata metadata) + { + return $"{metadata.ContainingType.FullName}.{metadata.PropertyName}"; + } + + /// + /// Generates a cache key from a PropertyInfo (reflection-based properties). + /// + /// The PropertyInfo from reflection. + /// A unique cache key string for the property. + public static string GetCacheKey(PropertyInfo property) + { + return $"{property.DeclaringType!.FullName}.{property.Name}"; + } +} diff --git a/TUnit.Core/PropertyInjection/PropertyInjectionPlanBuilder.cs b/TUnit.Core/PropertyInjection/PropertyInjectionPlanBuilder.cs index c0251ce158..0c16dc1d14 100644 --- a/TUnit.Core/PropertyInjection/PropertyInjectionPlanBuilder.cs +++ b/TUnit.Core/PropertyInjection/PropertyInjectionPlanBuilder.cs @@ -134,6 +134,7 @@ public static PropertyInjectionPlan Build(Type type) /// /// Represents a plan for injecting properties into an object. +/// Provides iterator methods to abstract source-gen vs reflection branching (DRY). /// internal sealed class PropertyInjectionPlan { @@ -141,4 +142,100 @@ internal sealed class PropertyInjectionPlan public required PropertyInjectionMetadata[] SourceGeneratedProperties { get; init; } public required (PropertyInfo Property, IDataSourceAttribute DataSource)[] ReflectionProperties { get; init; } public required bool HasProperties { get; init; } + + /// + /// Iterates over all properties in the plan, abstracting source-gen vs reflection. + /// Call the appropriate callback based on which mode has properties. + /// + /// Action to invoke for each source-generated property. + /// Action to invoke for each reflection property. + public void ForEachProperty( + Action onSourceGenerated, + Action<(PropertyInfo Property, IDataSourceAttribute DataSource)> onReflection) + { + if (SourceGeneratedProperties.Length > 0) + { + foreach (var metadata in SourceGeneratedProperties) + { + onSourceGenerated(metadata); + } + } + else if (ReflectionProperties.Length > 0) + { + foreach (var prop in ReflectionProperties) + { + onReflection(prop); + } + } + } + + /// + /// Iterates over all properties in the plan asynchronously. + /// + public async Task ForEachPropertyAsync( + Func onSourceGenerated, + Func<(PropertyInfo Property, IDataSourceAttribute DataSource), Task> onReflection) + { + if (SourceGeneratedProperties.Length > 0) + { + foreach (var metadata in SourceGeneratedProperties) + { + await onSourceGenerated(metadata); + } + } + else if (ReflectionProperties.Length > 0) + { + foreach (var prop in ReflectionProperties) + { + await onReflection(prop); + } + } + } + + /// + /// Executes actions for all properties in parallel. + /// + public Task ForEachPropertyParallelAsync( + Func onSourceGenerated, + Func<(PropertyInfo Property, IDataSourceAttribute DataSource), Task> onReflection) + { + if (SourceGeneratedProperties.Length > 0) + { + return Helpers.ParallelTaskHelper.ForEachAsync(SourceGeneratedProperties, onSourceGenerated); + } + else if (ReflectionProperties.Length > 0) + { + return Helpers.ParallelTaskHelper.ForEachAsync(ReflectionProperties, onReflection); + } + + return Task.CompletedTask; + } + + /// + /// Gets property values from an instance, abstracting source-gen vs reflection. + /// + public IEnumerable GetPropertyValues(object instance) + { + if (SourceGeneratedProperties.Length > 0) + { + foreach (var metadata in SourceGeneratedProperties) + { + var property = metadata.ContainingType.GetProperty(metadata.PropertyName); + if (property?.CanRead == true) + { + yield return property.GetValue(instance); + } + } + } + else if (ReflectionProperties.Length > 0) + { + foreach (var (property, _) in ReflectionProperties) + { + if (property.CanRead) + { + yield return property.GetValue(instance); + } + } + } + } } diff --git a/TUnit.Core/TestContext.cs b/TUnit.Core/TestContext.cs index af1b1710d9..bdd7d57a7b 100644 --- a/TUnit.Core/TestContext.cs +++ b/TUnit.Core/TestContext.cs @@ -49,7 +49,8 @@ public TestContext(string testName, IServiceProvider serviceProvider, ClassHookC private static readonly AsyncLocal TestContexts = new(); - internal static readonly Dictionary> InternalParametersDictionary = new(); + // Use ConcurrentDictionary for thread-safe access during parallel test discovery + internal static readonly ConcurrentDictionary> InternalParametersDictionary = new(); private StringWriter? _outputWriter; diff --git a/TUnit.Core/Tracking/ObjectTracker.cs b/TUnit.Core/Tracking/ObjectTracker.cs index 6a229f5fdc..3b86f5d04a 100644 --- a/TUnit.Core/Tracking/ObjectTracker.cs +++ b/TUnit.Core/Tracking/ObjectTracker.cs @@ -19,21 +19,48 @@ internal class ObjectTracker(TrackableObjectGraphProvider trackableObjectGraphPr private static readonly ConcurrentDictionary s_trackedObjects = new(Helpers.ReferenceEqualityComparer.Instance); + // Collects errors from async disposal callbacks for post-session review + private static readonly ConcurrentBag s_asyncCallbackErrors = new(); + + /// + /// Gets any errors that occurred during async disposal callbacks. + /// Check this at the end of a test session to surface hidden failures. + /// + public static IReadOnlyCollection GetAsyncCallbackErrors() => s_asyncCallbackErrors.ToArray(); + /// /// Clears all static tracking state. Call at the end of a test session to release memory. /// public static void ClearStaticTracking() { s_trackedObjects.Clear(); + s_asyncCallbackErrors.Clear(); } /// /// Flattens a ConcurrentDictionary of depth-keyed HashSets into a single HashSet. /// Thread-safe: locks each HashSet while copying. + /// Pre-calculates capacity to avoid HashSet resizing during population. /// private static HashSet FlattenTrackedObjects(ConcurrentDictionary> trackedObjects) { +#if NETSTANDARD2_0 + // .NET Standard 2.0 doesn't support HashSet capacity constructor var result = new HashSet(Helpers.ReferenceEqualityComparer.Instance); +#else + // First pass: calculate total capacity to avoid resizing + var totalCapacity = 0; + foreach (var kvp in trackedObjects) + { + lock (kvp.Value) + { + totalCapacity += kvp.Value.Count; + } + } + + // Second pass: populate with pre-sized HashSet + var result = new HashSet(totalCapacity, Helpers.ReferenceEqualityComparer.Instance); +#endif foreach (var kvp in trackedObjects) { lock (kvp.Value) @@ -142,6 +169,10 @@ private async ValueTask UntrackObject(object? obj) if (count == 0) { + // Remove from tracking dictionary to prevent memory leak + // Use TryRemove to ensure atomicity - only remove if still in dictionary + s_trackedObjects.TryRemove(obj, out _); + await disposer.DisposeAsync(obj).ConfigureAwait(false); } } @@ -155,6 +186,10 @@ private static bool ShouldSkipTracking(object? obj) return obj is not IDisposable and not IAsyncDisposable; } + /// + /// Registers a callback to be invoked when the object is disposed (ref count reaches 0). + /// If the object is already disposed, the callback is invoked immediately. + /// public static void OnDisposed(object? o, Action action) { if (o is not IDisposable and not IAsyncDisposable) @@ -162,16 +197,28 @@ public static void OnDisposed(object? o, Action action) return; } - s_trackedObjects.GetOrAdd(o, static _ => new Counter()) - .OnCountChanged += (_, count) => + var counter = s_trackedObjects.GetOrAdd(o, static _ => new Counter()); + + counter.OnCountChanged += (_, count) => { if (count == 0) { action(); } }; + + // Check if already disposed (count is 0) - invoke immediately if so + // This prevents lost callbacks when registering after disposal + if (counter.CurrentCount == 0) + { + action(); + } } + /// + /// Registers an async callback to be invoked when the object is disposed (ref count reaches 0). + /// If the object is already disposed, the callback is invoked immediately. + /// public static void OnDisposedAsync(object? o, Func asyncAction) { if (o is not IDisposable and not IAsyncDisposable) @@ -179,21 +226,29 @@ public static void OnDisposedAsync(object? o, Func asyncAction) return; } + var counter = s_trackedObjects.GetOrAdd(o, static _ => new Counter()); + // Avoid async void pattern by wrapping in fire-and-forget with exception handling - s_trackedObjects.GetOrAdd(o, static _ => new Counter()) - .OnCountChanged += (_, count) => + counter.OnCountChanged += (_, count) => { if (count == 0) { - // Fire-and-forget with exception handling to avoid unobserved exceptions + // Fire-and-forget with exception collection to surface errors _ = SafeExecuteAsync(asyncAction); } }; + + // Check if already disposed (count is 0) - invoke immediately if so + // This prevents lost callbacks when registering after disposal + if (counter.CurrentCount == 0) + { + _ = SafeExecuteAsync(asyncAction); + } } /// - /// Executes an async action safely, catching and logging any exceptions - /// to avoid unobserved task exceptions from fire-and-forget patterns. + /// Executes an async action safely, catching and collecting exceptions + /// for post-session review instead of silently swallowing them. /// private static async Task SafeExecuteAsync(Func asyncAction) { @@ -203,13 +258,12 @@ private static async Task SafeExecuteAsync(Func asyncAction) } catch (Exception ex) { - // Log to debug in DEBUG builds, otherwise swallow to prevent crashes - // The disposal itself already logged any errors + // Collect error for post-session review instead of silently swallowing + s_asyncCallbackErrors.Add(ex); + #if DEBUG System.Diagnostics.Debug.WriteLine($"[ObjectTracker] Exception in OnDisposedAsync callback: {ex.Message}"); #endif - // Prevent unobserved task exception from crashing the application - _ = ex; } } } diff --git a/TUnit.Engine/Services/ObjectLifecycleService.cs b/TUnit.Engine/Services/ObjectLifecycleService.cs index 1d72da3098..9895d65d3a 100644 --- a/TUnit.Engine/Services/ObjectLifecycleService.cs +++ b/TUnit.Engine/Services/ObjectLifecycleService.cs @@ -167,7 +167,7 @@ private void SetCachedPropertiesOnInstance(object instance, TestContext testCont { foreach (var metadata in plan.SourceGeneratedProperties) { - var cacheKey = $"{metadata.ContainingType.FullName}.{metadata.PropertyName}"; + var cacheKey = PropertyCacheKeyGenerator.GetCacheKey(metadata); if (cachedProperties.TryGetValue(cacheKey, out var cachedValue) && cachedValue != null) { @@ -180,7 +180,7 @@ private void SetCachedPropertiesOnInstance(object instance, TestContext testCont { foreach (var (property, _) in plan.ReflectionProperties) { - var cacheKey = $"{property.DeclaringType!.FullName}.{property.Name}"; + var cacheKey = PropertyCacheKeyGenerator.GetCacheKey(property); if (cachedProperties.TryGetValue(cacheKey, out var cachedValue) && cachedValue != null) { diff --git a/TUnit.Engine/Services/PropertyInjector.cs b/TUnit.Engine/Services/PropertyInjector.cs index 1073da6ed4..e0008d17ac 100644 --- a/TUnit.Engine/Services/PropertyInjector.cs +++ b/TUnit.Engine/Services/PropertyInjector.cs @@ -2,6 +2,7 @@ using System.Diagnostics.CodeAnalysis; using System.Reflection; using TUnit.Core; +using TUnit.Core.Helpers; using TUnit.Core.Interfaces; using TUnit.Core.Interfaces.SourceGenerator; using TUnit.Core.PropertyInjection; @@ -99,7 +100,7 @@ public async Task InjectPropertiesAsync( #if NETSTANDARD2_0 visitedObjects = new ConcurrentDictionary(); #else - visitedObjects = new ConcurrentDictionary(ReferenceEqualityComparer.Instance); + visitedObjects = new ConcurrentDictionary(Core.Helpers.ReferenceEqualityComparer.Instance); #endif } @@ -200,7 +201,7 @@ await InjectReflectionPropertiesAsync( } } - private async Task InjectSourceGeneratedPropertiesAsync( + private Task InjectSourceGeneratedPropertiesAsync( object instance, PropertyInjectionMetadata[] properties, ConcurrentDictionary objectBag, @@ -208,18 +209,8 @@ private async Task InjectSourceGeneratedPropertiesAsync( TestContextEvents events, ConcurrentDictionary visitedObjects) { - if (properties.Length == 0) - { - return; - } - - // Initialize properties in parallel without LINQ Select - var tasks = new Task[properties.Length]; - for (var i = 0; i < properties.Length; i++) - { - tasks[i] = InjectSourceGeneratedPropertyAsync(instance, properties[i], objectBag, methodMetadata, events, visitedObjects); - } - await Task.WhenAll(tasks); + return ParallelTaskHelper.ForEachAsync(properties, + prop => InjectSourceGeneratedPropertyAsync(instance, prop, objectBag, methodMetadata, events, visitedObjects)); } [UnconditionalSuppressMessage("Trimming", "IL2026", Justification = "Source-gen properties are AOT-safe")] @@ -248,7 +239,7 @@ private async Task InjectSourceGeneratedPropertyAsync( object? resolvedValue = null; // Use a composite key to avoid conflicts when nested classes have properties with the same name - var cacheKey = $"{metadata.ContainingType.FullName}.{metadata.PropertyName}"; + var cacheKey = PropertyCacheKeyGenerator.GetCacheKey(metadata); // Check if property was pre-resolved during registration if (testContext?.Metadata.TestDetails.TestClassInjectedPropertyArguments.TryGetValue(cacheKey, out resolvedValue) == true) @@ -292,7 +283,7 @@ private async Task InjectSourceGeneratedPropertyAsync( } [UnconditionalSuppressMessage("Trimming", "IL2026", Justification = "Reflection mode is not used in AOT")] - private async Task InjectReflectionPropertiesAsync( + private Task InjectReflectionPropertiesAsync( object instance, (PropertyInfo Property, IDataSourceAttribute DataSource)[] properties, ConcurrentDictionary objectBag, @@ -300,19 +291,8 @@ private async Task InjectReflectionPropertiesAsync( TestContextEvents events, ConcurrentDictionary visitedObjects) { - if (properties.Length == 0) - { - return; - } - - // Initialize properties in parallel without LINQ Select - var tasks = new Task[properties.Length]; - for (var i = 0; i < properties.Length; i++) - { - var pair = properties[i]; - tasks[i] = InjectReflectionPropertyAsync(instance, pair.Property, pair.DataSource, objectBag, methodMetadata, events, visitedObjects); - } - await Task.WhenAll(tasks); + return ParallelTaskHelper.ForEachAsync(properties, + pair => InjectReflectionPropertyAsync(instance, pair.Property, pair.DataSource, objectBag, methodMetadata, events, visitedObjects)); } [UnconditionalSuppressMessage("Trimming", "IL2026", Justification = "Reflection mode is not used in AOT")] @@ -406,25 +386,15 @@ private async Task RecurseIntoNestedPropertiesAsync( } } - private async Task ResolveAndCacheSourceGeneratedPropertiesAsync( + private Task ResolveAndCacheSourceGeneratedPropertiesAsync( PropertyInjectionMetadata[] properties, ConcurrentDictionary objectBag, MethodMetadata? methodMetadata, TestContextEvents events, TestContext testContext) { - if (properties.Length == 0) - { - return; - } - - // Resolve properties in parallel without LINQ Select - var tasks = new Task[properties.Length]; - for (var i = 0; i < properties.Length; i++) - { - tasks[i] = ResolveAndCacheSourceGeneratedPropertyAsync(properties[i], objectBag, methodMetadata, events, testContext); - } - await Task.WhenAll(tasks); + return ParallelTaskHelper.ForEachAsync(properties, + prop => ResolveAndCacheSourceGeneratedPropertyAsync(prop, objectBag, methodMetadata, events, testContext)); } [UnconditionalSuppressMessage("Trimming", "IL2026", Justification = "Source-gen properties are AOT-safe")] @@ -435,7 +405,7 @@ private async Task ResolveAndCacheSourceGeneratedPropertyAsync( TestContextEvents events, TestContext testContext) { - var cacheKey = $"{metadata.ContainingType.FullName}.{metadata.PropertyName}"; + var cacheKey = PropertyCacheKeyGenerator.GetCacheKey(metadata); // Check if already cached if (testContext.Metadata.TestDetails.TestClassInjectedPropertyArguments.ContainsKey(cacheKey)) @@ -469,26 +439,15 @@ private async Task ResolveAndCacheSourceGeneratedPropertyAsync( } [UnconditionalSuppressMessage("Trimming", "IL2026", Justification = "Reflection mode is not used in AOT")] - private async Task ResolveAndCacheReflectionPropertiesAsync( + private Task ResolveAndCacheReflectionPropertiesAsync( (PropertyInfo Property, IDataSourceAttribute DataSource)[] properties, ConcurrentDictionary objectBag, MethodMetadata? methodMetadata, TestContextEvents events, TestContext testContext) { - if (properties.Length == 0) - { - return; - } - - // Resolve properties in parallel without LINQ Select - var tasks = new Task[properties.Length]; - for (var i = 0; i < properties.Length; i++) - { - var pair = properties[i]; - tasks[i] = ResolveAndCacheReflectionPropertyAsync(pair.Property, pair.DataSource, objectBag, methodMetadata, events, testContext); - } - await Task.WhenAll(tasks); + return ParallelTaskHelper.ForEachAsync(properties, + pair => ResolveAndCacheReflectionPropertyAsync(pair.Property, pair.DataSource, objectBag, methodMetadata, events, testContext)); } [UnconditionalSuppressMessage("Trimming", "IL2026", Justification = "Reflection mode is not used in AOT")] @@ -500,7 +459,7 @@ private async Task ResolveAndCacheReflectionPropertyAsync( TestContextEvents events, TestContext testContext) { - var cacheKey = $"{property.DeclaringType!.FullName}.{property.Name}"; + var cacheKey = PropertyCacheKeyGenerator.GetCacheKey(property); // Check if already cached if (testContext.Metadata.TestDetails.TestClassInjectedPropertyArguments.ContainsKey(cacheKey)) From 8030e4238b59c65eaf66703620f78aa1695823a1 Mon Sep 17 00:00:00 2001 From: Tom Longhurst <30480171+thomhurst@users.noreply.github.com> Date: Sun, 7 Dec 2025 17:06:45 +0000 Subject: [PATCH 11/20] feat: implement IAsyncDiscoveryInitializer and related classes for improved test discovery handling --- .../Conditions/Helpers/ExpressionHelper.cs | 43 ++++++++++ .../Conditions/Helpers/ReflectionHelper.cs | 63 +++++++++++++++ .../Helpers/StructuralEqualityComparer.cs | 46 ++--------- .../Conditions/Helpers/TypeHelper.cs | 31 +++++++ .../StructuralEquivalencyAssertion.cs | 80 +++---------------- TUnit.Core/Helpers/Counter.cs | 44 +++++++++- .../Helpers/ReferenceEqualityComparer.cs | 9 +++ .../PropertyInjectionPlanBuilder.cs | 32 +++++--- TUnit.Core/TestContext.cs | 16 +++- TUnit.Core/Tracking/ObjectTracker.cs | 42 ++++++++-- 10 files changed, 274 insertions(+), 132 deletions(-) create mode 100644 TUnit.Assertions/Conditions/Helpers/ExpressionHelper.cs create mode 100644 TUnit.Assertions/Conditions/Helpers/ReflectionHelper.cs create mode 100644 TUnit.Assertions/Conditions/Helpers/TypeHelper.cs diff --git a/TUnit.Assertions/Conditions/Helpers/ExpressionHelper.cs b/TUnit.Assertions/Conditions/Helpers/ExpressionHelper.cs new file mode 100644 index 0000000000..4070c1c6a5 --- /dev/null +++ b/TUnit.Assertions/Conditions/Helpers/ExpressionHelper.cs @@ -0,0 +1,43 @@ +namespace TUnit.Assertions.Conditions.Helpers; + +/// +/// Helper methods for parsing and extracting information from assertion expressions. +/// Consolidates expression parsing logic to ensure consistent behavior across assertion classes. +/// +internal static class ExpressionHelper +{ + /// + /// Extracts the source variable name from an assertion expression string. + /// + /// The expression string, e.g., "Assert.That(variableName).IsEquivalentTo(...)" + /// The variable name, or "value" if it cannot be extracted or is a lambda expression. + /// + /// Input: "Assert.That(myObject).IsEquivalentTo(expected)" + /// Output: "myObject" + /// + /// Input: "Assert.That(async () => GetValue()).IsEquivalentTo(expected)" + /// Output: "value" + /// + public static string ExtractSourceVariable(string expression) + { + // Extract variable name from "Assert.That(variableName)" or similar + var thatIndex = expression.IndexOf(".That(", StringComparison.Ordinal); + if (thatIndex >= 0) + { + var startIndex = thatIndex + 6; // Length of ".That(" + var endIndex = expression.IndexOf(')', startIndex); + if (endIndex > startIndex) + { + var variable = expression.Substring(startIndex, endIndex - startIndex); + // Handle lambda expressions like "async () => ..." by returning "value" + if (variable.Contains("=>") || variable.StartsWith("()", StringComparison.Ordinal)) + { + return "value"; + } + return variable; + } + } + + return "value"; + } +} diff --git a/TUnit.Assertions/Conditions/Helpers/ReflectionHelper.cs b/TUnit.Assertions/Conditions/Helpers/ReflectionHelper.cs new file mode 100644 index 0000000000..c6c53162c5 --- /dev/null +++ b/TUnit.Assertions/Conditions/Helpers/ReflectionHelper.cs @@ -0,0 +1,63 @@ +using System.Diagnostics.CodeAnalysis; +using System.Reflection; + +namespace TUnit.Assertions.Conditions.Helpers; + +/// +/// Helper methods for reflection-based member access. +/// Consolidates reflection logic to ensure consistent behavior and reduce code duplication. +/// +internal static class ReflectionHelper +{ + /// + /// Gets all public instance properties and fields to compare for structural equivalency. + /// + /// The type to get members from. + /// A list of PropertyInfo and FieldInfo members. + public static List GetMembersToCompare( + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicProperties | DynamicallyAccessedMemberTypes.PublicFields)] + Type type) + { + var members = new List(); + members.AddRange(type.GetProperties(BindingFlags.Public | BindingFlags.Instance)); + members.AddRange(type.GetFields(BindingFlags.Public | BindingFlags.Instance)); + return members; + } + + /// + /// Gets the value of a member (property or field) from an object. + /// + /// The object to get the value from. + /// The member (PropertyInfo or FieldInfo) to read. + /// The value of the member. + /// Thrown if the member is not a PropertyInfo or FieldInfo. + public static object? GetMemberValue(object obj, MemberInfo member) + { + return member switch + { + PropertyInfo prop => prop.GetValue(obj), + FieldInfo field => field.GetValue(obj), + _ => throw new InvalidOperationException($"Unknown member type: {member.GetType()}") + }; + } + + /// + /// Gets a member (property or field) by name from a type. + /// + /// The type to search. + /// The member name to find. + /// The MemberInfo if found; null otherwise. + public static MemberInfo? GetMemberInfo( + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicProperties | DynamicallyAccessedMemberTypes.PublicFields)] + Type type, + string name) + { + var property = type.GetProperty(name, BindingFlags.Public | BindingFlags.Instance); + if (property != null) + { + return property; + } + + return type.GetField(name, BindingFlags.Public | BindingFlags.Instance); + } +} diff --git a/TUnit.Assertions/Conditions/Helpers/StructuralEqualityComparer.cs b/TUnit.Assertions/Conditions/Helpers/StructuralEqualityComparer.cs index 9ed8e63f3d..13273b4a45 100644 --- a/TUnit.Assertions/Conditions/Helpers/StructuralEqualityComparer.cs +++ b/TUnit.Assertions/Conditions/Helpers/StructuralEqualityComparer.cs @@ -1,6 +1,5 @@ using System.Collections; using System.Diagnostics.CodeAnalysis; -using System.Reflection; namespace TUnit.Assertions.Conditions.Helpers; @@ -36,7 +35,7 @@ public bool Equals(T? x, T? y) var type = typeof(T); - if (IsPrimitiveType(type)) + if (TypeHelper.IsPrimitiveOrWellKnownType(type)) { return EqualityComparer.Default.Equals(x, y); } @@ -54,23 +53,6 @@ public int GetHashCode(T obj) return EqualityComparer.Default.GetHashCode(obj); } - private static bool IsPrimitiveType(Type type) - { - return type.IsPrimitive - || type.IsEnum - || type == typeof(string) - || type == typeof(decimal) - || type == typeof(DateTime) - || type == typeof(DateTimeOffset) - || type == typeof(TimeSpan) - || type == typeof(Guid) -#if NET6_0_OR_GREATER - || type == typeof(DateOnly) - || type == typeof(TimeOnly) -#endif - ; - } - [UnconditionalSuppressMessage("Trimming", "IL2072", Justification = "GetType() is acceptable for runtime structural comparison")] private bool CompareStructurally(object? x, object? y, HashSet visited) { @@ -87,7 +69,7 @@ private bool CompareStructurally(object? x, object? y, HashSet visited) var xType = x.GetType(); var yType = y.GetType(); - if (IsPrimitiveType(xType)) + if (TypeHelper.IsPrimitiveOrWellKnownType(xType)) { return Equals(x, y); } @@ -121,12 +103,12 @@ private bool CompareStructurally(object? x, object? y, HashSet visited) return true; } - var members = GetMembersToCompare(xType); + var members = ReflectionHelper.GetMembersToCompare(xType); foreach (var member in members) { - var xValue = GetMemberValue(x, member); - var yValue = GetMemberValue(y, member); + var xValue = ReflectionHelper.GetMemberValue(x, member); + var yValue = ReflectionHelper.GetMemberValue(y, member); if (!CompareStructurally(xValue, yValue, visited)) { @@ -136,22 +118,4 @@ private bool CompareStructurally(object? x, object? y, HashSet visited) return true; } - - private static List GetMembersToCompare([DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicProperties | DynamicallyAccessedMemberTypes.PublicFields)] Type type) - { - var members = new List(); - members.AddRange(type.GetProperties(BindingFlags.Public | BindingFlags.Instance)); - members.AddRange(type.GetFields(BindingFlags.Public | BindingFlags.Instance)); - return members; - } - - private static object? GetMemberValue(object obj, MemberInfo member) - { - return member switch - { - PropertyInfo prop => prop.GetValue(obj), - FieldInfo field => field.GetValue(obj), - _ => throw new InvalidOperationException($"Unknown member type: {member.GetType()}") - }; - } } diff --git a/TUnit.Assertions/Conditions/Helpers/TypeHelper.cs b/TUnit.Assertions/Conditions/Helpers/TypeHelper.cs new file mode 100644 index 0000000000..f497679a85 --- /dev/null +++ b/TUnit.Assertions/Conditions/Helpers/TypeHelper.cs @@ -0,0 +1,31 @@ +namespace TUnit.Assertions.Conditions.Helpers; + +/// +/// Helper methods for type checking and classification. +/// Consolidates type checking logic to ensure consistent behavior across assertion classes. +/// +internal static class TypeHelper +{ + /// + /// Determines if a type is a primitive or well-known immutable type that should use + /// value equality rather than structural comparison. + /// + /// The type to check. + /// True if the type should use value equality; false for structural comparison. + public static bool IsPrimitiveOrWellKnownType(Type type) + { + return type.IsPrimitive + || type.IsEnum + || type == typeof(string) + || type == typeof(decimal) + || type == typeof(DateTime) + || type == typeof(DateTimeOffset) + || type == typeof(TimeSpan) + || type == typeof(Guid) +#if NET6_0_OR_GREATER + || type == typeof(DateOnly) + || type == typeof(TimeOnly) +#endif + ; + } +} diff --git a/TUnit.Assertions/Conditions/StructuralEquivalencyAssertion.cs b/TUnit.Assertions/Conditions/StructuralEquivalencyAssertion.cs index e17718a046..09c12803e7 100644 --- a/TUnit.Assertions/Conditions/StructuralEquivalencyAssertion.cs +++ b/TUnit.Assertions/Conditions/StructuralEquivalencyAssertion.cs @@ -110,7 +110,7 @@ internal AssertionResult CompareObjects(object? actual, object? expected, string var expectedType = expected.GetType(); // Handle primitive types and strings - if (IsPrimitiveType(actualType)) + if (TypeHelper.IsPrimitiveOrWellKnownType(actualType)) { if (!Equals(actual, expected)) { @@ -168,7 +168,7 @@ internal AssertionResult CompareObjects(object? actual, object? expected, string } // Compare properties and fields - var expectedMembers = GetMembersToCompare(expectedType); + var expectedMembers = ReflectionHelper.GetMembersToCompare(expectedType); foreach (var member in expectedMembers) { @@ -179,7 +179,7 @@ internal AssertionResult CompareObjects(object? actual, object? expected, string continue; } - var expectedValue = GetMemberValue(expected, member); + var expectedValue = ReflectionHelper.GetMemberValue(expected, member); // Check if this member's type should be ignored var memberType = member switch @@ -199,22 +199,22 @@ internal AssertionResult CompareObjects(object? actual, object? expected, string // In partial equivalency mode, skip members that don't exist on actual if (_usePartialEquivalency) { - var actualMember = GetMemberInfo(actualType, member.Name); + var actualMember = ReflectionHelper.GetMemberInfo(actualType, member.Name); if (actualMember == null) { continue; } - actualValue = GetMemberValue(actual, actualMember); + actualValue = ReflectionHelper.GetMemberValue(actual, actualMember); } else { - var actualMember = GetMemberInfo(actualType, member.Name); + var actualMember = ReflectionHelper.GetMemberInfo(actualType, member.Name); if (actualMember == null) { return AssertionResult.Failed($"Property {memberPath} did not match{Environment.NewLine}Expected: {FormatValue(expectedValue)}{Environment.NewLine}Received: null"); } - actualValue = GetMemberValue(actual, actualMember); + actualValue = ReflectionHelper.GetMemberValue(actual, actualMember); } var result = CompareObjects(actualValue, expectedValue, memberPath, visited); @@ -227,7 +227,7 @@ internal AssertionResult CompareObjects(object? actual, object? expected, string // In non-partial mode, check for extra properties on actual if (!_usePartialEquivalency) { - var actualMembers = GetMembersToCompare(actualType); + var actualMembers = ReflectionHelper.GetMembersToCompare(actualType); var expectedMemberNames = new HashSet(expectedMembers.Select(m => m.Name)); foreach (var member in actualMembers) @@ -248,7 +248,7 @@ internal AssertionResult CompareObjects(object? actual, object? expected, string } var memberPath = string.IsNullOrEmpty(path) ? member.Name : $"{path}.{member.Name}"; - var actualValue = GetMemberValue(actual, member); + var actualValue = ReflectionHelper.GetMemberValue(actual, member); // Skip properties with null values - they're equivalent to not having the property if (actualValue == null) @@ -264,13 +264,6 @@ internal AssertionResult CompareObjects(object? actual, object? expected, string return AssertionResult.Passed; } - private static bool IsPrimitiveType(Type type) - { - return type.IsPrimitive || type.IsEnum || type == typeof(string) || type == typeof(decimal) - || type == typeof(DateTime) || type == typeof(DateTimeOffset) || type == typeof(TimeSpan) - || type == typeof(Guid); - } - private bool ShouldIgnoreType(Type type) { // Check if the type itself should be ignored @@ -289,36 +282,6 @@ private bool ShouldIgnoreType(Type type) return false; } - private static List GetMembersToCompare([DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicProperties | DynamicallyAccessedMemberTypes.PublicFields)] Type type) - { - var members = new List(); - members.AddRange(type.GetProperties(BindingFlags.Public | BindingFlags.Instance)); - members.AddRange(type.GetFields(BindingFlags.Public | BindingFlags.Instance)); - return members; - } - - private static MemberInfo? GetMemberInfo([DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicProperties | DynamicallyAccessedMemberTypes.PublicFields)] Type type, string name) - { - var property = type.GetProperty(name, BindingFlags.Public | BindingFlags.Instance); - if (property != null) - { - return property; - } - - var field = type.GetField(name, BindingFlags.Public | BindingFlags.Instance); - return field; - } - - private static object? GetMemberValue(object obj, MemberInfo member) - { - return member switch - { - PropertyInfo prop => prop.GetValue(obj), - FieldInfo field => field.GetValue(obj), - _ => throw new InvalidOperationException($"Unknown member type: {member.GetType()}") - }; - } - private static string FormatValue(object? value) { if (value == null) @@ -339,32 +302,9 @@ protected override string GetExpectation() // Extract the source variable name from the expression builder // Format: "Assert.That(variableName).IsEquivalentTo(...)" var expressionString = Context.ExpressionBuilder.ToString(); - var sourceVariable = ExtractSourceVariable(expressionString); + var sourceVariable = ExpressionHelper.ExtractSourceVariable(expressionString); var expectedDesc = _expectedExpression ?? "expected value"; return $"{sourceVariable} to be equivalent to {expectedDesc}"; } - - private static string ExtractSourceVariable(string expression) - { - // Extract variable name from "Assert.That(variableName)" or similar - var thatIndex = expression.IndexOf(".That("); - if (thatIndex >= 0) - { - var startIndex = thatIndex + 6; // Length of ".That(" - var endIndex = expression.IndexOf(')', startIndex); - if (endIndex > startIndex) - { - var variable = expression.Substring(startIndex, endIndex - startIndex); - // Handle lambda expressions like "async () => ..." by returning "value" - if (variable.Contains("=>") || variable.StartsWith("()")) - { - return "value"; - } - return variable; - } - } - - return "value"; - } } diff --git a/TUnit.Core/Helpers/Counter.cs b/TUnit.Core/Helpers/Counter.cs index 6fb4080833..ef7722e02d 100644 --- a/TUnit.Core/Helpers/Counter.cs +++ b/TUnit.Core/Helpers/Counter.cs @@ -21,7 +21,7 @@ public int Increment() var handler = _onCountChanged; var newCount = Interlocked.Increment(ref _count); - handler?.Invoke(this, newCount); + RaiseEventSafely(handler, newCount); return newCount; } @@ -33,7 +33,7 @@ public int Decrement() var handler = _onCountChanged; var newCount = Interlocked.Decrement(ref _count); - handler?.Invoke(this, newCount); + RaiseEventSafely(handler, newCount); return newCount; } @@ -45,11 +45,49 @@ public int Add(int value) var handler = _onCountChanged; var newCount = Interlocked.Add(ref _count, value); - handler?.Invoke(this, newCount); + RaiseEventSafely(handler, newCount); return newCount; } + /// + /// Raises the event safely, ensuring all subscribers are notified even if some throw exceptions. + /// Collects all exceptions and throws AggregateException if any occurred. + /// + private void RaiseEventSafely(EventHandler? handler, int newCount) + { + if (handler == null) + { + return; + } + + var invocationList = handler.GetInvocationList(); + List? exceptions = null; + + foreach (var subscriber in invocationList) + { + try + { + ((EventHandler)subscriber).Invoke(this, newCount); + } + catch (Exception ex) + { + exceptions ??= []; + exceptions.Add(ex); + +#if DEBUG + Debug.WriteLine($"[Counter] Exception in OnCountChanged subscriber: {ex.Message}"); +#endif + } + } + + // If any subscribers threw, aggregate and rethrow after all are notified + if (exceptions?.Count > 0) + { + throw new AggregateException("One or more OnCountChanged subscribers threw an exception.", exceptions); + } + } + public int CurrentCount => Interlocked.CompareExchange(ref _count, 0, 0); public event EventHandler? OnCountChanged diff --git a/TUnit.Core/Helpers/ReferenceEqualityComparer.cs b/TUnit.Core/Helpers/ReferenceEqualityComparer.cs index e2058527f0..16da75d77b 100644 --- a/TUnit.Core/Helpers/ReferenceEqualityComparer.cs +++ b/TUnit.Core/Helpers/ReferenceEqualityComparer.cs @@ -20,6 +20,15 @@ private ReferenceEqualityComparer() { } + /// + /// Compares two objects by reference identity. + /// + /// + /// The 'new' keyword is used because this method explicitly implements + /// IEqualityComparer<object>.Equals with nullable parameters, which + /// hides the inherited static Object.Equals(object?, object?) method. + /// This is intentional and provides the correct behavior for reference equality. + /// public new bool Equals(object? x, object? y) { return ReferenceEquals(x, y); diff --git a/TUnit.Core/PropertyInjection/PropertyInjectionPlanBuilder.cs b/TUnit.Core/PropertyInjection/PropertyInjectionPlanBuilder.cs index 0c16dc1d14..9c950da9c5 100644 --- a/TUnit.Core/PropertyInjection/PropertyInjectionPlanBuilder.cs +++ b/TUnit.Core/PropertyInjection/PropertyInjectionPlanBuilder.cs @@ -10,6 +10,22 @@ namespace TUnit.Core.PropertyInjection; /// internal static class PropertyInjectionPlanBuilder { + /// + /// Walks up the inheritance chain from the given type to typeof(object), + /// invoking the action for each type in the hierarchy. + /// + /// The starting type. + /// The action to invoke for each type in the inheritance chain. + private static void WalkInheritanceChain(Type type, Action action) + { + var currentType = type; + while (currentType != null && currentType != typeof(object)) + { + action(currentType); + currentType = currentType.BaseType; + } + } + /// /// Creates an injection plan for source-generated mode. /// Walks the inheritance chain to include all injectable properties from base classes. @@ -20,8 +36,7 @@ public static PropertyInjectionPlan BuildSourceGeneratedPlan(Type type) var processedProperties = new HashSet(); // Walk up the inheritance chain to find all properties with data sources - var currentType = type; - while (currentType != null && currentType != typeof(object)) + WalkInheritanceChain(type, currentType => { var propertySource = PropertySourceRegistry.GetSource(currentType); if (propertySource?.ShouldInitialize == true) @@ -35,9 +50,7 @@ public static PropertyInjectionPlan BuildSourceGeneratedPlan(Type type) } } } - - currentType = currentType.BaseType; - } + }); var sourceGenProps = allProperties.ToArray(); @@ -62,8 +75,7 @@ public static PropertyInjectionPlan BuildReflectionPlan(Type type) var processedProperties = new HashSet(); // Walk up the inheritance chain to find all properties with data source attributes - var currentType = type; - while (currentType != null && currentType != typeof(object)) + WalkInheritanceChain(type, currentType => { var properties = currentType.GetProperties( BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance | BindingFlags.Static | BindingFlags.DeclaredOnly) @@ -87,9 +99,7 @@ public static PropertyInjectionPlan BuildReflectionPlan(Type type) } } } - - currentType = currentType.BaseType; - } + }); return new PropertyInjectionPlan { @@ -106,7 +116,7 @@ public static PropertyInjectionPlan BuildReflectionPlan(Type type) /// This handles generic types like ErrFixture<MyType> where the source generator /// couldn't register a property source for the closed generic type. /// - [UnconditionalSuppressMessage("Trimming", "IL2026:Members annotated with 'RequiresUnreferencedCodeAttribute' require dynamic access otherwise can break functionality when trimming application code", Justification = "Source gen mode has its own path>")] + [UnconditionalSuppressMessage("Trimming", "IL2026:Members annotated with 'RequiresUnreferencedCodeAttribute' require dynamic access otherwise can break functionality when trimming application code", Justification = "Source gen mode has its own path")] public static PropertyInjectionPlan Build(Type type) { if (!SourceRegistrar.IsEnabled) diff --git a/TUnit.Core/TestContext.cs b/TUnit.Core/TestContext.cs index bdd7d57a7b..f5c366da4a 100644 --- a/TUnit.Core/TestContext.cs +++ b/TUnit.Core/TestContext.cs @@ -73,7 +73,21 @@ internal set public static IReadOnlyDictionary> Parameters => InternalParametersDictionary; - public static IConfiguration Configuration { get; internal set; } = null!; + private static IConfiguration? _configuration; + + /// + /// Gets the test configuration. Throws a descriptive exception if accessed before initialization. + /// + /// Thrown if Configuration is accessed before the test engine initializes it. + public static IConfiguration Configuration + { + get => _configuration ?? throw new InvalidOperationException( + "TestContext.Configuration has not been initialized. " + + "This property is only available after the TUnit test engine has started. " + + "If you are accessing this from a static constructor or field initializer, " + + "consider moving the code to a test setup method or test body instead."); + internal set => _configuration = value; + } public static string? OutputDirectory { diff --git a/TUnit.Core/Tracking/ObjectTracker.cs b/TUnit.Core/Tracking/ObjectTracker.cs index 3b86f5d04a..681ba0134b 100644 --- a/TUnit.Core/Tracking/ObjectTracker.cs +++ b/TUnit.Core/Tracking/ObjectTracker.cs @@ -189,6 +189,7 @@ private static bool ShouldSkipTracking(object? obj) /// /// Registers a callback to be invoked when the object is disposed (ref count reaches 0). /// If the object is already disposed, the callback is invoked immediately. + /// The callback is guaranteed to be invoked exactly once (idempotent). /// public static void OnDisposed(object? o, Action action) { @@ -199,18 +200,32 @@ public static void OnDisposed(object? o, Action action) var counter = s_trackedObjects.GetOrAdd(o, static _ => new Counter()); - counter.OnCountChanged += (_, count) => + // Use flag to ensure callback only fires once (idempotent) + var invoked = 0; + EventHandler? handler = null; + + handler = (sender, count) => { - if (count == 0) + if (count == 0 && Interlocked.Exchange(ref invoked, 1) == 0) { + // Remove handler to prevent memory leaks + if (sender is Counter c && handler != null) + { + c.OnCountChanged -= handler; + } + action(); } }; + counter.OnCountChanged += handler; + // Check if already disposed (count is 0) - invoke immediately if so // This prevents lost callbacks when registering after disposal - if (counter.CurrentCount == 0) + // Idempotent check ensures this doesn't double-fire if event already triggered + if (counter.CurrentCount == 0 && Interlocked.Exchange(ref invoked, 1) == 0) { + counter.OnCountChanged -= handler; action(); } } @@ -218,6 +233,7 @@ public static void OnDisposed(object? o, Action action) /// /// Registers an async callback to be invoked when the object is disposed (ref count reaches 0). /// If the object is already disposed, the callback is invoked immediately. + /// The callback is guaranteed to be invoked exactly once (idempotent). /// public static void OnDisposedAsync(object? o, Func asyncAction) { @@ -228,20 +244,34 @@ public static void OnDisposedAsync(object? o, Func asyncAction) var counter = s_trackedObjects.GetOrAdd(o, static _ => new Counter()); + // Use flag to ensure callback only fires once (idempotent) + var invoked = 0; + EventHandler? handler = null; + // Avoid async void pattern by wrapping in fire-and-forget with exception handling - counter.OnCountChanged += (_, count) => + handler = (sender, count) => { - if (count == 0) + if (count == 0 && Interlocked.Exchange(ref invoked, 1) == 0) { + // Remove handler to prevent memory leaks + if (sender is Counter c && handler != null) + { + c.OnCountChanged -= handler; + } + // Fire-and-forget with exception collection to surface errors _ = SafeExecuteAsync(asyncAction); } }; + counter.OnCountChanged += handler; + // Check if already disposed (count is 0) - invoke immediately if so // This prevents lost callbacks when registering after disposal - if (counter.CurrentCount == 0) + // Idempotent check ensures this doesn't double-fire if event already triggered + if (counter.CurrentCount == 0 && Interlocked.Exchange(ref invoked, 1) == 0) { + counter.OnCountChanged -= handler; _ = SafeExecuteAsync(asyncAction); } } From 164668005b70dbe01e4ce399c8fd7484eedef69f Mon Sep 17 00:00:00 2001 From: Tom Longhurst <30480171+thomhurst@users.noreply.github.com> Date: Sun, 7 Dec 2025 17:49:46 +0000 Subject: [PATCH 12/20] feat: enhance type handling and object tracking with custom primitives and improved disposal logic --- .../Conditions/Helpers/TypeHelper.cs | 53 +++++++++++ .../NotStructuralEquivalencyAssertion.cs | 32 ++----- .../StructuralEquivalencyAssertion.cs | 36 ++++++-- TUnit.Core/Helpers/Counter.cs | 13 +++ TUnit.Core/ObjectInitializer.cs | 6 +- TUnit.Core/TestContext.cs | 7 +- TUnit.Core/Tracking/ObjectTracker.cs | 76 ++++++++++++---- ...Has_No_API_Changes.DotNet10_0.verified.txt | 89 +++++++++++++++++-- ..._Has_No_API_Changes.DotNet8_0.verified.txt | 89 +++++++++++++++++-- ..._Has_No_API_Changes.DotNet9_0.verified.txt | 89 +++++++++++++++++-- ...ary_Has_No_API_Changes.Net4_7.verified.txt | 89 +++++++++++++++++-- 11 files changed, 504 insertions(+), 75 deletions(-) diff --git a/TUnit.Assertions/Conditions/Helpers/TypeHelper.cs b/TUnit.Assertions/Conditions/Helpers/TypeHelper.cs index f497679a85..856a05d9e6 100644 --- a/TUnit.Assertions/Conditions/Helpers/TypeHelper.cs +++ b/TUnit.Assertions/Conditions/Helpers/TypeHelper.cs @@ -1,3 +1,5 @@ +using System.Collections.Concurrent; + namespace TUnit.Assertions.Conditions.Helpers; /// @@ -6,6 +8,51 @@ namespace TUnit.Assertions.Conditions.Helpers; /// internal static class TypeHelper { + /// + /// Thread-safe registry of user-defined types that should be treated as primitives + /// (using value equality rather than structural comparison). + /// + private static readonly ConcurrentDictionary CustomPrimitiveTypes = new(); + + /// + /// Registers a type to be treated as a primitive for structural equivalency comparisons. + /// Once registered, instances of this type will use value equality (via Equals) rather + /// than having their properties compared individually. + /// + /// The type to register as a primitive. + public static void RegisterAsPrimitive() + { + CustomPrimitiveTypes.TryAdd(typeof(T), 0); + } + + /// + /// Registers a type to be treated as a primitive for structural equivalency comparisons. + /// + /// The type to register as a primitive. + public static void RegisterAsPrimitive(Type type) + { + CustomPrimitiveTypes.TryAdd(type, 0); + } + + /// + /// Removes a previously registered custom primitive type. + /// + /// The type to unregister. + /// True if the type was removed; false if it wasn't registered. + public static bool UnregisterPrimitive() + { + return CustomPrimitiveTypes.TryRemove(typeof(T), out _); + } + + /// + /// Clears all registered custom primitive types. + /// Useful for test cleanup between tests. + /// + public static void ClearCustomPrimitives() + { + CustomPrimitiveTypes.Clear(); + } + /// /// Determines if a type is a primitive or well-known immutable type that should use /// value equality rather than structural comparison. @@ -14,6 +61,12 @@ internal static class TypeHelper /// True if the type should use value equality; false for structural comparison. public static bool IsPrimitiveOrWellKnownType(Type type) { + // Check user-defined primitives first (fast path for common case) + if (CustomPrimitiveTypes.ContainsKey(type)) + { + return true; + } + return type.IsPrimitive || type.IsEnum || type == typeof(string) diff --git a/TUnit.Assertions/Conditions/NotStructuralEquivalencyAssertion.cs b/TUnit.Assertions/Conditions/NotStructuralEquivalencyAssertion.cs index 32399096b5..2d8c30306a 100644 --- a/TUnit.Assertions/Conditions/NotStructuralEquivalencyAssertion.cs +++ b/TUnit.Assertions/Conditions/NotStructuralEquivalencyAssertion.cs @@ -91,7 +91,12 @@ protected override Task CheckAsync(EvaluationMetadata m foreach (var type in _ignoredTypes) tempAssertion.IgnoringType(type); - var result = tempAssertion.CompareObjects(value, _notExpected, "", new HashSet(ReferenceEqualityComparer.Instance)); + var result = tempAssertion.CompareObjects( + value, + _notExpected, + "", + new HashSet(ReferenceEqualityComparer.Instance), + new HashSet(ReferenceEqualityComparer.Instance)); // Invert the result - we want them to NOT be equivalent if (result.IsPassed) @@ -107,32 +112,9 @@ protected override string GetExpectation() // Extract the source variable name from the expression builder // Format: "Assert.That(variableName).IsNotEquivalentTo(...)" var expressionString = Context.ExpressionBuilder.ToString(); - var sourceVariable = ExtractSourceVariable(expressionString); + var sourceVariable = ExpressionHelper.ExtractSourceVariable(expressionString); var notExpectedDesc = _notExpectedExpression ?? "expected value"; return $"{sourceVariable} to not be equivalent to {notExpectedDesc}"; } - - private static string ExtractSourceVariable(string expression) - { - // Extract variable name from "Assert.That(variableName)" or similar - var thatIndex = expression.IndexOf(".That("); - if (thatIndex >= 0) - { - var startIndex = thatIndex + 6; // Length of ".That(" - var endIndex = expression.IndexOf(')', startIndex); - if (endIndex > startIndex) - { - var variable = expression.Substring(startIndex, endIndex - startIndex); - // Handle lambda expressions like "async () => ..." by returning "value" - if (variable.Contains("=>") || variable.StartsWith("()")) - { - return "value"; - } - return variable; - } - } - - return "value"; - } } diff --git a/TUnit.Assertions/Conditions/StructuralEquivalencyAssertion.cs b/TUnit.Assertions/Conditions/StructuralEquivalencyAssertion.cs index 09c12803e7..2f1086a1a7 100644 --- a/TUnit.Assertions/Conditions/StructuralEquivalencyAssertion.cs +++ b/TUnit.Assertions/Conditions/StructuralEquivalencyAssertion.cs @@ -78,11 +78,21 @@ protected override Task CheckAsync(EvaluationMetadata m return Task.FromResult(AssertionResult.Failed($"threw {exception.GetType().Name}: {exception.Message}")); } - var result = CompareObjects(value, _expected, "", new HashSet(ReferenceEqualityComparer.Instance)); + var result = CompareObjects( + value, + _expected, + "", + new HashSet(ReferenceEqualityComparer.Instance), + new HashSet(ReferenceEqualityComparer.Instance)); return Task.FromResult(result); } - internal AssertionResult CompareObjects(object? actual, object? expected, string path, HashSet visited) + internal AssertionResult CompareObjects( + object? actual, + object? expected, + string path, + HashSet visitedActual, + HashSet? visitedExpected = null) { // Check for ignored paths if (_ignoredMembers.Contains(path)) @@ -119,13 +129,25 @@ internal AssertionResult CompareObjects(object? actual, object? expected, string return AssertionResult.Passed; } - // Handle cycles - if (visited.Contains(actual)) + // Handle cycles - check both actual and expected to prevent infinite recursion + // from cycles in either object graph + if (visitedActual.Contains(actual)) { return AssertionResult.Passed; } - visited.Add(actual); + visitedActual.Add(actual); + + // Also track expected objects to handle cycles in the expected graph + if (visitedExpected != null) + { + if (visitedExpected.Contains(expected)) + { + return AssertionResult.Passed; + } + + visitedExpected.Add(expected); + } // Handle enumerables if (actual is IEnumerable actualEnumerable && expected is IEnumerable expectedEnumerable @@ -157,7 +179,7 @@ internal AssertionResult CompareObjects(object? actual, object? expected, string return AssertionResult.Failed($"{itemPath} did not match{Environment.NewLine}Expected: null{Environment.NewLine}Received: {FormatValue(actualList[i])}"); } - var result = CompareObjects(actualList[i], expectedList[i], itemPath, visited); + var result = CompareObjects(actualList[i], expectedList[i], itemPath, visitedActual, visitedExpected); if (!result.IsPassed) { return result; @@ -217,7 +239,7 @@ internal AssertionResult CompareObjects(object? actual, object? expected, string actualValue = ReflectionHelper.GetMemberValue(actual, actualMember); } - var result = CompareObjects(actualValue, expectedValue, memberPath, visited); + var result = CompareObjects(actualValue, expectedValue, memberPath, visitedActual, visitedExpected); if (!result.IsPassed) { return result; diff --git a/TUnit.Core/Helpers/Counter.cs b/TUnit.Core/Helpers/Counter.cs index ef7722e02d..e4ca4b759a 100644 --- a/TUnit.Core/Helpers/Counter.cs +++ b/TUnit.Core/Helpers/Counter.cs @@ -38,6 +38,12 @@ public int Decrement() return newCount; } + /// + /// Adds a value to the counter. Use Increment/Decrement for single-step changes. + /// + /// The value to add (can be positive or negative). + /// The new count after the addition. + /// Thrown if the resulting count is negative, indicating a logic error. public int Add(int value) { // Capture handler BEFORE state change to ensure all subscribers @@ -45,6 +51,13 @@ public int Add(int value) var handler = _onCountChanged; var newCount = Interlocked.Add(ref _count, value); + // Guard against reference count going negative - indicates a bug in calling code + if (newCount < 0) + { + throw new InvalidOperationException( + $"Counter went below zero (result: {newCount}). This indicates a bug in the reference counting logic."); + } + RaiseEventSafely(handler, newCount); return newCount; diff --git a/TUnit.Core/ObjectInitializer.cs b/TUnit.Core/ObjectInitializer.cs index 24efeeed71..46c52923e8 100644 --- a/TUnit.Core/ObjectInitializer.cs +++ b/TUnit.Core/ObjectInitializer.cs @@ -18,7 +18,7 @@ namespace TUnit.Core; /// For dependency injection scenarios, use directly. /// /// -public static class ObjectInitializer +internal static class ObjectInitializer { // Use Lazy pattern to ensure InitializeAsync is called exactly once per object, // even under contention. GetOrAdd's factory can be called multiple times, but with @@ -34,7 +34,7 @@ public static class ObjectInitializer /// /// The object to potentially initialize. /// Cancellation token. - public static ValueTask InitializeForDiscoveryAsync(object? obj, CancellationToken cancellationToken = default) + internal static ValueTask InitializeForDiscoveryAsync(object? obj, CancellationToken cancellationToken = default) { // During discovery, only initialize IAsyncDiscoveryInitializer if (obj is not IAsyncDiscoveryInitializer asyncDiscoveryInitializer) @@ -52,7 +52,7 @@ public static ValueTask InitializeForDiscoveryAsync(object? obj, CancellationTok /// /// The object to potentially initialize. /// Cancellation token. - public static ValueTask InitializeAsync(object? obj, CancellationToken cancellationToken = default) + internal static ValueTask InitializeAsync(object? obj, CancellationToken cancellationToken = default) { if (obj is not IAsyncInitializer asyncInitializer) { diff --git a/TUnit.Core/TestContext.cs b/TUnit.Core/TestContext.cs index f5c366da4a..6ef95724de 100644 --- a/TUnit.Core/TestContext.cs +++ b/TUnit.Core/TestContext.cs @@ -173,8 +173,13 @@ internal override void SetAsyncLocalContext() internal AbstractExecutableTest InternalExecutableTest { get; set; } = null!; private ConcurrentDictionary>? _trackedObjects; + + /// + /// Thread-safe lazy initialization of TrackedObjects using LazyInitializer + /// to prevent race conditions when multiple threads access this property simultaneously. + /// internal ConcurrentDictionary> TrackedObjects => - _trackedObjects ??= new(); + LazyInitializer.EnsureInitialized(ref _trackedObjects)!; /// /// Sets the output captured during test building phase. diff --git a/TUnit.Core/Tracking/ObjectTracker.cs b/TUnit.Core/Tracking/ObjectTracker.cs index 681ba0134b..a0d0d28974 100644 --- a/TUnit.Core/Tracking/ObjectTracker.cs +++ b/TUnit.Core/Tracking/ObjectTracker.cs @@ -19,6 +19,9 @@ internal class ObjectTracker(TrackableObjectGraphProvider trackableObjectGraphPr private static readonly ConcurrentDictionary s_trackedObjects = new(Helpers.ReferenceEqualityComparer.Instance); + // Lock for atomic decrement-check-dispose operations to prevent race conditions + private static readonly object s_disposalLock = new(); + // Collects errors from async disposal callbacks for post-session review private static readonly ConcurrentBag s_asyncCallbackErrors = new(); @@ -37,6 +40,13 @@ public static void ClearStaticTracking() s_asyncCallbackErrors.Clear(); } + /// + /// Gets an existing counter for the object or creates a new one. + /// Centralizes the GetOrAdd pattern to ensure consistent counter creation. + /// + private static Counter GetOrCreateCounter(object obj) => + s_trackedObjects.GetOrAdd(obj, static _ => new Counter()); + /// /// Flattens a ConcurrentDictionary of depth-keyed HashSets into a single HashSet. /// Thread-safe: locks each HashSet while copying. @@ -147,7 +157,7 @@ private void TrackObject(object? obj) return; } - var counter = s_trackedObjects.GetOrAdd(obj, static _ => new Counter()); + var counter = GetOrCreateCounter(obj); counter.Increment(); } @@ -158,24 +168,36 @@ private async ValueTask UntrackObject(object? obj) return; } - if (s_trackedObjects.TryGetValue(obj, out var counter)) - { - var count = counter.Decrement(); + var shouldDispose = false; - if (count < 0) + // Use lock to make decrement-check-remove atomic and prevent race conditions + // where multiple tests could try to dispose the same object simultaneously + lock (s_disposalLock) + { + if (s_trackedObjects.TryGetValue(obj, out var counter)) { - throw new InvalidOperationException("Reference count for object went below zero. This indicates a bug in the reference counting logic."); - } + var count = counter.Decrement(); - if (count == 0) - { - // Remove from tracking dictionary to prevent memory leak - // Use TryRemove to ensure atomicity - only remove if still in dictionary - s_trackedObjects.TryRemove(obj, out _); + if (count < 0) + { + throw new InvalidOperationException("Reference count for object went below zero. This indicates a bug in the reference counting logic."); + } - await disposer.DisposeAsync(obj).ConfigureAwait(false); + if (count == 0) + { + // Remove from tracking dictionary to prevent memory leak + // Use TryRemove to ensure atomicity - only remove if still in dictionary + s_trackedObjects.TryRemove(obj, out _); + shouldDispose = true; + } } } + + // Dispose outside the lock to avoid blocking other untrack operations + if (shouldDispose) + { + await disposer.DisposeAsync(obj).ConfigureAwait(false); + } } /// @@ -191,14 +213,26 @@ private static bool ShouldSkipTracking(object? obj) /// If the object is already disposed, the callback is invoked immediately. /// The callback is guaranteed to be invoked exactly once (idempotent). /// + /// The object to monitor for disposal. If null or not disposable, the method returns without action. + /// The callback to invoke on disposal. Must not be null. + /// Thrown when is null. public static void OnDisposed(object? o, Action action) { +#if NET6_0_OR_GREATER + ArgumentNullException.ThrowIfNull(action); +#else + if (action == null) + { + throw new ArgumentNullException(nameof(action)); + } +#endif + if (o is not IDisposable and not IAsyncDisposable) { return; } - var counter = s_trackedObjects.GetOrAdd(o, static _ => new Counter()); + var counter = GetOrCreateCounter(o); // Use flag to ensure callback only fires once (idempotent) var invoked = 0; @@ -235,14 +269,26 @@ public static void OnDisposed(object? o, Action action) /// If the object is already disposed, the callback is invoked immediately. /// The callback is guaranteed to be invoked exactly once (idempotent). /// + /// The object to monitor for disposal. If null or not disposable, the method returns without action. + /// The async callback to invoke on disposal. Must not be null. + /// Thrown when is null. public static void OnDisposedAsync(object? o, Func asyncAction) { +#if NET6_0_OR_GREATER + ArgumentNullException.ThrowIfNull(asyncAction); +#else + if (asyncAction == null) + { + throw new ArgumentNullException(nameof(asyncAction)); + } +#endif + if (o is not IDisposable and not IAsyncDisposable) { return; } - var counter = s_trackedObjects.GetOrAdd(o, static _ => new Counter()); + var counter = GetOrCreateCounter(o); // Use flag to ensure callback only fires once (idempotent) var invoked = 0; diff --git a/TUnit.PublicAPI/Tests.Core_Library_Has_No_API_Changes.DotNet10_0.verified.txt b/TUnit.PublicAPI/Tests.Core_Library_Has_No_API_Changes.DotNet10_0.verified.txt index 33684f6146..d6c7f837e4 100644 --- a/TUnit.PublicAPI/Tests.Core_Library_Has_No_API_Changes.DotNet10_0.verified.txt +++ b/TUnit.PublicAPI/Tests.Core_Library_Has_No_API_Changes.DotNet10_0.verified.txt @@ -1000,10 +1000,6 @@ namespace public . NotInParallelConstraintKeys { get; init; } public int Order { get; set; } } - public static class ObjectInitializer - { - public static . InitializeAsync(object? obj, .CancellationToken cancellationToken = default) { } - } public class ParallelGroupAttribute : .TUnitAttribute, ., . { public ParallelGroupAttribute(string group) { } @@ -1668,6 +1664,36 @@ namespace .DataSources public static string FormatArguments(object?[] arguments, .<> formatters) { } } } +namespace .Discovery +{ + public readonly struct DiscoveryError : <.> + { + public DiscoveryError(string TypeName, string PropertyName, string ErrorMessage, Exception) { } + public string ErrorMessage { get; init; } + public Exception { get; init; } + public string PropertyName { get; init; } + public string TypeName { get; init; } + } + public sealed class ObjectGraph : . + { + public ObjectGraph(.> objectsByDepth, . allObjects) { } + public . AllObjects { get; } + public int MaxDepth { get; } + public .> ObjectsByDepth { get; } + public . GetDepthsDescending() { } + public . GetObjectsAtDepth(int depth) { } + } + public sealed class ObjectGraphDiscoverer : ., . + { + public ObjectGraphDiscoverer() { } + public .> DiscoverAndTrackObjects(.TestContext testContext, .CancellationToken cancellationToken = default) { } + public . DiscoverNestedObjectGraph(object rootObject, .CancellationToken cancellationToken = default) { } + public . DiscoverObjectGraph(.TestContext testContext, .CancellationToken cancellationToken = default) { } + public static void ClearCache() { } + public static void ClearDiscoveryErrors() { } + public static .<.> GetDiscoveryErrors() { } + } +} namespace .Enums { public enum DataGeneratorType @@ -2021,14 +2047,23 @@ namespace .Helpers [.("MakeGenericType requires runtime code generation")] public static MakeGenericTypeSafe( genericTypeDefinition, params [] typeArguments) { } } + public static class ParallelTaskHelper + { + public static . ForEachAsync(. items, action) { } + public static . ForEachAsync(T[] items, action) { } + public static . ForEachAsync(. items, action, .CancellationToken cancellationToken) { } + public static . ForEachAsync(T[] items, action, .CancellationToken cancellationToken) { } + public static . ForEachWithIndexAsync(T[] items, action) { } + public static . ForEachWithIndexAsync(T[] items, action, .CancellationToken cancellationToken) { } + } public class ProcessorCountParallelLimit : . { public ProcessorCountParallelLimit() { } public int Limit { get; } } - public class ReferenceEqualityComparer : . + public sealed class ReferenceEqualityComparer : . { - public ReferenceEqualityComparer() { } + public static readonly . Instance; public bool Equals(object? x, object? y) { } public int GetHashCode(object obj) { } } @@ -2229,6 +2264,10 @@ namespace .Interfaces { .<> GenerateDataFactories(.DataSourceContext context); } + public interface IDisposer + { + . DisposeAsync(object? obj); + } public interface IEventReceiver { int Order { get; } @@ -2289,6 +2328,28 @@ namespace .Interfaces { . OnLastTestInTestSession(.TestSessionContext current, .TestContext testContext); } + public interface IObjectGraph + { + . AllObjects { get; } + int MaxDepth { get; } + .> ObjectsByDepth { get; } + . GetDepthsDescending(); + . GetObjectsAtDepth(int depth); + } + public interface IObjectGraphDiscoverer + { + .> DiscoverAndTrackObjects(.TestContext testContext, .CancellationToken cancellationToken = default); + . DiscoverNestedObjectGraph(object rootObject, .CancellationToken cancellationToken = default); + . DiscoverObjectGraph(.TestContext testContext, .CancellationToken cancellationToken = default); + } + public interface IObjectGraphTracker : . { } + public interface IObjectInitializationService + { + void ClearCache(); + . InitializeAsync(object? obj, .CancellationToken cancellationToken = default); + . InitializeForDiscoveryAsync(object? obj, .CancellationToken cancellationToken = default); + bool IsInitialized(object? obj); + } public interface IParallelConstraint { } public interface IParallelLimit { @@ -2588,6 +2649,14 @@ namespace .Models public ? MethodInvoker { get; set; } } } +namespace .PropertyInjection +{ + public static class PropertyCacheKeyGenerator + { + public static string GetCacheKey(.PropertyInfo property) { } + public static string GetCacheKey(..PropertyInjectionMetadata metadata) { } + } +} namespace .Services { [.("Generic type resolution requires runtime type generation")] @@ -2616,6 +2685,14 @@ namespace .Services public [] ResolveGenericClassArguments([.(..PublicConstructors)] genericTypeDefinition, object?[] constructorArguments) { } public [] ResolveGenericMethodArguments(.MethodInfo genericMethodDefinition, object?[] runtimeArguments) { } } + public sealed class ObjectInitializationService : . + { + public ObjectInitializationService() { } + public void ClearCache() { } + public . InitializeAsync(object? obj, .CancellationToken cancellationToken = default) { } + public . InitializeForDiscoveryAsync(object? obj, .CancellationToken cancellationToken = default) { } + public bool IsInitialized(object? obj) { } + } public static class ServiceProviderExtensions { public static object GetRequiredService(this serviceProvider, serviceType) { } diff --git a/TUnit.PublicAPI/Tests.Core_Library_Has_No_API_Changes.DotNet8_0.verified.txt b/TUnit.PublicAPI/Tests.Core_Library_Has_No_API_Changes.DotNet8_0.verified.txt index d97c4f38b5..b231e50b76 100644 --- a/TUnit.PublicAPI/Tests.Core_Library_Has_No_API_Changes.DotNet8_0.verified.txt +++ b/TUnit.PublicAPI/Tests.Core_Library_Has_No_API_Changes.DotNet8_0.verified.txt @@ -1000,10 +1000,6 @@ namespace public . NotInParallelConstraintKeys { get; init; } public int Order { get; set; } } - public static class ObjectInitializer - { - public static . InitializeAsync(object? obj, .CancellationToken cancellationToken = default) { } - } public class ParallelGroupAttribute : .TUnitAttribute, ., . { public ParallelGroupAttribute(string group) { } @@ -1668,6 +1664,36 @@ namespace .DataSources public static string FormatArguments(object?[] arguments, .<> formatters) { } } } +namespace .Discovery +{ + public readonly struct DiscoveryError : <.> + { + public DiscoveryError(string TypeName, string PropertyName, string ErrorMessage, Exception) { } + public string ErrorMessage { get; init; } + public Exception { get; init; } + public string PropertyName { get; init; } + public string TypeName { get; init; } + } + public sealed class ObjectGraph : . + { + public ObjectGraph(.> objectsByDepth, . allObjects) { } + public . AllObjects { get; } + public int MaxDepth { get; } + public .> ObjectsByDepth { get; } + public . GetDepthsDescending() { } + public . GetObjectsAtDepth(int depth) { } + } + public sealed class ObjectGraphDiscoverer : ., . + { + public ObjectGraphDiscoverer() { } + public .> DiscoverAndTrackObjects(.TestContext testContext, .CancellationToken cancellationToken = default) { } + public . DiscoverNestedObjectGraph(object rootObject, .CancellationToken cancellationToken = default) { } + public . DiscoverObjectGraph(.TestContext testContext, .CancellationToken cancellationToken = default) { } + public static void ClearCache() { } + public static void ClearDiscoveryErrors() { } + public static .<.> GetDiscoveryErrors() { } + } +} namespace .Enums { public enum DataGeneratorType @@ -2021,14 +2047,23 @@ namespace .Helpers [.("MakeGenericType requires runtime code generation")] public static MakeGenericTypeSafe( genericTypeDefinition, params [] typeArguments) { } } + public static class ParallelTaskHelper + { + public static . ForEachAsync(. items, action) { } + public static . ForEachAsync(T[] items, action) { } + public static . ForEachAsync(. items, action, .CancellationToken cancellationToken) { } + public static . ForEachAsync(T[] items, action, .CancellationToken cancellationToken) { } + public static . ForEachWithIndexAsync(T[] items, action) { } + public static . ForEachWithIndexAsync(T[] items, action, .CancellationToken cancellationToken) { } + } public class ProcessorCountParallelLimit : . { public ProcessorCountParallelLimit() { } public int Limit { get; } } - public class ReferenceEqualityComparer : . + public sealed class ReferenceEqualityComparer : . { - public ReferenceEqualityComparer() { } + public static readonly . Instance; public bool Equals(object? x, object? y) { } public int GetHashCode(object obj) { } } @@ -2229,6 +2264,10 @@ namespace .Interfaces { .<> GenerateDataFactories(.DataSourceContext context); } + public interface IDisposer + { + . DisposeAsync(object? obj); + } public interface IEventReceiver { int Order { get; } @@ -2289,6 +2328,28 @@ namespace .Interfaces { . OnLastTestInTestSession(.TestSessionContext current, .TestContext testContext); } + public interface IObjectGraph + { + . AllObjects { get; } + int MaxDepth { get; } + .> ObjectsByDepth { get; } + . GetDepthsDescending(); + . GetObjectsAtDepth(int depth); + } + public interface IObjectGraphDiscoverer + { + .> DiscoverAndTrackObjects(.TestContext testContext, .CancellationToken cancellationToken = default); + . DiscoverNestedObjectGraph(object rootObject, .CancellationToken cancellationToken = default); + . DiscoverObjectGraph(.TestContext testContext, .CancellationToken cancellationToken = default); + } + public interface IObjectGraphTracker : . { } + public interface IObjectInitializationService + { + void ClearCache(); + . InitializeAsync(object? obj, .CancellationToken cancellationToken = default); + . InitializeForDiscoveryAsync(object? obj, .CancellationToken cancellationToken = default); + bool IsInitialized(object? obj); + } public interface IParallelConstraint { } public interface IParallelLimit { @@ -2588,6 +2649,14 @@ namespace .Models public ? MethodInvoker { get; set; } } } +namespace .PropertyInjection +{ + public static class PropertyCacheKeyGenerator + { + public static string GetCacheKey(.PropertyInfo property) { } + public static string GetCacheKey(..PropertyInjectionMetadata metadata) { } + } +} namespace .Services { [.("Generic type resolution requires runtime type generation")] @@ -2616,6 +2685,14 @@ namespace .Services public [] ResolveGenericClassArguments([.(..PublicConstructors)] genericTypeDefinition, object?[] constructorArguments) { } public [] ResolveGenericMethodArguments(.MethodInfo genericMethodDefinition, object?[] runtimeArguments) { } } + public sealed class ObjectInitializationService : . + { + public ObjectInitializationService() { } + public void ClearCache() { } + public . InitializeAsync(object? obj, .CancellationToken cancellationToken = default) { } + public . InitializeForDiscoveryAsync(object? obj, .CancellationToken cancellationToken = default) { } + public bool IsInitialized(object? obj) { } + } public static class ServiceProviderExtensions { public static object GetRequiredService(this serviceProvider, serviceType) { } diff --git a/TUnit.PublicAPI/Tests.Core_Library_Has_No_API_Changes.DotNet9_0.verified.txt b/TUnit.PublicAPI/Tests.Core_Library_Has_No_API_Changes.DotNet9_0.verified.txt index 5280c9595c..b935f9f012 100644 --- a/TUnit.PublicAPI/Tests.Core_Library_Has_No_API_Changes.DotNet9_0.verified.txt +++ b/TUnit.PublicAPI/Tests.Core_Library_Has_No_API_Changes.DotNet9_0.verified.txt @@ -1000,10 +1000,6 @@ namespace public . NotInParallelConstraintKeys { get; init; } public int Order { get; set; } } - public static class ObjectInitializer - { - public static . InitializeAsync(object? obj, .CancellationToken cancellationToken = default) { } - } public class ParallelGroupAttribute : .TUnitAttribute, ., . { public ParallelGroupAttribute(string group) { } @@ -1668,6 +1664,36 @@ namespace .DataSources public static string FormatArguments(object?[] arguments, .<> formatters) { } } } +namespace .Discovery +{ + public readonly struct DiscoveryError : <.> + { + public DiscoveryError(string TypeName, string PropertyName, string ErrorMessage, Exception) { } + public string ErrorMessage { get; init; } + public Exception { get; init; } + public string PropertyName { get; init; } + public string TypeName { get; init; } + } + public sealed class ObjectGraph : . + { + public ObjectGraph(.> objectsByDepth, . allObjects) { } + public . AllObjects { get; } + public int MaxDepth { get; } + public .> ObjectsByDepth { get; } + public . GetDepthsDescending() { } + public . GetObjectsAtDepth(int depth) { } + } + public sealed class ObjectGraphDiscoverer : ., . + { + public ObjectGraphDiscoverer() { } + public .> DiscoverAndTrackObjects(.TestContext testContext, .CancellationToken cancellationToken = default) { } + public . DiscoverNestedObjectGraph(object rootObject, .CancellationToken cancellationToken = default) { } + public . DiscoverObjectGraph(.TestContext testContext, .CancellationToken cancellationToken = default) { } + public static void ClearCache() { } + public static void ClearDiscoveryErrors() { } + public static .<.> GetDiscoveryErrors() { } + } +} namespace .Enums { public enum DataGeneratorType @@ -2021,14 +2047,23 @@ namespace .Helpers [.("MakeGenericType requires runtime code generation")] public static MakeGenericTypeSafe( genericTypeDefinition, params [] typeArguments) { } } + public static class ParallelTaskHelper + { + public static . ForEachAsync(. items, action) { } + public static . ForEachAsync(T[] items, action) { } + public static . ForEachAsync(. items, action, .CancellationToken cancellationToken) { } + public static . ForEachAsync(T[] items, action, .CancellationToken cancellationToken) { } + public static . ForEachWithIndexAsync(T[] items, action) { } + public static . ForEachWithIndexAsync(T[] items, action, .CancellationToken cancellationToken) { } + } public class ProcessorCountParallelLimit : . { public ProcessorCountParallelLimit() { } public int Limit { get; } } - public class ReferenceEqualityComparer : . + public sealed class ReferenceEqualityComparer : . { - public ReferenceEqualityComparer() { } + public static readonly . Instance; public bool Equals(object? x, object? y) { } public int GetHashCode(object obj) { } } @@ -2229,6 +2264,10 @@ namespace .Interfaces { .<> GenerateDataFactories(.DataSourceContext context); } + public interface IDisposer + { + . DisposeAsync(object? obj); + } public interface IEventReceiver { int Order { get; } @@ -2289,6 +2328,28 @@ namespace .Interfaces { . OnLastTestInTestSession(.TestSessionContext current, .TestContext testContext); } + public interface IObjectGraph + { + . AllObjects { get; } + int MaxDepth { get; } + .> ObjectsByDepth { get; } + . GetDepthsDescending(); + . GetObjectsAtDepth(int depth); + } + public interface IObjectGraphDiscoverer + { + .> DiscoverAndTrackObjects(.TestContext testContext, .CancellationToken cancellationToken = default); + . DiscoverNestedObjectGraph(object rootObject, .CancellationToken cancellationToken = default); + . DiscoverObjectGraph(.TestContext testContext, .CancellationToken cancellationToken = default); + } + public interface IObjectGraphTracker : . { } + public interface IObjectInitializationService + { + void ClearCache(); + . InitializeAsync(object? obj, .CancellationToken cancellationToken = default); + . InitializeForDiscoveryAsync(object? obj, .CancellationToken cancellationToken = default); + bool IsInitialized(object? obj); + } public interface IParallelConstraint { } public interface IParallelLimit { @@ -2588,6 +2649,14 @@ namespace .Models public ? MethodInvoker { get; set; } } } +namespace .PropertyInjection +{ + public static class PropertyCacheKeyGenerator + { + public static string GetCacheKey(.PropertyInfo property) { } + public static string GetCacheKey(..PropertyInjectionMetadata metadata) { } + } +} namespace .Services { [.("Generic type resolution requires runtime type generation")] @@ -2616,6 +2685,14 @@ namespace .Services public [] ResolveGenericClassArguments([.(..PublicConstructors)] genericTypeDefinition, object?[] constructorArguments) { } public [] ResolveGenericMethodArguments(.MethodInfo genericMethodDefinition, object?[] runtimeArguments) { } } + public sealed class ObjectInitializationService : . + { + public ObjectInitializationService() { } + public void ClearCache() { } + public . InitializeAsync(object? obj, .CancellationToken cancellationToken = default) { } + public . InitializeForDiscoveryAsync(object? obj, .CancellationToken cancellationToken = default) { } + public bool IsInitialized(object? obj) { } + } public static class ServiceProviderExtensions { public static object GetRequiredService(this serviceProvider, serviceType) { } diff --git a/TUnit.PublicAPI/Tests.Core_Library_Has_No_API_Changes.Net4_7.verified.txt b/TUnit.PublicAPI/Tests.Core_Library_Has_No_API_Changes.Net4_7.verified.txt index 538bd3ca00..5d4c0dc092 100644 --- a/TUnit.PublicAPI/Tests.Core_Library_Has_No_API_Changes.Net4_7.verified.txt +++ b/TUnit.PublicAPI/Tests.Core_Library_Has_No_API_Changes.Net4_7.verified.txt @@ -963,10 +963,6 @@ namespace public . NotInParallelConstraintKeys { get; init; } public int Order { get; set; } } - public static class ObjectInitializer - { - public static . InitializeAsync(object? obj, .CancellationToken cancellationToken = default) { } - } public class ParallelGroupAttribute : .TUnitAttribute, ., . { public ParallelGroupAttribute(string group) { } @@ -1621,6 +1617,36 @@ namespace .DataSources public static string FormatArguments(object?[] arguments, .<> formatters) { } } } +namespace .Discovery +{ + public readonly struct DiscoveryError : <.> + { + public DiscoveryError(string TypeName, string PropertyName, string ErrorMessage, Exception) { } + public string ErrorMessage { get; init; } + public Exception { get; init; } + public string PropertyName { get; init; } + public string TypeName { get; init; } + } + public sealed class ObjectGraph : . + { + public ObjectGraph(.> objectsByDepth, . allObjects) { } + public . AllObjects { get; } + public int MaxDepth { get; } + public .> ObjectsByDepth { get; } + public . GetDepthsDescending() { } + public . GetObjectsAtDepth(int depth) { } + } + public sealed class ObjectGraphDiscoverer : ., . + { + public ObjectGraphDiscoverer() { } + public .> DiscoverAndTrackObjects(.TestContext testContext, .CancellationToken cancellationToken = default) { } + public . DiscoverNestedObjectGraph(object rootObject, .CancellationToken cancellationToken = default) { } + public . DiscoverObjectGraph(.TestContext testContext, .CancellationToken cancellationToken = default) { } + public static void ClearCache() { } + public static void ClearDiscoveryErrors() { } + public static .<.> GetDiscoveryErrors() { } + } +} namespace .Enums { public enum DataGeneratorType @@ -1960,14 +1986,23 @@ namespace .Helpers public static bool IsConstructedGenericType( type) { } public static MakeGenericTypeSafe( genericTypeDefinition, params [] typeArguments) { } } + public static class ParallelTaskHelper + { + public static . ForEachAsync(. items, action) { } + public static . ForEachAsync(T[] items, action) { } + public static . ForEachAsync(. items, action, .CancellationToken cancellationToken) { } + public static . ForEachAsync(T[] items, action, .CancellationToken cancellationToken) { } + public static . ForEachWithIndexAsync(T[] items, action) { } + public static . ForEachWithIndexAsync(T[] items, action, .CancellationToken cancellationToken) { } + } public class ProcessorCountParallelLimit : . { public ProcessorCountParallelLimit() { } public int Limit { get; } } - public class ReferenceEqualityComparer : . + public sealed class ReferenceEqualityComparer : . { - public ReferenceEqualityComparer() { } + public static readonly . Instance; public bool Equals(object? x, object? y) { } public int GetHashCode(object obj) { } } @@ -2161,6 +2196,10 @@ namespace .Interfaces { .<> GenerateDataFactories(.DataSourceContext context); } + public interface IDisposer + { + . DisposeAsync(object? obj); + } public interface IEventReceiver { int Order { get; } @@ -2221,6 +2260,28 @@ namespace .Interfaces { . OnLastTestInTestSession(.TestSessionContext current, .TestContext testContext); } + public interface IObjectGraph + { + . AllObjects { get; } + int MaxDepth { get; } + .> ObjectsByDepth { get; } + . GetDepthsDescending(); + . GetObjectsAtDepth(int depth); + } + public interface IObjectGraphDiscoverer + { + .> DiscoverAndTrackObjects(.TestContext testContext, .CancellationToken cancellationToken = default); + . DiscoverNestedObjectGraph(object rootObject, .CancellationToken cancellationToken = default); + . DiscoverObjectGraph(.TestContext testContext, .CancellationToken cancellationToken = default); + } + public interface IObjectGraphTracker : . { } + public interface IObjectInitializationService + { + void ClearCache(); + . InitializeAsync(object? obj, .CancellationToken cancellationToken = default); + . InitializeForDiscoveryAsync(object? obj, .CancellationToken cancellationToken = default); + bool IsInitialized(object? obj); + } public interface IParallelConstraint { } public interface IParallelLimit { @@ -2510,6 +2571,14 @@ namespace .Models public ? MethodInvoker { get; set; } } } +namespace .PropertyInjection +{ + public static class PropertyCacheKeyGenerator + { + public static string GetCacheKey(.PropertyInfo property) { } + public static string GetCacheKey(..PropertyInjectionMetadata metadata) { } + } +} namespace .Services { public class GenericTypeResolver : . @@ -2537,6 +2606,14 @@ namespace .Services public [] ResolveGenericClassArguments( genericTypeDefinition, object?[] constructorArguments) { } public [] ResolveGenericMethodArguments(.MethodInfo genericMethodDefinition, object?[] runtimeArguments) { } } + public sealed class ObjectInitializationService : . + { + public ObjectInitializationService() { } + public void ClearCache() { } + public . InitializeAsync(object? obj, .CancellationToken cancellationToken = default) { } + public . InitializeForDiscoveryAsync(object? obj, .CancellationToken cancellationToken = default) { } + public bool IsInitialized(object? obj) { } + } public static class ServiceProviderExtensions { public static object GetRequiredService(this serviceProvider, serviceType) { } From 851e21ced02f82ba4405075c12909fdaade6bb05 Mon Sep 17 00:00:00 2001 From: Tom Longhurst <30480171+thomhurst@users.noreply.github.com> Date: Sun, 7 Dec 2025 17:57:59 +0000 Subject: [PATCH 13/20] feat: improve disposal callback registration logic to handle untracked objects and enhance error handling during initialization --- TUnit.Core/Tracking/ObjectTracker.cs | 74 ++++++++----------- .../Services/ObjectLifecycleService.cs | 10 +++ 2 files changed, 42 insertions(+), 42 deletions(-) diff --git a/TUnit.Core/Tracking/ObjectTracker.cs b/TUnit.Core/Tracking/ObjectTracker.cs index a0d0d28974..bb273255c5 100644 --- a/TUnit.Core/Tracking/ObjectTracker.cs +++ b/TUnit.Core/Tracking/ObjectTracker.cs @@ -210,7 +210,7 @@ private static bool ShouldSkipTracking(object? obj) /// /// Registers a callback to be invoked when the object is disposed (ref count reaches 0). - /// If the object is already disposed, the callback is invoked immediately. + /// If the object is already disposed (or was never tracked), the callback is invoked immediately. /// The callback is guaranteed to be invoked exactly once (idempotent). /// /// The object to monitor for disposal. If null or not disposable, the method returns without action. @@ -227,46 +227,12 @@ public static void OnDisposed(object? o, Action action) } #endif - if (o is not IDisposable and not IAsyncDisposable) - { - return; - } - - var counter = GetOrCreateCounter(o); - - // Use flag to ensure callback only fires once (idempotent) - var invoked = 0; - EventHandler? handler = null; - - handler = (sender, count) => - { - if (count == 0 && Interlocked.Exchange(ref invoked, 1) == 0) - { - // Remove handler to prevent memory leaks - if (sender is Counter c && handler != null) - { - c.OnCountChanged -= handler; - } - - action(); - } - }; - - counter.OnCountChanged += handler; - - // Check if already disposed (count is 0) - invoke immediately if so - // This prevents lost callbacks when registering after disposal - // Idempotent check ensures this doesn't double-fire if event already triggered - if (counter.CurrentCount == 0 && Interlocked.Exchange(ref invoked, 1) == 0) - { - counter.OnCountChanged -= handler; - action(); - } + RegisterDisposalCallback(o, action, static a => a()); } /// /// Registers an async callback to be invoked when the object is disposed (ref count reaches 0). - /// If the object is already disposed, the callback is invoked immediately. + /// If the object is already disposed (or was never tracked), the callback is invoked immediately. /// The callback is guaranteed to be invoked exactly once (idempotent). /// /// The object to monitor for disposal. If null or not disposable, the method returns without action. @@ -283,18 +249,43 @@ public static void OnDisposedAsync(object? o, Func asyncAction) } #endif + // Wrap async action in fire-and-forget with exception collection + RegisterDisposalCallback(o, asyncAction, static a => _ = SafeExecuteAsync(a)); + } + + /// + /// Core implementation for registering disposal callbacks. + /// Extracts common logic from OnDisposed and OnDisposedAsync (DRY principle). + /// + /// The type of action (Action or Func<Task>). + /// The object to monitor for disposal. + /// The callback action. + /// How to invoke the action (sync vs async wrapper). + private static void RegisterDisposalCallback( + object? o, + TAction action, + Action invoker) + where TAction : Delegate + { if (o is not IDisposable and not IAsyncDisposable) { return; } - var counter = GetOrCreateCounter(o); + // Only register callback if the object is actually being tracked. + // If not tracked, invoke callback immediately (object is effectively "disposed"). + // This prevents creating spurious counters for untracked objects. + if (!s_trackedObjects.TryGetValue(o, out var counter)) + { + // Object not tracked - invoke callback immediately + invoker(action); + return; + } // Use flag to ensure callback only fires once (idempotent) var invoked = 0; EventHandler? handler = null; - // Avoid async void pattern by wrapping in fire-and-forget with exception handling handler = (sender, count) => { if (count == 0 && Interlocked.Exchange(ref invoked, 1) == 0) @@ -305,8 +296,7 @@ public static void OnDisposedAsync(object? o, Func asyncAction) c.OnCountChanged -= handler; } - // Fire-and-forget with exception collection to surface errors - _ = SafeExecuteAsync(asyncAction); + invoker(action); } }; @@ -318,7 +308,7 @@ public static void OnDisposedAsync(object? o, Func asyncAction) if (counter.CurrentCount == 0 && Interlocked.Exchange(ref invoked, 1) == 0) { counter.OnCountChanged -= handler; - _ = SafeExecuteAsync(asyncAction); + invoker(action); } } diff --git a/TUnit.Engine/Services/ObjectLifecycleService.cs b/TUnit.Engine/Services/ObjectLifecycleService.cs index 9895d65d3a..2d7d4a4ac3 100644 --- a/TUnit.Engine/Services/ObjectLifecycleService.cs +++ b/TUnit.Engine/Services/ObjectLifecycleService.cs @@ -322,8 +322,18 @@ public async ValueTask EnsureInitializedAsync( await InitializeObjectCoreAsync(obj, objectBag, methodMetadata, events, cancellationToken); tcs.SetResult(true); } + catch (OperationCanceledException) + { + // Propagate cancellation without caching failure - allows retry after cancel + _initializationTasks.TryRemove(obj, out _); + tcs.SetCanceled(); + throw; + } catch (Exception ex) { + // Remove failed initialization from cache to allow retry + // This is important for transient failures that may succeed on retry + _initializationTasks.TryRemove(obj, out _); tcs.SetException(ex); throw; } From f89adb836a1c980ab2c8d052acf64261ccf1eb13 Mon Sep 17 00:00:00 2001 From: Tom Longhurst <30480171+thomhurst@users.noreply.github.com> Date: Sun, 7 Dec 2025 18:05:11 +0000 Subject: [PATCH 14/20] feat: change visibility of object graph related classes and interfaces to internal for encapsulation --- TUnit.Core/Discovery/ObjectGraph.cs | 2 +- TUnit.Core/Discovery/ObjectGraphDiscoverer.cs | 4 +- .../Interfaces/IObjectGraphDiscoverer.cs | 6 +- .../IObjectInitializationService.cs | 2 +- .../Services/ObjectInitializationService.cs | 2 +- ...Has_No_API_Changes.DotNet10_0.verified.txt | 60 ------------------- ..._Has_No_API_Changes.DotNet8_0.verified.txt | 60 ------------------- ..._Has_No_API_Changes.DotNet9_0.verified.txt | 60 ------------------- ...ary_Has_No_API_Changes.Net4_7.verified.txt | 60 ------------------- 9 files changed, 8 insertions(+), 248 deletions(-) diff --git a/TUnit.Core/Discovery/ObjectGraph.cs b/TUnit.Core/Discovery/ObjectGraph.cs index 3d79a52785..85a1b6f960 100644 --- a/TUnit.Core/Discovery/ObjectGraph.cs +++ b/TUnit.Core/Discovery/ObjectGraph.cs @@ -12,7 +12,7 @@ namespace TUnit.Core.Discovery; /// to prevent callers from corrupting internal state. /// Uses Lazy<T> for thread-safe lazy initialization of read-only views. /// -public sealed class ObjectGraph : IObjectGraph +internal sealed class ObjectGraph : IObjectGraph { private readonly ConcurrentDictionary> _objectsByDepth; private readonly HashSet _allObjects; diff --git a/TUnit.Core/Discovery/ObjectGraphDiscoverer.cs b/TUnit.Core/Discovery/ObjectGraphDiscoverer.cs index 27c8f811b3..f2389fda28 100644 --- a/TUnit.Core/Discovery/ObjectGraphDiscoverer.cs +++ b/TUnit.Core/Discovery/ObjectGraphDiscoverer.cs @@ -16,7 +16,7 @@ namespace TUnit.Core.Discovery; /// The name of the property that failed to access. /// The error message. /// The exception that occurred. -public readonly record struct DiscoveryError(string TypeName, string PropertyName, string ErrorMessage, Exception Exception); +internal readonly record struct DiscoveryError(string TypeName, string PropertyName, string ErrorMessage, Exception Exception); /// /// Centralized service for discovering and organizing object graphs. @@ -37,7 +37,7 @@ namespace TUnit.Core.Discovery; /// rather than thrown, allowing discovery to continue despite individual property failures. /// /// -public sealed class ObjectGraphDiscoverer : IObjectGraphTracker +internal sealed class ObjectGraphDiscoverer : IObjectGraphTracker { /// /// Maximum recursion depth for object graph discovery. diff --git a/TUnit.Core/Interfaces/IObjectGraphDiscoverer.cs b/TUnit.Core/Interfaces/IObjectGraphDiscoverer.cs index c6525cf634..8ac4586be3 100644 --- a/TUnit.Core/Interfaces/IObjectGraphDiscoverer.cs +++ b/TUnit.Core/Interfaces/IObjectGraphDiscoverer.cs @@ -24,7 +24,7 @@ namespace TUnit.Core.Interfaces; /// For tracking operations that modify TestContext.TrackedObjects, see . /// /// -public interface IObjectGraphDiscoverer +internal interface IObjectGraphDiscoverer { /// /// Discovers all objects from a test context, organized by depth level. @@ -81,7 +81,7 @@ public interface IObjectGraphDiscoverer /// The distinction exists for semantic clarity and future extensibility. /// /// -public interface IObjectGraphTracker : IObjectGraphDiscoverer +internal interface IObjectGraphTracker : IObjectGraphDiscoverer { // All methods inherited from IObjectGraphDiscoverer // This interface provides semantic clarity for tracking operations @@ -94,7 +94,7 @@ public interface IObjectGraphTracker : IObjectGraphDiscoverer /// Collections are exposed as read-only to prevent callers from corrupting internal state. /// Use and for safe iteration. /// -public interface IObjectGraph +internal interface IObjectGraph { /// /// Gets objects organized by depth (0 = root arguments, 1+ = nested). diff --git a/TUnit.Core/Interfaces/IObjectInitializationService.cs b/TUnit.Core/Interfaces/IObjectInitializationService.cs index 4cf5d26bba..6e5eb70338 100644 --- a/TUnit.Core/Interfaces/IObjectInitializationService.cs +++ b/TUnit.Core/Interfaces/IObjectInitializationService.cs @@ -16,7 +16,7 @@ namespace TUnit.Core.Interfaces; /// /// /// -public interface IObjectInitializationService +internal interface IObjectInitializationService { /// /// Initializes an object during the execution phase. diff --git a/TUnit.Core/Services/ObjectInitializationService.cs b/TUnit.Core/Services/ObjectInitializationService.cs index d033605fd1..a18845001c 100644 --- a/TUnit.Core/Services/ObjectInitializationService.cs +++ b/TUnit.Core/Services/ObjectInitializationService.cs @@ -12,7 +12,7 @@ namespace TUnit.Core.Services; /// behavior and avoid duplicate caches. This consolidates initialization tracking in one place. /// /// -public sealed class ObjectInitializationService : IObjectInitializationService +internal sealed class ObjectInitializationService : IObjectInitializationService { /// /// Creates a new instance of the initialization service. diff --git a/TUnit.PublicAPI/Tests.Core_Library_Has_No_API_Changes.DotNet10_0.verified.txt b/TUnit.PublicAPI/Tests.Core_Library_Has_No_API_Changes.DotNet10_0.verified.txt index d6c7f837e4..212cbcb9fc 100644 --- a/TUnit.PublicAPI/Tests.Core_Library_Has_No_API_Changes.DotNet10_0.verified.txt +++ b/TUnit.PublicAPI/Tests.Core_Library_Has_No_API_Changes.DotNet10_0.verified.txt @@ -1664,36 +1664,6 @@ namespace .DataSources public static string FormatArguments(object?[] arguments, .<> formatters) { } } } -namespace .Discovery -{ - public readonly struct DiscoveryError : <.> - { - public DiscoveryError(string TypeName, string PropertyName, string ErrorMessage, Exception) { } - public string ErrorMessage { get; init; } - public Exception { get; init; } - public string PropertyName { get; init; } - public string TypeName { get; init; } - } - public sealed class ObjectGraph : . - { - public ObjectGraph(.> objectsByDepth, . allObjects) { } - public . AllObjects { get; } - public int MaxDepth { get; } - public .> ObjectsByDepth { get; } - public . GetDepthsDescending() { } - public . GetObjectsAtDepth(int depth) { } - } - public sealed class ObjectGraphDiscoverer : ., . - { - public ObjectGraphDiscoverer() { } - public .> DiscoverAndTrackObjects(.TestContext testContext, .CancellationToken cancellationToken = default) { } - public . DiscoverNestedObjectGraph(object rootObject, .CancellationToken cancellationToken = default) { } - public . DiscoverObjectGraph(.TestContext testContext, .CancellationToken cancellationToken = default) { } - public static void ClearCache() { } - public static void ClearDiscoveryErrors() { } - public static .<.> GetDiscoveryErrors() { } - } -} namespace .Enums { public enum DataGeneratorType @@ -2328,28 +2298,6 @@ namespace .Interfaces { . OnLastTestInTestSession(.TestSessionContext current, .TestContext testContext); } - public interface IObjectGraph - { - . AllObjects { get; } - int MaxDepth { get; } - .> ObjectsByDepth { get; } - . GetDepthsDescending(); - . GetObjectsAtDepth(int depth); - } - public interface IObjectGraphDiscoverer - { - .> DiscoverAndTrackObjects(.TestContext testContext, .CancellationToken cancellationToken = default); - . DiscoverNestedObjectGraph(object rootObject, .CancellationToken cancellationToken = default); - . DiscoverObjectGraph(.TestContext testContext, .CancellationToken cancellationToken = default); - } - public interface IObjectGraphTracker : . { } - public interface IObjectInitializationService - { - void ClearCache(); - . InitializeAsync(object? obj, .CancellationToken cancellationToken = default); - . InitializeForDiscoveryAsync(object? obj, .CancellationToken cancellationToken = default); - bool IsInitialized(object? obj); - } public interface IParallelConstraint { } public interface IParallelLimit { @@ -2685,14 +2633,6 @@ namespace .Services public [] ResolveGenericClassArguments([.(..PublicConstructors)] genericTypeDefinition, object?[] constructorArguments) { } public [] ResolveGenericMethodArguments(.MethodInfo genericMethodDefinition, object?[] runtimeArguments) { } } - public sealed class ObjectInitializationService : . - { - public ObjectInitializationService() { } - public void ClearCache() { } - public . InitializeAsync(object? obj, .CancellationToken cancellationToken = default) { } - public . InitializeForDiscoveryAsync(object? obj, .CancellationToken cancellationToken = default) { } - public bool IsInitialized(object? obj) { } - } public static class ServiceProviderExtensions { public static object GetRequiredService(this serviceProvider, serviceType) { } diff --git a/TUnit.PublicAPI/Tests.Core_Library_Has_No_API_Changes.DotNet8_0.verified.txt b/TUnit.PublicAPI/Tests.Core_Library_Has_No_API_Changes.DotNet8_0.verified.txt index b231e50b76..00c508f06b 100644 --- a/TUnit.PublicAPI/Tests.Core_Library_Has_No_API_Changes.DotNet8_0.verified.txt +++ b/TUnit.PublicAPI/Tests.Core_Library_Has_No_API_Changes.DotNet8_0.verified.txt @@ -1664,36 +1664,6 @@ namespace .DataSources public static string FormatArguments(object?[] arguments, .<> formatters) { } } } -namespace .Discovery -{ - public readonly struct DiscoveryError : <.> - { - public DiscoveryError(string TypeName, string PropertyName, string ErrorMessage, Exception) { } - public string ErrorMessage { get; init; } - public Exception { get; init; } - public string PropertyName { get; init; } - public string TypeName { get; init; } - } - public sealed class ObjectGraph : . - { - public ObjectGraph(.> objectsByDepth, . allObjects) { } - public . AllObjects { get; } - public int MaxDepth { get; } - public .> ObjectsByDepth { get; } - public . GetDepthsDescending() { } - public . GetObjectsAtDepth(int depth) { } - } - public sealed class ObjectGraphDiscoverer : ., . - { - public ObjectGraphDiscoverer() { } - public .> DiscoverAndTrackObjects(.TestContext testContext, .CancellationToken cancellationToken = default) { } - public . DiscoverNestedObjectGraph(object rootObject, .CancellationToken cancellationToken = default) { } - public . DiscoverObjectGraph(.TestContext testContext, .CancellationToken cancellationToken = default) { } - public static void ClearCache() { } - public static void ClearDiscoveryErrors() { } - public static .<.> GetDiscoveryErrors() { } - } -} namespace .Enums { public enum DataGeneratorType @@ -2328,28 +2298,6 @@ namespace .Interfaces { . OnLastTestInTestSession(.TestSessionContext current, .TestContext testContext); } - public interface IObjectGraph - { - . AllObjects { get; } - int MaxDepth { get; } - .> ObjectsByDepth { get; } - . GetDepthsDescending(); - . GetObjectsAtDepth(int depth); - } - public interface IObjectGraphDiscoverer - { - .> DiscoverAndTrackObjects(.TestContext testContext, .CancellationToken cancellationToken = default); - . DiscoverNestedObjectGraph(object rootObject, .CancellationToken cancellationToken = default); - . DiscoverObjectGraph(.TestContext testContext, .CancellationToken cancellationToken = default); - } - public interface IObjectGraphTracker : . { } - public interface IObjectInitializationService - { - void ClearCache(); - . InitializeAsync(object? obj, .CancellationToken cancellationToken = default); - . InitializeForDiscoveryAsync(object? obj, .CancellationToken cancellationToken = default); - bool IsInitialized(object? obj); - } public interface IParallelConstraint { } public interface IParallelLimit { @@ -2685,14 +2633,6 @@ namespace .Services public [] ResolveGenericClassArguments([.(..PublicConstructors)] genericTypeDefinition, object?[] constructorArguments) { } public [] ResolveGenericMethodArguments(.MethodInfo genericMethodDefinition, object?[] runtimeArguments) { } } - public sealed class ObjectInitializationService : . - { - public ObjectInitializationService() { } - public void ClearCache() { } - public . InitializeAsync(object? obj, .CancellationToken cancellationToken = default) { } - public . InitializeForDiscoveryAsync(object? obj, .CancellationToken cancellationToken = default) { } - public bool IsInitialized(object? obj) { } - } public static class ServiceProviderExtensions { public static object GetRequiredService(this serviceProvider, serviceType) { } diff --git a/TUnit.PublicAPI/Tests.Core_Library_Has_No_API_Changes.DotNet9_0.verified.txt b/TUnit.PublicAPI/Tests.Core_Library_Has_No_API_Changes.DotNet9_0.verified.txt index b935f9f012..af52063e82 100644 --- a/TUnit.PublicAPI/Tests.Core_Library_Has_No_API_Changes.DotNet9_0.verified.txt +++ b/TUnit.PublicAPI/Tests.Core_Library_Has_No_API_Changes.DotNet9_0.verified.txt @@ -1664,36 +1664,6 @@ namespace .DataSources public static string FormatArguments(object?[] arguments, .<> formatters) { } } } -namespace .Discovery -{ - public readonly struct DiscoveryError : <.> - { - public DiscoveryError(string TypeName, string PropertyName, string ErrorMessage, Exception) { } - public string ErrorMessage { get; init; } - public Exception { get; init; } - public string PropertyName { get; init; } - public string TypeName { get; init; } - } - public sealed class ObjectGraph : . - { - public ObjectGraph(.> objectsByDepth, . allObjects) { } - public . AllObjects { get; } - public int MaxDepth { get; } - public .> ObjectsByDepth { get; } - public . GetDepthsDescending() { } - public . GetObjectsAtDepth(int depth) { } - } - public sealed class ObjectGraphDiscoverer : ., . - { - public ObjectGraphDiscoverer() { } - public .> DiscoverAndTrackObjects(.TestContext testContext, .CancellationToken cancellationToken = default) { } - public . DiscoverNestedObjectGraph(object rootObject, .CancellationToken cancellationToken = default) { } - public . DiscoverObjectGraph(.TestContext testContext, .CancellationToken cancellationToken = default) { } - public static void ClearCache() { } - public static void ClearDiscoveryErrors() { } - public static .<.> GetDiscoveryErrors() { } - } -} namespace .Enums { public enum DataGeneratorType @@ -2328,28 +2298,6 @@ namespace .Interfaces { . OnLastTestInTestSession(.TestSessionContext current, .TestContext testContext); } - public interface IObjectGraph - { - . AllObjects { get; } - int MaxDepth { get; } - .> ObjectsByDepth { get; } - . GetDepthsDescending(); - . GetObjectsAtDepth(int depth); - } - public interface IObjectGraphDiscoverer - { - .> DiscoverAndTrackObjects(.TestContext testContext, .CancellationToken cancellationToken = default); - . DiscoverNestedObjectGraph(object rootObject, .CancellationToken cancellationToken = default); - . DiscoverObjectGraph(.TestContext testContext, .CancellationToken cancellationToken = default); - } - public interface IObjectGraphTracker : . { } - public interface IObjectInitializationService - { - void ClearCache(); - . InitializeAsync(object? obj, .CancellationToken cancellationToken = default); - . InitializeForDiscoveryAsync(object? obj, .CancellationToken cancellationToken = default); - bool IsInitialized(object? obj); - } public interface IParallelConstraint { } public interface IParallelLimit { @@ -2685,14 +2633,6 @@ namespace .Services public [] ResolveGenericClassArguments([.(..PublicConstructors)] genericTypeDefinition, object?[] constructorArguments) { } public [] ResolveGenericMethodArguments(.MethodInfo genericMethodDefinition, object?[] runtimeArguments) { } } - public sealed class ObjectInitializationService : . - { - public ObjectInitializationService() { } - public void ClearCache() { } - public . InitializeAsync(object? obj, .CancellationToken cancellationToken = default) { } - public . InitializeForDiscoveryAsync(object? obj, .CancellationToken cancellationToken = default) { } - public bool IsInitialized(object? obj) { } - } public static class ServiceProviderExtensions { public static object GetRequiredService(this serviceProvider, serviceType) { } diff --git a/TUnit.PublicAPI/Tests.Core_Library_Has_No_API_Changes.Net4_7.verified.txt b/TUnit.PublicAPI/Tests.Core_Library_Has_No_API_Changes.Net4_7.verified.txt index 5d4c0dc092..a3d525329b 100644 --- a/TUnit.PublicAPI/Tests.Core_Library_Has_No_API_Changes.Net4_7.verified.txt +++ b/TUnit.PublicAPI/Tests.Core_Library_Has_No_API_Changes.Net4_7.verified.txt @@ -1617,36 +1617,6 @@ namespace .DataSources public static string FormatArguments(object?[] arguments, .<> formatters) { } } } -namespace .Discovery -{ - public readonly struct DiscoveryError : <.> - { - public DiscoveryError(string TypeName, string PropertyName, string ErrorMessage, Exception) { } - public string ErrorMessage { get; init; } - public Exception { get; init; } - public string PropertyName { get; init; } - public string TypeName { get; init; } - } - public sealed class ObjectGraph : . - { - public ObjectGraph(.> objectsByDepth, . allObjects) { } - public . AllObjects { get; } - public int MaxDepth { get; } - public .> ObjectsByDepth { get; } - public . GetDepthsDescending() { } - public . GetObjectsAtDepth(int depth) { } - } - public sealed class ObjectGraphDiscoverer : ., . - { - public ObjectGraphDiscoverer() { } - public .> DiscoverAndTrackObjects(.TestContext testContext, .CancellationToken cancellationToken = default) { } - public . DiscoverNestedObjectGraph(object rootObject, .CancellationToken cancellationToken = default) { } - public . DiscoverObjectGraph(.TestContext testContext, .CancellationToken cancellationToken = default) { } - public static void ClearCache() { } - public static void ClearDiscoveryErrors() { } - public static .<.> GetDiscoveryErrors() { } - } -} namespace .Enums { public enum DataGeneratorType @@ -2260,28 +2230,6 @@ namespace .Interfaces { . OnLastTestInTestSession(.TestSessionContext current, .TestContext testContext); } - public interface IObjectGraph - { - . AllObjects { get; } - int MaxDepth { get; } - .> ObjectsByDepth { get; } - . GetDepthsDescending(); - . GetObjectsAtDepth(int depth); - } - public interface IObjectGraphDiscoverer - { - .> DiscoverAndTrackObjects(.TestContext testContext, .CancellationToken cancellationToken = default); - . DiscoverNestedObjectGraph(object rootObject, .CancellationToken cancellationToken = default); - . DiscoverObjectGraph(.TestContext testContext, .CancellationToken cancellationToken = default); - } - public interface IObjectGraphTracker : . { } - public interface IObjectInitializationService - { - void ClearCache(); - . InitializeAsync(object? obj, .CancellationToken cancellationToken = default); - . InitializeForDiscoveryAsync(object? obj, .CancellationToken cancellationToken = default); - bool IsInitialized(object? obj); - } public interface IParallelConstraint { } public interface IParallelLimit { @@ -2606,14 +2554,6 @@ namespace .Services public [] ResolveGenericClassArguments( genericTypeDefinition, object?[] constructorArguments) { } public [] ResolveGenericMethodArguments(.MethodInfo genericMethodDefinition, object?[] runtimeArguments) { } } - public sealed class ObjectInitializationService : . - { - public ObjectInitializationService() { } - public void ClearCache() { } - public . InitializeAsync(object? obj, .CancellationToken cancellationToken = default) { } - public . InitializeForDiscoveryAsync(object? obj, .CancellationToken cancellationToken = default) { } - public bool IsInitialized(object? obj) { } - } public static class ServiceProviderExtensions { public static object GetRequiredService(this serviceProvider, serviceType) { } From a1938e394a85c56a6bb6f5e3873c8dc4aff4da3c Mon Sep 17 00:00:00 2001 From: Tom Longhurst <30480171+thomhurst@users.noreply.github.com> Date: Sun, 7 Dec 2025 18:25:08 +0000 Subject: [PATCH 15/20] feat: implement IAsyncDiscoveryInitializer and related classes for improved test discovery handling --- TUnit.TestProject/Bugs/3992/DummyContainer.cs | 22 -------- ...ecreation.cs => RuntimeInitializeTests.cs} | 27 +++++++-- .../Bugs/3992/RuntimeInitializeTests2.cs | 56 +++++++++++++++++++ 3 files changed, 77 insertions(+), 28 deletions(-) delete mode 100644 TUnit.TestProject/Bugs/3992/DummyContainer.cs rename TUnit.TestProject/Bugs/3992/{BugRecreation.cs => RuntimeInitializeTests.cs} (68%) create mode 100644 TUnit.TestProject/Bugs/3992/RuntimeInitializeTests2.cs diff --git a/TUnit.TestProject/Bugs/3992/DummyContainer.cs b/TUnit.TestProject/Bugs/3992/DummyContainer.cs deleted file mode 100644 index b075f5e9f1..0000000000 --- a/TUnit.TestProject/Bugs/3992/DummyContainer.cs +++ /dev/null @@ -1,22 +0,0 @@ -using TUnit.Core.Interfaces; - -namespace TUnit.TestProject.Bugs._3992; - -public class DummyContainer : IAsyncInitializer, IAsyncDisposable -{ - public Task InitializeAsync() - { - NumberOfInits++; - Ints = [1, 2, 3, 4, 5, 6]; - return Task.CompletedTask; - } - - public int[] Ints { get; private set; } = null!; - - public static int NumberOfInits { get; private set; } - - public ValueTask DisposeAsync() - { - return default; - } -} diff --git a/TUnit.TestProject/Bugs/3992/BugRecreation.cs b/TUnit.TestProject/Bugs/3992/RuntimeInitializeTests.cs similarity index 68% rename from TUnit.TestProject/Bugs/3992/BugRecreation.cs rename to TUnit.TestProject/Bugs/3992/RuntimeInitializeTests.cs index fbc042d9fa..1d86987a8a 100644 --- a/TUnit.TestProject/Bugs/3992/BugRecreation.cs +++ b/TUnit.TestProject/Bugs/3992/RuntimeInitializeTests.cs @@ -1,4 +1,5 @@ -using TUnit.TestProject.Attributes; +using TUnit.Core.Interfaces; +using TUnit.TestProject.Attributes; namespace TUnit.TestProject.Bugs._3992; @@ -6,15 +7,12 @@ namespace TUnit.TestProject.Bugs._3992; /// Once this is discovered during test discovery, containers spin up /// [EngineTest(ExpectedResult.Pass)] -public sealed class BugRecreation +public sealed class RuntimeInitializeTests { //Docker container [ClassDataSource(Shared = SharedType.PerClass)] public required DummyContainer Container { get; init; } - public IEnumerable> Executions - => Container.Ints.Select(e => new Func(() => e)); - [Before(Class)] public static Task BeforeClass(ClassHookContext context) => NotInitialised(context.Tests); @@ -23,7 +21,7 @@ public IEnumerable> Executions public static async Task NotInitialised(IEnumerable tests) { - var bugRecreations = tests.Select(x => x.Metadata.TestDetails.ClassInstance).OfType(); + var bugRecreations = tests.Select(x => x.Metadata.TestDetails.ClassInstance).OfType(); foreach (var bugRecreation in bugRecreations) { @@ -38,4 +36,21 @@ public async Task Test(int value, CancellationToken token) await Assert.That(value).IsNotDefault(); await Assert.That(DummyContainer.NumberOfInits).IsEqualTo(1); } + + public class DummyContainer : IAsyncInitializer, IAsyncDisposable + { + public Task InitializeAsync() + { + NumberOfInits++; + return Task.CompletedTask; + } + + public static int NumberOfInits { get; private set; } + + public ValueTask DisposeAsync() + { + return default; + } + } + } diff --git a/TUnit.TestProject/Bugs/3992/RuntimeInitializeTests2.cs b/TUnit.TestProject/Bugs/3992/RuntimeInitializeTests2.cs new file mode 100644 index 0000000000..4edc097c47 --- /dev/null +++ b/TUnit.TestProject/Bugs/3992/RuntimeInitializeTests2.cs @@ -0,0 +1,56 @@ +using TUnit.Core.Interfaces; +using TUnit.TestProject.Attributes; + +namespace TUnit.TestProject.Bugs._3992; + +/// +/// Once this is discovered during test discovery, containers spin up +/// +[EngineTest(ExpectedResult.Pass)] +public sealed class DiscoveryInitializeTests +{ + //Docker container + [ClassDataSource(Shared = SharedType.PerClass)] + public required DummyContainer Container { get; init; } + + [Before(Class)] + public static Task BeforeClass(ClassHookContext context) => Initialised(context.Tests); + + [After(TestDiscovery)] + public static Task AfterDiscovery(TestDiscoveryContext context) => Initialised(context.AllTests); + + public static async Task Initialised(IEnumerable tests) + { + var bugRecreations = tests.Select(x => x.Metadata.TestDetails.ClassInstance).OfType(); + + foreach (var bugRecreation in bugRecreations) + { + await Assert.That(bugRecreation.Container).IsNotNull(); + await Assert.That(DummyContainer.NumberOfInits).IsEqualTo(1); + } + } + + [Test, Arguments(1)] + public async Task Test(int value, CancellationToken token) + { + await Assert.That(value).IsNotDefault(); + await Assert.That(DummyContainer.NumberOfInits).IsEqualTo(1); + } + + public class DummyContainer : IAsyncDiscoveryInitializer, IAsyncDisposable + { + public Task InitializeAsync() + { + NumberOfInits++; + return Task.CompletedTask; + } + + public static int NumberOfInits { get; private set; } + + public ValueTask DisposeAsync() + { + return default; + } + } + +} From 9efbd557225cd2e0af3aeb3ce7ea40516ce050ab Mon Sep 17 00:00:00 2001 From: Tom Longhurst <30480171+thomhurst@users.noreply.github.com> Date: Sun, 7 Dec 2025 18:35:57 +0000 Subject: [PATCH 16/20] feat: skip PlaceholderInstance during data source class resolution for improved test instance handling --- TUnit.Core/Attributes/TestData/MethodDataSourceAttribute.cs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/TUnit.Core/Attributes/TestData/MethodDataSourceAttribute.cs b/TUnit.Core/Attributes/TestData/MethodDataSourceAttribute.cs index efa53d7741..c6bd260fe0 100644 --- a/TUnit.Core/Attributes/TestData/MethodDataSourceAttribute.cs +++ b/TUnit.Core/Attributes/TestData/MethodDataSourceAttribute.cs @@ -82,7 +82,10 @@ public MethodDataSourceAttribute( // If we have a test class instance and no explicit class was provided, // use the instance's actual type (which will be the constructed generic type) - if (ClassProvidingDataSource == null && dataGeneratorMetadata.TestClassInstance != null) + // Skip PlaceholderInstance as it's used during discovery when the actual instance isn't created yet + if (ClassProvidingDataSource == null + && dataGeneratorMetadata.TestClassInstance != null + && dataGeneratorMetadata.TestClassInstance is not PlaceholderInstance) { targetType = dataGeneratorMetadata.TestClassInstance.GetType(); } From da86e469d3379c546ea30a856b8af4978065edf5 Mon Sep 17 00:00:00 2001 From: Tom Longhurst <30480171+thomhurst@users.noreply.github.com> Date: Sun, 7 Dec 2025 19:09:09 +0000 Subject: [PATCH 17/20] fix: normalize line endings in exception message assertions for consistency --- TUnit.Assertions.Tests/AssertConditions/BecauseTests.cs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/TUnit.Assertions.Tests/AssertConditions/BecauseTests.cs b/TUnit.Assertions.Tests/AssertConditions/BecauseTests.cs index 1d44b17e48..0c4c88137f 100644 --- a/TUnit.Assertions.Tests/AssertConditions/BecauseTests.cs +++ b/TUnit.Assertions.Tests/AssertConditions/BecauseTests.cs @@ -1,4 +1,4 @@ -namespace TUnit.Assertions.Tests.AssertConditions; +namespace TUnit.Assertions.Tests.AssertConditions; public class BecauseTests { @@ -68,7 +68,7 @@ at Assert.That(variable).IsFalse() }; var exception = await Assert.ThrowsAsync(action); - await Assert.That(exception.Message).IsEqualTo(expectedMessage); + await Assert.That(exception.Message.NormalizeLineEndings()).IsEqualTo(expectedMessage.NormalizeLineEndings()); } [Test] @@ -91,7 +91,7 @@ await Assert.That(variable).IsTrue().Because(because) }; var exception = await Assert.ThrowsAsync(action); - await Assert.That(exception.Message.NormalizeLineEndings()).IsEqualTo(expectedMessage); + await Assert.That(exception.Message.NormalizeLineEndings()).IsEqualTo(expectedMessage.NormalizeLineEndings()); } [Test] From ca3065fa476f269b9be235f3f7241140715f1e74 Mon Sep 17 00:00:00 2001 From: Tom Longhurst <30480171+thomhurst@users.noreply.github.com> Date: Sun, 7 Dec 2025 19:58:29 +0000 Subject: [PATCH 18/20] fix: normalize line endings in exception messages for consistency across tests --- .../Delegates/Throws.ExactlyTests.cs | 27 ++++++++------- .../Delegates/Throws.ExceptionTests.cs | 6 ++-- .../Delegates/Throws.NothingTests.cs | 6 ++-- .../Delegates/Throws.OfTypeTests.cs | 18 +++++----- .../Throws.WithInnerExceptionTests.cs | 6 ++-- .../Throws.WithMessageMatchingTests.cs | 6 ++-- .../Delegates/Throws.WithMessageTests.cs | 6 ++-- .../Throws.WithParameterNameTests.cs | 6 ++-- TUnit.Assertions.Tests/Bugs/Tests2117.cs | 14 ++++---- .../Helpers/StringDifferenceTests.cs | 24 ++++++------- .../Old/AssertMultipleTests.cs | 34 +++++++++---------- .../Old/EquivalentAssertionTests.cs | 4 +-- .../Old/StringRegexAssertionTests.cs | 24 ++++++------- .../ThrowInDelegateValueAssertionTests.cs | 13 +++---- 14 files changed, 99 insertions(+), 95 deletions(-) diff --git a/TUnit.Assertions.Tests/Assertions/Delegates/Throws.ExactlyTests.cs b/TUnit.Assertions.Tests/Assertions/Delegates/Throws.ExactlyTests.cs index c3292ec78e..35d744499e 100644 --- a/TUnit.Assertions.Tests/Assertions/Delegates/Throws.ExactlyTests.cs +++ b/TUnit.Assertions.Tests/Assertions/Delegates/Throws.ExactlyTests.cs @@ -13,15 +13,15 @@ public async Task Fails_For_Code_With_Other_Exceptions() but threw TUnit.Assertions.Tests.Assertions.Delegates.Throws+OtherException at Assert.That(action).ThrowsExactly() - """; + """.NormalizeLineEndings(); Exception exception = CreateOtherException(); Action action = () => throw exception; var sut = async () => await Assert.That(action).ThrowsExactly(); - await Assert.That(sut).ThrowsException() - .WithMessage(expectedMessage); + var thrownException = await Assert.That(sut).ThrowsException(); + await Assert.That(thrownException.Message.NormalizeLineEndings()).IsEqualTo(expectedMessage); } [Test] @@ -32,15 +32,15 @@ public async Task Fails_For_Code_With_Subtype_Exceptions() but wrong exception type: SubCustomException instead of exactly CustomException at Assert.That(action).ThrowsExactly() - """; + """.NormalizeLineEndings(); Exception exception = CreateSubCustomException(); Action action = () => throw exception; var sut = async () => await Assert.That(action).ThrowsExactly(); - await Assert.That(sut).ThrowsException() - .WithMessage(expectedMessage); + var thrownException = await Assert.That(sut).ThrowsException(); + await Assert.That(thrownException.Message.NormalizeLineEndings()).IsEqualTo(expectedMessage); } [Test] @@ -51,14 +51,14 @@ public async Task Fails_For_Code_Without_Exceptions() but no exception was thrown at Assert.That(action).ThrowsExactly() - """; + """.NormalizeLineEndings(); var action = () => { }; var sut = async () => await Assert.That(action).ThrowsExactly(); - await Assert.That(sut).ThrowsException() - .WithMessage(expectedMessage); + var thrownException = await Assert.That(sut).ThrowsException(); + await Assert.That(thrownException.Message.NormalizeLineEndings()).IsEqualTo(expectedMessage); } [Test] @@ -117,10 +117,11 @@ public async Task Conversion_To_Value_Assertion_Builder_On_Casted_Exception_Type await Assert.That((object)ex).IsAssignableTo(); }); - await Assert.That(assertionException).HasMessageStartingWith(""" - Expected to throw exactly Exception - but wrong exception type: CustomException instead of exactly Exception - """); + var expectedPrefix = """ + Expected to throw exactly Exception + but wrong exception type: CustomException instead of exactly Exception + """.NormalizeLineEndings(); + await Assert.That(assertionException.Message.NormalizeLineEndings()).StartsWith(expectedPrefix); } [Test] diff --git a/TUnit.Assertions.Tests/Assertions/Delegates/Throws.ExceptionTests.cs b/TUnit.Assertions.Tests/Assertions/Delegates/Throws.ExceptionTests.cs index cedfa0b176..3b51f1e29e 100644 --- a/TUnit.Assertions.Tests/Assertions/Delegates/Throws.ExceptionTests.cs +++ b/TUnit.Assertions.Tests/Assertions/Delegates/Throws.ExceptionTests.cs @@ -14,14 +14,14 @@ public async Task Fails_For_Code_Without_Exceptions() but no exception was thrown at Assert.That(action).ThrowsException() - """; + """.NormalizeLineEndings(); var action = () => { }; var sut = async () => await Assert.That(action).ThrowsException(); - await Assert.That(sut).ThrowsException() - .WithMessage(expectedMessage); + var thrownException = await Assert.That(sut).ThrowsException(); + await Assert.That(thrownException.Message.NormalizeLineEndings()).IsEqualTo(expectedMessage); } [Test] diff --git a/TUnit.Assertions.Tests/Assertions/Delegates/Throws.NothingTests.cs b/TUnit.Assertions.Tests/Assertions/Delegates/Throws.NothingTests.cs index 48077a848d..29fc9e7d33 100644 --- a/TUnit.Assertions.Tests/Assertions/Delegates/Throws.NothingTests.cs +++ b/TUnit.Assertions.Tests/Assertions/Delegates/Throws.NothingTests.cs @@ -12,15 +12,15 @@ public async Task Fails_For_Code_With_Exceptions() but threw TUnit.Assertions.Tests.Assertions.Delegates.Throws+CustomException: {nameof(Fails_For_Code_With_Exceptions)} at Assert.That(action).ThrowsNothing() - """; + """.NormalizeLineEndings(); Exception exception = CreateCustomException(); Action action = () => throw exception; var sut = async () => await Assert.That(action).ThrowsNothing(); - await Assert.That(sut).ThrowsException() - .WithMessage(expectedMessage); + var thrownException = await Assert.That(sut).ThrowsException(); + await Assert.That(thrownException.Message.NormalizeLineEndings()).IsEqualTo(expectedMessage); } [Test] diff --git a/TUnit.Assertions.Tests/Assertions/Delegates/Throws.OfTypeTests.cs b/TUnit.Assertions.Tests/Assertions/Delegates/Throws.OfTypeTests.cs index a900ceaff2..ef4c853221 100644 --- a/TUnit.Assertions.Tests/Assertions/Delegates/Throws.OfTypeTests.cs +++ b/TUnit.Assertions.Tests/Assertions/Delegates/Throws.OfTypeTests.cs @@ -13,15 +13,15 @@ public async Task Fails_For_Code_With_Other_Exceptions() but threw TUnit.Assertions.Tests.Assertions.Delegates.Throws+OtherException at Assert.That(action).Throws() - """; + """.NormalizeLineEndings(); Exception exception = CreateOtherException(); Action action = () => throw exception; var sut = async () => await Assert.That(action).Throws(); - await Assert.That(sut).ThrowsException() - .WithMessage(expectedMessage); + var thrownException = await Assert.That(sut).ThrowsException(); + await Assert.That(thrownException.Message.NormalizeLineEndings()).IsEqualTo(expectedMessage); } [Test] @@ -32,15 +32,15 @@ public async Task Fails_For_Code_With_Supertype_Exceptions() but threw TUnit.Assertions.Tests.Assertions.Delegates.Throws+CustomException at Assert.That(action).Throws() - """; + """.NormalizeLineEndings(); Exception exception = CreateCustomException(); Action action = () => throw exception; var sut = async () => await Assert.That(action).Throws(); - await Assert.That(sut).ThrowsException() - .WithMessage(expectedMessage); + var thrownException = await Assert.That(sut).ThrowsException(); + await Assert.That(thrownException.Message.NormalizeLineEndings()).IsEqualTo(expectedMessage); } [Test] @@ -51,14 +51,14 @@ public async Task Fails_For_Code_Without_Exceptions() but no exception was thrown at Assert.That(action).Throws() - """; + """.NormalizeLineEndings(); var action = () => { }; var sut = async () => await Assert.That(action).Throws(); - await Assert.That(sut).ThrowsException() - .WithMessage(expectedMessage); + var thrownException = await Assert.That(sut).ThrowsException(); + await Assert.That(thrownException.Message.NormalizeLineEndings()).IsEqualTo(expectedMessage); } [Test] diff --git a/TUnit.Assertions.Tests/Assertions/Delegates/Throws.WithInnerExceptionTests.cs b/TUnit.Assertions.Tests/Assertions/Delegates/Throws.WithInnerExceptionTests.cs index 88386844cd..2b167014c5 100644 --- a/TUnit.Assertions.Tests/Assertions/Delegates/Throws.WithInnerExceptionTests.cs +++ b/TUnit.Assertions.Tests/Assertions/Delegates/Throws.WithInnerExceptionTests.cs @@ -15,7 +15,7 @@ Expected exception message to equal "bar" but exception message was "some different inner message" at Assert.That(action).ThrowsException().WithInnerException().WithMessage("bar") - """; + """.NormalizeLineEndings(); Exception exception = CreateCustomException(outerMessage, CreateCustomException("some different inner message")); Action action = () => throw exception; @@ -24,8 +24,8 @@ at Assert.That(action).ThrowsException().WithInnerException().WithMessage("bar") => await Assert.That(action).ThrowsException() .WithInnerException().WithMessage(expectedInnerMessage); - await Assert.That(sut).ThrowsException() - .WithMessage(expectedMessage); + var thrownException = await Assert.That(sut).ThrowsException(); + await Assert.That(thrownException.Message.NormalizeLineEndings()).IsEqualTo(expectedMessage); } [Test] diff --git a/TUnit.Assertions.Tests/Assertions/Delegates/Throws.WithMessageMatchingTests.cs b/TUnit.Assertions.Tests/Assertions/Delegates/Throws.WithMessageMatchingTests.cs index 8401907110..0d8e0b37ff 100644 --- a/TUnit.Assertions.Tests/Assertions/Delegates/Throws.WithMessageMatchingTests.cs +++ b/TUnit.Assertions.Tests/Assertions/Delegates/Throws.WithMessageMatchingTests.cs @@ -42,15 +42,15 @@ Expected exception message to match pattern "bar" but exception message "foo" does not match pattern "bar" at Assert.That(action).ThrowsExactly().WithMessageMatching("bar") - """; + """.NormalizeLineEndings(); Exception exception = CreateCustomException(message1); Action action = () => throw exception; var sut = async () => await Assert.That(action).ThrowsExactly().WithMessageMatching(message2); - await Assert.That(sut).ThrowsException() - .WithMessage(expectedMessage); + var thrownException = await Assert.That(sut).ThrowsException(); + await Assert.That(thrownException.Message.NormalizeLineEndings()).IsEqualTo(expectedMessage); } [Test] diff --git a/TUnit.Assertions.Tests/Assertions/Delegates/Throws.WithMessageTests.cs b/TUnit.Assertions.Tests/Assertions/Delegates/Throws.WithMessageTests.cs index d444dc3372..293d109d1e 100644 --- a/TUnit.Assertions.Tests/Assertions/Delegates/Throws.WithMessageTests.cs +++ b/TUnit.Assertions.Tests/Assertions/Delegates/Throws.WithMessageTests.cs @@ -15,15 +15,15 @@ Expected exception message to equal "bar" but exception message was "foo" at Assert.That(action).ThrowsExactly().WithMessage("bar") - """; + """.NormalizeLineEndings(); Exception exception = CreateCustomException(message1); Action action = () => throw exception; var sut = async () => await Assert.That(action).ThrowsExactly().WithMessage(message2); - await Assert.That(sut).ThrowsException() - .WithMessage(expectedMessage); + var thrownException = await Assert.That(sut).ThrowsException(); + await Assert.That(thrownException.Message.NormalizeLineEndings()).IsEqualTo(expectedMessage); } [Test] diff --git a/TUnit.Assertions.Tests/Assertions/Delegates/Throws.WithParameterNameTests.cs b/TUnit.Assertions.Tests/Assertions/Delegates/Throws.WithParameterNameTests.cs index 41c2ca56c0..70c0c658f9 100644 --- a/TUnit.Assertions.Tests/Assertions/Delegates/Throws.WithParameterNameTests.cs +++ b/TUnit.Assertions.Tests/Assertions/Delegates/Throws.WithParameterNameTests.cs @@ -15,15 +15,15 @@ public async Task Fails_For_Different_Parameter_Name() but ArgumentException parameter name was "foo" at Assert.That(action).ThrowsExactly().WithParameterName("bar") - """; + """.NormalizeLineEndings(); ArgumentException exception = new(string.Empty, paramName1); Action action = () => throw exception; var sut = async () => await Assert.That(action).ThrowsExactly().WithParameterName(paramName2); - await Assert.That(sut).ThrowsException() - .WithMessage(expectedMessage); + var thrownException = await Assert.That(sut).ThrowsException(); + await Assert.That(thrownException.Message.NormalizeLineEndings()).IsEqualTo(expectedMessage); } [Test] diff --git a/TUnit.Assertions.Tests/Bugs/Tests2117.cs b/TUnit.Assertions.Tests/Bugs/Tests2117.cs index df5fcad7cb..6fe04798e3 100644 --- a/TUnit.Assertions.Tests/Bugs/Tests2117.cs +++ b/TUnit.Assertions.Tests/Bugs/Tests2117.cs @@ -28,12 +28,13 @@ at Assert.That(a).IsEquivalentTo(b) """)] public async Task IsEquivalent_Fail(int[] a, int[] b, CollectionOrdering? collectionOrdering, string expectedError) { - await Assert.That(async () => + var exception = await Assert.That(async () => await (collectionOrdering is null ? Assert.That(a).IsEquivalentTo(b) : Assert.That(a).IsEquivalentTo(b, collectionOrdering.Value)) - ).Throws() - .WithMessage(expectedError); + ).Throws(); + + await Assert.That(exception.Message.NormalizeLineEndings()).IsEqualTo(expectedError.NormalizeLineEndings()); } [Test] @@ -60,11 +61,12 @@ at Assert.That(a).IsNotEquivalentTo(b) """)] public async Task IsNotEquivalent_Fail(int[] a, int[] b, CollectionOrdering? collectionOrdering, string expectedError) { - await Assert.That(async () => + var exception = await Assert.That(async () => await (collectionOrdering is null ? Assert.That(a).IsNotEquivalentTo(b) : Assert.That(a).IsNotEquivalentTo(b, collectionOrdering.Value)) - ).Throws() - .WithMessage(expectedError); + ).Throws(); + + await Assert.That(exception.Message.NormalizeLineEndings()).IsEqualTo(expectedError.NormalizeLineEndings()); } } diff --git a/TUnit.Assertions.Tests/Helpers/StringDifferenceTests.cs b/TUnit.Assertions.Tests/Helpers/StringDifferenceTests.cs index 8aee42cc8e..18d1ef01d7 100644 --- a/TUnit.Assertions.Tests/Helpers/StringDifferenceTests.cs +++ b/TUnit.Assertions.Tests/Helpers/StringDifferenceTests.cs @@ -10,15 +10,15 @@ Expected to be equal to "some text" but found "" at Assert.That(actual).IsEqualTo(expected) - """; + """.NormalizeLineEndings(); var actual = ""; var expected = "some text"; var sut = async () => await Assert.That(actual).IsEqualTo(expected); - await Assert.That(sut).ThrowsException() - .WithMessage(expectedMessage); + var exception = await Assert.That(sut).ThrowsException(); + await Assert.That(exception.Message.NormalizeLineEndings()).IsEqualTo(expectedMessage); } [Test] @@ -29,15 +29,15 @@ Expected to be equal to "" but found "actual text" at Assert.That(actual).IsEqualTo(expected) - """; + """.NormalizeLineEndings(); var actual = "actual text"; var expected = ""; var sut = async () => await Assert.That(actual).IsEqualTo(expected); - await Assert.That(sut).ThrowsException() - .WithMessage(expectedMessage); + var exception = await Assert.That(sut).ThrowsException(); + await Assert.That(exception.Message.NormalizeLineEndings()).IsEqualTo(expectedMessage); } [Test] @@ -48,15 +48,15 @@ Expected to be equal to "some text" but found "some" at Assert.That(actual).IsEqualTo(expected) - """; + """.NormalizeLineEndings(); var actual = "some"; var expected = "some text"; var sut = async () => await Assert.That(actual).IsEqualTo(expected); - await Assert.That(sut).ThrowsException() - .WithMessage(expectedMessage); + var exception = await Assert.That(sut).ThrowsException(); + await Assert.That(exception.Message.NormalizeLineEndings()).IsEqualTo(expectedMessage); } [Test] @@ -67,14 +67,14 @@ Expected to be equal to "some" but found "some text" at Assert.That(actual).IsEqualTo(expected) - """; + """.NormalizeLineEndings(); var actual = "some text"; var expected = "some"; var sut = async () => await Assert.That(actual).IsEqualTo(expected); - await Assert.That(sut).ThrowsException() - .WithMessage(expectedMessage); + var exception = await Assert.That(sut).ThrowsException(); + await Assert.That(exception.Message.NormalizeLineEndings()).IsEqualTo(expectedMessage); } } diff --git a/TUnit.Assertions.Tests/Old/AssertMultipleTests.cs b/TUnit.Assertions.Tests/Old/AssertMultipleTests.cs index a12e4a5421..8f7ff94050 100644 --- a/TUnit.Assertions.Tests/Old/AssertMultipleTests.cs +++ b/TUnit.Assertions.Tests/Old/AssertMultipleTests.cs @@ -33,35 +33,35 @@ Expected to be 2 but found 1 at Assert.That(1).IsEqualTo(2) - """); + """.NormalizeLineEndings()); await TUnitAssert.That(exception2.Message.NormalizeLineEndings()).IsEqualTo(""" Expected to be 3 but found 2 at Assert.That(2).IsEqualTo(3) - """); + """.NormalizeLineEndings()); await TUnitAssert.That(exception3.Message.NormalizeLineEndings()).IsEqualTo(""" Expected to be 4 but found 3 at Assert.That(3).IsEqualTo(4) - """); + """.NormalizeLineEndings()); await TUnitAssert.That(exception4.Message.NormalizeLineEndings()).IsEqualTo(""" Expected to be 5 but found 4 at Assert.That(4).IsEqualTo(5) - """); + """.NormalizeLineEndings()); await TUnitAssert.That(exception5.Message.NormalizeLineEndings()).IsEqualTo(""" Expected to be 6 but found 5 at Assert.That(5).IsEqualTo(6) - """); + """.NormalizeLineEndings()); } [Test] @@ -93,7 +93,7 @@ or to be 3 but found 1 at Assert.That(1).IsEqualTo(2).Or.IsEqualTo(3) - """); + """.NormalizeLineEndings()); await TUnitAssert.That(exception2.Message.NormalizeLineEndings()).IsEqualTo(""" Expected to be 3 @@ -101,7 +101,7 @@ and to be 4 but found 2 at Assert.That(2).IsEqualTo(3).And.IsEqualTo(4) - """); + """.NormalizeLineEndings()); await TUnitAssert.That(exception3.Message.NormalizeLineEndings()).IsEqualTo(""" Expected to be 4 @@ -109,7 +109,7 @@ or to be 5 but found 3 at Assert.That(3).IsEqualTo(4).Or.IsEqualTo(5) - """); + """.NormalizeLineEndings()); await TUnitAssert.That(exception4.Message.NormalizeLineEndings()).IsEqualTo(""" Expected to be 5 @@ -117,7 +117,7 @@ and to be 6 but found 4 at Assert.That(4).IsEqualTo(5).And.IsEqualTo(6) - """); + """.NormalizeLineEndings()); await TUnitAssert.That(exception5.Message.NormalizeLineEndings()).IsEqualTo(""" Expected to be 6 @@ -125,7 +125,7 @@ or to be 7 but found 5 at Assert.That(5).IsEqualTo(6).Or.IsEqualTo(7) - """); + """.NormalizeLineEndings()); } [Test] @@ -176,48 +176,48 @@ Expected to be 2 but found 1 at Assert.That(1).IsEqualTo(2) - """); + """.NormalizeLineEndings()); await TUnitAssert.That(assertionException2.Message.NormalizeLineEndings()).IsEqualTo(""" Expected to be 3 but found 2 at Assert.That(2).IsEqualTo(3) - """); + """.NormalizeLineEndings()); await TUnitAssert.That(assertionException3.Message.NormalizeLineEndings()).IsEqualTo(""" Expected to be 4 but found 3 at Assert.That(3).IsEqualTo(4) - """); + """.NormalizeLineEndings()); await TUnitAssert.That(assertionException4.Message.NormalizeLineEndings()).IsEqualTo(""" Expected to be 5 but found 4 at Assert.That(4).IsEqualTo(5) - """); + """.NormalizeLineEndings()); await TUnitAssert.That(assertionException5.Message.NormalizeLineEndings()).IsEqualTo(""" Expected to be 6 but found 5 at Assert.That(5).IsEqualTo(6) - """); + """.NormalizeLineEndings()); await TUnitAssert.That(assertionException6.Message.NormalizeLineEndings()).IsEqualTo(""" Expected to be 7 but found 6 at Assert.That(6).IsEqualTo(7) - """); + """.NormalizeLineEndings()); await TUnitAssert.That(assertionException7.Message.NormalizeLineEndings()).IsEqualTo(""" Expected to be 8 but found 7 at Assert.That(7).IsEqualTo(8) - """); + """.NormalizeLineEndings()); } } diff --git a/TUnit.Assertions.Tests/Old/EquivalentAssertionTests.cs b/TUnit.Assertions.Tests/Old/EquivalentAssertionTests.cs index 841660fe8d..a420ea73d1 100644 --- a/TUnit.Assertions.Tests/Old/EquivalentAssertionTests.cs +++ b/TUnit.Assertions.Tests/Old/EquivalentAssertionTests.cs @@ -136,7 +136,7 @@ await TUnitAssert.That(exception!.Message.NormalizeLineEndings()).IsEqualTo( but collection item at index 1 does not match: expected 2, but was 5 at Assert.That(array).IsEquivalentTo(list, CollectionOrdering.Matching) - """ + """.NormalizeLineEndings() ); } @@ -155,7 +155,7 @@ await TUnitAssert.That(exception!.Message.NormalizeLineEndings()).IsEqualTo( but collection item at index 1 does not match: expected 2, but was 5 at Assert.That(array).IsEquivalentTo(list, CollectionOrdering.Matching) - """ + """.NormalizeLineEndings() ); } diff --git a/TUnit.Assertions.Tests/Old/StringRegexAssertionTests.cs b/TUnit.Assertions.Tests/Old/StringRegexAssertionTests.cs index 3c1bffb616..98b2bcae9f 100644 --- a/TUnit.Assertions.Tests/Old/StringRegexAssertionTests.cs +++ b/TUnit.Assertions.Tests/Old/StringRegexAssertionTests.cs @@ -56,13 +56,13 @@ public async Task Matches_WithInvalidPattern_StringPattern_Throws(Type exception return; } - await TUnitAssert.That(exception!.Message).IsEqualTo( + await TUnitAssert.That(exception!.Message.NormalizeLineEndings()).IsEqualTo( $""" Expected text match pattern but The regex "^\d+$" does not match with "{text}" at Assert.That(text).Matches(pattern) - """ + """.NormalizeLineEndings() ); } @@ -81,13 +81,13 @@ public async Task Matches_WithInvalidPattern_RegexPattern_Throws(Type exceptionT return; } - await TUnitAssert.That(exception!.Message).IsEqualTo( + await TUnitAssert.That(exception!.Message.NormalizeLineEndings()).IsEqualTo( $""" Expected text match pattern but The regex "^\d+$" does not match with "{text}" at Assert.That(text).Matches(pattern) - """ + """.NormalizeLineEndings() ); } @@ -110,13 +110,13 @@ public async Task Matches_WithInvalidPattern_GeneratedRegexPattern_Throws(Type e return; } - await TUnitAssert.That(exception!.Message).IsEqualTo( + await TUnitAssert.That(exception!.Message.NormalizeLineEndings()).IsEqualTo( $""" Expected text match regex but The regex "^\d+$" does not match with "Hello123World" at Assert.That(text).Matches(regex) - """ + """.NormalizeLineEndings() ); } #endif @@ -192,13 +192,13 @@ public async Task DoesNotMatch_WithInvalidPattern_StringPattern_Throws(Type exce return; } - await TUnitAssert.That(exception!.Message).IsEqualTo( + await TUnitAssert.That(exception!.Message.NormalizeLineEndings()).IsEqualTo( $""" Expected text to not match with pattern but The regex "^\d+$" matches with "{text}" at Assert.That(text).DoesNotMatch(pattern) - """ + """.NormalizeLineEndings() ); } @@ -217,13 +217,13 @@ public async Task DoesNotMatch_WithInvalidPattern_RegexPattern_Throws(Type excep return; } - await TUnitAssert.That(exception!.Message).IsEqualTo( + await TUnitAssert.That(exception!.Message.NormalizeLineEndings()).IsEqualTo( $""" Expected text to not match with pattern but The regex "^\d+$" matches with "{text}" at Assert.That(text).DoesNotMatch(pattern) - """ + """.NormalizeLineEndings() ); } @@ -246,13 +246,13 @@ public async Task DoesNotMatch_WithInvalidPattern_GeneratedRegexPattern_Throws(T return; } - await TUnitAssert.That(exception!.Message).IsEqualTo( + await TUnitAssert.That(exception!.Message.NormalizeLineEndings()).IsEqualTo( $""" Expected text to not match with regex but The regex "^\d+$" matches with "{text}" at Assert.That(text).DoesNotMatch(regex) - """ + """.NormalizeLineEndings() ); } #endif diff --git a/TUnit.Assertions.Tests/ThrowInDelegateValueAssertionTests.cs b/TUnit.Assertions.Tests/ThrowInDelegateValueAssertionTests.cs index 86cec759e2..c776362562 100644 --- a/TUnit.Assertions.Tests/ThrowInDelegateValueAssertionTests.cs +++ b/TUnit.Assertions.Tests/ThrowInDelegateValueAssertionTests.cs @@ -5,18 +5,19 @@ public class ThrowInDelegateValueAssertionTests [Test] public async Task ThrowInDelegateValueAssertion_ReturnsExpectedErrorMessage() { + var expectedContains = """ + Expected to be equal to True + but threw System.Exception + """.NormalizeLineEndings(); var assertion = async () => await Assert.That(() => { throw new Exception("No"); return true; }).IsEqualTo(true); - await Assert.That(assertion) - .Throws() - .WithMessageContaining(""" - Expected to be equal to True - but threw System.Exception - """); + var exception = await Assert.That(assertion) + .Throws(); + await Assert.That(exception.Message.NormalizeLineEndings()).Contains(expectedContains); } [Test] From 2799dd71fe0aa845a4cf5a570d5ce1c128089968 Mon Sep 17 00:00:00 2001 From: Tom Longhurst <30480171+thomhurst@users.noreply.github.com> Date: Sun, 7 Dec 2025 21:24:46 +0000 Subject: [PATCH 19/20] fix: remove premature cache removal for shared data sources MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The ObjectTracker.OnDisposed callback was removing shared objects from the ScopedDictionary cache when disposed, causing SharedType.PerTestSession and other shared containers to be recreated for each test instead of being reused across the session. This fixes Bug 3803 where PerTestSession containers were instantiated multiple times (once per test) instead of once for the entire session. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- TUnit.Core/Data/ScopedDictionary.cs | 14 ++------------ 1 file changed, 2 insertions(+), 12 deletions(-) diff --git a/TUnit.Core/Data/ScopedDictionary.cs b/TUnit.Core/Data/ScopedDictionary.cs index 7a751429ff..e5766792b0 100644 --- a/TUnit.Core/Data/ScopedDictionary.cs +++ b/TUnit.Core/Data/ScopedDictionary.cs @@ -1,6 +1,4 @@ -using TUnit.Core.Tracking; - -namespace TUnit.Core.Data; +namespace TUnit.Core.Data; public class ScopedDictionary where TScope : notnull @@ -11,14 +9,6 @@ public class ScopedDictionary { var innerDictionary = _scopedContainers.GetOrAdd(scope, static _ => new ThreadSafeDictionary()); - var obj = innerDictionary.GetOrAdd(type, factory); - - ObjectTracker.OnDisposed(obj, () => - { - innerDictionary.Remove(type); - }); - - return obj; + return innerDictionary.GetOrAdd(type, factory); } - } From ff2e57edda3085cf2bf49e1eab33c487a141ff36 Mon Sep 17 00:00:00 2001 From: Tom Longhurst <30480171+thomhurst@users.noreply.github.com> Date: Sun, 7 Dec 2025 22:30:53 +0000 Subject: [PATCH 20/20] feat: enable parallel initialization of tracked objects during test execution --- .../Services/ObjectLifecycleService.cs | 25 +++++++++++++++---- 1 file changed, 20 insertions(+), 5 deletions(-) diff --git a/TUnit.Engine/Services/ObjectLifecycleService.cs b/TUnit.Engine/Services/ObjectLifecycleService.cs index 2d7d4a4ac3..8ab1773f6a 100644 --- a/TUnit.Engine/Services/ObjectLifecycleService.cs +++ b/TUnit.Engine/Services/ObjectLifecycleService.cs @@ -195,6 +195,7 @@ private void SetCachedPropertiesOnInstance(object instance, TestContext testCont /// /// Initializes all tracked objects depth-first (deepest objects first). /// This is called during test execution (after BeforeClass hooks) to initialize IAsyncInitializer objects. + /// Objects at the same level are initialized in parallel. /// private async Task InitializeTrackedObjectsAsync(TestContext testContext, CancellationToken cancellationToken) { @@ -223,14 +224,16 @@ private async Task InitializeTrackedObjectsAsync(TestContext testContext, Cancel objectsAtLevel.CopyTo(objectsCopy); } - // Initialize each tracked object and its nested objects + // Initialize all objects at this level in parallel + var tasks = new List(objectsCopy.Length); foreach (var obj in objectsCopy) { - // First initialize nested objects depth-first - await InitializeNestedObjectsForExecutionAsync(obj, cancellationToken); + tasks.Add(InitializeObjectWithNestedAsync(obj, cancellationToken)); + } - // Then initialize the object itself - await ObjectInitializer.InitializeAsync(obj, cancellationToken); + if (tasks.Count > 0) + { + await Task.WhenAll(tasks); } } } @@ -241,6 +244,18 @@ private async Task InitializeTrackedObjectsAsync(TestContext testContext, Cancel await ObjectInitializer.InitializeAsync(classInstance, cancellationToken); } + /// + /// Initializes an object and its nested objects. + /// + private async Task InitializeObjectWithNestedAsync(object obj, CancellationToken cancellationToken) + { + // First initialize nested objects depth-first + await InitializeNestedObjectsForExecutionAsync(obj, cancellationToken); + + // Then initialize the object itself + await ObjectInitializer.InitializeAsync(obj, cancellationToken); + } + /// /// Initializes nested objects during execution phase - all IAsyncInitializer objects. ///