Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,14 @@
package org.apache.spark.sql.catalyst.expressions.aggregate

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, UnresolvedWithinGroup}
import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, TypeCheckResult, UnresolvedWithinGroup}
import org.apache.spark.sql.catalyst.expressions.{Ascending, Descending, Expression, ExpressionDescription, ImplicitCastInputTypes, SortOrder}
import org.apache.spark.sql.catalyst.trees.UnaryLike
import org.apache.spark.sql.catalyst.types.PhysicalDataType
import org.apache.spark.sql.catalyst.util.GenericArrayData
import org.apache.spark.sql.catalyst.util.{CollationFactory, GenericArrayData, UnsafeRowUtils}
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.types.{AbstractDataType, AnyDataType, ArrayType, BooleanType, DataType}
import org.apache.spark.sql.types.{AbstractDataType, AnyDataType, ArrayType, BooleanType, DataType, StringType}
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.collection.OpenHashMap

case class Mode(
Expand All @@ -48,6 +49,21 @@ case class Mode(

// Declares a single input slot that accepts any data type (AnyDataType).
override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)

/**
 * Validates the data type of the child expression.
 *
 * A child whose type is a `StringType`, or for which `UnsafeRowUtils.isBinaryStable`
 * returns true, is passed on to the parent type check. Any other type is rejected
 * with a `TypeCheckFailure`.
 */
override def checkInputDataTypes(): TypeCheckResult = {
  val inputType = child.dataType
  val isSupported = inputType match {
    case _: StringType => true
    case other => UnsafeRowUtils.isBinaryStable(other)
  }
  if (!isSupported) {
    val message = "The input to the function 'mode' was a type of binary-unstable type " +
      s"that is not currently supported by ${prettyName}."
    TypeCheckResult.TypeCheckFailure(message)
  } else {
    // Mode has collation-aware handling for string data only; complex types with
    // collated fields are not yet supported.
    // TODO: SPARK-48700: Mode expression for complex types (all collations)
    super.checkInputDataTypes()
  }
}

// Display name of this expression; also interpolated into the failure message
// produced by checkInputDataTypes.
override def prettyName: String = "mode"

override def update(
Expand All @@ -74,16 +90,38 @@ case class Mode(
if (buffer.isEmpty) {
return null
}

/*
* The Mode class uses special collation awareness logic
* to handle string data types with various collations.
*
* For string types that don't support binary equality,
* we create a new map where the keys are the collation keys of the original strings.
*
* Keys from the original map are aggregated based on the corresponding collation keys.
* The groupMapReduce method groups the entries by collation key and maps each group
* to a single value (the sum of the counts), and finally reduces the groups to a single map.
*
* The new map is then used in the rest of the Mode evaluation logic.
*/
val collationAwareBuffer = child.dataType match {
case c: StringType if
!CollationFactory.fetchCollation(c.collationId).supportsBinaryEquality =>
val collationId = c.collationId
val modeMap = buffer.toSeq.groupMapReduce {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not an expert in this part of the code, but I wonder if we could do better than this.
I see that most of the logic is in OpenHashMap and OpenHashSet. In OpenHashSet the hash calculation is usually done like this: hashcode(hasher.hash(k)). If we could just get the hash to respect collation, the problem might be solved.

On collation level we do have Collation.hashFunction. Can we somehow pass this to the OpenHashSet?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@dbatomic What you are proposing would make sense. The complexity is increased, but I can whip up a draft PR and we can see whether it makes sense to proceed.

Copy link
Contributor Author

@GideonPotok GideonPotok Jun 5, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@dbatomic @uros-db here is a mockup/proof of concept of this proposal: #46917.

The relevant unit test has passed, which indicates that this approach is viable! Now, we need to consider whether to advance and determine how to integrate the relevant information about key datatype into the OpenHashMap. What are your thoughts on the feasibility of moving forward?

I'm primarily concerned about the risks involved: Integrating collation with specialized types and complex hash functions might lead to subtle bugs. Considering the crucial nature of this data structure, we should approach any changes with a detailed plan for validation and with caution. It may be wise to consider less invasive modifications, such as the one proposed in this PR (#46597).

Despite these concerns, this approach is functioning, and it touches on a particularly intriguing part of the codebase that I am eager to work on. If you think it's a promising route, I'm ready to complete the implementation and perform further benchmarks. However, I would appreciate some design suggestions as mentioned below.

To effectively implement this, I see two possible directions:

  1. Is there a benefit to using AnyRef (as in OpenHashMap[AnyRef, ...]) by TypedAggregateWithHashMapAsBuffer? This was introduced here: https://github.com/apache/spark/pull/37216/files without a clear explanation of why AnyRef was preferred over generics. Should TypedAggregateWithHashMapAsBuffer remain unchanged, or should it evolve to rely on (Pseudo Code) OpenHashMap[childExpression.dataType.getClass, ...] for more specific typing? @beliefer, although it’s been some time since you worked on this, could you advise on whether this component should be modified?
  2. Assuming TypedAggregateWithHashMapAsBuffer remains unchanged, I'm seeking a more effective method to inject the custom hashing logic (and a custom keyExistsAtPos method) from Mode into the OpenHashMap, depending on the childExpr.dataType. I would greatly value ideas on how to best integrate this. At the moment, the proof of concept is assuming any object passed into OpenHashSet that is not Long,Int,Double, or Float is a UTF8String with UTF8_BINARY_LCASE collation.

Lastly, while I am eager to complete the implementation, I hope to ensure that this is something you would definitively want to pursue, barring any significant performance setbacks revealed by benchmarking. I've developed this proof of concept and it's operational, but a full implementation should ideally be something you are confident is the right direction.

Copy link
Contributor Author

@GideonPotok GideonPotok Jun 11, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@uros-db @dbatomic Update: #46917 is, at this point, pretty close to what a finished version could look like. Please let me know which you prefer: this approach, or that one, which features the use of a collation-aware hash function within OpenHashSet. Thank you.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I left some comments in #46917 - overall, I'm liking the new approach, but I agree with you that groupMapReduce was far less invasive... Personally, I'm fine with either approach and don't have a hunch on why one might obviously prevail over the other at this point, but there might be good value in drilling deeper and running benchmarks on #46917 - then we would be able to compare, discuss, and opt for our winner.

After all, we shouldn't spend much more time here, so I suggest the following - if you're feeling up for it @GideonPotok please continue investigating further with the new approach (OHM) until we reach benchmark results for a final decision. If not, I would be happy with cleaning everything up with the old approach (GMR). After all, we can always re-visit this in the future

How does this sound? @GideonPotok @dbatomic

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can create another PR to change the method signature.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@beliefer sounds good. Once it is ready I will tag you in that one, assuming we proceed with that approach.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@dbatomic see comment here #46917 (comment)

And let us know if you are on board to continue with this approach.

It's always an option to not have mode support collations. But if supporting collations, this is the best way to go. We have tried a bunch of approaches and this way is simple, tested, decoupled and at the correct layer of abstraction, and easy to modify.

I can also add support for complex types and for PandasMode in this pr once I know you are on board with this approach.

case (k, _) => CollationFactory.getCollationKey(k.asInstanceOf[UTF8String], collationId)
}(x => x)((x, y) => (x._1, x._2 + y._2)).values
modeMap
case _ => buffer
}
reverseOpt.map { reverse =>
val defaultKeyOrdering = if (reverse) {
PhysicalDataType.ordering(child.dataType).asInstanceOf[Ordering[AnyRef]].reverse
} else {
PhysicalDataType.ordering(child.dataType).asInstanceOf[Ordering[AnyRef]]
}
val ordering = Ordering.Tuple2(Ordering.Long, defaultKeyOrdering)
buffer.maxBy { case (key, count) => (count, key) }(ordering)
}.getOrElse(buffer.maxBy(_._2))._1
collationAwareBuffer.maxBy { case (key, count) => (count, key) }(ordering)
}.getOrElse(collationAwareBuffer.maxBy(_._2))._1
}

override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): Mode =
Expand Down Expand Up @@ -128,6 +166,7 @@ case class Mode(
copy(child = newChild)
}

// TODO: SPARK-48701: PandasMode (all collations)
// scalastyle:off line.size.limit
@ExpressionDescription(
usage = """
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,13 @@ import scala.collection.immutable.Seq

import org.apache.spark.{SparkConf, SparkException, SparkIllegalArgumentException, SparkRuntimeException}
import org.apache.spark.sql.catalyst.ExtendedAnalysisException
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.expressions.aggregate.Mode
import org.apache.spark.sql.internal.{SqlApiConf, SQLConf}
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.collection.OpenHashMap

// scalastyle:off nonascii
class CollationSQLExpressionsSuite
Expand Down Expand Up @@ -1646,6 +1650,216 @@ class CollationSQLExpressionsSuite
}
}

// Smoke test: `mode` over a single collated literal returns that literal, and the
// result column's type matches StringType for the requested collation.
test("Support mode for string expression with collation - Basic Test") {
  val collationNames = Seq("utf8_binary", "UTF8_LCASE", "unicode_ci", "unicode")
  for (collationName <- collationNames) {
    val query = s"SELECT mode(collate('abc', '${collationName}'))"
    checkAnswer(sql(query), Row("abc"))
    val resultType = sql(query).schema.fields.head.dataType
    assert(resultType.sameType(StringType(collationName)))
  }
}

// End-to-end check of `mode` over a table column under several collations.
// `bufferValues` maps each inserted string to the number of rows carrying it;
// `result` is the expected mode for the given collation.
test("Support mode for string expression with collation - Advanced Test") {
  case class ModeTestCase[R](collationId: String, bufferValues: Map[String, Long], result: R)
  val testCases = Seq(
    ModeTestCase("utf8_binary", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "a"),
    ModeTestCase("UTF8_LCASE", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "b"),
    ModeTestCase("unicode_ci", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "b"),
    ModeTestCase("unicode", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "a")
  )
  testCases.foreach { t =>
    // Emit exactly `numRepeats` rows per value. The previous inclusive range
    // `0L to numRepeats` produced one extra row for every value.
    val valuesToAdd = t.bufferValues.toSeq.flatMap { case (elt, numRepeats) =>
      Seq.fill(numRepeats.toInt)(s"('$elt')")
    }.mkString(",")

    val tableName = s"t_${t.collationId}_mode"
    withTable(tableName) {
      sql(s"CREATE TABLE ${tableName}(i STRING) USING parquet")
      sql(s"INSERT INTO ${tableName} VALUES " + valuesToAdd)
      val query = s"SELECT mode(collate(i, '${t.collationId}')) FROM ${tableName}"
      checkAnswer(sql(query), Row(t.result))
      assert(sql(query).schema.fields.head.dataType.sameType(StringType(t.collationId)))
    }
  }
}

// Exercises Mode.eval directly on a hand-built OpenHashMap buffer, bypassing SQL.
// The comparison lowercases both sides, so any case variant of the expected key passes.
test("Support Mode.eval(buffer)") {
  case class UTF8StringModeTestCase[R](
      collationId: String,
      bufferValues: Map[UTF8String, Long],
      result: R)

  // Occurrence counts keyed by the raw (un-collated) string value.
  val countsByValue = Map(
    UTF8String.fromString("a") -> 5L,
    UTF8String.fromString("b") -> 4L,
    UTF8String.fromString("B") -> 3L,
    UTF8String.fromString("d") -> 2L,
    UTF8String.fromString("e") -> 1L)

  val testCasesUTF8String = Seq(
    UTF8StringModeTestCase("utf8_binary", countsByValue, "a"),
    UTF8StringModeTestCase("UTF8_LCASE", countsByValue, "b"),
    UTF8StringModeTestCase("unicode_ci", countsByValue, "b"),
    UTF8StringModeTestCase("unicode", countsByValue, "a"))

  for (tc <- testCasesUTF8String) {
    val buffer = new OpenHashMap[AnyRef, Long](5)
    val modeExpr = Mode(child = Literal.create("some_column_name", StringType(tc.collationId)))
    tc.bufferValues.foreach { case (value, count) => buffer.update(value, count) }
    val actual = modeExpr.eval(buffer).toString.toLowerCase()
    assert(actual == tc.result.toLowerCase())
  }
}

// `mode` over a STRUCT column that contains a collated string field.
// For the collations listed in `unsupportedCollations` the query is expected to fail
// analysis with a type-check error; otherwise the mode of field `f1` is verified.
test("Support mode for string expression with collated strings in struct") {
  case class ModeTestCase[R](collationId: String, bufferValues: Map[String, Long], result: R)
  val testCases = Seq(
    ModeTestCase("utf8_binary", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "a"),
    ModeTestCase("UTF8_LCASE", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "b"),
    ModeTestCase("unicode", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "a"),
    ModeTestCase("unicode_ci", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "b")
  )
  val unsupportedCollations = Set("UTF8_LCASE", "unicode_ci", "unicode")
  testCases.foreach { t =>
    // Emit exactly `numRepeats` rows per value. The previous inclusive range
    // `0L to numRepeats` produced one extra row for every value.
    val valuesToAdd = t.bufferValues.toSeq.flatMap { case (elt, numRepeats) =>
      Seq.fill(numRepeats.toInt)(
        s"named_struct('f1', collate('$elt', '${t.collationId}'), 'f2', 1)")
    }.mkString(",")

    val tableName = s"t_${t.collationId}_mode_struct"
    withTable(tableName) {
      sql(s"CREATE TABLE ${tableName}(i STRUCT<f1: STRING COLLATE " +
        t.collationId + ", f2: INT>) USING parquet")
      sql(s"INSERT INTO ${tableName} VALUES " + valuesToAdd)
      val query = s"SELECT lower(mode(i).f1) FROM ${tableName}"
      if (unsupportedCollations.contains(t.collationId)) {
        // `mode(i)` must be rejected by the type check with the message asserted below.
        val params = Map(
          "sqlExpr" -> "\"mode(i)\"",
          "msg" -> ("The input to the function 'mode' was a type of binary-unstable type " +
            "that is not currently supported by mode."),
          "hint" -> "")
        checkError(
          exception = intercept[AnalysisException] {
            sql(query)
          },
          errorClass = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT",
          parameters = params,
          queryContext = Array(
            ExpectedContext(objectType = "",
              objectName = "",
              startIndex = 13,
              stopIndex = 19,
              fragment = "mode(i)")
          )
        )
      } else {
        checkAnswer(sql(query), Row(t.result))
      }
    }
  }
}

// `mode` over a STRUCT column whose collated string lives inside a nested STRUCT.
// For the collations listed in `unsupportedCollations` the query is expected to fail
// analysis with a type-check error; otherwise the mode of field `f1.f2` is verified.
test("Support mode for string expression with collated strings in recursively nested struct") {
  case class ModeTestCase[R](collationId: String, bufferValues: Map[String, Long], result: R)
  val testCases = Seq(
    ModeTestCase("utf8_binary", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "a"),
    ModeTestCase("UTF8_LCASE", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "b"),
    ModeTestCase("unicode", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "a"),
    ModeTestCase("unicode_ci", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "b")
  )
  val unsupportedCollations = Set("UTF8_LCASE", "unicode_ci", "unicode")
  testCases.foreach { t =>
    // Emit exactly `numRepeats` rows per value. The previous inclusive range
    // `0L to numRepeats` produced one extra row for every value.
    val valuesToAdd = t.bufferValues.toSeq.flatMap { case (elt, numRepeats) =>
      Seq.fill(numRepeats.toInt)(
        "named_struct('f1', " +
          s"named_struct('f2', collate('$elt', '${t.collationId}')), 'f3', 1)")
    }.mkString(",")

    val tableName = s"t_${t.collationId}_mode_nested_struct"
    withTable(tableName) {
      sql(s"CREATE TABLE ${tableName}(i STRUCT<f1: STRUCT<f2: STRING COLLATE " +
        t.collationId + ">, f3: INT>) USING parquet")
      sql(s"INSERT INTO ${tableName} VALUES " + valuesToAdd)
      val query = s"SELECT lower(mode(i).f1.f2) FROM ${tableName}"
      if (unsupportedCollations.contains(t.collationId)) {
        // `mode(i)` must be rejected by the type check with the message asserted below.
        val params = Map(
          "sqlExpr" -> "\"mode(i)\"",
          "msg" -> ("The input to the function 'mode' was a type of binary-unstable type " +
            "that is not currently supported by mode."),
          "hint" -> "")
        checkError(
          exception = intercept[AnalysisException] {
            sql(query)
          },
          errorClass = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT",
          parameters = params,
          queryContext = Array(
            ExpectedContext(objectType = "",
              objectName = "",
              startIndex = 13,
              stopIndex = 19,
              fragment = "mode(i)")
          )
        )
      } else {
        checkAnswer(sql(query), Row(t.result))
      }
    }
  }
}

// `mode` over an ARRAY<STRUCT<...>> column whose innermost element is a collated string.
// For the collations listed in `unsupportedCollations` the query is expected to fail
// analysis with a type-check error; otherwise the innermost string of the mode is verified.
test("Support mode for string expression with collated strings in array complex type") {
  case class ModeTestCase[R](collationId: String, bufferValues: Map[String, Long], result: R)
  val testCases = Seq(
    ModeTestCase("utf8_binary", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "a"),
    ModeTestCase("UTF8_LCASE", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "b"),
    ModeTestCase("unicode", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "a"),
    ModeTestCase("unicode_ci", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "b")
  )
  val unsupportedCollations = Set("UTF8_LCASE", "unicode_ci", "unicode")
  testCases.foreach { t =>
    // Emit exactly `numRepeats` rows per value. The previous inclusive range
    // `0L to numRepeats` produced one extra row for every value.
    val valuesToAdd = t.bufferValues.toSeq.flatMap { case (elt, numRepeats) =>
      Seq.fill(numRepeats.toInt)(
        "array(named_struct('s1', named_struct('a2', " +
          s"array(collate('$elt', '${t.collationId}'))), 'f3', 1))")
    }.mkString(",")

    // Distinct suffix so this table name no longer duplicates the nested-struct test's.
    val tableName = s"t_${t.collationId}_mode_array"
    withTable(tableName) {
      sql(s"CREATE TABLE ${tableName}(" +
        s"i ARRAY<STRUCT<s1: STRUCT<a2: ARRAY<STRING COLLATE ${t.collationId}>>, f3: INT>>)" +
        s" USING parquet")
      sql(s"INSERT INTO ${tableName} VALUES " + valuesToAdd)
      val query = s"SELECT lower(element_at(element_at(mode(i), 1).s1.a2, 1)) FROM ${tableName}"
      if (unsupportedCollations.contains(t.collationId)) {
        // `mode(i)` must be rejected by the type check with the message asserted below.
        val params = Map(
          "sqlExpr" -> "\"mode(i)\"",
          "msg" -> ("The input to the function 'mode' was a type of binary-unstable type " +
            "that is not currently supported by mode."),
          "hint" -> "")
        checkError(
          exception = intercept[AnalysisException] {
            sql(query)
          },
          errorClass = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT",
          parameters = params,
          queryContext = Array(
            ExpectedContext(objectType = "",
              objectName = "",
              startIndex = 35,
              stopIndex = 41,
              fragment = "mode(i)")
          )
        )
      } else {
        checkAnswer(sql(query), Row(t.result))
      }
    }
  }
}

test("SPARK-48430: Map value extraction with collations") {
for {
collateKey <- Seq(true, false)
Expand Down