diff --git a/apps/roam/src/utils/hyde.ts b/apps/roam/src/utils/hyde.ts new file mode 100644 index 000000000..1e24635e9 --- /dev/null +++ b/apps/roam/src/utils/hyde.ts @@ -0,0 +1,375 @@ +import { Result } from "./types"; +import { getNodeEnv } from "roamjs-components/util/env"; + +type ApiEmbeddingResponse = { + data: Array<{ + embedding: number[]; + }>; +}; + +type ApiSupabaseResultItem = { + roam_uid: string; + text_content: string; + similarity: number; +}; + +export type EmbeddingVectorType = number[]; + +export type CandidateNodeWithEmbedding = Result & { + type: string; + embedding: EmbeddingVectorType; +}; + +export type SuggestedNode = Result & { + type: string; +}; + +export type RelationDetails = { + relationLabel: string; + relatedNodeText: string; + relatedNodeFormat: string; +}; + +export type NodeSearchResult = { + object: { uid: string; text: string }; + score: number; +}; + +type HypotheticalNodeGenerator = (params: { + node: string; + relationType: RelationDetails; +}) => Promise; + +type EmbeddingFunc = (text: string) => Promise; + +type SearchFunc = (params: { + queryEmbedding: EmbeddingVectorType; + indexData: CandidateNodeWithEmbedding[]; +}) => Promise; + +const API_CONFIG = { + LLM: { + URL: "https://discoursegraphs.com/api/llm/openai/chat", + MODEL: "gpt-4.1", + TIMEOUT_MS: 30_000, + MAX_TOKENS: 104, + TEMPERATURE: 0.9, + }, + BASE_URL: { + DEV: "http://localhost:3000", + PROD: "https://discoursegraphs.com", + }, + EMBEDDINGS: { + PATH: "/api/embeddings/openai/small", + }, + SUPABASE: { + MATCH_EMBEDDINGS_PATH: + "/api/supabase/rpc/match-embeddings-for-subset-nodes", + }, +} as const; + +const getBaseUrl = (): string => { + const isDevelopment = getNodeEnv() === "development"; + return isDevelopment ? API_CONFIG.BASE_URL.DEV : API_CONFIG.BASE_URL.PROD; +}; + +const handleApiError = async ( + response: Response, + context: string, +): Promise => { + const errorText = await response.text(); + let errorData; + try { + errorData = JSON.parse(errorText); + } catch (e) { + errorData = { error: `Server responded with ${response.status}` }; + } + console.error( + `${context} failed with status ${response.status}. Error:`, + errorData, + ); + throw new Error( + errorData.error || + `${context} failed with status ${response.status}. Response: ${errorText}`, + ); +}; + +const generateHypotheticalNode: HypotheticalNodeGenerator = async ({ + node, + relationType, +}) => { + const { relationLabel, relatedNodeText, relatedNodeFormat } = relationType; + + const userPromptContent = `Given the source discourse node \`\`\`${node}\`\`\`, +and considering the relation \`\`\`${relationLabel}\`\`\` +which typically connects to a node of type \`\`\`${relatedNodeText}\`\`\` +(formatted like \`\`\`${relatedNodeFormat}\`\`\`), +generate a hypothetical related discourse node text that would plausibly fit this relationship. +Only return the text of the hypothetical node.`; + const requestBody = { + documents: [{ role: "user", content: userPromptContent }], + passphrase: "", + settings: { + model: API_CONFIG.LLM.MODEL, + maxTokens: API_CONFIG.LLM.MAX_TOKENS, + temperature: API_CONFIG.LLM.TEMPERATURE, + }, + }; + + let response: Response | null = null; + try { + const signal = AbortSignal.timeout(API_CONFIG.LLM.TIMEOUT_MS); + response = await fetch(API_CONFIG.LLM.URL, { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify(requestBody), + signal, + }); + + if (!response.ok) { + await handleApiError(response, "Hypothetical node generation"); + } + + return await response.text(); + } catch (error: unknown) { + if ( + error instanceof Error && + (error.name === "AbortError" || error.name === "TimeoutError") + ) { + console.error("Hypothetical node generation timed out", error); + return `Error: Failed to generate hypothetical node. Request timed out.`; + } + console.error("Hypothetical node generation failed:", error); + return `Error: Failed to generate hypothetical node. ${ + error instanceof Error ? error.message : String(error) + }`; + } +}; + +const createEmbedding: EmbeddingFunc = async ( + text: string, +): Promise => { + if (!text.trim()) throw new Error("Input text for embedding is empty."); + + const apiUrl = `${getBaseUrl()}${API_CONFIG.EMBEDDINGS.PATH}`; + + try { + const response = await fetch(apiUrl, { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ input: text }), + }); + + if (!response.ok) { + await handleApiError(response, "Embedding creation"); + } + + const data = (await response.json()) as ApiEmbeddingResponse; + if (!data?.data?.[0]?.embedding) { + throw new Error("Invalid API response format from embedding service."); + } + return data.data[0].embedding; + } catch (error: unknown) { + console.error("Error creating embedding:", error); + throw error; + } +}; + +const searchEmbeddings: SearchFunc = async ({ + queryEmbedding, + indexData, +}): Promise => { + if (!indexData?.length) { + return []; + } + + const subsetRoamUids = indexData.map((node) => node.uid); + const fullApiUrl = `${getBaseUrl()}${API_CONFIG.SUPABASE.MATCH_EMBEDDINGS_PATH}`; + + try { + const response = await fetch(fullApiUrl, { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify({ + queryEmbedding: queryEmbedding, + subsetRoamUids: subsetRoamUids, + }), + }); + + if (!response.ok) { + await handleApiError(response, "Embedding search"); + } + + const results = await response.json(); + + if (!Array.isArray(results)) { + console.error("Embedding search response was not an array:", results); + throw new Error("Invalid API response format: Expected an array."); + } + + const mappedResults = results.map((item: ApiSupabaseResultItem) => ({ + object: { uid: item.roam_uid, text: item.text_content }, + score: item.similarity, + })); + return mappedResults; + } catch (error: unknown) { + let errorMessage = `Embedding search failed for ${fullApiUrl}.`; + let errorStack; + if (error instanceof Error) { + errorMessage += ` Message: ${error.message}`; + errorStack = error.stack; + } else { + errorMessage += ` Error: ${String(error)}`; + } + console.error( + `Embedding search exception:`, + errorMessage, + errorStack && `Stack: ${errorStack}`, + ); + return []; + } +}; + +const searchAgainstCandidates = async ({ + hypotheticalTexts, + indexData, +}: { + hypotheticalTexts: string[]; + indexData: CandidateNodeWithEmbedding[]; +}): Promise => { + if (!hypotheticalTexts?.length || !indexData?.length) { + return []; + } + + const results = await Promise.all( + hypotheticalTexts.map(async (hypoText) => { + try { + const queryEmbedding = await createEmbedding(hypoText); + return await searchEmbeddings({ + queryEmbedding, + indexData, + }); + } catch (error: unknown) { + let errorMessage = `Search failed for hypothetical text "${hypoText}".`; + let errorStack; + if (error instanceof Error) { + errorMessage += ` Message: ${error.message}`; + errorStack = error.stack; + } else { + errorMessage += ` Error: ${String(error)}`; + } + console.error( + `Search exception:`, + errorMessage, + errorStack && `Stack: ${errorStack}`, + ); + return []; + } + }), + ); + return results; +}; + +const combineScores = ( + allSearchResults: NodeSearchResult[][], +): Map => { + const maxScores = new Map(); + allSearchResults.forEach((resultSet) => { + resultSet.forEach((result) => { + const currentMax = maxScores.get(result.object.uid) ?? -Infinity; + if (result.score > currentMax) { + maxScores.set(result.object.uid, result.score); + } + }); + }); + return maxScores; +}; + +const rankNodes = ({ + maxScores, + candidateNodes, +}: { + maxScores: Map; + candidateNodes: CandidateNodeWithEmbedding[]; +}): SuggestedNode[] => { + if (!candidateNodes?.length) { + return []; + } + + const nodeMap = new Map( + candidateNodes.map((node) => [node.uid, node]), + ); + + const combinedResults: { node: SuggestedNode; score: number }[] = []; + maxScores.forEach((score, uid) => { + const fullNode = nodeMap.get(uid); + if (fullNode) { + const { embedding, ...restOfNode } = fullNode; + const suggestedNodeObject: SuggestedNode = restOfNode; + combinedResults.push({ node: suggestedNodeObject, score }); + } + }); + + combinedResults.sort((a, b) => b.score - a.score); + return combinedResults.map((item) => item.node); +}; + +export const findSimilarNodesUsingHyde = async ({ + candidateNodes, + currentNodeText, + relationDetails, +}: { + candidateNodes: CandidateNodeWithEmbedding[]; + currentNodeText: string; + relationDetails: RelationDetails[]; +}): Promise => { + if ( + !candidateNodes?.length || + !currentNodeText?.trim() || + !relationDetails?.length + ) { + return []; + } + + try { + const hypotheticalTexts = ( + await Promise.all( + relationDetails.map((relationType) => + generateHypotheticalNode({ node: currentNodeText, relationType }), + ), + ) + ).filter((text) => !text.startsWith("Error:")); + + if (!hypotheticalTexts.length) { + console.warn("No valid hypothetical nodes were generated. Exiting."); + return []; + } + + const allSearchResults = await searchAgainstCandidates({ + hypotheticalTexts, + indexData: candidateNodes, + }); + + const maxScores = combineScores(allSearchResults); + return rankNodes({ maxScores, candidateNodes }); + } catch (error: unknown) { + let errorMessage = "Failed to find similar nodes."; + let errorStack; + if (error instanceof Error) { + errorMessage += ` Message: ${error.message}`; + errorStack = error.stack; + } else { + errorMessage += ` Error: ${String(error)}`; + } + console.error( + "Similar nodes search exception:", + errorMessage, + errorStack && `Stack: ${errorStack}`, + ); + return []; + } +}; diff --git a/apps/website/app/api/embeddings/openai/small/route.ts b/apps/website/app/api/embeddings/openai/small/route.ts index 455550770..988248acb 100644 --- a/apps/website/app/api/embeddings/openai/small/route.ts +++ b/apps/website/app/api/embeddings/openai/small/route.ts @@ -5,15 +5,18 @@ import cors from "~/utils/llm/cors"; const apiKey = process.env.OPENAI_API_KEY; if (!apiKey) { - throw new Error( + console.error( "Missing OPENAI_API_KEY environment variable. The embeddings API will not function.", ); } -const openai = new OpenAI({ apiKey }); +const openai = apiKey ? new OpenAI({ apiKey }) : null; type RequestBody = { input: string | string[]; + model?: string; + dimensions?: number; + encoding_format?: "float" | "base64"; }; const OPENAI_REQUEST_TIMEOUT_MS = 30000; @@ -21,9 +24,25 @@ const OPENAI_REQUEST_TIMEOUT_MS = 30000; export async function POST(req: NextRequest): Promise { let response: NextResponse; + if (!apiKey) { + response = NextResponse.json( + { + error: "Server configuration error.", + details: "Embeddings service is not configured.", + }, + { status: 500 }, + ); + return cors(req, response) as NextResponse; + } + try { const body: RequestBody = await req.json(); - const { input } = body; + const { + input, + model = "text-embedding-3-small", + dimensions, + encoding_format = "float", + } = body; if (!input || (Array.isArray(input) && input.length === 0)) { response = NextResponse.json( @@ -34,13 +53,16 @@ export async function POST(req: NextRequest): Promise { } const options: OpenAI.EmbeddingCreateParams = { - model: "text-embedding-3-small", + model: model, input: input, - encoding_format: "float", - dimensions: 1536, + encoding_format: encoding_format, }; - const embeddingsPromise = openai.embeddings.create(options); + if (dimensions && model.startsWith("text-embedding-3")) { + options.dimensions = dimensions; + } + + const embeddingsPromise = openai!.embeddings.create(options); const timeoutPromise = new Promise((_, reject) => setTimeout( () => reject(new Error("OpenAI API request timeout")),