[SPARK-23819][SQL] Fix InMemoryTableScanExec complex type pruning #20935
sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala:

```diff
@@ -19,7 +19,8 @@ package org.apache.spark.sql.execution.columnar
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, RowOrdering}
-import org.apache.spark.sql.catalyst.util.{ArrayData, TypeUtils}
+import org.apache.spark.sql.catalyst.expressions.{UnsafeArrayData, UnsafeRow}
+import org.apache.spark.sql.catalyst.util.TypeUtils
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
@@ -357,20 +358,21 @@ private abstract class OrderableSafeColumnStats[T](dataType: DataType) extends ColumnStats {
 }

 private[columnar] final class ArrayColumnStats(dataType: DataType)
-  extends OrderableSafeColumnStats[ArrayData](dataType) {
-  override def getValue(row: InternalRow, ordinal: Int): ArrayData = row.getArray(ordinal)
+  extends OrderableSafeColumnStats[UnsafeArrayData](dataType) {
+  override def getValue(row: InternalRow, ordinal: Int): UnsafeArrayData =
+    row.getArray(ordinal).asInstanceOf[UnsafeArrayData]

-  override def copy(value: ArrayData): ArrayData = value.copy()
+  override def copy(value: UnsafeArrayData): UnsafeArrayData = value.copy()
 }

 private[columnar] final class StructColumnStats(dataType: DataType)
-  extends OrderableSafeColumnStats[InternalRow](dataType) {
+  extends OrderableSafeColumnStats[UnsafeRow](dataType) {
   private val numFields = dataType.asInstanceOf[StructType].fields.length

-  override def getValue(row: InternalRow, ordinal: Int): InternalRow =
-    row.getStruct(ordinal, numFields)
+  override def getValue(row: InternalRow, ordinal: Int): UnsafeRow =
+    row.getStruct(ordinal, numFields).asInstanceOf[UnsafeRow]

-  override def copy(value: InternalRow): InternalRow = value.copy()
+  override def copy(value: UnsafeRow): UnsafeRow = value.copy()
 }

 private[columnar] final class MapColumnStats(dataType: DataType) extends ColumnStats {
```
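The core of the fix is visible above: the array and struct stats now pin down the concrete `UnsafeArrayData`/`UnsafeRow` types and take defensive copies in `copy()`, because the unsafe variants are backed by a byte buffer that Spark reuses across rows. A minimal, hypothetical sketch of the hazard (not part of the patch; the object and variable names are illustrative):

```scala
import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData

object UnsafeReuseDemo {
  def main(args: Array[String]): Unit = {
    val shared = UnsafeArrayData.fromPrimitiveArray(Array(1))
    val heldWithoutCopy = shared      // aliases the same backing buffer
    val heldWithCopy = shared.copy()  // snapshots the bytes

    shared.setInt(0, 10)              // simulates Spark reusing the buffer for the next row

    println(heldWithoutCopy.getInt(0)) // 10 -- the remembered "bound" silently changed
    println(heldWithCopy.getInt(0))    // 1  -- the copy is stable
  }
}
```

This is exactly the failure mode the new "Reuse UnsafeArrayData for stats" test below exercises: without the copy, a stats collector that keeps a reference as its lower bound would see that value drift as later rows overwrite the buffer.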
sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnStatsSuite.scala:

```diff
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.execution.columnar

 import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.catalyst.expressions.RowOrdering
+import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeArrayData}
 import org.apache.spark.sql.catalyst.util.TypeUtils
 import org.apache.spark.sql.types._

@@ -35,9 +35,30 @@ class ColumnStatsSuite extends SparkFunSuite {
   )
   testColumnStats(classOf[StringColumnStats], STRING, Array(null, null, 0, 0, 0))
   testDecimalColumnStats(Array(null, null, 0, 0, 0))
-  testArrayColumnStats(ArrayType(IntegerType), orderable = true, Array(null, null, 0, 0, 0))
-  testStructColumnStats(
-    StructType(Array(StructField("test", DataTypes.StringType))),
-    orderable = true,
-    Array(null, null, 0, 0, 0)
-  )
+
+  private val orderableArrayDataType = ArrayType(IntegerType)
+  testOrderableColumnStats(
+    orderableArrayDataType,
+    () => new ArrayColumnStats(orderableArrayDataType),
+    ARRAY(orderableArrayDataType),
+    orderable = true,
+    Array(null, null, 0, 0, 0)
+  )
+
+  private val unorderableArrayDataType = ArrayType(MapType(IntegerType, StringType))
+  testOrderableColumnStats(
+    unorderableArrayDataType,
+    () => new ArrayColumnStats(unorderableArrayDataType),
+    ARRAY(unorderableArrayDataType),
+    orderable = false,
+    Array(null, null, 0, 0, 0)
+  )
+
+  private val structDataType = StructType(Array(StructField("test", DataTypes.StringType)))
+  testOrderableColumnStats(
+    structDataType,
+    () => new StructColumnStats(structDataType),
+    STRUCT(structDataType),
+    orderable = true,
+    Array(null, null, 0, 0, 0)
+  )
@@ -120,58 +141,23 @@ class ColumnStatsSuite extends SparkFunSuite {
     }
   }

-  def testArrayColumnStats(
-      dataType: DataType, orderable: Boolean, initialStatistics: Array[Any]): Unit = {
-    val columnType = ColumnType(dataType)
-
-    test(s"${dataType.typeName}: empty") {
-      val objectStats = new ArrayColumnStats(dataType)
-      objectStats.collectedStatistics.zip(initialStatistics).foreach {
-        case (actual, expected) => assert(actual === expected)
-      }
-    }
-
-    test(s"${dataType.typeName}: non-empty") {
-      import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._
-      val objectStats = new ArrayColumnStats(dataType)
-      val rows = Seq.fill(10)(makeRandomRow(columnType)) ++ Seq.fill(10)(makeNullRow(1))
-      rows.foreach(objectStats.gatherStats(_, 0))
-
-      val stats = objectStats.collectedStatistics
-      if (orderable) {
-        val values = rows.take(10).map(_.get(0, columnType.dataType))
-        val ordering = TypeUtils.getInterpretedOrdering(dataType)
-
-        assertResult(values.min(ordering), "Wrong lower bound")(stats(0))
-        assertResult(values.max(ordering), "Wrong upper bound")(stats(1))
-      } else {
-        assertResult(null, "Wrong lower bound")(stats(0))
-        assertResult(null, "Wrong upper bound")(stats(1))
-      }
-      assertResult(10, "Wrong null count")(stats(2))
-      assertResult(20, "Wrong row count")(stats(3))
-      assertResult(stats(4), "Wrong size in bytes") {
-        rows.map { row =>
-          if (row.isNullAt(0)) 4 else columnType.actualSize(row, 0)
-        }.sum
-      }
-    }
-  }
-
-  def testStructColumnStats(
-      dataType: DataType, orderable: Boolean, initialStatistics: Array[Any]): Unit = {
-    val columnType = ColumnType(dataType)
+  def testOrderableColumnStats[T](
+      dataType: DataType,
+      statsSupplier: () => OrderableSafeColumnStats[T],
+      columnType: ColumnType[T],
+      orderable: Boolean,
+      initialStatistics: Array[Any]): Unit = {

-    test(s"${dataType.typeName}: empty") {
-      val objectStats = new StructColumnStats(dataType)
+    test(s"${dataType.typeName}, $orderable: empty") {
+      val objectStats = statsSupplier()
       objectStats.collectedStatistics.zip(initialStatistics).foreach {
         case (actual, expected) => assert(actual === expected)
       }
     }

-    test(s"${dataType.typeName}: non-empty") {
+    test(s"${dataType.typeName}, $orderable: non-empty") {
       import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._
-      val objectStats = new StructColumnStats(dataType)
+      val objectStats = statsSupplier()
       val rows = Seq.fill(10)(makeRandomRow(columnType)) ++ Seq.fill(10)(makeNullRow(1))
       rows.foreach(objectStats.gatherStats(_, 0))
@@ -224,4 +210,20 @@ class ColumnStatsSuite extends SparkFunSuite {
       }
     }
   }
+
+  test("Reuse UnsafeArrayData for stats") {
+    val stats = new ArrayColumnStats(ArrayType(IntegerType))
+    val unsafeData = UnsafeArrayData.fromPrimitiveArray(Array(1))
+    (1 to 10).foreach { value =>
+      val row = new GenericInternalRow(Array[Any](unsafeData))
+      unsafeData.setInt(0, value)
+      stats.gatherStats(row, 0)
+    }
+    val collected = stats.collectedStatistics
+    assertResult(UnsafeArrayData.fromPrimitiveArray(Array(1)))(collected(0))
+    assertResult(UnsafeArrayData.fromPrimitiveArray(Array(10)))(collected(1))
+    assertResult(0)(collected(2))
+    assertResult(10)(collected(3))
+    assertResult(10 * (4 + unsafeData.getSizeInBytes))(collected(4))
+  }
 }
```

Member review comment on the new test: We should also test against UnsafeRow.
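A companion test along the lines the reviewer asks for might look like the sketch below. The test name, the single-int schema, and the use of `UnsafeProjection` to produce a buffer-reusing `UnsafeRow` are my assumptions, not part of this diff:

```scala
// Hypothetical companion test (not in this commit). UnsafeProjection
// returns the same UnsafeRow instance on every call, overwriting its
// buffer in place, which mimics how operators reuse rows at runtime.
// Assumes additional imports: InternalRow, UnsafeProjection, UnsafeRow.
test("Reuse UnsafeRow for stats") {
  val structType = StructType(Array(StructField("int", IntegerType)))
  val stats = new StructColumnStats(structType)
  val toUnsafe = UnsafeProjection.create(structType)
  (1 to 10).foreach { value =>
    val row = new GenericInternalRow(Array[Any](toUnsafe(InternalRow(value))))
    stats.gatherStats(row, 0)
  }
  val collected = stats.collectedStatistics
  // Without the defensive copy() in StructColumnStats, both bounds would read 10.
  assertResult(1)(collected(0).asInstanceOf[UnsafeRow].getInt(0))
  assertResult(10)(collected(1).asInstanceOf[UnsafeRow].getInt(0))
  assertResult(0)(collected(2))   // null count
  assertResult(10)(collected(3))  // row count
}
```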
Review comment: `dataType: DataType` -> `dataType: ArrayType`?
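Assuming the comment targets `ArrayColumnStats`, the suggested tightening would read roughly as follows (sketch only; `ArrayType` is still a `DataType`, so the superclass call is unchanged):

```scala
// Sketch of the reviewer's suggestion: accept ArrayType directly so the
// only valid input is explicit in the signature rather than a generic DataType.
private[columnar] final class ArrayColumnStats(dataType: ArrayType)
  extends OrderableSafeColumnStats[UnsafeArrayData](dataType) {
  override def getValue(row: InternalRow, ordinal: Int): UnsafeArrayData =
    row.getArray(ordinal).asInstanceOf[UnsafeArrayData]

  override def copy(value: UnsafeArrayData): UnsafeArrayData = value.copy()
}
```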