diff --git a/TUnit.Core.SourceGenerator/Generators/HookMetadataGenerator.cs b/TUnit.Core.SourceGenerator/Generators/HookMetadataGenerator.cs index 1162f4e66c..f44233da9f 100644 --- a/TUnit.Core.SourceGenerator/Generators/HookMetadataGenerator.cs +++ b/TUnit.Core.SourceGenerator/Generators/HookMetadataGenerator.cs @@ -114,7 +114,7 @@ public int GetHashCode(HookMethodMetadata? obj) } var typeSymbol = methodSymbol.ContainingType; - if (typeSymbol == null || typeSymbol is not { }) + if (typeSymbol is not { }) { return null; } @@ -125,6 +125,7 @@ public int GetHashCode(HookMethodMetadata? obj) } var hookAttribute = context.Attributes[0]; + var hookType = GetHookType(hookAttribute); if (!IsValidHookMethod(methodSymbol, hookType)) @@ -133,7 +134,7 @@ public int GetHashCode(HookMethodMetadata? obj) } var location = context.TargetNode.GetLocation(); - var filePath = location.SourceTree?.FilePath ?? ""; + var filePath = location.SourceTree?.FilePath ?? hookAttribute.ConstructorArguments.ElementAtOrDefault(1).Value?.ToString() ?? ""; var lineNumber = location.GetLineSpan().StartLinePosition.Line + 1; var order = GetHookOrder(hookAttribute); @@ -243,8 +244,8 @@ private static int GetHookOrder(AttributeData attribute) private static string? GetHookExecutorType(IMethodSymbol methodSymbol) { var hookExecutorAttribute = methodSymbol.GetAttributes() - .FirstOrDefault(a => a.AttributeClass?.Name == "HookExecutorAttribute" || - (a.AttributeClass?.IsGenericType == true && + .FirstOrDefault(a => a.AttributeClass?.Name == "HookExecutorAttribute" || + (a.AttributeClass?.IsGenericType == true && a.AttributeClass?.ConstructedFrom?.Name == "HookExecutorAttribute")); if (hookExecutorAttribute == null) @@ -617,7 +618,7 @@ private static void GenerateHookDelegate(CodeWriter writer, HookMethodMetadata h } writer.AppendLine($"var result = method.Invoke({(isStatic ? "null" : "instance")}, methodArgs);"); - + if (!hook.MethodSymbol.ReturnsVoid) { writer.AppendLine("if (result != null)"); @@ -916,7 +917,7 @@ private static string GetHookIndexMethodName(HookMethodMetadata hook) var prefix = hook.HookKind == "Before" || hook.HookKind == "BeforeEvery" ? "Before" : "After"; var suffix = hook.HookKind.Contains("Every") && hook.HookType != "TestSession" && hook.HookType != "TestDiscovery" ? "Every" : ""; var hookType = hook.HookType; - + return $"{prefix}{suffix}{hookType}HookIndex()"; } } diff --git a/TUnit.Core.SourceGenerator/Generators/TestMetadataGenerator.cs b/TUnit.Core.SourceGenerator/Generators/TestMetadataGenerator.cs index 7666f5c7c8..5605f8dbaf 100644 --- a/TUnit.Core.SourceGenerator/Generators/TestMetadataGenerator.cs +++ b/TUnit.Core.SourceGenerator/Generators/TestMetadataGenerator.cs @@ -79,6 +79,8 @@ public void Initialize(IncrementalGeneratorInitializationContext context) return null; } + var testAttribute = methodSymbol!.GetRequiredTestAttribute(); + // Skip abstract classes (cannot be instantiated) if (containingType.IsAbstract) { @@ -91,9 +93,9 @@ public void Initialize(IncrementalGeneratorInitializationContext context) return new TestMethodMetadata { - MethodSymbol = methodSymbol ?? throw new global::System.InvalidOperationException("Symbol is not a method"), + MethodSymbol = methodSymbol ?? throw new InvalidOperationException("Symbol is not a method"), TypeSymbol = containingType, - FilePath = methodSyntax.SyntaxTree.FilePath, + FilePath = methodSyntax.GetLocation().SourceTree?.FilePath ?? testAttribute.ConstructorArguments.ElementAtOrDefault(0).Value?.ToString() ?? methodSyntax.SyntaxTree.FilePath, LineNumber = methodSyntax.GetLocation().GetLineSpan().StartLinePosition.Line + 1, TestAttribute = context.Attributes.First(), Context = context, @@ -132,7 +134,7 @@ private static void GenerateInheritedTestSources(SourceProductionContext context { MethodSymbol = concreteMethod ?? method, // Use concrete method if found, otherwise base method TypeSymbol = classInfo.TypeSymbol, - FilePath = classInfo.ClassSyntax.SyntaxTree.FilePath, + FilePath = classInfo.ClassSyntax.GetLocation().SourceTree?.FilePath ?? testAttribute.ConstructorArguments.ElementAtOrDefault(0).Value?.ToString() ?? classInfo.ClassSyntax.SyntaxTree.FilePath, LineNumber = classInfo.ClassSyntax.GetLocation().GetLineSpan().StartLinePosition.Line + 1, TestAttribute = testAttribute, Context = null, // No context for inherited tests @@ -1347,13 +1349,13 @@ private static void GeneratePropertyInjectionsForType(CodeWriter writer, ITypeSy if (property.SetMethod.IsInitOnly) { // For nested init-only properties with ClassDataSource, create the value if null - if (dataSourceAttr != null && + if (dataSourceAttr != null && dataSourceAttr.AttributeClass?.IsOrInherits("global::TUnit.Core.ClassDataSourceAttribute") == true && dataSourceAttr.AttributeClass is { IsGenericType: true, TypeArguments.Length: > 0 }) { var dataSourceType = dataSourceAttr.AttributeClass.TypeArguments[0]; var fullyQualifiedType = dataSourceType.GloballyQualified(); - + writer.AppendLine("Setter = (instance, value) =>"); writer.AppendLine("{"); writer.Indent(); @@ -1427,10 +1429,10 @@ private static void GeneratePropertyValueExtraction(CodeWriter writer, ITypeSymb var currentType = typeSymbol; var processedProperties = new HashSet(); var className = typeSymbol.GloballyQualified(); - + // Generate a single cast check and extract all properties var propertiesWithDataSource = new List(); - + while (currentType != null) { foreach (var member in currentType.GetMembers()) @@ -1450,19 +1452,19 @@ private static void GeneratePropertyValueExtraction(CodeWriter writer, ITypeSymb } currentType = currentType.BaseType; } - + // Generate a single if statement with all property extractions if (propertiesWithDataSource.Any()) { writer.AppendLine($"if (obj is {className} typedObj)"); writer.AppendLine("{"); writer.Indent(); - + foreach (var property in propertiesWithDataSource) { writer.AppendLine($"nestedValues[\"{property.Name}\"] = typedObj.{property.Name};"); } - + writer.Unindent(); writer.AppendLine("}"); } @@ -1832,7 +1834,7 @@ private static void GenerateDependencies(CodeWriter writer, Compilation compilat private static void GenerateTestDependency(CodeWriter writer, AttributeData attributeData) { var constructorArgs = attributeData.ConstructorArguments; - + // Extract ProceedOnFailure property value var proceedOnFailure = GetProceedOnFailureValue(attributeData); @@ -1954,7 +1956,7 @@ private static bool GetProceedOnFailureValue(AttributeData attributeData) return proceedOnFailure; } } - + // Default value is false return false; }