diff --git a/TUnit.Assertions.Tests/AssertConditions/BecauseTests.cs b/TUnit.Assertions.Tests/AssertConditions/BecauseTests.cs index 1d44b17e48..0c4c88137f 100644 --- a/TUnit.Assertions.Tests/AssertConditions/BecauseTests.cs +++ b/TUnit.Assertions.Tests/AssertConditions/BecauseTests.cs @@ -1,4 +1,4 @@ -namespace TUnit.Assertions.Tests.AssertConditions; +namespace TUnit.Assertions.Tests.AssertConditions; public class BecauseTests { @@ -68,7 +68,7 @@ at Assert.That(variable).IsFalse() }; var exception = await Assert.ThrowsAsync(action); - await Assert.That(exception.Message).IsEqualTo(expectedMessage); + await Assert.That(exception.Message.NormalizeLineEndings()).IsEqualTo(expectedMessage.NormalizeLineEndings()); } [Test] @@ -91,7 +91,7 @@ await Assert.That(variable).IsTrue().Because(because) }; var exception = await Assert.ThrowsAsync(action); - await Assert.That(exception.Message.NormalizeLineEndings()).IsEqualTo(expectedMessage); + await Assert.That(exception.Message.NormalizeLineEndings()).IsEqualTo(expectedMessage.NormalizeLineEndings()); } [Test] diff --git a/TUnit.Assertions.Tests/Assertions/Delegates/Throws.ExactlyTests.cs b/TUnit.Assertions.Tests/Assertions/Delegates/Throws.ExactlyTests.cs index c3292ec78e..35d744499e 100644 --- a/TUnit.Assertions.Tests/Assertions/Delegates/Throws.ExactlyTests.cs +++ b/TUnit.Assertions.Tests/Assertions/Delegates/Throws.ExactlyTests.cs @@ -13,15 +13,15 @@ public async Task Fails_For_Code_With_Other_Exceptions() but threw TUnit.Assertions.Tests.Assertions.Delegates.Throws+OtherException at Assert.That(action).ThrowsExactly() - """; + """.NormalizeLineEndings(); Exception exception = CreateOtherException(); Action action = () => throw exception; var sut = async () => await Assert.That(action).ThrowsExactly(); - await Assert.That(sut).ThrowsException() - .WithMessage(expectedMessage); + var thrownException = await Assert.That(sut).ThrowsException(); + await Assert.That(thrownException.Message.NormalizeLineEndings()).IsEqualTo(expectedMessage); } [Test] @@ -32,15 +32,15 @@ public async Task Fails_For_Code_With_Subtype_Exceptions() but wrong exception type: SubCustomException instead of exactly CustomException at Assert.That(action).ThrowsExactly() - """; + """.NormalizeLineEndings(); Exception exception = CreateSubCustomException(); Action action = () => throw exception; var sut = async () => await Assert.That(action).ThrowsExactly(); - await Assert.That(sut).ThrowsException() - .WithMessage(expectedMessage); + var thrownException = await Assert.That(sut).ThrowsException(); + await Assert.That(thrownException.Message.NormalizeLineEndings()).IsEqualTo(expectedMessage); } [Test] @@ -51,14 +51,14 @@ public async Task Fails_For_Code_Without_Exceptions() but no exception was thrown at Assert.That(action).ThrowsExactly() - """; + """.NormalizeLineEndings(); var action = () => { }; var sut = async () => await Assert.That(action).ThrowsExactly(); - await Assert.That(sut).ThrowsException() - .WithMessage(expectedMessage); + var thrownException = await Assert.That(sut).ThrowsException(); + await Assert.That(thrownException.Message.NormalizeLineEndings()).IsEqualTo(expectedMessage); } [Test] @@ -117,10 +117,11 @@ public async Task Conversion_To_Value_Assertion_Builder_On_Casted_Exception_Type await Assert.That((object)ex).IsAssignableTo(); }); - await Assert.That(assertionException).HasMessageStartingWith(""" - Expected to throw exactly Exception - but wrong exception type: CustomException instead of exactly Exception - """); + var expectedPrefix = """ + Expected to throw exactly Exception + but wrong exception type: CustomException instead of exactly Exception + """.NormalizeLineEndings(); + await Assert.That(assertionException.Message.NormalizeLineEndings()).StartsWith(expectedPrefix); } [Test] diff --git a/TUnit.Assertions.Tests/Assertions/Delegates/Throws.ExceptionTests.cs b/TUnit.Assertions.Tests/Assertions/Delegates/Throws.ExceptionTests.cs index cedfa0b176..3b51f1e29e 100644 --- a/TUnit.Assertions.Tests/Assertions/Delegates/Throws.ExceptionTests.cs +++ b/TUnit.Assertions.Tests/Assertions/Delegates/Throws.ExceptionTests.cs @@ -14,14 +14,14 @@ public async Task Fails_For_Code_Without_Exceptions() but no exception was thrown at Assert.That(action).ThrowsException() - """; + """.NormalizeLineEndings(); var action = () => { }; var sut = async () => await Assert.That(action).ThrowsException(); - await Assert.That(sut).ThrowsException() - .WithMessage(expectedMessage); + var thrownException = await Assert.That(sut).ThrowsException(); + await Assert.That(thrownException.Message.NormalizeLineEndings()).IsEqualTo(expectedMessage); } [Test] diff --git a/TUnit.Assertions.Tests/Assertions/Delegates/Throws.NothingTests.cs b/TUnit.Assertions.Tests/Assertions/Delegates/Throws.NothingTests.cs index 48077a848d..29fc9e7d33 100644 --- a/TUnit.Assertions.Tests/Assertions/Delegates/Throws.NothingTests.cs +++ b/TUnit.Assertions.Tests/Assertions/Delegates/Throws.NothingTests.cs @@ -12,15 +12,15 @@ public async Task Fails_For_Code_With_Exceptions() but threw TUnit.Assertions.Tests.Assertions.Delegates.Throws+CustomException: {nameof(Fails_For_Code_With_Exceptions)} at Assert.That(action).ThrowsNothing() - """; + """.NormalizeLineEndings(); Exception exception = CreateCustomException(); Action action = () => throw exception; var sut = async () => await Assert.That(action).ThrowsNothing(); - await Assert.That(sut).ThrowsException() - .WithMessage(expectedMessage); + var thrownException = await Assert.That(sut).ThrowsException(); + await Assert.That(thrownException.Message.NormalizeLineEndings()).IsEqualTo(expectedMessage); } [Test] diff --git a/TUnit.Assertions.Tests/Assertions/Delegates/Throws.OfTypeTests.cs b/TUnit.Assertions.Tests/Assertions/Delegates/Throws.OfTypeTests.cs index a900ceaff2..ef4c853221 100644 --- a/TUnit.Assertions.Tests/Assertions/Delegates/Throws.OfTypeTests.cs +++ b/TUnit.Assertions.Tests/Assertions/Delegates/Throws.OfTypeTests.cs @@ -13,15 +13,15 @@ public async Task Fails_For_Code_With_Other_Exceptions() but threw TUnit.Assertions.Tests.Assertions.Delegates.Throws+OtherException at Assert.That(action).Throws() - """; + """.NormalizeLineEndings(); Exception exception = CreateOtherException(); Action action = () => throw exception; var sut = async () => await Assert.That(action).Throws(); - await Assert.That(sut).ThrowsException() - .WithMessage(expectedMessage); + var thrownException = await Assert.That(sut).ThrowsException(); + await Assert.That(thrownException.Message.NormalizeLineEndings()).IsEqualTo(expectedMessage); } [Test] @@ -32,15 +32,15 @@ public async Task Fails_For_Code_With_Supertype_Exceptions() but threw TUnit.Assertions.Tests.Assertions.Delegates.Throws+CustomException at Assert.That(action).Throws() - """; + """.NormalizeLineEndings(); Exception exception = CreateCustomException(); Action action = () => throw exception; var sut = async () => await Assert.That(action).Throws(); - await Assert.That(sut).ThrowsException() - .WithMessage(expectedMessage); + var thrownException = await Assert.That(sut).ThrowsException(); + await Assert.That(thrownException.Message.NormalizeLineEndings()).IsEqualTo(expectedMessage); } [Test] @@ -51,14 +51,14 @@ public async Task Fails_For_Code_Without_Exceptions() but no exception was thrown at Assert.That(action).Throws() - """; + """.NormalizeLineEndings(); var action = () => { }; var sut = async () => await Assert.That(action).Throws(); - await Assert.That(sut).ThrowsException() - .WithMessage(expectedMessage); + var thrownException = await Assert.That(sut).ThrowsException(); + await Assert.That(thrownException.Message.NormalizeLineEndings()).IsEqualTo(expectedMessage); } [Test] diff --git a/TUnit.Assertions.Tests/Assertions/Delegates/Throws.WithInnerExceptionTests.cs b/TUnit.Assertions.Tests/Assertions/Delegates/Throws.WithInnerExceptionTests.cs index 88386844cd..2b167014c5 100644 --- a/TUnit.Assertions.Tests/Assertions/Delegates/Throws.WithInnerExceptionTests.cs +++ b/TUnit.Assertions.Tests/Assertions/Delegates/Throws.WithInnerExceptionTests.cs @@ -15,7 +15,7 @@ Expected exception message to equal "bar" but exception message was "some different inner message" at Assert.That(action).ThrowsException().WithInnerException().WithMessage("bar") - """; + """.NormalizeLineEndings(); Exception exception = CreateCustomException(outerMessage, CreateCustomException("some different inner message")); Action action = () => throw exception; @@ -24,8 +24,8 @@ at Assert.That(action).ThrowsException().WithInnerException().WithMessage("bar") => await Assert.That(action).ThrowsException() .WithInnerException().WithMessage(expectedInnerMessage); - await Assert.That(sut).ThrowsException() - .WithMessage(expectedMessage); + var thrownException = await Assert.That(sut).ThrowsException(); + await Assert.That(thrownException.Message.NormalizeLineEndings()).IsEqualTo(expectedMessage); } [Test] diff --git a/TUnit.Assertions.Tests/Assertions/Delegates/Throws.WithMessageMatchingTests.cs b/TUnit.Assertions.Tests/Assertions/Delegates/Throws.WithMessageMatchingTests.cs index 8401907110..0d8e0b37ff 100644 --- a/TUnit.Assertions.Tests/Assertions/Delegates/Throws.WithMessageMatchingTests.cs +++ b/TUnit.Assertions.Tests/Assertions/Delegates/Throws.WithMessageMatchingTests.cs @@ -42,15 +42,15 @@ Expected exception message to match pattern "bar" but exception message "foo" does not match pattern "bar" at Assert.That(action).ThrowsExactly().WithMessageMatching("bar") - """; + """.NormalizeLineEndings(); Exception exception = CreateCustomException(message1); Action action = () => throw exception; var sut = async () => await Assert.That(action).ThrowsExactly().WithMessageMatching(message2); - await Assert.That(sut).ThrowsException() - .WithMessage(expectedMessage); + var thrownException = await Assert.That(sut).ThrowsException(); + await Assert.That(thrownException.Message.NormalizeLineEndings()).IsEqualTo(expectedMessage); } [Test] diff --git a/TUnit.Assertions.Tests/Assertions/Delegates/Throws.WithMessageTests.cs b/TUnit.Assertions.Tests/Assertions/Delegates/Throws.WithMessageTests.cs index d444dc3372..293d109d1e 100644 --- a/TUnit.Assertions.Tests/Assertions/Delegates/Throws.WithMessageTests.cs +++ b/TUnit.Assertions.Tests/Assertions/Delegates/Throws.WithMessageTests.cs @@ -15,15 +15,15 @@ Expected exception message to equal "bar" but exception message was "foo" at Assert.That(action).ThrowsExactly().WithMessage("bar") - """; + """.NormalizeLineEndings(); Exception exception = CreateCustomException(message1); Action action = () => throw exception; var sut = async () => await Assert.That(action).ThrowsExactly().WithMessage(message2); - await Assert.That(sut).ThrowsException() - .WithMessage(expectedMessage); + var thrownException = await Assert.That(sut).ThrowsException(); + await Assert.That(thrownException.Message.NormalizeLineEndings()).IsEqualTo(expectedMessage); } [Test] diff --git a/TUnit.Assertions.Tests/Assertions/Delegates/Throws.WithParameterNameTests.cs b/TUnit.Assertions.Tests/Assertions/Delegates/Throws.WithParameterNameTests.cs index 41c2ca56c0..70c0c658f9 100644 --- a/TUnit.Assertions.Tests/Assertions/Delegates/Throws.WithParameterNameTests.cs +++ b/TUnit.Assertions.Tests/Assertions/Delegates/Throws.WithParameterNameTests.cs @@ -15,15 +15,15 @@ public async Task Fails_For_Different_Parameter_Name() but ArgumentException parameter name was "foo" at Assert.That(action).ThrowsExactly().WithParameterName("bar") - """; + """.NormalizeLineEndings(); ArgumentException exception = new(string.Empty, paramName1); Action action = () => throw exception; var sut = async () => await Assert.That(action).ThrowsExactly().WithParameterName(paramName2); - await Assert.That(sut).ThrowsException() - .WithMessage(expectedMessage); + var thrownException = await Assert.That(sut).ThrowsException(); + await Assert.That(thrownException.Message.NormalizeLineEndings()).IsEqualTo(expectedMessage); } [Test] diff --git a/TUnit.Assertions.Tests/Bugs/Tests2117.cs b/TUnit.Assertions.Tests/Bugs/Tests2117.cs index df5fcad7cb..6fe04798e3 100644 --- a/TUnit.Assertions.Tests/Bugs/Tests2117.cs +++ b/TUnit.Assertions.Tests/Bugs/Tests2117.cs @@ -28,12 +28,13 @@ at Assert.That(a).IsEquivalentTo(b) """)] public async Task IsEquivalent_Fail(int[] a, int[] b, CollectionOrdering? collectionOrdering, string expectedError) { - await Assert.That(async () => + var exception = await Assert.That(async () => await (collectionOrdering is null ? Assert.That(a).IsEquivalentTo(b) : Assert.That(a).IsEquivalentTo(b, collectionOrdering.Value)) - ).Throws() - .WithMessage(expectedError); + ).Throws(); + + await Assert.That(exception.Message.NormalizeLineEndings()).IsEqualTo(expectedError.NormalizeLineEndings()); } [Test] @@ -60,11 +61,12 @@ at Assert.That(a).IsNotEquivalentTo(b) """)] public async Task IsNotEquivalent_Fail(int[] a, int[] b, CollectionOrdering? collectionOrdering, string expectedError) { - await Assert.That(async () => + var exception = await Assert.That(async () => await (collectionOrdering is null ? Assert.That(a).IsNotEquivalentTo(b) : Assert.That(a).IsNotEquivalentTo(b, collectionOrdering.Value)) - ).Throws() - .WithMessage(expectedError); + ).Throws(); + + await Assert.That(exception.Message.NormalizeLineEndings()).IsEqualTo(expectedError.NormalizeLineEndings()); } } diff --git a/TUnit.Assertions.Tests/Helpers/StringDifferenceTests.cs b/TUnit.Assertions.Tests/Helpers/StringDifferenceTests.cs index 8aee42cc8e..18d1ef01d7 100644 --- a/TUnit.Assertions.Tests/Helpers/StringDifferenceTests.cs +++ b/TUnit.Assertions.Tests/Helpers/StringDifferenceTests.cs @@ -10,15 +10,15 @@ Expected to be equal to "some text" but found "" at Assert.That(actual).IsEqualTo(expected) - """; + """.NormalizeLineEndings(); var actual = ""; var expected = "some text"; var sut = async () => await Assert.That(actual).IsEqualTo(expected); - await Assert.That(sut).ThrowsException() - .WithMessage(expectedMessage); + var exception = await Assert.That(sut).ThrowsException(); + await Assert.That(exception.Message.NormalizeLineEndings()).IsEqualTo(expectedMessage); } [Test] @@ -29,15 +29,15 @@ Expected to be equal to "" but found "actual text" at Assert.That(actual).IsEqualTo(expected) - """; + """.NormalizeLineEndings(); var actual = "actual text"; var expected = ""; var sut = async () => await Assert.That(actual).IsEqualTo(expected); - await Assert.That(sut).ThrowsException() - .WithMessage(expectedMessage); + var exception = await Assert.That(sut).ThrowsException(); + await Assert.That(exception.Message.NormalizeLineEndings()).IsEqualTo(expectedMessage); } [Test] @@ -48,15 +48,15 @@ Expected to be equal to "some text" but found "some" at Assert.That(actual).IsEqualTo(expected) - """; + """.NormalizeLineEndings(); var actual = "some"; var expected = "some text"; var sut = async () => await Assert.That(actual).IsEqualTo(expected); - await Assert.That(sut).ThrowsException() - .WithMessage(expectedMessage); + var exception = await Assert.That(sut).ThrowsException(); + await Assert.That(exception.Message.NormalizeLineEndings()).IsEqualTo(expectedMessage); } [Test] @@ -67,14 +67,14 @@ Expected to be equal to "some" but found "some text" at Assert.That(actual).IsEqualTo(expected) - """; + """.NormalizeLineEndings(); var actual = "some text"; var expected = "some"; var sut = async () => await Assert.That(actual).IsEqualTo(expected); - await Assert.That(sut).ThrowsException() - .WithMessage(expectedMessage); + var exception = await Assert.That(sut).ThrowsException(); + await Assert.That(exception.Message.NormalizeLineEndings()).IsEqualTo(expectedMessage); } } diff --git a/TUnit.Assertions.Tests/Old/AssertMultipleTests.cs b/TUnit.Assertions.Tests/Old/AssertMultipleTests.cs index a12e4a5421..8f7ff94050 100644 --- a/TUnit.Assertions.Tests/Old/AssertMultipleTests.cs +++ b/TUnit.Assertions.Tests/Old/AssertMultipleTests.cs @@ -33,35 +33,35 @@ Expected to be 2 but found 1 at Assert.That(1).IsEqualTo(2) - """); + """.NormalizeLineEndings()); await TUnitAssert.That(exception2.Message.NormalizeLineEndings()).IsEqualTo(""" Expected to be 3 but found 2 at Assert.That(2).IsEqualTo(3) - """); + """.NormalizeLineEndings()); await TUnitAssert.That(exception3.Message.NormalizeLineEndings()).IsEqualTo(""" Expected to be 4 but found 3 at Assert.That(3).IsEqualTo(4) - """); + """.NormalizeLineEndings()); await TUnitAssert.That(exception4.Message.NormalizeLineEndings()).IsEqualTo(""" Expected to be 5 but found 4 at Assert.That(4).IsEqualTo(5) - """); + """.NormalizeLineEndings()); await TUnitAssert.That(exception5.Message.NormalizeLineEndings()).IsEqualTo(""" Expected to be 6 but found 5 at Assert.That(5).IsEqualTo(6) - """); + """.NormalizeLineEndings()); } [Test] @@ -93,7 +93,7 @@ or to be 3 but found 1 at Assert.That(1).IsEqualTo(2).Or.IsEqualTo(3) - """); + """.NormalizeLineEndings()); await TUnitAssert.That(exception2.Message.NormalizeLineEndings()).IsEqualTo(""" Expected to be 3 @@ -101,7 +101,7 @@ and to be 4 but found 2 at Assert.That(2).IsEqualTo(3).And.IsEqualTo(4) - """); + """.NormalizeLineEndings()); await TUnitAssert.That(exception3.Message.NormalizeLineEndings()).IsEqualTo(""" Expected to be 4 @@ -109,7 +109,7 @@ or to be 5 but found 3 at Assert.That(3).IsEqualTo(4).Or.IsEqualTo(5) - """); + """.NormalizeLineEndings()); await TUnitAssert.That(exception4.Message.NormalizeLineEndings()).IsEqualTo(""" Expected to be 5 @@ -117,7 +117,7 @@ and to be 6 but found 4 at Assert.That(4).IsEqualTo(5).And.IsEqualTo(6) - """); + """.NormalizeLineEndings()); await TUnitAssert.That(exception5.Message.NormalizeLineEndings()).IsEqualTo(""" Expected to be 6 @@ -125,7 +125,7 @@ or to be 7 but found 5 at Assert.That(5).IsEqualTo(6).Or.IsEqualTo(7) - """); + """.NormalizeLineEndings()); } [Test] @@ -176,48 +176,48 @@ Expected to be 2 but found 1 at Assert.That(1).IsEqualTo(2) - """); + """.NormalizeLineEndings()); await TUnitAssert.That(assertionException2.Message.NormalizeLineEndings()).IsEqualTo(""" Expected to be 3 but found 2 at Assert.That(2).IsEqualTo(3) - """); + """.NormalizeLineEndings()); await TUnitAssert.That(assertionException3.Message.NormalizeLineEndings()).IsEqualTo(""" Expected to be 4 but found 3 at Assert.That(3).IsEqualTo(4) - """); + """.NormalizeLineEndings()); await TUnitAssert.That(assertionException4.Message.NormalizeLineEndings()).IsEqualTo(""" Expected to be 5 but found 4 at Assert.That(4).IsEqualTo(5) - """); + """.NormalizeLineEndings()); await TUnitAssert.That(assertionException5.Message.NormalizeLineEndings()).IsEqualTo(""" Expected to be 6 but found 5 at Assert.That(5).IsEqualTo(6) - """); + """.NormalizeLineEndings()); await TUnitAssert.That(assertionException6.Message.NormalizeLineEndings()).IsEqualTo(""" Expected to be 7 but found 6 at Assert.That(6).IsEqualTo(7) - """); + """.NormalizeLineEndings()); await TUnitAssert.That(assertionException7.Message.NormalizeLineEndings()).IsEqualTo(""" Expected to be 8 but found 7 at Assert.That(7).IsEqualTo(8) - """); + """.NormalizeLineEndings()); } } diff --git a/TUnit.Assertions.Tests/Old/EquivalentAssertionTests.cs b/TUnit.Assertions.Tests/Old/EquivalentAssertionTests.cs index 841660fe8d..a420ea73d1 100644 --- a/TUnit.Assertions.Tests/Old/EquivalentAssertionTests.cs +++ b/TUnit.Assertions.Tests/Old/EquivalentAssertionTests.cs @@ -136,7 +136,7 @@ await TUnitAssert.That(exception!.Message.NormalizeLineEndings()).IsEqualTo( but collection item at index 1 does not match: expected 2, but was 5 at Assert.That(array).IsEquivalentTo(list, CollectionOrdering.Matching) - """ + """.NormalizeLineEndings() ); } @@ -155,7 +155,7 @@ await TUnitAssert.That(exception!.Message.NormalizeLineEndings()).IsEqualTo( but collection item at index 1 does not match: expected 2, but was 5 at Assert.That(array).IsEquivalentTo(list, CollectionOrdering.Matching) - """ + """.NormalizeLineEndings() ); } diff --git a/TUnit.Assertions.Tests/Old/StringRegexAssertionTests.cs b/TUnit.Assertions.Tests/Old/StringRegexAssertionTests.cs index 3c1bffb616..98b2bcae9f 100644 --- a/TUnit.Assertions.Tests/Old/StringRegexAssertionTests.cs +++ b/TUnit.Assertions.Tests/Old/StringRegexAssertionTests.cs @@ -56,13 +56,13 @@ public async Task Matches_WithInvalidPattern_StringPattern_Throws(Type exception return; } - await TUnitAssert.That(exception!.Message).IsEqualTo( + await TUnitAssert.That(exception!.Message.NormalizeLineEndings()).IsEqualTo( $""" Expected text match pattern but The regex "^\d+$" does not match with "{text}" at Assert.That(text).Matches(pattern) - """ + """.NormalizeLineEndings() ); } @@ -81,13 +81,13 @@ public async Task Matches_WithInvalidPattern_RegexPattern_Throws(Type exceptionT return; } - await TUnitAssert.That(exception!.Message).IsEqualTo( + await TUnitAssert.That(exception!.Message.NormalizeLineEndings()).IsEqualTo( $""" Expected text match pattern but The regex "^\d+$" does not match with "{text}" at Assert.That(text).Matches(pattern) - """ + """.NormalizeLineEndings() ); } @@ -110,13 +110,13 @@ public async Task Matches_WithInvalidPattern_GeneratedRegexPattern_Throws(Type e return; } - await TUnitAssert.That(exception!.Message).IsEqualTo( + await TUnitAssert.That(exception!.Message.NormalizeLineEndings()).IsEqualTo( $""" Expected text match regex but The regex "^\d+$" does not match with "Hello123World" at Assert.That(text).Matches(regex) - """ + """.NormalizeLineEndings() ); } #endif @@ -192,13 +192,13 @@ public async Task DoesNotMatch_WithInvalidPattern_StringPattern_Throws(Type exce return; } - await TUnitAssert.That(exception!.Message).IsEqualTo( + await TUnitAssert.That(exception!.Message.NormalizeLineEndings()).IsEqualTo( $""" Expected text to not match with pattern but The regex "^\d+$" matches with "{text}" at Assert.That(text).DoesNotMatch(pattern) - """ + """.NormalizeLineEndings() ); } @@ -217,13 +217,13 @@ public async Task DoesNotMatch_WithInvalidPattern_RegexPattern_Throws(Type excep return; } - await TUnitAssert.That(exception!.Message).IsEqualTo( + await TUnitAssert.That(exception!.Message.NormalizeLineEndings()).IsEqualTo( $""" Expected text to not match with pattern but The regex "^\d+$" matches with "{text}" at Assert.That(text).DoesNotMatch(pattern) - """ + """.NormalizeLineEndings() ); } @@ -246,13 +246,13 @@ public async Task DoesNotMatch_WithInvalidPattern_GeneratedRegexPattern_Throws(T return; } - await TUnitAssert.That(exception!.Message).IsEqualTo( + await TUnitAssert.That(exception!.Message.NormalizeLineEndings()).IsEqualTo( $""" Expected text to not match with regex but The regex "^\d+$" matches with "{text}" at Assert.That(text).DoesNotMatch(regex) - """ + """.NormalizeLineEndings() ); } #endif diff --git a/TUnit.Assertions.Tests/ThrowInDelegateValueAssertionTests.cs b/TUnit.Assertions.Tests/ThrowInDelegateValueAssertionTests.cs index 86cec759e2..c776362562 100644 --- a/TUnit.Assertions.Tests/ThrowInDelegateValueAssertionTests.cs +++ b/TUnit.Assertions.Tests/ThrowInDelegateValueAssertionTests.cs @@ -5,18 +5,19 @@ public class ThrowInDelegateValueAssertionTests [Test] public async Task ThrowInDelegateValueAssertion_ReturnsExpectedErrorMessage() { + var expectedContains = """ + Expected to be equal to True + but threw System.Exception + """.NormalizeLineEndings(); var assertion = async () => await Assert.That(() => { throw new Exception("No"); return true; }).IsEqualTo(true); - await Assert.That(assertion) - .Throws() - .WithMessageContaining(""" - Expected to be equal to True - but threw System.Exception - """); + var exception = await Assert.That(assertion) + .Throws(); + await Assert.That(exception.Message.NormalizeLineEndings()).Contains(expectedContains); } [Test] diff --git a/TUnit.Assertions/Conditions/EqualsAssertion.cs b/TUnit.Assertions/Conditions/EqualsAssertion.cs index 821a2b4e7f..350c439fad 100644 --- a/TUnit.Assertions/Conditions/EqualsAssertion.cs +++ b/TUnit.Assertions/Conditions/EqualsAssertion.cs @@ -3,6 +3,7 @@ using System.Reflection; using System.Text; using TUnit.Assertions.Attributes; +using TUnit.Assertions.Conditions.Helpers; using TUnit.Assertions.Core; namespace TUnit.Assertions.Conditions; @@ -84,7 +85,7 @@ protected override Task CheckAsync(EvaluationMetadata m if (_ignoredTypes.Count > 0) { // Use reference-based tracking to detect cycles - var visited = new HashSet(new ReferenceEqualityComparer()); + var visited = new HashSet(ReferenceEqualityComparer.Instance); var result = DeepEquals(value, _expected, _ignoredTypes, visited); if (result.IsSuccess) { @@ -213,15 +214,4 @@ private static (bool IsSuccess, string? Message) DeepEquals(object? actual, obje } protected override string GetExpectation() => $"to be equal to {(_expected is string s ? $"\"{s}\"" : _expected)}"; - - /// - /// Comparer that uses reference equality instead of value equality. - /// Used for cycle detection in deep comparison. - /// - private sealed class ReferenceEqualityComparer : IEqualityComparer - { - public new bool Equals(object? x, object? y) => ReferenceEquals(x, y); - - public int GetHashCode(object obj) => System.Runtime.CompilerServices.RuntimeHelpers.GetHashCode(obj); - } } diff --git a/TUnit.Assertions/Conditions/Helpers/ExpressionHelper.cs b/TUnit.Assertions/Conditions/Helpers/ExpressionHelper.cs new file mode 100644 index 0000000000..4070c1c6a5 --- /dev/null +++ b/TUnit.Assertions/Conditions/Helpers/ExpressionHelper.cs @@ -0,0 +1,43 @@ +namespace TUnit.Assertions.Conditions.Helpers; + +/// +/// Helper methods for parsing and extracting information from assertion expressions. +/// Consolidates expression parsing logic to ensure consistent behavior across assertion classes. +/// +internal static class ExpressionHelper +{ + /// + /// Extracts the source variable name from an assertion expression string. + /// + /// The expression string, e.g., "Assert.That(variableName).IsEquivalentTo(...)" + /// The variable name, or "value" if it cannot be extracted or is a lambda expression. + /// + /// Input: "Assert.That(myObject).IsEquivalentTo(expected)" + /// Output: "myObject" + /// + /// Input: "Assert.That(async () => GetValue()).IsEquivalentTo(expected)" + /// Output: "value" + /// + public static string ExtractSourceVariable(string expression) + { + // Extract variable name from "Assert.That(variableName)" or similar + var thatIndex = expression.IndexOf(".That(", StringComparison.Ordinal); + if (thatIndex >= 0) + { + var startIndex = thatIndex + 6; // Length of ".That(" + var endIndex = expression.IndexOf(')', startIndex); + if (endIndex > startIndex) + { + var variable = expression.Substring(startIndex, endIndex - startIndex); + // Handle lambda expressions like "async () => ..." by returning "value" + if (variable.Contains("=>") || variable.StartsWith("()", StringComparison.Ordinal)) + { + return "value"; + } + return variable; + } + } + + return "value"; + } +} diff --git a/TUnit.Assertions/Conditions/Helpers/ReflectionHelper.cs b/TUnit.Assertions/Conditions/Helpers/ReflectionHelper.cs new file mode 100644 index 0000000000..c6c53162c5 --- /dev/null +++ b/TUnit.Assertions/Conditions/Helpers/ReflectionHelper.cs @@ -0,0 +1,63 @@ +using System.Diagnostics.CodeAnalysis; +using System.Reflection; + +namespace TUnit.Assertions.Conditions.Helpers; + +/// +/// Helper methods for reflection-based member access. +/// Consolidates reflection logic to ensure consistent behavior and reduce code duplication. +/// +internal static class ReflectionHelper +{ + /// + /// Gets all public instance properties and fields to compare for structural equivalency. + /// + /// The type to get members from. + /// A list of PropertyInfo and FieldInfo members. + public static List GetMembersToCompare( + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicProperties | DynamicallyAccessedMemberTypes.PublicFields)] + Type type) + { + var members = new List(); + members.AddRange(type.GetProperties(BindingFlags.Public | BindingFlags.Instance)); + members.AddRange(type.GetFields(BindingFlags.Public | BindingFlags.Instance)); + return members; + } + + /// + /// Gets the value of a member (property or field) from an object. + /// + /// The object to get the value from. + /// The member (PropertyInfo or FieldInfo) to read. + /// The value of the member. + /// Thrown if the member is not a PropertyInfo or FieldInfo. + public static object? GetMemberValue(object obj, MemberInfo member) + { + return member switch + { + PropertyInfo prop => prop.GetValue(obj), + FieldInfo field => field.GetValue(obj), + _ => throw new InvalidOperationException($"Unknown member type: {member.GetType()}") + }; + } + + /// + /// Gets a member (property or field) by name from a type. + /// + /// The type to search. + /// The member name to find. + /// The MemberInfo if found; null otherwise. + public static MemberInfo? GetMemberInfo( + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicProperties | DynamicallyAccessedMemberTypes.PublicFields)] + Type type, + string name) + { + var property = type.GetProperty(name, BindingFlags.Public | BindingFlags.Instance); + if (property != null) + { + return property; + } + + return type.GetField(name, BindingFlags.Public | BindingFlags.Instance); + } +} diff --git a/TUnit.Assertions/Conditions/Helpers/StructuralEqualityComparer.cs b/TUnit.Assertions/Conditions/Helpers/StructuralEqualityComparer.cs index af1a1876e9..13273b4a45 100644 --- a/TUnit.Assertions/Conditions/Helpers/StructuralEqualityComparer.cs +++ b/TUnit.Assertions/Conditions/Helpers/StructuralEqualityComparer.cs @@ -1,6 +1,5 @@ using System.Collections; using System.Diagnostics.CodeAnalysis; -using System.Reflection; namespace TUnit.Assertions.Conditions.Helpers; @@ -36,12 +35,12 @@ public bool Equals(T? x, T? y) var type = typeof(T); - if (IsPrimitiveType(type)) + if (TypeHelper.IsPrimitiveOrWellKnownType(type)) { return EqualityComparer.Default.Equals(x, y); } - return CompareStructurally(x, y, new HashSet(new ReferenceEqualityComparer())); + return CompareStructurally(x, y, new HashSet(ReferenceEqualityComparer.Instance)); } public int GetHashCode(T obj) @@ -54,23 +53,6 @@ public int GetHashCode(T obj) return EqualityComparer.Default.GetHashCode(obj); } - private static bool IsPrimitiveType(Type type) - { - return type.IsPrimitive - || type.IsEnum - || type == typeof(string) - || type == typeof(decimal) - || type == typeof(DateTime) - || type == typeof(DateTimeOffset) - || type == typeof(TimeSpan) - || type == typeof(Guid) -#if NET6_0_OR_GREATER - || type == typeof(DateOnly) - || type == typeof(TimeOnly) -#endif - ; - } - [UnconditionalSuppressMessage("Trimming", "IL2072", Justification = "GetType() is acceptable for runtime structural comparison")] private bool CompareStructurally(object? x, object? y, HashSet visited) { @@ -87,7 +69,7 @@ private bool CompareStructurally(object? x, object? y, HashSet visited) var xType = x.GetType(); var yType = y.GetType(); - if (IsPrimitiveType(xType)) + if (TypeHelper.IsPrimitiveOrWellKnownType(xType)) { return Equals(x, y); } @@ -121,12 +103,12 @@ private bool CompareStructurally(object? x, object? y, HashSet visited) return true; } - var members = GetMembersToCompare(xType); + var members = ReflectionHelper.GetMembersToCompare(xType); foreach (var member in members) { - var xValue = GetMemberValue(x, member); - var yValue = GetMemberValue(y, member); + var xValue = ReflectionHelper.GetMemberValue(x, member); + var yValue = ReflectionHelper.GetMemberValue(y, member); if (!CompareStructurally(xValue, yValue, visited)) { @@ -136,28 +118,4 @@ private bool CompareStructurally(object? x, object? y, HashSet visited) return true; } - - private static List GetMembersToCompare([DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicProperties | DynamicallyAccessedMemberTypes.PublicFields)] Type type) - { - var members = new List(); - members.AddRange(type.GetProperties(BindingFlags.Public | BindingFlags.Instance)); - members.AddRange(type.GetFields(BindingFlags.Public | BindingFlags.Instance)); - return members; - } - - private static object? GetMemberValue(object obj, MemberInfo member) - { - return member switch - { - PropertyInfo prop => prop.GetValue(obj), - FieldInfo field => field.GetValue(obj), - _ => throw new InvalidOperationException($"Unknown member type: {member.GetType()}") - }; - } - - private sealed class ReferenceEqualityComparer : IEqualityComparer - { - public new bool Equals(object? x, object? y) => ReferenceEquals(x, y); - public int GetHashCode(object obj) => System.Runtime.CompilerServices.RuntimeHelpers.GetHashCode(obj); - } } diff --git a/TUnit.Assertions/Conditions/Helpers/TypeHelper.cs b/TUnit.Assertions/Conditions/Helpers/TypeHelper.cs new file mode 100644 index 0000000000..856a05d9e6 --- /dev/null +++ b/TUnit.Assertions/Conditions/Helpers/TypeHelper.cs @@ -0,0 +1,84 @@ +using System.Collections.Concurrent; + +namespace TUnit.Assertions.Conditions.Helpers; + +/// +/// Helper methods for type checking and classification. +/// Consolidates type checking logic to ensure consistent behavior across assertion classes. +/// +internal static class TypeHelper +{ + /// + /// Thread-safe registry of user-defined types that should be treated as primitives + /// (using value equality rather than structural comparison). + /// + private static readonly ConcurrentDictionary CustomPrimitiveTypes = new(); + + /// + /// Registers a type to be treated as a primitive for structural equivalency comparisons. + /// Once registered, instances of this type will use value equality (via Equals) rather + /// than having their properties compared individually. + /// + /// The type to register as a primitive. + public static void RegisterAsPrimitive() + { + CustomPrimitiveTypes.TryAdd(typeof(T), 0); + } + + /// + /// Registers a type to be treated as a primitive for structural equivalency comparisons. + /// + /// The type to register as a primitive. + public static void RegisterAsPrimitive(Type type) + { + CustomPrimitiveTypes.TryAdd(type, 0); + } + + /// + /// Removes a previously registered custom primitive type. + /// + /// The type to unregister. + /// True if the type was removed; false if it wasn't registered. + public static bool UnregisterPrimitive() + { + return CustomPrimitiveTypes.TryRemove(typeof(T), out _); + } + + /// + /// Clears all registered custom primitive types. + /// Useful for test cleanup between tests. + /// + public static void ClearCustomPrimitives() + { + CustomPrimitiveTypes.Clear(); + } + + /// + /// Determines if a type is a primitive or well-known immutable type that should use + /// value equality rather than structural comparison. + /// + /// The type to check. + /// True if the type should use value equality; false for structural comparison. + public static bool IsPrimitiveOrWellKnownType(Type type) + { + // Check user-defined primitives first (fast path for common case) + if (CustomPrimitiveTypes.ContainsKey(type)) + { + return true; + } + + return type.IsPrimitive + || type.IsEnum + || type == typeof(string) + || type == typeof(decimal) + || type == typeof(DateTime) + || type == typeof(DateTimeOffset) + || type == typeof(TimeSpan) + || type == typeof(Guid) +#if NET6_0_OR_GREATER + || type == typeof(DateOnly) + || type == typeof(TimeOnly) +#endif + ; + } +} diff --git a/TUnit.Assertions/Conditions/NotStructuralEquivalencyAssertion.cs b/TUnit.Assertions/Conditions/NotStructuralEquivalencyAssertion.cs index 4e58fd0bb6..2d8c30306a 100644 --- a/TUnit.Assertions/Conditions/NotStructuralEquivalencyAssertion.cs +++ b/TUnit.Assertions/Conditions/NotStructuralEquivalencyAssertion.cs @@ -1,5 +1,6 @@ using System.Diagnostics.CodeAnalysis; using System.Text; +using TUnit.Assertions.Conditions.Helpers; using TUnit.Assertions.Core; namespace TUnit.Assertions.Conditions; @@ -90,7 +91,12 @@ protected override Task CheckAsync(EvaluationMetadata m foreach (var type in _ignoredTypes) tempAssertion.IgnoringType(type); - var result = tempAssertion.CompareObjects(value, _notExpected, "", new HashSet(new ReferenceEqualityComparer())); + var result = tempAssertion.CompareObjects( + value, + _notExpected, + "", + new HashSet(ReferenceEqualityComparer.Instance), + new HashSet(ReferenceEqualityComparer.Instance)); // Invert the result - we want them to NOT be equivalent if (result.IsPassed) @@ -101,43 +107,14 @@ protected override Task CheckAsync(EvaluationMetadata m return Task.FromResult(AssertionResult.Passed); } - private sealed class ReferenceEqualityComparer : IEqualityComparer - { - public new bool Equals(object? x, object? y) => ReferenceEquals(x, y); - public int GetHashCode(object obj) => System.Runtime.CompilerServices.RuntimeHelpers.GetHashCode(obj); - } - protected override string GetExpectation() { // Extract the source variable name from the expression builder // Format: "Assert.That(variableName).IsNotEquivalentTo(...)" var expressionString = Context.ExpressionBuilder.ToString(); - var sourceVariable = ExtractSourceVariable(expressionString); + var sourceVariable = ExpressionHelper.ExtractSourceVariable(expressionString); var notExpectedDesc = _notExpectedExpression ?? "expected value"; return $"{sourceVariable} to not be equivalent to {notExpectedDesc}"; } - - private static string ExtractSourceVariable(string expression) - { - // Extract variable name from "Assert.That(variableName)" or similar - var thatIndex = expression.IndexOf(".That("); - if (thatIndex >= 0) - { - var startIndex = thatIndex + 6; // Length of ".That(" - var endIndex = expression.IndexOf(')', startIndex); - if (endIndex > startIndex) - { - var variable = expression.Substring(startIndex, endIndex - startIndex); - // Handle lambda expressions like "async () => ..." by returning "value" - if (variable.Contains("=>") || variable.StartsWith("()")) - { - return "value"; - } - return variable; - } - } - - return "value"; - } } diff --git a/TUnit.Assertions/Conditions/StructuralEquivalencyAssertion.cs b/TUnit.Assertions/Conditions/StructuralEquivalencyAssertion.cs index 8591bf4da9..2f1086a1a7 100644 --- a/TUnit.Assertions/Conditions/StructuralEquivalencyAssertion.cs +++ b/TUnit.Assertions/Conditions/StructuralEquivalencyAssertion.cs @@ -2,6 +2,7 @@ using System.Diagnostics.CodeAnalysis; using System.Reflection; using System.Text; +using TUnit.Assertions.Conditions.Helpers; using TUnit.Assertions.Core; namespace TUnit.Assertions.Conditions; @@ -77,11 +78,21 @@ protected override Task CheckAsync(EvaluationMetadata m return Task.FromResult(AssertionResult.Failed($"threw {exception.GetType().Name}: {exception.Message}")); } - var result = CompareObjects(value, _expected, "", new HashSet(new ReferenceEqualityComparer())); + var result = CompareObjects( + value, + _expected, + "", + new HashSet(ReferenceEqualityComparer.Instance), + new HashSet(ReferenceEqualityComparer.Instance)); return Task.FromResult(result); } - internal AssertionResult CompareObjects(object? actual, object? expected, string path, HashSet visited) + internal AssertionResult CompareObjects( + object? actual, + object? expected, + string path, + HashSet visitedActual, + HashSet? visitedExpected = null) { // Check for ignored paths if (_ignoredMembers.Contains(path)) @@ -109,7 +120,7 @@ internal AssertionResult CompareObjects(object? actual, object? expected, string var expectedType = expected.GetType(); // Handle primitive types and strings - if (IsPrimitiveType(actualType)) + if (TypeHelper.IsPrimitiveOrWellKnownType(actualType)) { if (!Equals(actual, expected)) { @@ -118,13 +129,25 @@ internal AssertionResult CompareObjects(object? actual, object? expected, string return AssertionResult.Passed; } - // Handle cycles - if (visited.Contains(actual)) + // Handle cycles - check both actual and expected to prevent infinite recursion + // from cycles in either object graph + if (visitedActual.Contains(actual)) { return AssertionResult.Passed; } - visited.Add(actual); + visitedActual.Add(actual); + + // Also track expected objects to handle cycles in the expected graph + if (visitedExpected != null) + { + if (visitedExpected.Contains(expected)) + { + return AssertionResult.Passed; + } + + visitedExpected.Add(expected); + } // Handle enumerables if (actual is IEnumerable actualEnumerable && expected is IEnumerable expectedEnumerable @@ -156,7 +179,7 @@ internal AssertionResult CompareObjects(object? actual, object? expected, string return AssertionResult.Failed($"{itemPath} did not match{Environment.NewLine}Expected: null{Environment.NewLine}Received: {FormatValue(actualList[i])}"); } - var result = CompareObjects(actualList[i], expectedList[i], itemPath, visited); + var result = CompareObjects(actualList[i], expectedList[i], itemPath, visitedActual, visitedExpected); if (!result.IsPassed) { return result; @@ -167,7 +190,7 @@ internal AssertionResult CompareObjects(object? actual, object? expected, string } // Compare properties and fields - var expectedMembers = GetMembersToCompare(expectedType); + var expectedMembers = ReflectionHelper.GetMembersToCompare(expectedType); foreach (var member in expectedMembers) { @@ -178,7 +201,7 @@ internal AssertionResult CompareObjects(object? actual, object? expected, string continue; } - var expectedValue = GetMemberValue(expected, member); + var expectedValue = ReflectionHelper.GetMemberValue(expected, member); // Check if this member's type should be ignored var memberType = member switch @@ -198,25 +221,25 @@ internal AssertionResult CompareObjects(object? actual, object? expected, string // In partial equivalency mode, skip members that don't exist on actual if (_usePartialEquivalency) { - var actualMember = GetMemberInfo(actualType, member.Name); + var actualMember = ReflectionHelper.GetMemberInfo(actualType, member.Name); if (actualMember == null) { continue; } - actualValue = GetMemberValue(actual, actualMember); + actualValue = ReflectionHelper.GetMemberValue(actual, actualMember); } else { - var actualMember = GetMemberInfo(actualType, member.Name); + var actualMember = ReflectionHelper.GetMemberInfo(actualType, member.Name); if (actualMember == null) { return AssertionResult.Failed($"Property {memberPath} did not match{Environment.NewLine}Expected: {FormatValue(expectedValue)}{Environment.NewLine}Received: null"); } - actualValue = GetMemberValue(actual, actualMember); + actualValue = ReflectionHelper.GetMemberValue(actual, actualMember); } - var result = CompareObjects(actualValue, expectedValue, memberPath, visited); + var result = CompareObjects(actualValue, expectedValue, memberPath, visitedActual, visitedExpected); if (!result.IsPassed) { return result; @@ -226,7 +249,7 @@ internal AssertionResult CompareObjects(object? actual, object? expected, string // In non-partial mode, check for extra properties on actual if (!_usePartialEquivalency) { - var actualMembers = GetMembersToCompare(actualType); + var actualMembers = ReflectionHelper.GetMembersToCompare(actualType); var expectedMemberNames = new HashSet(expectedMembers.Select(m => m.Name)); foreach (var member in actualMembers) @@ -247,7 +270,7 @@ internal AssertionResult CompareObjects(object? actual, object? expected, string } var memberPath = string.IsNullOrEmpty(path) ? member.Name : $"{path}.{member.Name}"; - var actualValue = GetMemberValue(actual, member); + var actualValue = ReflectionHelper.GetMemberValue(actual, member); // Skip properties with null values - they're equivalent to not having the property if (actualValue == null) @@ -263,13 +286,6 @@ internal AssertionResult CompareObjects(object? actual, object? expected, string return AssertionResult.Passed; } - private static bool IsPrimitiveType(Type type) - { - return type.IsPrimitive || type.IsEnum || type == typeof(string) || type == typeof(decimal) - || type == typeof(DateTime) || type == typeof(DateTimeOffset) || type == typeof(TimeSpan) - || type == typeof(Guid); - } - private bool ShouldIgnoreType(Type type) { // Check if the type itself should be ignored @@ -288,36 +304,6 @@ private bool ShouldIgnoreType(Type type) return false; } - private static List GetMembersToCompare([DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicProperties | DynamicallyAccessedMemberTypes.PublicFields)] Type type) - { - var members = new List(); - members.AddRange(type.GetProperties(BindingFlags.Public | BindingFlags.Instance)); - members.AddRange(type.GetFields(BindingFlags.Public | BindingFlags.Instance)); - return members; - } - - private static MemberInfo? GetMemberInfo([DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicProperties | DynamicallyAccessedMemberTypes.PublicFields)] Type type, string name) - { - var property = type.GetProperty(name, BindingFlags.Public | BindingFlags.Instance); - if (property != null) - { - return property; - } - - var field = type.GetField(name, BindingFlags.Public | BindingFlags.Instance); - return field; - } - - private static object? GetMemberValue(object obj, MemberInfo member) - { - return member switch - { - PropertyInfo prop => prop.GetValue(obj), - FieldInfo field => field.GetValue(obj), - _ => throw new InvalidOperationException($"Unknown member type: {member.GetType()}") - }; - } - private static string FormatValue(object? value) { if (value == null) @@ -338,38 +324,9 @@ protected override string GetExpectation() // Extract the source variable name from the expression builder // Format: "Assert.That(variableName).IsEquivalentTo(...)" var expressionString = Context.ExpressionBuilder.ToString(); - var sourceVariable = ExtractSourceVariable(expressionString); + var sourceVariable = ExpressionHelper.ExtractSourceVariable(expressionString); var expectedDesc = _expectedExpression ?? "expected value"; return $"{sourceVariable} to be equivalent to {expectedDesc}"; } - - private static string ExtractSourceVariable(string expression) - { - // Extract variable name from "Assert.That(variableName)" or similar - var thatIndex = expression.IndexOf(".That("); - if (thatIndex >= 0) - { - var startIndex = thatIndex + 6; // Length of ".That(" - var endIndex = expression.IndexOf(')', startIndex); - if (endIndex > startIndex) - { - var variable = expression.Substring(startIndex, endIndex - startIndex); - // Handle lambda expressions like "async () => ..." by returning "value" - if (variable.Contains("=>") || variable.StartsWith("()")) - { - return "value"; - } - return variable; - } - } - - return "value"; - } - - private sealed class ReferenceEqualityComparer : IEqualityComparer - { - public new bool Equals(object? x, object? y) => ReferenceEquals(x, y); - public int GetHashCode(object obj) => System.Runtime.CompilerServices.RuntimeHelpers.GetHashCode(obj); - } } diff --git a/TUnit.Core/Attributes/TestData/MethodDataSourceAttribute.cs b/TUnit.Core/Attributes/TestData/MethodDataSourceAttribute.cs index efa53d7741..c6bd260fe0 100644 --- a/TUnit.Core/Attributes/TestData/MethodDataSourceAttribute.cs +++ b/TUnit.Core/Attributes/TestData/MethodDataSourceAttribute.cs @@ -82,7 +82,10 @@ public MethodDataSourceAttribute( // If we have a test class instance and no explicit class was provided, // use the instance's actual type (which will be the constructed generic type) - if (ClassProvidingDataSource == null && dataGeneratorMetadata.TestClassInstance != null) + // Skip PlaceholderInstance as it's used during discovery when the actual instance isn't created yet + if (ClassProvidingDataSource == null + && dataGeneratorMetadata.TestClassInstance != null + && dataGeneratorMetadata.TestClassInstance is not PlaceholderInstance) { targetType = dataGeneratorMetadata.TestClassInstance.GetType(); } diff --git a/TUnit.Core/Data/ScopedDictionary.cs b/TUnit.Core/Data/ScopedDictionary.cs index 7a751429ff..e5766792b0 100644 --- a/TUnit.Core/Data/ScopedDictionary.cs +++ b/TUnit.Core/Data/ScopedDictionary.cs @@ -1,6 +1,4 @@ -using TUnit.Core.Tracking; - -namespace TUnit.Core.Data; +namespace TUnit.Core.Data; public class ScopedDictionary where TScope : notnull @@ -11,14 +9,6 @@ public class ScopedDictionary { var innerDictionary = _scopedContainers.GetOrAdd(scope, static _ => new ThreadSafeDictionary()); - var obj = innerDictionary.GetOrAdd(type, factory); - - ObjectTracker.OnDisposed(obj, () => - { - innerDictionary.Remove(type); - }); - - return obj; + return innerDictionary.GetOrAdd(type, factory); } - } diff --git a/TUnit.Core/Discovery/ObjectGraph.cs b/TUnit.Core/Discovery/ObjectGraph.cs new file mode 100644 index 0000000000..85a1b6f960 --- /dev/null +++ b/TUnit.Core/Discovery/ObjectGraph.cs @@ -0,0 +1,124 @@ +using System.Collections.Concurrent; +using System.Collections.ObjectModel; +using TUnit.Core.Interfaces; + +namespace TUnit.Core.Discovery; + +/// +/// Represents a discovered object graph organized by depth level. +/// +/// +/// Internal collections are stored privately and exposed as read-only views +/// to prevent callers from corrupting internal state. +/// Uses Lazy<T> for thread-safe lazy initialization of read-only views. +/// +internal sealed class ObjectGraph : IObjectGraph +{ + private readonly ConcurrentDictionary> _objectsByDepth; + private readonly HashSet _allObjects; + + // Thread-safe lazy initialization of read-only views + private readonly Lazy>> _lazyReadOnlyObjectsByDepth; + private readonly Lazy> _lazyReadOnlyAllObjects; + + // Cached sorted depths (computed once in constructor) + private readonly int[] _sortedDepthsDescending; + + /// + /// Creates a new object graph from the discovered objects. + /// + /// Objects organized by depth level. + /// All unique objects in the graph. + public ObjectGraph(ConcurrentDictionary> objectsByDepth, HashSet allObjects) + { + _objectsByDepth = objectsByDepth; + _allObjects = allObjects; + + // Compute MaxDepth and sorted depths without LINQ to reduce allocations + var keyCount = objectsByDepth.Count; + if (keyCount == 0) + { + MaxDepth = -1; + _sortedDepthsDescending = []; + } + else + { + var keys = new int[keyCount]; + objectsByDepth.Keys.CopyTo(keys, 0); + + // Find max manually + var maxDepth = int.MinValue; + foreach (var key in keys) + { + if (key > maxDepth) + { + maxDepth = key; + } + } + MaxDepth = maxDepth; + + // Sort in descending order using Array.Sort with reverse comparison + Array.Sort(keys, (a, b) => b.CompareTo(a)); + _sortedDepthsDescending = keys; + } + + // Use Lazy with ExecutionAndPublication for thread-safe single initialization + _lazyReadOnlyObjectsByDepth = new Lazy>>( + CreateReadOnlyObjectsByDepth, + LazyThreadSafetyMode.ExecutionAndPublication); + + _lazyReadOnlyAllObjects = new Lazy>( + () => _allObjects.ToArray(), + LazyThreadSafetyMode.ExecutionAndPublication); + } + + /// + public IReadOnlyDictionary> ObjectsByDepth => _lazyReadOnlyObjectsByDepth.Value; + + /// + public IReadOnlyCollection AllObjects => _lazyReadOnlyAllObjects.Value; + + /// + public int MaxDepth { get; } + + /// + public IEnumerable GetObjectsAtDepth(int depth) + { + if (!_objectsByDepth.TryGetValue(depth, out var objects)) + { + return []; + } + + // Lock and copy to prevent concurrent modification issues + lock (objects) + { + return objects.ToArray(); + } + } + + /// + public IEnumerable GetDepthsDescending() + { + // Return cached sorted depths (computed once in constructor) + return _sortedDepthsDescending; + } + + /// + /// Creates a thread-safe read-only snapshot of objects by depth. + /// + private IReadOnlyDictionary> CreateReadOnlyObjectsByDepth() + { + var dict = new Dictionary>(_objectsByDepth.Count); + + foreach (var kvp in _objectsByDepth) + { + // Lock each HashSet while copying to ensure consistency + lock (kvp.Value) + { + dict[kvp.Key] = kvp.Value.ToArray(); + } + } + + return new ReadOnlyDictionary>(dict); + } +} diff --git a/TUnit.Core/Discovery/ObjectGraphDiscoverer.cs b/TUnit.Core/Discovery/ObjectGraphDiscoverer.cs new file mode 100644 index 0000000000..f2389fda28 --- /dev/null +++ b/TUnit.Core/Discovery/ObjectGraphDiscoverer.cs @@ -0,0 +1,516 @@ +using System.Collections.Concurrent; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Reflection; +using TUnit.Core.Helpers; +using TUnit.Core.Interfaces; +using TUnit.Core.Interfaces.SourceGenerator; +using TUnit.Core.PropertyInjection; + +namespace TUnit.Core.Discovery; + +/// +/// Represents an error that occurred during object graph discovery. +/// +/// The name of the type being inspected. +/// The name of the property that failed to access. +/// The error message. +/// The exception that occurred. +internal readonly record struct DiscoveryError(string TypeName, string PropertyName, string ErrorMessage, Exception Exception); + +/// +/// Centralized service for discovering and organizing object graphs. +/// Consolidates duplicate graph traversal logic from ObjectGraphDiscoveryService and TrackableObjectGraphProvider. +/// Follows Single Responsibility Principle - only discovers objects, doesn't modify them. +/// +/// +/// +/// This class is thread-safe and uses cached reflection for performance. +/// Objects are organized by their nesting depth in the hierarchy: +/// +/// +/// Depth 0: Root objects (class args, method args, property values) +/// Depth 1+: Nested objects found in properties of objects at previous depth +/// +/// +/// Discovery errors (e.g., property access failures) are collected in +/// rather than thrown, allowing discovery to continue despite individual property failures. +/// +/// +internal sealed class ObjectGraphDiscoverer : IObjectGraphTracker +{ + /// + /// Maximum recursion depth for object graph discovery. + /// Prevents stack overflow on deep or circular object graphs. + /// + private const int MaxRecursionDepth = 50; + + // Reference equality comparer for object tracking (ignores Equals overrides) + private static readonly Helpers.ReferenceEqualityComparer ReferenceComparer = Helpers.ReferenceEqualityComparer.Instance; + + // Types to skip during discovery (primitives, strings, system types) + private static readonly HashSet SkipTypes = + [ + typeof(string), + typeof(decimal), + typeof(DateTime), + typeof(DateTimeOffset), + typeof(TimeSpan), + typeof(Guid) + ]; + + // Thread-safe collection of discovery errors for diagnostics + private static readonly ConcurrentBag DiscoveryErrors = []; + + /// + /// Gets all discovery errors that occurred during object graph traversal. + /// Useful for debugging and diagnostics when property access fails. + /// + /// A read-only list of discovery errors. + public static IReadOnlyList GetDiscoveryErrors() + { + return DiscoveryErrors.ToArray(); + } + + /// + /// Clears all recorded discovery errors. Call at end of test session. + /// + public static void ClearDiscoveryErrors() + { + DiscoveryErrors.Clear(); + } + + /// + /// Delegate for adding discovered objects to collections. + /// Returns true if the object was newly added (not a duplicate). + /// + private delegate bool TryAddObjectFunc(object obj, int depth); + + /// + /// Delegate for recursive discovery after an object is added. + /// + private delegate void RecurseFunc(object obj, int depth); + + /// + /// Delegate for processing a root object after it's been added. + /// + private delegate void RootObjectCallback(object obj); + + /// + public IObjectGraph DiscoverObjectGraph(TestContext testContext, CancellationToken cancellationToken = default) + { + var objectsByDepth = new ConcurrentDictionary>(); + var allObjects = new HashSet(ReferenceComparer); + var allObjectsLock = new object(); // Thread-safety for allObjects HashSet + var visitedObjects = new ConcurrentDictionary(ReferenceComparer); + + // Standard mode add callback (thread-safe) + bool TryAddStandard(object obj, int depth) + { + if (!visitedObjects.TryAdd(obj, 0)) + { + return false; + } + + AddToDepth(objectsByDepth, depth, obj); + lock (allObjectsLock) + { + allObjects.Add(obj); + } + + return true; + } + + // Collect root-level objects and discover nested objects + CollectRootObjects( + testContext.Metadata.TestDetails, + TryAddStandard, + obj => DiscoverNestedObjects(obj, objectsByDepth, visitedObjects, allObjects, allObjectsLock, currentDepth: 1, cancellationToken), + cancellationToken); + + return new ObjectGraph(objectsByDepth, allObjects); + } + + /// + public IObjectGraph DiscoverNestedObjectGraph(object rootObject, CancellationToken cancellationToken = default) + { + var objectsByDepth = new ConcurrentDictionary>(); + var allObjects = new HashSet(ReferenceComparer); + var allObjectsLock = new object(); // Thread-safety for allObjects HashSet + var visitedObjects = new ConcurrentDictionary(ReferenceComparer); + + if (visitedObjects.TryAdd(rootObject, 0)) + { + AddToDepth(objectsByDepth, 0, rootObject); + lock (allObjectsLock) + { + allObjects.Add(rootObject); + } + + DiscoverNestedObjects(rootObject, objectsByDepth, visitedObjects, allObjects, allObjectsLock, currentDepth: 1, cancellationToken); + } + + return new ObjectGraph(objectsByDepth, allObjects); + } + + /// + /// Discovers objects and adds them to the existing tracked objects dictionary. + /// Used by TrackableObjectGraphProvider to populate TestContext.TrackedObjects. + /// + /// The test context to discover objects from. + /// Cancellation token for the operation. + /// The tracked objects dictionary (same as testContext.TrackedObjects). + public ConcurrentDictionary> DiscoverAndTrackObjects(TestContext testContext, CancellationToken cancellationToken = default) + { + var visitedObjects = testContext.TrackedObjects; + + // Collect root-level objects and discover nested objects for tracking + CollectRootObjects( + testContext.Metadata.TestDetails, + (obj, depth) => TryAddToHashSet(visitedObjects, depth, obj), + obj => DiscoverNestedObjectsForTracking(obj, visitedObjects, 1, cancellationToken), + cancellationToken); + + return visitedObjects; + } + + /// + /// Recursively discovers nested objects that have injectable properties OR implement IAsyncInitializer. + /// Uses consolidated TraverseInjectableProperties and TraverseInitializerProperties methods. + /// + private void DiscoverNestedObjects( + object obj, + ConcurrentDictionary> objectsByDepth, + ConcurrentDictionary visitedObjects, + HashSet allObjects, + object allObjectsLock, + int currentDepth, + CancellationToken cancellationToken) + { + if (!CheckRecursionDepth(obj, currentDepth)) + { + return; + } + + cancellationToken.ThrowIfCancellationRequested(); + + // Standard mode add callback: visitedObjects + objectsByDepth + allObjects (thread-safe) + bool TryAddStandard(object value, int depth) + { + if (!visitedObjects.TryAdd(value, 0)) + { + return false; + } + + AddToDepth(objectsByDepth, depth, value); + lock (allObjectsLock) + { + allObjects.Add(value); + } + + return true; + } + + // Recursive callback + void Recurse(object value, int depth) + { + DiscoverNestedObjects(value, objectsByDepth, visitedObjects, allObjects, allObjectsLock, depth, cancellationToken); + } + + // Traverse injectable properties (useSourceRegistrarCheck = false) + TraverseInjectableProperties(obj, TryAddStandard, Recurse, currentDepth, cancellationToken, useSourceRegistrarCheck: false); + + // Also discover nested IAsyncInitializer objects from ALL properties + TraverseInitializerProperties(obj, TryAddStandard, Recurse, currentDepth, cancellationToken); + } + + /// + /// Discovers nested objects for tracking (uses HashSet pattern for compatibility with TestContext.TrackedObjects). + /// Uses consolidated TraverseInjectableProperties and TraverseInitializerProperties methods. + /// + private void DiscoverNestedObjectsForTracking( + object obj, + ConcurrentDictionary> visitedObjects, + int currentDepth, + CancellationToken cancellationToken) + { + if (!CheckRecursionDepth(obj, currentDepth)) + { + return; + } + + cancellationToken.ThrowIfCancellationRequested(); + + // Tracking mode add callback: TryAddToHashSet only + bool TryAddTracking(object value, int depth) + { + return TryAddToHashSet(visitedObjects, depth, value); + } + + // Recursive callback + void Recurse(object value, int depth) + { + DiscoverNestedObjectsForTracking(value, visitedObjects, depth, cancellationToken); + } + + // Traverse injectable properties (useSourceRegistrarCheck = true for tracking mode) + TraverseInjectableProperties(obj, TryAddTracking, Recurse, currentDepth, cancellationToken, useSourceRegistrarCheck: true); + + // Also discover nested IAsyncInitializer objects from ALL properties + TraverseInitializerProperties(obj, TryAddTracking, Recurse, currentDepth, cancellationToken); + } + + /// + /// Clears all caches. Called at end of test session to release memory. + /// + public static void ClearCache() + { + PropertyCacheManager.ClearCache(); + ClearDiscoveryErrors(); + } + + /// + /// Checks if a type should be skipped during discovery. + /// + private static bool ShouldSkipType(Type type) + { + return type.IsPrimitive || + SkipTypes.Contains(type) || + type.Namespace?.StartsWith("System") == true; + } + + /// + /// Adds an object to the specified depth level. + /// Thread-safe: uses lock to protect HashSet modifications. + /// + private static void AddToDepth(ConcurrentDictionary> objectsByDepth, int depth, object obj) + { + var hashSet = objectsByDepth.GetOrAdd(depth, _ => new HashSet(ReferenceComparer)); + lock (hashSet) + { + hashSet.Add(obj); + } + } + + /// + /// Thread-safe add to HashSet at specified depth. Returns true if added (not duplicate). + /// + private static bool TryAddToHashSet(ConcurrentDictionary> dict, int depth, object obj) + { + var hashSet = dict.GetOrAdd(depth, _ => new HashSet(ReferenceComparer)); + lock (hashSet) + { + return hashSet.Add(obj); + } + } + + #region Consolidated Traversal Methods (DRY) + + /// + /// Checks recursion depth guard. Returns false if depth exceeded (caller should return early). + /// + private static bool CheckRecursionDepth(object obj, int currentDepth) + { + if (currentDepth > MaxRecursionDepth) + { +#if DEBUG + Debug.WriteLine($"[ObjectGraphDiscoverer] Max recursion depth ({MaxRecursionDepth}) reached for type '{obj.GetType().Name}'"); +#endif + return false; + } + + return true; + } + + /// + /// Unified traversal for injectable properties (from PropertyInjectionCache). + /// Eliminates duplicate code between DiscoverNestedObjects and DiscoverNestedObjectsForTracking. + /// + [UnconditionalSuppressMessage("Trimming", "IL2026", Justification = "Property discovery handles both AOT and reflection modes")] + private static void TraverseInjectableProperties( + object obj, + TryAddObjectFunc tryAdd, + RecurseFunc recurse, + int currentDepth, + CancellationToken cancellationToken, + bool useSourceRegistrarCheck) + { + var plan = PropertyInjectionCache.GetOrCreatePlan(obj.GetType()); + + if (!plan.HasProperties && !useSourceRegistrarCheck) + { + return; + } + + // The two modes differ in how they choose source-gen vs reflection: + // - Standard mode: Uses plan.SourceGeneratedProperties.Length > 0 + // - Tracking mode: Uses SourceRegistrar.IsEnabled + bool useSourceGen = useSourceRegistrarCheck + ? SourceRegistrar.IsEnabled + : plan.SourceGeneratedProperties.Length > 0; + + if (useSourceGen) + { + TraverseSourceGeneratedProperties(obj, plan.SourceGeneratedProperties, tryAdd, recurse, currentDepth, cancellationToken); + } + else + { + var reflectionProps = useSourceRegistrarCheck + ? plan.ReflectionProperties + : (plan.ReflectionProperties.Length > 0 ? plan.ReflectionProperties : []); + + TraverseReflectionProperties(obj, reflectionProps, tryAdd, recurse, currentDepth, cancellationToken); + } + } + + /// + /// Traverses source-generated properties and discovers nested objects. + /// Extracted for reduced complexity in TraverseInjectableProperties. + /// + [UnconditionalSuppressMessage("Trimming", "IL2026", Justification = "Property discovery handles both AOT and reflection modes")] + private static void TraverseSourceGeneratedProperties( + object obj, + PropertyInjectionMetadata[] sourceGeneratedProperties, + TryAddObjectFunc tryAdd, + RecurseFunc recurse, + int currentDepth, + CancellationToken cancellationToken) + { + foreach (var metadata in sourceGeneratedProperties) + { + cancellationToken.ThrowIfCancellationRequested(); + var property = metadata.ContainingType.GetProperty(metadata.PropertyName); + if (property == null || !property.CanRead) + { + continue; + } + + var value = property.GetValue(obj); + if (value != null && tryAdd(value, currentDepth)) + { + recurse(value, currentDepth + 1); + } + } + } + + /// + /// Traverses reflection-based properties and discovers nested objects. + /// Extracted for reduced complexity in TraverseInjectableProperties. + /// + private static void TraverseReflectionProperties( + object obj, + (PropertyInfo Property, IDataSourceAttribute DataSource)[] reflectionProperties, + TryAddObjectFunc tryAdd, + RecurseFunc recurse, + int currentDepth, + CancellationToken cancellationToken) + { + foreach (var prop in reflectionProperties) + { + cancellationToken.ThrowIfCancellationRequested(); + var value = prop.Property.GetValue(obj); + if (value != null && tryAdd(value, currentDepth)) + { + recurse(value, currentDepth + 1); + } + } + } + + /// + /// Unified traversal for IAsyncInitializer objects (from all properties). + /// Eliminates duplicate code between DiscoverNestedInitializerObjects and DiscoverNestedInitializerObjectsForTracking. + /// + [UnconditionalSuppressMessage("Trimming", "IL2026", Justification = "Reflection fallback for nested initializers. In AOT, source-gen handles primary discovery.")] + [UnconditionalSuppressMessage("Trimming", "IL2070", Justification = "Reflection fallback for nested initializers. In AOT, source-gen handles primary discovery.")] + [UnconditionalSuppressMessage("Trimming", "IL2075", Justification = "Reflection fallback for nested initializers. In AOT, source-gen handles primary discovery.")] + private static void TraverseInitializerProperties( + object obj, + TryAddObjectFunc tryAdd, + RecurseFunc recurse, + int currentDepth, + CancellationToken cancellationToken) + { + var type = obj.GetType(); + + if (ShouldSkipType(type)) + { + return; + } + + var properties = PropertyCacheManager.GetCachedProperties(type); + + foreach (var property in properties) + { + cancellationToken.ThrowIfCancellationRequested(); + try + { + var value = property.GetValue(obj); + if (value == null) + { + continue; + } + + // Only discover IAsyncInitializer objects + if (value is IAsyncInitializer && tryAdd(value, currentDepth)) + { + recurse(value, currentDepth + 1); + } + } + catch (OperationCanceledException) + { + throw; // Propagate cancellation + } + catch (Exception ex) + { + // Record error for diagnostics (available via GetDiscoveryErrors()) + DiscoveryErrors.Add(new DiscoveryError(type.Name, property.Name, ex.Message, ex)); +#if DEBUG + Debug.WriteLine($"[ObjectGraphDiscoverer] Failed to access property '{property.Name}' on type '{type.Name}': {ex.Message}"); +#endif + // Continue discovery despite property access failures + } + } + } + + /// + /// Collects root-level objects (class args, method args, properties) from test details. + /// Eliminates duplicate loops in DiscoverObjectGraph and DiscoverAndTrackObjects. + /// + private static void CollectRootObjects( + TestDetails testDetails, + TryAddObjectFunc tryAdd, + RootObjectCallback onRootObjectAdded, + CancellationToken cancellationToken) + { + // Process class arguments + ProcessRootCollection(testDetails.TestClassArguments, tryAdd, onRootObjectAdded, cancellationToken); + + // Process method arguments + ProcessRootCollection(testDetails.TestMethodArguments, tryAdd, onRootObjectAdded, cancellationToken); + + // Process injected property values + ProcessRootCollection(testDetails.TestClassInjectedPropertyArguments.Values, tryAdd, onRootObjectAdded, cancellationToken); + } + + /// + /// Processes a collection of root objects, adding them to the graph and invoking callback. + /// Extracted to eliminate duplicate iteration patterns in CollectRootObjects. + /// + private static void ProcessRootCollection( + IEnumerable collection, + TryAddObjectFunc tryAdd, + RootObjectCallback onRootObjectAdded, + CancellationToken cancellationToken) + { + foreach (var item in collection) + { + cancellationToken.ThrowIfCancellationRequested(); + if (item != null && tryAdd(item, 0)) + { + onRootObjectAdded(item); + } + } + } + + #endregion +} diff --git a/TUnit.Core/Discovery/PropertyCacheManager.cs b/TUnit.Core/Discovery/PropertyCacheManager.cs new file mode 100644 index 0000000000..12319608cc --- /dev/null +++ b/TUnit.Core/Discovery/PropertyCacheManager.cs @@ -0,0 +1,116 @@ +using System.Collections.Concurrent; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Reflection; + +namespace TUnit.Core.Discovery; + +/// +/// Manages cached property reflection results for object graph discovery. +/// Extracted from ObjectGraphDiscoverer to follow Single Responsibility Principle. +/// +/// +/// +/// This class caches arrays per type to avoid repeated reflection calls. +/// Includes automatic cache cleanup when size exceeds to prevent memory leaks. +/// +/// +/// Thread-safe: Uses and for coordination. +/// +/// +internal static class PropertyCacheManager +{ + /// + /// Maximum size for the property cache before cleanup is triggered. + /// Prevents unbounded memory growth in long-running test sessions. + /// + private const int MaxCacheSize = 10000; + + // Cache for GetProperties() results per type - eliminates repeated reflection calls + private static readonly ConcurrentDictionary PropertyCache = new(); + + // Flag to coordinate cache cleanup (prevents multiple threads cleaning simultaneously) + private static int _cleanupInProgress; + + /// + /// Gets cached properties for a type, filtering to only readable non-indexed properties. + /// Includes periodic cache cleanup to prevent unbounded memory growth. + /// + /// The type to get properties for. + /// An array of readable, non-indexed properties for the type. + [UnconditionalSuppressMessage("Trimming", "IL2070", Justification = "Reflection fallback for nested initializers. In AOT, source-gen handles primary discovery.")] + public static PropertyInfo[] GetCachedProperties(Type type) + { + // Periodic cleanup if cache grows too large to prevent memory leaks + // Use Interlocked to ensure only one thread performs cleanup at a time + if (PropertyCache.Count > MaxCacheSize && + Interlocked.CompareExchange(ref _cleanupInProgress, 1, 0) == 0) + { + try + { + // Double-check after acquiring cleanup flag + if (PropertyCache.Count > MaxCacheSize) + { + // Use ToArray() to get a true snapshot for thread-safe enumeration + // This prevents issues with concurrent modifications during iteration + var allKeys = PropertyCache.Keys.ToArray(); + var removeCount = Math.Min(allKeys.Length / 2, MaxCacheSize / 2); + + for (var i = 0; i < removeCount; i++) + { + PropertyCache.TryRemove(allKeys[i], out _); + } +#if DEBUG + Debug.WriteLine($"[PropertyCacheManager] PropertyCache exceeded {MaxCacheSize} entries, cleared {removeCount} entries"); +#endif + } + } + finally + { + Interlocked.Exchange(ref _cleanupInProgress, 0); + } + } + + return PropertyCache.GetOrAdd(type, static t => + { + // Use explicit loops instead of LINQ to avoid allocations in hot path + var allProps = t.GetProperties(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance); + + // First pass: count eligible properties + var eligibleCount = 0; + foreach (var p in allProps) + { + if (p.CanRead && p.GetIndexParameters().Length == 0) + { + eligibleCount++; + } + } + + // Second pass: fill result array + var result = new PropertyInfo[eligibleCount]; + var i = 0; + foreach (var p in allProps) + { + if (p.CanRead && p.GetIndexParameters().Length == 0) + { + result[i++] = p; + } + } + + return result; + }); + } + + /// + /// Clears the property cache. Called at end of test session to release memory. + /// + public static void ClearCache() + { + PropertyCache.Clear(); + } + + /// + /// Gets the current number of cached types. Useful for diagnostics. + /// + public static int CacheCount => PropertyCache.Count; +} diff --git a/TUnit.Core/Helpers/Counter.cs b/TUnit.Core/Helpers/Counter.cs index d3df14af6f..e4ca4b759a 100644 --- a/TUnit.Core/Helpers/Counter.cs +++ b/TUnit.Core/Helpers/Counter.cs @@ -2,6 +2,11 @@ namespace TUnit.Core.Helpers; +/// +/// Thread-safe counter with event notification. +/// Captures event handler BEFORE state change to prevent race conditions +/// where subscribers miss notifications that occur during subscription. +/// [DebuggerDisplay("Count = {CurrentCount}")] public class Counter { @@ -11,34 +16,91 @@ public class Counter public int Increment() { + // Capture handler BEFORE state change to ensure all subscribers + // at the time of the change are notified (prevents TOCTOU race) + var handler = _onCountChanged; var newCount = Interlocked.Increment(ref _count); - var handler = _onCountChanged; - handler?.Invoke(this, newCount); + RaiseEventSafely(handler, newCount); return newCount; } public int Decrement() { + // Capture handler BEFORE state change to ensure all subscribers + // at the time of the change are notified (prevents TOCTOU race) + var handler = _onCountChanged; var newCount = Interlocked.Decrement(ref _count); - var handler = _onCountChanged; - handler?.Invoke(this, newCount); + RaiseEventSafely(handler, newCount); return newCount; } + /// + /// Adds a value to the counter. Use Increment/Decrement for single-step changes. + /// + /// The value to add (can be positive or negative). + /// The new count after the addition. + /// Thrown if the resulting count is negative, indicating a logic error. public int Add(int value) { + // Capture handler BEFORE state change to ensure all subscribers + // at the time of the change are notified (prevents TOCTOU race) + var handler = _onCountChanged; var newCount = Interlocked.Add(ref _count, value); - var handler = _onCountChanged; - handler?.Invoke(this, newCount); + // Guard against reference count going negative - indicates a bug in calling code + if (newCount < 0) + { + throw new InvalidOperationException( + $"Counter went below zero (result: {newCount}). This indicates a bug in the reference counting logic."); + } + + RaiseEventSafely(handler, newCount); return newCount; } + /// + /// Raises the event safely, ensuring all subscribers are notified even if some throw exceptions. + /// Collects all exceptions and throws AggregateException if any occurred. + /// + private void RaiseEventSafely(EventHandler? handler, int newCount) + { + if (handler == null) + { + return; + } + + var invocationList = handler.GetInvocationList(); + List? exceptions = null; + + foreach (var subscriber in invocationList) + { + try + { + ((EventHandler)subscriber).Invoke(this, newCount); + } + catch (Exception ex) + { + exceptions ??= []; + exceptions.Add(ex); + +#if DEBUG + Debug.WriteLine($"[Counter] Exception in OnCountChanged subscriber: {ex.Message}"); +#endif + } + } + + // If any subscribers threw, aggregate and rethrow after all are notified + if (exceptions?.Count > 0) + { + throw new AggregateException("One or more OnCountChanged subscribers threw an exception.", exceptions); + } + } + public int CurrentCount => Interlocked.CompareExchange(ref _count, 0, 0); public event EventHandler? OnCountChanged diff --git a/TUnit.Core/Helpers/DataSourceHelpers.cs b/TUnit.Core/Helpers/DataSourceHelpers.cs index 6802e5add5..fc7ba158e6 100644 --- a/TUnit.Core/Helpers/DataSourceHelpers.cs +++ b/TUnit.Core/Helpers/DataSourceHelpers.cs @@ -178,12 +178,9 @@ public static T InvokeIfFunc(object? value) // If it's a Func, invoke it first var actualData = InvokeIfFunc(data); - // Only initialize during discovery if explicitly opted-in via IAsyncDiscoveryInitializer - // Regular IAsyncInitializer objects are initialized during test execution by ObjectLifecycleService - if (actualData is IAsyncDiscoveryInitializer) - { - await ObjectInitializer.InitializeAsync(actualData); - } + // During discovery, only IAsyncDiscoveryInitializer objects are initialized. + // Regular IAsyncInitializer objects are deferred to Execution phase. + await ObjectInitializer.InitializeForDiscoveryAsync(actualData); return actualData; } @@ -202,11 +199,8 @@ public static T InvokeIfFunc(object? value) if (enumerator.MoveNext()) { var value = enumerator.Current; - // Only initialize during discovery if explicitly opted-in via IAsyncDiscoveryInitializer - if (value is IAsyncDiscoveryInitializer) - { - await ObjectInitializer.InitializeAsync(value); - } + // Discovery: only IAsyncDiscoveryInitializer + await ObjectInitializer.InitializeForDiscoveryAsync(value); return value; } @@ -233,22 +227,16 @@ public static T InvokeIfFunc(object? value) if (enumerator.MoveNext()) { var value = enumerator.Current; - // Only initialize during discovery if explicitly opted-in via IAsyncDiscoveryInitializer - if (value is IAsyncDiscoveryInitializer) - { - await ObjectInitializer.InitializeAsync(value); - } + // Discovery: only IAsyncDiscoveryInitializer + await ObjectInitializer.InitializeForDiscoveryAsync(value); return value; } return null; } - // Only initialize during discovery if explicitly opted-in via IAsyncDiscoveryInitializer - // Regular IAsyncInitializer objects are initialized during test execution by ObjectLifecycleService - if (actualData is IAsyncDiscoveryInitializer) - { - await ObjectInitializer.InitializeAsync(actualData); - } + // During discovery, only IAsyncDiscoveryInitializer objects are initialized. + // Regular IAsyncInitializer objects are deferred to Execution phase. + await ObjectInitializer.InitializeForDiscoveryAsync(actualData); return actualData; } @@ -596,12 +584,8 @@ public static void RegisterTypeCreator(Func> { var value = args[0]; - // Only initialize during discovery if explicitly opted-in via IAsyncDiscoveryInitializer - // Regular IAsyncInitializer objects are initialized during test execution by ObjectLifecycleService - if (value is IAsyncDiscoveryInitializer) - { - await ObjectInitializer.InitializeAsync(value); - } + // Discovery: only IAsyncDiscoveryInitializer + await ObjectInitializer.InitializeForDiscoveryAsync(value); return value; } diff --git a/TUnit.Core/Helpers/Disposer.cs b/TUnit.Core/Helpers/Disposer.cs index 772b49de07..a70033bb79 100644 --- a/TUnit.Core/Helpers/Disposer.cs +++ b/TUnit.Core/Helpers/Disposer.cs @@ -1,9 +1,18 @@ -using TUnit.Core.Logging; +using TUnit.Core.Interfaces; +using TUnit.Core.Logging; namespace TUnit.Core.Helpers; -internal class Disposer(ILogger logger) +/// +/// Disposes objects asynchronously with logging. +/// Implements IDisposer for Dependency Inversion Principle compliance. +/// +internal class Disposer(ILogger logger) : IDisposer { + /// + /// Disposes an object and propagates any exceptions. + /// Exceptions are logged but NOT swallowed - callers must handle them. + /// public async ValueTask DisposeAsync(object? obj) { try @@ -19,10 +28,15 @@ public async ValueTask DisposeAsync(object? obj) } catch (Exception e) { + // Log the error for diagnostics if (logger != null) { await logger.LogErrorAsync(e); } + + // Propagate the exception - don't silently swallow disposal failures + // Callers can catch and aggregate if disposing multiple objects + throw; } } } diff --git a/TUnit.Core/Helpers/ParallelTaskHelper.cs b/TUnit.Core/Helpers/ParallelTaskHelper.cs new file mode 100644 index 0000000000..81f05f06d4 --- /dev/null +++ b/TUnit.Core/Helpers/ParallelTaskHelper.cs @@ -0,0 +1,165 @@ +namespace TUnit.Core.Helpers; + +/// +/// Helper methods for parallel task execution without LINQ allocations. +/// Provides optimized patterns for executing async operations in parallel. +/// Exceptions are aggregated in AggregateException when multiple tasks fail. +/// +public static class ParallelTaskHelper +{ + /// + /// Executes an async action for each item in an array, in parallel. + /// Uses pre-allocated task array to avoid LINQ allocations. + /// + /// The type of items to process. + /// The array of items to process. + /// The async action to execute for each item. + /// A task that completes when all items have been processed. + public static async Task ForEachAsync(T[] items, Func action) + { + if (items.Length == 0) + { + return; + } + + var tasks = new Task[items.Length]; + for (var i = 0; i < items.Length; i++) + { + tasks[i] = action(items[i]); + } + + await Task.WhenAll(tasks); + } + + /// + /// Executes an async action for each item in an array, in parallel, with cancellation support. + /// Uses pre-allocated task array to avoid LINQ allocations. + /// + /// The type of items to process. + /// The array of items to process. + /// The async action to execute for each item. + /// Token to cancel the operation. + /// A task that completes when all items have been processed. + public static async Task ForEachAsync(T[] items, Func action, CancellationToken cancellationToken) + { + if (items.Length == 0) + { + return; + } + + cancellationToken.ThrowIfCancellationRequested(); + + var tasks = new Task[items.Length]; + for (var i = 0; i < items.Length; i++) + { + cancellationToken.ThrowIfCancellationRequested(); + tasks[i] = action(items[i], cancellationToken); + } + + await Task.WhenAll(tasks); + } + + /// + /// Executes an async action for each item in an array, in parallel, with an index. + /// Uses pre-allocated task array to avoid LINQ allocations. + /// + /// The type of items to process. + /// The array of items to process. + /// The async action to execute for each item with its index. + /// A task that completes when all items have been processed. + public static async Task ForEachWithIndexAsync(T[] items, Func action) + { + if (items.Length == 0) + { + return; + } + + var tasks = new Task[items.Length]; + for (var i = 0; i < items.Length; i++) + { + tasks[i] = action(items[i], i); + } + + await Task.WhenAll(tasks); + } + + /// + /// Executes an async action for each item in an array, in parallel, with an index and cancellation support. + /// Uses pre-allocated task array to avoid LINQ allocations. + /// + /// The type of items to process. + /// The array of items to process. + /// The async action to execute for each item with its index. + /// Token to cancel the operation. + /// A task that completes when all items have been processed. + public static async Task ForEachWithIndexAsync(T[] items, Func action, CancellationToken cancellationToken) + { + if (items.Length == 0) + { + return; + } + + cancellationToken.ThrowIfCancellationRequested(); + + var tasks = new Task[items.Length]; + for (var i = 0; i < items.Length; i++) + { + cancellationToken.ThrowIfCancellationRequested(); + tasks[i] = action(items[i], i, cancellationToken); + } + + await Task.WhenAll(tasks); + } + + /// + /// Executes an async action for each item in a list, in parallel. + /// Uses pre-allocated task array to avoid LINQ allocations. + /// + /// The type of items to process. + /// The list of items to process. + /// The async action to execute for each item. + /// A task that completes when all items have been processed. + public static async Task ForEachAsync(IReadOnlyList items, Func action) + { + if (items.Count == 0) + { + return; + } + + var tasks = new Task[items.Count]; + for (var i = 0; i < items.Count; i++) + { + tasks[i] = action(items[i]); + } + + await Task.WhenAll(tasks); + } + + /// + /// Executes an async action for each item in a list, in parallel, with cancellation support. + /// Uses pre-allocated task array to avoid LINQ allocations. + /// + /// The type of items to process. + /// The list of items to process. + /// The async action to execute for each item. + /// Token to cancel the operation. + /// A task that completes when all items have been processed. + public static async Task ForEachAsync(IReadOnlyList items, Func action, CancellationToken cancellationToken) + { + if (items.Count == 0) + { + return; + } + + cancellationToken.ThrowIfCancellationRequested(); + + var tasks = new Task[items.Count]; + for (var i = 0; i < items.Count; i++) + { + cancellationToken.ThrowIfCancellationRequested(); + tasks[i] = action(items[i], cancellationToken); + } + + await Task.WhenAll(tasks); + } +} diff --git a/TUnit.Core/Helpers/ReferenceEqualityComparer.cs b/TUnit.Core/Helpers/ReferenceEqualityComparer.cs index 2c26bfb5d1..16da75d77b 100644 --- a/TUnit.Core/Helpers/ReferenceEqualityComparer.cs +++ b/TUnit.Core/Helpers/ReferenceEqualityComparer.cs @@ -1,7 +1,34 @@ -namespace TUnit.Core.Helpers; +using System.Runtime.CompilerServices; -public class ReferenceEqualityComparer : IEqualityComparer +namespace TUnit.Core.Helpers; + +/// +/// Compares objects by reference identity, not value equality. +/// Uses RuntimeHelpers.GetHashCode to get identity-based hash codes. +/// +public sealed class ReferenceEqualityComparer : IEqualityComparer { + /// + /// Singleton instance to avoid repeated allocations. + /// + public static readonly ReferenceEqualityComparer Instance = new(); + + /// + /// Private constructor to enforce singleton pattern. + /// + private ReferenceEqualityComparer() + { + } + + /// + /// Compares two objects by reference identity. + /// + /// + /// The 'new' keyword is used because this method explicitly implements + /// IEqualityComparer<object>.Equals with nullable parameters, which + /// hides the inherited static Object.Equals(object?, object?) method. + /// This is intentional and provides the correct behavior for reference equality. + /// public new bool Equals(object? x, object? y) { return ReferenceEquals(x, y); @@ -9,6 +36,8 @@ public class ReferenceEqualityComparer : IEqualityComparer public int GetHashCode(object obj) { - return obj.GetHashCode(); + // Use RuntimeHelpers.GetHashCode for identity-based hash code + // This returns the same value as Object.GetHashCode() would if not overridden + return RuntimeHelpers.GetHashCode(obj); } } diff --git a/TUnit.Core/Interfaces/IDisposer.cs b/TUnit.Core/Interfaces/IDisposer.cs new file mode 100644 index 0000000000..039665dbae --- /dev/null +++ b/TUnit.Core/Interfaces/IDisposer.cs @@ -0,0 +1,16 @@ +namespace TUnit.Core.Interfaces; + +/// +/// Interface for disposing objects. +/// Follows Dependency Inversion Principle - high-level modules depend on this abstraction. +/// +public interface IDisposer +{ + /// + /// Disposes an object asynchronously. + /// Implementations should propagate exceptions - callers handle aggregation. + /// + /// The object to dispose. + /// A task representing the disposal operation. + ValueTask DisposeAsync(object? obj); +} diff --git a/TUnit.Core/Interfaces/IInitializationCallback.cs b/TUnit.Core/Interfaces/IInitializationCallback.cs new file mode 100644 index 0000000000..4e6449dd6e --- /dev/null +++ b/TUnit.Core/Interfaces/IInitializationCallback.cs @@ -0,0 +1,33 @@ +using System.Collections.Concurrent; + +namespace TUnit.Core.Interfaces; + +/// +/// Defines a callback interface for object initialization during property injection. +/// +/// +/// +/// This interface is used to break circular dependencies between property injection +/// and initialization services. Property injectors can call back to the initialization +/// service without directly depending on it. +/// +/// +internal interface IInitializationCallback +{ + /// + /// Ensures an object is fully initialized (property injection + IAsyncInitializer). + /// + /// The type of object to initialize. + /// The object to initialize. + /// Shared object bag for the test context. + /// Method metadata for the test. Can be null. + /// Test context events for tracking. + /// A token to monitor for cancellation requests. + /// The initialized object. + ValueTask EnsureInitializedAsync( + T obj, + ConcurrentDictionary? objectBag = null, + MethodMetadata? methodMetadata = null, + TestContextEvents? events = null, + CancellationToken cancellationToken = default) where T : notnull; +} diff --git a/TUnit.Core/Interfaces/IObjectGraphDiscoverer.cs b/TUnit.Core/Interfaces/IObjectGraphDiscoverer.cs new file mode 100644 index 0000000000..8ac4586be3 --- /dev/null +++ b/TUnit.Core/Interfaces/IObjectGraphDiscoverer.cs @@ -0,0 +1,132 @@ +using System.Collections.Concurrent; + +namespace TUnit.Core.Interfaces; + +/// +/// Defines a contract for discovering object graphs from test contexts. +/// Pure query interface - only reads and returns data, does not modify state. +/// +/// +/// +/// Object graph discovery is used to find all objects that need initialization or disposal, +/// organized by their nesting depth in the object hierarchy. +/// +/// +/// The discoverer traverses: +/// +/// Test class constructor arguments +/// Test method arguments +/// Injected property values +/// Nested objects that implement +/// +/// +/// +/// For tracking operations that modify TestContext.TrackedObjects, see . +/// +/// +internal interface IObjectGraphDiscoverer +{ + /// + /// Discovers all objects from a test context, organized by depth level. + /// + /// The test context to discover objects from. + /// Optional cancellation token for long-running discovery. + /// + /// An containing all discovered objects organized by depth. + /// Depth 0 contains root objects (arguments and property values). + /// Higher depths contain nested objects. + /// + IObjectGraph DiscoverObjectGraph(TestContext testContext, CancellationToken cancellationToken = default); + + /// + /// Discovers nested objects from a single root object, organized by depth. + /// + /// The root object to discover nested objects from. + /// Optional cancellation token for long-running discovery. + /// + /// An containing all discovered objects organized by depth. + /// Depth 0 contains the root object itself. + /// Higher depths contain nested objects. + /// + IObjectGraph DiscoverNestedObjectGraph(object rootObject, CancellationToken cancellationToken = default); + + /// + /// Discovers objects and populates the test context's tracked objects dictionary directly. + /// Used for efficient object tracking without intermediate allocations. + /// + /// The test context to discover objects from and populate. + /// Optional cancellation token for long-running discovery. + /// + /// The tracked objects dictionary (same as testContext.TrackedObjects) populated with discovered objects. + /// + /// + /// This method modifies testContext.TrackedObjects directly. For pure query operations, + /// use instead. + /// + ConcurrentDictionary> DiscoverAndTrackObjects(TestContext testContext, CancellationToken cancellationToken = default); +} + +/// +/// Marker interface for object graph tracking operations. +/// Extends with operations that modify state. +/// +/// +/// +/// This interface exists to support Interface Segregation Principle: +/// clients that only need query operations can depend on , +/// while clients that need tracking can depend on . +/// +/// +/// Currently inherits all methods from . +/// The distinction exists for semantic clarity and future extensibility. +/// +/// +internal interface IObjectGraphTracker : IObjectGraphDiscoverer +{ + // All methods inherited from IObjectGraphDiscoverer + // This interface provides semantic clarity for tracking operations +} + +/// +/// Represents a discovered object graph organized by depth level. +/// +/// +/// Collections are exposed as read-only to prevent callers from corrupting internal state. +/// Use and for safe iteration. +/// +internal interface IObjectGraph +{ + /// + /// Gets objects organized by depth (0 = root arguments, 1+ = nested). + /// + /// + /// Returns a read-only view. Use for iteration. + /// + IReadOnlyDictionary> ObjectsByDepth { get; } + + /// + /// Gets all unique objects in the graph. + /// + /// + /// Returns a read-only view to prevent modification. + /// + IReadOnlyCollection AllObjects { get; } + + /// + /// Gets the maximum nesting depth (-1 if empty). + /// + int MaxDepth { get; } + + /// + /// Gets objects at a specific depth level. + /// + /// The depth level to retrieve objects from. + /// An enumerable of objects at the specified depth, or empty if none exist. + IEnumerable GetObjectsAtDepth(int depth); + + /// + /// Gets depth levels in descending order (deepest first). + /// + /// An enumerable of depth levels ordered from deepest to shallowest. + IEnumerable GetDepthsDescending(); +} diff --git a/TUnit.Core/Interfaces/IObjectInitializationService.cs b/TUnit.Core/Interfaces/IObjectInitializationService.cs new file mode 100644 index 0000000000..6e5eb70338 --- /dev/null +++ b/TUnit.Core/Interfaces/IObjectInitializationService.cs @@ -0,0 +1,67 @@ +namespace TUnit.Core.Interfaces; + +/// +/// Defines a contract for managing object initialization with phase awareness. +/// +/// +/// +/// This service provides thread-safe, deduplicated initialization of objects that implement +/// or . +/// +/// +/// The service supports two initialization phases: +/// +/// Discovery phase: Only objects are initialized +/// Execution phase: All objects are initialized +/// +/// +/// +internal interface IObjectInitializationService +{ + /// + /// Initializes an object during the execution phase. + /// + /// The object to initialize. If null or not an , no action is taken. + /// A token to monitor for cancellation requests. + /// A representing the asynchronous operation. + /// + /// + /// This method is thread-safe and ensures that each object is initialized exactly once. + /// Multiple concurrent calls for the same object will share the same initialization task. + /// + /// + ValueTask InitializeAsync(object? obj, CancellationToken cancellationToken = default); + + /// + /// Initializes an object during the discovery phase. + /// + /// The object to initialize. If null or not an , no action is taken. + /// A token to monitor for cancellation requests. + /// A representing the asynchronous operation. + /// + /// + /// Only objects implementing are initialized during discovery. + /// Regular objects are deferred to execution phase. + /// + /// + ValueTask InitializeForDiscoveryAsync(object? obj, CancellationToken cancellationToken = default); + + /// + /// Checks if an object has been successfully initialized. + /// + /// The object to check. + /// True if the object has been initialized successfully; otherwise, false. + /// + /// Returns false if the object is null, not an , + /// has not been initialized yet, or if initialization failed. + /// + bool IsInitialized(object? obj); + + /// + /// Clears the initialization cache. + /// + /// + /// Called at the end of a test session to release resources. + /// + void ClearCache(); +} diff --git a/TUnit.Core/ObjectInitializer.cs b/TUnit.Core/ObjectInitializer.cs index 362445816e..46c52923e8 100644 --- a/TUnit.Core/ObjectInitializer.cs +++ b/TUnit.Core/ObjectInitializer.cs @@ -1,52 +1,131 @@ -using System.Runtime.CompilerServices; +using System.Collections.Concurrent; +using TUnit.Core.Helpers; using TUnit.Core.Interfaces; +using TUnit.Core.Services; namespace TUnit.Core; -public static class ObjectInitializer +/// +/// Static facade for initializing objects that implement . +/// Provides thread-safe, deduplicated initialization with explicit phase control. +/// +/// +/// +/// Use during test discovery - only objects are initialized. +/// Use during test execution - all objects are initialized. +/// +/// +/// For dependency injection scenarios, use directly. +/// +/// +internal static class ObjectInitializer { - private static readonly ConditionalWeakTable _initializationTasks = new(); - private static readonly Lock _lock = new(); + // Use Lazy pattern to ensure InitializeAsync is called exactly once per object, + // even under contention. GetOrAdd's factory can be called multiple times, but with + // Lazy + ExecutionAndPublication mode, only one initialization actually runs. + private static readonly ConcurrentDictionary> InitializationTasks = + new(Helpers.ReferenceEqualityComparer.Instance); - internal static bool IsInitialized(object? obj) + /// + /// Initializes an object during the discovery phase. + /// Only objects implementing IAsyncDiscoveryInitializer are initialized. + /// Regular IAsyncInitializer objects are skipped (deferred to execution phase). + /// Thread-safe with deduplication - safe to call multiple times. + /// + /// The object to potentially initialize. + /// Cancellation token. + internal static ValueTask InitializeForDiscoveryAsync(object? obj, CancellationToken cancellationToken = default) { - if (obj is not IAsyncInitializer) + // During discovery, only initialize IAsyncDiscoveryInitializer + if (obj is not IAsyncDiscoveryInitializer asyncDiscoveryInitializer) { - return false; + return default; } - lock (_lock) + return InitializeCoreAsync(obj, asyncDiscoveryInitializer, cancellationToken); + } + + /// + /// Initializes an object during the execution phase. + /// All objects implementing IAsyncInitializer are initialized. + /// Thread-safe with deduplication - safe to call multiple times. + /// + /// The object to potentially initialize. + /// Cancellation token. + internal static ValueTask InitializeAsync(object? obj, CancellationToken cancellationToken = default) + { + if (obj is not IAsyncInitializer asyncInitializer) { - return _initializationTasks.TryGetValue(obj, out var task) && task.IsCompleted; + return default; } + + return InitializeCoreAsync(obj, asyncInitializer, cancellationToken); } - public static async ValueTask InitializeAsync(object? obj, CancellationToken cancellationToken = default) + /// + /// Checks if an object has been successfully initialized by ObjectInitializer. + /// + /// The object to check. + /// True if the object has been initialized successfully; otherwise, false. + /// + /// Returns false if the object is null, not an , + /// has not been initialized yet, or if initialization failed. + /// + internal static bool IsInitialized(object? obj) { - if (obj is IAsyncInitializer asyncInitializer) + if (obj is not IAsyncInitializer) { - await GetInitializationTask(obj, asyncInitializer, cancellationToken); + return false; } + + // Use Status == RanToCompletion to ensure we don't return true for faulted/canceled tasks + // (IsCompletedSuccessfully is not available in netstandard2.0) + // With Lazy, we need to check if the Lazy has a value AND that value completed successfully + return InitializationTasks.TryGetValue(obj, out var lazyTask) && + lazyTask.IsValueCreated && + lazyTask.Value.Status == TaskStatus.RanToCompletion; } - private static async Task GetInitializationTask(object obj, IAsyncInitializer asyncInitializer, CancellationToken cancellationToken) + /// + /// Clears the initialization cache. + /// + /// + /// Called at the end of a test session to release resources. + /// + internal static void ClearCache() { - Task initializationTask; + InitializationTasks.Clear(); + } - lock (_lock) + private static async ValueTask InitializeCoreAsync( + object obj, + IAsyncInitializer asyncInitializer, + CancellationToken cancellationToken) + { + // Use Lazy with ExecutionAndPublication mode to ensure InitializeAsync + // is called exactly once, even under contention. GetOrAdd's factory may be + // called multiple times, but Lazy ensures only one initialization runs. + var lazyTask = InitializationTasks.GetOrAdd(obj, + _ => new Lazy( + () => asyncInitializer.InitializeAsync(), + LazyThreadSafetyMode.ExecutionAndPublication)); + + try { - if (_initializationTasks.TryGetValue(obj, out var existingTask)) - { - initializationTask = existingTask; - } - else - { - initializationTask = asyncInitializer.InitializeAsync(); - _initializationTasks.Add(obj, initializationTask); - } + // Wait for initialization with cancellation support + await lazyTask.Value.WaitAsync(cancellationToken); + } + catch (OperationCanceledException) + { + // Propagate cancellation without modification + throw; + } + catch + { + // Remove failed initialization from cache to allow retry + // This is important for transient failures that may succeed on retry + InitializationTasks.TryRemove(obj, out _); + throw; } - - // Wait for initialization with cancellation support - await initializationTask.WaitAsync(cancellationToken); } } diff --git a/TUnit.Core/PropertyInjection/Initialization/PropertyInitializationContext.cs b/TUnit.Core/PropertyInjection/Initialization/PropertyInitializationContext.cs index 12aef89941..85d4f6ae80 100644 --- a/TUnit.Core/PropertyInjection/Initialization/PropertyInitializationContext.cs +++ b/TUnit.Core/PropertyInjection/Initialization/PropertyInitializationContext.cs @@ -7,6 +7,7 @@ namespace TUnit.Core.PropertyInjection.Initialization; /// /// Encapsulates all context needed for property initialization. /// Follows Single Responsibility Principle by being a pure data container. +/// Provides factory methods to reduce duplication when creating contexts (DRY). /// internal sealed class PropertyInitializationContext { @@ -84,4 +85,124 @@ internal sealed class PropertyInitializationContext /// Parent object for nested properties. /// public object? ParentInstance { get; init; } + + #region Factory Methods (DRY) + + /// + /// Creates a context for source-generated property injection. + /// + public static PropertyInitializationContext ForSourceGenerated( + object instance, + PropertyInjectionMetadata metadata, + ConcurrentDictionary objectBag, + MethodMetadata? methodMetadata, + TestContextEvents events, + ConcurrentDictionary visitedObjects, + TestContext? testContext, + bool isNestedProperty = false) + { + return new PropertyInitializationContext + { + Instance = instance, + SourceGeneratedMetadata = metadata, + PropertyName = metadata.PropertyName, + PropertyType = metadata.PropertyType, + PropertySetter = metadata.SetProperty, + ObjectBag = objectBag, + MethodMetadata = methodMetadata, + Events = events, + VisitedObjects = visitedObjects, + TestContext = testContext, + IsNestedProperty = isNestedProperty + }; + } + + /// + /// Creates a context for reflection-based property injection. + /// + public static PropertyInitializationContext ForReflection( + object instance, + PropertyInfo property, + IDataSourceAttribute dataSource, + Action propertySetter, + ConcurrentDictionary objectBag, + MethodMetadata? methodMetadata, + TestContextEvents events, + ConcurrentDictionary visitedObjects, + TestContext? testContext, + bool isNestedProperty = false) + { + return new PropertyInitializationContext + { + Instance = instance, + PropertyInfo = property, + DataSource = dataSource, + PropertyName = property.Name, + PropertyType = property.PropertyType, + PropertySetter = propertySetter, + ObjectBag = objectBag, + MethodMetadata = methodMetadata, + Events = events, + VisitedObjects = visitedObjects, + TestContext = testContext, + IsNestedProperty = isNestedProperty + }; + } + + /// + /// Creates a context for caching during registration (uses placeholder instance). + /// + public static PropertyInitializationContext ForCaching( + PropertyInjectionMetadata metadata, + ConcurrentDictionary objectBag, + MethodMetadata? methodMetadata, + TestContextEvents events, + TestContext testContext) + { + return new PropertyInitializationContext + { + Instance = PlaceholderInstance.Instance, + SourceGeneratedMetadata = metadata, + PropertyName = metadata.PropertyName, + PropertyType = metadata.PropertyType, + PropertySetter = metadata.SetProperty, + ObjectBag = objectBag, + MethodMetadata = methodMetadata, + Events = events, + VisitedObjects = new ConcurrentDictionary(), + TestContext = testContext, + IsNestedProperty = false + }; + } + + /// + /// Creates a context for reflection caching during registration (uses placeholder instance). + /// + public static PropertyInitializationContext ForReflectionCaching( + PropertyInfo property, + IDataSourceAttribute dataSource, + Action propertySetter, + ConcurrentDictionary objectBag, + MethodMetadata? methodMetadata, + TestContextEvents events, + TestContext testContext) + { + return new PropertyInitializationContext + { + Instance = PlaceholderInstance.Instance, + PropertyInfo = property, + DataSource = dataSource, + PropertyName = property.Name, + PropertyType = property.PropertyType, + PropertySetter = propertySetter, + ObjectBag = objectBag, + MethodMetadata = methodMetadata, + Events = events, + VisitedObjects = new ConcurrentDictionary(), + TestContext = testContext, + IsNestedProperty = false + }; + } + + #endregion } \ No newline at end of file diff --git a/TUnit.Core/PropertyInjection/PropertyCacheKeyGenerator.cs b/TUnit.Core/PropertyInjection/PropertyCacheKeyGenerator.cs new file mode 100644 index 0000000000..c98abadadb --- /dev/null +++ b/TUnit.Core/PropertyInjection/PropertyCacheKeyGenerator.cs @@ -0,0 +1,36 @@ +using System.Reflection; +using TUnit.Core.Interfaces.SourceGenerator; + +namespace TUnit.Core.PropertyInjection; + +/// +/// Generates consistent cache keys for property injection values. +/// Centralizes cache key generation to ensure consistency across the codebase (DRY principle). +/// +/// +/// Cache keys are formatted as "{DeclaringTypeName}.{PropertyName}" to uniquely identify +/// properties across different types. This format is used for storing and retrieving +/// injected property values in test contexts. +/// +public static class PropertyCacheKeyGenerator +{ + /// + /// Generates a cache key from source-generated property metadata. + /// + /// The property injection metadata from source generation. + /// A unique cache key string for the property. + public static string GetCacheKey(PropertyInjectionMetadata metadata) + { + return $"{metadata.ContainingType.FullName}.{metadata.PropertyName}"; + } + + /// + /// Generates a cache key from a PropertyInfo (reflection-based properties). + /// + /// The PropertyInfo from reflection. + /// A unique cache key string for the property. + public static string GetCacheKey(PropertyInfo property) + { + return $"{property.DeclaringType!.FullName}.{property.Name}"; + } +} diff --git a/TUnit.Core/PropertyInjection/PropertyInjectionPlanBuilder.cs b/TUnit.Core/PropertyInjection/PropertyInjectionPlanBuilder.cs index c0251ce158..9c950da9c5 100644 --- a/TUnit.Core/PropertyInjection/PropertyInjectionPlanBuilder.cs +++ b/TUnit.Core/PropertyInjection/PropertyInjectionPlanBuilder.cs @@ -10,6 +10,22 @@ namespace TUnit.Core.PropertyInjection; /// internal static class PropertyInjectionPlanBuilder { + /// + /// Walks up the inheritance chain from the given type to typeof(object), + /// invoking the action for each type in the hierarchy. + /// + /// The starting type. + /// The action to invoke for each type in the inheritance chain. + private static void WalkInheritanceChain(Type type, Action action) + { + var currentType = type; + while (currentType != null && currentType != typeof(object)) + { + action(currentType); + currentType = currentType.BaseType; + } + } + /// /// Creates an injection plan for source-generated mode. /// Walks the inheritance chain to include all injectable properties from base classes. @@ -20,8 +36,7 @@ public static PropertyInjectionPlan BuildSourceGeneratedPlan(Type type) var processedProperties = new HashSet(); // Walk up the inheritance chain to find all properties with data sources - var currentType = type; - while (currentType != null && currentType != typeof(object)) + WalkInheritanceChain(type, currentType => { var propertySource = PropertySourceRegistry.GetSource(currentType); if (propertySource?.ShouldInitialize == true) @@ -35,9 +50,7 @@ public static PropertyInjectionPlan BuildSourceGeneratedPlan(Type type) } } } - - currentType = currentType.BaseType; - } + }); var sourceGenProps = allProperties.ToArray(); @@ -62,8 +75,7 @@ public static PropertyInjectionPlan BuildReflectionPlan(Type type) var processedProperties = new HashSet(); // Walk up the inheritance chain to find all properties with data source attributes - var currentType = type; - while (currentType != null && currentType != typeof(object)) + WalkInheritanceChain(type, currentType => { var properties = currentType.GetProperties( BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance | BindingFlags.Static | BindingFlags.DeclaredOnly) @@ -87,9 +99,7 @@ public static PropertyInjectionPlan BuildReflectionPlan(Type type) } } } - - currentType = currentType.BaseType; - } + }); return new PropertyInjectionPlan { @@ -106,7 +116,7 @@ public static PropertyInjectionPlan BuildReflectionPlan(Type type) /// This handles generic types like ErrFixture<MyType> where the source generator /// couldn't register a property source for the closed generic type. /// - [UnconditionalSuppressMessage("Trimming", "IL2026:Members annotated with 'RequiresUnreferencedCodeAttribute' require dynamic access otherwise can break functionality when trimming application code", Justification = "Source gen mode has its own path>")] + [UnconditionalSuppressMessage("Trimming", "IL2026:Members annotated with 'RequiresUnreferencedCodeAttribute' require dynamic access otherwise can break functionality when trimming application code", Justification = "Source gen mode has its own path")] public static PropertyInjectionPlan Build(Type type) { if (!SourceRegistrar.IsEnabled) @@ -134,6 +144,7 @@ public static PropertyInjectionPlan Build(Type type) /// /// Represents a plan for injecting properties into an object. +/// Provides iterator methods to abstract source-gen vs reflection branching (DRY). /// internal sealed class PropertyInjectionPlan { @@ -141,4 +152,100 @@ internal sealed class PropertyInjectionPlan public required PropertyInjectionMetadata[] SourceGeneratedProperties { get; init; } public required (PropertyInfo Property, IDataSourceAttribute DataSource)[] ReflectionProperties { get; init; } public required bool HasProperties { get; init; } + + /// + /// Iterates over all properties in the plan, abstracting source-gen vs reflection. + /// Call the appropriate callback based on which mode has properties. + /// + /// Action to invoke for each source-generated property. + /// Action to invoke for each reflection property. + public void ForEachProperty( + Action onSourceGenerated, + Action<(PropertyInfo Property, IDataSourceAttribute DataSource)> onReflection) + { + if (SourceGeneratedProperties.Length > 0) + { + foreach (var metadata in SourceGeneratedProperties) + { + onSourceGenerated(metadata); + } + } + else if (ReflectionProperties.Length > 0) + { + foreach (var prop in ReflectionProperties) + { + onReflection(prop); + } + } + } + + /// + /// Iterates over all properties in the plan asynchronously. + /// + public async Task ForEachPropertyAsync( + Func onSourceGenerated, + Func<(PropertyInfo Property, IDataSourceAttribute DataSource), Task> onReflection) + { + if (SourceGeneratedProperties.Length > 0) + { + foreach (var metadata in SourceGeneratedProperties) + { + await onSourceGenerated(metadata); + } + } + else if (ReflectionProperties.Length > 0) + { + foreach (var prop in ReflectionProperties) + { + await onReflection(prop); + } + } + } + + /// + /// Executes actions for all properties in parallel. + /// + public Task ForEachPropertyParallelAsync( + Func onSourceGenerated, + Func<(PropertyInfo Property, IDataSourceAttribute DataSource), Task> onReflection) + { + if (SourceGeneratedProperties.Length > 0) + { + return Helpers.ParallelTaskHelper.ForEachAsync(SourceGeneratedProperties, onSourceGenerated); + } + else if (ReflectionProperties.Length > 0) + { + return Helpers.ParallelTaskHelper.ForEachAsync(ReflectionProperties, onReflection); + } + + return Task.CompletedTask; + } + + /// + /// Gets property values from an instance, abstracting source-gen vs reflection. + /// + public IEnumerable GetPropertyValues(object instance) + { + if (SourceGeneratedProperties.Length > 0) + { + foreach (var metadata in SourceGeneratedProperties) + { + var property = metadata.ContainingType.GetProperty(metadata.PropertyName); + if (property?.CanRead == true) + { + yield return property.GetValue(instance); + } + } + } + else if (ReflectionProperties.Length > 0) + { + foreach (var (property, _) in ReflectionProperties) + { + if (property.CanRead) + { + yield return property.GetValue(instance); + } + } + } + } } diff --git a/TUnit.Core/PropertyInjection/PropertySetterFactory.cs b/TUnit.Core/PropertyInjection/PropertySetterFactory.cs index 3d9dce651d..8bc1662cdf 100644 --- a/TUnit.Core/PropertyInjection/PropertySetterFactory.cs +++ b/TUnit.Core/PropertyInjection/PropertySetterFactory.cs @@ -1,21 +1,56 @@ -using System.Diagnostics.CodeAnalysis; +using System.Collections.Concurrent; +using System.Diagnostics.CodeAnalysis; using System.Reflection; namespace TUnit.Core.PropertyInjection; /// -/// Factory for creating property setters. +/// Factory for creating property setters with caching for performance. /// Consolidates all property setter creation logic in one place following DRY principle. /// +/// +/// Setters are cached using the PropertyInfo as the key to avoid repeated reflection calls. +/// This significantly improves performance when the same property is accessed multiple times +/// (e.g., in test retries or shared test data scenarios). +/// internal static class PropertySetterFactory { + // Cache setters per PropertyInfo to avoid repeated reflection + private static readonly ConcurrentDictionary> SetterCache = new(); + + /// + /// Gets or creates a setter delegate for the given property. + /// Uses caching to avoid repeated reflection calls. + /// + #if NET6_0_OR_GREATER + [RequiresUnreferencedCode("Backing field access for init-only properties requires reflection")] + #endif + public static Action GetOrCreateSetter(PropertyInfo property) + { + return SetterCache.GetOrAdd(property, CreateSetterCore); + } + /// /// Creates a setter delegate for the given property. + /// Consider using for better performance through caching. /// #if NET6_0_OR_GREATER [RequiresUnreferencedCode("Backing field access for init-only properties requires reflection")] #endif public static Action CreateSetter(PropertyInfo property) + { + // Delegate to cached version for consistency + return GetOrCreateSetter(property); + } + + /// + /// Core implementation for creating a setter delegate. + /// Called by GetOrCreateSetter for caching. + /// + #if NET6_0_OR_GREATER + [RequiresUnreferencedCode("Backing field access for init-only properties requires reflection")] + #endif + private static Action CreateSetterCore(PropertyInfo property) { if (property.CanWrite && property.SetMethod != null) { diff --git a/TUnit.Core/Services/ObjectInitializationService.cs b/TUnit.Core/Services/ObjectInitializationService.cs new file mode 100644 index 0000000000..a18845001c --- /dev/null +++ b/TUnit.Core/Services/ObjectInitializationService.cs @@ -0,0 +1,40 @@ +using TUnit.Core.Interfaces; + +namespace TUnit.Core.Services; + +/// +/// Thread-safe service for initializing objects that implement . +/// Provides deduplicated initialization with explicit phase control. +/// +/// +/// +/// This service delegates to the static to ensure consistent +/// behavior and avoid duplicate caches. This consolidates initialization tracking in one place. +/// +/// +internal sealed class ObjectInitializationService : IObjectInitializationService +{ + /// + /// Creates a new instance of the initialization service. + /// + public ObjectInitializationService() + { + // No local cache needed - delegates to static ObjectInitializer + } + + /// + public ValueTask InitializeForDiscoveryAsync(object? obj, CancellationToken cancellationToken = default) + => ObjectInitializer.InitializeForDiscoveryAsync(obj, cancellationToken); + + /// + public ValueTask InitializeAsync(object? obj, CancellationToken cancellationToken = default) + => ObjectInitializer.InitializeAsync(obj, cancellationToken); + + /// + public bool IsInitialized(object? obj) + => ObjectInitializer.IsInitialized(obj); + + /// + public void ClearCache() + => ObjectInitializer.ClearCache(); +} diff --git a/TUnit.Core/TestBuilderContext.cs b/TUnit.Core/TestBuilderContext.cs index 25b20c65d4..f6ec4b0eef 100644 --- a/TUnit.Core/TestBuilderContext.cs +++ b/TUnit.Core/TestBuilderContext.cs @@ -49,7 +49,8 @@ public void RegisterForInitialization(object? obj) { Events.OnInitialize += async (sender, args) => { - await ObjectInitializer.InitializeAsync(obj); + // Discovery: only IAsyncDiscoveryInitializer + await ObjectInitializer.InitializeForDiscoveryAsync(obj); }; } diff --git a/TUnit.Core/TestContext.cs b/TUnit.Core/TestContext.cs index af1b1710d9..6ef95724de 100644 --- a/TUnit.Core/TestContext.cs +++ b/TUnit.Core/TestContext.cs @@ -49,7 +49,8 @@ public TestContext(string testName, IServiceProvider serviceProvider, ClassHookC private static readonly AsyncLocal TestContexts = new(); - internal static readonly Dictionary> InternalParametersDictionary = new(); + // Use ConcurrentDictionary for thread-safe access during parallel test discovery + internal static readonly ConcurrentDictionary> InternalParametersDictionary = new(); private StringWriter? _outputWriter; @@ -72,7 +73,21 @@ internal set public static IReadOnlyDictionary> Parameters => InternalParametersDictionary; - public static IConfiguration Configuration { get; internal set; } = null!; + private static IConfiguration? _configuration; + + /// + /// Gets the test configuration. Throws a descriptive exception if accessed before initialization. + /// + /// Thrown if Configuration is accessed before the test engine initializes it. + public static IConfiguration Configuration + { + get => _configuration ?? throw new InvalidOperationException( + "TestContext.Configuration has not been initialized. " + + "This property is only available after the TUnit test engine has started. " + + "If you are accessing this from a static constructor or field initializer, " + + "consider moving the code to a test setup method or test body instead."); + internal set => _configuration = value; + } public static string? OutputDirectory { @@ -158,8 +173,13 @@ internal override void SetAsyncLocalContext() internal AbstractExecutableTest InternalExecutableTest { get; set; } = null!; private ConcurrentDictionary>? _trackedObjects; + + /// + /// Thread-safe lazy initialization of TrackedObjects using LazyInitializer + /// to prevent race conditions when multiple threads access this property simultaneously. + /// internal ConcurrentDictionary> TrackedObjects => - _trackedObjects ??= new(); + LazyInitializer.EnsureInitialized(ref _trackedObjects)!; /// /// Sets the output captured during test building phase. diff --git a/TUnit.Core/Tracking/ObjectTracker.cs b/TUnit.Core/Tracking/ObjectTracker.cs index 8e702a6262..bb273255c5 100644 --- a/TUnit.Core/Tracking/ObjectTracker.cs +++ b/TUnit.Core/Tracking/ObjectTracker.cs @@ -7,19 +7,105 @@ namespace TUnit.Core.Tracking; /// /// Pure reference counting object tracker for disposable objects. /// Objects are disposed when their reference count reaches zero, regardless of sharing type. +/// Uses ReferenceEqualityComparer to track objects by identity, not value equality. /// +/// +/// The static s_trackedObjects dictionary is shared across all tests. +/// Call at the end of a test session to release memory. +/// internal class ObjectTracker(TrackableObjectGraphProvider trackableObjectGraphProvider, Disposer disposer) { - private static readonly ConcurrentDictionary _trackedObjects = new(); + // Use ReferenceEqualityComparer to prevent objects with custom Equals from sharing state + private static readonly ConcurrentDictionary s_trackedObjects = + new(Helpers.ReferenceEqualityComparer.Instance); + + // Lock for atomic decrement-check-dispose operations to prevent race conditions + private static readonly object s_disposalLock = new(); + + // Collects errors from async disposal callbacks for post-session review + private static readonly ConcurrentBag s_asyncCallbackErrors = new(); + + /// + /// Gets any errors that occurred during async disposal callbacks. + /// Check this at the end of a test session to surface hidden failures. + /// + public static IReadOnlyCollection GetAsyncCallbackErrors() => s_asyncCallbackErrors.ToArray(); + + /// + /// Clears all static tracking state. Call at the end of a test session to release memory. + /// + public static void ClearStaticTracking() + { + s_trackedObjects.Clear(); + s_asyncCallbackErrors.Clear(); + } + + /// + /// Gets an existing counter for the object or creates a new one. + /// Centralizes the GetOrAdd pattern to ensure consistent counter creation. + /// + private static Counter GetOrCreateCounter(object obj) => + s_trackedObjects.GetOrAdd(obj, static _ => new Counter()); + + /// + /// Flattens a ConcurrentDictionary of depth-keyed HashSets into a single HashSet. + /// Thread-safe: locks each HashSet while copying. + /// Pre-calculates capacity to avoid HashSet resizing during population. + /// + private static HashSet FlattenTrackedObjects(ConcurrentDictionary> trackedObjects) + { +#if NETSTANDARD2_0 + // .NET Standard 2.0 doesn't support HashSet capacity constructor + var result = new HashSet(Helpers.ReferenceEqualityComparer.Instance); +#else + // First pass: calculate total capacity to avoid resizing + var totalCapacity = 0; + foreach (var kvp in trackedObjects) + { + lock (kvp.Value) + { + totalCapacity += kvp.Value.Count; + } + } + + // Second pass: populate with pre-sized HashSet + var result = new HashSet(totalCapacity, Helpers.ReferenceEqualityComparer.Instance); +#endif + foreach (var kvp in trackedObjects) + { + lock (kvp.Value) + { + foreach (var obj in kvp.Value) + { + result.Add(obj); + } + } + } + + return result; + } public void TrackObjects(TestContext testContext) { - var alreadyTracked = testContext.TrackedObjects.SelectMany(x => x.Value).ToHashSet(); + // Get already tracked objects (DRY: use helper method) + var alreadyTracked = FlattenTrackedObjects(testContext.TrackedObjects); - var newTrackableObjects = trackableObjectGraphProvider.GetTrackableObjects(testContext) - .SelectMany(x => x.Value) - .Except(alreadyTracked) - .ToHashSet(); + // Get new trackable objects + var newTrackableObjects = new HashSet(Helpers.ReferenceEqualityComparer.Instance); + var trackableDict = trackableObjectGraphProvider.GetTrackableObjects(testContext); + foreach (var kvp in trackableDict) + { + lock (kvp.Value) + { + foreach (var obj in kvp.Value) + { + if (!alreadyTracked.Contains(obj)) + { + newTrackableObjects.Add(obj); + } + } + } + } foreach (var obj in newTrackableObjects) { @@ -29,9 +115,10 @@ public void TrackObjects(TestContext testContext) public async ValueTask UntrackObjects(TestContext testContext, List cleanupExceptions) { - foreach (var obj in testContext.TrackedObjects - .SelectMany(x => x.Value) - .ToHashSet()) + // Get all objects to untrack (DRY: use helper method) + var objectsToUntrack = FlattenTrackedObjects(testContext.TrackedObjects); + + foreach (var obj in objectsToUntrack) { try { @@ -70,7 +157,7 @@ private void TrackObject(object? obj) return; } - var counter = _trackedObjects.GetOrAdd(obj, static _ => new Counter()); + var counter = GetOrCreateCounter(obj); counter.Increment(); } @@ -81,20 +168,36 @@ private async ValueTask UntrackObject(object? obj) return; } - if (_trackedObjects.TryGetValue(obj, out var counter)) - { - var count = counter.Decrement(); + var shouldDispose = false; - if (count < 0) + // Use lock to make decrement-check-remove atomic and prevent race conditions + // where multiple tests could try to dispose the same object simultaneously + lock (s_disposalLock) + { + if (s_trackedObjects.TryGetValue(obj, out var counter)) { - throw new InvalidOperationException("Reference count for object went below zero. This indicates a bug in the reference counting logic."); - } + var count = counter.Decrement(); - if (count == 0) - { - await disposer.DisposeAsync(obj).ConfigureAwait(false); + if (count < 0) + { + throw new InvalidOperationException("Reference count for object went below zero. This indicates a bug in the reference counting logic."); + } + + if (count == 0) + { + // Remove from tracking dictionary to prevent memory leak + // Use TryRemove to ensure atomicity - only remove if still in dictionary + s_trackedObjects.TryRemove(obj, out _); + shouldDispose = true; + } } } + + // Dispose outside the lock to avoid blocking other untrack operations + if (shouldDispose) + { + await disposer.DisposeAsync(obj).ConfigureAwait(false); + } } /// @@ -105,37 +208,128 @@ private static bool ShouldSkipTracking(object? obj) return obj is not IDisposable and not IAsyncDisposable; } + /// + /// Registers a callback to be invoked when the object is disposed (ref count reaches 0). + /// If the object is already disposed (or was never tracked), the callback is invoked immediately. + /// The callback is guaranteed to be invoked exactly once (idempotent). + /// + /// The object to monitor for disposal. If null or not disposable, the method returns without action. + /// The callback to invoke on disposal. Must not be null. + /// Thrown when is null. public static void OnDisposed(object? o, Action action) { - if(o is not IDisposable and not IAsyncDisposable) +#if NET6_0_OR_GREATER + ArgumentNullException.ThrowIfNull(action); +#else + if (action == null) { - return; + throw new ArgumentNullException(nameof(action)); } +#endif - _trackedObjects.GetOrAdd(o, static _ => new Counter()) - .OnCountChanged += (_, count) => - { - if (count == 0) - { - action(); - } - }; + RegisterDisposalCallback(o, action, static a => a()); } + /// + /// Registers an async callback to be invoked when the object is disposed (ref count reaches 0). + /// If the object is already disposed (or was never tracked), the callback is invoked immediately. + /// The callback is guaranteed to be invoked exactly once (idempotent). + /// + /// The object to monitor for disposal. If null or not disposable, the method returns without action. + /// The async callback to invoke on disposal. Must not be null. + /// Thrown when is null. public static void OnDisposedAsync(object? o, Func asyncAction) { - if(o is not IDisposable and not IAsyncDisposable) +#if NET6_0_OR_GREATER + ArgumentNullException.ThrowIfNull(asyncAction); +#else + if (asyncAction == null) + { + throw new ArgumentNullException(nameof(asyncAction)); + } +#endif + + // Wrap async action in fire-and-forget with exception collection + RegisterDisposalCallback(o, asyncAction, static a => _ = SafeExecuteAsync(a)); + } + + /// + /// Core implementation for registering disposal callbacks. + /// Extracts common logic from OnDisposed and OnDisposedAsync (DRY principle). + /// + /// The type of action (Action or Func<Task>). + /// The object to monitor for disposal. + /// The callback action. + /// How to invoke the action (sync vs async wrapper). + private static void RegisterDisposalCallback( + object? o, + TAction action, + Action invoker) + where TAction : Delegate + { + if (o is not IDisposable and not IAsyncDisposable) { return; } - _trackedObjects.GetOrAdd(o, static _ => new Counter()) - .OnCountChanged += async (_, count) => + // Only register callback if the object is actually being tracked. + // If not tracked, invoke callback immediately (object is effectively "disposed"). + // This prevents creating spurious counters for untracked objects. + if (!s_trackedObjects.TryGetValue(o, out var counter)) { - if (count == 0) + // Object not tracked - invoke callback immediately + invoker(action); + return; + } + + // Use flag to ensure callback only fires once (idempotent) + var invoked = 0; + EventHandler? handler = null; + + handler = (sender, count) => + { + if (count == 0 && Interlocked.Exchange(ref invoked, 1) == 0) { - await asyncAction().ConfigureAwait(false); + // Remove handler to prevent memory leaks + if (sender is Counter c && handler != null) + { + c.OnCountChanged -= handler; + } + + invoker(action); } }; + + counter.OnCountChanged += handler; + + // Check if already disposed (count is 0) - invoke immediately if so + // This prevents lost callbacks when registering after disposal + // Idempotent check ensures this doesn't double-fire if event already triggered + if (counter.CurrentCount == 0 && Interlocked.Exchange(ref invoked, 1) == 0) + { + counter.OnCountChanged -= handler; + invoker(action); + } + } + + /// + /// Executes an async action safely, catching and collecting exceptions + /// for post-session review instead of silently swallowing them. + /// + private static async Task SafeExecuteAsync(Func asyncAction) + { + try + { + await asyncAction().ConfigureAwait(false); + } + catch (Exception ex) + { + // Collect error for post-session review instead of silently swallowing + s_asyncCallbackErrors.Add(ex); + +#if DEBUG + System.Diagnostics.Debug.WriteLine($"[ObjectTracker] Exception in OnDisposedAsync callback: {ex.Message}"); +#endif + } } } diff --git a/TUnit.Core/Tracking/TrackableObjectGraphProvider.cs b/TUnit.Core/Tracking/TrackableObjectGraphProvider.cs index 97301e53f8..2460406c6f 100644 --- a/TUnit.Core/Tracking/TrackableObjectGraphProvider.cs +++ b/TUnit.Core/Tracking/TrackableObjectGraphProvider.cs @@ -1,57 +1,47 @@ -using System.Collections.Concurrent; -using System.Diagnostics.CodeAnalysis; -using TUnit.Core.PropertyInjection; +using System.Collections.Concurrent; +using TUnit.Core.Discovery; +using TUnit.Core.Interfaces; using TUnit.Core.StaticProperties; namespace TUnit.Core.Tracking; +/// +/// Provides trackable objects from test contexts for lifecycle management. +/// Delegates to for the actual discovery logic. +/// internal class TrackableObjectGraphProvider { - public ConcurrentDictionary> GetTrackableObjects(TestContext testContext) - { - var visitedObjects = testContext.TrackedObjects; - - var testDetails = testContext.Metadata.TestDetails; - - foreach (var classArgument in testDetails.TestClassArguments) - { - if (classArgument != null && visitedObjects.GetOrAdd(0, []).Add(classArgument)) - { - AddNestedTrackableObjects(classArgument, visitedObjects, 1); - } - } + private readonly IObjectGraphDiscoverer _discoverer; - foreach (var methodArgument in testDetails.TestMethodArguments) - { - if (methodArgument != null && visitedObjects.GetOrAdd(0, []).Add(methodArgument)) - { - AddNestedTrackableObjects(methodArgument, visitedObjects, 1); - } - } - - foreach (var property in testDetails.TestClassInjectedPropertyArguments.Values) - { - if (property != null && visitedObjects.GetOrAdd(0, []).Add(property)) - { - AddNestedTrackableObjects(property, visitedObjects, 1); - } - } + /// + /// Creates a new instance with the default discoverer. + /// + public TrackableObjectGraphProvider() : this(new ObjectGraphDiscoverer()) + { + } - return visitedObjects; + /// + /// Creates a new instance with a custom discoverer (for testing). + /// + public TrackableObjectGraphProvider(IObjectGraphDiscoverer discoverer) + { + _discoverer = discoverer; } - private static void AddToLevel(Dictionary> objectsByLevel, int level, object obj) + /// + /// Gets trackable objects from a test context, organized by depth level. + /// Delegates to the shared IObjectGraphDiscoverer to eliminate code duplication. + /// + /// The test context to get trackable objects from. + /// Optional cancellation token for long-running discovery. + public ConcurrentDictionary> GetTrackableObjects(TestContext testContext, CancellationToken cancellationToken = default) { - if (!objectsByLevel.TryGetValue(level, out var list)) - { - list = []; - objectsByLevel[level] = list; - } - list.Add(obj); + // OCP-compliant: Use the interface method directly instead of type-checking + return _discoverer.DiscoverAndTrackObjects(testContext, cancellationToken); } /// - /// Get trackable objects for static properties (session-level) + /// Gets trackable objects for static properties (session-level). /// public IEnumerable GetStaticPropertyTrackableObjects() { @@ -63,67 +53,4 @@ public IEnumerable GetStaticPropertyTrackableObjects() } } } - - private void AddNestedTrackableObjects(object obj, ConcurrentDictionary> visitedObjects, int currentDepth) - { - var plan = PropertyInjectionCache.GetOrCreatePlan(obj.GetType()); - - if(!SourceRegistrar.IsEnabled) - { - foreach (var prop in plan.ReflectionProperties) - { - var value = prop.Property.GetValue(obj); - - if (value == null) - { - continue; - } - - // Check if already visited before yielding to prevent duplicates - if (!visitedObjects.GetOrAdd(currentDepth, []).Add(value)) - { - continue; - } - - if (!PropertyInjectionCache.HasInjectableProperties(value.GetType())) - { - continue; - } - - AddNestedTrackableObjects(value, visitedObjects, currentDepth + 1); - } - } - else - { - foreach (var metadata in plan.SourceGeneratedProperties) - { - var property = metadata.ContainingType.GetProperty(metadata.PropertyName); - - if (property == null || !property.CanRead) - { - continue; - } - - var value = property.GetValue(obj); - - if (value == null) - { - continue; - } - - // Check if already visited before yielding to prevent duplicates - if (!visitedObjects.GetOrAdd(currentDepth, []).Add(value)) - { - continue; - } - - if (!PropertyInjectionCache.HasInjectableProperties(value.GetType())) - { - continue; - } - - AddNestedTrackableObjects(value, visitedObjects, currentDepth + 1); - } - } - } } diff --git a/TUnit.Engine/Building/TestBuilder.cs b/TUnit.Engine/Building/TestBuilder.cs index 30dc1bcb7e..a8ebc2afbc 100644 --- a/TUnit.Engine/Building/TestBuilder.cs +++ b/TUnit.Engine/Building/TestBuilder.cs @@ -45,9 +45,11 @@ public TestBuilder( } /// - /// Initializes any IAsyncInitializer objects in class data that were deferred during registration. + /// Initializes class data objects during test building. + /// Only IAsyncDiscoveryInitializer objects are initialized during discovery. + /// Regular IAsyncInitializer objects are deferred to execution phase. /// - private async Task InitializeDeferredClassDataAsync(object?[] classData) + private static async Task InitializeClassDataAsync(object?[] classData) { if (classData == null || classData.Length == 0) { @@ -56,44 +58,16 @@ private async Task InitializeDeferredClassDataAsync(object?[] classData) foreach (var data in classData) { - if (data is IAsyncInitializer asyncInitializer && data is not IDataSourceAttribute) - { - if (!ObjectInitializer.IsInitialized(data)) - { - await ObjectInitializer.InitializeAsync(data); - } - } - } - } - - /// - /// Initializes any IAsyncDiscoveryInitializer objects in class data during test discovery. - /// This is called BEFORE method data sources are evaluated, enabling data sources - /// to access initialized shared objects (like Docker containers). - /// - private static async Task InitializeDiscoveryObjectsAsync(object?[] classData) - { - if (classData == null || classData.Length == 0) - { - return; - } - - foreach (var data in classData) - { - if (data is IAsyncDiscoveryInitializer) - { - // Uses ObjectInitializer which handles deduplication. - // This also prevents double-init during execution since ObjectInitializer - // tracks initialized objects. - await ObjectInitializer.InitializeAsync(data); - } + // Discovery: only IAsyncDiscoveryInitializer objects are initialized. + // Regular IAsyncInitializer objects are deferred to execution phase. + await ObjectInitializer.InitializeForDiscoveryAsync(data); } } private async Task CreateInstance(TestMetadata metadata, Type[] resolvedClassGenericArgs, object?[] classData, TestBuilderContext builderContext) { // Initialize any deferred IAsyncInitializer objects in class data - await InitializeDeferredClassDataAsync(classData); + await InitializeClassDataAsync(classData); // First try to create instance with ClassConstructor attribute // Use attributes from context if available @@ -230,9 +204,9 @@ public async Task> BuildTestsFromMetadataAsy var classDataResult = await classDataFactory() ?? []; var classData = DataUnwrapper.Unwrap(classDataResult); - // Initialize IAsyncDiscoveryInitializer objects before method data sources are evaluated. - // This enables InstanceMethodDataSource to access initialized shared objects. - await InitializeDiscoveryObjectsAsync(classData); + // Initialize objects before method data sources are evaluated. + // ObjectInitializer is phase-aware and will only initialize IAsyncDiscoveryInitializer during Discovery. + await InitializeClassDataAsync(classData); var needsInstanceForMethodDataSources = metadata.DataSources.Any(ds => ds is IAccessesInstanceData); @@ -294,11 +268,8 @@ await _objectLifecycleService.RegisterObjectAsync( metadata.MethodMetadata, tempEvents); - // Initialize the test class instance if it implements IAsyncDiscoveryInitializer - if (instanceForMethodDataSources is IAsyncDiscoveryInitializer) - { - await ObjectInitializer.InitializeAsync(instanceForMethodDataSources); - } + // Discovery: only IAsyncDiscoveryInitializer is initialized + await ObjectInitializer.InitializeForDiscoveryAsync(instanceForMethodDataSources); } } catch (Exception ex) @@ -347,8 +318,8 @@ await _objectLifecycleService.RegisterObjectAsync( classData = DataUnwrapper.Unwrap(await classDataFactory() ?? []); var methodData = DataUnwrapper.UnwrapWithTypes(await methodDataFactory() ?? [], metadata.MethodMetadata.Parameters); - // Initialize any IAsyncDiscoveryInitializer objects in method data - await InitializeDiscoveryObjectsAsync(methodData); + // Initialize method data objects (ObjectInitializer is phase-aware) + await InitializeClassDataAsync(methodData); // For concrete generic instantiations, check if the data is compatible with the expected types if (metadata.GenericMethodTypeArguments is { Length: > 0 }) @@ -1423,9 +1394,8 @@ public async IAsyncEnumerable BuildTestsStreamingAsync( var classData = DataUnwrapper.Unwrap(await classDataFactory() ?? []); - // Initialize IAsyncDiscoveryInitializer objects before method data sources are evaluated. - // This enables InstanceMethodDataSource to access initialized shared objects. - await InitializeDiscoveryObjectsAsync(classData); + // Initialize objects before method data sources are evaluated (ObjectInitializer is phase-aware) + await InitializeClassDataAsync(classData); // Handle instance creation for method data sources var needsInstanceForMethodDataSources = metadata.DataSources.Any(ds => ds is IAccessesInstanceData); @@ -1452,11 +1422,8 @@ await _objectLifecycleService.RegisterObjectAsync( metadata.MethodMetadata, tempEvents); - // Initialize the test class instance if it implements IAsyncDiscoveryInitializer - if (instanceForMethodDataSources is IAsyncDiscoveryInitializer) - { - await ObjectInitializer.InitializeAsync(instanceForMethodDataSources); - } + // Discovery: only IAsyncDiscoveryInitializer is initialized + await ObjectInitializer.InitializeForDiscoveryAsync(instanceForMethodDataSources); } // Stream through method data sources @@ -1567,8 +1534,8 @@ await _objectLifecycleService.RegisterObjectAsync( var methodData = DataUnwrapper.UnwrapWithTypes(await methodDataFactory() ?? [], metadata.MethodMetadata.Parameters); - // Initialize any IAsyncDiscoveryInitializer objects in method data - await InitializeDiscoveryObjectsAsync(methodData); + // Initialize method data objects (ObjectInitializer is phase-aware) + await InitializeClassDataAsync(methodData); // Check data compatibility for generic methods if (metadata.GenericMethodTypeArguments is { Length: > 0 }) diff --git a/TUnit.Engine/Framework/TUnitServiceProvider.cs b/TUnit.Engine/Framework/TUnitServiceProvider.cs index f64d64d004..9a31e5eb99 100644 --- a/TUnit.Engine/Framework/TUnitServiceProvider.cs +++ b/TUnit.Engine/Framework/TUnitServiceProvider.cs @@ -114,9 +114,11 @@ public TUnitServiceProvider(IExtension extension, var objectTracker = new ObjectTracker(trackableObjectGraphProvider, disposer); // Use Lazy to break circular dependency between PropertyInjector and ObjectLifecycleService + // PropertyInjector now depends on IInitializationCallback interface (implemented by ObjectLifecycleService) + // This follows Dependency Inversion Principle and improves testability ObjectLifecycleService? objectLifecycleServiceInstance = null; - var lazyObjectLifecycleService = new Lazy(() => objectLifecycleServiceInstance!); - var lazyPropertyInjector = new Lazy(() => new PropertyInjector(lazyObjectLifecycleService, TestSessionId)); + var lazyInitializationCallback = new Lazy(() => objectLifecycleServiceInstance!); + var lazyPropertyInjector = new Lazy(() => new PropertyInjector(lazyInitializationCallback, TestSessionId)); objectLifecycleServiceInstance = new ObjectLifecycleService(lazyPropertyInjector, objectGraphDiscoveryService, objectTracker); ObjectLifecycleService = Register(objectLifecycleServiceInstance); diff --git a/TUnit.Engine/Services/ObjectGraphDiscoveryService.cs b/TUnit.Engine/Services/ObjectGraphDiscoveryService.cs index 8d5e46f7f9..1a0636f605 100644 --- a/TUnit.Engine/Services/ObjectGraphDiscoveryService.cs +++ b/TUnit.Engine/Services/ObjectGraphDiscoveryService.cs @@ -1,198 +1,45 @@ -using System.Collections.Concurrent; -using System.Diagnostics.CodeAnalysis; using TUnit.Core; -using TUnit.Core.PropertyInjection; +using TUnit.Core.Discovery; +using TUnit.Core.Interfaces; namespace TUnit.Engine.Services; /// -/// Centralized service for discovering and organizing object graphs. -/// Eliminates duplicate graph traversal logic that was scattered across -/// PropertyInjectionService, DataSourceInitializer, and TrackableObjectGraphProvider. -/// Follows Single Responsibility Principle - only discovers objects, doesn't modify them. +/// Service for discovering and organizing object graphs in TUnit.Engine. +/// Delegates to in TUnit.Core for the actual discovery logic. /// internal sealed class ObjectGraphDiscoveryService { - /// - /// Discovers all objects from test context arguments and properties, organized by depth level. - /// Depth 0 = root objects (class args, method args, property values) - /// Depth 1+ = nested objects found in properties of objects at previous depth - /// - public ObjectGraph DiscoverObjectGraph(TestContext testContext) - { - var objectsByDepth = new ConcurrentDictionary>(); - var allObjects = new HashSet(); - var visitedObjects = new HashSet(); - - var testDetails = testContext.Metadata.TestDetails; - - // Collect root-level objects (depth 0) - foreach (var classArgument in testDetails.TestClassArguments) - { - if (classArgument != null && visitedObjects.Add(classArgument)) - { - AddToDepth(objectsByDepth, 0, classArgument); - allObjects.Add(classArgument); - DiscoverNestedObjects(classArgument, objectsByDepth, visitedObjects, allObjects, currentDepth: 1); - } - } - - foreach (var methodArgument in testDetails.TestMethodArguments) - { - if (methodArgument != null && visitedObjects.Add(methodArgument)) - { - AddToDepth(objectsByDepth, 0, methodArgument); - allObjects.Add(methodArgument); - DiscoverNestedObjects(methodArgument, objectsByDepth, visitedObjects, allObjects, currentDepth: 1); - } - } - - foreach (var property in testDetails.TestClassInjectedPropertyArguments.Values) - { - if (property != null && visitedObjects.Add(property)) - { - AddToDepth(objectsByDepth, 0, property); - allObjects.Add(property); - DiscoverNestedObjects(property, objectsByDepth, visitedObjects, allObjects, currentDepth: 1); - } - } - - return new ObjectGraph(objectsByDepth, allObjects); - } + private readonly IObjectGraphDiscoverer _discoverer; /// - /// Discovers nested objects from a single root object, organized by depth. - /// Used for discovering objects within a data source or property value. + /// Creates a new instance with the default discoverer. /// - public ObjectGraph DiscoverNestedObjectGraph(object rootObject) + public ObjectGraphDiscoveryService() : this(new ObjectGraphDiscoverer()) { - var objectsByDepth = new ConcurrentDictionary>(); - var allObjects = new HashSet(); - var visitedObjects = new HashSet(); - - if (visitedObjects.Add(rootObject)) - { - AddToDepth(objectsByDepth, 0, rootObject); - allObjects.Add(rootObject); - DiscoverNestedObjects(rootObject, objectsByDepth, visitedObjects, allObjects, currentDepth: 1); - } - - return new ObjectGraph(objectsByDepth, allObjects); } /// - /// Recursively discovers nested objects that have injectable properties. + /// Creates a new instance with a custom discoverer (for testing). /// - [UnconditionalSuppressMessage("Trimming", "IL2026", Justification = "Property discovery handles both AOT and reflection modes")] - private void DiscoverNestedObjects( - object obj, - ConcurrentDictionary> objectsByDepth, - HashSet visitedObjects, - HashSet allObjects, - int currentDepth) + public ObjectGraphDiscoveryService(IObjectGraphDiscoverer discoverer) { - var plan = PropertyInjectionCache.GetOrCreatePlan(obj.GetType()); - - if (!plan.HasProperties) - { - return; - } - - // Use source-generated properties if available, otherwise fall back to reflection - if (plan.SourceGeneratedProperties.Length > 0) - { - foreach (var metadata in plan.SourceGeneratedProperties) - { - var property = metadata.ContainingType.GetProperty(metadata.PropertyName); - if (property == null || !property.CanRead) - { - continue; - } - - var value = property.GetValue(obj); - if (value == null || !visitedObjects.Add(value)) - { - continue; - } - - AddToDepth(objectsByDepth, currentDepth, value); - allObjects.Add(value); - - // Recursively discover if this value has injectable properties - if (PropertyInjectionCache.HasInjectableProperties(value.GetType())) - { - DiscoverNestedObjects(value, objectsByDepth, visitedObjects, allObjects, currentDepth + 1); - } - } - } - else if (plan.ReflectionProperties.Length > 0) - { - foreach (var (property, _) in plan.ReflectionProperties) - { - var value = property.GetValue(obj); - if (value == null || !visitedObjects.Add(value)) - { - continue; - } - - AddToDepth(objectsByDepth, currentDepth, value); - allObjects.Add(value); - - // Recursively discover if this value has injectable properties - if (PropertyInjectionCache.HasInjectableProperties(value.GetType())) - { - DiscoverNestedObjects(value, objectsByDepth, visitedObjects, allObjects, currentDepth + 1); - } - } - } + _discoverer = discoverer; } - private static void AddToDepth(ConcurrentDictionary> objectsByDepth, int depth, object obj) - { - objectsByDepth.GetOrAdd(depth, _ => []).Add(obj); - } -} - -/// -/// Represents a discovered object graph organized by depth level. -/// -internal sealed class ObjectGraph -{ - public ObjectGraph(ConcurrentDictionary> objectsByDepth, HashSet allObjects) - { - ObjectsByDepth = objectsByDepth; - AllObjects = allObjects; - MaxDepth = objectsByDepth.Count > 0 ? objectsByDepth.Keys.Max() : -1; - } - - /// - /// Objects organized by depth (0 = root arguments, 1+ = nested). - /// - public ConcurrentDictionary> ObjectsByDepth { get; } - - /// - /// All unique objects in the graph. - /// - public HashSet AllObjects { get; } - /// - /// Maximum nesting depth (-1 if empty). - /// - public int MaxDepth { get; } - - /// - /// Gets objects at a specific depth level. + /// Discovers all objects from test context arguments and properties, organized by depth level. /// - public IEnumerable GetObjectsAtDepth(int depth) + public IObjectGraph DiscoverObjectGraph(TestContext testContext, CancellationToken cancellationToken = default) { - return ObjectsByDepth.TryGetValue(depth, out var objects) ? objects : []; + return _discoverer.DiscoverObjectGraph(testContext, cancellationToken); } /// - /// Gets depth levels in descending order (deepest first). + /// Discovers nested objects from a single root object, organized by depth. /// - public IEnumerable GetDepthsDescending() + public IObjectGraph DiscoverNestedObjectGraph(object rootObject, CancellationToken cancellationToken = default) { - return ObjectsByDepth.Keys.OrderByDescending(d => d); + return _discoverer.DiscoverNestedObjectGraph(rootObject, cancellationToken); } } diff --git a/TUnit.Engine/Services/ObjectLifecycleService.cs b/TUnit.Engine/Services/ObjectLifecycleService.cs index 33d50d9b08..8ab1773f6a 100644 --- a/TUnit.Engine/Services/ObjectLifecycleService.cs +++ b/TUnit.Engine/Services/ObjectLifecycleService.cs @@ -1,6 +1,7 @@ using System.Collections.Concurrent; using System.Diagnostics.CodeAnalysis; using TUnit.Core; +using TUnit.Core.Helpers; using TUnit.Core.Interfaces; using TUnit.Core.PropertyInjection; using TUnit.Core.PropertyInjection.Initialization; @@ -17,14 +18,20 @@ namespace TUnit.Engine.Services; /// Uses Lazy<T> for dependencies to break circular references without manual Initialize() calls. /// Follows clear phase separation: Register → Inject → Initialize → Cleanup. /// -internal sealed class ObjectLifecycleService : IObjectRegistry +/// +/// Implements to allow PropertyInjector to call back for initialization +/// without creating a direct dependency (breaking the circular reference pattern). +/// +internal sealed class ObjectLifecycleService : IObjectRegistry, IInitializationCallback { private readonly Lazy _propertyInjector; private readonly ObjectGraphDiscoveryService _objectGraphDiscoveryService; private readonly ObjectTracker _objectTracker; // Track initialization state per object - private readonly ConcurrentDictionary> _initializationTasks = new(); + // Use ReferenceEqualityComparer to prevent objects with custom Equals from sharing initialization state + private readonly ConcurrentDictionary> _initializationTasks = + new(Core.Helpers.ReferenceEqualityComparer.Instance); public ObjectLifecycleService( Lazy propertyInjector, @@ -94,7 +101,8 @@ public async Task RegisterArgumentsAsync( return; } - var tasks = new List(); + // Pre-allocate with expected capacity to avoid resizing + var tasks = new List(arguments.Length); foreach (var argument in arguments) { if (argument != null) @@ -103,7 +111,10 @@ public async Task RegisterArgumentsAsync( } } - await Task.WhenAll(tasks); + if (tasks.Count > 0) + { + await Task.WhenAll(tasks); + } } #endregion @@ -112,10 +123,11 @@ public async Task RegisterArgumentsAsync( /// /// Prepares a test for execution. - /// Sets already-resolved cached property values on the current instance and initializes tracked objects. + /// Sets already-resolved cached property values on the current instance. /// This is needed because retries create new instances that don't have properties set yet. + /// Does NOT call IAsyncInitializer - that is deferred until after BeforeClass hooks via InitializeTestObjectsAsync. /// - public async Task PrepareTestAsync(TestContext testContext, CancellationToken cancellationToken) + public void PrepareTest(TestContext testContext) { var testClassInstance = testContext.Metadata.TestDetails.ClassInstance; @@ -123,7 +135,14 @@ public async Task PrepareTestAsync(TestContext testContext, CancellationToken ca // Properties were resolved and cached during RegisterTestAsync, so shared objects are already created // We just need to set them on the actual test instance (retries create new instances) SetCachedPropertiesOnInstance(testClassInstance, testContext); + } + /// + /// Initializes test objects (IAsyncInitializer) after BeforeClass hooks have run. + /// This ensures resources like Docker containers are not started until needed. + /// + public async Task InitializeTestObjectsAsync(TestContext testContext, CancellationToken cancellationToken) + { // Initialize all tracked objects (IAsyncInitializer) depth-first await InitializeTrackedObjectsAsync(testContext, cancellationToken); } @@ -148,7 +167,7 @@ private void SetCachedPropertiesOnInstance(object instance, TestContext testCont { foreach (var metadata in plan.SourceGeneratedProperties) { - var cacheKey = $"{metadata.ContainingType.FullName}.{metadata.PropertyName}"; + var cacheKey = PropertyCacheKeyGenerator.GetCacheKey(metadata); if (cachedProperties.TryGetValue(cacheKey, out var cachedValue) && cachedValue != null) { @@ -161,7 +180,7 @@ private void SetCachedPropertiesOnInstance(object instance, TestContext testCont { foreach (var (property, _) in plan.ReflectionProperties) { - var cacheKey = $"{property.DeclaringType!.FullName}.{property.Name}"; + var cacheKey = PropertyCacheKeyGenerator.GetCacheKey(property); if (cachedProperties.TryGetValue(cacheKey, out var cachedValue) && cachedValue != null) { @@ -175,31 +194,76 @@ private void SetCachedPropertiesOnInstance(object instance, TestContext testCont /// /// Initializes all tracked objects depth-first (deepest objects first). + /// This is called during test execution (after BeforeClass hooks) to initialize IAsyncInitializer objects. + /// Objects at the same level are initialized in parallel. /// private async Task InitializeTrackedObjectsAsync(TestContext testContext, CancellationToken cancellationToken) { - var levels = testContext.TrackedObjects.Keys.OrderByDescending(level => level); + // Get levels without LINQ - use Array.Sort with reverse comparison for descending order + var trackedObjects = testContext.TrackedObjects; + var levelCount = trackedObjects.Count; - foreach (var level in levels) + if (levelCount > 0) { - var objectsAtLevel = testContext.TrackedObjects[level]; - - // Initialize all objects at this depth in parallel - await Task.WhenAll(objectsAtLevel.Select(obj => - EnsureInitializedAsync( - obj, - testContext.StateBag.Items, - testContext.Metadata.TestDetails.MethodMetadata, - testContext.InternalEvents, - cancellationToken).AsTask())); + var levels = new int[levelCount]; + trackedObjects.Keys.CopyTo(levels, 0); + Array.Sort(levels, (a, b) => b.CompareTo(a)); // Descending order + + foreach (var level in levels) + { + if (!trackedObjects.TryGetValue(level, out var objectsAtLevel)) + { + continue; + } + + // Copy to array under lock to prevent concurrent modification + object[] objectsCopy; + lock (objectsAtLevel) + { + objectsCopy = new object[objectsAtLevel.Count]; + objectsAtLevel.CopyTo(objectsCopy); + } + + // Initialize all objects at this level in parallel + var tasks = new List(objectsCopy.Length); + foreach (var obj in objectsCopy) + { + tasks.Add(InitializeObjectWithNestedAsync(obj, cancellationToken)); + } + + if (tasks.Count > 0) + { + await Task.WhenAll(tasks); + } + } } - // Finally initialize the test class itself - await EnsureInitializedAsync( - testContext.Metadata.TestDetails.ClassInstance, - testContext.StateBag.Items, - testContext.Metadata.TestDetails.MethodMetadata, - testContext.InternalEvents, + // Finally initialize the test class and its nested objects + var classInstance = testContext.Metadata.TestDetails.ClassInstance; + await InitializeNestedObjectsForExecutionAsync(classInstance, cancellationToken); + await ObjectInitializer.InitializeAsync(classInstance, cancellationToken); + } + + /// + /// Initializes an object and its nested objects. + /// + private async Task InitializeObjectWithNestedAsync(object obj, CancellationToken cancellationToken) + { + // First initialize nested objects depth-first + await InitializeNestedObjectsForExecutionAsync(obj, cancellationToken); + + // Then initialize the object itself + await ObjectInitializer.InitializeAsync(obj, cancellationToken); + } + + /// + /// Initializes nested objects during execution phase - all IAsyncInitializer objects. + /// + private Task InitializeNestedObjectsForExecutionAsync(object rootObject, CancellationToken cancellationToken) + { + return InitializeNestedObjectsAsync( + rootObject, + ObjectInitializer.InitializeAsync, cancellationToken); } @@ -234,6 +298,7 @@ public async ValueTask InjectPropertiesAsync( /// /// Ensures an object is fully initialized (property injection + IAsyncInitializer). /// Thread-safe with fast-path for already-initialized objects. + /// Called during test execution to initialize all IAsyncInitializer objects. /// public async ValueTask EnsureInitializedAsync( T obj, @@ -247,13 +312,17 @@ public async ValueTask EnsureInitializedAsync( throw new ArgumentNullException(nameof(obj)); } - // Fast path: already initialized + // Fast path: already processed by this service if (_initializationTasks.TryGetValue(obj, out var existingTcs) && existingTcs.Task.IsCompleted) { if (existingTcs.Task.IsFaulted) { await existingTcs.Task.ConfigureAwait(false); } + + // EnsureInitializedAsync is only called during discovery (from PropertyInjector). + // If the object is shared and has already been processed, just return it. + // Regular IAsyncInitializer objects will be initialized during execution via InitializeTrackedObjectsAsync. return obj; } @@ -268,8 +337,18 @@ public async ValueTask EnsureInitializedAsync( await InitializeObjectCoreAsync(obj, objectBag, methodMetadata, events, cancellationToken); tcs.SetResult(true); } + catch (OperationCanceledException) + { + // Propagate cancellation without caching failure - allows retry after cancel + _initializationTasks.TryRemove(obj, out _); + tcs.SetCanceled(); + throw; + } catch (Exception ex) { + // Remove failed initialization from cache to allow retry + // This is important for transient failures that may succeed on retry + _initializationTasks.TryRemove(obj, out _); tcs.SetException(ex); throw; } @@ -284,7 +363,8 @@ public async ValueTask EnsureInitializedAsync( } /// - /// Core initialization: property injection + nested objects + IAsyncInitializer. + /// Core initialization: property injection + IAsyncDiscoveryInitializer only. + /// Regular IAsyncInitializer objects are NOT initialized here - they are deferred to execution phase. /// private async Task InitializeObjectCoreAsync( object obj, @@ -296,44 +376,66 @@ private async Task InitializeObjectCoreAsync( objectBag ??= new ConcurrentDictionary(); events ??= new TestContextEvents(); - try - { - // Step 1: Inject properties - await PropertyInjector.InjectPropertiesAsync(obj, objectBag, methodMetadata, events); + // Let exceptions propagate naturally - don't wrap in InvalidOperationException + // This aligns with ObjectInitializer behavior and provides cleaner stack traces - // Step 2: Initialize nested objects depth-first - await InitializeNestedObjectsAsync(obj, cancellationToken); + // Step 1: Inject properties + await PropertyInjector.InjectPropertiesAsync(obj, objectBag, methodMetadata, events); - // Step 3: Call IAsyncInitializer on the object itself - if (obj is IAsyncInitializer asyncInitializer) - { - await ObjectInitializer.InitializeAsync(asyncInitializer, cancellationToken); - } - } - catch (Exception ex) - { - throw new InvalidOperationException( - $"Failed to initialize object of type '{obj.GetType().Name}': {ex.Message}", ex); - } + // Step 2: Initialize nested objects depth-first (discovery-only) + await InitializeNestedObjectsForDiscoveryAsync(obj, cancellationToken); + + // Step 3: Call IAsyncDiscoveryInitializer only (not regular IAsyncInitializer) + // Regular IAsyncInitializer objects are deferred to execution phase via InitializeTestObjectsAsync + await ObjectInitializer.InitializeForDiscoveryAsync(obj, cancellationToken); } /// - /// Initializes nested objects depth-first using the centralized ObjectGraphDiscoveryService. + /// Initializes nested objects during discovery phase - only IAsyncDiscoveryInitializer objects. /// - private async Task InitializeNestedObjectsAsync(object rootObject, CancellationToken cancellationToken) + private Task InitializeNestedObjectsForDiscoveryAsync(object rootObject, CancellationToken cancellationToken) { - var graph = _objectGraphDiscoveryService.DiscoverNestedObjectGraph(rootObject); + return InitializeNestedObjectsAsync( + rootObject, + ObjectInitializer.InitializeForDiscoveryAsync, + cancellationToken); + } + + /// + /// Shared implementation for nested object initialization (DRY). + /// Discovers nested objects and initializes them depth-first using the provided initializer. + /// + /// The root object to discover nested objects from. + /// The initializer function to call for each object. + /// Cancellation token. + private async Task InitializeNestedObjectsAsync( + object rootObject, + Func initializer, + CancellationToken cancellationToken) + { + var graph = _objectGraphDiscoveryService.DiscoverNestedObjectGraph(rootObject, cancellationToken); // Initialize from deepest to shallowest (skip depth 0 which is the root itself) foreach (var depth in graph.GetDepthsDescending()) { - if (depth == 0) continue; // Root handled separately + if (depth == 0) + { + continue; // Root handled separately + } var objectsAtDepth = graph.GetObjectsAtDepth(depth); - await Task.WhenAll(objectsAtDepth - .Where(obj => obj is IAsyncInitializer) - .Select(obj => ObjectInitializer.InitializeAsync(obj, cancellationToken).AsTask())); + // Pre-allocate task list without LINQ Select + var tasks = new List(); + foreach (var obj in objectsAtDepth) + { + tasks.Add(initializer(obj, cancellationToken).AsTask()); + } + + if (tasks.Count > 0) + { + await Task.WhenAll(tasks); + } } } diff --git a/TUnit.Engine/Services/PropertyInjector.cs b/TUnit.Engine/Services/PropertyInjector.cs index 249f19207b..e0008d17ac 100644 --- a/TUnit.Engine/Services/PropertyInjector.cs +++ b/TUnit.Engine/Services/PropertyInjector.cs @@ -2,6 +2,7 @@ using System.Diagnostics.CodeAnalysis; using System.Reflection; using TUnit.Core; +using TUnit.Core.Helpers; using TUnit.Core.Interfaces; using TUnit.Core.Interfaces.SourceGenerator; using TUnit.Core.PropertyInjection; @@ -14,17 +15,21 @@ namespace TUnit.Engine.Services; /// Follows Single Responsibility Principle - only injects property values, doesn't initialize objects. /// Uses Lazy initialization to break circular dependencies without manual Initialize() calls. /// +/// +/// Depends on rather than a concrete service, +/// enabling testability and following Dependency Inversion Principle. +/// internal sealed class PropertyInjector { - private readonly Lazy _objectLifecycleService; + private readonly Lazy _initializationCallback; private readonly string _testSessionId; // Object pool for visited dictionaries to reduce allocations private static readonly ConcurrentBag> _visitedObjectsPool = new(); - public PropertyInjector(Lazy objectLifecycleService, string testSessionId) + public PropertyInjector(Lazy initializationCallback, string testSessionId) { - _objectLifecycleService = objectLifecycleService; + _initializationCallback = initializationCallback; _testSessionId = testSessionId; } @@ -95,7 +100,7 @@ public async Task InjectPropertiesAsync( #if NETSTANDARD2_0 visitedObjects = new ConcurrentDictionary(); #else - visitedObjects = new ConcurrentDictionary(ReferenceEqualityComparer.Instance); + visitedObjects = new ConcurrentDictionary(Core.Helpers.ReferenceEqualityComparer.Instance); #endif } @@ -124,17 +129,29 @@ public async Task InjectPropertiesIntoArgumentsAsync( return; } - var injectableArgs = arguments - .Where(arg => arg != null && PropertyInjectionCache.HasInjectableProperties(arg.GetType())) - .ToArray(); + // Build list of injectable args without LINQ + var injectableArgs = new List(arguments.Length); + foreach (var arg in arguments) + { + if (arg != null && PropertyInjectionCache.HasInjectableProperties(arg.GetType())) + { + injectableArgs.Add(arg); + } + } - if (injectableArgs.Length == 0) + if (injectableArgs.Count == 0) { return; } - await Task.WhenAll(injectableArgs.Select(arg => - InjectPropertiesAsync(arg!, objectBag, methodMetadata, events))); + // Build task list without LINQ Select + var tasks = new List(injectableArgs.Count); + foreach (var arg in injectableArgs) + { + tasks.Add(InjectPropertiesAsync(arg, objectBag, methodMetadata, events)); + } + + await Task.WhenAll(tasks); } private async Task InjectPropertiesRecursiveAsync( @@ -184,7 +201,7 @@ await InjectReflectionPropertiesAsync( } } - private async Task InjectSourceGeneratedPropertiesAsync( + private Task InjectSourceGeneratedPropertiesAsync( object instance, PropertyInjectionMetadata[] properties, ConcurrentDictionary objectBag, @@ -192,14 +209,8 @@ private async Task InjectSourceGeneratedPropertiesAsync( TestContextEvents events, ConcurrentDictionary visitedObjects) { - if (properties.Length == 0) - { - return; - } - - // Initialize properties in parallel - await Task.WhenAll(properties.Select(metadata => - InjectSourceGeneratedPropertyAsync(instance, metadata, objectBag, methodMetadata, events, visitedObjects))); + return ParallelTaskHelper.ForEachAsync(properties, + prop => InjectSourceGeneratedPropertyAsync(instance, prop, objectBag, methodMetadata, events, visitedObjects)); } [UnconditionalSuppressMessage("Trimming", "IL2026", Justification = "Source-gen properties are AOT-safe")] @@ -228,7 +239,7 @@ private async Task InjectSourceGeneratedPropertyAsync( object? resolvedValue = null; // Use a composite key to avoid conflicts when nested classes have properties with the same name - var cacheKey = $"{metadata.ContainingType.FullName}.{metadata.PropertyName}"; + var cacheKey = PropertyCacheKeyGenerator.GetCacheKey(metadata); // Check if property was pre-resolved during registration if (testContext?.Metadata.TestDetails.TestClassInjectedPropertyArguments.TryGetValue(cacheKey, out resolvedValue) == true) @@ -272,7 +283,7 @@ private async Task InjectSourceGeneratedPropertyAsync( } [UnconditionalSuppressMessage("Trimming", "IL2026", Justification = "Reflection mode is not used in AOT")] - private async Task InjectReflectionPropertiesAsync( + private Task InjectReflectionPropertiesAsync( object instance, (PropertyInfo Property, IDataSourceAttribute DataSource)[] properties, ConcurrentDictionary objectBag, @@ -280,13 +291,8 @@ private async Task InjectReflectionPropertiesAsync( TestContextEvents events, ConcurrentDictionary visitedObjects) { - if (properties.Length == 0) - { - return; - } - - await Task.WhenAll(properties.Select(pair => - InjectReflectionPropertyAsync(instance, pair.Property, pair.DataSource, objectBag, methodMetadata, events, visitedObjects))); + return ParallelTaskHelper.ForEachAsync(properties, + pair => InjectReflectionPropertyAsync(instance, pair.Property, pair.DataSource, objectBag, methodMetadata, events, visitedObjects)); } [UnconditionalSuppressMessage("Trimming", "IL2026", Justification = "Reflection mode is not used in AOT")] @@ -380,21 +386,15 @@ private async Task RecurseIntoNestedPropertiesAsync( } } - private async Task ResolveAndCacheSourceGeneratedPropertiesAsync( + private Task ResolveAndCacheSourceGeneratedPropertiesAsync( PropertyInjectionMetadata[] properties, ConcurrentDictionary objectBag, MethodMetadata? methodMetadata, TestContextEvents events, TestContext testContext) { - if (properties.Length == 0) - { - return; - } - - // Resolve properties in parallel - await Task.WhenAll(properties.Select(metadata => - ResolveAndCacheSourceGeneratedPropertyAsync(metadata, objectBag, methodMetadata, events, testContext))); + return ParallelTaskHelper.ForEachAsync(properties, + prop => ResolveAndCacheSourceGeneratedPropertyAsync(prop, objectBag, methodMetadata, events, testContext)); } [UnconditionalSuppressMessage("Trimming", "IL2026", Justification = "Source-gen properties are AOT-safe")] @@ -405,7 +405,7 @@ private async Task ResolveAndCacheSourceGeneratedPropertyAsync( TestContextEvents events, TestContext testContext) { - var cacheKey = $"{metadata.ContainingType.FullName}.{metadata.PropertyName}"; + var cacheKey = PropertyCacheKeyGenerator.GetCacheKey(metadata); // Check if already cached if (testContext.Metadata.TestDetails.TestClassInjectedPropertyArguments.ContainsKey(cacheKey)) @@ -439,20 +439,15 @@ private async Task ResolveAndCacheSourceGeneratedPropertyAsync( } [UnconditionalSuppressMessage("Trimming", "IL2026", Justification = "Reflection mode is not used in AOT")] - private async Task ResolveAndCacheReflectionPropertiesAsync( + private Task ResolveAndCacheReflectionPropertiesAsync( (PropertyInfo Property, IDataSourceAttribute DataSource)[] properties, ConcurrentDictionary objectBag, MethodMetadata? methodMetadata, TestContextEvents events, TestContext testContext) { - if (properties.Length == 0) - { - return; - } - - await Task.WhenAll(properties.Select(pair => - ResolveAndCacheReflectionPropertyAsync(pair.Property, pair.DataSource, objectBag, methodMetadata, events, testContext))); + return ParallelTaskHelper.ForEachAsync(properties, + pair => ResolveAndCacheReflectionPropertyAsync(pair.Property, pair.DataSource, objectBag, methodMetadata, events, testContext)); } [UnconditionalSuppressMessage("Trimming", "IL2026", Justification = "Reflection mode is not used in AOT")] @@ -464,7 +459,7 @@ private async Task ResolveAndCacheReflectionPropertyAsync( TestContextEvents events, TestContext testContext) { - var cacheKey = $"{property.DeclaringType!.FullName}.{property.Name}"; + var cacheKey = PropertyCacheKeyGenerator.GetCacheKey(property); // Check if already cached if (testContext.Metadata.TestDetails.TestClassInjectedPropertyArguments.ContainsKey(cacheKey)) @@ -525,15 +520,14 @@ private async Task ResolveAndCacheReflectionPropertyAsync( if (value != null) { - // Ensure nested objects are initialized - if (PropertyInjectionCache.HasInjectableProperties(value.GetType()) || value is IAsyncInitializer) - { - await _objectLifecycleService.Value.EnsureInitializedAsync( - value, - context.ObjectBag, - context.MethodMetadata, - context.Events); - } + // EnsureInitializedAsync handles property injection and initialization. + // ObjectInitializer is phase-aware: during Discovery phase, only IAsyncDiscoveryInitializer + // objects are initialized; regular IAsyncInitializer objects are deferred to Execution phase. + await _initializationCallback.Value.EnsureInitializedAsync( + value, + context.ObjectBag, + context.MethodMetadata, + context.Events); return value; } @@ -561,7 +555,7 @@ await _objectLifecycleService.Value.EnsureInitializedAsync( } // Ensure the data source is initialized - return await _objectLifecycleService.Value.EnsureInitializedAsync( + return await _initializationCallback.Value.EnsureInitializedAsync( dataSource, context.ObjectBag, context.MethodMetadata, diff --git a/TUnit.Engine/Services/TestExecution/TestCoordinator.cs b/TUnit.Engine/Services/TestExecution/TestCoordinator.cs index b448e459e6..683a9e8f4c 100644 --- a/TUnit.Engine/Services/TestExecution/TestCoordinator.cs +++ b/TUnit.Engine/Services/TestExecution/TestCoordinator.cs @@ -126,9 +126,9 @@ await TimeoutHelper.ExecuteWithTimeoutAsync( try { - await _testInitializer.InitializeTest(test, ct).ConfigureAwait(false); + _testInitializer.PrepareTest(test, ct); test.Context.RestoreExecutionContext(); - await _testExecutor.ExecuteAsync(test, ct).ConfigureAwait(false); + await _testExecutor.ExecuteAsync(test, _testInitializer, ct).ConfigureAwait(false); } finally { diff --git a/TUnit.Engine/TestExecutor.cs b/TUnit.Engine/TestExecutor.cs index e7d65b92f0..dbb80cc3f6 100644 --- a/TUnit.Engine/TestExecutor.cs +++ b/TUnit.Engine/TestExecutor.cs @@ -62,7 +62,7 @@ await _beforeHookTaskCache.GetOrCreateBeforeTestSessionTask( /// Creates a test executor delegate that wraps the provided executor with hook orchestration. /// Uses focused services that follow SRP to manage lifecycle and execution. /// - public async ValueTask ExecuteAsync(AbstractExecutableTest executableTest, CancellationToken cancellationToken) + public async ValueTask ExecuteAsync(AbstractExecutableTest executableTest, TestInitializer testInitializer, CancellationToken cancellationToken) { var testClass = executableTest.Metadata.TestClassType; @@ -112,6 +112,12 @@ await _eventReceiverOrchestrator.InvokeFirstTestInClassEventReceiversAsync( executableTest.Context.ClassContext.RestoreExecutionContext(); + // Initialize test objects (IAsyncInitializer) AFTER BeforeClass hooks + // This ensures resources like Docker containers are not started until needed + await testInitializer.InitializeTestObjectsAsync(executableTest, cancellationToken).ConfigureAwait(false); + + executableTest.Context.RestoreExecutionContext(); + // Early stage test start receivers run before instance-level hooks await _eventReceiverOrchestrator.InvokeTestStartEventReceiversAsync(executableTest.Context, cancellationToken, EventReceiverStage.Early).ConfigureAwait(false); diff --git a/TUnit.Engine/TestInitializer.cs b/TUnit.Engine/TestInitializer.cs index de117c6746..73fbb7fbc2 100644 --- a/TUnit.Engine/TestInitializer.cs +++ b/TUnit.Engine/TestInitializer.cs @@ -20,12 +20,19 @@ public TestInitializer( _objectLifecycleService = objectLifecycleService; } - public async ValueTask InitializeTest(AbstractExecutableTest test, CancellationToken cancellationToken) + public void PrepareTest(AbstractExecutableTest test, CancellationToken cancellationToken) { // Register event receivers _eventReceiverOrchestrator.RegisterReceivers(test.Context, cancellationToken); - // Prepare test: inject properties, track objects, initialize (IAsyncInitializer) - await _objectLifecycleService.PrepareTestAsync(test.Context, cancellationToken); + // Prepare test: set cached property values on the instance + // Does NOT call IAsyncInitializer - that is deferred until after BeforeClass hooks + _objectLifecycleService.PrepareTest(test.Context); + } + + public async ValueTask InitializeTestObjectsAsync(AbstractExecutableTest test, CancellationToken cancellationToken) + { + // Initialize test objects (IAsyncInitializer) - called after BeforeClass hooks + await _objectLifecycleService.InitializeTestObjectsAsync(test.Context, cancellationToken); } } diff --git a/TUnit.PublicAPI/Tests.Core_Library_Has_No_API_Changes.DotNet10_0.verified.txt b/TUnit.PublicAPI/Tests.Core_Library_Has_No_API_Changes.DotNet10_0.verified.txt index 33684f6146..212cbcb9fc 100644 --- a/TUnit.PublicAPI/Tests.Core_Library_Has_No_API_Changes.DotNet10_0.verified.txt +++ b/TUnit.PublicAPI/Tests.Core_Library_Has_No_API_Changes.DotNet10_0.verified.txt @@ -1000,10 +1000,6 @@ namespace public . NotInParallelConstraintKeys { get; init; } public int Order { get; set; } } - public static class ObjectInitializer - { - public static . InitializeAsync(object? obj, .CancellationToken cancellationToken = default) { } - } public class ParallelGroupAttribute : .TUnitAttribute, ., . { public ParallelGroupAttribute(string group) { } @@ -2021,14 +2017,23 @@ namespace .Helpers [.("MakeGenericType requires runtime code generation")] public static MakeGenericTypeSafe( genericTypeDefinition, params [] typeArguments) { } } + public static class ParallelTaskHelper + { + public static . ForEachAsync(. items, action) { } + public static . ForEachAsync(T[] items, action) { } + public static . ForEachAsync(. items, action, .CancellationToken cancellationToken) { } + public static . ForEachAsync(T[] items, action, .CancellationToken cancellationToken) { } + public static . ForEachWithIndexAsync(T[] items, action) { } + public static . ForEachWithIndexAsync(T[] items, action, .CancellationToken cancellationToken) { } + } public class ProcessorCountParallelLimit : . { public ProcessorCountParallelLimit() { } public int Limit { get; } } - public class ReferenceEqualityComparer : . + public sealed class ReferenceEqualityComparer : . { - public ReferenceEqualityComparer() { } + public static readonly . Instance; public bool Equals(object? x, object? y) { } public int GetHashCode(object obj) { } } @@ -2229,6 +2234,10 @@ namespace .Interfaces { .<> GenerateDataFactories(.DataSourceContext context); } + public interface IDisposer + { + . DisposeAsync(object? obj); + } public interface IEventReceiver { int Order { get; } @@ -2588,6 +2597,14 @@ namespace .Models public ? MethodInvoker { get; set; } } } +namespace .PropertyInjection +{ + public static class PropertyCacheKeyGenerator + { + public static string GetCacheKey(.PropertyInfo property) { } + public static string GetCacheKey(..PropertyInjectionMetadata metadata) { } + } +} namespace .Services { [.("Generic type resolution requires runtime type generation")] diff --git a/TUnit.PublicAPI/Tests.Core_Library_Has_No_API_Changes.DotNet8_0.verified.txt b/TUnit.PublicAPI/Tests.Core_Library_Has_No_API_Changes.DotNet8_0.verified.txt index d97c4f38b5..00c508f06b 100644 --- a/TUnit.PublicAPI/Tests.Core_Library_Has_No_API_Changes.DotNet8_0.verified.txt +++ b/TUnit.PublicAPI/Tests.Core_Library_Has_No_API_Changes.DotNet8_0.verified.txt @@ -1000,10 +1000,6 @@ namespace public . NotInParallelConstraintKeys { get; init; } public int Order { get; set; } } - public static class ObjectInitializer - { - public static . InitializeAsync(object? obj, .CancellationToken cancellationToken = default) { } - } public class ParallelGroupAttribute : .TUnitAttribute, ., . { public ParallelGroupAttribute(string group) { } @@ -2021,14 +2017,23 @@ namespace .Helpers [.("MakeGenericType requires runtime code generation")] public static MakeGenericTypeSafe( genericTypeDefinition, params [] typeArguments) { } } + public static class ParallelTaskHelper + { + public static . ForEachAsync(. items, action) { } + public static . ForEachAsync(T[] items, action) { } + public static . ForEachAsync(. items, action, .CancellationToken cancellationToken) { } + public static . ForEachAsync(T[] items, action, .CancellationToken cancellationToken) { } + public static . ForEachWithIndexAsync(T[] items, action) { } + public static . ForEachWithIndexAsync(T[] items, action, .CancellationToken cancellationToken) { } + } public class ProcessorCountParallelLimit : . { public ProcessorCountParallelLimit() { } public int Limit { get; } } - public class ReferenceEqualityComparer : . + public sealed class ReferenceEqualityComparer : . { - public ReferenceEqualityComparer() { } + public static readonly . Instance; public bool Equals(object? x, object? y) { } public int GetHashCode(object obj) { } } @@ -2229,6 +2234,10 @@ namespace .Interfaces { .<> GenerateDataFactories(.DataSourceContext context); } + public interface IDisposer + { + . DisposeAsync(object? obj); + } public interface IEventReceiver { int Order { get; } @@ -2588,6 +2597,14 @@ namespace .Models public ? MethodInvoker { get; set; } } } +namespace .PropertyInjection +{ + public static class PropertyCacheKeyGenerator + { + public static string GetCacheKey(.PropertyInfo property) { } + public static string GetCacheKey(..PropertyInjectionMetadata metadata) { } + } +} namespace .Services { [.("Generic type resolution requires runtime type generation")] diff --git a/TUnit.PublicAPI/Tests.Core_Library_Has_No_API_Changes.DotNet9_0.verified.txt b/TUnit.PublicAPI/Tests.Core_Library_Has_No_API_Changes.DotNet9_0.verified.txt index 5280c9595c..af52063e82 100644 --- a/TUnit.PublicAPI/Tests.Core_Library_Has_No_API_Changes.DotNet9_0.verified.txt +++ b/TUnit.PublicAPI/Tests.Core_Library_Has_No_API_Changes.DotNet9_0.verified.txt @@ -1000,10 +1000,6 @@ namespace public . NotInParallelConstraintKeys { get; init; } public int Order { get; set; } } - public static class ObjectInitializer - { - public static . InitializeAsync(object? obj, .CancellationToken cancellationToken = default) { } - } public class ParallelGroupAttribute : .TUnitAttribute, ., . { public ParallelGroupAttribute(string group) { } @@ -2021,14 +2017,23 @@ namespace .Helpers [.("MakeGenericType requires runtime code generation")] public static MakeGenericTypeSafe( genericTypeDefinition, params [] typeArguments) { } } + public static class ParallelTaskHelper + { + public static . ForEachAsync(. items, action) { } + public static . ForEachAsync(T[] items, action) { } + public static . ForEachAsync(. items, action, .CancellationToken cancellationToken) { } + public static . ForEachAsync(T[] items, action, .CancellationToken cancellationToken) { } + public static . ForEachWithIndexAsync(T[] items, action) { } + public static . ForEachWithIndexAsync(T[] items, action, .CancellationToken cancellationToken) { } + } public class ProcessorCountParallelLimit : . { public ProcessorCountParallelLimit() { } public int Limit { get; } } - public class ReferenceEqualityComparer : . + public sealed class ReferenceEqualityComparer : . { - public ReferenceEqualityComparer() { } + public static readonly . Instance; public bool Equals(object? x, object? y) { } public int GetHashCode(object obj) { } } @@ -2229,6 +2234,10 @@ namespace .Interfaces { .<> GenerateDataFactories(.DataSourceContext context); } + public interface IDisposer + { + . DisposeAsync(object? obj); + } public interface IEventReceiver { int Order { get; } @@ -2588,6 +2597,14 @@ namespace .Models public ? MethodInvoker { get; set; } } } +namespace .PropertyInjection +{ + public static class PropertyCacheKeyGenerator + { + public static string GetCacheKey(.PropertyInfo property) { } + public static string GetCacheKey(..PropertyInjectionMetadata metadata) { } + } +} namespace .Services { [.("Generic type resolution requires runtime type generation")] diff --git a/TUnit.PublicAPI/Tests.Core_Library_Has_No_API_Changes.Net4_7.verified.txt b/TUnit.PublicAPI/Tests.Core_Library_Has_No_API_Changes.Net4_7.verified.txt index 538bd3ca00..a3d525329b 100644 --- a/TUnit.PublicAPI/Tests.Core_Library_Has_No_API_Changes.Net4_7.verified.txt +++ b/TUnit.PublicAPI/Tests.Core_Library_Has_No_API_Changes.Net4_7.verified.txt @@ -963,10 +963,6 @@ namespace public . NotInParallelConstraintKeys { get; init; } public int Order { get; set; } } - public static class ObjectInitializer - { - public static . InitializeAsync(object? obj, .CancellationToken cancellationToken = default) { } - } public class ParallelGroupAttribute : .TUnitAttribute, ., . { public ParallelGroupAttribute(string group) { } @@ -1960,14 +1956,23 @@ namespace .Helpers public static bool IsConstructedGenericType( type) { } public static MakeGenericTypeSafe( genericTypeDefinition, params [] typeArguments) { } } + public static class ParallelTaskHelper + { + public static . ForEachAsync(. items, action) { } + public static . ForEachAsync(T[] items, action) { } + public static . ForEachAsync(. items, action, .CancellationToken cancellationToken) { } + public static . ForEachAsync(T[] items, action, .CancellationToken cancellationToken) { } + public static . ForEachWithIndexAsync(T[] items, action) { } + public static . ForEachWithIndexAsync(T[] items, action, .CancellationToken cancellationToken) { } + } public class ProcessorCountParallelLimit : . { public ProcessorCountParallelLimit() { } public int Limit { get; } } - public class ReferenceEqualityComparer : . + public sealed class ReferenceEqualityComparer : . { - public ReferenceEqualityComparer() { } + public static readonly . Instance; public bool Equals(object? x, object? y) { } public int GetHashCode(object obj) { } } @@ -2161,6 +2166,10 @@ namespace .Interfaces { .<> GenerateDataFactories(.DataSourceContext context); } + public interface IDisposer + { + . DisposeAsync(object? obj); + } public interface IEventReceiver { int Order { get; } @@ -2510,6 +2519,14 @@ namespace .Models public ? MethodInvoker { get; set; } } } +namespace .PropertyInjection +{ + public static class PropertyCacheKeyGenerator + { + public static string GetCacheKey(.PropertyInfo property) { } + public static string GetCacheKey(..PropertyInjectionMetadata metadata) { } + } +} namespace .Services { public class GenericTypeResolver : . diff --git a/TUnit.TestProject/Bugs/3992/RuntimeInitializeTests.cs b/TUnit.TestProject/Bugs/3992/RuntimeInitializeTests.cs new file mode 100644 index 0000000000..1d86987a8a --- /dev/null +++ b/TUnit.TestProject/Bugs/3992/RuntimeInitializeTests.cs @@ -0,0 +1,56 @@ +using TUnit.Core.Interfaces; +using TUnit.TestProject.Attributes; + +namespace TUnit.TestProject.Bugs._3992; + +/// +/// Once this is discovered during test discovery, containers spin up +/// +[EngineTest(ExpectedResult.Pass)] +public sealed class RuntimeInitializeTests +{ + //Docker container + [ClassDataSource(Shared = SharedType.PerClass)] + public required DummyContainer Container { get; init; } + + [Before(Class)] + public static Task BeforeClass(ClassHookContext context) => NotInitialised(context.Tests); + + [After(TestDiscovery)] + public static Task AfterDiscovery(TestDiscoveryContext context) => NotInitialised(context.AllTests); + + public static async Task NotInitialised(IEnumerable tests) + { + var bugRecreations = tests.Select(x => x.Metadata.TestDetails.ClassInstance).OfType(); + + foreach (var bugRecreation in bugRecreations) + { + await Assert.That(bugRecreation.Container).IsNotNull(); + await Assert.That(DummyContainer.NumberOfInits).IsEqualTo(0); + } + } + + [Test, Arguments(1)] + public async Task Test(int value, CancellationToken token) + { + await Assert.That(value).IsNotDefault(); + await Assert.That(DummyContainer.NumberOfInits).IsEqualTo(1); + } + + public class DummyContainer : IAsyncInitializer, IAsyncDisposable + { + public Task InitializeAsync() + { + NumberOfInits++; + return Task.CompletedTask; + } + + public static int NumberOfInits { get; private set; } + + public ValueTask DisposeAsync() + { + return default; + } + } + +} diff --git a/TUnit.TestProject/Bugs/3992/RuntimeInitializeTests2.cs b/TUnit.TestProject/Bugs/3992/RuntimeInitializeTests2.cs new file mode 100644 index 0000000000..4edc097c47 --- /dev/null +++ b/TUnit.TestProject/Bugs/3992/RuntimeInitializeTests2.cs @@ -0,0 +1,56 @@ +using TUnit.Core.Interfaces; +using TUnit.TestProject.Attributes; + +namespace TUnit.TestProject.Bugs._3992; + +/// +/// Once this is discovered during test discovery, containers spin up +/// +[EngineTest(ExpectedResult.Pass)] +public sealed class DiscoveryInitializeTests +{ + //Docker container + [ClassDataSource(Shared = SharedType.PerClass)] + public required DummyContainer Container { get; init; } + + [Before(Class)] + public static Task BeforeClass(ClassHookContext context) => Initialised(context.Tests); + + [After(TestDiscovery)] + public static Task AfterDiscovery(TestDiscoveryContext context) => Initialised(context.AllTests); + + public static async Task Initialised(IEnumerable tests) + { + var bugRecreations = tests.Select(x => x.Metadata.TestDetails.ClassInstance).OfType(); + + foreach (var bugRecreation in bugRecreations) + { + await Assert.That(bugRecreation.Container).IsNotNull(); + await Assert.That(DummyContainer.NumberOfInits).IsEqualTo(1); + } + } + + [Test, Arguments(1)] + public async Task Test(int value, CancellationToken token) + { + await Assert.That(value).IsNotDefault(); + await Assert.That(DummyContainer.NumberOfInits).IsEqualTo(1); + } + + public class DummyContainer : IAsyncDiscoveryInitializer, IAsyncDisposable + { + public Task InitializeAsync() + { + NumberOfInits++; + return Task.CompletedTask; + } + + public static int NumberOfInits { get; private set; } + + public ValueTask DisposeAsync() + { + return default; + } + } + +} diff --git a/TUnit.TestProject/TestBuildContextOutputCaptureTests.cs b/TUnit.TestProject/TestBuildContextOutputCaptureTests.cs index cb11a76453..246e622391 100644 --- a/TUnit.TestProject/TestBuildContextOutputCaptureTests.cs +++ b/TUnit.TestProject/TestBuildContextOutputCaptureTests.cs @@ -30,9 +30,12 @@ public DataSourceWithConstructorOutput() } /// - /// Data source that writes to console in async initializer + /// Data source that writes to console in async initializer. + /// Uses IAsyncDiscoveryInitializer so it initializes during test discovery/building, + /// allowing the output to be captured in the test's build context. + /// Note: Regular IAsyncInitializer only runs during test execution (per issue #3992 fix). /// - public class DataSourceWithAsyncInitOutput : IAsyncInitializer + public class DataSourceWithAsyncInitOutput : IAsyncDiscoveryInitializer { public string Value { get; private set; } = "Uninitialized"; @@ -88,8 +91,9 @@ public async Task Test_CapturesConstructorOutput_InTestResults(DataSourceWithCon [ClassDataSource] public async Task Test_CapturesAsyncInitializerOutput_InTestResults(DataSourceWithAsyncInitOutput data) { - // The InitializeAsync output should be captured during test building - // and included in the test's output + // The InitializeAsync output should be captured during test building. + // Note: This uses IAsyncDiscoveryInitializer which runs during discovery. + // Regular IAsyncInitializer runs during execution only (per issue #3992 fix). // Get the test output var output = TestContext.Current!.GetStandardOutput();