diff --git a/src/CommunityToolkit.Mvvm.CodeFixers/AsyncVoidReturningRelayCommandMethodCodeFixer.cs b/src/CommunityToolkit.Mvvm.CodeFixers/AsyncVoidReturningRelayCommandMethodCodeFixer.cs
new file mode 100644
index 000000000..78139433f
--- /dev/null
+++ b/src/CommunityToolkit.Mvvm.CodeFixers/AsyncVoidReturningRelayCommandMethodCodeFixer.cs
@@ -0,0 +1,85 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System.Collections.Immutable;
+using System.Composition;
+using System.Threading;
+using System.Threading.Tasks;
+using Microsoft.CodeAnalysis;
+using Microsoft.CodeAnalysis.CodeActions;
+using Microsoft.CodeAnalysis.CodeFixes;
+using Microsoft.CodeAnalysis.CSharp.Syntax;
+using Microsoft.CodeAnalysis.Editing;
+using Microsoft.CodeAnalysis.Simplification;
+using Microsoft.CodeAnalysis.Text;
+using static CommunityToolkit.Mvvm.SourceGenerators.Diagnostics.DiagnosticDescriptors;
+
+namespace CommunityToolkit.Mvvm.CodeFixers;
+
+///
+/// A code fixer that automatically updates the return type of methods using [RelayCommand] to return a instead.
+///
+[ExportCodeFixProvider(LanguageNames.CSharp)]
+[Shared]
+public sealed class AsyncVoidReturningRelayCommandMethodCodeFixer : CodeFixProvider
+{
+ ///
+ public override ImmutableArray FixableDiagnosticIds { get; } = ImmutableArray.Create(AsyncVoidReturningRelayCommandMethodId);
+
+ ///
+ public override FixAllProvider? GetFixAllProvider()
+ {
+ return WellKnownFixAllProviders.BatchFixer;
+ }
+
+ ///
+ public override async Task RegisterCodeFixesAsync(CodeFixContext context)
+ {
+ Diagnostic diagnostic = context.Diagnostics[0];
+ TextSpan diagnosticSpan = context.Span;
+
+ SyntaxNode? root = await context.Document.GetSyntaxRootAsync(context.CancellationToken).ConfigureAwait(false);
+
+ // Get the method declaration from the target diagnostic
+ if (root!.FindNode(diagnosticSpan) is MethodDeclarationSyntax methodDeclaration)
+ {
+ // Register the code fix to update the return type to be Task instead
+ context.RegisterCodeFix(
+ CodeAction.Create(
+ title: "Change return type to Task",
+ createChangedDocument: token => ChangeReturnType(context.Document, root, methodDeclaration, token),
+ equivalenceKey: "Change return type to Task"),
+ diagnostic);
+ }
+ }
+
+ ///
+ /// Applies the code fix to a target method declaration and returns an updated document.
+ ///
+ /// The original document being fixed.
+ /// The original tree root belonging to the current document.
+ /// The to update.
+ /// The cancellation token for the operation.
+ /// An updated document with the applied code fix, and the return type of the method being .
+ private static async Task ChangeReturnType(Document document, SyntaxNode root, MethodDeclarationSyntax methodDeclaration, CancellationToken cancellationToken)
+ {
+ // Get the semantic model (bail if it's not available)
+ if (await document.GetSemanticModelAsync(cancellationToken) is not SemanticModel semanticModel)
+ {
+ return document;
+ }
+
+ // Also bail if we can't resolve the Task symbol (this should really never happen)
+ if (semanticModel.Compilation.GetTypeByMetadataName("System.Threading.Tasks.Task") is not INamedTypeSymbol taskSymbol)
+ {
+ return document;
+ }
+
+ // Create the new syntax node for the return, and configure it to automatically add "using System.Threading.Tasks;" if needed
+ SyntaxNode typeSyntax = SyntaxGenerator.GetGenerator(document).TypeExpression(taskSymbol).WithAdditionalAnnotations(Simplifier.AddImportsAnnotation);
+
+ // Replace the void return type with the newly created Task type expression
+ return document.WithSyntaxRoot(root.ReplaceNode(methodDeclaration.ReturnType, typeSyntax));
+ }
+}
diff --git a/src/CommunityToolkit.Mvvm.CodeFixers/ClassUsingAttributeInsteadOfInheritanceCodeFixer.cs b/src/CommunityToolkit.Mvvm.CodeFixers/ClassUsingAttributeInsteadOfInheritanceCodeFixer.cs
index 9b19fc954..817b3c546 100644
--- a/src/CommunityToolkit.Mvvm.CodeFixers/ClassUsingAttributeInsteadOfInheritanceCodeFixer.cs
+++ b/src/CommunityToolkit.Mvvm.CodeFixers/ClassUsingAttributeInsteadOfInheritanceCodeFixer.cs
@@ -40,9 +40,9 @@ public sealed class ClassUsingAttributeInsteadOfInheritanceCodeFixer : CodeFixPr
public override async Task RegisterCodeFixesAsync(CodeFixContext context)
{
Diagnostic diagnostic = context.Diagnostics[0];
- TextSpan diagnosticSpan = diagnostic.Location.SourceSpan;
+ TextSpan diagnosticSpan = context.Span;
- // Retrieve the property passed by the analyzer
+ // Retrieve the properties passed by the analyzer
if (diagnostic.Properties[ClassUsingAttributeInsteadOfInheritanceAnalyzer.TypeNameKey] is not string typeName ||
diagnostic.Properties[ClassUsingAttributeInsteadOfInheritanceAnalyzer.AttributeTypeNameKey] is not string attributeTypeName)
{
@@ -59,11 +59,9 @@ public override async Task RegisterCodeFixesAsync(CodeFixContext context)
context.RegisterCodeFix(
CodeAction.Create(
title: "Inherit from ObservableObject",
- createChangedDocument: token => UpdateReference(context.Document, root, classDeclaration, attributeTypeName),
+ createChangedDocument: token => RemoveAttribute(context.Document, root, classDeclaration, attributeTypeName),
equivalenceKey: "Inherit from ObservableObject"),
diagnostic);
-
- return;
}
}
@@ -75,7 +73,7 @@ public override async Task RegisterCodeFixesAsync(CodeFixContext context)
/// The to update.
/// The name of the attribute that should be removed.
/// An updated document with the applied code fix, and inheriting from ObservableObject.
- private static Task UpdateReference(Document document, SyntaxNode root, ClassDeclarationSyntax classDeclaration, string attributeTypeName)
+ private static Task RemoveAttribute(Document document, SyntaxNode root, ClassDeclarationSyntax classDeclaration, string attributeTypeName)
{
// Insert ObservableObject always in first position in the base list. The type might have
// some interfaces in the base list, so we just copy them back after ObservableObject.
diff --git a/src/CommunityToolkit.Mvvm.CodeFixers/FieldReferenceForObservablePropertyFieldCodeFixer.cs b/src/CommunityToolkit.Mvvm.CodeFixers/FieldReferenceForObservablePropertyFieldCodeFixer.cs
index 3788192b3..9639705b4 100644
--- a/src/CommunityToolkit.Mvvm.CodeFixers/FieldReferenceForObservablePropertyFieldCodeFixer.cs
+++ b/src/CommunityToolkit.Mvvm.CodeFixers/FieldReferenceForObservablePropertyFieldCodeFixer.cs
@@ -37,7 +37,7 @@ public sealed class FieldReferenceForObservablePropertyFieldCodeFixer : CodeFixP
public override async Task RegisterCodeFixesAsync(CodeFixContext context)
{
Diagnostic diagnostic = context.Diagnostics[0];
- TextSpan diagnosticSpan = diagnostic.Location.SourceSpan;
+ TextSpan diagnosticSpan = context.Span;
// Retrieve the properties passed by the analyzer
if (diagnostic.Properties[FieldReferenceForObservablePropertyFieldAnalyzer.FieldNameKey] is not string fieldName ||
diff --git a/src/CommunityToolkit.Mvvm.SourceGenerators/AnalyzerReleases.Shipped.md b/src/CommunityToolkit.Mvvm.SourceGenerators/AnalyzerReleases.Shipped.md
index 0f9819023..ef01360ba 100644
--- a/src/CommunityToolkit.Mvvm.SourceGenerators/AnalyzerReleases.Shipped.md
+++ b/src/CommunityToolkit.Mvvm.SourceGenerators/AnalyzerReleases.Shipped.md
@@ -66,3 +66,4 @@ Rule ID | Category | Severity | Notes
--------|----------|----------|-------
MVVMTK0037 | CommunityToolkit.Mvvm.SourceGenerators.ObservablePropertyGenerator | Error | See https://aka.ms/mvvmtoolkit/errors/mvvmtk0037
MVVMTK0038 | CommunityToolkit.Mvvm.SourceGenerators.RelayCommandGenerator | Error | See https://aka.ms/mvvmtoolkit/errors/mvvmtk0038
+MVVMTK0039 | CommunityToolkit.Mvvm.SourceGenerators.RelayCommandGenerator | Warning | See https://aka.ms/mvvmtoolkit/errors/mvvmtk0039
diff --git a/src/CommunityToolkit.Mvvm.SourceGenerators/CommunityToolkit.Mvvm.SourceGenerators.projitems b/src/CommunityToolkit.Mvvm.SourceGenerators/CommunityToolkit.Mvvm.SourceGenerators.projitems
index 401916e36..edcdffa58 100644
--- a/src/CommunityToolkit.Mvvm.SourceGenerators/CommunityToolkit.Mvvm.SourceGenerators.projitems
+++ b/src/CommunityToolkit.Mvvm.SourceGenerators/CommunityToolkit.Mvvm.SourceGenerators.projitems
@@ -39,6 +39,7 @@
+
diff --git a/src/CommunityToolkit.Mvvm.SourceGenerators/Diagnostics/Analyzers/AsyncVoidReturningRelayCommandMethodAnalyzer.cs b/src/CommunityToolkit.Mvvm.SourceGenerators/Diagnostics/Analyzers/AsyncVoidReturningRelayCommandMethodAnalyzer.cs
new file mode 100644
index 000000000..8b8688d1e
--- /dev/null
+++ b/src/CommunityToolkit.Mvvm.SourceGenerators/Diagnostics/Analyzers/AsyncVoidReturningRelayCommandMethodAnalyzer.cs
@@ -0,0 +1,59 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System.Collections.Immutable;
+using System.Linq;
+using CommunityToolkit.Mvvm.SourceGenerators.Extensions;
+using Microsoft.CodeAnalysis;
+using Microsoft.CodeAnalysis.Diagnostics;
+using static CommunityToolkit.Mvvm.SourceGenerators.Diagnostics.DiagnosticDescriptors;
+
+namespace CommunityToolkit.Mvvm.SourceGenerators;
+
+///
+/// A diagnostic analyzer that generates a warning when using [RelayCommand] over an method.
+///
+[DiagnosticAnalyzer(LanguageNames.CSharp)]
+public sealed class AsyncVoidReturningRelayCommandMethodAnalyzer : DiagnosticAnalyzer
+{
+ ///
+ public override ImmutableArray SupportedDiagnostics { get; } = ImmutableArray.Create(AsyncVoidReturningRelayCommandMethod);
+
+ ///
+ public override void Initialize(AnalysisContext context)
+ {
+ context.ConfigureGeneratedCodeAnalysis(GeneratedCodeAnalysisFlags.Analyze | GeneratedCodeAnalysisFlags.ReportDiagnostics);
+ context.EnableConcurrentExecution();
+
+ context.RegisterCompilationStartAction(static context =>
+ {
+ // Get the symbol for [RelayCommand]
+ if (context.Compilation.GetTypeByMetadataName("CommunityToolkit.Mvvm.Input.RelayCommandAttribute") is not INamedTypeSymbol relayCommandSymbol)
+ {
+ return;
+ }
+
+ context.RegisterSymbolAction(context =>
+ {
+ // We're only looking for async void methods
+ if (context.Symbol is not IMethodSymbol { IsAsync: true, ReturnsVoid: true } methodSymbol)
+ {
+ return;
+ }
+
+ // We only care about methods annotated with [RelayCommand]
+ if (!methodSymbol.HasAttributeWithType(relayCommandSymbol))
+ {
+ return;
+ }
+
+ // Warn on async void methods using [RelayCommand] (they should return a Task instead)
+ context.ReportDiagnostic(Diagnostic.Create(
+ AsyncVoidReturningRelayCommandMethod,
+ context.Symbol.Locations.FirstOrDefault(),
+ context.Symbol));
+ }, SymbolKind.Method);
+ });
+ }
+}
diff --git a/src/CommunityToolkit.Mvvm.SourceGenerators/Diagnostics/DiagnosticDescriptors.cs b/src/CommunityToolkit.Mvvm.SourceGenerators/Diagnostics/DiagnosticDescriptors.cs
index 1305b2db6..1fef89e1d 100644
--- a/src/CommunityToolkit.Mvvm.SourceGenerators/Diagnostics/DiagnosticDescriptors.cs
+++ b/src/CommunityToolkit.Mvvm.SourceGenerators/Diagnostics/DiagnosticDescriptors.cs
@@ -29,6 +29,11 @@ internal static class DiagnosticDescriptors
///
public const string FieldReferenceForObservablePropertyFieldId = "MVVMTK0034";
+ ///
+ /// The diagnostic id for .
+ ///
+ public const string AsyncVoidReturningRelayCommandMethodId = "MVVMTK0039";
+
///
/// Gets a indicating when a duplicate declaration of would happen.
///
@@ -637,4 +642,20 @@ internal static class DiagnosticDescriptors
isEnabledByDefault: true,
description: "All attributes targeting the generated field or property for a method annotated with [RelayCommand] must be using valid expressions.",
helpLinkUri: "https://aka.ms/mvvmtoolkit/errors/mvvmtk0038");
+
+ ///
+ /// Gets a indicating when a method with [RelayCommand] is async void.
+ ///
+ /// Format: "The method {0} annotated with [RelayCommand] is async void (make sure to return a Task type instead)".
+ ///
+ ///
+ public static readonly DiagnosticDescriptor AsyncVoidReturningRelayCommandMethod = new DiagnosticDescriptor(
+ id: AsyncVoidReturningRelayCommandMethodId,
+ title: "Async void returning method annotated with RelayCommand",
+ messageFormat: "The method {0} annotated with [RelayCommand] is async void (make sure to return a Task type instead)",
+ category: typeof(RelayCommandGenerator).FullName,
+ defaultSeverity: DiagnosticSeverity.Warning,
+ isEnabledByDefault: true,
+ description: "All asynchronous methods annotated with [RelayCommand] should return a Task type, to benefit from the additional support provided by AsyncRelayCommand and AsyncRelayCommand.",
+ helpLinkUri: "https://aka.ms/mvvmtoolkit/errors/mvvmtk0039");
}
diff --git a/tests/CommunityToolkit.Mvvm.SourceGenerators.Roslyn401.UnitTests/Test_AsyncVoidReturningRelayCommandMethodCodeFixer.cs b/tests/CommunityToolkit.Mvvm.SourceGenerators.Roslyn401.UnitTests/Test_AsyncVoidReturningRelayCommandMethodCodeFixer.cs
new file mode 100644
index 000000000..e3c69458e
--- /dev/null
+++ b/tests/CommunityToolkit.Mvvm.SourceGenerators.Roslyn401.UnitTests/Test_AsyncVoidReturningRelayCommandMethodCodeFixer.cs
@@ -0,0 +1,113 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System.Threading.Tasks;
+using CommunityToolkit.Mvvm.Input;
+using Microsoft.CodeAnalysis.Testing;
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using CSharpCodeFixTest = Microsoft.CodeAnalysis.CSharp.Testing.CSharpCodeFixTest<
+ CommunityToolkit.Mvvm.SourceGenerators.AsyncVoidReturningRelayCommandMethodAnalyzer,
+ CommunityToolkit.Mvvm.CodeFixers.AsyncVoidReturningRelayCommandMethodCodeFixer,
+ Microsoft.CodeAnalysis.Testing.Verifiers.MSTestVerifier>;
+using CSharpCodeFixVerifier = Microsoft.CodeAnalysis.CSharp.Testing.CSharpCodeFixVerifier<
+ CommunityToolkit.Mvvm.SourceGenerators.AsyncVoidReturningRelayCommandMethodAnalyzer,
+ CommunityToolkit.Mvvm.CodeFixers.AsyncVoidReturningRelayCommandMethodCodeFixer,
+ Microsoft.CodeAnalysis.Testing.Verifiers.MSTestVerifier>;
+
+namespace CommunityToolkit.Mvvm.SourceGenerators.Roslyn401.UnitTests;
+
+[TestClass]
+public class Test_AsyncVoidReturningRelayCommandMethodCodeFixer
+{
+ [TestMethod]
+ public async Task AsyncVoidMethod_FileContainsSystemThreadingTasksUsingDirective()
+ {
+ string original = """
+ using System.Threading.Tasks;
+ using CommunityToolkit.Mvvm.Input;
+
+ partial class C
+ {
+ [RelayCommand]
+ private async void Foo()
+ {
+ }
+ }
+ """;
+
+ string @fixed = """
+ using System.Threading.Tasks;
+ using CommunityToolkit.Mvvm.Input;
+
+ partial class C
+ {
+ [RelayCommand]
+ private async Task Foo()
+ {
+ }
+ }
+ """;
+
+ CSharpCodeFixTest test = new()
+ {
+ TestCode = original,
+ FixedCode = @fixed,
+ ReferenceAssemblies = ReferenceAssemblies.Net.Net60
+ };
+
+ test.TestState.AdditionalReferences.Add(typeof(RelayCommand).Assembly);
+ test.ExpectedDiagnostics.AddRange(new[]
+ {
+ // /0/Test0.cs(7,24): error MVVMTK0039: The method C.Foo() annotated with [RelayCommand] is async void (make sure to return a Task type instead)
+ CSharpCodeFixVerifier.Diagnostic().WithSpan(7, 24, 7, 27).WithArguments("C.Foo()")
+ });
+
+ await test.RunAsync();
+ }
+
+ [TestMethod]
+ public async Task AsyncVoidMethod_FileDoesNotContainSystemThreadingTasksUsingDirective()
+ {
+ string original = """
+ using CommunityToolkit.Mvvm.Input;
+
+ partial class C
+ {
+ [RelayCommand]
+ private async void Foo()
+ {
+ }
+ }
+ """;
+
+ string @fixed = """
+ using System.Threading.Tasks;
+ using CommunityToolkit.Mvvm.Input;
+
+ partial class C
+ {
+ [RelayCommand]
+ private async Task Foo()
+ {
+ }
+ }
+ """;
+
+ CSharpCodeFixTest test = new()
+ {
+ TestCode = original,
+ FixedCode = @fixed,
+ ReferenceAssemblies = ReferenceAssemblies.Net.Net60
+ };
+
+ test.TestState.AdditionalReferences.Add(typeof(RelayCommand).Assembly);
+ test.ExpectedDiagnostics.AddRange(new[]
+ {
+ // /0/Test0.cs(7,24): error MVVMTK0039: The method C.Foo() annotated with [RelayCommand] is async void (make sure to return a Task type instead)
+ CSharpCodeFixVerifier.Diagnostic().WithSpan(6, 24, 6, 27).WithArguments("C.Foo()")
+ });
+
+ await test.RunAsync();
+ }
+}
diff --git a/tests/CommunityToolkit.Mvvm.SourceGenerators.UnitTests/Test_SourceGeneratorsDiagnostics.cs b/tests/CommunityToolkit.Mvvm.SourceGenerators.UnitTests/Test_SourceGeneratorsDiagnostics.cs
index d4e6f04ee..26e29cdae 100644
--- a/tests/CommunityToolkit.Mvvm.SourceGenerators.UnitTests/Test_SourceGeneratorsDiagnostics.cs
+++ b/tests/CommunityToolkit.Mvvm.SourceGenerators.UnitTests/Test_SourceGeneratorsDiagnostics.cs
@@ -1801,6 +1801,25 @@ public class TestAttribute : Attribute
VerifyGeneratedDiagnostics(source, "MVVMTK0038");
}
+ [TestMethod]
+ public async Task AsyncVoidReturningRelayCommandMethodAnalyzer()
+ {
+ string source = """
+ using System;
+ using CommunityToolkit.Mvvm.Input;
+
+ public partial class MyViewModel
+ {
+ [RelayCommand]
+ private async void {|MVVMTK0039:TestAsync|}()
+ {
+ }
+ }
+ """;
+
+ await VerifyAnalyzerDiagnosticsAndSuccessfulGeneration(source, LanguageVersion.CSharp8);
+ }
+
///
/// Verifies the diagnostic errors for a given analyzer, and that all available source generators can run successfully with the input source (including subsequent compilation).
///