pr feedback
Patrick Woody committed Apr 1, 2018
commit 1479bded7c6e220f6dad2eb5bff4feb64cbaf991
@@ -19,7 +19,8 @@ package org.apache.spark.sql.execution.columnar

 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, RowOrdering}
-import org.apache.spark.sql.catalyst.util.{ArrayData, TypeUtils}
+import org.apache.spark.sql.catalyst.expressions.{UnsafeArrayData, UnsafeRow}
+import org.apache.spark.sql.catalyst.util.TypeUtils
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String

@@ -357,20 +358,21 @@ private abstract class OrderableSafeColumnStats[T](dataType: DataType) extends C
 }
 
 private[columnar] final class ArrayColumnStats(dataType: DataType)

Review comment (Member): dataType: DataType -> dataType: ArrayType?
-  extends OrderableSafeColumnStats[ArrayData](dataType) {
-  override def getValue(row: InternalRow, ordinal: Int): ArrayData = row.getArray(ordinal)
+  extends OrderableSafeColumnStats[UnsafeArrayData](dataType) {
+  override def getValue(row: InternalRow, ordinal: Int): UnsafeArrayData =
+    row.getArray(ordinal).asInstanceOf[UnsafeArrayData]
 
-  override def copy(value: ArrayData): ArrayData = value.copy()
+  override def copy(value: UnsafeArrayData): UnsafeArrayData = value.copy()
 }
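
A sketch of the reviewer's suggested tightening, for illustration only (it is not part of the committed diff); since ArrayType is a DataType, the constructor can narrow its parameter without touching the base class:

private[columnar] final class ArrayColumnStats(dataType: ArrayType)
  extends OrderableSafeColumnStats[UnsafeArrayData](dataType) {
  // Body unchanged: callers now get a compile-time guarantee that this is an
  // array column instead of relying on being handed an ArrayType at runtime.
  override def getValue(row: InternalRow, ordinal: Int): UnsafeArrayData =
    row.getArray(ordinal).asInstanceOf[UnsafeArrayData]

  override def copy(value: UnsafeArrayData): UnsafeArrayData = value.copy()
}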

 private[columnar] final class StructColumnStats(dataType: DataType)

Review comment (Member): dataType: DataType -> dataType: StructType?
-  extends OrderableSafeColumnStats[InternalRow](dataType) {
+  extends OrderableSafeColumnStats[UnsafeRow](dataType) {
   private val numFields = dataType.asInstanceOf[StructType].fields.length
 
-  override def getValue(row: InternalRow, ordinal: Int): InternalRow =
-    row.getStruct(ordinal, numFields)
+  override def getValue(row: InternalRow, ordinal: Int): UnsafeRow =
+    row.getStruct(ordinal, numFields).asInstanceOf[UnsafeRow]
 
-  override def copy(value: InternalRow): InternalRow = value.copy()
+  override def copy(value: UnsafeRow): UnsafeRow = value.copy()
 }
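
The same suggested tightening sketched for StructColumnStats, where narrowing to StructType would also drop the runtime cast (again illustrative, not committed):

private[columnar] final class StructColumnStats(dataType: StructType)
  extends OrderableSafeColumnStats[UnsafeRow](dataType) {
  // The asInstanceOf[StructType] cast disappears once the constructor
  // guarantees a StructType.
  private val numFields = dataType.fields.length

  override def getValue(row: InternalRow, ordinal: Int): UnsafeRow =
    row.getStruct(ordinal, numFields).asInstanceOf[UnsafeRow]

  override def copy(value: UnsafeRow): UnsafeRow = value.copy()
}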

 private[columnar] final class MapColumnStats(dataType: DataType) extends ColumnStats {

Review comment (Member): dataType: DataType -> dataType: MapType?

Review comment (Member): Please add a TODO that we need to make this use OrderableSafeColumnStats when MapType is orderable.

Reply (Author): Now that you mention it - we can just have it use it now, since it will always fall through to the unorderable case. Everything will just work when we make it orderable, with no code change here.

Reply (Member): Oh, sounds good to me.
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.execution.columnar
 
 import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.catalyst.expressions.RowOrdering
+import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeArrayData}
 import org.apache.spark.sql.catalyst.util.TypeUtils
 import org.apache.spark.sql.types._

@@ -35,9 +35,30 @@ class ColumnStatsSuite extends SparkFunSuite {
   )
   testColumnStats(classOf[StringColumnStats], STRING, Array(null, null, 0, 0, 0))
   testDecimalColumnStats(Array(null, null, 0, 0, 0))
-  testArrayColumnStats(ArrayType(IntegerType), orderable = true, Array(null, null, 0, 0, 0))
-  testStructColumnStats(
-    StructType(Array(StructField("test", DataTypes.StringType))),
+
+  private val orderableArrayDataType = ArrayType(IntegerType)
+  testOrderableColumnStats(
+    orderableArrayDataType,
+    () => new ArrayColumnStats(orderableArrayDataType),
+    ARRAY(orderableArrayDataType),
+    orderable = true,
+    Array(null, null, 0, 0, 0)
+  )
+
+  private val unorderableArrayDataType = ArrayType(MapType(IntegerType, StringType))
+  testOrderableColumnStats(
+    unorderableArrayDataType,
+    () => new ArrayColumnStats(unorderableArrayDataType),
+    ARRAY(unorderableArrayDataType),
+    orderable = false,
+    Array(null, null, 0, 0, 0)
+  )
+
+  private val structDataType = StructType(Array(StructField("test", DataTypes.StringType)))
+  testOrderableColumnStats(
+    structDataType,
+    () => new StructColumnStats(structDataType),
+    STRUCT(structDataType),
+    orderable = true,
+    Array(null, null, 0, 0, 0)
+  )
@@ -120,58 +141,23 @@
     }
   }
 
-  def testArrayColumnStats(
-      dataType: DataType, orderable: Boolean, initialStatistics: Array[Any]): Unit = {
-    val columnType = ColumnType(dataType)
-
-    test(s"${dataType.typeName}: empty") {
-      val objectStats = new ArrayColumnStats(dataType)
-      objectStats.collectedStatistics.zip(initialStatistics).foreach {
-        case (actual, expected) => assert(actual === expected)
-      }
-    }
-
-    test(s"${dataType.typeName}: non-empty") {
-      import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._
-      val objectStats = new ArrayColumnStats(dataType)
-      val rows = Seq.fill(10)(makeRandomRow(columnType)) ++ Seq.fill(10)(makeNullRow(1))
-      rows.foreach(objectStats.gatherStats(_, 0))
-
-      val stats = objectStats.collectedStatistics
-      if (orderable) {
-        val values = rows.take(10).map(_.get(0, columnType.dataType))
-        val ordering = TypeUtils.getInterpretedOrdering(dataType)
-
-        assertResult(values.min(ordering), "Wrong lower bound")(stats(0))
-        assertResult(values.max(ordering), "Wrong upper bound")(stats(1))
-      } else {
-        assertResult(null, "Wrong lower bound")(stats(0))
-        assertResult(null, "Wrong upper bound")(stats(1))
-      }
-      assertResult(10, "Wrong null count")(stats(2))
-      assertResult(20, "Wrong row count")(stats(3))
-      assertResult(stats(4), "Wrong size in bytes") {
-        rows.map { row =>
-          if (row.isNullAt(0)) 4 else columnType.actualSize(row, 0)
-        }.sum
-      }
-    }
-  }
-
-  def testStructColumnStats(
-      dataType: DataType, orderable: Boolean, initialStatistics: Array[Any]): Unit = {
-    val columnType = ColumnType(dataType)
+  def testOrderableColumnStats[T](
+      dataType: DataType,
+      statsSupplier: () => OrderableSafeColumnStats[T],
+      columnType: ColumnType[T],
+      orderable: Boolean,
+      initialStatistics: Array[Any]): Unit = {
 
-    test(s"${dataType.typeName}: empty") {
-      val objectStats = new StructColumnStats(dataType)
+    test(s"${dataType.typeName}, $orderable: empty") {
+      val objectStats = statsSupplier()
       objectStats.collectedStatistics.zip(initialStatistics).foreach {
         case (actual, expected) => assert(actual === expected)
       }
     }
 
-    test(s"${dataType.typeName}: non-empty") {
+    test(s"${dataType.typeName}, $orderable: non-empty") {
       import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._
-      val objectStats = new StructColumnStats(dataType)
+      val objectStats = statsSupplier()
       val rows = Seq.fill(10)(makeRandomRow(columnType)) ++ Seq.fill(10)(makeNullRow(1))
       rows.foreach(objectStats.gatherStats(_, 0))
 
@@ -224,4 +210,20 @@
       }
     }
   }
 
+  test("Reuse UnsafeArrayData for stats") {

Review comment (Member): We should also test against UnsafeRow too.
+    val stats = new ArrayColumnStats(ArrayType(IntegerType))
+    val unsafeData = UnsafeArrayData.fromPrimitiveArray(Array(1))
+    (1 to 10).foreach { value =>
+      val row = new GenericInternalRow(Array[Any](unsafeData))
+      unsafeData.setInt(0, value)
+      stats.gatherStats(row, 0)
+    }
+    val collected = stats.collectedStatistics
+    assertResult(UnsafeArrayData.fromPrimitiveArray(Array(1)))(collected(0))
+    assertResult(UnsafeArrayData.fromPrimitiveArray(Array(10)))(collected(1))
+    assertResult(0)(collected(2))
+    assertResult(10)(collected(3))
+    assertResult(10 * (4 + unsafeData.getSizeInBytes))(collected(4))
+  }
 }
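
The UnsafeRow analogue the reviewer asks for is not in this commit; a minimal sketch of what it could look like, assuming StructColumnStats defensively copies the reused buffer the same way. UnsafeProjection and InternalRow would need to be imported, and the expected-value checks are illustrative:

  test("Reuse UnsafeRow for stats") {
    val structType = StructType(Array(StructField("int", IntegerType)))
    val stats = new StructColumnStats(structType)
    // One UnsafeRow buffer is mutated in place across iterations, so the stats
    // must copy it to keep correct lower/upper bounds.
    val unsafeRow = UnsafeProjection.create(structType).apply(InternalRow(1)).copy()
    (1 to 10).foreach { value =>
      val row = new GenericInternalRow(Array[Any](unsafeRow))
      unsafeRow.setInt(0, value)
      stats.gatherStats(row, 0)
    }
    val collected = stats.collectedStatistics
    assertResult(1)(collected(0).asInstanceOf[UnsafeRow].getInt(0))
    assertResult(10)(collected(1).asInstanceOf[UnsafeRow].getInt(0))
    assertResult(0)(collected(2))
    assertResult(10)(collected(3))
  }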