diff --git a/TUnit.Core.SourceGenerator/CodeGenerators/Formatting/TypedConstantFormatter.cs b/TUnit.Core.SourceGenerator/CodeGenerators/Formatting/TypedConstantFormatter.cs index 04918aa600..9a3a20fd3b 100644 --- a/TUnit.Core.SourceGenerator/CodeGenerators/Formatting/TypedConstantFormatter.cs +++ b/TUnit.Core.SourceGenerator/CodeGenerators/Formatting/TypedConstantFormatter.cs @@ -12,6 +12,16 @@ public string FormatForCode(TypedConstant constant, ITypeSymbol? targetType = nu { if (constant.IsNull) { + // If we have a nullable enum target type, cast null to that type + if (targetType?.IsNullableValueType() == true) + { + var underlyingType = targetType.GetNullableUnderlyingType(); + if (underlyingType?.TypeKind == TypeKind.Enum) + { + // For nullable enums, we need to cast null to the nullable enum type + return $"({targetType.GloballyQualified()})null"; + } + } return "null"; } @@ -218,7 +228,20 @@ private string FormatPrimitiveForCode(object? value, ITypeSymbol? targetType) private string FormatEnumForCode(TypedConstant constant, ITypeSymbol? targetType) { - var enumType = targetType as INamedTypeSymbol ?? constant.Type as INamedTypeSymbol; + // Check if target type is a nullable enum, and if so, get the underlying enum type + var isNullableEnum = targetType?.IsNullableValueType() == true; + INamedTypeSymbol? enumType = null; + + if (isNullableEnum) + { + var underlyingType = targetType!.GetNullableUnderlyingType(); + enumType = underlyingType as INamedTypeSymbol; + } + else + { + enumType = targetType as INamedTypeSymbol ?? constant.Type as INamedTypeSymbol; + } + if (enumType == null) { return FormatPrimitive(constant.Value); @@ -227,14 +250,28 @@ private string FormatEnumForCode(TypedConstant constant, ITypeSymbol? targetType var memberName = GetEnumMemberName(enumType, constant.Value); if (memberName != null) { - return $"{enumType.GloballyQualified()}.{memberName}"; + var formattedEnum = $"{enumType.GloballyQualified()}.{memberName}"; + // If the target type is nullable, cast the enum value to the nullable type + if (isNullableEnum) + { + return $"({targetType!.GloballyQualified()}){formattedEnum}"; + } + return formattedEnum; } // Fallback to cast syntax var formattedValue = FormatPrimitive(constant.Value); - return formattedValue != null && formattedValue.StartsWith("-") + var result = formattedValue != null && formattedValue.StartsWith("-") ? $"({enumType.GloballyQualified()})({formattedValue})" : $"({enumType.GloballyQualified()}){formattedValue}"; + + // If the target type is nullable, wrap the result in a cast to the nullable type + if (isNullableEnum) + { + return $"({targetType!.GloballyQualified()})({result})"; + } + + return result; } private string FormatArrayForCode(TypedConstant constant, ITypeSymbol? targetType = null) diff --git a/TUnit.TestProject/Bugs/3185/BugRepro3185.cs b/TUnit.TestProject/Bugs/3185/BugRepro3185.cs new file mode 100644 index 0000000000..9e6b9ec024 --- /dev/null +++ b/TUnit.TestProject/Bugs/3185/BugRepro3185.cs @@ -0,0 +1,110 @@ +using TUnit.TestProject.Attributes; + +namespace TUnit.TestProject.Bugs._3185; + +[Flags] +public enum FlagMock +{ + One = 1, + Two = 2, + Three = 4 +} + +public enum RegularEnum +{ + None = 0, + First = 1, + Second = 2, + Third = 3 +} + +public static class FlagsHelper +{ + public static FlagMock? GetFlags(FlagMock[] flags) + { + if (flags == null || flags.Length == 0) + return null; + + FlagMock result = 0; + foreach (var flag in flags) + { + result |= flag; + } + return result; + } + + public static RegularEnum? ProcessEnum(RegularEnum? input) + { + return input; + } +} + +[EngineTest(ExpectedResult.Pass)] +public class NullableEnumParameterTests +{ + [Test] + [Arguments(new FlagMock[] { }, null)] + [Arguments(new FlagMock[] { FlagMock.Two }, FlagMock.Two)] + [Arguments(new FlagMock[] { FlagMock.One, FlagMock.Three }, FlagMock.One | FlagMock.Three)] + public async Task Nullable_FlagsEnum_WithNull(FlagMock[] flags, FlagMock? expected) + { + await Assert.That(FlagsHelper.GetFlags(flags)).IsEqualTo(expected); + } + + [Test] + [Arguments(null)] + [Arguments(RegularEnum.First)] + [Arguments(RegularEnum.Second)] + [Arguments(RegularEnum.Third)] + public async Task Nullable_RegularEnum_SingleParam(RegularEnum? value) + { + await Assert.That(FlagsHelper.ProcessEnum(value)).IsEqualTo(value); + } + + [Test] + [Arguments(null, null)] + [Arguments(RegularEnum.First, RegularEnum.First)] + [Arguments(RegularEnum.Second, RegularEnum.Third)] + public async Task Nullable_RegularEnum_MultipleParams(RegularEnum? input, RegularEnum? expected) + { + if (input == RegularEnum.Second) + { + await Assert.That(expected).IsEqualTo(RegularEnum.Third); + } + else + { + await Assert.That(input).IsEqualTo(expected); + } + } + + [Test] + [Arguments(FlagMock.One, null)] + [Arguments(FlagMock.Two, FlagMock.Two)] + [Arguments(null, null)] + public async Task Nullable_FlagsEnum_MixedParams(FlagMock? input, FlagMock? expected) + { + if (input == FlagMock.One) + { + await Assert.That((FlagMock?)null).IsEqualTo(expected); + } + else + { + await Assert.That(input).IsEqualTo(expected); + } + } + + [Test] + [Arguments(1, RegularEnum.First, null)] + [Arguments(2, null, FlagMock.Two)] + [Arguments(3, RegularEnum.Third, FlagMock.One | FlagMock.Three)] + public async Task Nullable_MixedEnumTypes(int id, RegularEnum? regular, FlagMock? flags) + { + await Assert.That(id).IsGreaterThan(0); + + if (id == 2) + { + await Assert.That(regular).IsNull(); + await Assert.That(flags).IsNotNull(); + } + } +}