Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Add support for weighted RRF
  • Loading branch information
amanrao23 committed May 5, 2025
commit 25ab6c14656f0a0ea49d1264a4c2f8c7e6578be5
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,11 @@ export class HybridQueryExecutionContext implements ExecutionContext {
}

// Initialize an array to hold ranks for each document
const sortedHybridSearchResult = this.sortHybridSearchResultByRRFScore(this.hybridSearchResult);
const componentWeights = this.extractComponentWeights();
const sortedHybridSearchResult = this.sortHybridSearchResultByRRFScore(
this.hybridSearchResult,
componentWeights,
);
// store the result to buffer
// add only data from the sortedHybridSearchResult in the buffer
sortedHybridSearchResult.forEach((item) => this.buffer.push(item.data));
Expand Down Expand Up @@ -318,6 +322,7 @@ export class HybridQueryExecutionContext implements ExecutionContext {

private sortHybridSearchResultByRRFScore(
hybridSearchResult: HybridSearchQueryResult[],
componentWeights: ComponentWeight[],
): HybridSearchQueryResult[] {
if (hybridSearchResult.length === 0) {
return [];
Expand All @@ -329,7 +334,9 @@ export class HybridQueryExecutionContext implements ExecutionContext {
// Compute ranks for each component score
for (let i = 0; i < hybridSearchResult[0].componentScores.length; i++) {
// Sort based on the i-th component score
hybridSearchResult.sort((a, b) => b.componentScores[i] - a.componentScores[i]);
hybridSearchResult.sort((a, b) =>
componentWeights[i].comparator(a.componentScores[i], b.componentScores[i]),
);

// Assign ranks
let rank = 1;
Expand All @@ -338,7 +345,7 @@ export class HybridQueryExecutionContext implements ExecutionContext {
j > 0 &&
hybridSearchResult[j].componentScores[i] !== hybridSearchResult[j - 1].componentScores[i]
) {
rank = j + 1;
++rank;
}
const rankIndex = ranksArray.findIndex(
(rankItem) => rankItem.rid === hybridSearchResult[j].rid,
Expand All @@ -347,20 +354,14 @@ export class HybridQueryExecutionContext implements ExecutionContext {
}
}

// Function to compute RRF score
const computeRRFScore = (ranks: number[], k: number): number => {
return ranks.reduce((acc, rank) => acc + 1 / (k + rank), 0);
};

// Compute RRF scores and sort based on them
const rrfScores = ranksArray.map((item) => ({
rid: item.rid,
rrfScore: computeRRFScore(item.ranks, this.RRF_CONSTANT),
rrfScore: this.computeRRFScore(item.ranks, this.RRF_CONSTANT, componentWeights),
}));

// Sort based on RRF scores
rrfScores.sort((a, b) => b.rrfScore - a.rrfScore);

// Map sorted RRF scores back to hybridSearchResult
const sortedHybridSearchResult = rrfScores.map((scoreItem) =>
hybridSearchResult.find((item) => item.rid === scoreItem.rid),
Expand Down Expand Up @@ -455,8 +456,14 @@ export class HybridQueryExecutionContext implements ExecutionContext {
globalStats: GlobalStatistics,
): QueryInfo[] {
return componentQueryInfos.map((queryInfo) => {
if (!queryInfo.hasNonStreamingOrderBy) {
throw new Error("The component query must have a non-streaming order by clause.");
let rewrittenOrderByExpressions = queryInfo.orderByExpressions;
if (queryInfo.orderBy && queryInfo.orderBy.length > 0) {
if (!queryInfo.hasNonStreamingOrderBy) {
throw new Error("The component query must have a non-streaming order by clause.");
}
rewrittenOrderByExpressions = queryInfo.orderByExpressions.map((expr) =>
this.replacePlaceholdersWorkaroud(expr, globalStats, componentQueryInfos.length),
);
}
return {
...queryInfo,
Expand All @@ -465,9 +472,7 @@ export class HybridQueryExecutionContext implements ExecutionContext {
globalStats,
componentQueryInfos.length,
),
orderByExpressions: queryInfo.orderByExpressions.map((expr) =>
this.replacePlaceholdersWorkaroud(expr, globalStats, componentQueryInfos.length),
),
orderByExpressions: rewrittenOrderByExpressions,
};
});
}
Expand Down Expand Up @@ -531,4 +536,62 @@ export class HybridQueryExecutionContext implements ExecutionContext {
}
return query;
}

private computeRRFScore = (
ranks: number[],
k: number,
componentWeights: ComponentWeight[],
): number => {
if (ranks.length !== componentWeights.length) {
throw new Error("Ranks and component weights length mismatch");
}
let rrfScore = 0;
for (let i = 0; i < ranks.length; i++) {
const rank = ranks[i];
const weight = componentWeights[i].weight;
rrfScore += weight * (1 / (k + rank));
}
return rrfScore;
};

private extractComponentWeights(): ComponentWeight[] {
const hybridSearchQueryInfo = this.partitionedQueryExecutionInfo.hybridSearchQueryInfo;
const useDefaultComponentWeight =
!hybridSearchQueryInfo.componentWeights ||
hybridSearchQueryInfo.componentWeights.length === 0;

const result: {
weight: number;
comparator: (x: number, y: number) => number;
}[] = [];

for (let index = 0; index < hybridSearchQueryInfo.componentQueryInfos.length; ++index) {
const queryInfo = hybridSearchQueryInfo.componentQueryInfos[index];

if (queryInfo.orderBy && queryInfo.orderBy.length > 0) {
if (!queryInfo.hasNonStreamingOrderBy) {
throw new Error("The component query should have a non streaming order by");
}

if (!queryInfo.orderByExpressions || queryInfo.orderByExpressions.length !== 1) {
throw new Error("The component query should have exactly one order by expression");
}
}
const componentWeight = useDefaultComponentWeight
? 1
: hybridSearchQueryInfo.componentWeights[index];
const hasOrderBy = queryInfo.orderBy && queryInfo.orderBy.length > 0;
const sortOrder = hasOrderBy && queryInfo.orderBy[0].includes("Ascending") ? 1 : -1;
result.push({
weight: componentWeight,
comparator: (x: number, y: number) => sortOrder * (x - y),
});
}
return result;
}
}

export interface ComponentWeight {
weight: number;
comparator: (x: number, y: number) => number;
}