diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Mode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Mode.scala index d1a9cafdf61fa..878313c28f05c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Mode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Mode.scala @@ -22,11 +22,13 @@ import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, UnresolvedWith import org.apache.spark.sql.catalyst.expressions.{Ascending, Descending, Expression, ExpressionDescription, ImplicitCastInputTypes, SortOrder} import org.apache.spark.sql.catalyst.trees.UnaryLike import org.apache.spark.sql.catalyst.types.PhysicalDataType -import org.apache.spark.sql.catalyst.util.GenericArrayData +import org.apache.spark.sql.catalyst.util.{CollationFactory, GenericArrayData} import org.apache.spark.sql.errors.QueryCompilationErrors -import org.apache.spark.sql.types.{AbstractDataType, AnyDataType, ArrayType, BooleanType, DataType} +import org.apache.spark.sql.types.{AbstractDataType, AnyDataType, ArrayType, BooleanType, DataType, StringType} +import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.collection.OpenHashMap + case class Mode( child: Expression, mutableAggBufferOffset: Int = 0, @@ -35,6 +37,8 @@ case class Mode( extends TypedAggregateWithHashMapAsBuffer with ImplicitCastInputTypes with SupportsOrderingWithinGroup with UnaryLike[Expression] { + final lazy val collationId: Int = child.dataType.asInstanceOf[StringType].collationId + def this(child: Expression) = this(child, 0, 0) def this(child: Expression, reverse: Boolean) = { @@ -74,6 +78,19 @@ case class Mode( if (buffer.isEmpty) { return null } + val collationAwareBuffer = + if (!CollationFactory.fetchCollation(collationId).supportsBinaryEquality) { + val modeMap = buffer.toSeq.groupMapReduce { + case (key: String, _) => + CollationFactory.getCollationKey(UTF8String.fromString(key), collationId) + case (key: UTF8String, _) => + CollationFactory.getCollationKey(key, collationId) + case (key, _) => key + }(x => x)((x, y) => (x._1, x._2 + y._2)).values + modeMap + } else { + buffer + } reverseOpt.map { reverse => val defaultKeyOrdering = if (reverse) { @@ -82,8 +99,8 @@ case class Mode( PhysicalDataType.ordering(child.dataType).asInstanceOf[Ordering[AnyRef]] } val ordering = Ordering.Tuple2(Ordering.Long, defaultKeyOrdering) - buffer.maxBy { case (key, count) => (count, key) }(ordering) - }.getOrElse(buffer.maxBy(_._2))._1 + collationAwareBuffer.maxBy { case (key, count) => (count, key) }(ordering) + }.getOrElse(collationAwareBuffer.maxBy(_._2))._1 } override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): Mode = @@ -249,7 +266,7 @@ case class PandasMode( val (key, count) = iter.next() if (maxCount < count) { modes.clear() - modes.append(key) + modes. append(key) maxCount = count } else if (maxCount == count) { modes.append(key) diff --git a/sql/core/benchmarks/CollationBenchmark-jdk21-results.txt b/sql/core/benchmarks/CollationBenchmark-jdk21-results.txt index 500e39a965f37..5b71f71d6d859 100644 --- a/sql/core/benchmarks/CollationBenchmark-jdk21-results.txt +++ b/sql/core/benchmarks/CollationBenchmark-jdk21-results.txt @@ -2,53 +2,61 @@ OpenJDK 64-Bit Server VM 21.0.3+9-LTS on Linux 6.5.0-1018-azure AMD EPYC 7763 64-Core Processor collation unit benchmarks - equalsFunction: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative -------------------------------------------------------------------------------------------------------------------------- -UTF8_BINARY_LCASE 2948 2958 13 0.0 29483.6 1.0X -UNICODE 2040 2042 3 0.0 20396.6 1.4X -UTF8_BINARY 2043 2043 0 0.0 20426.3 1.4X -UNICODE_CI 16318 16338 28 0.0 163178.4 0.2X +UTF8_BINARY_LCASE 2889 2923 48 0.0 28892.1 1.0X +UNICODE 2748 2748 1 0.0 27476.5 1.1X +UTF8_BINARY 2744 2745 1 0.0 27439.5 1.1X +UNICODE_CI 16815 16817 2 0.0 168154.3 0.2X OpenJDK 64-Bit Server VM 21.0.3+9-LTS on Linux 6.5.0-1018-azure AMD EPYC 7763 64-Core Processor collation unit benchmarks - compareFunction: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative --------------------------------------------------------------------------------------------------------------------------- -UTF8_BINARY_LCASE 3227 3228 1 0.0 32272.1 1.0X -UNICODE 16637 16643 9 0.0 166367.7 0.2X -UTF8_BINARY 3132 3137 7 0.0 31319.2 1.0X -UNICODE_CI 17816 17829 18 0.0 178162.4 0.2X +UTF8_BINARY_LCASE 4782 4784 3 0.0 47819.3 1.0X +UNICODE 18986 18995 13 0.0 189855.8 0.3X +UTF8_BINARY 5026 5048 31 0.0 50258.2 1.0X +UNICODE_CI 19735 19771 50 0.0 197351.1 0.2X OpenJDK 64-Bit Server VM 21.0.3+9-LTS on Linux 6.5.0-1018-azure AMD EPYC 7763 64-Core Processor collation unit benchmarks - hashFunction: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -UTF8_BINARY_LCASE 4824 4824 0 0.0 48243.7 1.0X -UNICODE 69416 69475 84 0.0 694158.3 0.1X -UTF8_BINARY 3806 3808 2 0.0 38062.8 1.3X -UNICODE_CI 60943 60975 45 0.0 609426.2 0.1X +UTF8_BINARY_LCASE 4933 4933 1 0.0 49330.9 1.0X +UNICODE 68091 68119 40 0.0 680908.8 0.1X +UTF8_BINARY 3878 3879 2 0.0 38782.4 1.3X +UNICODE_CI 55501 55526 35 0.0 555014.2 0.1X OpenJDK 64-Bit Server VM 21.0.3+9-LTS on Linux 6.5.0-1018-azure AMD EPYC 7763 64-Core Processor collation unit benchmarks - contains: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -UTF8_BINARY_LCASE 11979 11980 1 0.0 119790.4 1.0X -UNICODE 6469 6474 7 0.0 64694.8 1.9X -UTF8_BINARY 7253 7253 1 0.0 72528.3 1.7X -UNICODE_CI 319124 319881 1070 0.0 3191244.0 0.0X +UTF8_BINARY_LCASE 10441 10444 4 0.0 104412.3 1.0X +UNICODE 5811 5812 1 0.0 58106.6 1.8X +UTF8_BINARY 6397 6411 19 0.0 63971.7 1.6X +UNICODE_CI 323853 324618 1082 0.0 3238530.0 0.0X OpenJDK 64-Bit Server VM 21.0.3+9-LTS on Linux 6.5.0-1018-azure AMD EPYC 7763 64-Core Processor collation unit benchmarks - startsWith: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -UTF8_BINARY_LCASE 11584 11595 15 0.0 115841.4 1.0X -UNICODE 6155 6156 2 0.0 61548.7 1.9X -UTF8_BINARY 6979 6982 5 0.0 69785.6 1.7X -UNICODE_CI 318228 318726 705 0.0 3182275.2 0.0X +UTF8_BINARY_LCASE 10123 10154 44 0.0 101227.4 1.0X +UNICODE 5682 5686 7 0.0 56815.0 1.8X +UTF8_BINARY 6296 6300 5 0.0 62961.9 1.6X +UNICODE_CI 318720 318957 336 0.0 3187199.4 0.0X OpenJDK 64-Bit Server VM 21.0.3+9-LTS on Linux 6.5.0-1018-azure AMD EPYC 7763 64-Core Processor collation unit benchmarks - endsWith: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -UTF8_BINARY_LCASE 11655 11664 12 0.0 116552.8 1.0X -UNICODE 6235 6239 5 0.0 62350.8 1.9X -UTF8_BINARY 7066 7069 5 0.0 70658.1 1.6X -UNICODE_CI 313515 313999 685 0.0 3135149.1 0.0X +UTF8_BINARY_LCASE 10195 10198 5 0.0 101948.5 1.0X +UNICODE 5731 5732 1 0.0 57314.8 1.8X +UTF8_BINARY 6344 6366 31 0.0 63443.6 1.6X +UNICODE_CI 324196 325450 1772 0.0 3241964.4 0.0X +OpenJDK 64-Bit Server VM 21.0.3+9-LTS on Linux 6.5.0-1018-azure +AMD EPYC 7763 64-Core Processor +collation unit benchmarks - mode - 30105 elements: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +--------------------------------------------------------------------------------------------------------------------------------- +UTF8_BINARY_LCASE - mode - 30105 elements 6 6 0 5.1 195.6 1.0X +UNICODE - mode - 30105 elements 3 3 0 11.6 86.0 2.3X +UTF8_BINARY - mode - 30105 elements 3 3 0 11.6 85.9 2.3X +UNICODE_CI - mode - 30105 elements 12 12 1 2.6 382.9 0.5X diff --git a/sql/core/benchmarks/CollationBenchmark-results.txt b/sql/core/benchmarks/CollationBenchmark-results.txt index 1e0515b182862..d889fff8f2b3d 100644 --- a/sql/core/benchmarks/CollationBenchmark-results.txt +++ b/sql/core/benchmarks/CollationBenchmark-results.txt @@ -2,53 +2,61 @@ OpenJDK 64-Bit Server VM 17.0.11+9-LTS on Linux 6.5.0-1018-azure AMD EPYC 7763 64-Core Processor collation unit benchmarks - equalsFunction: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative -------------------------------------------------------------------------------------------------------------------------- -UTF8_BINARY_LCASE 3571 3576 7 0.0 35708.8 1.0X -UNICODE 2235 2240 7 0.0 22349.2 1.6X -UTF8_BINARY 2237 2242 6 0.0 22371.7 1.6X -UNICODE_CI 18733 18817 118 0.0 187333.8 0.2X +UTF8_BINARY_LCASE 3260 3264 7 0.0 32595.0 1.0X +UNICODE 2783 2784 1 0.0 27834.5 1.2X +UTF8_BINARY 2789 2789 0 0.0 27889.1 1.2X +UNICODE_CI 17545 17548 5 0.0 175445.8 0.2X OpenJDK 64-Bit Server VM 17.0.11+9-LTS on Linux 6.5.0-1018-azure AMD EPYC 7763 64-Core Processor collation unit benchmarks - compareFunction: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative --------------------------------------------------------------------------------------------------------------------------- -UTF8_BINARY_LCASE 4260 4290 41 0.0 42602.6 1.0X -UNICODE 19536 19624 124 0.0 195360.2 0.2X -UTF8_BINARY 3582 3612 43 0.0 35818.5 1.2X -UNICODE_CI 20381 20454 103 0.0 203814.1 0.2X +UTF8_BINARY_LCASE 3716 3736 27 0.0 37164.3 1.0X +UNICODE 18425 18429 6 0.0 184247.4 0.2X +UTF8_BINARY 3192 3198 9 0.0 31922.3 1.2X +UNICODE_CI 19072 19079 10 0.0 190718.3 0.2X OpenJDK 64-Bit Server VM 17.0.11+9-LTS on Linux 6.5.0-1018-azure AMD EPYC 7763 64-Core Processor collation unit benchmarks - hashFunction: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -UTF8_BINARY_LCASE 7347 7349 3 0.0 73467.1 1.0X -UNICODE 73462 73608 206 0.0 734623.2 0.1X -UTF8_BINARY 5775 5815 57 0.0 57746.0 1.3X -UNICODE_CI 57543 57619 108 0.0 575425.2 0.1X +UTF8_BINARY_LCASE 7051 7053 4 0.0 70505.6 1.0X +UNICODE 64901 64941 58 0.0 649006.0 0.1X +UTF8_BINARY 5461 5501 57 0.0 54612.8 1.3X +UNICODE_CI 59907 59972 91 0.0 599073.6 0.1X OpenJDK 64-Bit Server VM 17.0.11+9-LTS on Linux 6.5.0-1018-azure AMD EPYC 7763 64-Core Processor collation unit benchmarks - contains: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -UTF8_BINARY_LCASE 15415 15424 13 0.0 154147.1 1.0X -UNICODE 8091 8108 25 0.0 80907.9 1.9X -UTF8_BINARY 8964 8979 21 0.0 89643.5 1.7X -UNICODE_CI 469123 474822 8060 0.0 4691227.7 0.0X +UTF8_BINARY_LCASE 13702 13706 5 0.0 137020.5 1.0X +UNICODE 7306 7309 5 0.0 73056.6 1.9X +UTF8_BINARY 8077 8079 4 0.0 80765.4 1.7X +UNICODE_CI 311083 311372 409 0.0 3110831.3 0.0X OpenJDK 64-Bit Server VM 17.0.11+9-LTS on Linux 6.5.0-1018-azure AMD EPYC 7763 64-Core Processor collation unit benchmarks - startsWith: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -UTF8_BINARY_LCASE 13064 13080 23 0.0 130635.2 1.0X -UNICODE 6836 6851 22 0.0 68360.1 1.9X -UTF8_BINARY 7693 7719 36 0.0 76933.9 1.7X -UNICODE_CI 488919 495530 9349 0.0 4889190.5 0.0X +UTF8_BINARY_LCASE 11665 11676 16 0.0 116650.6 1.0X +UNICODE 5983 5987 5 0.0 59832.3 1.9X +UTF8_BINARY 6680 6701 29 0.0 66803.2 1.7X +UNICODE_CI 307047 307098 72 0.0 3070474.8 0.0X OpenJDK 64-Bit Server VM 17.0.11+9-LTS on Linux 6.5.0-1018-azure AMD EPYC 7763 64-Core Processor collation unit benchmarks - endsWith: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -UTF8_BINARY_LCASE 13097 13112 21 0.0 130970.4 1.0X -UNICODE 6960 6985 34 0.0 69603.9 1.9X -UTF8_BINARY 7766 7768 3 0.0 77663.5 1.7X -UNICODE_CI 456956 470733 19485 0.0 4569556.7 0.0X +UTF8_BINARY_LCASE 11707 11712 7 0.0 117068.1 1.0X +UNICODE 6062 6064 2 0.0 60618.8 1.9X +UTF8_BINARY 6730 6730 1 0.0 67300.3 1.7X +UNICODE_CI 314370 314565 276 0.0 3143696.2 0.0X +OpenJDK 64-Bit Server VM 17.0.11+9-LTS on Linux 6.5.0-1018-azure +AMD EPYC 7763 64-Core Processor +collation unit benchmarks - mode - 30105 elements: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +--------------------------------------------------------------------------------------------------------------------------------- +UTF8_BINARY_LCASE - mode - 30105 elements 7 7 0 4.4 224.8 1.0X +UNICODE - mode - 30105 elements 3 3 0 9.4 105.8 2.1X +UTF8_BINARY - mode - 30105 elements 3 3 0 9.7 103.0 2.2X +UNICODE_CI - mode - 30105 elements 11 12 1 2.7 371.3 0.6X diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala index f8b3548b956ce..42cef2561d29a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala @@ -21,10 +21,14 @@ import java.text.SimpleDateFormat import scala.collection.immutable.Seq +import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.catalyst.expressions.aggregate.Mode import org.apache.spark.{SparkException, SparkIllegalArgumentException, SparkRuntimeException} import org.apache.spark.sql.internal.SqlApiConf import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.util.collection.OpenHashMap // scalastyle:off nonascii class CollationSQLExpressionsSuite @@ -1404,6 +1408,82 @@ class CollationSQLExpressionsSuite }) } + test("Support mode for string expression with collation - Basic Test") { + Seq("utf8_binary", "utf8_binary_lcase", "unicode_ci", "unicode").foreach { collationId => + val query = s"SELECT mode(collate('abc', '${collationId}'))" + checkAnswer(sql(query), Row("abc")) + assert(sql(query).schema.fields.head.dataType.sameType(StringType(collationId))) + } + } + + test("Support mode for string expression with collation - Advanced Test") { + case class ModeTestCase[R](collationId: String, bufferValues: Map[String, Long], result: R) + val testCases = Seq( + ModeTestCase("utf8_binary", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "a"), + ModeTestCase("utf8_binary_lcase", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "b"), + ModeTestCase("unicode_ci", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "b"), + ModeTestCase("unicode", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "a") + ) + testCases.foreach(t => { + val valuesToAdd = t.bufferValues.map { case (elt, numRepeats) => + (0L to numRepeats).map(_ => s"('$elt')").mkString(",") + }.mkString(",") + + withTable("t") { + sql("CREATE TABLE t(i STRING) USING parquet") + sql("INSERT INTO t VALUES " + valuesToAdd) + val query = s"SELECT mode(collate(i, '${t.collationId}')) FROM t" + checkAnswer(sql(query), Row(t.result)) + assert(sql(query).schema.fields.head.dataType.sameType(StringType(t.collationId))) + + } + }) + } + + test("Support Mode.eval(buffer)") { + case class ModeTestCase[R]( + collationId: String, + bufferValues: Map[String, Long], + result: R) + case class UTF8StringModeTestCase[R]( + collationId: String, + bufferValues: Map[UTF8String, Long], + result: R) + + val bufferValues = Map("a" -> 5L, "b" -> 4L, "B" -> 3L, "d" -> 2L, "e" -> 1L) + val testCasesStrings = Seq(ModeTestCase("utf8_binary", bufferValues, "a"), + ModeTestCase("utf8_binary_lcase", bufferValues, "b"), + ModeTestCase("unicode_ci", bufferValues, "b"), + ModeTestCase("unicode", bufferValues, "a")) + + val bufferValuesUTF8String = Map( + UTF8String.fromString("a") -> 5L, + UTF8String.fromString("b") -> 4L, + UTF8String.fromString("B") -> 3L, + UTF8String.fromString("d") -> 2L, + UTF8String.fromString("e") -> 1L) + + val testCasesUTF8String = Seq( + UTF8StringModeTestCase("utf8_binary", bufferValuesUTF8String, "a"), + UTF8StringModeTestCase("utf8_binary_lcase", bufferValuesUTF8String, "b"), + UTF8StringModeTestCase("unicode_ci", bufferValuesUTF8String, "b"), + UTF8StringModeTestCase("unicode", bufferValuesUTF8String, "a")) + + testCasesStrings.foreach(t => { + val buffer = new OpenHashMap[AnyRef, Long](5) + val myMode = Mode(child = Literal.create("some_column_name", StringType(t.collationId))) + t.bufferValues.foreach { case (k, v) => buffer.update(k, v) } + assert(myMode.eval(buffer).toString.toLowerCase() == t.result.toLowerCase()) + }) + + testCasesUTF8String.foreach(t => { + val buffer = new OpenHashMap[AnyRef, Long](5) + val myMode = Mode(child = Literal.create("some_column_name", StringType(t.collationId))) + t.bufferValues.foreach { case (k, v) => buffer.update(k, v) } + assert(myMode.eval(buffer).toString.toLowerCase() == t.result.toLowerCase()) + }) + } + // TODO: Add more tests for other SQL expressions } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/CollationBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/CollationBenchmark.scala index 03b638b786bfd..949d9b30b3c19 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/CollationBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/CollationBenchmark.scala @@ -19,8 +19,12 @@ package org.apache.spark.sql.execution.benchmark import scala.concurrent.duration._ import org.apache.spark.benchmark.{Benchmark, BenchmarkBase} +import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.catalyst.expressions.aggregate.Mode import org.apache.spark.sql.catalyst.util.{CollationFactory, CollationSupport} +import org.apache.spark.sql.types.StringType import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.util.collection.OpenHashMap abstract class CollationBenchmarkBase extends BenchmarkBase { protected val collationTypes: Seq[String] = @@ -185,6 +189,30 @@ abstract class CollationBenchmarkBase extends BenchmarkBase { } benchmark.run() } + + def benchmarkMode( + collationTypes: Seq[String], + value: Seq[UTF8String]): Unit = { + val benchmark = new Benchmark( + s"collation unit benchmarks - mode - ${value.size} elements", + value.size, + warmupTime = 10.seconds, + output = output) + collationTypes.foreach { collationType => { + benchmark.addCase(s"$collationType - mode - ${value.size} elements") { _ => + val modeDefaultCollation = Mode(child = + Literal.create("some_column_name", StringType(collationType))) + val buffer = new OpenHashMap[AnyRef, Long](value.size) + value.foreach(v => { + buffer.update(v.toString, (v.hashCode() % 1000).toLong) + }) + modeDefaultCollation.eval(buffer) + } + } + } + + benchmark.run() + } } /** @@ -201,16 +229,36 @@ abstract class CollationBenchmarkBase extends BenchmarkBase { */ object CollationBenchmark extends CollationBenchmarkBase { - override def generateSeqInput(n: Long): Seq[UTF8String] = { - val input = Seq("ABC", "ABC", "aBC", "aBC", "abc", "abc", "DEF", "DEF", "def", "def", - "GHI", "ghi", "JKL", "jkl", "MNO", "mno", "PQR", "pqr", "STU", "stu", "VWX", "vwx", - "ABC", "ABC", "aBC", "aBC", "abc", "abc", "DEF", "DEF", "def", "def", "GHI", "ghi", - "JKL", "jkl", "MNO", "mno", "PQR", "pqr", "STU", "stu", "VWX", "vwx", "YZ") - .map(UTF8String.fromString) - val inputLong: Seq[UTF8String] = (0L until n).map(i => input(i.toInt % input.size)) - inputLong + private val baseInputStrings = Seq("ABC", "ABC", "aBC", "aBC", "abc", + "abc", "DEF", "DEF", "def", "def", + "GHI", "ghi", "JKL", "jkl", "MNO", "mno", "PQR", "pqr", "STU", "stu", "VWX", "vwx", + "ABC", "ABC", "aBC", "aBC", "abc", "abc", "DEF", "DEF", "def", "def", "GHI", "ghi", + "JKL", "jkl", "MNO", "mno", "PQR", "pqr", "STU", "stu", "VWX", "vwx", "YZ") + + + /* + * Generate input strings for the benchmark. The input strings are a sequence of base strings + * repeated n / input.size times. + */ + private def generateBaseInputStrings(n: Long): Seq[UTF8String] = { + val input = baseInputStrings.map(UTF8String.fromString) + (0L until n).map(i => input(i.toInt % input.size)) } + /* + Lowercase and some repeated strings to test the performance of the collation functions. + */ + def generateBaseInputStringswithUniqueGroupNumber(n: Long): Seq[UTF8String] = { + (0 to n.toInt / baseInputStrings.size).flatMap(k => baseInputStrings.map( + x => UTF8String.fromString(x + "_" + k))) + .flatMap( + x => Seq(x, x.repeat(4), x.repeat(8))) // Variable Lengths... + .sortBy(f => f.reverse().hashCode()) // Shuffle the input + } + + override def generateSeqInput(n: Long): Seq[UTF8String] = + generateBaseInputStrings(n) + override def runBenchmarkSuite(mainArgs: Array[String]): Unit = { val inputs = generateSeqInput(10000L) benchmarkUTFStringEquals(collationTypes, inputs) @@ -219,6 +267,7 @@ object CollationBenchmark extends CollationBenchmarkBase { benchmarkContains(collationTypes, inputs) benchmarkStartsWith(collationTypes, inputs) benchmarkEndsWith(collationTypes, inputs) + benchmarkMode(collationTypes, generateBaseInputStringswithUniqueGroupNumber(10000L)) } } @@ -248,5 +297,6 @@ object CollationNonASCIIBenchmark extends CollationBenchmarkBase { benchmarkContains(collationTypes, inputs) benchmarkStartsWith(collationTypes, inputs) benchmarkEndsWith(collationTypes, inputs) + benchmarkMode(collationTypes, inputs) } }