Reacting to comments
tvaron3 committed Dec 18, 2024
commit e280624f77ce8c2bb029780601bd8c414bc6d6a2
@@ -238,30 +238,52 @@ private[cosmos] object SparkBridgeImplementationInternal extends BasicLoggingTrait

private[cosmos] def trySplitFeedRanges
(
partitionKeyDefinitionJson: String,
feedRange: FeedRangeEpkImpl,
bucketCount: Int
): Array[String] = {
cosmosClient: CosmosAsyncClient,
containerName: String,
databaseName: String,
targetedCount: Int
): List[String] = {
val container = cosmosClient
.getDatabase(databaseName)
.getContainer(containerName)

val pkDefinition = SparkModelBridgeInternal.createPartitionKeyDefinitionFromJson(partitionKeyDefinitionJson)
val feedRanges = FeedRangeInternal.trySplitCore(pkDefinition, feedRange.getRange, bucketCount)
val normalizedRanges = new Array[String](feedRanges.size())
val epkRange = rangeToNormalizedRange(FeedRange.forFullRange().asInstanceOf[FeedRangeEpkImpl].getRange)
val feedRanges = SparkBridgeInternal.trySplitFeedRange(container, epkRange, targetedCount)
var normalizedRanges: List[String] = List()
for (i <- feedRanges.indices) {
val normalizedRange = rangeToNormalizedRange(feedRanges(i).getRange)
normalizedRanges(i) = s"${normalizedRange.min}-${normalizedRange.max}"
normalizedRanges = normalizedRanges :+ s"${feedRanges(i).min}-${feedRanges(i).max}"
}
normalizedRanges
}

def findBucket(feedRanges: Array[String], pkValue: Object, pkDefinition: PartitionKeyDefinition):Int = {
private[cosmos] def getFeedRangesForContainer
(
cosmosClient: CosmosAsyncClient,
containerName: String,
databaseName: String
): List[String] = {
val container = cosmosClient
.getDatabase(databaseName)
.getContainer(containerName)

// normalize each physical feed range of the container into a "min-max" string
var feedRanges: List[String] = List()
container.getFeedRanges().block.forEach(feedRange => {
val effectiveRangeFromPk = feedRange.asInstanceOf[FeedRangeEpkImpl].getEffectiveRange(null, null, null).block
val normalizedRange = rangeToNormalizedRange(effectiveRangeFromPk)
feedRanges = feedRanges :+ s"${normalizedRange.min}-${normalizedRange.max}"
})
feedRanges
}

def getOverlappingRange(feedRanges: Array[String], pkValue: Object, pkDefinition: PartitionKeyDefinition): String = {
val pk = getPartitionKeyValue(pkDefinition, pkValue)
val feedRangeFromPk = FeedRange.forLogicalPartition(pk).asInstanceOf[FeedRangePartitionKeyImpl]
val effectiveRangeFromPk = feedRangeFromPk.getEffectiveRange(pkDefinition)

for (i <- feedRanges.indices) {
val range = SparkBridgeImplementationInternal.toCosmosRange(feedRanges(i))
if (range.contains(effectiveRangeFromPk.getMin)) {
return i
return feedRanges(i)
}
}
throw new IllegalArgumentException("The partition key value does not belong to any of the feed ranges")

This file was deleted.

@@ -0,0 +1,54 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
package com.azure.cosmos.spark.udf

import com.azure.cosmos.implementation.SparkBridgeImplementationInternal
import com.azure.cosmos.spark.{CosmosClientCache, CosmosClientCacheItem, CosmosClientConfiguration, CosmosConfig, CosmosContainerConfig, CosmosReadConfig, Loan}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.api.java.UDF2

@SerialVersionUID(1L)
class GetFeedRangesForContainer extends UDF2[Map[String, String], Option[Int], Array[String]] {
override def call
(
userProvidedConfig: Map[String, String],
targetedCount: Option[Int]
): Array[String] = {

val effectiveUserConfig = CosmosConfig.getEffectiveConfig(None, None, userProvidedConfig)
var feedRanges = List[String]()
val cosmosContainerConfig: CosmosContainerConfig =
CosmosContainerConfig.parseCosmosContainerConfig(effectiveUserConfig, None, None)
val readConfig = CosmosReadConfig.parseCosmosReadConfig(effectiveUserConfig)
val cosmosClientConfig = CosmosClientConfiguration(
effectiveUserConfig,
useEventualConsistency = readConfig.forceEventualConsistency,
CosmosClientConfiguration.getSparkEnvironmentInfo(SparkSession.getActiveSession))
Loan(
List[Option[CosmosClientCacheItem]](
Some(CosmosClientCache(
cosmosClientConfig,
None,
s"UDF GetFeedRangesForContainer"
))
))
.to(cosmosClientCacheItems => {

if (targetedCount.isEmpty) {
feedRanges = SparkBridgeImplementationInternal.getFeedRangesForContainer(
cosmosClientCacheItems.head.get.cosmosClient,
cosmosContainerConfig.container,
cosmosContainerConfig.database
)
} else {
feedRanges = SparkBridgeImplementationInternal.trySplitFeedRanges(
cosmosClientCacheItems.head.get.cosmosClient,
cosmosContainerConfig.container,
cosmosContainerConfig.database,
targetedCount.get)
}
})
feedRanges.toArray

}
}
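
For reference, a minimal usage sketch of the new UDF (not part of this commit). The account endpoint, key, database, and container values below are placeholders; each returned element is a normalized "min-max" effective partition key range string:

// Illustrative sketch only; configuration values are placeholders.
import com.azure.cosmos.spark.udf.GetFeedRangesForContainer

val cfg = Map(
  "spark.cosmos.accountEndpoint" -> "https://<account>.documents.azure.com:443/",
  "spark.cosmos.accountKey" -> "<account-key>",
  "spark.cosmos.database" -> "<database>",
  "spark.cosmos.container" -> "<container>"
)

// With a targeted count, the container's full range is split into that many sub feed ranges;
// with None, the container's physical feed ranges are returned as-is.
val splitRanges: Array[String] = new GetFeedRangesForContainer().call(cfg, Option(5))
val physicalRanges: Array[String] = new GetFeedRangesForContainer().call(cfg, None)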
@@ -8,16 +8,16 @@ import com.azure.cosmos.spark.CosmosPredicates.requireNotNullOrEmpty
import org.apache.spark.sql.api.java.UDF3

@SerialVersionUID(1L)
class GetBucketForPartitionKey extends UDF3[String, Object, Array[String], Int] {
class GetOverlappingFeedRange extends UDF3[String, Object, Array[String], String] {
override def call
(
partitionKeyDefinitionJson: String,
partitionKeyValue: Object,
feedRangesForBuckets: Array[String]
): Int = {
): String = {
requireNotNullOrEmpty(partitionKeyDefinitionJson, "partitionKeyDefinitionJson")

val pkDefinition = SparkModelBridgeInternal.createPartitionKeyDefinitionFromJson(partitionKeyDefinitionJson)
SparkBridgeImplementationInternal.findBucket(feedRangesForBuckets, partitionKeyValue, pkDefinition)
SparkBridgeImplementationInternal.getOverlappingRange(feedRangesForBuckets, partitionKeyValue, pkDefinition)
}
}
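
A companion sketch for the renamed UDF (illustrative only). The partition key definition and feed range strings mirror the integration test below; in practice the ranges would come from GetFeedRangesForContainer:

// Illustrative sketch only; the ranges and document id are sample values.
import com.azure.cosmos.spark.udf.GetOverlappingFeedRange

val pkDefinition = "{\"paths\":[\"/id\"],\"kind\":\"Hash\"}"
val feedRanges = Array("-05C1C9CD673398", "05C1C9CD673398-FF")

// Returns the "min-max" feed range string whose effective range contains the partition
// key value, instead of the bucket index the previous UDF returned.
val overlappingFeedRange: String =
  new GetOverlappingFeedRange().call(pkDefinition, "45170a78-eac0-4d3a-be5e-9b00bb5f4649", feedRanges)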
@@ -3,17 +3,16 @@
package com.azure.cosmos.spark

import com.azure.cosmos.CosmosAsyncContainer
import com.azure.cosmos.implementation.{SparkBridgeImplementationInternal, Utils}
import com.azure.cosmos.implementation.{SparkBridgeImplementationInternal, TestConfigurations, Utils}
import com.azure.cosmos.models.CosmosQueryRequestOptions
import com.azure.cosmos.spark.diagnostics.BasicLoggingTrait
import com.azure.cosmos.spark.udf.{GetBucketForPartitionKey, GetFeedRangesForBuckets}
import com.azure.cosmos.spark.udf.{GetFeedRangesForContainer, GetOverlappingFeedRange}
import com.fasterxml.jackson.databind.node.ObjectNode
import org.apache.spark.sql.types.{ArrayType, IntegerType, StringType}

import java.util.UUID
import scala.collection.mutable

class FeedRangesForBucketsITest
class FeedRangesForContainerITest
extends IntegrationSpec
with SparkWithDropwizardAndSlf4jMetrics
with CosmosClient
@@ -28,62 +28,73 @@ class FeedRangesForBucketsITest
this.reinitializeContainer()
}

"feed ranges" can "be split into different buckets" in {
spark.udf.register("GetFeedRangesForBuckets", new GetFeedRangesForBuckets(), ArrayType(StringType))
"feed ranges" can "be split into different sub feed ranges" in {

val cosmosEndpoint = TestConfigurations.HOST
val cosmosMasterKey = TestConfigurations.MASTER_KEY

val cfg = Map(
"spark.cosmos.accountEndpoint" -> cosmosEndpoint,
"spark.cosmos.accountKey" -> cosmosMasterKey,
"spark.cosmos.database" -> cosmosDatabase,
"spark.cosmos.container" -> cosmosContainer,
)
var pkDefinition = "{\"paths\":[\"/id\"],\"kind\":\"Hash\"}"
val dummyDf = spark.sql(s"SELECT GetFeedRangesForBuckets('$pkDefinition', 5)")
val feedRanges = new GetFeedRangesForContainer().call(cfg, Option(5))
val expectedFeedRanges = Array("-05C1C9CD673398", "05C1C9CD673398-05C1D9CD673398",
"05C1D9CD673398-05C1E399CD6732", "05C1E399CD6732-05C1E9CD673398", "05C1E9CD673398-FF")
val feedRange = dummyDf
.collect()(0)
.getList[String](0)
.toArray

assert(feedRange.sameElements(expectedFeedRanges), "Feed ranges do not match the expected values")


assert(feedRanges.sameElements(expectedFeedRanges), "Feed ranges do not match the expected values")
val lastId = "45170a78-eac0-4d3a-be5e-9b00bb5f4649"

var bucket = new GetBucketForPartitionKey().call(pkDefinition, lastId, expectedFeedRanges)
assert(bucket == 0, "Bucket does not match the expected value")
var feedRangeResult = new GetOverlappingFeedRange().call(pkDefinition, lastId, expectedFeedRanges)
assert(feedRangeResult == "-05C1C9CD673398", "feed range does not match the expected value")

// test with hpk partition key definition
pkDefinition = "{\"paths\":[\"/tenantId\",\"/userId\",\"/sessionId\"],\"kind\":\"MultiHash\"}"
val pkValues = "[\"" + lastId + "\"]"

bucket = new GetBucketForPartitionKey().call(pkDefinition, pkValues, expectedFeedRanges)
assert(bucket == 4, "Bucket does not match the expected value")
feedRangeResult = new GetOverlappingFeedRange().call(pkDefinition, pkValues, expectedFeedRanges)
assert(feedRangeResult == "05C1E9CD673398-FF", "feed range does not match the expected value")

}

"feed ranges" can "be converted into buckets for new partition key" in {
feedRangesForBuckets(false)
"feed ranges" can "be mapped to new partition key" in {
feedRangesForPK(false)
}

"feed ranges" can "be converted into buckets for new hierarchical partition key" in {
feedRangesForBuckets(true)
"feed ranges" can "be mapped to new hierarchical partition key" in {
feedRangesForPK(true)
}

def feedRangesForBuckets(hpk: Boolean): Unit = {
def feedRangesForPK(hpk: Boolean): Unit = {

val cosmosEndpoint = TestConfigurations.HOST
val cosmosMasterKey = TestConfigurations.MASTER_KEY

val container = cosmosClient.getDatabase(cosmosDatabase).getContainer(cosmosContainer)
val docs = createItems(container, 50, hpk)
val cfg = Map(
"spark.cosmos.accountEndpoint" -> cosmosEndpoint,
"spark.cosmos.accountKey" -> cosmosMasterKey,
"spark.cosmos.database" -> cosmosDatabase,
"spark.cosmos.container" -> cosmosContainer,
)

spark.udf.register("GetFeedRangesForBuckets", new GetFeedRangesForBuckets(), ArrayType(StringType))
val pkDefinition = if (hpk) {"{\"paths\":[\"/tenantId\",\"/userId\",\"/sessionId\"],\"kind\":\"MultiHash\"}"}
else {"{\"paths\":[\"/id\"],\"kind\":\"Hash\"}"}

val dummyDf = spark.sql(s"SELECT GetFeedRangesForBuckets('$pkDefinition', 5)")
val feedRanges = dummyDf
.collect()(0)
.getList[String](0)
.toArray(new Array[String](0))
val feedRanges = new GetFeedRangesForContainer().call(cfg, Option(5))

spark.udf.register("GetBucketForPartitionKey", new GetBucketForPartitionKey(), IntegerType)
val bucketToDocsMap = mutable.Map[Int, List[ObjectNode]]().withDefaultValue(List())
val feedRangeToDocsMap = mutable.Map[String, List[ObjectNode]]().withDefaultValue(List())

for (doc <- docs) {
val lastId = if (!hpk) doc.get("id").asText() else "[\"" + doc.get("tenantId").asText() + "\"]"
val bucket = new GetBucketForPartitionKey().call(pkDefinition, lastId, feedRanges)
// Add the document to the corresponding bucket in the map
bucketToDocsMap(bucket) = doc :: bucketToDocsMap(bucket)
val feedRange = new GetOverlappingFeedRange().call(pkDefinition, lastId, feedRanges)
// Add the document to the corresponding feed range in the map
feedRangeToDocsMap(feedRange) = doc :: feedRangeToDocsMap(feedRange)
}

for (i <- feedRanges.indices) {
@@ -93,14 +103,14 @@ class FeedRangesForBucketsITest
container.queryItems("SELECT * FROM c", requestOptions, classOf[ObjectNode]).byPage().collectList().block().forEach { rsp =>
val results = rsp.getResults
var numDocs = 0
val expectedResults = bucketToDocsMap(i)
val expectedResults = feedRangeToDocsMap(feedRanges(i))
results.forEach(doc => {
assert(expectedResults.collect({
case expectedDoc if expectedDoc.get("id").asText() == doc.get("id").asText() => expectedDoc
}).size >= 0, "Document not found in the expected bucket")
}).size > 0, "Document not found in the expected feed range")
numDocs += 1
})
assert(numDocs == results.size(), "Number of documents in the bucket does not match the number of docs for that feed range")
assert(numDocs == results.size(), "Number of documents in the target feed range does not match the number of docs for that feed range")
}
}
}
@@ -224,27 +224,23 @@ public Mono<List<FeedRangeEpkImpl>> trySplit(
return Collections.singletonList(new FeedRangeEpkImpl(effectiveRange));
}

return trySplitCore(pkDefinition, effectiveRange, targetedSplitCount);
PartitionKeyDefinitionVersion effectivePKVersion =
pkDefinition.getVersion() != null
? pkDefinition.getVersion()
: PartitionKeyDefinitionVersion.V1;
switch (effectivePKVersion) {
case V1:
return trySplitWithHashV1(effectiveRange, targetedSplitCount);

case V2:
return trySplitWithHashV2(effectiveRange, targetedSplitCount);

default:
return Collections.singletonList(new FeedRangeEpkImpl(effectiveRange));
}
});
}

public static List<FeedRangeEpkImpl> trySplitCore(PartitionKeyDefinition pkDefinition, Range<String> effectiveRange, int targetedSplitCount) {
PartitionKeyDefinitionVersion effectivePKVersion =
pkDefinition.getVersion() != null
? pkDefinition.getVersion()
: PartitionKeyDefinitionVersion.V1;
switch (effectivePKVersion) {
case V1:
return trySplitWithHashV1(effectiveRange, targetedSplitCount);

case V2:
return trySplitWithHashV2(effectiveRange, targetedSplitCount);

default:
return Collections.singletonList(new FeedRangeEpkImpl(effectiveRange));
}

}

static List<FeedRangeEpkImpl> trySplitWithHashV1(
Range<String> effectiveRange,