From 83e1e5385c8aae52ef14266a6a4c58f4ce9614a9 Mon Sep 17 00:00:00 2001 From: Patrick Woody Date: Thu, 29 Mar 2018 07:06:47 -0400 Subject: [PATCH 1/5] SPARK-23819: InMemoryTableScanExec prunes orderable complex types due to out of date ColumnStats --- .../sql/execution/columnar/ColumnStats.scala | 18 ++++- .../execution/columnar/ColumnStatsSuite.scala | 74 ++++++++++++++++--- .../columnar/ColumnarTestUtils.scala | 24 ++++-- .../columnar/InMemoryColumnarQuerySuite.scala | 7 +- 4 files changed, 104 insertions(+), 19 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala index bc7e73ae1ba8..4b7e5aaa3fac 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala @@ -18,7 +18,8 @@ package org.apache.spark.sql.execution.columnar import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, RowOrdering} +import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -323,18 +324,31 @@ private[columnar] final class DecimalColumnStats(precision: Int, scale: Int) ext } private[columnar] final class ObjectColumnStats(dataType: DataType) extends ColumnStats { + protected var upper: Any = null + protected var lower: Any = null + val columnType = ColumnType(dataType) + val ordering = if (RowOrdering.isOrderable(dataType)) { + Option(TypeUtils.getInterpretedOrdering(dataType)) + } else { + None + } override def gatherStats(row: InternalRow, ordinal: Int): Unit = { if (!row.isNullAt(ordinal)) { val size = columnType.actualSize(row, ordinal) sizeInBytes += size count += 1 + ordering.foreach { order => + val value = row.get(ordinal, dataType) + if (upper == null || order.gt(value, upper)) upper = value + if (lower == null || order.lt(value, lower)) lower = value + } } else { gatherNullStats } } override def collectedStatistics: Array[Any] = - Array[Any](null, null, nullCount, count, sizeInBytes) + Array[Any](lower, upper, nullCount, count, sizeInBytes) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnStatsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnStatsSuite.scala index d4e7e362c6c8..fae7a2ebccf8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnStatsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnStatsSuite.scala @@ -18,18 +18,35 @@ package org.apache.spark.sql.execution.columnar import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.expressions.RowOrdering +import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ class ColumnStatsSuite extends SparkFunSuite { - testColumnStats(classOf[BooleanColumnStats], BOOLEAN, Array(true, false, 0)) - testColumnStats(classOf[ByteColumnStats], BYTE, Array(Byte.MaxValue, Byte.MinValue, 0)) - testColumnStats(classOf[ShortColumnStats], SHORT, Array(Short.MaxValue, Short.MinValue, 0)) - testColumnStats(classOf[IntColumnStats], INT, Array(Int.MaxValue, Int.MinValue, 0)) - testColumnStats(classOf[LongColumnStats], LONG, Array(Long.MaxValue, Long.MinValue, 0)) - 
testColumnStats(classOf[FloatColumnStats], FLOAT, Array(Float.MaxValue, Float.MinValue, 0)) - testColumnStats(classOf[DoubleColumnStats], DOUBLE, Array(Double.MaxValue, Double.MinValue, 0)) - testColumnStats(classOf[StringColumnStats], STRING, Array(null, null, 0)) - testDecimalColumnStats(Array(null, null, 0)) + testColumnStats(classOf[BooleanColumnStats], BOOLEAN, Array(true, false, 0, 0, 0)) + testColumnStats(classOf[ByteColumnStats], BYTE, Array(Byte.MaxValue, Byte.MinValue, 0, 0, 0)) + testColumnStats(classOf[ShortColumnStats], SHORT, Array(Short.MaxValue, Short.MinValue, 0, 0, 0)) + testColumnStats(classOf[IntColumnStats], INT, Array(Int.MaxValue, Int.MinValue, 0, 0, 0)) + testColumnStats(classOf[LongColumnStats], LONG, Array(Long.MaxValue, Long.MinValue, 0, 0, 0)) + testColumnStats(classOf[FloatColumnStats], FLOAT, Array(Float.MaxValue, Float.MinValue, 0, 0, 0)) + testColumnStats( + classOf[DoubleColumnStats], DOUBLE, + Array(Double.MaxValue, Double.MinValue, 0, 0, 0) + ) + testColumnStats(classOf[StringColumnStats], STRING, Array(null, null, 0, 0, 0)) + testDecimalColumnStats(Array(null, null, 0, 0, 0)) + testObjectColumnStats(ArrayType(IntegerType), orderable = true, Array(null, null, 0, 0, 0)) + testObjectColumnStats( + StructType(Array(StructField("test", DataTypes.StringType))), + orderable = true, + Array(null, null, 0, 0, 0) + ) + testObjectColumnStats( + MapType(IntegerType, StringType), + orderable = false, + Array(null, null, 0, 0, 0) + ) + def testColumnStats[T <: AtomicType, U <: ColumnStats]( columnStatsClass: Class[U], @@ -103,4 +120,43 @@ class ColumnStatsSuite extends SparkFunSuite { } } } + + def testObjectColumnStats( + dataType: DataType, orderable: Boolean, initialStatistics: Array[Any]): Unit = { + assert(!(orderable ^ RowOrdering.isOrderable(dataType))) + val columnType = ColumnType(dataType) + + test(s"${dataType.typeName}: empty") { + val objectStats = new ObjectColumnStats(dataType) + objectStats.collectedStatistics.zip(initialStatistics).foreach { + case (actual, expected) => assert(actual === expected) + } + } + + test(s"${dataType.typeName}: non-empty") { + import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._ + val objectStats = new ObjectColumnStats(dataType) + val rows = Seq.fill(10)(makeRandomRow(columnType)) ++ Seq.fill(10)(makeNullRow(1)) + rows.foreach(objectStats.gatherStats(_, 0)) + + val stats = objectStats.collectedStatistics + if (orderable) { + val values = rows.take(10).map(_.get(0, columnType.dataType)) + val ordering = TypeUtils.getInterpretedOrdering(dataType) + + assertResult(values.min(ordering), "Wrong lower bound")(stats(0)) + assertResult(values.max(ordering), "Wrong upper bound")(stats(1)) + } else { + assertResult(null, "Wrong lower bound")(stats(0)) + assertResult(null, "Wrong upper bound")(stats(1)) + } + assertResult(10, "Wrong null count")(stats(2)) + assertResult(20, "Wrong row count")(stats(3)) + assertResult(stats(4), "Wrong size in bytes") { + rows.map { row => + if (row.isNullAt(0)) 4 else columnType.actualSize(row, 0) + }.sum + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnarTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnarTestUtils.scala index 686c8fa6f5fa..2647331c3f72 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnarTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnarTestUtils.scala @@ -21,9 +21,9 @@ import scala.collection.immutable.HashSet 
import scala.util.Random import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.GenericInternalRow -import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData} -import org.apache.spark.sql.types.{AtomicType, Decimal} +import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeArrayData, UnsafeMapData, UnsafeProjection} +import org.apache.spark.sql.catalyst.util.ArrayBasedMapData +import org.apache.spark.sql.types.{AtomicType, DataType, Decimal, IntegerType, MapType, StringType, StructField, StructType} import org.apache.spark.unsafe.types.UTF8String object ColumnarTestUtils { @@ -54,12 +54,22 @@ object ColumnarTestUtils { case COMPACT_DECIMAL(precision, scale) => Decimal(Random.nextLong() % 100, precision, scale) case LARGE_DECIMAL(precision, scale) => Decimal(Random.nextLong(), precision, scale) case STRUCT(_) => - new GenericInternalRow(Array[Any](UTF8String.fromString(Random.nextString(10)))) + val schema = StructType(Array(StructField("test", StringType))) + val converter = UnsafeProjection.create(schema) + converter(InternalRow(Array(UTF8String.fromString(Random.nextString(10))): _*)) case ARRAY(_) => - new GenericArrayData(Array[Any](Random.nextInt(), Random.nextInt())) + UnsafeArrayData.fromPrimitiveArray(Array(Random.nextInt(), Random.nextInt())) case MAP(_) => - ArrayBasedMapData( - Map(Random.nextInt() -> UTF8String.fromString(Random.nextString(Random.nextInt(32))))) + val unsafeConverter = + UnsafeProjection.create(Array[DataType](MapType(IntegerType, StringType))) + val row = new GenericInternalRow(1) + def toUnsafeMap(map: ArrayBasedMapData): UnsafeMapData = { + row.update(0, map) + val unsafeRow = unsafeConverter.apply(row) + unsafeRow.getMap(0).copy + } + toUnsafeMap(ArrayBasedMapData( + Map(Random.nextInt() -> UTF8String.fromString(Random.nextString(Random.nextInt(32)))))) case _ => throw new IllegalArgumentException(s"Unknown column type $columnType") }).asInstanceOf[JvmType] } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala index 26b63e8e8490..bc5ab3dcd7ce 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala @@ -30,7 +30,6 @@ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.test.SQLTestData._ import org.apache.spark.sql.types._ import org.apache.spark.storage.StorageLevel._ -import org.apache.spark.util.Utils class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { import testImplicits._ @@ -527,4 +526,10 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { } } } + + test("SPARK-23819: Complex type pruning should utilize proper statistics") { + val df = Seq((Array(1), (1, 1))).toDF("arr", "struct").cache() + assert(df.where("arr <=> array(1)").count() === 1) + assert(df.where("struct <=> named_struct('_1', 1, '_2', 1)").count() === 1) + } } From 5c95cef81257185205c7ecebe3353060780b2c52 Mon Sep 17 00:00:00 2001 From: Patrick Woody Date: Thu, 29 Mar 2018 11:15:18 -0400 Subject: [PATCH 2/5] Fix null types --- .../apache/spark/sql/execution/columnar/ColumnStats.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala index 4b7e5aaa3fac..731505d1b50c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala @@ -328,10 +328,10 @@ private[columnar] final class ObjectColumnStats(dataType: DataType) extends Colu protected var lower: Any = null val columnType = ColumnType(dataType) - val ordering = if (RowOrdering.isOrderable(dataType)) { - Option(TypeUtils.getInterpretedOrdering(dataType)) - } else { - None + val ordering = dataType match { + case x if RowOrdering.isOrderable(dataType) && x != NullType => + Option(TypeUtils.getInterpretedOrdering(x)) + case _ => None } override def gatherStats(row: InternalRow, ordinal: Int): Unit = { From 426374b6131f65f67f3f023640af37d9a3130585 Mon Sep 17 00:00:00 2001 From: Patrick Woody Date: Sat, 31 Mar 2018 17:44:48 -0400 Subject: [PATCH 3/5] Add copying for unsafe data structures --- .../execution/columnar/ColumnBuilder.scala | 8 +- .../sql/execution/columnar/ColumnStats.scala | 91 ++++++++++++++----- .../execution/columnar/ColumnStatsSuite.scala | 83 +++++++++++++++-- 3 files changed, 145 insertions(+), 37 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnBuilder.scala index d30655e0c4a2..48434d8d86a1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnBuilder.scala @@ -85,7 +85,7 @@ private[columnar] class BasicColumnBuilder[JvmType]( } private[columnar] class NullColumnBuilder - extends BasicColumnBuilder[Any](new ObjectColumnStats(NullType), NULL) + extends BasicColumnBuilder[Any](new NullColumnStats, NULL) with NullableColumnBuilder private[columnar] abstract class ComplexColumnBuilder[JvmType]( @@ -132,13 +132,13 @@ private[columnar] class DecimalColumnBuilder(dataType: DecimalType) extends ComplexColumnBuilder(new DecimalColumnStats(dataType), LARGE_DECIMAL(dataType)) private[columnar] class StructColumnBuilder(dataType: StructType) - extends ComplexColumnBuilder(new ObjectColumnStats(dataType), STRUCT(dataType)) + extends ComplexColumnBuilder(new StructColumnStats(dataType), STRUCT(dataType)) private[columnar] class ArrayColumnBuilder(dataType: ArrayType) - extends ComplexColumnBuilder(new ObjectColumnStats(dataType), ARRAY(dataType)) + extends ComplexColumnBuilder(new ArrayColumnStats(dataType), ARRAY(dataType)) private[columnar] class MapColumnBuilder(dataType: MapType) - extends ComplexColumnBuilder(new ObjectColumnStats(dataType), MAP(dataType)) + extends ComplexColumnBuilder(new MapColumnStats(dataType), MAP(dataType)) private[columnar] object ColumnBuilder { val DEFAULT_INITIAL_BUFFER_SIZE = 128 * 1024 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala index 731505d1b50c..b56df52dafb3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.columnar import org.apache.spark.sql.catalyst.InternalRow import 
org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, RowOrdering} -import org.apache.spark.sql.catalyst.util.TypeUtils +import org.apache.spark.sql.catalyst.util.{ArrayData, TypeUtils} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -81,7 +81,7 @@ private[columnar] final class NoopColumnStats extends ColumnStats { if (!row.isNullAt(ordinal)) { count += 1 } else { - gatherNullStats + gatherNullStats() } } @@ -97,7 +97,7 @@ private[columnar] final class BooleanColumnStats extends ColumnStats { val value = row.getBoolean(ordinal) gatherValueStats(value) } else { - gatherNullStats + gatherNullStats() } } @@ -121,7 +121,7 @@ private[columnar] final class ByteColumnStats extends ColumnStats { val value = row.getByte(ordinal) gatherValueStats(value) } else { - gatherNullStats + gatherNullStats() } } @@ -145,7 +145,7 @@ private[columnar] final class ShortColumnStats extends ColumnStats { val value = row.getShort(ordinal) gatherValueStats(value) } else { - gatherNullStats + gatherNullStats() } } @@ -169,7 +169,7 @@ private[columnar] final class IntColumnStats extends ColumnStats { val value = row.getInt(ordinal) gatherValueStats(value) } else { - gatherNullStats + gatherNullStats() } } @@ -193,7 +193,7 @@ private[columnar] final class LongColumnStats extends ColumnStats { val value = row.getLong(ordinal) gatherValueStats(value) } else { - gatherNullStats + gatherNullStats() } } @@ -217,7 +217,7 @@ private[columnar] final class FloatColumnStats extends ColumnStats { val value = row.getFloat(ordinal) gatherValueStats(value) } else { - gatherNullStats + gatherNullStats() } } @@ -241,7 +241,7 @@ private[columnar] final class DoubleColumnStats extends ColumnStats { val value = row.getDouble(ordinal) gatherValueStats(value) } else { - gatherNullStats + gatherNullStats() } } @@ -266,7 +266,7 @@ private[columnar] final class StringColumnStats extends ColumnStats { val size = STRING.actualSize(row, ordinal) gatherValueStats(value, size) } else { - gatherNullStats + gatherNullStats() } } @@ -288,7 +288,7 @@ private[columnar] final class BinaryColumnStats extends ColumnStats { sizeInBytes += size count += 1 } else { - gatherNullStats + gatherNullStats() } } @@ -308,7 +308,7 @@ private[columnar] final class DecimalColumnStats(precision: Int, scale: Int) ext // TODO: this is not right for DecimalType with precision > 18 gatherValueStats(value) } else { - gatherNullStats + gatherNullStats() } } @@ -323,32 +323,75 @@ private[columnar] final class DecimalColumnStats(precision: Int, scale: Int) ext Array[Any](lower, upper, nullCount, count, sizeInBytes) } -private[columnar] final class ObjectColumnStats(dataType: DataType) extends ColumnStats { - protected var upper: Any = null - protected var lower: Any = null +private abstract class OrderableSafeColumnStats[T](dataType: DataType) extends ColumnStats { + protected var upper: T = _ + protected var lower: T = _ - val columnType = ColumnType(dataType) - val ordering = dataType match { - case x if RowOrdering.isOrderable(dataType) && x != NullType => + private val columnType = ColumnType(dataType) + private val ordering = dataType match { + case x if RowOrdering.isOrderable(dataType) => Option(TypeUtils.getInterpretedOrdering(x)) case _ => None } override def gatherStats(row: InternalRow, ordinal: Int): Unit = { if (!row.isNullAt(ordinal)) { - val size = columnType.actualSize(row, ordinal) - sizeInBytes += size + sizeInBytes += columnType.actualSize(row, ordinal) count += 1 ordering.foreach { 
order => - val value = row.get(ordinal, dataType) - if (upper == null || order.gt(value, upper)) upper = value - if (lower == null || order.lt(value, lower)) lower = value + val value = getValue(row, ordinal) + if (upper == null || order.gt(value, upper)) upper = copy(value) + if (lower == null || order.lt(value, lower)) lower = copy(value) } } else { - gatherNullStats + gatherNullStats() } } + def getValue(row: InternalRow, ordinal: Int): T + + def copy(value: T): T + override def collectedStatistics: Array[Any] = Array[Any](lower, upper, nullCount, count, sizeInBytes) } + +private[columnar] final class ArrayColumnStats(dataType: DataType) + extends OrderableSafeColumnStats[ArrayData](dataType) { + override def getValue(row: InternalRow, ordinal: Int): ArrayData = row.getArray(ordinal) + + override def copy(value: ArrayData): ArrayData = value.copy() +} + +private[columnar] final class StructColumnStats(dataType: DataType) + extends OrderableSafeColumnStats[InternalRow](dataType) { + private val numFields = dataType.asInstanceOf[StructType].fields.length + + override def getValue(row: InternalRow, ordinal: Int): InternalRow = + row.getStruct(ordinal, numFields) + + override def copy(value: InternalRow): InternalRow = value.copy() +} + +private[columnar] final class MapColumnStats(dataType: DataType) extends ColumnStats { + private val columnType = ColumnType(dataType) + + override def gatherStats(row: InternalRow, ordinal: Int): Unit = { + if (!row.isNullAt(ordinal)) { + sizeInBytes += columnType.actualSize(row, ordinal) + count += 1 + } else { + gatherNullStats() + } + } + + override def collectedStatistics: Array[Any] = + Array[Any](null, null, nullCount, count, sizeInBytes) +} + +private[columnar] final class NullColumnStats extends ColumnStats { + override def gatherStats(row: InternalRow, ordinal: Int): Unit = gatherNullStats() + + override def collectedStatistics: Array[Any] = + Array[Any](null, null, nullCount, count, sizeInBytes) +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnStatsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnStatsSuite.scala index fae7a2ebccf8..89aa19fdfef3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnStatsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnStatsSuite.scala @@ -35,15 +35,14 @@ class ColumnStatsSuite extends SparkFunSuite { ) testColumnStats(classOf[StringColumnStats], STRING, Array(null, null, 0, 0, 0)) testDecimalColumnStats(Array(null, null, 0, 0, 0)) - testObjectColumnStats(ArrayType(IntegerType), orderable = true, Array(null, null, 0, 0, 0)) - testObjectColumnStats( + testArrayColumnStats(ArrayType(IntegerType), orderable = true, Array(null, null, 0, 0, 0)) + testStructColumnStats( StructType(Array(StructField("test", DataTypes.StringType))), orderable = true, Array(null, null, 0, 0, 0) ) - testObjectColumnStats( + testMapColumnStats( MapType(IntegerType, StringType), - orderable = false, Array(null, null, 0, 0, 0) ) @@ -121,13 +120,12 @@ class ColumnStatsSuite extends SparkFunSuite { } } - def testObjectColumnStats( - dataType: DataType, orderable: Boolean, initialStatistics: Array[Any]): Unit = { - assert(!(orderable ^ RowOrdering.isOrderable(dataType))) + def testArrayColumnStats( + dataType: DataType, orderable: Boolean, initialStatistics: Array[Any]): Unit = { val columnType = ColumnType(dataType) test(s"${dataType.typeName}: empty") { - val objectStats = new ObjectColumnStats(dataType) + 
val objectStats = new ArrayColumnStats(dataType) objectStats.collectedStatistics.zip(initialStatistics).foreach { case (actual, expected) => assert(actual === expected) } @@ -135,7 +133,7 @@ class ColumnStatsSuite extends SparkFunSuite { test(s"${dataType.typeName}: non-empty") { import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._ - val objectStats = new ObjectColumnStats(dataType) + val objectStats = new ArrayColumnStats(dataType) val rows = Seq.fill(10)(makeRandomRow(columnType)) ++ Seq.fill(10)(makeNullRow(1)) rows.foreach(objectStats.gatherStats(_, 0)) @@ -159,4 +157,71 @@ class ColumnStatsSuite extends SparkFunSuite { } } } + + def testStructColumnStats( + dataType: DataType, orderable: Boolean, initialStatistics: Array[Any]): Unit = { + val columnType = ColumnType(dataType) + + test(s"${dataType.typeName}: empty") { + val objectStats = new StructColumnStats(dataType) + objectStats.collectedStatistics.zip(initialStatistics).foreach { + case (actual, expected) => assert(actual === expected) + } + } + + test(s"${dataType.typeName}: non-empty") { + import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._ + val objectStats = new StructColumnStats(dataType) + val rows = Seq.fill(10)(makeRandomRow(columnType)) ++ Seq.fill(10)(makeNullRow(1)) + rows.foreach(objectStats.gatherStats(_, 0)) + + val stats = objectStats.collectedStatistics + if (orderable) { + val values = rows.take(10).map(_.get(0, columnType.dataType)) + val ordering = TypeUtils.getInterpretedOrdering(dataType) + + assertResult(values.min(ordering), "Wrong lower bound")(stats(0)) + assertResult(values.max(ordering), "Wrong upper bound")(stats(1)) + } else { + assertResult(null, "Wrong lower bound")(stats(0)) + assertResult(null, "Wrong upper bound")(stats(1)) + } + assertResult(10, "Wrong null count")(stats(2)) + assertResult(20, "Wrong row count")(stats(3)) + assertResult(stats(4), "Wrong size in bytes") { + rows.map { row => + if (row.isNullAt(0)) 4 else columnType.actualSize(row, 0) + }.sum + } + } + } + + def testMapColumnStats(dataType: DataType, initialStatistics: Array[Any]): Unit = { + val columnType = ColumnType(dataType) + + test(s"${dataType.typeName}: empty") { + val objectStats = new MapColumnStats(dataType) + objectStats.collectedStatistics.zip(initialStatistics).foreach { + case (actual, expected) => assert(actual === expected) + } + } + + test(s"${dataType.typeName}: non-empty") { + import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._ + val objectStats = new MapColumnStats(dataType) + val rows = Seq.fill(10)(makeRandomRow(columnType)) ++ Seq.fill(10)(makeNullRow(1)) + rows.foreach(objectStats.gatherStats(_, 0)) + + val stats = objectStats.collectedStatistics + assertResult(null, "Wrong lower bound")(stats(0)) + assertResult(null, "Wrong upper bound")(stats(1)) + assertResult(10, "Wrong null count")(stats(2)) + assertResult(20, "Wrong row count")(stats(3)) + assertResult(stats(4), "Wrong size in bytes") { + rows.map { row => + if (row.isNullAt(0)) 4 else columnType.actualSize(row, 0) + }.sum + } + } + } } From 1479bded7c6e220f6dad2eb5bff4feb64cbaf991 Mon Sep 17 00:00:00 2001 From: Patrick Woody Date: Sat, 31 Mar 2018 20:35:07 -0400 Subject: [PATCH 4/5] pr feedback --- .../sql/execution/columnar/ColumnStats.scala | 18 ++-- .../execution/columnar/ColumnStatsSuite.scala | 100 +++++++++--------- 2 files changed, 61 insertions(+), 57 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala index b56df52dafb3..f6cb15160dc9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala @@ -19,7 +19,8 @@ package org.apache.spark.sql.execution.columnar import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, RowOrdering} -import org.apache.spark.sql.catalyst.util.{ArrayData, TypeUtils} +import org.apache.spark.sql.catalyst.expressions.{UnsafeArrayData, UnsafeRow} +import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -357,20 +358,21 @@ private abstract class OrderableSafeColumnStats[T](dataType: DataType) extends C } private[columnar] final class ArrayColumnStats(dataType: DataType) - extends OrderableSafeColumnStats[ArrayData](dataType) { - override def getValue(row: InternalRow, ordinal: Int): ArrayData = row.getArray(ordinal) + extends OrderableSafeColumnStats[UnsafeArrayData](dataType) { + override def getValue(row: InternalRow, ordinal: Int): UnsafeArrayData = + row.getArray(ordinal).asInstanceOf[UnsafeArrayData] - override def copy(value: ArrayData): ArrayData = value.copy() + override def copy(value: UnsafeArrayData): UnsafeArrayData = value.copy() } private[columnar] final class StructColumnStats(dataType: DataType) - extends OrderableSafeColumnStats[InternalRow](dataType) { + extends OrderableSafeColumnStats[UnsafeRow](dataType) { private val numFields = dataType.asInstanceOf[StructType].fields.length - override def getValue(row: InternalRow, ordinal: Int): InternalRow = - row.getStruct(ordinal, numFields) + override def getValue(row: InternalRow, ordinal: Int): UnsafeRow = + row.getStruct(ordinal, numFields).asInstanceOf[UnsafeRow] - override def copy(value: InternalRow): InternalRow = value.copy() + override def copy(value: UnsafeRow): UnsafeRow = value.copy() } private[columnar] final class MapColumnStats(dataType: DataType) extends ColumnStats { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnStatsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnStatsSuite.scala index 89aa19fdfef3..3e987d00ab65 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnStatsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnStatsSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.columnar import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.expressions.RowOrdering +import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeArrayData} import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ @@ -35,9 +35,30 @@ class ColumnStatsSuite extends SparkFunSuite { ) testColumnStats(classOf[StringColumnStats], STRING, Array(null, null, 0, 0, 0)) testDecimalColumnStats(Array(null, null, 0, 0, 0)) - testArrayColumnStats(ArrayType(IntegerType), orderable = true, Array(null, null, 0, 0, 0)) - testStructColumnStats( - StructType(Array(StructField("test", DataTypes.StringType))), + + private val orderableArrayDataType = ArrayType(IntegerType) + testOrderableColumnStats( + orderableArrayDataType, + () => new ArrayColumnStats(orderableArrayDataType), + ARRAY(orderableArrayDataType), + orderable = true, + Array(null, 
null, 0, 0, 0) + ) + + private val unorderableArrayDataType = ArrayType(MapType(IntegerType, StringType)) + testOrderableColumnStats( + unorderableArrayDataType, + () => new ArrayColumnStats(unorderableArrayDataType), + ARRAY(unorderableArrayDataType), + orderable = false, + Array(null, null, 0, 0, 0) + ) + + private val structDataType = StructType(Array(StructField("test", DataTypes.StringType))) + testOrderableColumnStats( + structDataType, + () => new StructColumnStats(structDataType), + STRUCT(structDataType), orderable = true, Array(null, null, 0, 0, 0) ) @@ -120,58 +141,23 @@ class ColumnStatsSuite extends SparkFunSuite { } } - def testArrayColumnStats( - dataType: DataType, orderable: Boolean, initialStatistics: Array[Any]): Unit = { - val columnType = ColumnType(dataType) - - test(s"${dataType.typeName}: empty") { - val objectStats = new ArrayColumnStats(dataType) - objectStats.collectedStatistics.zip(initialStatistics).foreach { - case (actual, expected) => assert(actual === expected) - } - } - - test(s"${dataType.typeName}: non-empty") { - import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._ - val objectStats = new ArrayColumnStats(dataType) - val rows = Seq.fill(10)(makeRandomRow(columnType)) ++ Seq.fill(10)(makeNullRow(1)) - rows.foreach(objectStats.gatherStats(_, 0)) - - val stats = objectStats.collectedStatistics - if (orderable) { - val values = rows.take(10).map(_.get(0, columnType.dataType)) - val ordering = TypeUtils.getInterpretedOrdering(dataType) - - assertResult(values.min(ordering), "Wrong lower bound")(stats(0)) - assertResult(values.max(ordering), "Wrong upper bound")(stats(1)) - } else { - assertResult(null, "Wrong lower bound")(stats(0)) - assertResult(null, "Wrong upper bound")(stats(1)) - } - assertResult(10, "Wrong null count")(stats(2)) - assertResult(20, "Wrong row count")(stats(3)) - assertResult(stats(4), "Wrong size in bytes") { - rows.map { row => - if (row.isNullAt(0)) 4 else columnType.actualSize(row, 0) - }.sum - } - } - } - - def testStructColumnStats( - dataType: DataType, orderable: Boolean, initialStatistics: Array[Any]): Unit = { - val columnType = ColumnType(dataType) + def testOrderableColumnStats[T]( + dataType: DataType, + statsSupplier: () => OrderableSafeColumnStats[T], + columnType: ColumnType[T], + orderable: Boolean, + initialStatistics: Array[Any]): Unit = { - test(s"${dataType.typeName}: empty") { - val objectStats = new StructColumnStats(dataType) + test(s"${dataType.typeName}, $orderable: empty") { + val objectStats = statsSupplier() objectStats.collectedStatistics.zip(initialStatistics).foreach { case (actual, expected) => assert(actual === expected) } } - test(s"${dataType.typeName}: non-empty") { + test(s"${dataType.typeName}, $orderable: non-empty") { import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._ - val objectStats = new StructColumnStats(dataType) + val objectStats = statsSupplier() val rows = Seq.fill(10)(makeRandomRow(columnType)) ++ Seq.fill(10)(makeNullRow(1)) rows.foreach(objectStats.gatherStats(_, 0)) @@ -224,4 +210,20 @@ class ColumnStatsSuite extends SparkFunSuite { } } } + + test("Reuse UnsafeArrayData for stats") { + val stats = new ArrayColumnStats(ArrayType(IntegerType)) + val unsafeData = UnsafeArrayData.fromPrimitiveArray(Array(1)) + (1 to 10).foreach { value => + val row = new GenericInternalRow(Array[Any](unsafeData)) + unsafeData.setInt(0, value) + stats.gatherStats(row, 0) + } + val collected = stats.collectedStatistics + 
assertResult(UnsafeArrayData.fromPrimitiveArray(Array(1)))(collected(0)) + assertResult(UnsafeArrayData.fromPrimitiveArray(Array(10)))(collected(1)) + assertResult(0)(collected(2)) + assertResult(10)(collected(3)) + assertResult(10 * (4 + unsafeData.getSizeInBytes))(collected(4)) + } } From 6ea0919ec0a9dfc6b121c88790fac79aa072bc60 Mon Sep 17 00:00:00 2001 From: Patrick Woody Date: Sun, 1 Apr 2018 11:18:32 -0400 Subject: [PATCH 5/5] extra test make Map orderable safe --- .../sql/execution/columnar/ColumnStats.scala | 24 +++++++------------ .../execution/columnar/ColumnStatsSuite.scala | 23 ++++++++++++++++-- 2 files changed, 29 insertions(+), 18 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala index f6cb15160dc9..f054d21860f5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.columnar import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, RowOrdering} -import org.apache.spark.sql.catalyst.expressions.{UnsafeArrayData, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions.{UnsafeArrayData, UnsafeMapData, UnsafeRow} import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -357,7 +357,7 @@ private abstract class OrderableSafeColumnStats[T](dataType: DataType) extends C Array[Any](lower, upper, nullCount, count, sizeInBytes) } -private[columnar] final class ArrayColumnStats(dataType: DataType) +private[columnar] final class ArrayColumnStats(dataType: ArrayType) extends OrderableSafeColumnStats[UnsafeArrayData](dataType) { override def getValue(row: InternalRow, ordinal: Int): UnsafeArrayData = row.getArray(ordinal).asInstanceOf[UnsafeArrayData] @@ -365,7 +365,7 @@ private[columnar] final class ArrayColumnStats(dataType: DataType) override def copy(value: UnsafeArrayData): UnsafeArrayData = value.copy() } -private[columnar] final class StructColumnStats(dataType: DataType) +private[columnar] final class StructColumnStats(dataType: StructType) extends OrderableSafeColumnStats[UnsafeRow](dataType) { private val numFields = dataType.asInstanceOf[StructType].fields.length @@ -375,20 +375,12 @@ private[columnar] final class StructColumnStats(dataType: DataType) override def copy(value: UnsafeRow): UnsafeRow = value.copy() } -private[columnar] final class MapColumnStats(dataType: DataType) extends ColumnStats { - private val columnType = ColumnType(dataType) - - override def gatherStats(row: InternalRow, ordinal: Int): Unit = { - if (!row.isNullAt(ordinal)) { - sizeInBytes += columnType.actualSize(row, ordinal) - count += 1 - } else { - gatherNullStats() - } - } +private[columnar] final class MapColumnStats(dataType: MapType) + extends OrderableSafeColumnStats[UnsafeMapData](dataType) { + override def getValue(row: InternalRow, ordinal: Int): UnsafeMapData = + row.getMap(ordinal).asInstanceOf[UnsafeMapData] - override def collectedStatistics: Array[Any] = - Array[Any](null, null, nullCount, count, sizeInBytes) + override def copy(value: UnsafeMapData): UnsafeMapData = value.copy() } private[columnar] final class NullColumnStats extends ColumnStats { diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnStatsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnStatsSuite.scala index 3e987d00ab65..398f00b9395b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnStatsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnStatsSuite.scala @@ -18,7 +18,8 @@ package org.apache.spark.sql.execution.columnar import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeArrayData} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeArrayData, UnsafeProjection} import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ @@ -182,7 +183,7 @@ class ColumnStatsSuite extends SparkFunSuite { } } - def testMapColumnStats(dataType: DataType, initialStatistics: Array[Any]): Unit = { + def testMapColumnStats(dataType: MapType, initialStatistics: Array[Any]): Unit = { val columnType = ColumnType(dataType) test(s"${dataType.typeName}: empty") { @@ -226,4 +227,22 @@ class ColumnStatsSuite extends SparkFunSuite { assertResult(10)(collected(3)) assertResult(10 * (4 + unsafeData.getSizeInBytes))(collected(4)) } + + test("Reuse UnsafeRow for stats") { + val structType = StructType(Array(StructField("int", IntegerType))) + val stats = new StructColumnStats(structType) + val converter = UnsafeProjection.create(structType) + val unsafeData = converter(InternalRow(1)) + (1 to 10).foreach { value => + val row = new GenericInternalRow(Array[Any](unsafeData)) + unsafeData.setInt(0, value) + stats.gatherStats(row, 0) + } + val collected = stats.collectedStatistics + assertResult(converter(InternalRow(1)))(collected(0)) + assertResult(converter(InternalRow(10)))(collected(1)) + assertResult(0)(collected(2)) + assertResult(10)(collected(3)) + assertResult(10 * (4 + unsafeData.getSizeInBytes))(collected(4)) + } }
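
Why the old ObjectColumnStats broke pruning: before this series, ObjectColumnStats.collectedStatistics always reported null lower and upper bounds for complex types, so the min/max filters that InMemoryTableScanExec builds over cached-batch statistics can evaluate against those stale null bounds and skip batches that actually contain matching rows; the SPARK-23819 regression test above ("arr <=> array(1)") is the observable symptom. The series fixes this by gathering real bounds for orderable complex types (arrays and structs, and maps stay unordered because RowOrdering.isOrderable rejects MapType until patch 5 routes them through the same orderable-safe path). The toy Scala sketch below is illustrative only and not Spark code; PruningSketch, Stats, and mayContain are invented names. It shows the invariant min/max pruning has to respect: a batch may be skipped only when both bounds are known and provably exclude the value.

// Toy model (not Spark code): batch pruning from per-column min/max stats.
// A batch is scanned unless the bounds prove the predicate cannot match.
object PruningSketch {
  final case class Stats[T](lower: Option[T], upper: Option[T])

  // Keep the batch unless both bounds are known and exclude the value.
  def mayContain[T](stats: Stats[T], value: T)(implicit ord: Ordering[T]): Boolean =
    (stats.lower, stats.upper) match {
      case (Some(lo), Some(hi)) => ord.lteq(lo, value) && ord.lteq(value, hi)
      case _                    => true // missing stats => cannot prune
    }

  def main(args: Array[String]): Unit = {
    // Bounds were gathered, so pruning is safe and precise.
    println(mayContain(Stats(Some(1), Some(10)), 5))   // true: batch kept
    println(mayContain(Stats(Some(1), Some(10)), 42))  // false: batch pruned
    // No bounds (as with the old ObjectColumnStats): the batch must be kept.
    println(mayContain(Stats[Int](None, None), 42))    // true
  }
}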
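
Why the bounds have to be copied (patches 3-5): the UnsafeArrayData, UnsafeRow, and UnsafeMapData values returned by row accessors may be views over a buffer that the producer reuses for subsequent rows, so keeping the bare reference as a running lower/upper bound would silently track whatever the buffer holds last. That is what the "Reuse UnsafeArrayData for stats" and "Reuse UnsafeRow for stats" tests pin down, and why OrderableSafeColumnStats stores copy(value) rather than value. The sketch below is again illustrative only (the Slot class and CopyOnStoreSketch object are invented); it reproduces the hazard with an ordinary mutable object.

// Toy model (not Spark code): a running max must copy values that are
// backed by a reused, mutable buffer, as unsafe row/array data can be.
object CopyOnStoreSketch {
  final class Slot(var value: Int) {
    def copy(): Slot = new Slot(value)
    override def toString: String = s"Slot($value)"
  }

  def main(args: Array[String]): Unit = {
    val reused = new Slot(0) // one buffer, mutated for every incoming row

    var maxByRef: Slot = null   // stores the reference (the bug)
    var maxByCopy: Slot = null  // stores a copy (what the patch does)

    for (v <- Seq(7, 3, 9, 1)) {
      reused.value = v
      if (maxByRef == null || reused.value > maxByRef.value) maxByRef = reused
      if (maxByCopy == null || reused.value > maxByCopy.value) maxByCopy = reused.copy()
    }

    println(maxByRef)  // Slot(1): the reference tracked the last row, not the max
    println(maxByCopy) // Slot(9): the copy preserved the true maximum
  }
}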