diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/provider/tool/SyncMcpToolProvider.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/provider/tool/SyncMcpToolProvider.java index fd492bd..a11465d 100644 --- a/mcp-annotations/src/main/java/org/springaicommunity/mcp/provider/tool/SyncMcpToolProvider.java +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/provider/tool/SyncMcpToolProvider.java @@ -19,13 +19,16 @@ import java.lang.reflect.Method; import java.util.Arrays; import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; import java.util.function.BiFunction; +import java.util.function.Function; +import java.util.function.Predicate; import java.util.stream.Stream; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springaicommunity.mcp.annotation.McpTool; -import org.springaicommunity.mcp.method.tool.ReactiveUtils; import org.springaicommunity.mcp.method.tool.ReturnMode; import org.springaicommunity.mcp.method.tool.SyncMcpToolMethodCallback; import org.springaicommunity.mcp.method.tool.utils.ClassUtils; @@ -47,15 +50,58 @@ public class SyncMcpToolProvider { private static final Logger logger = LoggerFactory.getLogger(SyncMcpToolProvider.class); - private final List toolObjects; + protected final List toolObjects; + + // optional set of classes defining groups of annotated McpTools methods + protected final Class[] toolGroups; /** * Create a new SyncMcpToolProvider. * @param toolObjects the objects containing methods annotated with {@link McpTool} + * @param toolGroups optional array of classes defining the tool groups that all + * toolObjects are required to implement + * @exception IllegalArgumentException thrown if toolObjects is null, or any of the + * specified toolGroups are not implemented by all of the toolObjects */ - public SyncMcpToolProvider(List toolObjects) { + public SyncMcpToolProvider(List toolObjects, Class... toolGroups) { Assert.notNull(toolObjects, "toolObjects cannot be null"); this.toolObjects = toolObjects; + Assert.notNull(toolGroups, "toolGroups cannot be null"); + this.toolGroups = toolGroups; + // verify that every toolObject is instance of all toolGroups + this.toolObjects.forEach(toolObject -> { + Arrays.asList(this.toolGroups) + .forEach(clazz -> Assert.isTrue(clazz.isInstance(toolObject), + String.format("toolObject=%s is not an instance of %s", toolObject, clazz.getName()))); + }); + } + + public SyncMcpToolProvider(Object toolObject, Class... toolGroups) { + this(List.of(toolObject), toolGroups); + } + + protected Method[] doGetMethods(Class toolGroup) { + // For interfaces, getMethods() gets super interface methods + if (toolGroup.isInterface()) { + return toolGroup.getMethods(); + } + else { + return toolGroup.getDeclaredMethods(); + } + } + + protected String doGetFullyQualifiedToolName(String annotationToolName, Class toolGroup) { + return (this.toolGroups.length == 0) ? annotationToolName + : new StringBuffer(toolGroup.getName()).append(".").append(annotationToolName).toString(); + } + + protected Class[] doGetClasses(Object toolObject) { + return (this.toolGroups.length == 0) ? new Class[] { toolObject.getClass() } : this.toolGroups; + } + + protected Predicate distinctByName(Function nameExtractor) { + Map map = new ConcurrentHashMap<>(); + return t -> map.putIfAbsent(nameExtractor.apply(t), Boolean.TRUE) == null; } /** @@ -65,82 +111,87 @@ public SyncMcpToolProvider(List toolObjects) { * methods are found */ public List getToolSpecifications() { - - List toolSpecs = this.toolObjects.stream() - .map(toolObject -> Stream.of(doGetClassMethods(toolObject)) - .filter(method -> method.isAnnotationPresent(McpTool.class)) - .filter(method -> !Mono.class.isAssignableFrom(method.getReturnType())) - .map(mcpToolMethod -> { - - McpTool toolAnnotation = doGetMcpToolAnnotation(mcpToolMethod); - - String toolName = Utils.hasText(toolAnnotation.name()) ? toolAnnotation.name() - : mcpToolMethod.getName(); - - String toolDescription = toolAnnotation.description(); - - // Check if method has CallToolRequest parameter - boolean hasCallToolRequestParam = Arrays.stream(mcpToolMethod.getParameterTypes()) - .anyMatch(type -> CallToolRequest.class.isAssignableFrom(type)); - - String inputSchema; - if (hasCallToolRequestParam) { - // For methods with CallToolRequest, generate minimal schema or - // use the one from the request - // The schema generation will handle this appropriately - inputSchema = JsonSchemaGenerator.generateForMethodInput(mcpToolMethod); - logger.debug("Tool method '{}' uses CallToolRequest parameter, using minimal schema", toolName); - } - else { - inputSchema = JsonSchemaGenerator.generateForMethodInput(mcpToolMethod); - } - - var toolBuilder = McpSchema.Tool.builder() - .name(toolName) - .description(toolDescription) - .inputSchema(inputSchema); - - // Tool annotations - if (toolAnnotation.annotations() != null) { - var toolAnnotations = toolAnnotation.annotations(); - toolBuilder.annotations(new McpSchema.ToolAnnotations(toolAnnotations.title(), - toolAnnotations.readOnlyHint(), toolAnnotations.destructiveHint(), - toolAnnotations.idempotentHint(), toolAnnotations.openWorldHint(), null)); - } - - // ReactiveUtils.isReactiveReturnTypeOfCallToolResult(mcpToolMethod); - - // Generate Output Schema from the method return type. - // Output schema is not generated for primitive types, void, - // CallToolResult, simple value types (String, etc.) - // or if generateOutputSchema attribute is set to false. - Class methodReturnType = mcpToolMethod.getReturnType(); - if (toolAnnotation.generateOutputSchema() && methodReturnType != null - && methodReturnType != CallToolResult.class && methodReturnType != Void.class - && methodReturnType != void.class && !ClassUtils.isPrimitiveOrWrapper(methodReturnType) - && !ClassUtils.isSimpleValueType(methodReturnType)) { - - toolBuilder.outputSchema(JsonSchemaGenerator.generateFromClass(methodReturnType)); - } - - var tool = toolBuilder.build(); - - boolean useStructuredOtput = tool.outputSchema() != null; - - ReturnMode returnMode = useStructuredOtput ? ReturnMode.STRUCTURED - : (methodReturnType == Void.TYPE || methodReturnType == void.class ? ReturnMode.VOID - : ReturnMode.TEXT); - - BiFunction methodCallback = new SyncMcpToolMethodCallback( - returnMode, mcpToolMethod, toolObject); - - var toolSpec = SyncToolSpecification.builder().tool(tool).callHandler(methodCallback).build(); - - return toolSpec; - }) - .toList()) - .flatMap(List::stream) - .toList(); + List toolSpecs = this.toolObjects.stream().map(toolObject -> { + return Stream.of(doGetClasses(toolObject)).map(toolGroup -> { + return Stream.of(doGetMethods(toolGroup)) + .filter(method -> method.isAnnotationPresent(McpTool.class)) + .filter(method -> !Mono.class.isAssignableFrom(method.getReturnType())) + .map(mcpToolMethod -> { + + McpTool toolAnnotation = doGetMcpToolAnnotation(mcpToolMethod); + + String annotationToolName = Utils.hasText(toolAnnotation.name()) ? toolAnnotation.name() + : mcpToolMethod.getName(); + + String toolName = doGetFullyQualifiedToolName(annotationToolName, toolGroup); + + String toolDescription = toolAnnotation.description(); + + // Check if method has CallToolRequest parameter + boolean hasCallToolRequestParam = Arrays.stream(mcpToolMethod.getParameterTypes()) + .anyMatch(type -> CallToolRequest.class.isAssignableFrom(type)); + + String inputSchema; + if (hasCallToolRequestParam) { + // For methods with CallToolRequest, generate minimal schema + // or + // use the one from the request + // The schema generation will handle this appropriately + inputSchema = JsonSchemaGenerator.generateForMethodInput(mcpToolMethod); + logger.debug("Tool method '{}' uses CallToolRequest parameter, using minimal schema", + toolName); + } + else { + inputSchema = JsonSchemaGenerator.generateForMethodInput(mcpToolMethod); + } + + var toolBuilder = McpSchema.Tool.builder() + .name(toolName) + .description(toolDescription) + .inputSchema(inputSchema); + + // Tool annotations + if (toolAnnotation.annotations() != null) { + var toolAnnotations = toolAnnotation.annotations(); + toolBuilder.annotations(new McpSchema.ToolAnnotations(toolAnnotations.title(), + toolAnnotations.readOnlyHint(), toolAnnotations.destructiveHint(), + toolAnnotations.idempotentHint(), toolAnnotations.openWorldHint(), null)); + } + + // ReactiveUtils.isReactiveReturnTypeOfCallToolResult(mcpToolMethod); + + // Generate Output Schema from the method return type. + // Output schema is not generated for primitive types, void, + // CallToolResult, simple value types (String, etc.) + // or if generateOutputSchema attribute is set to false. + Class methodReturnType = mcpToolMethod.getReturnType(); + if (toolAnnotation.generateOutputSchema() && methodReturnType != null + && methodReturnType != CallToolResult.class && methodReturnType != Void.class + && methodReturnType != void.class && !ClassUtils.isPrimitiveOrWrapper(methodReturnType) + && !ClassUtils.isSimpleValueType(methodReturnType)) { + + toolBuilder.outputSchema(JsonSchemaGenerator.generateFromClass(methodReturnType)); + } + + var tool = toolBuilder.build(); + + boolean useStructuredOtput = tool.outputSchema() != null; + + ReturnMode returnMode = useStructuredOtput ? ReturnMode.STRUCTURED + : (methodReturnType == Void.TYPE || methodReturnType == void.class ? ReturnMode.VOID + : ReturnMode.TEXT); + + BiFunction methodCallback = new SyncMcpToolMethodCallback( + returnMode, mcpToolMethod, toolObject); + + var toolSpec = SyncToolSpecification.builder().tool(tool).callHandler(methodCallback).build(); + + return toolSpec; + + }) + .toList(); + }).flatMap(List::stream).toList(); + }).flatMap(List::stream).filter(distinctByName(s -> s.tool().name())).toList(); if (toolSpecs.isEmpty()) { logger.warn("No tool methods found in the provided tool objects: {}", this.toolObjects); @@ -149,10 +200,6 @@ public List getToolSpecifications() { return toolSpecs; } - protected Method[] doGetClassMethods(Object bean) { - return bean.getClass().getDeclaredMethods(); - } - protected McpTool doGetMcpToolAnnotation(Method method) { return method.getAnnotation(McpTool.class); } diff --git a/mcp-annotations/src/test/java/org/springaicommunity/mcp/provider/tool/SyncMcpToolProviderTests.java b/mcp-annotations/src/test/java/org/springaicommunity/mcp/provider/tool/SyncMcpToolProviderTests.java index 83e784e..340188b 100644 --- a/mcp-annotations/src/test/java/org/springaicommunity/mcp/provider/tool/SyncMcpToolProviderTests.java +++ b/mcp-annotations/src/test/java/org/springaicommunity/mcp/provider/tool/SyncMcpToolProviderTests.java @@ -46,6 +46,13 @@ void testConstructorWithNullToolObjects() { .hasMessageContaining("toolObjects cannot be null"); } + @Test + void testConstructorWithInvalidToolGroups() { + assertThatThrownBy(() -> new SyncMcpToolProvider(List.of(new Object()), List.class)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining(" is not an instance of java.util.List"); + } + @Test void testGetToolSpecificationsWithSingleValidTool() { // Create a class with only one valid tool method @@ -200,6 +207,102 @@ public String secondTool(String input) { assertThat(toolSpecs.get(0).tool().name()).isNotEqualTo(toolSpecs.get(1).tool().name()); } + @Test + void testGetToolSpecificationsWithMultipleToolMethodsToolGroup() { + class MultipleToolMethods { + + @McpTool(name = "tool1", description = "First tool") + public String firstTool(String input) { + return "First: " + input; + } + + @McpTool(name = "tool2", description = "Second tool") + public String secondTool(String input) { + return "Second: " + input; + } + + } + + MultipleToolMethods toolObject = new MultipleToolMethods(); + SyncMcpToolProvider provider = new SyncMcpToolProvider(List.of(toolObject), MultipleToolMethods.class); + + List toolSpecs = provider.getToolSpecifications(); + + assertThat(toolSpecs).hasSize(2); + assertThat(toolSpecs.get(0).tool().name()).isIn(MultipleToolMethods.class.getName() + ".tool1", + MultipleToolMethods.class.getName() + ".tool2"); + assertThat(toolSpecs.get(1).tool().name()).isIn(MultipleToolMethods.class.getName() + ".tool1", + MultipleToolMethods.class.getName() + ".tool2"); + assertThat(toolSpecs.get(0).tool().name()).isNotEqualTo(toolSpecs.get(1).tool().name()); + } + + interface MultipleToolMethodsIntf { + + @McpTool(name = "tool1", description = "First tool") + public String firstTool(String input); + + @McpTool(name = "tool2", description = "Second tool") + public String secondTool(String input); + + } + + @Test + void testGetToolSpecificationsWithMultipleToolMethodsInterfaceToolGroup() { + class MultipleToolMethods implements MultipleToolMethodsIntf { + + public String firstTool(String input) { + return "First: " + input; + } + + public String secondTool(String input) { + return "Second: " + input; + } + + } + + MultipleToolMethods toolObject = new MultipleToolMethods(); + SyncMcpToolProvider provider = new SyncMcpToolProvider(List.of(toolObject), MultipleToolMethodsIntf.class); + + List toolSpecs = provider.getToolSpecifications(); + + assertThat(toolSpecs).hasSize(2); + assertThat(toolSpecs.get(0).tool().name()).isIn(MultipleToolMethodsIntf.class.getName() + ".tool1", + MultipleToolMethodsIntf.class.getName() + ".tool2"); + assertThat(toolSpecs.get(1).tool().name()).isIn(MultipleToolMethodsIntf.class.getName() + ".tool1", + MultipleToolMethodsIntf.class.getName() + ".tool2"); + assertThat(toolSpecs.get(0).tool().name()).isNotEqualTo(toolSpecs.get(1).tool().name()); + } + + @Test + void testGetToolSpecificationsWithMultipleToolMethodsInterfaceToolGroupMultipleObjects() { + class MultipleToolMethods implements MultipleToolMethodsIntf { + + public String firstTool(String input) { + return "First: " + input; + } + + public String secondTool(String input) { + return "Second: " + input; + } + + } + + MultipleToolMethods toolObject1 = new MultipleToolMethods(); + MultipleToolMethods toolObject2 = new MultipleToolMethods(); + + SyncMcpToolProvider provider = new SyncMcpToolProvider(List.of(toolObject1, toolObject2), + MultipleToolMethodsIntf.class); + + List toolSpecs = provider.getToolSpecifications(); + + assertThat(toolSpecs).hasSize(2); + assertThat(toolSpecs.get(0).tool().name()).isIn(MultipleToolMethodsIntf.class.getName() + ".tool1", + MultipleToolMethodsIntf.class.getName() + ".tool2"); + assertThat(toolSpecs.get(1).tool().name()).isIn(MultipleToolMethodsIntf.class.getName() + ".tool1", + MultipleToolMethodsIntf.class.getName() + ".tool2"); + assertThat(toolSpecs.get(0).tool().name()).isNotEqualTo(toolSpecs.get(1).tool().name()); + } + @Test void testGetToolSpecificationsWithMultipleToolObjects() { class FirstToolObject {