Skip to content

Commit 72518ce

Browse files
authored
refactor(assertions): enhance collection assertion types for improved type inference and chaining (#3414)
1 parent 2f8b734 commit 72518ce

6 files changed

+61
-34
lines changed

TUnit.Assertions.Tests/TypeInferenceTests.cs

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,16 +36,21 @@ await Assert.That(enumerable)
3636
[Test]
3737
public async Task PredicateAssertionsReturnItem()
3838
{
39-
// Contains with predicate returns the found item, not the collection
40-
// This allows further assertions on the item itself
39+
// Contains with predicate can be awaited to get the found item
40+
// Or chained with .And to continue collection assertions
4141
IEnumerable<int> enumerable = [1, 2, 3];
4242

4343
try
4444
{
45+
// Test 1: Await to get the found item
46+
var item = await Assert.That(enumerable).Contains(x => x > 1);
47+
await Assert.That(item).IsGreaterThan(0);
48+
49+
// Test 2: Chain with .And for collection assertions
4550
await Assert.That(enumerable)
46-
.Contains(x => x > 1) // Returns Assertion<int> with the found item
51+
.Contains(x => x > 1)
4752
.And
48-
.IsGreaterThan(0); // Can assert on the found item
53+
.Contains(x => x > 2); // Can chain multiple Contains
4954
}
5055
catch
5156
{

TUnit.Assertions/Conditions/CollectionAssertions.cs

Lines changed: 40 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -531,51 +531,69 @@ protected override Task<AssertionResult> CheckAsync(EvaluationMetadata<TValue> m
531531

532532
/// <summary>
533533
/// Asserts that a collection contains an item matching the predicate.
534+
/// When awaited, returns the found item for further assertions.
534535
/// </summary>
535-
public class CollectionContainsPredicateAssertion<TCollection, TItem> : Assertion<TItem>
536+
public class CollectionContainsPredicateAssertion<TCollection, TItem> : Assertion<TCollection>
536537
where TCollection : IEnumerable<TItem>
537538
{
538539
private readonly Func<TItem, bool> _predicate;
540+
private TItem? _foundItem;
539541

540542
public CollectionContainsPredicateAssertion(
541543
AssertionContext<TCollection> context,
542544
Func<TItem, bool> predicate)
543-
: base(context.Map<TItem>(collection =>
544-
{
545-
if (collection == null)
546-
{
547-
throw new ArgumentNullException(nameof(collection), "collection was null");
548-
}
549-
550-
foreach (var item in collection)
551-
{
552-
if (predicate(item))
553-
{
554-
return item;
555-
}
556-
}
557-
558-
throw new InvalidOperationException("no item matching predicate found in collection");
559-
}))
545+
: base(context)
560546
{
561547
_predicate = predicate ?? throw new ArgumentNullException(nameof(predicate));
562548
}
563549

564-
protected override Task<AssertionResult> CheckAsync(EvaluationMetadata<TItem> metadata)
550+
protected override Task<AssertionResult> CheckAsync(EvaluationMetadata<TCollection> metadata)
565551
{
566552
var value = metadata.Value;
567553
var exception = metadata.Exception;
568554

569555
if (exception != null)
570556
{
571-
return Task.FromResult(AssertionResult.Failed(exception.Message));
557+
return Task.FromResult(AssertionResult.Failed($"threw {exception.GetType().Name}"));
572558
}
573559

574-
// If we got here, the item was found (the Map function succeeded)
575-
return Task.FromResult(AssertionResult.Passed);
560+
if (value == null)
561+
{
562+
return Task.FromResult(AssertionResult.Failed("collection was null"));
563+
}
564+
565+
// Search for matching item
566+
foreach (var item in value)
567+
{
568+
if (_predicate(item))
569+
{
570+
_foundItem = item;
571+
return Task.FromResult(AssertionResult.Passed);
572+
}
573+
}
574+
575+
return Task.FromResult(AssertionResult.Failed("no item matching predicate found in collection"));
576576
}
577577

578578
protected override string GetExpectation() => "to contain item matching predicate";
579+
580+
/// <summary>
581+
/// Enables await syntax that returns the found item.
582+
/// This allows both chaining (.And) and item capture (await).
583+
/// </summary>
584+
public new System.Runtime.CompilerServices.TaskAwaiter<TItem> GetAwaiter()
585+
{
586+
return ExecuteAndReturnItemAsync().GetAwaiter();
587+
}
588+
589+
private async Task<TItem> ExecuteAndReturnItemAsync()
590+
{
591+
// Execute the assertion (will throw if item not found)
592+
await AssertAsync();
593+
594+
// Return the found item
595+
return _foundItem!;
596+
}
579597
}
580598

581599
/// <summary>

TUnit.PublicAPI/Tests.Assertions_Library_Has_No_API_Changes.DotNet10_0.verified.txt

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -384,11 +384,12 @@ namespace .Conditions
384384
protected override .<.> CheckAsync(.<TCollection> metadata) { }
385385
protected override string GetExpectation() { }
386386
}
387-
public class CollectionContainsPredicateAssertion<TCollection, TItem> : .<TItem>
387+
public class CollectionContainsPredicateAssertion<TCollection, TItem> : .<TCollection>
388388
where TCollection : .<TItem>
389389
{
390390
public CollectionContainsPredicateAssertion(.<TCollection> context, <TItem, bool> predicate) { }
391-
protected override .<.> CheckAsync(.<TItem> metadata) { }
391+
protected override .<.> CheckAsync(.<TCollection> metadata) { }
392+
public new .<TItem> GetAwaiter() { }
392393
protected override string GetExpectation() { }
393394
}
394395
public class CollectionCountAssertion<TValue> : .<TValue>

TUnit.PublicAPI/Tests.Assertions_Library_Has_No_API_Changes.DotNet8_0.verified.txt

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -384,11 +384,12 @@ namespace .Conditions
384384
protected override .<.> CheckAsync(.<TCollection> metadata) { }
385385
protected override string GetExpectation() { }
386386
}
387-
public class CollectionContainsPredicateAssertion<TCollection, TItem> : .<TItem>
387+
public class CollectionContainsPredicateAssertion<TCollection, TItem> : .<TCollection>
388388
where TCollection : .<TItem>
389389
{
390390
public CollectionContainsPredicateAssertion(.<TCollection> context, <TItem, bool> predicate) { }
391-
protected override .<.> CheckAsync(.<TItem> metadata) { }
391+
protected override .<.> CheckAsync(.<TCollection> metadata) { }
392+
public new .<TItem> GetAwaiter() { }
392393
protected override string GetExpectation() { }
393394
}
394395
public class CollectionCountAssertion<TValue> : .<TValue>

TUnit.PublicAPI/Tests.Assertions_Library_Has_No_API_Changes.DotNet9_0.verified.txt

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -384,11 +384,12 @@ namespace .Conditions
384384
protected override .<.> CheckAsync(.<TCollection> metadata) { }
385385
protected override string GetExpectation() { }
386386
}
387-
public class CollectionContainsPredicateAssertion<TCollection, TItem> : .<TItem>
387+
public class CollectionContainsPredicateAssertion<TCollection, TItem> : .<TCollection>
388388
where TCollection : .<TItem>
389389
{
390390
public CollectionContainsPredicateAssertion(.<TCollection> context, <TItem, bool> predicate) { }
391-
protected override .<.> CheckAsync(.<TItem> metadata) { }
391+
protected override .<.> CheckAsync(.<TCollection> metadata) { }
392+
public new .<TItem> GetAwaiter() { }
392393
protected override string GetExpectation() { }
393394
}
394395
public class CollectionCountAssertion<TValue> : .<TValue>

TUnit.PublicAPI/Tests.Assertions_Library_Has_No_API_Changes.Net4_7.verified.txt

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -382,11 +382,12 @@ namespace .Conditions
382382
protected override .<.> CheckAsync(.<TCollection> metadata) { }
383383
protected override string GetExpectation() { }
384384
}
385-
public class CollectionContainsPredicateAssertion<TCollection, TItem> : .<TItem>
385+
public class CollectionContainsPredicateAssertion<TCollection, TItem> : .<TCollection>
386386
where TCollection : .<TItem>
387387
{
388388
public CollectionContainsPredicateAssertion(.<TCollection> context, <TItem, bool> predicate) { }
389-
protected override .<.> CheckAsync(.<TItem> metadata) { }
389+
protected override .<.> CheckAsync(.<TCollection> metadata) { }
390+
public new .<TItem> GetAwaiter() { }
390391
protected override string GetExpectation() { }
391392
}
392393
public class CollectionCountAssertion<TValue> : .<TValue>

0 commit comments

Comments
 (0)