Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,9 @@ public static CheckResult AreEquivalent<TItem>(
return CheckResult.Failure("collection was null");
}

var actualList = actual.ToList();
var expectedList = expected.ToList();
// Optimize for collections that are already lists to avoid re-enumeration
var actualList = actual is List<TItem> actualListCasted ? actualListCasted : actual.ToList();
var expectedList = expected is List<TItem> expectedListCasted ? expectedListCasted : expected.ToList();

// Check counts first
if (actualList.Count != expectedList.Count)
Expand Down
95 changes: 51 additions & 44 deletions TUnit.Core/AsyncEvent.cs
Original file line number Diff line number Diff line change
@@ -1,78 +1,85 @@
using TUnit.Core.Interfaces;
using TUnit.Core.Interfaces;

namespace TUnit.Core;

public class AsyncEvent<TEventArgs>
{
public int Order
{
get;
set
{
field = value;

if (InvocationList.Count > 0)
{
InvocationList[^1].Order = field;
}
}
} = int.MaxValue / 2;

internal List<Invocation> InvocationList { get; } = [];

private static readonly Lock _newEventLock = new();
private readonly Lock _locker = new();
private List<Invocation>? _handlers;

public class Invocation(Func<object, TEventArgs, ValueTask> factory, int order) : IEventReceiver
{
public int Order
{
get;
internal set;
} = order;
public int Order { get; } = order;

public async ValueTask InvokeAsync(object sender, TEventArgs eventArgs)
{
await factory(sender, eventArgs);
}
}

public static AsyncEvent<TEventArgs> operator +(
AsyncEvent<TEventArgs>? e, Func<object, TEventArgs, ValueTask> callback
)
public void Add(Func<object, TEventArgs, ValueTask> callback, int order = int.MaxValue / 2)
{
if (callback == null)
{
throw new NullReferenceException("callback is null");
throw new ArgumentNullException(nameof(callback));
}

lock (_newEventLock)
{
e ??= new AsyncEvent<TEventArgs>();
}
var invocation = new Invocation(callback, order);
var insertIndex = FindInsertionIndex(order);
(_handlers ??= []).Insert(insertIndex, invocation);
}

lock (e._locker)
public void AddAt(Func<object, TEventArgs, ValueTask> callback, int index, int order = int.MaxValue / 2)
{
if (callback == null)
{
e.InvocationList.Add(new Invocation(callback, e.Order));
e.Order = int.MaxValue / 2;
throw new ArgumentNullException(nameof(callback));
}

return e;
var invocation = new Invocation(callback, order);
var handlers = _handlers ??= [];
var clampedIndex = index < 0 ? 0 : (index > handlers.Count ? handlers.Count : index);
handlers.Insert(clampedIndex, invocation);
}

public AsyncEvent<TEventArgs> InsertAtFront(Func<object, TEventArgs, ValueTask> callback)
public IReadOnlyList<Invocation> InvocationList
{
if (callback == null)
get
{
throw new NullReferenceException("callback is null");
}
if (_handlers == null)
{
return [];
}

return _handlers;

lock (_locker)
{
InvocationList.Insert(0, new Invocation(callback, Order));
Order = int.MaxValue / 2;
}
}

public AsyncEvent<TEventArgs> InsertAtFront(Func<object, TEventArgs, ValueTask> callback)
{
AddAt(callback, 0);
return this;
}

public static AsyncEvent<TEventArgs> operator +(
AsyncEvent<TEventArgs>? e, Func<object, TEventArgs, ValueTask> callback)
{
e ??= new AsyncEvent<TEventArgs>();
e.Add(callback);
return e;
}

private int FindInsertionIndex(int order)
{
int left = 0, right = (_handlers ??= []).Count;
while (left < right)
{
var mid = left + (right - left) / 2;
if (_handlers[mid].Order <= order)
left = mid + 1;
else
right = mid;
}
return left;
}
}
4 changes: 2 additions & 2 deletions TUnit.Core/DataGeneratorMetadataCreator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ public static DataGeneratorMetadata CreateDataGeneratorMetadata(
// Filter out CancellationToken if it's the last parameter (handled by the engine)
if (generatorType == DataGeneratorType.TestParameters && parametersToGenerate.Length > 0)
{
var lastParam = parametersToGenerate[parametersToGenerate.Length - 1];
var lastParam = parametersToGenerate[^1];
if (lastParam.Type == typeof(CancellationToken))
{
var newArray = new ParameterMetadata[parametersToGenerate.Length - 1];
Expand Down Expand Up @@ -244,7 +244,7 @@ private static ParameterMetadata[] FilterOutCancellationToken(ParameterMetadata[
{
if (parameters.Length > 0)
{
var lastParam = parameters[parameters.Length - 1];
var lastParam = parameters[^1];
if (lastParam.Type == typeof(CancellationToken))
{
var newArray = new ParameterMetadata[parameters.Length - 1];
Expand Down
1 change: 1 addition & 0 deletions TUnit.Core/TUnit.Core.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
<ItemGroup Condition="'$(TargetFramework)' == 'netstandard2.0'">
<PackageReference Include="Microsoft.CSharp" />
<PackageReference Include="System.Text.Json" />
<PackageReference Include="System.Collections.Immutable" />
</ItemGroup>
<ItemGroup>
<Compile Remove="IClassDataSourceAttribute.cs" />
Expand Down
13 changes: 6 additions & 7 deletions TUnit.Engine/Building/TestBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -813,16 +813,15 @@ public async Task<AbstractExecutableTest> BuildTestAsync(TestMetadata metadata,
private static string? GetBasicSkipReason(TestMetadata metadata, Attribute[]? cachedAttributes = null)
{
var attributes = cachedAttributes ?? metadata.AttributeFactory();
var skipAttributes = attributes.OfType<SkipAttribute>().ToList();
var skipAttributes = attributes.OfType<SkipAttribute>();

if (skipAttributes.Count == 0)
{
return null; // No skip attributes
}
SkipAttribute? firstSkipAttribute = null;

// Check if all skip attributes are basic (non-derived) SkipAttribute instances
foreach (var skipAttribute in skipAttributes)
{
firstSkipAttribute ??= skipAttribute;

var attributeType = skipAttribute.GetType();
if (attributeType != typeof(SkipAttribute))
{
Expand All @@ -832,8 +831,8 @@ public async Task<AbstractExecutableTest> BuildTestAsync(TestMetadata metadata,
}

// All skip attributes are basic SkipAttribute instances
// Return the first reason (they all should skip)
return skipAttributes[0].Reason;
// Return the first reason (they all should skip), or null if no skip attributes
return firstSkipAttribute?.Reason;
}


Expand Down
2 changes: 1 addition & 1 deletion TUnit.Engine/Capabilities/StopExecutionCapability.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ public async Task StopTestExecutionAsync(CancellationToken cancellationToken)

if (OnStopRequested != null)
{
foreach (var invocation in OnStopRequested.InvocationList.OrderBy(x => x.Order))
foreach (var invocation in OnStopRequested.InvocationList)
{
await invocation.InvokeAsync(this, EventArgs.Empty);
}
Expand Down
116 changes: 14 additions & 102 deletions TUnit.Engine/ConcurrentHashSet.cs
Original file line number Diff line number Diff line change
@@ -1,122 +1,34 @@
namespace TUnit.Engine;
using System.Collections.Concurrent;

internal class ConcurrentHashSet<T>
{
private readonly ReaderWriterLockSlim _lock = new(LockRecursionPolicy.SupportsRecursion);
private readonly HashSet<T> _hashSet = [];
namespace TUnit.Engine;

#region Implementation of ICollection<T> ...ish
/// <summary>
/// Thread-safe hash set implementation using ConcurrentDictionary for better performance.
/// Provides lock-free reads and fine-grained locking for writes.
/// </summary>
internal class ConcurrentHashSet<T> where T : notnull
{
private readonly ConcurrentDictionary<T, byte> _dictionary = new();

public bool Add(T item)
{
_lock.EnterWriteLock();

try
{
return _hashSet.Add(item);
}
finally
{
if (_lock.IsWriteLockHeld)
{
_lock.ExitWriteLock();
}
}
return _dictionary.TryAdd(item, 0);
}

public void Clear()
{
_lock.EnterWriteLock();

try
{
_hashSet.Clear();
}
finally
{
if (_lock.IsWriteLockHeld)
{
_lock.ExitWriteLock();
}
}
_dictionary.Clear();
}

public bool Contains(T item)
{
_lock.EnterReadLock();

try
{
return _hashSet.Contains(item);
}
finally
{
if (_lock.IsReadLockHeld)
{
_lock.ExitReadLock();
}
}
return _dictionary.ContainsKey(item);
}

public bool Remove(T item)
{
_lock.EnterWriteLock();

try
{
return _hashSet.Remove(item);
}
finally
{
if (_lock.IsWriteLockHeld)
{
_lock.ExitWriteLock();
}
}
}

public int Count
{
get
{
_lock.EnterReadLock();

try
{
return _hashSet.Count;
}
finally
{
if (_lock.IsReadLockHeld)
{
_lock.ExitReadLock();
}
}
}
}

#endregion

#region Dispose

public void Dispose()
{
Dispose(true);
GC.SuppressFinalize(this);
}

protected virtual void Dispose(bool disposing)
{
if (disposing)
{
_lock.Dispose();
}
}

~ConcurrentHashSet()
{
Dispose(false);
return _dictionary.TryRemove(item, out _);
}

#endregion
public int Count => _dictionary.Count;
}
15 changes: 10 additions & 5 deletions TUnit.Engine/Discovery/ReflectionTestMetadata.cs
Original file line number Diff line number Diff line change
Expand Up @@ -81,11 +81,16 @@ async Task<object> CreateInstance(TestContext testContext)

// Create test invoker with CancellationToken support
// Determine if the test method has a CancellationToken parameter
var parameterTypes = MethodMetadata.Parameters.Select(static p => p.Type).ToArray();
var hasCancellationToken = parameterTypes.Any(t => t == typeof(CancellationToken));
var cancellationTokenIndex = hasCancellationToken
? Array.IndexOf(parameterTypes, typeof(CancellationToken))
: -1;
var cancellationTokenIndex = -1;
for (var i = 0; i < MethodMetadata.Parameters.Length; i++)
{
if (MethodMetadata.Parameters[i].Type == typeof(CancellationToken))
{
cancellationTokenIndex = i;
break;
}
}
var hasCancellationToken = cancellationTokenIndex != -1;

Func<object, object?[], TestContext, CancellationToken, Task> invokeTest = async (instance, args, testContext, cancellationToken) =>
{
Expand Down
Loading
Loading