-
Notifications
You must be signed in to change notification settings - Fork 2
DGAI: Create HyDe utility function - ENG-292 #147
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 9 commits
Commits
Show all changes
16 commits
Select commit
Hold shift + click to select a range
0f9ad1c
add api endpoint
sid597 8d0cf3e
address coderabbit review
sid597 349a0cb
hyde utility
sid597 7c4685c
address review
sid597 7c0e97d
Merge branch 'add-embedding-endpoint' into hyde-utility
sid597 eaa9c8f
redundant code
sid597 8e7a928
Merge branch 'add-embedding-endpoint' into hyde-utility
sid597 0033fb1
address coderabbit review
sid597 2cc7e23
fx
sid597 7054186
use arrow functions
sid597 1be7a55
build off the existing type
sid597 7060255
named parameters
sid597 d127db4
hyde functionality from bigger pr
sid597 a5d7f79
undo style guide changes
sid597 33e6cbe
simpler fence
sid597 e5290c0
Merge branch 'eng-233-suggestive-mode-internal' into hyde-utility
sid597 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,217 @@ | ||
| export type EmbeddingVector = number[]; | ||
|
|
||
| export type CandidateNodeWithEmbedding = { | ||
sid597 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| text: string; | ||
| uid: string; | ||
| type: string; | ||
| embedding: EmbeddingVector; | ||
| }; | ||
|
|
||
| export type SuggestedNode = { | ||
sid597 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| text: string; | ||
| uid: string; | ||
| type: string; | ||
| }; | ||
|
|
||
| export type RelationTriplet = [string, string, string]; | ||
sid597 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| export type HypotheticalNodeGenerator = ( | ||
| node: string, | ||
| relationType: RelationTriplet, | ||
| ) => Promise<string>; | ||
|
|
||
| export type EmbeddingFunc = (text: string) => Promise<EmbeddingVector>; | ||
|
|
||
| export type SearchResultItem = { | ||
| object: SuggestedNode; | ||
| score: number; | ||
| }; | ||
| export type SearchFunc = ( | ||
| queryEmbedding: EmbeddingVector, | ||
| indexData: CandidateNodeWithEmbedding[], | ||
| options: { topK: number }, | ||
| ) => Promise<SearchResultItem[]>; | ||
|
|
||
| export const ANTHROPIC_API_URL = | ||
| "https://discoursegraphs.com/api/llm/anthropic/chat"; | ||
| export const ANTHROPIC_MODEL = "claude-3-sonnet-20240229"; | ||
| export const ANTHROPIC_REQUEST_TIMEOUT_MS = 30_000; | ||
|
|
||
| export const generateHypotheticalNode: HypotheticalNodeGenerator = async ( | ||
| node: string, | ||
| relationType: RelationTriplet, | ||
| ): Promise<string> => { | ||
| const [relationLabel, relatedNodeText, relatedNodeFormat] = relationType; | ||
|
|
||
| const userPromptContent = `Given the source discourse node \\\`\\\`\\\`${node}\\\`\\\`\\\`, \nand considering the relation \\\`\\\`\\\`${relationLabel}\\\`\\\`\\\` \nwhich typically connects to a node of type \\\`\\\`\\\`${relatedNodeText}\\\`\\\`\\\` \n(formatted like \\\`\\\`\\\`${relatedNodeFormat}\\\`\\\`\\\`), \ngenerate a hypothetical related discourse node text that would plausibly fit this relationship. \nOnly return the text of the hypothetical node.`; | ||
sid597 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| const requestBody = { | ||
| documents: [{ role: "user", content: userPromptContent }], | ||
| passphrase: "", | ||
| settings: { | ||
| model: ANTHROPIC_MODEL, | ||
| maxTokens: 104, | ||
| temperature: 0.9, | ||
| }, | ||
| }; | ||
coderabbitai[bot] marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| let response: Response | null = null; | ||
| try { | ||
| const signal = AbortSignal.timeout(ANTHROPIC_REQUEST_TIMEOUT_MS); | ||
| response = await fetch(ANTHROPIC_API_URL, { | ||
| method: "POST", | ||
| headers: { | ||
| "Content-Type": "application/json", | ||
| }, | ||
| body: JSON.stringify(requestBody), | ||
| signal, | ||
| }); | ||
coderabbitai[bot] marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| if (!response.ok) { | ||
| const errorText = await response.text(); | ||
| console.error( | ||
| `Claude API request failed with status ${response.status}. Response Text: ${errorText}`, | ||
| ); | ||
| throw new Error( | ||
| `Claude API request failed with status ${response.status}: ${errorText.substring(0, 500)}`, | ||
| ); | ||
| } | ||
|
|
||
| const body = await response.json().catch(() => null); | ||
| if (!body || typeof body.completion !== "string") { | ||
| console.error("Claude API returned unexpected payload:", body); | ||
| throw new Error("Claude API returned unexpected payload"); | ||
| } | ||
|
|
||
| return body.completion.trim(); | ||
| } catch (error) { | ||
| if ( | ||
| error instanceof Error && | ||
| (error.name === "AbortError" || error.name === "TimeoutError") | ||
| ) { | ||
| console.error( | ||
| "Error during fetch for Claude API: Request timed out", | ||
| error, | ||
| ); | ||
| return `Error: Failed to generate hypothetical node. Request timed out.`; | ||
| } | ||
| console.error("Error during fetch for Claude API:", error); | ||
| return `Error: Failed to generate hypothetical node. ${ | ||
| error instanceof Error ? error.message : String(error) | ||
| }`; | ||
| } | ||
| }; | ||
|
|
||
| async function searchAgainstCandidates( | ||
sid597 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| hypotheticalTexts: string[], | ||
sid597 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| indexData: CandidateNodeWithEmbedding[], | ||
| embeddingFunction: EmbeddingFunc, | ||
| searchFunction: SearchFunc, | ||
| ): Promise<SearchResultItem[][]> { | ||
| const allSearchResults = await Promise.all( | ||
| hypotheticalTexts.map(async (hypoText) => { | ||
| try { | ||
| const queryEmbedding = await embeddingFunction(hypoText); | ||
| return await searchFunction(queryEmbedding, indexData, { | ||
| topK: indexData.length, | ||
| }); | ||
| } catch (error) { | ||
| console.error( | ||
| `Error searching for hypothetical node "${hypoText}":`, | ||
| error, | ||
| ); | ||
| return []; | ||
| } | ||
| }), | ||
| ); | ||
| return allSearchResults; | ||
| } | ||
|
|
||
| function combineScores( | ||
sid597 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| allSearchResults: SearchResultItem[][], | ||
| ): Map<string, number> { | ||
| const maxScores = new Map<string, number>(); | ||
| for (const resultSet of allSearchResults) { | ||
| for (const result of resultSet) { | ||
| const currentMaxScore = maxScores.get(result.object.uid) ?? -Infinity; | ||
| if (result.score > currentMaxScore) { | ||
| maxScores.set(result.object.uid, result.score); | ||
| } | ||
| } | ||
| } | ||
| return maxScores; | ||
| } | ||
|
|
||
| function rankNodes( | ||
| maxScores: Map<string, number>, | ||
| candidateNodes: CandidateNodeWithEmbedding[], | ||
| ): SuggestedNode[] { | ||
| const nodeMap = new Map<string, CandidateNodeWithEmbedding>( | ||
| candidateNodes.map((node) => [node.uid, node]), | ||
| ); | ||
| const combinedResults = Array.from(maxScores.entries()) | ||
| .map(([uid, score]) => { | ||
| const node = nodeMap.get(uid); | ||
| return node ? { node, score } : undefined; | ||
| }) | ||
| .filter(Boolean) as { node: CandidateNodeWithEmbedding; score: number }[]; | ||
|
|
||
| combinedResults.sort((a, b) => b.score - a.score); | ||
| return combinedResults.map((item) => ({ | ||
| text: item.node.text, | ||
| uid: item.node.uid, | ||
| type: item.node.type, | ||
| })); | ||
| } | ||
|
|
||
| export const findSimilarNodesUsingHyde = async ( | ||
| candidateNodes: CandidateNodeWithEmbedding[], | ||
sid597 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| currentNodeText: string, | ||
| relationTriplets: RelationTriplet[], | ||
| options: { | ||
sid597 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| hypotheticalNodeGenerator: HypotheticalNodeGenerator; | ||
| embeddingFunction: EmbeddingFunc; | ||
| searchFunction: SearchFunc; | ||
| }, | ||
| ): Promise<SuggestedNode[]> => { | ||
| const { hypotheticalNodeGenerator, embeddingFunction, searchFunction } = | ||
| options; | ||
|
|
||
| if (candidateNodes.length === 0) { | ||
| return []; | ||
| } | ||
|
|
||
| try { | ||
| const indexData = candidateNodes; | ||
|
|
||
| const hypotheticalNodePromises = []; | ||
| for (const relationType of relationTriplets) { | ||
| hypotheticalNodePromises.push( | ||
| hypotheticalNodeGenerator(currentNodeText, relationType), | ||
| ); | ||
| } | ||
| const hypotheticalNodeTexts = ( | ||
| await Promise.all(hypotheticalNodePromises) | ||
| ).filter((text) => !text.startsWith("Error:")); | ||
|
|
||
| if (hypotheticalNodeTexts.length === 0) { | ||
| console.error("Failed to generate any valid hypothetical nodes."); | ||
| return []; | ||
| } | ||
|
|
||
| const allSearchResults = await searchAgainstCandidates( | ||
| hypotheticalNodeTexts, | ||
| indexData, | ||
| embeddingFunction, | ||
| searchFunction, | ||
| ); | ||
|
|
||
| const maxScores = combineScores(allSearchResults); | ||
|
|
||
| const rankedNodes = rankNodes(maxScores, candidateNodes); | ||
|
|
||
| return rankedNodes; | ||
| } catch (error) { | ||
| console.error("Error in findSimilarNodesUsingHyde:", error); | ||
| return []; | ||
| } | ||
| }; | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,97 @@ | ||
| import { NextRequest, NextResponse } from "next/server"; | ||
| import OpenAI from "openai"; | ||
| import cors from "~/utils/llm/cors"; | ||
|
|
||
| const apiKey = process.env.OPENAI_API_KEY; | ||
|
|
||
| if (!apiKey) { | ||
| console.error( | ||
| "Missing OPENAI_API_KEY environment variable. The embeddings API will not function.", | ||
| ); | ||
| } | ||
|
|
||
| const openai = apiKey ? new OpenAI({ apiKey }) : null; | ||
|
|
||
| type RequestBody = { | ||
| input: string | string[]; | ||
| model?: string; | ||
| dimensions?: number; | ||
| encoding_format?: "float" | "base64"; | ||
| }; | ||
|
|
||
| const OPENAI_REQUEST_TIMEOUT_MS = 30000; | ||
|
|
||
| export async function POST(req: NextRequest): Promise<NextResponse> { | ||
| let response: NextResponse; | ||
|
|
||
| if (!apiKey) { | ||
| response = NextResponse.json( | ||
| { | ||
| error: "Server configuration error.", | ||
| details: "Embeddings service is not configured.", | ||
| }, | ||
| { status: 500 }, | ||
| ); | ||
| return cors(req, response) as NextResponse; | ||
| } | ||
|
|
||
| try { | ||
| const body: RequestBody = await req.json(); | ||
| const { | ||
| input, | ||
| model = "text-embedding-3-small", | ||
| dimensions, | ||
| encoding_format = "float", | ||
| } = body; | ||
|
|
||
| if (!input || (Array.isArray(input) && input.length === 0)) { | ||
| response = NextResponse.json( | ||
| { error: "Input text cannot be empty." }, | ||
| { status: 400 }, | ||
| ); | ||
| return cors(req, response) as NextResponse; | ||
| } | ||
|
|
||
| const options: OpenAI.EmbeddingCreateParams = { | ||
| model: model, | ||
| input: input, | ||
| encoding_format: encoding_format, | ||
| }; | ||
|
|
||
| if (dimensions && model.startsWith("text-embedding-3")) { | ||
| options.dimensions = dimensions; | ||
| } | ||
|
|
||
| const embeddingsPromise = openai!.embeddings.create(options); | ||
| const timeoutPromise = new Promise((_, reject) => | ||
| setTimeout( | ||
| () => reject(new Error("OpenAI API request timeout")), | ||
| OPENAI_REQUEST_TIMEOUT_MS, | ||
| ), | ||
| ); | ||
|
|
||
| const openAIResponse = await Promise.race([ | ||
| embeddingsPromise, | ||
| timeoutPromise, | ||
| ]); | ||
|
|
||
| response = NextResponse.json(openAIResponse, { status: 200 }); | ||
| } catch (error: unknown) { | ||
| console.error("Error calling OpenAI Embeddings API:", error); | ||
| const errorMessage = | ||
| error instanceof Error ? error.message : "Unknown error"; | ||
| response = NextResponse.json( | ||
| { | ||
| error: "Failed to generate embeddings.", | ||
| details: errorMessage, | ||
| }, | ||
| { status: 500 }, | ||
| ); | ||
| } | ||
|
|
||
| return cors(req, response) as NextResponse; | ||
| } | ||
|
|
||
| export async function OPTIONS(req: NextRequest): Promise<NextResponse> { | ||
| return cors(req, new NextResponse(null, { status: 204 })) as NextResponse; | ||
| } |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.