Added more test coverage
tvaron3 committed Dec 3, 2024
commit 11effce2558f08a9130602b0b6b56849e552cd0c
1 change: 1 addition & 0 deletions sdk/cosmos/azure-cosmos-spark_3-1_2-12/CHANGELOG.md
@@ -3,6 +3,7 @@
### 4.35.0-beta.1 (Unreleased)

#### Features Added
* Added the UDFs `GetFeedRangesForBuckets` and `GetBucketForPartitionKey` to ease mapping of a Cosmos DB partition key to a Databricks table partition key. - See [PR 43092](https://github.com/Azure/azure-sdk-for-java/pull/43092)

#### Breaking Changes

1 change: 1 addition & 0 deletions sdk/cosmos/azure-cosmos-spark_3-2_2-12/CHANGELOG.md
@@ -3,6 +3,7 @@
### 4.35.0-beta.1 (Unreleased)

#### Features Added
* Added the UDFs `GetFeedRangesForBuckets` and `GetBucketForPartitionKey` to ease mapping of a Cosmos DB partition key to a Databricks table partition key. - See [PR 43092](https://github.com/Azure/azure-sdk-for-java/pull/43092)

#### Breaking Changes

1 change: 1 addition & 0 deletions sdk/cosmos/azure-cosmos-spark_3-3_2-12/CHANGELOG.md
@@ -3,6 +3,7 @@
### 4.35.0-beta.1 (Unreleased)

#### Features Added
* Added the UDFs `GetFeedRangesForBuckets` and `GetBucketForPartitionKey` to ease mapping of a Cosmos DB partition key to a Databricks table partition key. - See [PR 43092](https://github.com/Azure/azure-sdk-for-java/pull/43092)

#### Breaking Changes

1 change: 1 addition & 0 deletions sdk/cosmos/azure-cosmos-spark_3-4_2-12/CHANGELOG.md
@@ -3,6 +3,7 @@
### 4.35.0-beta.1 (Unreleased)

#### Features Added
* Added the UDFs `GetFeedRangesForBuckets` and `GetBucketForPartitionKey` to ease mapping of a Cosmos DB partition key to a Databricks table partition key. - See [PR 43092](https://github.com/Azure/azure-sdk-for-java/pull/43092)

#### Breaking Changes

1 change: 1 addition & 0 deletions sdk/cosmos/azure-cosmos-spark_3-5_2-12/CHANGELOG.md
@@ -3,6 +3,7 @@
### 4.35.0-beta.1 (Unreleased)

#### Features Added
* Added the UDFs `GetFeedRangesForBuckets` and `GetBucketForPartitionKey` to ease mapping of a Cosmos DB partition key to a Databricks table partition key. - See [PR 43092](https://github.com/Azure/azure-sdk-for-java/pull/43092)

#### Breaking Changes

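For context, a minimal sketch of how the two new UDFs fit together, based on the integration test added in this PR (the `CosmosBucketingSketch` object, the SparkSession setup, and the sample document id are illustrative assumptions, not part of the changeset):

```scala
import com.azure.cosmos.spark.udf.{GetBucketForPartitionKey, GetFeedRangesForBuckets}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.{ArrayType, IntegerType, StringType}

object CosmosBucketingSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("cosmos-bucketing-sketch").getOrCreate()

    // Partition key definition of the source Cosmos DB container (single hash key on /id).
    val pkDefinition = "{\"paths\":[\"/id\"],\"kind\":\"Hash\"}"

    // 1. Split the container's partition key space into 5 contiguous feed ranges (one per bucket).
    spark.udf.register("GetFeedRangesForBuckets", new GetFeedRangesForBuckets(), ArrayType(StringType))
    val feedRanges = spark.sql(s"SELECT GetFeedRangesForBuckets('$pkDefinition', 5)")
      .collect()(0)
      .getList[String](0)
      .toArray(new Array[String](0))

    // 2. Resolve the bucket index for a given partition key value; invoked directly here,
    //    mirroring the integration test below.
    spark.udf.register("GetBucketForPartitionKey", new GetBucketForPartitionKey(), IntegerType)
    val bucket: Int = new GetBucketForPartitionKey().call(pkDefinition, "some-document-id", feedRanges)
    println(s"Partition key 'some-document-id' maps to bucket $bucket of ${feedRanges.length}")

    spark.stop()
  }
}
```

The resulting bucket index is what a job would then use as the partition column when writing the data out to a Databricks table.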
@@ -210,7 +210,7 @@ private[cosmos] object SparkBridgeImplementationInternal extends BasicLoggingTrait
partitionKeyDefinitionJson: String
): NormalizedRange = {
val pkDefinition = SparkModelBridgeInternal.createPartitionKeyDefinitionFromJson(partitionKeyDefinitionJson)
partitionKeyToNormalizedRange(new PartitionKey(partitionKeyValue), pkDefinition)
partitionKeyToNormalizedRange(getPartitionKeyValue(pkDefinition, partitionKeyValue), pkDefinition)
}

private[cosmos] def partitionKeyToNormalizedRange(
@@ -226,27 +226,13 @@ private[cosmos] object SparkBridgeImplementationInternal extends BasicLoggingTrait
partitionKeyValueJsonArray: Object,
partitionKeyDefinitionJson: String
): NormalizedRange = {

val partitionKey = new PartitionKeyBuilder()
val objectMapper = new ObjectMapper()
val json = partitionKeyValueJsonArray.toString
try {
val partitionKeyValues = objectMapper.readValue(json, classOf[Array[String]])
for (value <- partitionKeyValues) {
partitionKey.add(value.trim)
}
partitionKey.build()
} catch {
case e: Exception =>
logInfo("Invalid partition key paths: " + json, e)
}

val feedRange = FeedRange
.forLogicalPartition(partitionKey.build())
val pkDefinition = SparkModelBridgeInternal.createPartitionKeyDefinitionFromJson(partitionKeyDefinitionJson)
val partitionKey = getPartitionKeyValue(pkDefinition, partitionKeyValueJsonArray)
val feedRange = FeedRange
.forLogicalPartition(partitionKey)
.asInstanceOf[FeedRangePartitionKeyImpl]

val pkDefinition = SparkModelBridgeInternal.createPartitionKeyDefinitionFromJson(partitionKeyDefinitionJson)
val effectiveRange = feedRange.getEffectiveRange(pkDefinition)
val effectiveRange = feedRange.getEffectiveRange(pkDefinition)
rangeToNormalizedRange(effectiveRange)
}

@@ -268,10 +254,23 @@ private[cosmos] object SparkBridgeImplementationInternal extends BasicLoggingTrait
}

def findBucket(feedRanges: Array[String], pkValue: Object, pkDefinition: PartitionKeyDefinition):Int = {
val pk = getPartitionKeyValue(pkDefinition, pkValue)
val feedRangeFromPk = FeedRange.forLogicalPartition(pk).asInstanceOf[FeedRangePartitionKeyImpl]
val effectiveRangeFromPk = feedRangeFromPk.getEffectiveRange(pkDefinition)

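// Each serialized feed range is an effective range ("min-max") over the partition key hash space;
// the bucket index is the first range whose span contains the partition key's effective-range minimum.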
for (i <- feedRanges.indices) {
val range = SparkBridgeImplementationInternal.toCosmosRange(feedRanges(i))
if (range.contains(effectiveRangeFromPk.getMin)) {
return i
}
}
throw new IllegalArgumentException("The partition key value does not belong to any of the feed ranges")
}

private def getPartitionKeyValue(pkDefinition: PartitionKeyDefinition, pkValue: Object): PartitionKey = {
val partitionKey = new PartitionKeyBuilder()
var pk: PartitionKey = null
// Build the PartitionKey according to the definition kind: MultiHash expects the value as a JSON
// array of path values, Hash expects a single value.
if (pkDefinition.getKind == PartitionKind.MULTI_HASH) {
if (pkDefinition.getKind.equals(PartitionKind.MULTI_HASH)) {
val objectMapper = new ObjectMapper()
val json = pkValue.toString
try {
@@ -284,21 +283,10 @@ private[cosmos] object SparkBridgeImplementationInternal extends BasicLoggingTrait
case e: Exception =>
logInfo("Invalid partition key paths: " + json, e)
}
} else if (pkDefinition.getKind == PartitionKind.HASH) {
} else if (pkDefinition.getKind.equals(PartitionKind.HASH)) {
pk = new PartitionKey(pkValue)
}
val feedRangeFromPk = FeedRange.forLogicalPartition(pk).asInstanceOf[FeedRangePartitionKeyImpl]
val effectiveRangeFromPk = feedRangeFromPk.getEffectiveRange(pkDefinition)

for (i <- feedRanges.indices) {
val range = SparkBridgeImplementationInternal.toCosmosRange(feedRanges(i))
if (range.contains(effectiveRangeFromPk.getMin)) {
return i
}

}
-1

pk
}

def setIoThreadCountPerCoreFactor
@@ -2,18 +2,22 @@
// Licensed under the MIT License.
package com.azure.cosmos.spark

import com.azure.cosmos.implementation.{TestConfigurations, Utils}
import com.azure.cosmos.CosmosAsyncContainer
import com.azure.cosmos.implementation.{SparkBridgeImplementationInternal, Utils}
import com.azure.cosmos.models.CosmosQueryRequestOptions
import com.azure.cosmos.spark.diagnostics.BasicLoggingTrait
import com.azure.cosmos.spark.udf.{GetBucketForPartitionKey, GetFeedRangeForHierarchicalPartitionKeyValues, GetFeedRangesForBuckets}
import org.apache.spark.sql.types._
import com.azure.cosmos.spark.udf.{GetBucketForPartitionKey, GetFeedRangesForBuckets}
import com.fasterxml.jackson.databind.node.ObjectNode
import org.apache.spark.sql.types.{ArrayType, IntegerType, StringType}

import java.util.UUID
import scala.collection.mutable

class FeedRangesForBucketsITest
extends IntegrationSpec
with SparkWithDropwizardAndSlf4jMetrics
with CosmosClient
with AutoCleanableCosmosContainerWithSubpartitions
with CosmosContainer
with BasicLoggingTrait
with MetricAssertions {

@@ -24,26 +28,102 @@ class FeedRangesForBucketsITest
this.reinitializeContainer()
}

"feed ranges" can "can be split into different buckets" in {
"feed ranges" can "be split into different buckets" in {
spark.udf.register("GetFeedRangesForBuckets", new GetFeedRangesForBuckets(), ArrayType(StringType))
val pkDefinition = "{\"paths\":[\"/id\"],\"kind\":\"Hash\"}"
var dummyDf = spark.sql(s"SELECT GetFeedRangesForBuckets('$pkDefinition', 5)")
var pkDefinition = "{\"paths\":[\"/id\"],\"kind\":\"Hash\"}"
val dummyDf = spark.sql(s"SELECT GetFeedRangesForBuckets('$pkDefinition', 5)")
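// Expected output: five contiguous effective ranges covering the entire partition key hash space (min .. "FF").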
val expectedFeedRanges = Array("-05C1C9CD673398", "05C1C9CD673398-05C1D9CD673398",
"05C1D9CD673398-05C1E399CD6732", "05C1E399CD6732-05C1E9CD673398", "05C1E9CD673398-FF")
val feedRange = dummyDf
.collect()(0)
.getList[String](0)
.toArray

logInfo(s"FeedRange from UDF: $feedRange")
assert(feedRange.sameElements(expectedFeedRanges), "Feed ranges do not match the expected values")
val lastId = "45170a78-eac0-4d3a-be5e-9b00bb5f4649"

// spark.udf.register("GetBucketForPartitionKey", new GetBucketForPartitionKey(), IntegerType)
// dummyDf = spark.sql(s"SELECT GetBucketForPartitionKey('$pkDefinition', 4979ea4a-6ba6-42ee-b9e6-1f5bf996a01f, '$feedRange')")
// val bucket = dummyDf.collect()(0).getInt(0)
// assert(bucket == 0, "Bucket does not match the expected value")
var bucket = new GetBucketForPartitionKey().call(pkDefinition, lastId, expectedFeedRanges)
assert(bucket == 0, "Bucket does not match the expected value")

// test with hpk partition key definition
pkDefinition = "{\"paths\":[\"/tenantId\",\"/userId\",\"/sessionId\"],\"kind\":\"MultiHash\"}"
val pkValues = "[\"" + lastId + "\"]"

bucket = new GetBucketForPartitionKey().call(pkDefinition, pkValues, expectedFeedRanges)
assert(bucket == 4, "Bucket does not match the expected value")

}

"feed ranges" can "be converted into buckets for new partition key" in {
feedRangesForBuckets(false)
}

"feed ranges" can "be converted into buckets for new hierarchical partition key" in {
feedRangesForBuckets(true)
}

def feedRangesForBuckets(hpk: Boolean): Unit = {
val container = cosmosClient.getDatabase(cosmosDatabase).getContainer(cosmosContainer)
val docs = createItems(container, 50, hpk)

spark.udf.register("GetFeedRangesForBuckets", new GetFeedRangesForBuckets(), ArrayType(StringType))
val pkDefinition = if (hpk) {"{\"paths\":[\"/tenantId\",\"/userId\",\"/sessionId\"],\"kind\":\"MultiHash\"}"}
else {"{\"paths\":[\"/id\"],\"kind\":\"Hash\"}"}

val dummyDf = spark.sql(s"SELECT GetFeedRangesForBuckets('$pkDefinition', 5)")
val feedRanges = dummyDf
.collect()(0)
.getList[String](0)
.toArray(new Array[String](0))

spark.udf.register("GetBucketForPartitionKey", new GetBucketForPartitionKey(), IntegerType)
val bucketToDocsMap = mutable.Map[Int, List[ObjectNode]]().withDefaultValue(List())

for (doc <- docs) {
val lastId = if (!hpk) doc.get("id").asText() else "[\"" + doc.get("tenantId").asText() + "\"]"
val bucket = new GetBucketForPartitionKey().call(pkDefinition, lastId, feedRanges)
// Add the document to the corresponding bucket in the map
bucketToDocsMap(bucket) = doc :: bucketToDocsMap(bucket)
}

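// Verify the bucket assignment: query the container scoped to each feed range and check that every
// document returned for that range was mapped to the same bucket by GetBucketForPartitionKey.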
for (i <- feedRanges.indices) {
val range = SparkBridgeImplementationInternal.toCosmosRange(feedRanges(i))
val feedRange = SparkBridgeImplementationInternal.toFeedRange(SparkBridgeImplementationInternal.rangeToNormalizedRange(range))
val requestOptions = new CosmosQueryRequestOptions().setFeedRange(feedRange)
container.queryItems("SELECT * FROM c", requestOptions, classOf[ObjectNode]).byPage().collectList().block().forEach { rsp =>
val results = rsp.getResults
var numDocs = 0
val expectedResults = bucketToDocsMap(i)
results.forEach(doc => {
assert(expectedResults.exists(expectedDoc =>
expectedDoc.get("id").asText() == doc.get("id").asText()), "Document not found in the expected bucket")
numDocs += 1
})
assert(numDocs == results.size(), "Number of documents in the bucket does not match the number of docs for that feed range")
}
}
}

def createItems(container: CosmosAsyncContainer, numOfItems: Int, hpk: Boolean): Array[ObjectNode] = {
val docs = new Array[ObjectNode](numOfItems)
for (sequenceNumber <- 1 to numOfItems) {
val lastId = UUID.randomUUID().toString
val objectNode = Utils.getSimpleObjectMapper.createObjectNode()
objectNode.put("name", "Shrodigner's cat")
objectNode.put("type", "cat")
objectNode.put("age", 20)
objectNode.put("sequenceNumber", sequenceNumber)
objectNode.put("id", lastId)
if (hpk) {
objectNode.put("tenantId", lastId)
objectNode.put("userId", "userId1")
objectNode.put("sessionId", "sessionId1")
}
docs(sequenceNumber - 1) = objectNode
container.createItem(objectNode).block()
}
docs
}

//scalastyle:on magic.number