Skip to content
Merged
4 changes: 3 additions & 1 deletion tools/azsdk-cli/.editorconfig
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
[*.cs]
# disallow single if statements with no curly braces
dotnet_diagnostic.SA1503.severity = error
dotnet_diagnostic.IDE0011.severity = error
dotnet_style_namespace_match_folder = false
Original file line number Diff line number Diff line change
@@ -0,0 +1,225 @@
using System;
using System.Collections.Immutable;
using System.Linq;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Diagnostics;

namespace Azure.Sdk.Tools.Cli.Analyzer
{
[DiagnosticAnalyzer(LanguageNames.CSharp)]
public class EnforceToolsReturnTypesAnalyzer : DiagnosticAnalyzer
{
public const string Id = "MCP003";
public static readonly DiagnosticDescriptor Rule = new DiagnosticDescriptor(
Id,
"Tool methods must return Response types, built-in value types, or string",
"Method '{0}' in Tools namespace must return a class implementing Response, a built-in value type, or string. Current return type: '{1}'.",
"Design",
DiagnosticSeverity.Error,
isEnabledByDefault: true);

public override ImmutableArray<DiagnosticDescriptor> SupportedDiagnostics => ImmutableArray.Create(Rule);

public override void Initialize(AnalysisContext context)
{
context.ConfigureGeneratedCodeAnalysis(GeneratedCodeAnalysisFlags.None);
context.EnableConcurrentExecution();
context.RegisterSyntaxNodeAction(AnalyzeMethod, SyntaxKind.MethodDeclaration);
}

private static void AnalyzeMethod(SyntaxNodeAnalysisContext context)
{
var methodDeclaration = (MethodDeclarationSyntax)context.Node;
var semanticModel = context.SemanticModel;

// Get method symbol
if (!(semanticModel.GetDeclaredSymbol(methodDeclaration) is IMethodSymbol methodSymbol))
{
return;
}

// Only analyze public non-static methods
if (!methodSymbol.DeclaredAccessibility.HasFlag(Accessibility.Public) || methodSymbol.IsStatic)
{
return;
}

// Only analyze methods in classes within Azure.Sdk.Tools.Cli.Tools namespace
var containingType = methodSymbol.ContainingType;
if (containingType?.ContainingNamespace == null)
{
return;
}

var namespaceName = containingType.ContainingNamespace.ToDisplayString();
if (!namespaceName.StartsWith("Azure.Sdk.Tools.Cli.Tools"))
{
return;
}

// Exclude abstract methods and virtual methods that are likely from base classes
if (methodSymbol.IsAbstract || methodSymbol.IsVirtual || methodSymbol.IsOverride)
{
return;
}

// Exclude specific framework methods by name
if (methodSymbol.Name == "GetCommand" || methodSymbol.Name == "HandleCommand")
{
return;
}

var returnType = methodSymbol.ReturnType;

// Handle Task<T> - get the inner type
if (IsTaskType(returnType, out var innerType))
{
returnType = innerType;
}

// Check if return type is valid
if (!IsValidReturnType(returnType, context.Compilation))
{
var returnTypeDisplayName = returnType.ToDisplayString();
var diagnostic = Diagnostic.Create(Rule,
methodDeclaration.Identifier.GetLocation(),
methodSymbol.Name,
returnTypeDisplayName);
context.ReportDiagnostic(diagnostic);
}
}

private static bool IsTaskType(ITypeSymbol type, out ITypeSymbol innerType)
{
innerType = type;

if (type is INamedTypeSymbol namedType)
{
// Check for Task<T>
if (namedType.IsGenericType &&
(namedType.ConstructedFrom?.ToDisplayString() == "System.Threading.Tasks.Task<T>" ||
namedType.Name == "Task" && namedType.ContainingNamespace?.ToDisplayString() == "System.Threading.Tasks"))
Comment on lines +102 to +103
Copy link

Copilot AI Aug 7, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nitpick] The Task detection logic checks both the constructed form string and separate name/namespace conditions. This is redundant and could be simplified to use either the ToDisplayString() approach or the name/namespace approach consistently.

Suggested change
(namedType.ConstructedFrom?.ToDisplayString() == "System.Threading.Tasks.Task<T>" ||
namedType.Name == "Task" && namedType.ContainingNamespace?.ToDisplayString() == "System.Threading.Tasks"))
namedType.ConstructedFrom?.ToDisplayString() == "System.Threading.Tasks.Task<T>")

Copilot uses AI. Check for mistakes.
{
if (namedType.TypeArguments.Length > 0)
{
innerType = namedType.TypeArguments[0];
return true;
}
}
}

return false;
}

private static bool IsPrimitiveOrString(ITypeSymbol returnType)
{
switch (returnType.SpecialType)
{
case SpecialType.System_String:
case SpecialType.System_Boolean:
case SpecialType.System_Byte:
case SpecialType.System_Char: // NOTE: this seems to be matching against 'string' for some reason
Copy link

Copilot AI Aug 7, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The comment suggests System_Char is matching against 'string', which indicates a potential logic error. System_Char should match char types, not string. If this is unexpected behavior, the logic needs investigation and the comment should be clarified or the case removed if incorrect.

Suggested change
case SpecialType.System_Char: // NOTE: this seems to be matching against 'string' for some reason

Copilot uses AI. Check for mistakes.
case SpecialType.System_Double:
case SpecialType.System_Int16:
case SpecialType.System_Int32:
case SpecialType.System_Int64:
case SpecialType.System_SByte:
case SpecialType.System_Single:
case SpecialType.System_UInt16:
case SpecialType.System_UInt32:
case SpecialType.System_UInt64:
return true;
}

return false;
}

private static bool IsValidReturnType(ITypeSymbol returnType, Compilation compilation)
{
if (IsPrimitiveOrString(returnType))
{
return true;
}

// void is allowed (for non-async methods)
if (returnType.SpecialType == SpecialType.System_Void)
{
return true;
}

// Task (without generic parameter) is allowed for void async methods
if (returnType.Name == "Task"
&& returnType.ContainingNamespace?.ToDisplayString() == "System.Threading.Tasks"
&& returnType is INamedTypeSymbol namedTaskType
&& !namedTaskType.IsGenericType)
{
return true;
}

// Check if it implements Response (look in Azure.Sdk.Tools.Cli.Models namespace)
var responseType = compilation.GetTypeByMetadataName("Azure.Sdk.Tools.Cli.Models.Response");
if (responseType != null && InheritsFromOrImplements(returnType, responseType))
{
return true;
}

// Check if it's an enumerable of allowed types
if (IsEnumerableOfAllowedType(returnType, compilation))
{
return true;
}

return false;
}

private static bool InheritsFromOrImplements(ITypeSymbol type, ITypeSymbol baseType)
{
// Check inheritance chain
for (var current = type; current != null; current = current.BaseType)
{
if (SymbolEqualityComparer.Default.Equals(current, baseType))
{
return true;
}
}

// Check interfaces
foreach (var interfaceType in type.AllInterfaces)
{
if (SymbolEqualityComparer.Default.Equals(interfaceType, baseType))
{
return true;
}
}

return false;
}

private static bool IsEnumerableOfAllowedType(ITypeSymbol returnType, Compilation compilation)
{
var ienumerableInterface = returnType.AllInterfaces.FirstOrDefault(
i => i.IsGenericType &&
i.ConstructedFrom?.ToDisplayString() == "System.Collections.Generic.IEnumerable<T>");

if (ienumerableInterface != null && ienumerableInterface.TypeArguments.Length > 0)
{
var elementType = ienumerableInterface.TypeArguments[0];
Comment on lines +202 to +208
Copy link

Copilot AI Aug 7, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using FirstOrDefault() with a complex predicate can be inefficient. Consider using a more specific lookup or caching the constructed generic type definition for IEnumerable to improve performance when analyzing large codebases.

Suggested change
var ienumerableInterface = returnType.AllInterfaces.FirstOrDefault(
i => i.IsGenericType &&
i.ConstructedFrom?.ToDisplayString() == "System.Collections.Generic.IEnumerable<T>");
if (ienumerableInterface != null && ienumerableInterface.TypeArguments.Length > 0)
{
var elementType = ienumerableInterface.TypeArguments[0];
var ienumerableTypeSymbol = compilation.GetTypeByMetadataName("System.Collections.Generic.IEnumerable`1");
if (ienumerableTypeSymbol == null)
{
return false;
}
var ienumerableInterface = returnType.AllInterfaces
.FirstOrDefault(i => i is INamedTypeSymbol named &&
named.IsGenericType &&
SymbolEqualityComparer.Default.Equals(named.OriginalDefinition, ienumerableTypeSymbol));
if (ienumerableInterface != null && ((INamedTypeSymbol)ienumerableInterface).TypeArguments.Length > 0)
{
var elementType = ((INamedTypeSymbol)ienumerableInterface).TypeArguments[0];

Copilot uses AI. Check for mistakes.

if (IsPrimitiveOrString(elementType))
{
return true;
}

var responseType = compilation.GetTypeByMetadataName("Azure.Sdk.Tools.Cli.Models.Response");
if (responseType != null && InheritsFromOrImplements(elementType, responseType))
{
return true;
}
}

return false;
}
}
}
153 changes: 153 additions & 0 deletions tools/azsdk-cli/Azure.Sdk.Tools.Cli.Analyzer/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
# Tool Exception Handling Analyzer (MCP001)

## Overview

The `EnforceToolsExceptionHandlingAnalyzer` ensures that all methods decorated with the `McpServerTool` attribute properly wrap their entire body in try/catch blocks with proper exception handling.

Unhandled exceptions in MCP mode prevent the client from logging tool responses correctly as the error information is not transmitted through the protocol. Currently there is no middleware support in the MCP C# SDK that could intercept unhandled exceptions. See https://github.com/modelcontextprotocol/csharp-sdk/issues/267

## Rule: MCP001

**Title**: McpServerTool methods must wrap body in try/catch

Methods decorated with `[McpServerTool]` must have their entire body wrapped in a try/catch block that catches `System.Exception`. This ensures consistent error handling across all MCP tools.

## Requirements

1. **Try/catch wrapping**: The entire method body must be within a try statement
2. **Exception type**: Must catch `System.Exception` (not specific exception types)
3. **Variable declarations**: Local variable declarations are allowed outside the try block
4. **SetFailure() call**: The catch block should call `SetFailure()` to mark the tool as failed

## Migration Guide

When you encounter MCP001 violations:

```csharp
// ❌ Incorrect - no try/catch
[McpServerTool]
public async Task<Response> ProcessData(string myArg)
{
var parsedArg = myArg.Trim(" ");
var result = await DoSomething(parsedArg);
return new Response { Data = result };
}

// ✅ Correct - proper try/catch structure
[McpServerTool]
public async Task<Response> ProcessData(string myArg)
{
// Variables are allowed to be defined outside the try/catch
// if they need to be referenced in the catch block
var parsedArg = myArg.Trim(" ");

try
{
var result = await DoSomething(parsedArg);
return new Response { Data = result };
}
catch (Exception ex)
{
SetFailure();
logger.LogError(ex, "Error processing data");
return new Response { ResponseError = $"Error processing data for {parsedArg}: {ex.Message}" };
}
}
```

# Tool Service Registration Analyzer (MCP002)

## Overview

The `EnforceToolsListAnalyzer` ensures that every class inheriting from `MCPTool` is properly registered in the `SharedOptions.ToolsList` static list, otherwise they will not be loaded at startup.

## Rule: MCP002

**Title**: Every MCPTool must be listed in SharedOptions.ToolsList

All non-abstract classes that inherit from `Azure.Sdk.Tools.Cli.Contract.MCPTool` must be included as `typeof(ClassName)` entries in the `SharedOptions.ToolsList` static field (`Azure.Sdk.Tools.Cli/Commands/SharedOptions.cs`).

## Requirements

1. **Registration**: Every `MCPTool` implementation must appear in `SharedOptions.ToolsList`
2. **Typeof syntax**: Tools must be registered using `typeof(YourToolClass)`
3. **Non-abstract only**: Only concrete (non-abstract) classes are validated
4. **Compile-time check**: This validation happens at compilation end

## Migration Guide

When you encounter MCP002 violations:

```csharp
// 1. Tool implementation
public class MyCustomTool : MCPTool
{
public override Command GetCommand() { /* implementation */ }
public override Task HandleCommand(InvocationContext ctx, CancellationToken ct) { /* implementation */ }
}

// 2. Add it to SharedOptions.ToolsList in Azure.Sdk.Tools.Cli/Commands/SharedOptions.cs
public static readonly List<Type> ToolsList = [
typeof(ExistingTool1),
typeof(ExistingTool2),
typeof(MyCustomTool), // ← Add your new tool here
// ... other tools
];
```

# Tool Return Type Analyzer (MCP003)

## Overview

The `EnforceToolsReturnTypesAnalyzer` ensures that all public non-static methods in classes within
the `Azure.Sdk.Tools.Cli.Tools` namespace return only approved types at compile time.

## Rule: MCP003

**Title**: Tool methods must return Response types, built-in value types, or string

This excludes inherited methods `GetCommand`, `HandleCommand`, etc.

## Allowed Return Types

1. Classes implementing `Azure.Sdk.Tools.Cli.Models.Response`
1. `string`
1. Primitive types (`int`, `bool`, etc.)
1. `IEnumerable<T>` of any of the above
1. `Task` (for `Task<T>` T must be any of the above)
1. `void`

## Migration Guide

When you encounter MCP003 violations:

**For custom objects**: Make them inherit from Response or wrap in Response
```csharp
// Instead of:
public async Task<CustomData> GetData() { }

// Option 1: Make CustomData inherit Response
public class CustomData : Response { }

// Option 2: Wrap in Response type
public async Task<CustomDataResponse> GetData() { }
```

Exceptions are not handled from top-level tool methods.
To bubble errors up in a way that can be formatted for supported callers (CLI, MCP, etc.),
return the custom type and set the inherited `ResponseError` or `ResponseErrors` property:

```csharp
try
{
// Tool business logic here
}
catch (Exception ex)
{
SetFailure();
logger.LogError(ex, "Error running tool");
return new CustomData {
ResponseError: $"Error running tool: {ex.Message}";
}
}
```
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
using Azure.Sdk.Tools.Cli.Models;
using Azure.Sdk.Tools.Cli.Services;
using Azure.Sdk.Tools.Cli.Tests.TestHelpers;
using Azure.Sdk.Tools.Cli.Tools.HelloWorldTool;
using Azure.Sdk.Tools.Cli.Tools;

namespace Azure.Sdk.Tools.Cli.Tests;

Expand Down
Loading