Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
209 changes: 209 additions & 0 deletions apps/roam/src/utils/hyde.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,209 @@
/** A dense embedding vector produced by an embedding model. */
export type EmbeddingVector = number[];

/** A discourse node with a precomputed embedding, used as a search candidate. */
export type CandidateNodeWithEmbedding = {
  text: string;
  uid: string;
  type: string;
  embedding: EmbeddingVector;
};
/** A candidate node as returned to callers, without its embedding. */
export type SuggestedNode = {
  text: string;
  uid: string;
  type: string;
};

/** [relationLabel, relatedNodeText, relatedNodeFormat] — one relation to hypothesize across. */
export type RelationTriplet = [string, string, string];

/**
 * Produces the text of a hypothetical node related to `node` via `relationType`.
 * NOTE(review): the implementation below resolves to an "Error:"-prefixed string
 * on failure rather than rejecting; callers filter on that prefix.
 */
export type HypotheticalNodeGenerator = (
  node: string,
  relationType: RelationTriplet,
) => Promise<string>;
/** Embeds a text string into a vector. */
export type EmbeddingFunc = (text: string) => Promise<EmbeddingVector>;

/** One similarity-search hit: the matched node and its score (higher = closer). */
export type SearchResultItem = {
  object: SuggestedNode;
  score: number;
};
/** Searches `indexData` for neighbors of `queryEmbedding`, returning up to `topK` hits. */
export type SearchFunc = (
  queryEmbedding: EmbeddingVector,
  indexData: CandidateNodeWithEmbedding[],
  options: { topK: number },
) => Promise<SearchResultItem[]>;

// Server-side proxy for Anthropic chat completions (keeps the API key off the client).
const ANTHROPIC_API_URL = "https://discoursegraphs.com/api/llm/anthropic/chat";
const ANTHROPIC_MODEL = "claude-3-7-sonnet-latest";
// Abort the LLM request after 30 seconds.
const ANTHROPIC_REQUEST_TIMEOUT_MS = 30000;

/**
 * Asks the Anthropic chat proxy for the text of a hypothetical discourse node
 * related to `node` through the given relation triplet.
 *
 * Never rejects: on HTTP failure, timeout, or network error it resolves to a
 * string beginning with "Error:", so callers can run many generations under
 * Promise.all and filter out the failures.
 */
export const generateHypotheticalNode: HypotheticalNodeGenerator = async (
  node: string,
  relationType: RelationTriplet,
): Promise<string> => {
  const [relationLabel, relatedNodeText, relatedNodeFormat] = relationType;

  const userPromptContent = `Given the source discourse node "${node}", and considering the relation
"${relationLabel}" which typically connects to a node of type "${relatedNodeText}"
(formatted like "${relatedNodeFormat}"), generate a hypothetical related discourse
node text that would plausibly fit this relationship. Only return the text of the hypothetical node.`;

  try {
    const response = await fetch(ANTHROPIC_API_URL, {
      method: "POST",
      headers: { "Content-Type": "application/json" },
      body: JSON.stringify({
        documents: [{ role: "user", content: userPromptContent }],
        passphrase: "",
        settings: {
          model: ANTHROPIC_MODEL,
          maxTokens: 104,
          temperature: 0.9,
        },
      }),
      // Aborts the fetch with a TimeoutError once the deadline passes.
      signal: AbortSignal.timeout(ANTHROPIC_REQUEST_TIMEOUT_MS),
    });

    if (!response.ok) {
      const errorText = await response.text();
      console.error(
        `Claude API request failed with status ${response.status}. Response Text: ${errorText}`,
      );
      throw new Error(
        `Claude API request failed with status ${response.status}: ${errorText.substring(0, 500)}`,
      );
    }

    // The proxy returns the generated node text as the raw response body.
    return await response.text();
  } catch (error) {
    if (error instanceof Error && error.name === "TimeoutError") {
      console.error(
        "Error during fetch for Claude API: Request timed out",
        error,
      );
      return `Error: Failed to generate hypothetical node. Request timed out.`;
    }
    console.error("Error during fetch for Claude API:", error);
    return `Error: Failed to generate hypothetical node. ${
      error instanceof Error ? error.message : String(error)
    }`;
  }
};

/**
 * Embeds each hypothetical text and runs a similarity search over the
 * candidate index, one query at a time. A failure on any single query is
 * logged and skipped; the result sets of successful queries are returned
 * in input order.
 */
async function searchAgainstCandidates(
  hypotheticalTexts: string[],
  indexData: CandidateNodeWithEmbedding[],
  embeddingFunction: EmbeddingFunc,
  searchFunction: SearchFunc,
): Promise<SearchResultItem[][]> {
  const collectedResults: SearchResultItem[][] = [];
  for (const queryText of hypotheticalTexts) {
    try {
      const queryVector = await embeddingFunction(queryText);
      // topK is the full index size so every candidate receives a score.
      const hits = await searchFunction(queryVector, indexData, {
        topK: indexData.length,
      });
      collectedResults.push(hits);
    } catch (error) {
      console.error(
        `Error searching for hypothetical node "${queryText}":`,
        error,
      );
    }
  }
  return collectedResults;
}

/**
 * Collapses the per-query result sets into a single score per node uid,
 * keeping the best (maximum) score each uid achieved across all queries.
 */
function combineScores(
  allSearchResults: SearchResultItem[][],
): Map<string, number> {
  const bestByUid = new Map<string, number>();
  for (const hit of allSearchResults.flat()) {
    const uid = hit.object.uid;
    const previousBest = bestByUid.get(uid);
    if (previousBest === undefined || hit.score > previousBest) {
      bestByUid.set(uid, hit.score);
    }
  }
  return bestByUid;
}

/**
 * Orders candidates by their combined score, highest first, and strips the
 * embedding so callers receive plain SuggestedNodes.
 */
function rankNodes(
  maxScores: Map<string, number>,
  candidateNodes: CandidateNodeWithEmbedding[],
): SuggestedNode[] {
  const byUid = new Map<string, CandidateNodeWithEmbedding>(
    candidateNodes.map((candidate) => [candidate.uid, candidate]),
  );
  return [...maxScores.entries()]
    .sort((left, right) => right[1] - left[1])
    .map(([uid]) => {
      // Every scored uid originates from candidateNodes, so the lookup succeeds.
      const matched = byUid.get(uid)!;
      return { text: matched.text, uid: matched.uid, type: matched.type };
    });
}

export const findSimilarNodesUsingHyde = async (
candidateNodes: CandidateNodeWithEmbedding[],
currentNodeText: string,
relationTriplets: RelationTriplet[],
options: {
hypotheticalNodeGenerator: HypotheticalNodeGenerator;
embeddingFunction: EmbeddingFunc;
searchFunction: SearchFunc;
},
): Promise<SuggestedNode[]> => {
const { hypotheticalNodeGenerator, embeddingFunction, searchFunction } =
options;

if (candidateNodes.length === 0) {
return [];
}

try {
const indexData = candidateNodes;

const hypotheticalNodePromises = [];
for (const relationType of relationTriplets) {
hypotheticalNodePromises.push(
hypotheticalNodeGenerator(currentNodeText, relationType),
);
}
const hypotheticalNodeTexts = (
await Promise.all(hypotheticalNodePromises)
).filter((text) => !text.startsWith("Error:"));

if (hypotheticalNodeTexts.length === 0) {
console.error("Failed to generate any valid hypothetical nodes.");
return [];
}

const allSearchResults = await searchAgainstCandidates(
hypotheticalNodeTexts,
indexData,
embeddingFunction,
searchFunction,
);

const maxScores = combineScores(allSearchResults);

const rankedNodes = rankNodes(maxScores, candidateNodes);

return rankedNodes;
} catch (error) {
console.error("Error in findSimilarNodesUsingHyde:", error);
return [];
}
};
101 changes: 101 additions & 0 deletions apps/website/app/api/embeddings/openai/small/route.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
import { NextRequest, NextResponse } from "next/server";
import OpenAI from "openai";
import cors from "~/utils/llm/cors";

// Read once at module load; a missing key degrades to 500 responses instead of crashing.
const apiKey = process.env.OPENAI_API_KEY;

if (!apiKey) {
  console.error(
    "Missing OPENAI_API_KEY environment variable. The embeddings API will not function.",
  );
}

// Client exists only when a key is configured; POST guards on apiKey before using it.
const openai = apiKey ? new OpenAI({ apiKey }) : null;

// Expected JSON body for POST requests to this route.
type RequestBody = {
  input: string | string[]; // text(s) to embed; must be non-empty
  model?: string; // defaults to "text-embedding-3-small"
  dimensions?: number; // only honored for text-embedding-3* models
  encoding_format?: "float" | "base64"; // defaults to "float"
};

// Fail the request if OpenAI has not answered within 30 seconds.
const OPENAI_REQUEST_TIMEOUT_MS = 30000;

export async function POST(req: NextRequest): Promise<NextResponse> {
let response: NextResponse;

if (!apiKey) {
response = NextResponse.json(
{
error: "Server configuration error.",
details: "Embeddings service is not configured.",
},
{ status: 500 },
);
return cors(req, response) as NextResponse;
}

try {
if (req.method === "OPTIONS") {
return cors(req, new NextResponse(null, { status: 204 })) as NextResponse;
}

const body: RequestBody = await req.json();
const {
input,
model = "text-embedding-3-small",
dimensions,
encoding_format = "float",
} = body;

if (!input || (Array.isArray(input) && input.length === 0)) {
response = NextResponse.json(
{ error: "Input text cannot be empty." },
{ status: 400 },
);
return cors(req, response) as NextResponse;
}

const options: OpenAI.EmbeddingCreateParams = {
model: model,
input: input,
encoding_format: encoding_format,
};

if (dimensions && model.startsWith("text-embedding-3")) {
options.dimensions = dimensions;
}

const embeddingsPromise = openai!.embeddings.create(options);
const timeoutPromise = new Promise((_, reject) =>
setTimeout(
() => reject(new Error("OpenAI API request timeout")),
OPENAI_REQUEST_TIMEOUT_MS,
),
);

const openAIResponse = await Promise.race([
embeddingsPromise,
timeoutPromise,
]);

response = NextResponse.json(openAIResponse, { status: 200 });
} catch (error: unknown) {
console.error("Error calling OpenAI Embeddings API:", error);
const errorMessage =
error instanceof Error ? error.message : "Unknown error";
response = NextResponse.json(
{
error: "Failed to generate embeddings.",
details: errorMessage,
},
{ status: 500 },
);
}

return cors(req, response) as NextResponse;
}

/** CORS preflight handler: replies 204 No Content with CORS headers applied. */
export async function OPTIONS(req: NextRequest): Promise<NextResponse> {
  const preflightResponse = new NextResponse(null, { status: 204 });
  return cors(req, preflightResponse) as NextResponse;
}
1 change: 1 addition & 0 deletions apps/website/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"@sindresorhus/slugify": "^2.2.1",
"gray-matter": "^4.0.3",
"next": "^15.0.3",
"openai": "^4.97.0",
"react": "19.0.0-rc-66855b96-20241106",
"react-dom": "19.0.0-rc-66855b96-20241106",
"rehype-parse": "^9.0.1",
Expand Down
Loading