Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions sdk/cosmosdb/cosmos/review/cosmos.api.md
Original file line number Diff line number Diff line change
Expand Up @@ -1191,6 +1191,7 @@ export interface FeedOptions extends SharedOptions {
continuation?: string;
continuationToken?: string;
continuationTokenLimitInKB?: number;
disableHybridSearchQueryPlanOptimization?: boolean;
disableNonStreamingOrderByQuery?: boolean;
enableQueryControl?: boolean;
enableScanInQuery?: boolean;
Expand Down Expand Up @@ -1321,6 +1322,7 @@ export enum HTTPMethod {
// @public
export interface HybridSearchQueryInfo {
componentQueryInfos: QueryInfo[];
componentWeights?: number[];
globalStatisticsQuery: string;
requiresGlobalStatistics: boolean;
skip: number;
Expand Down
4 changes: 1 addition & 3 deletions sdk/cosmosdb/cosmos/src/ClientContext.ts
Original file line number Diff line number Diff line change
Expand Up @@ -279,9 +279,7 @@ export class ClientContext {
request.headers[HttpHeaders.IsQueryPlan] = "True";
request.headers[HttpHeaders.QueryVersion] = "1.4";
request.headers[HttpHeaders.ContentType] = QueryJsonContentType;
request.headers[HttpHeaders.SupportedQueryFeatures] = supportedQueryFeaturesBuilder(
options.disableNonStreamingOrderByQuery,
);
request.headers[HttpHeaders.SupportedQueryFeatures] = supportedQueryFeaturesBuilder(options);

if (typeof query === "string") {
request.body = { query }; // Converts query text to query object.
Expand Down
2 changes: 2 additions & 0 deletions sdk/cosmosdb/cosmos/src/common/constants.ts
Original file line number Diff line number Diff line change
Expand Up @@ -523,6 +523,8 @@ export enum QueryFeature {
ListAndSetAggregate = "ListAndSetAggregate",
CountIf = "CountIf",
HybridSearch = "HybridSearch",
WeightedRankFusion = "WeightedRankFusion",
HybridSearchSkipOrderByRewrite = "HybridSearchSkipOrderByRewrite",
}

export enum SDKSupportedCapabilities {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,11 @@ export class HybridQueryExecutionContext implements ExecutionContext {
}

// Initialize an array to hold ranks for each document
const sortedHybridSearchResult = this.sortHybridSearchResultByRRFScore(this.hybridSearchResult);
const componentWeights = this.extractComponentWeights();
const sortedHybridSearchResult = this.sortHybridSearchResultByRRFScore(
this.hybridSearchResult,
componentWeights,
);
// store the result to buffer
// add only data from the sortedHybridSearchResult in the buffer
sortedHybridSearchResult.forEach((item) => this.buffer.push(item.data));
Expand Down Expand Up @@ -318,6 +322,7 @@ export class HybridQueryExecutionContext implements ExecutionContext {

private sortHybridSearchResultByRRFScore(
hybridSearchResult: HybridSearchQueryResult[],
componentWeights: ComponentWeight[],
): HybridSearchQueryResult[] {
if (hybridSearchResult.length === 0) {
return [];
Expand All @@ -329,7 +334,9 @@ export class HybridQueryExecutionContext implements ExecutionContext {
// Compute ranks for each component score
for (let i = 0; i < hybridSearchResult[0].componentScores.length; i++) {
// Sort based on the i-th component score
hybridSearchResult.sort((a, b) => b.componentScores[i] - a.componentScores[i]);
hybridSearchResult.sort((a, b) =>
componentWeights[i].comparator(a.componentScores[i], b.componentScores[i]),
);

// Assign ranks
let rank = 1;
Expand All @@ -338,7 +345,7 @@ export class HybridQueryExecutionContext implements ExecutionContext {
j > 0 &&
hybridSearchResult[j].componentScores[i] !== hybridSearchResult[j - 1].componentScores[i]
) {
rank = j + 1;
++rank;
}
const rankIndex = ranksArray.findIndex(
(rankItem) => rankItem.rid === hybridSearchResult[j].rid,
Expand All @@ -347,20 +354,14 @@ export class HybridQueryExecutionContext implements ExecutionContext {
}
}

// Function to compute RRF score
const computeRRFScore = (ranks: number[], k: number): number => {
return ranks.reduce((acc, rank) => acc + 1 / (k + rank), 0);
};

// Compute RRF scores and sort based on them
const rrfScores = ranksArray.map((item) => ({
rid: item.rid,
rrfScore: computeRRFScore(item.ranks, this.RRF_CONSTANT),
rrfScore: this.computeRRFScore(item.ranks, this.RRF_CONSTANT, componentWeights),
}));

// Sort based on RRF scores
rrfScores.sort((a, b) => b.rrfScore - a.rrfScore);

// Map sorted RRF scores back to hybridSearchResult
const sortedHybridSearchResult = rrfScores.map((scoreItem) =>
hybridSearchResult.find((item) => item.rid === scoreItem.rid),
Expand Down Expand Up @@ -455,8 +456,14 @@ export class HybridQueryExecutionContext implements ExecutionContext {
globalStats: GlobalStatistics,
): QueryInfo[] {
return componentQueryInfos.map((queryInfo) => {
if (!queryInfo.hasNonStreamingOrderBy) {
throw new Error("The component query must have a non-streaming order by clause.");
let rewrittenOrderByExpressions = queryInfo.orderByExpressions;
if (queryInfo.orderBy && queryInfo.orderBy.length > 0) {
if (!queryInfo.hasNonStreamingOrderBy) {
throw new Error("The component query must have a non-streaming order by clause.");
}
rewrittenOrderByExpressions = queryInfo.orderByExpressions.map((expr) =>
this.replacePlaceholdersWorkaroud(expr, globalStats, componentQueryInfos.length),
);
}
return {
...queryInfo,
Expand All @@ -465,9 +472,7 @@ export class HybridQueryExecutionContext implements ExecutionContext {
globalStats,
componentQueryInfos.length,
),
orderByExpressions: queryInfo.orderByExpressions.map((expr) =>
this.replacePlaceholdersWorkaroud(expr, globalStats, componentQueryInfos.length),
),
orderByExpressions: rewrittenOrderByExpressions,
};
});
}
Expand Down Expand Up @@ -531,4 +536,62 @@ export class HybridQueryExecutionContext implements ExecutionContext {
}
return query;
}

private computeRRFScore = (
ranks: number[],
k: number,
componentWeights: ComponentWeight[],
): number => {
if (ranks.length !== componentWeights.length) {
throw new Error("Ranks and component weights length mismatch");
}
let rrfScore = 0;
for (let i = 0; i < ranks.length; i++) {
const rank = ranks[i];
const weight = componentWeights[i].weight;
rrfScore += weight * (1 / (k + rank));
}
return rrfScore;
};

private extractComponentWeights(): ComponentWeight[] {
const hybridSearchQueryInfo = this.partitionedQueryExecutionInfo.hybridSearchQueryInfo;
const useDefaultComponentWeight =
!hybridSearchQueryInfo.componentWeights ||
hybridSearchQueryInfo.componentWeights.length === 0;

const result: {
weight: number;
comparator: (x: number, y: number) => number;
}[] = [];

for (let index = 0; index < hybridSearchQueryInfo.componentQueryInfos.length; ++index) {
const queryInfo = hybridSearchQueryInfo.componentQueryInfos[index];

if (queryInfo.orderBy && queryInfo.orderBy.length > 0) {
if (!queryInfo.hasNonStreamingOrderBy) {
throw new Error("The component query should have a non streaming order by");
}

if (!queryInfo.orderByExpressions || queryInfo.orderByExpressions.length !== 1) {
throw new Error("The component query should have exactly one order by expression");
}
}
const componentWeight = useDefaultComponentWeight
? 1
: hybridSearchQueryInfo.componentWeights[index];
const hasOrderBy = queryInfo.orderBy && queryInfo.orderBy.length > 0;
const sortOrder = hasOrderBy && queryInfo.orderBy[0].includes("Ascending") ? 1 : -1;
result.push({
weight: componentWeight,
comparator: (x: number, y: number) => sortOrder * (x - y),
});
}
return result;
}
}

export interface ComponentWeight {
weight: number;
comparator: (x: number, y: number) => number;
}
4 changes: 4 additions & 0 deletions sdk/cosmosdb/cosmos/src/request/ErrorResponse.ts
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,10 @@ export interface HybridSearchQueryInfo {
* Whether the query requires global statistics
*/
requiresGlobalStatistics: boolean;
/**
* Represents the weights for each component in a hybrid search query.
*/
componentWeights?: number[];
}

export type GroupByExpressions = string[];
Expand Down
6 changes: 5 additions & 1 deletion sdk/cosmosdb/cosmos/src/request/FeedOptions.ts
Original file line number Diff line number Diff line change
Expand Up @@ -128,13 +128,17 @@ export interface FeedOptions extends SharedOptions {
* Default: false; When set to true, it allows queries to bypass the default behavior that blocks nonStreaming queries without top or limit clauses.
*/
allowUnboundedNonStreamingQueries?: boolean;

/**
* Controls query execution behavior.
* Default: false. If set to false, the query will retry until results are ready and `maxItemCount` is reached, which can take time for large partitions with relatively small data.
* If set to true, scans partitions up to `maxDegreeOfParallelism`, adds results to the buffer, and returns what is available. If results are not ready, it returns an empty response.
*/
enableQueryControl?: boolean;
/**
* Default: false. If set to true, it disables the hybrid search query plan optimization.
* This optimization is enabled by default and is used to improve the performance of hybrid search queries.
*/
disableHybridSearchQueryPlanOptimization?: boolean;
/**
* @internal
* rid of the container.
Expand Down
19 changes: 13 additions & 6 deletions sdk/cosmosdb/cosmos/src/request/hybridSearchQueryResult.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,20 +29,27 @@ export class HybridSearchQueryResult {
}

const outerPayload = document[FieldNames.Payload];
let componentScores: number[];
let data: Record<string, unknown>;

if (!outerPayload || typeof outerPayload !== "object") {
throw new Error(`${FieldNames.Payload} must exist.`);
}

const innerPayload = outerPayload[FieldNames.Payload];
if (!innerPayload || typeof innerPayload !== "object") {
throw new Error(`${FieldNames.Payload} must exist nested within the outer payload field.`);
}

const componentScores = outerPayload[FieldNames.ComponentScores];
if (innerPayload && typeof innerPayload === "object") {
// older format without query plan optimization
componentScores = outerPayload[FieldNames.ComponentScores];
data = innerPayload;
} else {
// newer format with the optimization
componentScores = document[FieldNames.ComponentScores];
data = outerPayload;
}
if (!Array.isArray(componentScores)) {
throw new Error(`${FieldNames.ComponentScores} must exist.`);
}

return new HybridSearchQueryResult(rid, componentScores, innerPayload);
return new HybridSearchQueryResult(rid, componentScores, data);
}
}
18 changes: 11 additions & 7 deletions sdk/cosmosdb/cosmos/src/utils/supportedQueryFeaturesBuilder.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,17 @@
// Licensed under the MIT License.

import { QueryFeature } from "../common/index.js";
import type { FeedOptions } from "../request/FeedOptions.js";

export function supportedQueryFeaturesBuilder(disableNonStreamingOrderByQuery?: boolean): string {
if (disableNonStreamingOrderByQuery) {
return Object.keys(QueryFeature)
.filter((k) => k !== QueryFeature.NonStreamingOrderBy)
.join(", ");
} else {
return Object.keys(QueryFeature).join(", ");
export function supportedQueryFeaturesBuilder(options: FeedOptions): string {
const allFeatures = Object.keys(QueryFeature) as QueryFeature[];
const exclude: QueryFeature[] = [];

if (options.disableNonStreamingOrderByQuery) {
exclude.push(QueryFeature.NonStreamingOrderBy);
}
if (options.disableHybridSearchQueryPlanOptimization) {
exclude.push(QueryFeature.HybridSearchSkipOrderByRewrite);
}
return allFeatures.filter((feature) => !exclude.includes(feature)).join(",");
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import {
DiagnosticNodeType,
} from "../../../src/index.js";
import type { ClientContext, FeedOptions, QueryInfo } from "../../../src/index.js";
import type { ComponentWeight } from "../../../src/queryExecutionContext/hybridQueryExecutionContext.js";
import {
HybridQueryExecutionContext,
HybridQueryExecutionContextBaseStates,
Expand Down Expand Up @@ -210,7 +211,12 @@ describe("hybridQueryExecutionContext", () => {
];
const expectedSortedRids = ["3", "2", "1"];

const result = context["sortHybridSearchResultByRRFScore"](input);
const comparator = (x: number, y: number): number => -1 * (x - y);
const componentWeights = [
{ weight: 1, comparator },
{ weight: 1, comparator },
];
const result = context["sortHybridSearchResultByRRFScore"](input, componentWeights);
const resultRids = result.map((res) => res.rid);
// Assert that the result rids are equal to the expected sorted rids
assert.deepStrictEqual(resultRids, expectedSortedRids);
Expand All @@ -222,9 +228,11 @@ describe("hybridQueryExecutionContext", () => {
{ rid: "2", componentScores: [20], data: {}, score: 0, ranks: [] },
{ rid: "3", componentScores: [25], data: {}, score: 0, ranks: [] },
];
const comparator = (x: number, y: number): number => -1 * (x - y);
const componentWeights = [{ weight: 1, comparator }];
const expectedSortedRids = ["1", "3", "2"];

const result = context["sortHybridSearchResultByRRFScore"](input);
const result = context["sortHybridSearchResultByRRFScore"](input, componentWeights);
const resultRids = result.map((res) => res.rid);
// Assert that the result rids are equal to the expected sorted rids
assert.deepStrictEqual(resultRids, expectedSortedRids);
Expand All @@ -236,15 +244,22 @@ describe("hybridQueryExecutionContext", () => {
];
const expectedSortedRids = ["1"];

const result = context["sortHybridSearchResultByRRFScore"](input);
const comparator = (x: number, y: number): number => -1 * (x - y);
const componentWeights = [
{ weight: 1, comparator },
{ weight: 1, comparator },
{ weight: 1, comparator },
];
const result = context["sortHybridSearchResultByRRFScore"](input, componentWeights);
const resultRids = result.map((res) => res.rid);
// Assert that the result rids are equal to the expected sorted rids
assert.deepStrictEqual(resultRids, expectedSortedRids);
});

it("sortHybridSearchResultByRRFScore method should handle empty HybridSearchQueryResult array", async () => {
const input: HybridSearchQueryResult[] = [];
const result = context["sortHybridSearchResultByRRFScore"](input);
const componentWeights: ComponentWeight[] = [];
const result = context["sortHybridSearchResultByRRFScore"](input, componentWeights);
assert.deepStrictEqual(input, result);
});
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,43 @@ import { describe, it, assert } from "vitest";
describe("validate supportedQueryFeaturesBuilder", () => {
it("should contain nonStreamingOrderBy feature", () => {
const feedOptions: FeedOptions = {};
const res = supportedQueryFeaturesBuilder(feedOptions.disableNonStreamingOrderByQuery);
const res = supportedQueryFeaturesBuilder(feedOptions);
assert.equal(res.includes("NonStreamingOrderBy"), true);
});

it("should contain nonStreamingOrderBy feature", () => {
const feedOptions: FeedOptions = { disableNonStreamingOrderByQuery: false };
const res = supportedQueryFeaturesBuilder(feedOptions.disableNonStreamingOrderByQuery);
const res = supportedQueryFeaturesBuilder(feedOptions);
assert.equal(res.includes("NonStreamingOrderBy"), true);
});

it("should contain nonStreamingOrderBy feature", () => {
it("should not contain nonStreamingOrderBy feature", () => {
const feedOptions: FeedOptions = { disableNonStreamingOrderByQuery: true };
const res = supportedQueryFeaturesBuilder(feedOptions.disableNonStreamingOrderByQuery);
const res = supportedQueryFeaturesBuilder(feedOptions);
assert.equal(res.includes("NonStreamingOrderBy"), false);
});
it("should contain hybridSearchSkipOrderByRewrite feature", () => {
const feedOptions: FeedOptions = {};
const res = supportedQueryFeaturesBuilder(feedOptions);
assert.equal(res.includes("HybridSearchSkipOrderByRewrite"), true);
});
it("should contain hybridSearchSkipOrderByRewrite feature", () => {
const feedOptions: FeedOptions = { disableHybridSearchQueryPlanOptimization: false };
const res = supportedQueryFeaturesBuilder(feedOptions);
assert.equal(res.includes("HybridSearchSkipOrderByRewrite"), true);
});
it("should not contain hybridSearchSkipOrderByRewrite feature", () => {
const feedOptions: FeedOptions = { disableHybridSearchQueryPlanOptimization: true };
const res = supportedQueryFeaturesBuilder(feedOptions);
assert.equal(res.includes("HybridSearchSkipOrderByRewrite"), false);
});
it("should not contain nonStreamingOrderBy and hybridSearchSkipOrderByRewrite features", () => {
const feedOptions: FeedOptions = {
disableNonStreamingOrderByQuery: true,
disableHybridSearchQueryPlanOptimization: true,
};
const res = supportedQueryFeaturesBuilder(feedOptions);
assert.equal(res.includes("NonStreamingOrderBy"), false);
assert.equal(res.includes("HybridSearchSkipOrderByRewrite"), false);
});
});
Loading