Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 66 additions & 4 deletions TUnit.Core.SourceGenerator/Generators/HookMetadataGenerator.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using System.Collections.Immutable;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using TUnit.Core.SourceGenerator.CodeGenerators.Helpers;
using TUnit.Core.SourceGenerator.CodeGenerators.Writers;
using TUnit.Core.SourceGenerator.Extensions;

Expand Down Expand Up @@ -136,6 +137,7 @@ public int GetHashCode(HookMethodMetadata? obj)
var lineNumber = location.GetLineSpan().StartLinePosition.Line + 1;

var order = GetHookOrder(hookAttribute);
var hookExecutor = GetHookExecutorType(methodSymbol);

return new HookMethodMetadata
{
Expand All @@ -146,7 +148,8 @@ public int GetHashCode(HookMethodMetadata? obj)
HookKind = hookKind,
HookType = hookType,
Order = order,
Context = context
Context = context,
HookExecutor = hookExecutor
};
}

Expand Down Expand Up @@ -237,6 +240,62 @@ private static int GetHookOrder(AttributeData attribute)
return 0;
}

private static string? GetHookExecutorType(IMethodSymbol methodSymbol)
{
var hookExecutorAttribute = methodSymbol.GetAttributes()
.FirstOrDefault(a => a.AttributeClass?.Name == "HookExecutorAttribute" ||
(a.AttributeClass?.IsGenericType == true &&
a.AttributeClass?.ConstructedFrom?.Name == "HookExecutorAttribute"));

if (hookExecutorAttribute == null)
{
return null;
}

// For generic HookExecutorAttribute<T>, get the type argument
if (hookExecutorAttribute.AttributeClass?.IsGenericType == true)
{
var typeArg = hookExecutorAttribute.AttributeClass.TypeArguments.FirstOrDefault();
return typeArg?.GloballyQualified();
}

// For non-generic HookExecutorAttribute(Type type), get the constructor argument
var typeArgument = hookExecutorAttribute.ConstructorArguments.FirstOrDefault();
if (typeArgument.Value is ITypeSymbol typeSymbol)
{
return typeSymbol.GloballyQualified();
}

return null;
}

private static string GetConcreteHookType(string dictionaryName, bool isInstance)
{
if (isInstance)
{
return "InstanceHookMethod";
}

return dictionaryName switch
{
"BeforeClassHooks" => "BeforeClassHookMethod",
"AfterClassHooks" => "AfterClassHookMethod",
"BeforeAssemblyHooks" => "BeforeAssemblyHookMethod",
"AfterAssemblyHooks" => "AfterAssemblyHookMethod",
"BeforeTestSessionHooks" => "BeforeTestSessionHookMethod",
"AfterTestSessionHooks" => "AfterTestSessionHookMethod",
"BeforeTestDiscoveryHooks" => "BeforeTestDiscoveryHookMethod",
"AfterTestDiscoveryHooks" => "AfterTestDiscoveryHookMethod",
"BeforeEveryTestHooks" => "BeforeTestHookMethod",
"AfterEveryTestHooks" => "AfterTestHookMethod",
"BeforeEveryClassHooks" => "BeforeClassHookMethod",
"AfterEveryClassHooks" => "AfterClassHookMethod",
"BeforeEveryAssemblyHooks" => "BeforeAssemblyHookMethod",
"AfterEveryAssemblyHooks" => "AfterAssemblyHookMethod",
_ => throw new ArgumentException($"Unknown dictionary name: {dictionaryName}")
};
}

private static void GenerateHookRegistry(SourceProductionContext context, ImmutableArray<HookMethodMetadata> hooks)
{
try
Expand Down Expand Up @@ -663,7 +722,8 @@ private static void GenerateHookDelegate(CodeWriter writer, HookMethodMetadata h

private static void GenerateHookListPopulation(CodeWriter writer, string dictionaryName, string typeDisplay, List<HookMethodMetadata> hooks, bool isInstance)
{
writer.AppendLine($"global::TUnit.Core.Sources.{dictionaryName}.GetOrAdd(typeof({typeDisplay}), _ => new global::System.Collections.Concurrent.ConcurrentBag<global::TUnit.Core.Hooks.{(isInstance ? "InstanceHookMethod" : $"StaticHookMethod<{GetContextType(hooks.First().HookType)}>")}>());");
var hookType = GetConcreteHookType(dictionaryName, isInstance);
writer.AppendLine($"global::TUnit.Core.Sources.{dictionaryName}.GetOrAdd(typeof({typeDisplay}), _ => new global::System.Collections.Concurrent.ConcurrentBag<global::TUnit.Core.Hooks.{hookType}>());");

foreach (var hook in hooks.OrderBy(h => h.Order))
{
Expand All @@ -679,7 +739,8 @@ private static void GenerateHookListPopulation(CodeWriter writer, string diction
private static void GenerateAssemblyHookListPopulation(CodeWriter writer, string dictionaryName, string assemblyVarName, List<HookMethodMetadata> hooks)
{
var assemblyVar = assemblyVarName.Replace(".", "_") + "_assembly";
writer.AppendLine($"global::TUnit.Core.Sources.{dictionaryName}.GetOrAdd({assemblyVar}, _ => new global::System.Collections.Concurrent.ConcurrentBag<global::TUnit.Core.Hooks.StaticHookMethod<AssemblyHookContext>>());");
var hookType = GetConcreteHookType(dictionaryName, false);
writer.AppendLine($"global::TUnit.Core.Sources.{dictionaryName}.GetOrAdd({assemblyVar}, _ => new global::System.Collections.Concurrent.ConcurrentBag<global::TUnit.Core.Hooks.{hookType}>());");

foreach (var hook in hooks.OrderBy(h => h.Order))
{
Expand Down Expand Up @@ -722,7 +783,7 @@ private static void GenerateHookObject(CodeWriter writer, HookMethodMetadata hoo
writer.Append("MethodInfo = ");
SourceInformationWriter.GenerateMethodInformation(writer, hook.Context.SemanticModel.Compilation, hook.TypeSymbol, hook.MethodSymbol, null, ',');
writer.AppendLine();
writer.AppendLine("HookExecutor = null!,");
writer.AppendLine($"HookExecutor = {HookExecutorHelper.GetHookExecutor(hook.HookExecutor)},");
writer.AppendLine($"Order = {hook.Order},");
writer.AppendLine($"Body = {delegateKey}_Body" + (isInstance ? "" : ","));

Expand Down Expand Up @@ -835,4 +896,5 @@ public class HookMethodMetadata
public required string HookType { get; init; }
public required int Order { get; init; }
public required GeneratorAttributeSyntaxContext Context { get; init; }
public string? HookExecutor { get; init; }
}
28 changes: 14 additions & 14 deletions TUnit.Core/Sources.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,23 +15,23 @@ public static class Sources

public static readonly ConcurrentDictionary<Type, ConcurrentBag<Hooks.InstanceHookMethod>> BeforeTestHooks = new();
public static readonly ConcurrentDictionary<Type, ConcurrentBag<Hooks.InstanceHookMethod>> AfterTestHooks = new();
public static readonly ConcurrentBag<Hooks.StaticHookMethod<TestContext>> BeforeEveryTestHooks = new();
public static readonly ConcurrentBag<Hooks.StaticHookMethod<TestContext>> AfterEveryTestHooks = new();
public static readonly ConcurrentBag<Hooks.BeforeTestHookMethod> BeforeEveryTestHooks = new();
public static readonly ConcurrentBag<Hooks.AfterTestHookMethod> AfterEveryTestHooks = new();

public static readonly ConcurrentDictionary<Type, ConcurrentBag<Hooks.StaticHookMethod<ClassHookContext>>> BeforeClassHooks = new();
public static readonly ConcurrentDictionary<Type, ConcurrentBag<Hooks.StaticHookMethod<ClassHookContext>>> AfterClassHooks = new();
public static readonly ConcurrentBag<Hooks.StaticHookMethod<ClassHookContext>> BeforeEveryClassHooks = new();
public static readonly ConcurrentBag<Hooks.StaticHookMethod<ClassHookContext>> AfterEveryClassHooks = new();
public static readonly ConcurrentDictionary<Type, ConcurrentBag<Hooks.BeforeClassHookMethod>> BeforeClassHooks = new();
public static readonly ConcurrentDictionary<Type, ConcurrentBag<Hooks.AfterClassHookMethod>> AfterClassHooks = new();
public static readonly ConcurrentBag<Hooks.BeforeClassHookMethod> BeforeEveryClassHooks = new();
public static readonly ConcurrentBag<Hooks.AfterClassHookMethod> AfterEveryClassHooks = new();

public static readonly ConcurrentDictionary<Assembly, ConcurrentBag<Hooks.StaticHookMethod<AssemblyHookContext>>> BeforeAssemblyHooks = new();
public static readonly ConcurrentDictionary<Assembly, ConcurrentBag<Hooks.StaticHookMethod<AssemblyHookContext>>> AfterAssemblyHooks = new();
public static readonly ConcurrentBag<Hooks.StaticHookMethod<AssemblyHookContext>> BeforeEveryAssemblyHooks = new();
public static readonly ConcurrentBag<Hooks.StaticHookMethod<AssemblyHookContext>> AfterEveryAssemblyHooks = new();
public static readonly ConcurrentDictionary<Assembly, ConcurrentBag<Hooks.BeforeAssemblyHookMethod>> BeforeAssemblyHooks = new();
public static readonly ConcurrentDictionary<Assembly, ConcurrentBag<Hooks.AfterAssemblyHookMethod>> AfterAssemblyHooks = new();
public static readonly ConcurrentBag<Hooks.BeforeAssemblyHookMethod> BeforeEveryAssemblyHooks = new();
public static readonly ConcurrentBag<Hooks.AfterAssemblyHookMethod> AfterEveryAssemblyHooks = new();

public static readonly ConcurrentBag<Hooks.StaticHookMethod<TestSessionContext>> BeforeTestSessionHooks = [];
public static readonly ConcurrentBag<Hooks.StaticHookMethod<TestSessionContext>> AfterTestSessionHooks = [];
public static readonly ConcurrentBag<Hooks.StaticHookMethod<BeforeTestDiscoveryContext>> BeforeTestDiscoveryHooks = [];
public static readonly ConcurrentBag<Hooks.StaticHookMethod<TestDiscoveryContext>> AfterTestDiscoveryHooks = [];
public static readonly ConcurrentBag<Hooks.BeforeTestSessionHookMethod> BeforeTestSessionHooks = [];
public static readonly ConcurrentBag<Hooks.AfterTestSessionHookMethod> AfterTestSessionHooks = [];
public static readonly ConcurrentBag<Hooks.BeforeTestDiscoveryHookMethod> BeforeTestDiscoveryHooks = [];
public static readonly ConcurrentBag<Hooks.AfterTestDiscoveryHookMethod> AfterTestDiscoveryHooks = [];

public static readonly ConcurrentQueue<Func<Task>> GlobalInitializers = [];
public static readonly ConcurrentQueue<IPropertySource> PropertySources = [];
Expand Down
Loading
Loading