ColumnAccessor.scala
@@ -100,8 +100,8 @@ private[columnar] class FloatColumnAccessor(buffer: ByteBuffer)
private[columnar] class DoubleColumnAccessor(buffer: ByteBuffer)
extends NativeColumnAccessor(buffer, DOUBLE)

-private[columnar] class StringColumnAccessor(buffer: ByteBuffer)
-  extends NativeColumnAccessor(buffer, STRING)
+private[columnar] class StringColumnAccessor(buffer: ByteBuffer, dataType: StringType)
+  extends NativeColumnAccessor(buffer, STRING(dataType))

private[columnar] class BinaryColumnAccessor(buffer: ByteBuffer)
extends BasicColumnAccessor[Array[Byte]](buffer, BINARY)
@@ -147,7 +147,7 @@ private[sql] object ColumnAccessor {
new LongColumnAccessor(buf)
case FloatType => new FloatColumnAccessor(buf)
case DoubleType => new DoubleColumnAccessor(buf)
-      case StringType => new StringColumnAccessor(buf)
+      case s: StringType => new StringColumnAccessor(buf, s)
case BinaryType => new BinaryColumnAccessor(buf)
case dt: DecimalType if dt.precision <= Decimal.MAX_LONG_DIGITS =>
new CompactDecimalColumnAccessor(buf, dt)
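Note: the recurring change in this PR is swapping an equality match on the StringType singleton for a type pattern that also captures collated instances. A minimal sketch of the difference, using simplified stand-in types rather than Spark's real hierarchy:

object DispatchSketch {
  // Stand-ins: a string type carrying a collation id, plus a default instance.
  class Str(val collationId: Int)
  object Str extends Str(0)

  // Old shape: a stable-identifier pattern, which only matches the default
  // singleton by equality, so non-default collations fall through.
  def dispatchOld(dt: Any): String = dt match {
    case Str => "string accessor"
    case _ => "no accessor"
  }

  // New shape: a type pattern matches every collation and keeps a handle on it.
  def dispatchNew(dt: Any): String = dt match {
    case s: Str => s"string accessor (collation ${s.collationId})"
    case _ => "no accessor"
  }

  def main(args: Array[String]): Unit = {
    println(dispatchOld(new Str(1))) // no accessor
    println(dispatchNew(new Str(1))) // string accessor (collation 1)
  }
}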
ColumnBuilder.scala
@@ -122,7 +122,8 @@ private[columnar]
class DoubleColumnBuilder extends NativeColumnBuilder(new DoubleColumnStats, DOUBLE)

private[columnar]
-class StringColumnBuilder extends NativeColumnBuilder(new StringColumnStats, STRING)
+class StringColumnBuilder(dataType: StringType)
+  extends NativeColumnBuilder(new StringColumnStats(dataType), STRING(dataType))

private[columnar]
class BinaryColumnBuilder extends ComplexColumnBuilder(new BinaryColumnStats, BINARY)
@@ -185,7 +186,7 @@ private[columnar] object ColumnBuilder {
new LongColumnBuilder
case FloatType => new FloatColumnBuilder
case DoubleType => new DoubleColumnBuilder
-      case StringType => new StringColumnBuilder
+      case s: StringType => new StringColumnBuilder(s)
case BinaryType => new BinaryColumnBuilder
case CalendarIntervalType => new IntervalColumnBuilder
case dt: DecimalType if dt.precision <= Decimal.MAX_LONG_DIGITS =>
ColumnStats.scala
@@ -255,23 +255,25 @@ private[columnar] final class DoubleColumnStats extends ColumnStats {
Array[Any](lower, upper, nullCount, count, sizeInBytes)
}

-private[columnar] final class StringColumnStats extends ColumnStats {
+private[columnar] final class StringColumnStats(collationId: Int) extends ColumnStats {
+  def this(dt: StringType) = this(dt.collationId)
+
protected var upper: UTF8String = null
protected var lower: UTF8String = null

override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
if (!row.isNullAt(ordinal)) {
val value = row.getUTF8String(ordinal)
-      val size = STRING.actualSize(row, ordinal)
+      val size = STRING(collationId).actualSize(row, ordinal)
gatherValueStats(value, size)
} else {
gatherNullStats()
}
}

def gatherValueStats(value: UTF8String, size: Int): Unit = {
-    if (upper == null || value.binaryCompare(upper) > 0) upper = value.clone()
-    if (lower == null || value.binaryCompare(lower) < 0) lower = value.clone()
+    if (upper == null || value.semanticCompare(upper, collationId) > 0) upper = value.clone()
+    if (lower == null || value.semanticCompare(lower, collationId) < 0) lower = value.clone()
sizeInBytes += size
count += 1
}
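Note: the switch from binaryCompare to semanticCompare is what makes the cached min/max statistics collation-aware. A self-contained illustration in plain Scala, with an Ordering standing in for the comparator a collation would supply:

object CollatedBoundsDemo {
  // Hypothetical stand-in for a case-insensitive collation's comparator.
  val lowerCaseOrdering: Ordering[String] = Ordering.by(_.toLowerCase)

  def bounds(values: Seq[String], ord: Ordering[String]): (String, String) =
    (values.min(ord), values.max(ord))

  def main(args: Array[String]): Unit = {
    val values = Seq("b", "a", "C", "A")
    println(bounds(values, Ordering.String))   // (A,b): byte-wise, uppercase sorts first
    println(bounds(values, lowerCaseOrdering)) // (a,C): matches the UTF8_LCASE test below
  }
}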
ColumnType.scala
@@ -491,8 +491,8 @@ private[columnar] trait DirectCopyColumnType[JvmType] extends ColumnType[JvmType
}
}

-private[columnar] object STRING
-  extends NativeColumnType(PhysicalStringType(StringType.collationId), 8)
+private[columnar] case class STRING(collationId: Int)
+  extends NativeColumnType(PhysicalStringType(collationId), 8)
with DirectCopyColumnType[UTF8String] {

override def actualSize(row: InternalRow, ordinal: Int): Int = {
@@ -532,6 +532,12 @@ private[columnar] object STRING
override def clone(v: UTF8String): UTF8String = v.clone()
}

+private[columnar] object STRING {
+  def apply(dt: StringType): STRING = {
+    STRING(dt.collationId)
+  }
+}
+
private[columnar] case class COMPACT_DECIMAL(precision: Int, scale: Int)
extends NativeColumnType(PhysicalDecimalType(precision, scale), 8) {

@@ -821,7 +827,7 @@ private[columnar] object ColumnType {
case LongType | TimestampType | TimestampNTZType | _: DayTimeIntervalType => LONG
case FloatType => FLOAT
case DoubleType => DOUBLE
-      case StringType => STRING
+      case s: StringType => STRING(s)
case BinaryType => BINARY
case i: CalendarIntervalType => CALENDAR_INTERVAL
case dt: DecimalType if dt.precision <= Decimal.MAX_LONG_DIGITS => COMPACT_DECIMAL(dt)
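Note: STRING is now a case class keyed by collationId, so two STRING column types compare equal only when their collations match, while a `_: STRING` type pattern still covers all of them. A sketch with illustrative names:

object ColumnTypeEqualitySketch {
  sealed trait ColType
  case class StrCol(collationId: Int) extends ColType // ids are illustrative
  case object IntCol extends ColType
  case object BinCol extends ColType

  def supportsEncoding(t: ColType): Boolean = t match {
    case IntCol | _: StrCol => true // type pattern, as in the compression hunks below
    case _ => false
  }

  def main(args: Array[String]): Unit = {
    println(StrCol(0) == StrCol(0))      // true: same collation
    println(StrCol(0) == StrCol(1))      // false: collations differ
    println(supportsEncoding(StrCol(1))) // true, for any collation
    println(supportsEncoding(BinCol))    // false
  }
}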
GenerateColumnAccessor.scala
@@ -86,7 +86,7 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera
classOf[LongColumnAccessor].getName
case FloatType => classOf[FloatColumnAccessor].getName
case DoubleType => classOf[DoubleColumnAccessor].getName
-      case StringType => classOf[StringColumnAccessor].getName
+      case _: StringType => classOf[StringColumnAccessor].getName
case BinaryType => classOf[BinaryColumnAccessor].getName
case CalendarIntervalType => classOf[IntervalColumnAccessor].getName
case dt: DecimalType if dt.precision <= Decimal.MAX_LONG_DIGITS =>
@@ -101,7 +101,7 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera
val createCode = dt match {
case t if CodeGenerator.isPrimitiveType(dt) =>
s"$accessorName = new $accessorCls(ByteBuffer.wrap(buffers[$index]).order(nativeOrder));"
-        case NullType | StringType | BinaryType | CalendarIntervalType =>
+        case NullType | BinaryType | CalendarIntervalType =>
s"$accessorName = new $accessorCls(ByteBuffer.wrap(buffers[$index]).order(nativeOrder));"
case other =>
s"""$accessorName = new $accessorCls(ByteBuffer.wrap(buffers[$index]).order(nativeOrder),
compressionSchemes.scala
@@ -176,7 +176,7 @@ private[columnar] case object RunLengthEncoding extends CompressionScheme {
}

override def supports(columnType: ColumnType[_]): Boolean = columnType match {
-    case INT | LONG | SHORT | BYTE | STRING | BOOLEAN => true
+    case INT | LONG | SHORT | BYTE | _: STRING | BOOLEAN => true
case _ => false
}

@@ -373,7 +373,7 @@ private[columnar] case object DictionaryEncoding extends CompressionScheme {
}

override def supports(columnType: ColumnType[_]): Boolean = columnType match {
-    case INT | LONG | STRING => true
+    case INT | LONG | _: STRING => true
case _ => false
}

sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala (34 additions, 0 deletions)
@@ -31,6 +31,7 @@ import org.apache.spark.sql.errors.DataTypeErrors.toSQLType
import org.apache.spark.sql.errors.DataTypeErrors.toSQLType
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec}
+import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
import org.apache.spark.sql.execution.joins._
import org.apache.spark.sql.internal.{SqlApiConf, SQLConf}
import org.apache.spark.sql.types.{MapType, StringType, StructField, StructType}
@@ -1431,4 +1432,37 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
})
}

test("cache table with collated columns") {
val collations = Seq("UTF8_BINARY", "UTF8_LCASE", "UNICODE", "UNICODE_CI")
val lazyOptions = Seq(false, true)

for (
collation <- collations;
lazyTable <- lazyOptions
) {
val lazyStr = if (lazyTable) "LAZY" else ""

def checkCacheTable(values: String): Unit = {
sql(s"CACHE $lazyStr TABLE tbl AS SELECT col FROM VALUES ($values) AS (col)")
// Checks in-memory fetching code path.
val all = sql("SELECT col FROM tbl")
assert(all.queryExecution.executedPlan.collectFirst {
case _: InMemoryTableScanExec => true
}.nonEmpty)
checkAnswer(all, Row("a"))
// Checks column stats code path.
checkAnswer(sql("SELECT col FROM tbl WHERE col = 'a'"), Row("a"))
checkAnswer(sql("SELECT col FROM tbl WHERE col = 'b'"), Seq.empty)
}

withTable("tbl") {
checkCacheTable(s"'a' COLLATE $collation")
}
withSQLConf(SqlApiConf.DEFAULT_COLLATION -> collation) {
withTable("tbl") {
checkCacheTable("'a'")
}
}
}
}
}
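Note: outside the test harness, the behavior this test locks down looks roughly like the following spark-shell session (assumes a SparkSession named `spark` built with this PR; the 'A' probe is an extra illustration, the test itself only probes 'a' and 'b'):

spark.sql("CACHE TABLE tbl AS SELECT col FROM VALUES ('a' COLLATE UNICODE_CI) AS (col)")
spark.sql("SELECT col FROM tbl").show()                 // served by InMemoryTableScanExec
spark.sql("SELECT col FROM tbl WHERE col = 'A'").show() // matches, since UNICODE_CI is case-insensitive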
ColumnStatsSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.columnar

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.types.PhysicalDataType
+import org.apache.spark.sql.types.StringType

class ColumnStatsSuite extends SparkFunSuite {
testColumnStats(classOf[BooleanColumnStats], BOOLEAN, Array(true, false, 0))
@@ -28,9 +29,9 @@ class ColumnStatsSuite extends SparkFunSuite {
testColumnStats(classOf[LongColumnStats], LONG, Array(Long.MaxValue, Long.MinValue, 0))
testColumnStats(classOf[FloatColumnStats], FLOAT, Array(Float.MaxValue, Float.MinValue, 0))
testColumnStats(classOf[DoubleColumnStats], DOUBLE, Array(Double.MaxValue, Double.MinValue, 0))
-  testColumnStats(classOf[StringColumnStats], STRING, Array(null, null, 0))
testDecimalColumnStats(Array(null, null, 0))
testIntervalColumnStats(Array(null, null, 0))
+  testStringColumnStats(Array(null, null, 0))

def testColumnStats[T <: PhysicalDataType, U <: ColumnStats](
columnStatsClass: Class[U],
@@ -141,4 +142,60 @@ class ColumnStatsSuite extends SparkFunSuite {
}
}
}

+  def testStringColumnStats[T <: PhysicalDataType, U <: ColumnStats](
+      initialStatistics: Array[Any]): Unit = {
+
+    Seq("UTF8_BINARY", "UTF8_LCASE", "UNICODE", "UNICODE_CI").foreach(collation => {
+      val columnType = STRING(StringType(collation))
+
+      test(s"STRING($collation): empty") {
+        val columnStats = new StringColumnStats(StringType(collation).collationId)
+        columnStats.collectedStatistics.zip(initialStatistics).foreach {
+          case (actual, expected) => assert(actual === expected)
+        }
+      }
+
+      test(s"STRING($collation): non-empty") {
+        import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._
+
+        val columnStats = new StringColumnStats(StringType(collation).collationId)
+        val rows = Seq.fill(10)(makeRandomRow(columnType)) ++ Seq.fill(10)(makeNullRow(1))
+        rows.foreach(columnStats.gatherStats(_, 0))
+
+        val values = rows.take(10).map(_.get(0,
+          ColumnarDataTypeUtils.toLogicalDataType(columnType.dataType)))
+        val ordering = PhysicalDataType.ordering(
+          ColumnarDataTypeUtils.toLogicalDataType(columnType.dataType))
+        val stats = columnStats.collectedStatistics
+
+        assertResult(values.min(ordering), "Wrong lower bound")(stats(0))
+        assertResult(values.max(ordering), "Wrong upper bound")(stats(1))
+        assertResult(10, "Wrong null count")(stats(2))
+        assertResult(20, "Wrong row count")(stats(3))
+        assertResult(stats(4), "Wrong size in bytes") {
+          rows.map { row =>
+            if (row.isNullAt(0)) 4 else columnType.actualSize(row, 0)
+          }.sum
+        }
+      }
+    })
+
+    test("STRING(UTF8_LCASE): collation-defined ordering") {
+      import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
+      import org.apache.spark.unsafe.types.UTF8String
+
+      val columnStats = new StringColumnStats(StringType("UTF8_LCASE").collationId)
+      val rows = Seq("b", "a", "C", "A").map(str => {
+        val row = new GenericInternalRow(1)
+        row(0) = UTF8String.fromString(str)
+        row
+      })
+      rows.foreach(columnStats.gatherStats(_, 0))
+
+      val stats = columnStats.collectedStatistics
+      assertResult(UTF8String.fromString("a"), "Wrong lower bound")(stats(0))
+      assertResult(UTF8String.fromString("C"), "Wrong upper bound")(stats(1))
+    }
+  }
}
ColumnTypeSuite.scala
@@ -26,6 +26,7 @@ import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection}
import org.apache.spark.sql.catalyst.types.{PhysicalArrayType, PhysicalDataType, PhysicalMapType, PhysicalStructType}
+import org.apache.spark.sql.catalyst.util.CollationFactory
import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.CalendarInterval
@@ -40,7 +41,9 @@ class ColumnTypeSuite extends SparkFunSuite {
val checks = Map(
NULL -> 0, BOOLEAN -> 1, BYTE -> 1, SHORT -> 2, INT -> 4, LONG -> 8,
FLOAT -> 4, DOUBLE -> 8, COMPACT_DECIMAL(15, 10) -> 8, LARGE_DECIMAL(20, 10) -> 12,
-    STRING -> 8, BINARY -> 16, STRUCT_TYPE -> 20, ARRAY_TYPE -> 28, MAP_TYPE -> 68,
+    STRING(StringType) -> 8, STRING(StringType("UTF8_LCASE")) -> 8,
+    STRING(StringType("UNICODE")) -> 8, STRING(StringType("UNICODE_CI")) -> 8,
+    BINARY -> 16, STRUCT_TYPE -> 20, ARRAY_TYPE -> 28, MAP_TYPE -> 68,
CALENDAR_INTERVAL -> 16)

checks.foreach { case (columnType, expectedSize) =>
@@ -73,7 +76,12 @@ class ColumnTypeSuite extends SparkFunSuite {
checkActualSize(LONG, Long.MaxValue, 8)
checkActualSize(FLOAT, Float.MaxValue, 4)
checkActualSize(DOUBLE, Double.MaxValue, 8)
-    checkActualSize(STRING, "hello", 4 + "hello".getBytes(StandardCharsets.UTF_8).length)
+    Seq(
+      "UTF8_BINARY", "UTF8_LCASE", "UNICODE", "UNICODE_CI"
+    ).foreach(collation => {
+      checkActualSize(STRING(StringType(collation)),
+        "hello", 4 + "hello".getBytes(StandardCharsets.UTF_8).length)
+    })
checkActualSize(BINARY, Array.fill[Byte](4)(0.toByte), 4 + 4)
checkActualSize(COMPACT_DECIMAL(15, 10), Decimal(0, 15, 10), 8)
checkActualSize(LARGE_DECIMAL(20, 10), Decimal(0, 20, 10), 5)
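Note: the loop above expects the same size for every collation, because a string cell is stored as a 4-byte length prefix followed by its UTF-8 bytes; collation affects ordering, not layout. A quick check in plain Scala:

import java.nio.charset.StandardCharsets

object StringCellSize extends App {
  val v = "hello"
  println(4 + v.getBytes(StandardCharsets.UTF_8).length) // 9, for every collation tested
}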
@@ -93,7 +101,10 @@ class ColumnTypeSuite extends SparkFunSuite {
testNativeColumnType(FLOAT)
testNativeColumnType(DOUBLE)
testNativeColumnType(COMPACT_DECIMAL(15, 10))
-  testNativeColumnType(STRING)
+  testNativeColumnType(STRING(StringType)) // UTF8_BINARY
+  testNativeColumnType(STRING(StringType("UTF8_LCASE")))
+  testNativeColumnType(STRING(StringType("UNICODE")))
+  testNativeColumnType(STRING(StringType("UNICODE_CI")))

testColumnType(NULL)
testColumnType(BINARY)
@@ -104,20 +115,28 @@ class ColumnTypeSuite extends SparkFunSuite {
testColumnType(CALENDAR_INTERVAL)

def testNativeColumnType[T <: PhysicalDataType](columnType: NativeColumnType[T]): Unit = {
-    testColumnType[T#InternalType](columnType)
+    val typeName = columnType match {
+      case s: STRING =>
+        val collation = CollationFactory.fetchCollation(s.collationId).collationName
+        Some(if (collation == "UTF8_BINARY") "STRING" else s"STRING($collation)")
+      case _ => None
+    }
+    testColumnType[T#InternalType](columnType, typeName)
}

-  def testColumnType[JvmType](columnType: ColumnType[JvmType]): Unit = {
-
+  def testColumnType[JvmType](
+      columnType: ColumnType[JvmType],
+      typeName: Option[String] = None): Unit = {
val proj = UnsafeProjection.create(
Array[DataType](ColumnarDataTypeUtils.toLogicalDataType(columnType.dataType)))
val converter = CatalystTypeConverters.createToScalaConverter(
ColumnarDataTypeUtils.toLogicalDataType(columnType.dataType))
val seq = (0 until 4).map(_ => proj(makeRandomRow(columnType)).copy())
val totalSize = seq.map(_.getSizeInBytes).sum
val bufferSize = Math.max(DEFAULT_BUFFER_SIZE, totalSize)
+    val testName = typeName.getOrElse(columnType.toString)

test(s"$columnType append/extract") {
test(s"$testName append/extract") {
val buffer = ByteBuffer.allocate(bufferSize).order(ByteOrder.nativeOrder())
seq.foreach(r => columnType.append(columnType.getField(r, 0), buffer))

ColumnarTestUtils.scala
@@ -50,7 +50,7 @@ object ColumnarTestUtils {
case LONG => Random.nextLong()
case FLOAT => Random.nextFloat()
case DOUBLE => Random.nextDouble()
-      case STRING => UTF8String.fromString(Random.nextString(Random.nextInt(32)))
+      case _: STRING => UTF8String.fromString(Random.nextString(Random.nextInt(32)))
case BINARY => randomBytes(Random.nextInt(32))
case CALENDAR_INTERVAL =>
new CalendarInterval(Random.nextInt(), Random.nextInt(), Random.nextLong())
NullableColumnAccessorSuite.scala
@@ -23,6 +23,7 @@ import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection}
import org.apache.spark.sql.catalyst.types.{PhysicalArrayType, PhysicalMapType, PhysicalStructType}
+import org.apache.spark.sql.catalyst.util.CollationFactory
import org.apache.spark.sql.types._

class TestNullableColumnAccessor[JvmType](
@@ -41,21 +42,33 @@ object TestNullableColumnAccessor {
class NullableColumnAccessorSuite extends SparkFunSuite {
import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._

-  Seq(
+  val stringTypes = Seq(
+    STRING(StringType), // UTF8_BINARY
+    STRING(StringType("UTF8_LCASE")),
+    STRING(StringType("UNICODE")),
+    STRING(StringType("UNICODE_CI")))
+  val otherTypes = Seq(
NULL, BOOLEAN, BYTE, SHORT, INT, LONG, FLOAT, DOUBLE,
-    STRING, BINARY, COMPACT_DECIMAL(15, 10), LARGE_DECIMAL(20, 10),
+    BINARY, COMPACT_DECIMAL(15, 10), LARGE_DECIMAL(20, 10),
STRUCT(PhysicalStructType(Array(StructField("a", StringType)))),
ARRAY(PhysicalArrayType(IntegerType, true)),
MAP(PhysicalMapType(IntegerType, StringType, true)),
CALENDAR_INTERVAL)
-    .foreach {
+
+  stringTypes.foreach(s => {
+    val collation = CollationFactory.fetchCollation(s.collationId).collationName
+    val typeName = if (collation == "UTF8_BINARY") "STRING" else s"STRING($collation)"
+    testNullableColumnAccessor(s, Some(typeName))
+  })
+  otherTypes.foreach {
testNullableColumnAccessor(_)
}

def testNullableColumnAccessor[JvmType](
-      columnType: ColumnType[JvmType]): Unit = {
+      columnType: ColumnType[JvmType],
+      testTypeName: Option[String] = None): Unit = {

-    val typeName = columnType.getClass.getSimpleName.stripSuffix("$")
+    val typeName = testTypeName.getOrElse(columnType.getClass.getSimpleName.stripSuffix("$"))
val nullRow = makeNullRow(1)

test(s"Nullable $typeName column accessor: empty column") {