diff --git a/Directory.Packages.props b/Directory.Packages.props
index 1fa7f407..42145995 100644
--- a/Directory.Packages.props
+++ b/Directory.Packages.props
@@ -19,7 +19,8 @@
-
+
+
diff --git a/src/Common/ITypeSymbolExtensions.cs b/src/Common/ITypeSymbolExtensions.cs
index 38cb625f..de0d1f10 100644
--- a/src/Common/ITypeSymbolExtensions.cs
+++ b/src/Common/ITypeSymbolExtensions.cs
@@ -96,7 +96,7 @@ typeSymbol is INamedTypeSymbol
},
};
- public static bool IsCancellationToken(this ITypeSymbol typeSymbol) =>
+ public static bool IsCancellationToken(this ITypeSymbol? typeSymbol) =>
typeSymbol is INamedTypeSymbol
{
Name: "CancellationToken",
diff --git a/src/Common/SyntaxExtensions.cs b/src/Common/SyntaxExtensions.cs
new file mode 100644
index 00000000..86486284
--- /dev/null
+++ b/src/Common/SyntaxExtensions.cs
@@ -0,0 +1,12 @@
+using Microsoft.CodeAnalysis;
+using Microsoft.CodeAnalysis.CSharp.Syntax;
+
+namespace Immediate.Handlers;
+
+internal static class SyntaxExtensions
+{
+ public static bool IsCancellationToken(this SemanticModel model, TypeSyntax? typeSyntax, CancellationToken token) =>
+ typeSyntax is { } syntax
+ && model.GetSymbolInfo(syntax, token).Symbol is INamedTypeSymbol namedType
+ && namedType.IsCancellationToken();
+}
diff --git a/src/Immediate.Handlers.CodeFixes/RefactoringExtensions.cs b/src/Immediate.Handlers.CodeFixes/RefactoringExtensions.cs
new file mode 100644
index 00000000..8afaff73
--- /dev/null
+++ b/src/Immediate.Handlers.CodeFixes/RefactoringExtensions.cs
@@ -0,0 +1,35 @@
+using System.Diagnostics.CodeAnalysis;
+using Microsoft.CodeAnalysis;
+using Microsoft.CodeAnalysis.CodeRefactorings;
+using Microsoft.CodeAnalysis.Text;
+
+namespace Immediate.Handlers.CodeFixes;
+
+[ExcludeFromCodeCoverage]
+internal static class RefactoringExtensions
+{
+ internal static void Deconstruct(this CodeRefactoringContext context, out Document document, out TextSpan span, out CancellationToken cancellationToken)
+ {
+ document = context.Document;
+ span = context.Span;
+ cancellationToken = context.CancellationToken;
+ }
+
+ public static async ValueTask GetRequiredSyntaxRootAsync(this Document document, CancellationToken cancellationToken)
+ {
+ if (document.TryGetSyntaxRoot(out var root))
+ return root;
+
+ return await document.GetSyntaxRootAsync(cancellationToken)
+ ?? throw new InvalidOperationException($"Failed to retrieve the syntax root for document '{document.Name ?? document.FilePath ?? "unknown"}'.");
+ }
+
+ public static async ValueTask GetRequiredSemanticModelAsync(this Document document, CancellationToken cancellationToken)
+ {
+ if (document.TryGetSemanticModel(out var semanticModel))
+ return semanticModel;
+
+ return await document.GetSemanticModelAsync(cancellationToken)
+ ?? throw new InvalidOperationException("Could not retrieve semantic model for the document.");
+ }
+}
diff --git a/src/Immediate.Handlers.CodeFixes/StaticToSealedHandlerRefactoringProvider.cs b/src/Immediate.Handlers.CodeFixes/StaticToSealedHandlerRefactoringProvider.cs
new file mode 100644
index 00000000..d201f661
--- /dev/null
+++ b/src/Immediate.Handlers.CodeFixes/StaticToSealedHandlerRefactoringProvider.cs
@@ -0,0 +1,178 @@
+using Microsoft.CodeAnalysis;
+using Microsoft.CodeAnalysis.CodeActions;
+using Microsoft.CodeAnalysis.CodeRefactorings;
+using Microsoft.CodeAnalysis.CSharp;
+using Microsoft.CodeAnalysis.CSharp.Syntax;
+using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory;
+
+namespace Immediate.Handlers.CodeFixes;
+
+[ExportCodeRefactoringProvider(LanguageNames.CSharp, Name = "Convert to instance handler")]
+public sealed class StaticToSealedHandlerRefactoringProvider : CodeRefactoringProvider
+{
+ public override async Task ComputeRefactoringsAsync(CodeRefactoringContext context)
+ {
+ var (document, span, token) = context;
+ token.ThrowIfCancellationRequested();
+
+ if (await document.GetRequiredSyntaxRootAsync(token) is not CompilationUnitSyntax root)
+ return;
+
+ var model = await document.GetRequiredSemanticModelAsync(token);
+
+ switch (root.FindNode(span))
+ {
+ case ClassDeclarationSyntax cds:
+ {
+ if (model.GetDeclaredSymbol(cds, token) is not INamedTypeSymbol { IsStatic: true } container)
+ return;
+
+ if (!container.GetAttributes().Any(a => a.AttributeClass.IsHandlerAttribute()))
+ return;
+
+ var method = container.GetMembers()
+ .OfType()
+ .FirstOrDefault(m => m is { IsStatic: true, Name: "Handle" or "HandleAsync" });
+
+ if (method is null)
+ return;
+
+ var mds = (MethodDeclarationSyntax)await method
+ .DeclaringSyntaxReferences[0]
+ .GetSyntaxAsync(token);
+
+ var service = new RefactoringService(
+ document,
+ model,
+ root,
+ cds,
+ mds
+ );
+
+ context.RegisterRefactoring(
+ CodeAction.Create(
+ title: "Convert to instance handler",
+ createChangedDocument: service.ConvertToInstanceHandler,
+ equivalenceKey: nameof(StaticToSealedHandlerRefactoringProvider)
+ )
+ );
+
+ break;
+ }
+
+ case MethodDeclarationSyntax mds:
+ {
+ if (model.GetDeclaredSymbol(mds, token) is not IMethodSymbol
+ {
+ IsStatic: true,
+ Name: "Handle" or "HandleAsync",
+ ContainingType: INamedTypeSymbol { IsStatic: true } container,
+ } method)
+ {
+ return;
+ }
+
+ if (!container.GetAttributes().Any(a => a.AttributeClass.IsHandlerAttribute()))
+ return;
+
+ var service = new RefactoringService(
+ document,
+ model,
+ root,
+ (ClassDeclarationSyntax)mds.Parent!,
+ mds
+ );
+
+ context.RegisterRefactoring(
+ CodeAction.Create(
+ title: "Convert to instance handler",
+ createChangedDocument: service.ConvertToInstanceHandler,
+ equivalenceKey: nameof(StaticToSealedHandlerRefactoringProvider)
+ )
+ );
+
+ break;
+ }
+
+ default:
+ break;
+ }
+ }
+
+}
+
+file sealed class RefactoringService(
+ Document document,
+ SemanticModel model,
+ CompilationUnitSyntax documentRoot,
+ ClassDeclarationSyntax classDeclarationSyntax,
+ MethodDeclarationSyntax methodDeclarationSyntax
+)
+{
+ public Task ConvertToInstanceHandler(
+ CancellationToken token
+ )
+ {
+ var methodParameters = methodDeclarationSyntax.ParameterList.Parameters;
+
+ var isLastParamCancellationToken = model.IsCancellationToken(methodParameters[^1].Type, token);
+
+ var classParameters = methodParameters
+ .Skip(1)
+ .Take(methodParameters.Count - (isLastParamCancellationToken ? 2 : 1))
+ .Select(p => p.WithTrailingTrivia(ElasticSpace))
+ .ToList();
+
+ var newMethodParameters = methodParameters.RemoveParametersUntilCount(isLastParamCancellationToken ? 2 : 1);
+
+ var newMethodDeclarationSyntax = methodDeclarationSyntax
+ .WithParameterList(
+ methodDeclarationSyntax.ParameterList
+ .WithParameters(newMethodParameters)
+ )
+ .WithModifiers(
+ methodDeclarationSyntax.Modifiers
+ .RemoveStaticModifier()
+ );
+
+ var newClassDeclarationSyntax = classDeclarationSyntax
+ .ReplaceNode(methodDeclarationSyntax, newMethodDeclarationSyntax)
+ .WithModifiers(
+ classDeclarationSyntax.Modifiers
+ .RemoveStaticModifier()
+ .Insert(
+ // valid case will have `partial` as final element; insert `sealed` before `partial`
+ classDeclarationSyntax.Modifiers.Count - 2,
+ Token(SyntaxKind.SealedKeyword).WithTrailingTrivia(ElasticSpace)
+ )
+ );
+
+ if (classParameters.Count > 0)
+ {
+ newClassDeclarationSyntax = newClassDeclarationSyntax
+ .WithParameterList(
+ ParameterList(SeparatedList(classParameters))
+ )
+ .WithIdentifier(classDeclarationSyntax.Identifier.WithoutTrivia());
+ }
+
+ return Task.FromResult(document.WithSyntaxRoot(documentRoot.ReplaceNode(classDeclarationSyntax, newClassDeclarationSyntax)));
+ }
+}
+
+file static class SyntaxExtensions
+{
+ public static SeparatedSyntaxList RemoveParametersUntilCount(
+ this SeparatedSyntaxList nodes,
+ int count
+ )
+ {
+ while (nodes.Count > count)
+ nodes = nodes.RemoveAt(1);
+ return nodes;
+ }
+
+ public static SyntaxTokenList RemoveStaticModifier(
+ this SyntaxTokenList list
+ ) => new(list.Where(static token => !token.IsKind(SyntaxKind.StaticKeyword)));
+}
diff --git a/tests/Immediate.Handlers.Tests/CodeFixTests/CodeRefactoringTestHelper.cs b/tests/Immediate.Handlers.Tests/CodeFixTests/CodeRefactoringTestHelper.cs
new file mode 100644
index 00000000..b311cd9b
--- /dev/null
+++ b/tests/Immediate.Handlers.Tests/CodeFixTests/CodeRefactoringTestHelper.cs
@@ -0,0 +1,52 @@
+using System.Diagnostics.CodeAnalysis;
+using Immediate.Handlers.Tests.Helpers;
+using Microsoft.CodeAnalysis.CodeRefactorings;
+using Microsoft.CodeAnalysis.CSharp.Testing;
+using Microsoft.CodeAnalysis.Testing;
+
+namespace Immediate.Handlers.Tests.CodeFixTests;
+
+public static class CodeRefactoringTestHelper
+{
+ private const string EditorConfig =
+ """
+ root = true
+
+ [*.cs]
+ charset = utf-8
+ indent_style = tab
+ insert_final_newline = true
+ indent_size = 4
+ """;
+
+ public static CSharpCodeRefactoringTest CreateCodeRefactoringTest(
+ [StringSyntax("c#-test")] string inputSource,
+ [StringSyntax("c#-test")] string fixedSource,
+ int codeActionIndex = 0
+ )
+ where TRefactoring : CodeRefactoringProvider, new()
+ {
+ var csTest = new CSharpCodeRefactoringTest
+ {
+ CodeActionIndex = codeActionIndex,
+ TestState =
+ {
+ Sources = { inputSource },
+ AnalyzerConfigFiles = { { ("/.editorconfig", EditorConfig) } },
+ ReferenceAssemblies = new ReferenceAssemblies(
+ "net8.0",
+ new PackageIdentity(
+ "Microsoft.NETCore.App.Ref",
+ "8.0.0"),
+ Path.Combine("ref", "net8.0")
+ ),
+ },
+ FixedState = { MarkupHandling = MarkupMode.IgnoreFixable, Sources = { fixedSource } },
+ };
+
+ csTest.TestState.AdditionalReferences
+ .AddRange(DriverReferenceAssemblies.Msdi.GetAdditionalReferences());
+
+ return csTest;
+ }
+}
diff --git a/tests/Immediate.Handlers.Tests/CodeFixTests/Tests.HandleMethodDoesNotExist.cs b/tests/Immediate.Handlers.Tests/CodeFixTests/HandlerMethodMustExistCodeFixProviderTests.cs
similarity index 94%
rename from tests/Immediate.Handlers.Tests/CodeFixTests/Tests.HandleMethodDoesNotExist.cs
rename to tests/Immediate.Handlers.Tests/CodeFixTests/HandlerMethodMustExistCodeFixProviderTests.cs
index 1b9a3116..814be64d 100644
--- a/tests/Immediate.Handlers.Tests/CodeFixTests/Tests.HandleMethodDoesNotExist.cs
+++ b/tests/Immediate.Handlers.Tests/CodeFixTests/HandlerMethodMustExistCodeFixProviderTests.cs
@@ -5,7 +5,7 @@
namespace Immediate.Handlers.Tests.CodeFixTests;
[System.Diagnostics.CodeAnalysis.SuppressMessage("Naming", "CA1724:Type names should not match namespaces", Justification = "Not being consumed by other code")]
-public sealed partial class Tests
+public sealed partial class HandlerMethodMustExistCodeFixProviderTests
{
[Test]
public async Task HandleMethodDoesNotExist() =>
diff --git a/tests/Immediate.Handlers.Tests/CodeFixTests/StaticToSealedHandlerRefactoringProviderTests.cs b/tests/Immediate.Handlers.Tests/CodeFixTests/StaticToSealedHandlerRefactoringProviderTests.cs
new file mode 100644
index 00000000..89e07231
--- /dev/null
+++ b/tests/Immediate.Handlers.Tests/CodeFixTests/StaticToSealedHandlerRefactoringProviderTests.cs
@@ -0,0 +1,183 @@
+using Immediate.Handlers.CodeFixes;
+
+namespace Immediate.Handlers.Tests.CodeFixTests;
+
+public sealed class StaticToSealedHandlerRefactoringProviderTests
+{
+ [Test]
+ public async Task RefactorOnHandlerClass() =>
+ await CodeRefactoringTestHelper.CreateCodeRefactoringTest(
+ """
+ using System.Threading;
+ using System.Threading.Tasks;
+
+ namespace Immediate.Handlers.Shared;
+
+ public sealed class Dependency1;
+ public sealed class Dependency2;
+
+ [Handler]
+ public static partial class {|Refactoring:DoSomething|}
+ {
+ public sealed record Query;
+ public sealed record Response;
+
+ private static ValueTask HandleAsync(
+ Query query,
+ Dependency1 dependency1,
+ Dependency2 dependency2,
+ CancellationToken token
+ )
+ {
+ Method(dependency2, 2);
+ return new();
+ }
+
+ private static void Method(Dependency2 dependency2, int value)
+ {
+ }
+ }
+ """,
+ """
+ using System.Threading;
+ using System.Threading.Tasks;
+
+ namespace Immediate.Handlers.Shared;
+
+ public sealed class Dependency1;
+ public sealed class Dependency2;
+
+ [Handler]
+ public sealed partial class DoSomething(Dependency1 dependency1, Dependency2 dependency2)
+ {
+ public sealed record Query;
+ public sealed record Response;
+
+ private ValueTask HandleAsync(
+ Query query,
+ CancellationToken token
+ )
+ {
+ Method(dependency2, 2);
+ return new();
+ }
+
+ private static void Method(Dependency2 dependency2, int value)
+ {
+ }
+ }
+ """
+ ).RunAsync();
+
+ [Test]
+ public async Task RefactorOnHandlerMethod() =>
+ await CodeRefactoringTestHelper.CreateCodeRefactoringTest(
+ """
+ using System.Threading;
+ using System.Threading.Tasks;
+
+ namespace Immediate.Handlers.Shared;
+
+ public sealed class Dependency1;
+ public sealed class Dependency2;
+
+ [Handler]
+ public static partial class DoSomething
+ {
+ public sealed record Query;
+ public sealed record Response;
+
+ private static ValueTask {|Refactoring:HandleAsync|}(
+ Query query,
+ Dependency1 dependency1,
+ Dependency2 dependency2,
+ CancellationToken token
+ )
+ {
+ Method(dependency2, 2);
+ return new();
+ }
+
+ private static void Method(Dependency2 dependency2, int value)
+ {
+ }
+ }
+ """,
+ """
+ using System.Threading;
+ using System.Threading.Tasks;
+
+ namespace Immediate.Handlers.Shared;
+
+ public sealed class Dependency1;
+ public sealed class Dependency2;
+
+ [Handler]
+ public sealed partial class DoSomething(Dependency1 dependency1, Dependency2 dependency2)
+ {
+ public sealed record Query;
+ public sealed record Response;
+
+ private ValueTask HandleAsync(
+ Query query,
+ CancellationToken token
+ )
+ {
+ Method(dependency2, 2);
+ return new();
+ }
+
+ private static void Method(Dependency2 dependency2, int value)
+ {
+ }
+ }
+ """
+ ).RunAsync();
+
+ [Test]
+ public async Task RefactorWithNoDependencyParameters() =>
+ await CodeRefactoringTestHelper.CreateCodeRefactoringTest(
+ """
+ using System.Threading;
+ using System.Threading.Tasks;
+
+ namespace Immediate.Handlers.Shared;
+
+ [Handler]
+ public static partial class {|Refactoring:DoSomething|}
+ {
+ public sealed record Query;
+ public sealed record Response;
+
+ private static ValueTask HandleAsync(
+ Query query,
+ CancellationToken token
+ )
+ {
+ return new();
+ }
+ }
+ """,
+ """
+ using System.Threading;
+ using System.Threading.Tasks;
+
+ namespace Immediate.Handlers.Shared;
+
+ [Handler]
+ public sealed partial class DoSomething
+ {
+ public sealed record Query;
+ public sealed record Response;
+
+ private ValueTask HandleAsync(
+ Query query,
+ CancellationToken token
+ )
+ {
+ return new();
+ }
+ }
+ """
+ ).RunAsync();
+}
diff --git a/tests/Immediate.Handlers.Tests/Immediate.Handlers.Tests.csproj b/tests/Immediate.Handlers.Tests/Immediate.Handlers.Tests.csproj
index 96974f00..aa6bdba7 100644
--- a/tests/Immediate.Handlers.Tests/Immediate.Handlers.Tests.csproj
+++ b/tests/Immediate.Handlers.Tests/Immediate.Handlers.Tests.csproj
@@ -19,6 +19,7 @@
+