Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
0269f52
Adding changes for vectorIndex and vectorEmbeddingPolicy
aayush3011 Mar 22, 2024
29ad391
Adding some necessary comments
aayush3011 Mar 27, 2024
cd4d8cf
Adding test case
aayush3011 Mar 28, 2024
a2f6a83
updating enum values
aayush3011 Mar 28, 2024
c4bc283
Updating test case
aayush3011 Mar 28, 2024
af99d7b
Updating test case
aayush3011 Mar 29, 2024
6d8fc9b
Updating test case
aayush3011 Mar 29, 2024
2f7112d
updating changelog
aayush3011 Mar 29, 2024
158880f
Merge branch 'main' into users/akataria/vectorindexing
aayush3011 Mar 29, 2024
bb85dd3
Updating test case
aayush3011 Mar 29, 2024
a7185d7
Merge branch 'users/akataria/vectorindexing' of https://github.com/aa…
aayush3011 Mar 29, 2024
f4c4012
Merge branch 'Azure:main' into users/akataria/vectorindexing
aayush3011 Apr 2, 2024
72a4bcd
Resolving comments
aayush3011 Apr 2, 2024
dfb3575
Resolving comments
aayush3011 Apr 2, 2024
67f51cb
Fixing test case
aayush3011 Apr 2, 2024
730f8c2
Resolving comments
aayush3011 Apr 23, 2024
ad3ac89
Resolving Comments
aayush3011 Apr 27, 2024
940c6af
Merge branch 'main' into users/akataria/vectorindexing
aayush3011 Apr 27, 2024
3eb77ea
Fixing build issues
aayush3011 Apr 27, 2024
44f4e07
Merge branch 'main' into users/akataria/vectorindexing
aayush3011 Apr 29, 2024
460f681
Resolving comments
aayush3011 Apr 30, 2024
5579dd1
Resolving Comments
aayush3011 May 1, 2024
54a2ce3
Merge branch 'users/akataria/vectorindexing' of https://github.com/aa…
aayush3011 May 1, 2024
528a0eb
[Cosmos][VectorIndex]Adding changes for vectorIndex and vectorEmbeddi…
aayush3011 May 2, 2024
148cba5
[Cosmos][VectorSearch] Non Streaming Order By Query (#40085)
aayush3011 May 8, 2024
df7e838
[Cosmos][VectorSearch] Non Streaming Order By Query (#40096)
aayush3011 May 9, 2024
c8de52f
[Cosmos][VectorSearch] Non Streaming Order By Query (#40098)
aayush3011 May 9, 2024
425b78f
Merge branch 'Azure:main' into users/akataria/vectorindexing
aayush3011 May 10, 2024
55efb81
Resolving comments
aayush3011 May 10, 2024
9a3b003
Resolving comments
aayush3011 May 10, 2024
52917f9
[Cosmos][VectorSearch] Non Streaming Order By Query (#40115)
aayush3011 May 10, 2024
72a7145
Merge branch 'feature/vector_search' into users/akataria/vectorindexing
aayush3011 May 10, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Resolving comments
  • Loading branch information
aayush3011 committed Apr 2, 2024
commit dfb3575042601b07bddda41fe7821fee2813aa78
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import com.azure.cosmos.CosmosException;
import com.azure.cosmos.DirectConnectionConfig;
import com.azure.cosmos.implementation.TestConfigurations;
import com.azure.cosmos.implementation.Utils;
import com.azure.cosmos.implementation.guava25.collect.ImmutableList;
import com.azure.cosmos.models.CosmosContainerProperties;
import com.azure.cosmos.models.CosmosVectorDataType;
Expand All @@ -25,6 +26,8 @@
import com.azure.cosmos.models.VectorEmbeddingPolicy;
import com.azure.cosmos.models.VectorIndexSpec;
import com.azure.cosmos.models.VectorIndexType;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.testng.annotations.AfterClass;
Expand All @@ -48,6 +51,7 @@ public class VectorIndexTest extends TestSuiteBase {
protected static final int SHUTDOWN_TIMEOUT = 20000;

protected static Logger logger = LoggerFactory.getLogger(VectorIndexTest.class.getSimpleName());
private final ObjectMapper simpleObjectMapper = Utils.getSimpleObjectMapper();
private final String databaseId = CosmosDatabaseForTest.generateId();
private CosmosAsyncClient client;
private CosmosAsyncDatabase database;
Expand Down Expand Up @@ -161,9 +165,9 @@ public void shouldFailOnWrongVectorEmbeddingPolicy() {

CosmosVectorEmbedding embedding = new CosmosVectorEmbedding(
"/vector1",
CosmosVectorDistanceFunction.COSINE.toString(),
"String",
3L,
"String");
CosmosVectorDistanceFunction.COSINE.toString());

try {
VectorEmbeddingPolicy vectorEmbeddingPolicy = new VectorEmbeddingPolicy(ImmutableList.of(embedding));
Expand All @@ -173,7 +177,7 @@ public void shouldFailOnWrongVectorEmbeddingPolicy() {
assertThat(ex.getMessage()).isEqualTo("Invalid vector data type for the vector embedding policy.");
}

embedding.setVectorDataType("");
embedding.setCosmosVectorDataType("");
try {
VectorEmbeddingPolicy vectorEmbeddingPolicy = new VectorEmbeddingPolicy(ImmutableList.of(embedding));
collectionDefinition.setVectorEmbeddingPolicy(vectorEmbeddingPolicy);
Expand All @@ -182,8 +186,8 @@ public void shouldFailOnWrongVectorEmbeddingPolicy() {
assertThat(ex.getMessage()).isEqualTo("Vector data type cannot be empty for the vector embedding policy.");
}

embedding.setVectorDataType(CosmosVectorDataType.FLOAT32.toString());
embedding.setDistanceFunction("COS");
embedding.setCosmosVectorDataType(CosmosVectorDataType.FLOAT32.toString());
embedding.setCosmosVectorDistanceFunction("COS");
try {
VectorEmbeddingPolicy vectorEmbeddingPolicy = new VectorEmbeddingPolicy(ImmutableList.of(embedding));
collectionDefinition.setVectorEmbeddingPolicy(vectorEmbeddingPolicy);
Expand All @@ -192,7 +196,7 @@ public void shouldFailOnWrongVectorEmbeddingPolicy() {
assertThat(ex.getMessage()).isEqualTo("Invalid distance function for the vector embedding policy.");
}

embedding.setDistanceFunction("");
embedding.setCosmosVectorDistanceFunction("");
try {
VectorEmbeddingPolicy vectorEmbeddingPolicy = new VectorEmbeddingPolicy(ImmutableList.of(embedding));
collectionDefinition.setVectorEmbeddingPolicy(vectorEmbeddingPolicy);
Expand All @@ -201,7 +205,7 @@ public void shouldFailOnWrongVectorEmbeddingPolicy() {
assertThat(ex.getMessage()).isEqualTo("Distance function cannot be empty for the vector embedding policy.");
}

embedding.setDistanceFunction(CosmosVectorDistanceFunction.COSINE.toString());
embedding.setCosmosVectorDistanceFunction(CosmosVectorDistanceFunction.COSINE.toString());
embedding.setDimensions(-1L);
try {
VectorEmbeddingPolicy vectorEmbeddingPolicy = new VectorEmbeddingPolicy(ImmutableList.of(embedding));
Expand Down Expand Up @@ -239,7 +243,7 @@ public void shouldFailOnWrongVectorIndex() {
"/vector1",
CosmosVectorDistanceFunction.COSINE.toString(),
3L,
CosmosVectorDataType.INT8.toString());
CosmosVectorDistanceFunction.COSINE.toString());
VectorEmbeddingPolicy vectorEmbeddingPolicy = new VectorEmbeddingPolicy(ImmutableList.of(embedding));
collectionDefinition.setVectorEmbeddingPolicy(vectorEmbeddingPolicy);

Expand Down Expand Up @@ -287,22 +291,47 @@ public void shouldCreateVectorIndexSimilarPathDifferentVectorType() {
validateCollectionProperties(collectionDefinition, collectionProperties);
}

@Test(groups = {"emulator"}, timeOut = TIMEOUT)
public void shouldValidateVectorEmbeddingPolicySerializationAndDeserialization() throws JsonProcessingException {
IndexingPolicy indexingPolicy = new IndexingPolicy();
indexingPolicy.setVectorIndexes(populateVectorIndexes());

VectorEmbeddingPolicy vectorEmbeddingPolicy = new VectorEmbeddingPolicy(populateEmbeddings());
String vectorEmbeddingPolicyJson = getVectorEmbeddingPolicyAsString();
String expectedVectorEmbeddingPolicyJson = simpleObjectMapper.writeValueAsString(vectorEmbeddingPolicy);
assertThat(vectorEmbeddingPolicyJson).isEqualTo(expectedVectorEmbeddingPolicyJson);

VectorEmbeddingPolicy expectedVectorEmbeddingPolicy = simpleObjectMapper.readValue(expectedVectorEmbeddingPolicyJson, VectorEmbeddingPolicy.class);
validateVectorEmbeddingPolicy(vectorEmbeddingPolicy, expectedVectorEmbeddingPolicy);
}

private void validateCollectionProperties(CosmosContainerProperties collectionDefinition, CosmosContainerProperties collectionProperties) {
assertThat(collectionProperties.getVectorEmbeddingPolicy()).isNotNull();
assertThat(collectionProperties.getVectorEmbeddingPolicy().getEmbeddings()).isNotNull();
List<CosmosVectorEmbedding> embeddings = collectionProperties.getVectorEmbeddingPolicy().getEmbeddings();
assertThat(embeddings).hasSameSizeAs(collectionDefinition.getVectorEmbeddingPolicy().getEmbeddings());
for (int i = 0; i < embeddings.size(); i++) {
assertThat(embeddings.get(0).getPath()).isEqualTo(
collectionDefinition.getVectorEmbeddingPolicy().getEmbeddings().get(0).getPath());
}
assertThat(collectionProperties.getVectorEmbeddingPolicy().getCosmosVectorEmbeddings()).isNotNull();
validateVectorEmbeddingPolicy(collectionProperties.getVectorEmbeddingPolicy(),
collectionDefinition.getVectorEmbeddingPolicy());

assertThat(collectionProperties.getIndexingPolicy().getVectorIndexes()).isNotNull();
List<VectorIndexSpec> vectorIndexes = collectionProperties.getIndexingPolicy().getVectorIndexes();
assertThat(vectorIndexes).hasSameSizeAs(collectionDefinition.getIndexingPolicy().getVectorIndexes());
for (int i = 0; i < vectorIndexes.size(); i++) {
assertThat(vectorIndexes.get(0).getPath()).isEqualTo(
collectionDefinition.getIndexingPolicy().getVectorIndexes().get(0).getPath());
validateVectorIndexes(collectionDefinition.getIndexingPolicy().getVectorIndexes(), collectionProperties.getIndexingPolicy().getVectorIndexes());
}

private void validateVectorEmbeddingPolicy(VectorEmbeddingPolicy actual, VectorEmbeddingPolicy expected) {
List<CosmosVectorEmbedding> actualEmbeddings = actual.getCosmosVectorEmbeddings();
List<CosmosVectorEmbedding> expectedEmbeddings = expected.getCosmosVectorEmbeddings();
assertThat(expectedEmbeddings).hasSameSizeAs(actualEmbeddings);
for (int i = 0; i < expectedEmbeddings.size(); i++) {
assertThat(expectedEmbeddings.get(i).getPath()).isEqualTo(actualEmbeddings.get(i).getPath());
assertThat(expectedEmbeddings.get(i).getCosmosVectorDataType()).isEqualTo(actualEmbeddings.get(i).getCosmosVectorDataType());
assertThat(expectedEmbeddings.get(i).getDimensions()).isEqualTo(actualEmbeddings.get(i).getDimensions());
assertThat(expectedEmbeddings.get(i).getCosmosVectorDistanceFunction()).isEqualTo(actualEmbeddings.get(i).getCosmosVectorDistanceFunction());
}
}

private void validateVectorIndexes(List<VectorIndexSpec> actual, List<VectorIndexSpec> expected) {
assertThat(expected).hasSameSizeAs(actual);
for (int i = 0; i < expected.size(); i++) {
assertThat(expected.get(i).getPath()).isEqualTo(actual.get(i).getPath());
assertThat(expected.get(i).getType()).isEqualTo(actual.get(i).getType());
}
}

Expand All @@ -322,21 +351,29 @@ private List<VectorIndexSpec> populateVectorIndexes() {
private List<CosmosVectorEmbedding> populateEmbeddings() {
CosmosVectorEmbedding embedding1 = new CosmosVectorEmbedding(
"/vector1",
CosmosVectorDistanceFunction.COSINE.toString(),
CosmosVectorDataType.FLOAT32.toString(),
3L,
CosmosVectorDataType.FLOAT32.toString());
CosmosVectorDistanceFunction.COSINE.toString());

CosmosVectorEmbedding embedding2 = new CosmosVectorEmbedding(
"/vector2",
CosmosVectorDistanceFunction.DOT_PRODUCT.toString(),
CosmosVectorDataType.INT8.toString(),
3L,
CosmosVectorDataType.INT8.toString());
CosmosVectorDistanceFunction.DOT_PRODUCT.toString());

CosmosVectorEmbedding embedding3 = new CosmosVectorEmbedding(
"/vector3",
CosmosVectorDistanceFunction.EUCLIDEAN.toString(),
CosmosVectorDataType.UINT8.toString(),
3L,
CosmosVectorDataType.UINT8.toString());
CosmosVectorDistanceFunction.EUCLIDEAN.toString());
return Arrays.asList(embedding1, embedding2, embedding3);
}

private String getVectorEmbeddingPolicyAsString() {
return "{\"vectorEmbeddings\":[" +
"{\"path\":\"/vector1\",\"dataType\":\"float32\",\"dimensions\":3,\"distanceFunction\":\"cosine\"}," +
"{\"path\":\"/vector2\",\"dataType\":\"int8\",\"dimensions\":3,\"distanceFunction\":\"dotproduct\"}," +
"{\"path\":\"/vector3\",\"dataType\":\"uint8\",\"dimensions\":3,\"distanceFunction\":\"euclidean\"}" +
"]}";
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package com.azure.cosmos.models;

import com.azure.cosmos.implementation.Constants;
import com.azure.cosmos.implementation.JsonSerializable;
import com.fasterxml.jackson.annotation.JsonProperty;

/**
Expand All @@ -13,11 +14,12 @@ public final class CosmosVectorEmbedding {
@JsonProperty(Constants.Properties.PATH)
private String path;
@JsonProperty(Constants.Properties.VECTOR_DATA_TYPE)
private String vectorDataType;
private String cosmosVectorDataType;
@JsonProperty(Constants.Properties.VECTOR_DIMENSIONS)
private Long dimensions;
@JsonProperty(Constants.Properties.DISTANCE_FUNCTION)
private String distanceFunction;
private String cosmosVectorDistanceFunction;
private JsonSerializable jsonSerializable;

/**
* Constructor
Expand All @@ -29,9 +31,16 @@ public final class CosmosVectorEmbedding {
*/
public CosmosVectorEmbedding(String path, String vectorDataType, Long dimensions, String distanceFunction) {
this.path = path;
this.vectorDataType = vectorDataType;
this.cosmosVectorDataType = vectorDataType;
this.dimensions = dimensions;
this.distanceFunction = distanceFunction;
this.cosmosVectorDistanceFunction = distanceFunction;
}

/**
* Constructor
*/
public CosmosVectorEmbedding() {
this.jsonSerializable = new JsonSerializable();
}

/**
Expand All @@ -57,20 +66,20 @@ public CosmosVectorEmbedding setPath(String path) {
/**
* Gets the data type for the cosmosVectorEmbedding.
*
* @return vectorDataType
* @return cosmosVectorDataType
*/
public String getVectorDataType() {
return vectorDataType;
public String getCosmosVectorDataType() {
return cosmosVectorDataType;
}

/**
* Sets the data type for the cosmosVectorEmbedding.
*
* @param vectorDataType the data type for the cosmosVectorEmbedding
* @param cosmosVectorDataType the data type for the cosmosVectorEmbedding
* @return CosmosVectorEmbedding
*/
public CosmosVectorEmbedding setVectorDataType(String vectorDataType) {
this.vectorDataType = vectorDataType;
public CosmosVectorEmbedding setCosmosVectorDataType(String cosmosVectorDataType) {
this.cosmosVectorDataType = cosmosVectorDataType;
return this;
}

Expand All @@ -97,20 +106,20 @@ public CosmosVectorEmbedding setDimensions(Long dimensions) {
/**
* Gets the distanceFunction for the cosmosVectorEmbedding.
*
* @return distanceFunction
* @return cosmosVectorDistanceFunction
*/
public String getDistanceFunction() {
return distanceFunction;
public String getCosmosVectorDistanceFunction() {
return cosmosVectorDistanceFunction;
}

/**
* Sets the distanceFunction for the cosmosVectorEmbedding.
*
* @param distanceFunction the distanceFunction for the cosmosVectorEmbedding
* @param cosmosVectorDistanceFunction the distanceFunction for the cosmosVectorEmbedding
* @return CosmosVectorEmbedding
*/
public CosmosVectorEmbedding setDistanceFunction(String distanceFunction) {
this.distanceFunction = distanceFunction;
public CosmosVectorEmbedding setCosmosVectorDistanceFunction(String cosmosVectorDistanceFunction) {
this.cosmosVectorDistanceFunction = cosmosVectorDistanceFunction;
return this;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,16 @@ public final class VectorEmbeddingPolicy {
* Paths for embeddings along with path-specific settings for the item.
*/
@JsonProperty(Constants.Properties.VECTOR_EMBEDDINGS)
private List<CosmosVectorEmbedding> embeddings;
private List<CosmosVectorEmbedding> cosmosVectorEmbeddings;

/**
* Constructor
*
* @param embeddings list of path for embeddings along with path-specific settings for the item.
* @param cosmosVectorEmbeddings list of path for embeddings along with path-specific settings for the item.
*/
public VectorEmbeddingPolicy(List<CosmosVectorEmbedding> embeddings) {
validateEmbeddings(embeddings);
this.embeddings = embeddings;
public VectorEmbeddingPolicy(List<CosmosVectorEmbedding> cosmosVectorEmbeddings) {
validateEmbeddings(cosmosVectorEmbeddings);
this.cosmosVectorEmbeddings = cosmosVectorEmbeddings;
}

/**
Expand All @@ -41,15 +41,15 @@ public VectorEmbeddingPolicy() {
this.jsonSerializable = new JsonSerializable();
}

private static void validateEmbeddings(List<CosmosVectorEmbedding> embeddings) {
embeddings.forEach(embedding -> {
private static void validateEmbeddings(List<CosmosVectorEmbedding> cosmosVectorEmbeddings) {
cosmosVectorEmbeddings.forEach(embedding -> {
if (embedding == null) {
throw new IllegalArgumentException("Embedding cannot be null.");
}
validateEmbeddingPath(embedding.getPath());
validateEmbeddingDimensions(embedding.getDimensions());
validateEmbeddingVectorDataType(embedding.getVectorDataType());
validateEmbeddingDistanceFunction(embedding.getDistanceFunction());
validateEmbeddingVectorDataType(embedding.getCosmosVectorDataType());
validateEmbeddingDistanceFunction(embedding.getCosmosVectorDistanceFunction());
});
}

Expand Down Expand Up @@ -93,7 +93,7 @@ private static void validateEmbeddingDistanceFunction(String value) {
*
* @return the paths for embeddings along with path-specific settings for the item.
*/
public List<CosmosVectorEmbedding> getEmbeddings() {
return this.embeddings;
public List<CosmosVectorEmbedding> getCosmosVectorEmbeddings() {
return this.cosmosVectorEmbeddings;
}
}