From c4a4bf5318070aaf0b0c41fdfdf607aa372af9ee Mon Sep 17 00:00:00 2001 From: Charles Lien Date: Thu, 23 Oct 2025 11:22:52 -0700 Subject: [PATCH] add agent-runtime to sdk --- common/src/types/contracts/database.ts | 1 + sdk/src/impl/agent-runtime.ts | 9 +- sdk/src/impl/database.ts | 43 ++- sdk/src/impl/llm.ts | 22 +- sdk/src/run-state.ts | 10 +- sdk/src/run.ts | 380 +++++++++++++++---------- sdk/src/tools/glob.ts | 5 +- 7 files changed, 296 insertions(+), 174 deletions(-) diff --git a/common/src/types/contracts/database.ts b/common/src/types/contracts/database.ts index 65d448cc4..17e527df7 100644 --- a/common/src/types/contracts/database.ts +++ b/common/src/types/contracts/database.ts @@ -6,6 +6,7 @@ type User = { email: string discord_id: string | null } +export const userColumns = ['id', 'email', 'discord_id'] as const export type UserColumn = keyof User export type GetUserInfoFromApiKeyInput = { apiKey: string diff --git a/sdk/src/impl/agent-runtime.ts b/sdk/src/impl/agent-runtime.ts index cd1bbb127..797a4609a 100644 --- a/sdk/src/impl/agent-runtime.ts +++ b/sdk/src/impl/agent-runtime.ts @@ -1,6 +1,3 @@ -import { trackEvent } from '@codebuff/common/analytics' -import { success } from '@codebuff/common/util/error' - import { addAgentStep, fetchAgentFromDatabase, @@ -9,12 +6,14 @@ import { startAgentRun, } from './database' import { promptAiSdk, promptAiSdkStream, promptAiSdkStructured } from './llm' +import { trackEvent } from '../../../common/src/analytics' +import { success } from '../../../common/src/util/error' import type { AgentRuntimeDeps, AgentRuntimeScopedDeps, -} from '@codebuff/common/types/contracts/agent-runtime' -import type { Logger } from '@codebuff/common/types/contracts/logger' +} from '../../../common/src/types/contracts/agent-runtime' +import type { Logger } from '../../../common/src/types/contracts/logger' export function getAgentRuntimeImpl( params: { diff --git a/sdk/src/impl/database.ts b/sdk/src/impl/database.ts index 4e6eb9857..b22482fba 100644 --- a/sdk/src/impl/database.ts +++ b/sdk/src/impl/database.ts @@ -1,5 +1,5 @@ -import { getErrorObject } from '@codebuff/common/util/error' - +import { userColumns } from '../../../common/src/types/contracts/database' +import { getErrorObject } from '../../../common/src/util/error' import { WEBSITE_URL } from '../constants' import type { @@ -10,15 +10,35 @@ import type { GetUserInfoFromApiKeyOutput, StartAgentRunFn, UserColumn, -} from '@codebuff/common/types/contracts/database' -import type { ParamsOf } from '@codebuff/common/types/function-params' +} from '../../../common/src/types/contracts/database' +import type { ParamsOf } from '../../../common/src/types/function-params' + +const userInfoCache: Record< + string, + Awaited> +> = {} export async function getUserInfoFromApiKey( params: GetUserInfoFromApiKeyInput, ): GetUserInfoFromApiKeyOutput { const { apiKey, fields, logger } = params - const urlParams = new URLSearchParams({ apiKey, fields: fields.join(',') }) + if (apiKey in userInfoCache) { + const userInfo = userInfoCache[apiKey] + if (userInfo === null) { + return userInfo + } + return Object.fromEntries( + fields.map((field) => [field, userInfo[field]]), + ) as { + [K in (typeof fields)[number]]: (typeof userInfo)[K] + } + } + + const urlParams = new URLSearchParams({ + apiKey, + fields: userColumns.join(','), + }) const url = new URL(`/api/v1/me?${urlParams}`, WEBSITE_URL) try { @@ -36,7 +56,8 @@ export async function getUserInfoFromApiKey( ) return null } - return response.json() + + userInfoCache[apiKey] = await response.json() } catch (error) { logger.error( { error: getErrorObject(error), apiKey, fields }, @@ -44,6 +65,16 @@ export async function getUserInfoFromApiKey( ) return null } + + const userInfo = userInfoCache[apiKey] + if (userInfo === null) { + return userInfo + } + return Object.fromEntries( + fields.map((field) => [field, userInfo[field]]), + ) as { + [K in (typeof fields)[number]]: (typeof userInfo)[K] + } } export async function fetchAgentFromDatabase( diff --git a/sdk/src/impl/llm.ts b/sdk/src/impl/llm.ts index 02b0abe67..874558415 100644 --- a/sdk/src/impl/llm.ts +++ b/sdk/src/impl/llm.ts @@ -1,25 +1,25 @@ import { createOpenAICompatible } from '@ai-sdk/openai-compatible' +import { streamText, APICallError, generateText, generateObject } from 'ai' + +import { PROFIT_MARGIN } from '../../../common/src/old-constants' +import { buildArray } from '../../../common/src/util/array' +import { getErrorObject } from '../../../common/src/util/error' +import { convertCbToModelMessages } from '../../../common/src/util/messages' +import { StopSequenceHandler } from '../../../common/src/util/stop-sequence' import { checkLiveUserInput, getLiveUserInputIds, -} from '@codebuff/agent-runtime/live-user-inputs' -import { PROFIT_MARGIN } from '@codebuff/common/old-constants' -import { buildArray } from '@codebuff/common/util/array' -import { getErrorObject } from '@codebuff/common/util/error' -import { convertCbToModelMessages } from '@codebuff/common/util/messages' -import { StopSequenceHandler } from '@codebuff/common/util/stop-sequence' -import { streamText, APICallError, generateText, generateObject } from 'ai' - +} from '../../../packages/agent-runtime/src/live-user-inputs' import { WEBSITE_URL } from '../constants' -import type { LanguageModelV2 } from '@ai-sdk/provider' import type { PromptAiSdkFn, PromptAiSdkStreamFn, PromptAiSdkStructuredInput, PromptAiSdkStructuredOutput, -} from '@codebuff/common/types/contracts/llm' -import type { ParamsOf } from '@codebuff/common/types/function-params' +} from '../../../common/src/types/contracts/llm' +import type { ParamsOf } from '../../../common/src/types/function-params' +import type { LanguageModelV2 } from '@ai-sdk/provider' import type { OpenRouterProviderOptions, OpenRouterUsageAccounting, diff --git a/sdk/src/run-state.ts b/sdk/src/run-state.ts index ef0f58530..0ffb46927 100644 --- a/sdk/src/run-state.ts +++ b/sdk/src/run-state.ts @@ -1,7 +1,6 @@ import * as os from 'os' import path from 'path' -import { getFileTokenScores } from '@codebuff/code-map/parse' import { cloneDeep } from 'lodash' import { @@ -9,9 +8,11 @@ import { getAllFilePaths, } from '../../common/src/project-file-tree' import { getInitialSessionState } from '../../common/src/types/session-state' +import { getFileTokenScores } from '../../packages/code-map/src/parse' import type { CustomToolDefinition } from './custom-tool' import type { AgentDefinition } from '../../common/src/templates/initial-agents-dir/types/agent-definition' +import type { CodebuffFileSystem } from '../../common/src/types/filesystem' import type { Message } from '../../common/src/types/messages/codebuff-message' import type { AgentOutput, @@ -21,7 +22,6 @@ import type { CustomToolDefinitions, FileTreeNode, } from '../../common/src/util/file' -import type { CodebuffFileSystem } from '@codebuff/common/types/filesystem' export type RunState = { sessionState: SessionState @@ -36,7 +36,7 @@ export type InitialSessionStateOptions = { customToolDefinitions?: CustomToolDefinition[] maxAgentSteps?: number fs?: CodebuffFileSystem -}; +} /** * Processes agent definitions array and converts handleSteps functions to strings @@ -167,11 +167,11 @@ function deriveKnowledgeFiles( export function initialSessionState( options: InitialSessionStateOptions, -): Promise; +): Promise export function initialSessionState( cwd: string, options?: Omit, -): Promise; +): Promise export async function initialSessionState( arg1: string | InitialSessionStateOptions, arg2?: Omit, diff --git a/sdk/src/run.ts b/sdk/src/run.ts index 16af35ea7..5536d3d7e 100644 --- a/sdk/src/run.ts +++ b/sdk/src/run.ts @@ -2,6 +2,8 @@ import path from 'path' import { cloneDeep } from 'lodash' +import { getAgentRuntimeImpl } from './impl/agent-runtime' +import { getUserInfoFromApiKey } from './impl/database' import { initialSessionState, applyOverridesToSessionState } from './run-state' import { stripToolCallPayloads } from './tool-xml-buffer' import { @@ -14,15 +16,18 @@ import { glob } from './tools/glob' import { listDirectory } from './tools/list-directory' import { getFiles } from './tools/read-files' import { runTerminalCommand } from './tools/run-terminal-command' -import { WebSocketHandler } from './websocket-client' import { MAX_AGENT_STEPS_DEFAULT } from '../../common/src/constants/agents' +import { getMCPClient, listMCPTools } from '../../common/src/mcp/client' +import { toOptionalFile } from '../../common/src/old-constants' import { toolNames } from '../../common/src/tools/constants' import { clientToolCallSchema } from '../../common/src/tools/list' import { AgentOutputSchema } from '../../common/src/types/session-state' +import { callMainPrompt } from '../../packages/agent-runtime/src/main-prompt' import type { CustomToolDefinition } from './custom-tool' import type { RunState } from './run-state' import type { ToolXmlFilterState } from './tool-xml-filter' +import type { WebSocketHandler } from './websocket-client' import type { ServerAction } from '../../common/src/actions' import type { AgentDefinition } from '../../common/src/templates/initial-agents-dir/types/agent-definition' import type { @@ -35,6 +40,8 @@ import type { CodebuffToolOutput, PublishedClientToolName, } from '../../common/src/tools/list' +import type { Logger } from '../../common/src/types/contracts/logger' +import type { CodebuffFileSystem } from '../../common/src/types/filesystem' import type { ToolResultOutput, ToolResultPart, @@ -42,7 +49,6 @@ import type { import type { PrintModeEvent } from '../../common/src/types/print-mode' import type { SessionState } from '../../common/src/types/session-state' import type { Source } from '../../common/src/types/source' -import type { CodebuffFileSystem } from '@codebuff/common/types/filesystem' export type CodebuffClientOptions = { apiKey?: string @@ -71,6 +77,7 @@ export type CodebuffClientOptions = { customToolDefinitions?: CustomToolDefinition[] fsSource?: Source + logger?: Logger } export type RunOptions = { @@ -100,6 +107,7 @@ export async function run({ customToolDefinitions, fsSource = () => require('fs'), + logger, agent, prompt, @@ -328,26 +336,127 @@ export async function run({ } } - const websocketHandler = new WebSocketHandler({ + const onResponseChunk = async ( + action: ServerAction<'response-chunk'>, + ): Promise => { + checkAborted(signal) + const { chunk } = action + if (typeof chunk === 'string') { + ensureSectionStart(ROOT_AGENT_KEY) + const { text: sanitized } = filterToolXmlFromText( + streamFilterState, + chunk, + MAX_TOOL_XML_BUFFER, + ) + + if (sanitized) { + const nextFullText = accumulateText(ROOT_AGENT_KEY, sanitized) + await emitStreamDelta(ROOT_AGENT_KEY, nextFullText) + } + } else { + const chunkType = chunk.type as string + + if ( + chunkType !== 'finish' && + chunkType !== 'subagent_finish' && + chunkType !== 'subagent-finish' + ) { + await emitPendingSection(ROOT_AGENT_KEY) + const pendingAgentId = 'agentId' in chunk ? chunk.agentId : undefined + if (pendingAgentId && pendingAgentId !== ROOT_AGENT_KEY) { + await emitPendingSection(pendingAgentId, pendingAgentId) + } + } + + if (chunkType === 'finish') { + const { text: streamTail } = filterToolXmlFromText( + streamFilterState, + '', + MAX_TOOL_XML_BUFFER, + ) + let remainder = streamTail + + if ( + streamFilterState.buffer && + !streamFilterState.buffer.includes('<') + ) { + remainder += streamFilterState.buffer + } + streamFilterState.buffer = '' + streamFilterState.activeTag = null + + if (remainder) { + const nextFullText = accumulateText(ROOT_AGENT_KEY, remainder) + await emitStreamDelta(ROOT_AGENT_KEY, nextFullText) + } + + await flushTextState(ROOT_AGENT_KEY) + + const finishAgentKey = 'agentId' in chunk ? chunk.agentId : undefined + if (finishAgentKey && finishAgentKey !== ROOT_AGENT_KEY) { + await flushTextState(finishAgentKey, finishAgentKey) + await flushSubagentState( + finishAgentKey, + (chunk as { agentType?: string }).agentType, + ) + } + } else if ( + chunkType === 'subagent_finish' || + chunkType === 'subagent-finish' + ) { + const subagentId = 'agentId' in chunk ? chunk.agentId : undefined + if (subagentId) { + await flushTextState(subagentId, subagentId) + await flushSubagentState( + subagentId, + (chunk as { agentType?: string }).agentType, + ) + } + } + + await handleEvent?.(chunk) + } + } + const onSubagentResponseChunk = async ( + action: ServerAction<'subagent-response-chunk'>, + ) => { + checkAborted(signal) + const { agentId, agentType, chunk } = action + + const state = getSubagentFilterState(agentId) + const { text: sanitized } = filterToolXmlFromText( + state, + chunk, + MAX_TOOL_XML_BUFFER, + ) + + if (sanitized && handleEvent) { + await handleEvent({ + type: 'subagent-chunk', + agentId, + agentType, + chunk: sanitized, + } as any) + } + } + + const agentRuntimeImpl = getAgentRuntimeImpl({ + logger, apiKey, - onWebsocketError: (error) => { - onError({ message: error.message }) - }, - onWebsocketReconnect: () => {}, - onRequestReconnect: async () => {}, - onResponseError: async (error) => { - onError({ message: error.message }) + handleStepsLogChunk: () => { + // Does nothing for now }, - readFiles: ({ filePaths }) => - readFiles({ - filePaths, - override: overrideTools?.read_files, - cwd, - fs, - }), - handleToolCall: (action) => - handleToolCall({ - action, + requestToolCall: async ({ userInputId, toolName, input, mcpConfig }) => { + return handleToolCall({ + action: { + type: 'tool-call-request', + requestId: crypto.randomUUID(), + userInputId, + toolName, + input, + timeout: undefined, + mcpConfig, + }, overrides: overrideTools ?? {}, customToolDefinitions: customToolDefinitions ? Object.fromEntries( @@ -356,124 +465,91 @@ export async function run({ : {}, cwd, fs, - }), - onCostResponse: async () => {}, - - onResponseChunk: async (action) => { - checkAborted(signal) - const { chunk } = action - if (typeof chunk === 'string') { - ensureSectionStart(ROOT_AGENT_KEY) - const { text: sanitized } = filterToolXmlFromText( - streamFilterState, - chunk, - MAX_TOOL_XML_BUFFER, - ) - - if (sanitized) { - const nextFullText = accumulateText(ROOT_AGENT_KEY, sanitized) - await emitStreamDelta(ROOT_AGENT_KEY, nextFullText) - } - } else { - const chunkType = chunk.type as string - - if ( - chunkType !== 'finish' && - chunkType !== 'subagent_finish' && - chunkType !== 'subagent-finish' - ) { - await emitPendingSection(ROOT_AGENT_KEY) - const pendingAgentId = - 'agentId' in chunk ? chunk.agentId : undefined - if (pendingAgentId && pendingAgentId !== ROOT_AGENT_KEY) { - await emitPendingSection(pendingAgentId, pendingAgentId) - } + }) + }, + requestMcpToolData: async ({ mcpConfig, toolNames }) => { + const mcpClientId = await getMCPClient(mcpConfig) + const tools = (await listMCPTools(mcpClientId)).tools + const filteredTools: typeof tools = [] + for (const tool of tools) { + if (!toolNames) { + filteredTools.push(tool) + continue } - - if (chunkType === 'finish') { - const { text: streamTail } = filterToolXmlFromText( - streamFilterState, - '', - MAX_TOOL_XML_BUFFER, - ) - let remainder = streamTail - - if ( - streamFilterState.buffer && - !streamFilterState.buffer.includes('<') - ) { - remainder += streamFilterState.buffer - } - streamFilterState.buffer = '' - streamFilterState.activeTag = null - - if (remainder) { - const nextFullText = accumulateText(ROOT_AGENT_KEY, remainder) - await emitStreamDelta(ROOT_AGENT_KEY, nextFullText) - } - - await flushTextState(ROOT_AGENT_KEY) - - const finishAgentKey = 'agentId' in chunk ? chunk.agentId : undefined - if (finishAgentKey && finishAgentKey !== ROOT_AGENT_KEY) { - await flushTextState(finishAgentKey, finishAgentKey) - await flushSubagentState( - finishAgentKey, - (chunk as { agentType?: string }).agentType, - ) - } - } else if ( - chunkType === 'subagent_finish' || - chunkType === 'subagent-finish' - ) { - const subagentId = 'agentId' in chunk ? chunk.agentId : undefined - if (subagentId) { - await flushTextState(subagentId, subagentId) - await flushSubagentState( - subagentId, - (chunk as { agentType?: string }).agentType, - ) - } + if (tool.name in toolNames) { + filteredTools.push(tool) + continue } + } - await handleEvent?.(chunk) + return filteredTools + }, + requestFiles: ({ filePaths }) => + readFiles({ + filePaths, + override: overrideTools?.read_files, + cwd, + fs, + }), + requestOptionalFile: async ({ filePath }) => { + const files = await readFiles({ + filePaths: [filePath], + override: overrideTools?.read_files, + cwd, + fs, + }) + return toOptionalFile(files[filePath] ?? null) + }, + sendAction: ({ action }) => { + if (action.type === 'action-error') { + onError({ message: action.message }) + return + } + if (action.type === 'response-chunk') { + onResponseChunk(action) + return + } + if (action.type === 'subagent-response-chunk') { + onSubagentResponseChunk(action) + return + } + if (action.type === 'prompt-response') { + handlePromptResponse({ + action, + resolve, + onError, + initialSessionState: sessionState, + }) + return + } + if (action.type === 'prompt-error') { + handlePromptResponse({ + action, + resolve, + onError, + initialSessionState: sessionState, + }) + return } }, - onSubagentResponseChunk: async (action) => { - checkAborted(signal) - const { agentId, agentType, chunk } = action - - const state = getSubagentFilterState(agentId) - const { text: sanitized } = filterToolXmlFromText( - state, + sendSubagentChunk: ({ + userInputId, + agentId, + agentType, + chunk, + prompt, + forwardToPrompt = true, + }) => { + onSubagentResponseChunk({ + type: 'subagent-response-chunk', + userInputId, + agentId, + agentType, chunk, - MAX_TOOL_XML_BUFFER, - ) - - if (sanitized && handleEvent) { - await handleEvent({ - type: 'subagent-chunk', - agentId, - agentType, - chunk: sanitized, - } as any) - } + prompt, + forwardToPrompt, + }) }, - - onPromptResponse: (action) => - handlePromptResponse({ - action, - resolve, - onError, - initialSessionState: sessionState, - }), - onPromptError: (action) => - handlePromptResponse({ - action, - resolve, - onError, - initialSessionState: sessionState, - }), }) // Init session state @@ -515,24 +591,38 @@ export async function run({ // Send input checkAborted(signal) - await websocketHandler.connect() - websocketHandler.sendInput({ - promptId, - prompt, - promptParams: params, - fingerprintId: fingerprintId, - costMode: 'normal', - sessionState, - toolResults: extraToolResults ?? [], - agentId, + const userInfo = await getUserInfoFromApiKey({ + ...agentRuntimeImpl, + apiKey, + fields: ['id'], }) + if (!userInfo) { + throw new Error('No user found for key') + } + const userId = userInfo.id - const result = await promise - - websocketHandler.close() + callMainPrompt({ + ...agentRuntimeImpl, + promptId, + action: { + type: 'prompt', + promptId, + prompt, + promptParams: params, + fingerprintId: fingerprintId, + costMode: 'normal', + sessionState, + toolResults: extraToolResults ?? [], + agentId, + }, + repoUrl: undefined, + repoId: undefined, + clientSessionId: promptId, + userId, + }) - return result + return promise } function requireCwd(cwd: string | undefined, toolName: string): string { diff --git a/sdk/src/tools/glob.ts b/sdk/src/tools/glob.ts index d93273374..f69f313b8 100644 --- a/sdk/src/tools/glob.ts +++ b/sdk/src/tools/glob.ts @@ -1,8 +1,9 @@ +import micromatch from 'micromatch' + import { flattenTree, getProjectFileTree, -} from '@codebuff/common/project-file-tree' -import micromatch from 'micromatch' +} from '../../../common/src/project-file-tree' import type { CodebuffToolOutput } from '../../../common/src/tools/list' import type { CodebuffFileSystem } from '../../../common/src/types/filesystem'