extra test make Map orderable safe
Patrick Woody committed Apr 1, 2018
commit 6ea0919ec0a9dfc6b121c88790fac79aa072bc60
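
The diff below narrows the constructor parameters of ArrayColumnStats, StructColumnStats, and MapColumnStats from DataType to their concrete types, and reroutes MapColumnStats through the OrderableSafeColumnStats base class. That base class sits above line 357 of the file being diffed and is not shown in this commit; what follows is a hedged sketch of its plausible contract, inferred only from the RowOrdering/TypeUtils imports and the collectedStatistics layout visible in the hunks. The name OrderableSafeColumnStatsSketch and every line of its gatherStats are assumptions, not the fork's actual code:

  // Hypothetical reconstruction only -- not the code this commit touches.
  // Assumed behavior: track lower/upper bounds only when the type is
  // orderable, and defensively copy values because unsafe rows, arrays,
  // and maps reuse their backing buffers between calls.
  private abstract class OrderableSafeColumnStatsSketch[T](dataType: DataType)
    extends ColumnStats {
    private var lower: T = _
    private var upper: T = _
    private val columnType = ColumnType(dataType)
    private val ordering: Option[Ordering[T]] =
      if (RowOrdering.isOrderable(dataType)) {
        Some(TypeUtils.getInterpretedOrdering(dataType).asInstanceOf[Ordering[T]])
      } else {
        None
      }

    // Subclasses extract the unsafe value and define how to copy it.
    def getValue(row: InternalRow, ordinal: Int): T
    def copy(value: T): T

    override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
      if (!row.isNullAt(ordinal)) {
        ordering.foreach { order =>
          val value = getValue(row, ordinal)
          if (upper == null || order.gt(value, upper)) upper = copy(value)
          if (lower == null || order.lt(value, lower)) lower = copy(value)
        }
        sizeInBytes += columnType.actualSize(row, ordinal)
        count += 1
      } else {
        gatherNullStats()
      }
    }

    override def collectedStatistics: Array[Any] =
      Array[Any](lower, upper, nullCount, count, sizeInBytes)
  }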
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.columnar
 
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, RowOrdering}
-import org.apache.spark.sql.catalyst.expressions.{UnsafeArrayData, UnsafeRow}
+import org.apache.spark.sql.catalyst.expressions.{UnsafeArrayData, UnsafeMapData, UnsafeRow}
 import org.apache.spark.sql.catalyst.util.TypeUtils
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
@@ -357,15 +357,15 @@ private abstract class OrderableSafeColumnStats[T](dataType: DataType) extends C
     Array[Any](lower, upper, nullCount, count, sizeInBytes)
 }
 
-private[columnar] final class ArrayColumnStats(dataType: DataType)
+private[columnar] final class ArrayColumnStats(dataType: ArrayType)
   extends OrderableSafeColumnStats[UnsafeArrayData](dataType) {
   override def getValue(row: InternalRow, ordinal: Int): UnsafeArrayData =
     row.getArray(ordinal).asInstanceOf[UnsafeArrayData]
 
   override def copy(value: UnsafeArrayData): UnsafeArrayData = value.copy()
 }
 
-private[columnar] final class StructColumnStats(dataType: DataType)
+private[columnar] final class StructColumnStats(dataType: StructType)
   extends OrderableSafeColumnStats[UnsafeRow](dataType) {
   private val numFields = dataType.asInstanceOf[StructType].fields.length
 
@@ -375,20 +375,12 @@ private[columnar] final class StructColumnStats(dataType: DataType)
   override def copy(value: UnsafeRow): UnsafeRow = value.copy()
 }
 
-private[columnar] final class MapColumnStats(dataType: DataType) extends ColumnStats {
-  private val columnType = ColumnType(dataType)
-
-  override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
-    if (!row.isNullAt(ordinal)) {
-      sizeInBytes += columnType.actualSize(row, ordinal)
-      count += 1
-    } else {
-      gatherNullStats()
-    }
-  }
+private[columnar] final class MapColumnStats(dataType: MapType)
+  extends OrderableSafeColumnStats[UnsafeMapData](dataType) {
+  override def getValue(row: InternalRow, ordinal: Int): UnsafeMapData =
+    row.getMap(ordinal).asInstanceOf[UnsafeMapData]
 
-  override def collectedStatistics: Array[Any] =
-    Array[Any](null, null, nullCount, count, sizeInBytes)
+  override def copy(value: UnsafeMapData): UnsafeMapData = value.copy()
 }
 
 private[columnar] final class NullColumnStats extends ColumnStats {
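
For illustration, a hedged usage sketch (not part of this commit) of the reworked MapColumnStats, mirroring the struct test added to the suite below; the map contents and single-column layout are assumptions made for the example:

  import org.apache.spark.sql.catalyst.InternalRow
  import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection}
  import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData}
  import org.apache.spark.sql.types._
  import org.apache.spark.unsafe.types.UTF8String

  val mapType = MapType(IntegerType, StringType)
  val stats = new MapColumnStats(mapType)
  // Project a generic map through UnsafeProjection to obtain an UnsafeMapData,
  // since MapColumnStats.getValue casts row.getMap(ordinal) to UnsafeMapData.
  val converter = UnsafeProjection.create(Array[DataType](mapType))
  val mapData = new ArrayBasedMapData(
    new GenericArrayData(Array[Any](1)),
    new GenericArrayData(Array[Any](UTF8String.fromString("a"))))
  val unsafeMap = converter(InternalRow(mapData)).getMap(0)
  stats.gatherStats(new GenericInternalRow(Array[Any](unsafeMap)), 0)
  // collectedStatistics is laid out as Array(lower, upper, nullCount, count, sizeInBytes).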
@@ -18,7 +18,8 @@
 package org.apache.spark.sql.execution.columnar
 
 import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeArrayData}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeArrayData, UnsafeProjection}
 import org.apache.spark.sql.catalyst.util.TypeUtils
 import org.apache.spark.sql.types._
 
@@ -182,7 +183,7 @@ class ColumnStatsSuite extends SparkFunSuite {
     }
   }
 
-  def testMapColumnStats(dataType: DataType, initialStatistics: Array[Any]): Unit = {
+  def testMapColumnStats(dataType: MapType, initialStatistics: Array[Any]): Unit = {
     val columnType = ColumnType(dataType)
 
     test(s"${dataType.typeName}: empty") {
@@ -226,4 +227,22 @@
     assertResult(10)(collected(3))
     assertResult(10 * (4 + unsafeData.getSizeInBytes))(collected(4))
   }
+
+  test("Reuse UnsafeRow for stats") {
+    val structType = StructType(Array(StructField("int", IntegerType)))
+    val stats = new StructColumnStats(structType)
+    val converter = UnsafeProjection.create(structType)
+    val unsafeData = converter(InternalRow(1))
+    (1 to 10).foreach { value =>
+      val row = new GenericInternalRow(Array[Any](unsafeData))
+      unsafeData.setInt(0, value)
+      stats.gatherStats(row, 0)
+    }
+    val collected = stats.collectedStatistics
+    assertResult(converter(InternalRow(1)))(collected(0))
+    assertResult(converter(InternalRow(10)))(collected(1))
+    assertResult(0)(collected(2))
+    assertResult(10)(collected(3))
+    assertResult(10 * (4 + unsafeData.getSizeInBytes))(collected(4))
+  }
 }
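
The "Reuse UnsafeRow for stats" test added above pins down why the copy() overrides matter: an UnsafeProjection returns the same UnsafeRow instance on every call, so a stats class that stored the reference instead of a copy would report identical lower and upper bounds. A hedged illustration of that aliasing hazard (not part of the commit):

  val converter = UnsafeProjection.create(StructType(Array(StructField("int", IntegerType))))
  val first = converter(InternalRow(1))  // a handle onto the projection's shared buffer
  converter(InternalRow(10))             // the next projection overwrites that buffer
  assert(first.getInt(0) == 10)          // the stale handle now reads 10; first.copy() would have kept 1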