Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,16 @@
import java.lang.reflect.Method;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.stream.Stream;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springaicommunity.mcp.annotation.McpTool;
import org.springaicommunity.mcp.method.tool.ReactiveUtils;
import org.springaicommunity.mcp.method.tool.ReturnMode;
import org.springaicommunity.mcp.method.tool.SyncMcpToolMethodCallback;
import org.springaicommunity.mcp.method.tool.utils.ClassUtils;
Expand All @@ -47,15 +50,58 @@ public class SyncMcpToolProvider {

private static final Logger logger = LoggerFactory.getLogger(SyncMcpToolProvider.class);

private final List<Object> toolObjects;
protected final List<Object> toolObjects;

// optional set of classes defining groups of annotated McpTools methods
protected final Class<?>[] toolGroups;

/**
* Create a new SyncMcpToolProvider.
* @param toolObjects the objects containing methods annotated with {@link McpTool}
* @param toolGroups optional array of classes defining the tool groups that all
* toolObjects are required to implement
* @exception IllegalArgumentException thrown if toolObjects is null, or any of the
* specified toolGroups are not implemented by all of the toolObjects
*/
public SyncMcpToolProvider(List<Object> toolObjects) {
public SyncMcpToolProvider(List<Object> toolObjects, Class<?>... toolGroups) {
Assert.notNull(toolObjects, "toolObjects cannot be null");
this.toolObjects = toolObjects;
Assert.notNull(toolGroups, "toolGroups cannot be null");
this.toolGroups = toolGroups;
// verify that every toolObject is instance of all toolGroups
this.toolObjects.forEach(toolObject -> {
Arrays.asList(this.toolGroups)
.forEach(clazz -> Assert.isTrue(clazz.isInstance(toolObject),
String.format("toolObject=%s is not an instance of %s", toolObject, clazz.getName())));
});
}

public SyncMcpToolProvider(Object toolObject, Class<?>... toolGroups) {
this(List.of(toolObject), toolGroups);
}

protected Method[] doGetMethods(Class<?> toolGroup) {
// For interfaces, getMethods() gets super interface methods
if (toolGroup.isInterface()) {
return toolGroup.getMethods();
}
else {
return toolGroup.getDeclaredMethods();
}
}

protected String doGetFullyQualifiedToolName(String annotationToolName, Class<?> toolGroup) {
return (this.toolGroups.length == 0) ? annotationToolName
: new StringBuffer(toolGroup.getName()).append(".").append(annotationToolName).toString();
}

protected Class<?>[] doGetClasses(Object toolObject) {
return (this.toolGroups.length == 0) ? new Class[] { toolObject.getClass() } : this.toolGroups;
}

protected <T> Predicate<T> distinctByName(Function<? super T, Object> nameExtractor) {
Map<Object, Boolean> map = new ConcurrentHashMap<>();
return t -> map.putIfAbsent(nameExtractor.apply(t), Boolean.TRUE) == null;
}

/**
Expand All @@ -65,82 +111,87 @@ public SyncMcpToolProvider(List<Object> toolObjects) {
* methods are found
*/
public List<SyncToolSpecification> getToolSpecifications() {

List<SyncToolSpecification> toolSpecs = this.toolObjects.stream()
.map(toolObject -> Stream.of(doGetClassMethods(toolObject))
.filter(method -> method.isAnnotationPresent(McpTool.class))
.filter(method -> !Mono.class.isAssignableFrom(method.getReturnType()))
.map(mcpToolMethod -> {

McpTool toolAnnotation = doGetMcpToolAnnotation(mcpToolMethod);

String toolName = Utils.hasText(toolAnnotation.name()) ? toolAnnotation.name()
: mcpToolMethod.getName();

String toolDescription = toolAnnotation.description();

// Check if method has CallToolRequest parameter
boolean hasCallToolRequestParam = Arrays.stream(mcpToolMethod.getParameterTypes())
.anyMatch(type -> CallToolRequest.class.isAssignableFrom(type));

String inputSchema;
if (hasCallToolRequestParam) {
// For methods with CallToolRequest, generate minimal schema or
// use the one from the request
// The schema generation will handle this appropriately
inputSchema = JsonSchemaGenerator.generateForMethodInput(mcpToolMethod);
logger.debug("Tool method '{}' uses CallToolRequest parameter, using minimal schema", toolName);
}
else {
inputSchema = JsonSchemaGenerator.generateForMethodInput(mcpToolMethod);
}

var toolBuilder = McpSchema.Tool.builder()
.name(toolName)
.description(toolDescription)
.inputSchema(inputSchema);

// Tool annotations
if (toolAnnotation.annotations() != null) {
var toolAnnotations = toolAnnotation.annotations();
toolBuilder.annotations(new McpSchema.ToolAnnotations(toolAnnotations.title(),
toolAnnotations.readOnlyHint(), toolAnnotations.destructiveHint(),
toolAnnotations.idempotentHint(), toolAnnotations.openWorldHint(), null));
}

// ReactiveUtils.isReactiveReturnTypeOfCallToolResult(mcpToolMethod);

// Generate Output Schema from the method return type.
// Output schema is not generated for primitive types, void,
// CallToolResult, simple value types (String, etc.)
// or if generateOutputSchema attribute is set to false.
Class<?> methodReturnType = mcpToolMethod.getReturnType();
if (toolAnnotation.generateOutputSchema() && methodReturnType != null
&& methodReturnType != CallToolResult.class && methodReturnType != Void.class
&& methodReturnType != void.class && !ClassUtils.isPrimitiveOrWrapper(methodReturnType)
&& !ClassUtils.isSimpleValueType(methodReturnType)) {

toolBuilder.outputSchema(JsonSchemaGenerator.generateFromClass(methodReturnType));
}

var tool = toolBuilder.build();

boolean useStructuredOtput = tool.outputSchema() != null;

ReturnMode returnMode = useStructuredOtput ? ReturnMode.STRUCTURED
: (methodReturnType == Void.TYPE || methodReturnType == void.class ? ReturnMode.VOID
: ReturnMode.TEXT);

BiFunction<McpSyncServerExchange, CallToolRequest, CallToolResult> methodCallback = new SyncMcpToolMethodCallback(
returnMode, mcpToolMethod, toolObject);

var toolSpec = SyncToolSpecification.builder().tool(tool).callHandler(methodCallback).build();

return toolSpec;
})
.toList())
.flatMap(List::stream)
.toList();
List<SyncToolSpecification> toolSpecs = this.toolObjects.stream().map(toolObject -> {
return Stream.of(doGetClasses(toolObject)).map(toolGroup -> {
return Stream.of(doGetMethods(toolGroup))
.filter(method -> method.isAnnotationPresent(McpTool.class))
.filter(method -> !Mono.class.isAssignableFrom(method.getReturnType()))
.map(mcpToolMethod -> {

McpTool toolAnnotation = doGetMcpToolAnnotation(mcpToolMethod);

String annotationToolName = Utils.hasText(toolAnnotation.name()) ? toolAnnotation.name()
: mcpToolMethod.getName();

String toolName = doGetFullyQualifiedToolName(annotationToolName, toolGroup);

String toolDescription = toolAnnotation.description();

// Check if method has CallToolRequest parameter
boolean hasCallToolRequestParam = Arrays.stream(mcpToolMethod.getParameterTypes())
.anyMatch(type -> CallToolRequest.class.isAssignableFrom(type));

String inputSchema;
if (hasCallToolRequestParam) {
// For methods with CallToolRequest, generate minimal schema
// or
// use the one from the request
// The schema generation will handle this appropriately
inputSchema = JsonSchemaGenerator.generateForMethodInput(mcpToolMethod);
logger.debug("Tool method '{}' uses CallToolRequest parameter, using minimal schema",
toolName);
}
else {
inputSchema = JsonSchemaGenerator.generateForMethodInput(mcpToolMethod);
}

var toolBuilder = McpSchema.Tool.builder()
.name(toolName)
.description(toolDescription)
.inputSchema(inputSchema);

// Tool annotations
if (toolAnnotation.annotations() != null) {
var toolAnnotations = toolAnnotation.annotations();
toolBuilder.annotations(new McpSchema.ToolAnnotations(toolAnnotations.title(),
toolAnnotations.readOnlyHint(), toolAnnotations.destructiveHint(),
toolAnnotations.idempotentHint(), toolAnnotations.openWorldHint(), null));
}

// ReactiveUtils.isReactiveReturnTypeOfCallToolResult(mcpToolMethod);

// Generate Output Schema from the method return type.
// Output schema is not generated for primitive types, void,
// CallToolResult, simple value types (String, etc.)
// or if generateOutputSchema attribute is set to false.
Class<?> methodReturnType = mcpToolMethod.getReturnType();
if (toolAnnotation.generateOutputSchema() && methodReturnType != null
&& methodReturnType != CallToolResult.class && methodReturnType != Void.class
&& methodReturnType != void.class && !ClassUtils.isPrimitiveOrWrapper(methodReturnType)
&& !ClassUtils.isSimpleValueType(methodReturnType)) {

toolBuilder.outputSchema(JsonSchemaGenerator.generateFromClass(methodReturnType));
}

var tool = toolBuilder.build();

boolean useStructuredOtput = tool.outputSchema() != null;

ReturnMode returnMode = useStructuredOtput ? ReturnMode.STRUCTURED
: (methodReturnType == Void.TYPE || methodReturnType == void.class ? ReturnMode.VOID
: ReturnMode.TEXT);

BiFunction<McpSyncServerExchange, CallToolRequest, CallToolResult> methodCallback = new SyncMcpToolMethodCallback(
returnMode, mcpToolMethod, toolObject);

var toolSpec = SyncToolSpecification.builder().tool(tool).callHandler(methodCallback).build();

return toolSpec;

})
.toList();
}).flatMap(List::stream).toList();
}).flatMap(List::stream).filter(distinctByName(s -> s.tool().name())).toList();

if (toolSpecs.isEmpty()) {
logger.warn("No tool methods found in the provided tool objects: {}", this.toolObjects);
Expand All @@ -149,10 +200,6 @@ public List<SyncToolSpecification> getToolSpecifications() {
return toolSpecs;
}

protected Method[] doGetClassMethods(Object bean) {
return bean.getClass().getDeclaredMethods();
}

protected McpTool doGetMcpToolAnnotation(Method method) {
return method.getAnnotation(McpTool.class);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,13 @@ void testConstructorWithNullToolObjects() {
.hasMessageContaining("toolObjects cannot be null");
}

@Test
void testConstructorWithInvalidToolGroups() {
assertThatThrownBy(() -> new SyncMcpToolProvider(List.of(new Object()), List.class))
.isInstanceOf(IllegalArgumentException.class)
.hasMessageContaining(" is not an instance of java.util.List");
}

@Test
void testGetToolSpecificationsWithSingleValidTool() {
// Create a class with only one valid tool method
Expand Down Expand Up @@ -200,6 +207,102 @@ public String secondTool(String input) {
assertThat(toolSpecs.get(0).tool().name()).isNotEqualTo(toolSpecs.get(1).tool().name());
}

@Test
void testGetToolSpecificationsWithMultipleToolMethodsToolGroup() {
class MultipleToolMethods {

@McpTool(name = "tool1", description = "First tool")
public String firstTool(String input) {
return "First: " + input;
}

@McpTool(name = "tool2", description = "Second tool")
public String secondTool(String input) {
return "Second: " + input;
}

}

MultipleToolMethods toolObject = new MultipleToolMethods();
SyncMcpToolProvider provider = new SyncMcpToolProvider(List.of(toolObject), MultipleToolMethods.class);

List<SyncToolSpecification> toolSpecs = provider.getToolSpecifications();

assertThat(toolSpecs).hasSize(2);
assertThat(toolSpecs.get(0).tool().name()).isIn(MultipleToolMethods.class.getName() + ".tool1",
MultipleToolMethods.class.getName() + ".tool2");
assertThat(toolSpecs.get(1).tool().name()).isIn(MultipleToolMethods.class.getName() + ".tool1",
MultipleToolMethods.class.getName() + ".tool2");
assertThat(toolSpecs.get(0).tool().name()).isNotEqualTo(toolSpecs.get(1).tool().name());
}

interface MultipleToolMethodsIntf {

@McpTool(name = "tool1", description = "First tool")
public String firstTool(String input);

@McpTool(name = "tool2", description = "Second tool")
public String secondTool(String input);

}

@Test
void testGetToolSpecificationsWithMultipleToolMethodsInterfaceToolGroup() {
class MultipleToolMethods implements MultipleToolMethodsIntf {

public String firstTool(String input) {
return "First: " + input;
}

public String secondTool(String input) {
return "Second: " + input;
}

}

MultipleToolMethods toolObject = new MultipleToolMethods();
SyncMcpToolProvider provider = new SyncMcpToolProvider(List.of(toolObject), MultipleToolMethodsIntf.class);

List<SyncToolSpecification> toolSpecs = provider.getToolSpecifications();

assertThat(toolSpecs).hasSize(2);
assertThat(toolSpecs.get(0).tool().name()).isIn(MultipleToolMethodsIntf.class.getName() + ".tool1",
MultipleToolMethodsIntf.class.getName() + ".tool2");
assertThat(toolSpecs.get(1).tool().name()).isIn(MultipleToolMethodsIntf.class.getName() + ".tool1",
MultipleToolMethodsIntf.class.getName() + ".tool2");
assertThat(toolSpecs.get(0).tool().name()).isNotEqualTo(toolSpecs.get(1).tool().name());
}

@Test
void testGetToolSpecificationsWithMultipleToolMethodsInterfaceToolGroupMultipleObjects() {
class MultipleToolMethods implements MultipleToolMethodsIntf {

public String firstTool(String input) {
return "First: " + input;
}

public String secondTool(String input) {
return "Second: " + input;
}

}

MultipleToolMethods toolObject1 = new MultipleToolMethods();
MultipleToolMethods toolObject2 = new MultipleToolMethods();

SyncMcpToolProvider provider = new SyncMcpToolProvider(List.of(toolObject1, toolObject2),
MultipleToolMethodsIntf.class);

List<SyncToolSpecification> toolSpecs = provider.getToolSpecifications();

assertThat(toolSpecs).hasSize(2);
assertThat(toolSpecs.get(0).tool().name()).isIn(MultipleToolMethodsIntf.class.getName() + ".tool1",
MultipleToolMethodsIntf.class.getName() + ".tool2");
assertThat(toolSpecs.get(1).tool().name()).isIn(MultipleToolMethodsIntf.class.getName() + ".tool1",
MultipleToolMethodsIntf.class.getName() + ".tool2");
assertThat(toolSpecs.get(0).tool().name()).isNotEqualTo(toolSpecs.get(1).tool().name());
}

@Test
void testGetToolSpecificationsWithMultipleToolObjects() {
class FirstToolObject {
Expand Down