Skip to content

Commit e4926e5

Browse files
tzolovilayaperumalg
authored andcommitted
refactor: improve ToolCallbackProvider injection using ObjectProvider
Replace direct List<ToolCallbackProvider> injection with ObjectProvider pattern for better flexibility and backward compatibility in tool callback autoconfiguration. - Use ObjectProvider<List<ToolCallbackProvider>> and ObjectProvider<ToolCallbackProvider> for flexible bean resolution - Add merging and de-duplication logic for ToolCallbackProvider instances from multiple sources - Update StatelessToolCallbackConverterAutoConfiguration with new provider handling in syncTools() and asyncTools() methods - Update ToolCallbackConverterAutoConfiguration with similar provider handling improvements - Enhance ToolCallingAutoConfiguration toolCallbackResolver() to support both legacy List injection and new ObjectProvider pattern - Add test coverage Co-authored-by: Yanming Zhou <[email protected]> Signed-off-by: Christian Tzolov <[email protected]>
1 parent 9cec4d7 commit e4926e5

File tree

4 files changed

+350
-27
lines changed

4 files changed

+350
-27
lines changed

auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-common/src/main/java/org/springframework/ai/mcp/server/common/autoconfigure/StatelessToolCallbackConverterAutoConfiguration.java

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -51,9 +51,11 @@ public class StatelessToolCallbackConverterAutoConfiguration {
5151
matchIfMissing = true)
5252
public List<McpStatelessServerFeatures.SyncToolSpecification> syncTools(
5353
ObjectProvider<List<ToolCallback>> toolCalls, List<ToolCallback> toolCallbackList,
54-
List<ToolCallbackProvider> toolCallbackProvider, McpServerProperties serverProperties) {
54+
ObjectProvider<List<ToolCallbackProvider>> tcbProviderList,
55+
ObjectProvider<ToolCallbackProvider> tcbProviders, McpServerProperties serverProperties) {
5556

56-
List<ToolCallback> tools = this.aggregateToolCallbacks(toolCalls, toolCallbackList, toolCallbackProvider);
57+
List<ToolCallback> tools = this.aggregateToolCallbacks(toolCalls, toolCallbackList, tcbProviderList,
58+
tcbProviders);
5759

5860
return this.toSyncToolSpecifications(tools, serverProperties);
5961
}
@@ -81,9 +83,11 @@ private List<McpStatelessServerFeatures.SyncToolSpecification> toSyncToolSpecifi
8183
@ConditionalOnProperty(prefix = McpServerProperties.CONFIG_PREFIX, name = "type", havingValue = "ASYNC")
8284
public List<McpStatelessServerFeatures.AsyncToolSpecification> asyncTools(
8385
ObjectProvider<List<ToolCallback>> toolCalls, List<ToolCallback> toolCallbackList,
84-
List<ToolCallbackProvider> toolCallbackProvider, McpServerProperties serverProperties) {
86+
ObjectProvider<List<ToolCallbackProvider>> tcbProviderList,
87+
ObjectProvider<ToolCallbackProvider> tcbProviders, McpServerProperties serverProperties) {
8588

86-
List<ToolCallback> tools = this.aggregateToolCallbacks(toolCalls, toolCallbackList, toolCallbackProvider);
89+
List<ToolCallback> tools = this.aggregateToolCallbacks(toolCalls, toolCallbackList, tcbProviderList,
90+
tcbProviders);
8791

8892
return this.toAsyncToolSpecification(tools, serverProperties);
8993
}
@@ -107,15 +111,24 @@ private List<McpStatelessServerFeatures.AsyncToolSpecification> toAsyncToolSpeci
107111
}
108112

109113
private List<ToolCallback> aggregateToolCallbacks(ObjectProvider<List<ToolCallback>> toolCalls,
110-
List<ToolCallback> toolCallbacksList, List<ToolCallbackProvider> toolCallbackProvider) {
114+
List<ToolCallback> toolCallbacksList, ObjectProvider<List<ToolCallbackProvider>> tcbProviderList,
115+
ObjectProvider<ToolCallbackProvider> tcbProviders) {
116+
117+
// Merge ToolCallbackProviders from both ObjectProviders.
118+
List<ToolCallbackProvider> totalToolCallbackProviders = new ArrayList<>(
119+
tcbProviderList.stream().flatMap(List::stream).toList());
120+
totalToolCallbackProviders.addAll(tcbProviders.stream().toList());
121+
122+
// De-duplicate ToolCallbackProviders
123+
totalToolCallbackProviders = totalToolCallbackProviders.stream().distinct().toList();
111124

112125
List<ToolCallback> tools = new ArrayList<>(toolCalls.stream().flatMap(List::stream).toList());
113126

114127
if (!CollectionUtils.isEmpty(toolCallbacksList)) {
115128
tools.addAll(toolCallbacksList);
116129
}
117130

118-
List<ToolCallback> providerToolCallbacks = toolCallbackProvider.stream()
131+
List<ToolCallback> providerToolCallbacks = totalToolCallbackProviders.stream()
119132
.map(pr -> List.of(pr.getToolCallbacks()))
120133
.flatMap(List::stream)
121134
.filter(fc -> fc instanceof ToolCallback)

auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-common/src/main/java/org/springframework/ai/mcp/server/common/autoconfigure/ToolCallbackConverterAutoConfiguration.java

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -49,10 +49,19 @@ public class ToolCallbackConverterAutoConfiguration {
4949
@ConditionalOnProperty(prefix = McpServerProperties.CONFIG_PREFIX, name = "type", havingValue = "SYNC",
5050
matchIfMissing = true)
5151
public List<McpServerFeatures.SyncToolSpecification> syncTools(ObjectProvider<List<ToolCallback>> toolCalls,
52-
List<ToolCallback> toolCallbacksList, List<ToolCallbackProvider> toolCallbackProvider,
53-
McpServerProperties serverProperties) {
52+
List<ToolCallback> toolCallbacksList, ObjectProvider<List<ToolCallbackProvider>> tcbProviderList,
53+
ObjectProvider<ToolCallbackProvider> tcbProviders, McpServerProperties serverProperties) {
54+
55+
// Merge ToolCallbackProviders from both ObjectProviders.
56+
List<ToolCallbackProvider> totalToolCallbackProviders = new ArrayList<>(
57+
tcbProviderList.stream().flatMap(List::stream).toList());
58+
totalToolCallbackProviders.addAll(tcbProviders.stream().toList());
59+
60+
// De-duplicate ToolCallbackProviders
61+
totalToolCallbackProviders = totalToolCallbackProviders.stream().distinct().toList();
5462

55-
List<ToolCallback> tools = this.aggregateToolCallbacks(toolCalls, toolCallbacksList, toolCallbackProvider);
63+
List<ToolCallback> tools = this.aggregateToolCallbacks(toolCalls, toolCallbacksList, tcbProviderList,
64+
tcbProviders);
5665

5766
return this.toSyncToolSpecifications(tools, serverProperties);
5867
}
@@ -63,10 +72,7 @@ private List<McpServerFeatures.SyncToolSpecification> toSyncToolSpecifications(L
6372
// De-duplicate tools by their name, keeping the first occurrence of each tool
6473
// name
6574
return tools.stream() // Key: tool name
66-
.collect(Collectors.toMap(tool -> tool.getToolDefinition().name(), tool -> tool, // Value:
67-
// the
68-
// tool
69-
// itself
75+
.collect(Collectors.toMap(tool -> tool.getToolDefinition().name(), tool -> tool,
7076
(existing, replacement) -> existing)) // On duplicate key, keep the
7177
// existing tool
7278
.values()
@@ -83,10 +89,11 @@ private List<McpServerFeatures.SyncToolSpecification> toSyncToolSpecifications(L
8389
@Bean
8490
@ConditionalOnProperty(prefix = McpServerProperties.CONFIG_PREFIX, name = "type", havingValue = "ASYNC")
8591
public List<McpServerFeatures.AsyncToolSpecification> asyncTools(ObjectProvider<List<ToolCallback>> toolCalls,
86-
List<ToolCallback> toolCallbacksList, List<ToolCallbackProvider> toolCallbackProvider,
87-
McpServerProperties serverProperties) {
92+
List<ToolCallback> toolCallbacksList, ObjectProvider<List<ToolCallbackProvider>> tcbProviderList,
93+
ObjectProvider<ToolCallbackProvider> tcbProviders, McpServerProperties serverProperties) {
8894

89-
List<ToolCallback> tools = this.aggregateToolCallbacks(toolCalls, toolCallbacksList, toolCallbackProvider);
95+
List<ToolCallback> tools = this.aggregateToolCallbacks(toolCalls, toolCallbacksList, tcbProviderList,
96+
tcbProviders);
9097

9198
return this.toAsyncToolSpecification(tools, serverProperties);
9299
}
@@ -114,15 +121,24 @@ private List<McpServerFeatures.AsyncToolSpecification> toAsyncToolSpecification(
114121
}
115122

116123
private List<ToolCallback> aggregateToolCallbacks(ObjectProvider<List<ToolCallback>> toolCalls,
117-
List<ToolCallback> toolCallbacksList, List<ToolCallbackProvider> toolCallbackProvider) {
124+
List<ToolCallback> toolCallbacksList, ObjectProvider<List<ToolCallbackProvider>> tcbProviderList,
125+
ObjectProvider<ToolCallbackProvider> tcbProviders) {
126+
127+
// Merge ToolCallbackProviders from both ObjectProviders.
128+
List<ToolCallbackProvider> totalToolCallbackProviders = new ArrayList<>(
129+
tcbProviderList.stream().flatMap(List::stream).toList());
130+
totalToolCallbackProviders.addAll(tcbProviders.stream().toList());
131+
132+
// De-duplicate ToolCallbackProviders
133+
totalToolCallbackProviders = totalToolCallbackProviders.stream().distinct().toList();
118134

119135
List<ToolCallback> tools = new ArrayList<>(toolCalls.stream().flatMap(List::stream).toList());
120136

121137
if (!CollectionUtils.isEmpty(toolCallbacksList)) {
122138
tools.addAll(toolCallbacksList);
123139
}
124140

125-
List<ToolCallback> providerToolCallbacks = toolCallbackProvider.stream()
141+
List<ToolCallback> providerToolCallbacks = totalToolCallbackProviders.stream()
126142
.map(pr -> List.of(pr.getToolCallbacks()))
127143
.flatMap(List::stream)
128144
.filter(fc -> fc instanceof ToolCallback)

auto-configurations/models/tool/spring-ai-autoconfigure-model-tool/src/main/java/org/springframework/ai/model/tool/autoconfigure/ToolCallingAutoConfiguration.java

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,10 +69,24 @@ public class ToolCallingAutoConfiguration {
6969
*/
7070
@Bean
7171
@ConditionalOnMissingBean
72-
ToolCallbackResolver toolCallbackResolver(GenericApplicationContext applicationContext,
73-
List<ToolCallback> toolCallbacks, List<ToolCallbackProvider> tcbProviders) {
72+
ToolCallbackResolver toolCallbackResolver(
73+
GenericApplicationContext applicationContext, // @formatter:off
74+
List<ToolCallback> toolCallbacks,
75+
// Deprecated in favor of the tcbProviders. Kept for backward compatibility.
76+
ObjectProvider<List<ToolCallbackProvider>> tcbProviderList,
77+
ObjectProvider<ToolCallbackProvider> tcbProviders) { // @formatter:on
78+
7479
List<ToolCallback> allFunctionAndToolCallbacks = new ArrayList<>(toolCallbacks);
75-
tcbProviders.stream()
80+
81+
// Merge ToolCallbackProviders from both ObjectProviders.
82+
List<ToolCallbackProvider> totalToolCallbackProviders = new ArrayList<>(
83+
tcbProviderList.stream().flatMap(List::stream).toList());
84+
totalToolCallbackProviders.addAll(tcbProviders.stream().toList());
85+
86+
// De-duplicate ToolCallbackProviders
87+
totalToolCallbackProviders = totalToolCallbackProviders.stream().distinct().toList();
88+
89+
totalToolCallbackProviders.stream()
7690
.filter(pr -> !isMcpToolCallbackProvider(ResolvableType.forInstance(pr)))
7791
.map(pr -> List.of(pr.getToolCallbacks()))
7892
.forEach(allFunctionAndToolCallbacks::addAll);

0 commit comments

Comments
 (0)