apps/roam/src/utils/hyde.ts: 217 additions & 0 deletions (new file)
@@ -0,0 +1,217 @@
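// HyDE ("Hypothetical Document Embeddings") utilities for suggesting related
// discourse nodes: an LLM drafts a hypothetical related node for each relation,
// the draft is embedded, and existing candidate nodes are ranked by how closely
// their embeddings match the drafts.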
export type EmbeddingVector = number[];

export type CandidateNodeWithEmbedding = {
  text: string;
  uid: string;
  type: string;
  embedding: EmbeddingVector;
};

export type SuggestedNode = {
  text: string;
  uid: string;
  type: string;
};

export type RelationTriplet = [string, string, string];

export type HypotheticalNodeGenerator = (
  node: string,
  relationType: RelationTriplet,
) => Promise<string>;

export type EmbeddingFunc = (text: string) => Promise<EmbeddingVector>;

export type SearchResultItem = {
  object: SuggestedNode;
  score: number;
};
export type SearchFunc = (
  queryEmbedding: EmbeddingVector,
  indexData: CandidateNodeWithEmbedding[],
  options: { topK: number },
) => Promise<SearchResultItem[]>;

export const ANTHROPIC_API_URL =
  "https://discoursegraphs.com/api/llm/anthropic/chat";
export const ANTHROPIC_MODEL = "claude-3-sonnet-20240229";
export const ANTHROPIC_REQUEST_TIMEOUT_MS = 30_000;

// Asks Claude (via the Discourse Graphs LLM proxy) for a plausible related
// node text, given the source node and a [label, target node type, format] triplet.
export const generateHypotheticalNode: HypotheticalNodeGenerator = async (
  node: string,
  relationType: RelationTriplet,
): Promise<string> => {
  const [relationLabel, relatedNodeText, relatedNodeFormat] = relationType;

  const userPromptContent = `Given the source discourse node \`\`\`${node}\`\`\`, \nand considering the relation \`\`\`${relationLabel}\`\`\` \nwhich typically connects to a node of type \`\`\`${relatedNodeText}\`\`\` \n(formatted like \`\`\`${relatedNodeFormat}\`\`\`), \ngenerate a hypothetical related discourse node text that would plausibly fit this relationship. \nOnly return the text of the hypothetical node.`;
  const requestBody = {
    documents: [{ role: "user", content: userPromptContent }],
    passphrase: "",
    settings: {
      model: ANTHROPIC_MODEL,
      maxTokens: 104,
      temperature: 0.9,
    },
  };

  let response: Response | null = null;
  try {
    const signal = AbortSignal.timeout(ANTHROPIC_REQUEST_TIMEOUT_MS);
    response = await fetch(ANTHROPIC_API_URL, {
      method: "POST",
      headers: {
        "Content-Type": "application/json",
      },
      body: JSON.stringify(requestBody),
      signal,
    });

    if (!response.ok) {
      const errorText = await response.text();
      console.error(
        `Claude API request failed with status ${response.status}. Response Text: ${errorText}`,
      );
      throw new Error(
        `Claude API request failed with status ${response.status}: ${errorText.substring(0, 500)}`,
      );
    }

    const body = await response.json().catch(() => null);
    if (!body || typeof body.completion !== "string") {
      console.error("Claude API returned unexpected payload:", body);
      throw new Error("Claude API returned unexpected payload");
    }

    return body.completion.trim();
  } catch (error) {
    if (
      error instanceof Error &&
      (error.name === "AbortError" || error.name === "TimeoutError")
    ) {
      console.error(
        "Error during fetch for Claude API: Request timed out",
        error,
      );
      return `Error: Failed to generate hypothetical node. Request timed out.`;
    }
    console.error("Error during fetch for Claude API:", error);
    return `Error: Failed to generate hypothetical node. ${
      error instanceof Error ? error.message : String(error)
    }`;
  }
};

// Embed each hypothetical node text and search the candidate index with it,
// returning one result set per hypothetical text (or [] if that query fails).
async function searchAgainstCandidates(
  hypotheticalTexts: string[],
  indexData: CandidateNodeWithEmbedding[],
  embeddingFunction: EmbeddingFunc,
  searchFunction: SearchFunc,
): Promise<SearchResultItem[][]> {
  const allSearchResults = await Promise.all(
    hypotheticalTexts.map(async (hypoText) => {
      try {
        const queryEmbedding = await embeddingFunction(hypoText);
        return await searchFunction(queryEmbedding, indexData, {
          topK: indexData.length,
        });
      } catch (error) {
        console.error(
          `Error searching for hypothetical node "${hypoText}":`,
          error,
        );
        return [];
      }
    }),
  );
  return allSearchResults;
}

// For each candidate node, keep the best (maximum) similarity score it
// received across all hypothetical-node result sets.
function combineScores(
  allSearchResults: SearchResultItem[][],
): Map<string, number> {
  const maxScores = new Map<string, number>();
  for (const resultSet of allSearchResults) {
    for (const result of resultSet) {
      const currentMaxScore = maxScores.get(result.object.uid) ?? -Infinity;
      if (result.score > currentMaxScore) {
        maxScores.set(result.object.uid, result.score);
      }
    }
  }
  return maxScores;
}

// Sort candidates by their combined score (highest first) and return them
// without their embeddings.
function rankNodes(
  maxScores: Map<string, number>,
  candidateNodes: CandidateNodeWithEmbedding[],
): SuggestedNode[] {
  const nodeMap = new Map<string, CandidateNodeWithEmbedding>(
    candidateNodes.map((node) => [node.uid, node]),
  );
  const combinedResults = Array.from(maxScores.entries())
    .map(([uid, score]) => {
      const node = nodeMap.get(uid);
      return node ? { node, score } : undefined;
    })
    .filter(Boolean) as { node: CandidateNodeWithEmbedding; score: number }[];

  combinedResults.sort((a, b) => b.score - a.score);
  return combinedResults.map((item) => ({
    text: item.node.text,
    uid: item.node.uid,
    type: item.node.type,
  }));
}

// HyDE-style retrieval: for each relation triplet, generate a hypothetical
// related node, embed it, search the candidates, then rank candidates by
// their best score across all hypothetical queries.
export const findSimilarNodesUsingHyde = async (
  candidateNodes: CandidateNodeWithEmbedding[],
  currentNodeText: string,
  relationTriplets: RelationTriplet[],
  options: {
    hypotheticalNodeGenerator: HypotheticalNodeGenerator;
    embeddingFunction: EmbeddingFunc;
    searchFunction: SearchFunc;
  },
): Promise<SuggestedNode[]> => {
  const { hypotheticalNodeGenerator, embeddingFunction, searchFunction } =
    options;

  if (candidateNodes.length === 0) {
    return [];
  }

  try {
    const indexData = candidateNodes;

    const hypotheticalNodePromises = [];
    for (const relationType of relationTriplets) {
      hypotheticalNodePromises.push(
        hypotheticalNodeGenerator(currentNodeText, relationType),
      );
    }
    const hypotheticalNodeTexts = (
      await Promise.all(hypotheticalNodePromises)
    ).filter((text) => !text.startsWith("Error:"));

    if (hypotheticalNodeTexts.length === 0) {
      console.error("Failed to generate any valid hypothetical nodes.");
      return [];
    }

    const allSearchResults = await searchAgainstCandidates(
      hypotheticalNodeTexts,
      indexData,
      embeddingFunction,
      searchFunction,
    );

    const maxScores = combineScores(allSearchResults);

    const rankedNodes = rankNodes(maxScores, candidateNodes);

    return rankedNodes;
  } catch (error) {
    console.error("Error in findSimilarNodesUsingHyde:", error);
    return [];
  }
};
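For reviewers, a minimal sketch of how these exports might be wired together follows. The embedding helper, the brute-force cosine-similarity search, and the sample data are assumptions for illustration; only generateHypotheticalNode and findSimilarNodesUsingHyde come from this file, and the embeddings URL assumes the route added below is deployed on the same host as the existing LLM proxy.

import {
  findSimilarNodesUsingHyde,
  generateHypotheticalNode,
  type CandidateNodeWithEmbedding,
  type EmbeddingFunc,
  type SearchFunc,
} from "./hyde";

// Assumed: the route added below in this PR, deployed at discoursegraphs.com.
const EMBEDDINGS_URL = "https://discoursegraphs.com/api/embeddings/openai/small";

const embed: EmbeddingFunc = async (text) => {
  const res = await fetch(EMBEDDINGS_URL, {
    method: "POST",
    headers: { "Content-Type": "application/json" },
    body: JSON.stringify({ input: text }),
  });
  const data = await res.json();
  return data.data[0].embedding; // OpenAI embeddings payload is passed through unchanged
};

// Assumed: a simple in-memory cosine-similarity search over the candidates.
const cosineSearch: SearchFunc = async (queryEmbedding, indexData, { topK }) => {
  const dot = (a: number[], b: number[]) => a.reduce((s, v, i) => s + v * b[i], 0);
  const norm = (a: number[]) => Math.sqrt(dot(a, a)) || 1;
  return indexData
    .map((node) => ({
      object: { text: node.text, uid: node.uid, type: node.type },
      score: dot(queryEmbedding, node.embedding) / (norm(queryEmbedding) * norm(node.embedding)),
    }))
    .sort((a, b) => b.score - a.score)
    .slice(0, topK);
};

// Hypothetical example data: candidates must already carry embeddings.
const candidates: CandidateNodeWithEmbedding[] = await Promise.all(
  [
    { text: "EVD - Calcium influx precedes vesicle fusion", uid: "evd-1", type: "Evidence" },
    { text: "CLM - Synaptotagmin senses calcium", uid: "clm-1", type: "Claim" },
  ].map(async (n) => ({ ...n, embedding: await embed(n.text) })),
);

const suggestions = await findSimilarNodesUsingHyde(
  candidates,
  "QUE - What triggers synaptic vesicle fusion?",
  // [relation label, related node type, related node format]
  [["informed by", "Evidence", "EVD - {content} - {Source}"]],
  {
    hypotheticalNodeGenerator: generateHypotheticalNode,
    embeddingFunction: embed,
    searchFunction: cosineSearch,
  },
);
console.log(suggestions);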
apps/website/app/api/embeddings/openai/small/route.ts: 97 additions & 0 deletions (new file)
@@ -0,0 +1,97 @@
import { NextRequest, NextResponse } from "next/server";
import OpenAI from "openai";
import cors from "~/utils/llm/cors";

const apiKey = process.env.OPENAI_API_KEY;

if (!apiKey) {
  console.error(
    "Missing OPENAI_API_KEY environment variable. The embeddings API will not function.",
  );
}

const openai = apiKey ? new OpenAI({ apiKey }) : null;

type RequestBody = {
  input: string | string[];
  model?: string;
  dimensions?: number;
  encoding_format?: "float" | "base64";
};

const OPENAI_REQUEST_TIMEOUT_MS = 30000;

export async function POST(req: NextRequest): Promise<NextResponse> {
  let response: NextResponse;

  if (!apiKey) {
    response = NextResponse.json(
      {
        error: "Server configuration error.",
        details: "Embeddings service is not configured.",
      },
      { status: 500 },
    );
    return cors(req, response) as NextResponse;
  }

  try {
    const body: RequestBody = await req.json();
    const {
      input,
      model = "text-embedding-3-small",
      dimensions,
      encoding_format = "float",
    } = body;

    if (!input || (Array.isArray(input) && input.length === 0)) {
      response = NextResponse.json(
        { error: "Input text cannot be empty." },
        { status: 400 },
      );
      return cors(req, response) as NextResponse;
    }

    const options: OpenAI.EmbeddingCreateParams = {
      model: model,
      input: input,
      encoding_format: encoding_format,
    };

    if (dimensions && model.startsWith("text-embedding-3")) {
      options.dimensions = dimensions;
    }

    const embeddingsPromise = openai!.embeddings.create(options);
    // Note: Promise.race only rejects after the timeout; it does not cancel
    // the in-flight OpenAI request.
    const timeoutPromise = new Promise((_, reject) =>
      setTimeout(
        () => reject(new Error("OpenAI API request timeout")),
        OPENAI_REQUEST_TIMEOUT_MS,
      ),
    );

    const openAIResponse = await Promise.race([
      embeddingsPromise,
      timeoutPromise,
    ]);

    response = NextResponse.json(openAIResponse, { status: 200 });
  } catch (error: unknown) {
    console.error("Error calling OpenAI Embeddings API:", error);
    const errorMessage =
      error instanceof Error ? error.message : "Unknown error";
    response = NextResponse.json(
      {
        error: "Failed to generate embeddings.",
        details: errorMessage,
      },
      { status: 500 },
    );
  }

  return cors(req, response) as NextResponse;
}

export async function OPTIONS(req: NextRequest): Promise<NextResponse> {
  return cors(req, new NextResponse(null, { status: 204 })) as NextResponse;
}
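A request to this route (the URL path follows from the file's location under app/) might look like the following sketch; the response body is OpenAI's embeddings payload passed through unchanged:

const res = await fetch("/api/embeddings/openai/small", {
  method: "POST",
  headers: { "Content-Type": "application/json" },
  body: JSON.stringify({
    input: ["first text to embed", "second text to embed"],
    dimensions: 512, // optional; only applied to text-embedding-3* models
  }),
});
const { data } = await res.json();
// data[i].embedding is the vector for the i-th input string.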
apps/website/package.json: 1 addition & 0 deletions
@@ -15,6 +15,7 @@
"@sindresorhus/slugify": "^2.2.1",
"gray-matter": "^4.0.3",
"next": "^15.0.3",
"openai": "^4.97.0",
"react": "19.0.0-rc-66855b96-20241106",
"react-dom": "19.0.0-rc-66855b96-20241106",
"rehype-parse": "^9.0.1",