diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/rx/VectorIndexTest.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/rx/VectorIndexTest.java new file mode 100644 index 000000000000..74f13be9d4f4 --- /dev/null +++ b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/rx/VectorIndexTest.java @@ -0,0 +1,345 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package com.azure.cosmos.rx; + +import com.azure.cosmos.ConsistencyLevel; +import com.azure.cosmos.CosmosAsyncClient; +import com.azure.cosmos.CosmosAsyncContainer; +import com.azure.cosmos.CosmosAsyncDatabase; +import com.azure.cosmos.CosmosClientBuilder; +import com.azure.cosmos.CosmosDatabaseForTest; +import com.azure.cosmos.CosmosException; +import com.azure.cosmos.DirectConnectionConfig; +import com.azure.cosmos.implementation.TestConfigurations; +import com.azure.cosmos.implementation.Utils; +import com.azure.cosmos.implementation.guava25.collect.ImmutableList; +import com.azure.cosmos.models.CosmosContainerProperties; +import com.azure.cosmos.models.CosmosVectorDataType; +import com.azure.cosmos.models.CosmosVectorDistanceFunction; +import com.azure.cosmos.models.CosmosVectorEmbedding; +import com.azure.cosmos.models.CosmosVectorEmbeddingPolicy; +import com.azure.cosmos.models.ExcludedPath; +import com.azure.cosmos.models.IncludedPath; +import com.azure.cosmos.models.IndexingMode; +import com.azure.cosmos.models.IndexingPolicy; +import com.azure.cosmos.models.PartitionKeyDefinition; +import com.azure.cosmos.models.CosmosVectorIndexSpec; +import com.azure.cosmos.models.CosmosVectorIndexType; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.testng.annotations.AfterClass; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Ignore; +import 
org.testng.annotations.Test; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.UUID; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.fail; + +@Ignore("TODO: Ignore these test cases until the public emulator with vector indexes is released.") +public class VectorIndexTest extends TestSuiteBase { + protected static final int TIMEOUT = 30000; + protected static final int SETUP_TIMEOUT = 20000; + protected static final int SHUTDOWN_TIMEOUT = 20000; + + protected static Logger logger = LoggerFactory.getLogger(VectorIndexTest.class.getSimpleName()); + private final ObjectMapper simpleObjectMapper = Utils.getSimpleObjectMapper(); + private final String databaseId = CosmosDatabaseForTest.generateId(); + private CosmosAsyncClient client; + private CosmosAsyncDatabase database; + + @BeforeClass(groups = {"emulator"}, timeOut = SETUP_TIMEOUT) + public void before_VectorIndexTest() { + // set up the client + client = new CosmosClientBuilder() + .endpoint(TestConfigurations.HOST) + .key(TestConfigurations.MASTER_KEY) + .directMode(DirectConnectionConfig.getDefaultConfig()) + .consistencyLevel(ConsistencyLevel.SESSION) + .contentResponseOnWriteEnabled(true) + .buildAsyncClient(); + + database = createDatabase(client, databaseId); + } + + @AfterClass(groups = {"emulator"}, timeOut = SHUTDOWN_TIMEOUT, alwaysRun = true) + public void afterClass() { + safeDeleteDatabase(database); + safeClose(client); + } + + @Test(groups = {"emulator"}, timeOut = TIMEOUT) + public void shouldCreateVectorEmbeddingPolicy() { + PartitionKeyDefinition partitionKeyDef = new PartitionKeyDefinition(); + ArrayList paths = new ArrayList(); + paths.add("/mypk"); + partitionKeyDef.setPaths(paths); + + CosmosContainerProperties collectionDefinition = new CosmosContainerProperties(UUID.randomUUID().toString(), partitionKeyDef); + + IndexingPolicy indexingPolicy = new
IndexingPolicy(); + indexingPolicy.setIndexingMode(IndexingMode.CONSISTENT); + ExcludedPath excludedPath = new ExcludedPath("/*"); + indexingPolicy.setExcludedPaths(Collections.singletonList(excludedPath)); + + IncludedPath includedPath1 = new IncludedPath("/name/?"); + IncludedPath includedPath2 = new IncludedPath("/description/?"); + indexingPolicy.setIncludedPaths(ImmutableList.of(includedPath1, includedPath2)); + + indexingPolicy.setVectorIndexes(populateVectorIndexes()); + + CosmosVectorEmbeddingPolicy cosmosVectorEmbeddingPolicy = new CosmosVectorEmbeddingPolicy(); + cosmosVectorEmbeddingPolicy.setCosmosVectorEmbeddings(populateEmbeddings()); + + collectionDefinition.setIndexingPolicy(indexingPolicy); + collectionDefinition.setVectorEmbeddingPolicy(cosmosVectorEmbeddingPolicy); + + database.createContainer(collectionDefinition).block(); + CosmosAsyncContainer createdCollection = database.getContainer(collectionDefinition.getId()); + CosmosContainerProperties collectionProperties = createdCollection.read().block().getProperties(); + validateCollectionProperties(collectionDefinition, collectionProperties); + } + + @Test(groups = {"emulator"}, timeOut = TIMEOUT) + public void shouldFailOnEmptyVectorEmbeddingPolicy() { + PartitionKeyDefinition partitionKeyDef = new PartitionKeyDefinition(); + ArrayList paths = new ArrayList(); + paths.add("/mypk"); + partitionKeyDef.setPaths(paths); + + CosmosContainerProperties collectionDefinition = new CosmosContainerProperties(UUID.randomUUID().toString(), partitionKeyDef); + + IndexingPolicy indexingPolicy = new IndexingPolicy(); + indexingPolicy.setIndexingMode(IndexingMode.CONSISTENT); + ExcludedPath excludedPath = new ExcludedPath("/*"); + indexingPolicy.setExcludedPaths(Collections.singletonList(excludedPath)); + + IncludedPath includedPath1 = new IncludedPath("/name/?"); + IncludedPath includedPath2 = new IncludedPath("/description/?"); + indexingPolicy.setIncludedPaths(ImmutableList.of(includedPath1, includedPath2)); + 
+ CosmosVectorIndexSpec cosmosVectorIndexSpec = new CosmosVectorIndexSpec(); + cosmosVectorIndexSpec.setPath("/vector1"); + cosmosVectorIndexSpec.setType(CosmosVectorIndexType.FLAT.toString()); + indexingPolicy.setVectorIndexes(ImmutableList.of(cosmosVectorIndexSpec)); + + collectionDefinition.setIndexingPolicy(indexingPolicy); + + try { + database.createContainer(collectionDefinition).block(); + fail("Container creation will fail as no vector embedding policy is being passed"); + } catch (CosmosException ex) { + assertThat(ex.getStatusCode()).isEqualTo(400); + assertThat(ex.getMessage()).contains("vector1 not matching in Embedding's path"); + } + } + + @Test(groups = {"emulator"}, timeOut = TIMEOUT) + public void shouldFailOnWrongVectorIndex() { + PartitionKeyDefinition partitionKeyDef = new PartitionKeyDefinition(); + ArrayList paths = new ArrayList(); + paths.add("/mypk"); + partitionKeyDef.setPaths(paths); + + CosmosContainerProperties collectionDefinition = new CosmosContainerProperties(UUID.randomUUID().toString(), partitionKeyDef); + + IndexingPolicy indexingPolicy = new IndexingPolicy(); + indexingPolicy.setIndexingMode(IndexingMode.CONSISTENT); + ExcludedPath excludedPath = new ExcludedPath("/*"); + indexingPolicy.setExcludedPaths(Collections.singletonList(excludedPath)); + + IncludedPath includedPath1 = new IncludedPath("/name/?"); + IncludedPath includedPath2 = new IncludedPath("/description/?"); + indexingPolicy.setIncludedPaths(ImmutableList.of(includedPath1, includedPath2)); + + CosmosVectorIndexSpec cosmosVectorIndexSpec = new CosmosVectorIndexSpec(); + cosmosVectorIndexSpec.setPath("/vector1"); + cosmosVectorIndexSpec.setType("NonFlat"); + indexingPolicy.setVectorIndexes(ImmutableList.of(cosmosVectorIndexSpec)); + collectionDefinition.setIndexingPolicy(indexingPolicy); + + CosmosVectorEmbedding embedding = new CosmosVectorEmbedding(); + embedding.setPath("/vector1"); + embedding.setDataType(CosmosVectorDataType.FLOAT32); + 
embedding.setDimensions(3L); + embedding.setDistanceFunction(CosmosVectorDistanceFunction.COSINE); + CosmosVectorEmbeddingPolicy cosmosVectorEmbeddingPolicy = new CosmosVectorEmbeddingPolicy(); + cosmosVectorEmbeddingPolicy.setCosmosVectorEmbeddings(ImmutableList.of(embedding)); + collectionDefinition.setVectorEmbeddingPolicy(cosmosVectorEmbeddingPolicy); + + try { + database.createContainer(collectionDefinition).block(); + fail("Container creation will fail as wrong vector index type is being passed"); + } catch (CosmosException ex) { + assertThat(ex.getStatusCode()).isEqualTo(400); + assertThat(ex.getMessage()).contains("NonFlat is invalid, Valid types are 'flat' or 'quantizedFlat'"); + } + } + + @Test(groups = {"emulator"}, timeOut = TIMEOUT) + public void shouldCreateVectorIndexSimilarPathDifferentVectorType() { + PartitionKeyDefinition partitionKeyDef = new PartitionKeyDefinition(); + ArrayList paths = new ArrayList(); + paths.add("/mypk"); + partitionKeyDef.setPaths(paths); + + CosmosContainerProperties collectionDefinition = new CosmosContainerProperties(UUID.randomUUID().toString(), partitionKeyDef); + + IndexingPolicy indexingPolicy = new IndexingPolicy(); + indexingPolicy.setIndexingMode(IndexingMode.CONSISTENT); + ExcludedPath excludedPath = new ExcludedPath("/*"); + indexingPolicy.setExcludedPaths(Collections.singletonList(excludedPath)); + + IncludedPath includedPath1 = new IncludedPath("/name/?"); + IncludedPath includedPath2 = new IncludedPath("/description/?"); + indexingPolicy.setIncludedPaths(ImmutableList.of(includedPath1, includedPath2)); + + List vectorIndexes = populateVectorIndexes(); + vectorIndexes.get(2).setPath("/vector2"); + indexingPolicy.setVectorIndexes(vectorIndexes); + + List embeddings = populateEmbeddings(); + embeddings.get(2).setPath("/vector2"); + CosmosVectorEmbeddingPolicy cosmosVectorEmbeddingPolicy = new CosmosVectorEmbeddingPolicy(); + cosmosVectorEmbeddingPolicy.setCosmosVectorEmbeddings(embeddings); + + 
collectionDefinition.setIndexingPolicy(indexingPolicy); + collectionDefinition.setVectorEmbeddingPolicy(cosmosVectorEmbeddingPolicy); + + database.createContainer(collectionDefinition).block(); + CosmosAsyncContainer createdCollection = database.getContainer(collectionDefinition.getId()); + CosmosContainerProperties collectionProperties = createdCollection.read().block().getProperties(); + validateCollectionProperties(collectionDefinition, collectionProperties); + } + + @Test(groups = {"unit"}, timeOut = TIMEOUT) + public void shouldFailOnWrongVectorEmbeddingPolicy() { + CosmosVectorEmbedding embedding = new CosmosVectorEmbedding(); + try { + + embedding.setDataType(null); + fail("Embedding creation failed because cosmosVectorDataType argument is empty"); + } catch (NullPointerException ex) { + assertThat(ex.getMessage()).isEqualTo("cosmosVectorDataType cannot be null"); + } + + try { + embedding.setDistanceFunction(null); + fail("Embedding creation failed because cosmosVectorDistanceFunction argument is empty"); + } catch (NullPointerException ex) { + assertThat(ex.getMessage()).isEqualTo("cosmosVectorDistanceFunction cannot be empty"); + } + + try { + embedding.setDimensions(null); + fail("Embedding creation failed because dimensions argument is empty"); + } catch (NullPointerException ex) { + assertThat(ex.getMessage()).isEqualTo("dimensions cannot be null"); + } + + try { + embedding.setDimensions(-1L); + fail("Vector Embedding policy creation will fail for negative dimensions being passed"); + } catch (IllegalArgumentException ex) { + assertThat(ex.getMessage()).isEqualTo("Dimensions for the embedding has to be a long value greater than 0 for the vector embedding policy"); + } + } + + @Test(groups = {"unit"}, timeOut = TIMEOUT) + public void shouldValidateVectorEmbeddingPolicySerializationAndDeserialization() throws JsonProcessingException { + IndexingPolicy indexingPolicy = new IndexingPolicy(); + indexingPolicy.setVectorIndexes(populateVectorIndexes()); +
+ CosmosVectorEmbeddingPolicy cosmosVectorEmbeddingPolicy = new CosmosVectorEmbeddingPolicy(); + cosmosVectorEmbeddingPolicy.setCosmosVectorEmbeddings(populateEmbeddings()); + String vectorEmbeddingPolicyJson = getVectorEmbeddingPolicyAsString(); + String expectedVectorEmbeddingPolicyJson = simpleObjectMapper.writeValueAsString(cosmosVectorEmbeddingPolicy); + assertThat(vectorEmbeddingPolicyJson).isEqualTo(expectedVectorEmbeddingPolicyJson); + + CosmosVectorEmbeddingPolicy expectedCosmosVectorEmbeddingPolicy = simpleObjectMapper.readValue(expectedVectorEmbeddingPolicyJson, CosmosVectorEmbeddingPolicy.class); + validateVectorEmbeddingPolicy(cosmosVectorEmbeddingPolicy, expectedCosmosVectorEmbeddingPolicy); + } + + private void validateCollectionProperties(CosmosContainerProperties collectionDefinition, CosmosContainerProperties collectionProperties) { + assertThat(collectionProperties.getVectorEmbeddingPolicy()).isNotNull(); + assertThat(collectionProperties.getVectorEmbeddingPolicy().getVectorEmbeddings()).isNotNull(); + validateVectorEmbeddingPolicy(collectionProperties.getVectorEmbeddingPolicy(), + collectionDefinition.getVectorEmbeddingPolicy()); + + assertThat(collectionProperties.getIndexingPolicy().getVectorIndexes()).isNotNull(); + validateVectorIndexes(collectionDefinition.getIndexingPolicy().getVectorIndexes(), collectionProperties.getIndexingPolicy().getVectorIndexes()); + } + + private void validateVectorEmbeddingPolicy(CosmosVectorEmbeddingPolicy actual, CosmosVectorEmbeddingPolicy expected) { + List actualEmbeddings = actual.getVectorEmbeddings(); + List expectedEmbeddings = expected.getVectorEmbeddings(); + assertThat(expectedEmbeddings).hasSameSizeAs(actualEmbeddings); + for (int i = 0; i < expectedEmbeddings.size(); i++) { + assertThat(expectedEmbeddings.get(i).getPath()).isEqualTo(actualEmbeddings.get(i).getPath()); + assertThat(expectedEmbeddings.get(i).getDataType()).isEqualTo(actualEmbeddings.get(i).getDataType()); + 
assertThat(expectedEmbeddings.get(i).getDimensions()).isEqualTo(actualEmbeddings.get(i).getDimensions()); + assertThat(expectedEmbeddings.get(i).getDistanceFunction()).isEqualTo(actualEmbeddings.get(i).getDistanceFunction()); + } + } + + private void validateVectorIndexes(List actual, List expected) { + assertThat(expected).hasSameSizeAs(actual); + for (int i = 0; i < expected.size(); i++) { + assertThat(expected.get(i).getPath()).isEqualTo(actual.get(i).getPath()); + assertThat(expected.get(i).getType()).isEqualTo(actual.get(i).getType()); + } + } + + private List populateVectorIndexes() { + CosmosVectorIndexSpec cosmosVectorIndexSpec1 = new CosmosVectorIndexSpec(); + cosmosVectorIndexSpec1.setPath("/vector1"); + cosmosVectorIndexSpec1.setType(CosmosVectorIndexType.FLAT.toString()); + + CosmosVectorIndexSpec cosmosVectorIndexSpec2 = new CosmosVectorIndexSpec(); + cosmosVectorIndexSpec2.setPath("/vector2"); + cosmosVectorIndexSpec2.setType(CosmosVectorIndexType.QUANTIZED_FLAT.toString()); + + CosmosVectorIndexSpec cosmosVectorIndexSpec3 = new CosmosVectorIndexSpec(); + cosmosVectorIndexSpec3.setPath("/vector3"); + cosmosVectorIndexSpec3.setType(CosmosVectorIndexType.DISK_ANN.toString()); + + return Arrays.asList(cosmosVectorIndexSpec1, cosmosVectorIndexSpec2, cosmosVectorIndexSpec3); + } + + private List populateEmbeddings() { + CosmosVectorEmbedding embedding1 = new CosmosVectorEmbedding(); + embedding1.setPath("/vector1"); + embedding1.setDataType(CosmosVectorDataType.INT8); + embedding1.setDimensions(3L); + embedding1.setDistanceFunction(CosmosVectorDistanceFunction.COSINE); + + CosmosVectorEmbedding embedding2 = new CosmosVectorEmbedding(); + embedding2.setPath("/vector2"); + embedding2.setDataType(CosmosVectorDataType.FLOAT32); + embedding2.setDimensions(3L); + embedding2.setDistanceFunction(CosmosVectorDistanceFunction.DOT_PRODUCT); + + CosmosVectorEmbedding embedding3 = new CosmosVectorEmbedding(); + embedding3.setPath("/vector3"); + 
embedding3.setDataType(CosmosVectorDataType.UINT8); + embedding3.setDimensions(3L); + embedding3.setDistanceFunction(CosmosVectorDistanceFunction.EUCLIDEAN); + return Arrays.asList(embedding1, embedding2, embedding3); + } + + private String getVectorEmbeddingPolicyAsString() { + return "{\"vectorEmbeddings\":[" + + "{\"path\":\"/vector1\",\"dataType\":\"int8\",\"dimensions\":3,\"distanceFunction\":\"cosine\"}," + + "{\"path\":\"/vector2\",\"dataType\":\"float32\",\"dimensions\":3,\"distanceFunction\":\"dotproduct\"}," + + "{\"path\":\"/vector3\",\"dataType\":\"uint8\",\"dimensions\":3,\"distanceFunction\":\"euclidean\"}" + + "]}"; + } +} diff --git a/sdk/cosmos/azure-cosmos/CHANGELOG.md b/sdk/cosmos/azure-cosmos/CHANGELOG.md index 7586e505b1e8..35353b94d100 100644 --- a/sdk/cosmos/azure-cosmos/CHANGELOG.md +++ b/sdk/cosmos/azure-cosmos/CHANGELOG.md @@ -3,8 +3,9 @@ ### 4.59.0 (2024-04-27) #### Features Added +* Added `CosmosVectorEmbeddingPolicy` in `CosmosContainerProperties` and `vectorIndexes` in `IndexingPolicy` to support vector search in CosmosDB - See [PR 39379](https://github.com/Azure/azure-sdk-for-java/pull/39379) * Added public APIs `getCustomItemSerializer` and `setCustomItemSerializer` to allow customers to specify custom payload transformations or serialization settings. - See [PR 38997](https://github.com/Azure/azure-sdk-for-java/pull/38997) and [PR 39933](https://github.com/Azure/azure-sdk-for-java/pull/39933) - + #### Other Changes * Load Blackbird or Afterburner into the ObjectMapper depending upon Java version and presence of modules in classpath. Make Afterburner and Blackbird optional maven dependencies.
See - [PR 39689](https://github.com/Azure/azure-sdk-for-java/pull/39689) diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/Constants.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/Constants.java index 8409d5b7ec23..f789963783ed 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/Constants.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/Constants.java @@ -120,6 +120,15 @@ public static final class Properties { public static final String SPATIAL_INDEXES = "spatialIndexes"; public static final String TYPES = "types"; + // Vector Embedding Policy + public static final String VECTOR_EMBEDDING_POLICY = "vectorEmbeddingPolicy"; + public static final String VECTOR_INDEXES = "vectorIndexes"; + public static final String VECTOR_EMBEDDINGS = "vectorEmbeddings"; + public static final String VECTOR_INDEX_TYPE = "type"; + public static final String VECTOR_DATA_TYPE = "dataType"; + public static final String VECTOR_DIMENSIONS = "dimensions"; + public static final String DISTANCE_FUNCTION = "distanceFunction"; + // Unique index. 
public static final String UNIQUE_KEY_POLICY = "uniqueKeyPolicy"; public static final String UNIQUE_KEYS = "uniqueKeys"; diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/DocumentCollection.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/DocumentCollection.java index 678bdb90378c..1930f2275a61 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/DocumentCollection.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/DocumentCollection.java @@ -6,10 +6,11 @@ import com.azure.cosmos.CosmosItemSerializer; import com.azure.cosmos.implementation.apachecommons.lang.StringUtils; import com.azure.cosmos.implementation.caches.SerializableWrapper; -import com.azure.cosmos.models.ClientEncryptionPolicy; import com.azure.cosmos.models.ChangeFeedPolicy; +import com.azure.cosmos.models.ClientEncryptionPolicy; import com.azure.cosmos.models.ComputedProperty; import com.azure.cosmos.models.ConflictResolutionPolicy; +import com.azure.cosmos.models.CosmosVectorEmbeddingPolicy; import com.azure.cosmos.models.IndexingPolicy; import com.azure.cosmos.models.ModelBridgeInternal; import com.azure.cosmos.models.PartitionKeyDefinition; @@ -24,6 +25,8 @@ import java.util.Collection; import java.util.Collections; +import static com.azure.cosmos.implementation.guava25.base.Preconditions.checkNotNull; + /** * Represents a document collection in the Azure Cosmos DB database service. A collection is a named logical container * for documents. @@ -40,6 +43,7 @@ public final class DocumentCollection extends Resource { private UniqueKeyPolicy uniqueKeyPolicy; private PartitionKeyDefinition partitionKeyDefinition; private ClientEncryptionPolicy clientEncryptionPolicyInternal; + private CosmosVectorEmbeddingPolicy cosmosVectorEmbeddingPolicy; /** * Constructor. 
@@ -410,6 +414,34 @@ public void setClientEncryptionPolicy(ClientEncryptionPolicy value) { this.set(Constants.Properties.CLIENT_ENCRYPTION_POLICY, value, CosmosItemSerializer.DEFAULT_SERIALIZER); } + /** + * Gets the Vector Embedding Policy containing paths for embeddings along with path-specific settings for the item + * used in performing vector search on the items in a collection in the Azure CosmosDB database service. + * + * @return the Vector Embedding Policy, or null when the property is not present on this collection. + */ + public CosmosVectorEmbeddingPolicy getVectorEmbeddingPolicy() { + if (this.cosmosVectorEmbeddingPolicy == null) { + if (super.has(Constants.Properties.VECTOR_EMBEDDING_POLICY)) { + this.cosmosVectorEmbeddingPolicy = super.getObject(Constants.Properties.VECTOR_EMBEDDING_POLICY, + CosmosVectorEmbeddingPolicy.class); + } + } + return this.cosmosVectorEmbeddingPolicy; + } + + /** + * Sets the Vector Embedding Policy containing paths for embeddings along with path-specific settings for the item + * used in performing vector search on the items in a collection in the Azure CosmosDB database service. + * + * @param value the Vector Embedding Policy; must not be null. + */ + public void setVectorEmbeddingPolicy(CosmosVectorEmbeddingPolicy value) { + checkNotNull(value, "cosmosVectorEmbeddingPolicy cannot be null"); + this.set(Constants.Properties.VECTOR_EMBEDDING_POLICY, value, CosmosItemSerializer.DEFAULT_SERIALIZER); + this.cosmosVectorEmbeddingPolicy = null; // invalidate the cached copy so the getter re-reads the newly stored value + } + public void populatePropertyBag() { super.populatePropertyBag(); if (this.indexingPolicy == null) { diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CompositePath.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CompositePath.java index a278e0aa8f50..8cfd99b46e37 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CompositePath.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CompositePath.java @@ -93,7 +93,7 @@ public CompositePathSortOrder getOrder() { } /** - * Gets the sort order for the composite path.
+ * Sets the sort order for the composite path. *

* For example if you want to run the query "SELECT * FROM c ORDER BY c.age asc, c.height desc", * then you need to make the order for "/age" "ascending" and the order for "/height" "descending". diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosContainerProperties.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosContainerProperties.java index 4fae5a797a70..0d357da0cc37 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosContainerProperties.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosContainerProperties.java @@ -347,6 +347,28 @@ public CosmosContainerProperties setClientEncryptionPolicy(ClientEncryptionPolic return this; } + /** + * Gets the Vector Embedding Policy containing paths for embeddings along with path-specific settings for the item + * used in performing vector search on the items in a collection in the Azure CosmosDB database service. + * + * @return the Vector Embedding Policy. + */ + public CosmosVectorEmbeddingPolicy getVectorEmbeddingPolicy() { + return this.documentCollection.getVectorEmbeddingPolicy(); + } + + /** + * Sets the Vector Embedding Policy containing paths for embeddings along with path-specific settings for the item + * used in performing vector search on the items in a collection in the Azure CosmosDB database service. + * + * @param value the Vector Embedding Policy. + * @return the CosmosContainerProperties. 
+ */ + public CosmosContainerProperties setVectorEmbeddingPolicy(CosmosVectorEmbeddingPolicy value) { + this.documentCollection.setVectorEmbeddingPolicy(value); + return this; + } + Resource getResource() { return this.documentCollection; } diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosVectorDataType.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosVectorDataType.java new file mode 100644 index 000000000000..1a0d42af17a4 --- /dev/null +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosVectorDataType.java @@ -0,0 +1,58 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package com.azure.cosmos.models; + +import com.fasterxml.jackson.annotation.JsonValue; + +import java.util.Arrays; + +/** + * Data types for the embeddings in Cosmos DB database service. + */ +public enum CosmosVectorDataType { + /** + * Represents a int8 data type. + */ + INT8("int8"), + + /** + * Represents a uint8 data type. + */ + UINT8("uint8"), + + /** + * Represents a float16 data type. + */ + FLOAT16("float16"), + + /** + * Represents a float32 data type. + */ + FLOAT32("float32"); + + private final String overWireValue; + + CosmosVectorDataType(String overWireValue) { + this.overWireValue = overWireValue; + } + + @JsonValue + @Override + public String toString() { + return this.overWireValue; + } + + /** + * Method to retrieve the enum constant by its overWireValue. 
+ * @param value the overWire value of the enum constant + * @return the matching CosmosVectorDataType + * @throws IllegalArgumentException if no matching enum constant is found + */ + public static CosmosVectorDataType fromString(String value) { + return Arrays.stream(CosmosVectorDataType.values()) + .filter(vectorDataType -> vectorDataType.toString().equalsIgnoreCase(value)) + .findFirst() + .orElseThrow(() -> new IllegalArgumentException("Invalid vector data type for the vector embedding policy.")); + } +} diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosVectorDistanceFunction.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosVectorDistanceFunction.java new file mode 100644 index 000000000000..60efd432ad7f --- /dev/null +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosVectorDistanceFunction.java @@ -0,0 +1,53 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package com.azure.cosmos.models; + +import com.fasterxml.jackson.annotation.JsonValue; + +import java.util.Arrays; + +/** + * Distance Function for the embeddings in the Cosmos DB database service. + */ +public enum CosmosVectorDistanceFunction { + /** + * Represents the euclidean distance function. + */ + EUCLIDEAN("euclidean"), + + /** + * Represents the cosine distance function. + */ + COSINE("cosine"), + + /** + * Represents the dot product distance function. + */ + DOT_PRODUCT("dotproduct"); + + private final String overWireValue; + + CosmosVectorDistanceFunction(String overWireValue) { + this.overWireValue = overWireValue; + } + + @JsonValue + @Override + public String toString() { + return this.overWireValue; + } + + /** + * Method to retrieve the enum constant by its overWireValue. 
+ * @param value the overWire value of the enum constant, matched case-insensitively + * @return the matching CosmosVectorDistanceFunction + * @throws IllegalArgumentException if no matching enum constant is found + */ + public static CosmosVectorDistanceFunction fromString(String value) { + return Arrays.stream(CosmosVectorDistanceFunction.values()) + .filter(vectorDistanceFunction -> vectorDistanceFunction.toString().equalsIgnoreCase(value)) + .findFirst() + .orElseThrow(() -> new IllegalArgumentException("Invalid distance function for the vector embedding policy.")); + } +} diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosVectorEmbedding.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosVectorEmbedding.java new file mode 100644 index 000000000000..94a519ef9d38 --- /dev/null +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosVectorEmbedding.java @@ -0,0 +1,128 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License.
+ +package com.azure.cosmos.models; + +import com.azure.cosmos.implementation.Constants; +import com.azure.cosmos.implementation.JsonSerializable; +import com.azure.cosmos.implementation.apachecommons.lang.StringUtils; +import com.fasterxml.jackson.annotation.JsonProperty; +import static com.azure.cosmos.implementation.guava25.base.Preconditions.checkNotNull; + +/** + * Embedding settings within {@link CosmosVectorEmbeddingPolicy} + */ +public final class CosmosVectorEmbedding { + @JsonProperty(Constants.Properties.PATH) + private String path; + @JsonProperty(Constants.Properties.VECTOR_DATA_TYPE) + private String dataType; + @JsonProperty(Constants.Properties.VECTOR_DIMENSIONS) + private Long dimensions; + @JsonProperty(Constants.Properties.DISTANCE_FUNCTION) + private String distanceFunction; + private JsonSerializable jsonSerializable; + + /** + * Constructor + */ + public CosmosVectorEmbedding() { + this.jsonSerializable = new JsonSerializable(); + } + + /** + * Gets the path for the cosmosVectorEmbedding. + * + * @return path + */ + public String getPath() { + return path; + } + + /** + * Sets the path for the cosmosVectorEmbedding. + * + * @param path the path for the cosmosVectorEmbedding; must be non-empty, start with '/' and contain no other '/' + * @return CosmosVectorEmbedding + */ + public CosmosVectorEmbedding setPath(String path) { + if (StringUtils.isEmpty(path)) { + throw new NullPointerException("embedding path is empty"); + } + + if (path.charAt(0) != '/' || path.lastIndexOf('/') != 0) { + throw new IllegalArgumentException("embedding path must start with '/' and must not contain any other '/': " + path); + } + + this.path = path; + return this; + } + + /** + * Gets the data type for the cosmosVectorEmbedding. + * + * @return dataType + */ + public CosmosVectorDataType getDataType() { + return CosmosVectorDataType.fromString(dataType); + } + + /** + * Sets the data type for the cosmosVectorEmbedding.
+ * + * @param dataType the data type for the cosmosVectorEmbedding + * @return CosmosVectorEmbedding + */ + public CosmosVectorEmbedding setDataType(CosmosVectorDataType dataType) { + checkNotNull(dataType, "cosmosVectorDataType cannot be null"); + this.dataType = dataType.toString(); + return this; + } + + /** + * Gets the dimensions for the cosmosVectorEmbedding. + * + * @return dimensions + */ + public Long getDimensions() { + return dimensions; + } + + /** + * Sets the dimensions for the cosmosVectorEmbedding. + * + * @param dimensions the dimensions for the cosmosVectorEmbedding + * @return CosmosVectorEmbedding + */ + public CosmosVectorEmbedding setDimensions(Long dimensions) { + checkNotNull(dimensions, "dimensions cannot be null"); + if (dimensions < 1) { + throw new IllegalArgumentException("Dimensions for the embedding has to be a long value greater than 0 " + + "for the vector embedding policy"); + } + + this.dimensions = dimensions; + return this; + } + + /** + * Gets the distanceFunction for the cosmosVectorEmbedding. + * + * @return distanceFunction + */ + public CosmosVectorDistanceFunction getDistanceFunction() { + return CosmosVectorDistanceFunction.fromString(distanceFunction); + } + + /** + * Sets the distanceFunction for the cosmosVectorEmbedding. 
+ * + * @param distanceFunction the distanceFunction for the cosmosVectorEmbedding + * @return CosmosVectorEmbedding + */ + public CosmosVectorEmbedding setDistanceFunction(CosmosVectorDistanceFunction distanceFunction) { + checkNotNull(distanceFunction, "cosmosVectorDistanceFunction cannot be empty"); + this.distanceFunction = distanceFunction.toString(); + return this; + } +} diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosVectorEmbeddingPolicy.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosVectorEmbeddingPolicy.java new file mode 100644 index 000000000000..c54c843ebd96 --- /dev/null +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosVectorEmbeddingPolicy.java @@ -0,0 +1,55 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package com.azure.cosmos.models; + +import com.azure.cosmos.implementation.Constants; +import com.azure.cosmos.implementation.JsonSerializable; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.util.List; + +/** + * Vector Embedding Policy + */ +public final class CosmosVectorEmbeddingPolicy { + + private JsonSerializable jsonSerializable; + /** + * Paths for embeddings along with path-specific settings for the item. + */ + @JsonProperty(Constants.Properties.VECTOR_EMBEDDINGS) + private List cosmosVectorEmbeddings; + + /** + * Constructor + */ + public CosmosVectorEmbeddingPolicy() { + this.jsonSerializable = new JsonSerializable(); + } + + /** + * Gets the paths for embeddings along with path-specific settings for the item. + * + * @return the paths for embeddings along with path-specific settings for the item. + */ + public List getVectorEmbeddings() { + return this.cosmosVectorEmbeddings; + } + + /** + * Sets the paths for embeddings along with path-specific settings for the item. + * + * @param cosmosVectorEmbeddings paths for embeddings along with path-specific settings for the item. 
+ */ + public void setCosmosVectorEmbeddings(List cosmosVectorEmbeddings) { + /* Validate before storing: a null list or a null element fails fast with a clear message. */ + java.util.Objects.requireNonNull(cosmosVectorEmbeddings, "cosmosVectorEmbeddings cannot be null"); + for (Object embedding : cosmosVectorEmbeddings) { + java.util.Objects.requireNonNull(embedding, "Embedding cannot be null."); + } + + this.cosmosVectorEmbeddings = cosmosVectorEmbeddings; + } + +} diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosVectorIndexSpec.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosVectorIndexSpec.java new file mode 100644 index 000000000000..13b1559810ca --- /dev/null +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosVectorIndexSpec.java @@ -0,0 +1,77 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package com.azure.cosmos.models; + +import com.azure.cosmos.CosmosItemSerializer; +import com.azure.cosmos.implementation.Constants; +import com.azure.cosmos.implementation.JsonSerializable; + +import static com.azure.cosmos.implementation.guava25.base.Preconditions.checkNotNull; + +/** + * Vector index spec (path and vector index type) for Azure CosmosDB service. + */ +public final class CosmosVectorIndexSpec { + + private JsonSerializable jsonSerializable; + private String type; + + /** + * Creates a vector index spec backed by a new JsonSerializable. + */ + public CosmosVectorIndexSpec() { this.jsonSerializable = new JsonSerializable(); } + + /** + * Gets the path of the vector index. + * + * @return the path, read from the backing JsonSerializable. + */ + public String getPath() { + return this.jsonSerializable.getString(Constants.Properties.PATH); + } + + /** + * Sets the path of the vector index. + * + * @param path the path. + * @return the CosmosVectorIndexSpec.
+ */ + public CosmosVectorIndexSpec setPath(String path) { + this.jsonSerializable.set(Constants.Properties.PATH, path, CosmosItemSerializer.DEFAULT_SERIALIZER); + return this; + } + + /** + * Gets the vector index type, reading it from the backing JsonSerializable when no value is cached yet + * + * @return the vector index type + */ + public String getType() { + if (this.type == null) { + this.type = this.jsonSerializable.getString(Constants.Properties.VECTOR_INDEX_TYPE); + } + return this.type; + } + + /** + * Sets the vector index type and stores it in the backing JsonSerializable + * + * @param type the vector index type; must not be null + * @return this CosmosVectorIndexSpec + */ + public CosmosVectorIndexSpec setType(String type) { + checkNotNull(type, "cosmosVectorIndexType cannot be null"); + this.type = type; + this.jsonSerializable.set(Constants.Properties.VECTOR_INDEX_TYPE, this.type, CosmosItemSerializer.DEFAULT_SERIALIZER); + return this; + } + + void populatePropertyBag() { + this.jsonSerializable.populatePropertyBag(); + } + + JsonSerializable getJsonSerializable() { + return this.jsonSerializable; + } +} diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosVectorIndexType.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosVectorIndexType.java new file mode 100644 index 000000000000..679ea1f991c0 --- /dev/null +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosVectorIndexType.java @@ -0,0 +1,36 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package com.azure.cosmos.models; + +/** + * Defines the index type of vector index specification in the Azure Cosmos DB service. + */ +public enum CosmosVectorIndexType { + /** + * Represents a flat vector index type (over-the-wire value "flat"). + */ + FLAT("flat"), + + /** + * Represents a quantized flat vector index type (over-the-wire value "quantizedFlat"). + */ + QUANTIZED_FLAT("quantizedFlat"), + + /** + * Represents a disk ANN vector index type (over-the-wire value "diskANN").
+ */ + DISK_ANN("diskANN"); + + + private final String overWireValue; + + CosmosVectorIndexType(String overWireValue) { + this.overWireValue = overWireValue; + } + + @Override + public String toString() { + return this.overWireValue; + } +} diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/IndexingPolicy.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/IndexingPolicy.java index 4ee5153877f8..678cea56fcc4 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/IndexingPolicy.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/IndexingPolicy.java @@ -20,13 +20,12 @@ */ public final class IndexingPolicy { private static final String DEFAULT_PATH = "/*"; - + private final JsonSerializable jsonSerializable; private List includedPaths; private List excludedPaths; private List> compositeIndexes; private List spatialIndexes; - - private JsonSerializable jsonSerializable; + private List vectorIndexes; /** * Constructor. @@ -54,7 +53,7 @@ public IndexingPolicy() { * * * @param defaultIndexOverrides comma separated set of indexes that serve as default index specifications for the - * root path. + * root path. * @throws IllegalArgumentException throws when defaultIndexOverrides is null */ IndexingPolicy(Index[] defaultIndexOverrides) { @@ -235,7 +234,7 @@ public IndexingPolicy setCompositeIndexes(List> compositeInd } /** - * Sets the spatial indexes for additional indexes. + * Gets the spatial indexes for additional indexes. * * @return the spatial indexes. */ @@ -266,11 +265,55 @@ public IndexingPolicy setSpatialIndexes(List spatialIndexes) { return this; } + /** + * Gets the vector indexes, loading them from the backing JsonSerializable on first access.
+ * + * @return the vector indexes; never null (an empty list is created when none are present) + */ + public List getVectorIndexes() { + if (this.vectorIndexes == null) { + this.vectorIndexes = this.jsonSerializable.getList(Constants.Properties.VECTOR_INDEXES, CosmosVectorIndexSpec.class); + + if (this.vectorIndexes == null) { + this.vectorIndexes = new ArrayList(); + } + } + + return this.vectorIndexes; + } + + /** + * Sets the vector indexes and writes them to the backing JsonSerializable. + * + * Example of the vectorIndexes: + * "vectorIndexes": [ + * { + * "path": "/vector1", + * "type": "diskANN" + * }, + * { + * "path": "/vector1", + * "type": "flat" + * }, + * { + * "path": "/vector2", + * "type": "quantizedFlat" + * }] + * + * @param vectorIndexes the vector indexes + * @return this Indexing Policy. + */ + public IndexingPolicy setVectorIndexes(List vectorIndexes) { + this.vectorIndexes = vectorIndexes; + this.jsonSerializable.set(Constants.Properties.VECTOR_INDEXES,this.vectorIndexes, CosmosItemSerializer.DEFAULT_SERIALIZER); + return this; + } + void populatePropertyBag() { this.jsonSerializable.populatePropertyBag(); // If indexing mode is not 'none' and not paths are set, set them to the defaults if (this.getIndexingMode() != IndexingMode.NONE && this.getIncludedPaths().size() == 0 - && this.getExcludedPaths().size() == 0) { + && this.getExcludedPaths().size() == 0) { IncludedPath includedPath = new IncludedPath(IndexingPolicy.DEFAULT_PATH); this.getIncludedPaths().add(includedPath); } @@ -296,5 +339,7 @@ void populatePropertyBag() { } } - JsonSerializable getJsonSerializable() { return this.jsonSerializable; } + JsonSerializable getJsonSerializable() { + return this.jsonSerializable; + } } diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/ModelBridgeInternal.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/ModelBridgeInternal.java index a8f344ce4e2d..4931c9b3b1b0 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/ModelBridgeInternal.java +++
b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/ModelBridgeInternal.java @@ -435,6 +435,8 @@ public static void populatePropertyBag(T t) { ((PartitionKeyDefinition) t).populatePropertyBag(); } else if (t instanceof SpatialSpec) { ((SpatialSpec) t).populatePropertyBag(); + } else if (t instanceof CosmosVectorIndexSpec) { + ((CosmosVectorIndexSpec) t).populatePropertyBag(); } else if (t instanceof SqlParameter) { ((SqlParameter) t).populatePropertyBag(); } else if (t instanceof SqlQuerySpec) { @@ -468,6 +470,8 @@ public static JsonSerializable getJsonSerializable(T t) { return ((PartitionKeyDefinition) t).getJsonSerializable(); } else if (t instanceof SpatialSpec) { return ((SpatialSpec) t).getJsonSerializable(); + } else if (t instanceof CosmosVectorIndexSpec) { + return ((CosmosVectorIndexSpec) t).getJsonSerializable(); } else if (t instanceof SqlParameter) { return ((SqlParameter) t).getJsonSerializable(); } else if (t instanceof SqlQuerySpec) {