diff --git a/TUnit.Analyzers.Tests/TUnit.Analyzers.Tests.csproj b/TUnit.Analyzers.Tests/TUnit.Analyzers.Tests.csproj index 4118449776..dabf679d09 100644 --- a/TUnit.Analyzers.Tests/TUnit.Analyzers.Tests.csproj +++ b/TUnit.Analyzers.Tests/TUnit.Analyzers.Tests.csproj @@ -22,20 +22,20 @@ - - + + - - + + - - + + diff --git a/TUnit.Analyzers.Tests/Verifiers/CSharpAnalyzerVerifier.cs b/TUnit.Analyzers.Tests/Verifiers/CSharpAnalyzerVerifier.cs index b1268538d6..296fc46c1c 100644 --- a/TUnit.Analyzers.Tests/Verifiers/CSharpAnalyzerVerifier.cs +++ b/TUnit.Analyzers.Tests/Verifiers/CSharpAnalyzerVerifier.cs @@ -1,3 +1,4 @@ +using System.Collections.Immutable; using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp; using Microsoft.CodeAnalysis.CSharp.Testing; @@ -42,7 +43,11 @@ public Test() } compilationOptions = compilationOptions - .WithSpecificDiagnosticOptions(compilationOptions.SpecificDiagnosticOptions.SetItems(CSharpVerifierHelper.NullableWarnings)); + .WithSpecificDiagnosticOptions(compilationOptions.SpecificDiagnosticOptions + .SetItems(CSharpVerifierHelper.NullableWarnings) + // Suppress analyzer release tracking warnings - we're testing TUnit analyzers, not release tracking + .SetItem("RS2007", ReportDiagnostic.Suppress) + .SetItem("RS2008", ReportDiagnostic.Suppress)); solution = solution.WithProjectCompilationOptions(projectId, compilationOptions) .WithProjectParseOptions(projectId, parseOptions diff --git a/TUnit.Analyzers.Tests/Verifiers/CSharpCodeFixVerifier.cs b/TUnit.Analyzers.Tests/Verifiers/CSharpCodeFixVerifier.cs index 79656d8ff4..3683793e41 100644 --- a/TUnit.Analyzers.Tests/Verifiers/CSharpCodeFixVerifier.cs +++ b/TUnit.Analyzers.Tests/Verifiers/CSharpCodeFixVerifier.cs @@ -1,3 +1,5 @@ +using System.Collections.Immutable; +using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CodeFixes; using Microsoft.CodeAnalysis.CSharp; using Microsoft.CodeAnalysis.CSharp.Testing; @@ -35,7 +37,12 @@ public Test() return solution; } - compilationOptions = compilationOptions.WithSpecificDiagnosticOptions(compilationOptions.SpecificDiagnosticOptions.SetItems(CSharpVerifierHelper.NullableWarnings)); + compilationOptions = compilationOptions + .WithSpecificDiagnosticOptions(compilationOptions.SpecificDiagnosticOptions + .SetItems(CSharpVerifierHelper.NullableWarnings) + // Suppress analyzer release tracking warnings - we're testing TUnit analyzers, not release tracking + .SetItem("RS2007", ReportDiagnostic.Suppress) + .SetItem("RS2008", ReportDiagnostic.Suppress)); solution = solution.WithProjectCompilationOptions(projectId, compilationOptions) .WithProjectParseOptions(projectId, parseOptions.WithLanguageVersion(LanguageVersion.Preview)); diff --git a/TUnit.Analyzers.Tests/Verifiers/CSharpCodeRefactoringVerifier.cs b/TUnit.Analyzers.Tests/Verifiers/CSharpCodeRefactoringVerifier.cs index 017501bcdf..18db324888 100644 --- a/TUnit.Analyzers.Tests/Verifiers/CSharpCodeRefactoringVerifier.cs +++ b/TUnit.Analyzers.Tests/Verifiers/CSharpCodeRefactoringVerifier.cs @@ -1,3 +1,4 @@ +using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CodeRefactorings; using Microsoft.CodeAnalysis.CSharp; using Microsoft.CodeAnalysis.CSharp.Testing; @@ -33,7 +34,12 @@ public Test() return solution; } - compilationOptions = compilationOptions.WithSpecificDiagnosticOptions(compilationOptions.SpecificDiagnosticOptions.SetItems(CSharpVerifierHelper.NullableWarnings)); + compilationOptions = compilationOptions + .WithSpecificDiagnosticOptions(compilationOptions.SpecificDiagnosticOptions + .SetItems(CSharpVerifierHelper.NullableWarnings) + // Suppress analyzer release tracking warnings - we're testing TUnit analyzers, not release tracking + .SetItem("RS2007", ReportDiagnostic.Suppress) + .SetItem("RS2008", ReportDiagnostic.Suppress)); solution = solution.WithProjectCompilationOptions(projectId, compilationOptions) .WithProjectParseOptions(projectId, parseOptions.WithLanguageVersion(LanguageVersion.Preview)); diff --git a/TUnit.Assertions.Analyzers.Tests/AnalyzerTestHelpers.cs b/TUnit.Assertions.Analyzers.Tests/AnalyzerTestHelpers.cs new file mode 100644 index 0000000000..58efd18d48 --- /dev/null +++ b/TUnit.Assertions.Analyzers.Tests/AnalyzerTestHelpers.cs @@ -0,0 +1,186 @@ +using System.Collections.Immutable; +using System.ComponentModel; +using System.Diagnostics.CodeAnalysis; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp.Testing; +using Microsoft.CodeAnalysis.Diagnostics; +using Microsoft.CodeAnalysis.Testing; + +namespace TUnit.Assertions.Analyzers.Tests; + +public static class AnalyzerTestHelpers +{ + public static CSharpAnalyzerTest CreateAnalyzerTest( + [StringSyntax("c#-test")] string inputSource + ) + where TAnalyzer : DiagnosticAnalyzer, new() + { + var csTest = new CSharpAnalyzerTest + { + TestState = + { + Sources = { inputSource }, + ReferenceAssemblies = new ReferenceAssemblies( + "net8.0", + new PackageIdentity( + "Microsoft.NETCore.App.Ref", + "8.0.0"), + Path.Combine("ref", "net8.0")), + }, + }; + + csTest.TestState.AdditionalReferences + .AddRange( + [ + MetadataReference.CreateFromFile(typeof(TUnitAttribute).Assembly.Location), + MetadataReference.CreateFromFile(typeof(Assert).Assembly.Location), + ] + ); + + return csTest; + } + + public sealed class CSharpSuppressorTest : CSharpAnalyzerTest + where TSuppressor : DiagnosticSuppressor, new() + where TVerifier : IVerifier, new() + { + private readonly List _analyzers = []; + + protected override IEnumerable GetDiagnosticAnalyzers() + { + return base.GetDiagnosticAnalyzers().Concat(_analyzers); + } + + public CSharpSuppressorTest WithAnalyzer(bool enableDiagnostics = false) + where TAnalyzer : DiagnosticAnalyzer, new() + { + var analyzer = new TAnalyzer(); + _analyzers.Add(analyzer); + + if (enableDiagnostics) + { + var diagnosticOptions = analyzer.SupportedDiagnostics + .ToImmutableDictionary( + descriptor => descriptor.Id, + descriptor => descriptor.DefaultSeverity.ToReportDiagnostic() + ); + + SolutionTransforms.Clear(); + SolutionTransforms.Add(EnableDiagnostics(diagnosticOptions)); + } + + return this; + } + + public CSharpSuppressorTest WithSpecificDiagnostics( + params DiagnosticResult[] diagnostics + ) + { + var diagnosticOptions = diagnostics + .ToImmutableDictionary( + descriptor => descriptor.Id, + descriptor => descriptor.Severity.ToReportDiagnostic() + ); + + SolutionTransforms.Clear(); + SolutionTransforms.Add(EnableDiagnostics(diagnosticOptions)); + return this; + } + + private static Func EnableDiagnostics( + ImmutableDictionary diagnostics + ) => + (solution, id) => + { + var options = solution.GetProject(id)?.CompilationOptions + ?? throw new InvalidOperationException("Compilation options missing."); + + return solution + .WithProjectCompilationOptions( + id, + options + .WithSpecificDiagnosticOptions(diagnostics) + ); + }; + + public CSharpSuppressorTest WithExpectedDiagnosticsResults( + params DiagnosticResult[] diagnostics + ) + { + ExpectedDiagnostics.AddRange(diagnostics); + return this; + } + + public CSharpSuppressorTest WithCompilerDiagnostics( + CompilerDiagnostics diagnostics + ) + { + CompilerDiagnostics = diagnostics; + return this; + } + + public CSharpSuppressorTest IgnoringDiagnostics(params string[] diagnostics) + { + DisabledDiagnostics.AddRange(diagnostics); + return this; + } + } + + public static CSharpSuppressorTest CreateSuppressorTest( + [StringSyntax("c#-test")] string inputSource + ) + where TSuppressor : DiagnosticSuppressor, new() + { + var test = new CSharpSuppressorTest + { + TestCode = inputSource, + ReferenceAssemblies = GetReferenceAssemblies() + }; + + test.TestState.AdditionalReferences + .AddRange([ + MetadataReference.CreateFromFile(typeof(TUnitAttribute).Assembly.Location), + MetadataReference.CreateFromFile(typeof(Assert).Assembly.Location), + ]); + + return test; + } + + private static ReferenceAssemblies GetReferenceAssemblies() + { +#if NET472 + return ReferenceAssemblies.NetFramework.Net472.Default; +#elif NET8_0 + return ReferenceAssemblies.Net.Net80; +#elif NET9_0 + return ReferenceAssemblies.Net.Net90; +#elif NET10_0_OR_GREATER + return ReferenceAssemblies.Net.Net90; +#else + return ReferenceAssemblies.Net.Net80; // Default fallback +#endif + } + + public static CSharpSuppressorTest CreateSuppressorTest( + [StringSyntax("c#-test")] string inputSource + ) + where TSuppressor : DiagnosticSuppressor, new() + where TAnalyzer : DiagnosticAnalyzer, new() + { + return CreateSuppressorTest(inputSource) + .WithAnalyzer(enableDiagnostics: true); + } +} + +static file class DiagnosticSeverityExtensions +{ + public static ReportDiagnostic ToReportDiagnostic(this DiagnosticSeverity severity) + => severity switch + { + DiagnosticSeverity.Hidden => ReportDiagnostic.Hidden, + DiagnosticSeverity.Info => ReportDiagnostic.Info, + DiagnosticSeverity.Warning => ReportDiagnostic.Warn, + DiagnosticSeverity.Error => ReportDiagnostic.Error, + _ => throw new InvalidEnumArgumentException(nameof(severity), (int) severity, typeof(DiagnosticSeverity)), + }; +} diff --git a/TUnit.Assertions.Analyzers.Tests/IsNotNullAssertionSuppressorTests.cs b/TUnit.Assertions.Analyzers.Tests/IsNotNullAssertionSuppressorTests.cs new file mode 100644 index 0000000000..28e196efff --- /dev/null +++ b/TUnit.Assertions.Analyzers.Tests/IsNotNullAssertionSuppressorTests.cs @@ -0,0 +1,425 @@ +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.Testing; +using TUnit.Core; + +namespace TUnit.Assertions.Analyzers.Tests; + +/// +/// Tests for the IsNotNullAssertionSuppressor which suppresses nullability warnings +/// (CS8600, CS8602, CS8604, CS8618) for variables after Assert.That(x).IsNotNull(). +/// +/// Note: These tests verify that the suppressor correctly identifies and suppresses +/// nullability warnings. The suppressor does not change null-state flow analysis, +/// only suppresses the resulting warnings. +/// +public class IsNotNullAssertionSuppressorTests +{ + private static readonly DiagnosticResult CS8602 = new("CS8602", DiagnosticSeverity.Warning); + private static readonly DiagnosticResult CS8604 = new("CS8604", DiagnosticSeverity.Warning); + + [Test] + public async Task Suppresses_CS8602_After_IsNotNull_Assertion() + { + const string code = """ + #nullable enable + using System.Threading.Tasks; + using TUnit.Assertions; + using TUnit.Assertions.Extensions; + + public class MyTests + { + public async Task TestMethod() + { + string? nullableString = GetNullableString(); + + await Assert.That(nullableString).IsNotNull(); + + // This would normally produce CS8602: Dereference of a possibly null reference + // But the suppressor should suppress it after IsNotNull assertion + var length = {|#0:nullableString|}.Length; + } + + private string? GetNullableString() => "test"; + } + """; + + await AnalyzerTestHelpers + .CreateSuppressorTest(code) + .IgnoringDiagnostics("CS1591") + .WithSpecificDiagnostics(CS8602) + .WithExpectedDiagnosticsResults(CS8602.WithLocation(0).WithIsSuppressed(true)) + .WithCompilerDiagnostics(CompilerDiagnostics.Warnings) + .RunAsync(); + } + + [Test] + public async Task Suppresses_CS8604_After_IsNotNull_Assertion() + { + const string code = """ + #nullable enable + using System.Threading.Tasks; + using TUnit.Assertions; + using TUnit.Assertions.Extensions; + + public class MyTests + { + public async Task TestMethod() + { + string? nullableString = GetNullableString(); + + await Assert.That(nullableString).IsNotNull(); + + // This would normally produce CS8604: Possible null reference argument + // But the suppressor should suppress it after IsNotNull assertion + AcceptsNonNull({|#0:nullableString|}); + } + + private void AcceptsNonNull(string nonNull) { } + private string? GetNullableString() => "test"; + } + """; + + await AnalyzerTestHelpers + .CreateSuppressorTest(code) + .IgnoringDiagnostics("CS1591") + .WithSpecificDiagnostics(CS8604) + .WithExpectedDiagnosticsResults(CS8604.WithLocation(0).WithIsSuppressed(true)) + .WithCompilerDiagnostics(CompilerDiagnostics.Warnings) + .RunAsync(); + } + + [Test] + public async Task Does_Not_Suppress_Without_IsNotNull_Assertion() + { + const string code = """ + #nullable enable + using System.Threading.Tasks; + using TUnit.Assertions; + using TUnit.Assertions.Extensions; + + public class MyTests + { + public void TestMethod() + { + string? nullableString = GetNullableString(); + + // No IsNotNull assertion here + + // This should still produce CS8602 warning + var length = {|#0:nullableString|}.Length; + } + + private string? GetNullableString() => "test"; + } + """; + + await AnalyzerTestHelpers + .CreateSuppressorTest(code) + .IgnoringDiagnostics("CS1591") + .WithSpecificDiagnostics(CS8602) + .WithExpectedDiagnosticsResults(CS8602.WithLocation(0).WithIsSuppressed(false)) + .WithCompilerDiagnostics(CompilerDiagnostics.Warnings) + .RunAsync(); + } + + [Test] + public async Task Suppresses_Multiple_Uses_After_IsNotNull() + { + const string code = """ + #nullable enable + using System.Threading.Tasks; + using TUnit.Assertions; + using TUnit.Assertions.Extensions; + + public class MyTests + { + public async Task TestMethod() + { + string? nullableString = GetNullableString(); + + await Assert.That(nullableString).IsNotNull(); + + // Multiple uses should all be suppressed + var length = {|#0:nullableString|}.Length; + var upper = {|#1:nullableString|}.ToUpper(); + AcceptsNonNull({|#2:nullableString|}); + } + + private void AcceptsNonNull(string nonNull) { } + private string? GetNullableString() => "test"; + } + """; + + await AnalyzerTestHelpers + .CreateSuppressorTest(code) + .IgnoringDiagnostics("CS1591") + .WithSpecificDiagnostics(CS8602) + .WithExpectedDiagnosticsResults( + // Only the first usage generates a warning; subsequent uses benefit from flow analysis + CS8602.WithLocation(0).WithIsSuppressed(true) + ) + .WithCompilerDiagnostics(CompilerDiagnostics.Warnings) + .RunAsync(); + } + + [Test] + public async Task Suppresses_Only_Asserted_Variable() + { + const string code = """ + #nullable enable + using System.Threading.Tasks; + using TUnit.Assertions; + using TUnit.Assertions.Extensions; + + public class MyTests + { + public async Task TestMethod() + { + string? nullableString1 = GetNullableString(); + string? nullableString2 = GetNullableString(); + + await Assert.That(nullableString1).IsNotNull(); + + // nullableString1 should be suppressed + var length1 = {|#0:nullableString1|}.Length; + + // nullableString2 should NOT be suppressed (not asserted) + var length2 = {|#1:nullableString2|}.Length; + } + + private string? GetNullableString() => "test"; + } + """; + + await AnalyzerTestHelpers + .CreateSuppressorTest(code) + .IgnoringDiagnostics("CS1591") + .WithSpecificDiagnostics(CS8602) + .WithExpectedDiagnosticsResults( + CS8602.WithLocation(0).WithIsSuppressed(true), + CS8602.WithLocation(1).WithIsSuppressed(false) + ) + .WithCompilerDiagnostics(CompilerDiagnostics.Warnings) + .RunAsync(); + } + + [Test] + public async Task Suppresses_Property_Access_Chain() + { + const string code = """ + #nullable enable + using System.Threading.Tasks; + using TUnit.Assertions; + using TUnit.Assertions.Extensions; + + public class MyClass + { + public string? Property { get; set; } + } + + public class MyTests + { + public async Task TestMethod() + { + MyClass? obj = GetNullableObject(); + + await Assert.That(obj).IsNotNull(); + + // This should be suppressed + var prop = {|#0:obj|}.Property; + } + + private MyClass? GetNullableObject() => new MyClass(); + } + """; + + await AnalyzerTestHelpers + .CreateSuppressorTest(code) + .IgnoringDiagnostics("CS1591") + .WithSpecificDiagnostics(CS8602) + .WithExpectedDiagnosticsResults(CS8602.WithLocation(0).WithIsSuppressed(true)) + .WithCompilerDiagnostics(CompilerDiagnostics.Warnings) + .RunAsync(); + } + + [Test] + public async Task Suppresses_After_IsNotNull_At_Start_Of_Assertion_Chain() + { + const string code = """ + #nullable enable + using System.Threading.Tasks; + using TUnit.Assertions; + using TUnit.Assertions.Extensions; + + public class MyTests + { + public async Task TestMethod() + { + string? nullableString = GetNullableString(); + + // IsNotNull at the START of the chain + await Assert.That(nullableString).IsNotNull().And.Contains("test"); + + // After the assertion chain, should be suppressed + var length = {|#0:nullableString|}.Length; + } + + private string? GetNullableString() => "test"; + } + """; + + await AnalyzerTestHelpers + .CreateSuppressorTest(code) + .IgnoringDiagnostics("CS1591") + .WithSpecificDiagnostics(CS8602) + .WithExpectedDiagnosticsResults(CS8602.WithLocation(0).WithIsSuppressed(true)) + .WithCompilerDiagnostics(CompilerDiagnostics.Warnings) + .RunAsync(); + } + + [Test] + public async Task Suppresses_After_IsNotNull_At_End_Of_Assertion_Chain() + { + const string code = """ + #nullable enable + using System.Threading.Tasks; + using TUnit.Assertions; + using TUnit.Assertions.Extensions; + + public class MyTests + { + public async Task TestMethod() + { + string? nullableString = GetNullableString(); + + // IsNotNull at the END of the chain + await Assert.That(nullableString).Contains("test").And.IsNotNull(); + + // After the assertion chain, should be suppressed + var length = {|#0:nullableString|}.Length; + } + + private string? GetNullableString() => "test"; + } + """; + + await AnalyzerTestHelpers + .CreateSuppressorTest(code) + .IgnoringDiagnostics("CS1591") + .WithSpecificDiagnostics(CS8602) + .WithExpectedDiagnosticsResults(CS8602.WithLocation(0).WithIsSuppressed(true)) + .WithCompilerDiagnostics(CompilerDiagnostics.Warnings) + .RunAsync(); + } + + [Test] + public async Task Suppresses_After_IsNotNull_In_Middle_Of_Assertion_Chain() + { + const string code = """ + #nullable enable + using System.Threading.Tasks; + using TUnit.Assertions; + using TUnit.Assertions.Extensions; + + public class MyTests + { + public async Task TestMethod() + { + string? nullableString = GetNullableString(); + + // IsNotNull in the MIDDLE of the chain + await Assert.That(nullableString).Contains("t").And.IsNotNull().And.Contains("test"); + + // After the assertion chain, should be suppressed + var length = {|#0:nullableString|}.Length; + } + + private string? GetNullableString() => "test"; + } + """; + + await AnalyzerTestHelpers + .CreateSuppressorTest(code) + .IgnoringDiagnostics("CS1591") + .WithSpecificDiagnostics(CS8602) + .WithExpectedDiagnosticsResults(CS8602.WithLocation(0).WithIsSuppressed(true)) + .WithCompilerDiagnostics(CompilerDiagnostics.Warnings) + .RunAsync(); + } + + [Test] + public async Task Suppresses_After_IsNotNull_With_Or_Chain() + { + const string code = """ + #nullable enable + using System.Threading.Tasks; + using TUnit.Assertions; + using TUnit.Assertions.Extensions; + + public class MyTests + { + public async Task TestMethod() + { + string? nullableString = GetNullableString(); + + // IsNotNull with Or chain + await Assert.That(nullableString).IsNotNull().Or.IsEqualTo("fallback"); + + // After the assertion, should be suppressed + var length = {|#0:nullableString|}.Length; + } + + private string? GetNullableString() => "test"; + } + """; + + await AnalyzerTestHelpers + .CreateSuppressorTest(code) + .IgnoringDiagnostics("CS1591") + .WithSpecificDiagnostics(CS8602) + .WithExpectedDiagnosticsResults(CS8602.WithLocation(0).WithIsSuppressed(true)) + .WithCompilerDiagnostics(CompilerDiagnostics.Warnings) + .RunAsync(); + } + + [Test] + public async Task Suppresses_Multiple_Variables_With_Chained_Assertions() + { + const string code = """ + #nullable enable + using System.Threading.Tasks; + using TUnit.Assertions; + using TUnit.Assertions.Extensions; + + public class MyTests + { + public async Task TestMethod() + { + string? str1 = GetNullableString(); + string? str2 = GetNullableString(); + + // Both variables asserted + await Assert.That(str1).IsNotNull().And.Contains("test"); + await Assert.That(str2).IsNotNull(); + + // Both should be suppressed + var length1 = {|#0:str1|}.Length; + var length2 = {|#1:str2|}.Length; + } + + private string? GetNullableString() => "test"; + } + """; + + await AnalyzerTestHelpers + .CreateSuppressorTest(code) + .IgnoringDiagnostics("CS1591") + .WithSpecificDiagnostics(CS8602) + .WithExpectedDiagnosticsResults( + CS8602.WithLocation(0).WithIsSuppressed(true), + CS8602.WithLocation(1).WithIsSuppressed(true) + ) + .WithCompilerDiagnostics(CompilerDiagnostics.Warnings) + .RunAsync(); + } +} diff --git a/TUnit.Assertions.Analyzers/IsNotNullAssertionSuppressor.cs b/TUnit.Assertions.Analyzers/IsNotNullAssertionSuppressor.cs new file mode 100644 index 0000000000..495afe2af2 --- /dev/null +++ b/TUnit.Assertions.Analyzers/IsNotNullAssertionSuppressor.cs @@ -0,0 +1,242 @@ +using System.Collections.Immutable; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using Microsoft.CodeAnalysis.Diagnostics; + +namespace TUnit.Assertions.Analyzers; + +/// +/// Suppresses nullability warnings (CS8600, CS8602, CS8604, CS8618) for variables +/// after they have been asserted as non-null using Assert.That(x).IsNotNull(). +/// +/// Note: This suppressor only hides the warnings; it does not change the compiler's +/// null-state flow analysis. Variables will still appear as nullable in IntelliSense. +/// +[DiagnosticAnalyzer(LanguageNames.CSharp)] +public class IsNotNullAssertionSuppressor : DiagnosticSuppressor +{ + public override void ReportSuppressions(SuppressionAnalysisContext context) + { + foreach (var diagnostic in context.ReportedDiagnostics) + { + // Only process nullability warnings + if (!IsNullabilityWarning(diagnostic.Id)) + { + continue; + } + + // Get the syntax tree and semantic model + if (diagnostic.Location.SourceTree is not { } sourceTree) + { + continue; + } + + var root = sourceTree.GetRoot(); + var diagnosticSpan = diagnostic.Location.SourceSpan; + var node = root.FindNode(diagnosticSpan); + + if (node is null) + { + continue; + } + + var semanticModel = context.GetSemanticModel(sourceTree); + + // Find the variable being referenced that caused the warning + var identifierName = GetIdentifierFromNode(node); + if (identifierName is null) + { + continue; + } + + // Check if this variable was previously asserted as non-null + if (WasAssertedNotNull(identifierName, semanticModel, context.CancellationToken)) + { + Suppress(context, diagnostic); + } + } + } + + private bool IsNullabilityWarning(string diagnosticId) + { + return diagnosticId is "CS8600" // Converting null literal or possible null value to non-nullable type + or "CS8602" // Dereference of a possibly null reference + or "CS8604" // Possible null reference argument + or "CS8618"; // Non-nullable field/property uninitialized + } + + private IdentifierNameSyntax? GetIdentifierFromNode(SyntaxNode node) + { + // The warning might be on the identifier itself or a parent node + return node switch + { + IdentifierNameSyntax identifier => identifier, + MemberAccessExpressionSyntax { Expression: IdentifierNameSyntax identifier } => identifier, + ArgumentSyntax { Expression: IdentifierNameSyntax identifier } => identifier, + _ => node.DescendantNodesAndSelf().OfType().FirstOrDefault() + }; + } + + private bool WasAssertedNotNull( + IdentifierNameSyntax identifierName, + SemanticModel semanticModel, + CancellationToken cancellationToken) + { + var symbol = semanticModel.GetSymbolInfo(identifierName, cancellationToken).Symbol; + if (symbol is null) + { + return false; + } + + // Find the containing method/block + var containingMethod = identifierName.FirstAncestorOrSelf(); + if (containingMethod is null) + { + return false; + } + + // Look for Assert.That(variable).IsNotNull() patterns before this usage + var allStatements = containingMethod.DescendantNodes().OfType().ToList(); + var identifierStatement = identifierName.FirstAncestorOrSelf(); + + if (identifierStatement is null) + { + return false; + } + + var identifierStatementIndex = allStatements.IndexOf(identifierStatement); + if (identifierStatementIndex < 0) + { + return false; + } + + // Check all statements before the current one + for (int i = 0; i < identifierStatementIndex; i++) + { + var statement = allStatements[i]; + + // Look for await Assert.That(x).IsNotNull() pattern + if (IsNotNullAssertion(statement, symbol, semanticModel, cancellationToken)) + { + return true; + } + } + + return false; + } + + private bool IsNotNullAssertion( + StatementSyntax statement, + ISymbol targetSymbol, + SemanticModel semanticModel, + CancellationToken cancellationToken) + { + // Pattern: await Assert.That(variable).IsNotNull() + // or: await Assert.That(variable).Contains("test").And.IsNotNull() + // or: Assert.That(variable).IsNotNull().GetAwaiter().GetResult() + + var invocations = statement.DescendantNodes().OfType(); + + foreach (var invocation in invocations) + { + // Check if this is a call to IsNotNull() + if (invocation.Expression is not MemberAccessExpressionSyntax { Name.Identifier.Text: "IsNotNull" }) + { + continue; + } + + // Walk up the expression chain to find Assert.That() call + var assertThatCall = FindAssertThatInChain(invocation); + if (assertThatCall is null) + { + continue; + } + + // Get the argument to Assert.That() + if (assertThatCall.ArgumentList.Arguments.Count != 1) + { + continue; + } + + var argument = assertThatCall.ArgumentList.Arguments[0].Expression; + + // Get the symbol of the argument + var argumentSymbol = semanticModel.GetSymbolInfo(argument, cancellationToken).Symbol; + + // Check if it's the same symbol we're looking for + if (SymbolEqualityComparer.Default.Equals(argumentSymbol, targetSymbol)) + { + return true; + } + } + + return false; + } + + private InvocationExpressionSyntax? FindAssertThatInChain(InvocationExpressionSyntax invocation) + { + // Walk up the expression chain looking for Assert.That() + var current = invocation.Expression; + + while (current is not null) + { + if (current is InvocationExpressionSyntax invocationExpr) + { + // Check if this is Assert.That() + if (invocationExpr.Expression is MemberAccessExpressionSyntax + { + Name.Identifier.Text: "That", + Expression: IdentifierNameSyntax { Identifier.Text: "Assert" } + }) + { + return invocationExpr; + } + + // Continue walking up from this invocation + current = invocationExpr.Expression; + } + else if (current is MemberAccessExpressionSyntax memberAccess) + { + // Move to the expression being accessed + current = memberAccess.Expression; + } + else + { + break; + } + } + + return null; + } + + private void Suppress(SuppressionAnalysisContext context, Diagnostic diagnostic) + { + var suppression = SupportedSuppressions.FirstOrDefault(s => s.SuppressedDiagnosticId == diagnostic.Id); + + if (suppression is not null) + { + context.ReportSuppression( + Suppression.Create( + suppression, + diagnostic + ) + ); + } + } + + public override ImmutableArray SupportedSuppressions { get; } = + ImmutableArray.Create( + CreateDescriptor("CS8600"), + CreateDescriptor("CS8602"), + CreateDescriptor("CS8604"), + CreateDescriptor("CS8618") + ); + + private static SuppressionDescriptor CreateDescriptor(string id) + => new( + id: $"{id}Suppression", + suppressedDiagnosticId: id, + justification: $"Suppress {id} for variables asserted as non-null via Assert.That(x).IsNotNull()." + ); +} diff --git a/TUnit.Assertions.Tests/AssertNotNullTests.cs b/TUnit.Assertions.Tests/AssertNotNullTests.cs new file mode 100644 index 0000000000..537ab1f2aa --- /dev/null +++ b/TUnit.Assertions.Tests/AssertNotNullTests.cs @@ -0,0 +1,149 @@ +#nullable enable +using TUnit.Assertions.Exceptions; +using TUnit.Core; + +namespace TUnit.Assertions.Tests; + +/// +/// Tests for Assert.NotNull() and Assert.Null() methods. +/// These methods properly update null-state flow analysis via [NotNull] attribute. +/// +public class AssertNotNullTests +{ + [Test] + public async Task NotNull_WithNonNullReferenceType_DoesNotThrow() + { + string? value = "test"; + + Assert.NotNull(value); + + // After NotNull, the compiler should know value is non-null + // This should compile without warnings + var length = value.Length; + + await Assert.That(length).IsEqualTo(4); + } + + [Test] + public async Task NotNull_WithNullReferenceType_Throws() + { + string? value = null; + + var exception = Assert.Throws(() => Assert.NotNull(value)); + + await Assert.That(exception.Message).Contains("to not be null"); + } + + [Test] + public async Task NotNull_WithNonNullableValueType_DoesNotThrow() + { + int? value = 42; + + Assert.NotNull(value); + + // After NotNull, the compiler should know value has a value + var intValue = value.Value; + + await Assert.That(intValue).IsEqualTo(42); + } + + [Test] + public async Task NotNull_WithNullableValueType_Throws() + { + int? value = null; + + var exception = Assert.Throws(() => Assert.NotNull(value)); + + await Assert.That(exception.Message).Contains("to not be null"); + } + + [Test] + public void Null_WithNullReferenceType_DoesNotThrow() + { + string? value = null; + + Assert.Null(value); + + // Test passes if no exception thrown + } + + [Test] + public async Task Null_WithNonNullReferenceType_Throws() + { + string? value = "test"; + + var exception = Assert.Throws(() => Assert.Null(value)); + + await Assert.That(exception.Message).Contains("to be null"); + } + + [Test] + public void Null_WithNullValueType_DoesNotThrow() + { + int? value = null; + + Assert.Null(value); + + // Test passes if no exception thrown + } + + [Test] + public async Task Null_WithNonNullValueType_Throws() + { + int? value = 42; + + var exception = Assert.Throws(() => Assert.Null(value)); + + await Assert.That(exception.Message).Contains("to be null"); + } + + [Test] + public async Task NotNull_CapturesExpressionInMessage() + { + string? myVariable = null; + + var exception = Assert.Throws(() => Assert.NotNull(myVariable)); + + await Assert.That(exception.Message).Contains("myVariable"); + } + + [Test] + public async Task Null_CapturesExpressionInMessage() + { + string myVariable = "not null"; + + var exception = Assert.Throws(() => Assert.Null(myVariable)); + + await Assert.That(exception.Message).Contains("myVariable"); + } + + [Test] + public async Task NotNull_AllowsChainingWithOtherAssertions() + { + string? value = "test"; + + Assert.NotNull(value); + + // Can use the value directly without null-forgiving operator + await Assert.That(value.ToUpper()).IsEqualTo("TEST"); + await Assert.That(value.Length).IsEqualTo(4); + } + + [Test] + public async Task NotNull_WithComplexObject_UpdatesNullState() + { + var obj = new TestClass { Name = "test" }; + TestClass? nullableObj = obj; + + Assert.NotNull(nullableObj); + + // Should be able to access properties without warnings + var name = nullableObj.Name; + await Assert.That(name).IsEqualTo("test"); + } + + private class TestClass + { + public string? Name { get; set; } + } +} diff --git a/TUnit.Assertions/Extensions/Assert.cs b/TUnit.Assertions/Extensions/Assert.cs index 6b9d5119fb..1cd31aee78 100644 --- a/TUnit.Assertions/Extensions/Assert.cs +++ b/TUnit.Assertions/Extensions/Assert.cs @@ -192,7 +192,7 @@ public static TException Throws(Action action) action(); throw new AssertionException($"Expected {typeof(TException).Name} but no exception was thrown"); } - catch (TException ex) when (ex is not AssertionException) + catch (TException ex) when (typeof(AssertionException).IsAssignableFrom(typeof(TException)) || ex is not AssertionException) { return ex; } @@ -219,7 +219,7 @@ public static Exception Throws(Type exceptionType, Action action) action(); throw new AssertionException($"Expected {exceptionType.Name} but no exception was thrown"); } - catch (Exception ex) when (exceptionType.IsInstanceOfType(ex) && ex is not AssertionException) + catch (Exception ex) when (exceptionType.IsInstanceOfType(ex) && (typeof(AssertionException).IsAssignableFrom(exceptionType) || ex is not AssertionException)) { return ex; } @@ -358,4 +358,79 @@ public static ExceptionParameterNameAssertion ThrowsExactlyAsync(parameterName); } + + /// + /// Asserts that a value is not null (for reference types). + /// This method properly updates null-state flow analysis, allowing the compiler to treat the value as non-null after this assertion. + /// Unlike Assert.That(x).IsNotNull() (fluent API), this method changes the compiler's null-state tracking. + /// Example: Assert.NotNull(myString); // After this, myString is treated as non-null + /// + /// The value to check for null + /// The expression being asserted (captured automatically) + /// Thrown if the value is null + public static void NotNull( + [NotNull] T? value, + [CallerArgumentExpression(nameof(value))] string? expression = null) + where T : class + { + if (value is null) + { + throw new AssertionException($"Expected {expression ?? "value"} to not be null, but it was null"); + } + } + + /// + /// Asserts that a nullable value type is not null. + /// This method properly updates null-state flow analysis, allowing the compiler to treat the value as non-null after this assertion. + /// Example: Assert.NotNull(myNullableInt); // After this, myNullableInt is treated as having a value + /// + /// The nullable value to check + /// The expression being asserted (captured automatically) + /// Thrown if the value is null + public static void NotNull( + [NotNull] T? value, + [CallerArgumentExpression(nameof(value))] string? expression = null) + where T : struct + { + if (!value.HasValue) + { + throw new AssertionException($"Expected {expression ?? "value"} to not be null, but it was null"); + } + } + + /// + /// Asserts that a value is null (for reference types). + /// Example: Assert.Null(myString); + /// + /// The value to check for null + /// The expression being asserted (captured automatically) + /// Thrown if the value is not null + public static void Null( + T? value, + [CallerArgumentExpression(nameof(value))] string? expression = null) + where T : class + { + if (value is not null) + { + throw new AssertionException($"Expected {expression ?? "value"} to be null, but it was {value}"); + } + } + + /// + /// Asserts that a nullable value type is null. + /// Example: Assert.Null(myNullableInt); + /// + /// The nullable value to check + /// The expression being asserted (captured automatically) + /// Thrown if the value is not null + public static void Null( + T? value, + [CallerArgumentExpression(nameof(value))] string? expression = null) + where T : struct + { + if (value.HasValue) + { + throw new AssertionException($"Expected {expression ?? "value"} to be null, but it was {value.Value}"); + } + } } diff --git a/TUnit.PublicAPI/Tests.Assertions_Library_Has_No_API_Changes.DotNet10_0.verified.txt b/TUnit.PublicAPI/Tests.Assertions_Library_Has_No_API_Changes.DotNet10_0.verified.txt index 8b838e794c..cbef291244 100644 --- a/TUnit.PublicAPI/Tests.Assertions_Library_Has_No_API_Changes.DotNet10_0.verified.txt +++ b/TUnit.PublicAPI/Tests.Assertions_Library_Has_No_API_Changes.DotNet10_0.verified.txt @@ -5,6 +5,14 @@ namespace { public static void Fail(string reason) { } public static Multiple() { } + public static void NotNull([.] T? value, [.("value")] string? expression = null) + where T : class { } + public static void NotNull([.] T? value, [.("value")] string? expression = null) + where T : struct { } + public static void Null(T? value, [.("value")] string? expression = null) + where T : class { } + public static void Null(T? value, [.("value")] string? expression = null) + where T : struct { } public static . That( action, [.("action")] string? expression = null) { } public static . That(.IEnumerable value, [.("value")] string? expression = null) { } public static . That(<.> action, [.("action")] string? expression = null) { } diff --git a/TUnit.PublicAPI/Tests.Assertions_Library_Has_No_API_Changes.DotNet8_0.verified.txt b/TUnit.PublicAPI/Tests.Assertions_Library_Has_No_API_Changes.DotNet8_0.verified.txt index 05c906c479..13d23a7731 100644 --- a/TUnit.PublicAPI/Tests.Assertions_Library_Has_No_API_Changes.DotNet8_0.verified.txt +++ b/TUnit.PublicAPI/Tests.Assertions_Library_Has_No_API_Changes.DotNet8_0.verified.txt @@ -5,6 +5,14 @@ namespace { public static void Fail(string reason) { } public static Multiple() { } + public static void NotNull([.] T? value, [.("value")] string? expression = null) + where T : class { } + public static void NotNull([.] T? value, [.("value")] string? expression = null) + where T : struct { } + public static void Null(T? value, [.("value")] string? expression = null) + where T : class { } + public static void Null(T? value, [.("value")] string? expression = null) + where T : struct { } public static . That( action, [.("action")] string? expression = null) { } public static . That(.IEnumerable value, [.("value")] string? expression = null) { } public static . That(<.> action, [.("action")] string? expression = null) { } diff --git a/TUnit.PublicAPI/Tests.Assertions_Library_Has_No_API_Changes.DotNet9_0.verified.txt b/TUnit.PublicAPI/Tests.Assertions_Library_Has_No_API_Changes.DotNet9_0.verified.txt index 941d60bc07..9adf2916fc 100644 --- a/TUnit.PublicAPI/Tests.Assertions_Library_Has_No_API_Changes.DotNet9_0.verified.txt +++ b/TUnit.PublicAPI/Tests.Assertions_Library_Has_No_API_Changes.DotNet9_0.verified.txt @@ -5,6 +5,14 @@ namespace { public static void Fail(string reason) { } public static Multiple() { } + public static void NotNull([.] T? value, [.("value")] string? expression = null) + where T : class { } + public static void NotNull([.] T? value, [.("value")] string? expression = null) + where T : struct { } + public static void Null(T? value, [.("value")] string? expression = null) + where T : class { } + public static void Null(T? value, [.("value")] string? expression = null) + where T : struct { } public static . That( action, [.("action")] string? expression = null) { } public static . That(.IEnumerable value, [.("value")] string? expression = null) { } public static . That(<.> action, [.("action")] string? expression = null) { } diff --git a/TUnit.PublicAPI/Tests.Assertions_Library_Has_No_API_Changes.Net4_7.verified.txt b/TUnit.PublicAPI/Tests.Assertions_Library_Has_No_API_Changes.Net4_7.verified.txt index 67f8cfe8c3..a2f6c9c0ea 100644 --- a/TUnit.PublicAPI/Tests.Assertions_Library_Has_No_API_Changes.Net4_7.verified.txt +++ b/TUnit.PublicAPI/Tests.Assertions_Library_Has_No_API_Changes.Net4_7.verified.txt @@ -5,6 +5,14 @@ namespace { public static void Fail(string reason) { } public static Multiple() { } + public static void NotNull([.] T? value, [.("value")] string? expression = null) + where T : class { } + public static void NotNull([.] T? value, [.("value")] string? expression = null) + where T : struct { } + public static void Null(T? value, [.("value")] string? expression = null) + where T : class { } + public static void Null(T? value, [.("value")] string? expression = null) + where T : struct { } public static . That( action, [.("action")] string? expression = null) { } public static . That(.IEnumerable value, [.("value")] string? expression = null) { } public static . That(<.> action, [.("action")] string? expression = null) { }