diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/SparkSQLTest1.scala b/examples/src/main/scala/org/apache/spark/examples/sql/SparkSQLTest1.scala new file mode 100644 index 000000000000..0bbd7b41c9a6 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/sql/SparkSQLTest1.scala @@ -0,0 +1,248 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.sql + +import org.apache.spark.sql.{Encoder, Encoders, Row, SparkSession} +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.expressions.Aggregator +import org.apache.spark.sql.internal.SQLConf._ +import org.apache.spark.sql.internal.StaticSQLConf._ +import org.apache.spark.sql.types.{DataType, UserDefinedType} + +// scalastyle:off println +object SparkSQLTest1 { + + def show(rows: Array[Row]): Unit = { + var printHeader = false + for (row <- rows) { + if (!printHeader) { + val headerBuilder = new StringBuilder("|") + val lineBuilder = new StringBuilder("|") + for (field <- row.schema.fields) { + headerBuilder.append(f"${field.name}%10s").append("|") + lineBuilder.append("----------").append("|") + } + println(headerBuilder.toString()) + println(lineBuilder.toString()) + printHeader = true + } + val contextBuilder = new StringBuilder("|") + for (ele <- row.toSeq) { + contextBuilder.append(f"$ele%10s").append("|") + } + println(contextBuilder.toString) + } + } + + def example(spark: SparkSession, sqlCommand: String, + sortByShuffle: Boolean, showResult: Boolean = true): Array[Row] = { + spark.sqlContext.setConf(SORTED_SHUFFLE_ENABLED.key, sortByShuffle.toString) + val res = spark.sql(sqlCommand).collect() + if (showResult) { + show(res) + } + res + } + + def testSortByShuffle(spark: SparkSession, sqlCommand: String, + considerSort: Boolean = false, show: Boolean = true): Unit = { + var success = true + val expect = example(spark, sqlCommand, false, show) + val actual = example(spark, sqlCommand, true, show) + if (expect.length != actual.length) { + println(s"Got wrong result size, " + + s"expected size is ${expect.size}, actual size is ${actual.size}") + success = false + } else if (considerSort) { + for ((e, a) <- expect.zip(actual)) { + if (!e.equals(a)) { + println(s"Got wrong matched result, " + + s"expected result is ${e}, actual result is ${a}") + success = false + } + } + } else { + val expectSet = expect.toSet + for (a <- actual) { + if (!expectSet.contains(a)) { + println(s"Got actual result ${a} which is not expected.") + success = false + } + } + } + if (success) { + println(s"Executed ${sqlCommand}\nSUCCESS!!!") + } else { + println(s"Executed ${sqlCommand}\nFAIL!!!") + sys.exit() + } + } + + case class UDFDataType(value: Int) extends UserDefinedType[UDFDataType] { + override def
sqlType: DataType = this + override def serialize(obj: UDFDataType): Int = value + def deserialize(datum: Any): UDFDataType = datum match {case v: Int => UDFDataType(v)} + override def userClass: Class[UDFDataType] = classOf[UDFDataType] + private[spark] override def asNullable: UDFDataType = this + } + + object SumUDFDataType extends Aggregator[Integer, Integer, UDFDataType] { + override def zero: Integer = 0 + override def reduce(b: Integer, a: Integer): Integer = a + b + override def merge(b1: Integer, b2: Integer): Integer = b1 + b2 + override def finish(reduction: Integer): UDFDataType = UDFDataType(reduction) + override def bufferEncoder: Encoder[Integer] = Encoders.INT + override def outputEncoder: Encoder[UDFDataType] = ExpressionEncoder() + } + + def main(args: Array[String]): Unit = { + val spark = SparkSession.builder().appName("SparkSQLTest1"). + config("spark.memory.fraction", "0.1"). + config(CODEGEN_COMMENTS.key, "true").master("local").getOrCreate() + spark.sqlContext.setConf(ADAPTIVE_EXECUTION_ENABLED.key, "false") + spark.sqlContext.setConf(SORTED_SHUFFLE_ENABLED.key, "true") + spark.sqlContext.setConf(CODEGEN_FALLBACK.key, "false") + // disable codegen +// spark.sqlContext.setConf(WHOLESTAGE_HUGE_METHOD_LIMIT.key, "0") +// spark.sqlContext.setConf(CODEGEN_FACTORY_MODE.key, "NO_CODEGEN") + // Used to simulate that memory is not enough +// spark.sqlContext.setConf("spark.sql.TungstenAggregate.testFallbackStartsAt", "10") +// spark.sqlContext.setConf(HASH_AGG_MAX_RECORD_IN_MEMORY.key, "10") +// // Only for ObjectHashAggregateExec and case "2.5 groupBy gender with udf" +// spark.sqlContext.setConf(OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD.key, "1") + + spark.read.json("examples/src/main/resources/student.json").createOrReplaceTempView("student") + var sqlCommand = "" + + // 1 Select with filter + // 1.1 Select with filter + sqlCommand = "select id, name, age from student where age > 16" + testSortByShuffle(spark, sqlCommand) +// show(spark.sql(sqlCommand).collect()) + + // 1.2 Select with filter and sort + sqlCommand = "select * from student where age > 16 order by age" + testSortByShuffle(spark, sqlCommand, true) +// show(spark.sql(sqlCommand).collect()) + + // 1.3 Select with filter and sort + sqlCommand = "select * from student where age > 16 order by gender, score desc" + testSortByShuffle(spark, sqlCommand, true) +// show(spark.sql(sqlCommand).collect()) + + // 1.4 select with filter and repartition + sqlCommand = "select /*+ REPARTITION(4) */ * from student where age > 16" + testSortByShuffle(spark, sqlCommand) +// show(spark.sql(sqlCommand).collect()) + + // 1.5 select with filter and sort and repartition + sqlCommand = "select /*+ REPARTITION(4) */ * from student" + + " where age > 16 order by gender, score desc" + testSortByShuffle(spark, sqlCommand) +// show(spark.sql(sqlCommand).collect()) + + // 2 GroupBy + // 2.1 GroupBy gender + sqlCommand = "select gender, avg(score) from student group by gender" + testSortByShuffle(spark, sqlCommand) +// show(spark.sql(sqlCommand).collect()) + + // 2.2 GroupBy gender and age + sqlCommand = "select gender, age, avg(score) from student group by gender, age" + testSortByShuffle(spark, sqlCommand) +// show(spark.sql(sqlCommand).collect()) + + // 2.3 GroupBy gender and age, then sort by gender and age + sqlCommand = "select gender, age, avg(score) from student " + + "group by gender, age order by gender, age" + testSortByShuffle(spark, sqlCommand, true) +// show(spark.sql(sqlCommand).collect()) + + // 2.4 GroupBy gender, and 
distinct by age, avg by score + sqlCommand = "select gender, avg(score), count(distinct age) " + + "from student group by gender" + testSortByShuffle(spark, sqlCommand) +// show(spark.sql(sqlCommand).collect()) + + // 2.5 groupBy gender with udf + // Here we use a user defined type so that we can use ObjectHashAggregateExec + // instead of HashAggregateExec + import org.apache.spark.sql.functions + sqlCommand = "select gender, mysum(score) as mysum_score from student group by gender" + spark.udf.register("mysum", functions.udaf(SumUDFDataType)) + testSortByShuffle(spark, sqlCommand) +// show(spark.sql(sqlCommand).collect()) + + // 2.6 groupBy id, which will generate many keys + sqlCommand = "select id, avg(score) from student group by id" + testSortByShuffle(spark, sqlCommand, false) +// show(spark.sql(sqlCommand).collect()) + + // 2.7 groupBy id and name without any aggregation functions + sqlCommand = "select id, name from student group by id, name" + testSortByShuffle(spark, sqlCommand) +// show(spark.sql(sqlCommand).collect()) + + // 2.8 groupBy gender with distinct count and sum + // This example will generate two shuffles. + // The first shuffle has two grouping expressions: gender, score. + // The map side of the first shuffle has one aggregate expression: partial_sum(score), + // and the reduce side of the first shuffle has one aggregate expression: merge_sum(score). + // The second shuffle has one grouping expression: gender. + // The map side of the second shuffle has two aggregate expressions: + // merge_sum(score), partial_count(distinct id), + // and the reduce side of the second shuffle has two aggregate expressions: + // sum(score), count(distinct id). + sqlCommand = "select sum(score), count(distinct id) from student group by gender" + testSortByShuffle(spark, sqlCommand) +// show(spark.sql(sqlCommand).collect()) + + // 3 window + // 3.1 rank partitioned by gender, order by score.
+ sqlCommand = "select id, name, gender, score, " + + "rank() over(partition by gender order by score desc) as rank from student" + testSortByShuffle(spark, sqlCommand, true) +// show(spark.sql(sqlCommand).collect()) + + // 4 Join + spark.read.json("examples/src/main/resources/student_info1.json") + .createOrReplaceTempView("student_info1") + spark.read.json("examples/src/main/resources/student_info2.json") + .createOrReplaceTempView("student_info2") +// // disable broadcast join and prefer sort merge join + spark.sqlContext.setConf(AUTO_BROADCASTJOIN_THRESHOLD.key, "-1") + spark.sqlContext.setConf(PREFER_SORTMERGEJOIN.key, "true") + // 4.1 two table join + sqlCommand = "select student.id,student.name,student.gender," + + "student_info1.address from student " + + "join student_info1 on student.name = student_info1.name " + + "and student.score = student_info1.score" + testSortByShuffle(spark, sqlCommand) +// show(spark.sql(sqlCommand).collect()) + + // 4.2 three table join + sqlCommand = "select student.id,student.name,student.gender," + + "student_info1.address, student_info2.hobby from student " + + "join student_info1 on student.name = student_info1.name " + + "join student_info2 on student.name = student_info2.name" + testSortByShuffle(spark, sqlCommand) +// show(spark.sql(sqlCommand).collect()) + + spark.stop() + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala index 37a3b3a34e49..39bbef4b1f96 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala @@ -121,9 +121,17 @@ object RowOrdering extends CodeGeneratorWithInterpretedFallback[Seq[SortOrder], * Creates a row ordering for the given schema, in natural ascending order. */ def createNaturalAscendingOrdering(dataTypes: Seq[DataType]): BaseOrdering = { - val order: Seq[SortOrder] = dataTypes.zipWithIndex.map { + val order: Seq[SortOrder] = createNaturalAscendingSortOrder(dataTypes) + create(order, Seq.empty) + } + + def createNaturalAscendingSortOrder(dataTypes: Seq[DataType]): Seq[SortOrder] = { + dataTypes.zipWithIndex.map { case (dt, index) => SortOrder(BoundReference(index, dt, nullable = true), Ascending) } - create(order, Seq.empty) + } + + def createNaturalInterpretedOrdering(sortOrder: Seq[SortOrder]): BaseOrdering = { + createInterpretedObject(bindReferences(sortOrder, Seq.empty)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 8767d00767aa..ec8a789cd1ad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -2786,7 +2786,9 @@ object SQLConf { .internal() .doc("In the case of ObjectHashAggregateExec, when the size of the in-memory hash map " + "grows too large, we will fall back to sort-based aggregation. This option sets a row " + - "count threshold for the size of the hash map.") + "count threshold for the size of the hash map. " + + "There is a special case. If sorted shuffle is enabled and the aggregation is in " + + "Partial mode, only part of the records are aggregated in memory, and complete " + + "aggregation is performed in the Final stage.") .version("2.2.0") .intConf // We are trying to be conservative and use a relatively small default count threshold here // percentile_approx). .createWithDefault(128) + + val HASH_AGG_MAX_RECORD_IN_MEMORY = buildConf("spark.sql.hashAggregate.maxRecordInMemory") + .internal() + .doc("When sorted shuffle is enabled and the aggregation is in Partial mode, only part " + + "of the records are aggregated in memory, and complete aggregation is performed in " + + "the Final stage. This option sets the maximum number of records kept in memory.") + .version("3.5.1") + .intConf + .createWithDefault(1024) + + val SORTED_SHUFFLE_ENABLED = buildConf("spark.sql.execution.sortedShuffle.enabled") + .internal() + .doc("Whether to enable sorted shuffle for Spark SQL.") + .version("3.5.1") + .booleanConf + .createWithDefault(false) + val USE_OBJECT_HASH_AGG = buildConf("spark.sql.execution.useObjectHashAggregateExec") .internal() .doc("Decides if we use ObjectHashAggregateExec") @@ -6191,6 +6209,10 @@ class SQLConf extends Serializable with Logging with SqlApiConf { def objectAggSortBasedFallbackThreshold: Int = getConf(OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD) + def hashAggMaxRecordsInMemory: Int = getConf(HASH_AGG_MAX_RECORD_IN_MEMORY) + + def sortedShuffleEnabled: Boolean = getConf(SORTED_SHUFFLE_ENABLED) + def variableSubstituteEnabled: Boolean = getConf(VARIABLE_SUBSTITUTE_ENABLED) def warehousePath: String = new Path(getConf(StaticSQLConf.WAREHOUSE_PATH)).toString diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/BufferedRowIterator.java b/sql/core/src/main/java/org/apache/spark/sql/execution/BufferedRowIterator.java index 3d0511b7ba83..f9e5dcc38fc9 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/BufferedRowIterator.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/BufferedRowIterator.java @@ -37,6 +37,7 @@ public abstract class BufferedRowIterator { private long startTimeNs = System.nanoTime(); protected int partitionIndex = -1; + protected boolean shouldBreak = false; public boolean hasNext() throws IOException { if (currentRows.isEmpty()) { @@ -82,6 +83,14 @@ public boolean shouldStop() { return !currentRows.isEmpty(); } + public boolean shouldBreak() { + if (shouldBreak) { + shouldBreak = false; + return true; + } + return false; + } + /** * Increase the peak execution memory for current task. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala index 367732dbb205..bc62da36b4a8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala @@ -128,13 +128,13 @@ class CoalescedPartitioner(val parent: Partitioner, val partitionStartIndices: A * (i.e. the number of partitions of the map output).
*/ class ShuffledRowRDD( - var dependency: ShuffleDependency[Int, InternalRow, InternalRow], + var dependency: ShuffleDependency[SqlKey, InternalRow, InternalRow], metrics: Map[String, SQLMetric], partitionSpecs: Array[ShufflePartitionSpec]) extends RDD[InternalRow](dependency.rdd.context, Nil) { def this( - dependency: ShuffleDependency[Int, InternalRow, InternalRow], + dependency: ShuffleDependency[SqlKey, InternalRow, InternalRow], metrics: Map[String, SQLMetric]) = { this(dependency, metrics, Array.tabulate(dependency.partitioner.numPartitions)(i => CoalescedPartitionSpec(i, i + 1))) @@ -229,7 +229,7 @@ class ShuffledRowRDD( context, sqlMetricsReporter) } - reader.read().asInstanceOf[Iterator[Product2[Int, InternalRow]]].map(_._2) + reader.read().asInstanceOf[Iterator[Product2[SqlKey, InternalRow]]].map(_._2) } override def clearDependencies(): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SqlKey.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SqlKey.scala new file mode 100644 index 000000000000..c64007174694 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SqlKey.scala @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution + +import com.esotericsoftware.kryo.{Kryo, KryoSerializable} +import com.esotericsoftware.kryo.io.{Input, Output} + +import org.apache.spark.Partitioner +import org.apache.spark.sql.catalyst.expressions.{BaseOrdering, UnsafeRow} +import org.apache.spark.util.Utils + +trait SqlKey extends Serializable with KryoSerializable + +object RowKey { + def apply(row: UnsafeRow): RowKey = new RowKey(row) +} + +object IntKey { + def apply(value: Int): IntKey = new IntKey(value) +} + +object SqlKeyPartitioner { + def apply(numPartitions: Int, real: Option[Partitioner] = None): SqlKeyPartitioner = + new SqlKeyPartitioner(numPartitions, real) +} + +case class RowKey(var row: UnsafeRow) extends SqlKey { + def this() = this(null) + + override def write(kryo: Kryo, output: Output): Unit = { + row.write(kryo, output) + } + + override def read(kryo: Kryo, input: Input): Unit = { + if (row == null) { + row = new UnsafeRow(); + } + row.read(kryo, input) + } +} +case class IntKey(var value: Int) extends SqlKey { + def this() = this(-1) + + override def write(kryo: Kryo, output: Output): Unit = { + output.writeInt(value) + } + + override def read(kryo: Kryo, input: Input): Unit = { + this.value = input.readInt + } +} + +class SqlKeyPartitioner(partitions: Int, real: Option[Partitioner] = None) extends Partitioner { + + override def numPartitions: Int = real.map(par => par.numPartitions).getOrElse(partitions) + + override def getPartition(key: Any): Int = key match { + case RowKey(row) => + real.map(par => par.getPartition(row)). + getOrElse(Utils.nonNegativeMod(row.hashCode(), partitions)) + case IntKey(v) => Utils.nonNegativeMod(v, partitions) + } +} + +object SqlKeyOrdering { + def apply(ordering: BaseOrdering): SqlKeyOrdering = new SqlKeyOrdering(ordering) +} + +class SqlKeyOrdering(ordering: BaseOrdering) extends Ordering[SqlKey] { + + def this() = this(null) + + override def compare(x: SqlKey, y: SqlKey): Int = { + (x, y) match { + case (IntKey(a), IntKey(b)) => a - b + case (RowKey(a), RowKey(b)) => ordering.compare(a, b) + } + } +} \ No newline at end of file diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala index 42fcfa8d60fa..c263b1238018 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala @@ -42,16 +42,18 @@ import org.apache.spark.unsafe.Platform * @param numFields the number of fields in the row being serialized. */ class UnsafeRowSerializer( - numFields: Int, - dataSize: SQLMetric = null) extends Serializer with Serializable { + numFields: Int, + sortKeyFields: Option[Int] = None, + dataSize: SQLMetric = null) extends Serializer with Serializable { override def newInstance(): SerializerInstance = - new UnsafeRowSerializerInstance(numFields, dataSize) + new UnsafeRowSerializerInstance(numFields, sortKeyFields, dataSize) override def supportsRelocationOfSerializedObjects: Boolean = true } private class UnsafeRowSerializerInstance( - numFields: Int, - dataSize: SQLMetric) extends SerializerInstance { + numFields: Int, + sortKeyFields: Option[Int] = None, + dataSize: SQLMetric = null) extends SerializerInstance { /** * Serializes a stream of UnsafeRows. Within the stream, each record consists of a record * length (stored as a 4-byte integer, written high byte first), followed by the record's bytes. 
@@ -74,7 +76,15 @@ private class UnsafeRowSerializerInstance( override def writeKey[T: ClassTag](key: T): SerializationStream = { // The key is only needed on the map side when computing partition ids. It does not need to // be shuffled. - assert(null == key || key.isInstanceOf[Int]) + assert(null == key || key.isInstanceOf[SqlKey]) + sortKeyFields.map(_ => { + val row = key.asInstanceOf[RowKey].row + if (dataSize != null) { + dataSize.add(row.getSizeInBytes) + } + dOut.writeInt(row.getSizeInBytes) + row.writeToStream(dOut, writeBuffer) + }) this } @@ -99,87 +109,167 @@ private class UnsafeRowSerializerInstance( } override def deserializeStream(in: InputStream): DeserializationStream = { - new DeserializationStream { - private[this] val dIn: DataInputStream = new DataInputStream(new BufferedInputStream(in)) - // 1024 is a default buffer size; this buffer will grow to accommodate larger rows - private[this] var rowBuffer: Array[Byte] = new Array[Byte](1024) - private[this] var row: UnsafeRow = new UnsafeRow(numFields) - private[this] var rowTuple: (Int, UnsafeRow) = (0, row) - private[this] val EOF: Int = -1 - - override def asKeyValueIterator: Iterator[(Int, UnsafeRow)] = { - new Iterator[(Int, UnsafeRow)] { - - private[this] def readSize(): Int = try { - dIn.readInt() - } catch { - case e: EOFException => - dIn.close() - EOF - } - - private[this] var rowSize: Int = readSize() - override def hasNext: Boolean = rowSize != EOF - - override def next(): (Int, UnsafeRow) = { - if (rowBuffer.length < rowSize) { - rowBuffer = new Array[Byte](rowSize) - } - ByteStreams.readFully(dIn, rowBuffer, 0, rowSize) - row.pointTo(rowBuffer, Platform.BYTE_ARRAY_OFFSET, rowSize) - rowSize = readSize() - if (rowSize == EOF) { // We are returning the last row in this stream - dIn.close() - val _rowTuple = rowTuple - // Null these out so that the byte array can be garbage collected once the entire - // iterator has been consumed - row = null - rowBuffer = null - rowTuple = null - _rowTuple - } else { - rowTuple - } - } - } - } + sortKeyFields.map(keyFields => new KeyValueDeserializationStream(in, keyFields, numFields)). + getOrElse(new ValueOnlyDeserializationStream(in, numFields)) + } - override def asIterator: Iterator[Any] = { - // This method is never called by shuffle code. - throw SparkUnsupportedOperationException() - } + // These methods are never called by shuffle code. + override def serialize[T: ClassTag](t: T): ByteBuffer = throw SparkUnsupportedOperationException() + override def deserialize[T: ClassTag](bytes: ByteBuffer): T = + throw SparkUnsupportedOperationException() + override def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T = + throw SparkUnsupportedOperationException() +} + +class ValueOnlyDeserializationStream(in: InputStream, + numFields: Int) extends DeserializationStream { + private[this] val dIn: DataInputStream = new DataInputStream(new BufferedInputStream(in)) + // 1024 is a default buffer size; this buffer will grow to accommodate larger rows + private[this] var rowBuffer: Array[Byte] = new Array[Byte](1024) + private[this] var row: UnsafeRow = new UnsafeRow(numFields) + private[this] var rowTuple: (IntKey, UnsafeRow) = (IntKey(0), row) + private[this] val EOF: Int = -1 - override def readKey[T: ClassTag](): T = { - // We skipped serialization of the key in writeKey(), so just return a dummy value since - // this is going to be discarded anyways. 
- null.asInstanceOf[T] + override def asKeyValueIterator: Iterator[(IntKey, UnsafeRow)] = { + new Iterator[(IntKey, UnsafeRow)] { + + private[this] def readSize(): Int = try { + dIn.readInt() + } catch { + case e: EOFException => + dIn.close() + EOF } - override def readValue[T: ClassTag](): T = { - val rowSize = dIn.readInt() + private[this] var rowSize: Int = readSize() + override def hasNext: Boolean = rowSize != EOF + + override def next(): (IntKey, UnsafeRow) = { if (rowBuffer.length < rowSize) { rowBuffer = new Array[Byte](rowSize) } ByteStreams.readFully(dIn, rowBuffer, 0, rowSize) row.pointTo(rowBuffer, Platform.BYTE_ARRAY_OFFSET, rowSize) - row.asInstanceOf[T] + rowSize = readSize() + if (rowSize == EOF) { // We are returning the last row in this stream + dIn.close() + val _rowTuple = rowTuple + // Null these out so that the byte array can be garbage collected once the entire + // iterator has been consumed + row = null + rowBuffer = null + rowTuple = null + _rowTuple + } else { + rowTuple + } } + } + } + + override def asIterator: Iterator[Any] = { + // This method is never called by shuffle code. + throw SparkUnsupportedOperationException() + } + + override def readKey[T: ClassTag](): T = { + // We skipped serialization of the key in writeKey(), so just return a dummy value since + // this is going to be discarded anyways. + null.asInstanceOf[T] + } + + override def readValue[T: ClassTag](): T = { + val rowSize = dIn.readInt() + if (rowBuffer.length < rowSize) { + rowBuffer = new Array[Byte](rowSize) + } + ByteStreams.readFully(dIn, rowBuffer, 0, rowSize) + row.pointTo(rowBuffer, Platform.BYTE_ARRAY_OFFSET, rowSize) + row.asInstanceOf[T] + } - override def readObject[T: ClassTag](): T = { - // This method is never called by shuffle code. - throw SparkUnsupportedOperationException() + override def readObject[T: ClassTag](): T = { + // This method is never called by shuffle code. + throw SparkUnsupportedOperationException() + } + + override def close(): Unit = { + dIn.close() + } +} + +class KeyValueDeserializationStream(in: InputStream, + sortKeyFields: Int, + valNumFields: Int) extends DeserializationStream { + private[this] val dIn: DataInputStream = new DataInputStream(new BufferedInputStream(in)) + + private[this] val EOF: Int = -1 + + override def asKeyValueIterator: Iterator[(RowKey, UnsafeRow)] = { + new Iterator[(RowKey, UnsafeRow)] { + + private[this] def readSize(): Int = try { + dIn.readInt() + } catch { + case e: EOFException => + dIn.close() + EOF } - override def close(): Unit = { - dIn.close() + private[this] var keySize: Int = readSize() + override def hasNext: Boolean = keySize != EOF + + override def next(): (RowKey, UnsafeRow) = { + // read key + val keyBuffer = new Array[Byte](keySize) + ByteStreams.readFully(dIn, keyBuffer, 0, keySize) + val keyRow = new UnsafeRow(sortKeyFields) + keyRow.pointTo(keyBuffer, Platform.BYTE_ARRAY_OFFSET, keySize) + + // read value + val valSize: Int = readSize() + val valBuffer = new Array[Byte](valSize) + ByteStreams.readFully(dIn, valBuffer, 0, valSize) + val valRow: UnsafeRow = new UnsafeRow(valNumFields) + valRow.pointTo(valBuffer, Platform.BYTE_ARRAY_OFFSET, valSize) + keySize = readSize() + if (keySize == EOF) { // We are returning the last row in this stream + dIn.close() + } + (RowKey(keyRow), valRow) } } } - // These methods are never called by shuffle code. 
- override def serialize[T: ClassTag](t: T): ByteBuffer = throw SparkUnsupportedOperationException() - override def deserialize[T: ClassTag](bytes: ByteBuffer): T = + override def asIterator: Iterator[Any] = { + // This method is never called by shuffle code. throw SparkUnsupportedOperationException() - override def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T = + } + + override def readKey[T: ClassTag](): T = { + val keySize = dIn.readInt() + val keyBuffer = new Array[Byte](keySize) + ByteStreams.readFully(dIn, keyBuffer, 0, keySize) + val keyRow = new UnsafeRow(sortKeyFields) + keyRow.pointTo(keyBuffer, Platform.BYTE_ARRAY_OFFSET, keySize) + RowKey(keyRow).asInstanceOf[T] + } + + override def readValue[T: ClassTag](): T = { + val valSize: Int = dIn.readInt() + val valBuffer = new Array[Byte](valSize) + ByteStreams.readFully(dIn, valBuffer, 0, valSize) + val valRow: UnsafeRow = new UnsafeRow(valNumFields) + valRow.pointTo(valBuffer, Platform.BYTE_ARRAY_OFFSET, valSize) + valRow.asInstanceOf[T] + } + + override def readObject[T: ClassTag](): T = { + // This method is never called by shuffle code. throw SparkUnsupportedOperationException() + } + + override def close(): Unit = { + dIn.close() + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index 920f61574770..2d6552e0b297 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -379,10 +379,16 @@ trait CodegenSupport extends SparkPlan { /** * Helper default should stop check code. */ - def shouldStopCheckCode: String = if (needStopCheck) { - "if (shouldStop()) return;" - } else { - "// shouldStop check is eliminated" + def shouldStopCheckCode: String = { + if (this.isInstanceOf[WholeStageCodegenExec] || this.parent == null) { + if (needStopCheck) { + "if (shouldStop()) return;" + } else { + "// shouldStop check is eliminated" + } + } else { + parent.shouldStopCheckCode + } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala index 367d4cfafb48..41f1c2fd2448 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala @@ -74,13 +74,14 @@ object AggUtils { aggregateAttributes: Seq[Attribute] = Nil, initialInputBufferOffset: Int = 0, resultExpressions: Seq[NamedExpression] = Nil, - child: SparkPlan): SparkPlan = { + child: SparkPlan, + isFinalMode: Boolean = false): SparkPlan = { val useHash = Aggregate.supportsHashAggregate( aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes), groupingExpressions) val forceObjHashAggregate = forceApplyObjectHashAggregate(child.conf) - val forceSortAggregate = forceApplySortAggregate(child.conf) - + val forceSortAggregate = (isFinalMode && SQLConf.get.sortedShuffleEnabled) || + forceApplySortAggregate(child.conf) if (useHash && !forceSortAggregate && !forceObjHashAggregate) { HashAggregateExec( requiredChildDistributionExpressions = requiredChildDistributionExpressions, @@ -91,7 +92,8 @@ object AggUtils { aggregateAttributes = aggregateAttributes, initialInputBufferOffset = initialInputBufferOffset, resultExpressions = resultExpressions, - child = child) + child = 
child, + isFinalMode = isFinalMode) } else { val objectHashEnabled = child.conf.useObjectHashAggregation val useObjectHash = Aggregate.supportsObjectHashAggregate(aggregateExpressions) @@ -106,7 +108,8 @@ aggregateAttributes = aggregateAttributes, initialInputBufferOffset = initialInputBufferOffset, resultExpressions = resultExpressions, - child = child) + child = child, + isFinalMode = isFinalMode) } else { SortAggregateExec( requiredChildDistributionExpressions = requiredChildDistributionExpressions, @@ -166,7 +169,8 @@ aggregateAttributes = finalAggregateAttributes, initialInputBufferOffset = groupingExpressions.length, resultExpressions = resultExpressions, - child = interExec) + child = interExec, + isFinalMode = true) finalAggregate :: Nil } @@ -295,7 +299,8 @@ aggregateAttributes = finalAggregateAttributes ++ distinctAggregateAttributes, initialInputBufferOffset = groupingAttributes.length, resultExpressions = resultExpressions, - child = partialDistinctAggregate) + child = partialDistinctAggregate, + isFinalMode = true) } finalAndCompleteAggregate :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 469f42dcc0af..5f603c53f258 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -58,7 +58,8 @@ case class HashAggregateExec( aggregateAttributes: Seq[Attribute], initialInputBufferOffset: Int, resultExpressions: Seq[NamedExpression], - child: SparkPlan) + child: SparkPlan, + isFinalMode: Boolean) extends AggregateCodegenSupport { require(Aggregate.supportsHashAggregate(aggregateBufferAttributes, groupingExpressions)) @@ -89,6 +90,9 @@ case class HashAggregateExec( } } + val aggregationInMemory: Boolean = !isStreaming && !isFinalMode && + SQLConf.get.sortedShuffleEnabled + protected override def doExecute(): RDD[InternalRow] = { val numOutputRows = longMetric("numOutputRows") val peakMemory = longMetric("peakMemory") @@ -118,12 +122,14 @@ case class HashAggregateExec( MutableProjection.create(expressions, inputSchema), inputAttributes, iter, - testFallbackStartsAt, + (if (aggregationInMemory) SQLConf.get.hashAggMaxRecordsInMemory + else testFallbackStartsAt.getOrElse((Int.MaxValue, Int.MaxValue))._2), numOutputRows, peakMemory, spillSize, avgHashProbe, - numTasksFallBacked) + numTasksFallBacked, + aggregationInMemory) if (!hasInput && groupingExpressions.isEmpty) { numOutputRows += 1 Iterator.single[UnsafeRow](aggregationIterator.outputForEmptyGroupingKeyWithoutInput()) @@ -406,7 +412,10 @@ case class HashAggregateExec( !conf.getConf(SQLConf.ENABLE_TWOLEVEL_AGG_MAP_PARTIAL_ONLY) } - isSupported && isNotByteArrayDecimalType && isEnabledForAggModes + // For now, if we do aggregation in memory, partial aggregation is bounded by + // hashAggMaxRecordsInMemory. If the fast map were used, it might frequently request + // memory. Therefore, the fast map is disabled when aggregating in memory.
+ isSupported && isNotByteArrayDecimalType && isEnabledForAggModes && !aggregationInMemory } private def enableTwoLevelHashMap(): Unit = { @@ -609,23 +618,39 @@ case class HashAggregateExec( val aggTime = metricTerm(ctx, "aggTime") val beforeAgg = ctx.freshName("beforeAgg") + val fetchNextRow = ctx.freshName("fetchNewRow") + + ctx.addNewFunction(fetchNextRow, + s""" + |private void $fetchNextRow(int times) throws java.io.IOException { + | if (!$initAgg) { + | $initAgg = true; + | $sorterTerm = null; + | $createFastHashMap + | $addHookToCloseFastHashMap + | $hashMapTerm = $thisPlan.createHashMap(); + | long $beforeAgg = System.nanoTime(); + | $doAggFuncName(); + | $aggTime.add((System.nanoTime() - $beforeAgg) / $NANOS_PER_MILLIS); + | } + | // output the result + | $outputFromFastHashMap + | $outputFromRegularHashMap + | if (times < 1) { + | $initAgg = false; + | $fetchNextRow(times + 1); + | } + |} + """.stripMargin) + s""" - |if (!$initAgg) { - | $initAgg = true; - | $createFastHashMap - | $addHookToCloseFastHashMap - | $hashMapTerm = $thisPlan.createHashMap(); - | long $beforeAgg = System.nanoTime(); - | $doAggFuncName(); - | $aggTime.add((System.nanoTime() - $beforeAgg) / $NANOS_PER_MILLIS); - |} - |// output the result - |$outputFromFastHashMap - |$outputFromRegularHashMap + | $fetchNextRow(0); """.stripMargin } protected override def doConsumeWithKeys(ctx: CodegenContext, input: Seq[ExprCode]): String = { + val aggInMem = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "aggInMem", + v => s"$v = $aggregationInMemory;") // create grouping key val unsafeRowKeyCode = GenerateUnsafeProjection.createCode( ctx, bindReferences[Expression](groupingExpressions, child.output)) @@ -648,12 +673,18 @@ case class HashAggregateExec( } } - val (checkFallbackForBytesToBytesMap, resetCounter, incCounter) = testFallbackStartsAt match { - case Some((_, regularMapCounter)) => - val countTerm = ctx.addMutableState(CodeGenerator.JAVA_INT, "fallbackCounter") - (s"$countTerm < $regularMapCounter", s"$countTerm = 0;", s"$countTerm += 1;") - case _ => ("true", "", "") - } + val countTerm = ctx.addMutableState(CodeGenerator.JAVA_INT, "fallbackCounter") + val (checkFallbackForBytesToBytesMap, resetCounter, incCounter) = + if (aggregationInMemory) { + val maxRecords = SQLConf.get.hashAggMaxRecordsInMemory + (s"$countTerm < $maxRecords", s"$countTerm = 0;", s"$countTerm += 1;") + } else { + testFallbackStartsAt match { + case Some((_, regularMapCounter)) => + (s"$countTerm < $regularMapCounter", s"$countTerm = 0;", s"$countTerm += 1;") + case _ => ("true", "", "") + } + } val oomeClassName = classOf[SparkOutOfMemoryError].getName @@ -662,28 +693,41 @@ case class HashAggregateExec( |// generate grouping key |${unsafeRowKeyCode.code} |int $unsafeRowKeyHash = ${unsafeRowKeyCode.value}.hashCode(); - |if ($checkFallbackForBytesToBytesMap) { - | // try to get the buffer from hash map + |if ($aggInMem) { | $unsafeRowBuffer = - | $hashMapTerm.getAggregationBufferFromUnsafeRow($unsafeRowKeys, $unsafeRowKeyHash); - |} - |// Can't allocate buffer from the hash map. Spill the map and fallback to sort-based - |// aggregation after processing all input rows. - |if ($unsafeRowBuffer == null) { - | if ($sorterTerm == null) { - | $sorterTerm = $hashMapTerm.destructAndCreateExternalSorter(); - | } else { - | $sorterTerm.merge($hashMapTerm.destructAndCreateExternalSorter()); - | } - | $resetCounter - | // the hash map had be spilled, it should have enough memory now, - | // try to allocate buffer again. 
- | $unsafeRowBuffer = $hashMapTerm.getAggregationBufferFromUnsafeRow( - | $unsafeRowKeys, $unsafeRowKeyHash); + | $hashMapTerm.getAggregationBufferFromUnsafeRow($unsafeRowKeys, $unsafeRowKeyHash); | if ($unsafeRowBuffer == null) { - | // failed to allocate the first page + | // failed to allocate page | throw new $oomeClassName("_LEGACY_ERROR_TEMP_3302", new java.util.HashMap()); | } + | if (!($checkFallbackForBytesToBytesMap)) { + | shouldBreak = true; + | $resetCounter + | } + |} else { + | if ($checkFallbackForBytesToBytesMap) { + | // try to get the buffer from hash map + | $unsafeRowBuffer = + | $hashMapTerm.getAggregationBufferFromUnsafeRow($unsafeRowKeys, $unsafeRowKeyHash); + | } + | // Can't allocate buffer from the hash map. Spill the map and fallback to sort-based + | // aggregation after processing all input rows. + | if ($unsafeRowBuffer == null) { + | if ($sorterTerm == null) { + | $sorterTerm = $hashMapTerm.destructAndCreateExternalSorter(); + | } else { + | $sorterTerm.merge($hashMapTerm.destructAndCreateExternalSorter()); + | } + | $resetCounter + | // the hash map had be spilled, it should have enough memory now, + | // try to allocate buffer again. + | $unsafeRowBuffer = $hashMapTerm.getAggregationBufferFromUnsafeRow( + | $unsafeRowKeys, $unsafeRowKeyHash); + | if ($unsafeRowBuffer == null) { + | // failed to allocate the first page + | throw new $oomeClassName("_LEGACY_ERROR_TEMP_3302", new java.util.HashMap()); + | } + | } |} """.stripMargin @@ -890,4 +934,18 @@ case class HashAggregateExec( override protected def withNewChildInternal(newChild: SparkPlan): HashAggregateExec = copy(child = newChild) + + override def needStopCheck: Boolean = aggregationInMemory + + override def shouldStopCheckCode: String = if (needStopCheck) { + "if (shouldBreak()) break;" + } else { + "// shouldStop check is eliminated" + } + + // If we do aggregation in memory, it means that we wil not consume all the inputs first. + // If the parent is BlockingOperatorWithCodegen, the parent may consume partial result. + // So we disable supportCodegen when aggregation in memory so that we can avoid such two + // plans in the same codegen stage. + override def supportCodegen: Boolean = !aggregationInMemory && super.supportCodegen } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala index a4a6dc8e4ab0..bfde1ca4924e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala @@ -46,7 +46,8 @@ class ObjectAggregationIterator( fallbackCountThreshold: Int, numOutputRows: SQLMetric, spillSize: SQLMetric, - numTasksFallBacked: SQLMetric) + numTasksFallBacked: SQLMetric, + aggregationInMemory: Boolean = false) extends AggregationIterator( partIndex, groupingExpressions, @@ -81,17 +82,15 @@ class ObjectAggregationIterator( newExpressions, newFunctions.toImmutableArraySeq, newInputAttributes.toImmutableArraySeq) } - /** - * Start processing input rows. - */ - processInputs() - TaskContext.get().addTaskCompletionListener[Unit](_ => { // At the end of the task, update the task's spill size. 
spillSize.set(TaskContext.get().taskMetrics().memoryBytesSpilled - spillSizeBefore) }) override final def hasNext: Boolean = { + if (aggBufferIterator == null || (!aggBufferIterator.hasNext && inputRows.hasNext)) { + processInputs() + } aggBufferIterator.hasNext } @@ -166,7 +165,8 @@ class ObjectAggregationIterator( processRow(buffer, inputRows.next()) } } else { - while (inputRows.hasNext && !sortBased) { + var stop = false + while (inputRows.hasNext && !stop) { val newInput = inputRows.next() val groupingKey = groupingProjection.apply(newInput) val buffer: InternalRow = getAggregationBufferByKey(hashMap, groupingKey) @@ -183,8 +183,11 @@ class ObjectAggregationIterator( log"${MDC(CONFIG, SQLConf.OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD.key)}" ) - // Falls back to sort-based aggregation - sortBased = true + stop = true + if (!aggregationInMemory) { + // Falls back to sort-based aggregation + sortBased = true + } numTasksFallBacked += 1 } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala index 7e8ce3e884a3..38dd13571816 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.sql.internal.SQLConf /** * A hash-based aggregate operator that supports [[TypedImperativeAggregate]] functions that may @@ -66,7 +67,8 @@ case class ObjectHashAggregateExec( aggregateAttributes: Seq[Attribute], initialInputBufferOffset: Int, resultExpressions: Seq[NamedExpression], - child: SparkPlan) + child: SparkPlan, + isFinalMode: Boolean) extends BaseAggregateExec { override def allAttributes: AttributeSeq = @@ -111,7 +113,8 @@ case class ObjectHashAggregateExec( fallbackCountThreshold, numOutputRows, spillSize, - numTasksFallBacked) + numTasksFallBacked, + !isStreaming && !isFinalMode && SQLConf.get.sortedShuffleEnabled) if (!hasInput && groupingExpressions.isEmpty) { numOutputRows += 1 Iterator.single[UnsafeRow](aggregationIterator.outputForEmptyGroupingKeyWithoutInput()) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala index 2f1cda9d0f9b..d7c58980933e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala @@ -92,12 +92,13 @@ class TungstenAggregationIterator( newMutableProjection: (Seq[Expression], Seq[Attribute]) => MutableProjection, originalInputAttributes: Seq[Attribute], inputIter: Iterator[InternalRow], - testFallbackStartsAt: Option[(Int, Int)], + maxRecordInMemory: Int, numOutputRows: SQLMetric, peakMemory: SQLMetric, spillSize: SQLMetric, avgHashProbe: SQLMetric, - numTasksFallBacked: SQLMetric) + numTasksFallBacked: SQLMetric, + aggregationInMemory: Boolean = true) extends AggregationIterator( partIndex, groupingExpressions, @@ -166,22 +167,42 @@ class TungstenAggregationIterator( // This is the hash map used for 
hash-based aggregation. It is backed by an // UnsafeFixedWidthAggregationMap and it is used to store // all groups and their corresponding aggregation buffers for hash-based aggregation. - private[this] val hashMap = new UnsafeFixedWidthAggregationMap( - initialAggregationBuffer, - DataTypeUtils.fromAttributes( - aggregateFunctions.flatMap(_.aggBufferAttributes).toImmutableArraySeq), - DataTypeUtils.fromAttributes(groupingExpressions.map(_.toAttribute)), - TaskContext.get(), - 1024 * 16, // initial capacity - TaskContext.get().taskMemoryManager().pageSizeBytes - ) - + private[this] var hashMap: UnsafeFixedWidthAggregationMap = null + + private[this] def createHashMap: UnsafeFixedWidthAggregationMap = + new UnsafeFixedWidthAggregationMap( + initialAggregationBuffer, + DataTypeUtils.fromAttributes( + aggregateFunctions.flatMap(_.aggBufferAttributes).toImmutableArraySeq), + DataTypeUtils.fromAttributes(groupingExpressions.map(_.toAttribute)), + TaskContext.get(), + 1024 * 16, // initial capacity + TaskContext.get().taskMemoryManager().pageSizeBytes + ) // The function used to read and process input rows. When processing input rows, // it first uses hash-based aggregation by putting groups and their buffers in // hashMap. If there is not enough memory, it will multiple hash-maps, spilling // after each becomes full then using sort to merge these spills, finally do sort // based aggregation. - private def processInputs(fallbackStartsAt: (Int, Int)): Unit = { + private def processInputs(): Unit = { + hashMap = createHashMap + TaskContext.get().addTaskCompletionListener[Unit](_ => { + // At the end of the task, update the task's peak memory usage. Since we destroy + // the map to create the sorter, their memory usages should not overlap, so it is safe + // to just use the max of the two. + val mapMemory = hashMap.getPeakMemoryUsedBytes + val sorterMemory = Option(externalSorter).map(_.getPeakMemoryUsedBytes).getOrElse(0L) + val maxMemory = Math.max(mapMemory, sorterMemory) + peakMemory.set(Math.max(maxMemory, peakMemory.value)) + val metrics = TaskContext.get().taskMetrics() + spillSize.set(metrics.memoryBytesSpilled - spillSizeBefore) + metrics.incPeakExecutionMemory(maxMemory) + + // Updating average hashmap probe + // When aggregationInMemory is enabled, only update last hashMap + avgHashProbe.set(hashMap.getAvgHashProbesPerKey) + }) + if (groupingExpressions.isEmpty) { // If there is no grouping expressions, we can just reuse the same buffer over and over again. // Note that it would be better to eliminate the hash map entirely in the future. 
@@ -193,21 +214,12 @@ class TungstenAggregationIterator( } } else { var i = 0 - while (inputIter.hasNext) { + var stop = false + while (inputIter.hasNext && !stop) { val newInput = inputIter.next() val groupingKey = groupingProjection.apply(newInput) var buffer: UnsafeRow = null - if (i < fallbackStartsAt._2) { - buffer = hashMap.getAggregationBufferFromUnsafeRow(groupingKey) - } - if (buffer == null) { - val sorter = hashMap.destructAndCreateExternalSorter() - if (externalSorter == null) { - externalSorter = sorter - } else { - externalSorter.merge(sorter) - } - i = 0 + if (aggregationInMemory) { buffer = hashMap.getAggregationBufferFromUnsafeRow(groupingKey) if (buffer == null) { // failed to allocate the first page @@ -215,6 +227,27 @@ class TungstenAggregationIterator( throw new SparkOutOfMemoryError("_LEGACY_ERROR_TEMP_3302", new util.HashMap()) // scalastyle:on throwerror } + if (i >= maxRecordInMemory) { + stop = true + } + } else { + buffer = hashMap.getAggregationBufferFromUnsafeRow(groupingKey) + if (i >= maxRecordInMemory || buffer == null) { + val sorter = hashMap.destructAndCreateExternalSorter() + if (externalSorter == null) { + externalSorter = sorter + } else { + externalSorter.merge(sorter) + } + i = 0 + buffer = hashMap.getAggregationBufferFromUnsafeRow(groupingKey) + if (buffer == null) { + // failed to allocate the first page + // scalastyle:off throwerror + throw new SparkOutOfMemoryError("_LEGACY_ERROR_TEMP_3302", new util.HashMap()) + // scalastyle:on throwerror + } + } } processRow(buffer, newInput) i += 1 @@ -226,6 +259,13 @@ class TungstenAggregationIterator( hashMap.free() switchToSortBasedAggregation() + } else { + // If we did not switch to sort-based aggregation in processInputs, + // we pre-load the first key-value pair from the map (to make hasNext idempotent). + // First, set aggregationBufferMapIterator. + aggregationBufferMapIterator = hashMap.iterator() + // Pre-load the first key-value pair from the aggregationBufferMapIterator. + mapIteratorHasNext = aggregationBufferMapIterator.next() } } } @@ -359,50 +399,19 @@ class TungstenAggregationIterator( } /////////////////////////////////////////////////////////////////////////// - // Part 6: Loads input rows and setup aggregationBufferMapIterator if we - // have not switched to sort-based aggregation. - /////////////////////////////////////////////////////////////////////////// - - /** - * Start processing input rows. - */ - processInputs(testFallbackStartsAt.getOrElse((Int.MaxValue, Int.MaxValue))) - - // If we did not switch to sort-based aggregation in processInputs, - // we pre-load the first key-value pair from the map (to make hasNext idempotent). - if (!sortBased) { - // First, set aggregationBufferMapIterator. - aggregationBufferMapIterator = hashMap.iterator() - // Pre-load the first key-value pair from the aggregationBufferMapIterator. - mapIteratorHasNext = aggregationBufferMapIterator.next() - // If the map is empty, we just free it. - if (!mapIteratorHasNext) { - hashMap.free() - } - } - - TaskContext.get().addTaskCompletionListener[Unit](_ => { - // At the end of the task, update the task's peak memory usage. Since we destroy - // the map to create the sorter, their memory usages should not overlap, so it is safe - // to just use the max of the two. 
- val mapMemory = hashMap.getPeakMemoryUsedBytes - val sorterMemory = Option(externalSorter).map(_.getPeakMemoryUsedBytes).getOrElse(0L) - val maxMemory = Math.max(mapMemory, sorterMemory) - val metrics = TaskContext.get().taskMetrics() - peakMemory.set(maxMemory) - spillSize.set(metrics.memoryBytesSpilled - spillSizeBefore) - metrics.incPeakExecutionMemory(maxMemory) - - // Updating average hashmap probe - avgHashProbe.set(hashMap.getAvgHashProbesPerKey) - }) - - /////////////////////////////////////////////////////////////////////////// - // Part 7: Iterator's public methods. + // Part 6: Iterator's public methods. /////////////////////////////////////////////////////////////////////////// override final def hasNext: Boolean = { - (sortBased && sortedInputHasNewGroup) || (!sortBased && mapIteratorHasNext) + if (sortBased) { + sortedInputHasNewGroup + } else { + if (aggregationBufferMapIterator == null || (!mapIteratorHasNext && inputIter.hasNext)) { + processInputs() + return hasNext + } + mapIteratorHasNext + } } override final def next(): UnsafeRow = { @@ -448,7 +457,7 @@ class TungstenAggregationIterator( } /////////////////////////////////////////////////////////////////////////// - // Part 8: Utility functions + // Part 7: Utility functions /////////////////////////////////////////////////////////////////////////// /** @@ -460,7 +469,6 @@ class TungstenAggregationIterator( // We create an output row and copy it. So, we can free the map. val resultCopy = generateOutput(UnsafeRow.createFromByteArray(0, 0), sortBasedAggregationBuffer).copy() - hashMap.free() resultCopy } else { throw SparkException.internalError( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala index 31a3f53eb719..1a17bd8b998c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala @@ -30,8 +30,7 @@ import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.{ShuffleWriteMetricsReporter, ShuffleWriteProcessor} import org.apache.spark.shuffle.sort.SortShuffleManager import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, UnsafeProjection, UnsafeRow} -import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences +import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, BoundReference, RowOrdering, SortOrder, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering import org.apache.spark.sql.catalyst.plans.logical.Statistics import org.apache.spark.sql.catalyst.plans.physical._ @@ -203,8 +202,16 @@ case class ShuffleExchangeExec( override def nodeName: String = "Exchange" + lazy val sortKeyFields: Option[Int] = outputPartitioning match { + case HashPartitioning(expressions, _) + if SQLConf.get.sortedShuffleEnabled && expressions.size > 0 => Some(expressions.size) + case RangePartitioning(ordering, _) + if SQLConf.get.sortedShuffleEnabled && ordering.size > 0 => Some(ordering.size) + case _ => None + } + private lazy val serializer: Serializer = - new UnsafeRowSerializer(child.output.size, longMetric("dataSize")) + new UnsafeRowSerializer(child.output.size, sortKeyFields, longMetric("dataSize")) @transient lazy val inputRDD: RDD[InternalRow] = child.execute() @@ 
-240,13 +247,14 @@ case class ShuffleExchangeExec( * the returned ShuffleDependency will be the input of shuffle. */ @transient - lazy val shuffleDependency : ShuffleDependency[Int, InternalRow, InternalRow] = { + lazy val shuffleDependency : ShuffleDependency[SqlKey, InternalRow, InternalRow] = { val dep = ShuffleExchangeExec.prepareShuffleDependency( inputRDD, child.output, outputPartitioning, serializer, - writeMetrics) + writeMetrics, + Some(this)) metrics("numPartitions").set(dep.partitioner.numPartitions) val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) SQLMetrics.postDriverMetricUpdates( @@ -262,6 +270,35 @@ case class ShuffleExchangeExec( override protected def withNewChildInternal(newChild: SparkPlan): ShuffleExchangeExec = copy(child = newChild) + + override def outputOrdering: Seq[SortOrder] = outputPartitioning match { + case HashPartitioning(expressions, _) if sortKeyFields.isDefined => + expressions.map(SortOrder(_, Ascending)) + case RangePartitioning(ordering, _) if sortKeyFields.isDefined => ordering + case _ => Nil + } + + def keyOrdering: Option[Ordering[SqlKey]] = { + val ordering = outputPartitioning match { + case HashPartitioning(expressions, _) if sortKeyFields.isDefined => + RowOrdering.createNaturalAscendingSortOrder(expressions.map(_.dataType)) + case RangePartitioning(ordering, _) if sortKeyFields.isDefined => + ordering.zipWithIndex.map { + case (order, index) => + SortOrder(BoundReference(index, order.child.dataType, nullable = true), order.direction) + } + case _ => Nil + } + ordering match { + case Nil => None + case sortOrders => + // The ordering is serialized as part of the TaskDescription and propagated to each + // executor. However, if we used a code generated class, other executors would not be + // able to find the generated classes, so we build an interpreted ordering here. + val ordering = RowOrdering.createNaturalInterpretedOrdering(sortOrders) + Some(SqlKeyOrdering(ordering)) + } + } } object ShuffleExchangeExec { @@ -336,14 +373,17 @@ object ShuffleExchangeExec { outputAttributes: Seq[Attribute], newPartitioning: Partitioning, serializer: Serializer, - writeMetrics: Map[String, SQLMetric]) - : ShuffleDependency[Int, InternalRow, InternalRow] = { + writeMetrics: Map[String, SQLMetric], + shuffleExchangeExec: Option[ShuffleExchangeExec] = None) + : ShuffleDependency[SqlKey, InternalRow, InternalRow] = { + val sortedShuffleEnabled = + shuffleExchangeExec.map(exec => exec.sortKeyFields.isDefined).getOrElse(false) val part: Partitioner = newPartitioning match { - case RoundRobinPartitioning(numPartitions) => new HashPartitioner(numPartitions) + case RoundRobinPartitioning(numPartitions) => SqlKeyPartitioner(numPartitions) case HashPartitioning(_, n) => // For HashPartitioning, the partitioning key is already a valid partition ID, as we use // `HashPartitioning.partitionIdExpression` to produce partitioning key.
- new PartitionIdPassthrough(n) + new SqlKeyPartitioner(n) case RangePartitioning(sortingExpressions, numPartitions) => // Extract only fields used for sorting to avoid collecting large fields that does not // affect sorting result when deciding partition bounds in RangePartitioner @@ -360,12 +400,12 @@ object ShuffleExchangeExec { ord.copy(child = BoundReference(i, ord.dataType, ord.nullable)) } implicit val ordering = new LazilyGeneratedOrdering(orderingAttributes) - new RangePartitioner( + SqlKeyPartitioner(numPartitions, Some(new RangePartitioner( numPartitions, rddForSampling, ascending = true, - samplePointsPerPartitionHint = SQLConf.get.rangeExchangeSampleSizePerPartition) - case SinglePartition => new ConstantPartitioner + samplePointsPerPartitionHint = SQLConf.get.rangeExchangeSampleSizePerPartition))) + case SinglePartition => SqlKeyPartitioner(1) case k @ KeyGroupedPartitioning(expressions, n, _, _) => val valueMap = k.uniquePartitionValues.zipWithIndex.map { case (partition, index) => (partition.toSeq(expressions.map(_.dataType)), index) @@ -374,7 +414,7 @@ object ShuffleExchangeExec { case _ => throw SparkException.internalError(s"Exchange not implemented for $newPartitioning") // TODO: Handle BroadcastPartitioning. } - def getPartitionKeyExtractor(): InternalRow => Any = newPartitioning match { + def getPartitionKeyExtractor(): InternalRow => SqlKey = newPartitioning match { case RoundRobinPartitioning(numPartitions) => // Distributes elements evenly across output partitions, starting from a random partition. // nextInt(numPartitions) implementation has a special case when bound is a power of 2, @@ -388,24 +428,32 @@ object ShuffleExchangeExec { (row: InternalRow) => { // The HashPartitioner will handle the `mod` by the number of partitions position += 1 - position + IntKey(position) } case h: HashPartitioning => - val projection = UnsafeProjection.create(h.partitionIdExpression :: Nil, outputAttributes) - row => projection(row).getInt(0) + if (shuffleExchangeExec.map(exec => exec.sortKeyFields.isDefined).getOrElse(false)) { + val projection = UnsafeProjection.create(h.expressions, outputAttributes) + row => RowKey(projection(row)) + } else { + val projection = UnsafeProjection.create(h.partitionIdExpression :: Nil, outputAttributes) + row => IntKey(projection(row).getInt(0)) + } case RangePartitioning(sortingExpressions, _) => val projection = UnsafeProjection.create(sortingExpressions.map(_.child), outputAttributes) - row => projection(row) - case SinglePartition => identity + row => RowKey(projection(row)) + case SinglePartition => + _ => IntKey(0) case KeyGroupedPartitioning(expressions, _, _, _) => - row => bindReferences(expressions, outputAttributes).map(_.eval(row)) + // row => RowKey(bindReferences(expressions, outputAttributes).map(_.eval(row))) + // TODO: support KeyGroupedPartitioning + throw SparkException.internalError(s"Exchange not implemented for $newPartitioning") case _ => throw SparkException.internalError(s"Exchange not implemented for $newPartitioning") } val isRoundRobin = newPartitioning.isInstanceOf[RoundRobinPartitioning] && newPartitioning.numPartitions > 1 - val rddWithPartitionIds: RDD[Product2[Int, InternalRow]] = { + val rddWithPartitionIds: RDD[Product2[SqlKey, InternalRow]] = { // [SPARK-23207] Have to make sure the generated RoundRobinPartitioning is deterministic, // otherwise a retry task may output different rows and thus lead to data loss. 
// @@ -457,13 +505,13 @@ object ShuffleExchangeExec { if (needToCopyObjectsBeforeShuffle(part)) { newRdd.mapPartitionsWithIndexInternal((_, iter) => { val getPartitionKey = getPartitionKeyExtractor() - iter.map { row => (part.getPartition(getPartitionKey(row)), row.copy()) } + iter.map { row => (getPartitionKey(row), row.copy()) } }, isOrderSensitive = isOrderSensitive) } else { newRdd.mapPartitionsWithIndexInternal((_, iter) => { val getPartitionKey = getPartitionKeyExtractor() - val mutablePair = new MutablePair[Int, InternalRow]() - iter.map { row => mutablePair.update(part.getPartition(getPartitionKey(row)), row) } + val mutablePair = new MutablePair[SqlKey, InternalRow]() + iter.map { row => mutablePair.update(getPartitionKey(row), row) } }, isOrderSensitive = isOrderSensitive) } } @@ -472,10 +520,14 @@ object ShuffleExchangeExec { // are in the form of (partitionId, row) and every partitionId is in the expected range // [0, part.numPartitions - 1]. The partitioner of this is a PartitionIdPassthrough. val dependency = - new ShuffleDependency[Int, InternalRow, InternalRow]( + new ShuffleDependency[SqlKey, InternalRow, InternalRow]( rddWithPartitionIds, - new PartitionIdPassthrough(part.numPartitions), + part, serializer, + keyOrdering = shuffleExchangeExec match { + case Some(exec) if exec.sortKeyFields.isDefined => exec.keyOrdering + case _ => None + }, shuffleWriterProcessor = createShuffleWriteProcessor(writeMetrics)) dependency diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala index 928d732f2a16..87c6751f7d84 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala @@ -141,9 +141,9 @@ class UnsafeRowSerializerSuite extends SparkFunSuite with LocalSparkSession { val unsafeRow = toUnsafeRow(row, Array(StringType, IntegerType)) val rowsRDD = spark.sparkContext.parallelize( Seq((0, unsafeRow), (1, unsafeRow), (0, unsafeRow)) - ).asInstanceOf[RDD[Product2[Int, InternalRow]]] + ).asInstanceOf[RDD[Product2[SqlKey, InternalRow]]] val dependency = - new ShuffleDependency[Int, InternalRow, InternalRow]( + new ShuffleDependency[SqlKey, InternalRow, InternalRow]( rowsRDD, new PartitionIdPassthrough(2), new UnsafeRowSerializer(2)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index a3cfdc5a240a..929ee0692f6f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -856,7 +856,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession assert( executedPlan.exists { case WholeStageCodegenExec( - HashAggregateExec(_, _, _, _, _, _, _, _, _: LocalTableScanExec)) => true + HashAggregateExec(_, _, _, _, _, _, _, _, _: LocalTableScanExec, _)) => true case _ => false }, "LocalTableScanExec should be within a WholeStageCodegen domain.")
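The patch is exercised end to end by `SparkSQLTest1` above, which runs every query twice (once with sorted shuffle disabled, once enabled) and compares the results. For a quicker smoke test of the same idea, the following is a minimal sketch, not part of the patch: it assumes the patch is applied so that the `spark.sql.execution.sortedShuffle.enabled` and `spark.sql.hashAggregate.maxRecordInMemory` keys introduced above exist (on a stock Spark build they are simply stored and have no effect), and it uses a small in-memory `student` table instead of the `student.json` resource that the example reads.

```scala
import org.apache.spark.sql.SparkSession

// Hypothetical smoke test; the two config keys come from this patch.
object SortedShuffleSmokeTest {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("SortedShuffleSmokeTest")
      .master("local[2]")
      .getOrCreate()
    import spark.implicits._

    // Tiny in-memory stand-in for the student.json table used by SparkSQLTest1.
    Seq((1, "a", "F", 90), (2, "b", "M", 75), (3, "c", "F", 60), (4, "d", "M", 88))
      .toDF("id", "name", "gender", "score")
      .createOrReplaceTempView("student")

    val query = "select gender, avg(score) from student group by gender"

    // Baseline: regular shuffle path.
    spark.conf.set("spark.sql.execution.sortedShuffle.enabled", "false")
    val expected = spark.sql(query).collect().toSet

    // Patched path: sorted shuffle on, with a tiny in-memory record limit so the
    // partial hash aggregate flushes early.
    spark.conf.set("spark.sql.execution.sortedShuffle.enabled", "true")
    spark.conf.set("spark.sql.hashAggregate.maxRecordInMemory", "2")
    val actual = spark.sql(query).collect().toSet

    assert(expected == actual, s"sorted shuffle changed the result: $expected vs $actual")
    spark.stop()
  }
}
```

Setting `spark.sql.hashAggregate.maxRecordInMemory` very low is a cheap way to force the partial aggregate to emit after a handful of groups, which exercises the `shouldBreak()` hook added to `BufferedRowIterator` and the in-memory aggregation branch generated by `HashAggregateExec`.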