diff --git a/Directory.Packages.props b/Directory.Packages.props index 1fa7f407..42145995 100644 --- a/Directory.Packages.props +++ b/Directory.Packages.props @@ -19,7 +19,8 @@ - + + diff --git a/src/Common/ITypeSymbolExtensions.cs b/src/Common/ITypeSymbolExtensions.cs index 38cb625f..de0d1f10 100644 --- a/src/Common/ITypeSymbolExtensions.cs +++ b/src/Common/ITypeSymbolExtensions.cs @@ -96,7 +96,7 @@ typeSymbol is INamedTypeSymbol }, }; - public static bool IsCancellationToken(this ITypeSymbol typeSymbol) => + public static bool IsCancellationToken(this ITypeSymbol? typeSymbol) => typeSymbol is INamedTypeSymbol { Name: "CancellationToken", diff --git a/src/Common/SyntaxExtensions.cs b/src/Common/SyntaxExtensions.cs new file mode 100644 index 00000000..86486284 --- /dev/null +++ b/src/Common/SyntaxExtensions.cs @@ -0,0 +1,12 @@ +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp.Syntax; + +namespace Immediate.Handlers; + +internal static class SyntaxExtensions +{ + public static bool IsCancellationToken(this SemanticModel model, TypeSyntax? typeSyntax, CancellationToken token) => + typeSyntax is { } syntax + && model.GetSymbolInfo(syntax, token).Symbol is INamedTypeSymbol namedType + && namedType.IsCancellationToken(); +} diff --git a/src/Immediate.Handlers.CodeFixes/RefactoringExtensions.cs b/src/Immediate.Handlers.CodeFixes/RefactoringExtensions.cs new file mode 100644 index 00000000..8afaff73 --- /dev/null +++ b/src/Immediate.Handlers.CodeFixes/RefactoringExtensions.cs @@ -0,0 +1,35 @@ +using System.Diagnostics.CodeAnalysis; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CodeRefactorings; +using Microsoft.CodeAnalysis.Text; + +namespace Immediate.Handlers.CodeFixes; + +[ExcludeFromCodeCoverage] +internal static class RefactoringExtensions +{ + internal static void Deconstruct(this CodeRefactoringContext context, out Document document, out TextSpan span, out CancellationToken cancellationToken) + { + document = context.Document; + span = context.Span; + cancellationToken = context.CancellationToken; + } + + public static async ValueTask GetRequiredSyntaxRootAsync(this Document document, CancellationToken cancellationToken) + { + if (document.TryGetSyntaxRoot(out var root)) + return root; + + return await document.GetSyntaxRootAsync(cancellationToken) + ?? throw new InvalidOperationException($"Failed to retrieve the syntax root for document '{document.Name ?? document.FilePath ?? "unknown"}'."); + } + + public static async ValueTask GetRequiredSemanticModelAsync(this Document document, CancellationToken cancellationToken) + { + if (document.TryGetSemanticModel(out var semanticModel)) + return semanticModel; + + return await document.GetSemanticModelAsync(cancellationToken) + ?? throw new InvalidOperationException("Could not retrieve semantic model for the document."); + } +} diff --git a/src/Immediate.Handlers.CodeFixes/StaticToSealedHandlerRefactoringProvider.cs b/src/Immediate.Handlers.CodeFixes/StaticToSealedHandlerRefactoringProvider.cs new file mode 100644 index 00000000..d201f661 --- /dev/null +++ b/src/Immediate.Handlers.CodeFixes/StaticToSealedHandlerRefactoringProvider.cs @@ -0,0 +1,178 @@ +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CodeActions; +using Microsoft.CodeAnalysis.CodeRefactorings; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; + +namespace Immediate.Handlers.CodeFixes; + +[ExportCodeRefactoringProvider(LanguageNames.CSharp, Name = "Convert to instance handler")] +public sealed class StaticToSealedHandlerRefactoringProvider : CodeRefactoringProvider +{ + public override async Task ComputeRefactoringsAsync(CodeRefactoringContext context) + { + var (document, span, token) = context; + token.ThrowIfCancellationRequested(); + + if (await document.GetRequiredSyntaxRootAsync(token) is not CompilationUnitSyntax root) + return; + + var model = await document.GetRequiredSemanticModelAsync(token); + + switch (root.FindNode(span)) + { + case ClassDeclarationSyntax cds: + { + if (model.GetDeclaredSymbol(cds, token) is not INamedTypeSymbol { IsStatic: true } container) + return; + + if (!container.GetAttributes().Any(a => a.AttributeClass.IsHandlerAttribute())) + return; + + var method = container.GetMembers() + .OfType() + .FirstOrDefault(m => m is { IsStatic: true, Name: "Handle" or "HandleAsync" }); + + if (method is null) + return; + + var mds = (MethodDeclarationSyntax)await method + .DeclaringSyntaxReferences[0] + .GetSyntaxAsync(token); + + var service = new RefactoringService( + document, + model, + root, + cds, + mds + ); + + context.RegisterRefactoring( + CodeAction.Create( + title: "Convert to instance handler", + createChangedDocument: service.ConvertToInstanceHandler, + equivalenceKey: nameof(StaticToSealedHandlerRefactoringProvider) + ) + ); + + break; + } + + case MethodDeclarationSyntax mds: + { + if (model.GetDeclaredSymbol(mds, token) is not IMethodSymbol + { + IsStatic: true, + Name: "Handle" or "HandleAsync", + ContainingType: INamedTypeSymbol { IsStatic: true } container, + } method) + { + return; + } + + if (!container.GetAttributes().Any(a => a.AttributeClass.IsHandlerAttribute())) + return; + + var service = new RefactoringService( + document, + model, + root, + (ClassDeclarationSyntax)mds.Parent!, + mds + ); + + context.RegisterRefactoring( + CodeAction.Create( + title: "Convert to instance handler", + createChangedDocument: service.ConvertToInstanceHandler, + equivalenceKey: nameof(StaticToSealedHandlerRefactoringProvider) + ) + ); + + break; + } + + default: + break; + } + } + +} + +file sealed class RefactoringService( + Document document, + SemanticModel model, + CompilationUnitSyntax documentRoot, + ClassDeclarationSyntax classDeclarationSyntax, + MethodDeclarationSyntax methodDeclarationSyntax +) +{ + public Task ConvertToInstanceHandler( + CancellationToken token + ) + { + var methodParameters = methodDeclarationSyntax.ParameterList.Parameters; + + var isLastParamCancellationToken = model.IsCancellationToken(methodParameters[^1].Type, token); + + var classParameters = methodParameters + .Skip(1) + .Take(methodParameters.Count - (isLastParamCancellationToken ? 2 : 1)) + .Select(p => p.WithTrailingTrivia(ElasticSpace)) + .ToList(); + + var newMethodParameters = methodParameters.RemoveParametersUntilCount(isLastParamCancellationToken ? 2 : 1); + + var newMethodDeclarationSyntax = methodDeclarationSyntax + .WithParameterList( + methodDeclarationSyntax.ParameterList + .WithParameters(newMethodParameters) + ) + .WithModifiers( + methodDeclarationSyntax.Modifiers + .RemoveStaticModifier() + ); + + var newClassDeclarationSyntax = classDeclarationSyntax + .ReplaceNode(methodDeclarationSyntax, newMethodDeclarationSyntax) + .WithModifiers( + classDeclarationSyntax.Modifiers + .RemoveStaticModifier() + .Insert( + // valid case will have `partial` as final element; insert `sealed` before `partial` + classDeclarationSyntax.Modifiers.Count - 2, + Token(SyntaxKind.SealedKeyword).WithTrailingTrivia(ElasticSpace) + ) + ); + + if (classParameters.Count > 0) + { + newClassDeclarationSyntax = newClassDeclarationSyntax + .WithParameterList( + ParameterList(SeparatedList(classParameters)) + ) + .WithIdentifier(classDeclarationSyntax.Identifier.WithoutTrivia()); + } + + return Task.FromResult(document.WithSyntaxRoot(documentRoot.ReplaceNode(classDeclarationSyntax, newClassDeclarationSyntax))); + } +} + +file static class SyntaxExtensions +{ + public static SeparatedSyntaxList RemoveParametersUntilCount( + this SeparatedSyntaxList nodes, + int count + ) + { + while (nodes.Count > count) + nodes = nodes.RemoveAt(1); + return nodes; + } + + public static SyntaxTokenList RemoveStaticModifier( + this SyntaxTokenList list + ) => new(list.Where(static token => !token.IsKind(SyntaxKind.StaticKeyword))); +} diff --git a/tests/Immediate.Handlers.Tests/CodeFixTests/CodeRefactoringTestHelper.cs b/tests/Immediate.Handlers.Tests/CodeFixTests/CodeRefactoringTestHelper.cs new file mode 100644 index 00000000..b311cd9b --- /dev/null +++ b/tests/Immediate.Handlers.Tests/CodeFixTests/CodeRefactoringTestHelper.cs @@ -0,0 +1,52 @@ +using System.Diagnostics.CodeAnalysis; +using Immediate.Handlers.Tests.Helpers; +using Microsoft.CodeAnalysis.CodeRefactorings; +using Microsoft.CodeAnalysis.CSharp.Testing; +using Microsoft.CodeAnalysis.Testing; + +namespace Immediate.Handlers.Tests.CodeFixTests; + +public static class CodeRefactoringTestHelper +{ + private const string EditorConfig = + """ + root = true + + [*.cs] + charset = utf-8 + indent_style = tab + insert_final_newline = true + indent_size = 4 + """; + + public static CSharpCodeRefactoringTest CreateCodeRefactoringTest( + [StringSyntax("c#-test")] string inputSource, + [StringSyntax("c#-test")] string fixedSource, + int codeActionIndex = 0 + ) + where TRefactoring : CodeRefactoringProvider, new() + { + var csTest = new CSharpCodeRefactoringTest + { + CodeActionIndex = codeActionIndex, + TestState = + { + Sources = { inputSource }, + AnalyzerConfigFiles = { { ("/.editorconfig", EditorConfig) } }, + ReferenceAssemblies = new ReferenceAssemblies( + "net8.0", + new PackageIdentity( + "Microsoft.NETCore.App.Ref", + "8.0.0"), + Path.Combine("ref", "net8.0") + ), + }, + FixedState = { MarkupHandling = MarkupMode.IgnoreFixable, Sources = { fixedSource } }, + }; + + csTest.TestState.AdditionalReferences + .AddRange(DriverReferenceAssemblies.Msdi.GetAdditionalReferences()); + + return csTest; + } +} diff --git a/tests/Immediate.Handlers.Tests/CodeFixTests/Tests.HandleMethodDoesNotExist.cs b/tests/Immediate.Handlers.Tests/CodeFixTests/HandlerMethodMustExistCodeFixProviderTests.cs similarity index 94% rename from tests/Immediate.Handlers.Tests/CodeFixTests/Tests.HandleMethodDoesNotExist.cs rename to tests/Immediate.Handlers.Tests/CodeFixTests/HandlerMethodMustExistCodeFixProviderTests.cs index 1b9a3116..814be64d 100644 --- a/tests/Immediate.Handlers.Tests/CodeFixTests/Tests.HandleMethodDoesNotExist.cs +++ b/tests/Immediate.Handlers.Tests/CodeFixTests/HandlerMethodMustExistCodeFixProviderTests.cs @@ -5,7 +5,7 @@ namespace Immediate.Handlers.Tests.CodeFixTests; [System.Diagnostics.CodeAnalysis.SuppressMessage("Naming", "CA1724:Type names should not match namespaces", Justification = "Not being consumed by other code")] -public sealed partial class Tests +public sealed partial class HandlerMethodMustExistCodeFixProviderTests { [Test] public async Task HandleMethodDoesNotExist() => diff --git a/tests/Immediate.Handlers.Tests/CodeFixTests/StaticToSealedHandlerRefactoringProviderTests.cs b/tests/Immediate.Handlers.Tests/CodeFixTests/StaticToSealedHandlerRefactoringProviderTests.cs new file mode 100644 index 00000000..89e07231 --- /dev/null +++ b/tests/Immediate.Handlers.Tests/CodeFixTests/StaticToSealedHandlerRefactoringProviderTests.cs @@ -0,0 +1,183 @@ +using Immediate.Handlers.CodeFixes; + +namespace Immediate.Handlers.Tests.CodeFixTests; + +public sealed class StaticToSealedHandlerRefactoringProviderTests +{ + [Test] + public async Task RefactorOnHandlerClass() => + await CodeRefactoringTestHelper.CreateCodeRefactoringTest( + """ + using System.Threading; + using System.Threading.Tasks; + + namespace Immediate.Handlers.Shared; + + public sealed class Dependency1; + public sealed class Dependency2; + + [Handler] + public static partial class {|Refactoring:DoSomething|} + { + public sealed record Query; + public sealed record Response; + + private static ValueTask HandleAsync( + Query query, + Dependency1 dependency1, + Dependency2 dependency2, + CancellationToken token + ) + { + Method(dependency2, 2); + return new(); + } + + private static void Method(Dependency2 dependency2, int value) + { + } + } + """, + """ + using System.Threading; + using System.Threading.Tasks; + + namespace Immediate.Handlers.Shared; + + public sealed class Dependency1; + public sealed class Dependency2; + + [Handler] + public sealed partial class DoSomething(Dependency1 dependency1, Dependency2 dependency2) + { + public sealed record Query; + public sealed record Response; + + private ValueTask HandleAsync( + Query query, + CancellationToken token + ) + { + Method(dependency2, 2); + return new(); + } + + private static void Method(Dependency2 dependency2, int value) + { + } + } + """ + ).RunAsync(); + + [Test] + public async Task RefactorOnHandlerMethod() => + await CodeRefactoringTestHelper.CreateCodeRefactoringTest( + """ + using System.Threading; + using System.Threading.Tasks; + + namespace Immediate.Handlers.Shared; + + public sealed class Dependency1; + public sealed class Dependency2; + + [Handler] + public static partial class DoSomething + { + public sealed record Query; + public sealed record Response; + + private static ValueTask {|Refactoring:HandleAsync|}( + Query query, + Dependency1 dependency1, + Dependency2 dependency2, + CancellationToken token + ) + { + Method(dependency2, 2); + return new(); + } + + private static void Method(Dependency2 dependency2, int value) + { + } + } + """, + """ + using System.Threading; + using System.Threading.Tasks; + + namespace Immediate.Handlers.Shared; + + public sealed class Dependency1; + public sealed class Dependency2; + + [Handler] + public sealed partial class DoSomething(Dependency1 dependency1, Dependency2 dependency2) + { + public sealed record Query; + public sealed record Response; + + private ValueTask HandleAsync( + Query query, + CancellationToken token + ) + { + Method(dependency2, 2); + return new(); + } + + private static void Method(Dependency2 dependency2, int value) + { + } + } + """ + ).RunAsync(); + + [Test] + public async Task RefactorWithNoDependencyParameters() => + await CodeRefactoringTestHelper.CreateCodeRefactoringTest( + """ + using System.Threading; + using System.Threading.Tasks; + + namespace Immediate.Handlers.Shared; + + [Handler] + public static partial class {|Refactoring:DoSomething|} + { + public sealed record Query; + public sealed record Response; + + private static ValueTask HandleAsync( + Query query, + CancellationToken token + ) + { + return new(); + } + } + """, + """ + using System.Threading; + using System.Threading.Tasks; + + namespace Immediate.Handlers.Shared; + + [Handler] + public sealed partial class DoSomething + { + public sealed record Query; + public sealed record Response; + + private ValueTask HandleAsync( + Query query, + CancellationToken token + ) + { + return new(); + } + } + """ + ).RunAsync(); +} diff --git a/tests/Immediate.Handlers.Tests/Immediate.Handlers.Tests.csproj b/tests/Immediate.Handlers.Tests/Immediate.Handlers.Tests.csproj index 96974f00..aa6bdba7 100644 --- a/tests/Immediate.Handlers.Tests/Immediate.Handlers.Tests.csproj +++ b/tests/Immediate.Handlers.Tests/Immediate.Handlers.Tests.csproj @@ -19,6 +19,7 @@ +