Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next commit
Initial work
  • Loading branch information
nikolamand-db committed Jun 20, 2024
commit eeb9ce04cfc866e1658b4576f1ad2b4de4adf6d3
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,8 @@ private[columnar] class FloatColumnAccessor(buffer: ByteBuffer)
private[columnar] class DoubleColumnAccessor(buffer: ByteBuffer)
extends NativeColumnAccessor(buffer, DOUBLE)

private[columnar] class StringColumnAccessor(buffer: ByteBuffer)
extends NativeColumnAccessor(buffer, STRING)
private[columnar] class StringColumnAccessor(buffer: ByteBuffer, dataType: StringType)
extends NativeColumnAccessor(buffer, STRING(dataType))

private[columnar] class BinaryColumnAccessor(buffer: ByteBuffer)
extends BasicColumnAccessor[Array[Byte]](buffer, BINARY)
Expand Down Expand Up @@ -147,7 +147,7 @@ private[sql] object ColumnAccessor {
new LongColumnAccessor(buf)
case FloatType => new FloatColumnAccessor(buf)
case DoubleType => new DoubleColumnAccessor(buf)
case StringType => new StringColumnAccessor(buf)
case dt: StringType => new StringColumnAccessor(buf, dt)
case BinaryType => new BinaryColumnAccessor(buf)
case dt: DecimalType if dt.precision <= Decimal.MAX_LONG_DIGITS =>
new CompactDecimalColumnAccessor(buf, dt)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,8 @@ private[columnar]
class DoubleColumnBuilder extends NativeColumnBuilder(new DoubleColumnStats, DOUBLE)

private[columnar]
class StringColumnBuilder extends NativeColumnBuilder(new StringColumnStats, STRING)
class StringColumnBuilder(dataType: StringType)
extends NativeColumnBuilder(new StringColumnStats(dataType), STRING(dataType))

private[columnar]
class BinaryColumnBuilder extends ComplexColumnBuilder(new BinaryColumnStats, BINARY)
Expand Down Expand Up @@ -185,7 +186,7 @@ private[columnar] object ColumnBuilder {
new LongColumnBuilder
case FloatType => new FloatColumnBuilder
case DoubleType => new DoubleColumnBuilder
case StringType => new StringColumnBuilder
case s: StringType => new StringColumnBuilder(s)
case BinaryType => new BinaryColumnBuilder
case CalendarIntervalType => new IntervalColumnBuilder
case dt: DecimalType if dt.precision <= Decimal.MAX_LONG_DIGITS =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.columnar

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference}
import org.apache.spark.sql.catalyst.util.CollationFactory
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

Expand Down Expand Up @@ -255,23 +256,27 @@ private[columnar] final class DoubleColumnStats extends ColumnStats {
Array[Any](lower, upper, nullCount, count, sizeInBytes)
}

private[columnar] final class StringColumnStats extends ColumnStats {
private[columnar] final class StringColumnStats(collationId: Int) extends ColumnStats {
def this(dt: StringType) = this(dt.collationId)

protected var upper: UTF8String = null
protected var lower: UTF8String = null

override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
if (!row.isNullAt(ordinal)) {
val value = row.getUTF8String(ordinal)
val size = STRING.actualSize(row, ordinal)
val size = STRING(collationId).actualSize(row, ordinal)
gatherValueStats(value, size)
} else {
gatherNullStats()
}
}

def gatherValueStats(value: UTF8String, size: Int): Unit = {
if (upper == null || value.binaryCompare(upper) > 0) upper = value.clone()
if (lower == null || value.binaryCompare(lower) < 0) lower = value.clone()
def collatedCompare(l: UTF8String, r: UTF8String): Int =
CollationFactory.fetchCollation(collationId).comparator.compare(l, r)
if (upper == null || collatedCompare(value, upper) > 0) upper = value.clone()
if (lower == null || collatedCompare(value, lower) < 0) lower = value.clone()
sizeInBytes += size
count += 1
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -491,8 +491,8 @@ private[columnar] trait DirectCopyColumnType[JvmType] extends ColumnType[JvmType
}
}

private[columnar] object STRING
extends NativeColumnType(PhysicalStringType(StringType.collationId), 8)
private[columnar] case class STRING(collationId: Int)
extends NativeColumnType(PhysicalStringType(collationId), 8)
with DirectCopyColumnType[UTF8String] {

override def actualSize(row: InternalRow, ordinal: Int): Int = {
Expand Down Expand Up @@ -532,6 +532,12 @@ private[columnar] object STRING
override def clone(v: UTF8String): UTF8String = v.clone()
}

/**
 * Companion of the `STRING` column type. Lets callers build the column type directly from a
 * `StringType` instead of passing the raw collation id.
 */
private[columnar] object STRING {
  /** Returns the `STRING` column type whose collation id is `dt.collationId`. */
  def apply(dt: StringType): STRING = new STRING(dt.collationId)
}

private[columnar] case class COMPACT_DECIMAL(precision: Int, scale: Int)
extends NativeColumnType(PhysicalDecimalType(precision, scale), 8) {

Expand Down Expand Up @@ -821,7 +827,7 @@ private[columnar] object ColumnType {
case LongType | TimestampType | TimestampNTZType | _: DayTimeIntervalType => LONG
case FloatType => FLOAT
case DoubleType => DOUBLE
case StringType => STRING
case s: StringType => STRING(s)
case BinaryType => BINARY
case i: CalendarIntervalType => CALENDAR_INTERVAL
case dt: DecimalType if dt.precision <= Decimal.MAX_LONG_DIGITS => COMPACT_DECIMAL(dt)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera
classOf[LongColumnAccessor].getName
case FloatType => classOf[FloatColumnAccessor].getName
case DoubleType => classOf[DoubleColumnAccessor].getName
case StringType => classOf[StringColumnAccessor].getName
case _: StringType => classOf[StringColumnAccessor].getName
case BinaryType => classOf[BinaryColumnAccessor].getName
case CalendarIntervalType => classOf[IntervalColumnAccessor].getName
case dt: DecimalType if dt.precision <= Decimal.MAX_LONG_DIGITS =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import scala.collection.mutable
import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.types.{PhysicalBooleanType, PhysicalByteType, PhysicalDataType, PhysicalDoubleType, PhysicalFloatType, PhysicalIntegerType, PhysicalLongType, PhysicalShortType, PhysicalStringType}
import org.apache.spark.sql.catalyst.util.CollationFactory
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.execution.columnar._
import org.apache.spark.sql.execution.vectorized.WritableColumnVector
Expand Down Expand Up @@ -176,7 +177,8 @@ private[columnar] case object RunLengthEncoding extends CompressionScheme {
}

override def supports(columnType: ColumnType[_]): Boolean = columnType match {
case INT | LONG | SHORT | BYTE | STRING | BOOLEAN => true
case INT | LONG | SHORT | BYTE | BOOLEAN => true
case STRING(CollationFactory.UTF8_BINARY_COLLATION_ID) => true
case _ => false
}

Expand Down Expand Up @@ -373,7 +375,7 @@ private[columnar] case object DictionaryEncoding extends CompressionScheme {
}

override def supports(columnType: ColumnType[_]): Boolean = columnType match {
case INT | LONG | STRING => true
case INT | LONG | STRING(CollationFactory.UTF8_BINARY_COLLATION_ID) => true
case _ => false
}

Expand Down
25 changes: 25 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1431,4 +1431,29 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
})
}

// Caches a single-row table whose string column carries each listed collation and reads it
// back, so that the cached data is actually built and scanned for every collation.
test("cache table with collated columns") {
  val collations = Seq(
    "UTF8_BINARY",
    "UTF8_LCASE",
    "UNICODE"
  )
  collations.foreach { collation =>
    withTable("t1", "t2") {
      sql(s"CACHE LAZY TABLE t1 AS SELECT col FROM VALUES ('a' COLLATE $collation) AS (col)")

      // The table is cached with LAZY, so nothing is materialized until it is queried.
      // Reading it back is what exercises the cache; without this the test asserts nothing.
      checkAnswer(sql("SELECT col FROM t1"), Row("a"))

      // NOTE(review): the checks below were already disabled; the reason is not visible here.
      // TODO: confirm the expected COLLATION(col) output, then re-enable or delete them.
      // checkAnswer(sql("SELECT COLLATION(col) FROM t1"), Row(collation))
      // withSQLConf(SqlApiConf.DEFAULT_COLLATION -> collation) {
      //   sql(s"CACHE TABLE t2 AS SELECT col FROM VALUES ('a') AS (col)")
      //   checkAnswer(sql("SELECT COLLATION(col) FROM t2"), Row(collation))
      // }
    }
  }
}

// Caches a one-row table with a UTF8_LCASE column, then filters it with a literal that
// differs from the stored value only in letter case and expects the stored row back.
test("cache table new test") {
  withTable("t") {
    val cacheStatement =
      "CACHE LAZY TABLE t AS SELECT col FROM VALUES ('a' COLLATE UTF8_LCASE) AS (col)"
    sql(cacheStatement)

    val filtered = sql("SELECT * FROM t WHERE col = 'A'")
    checkAnswer(filtered, Row("a"))
  }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.columnar

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.types.PhysicalDataType
import org.apache.spark.sql.types.StringType

class ColumnStatsSuite extends SparkFunSuite {
testColumnStats(classOf[BooleanColumnStats], BOOLEAN, Array(true, false, 0))
Expand All @@ -28,7 +29,7 @@ class ColumnStatsSuite extends SparkFunSuite {
testColumnStats(classOf[LongColumnStats], LONG, Array(Long.MaxValue, Long.MinValue, 0))
testColumnStats(classOf[FloatColumnStats], FLOAT, Array(Float.MaxValue, Float.MinValue, 0))
testColumnStats(classOf[DoubleColumnStats], DOUBLE, Array(Double.MaxValue, Double.MinValue, 0))
testColumnStats(classOf[StringColumnStats], STRING, Array(null, null, 0))
testColumnStats(classOf[StringColumnStats], STRING(StringType), Array(null, null, 0))
testDecimalColumnStats(Array(null, null, 0))
testIntervalColumnStats(Array(null, null, 0))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@ class ColumnTypeSuite extends SparkFunSuite {
val checks = Map(
NULL -> 0, BOOLEAN -> 1, BYTE -> 1, SHORT -> 2, INT -> 4, LONG -> 8,
FLOAT -> 4, DOUBLE -> 8, COMPACT_DECIMAL(15, 10) -> 8, LARGE_DECIMAL(20, 10) -> 12,
STRING -> 8, BINARY -> 16, STRUCT_TYPE -> 20, ARRAY_TYPE -> 28, MAP_TYPE -> 68,
// STRING is now keyed by collation; every listed collation is expected to report size 8.
// The last name must be "UNICODE_CI" (as in the other suites), not "UNICODE_CO".
STRING(StringType) -> 8, STRING(StringType("UTF8_LCASE")) -> 8,
STRING(StringType("UNICODE")) -> 8, STRING(StringType("UNICODE_CI")) -> 8,
BINARY -> 16, STRUCT_TYPE -> 20, ARRAY_TYPE -> 28, MAP_TYPE -> 68,
CALENDAR_INTERVAL -> 16)

checks.foreach { case (columnType, expectedSize) =>
Expand Down Expand Up @@ -73,7 +75,12 @@ class ColumnTypeSuite extends SparkFunSuite {
checkActualSize(LONG, Long.MaxValue, 8)
checkActualSize(FLOAT, Float.MaxValue, 4)
checkActualSize(DOUBLE, Double.MaxValue, 8)
checkActualSize(STRING, "hello", 4 + "hello".getBytes(StandardCharsets.UTF_8).length)
Seq(
"UTF8_BINARY", "UTF8_LCASE", "UNICODE", "UNICODE_CI"
).foreach(collation => {
checkActualSize(STRING(StringType(collation)),
"hello", 4 + "hello".getBytes(StandardCharsets.UTF_8).length)
})
checkActualSize(BINARY, Array.fill[Byte](4)(0.toByte), 4 + 4)
checkActualSize(COMPACT_DECIMAL(15, 10), Decimal(0, 15, 10), 8)
checkActualSize(LARGE_DECIMAL(20, 10), Decimal(0, 20, 10), 5)
Expand All @@ -93,7 +100,10 @@ class ColumnTypeSuite extends SparkFunSuite {
testNativeColumnType(FLOAT)
testNativeColumnType(DOUBLE)
testNativeColumnType(COMPACT_DECIMAL(15, 10))
testNativeColumnType(STRING)
testNativeColumnType(STRING(StringType))
testNativeColumnType(STRING(StringType("UTF8_LCASE")))
testNativeColumnType(STRING(StringType("UNICODE")))
testNativeColumnType(STRING(StringType("UNICODE_CI")))

testColumnType(NULL)
testColumnType(BINARY)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ object ColumnarTestUtils {
case LONG => Random.nextLong()
case FLOAT => Random.nextFloat()
case DOUBLE => Random.nextDouble()
case STRING => UTF8String.fromString(Random.nextString(Random.nextInt(32)))
case _: STRING => UTF8String.fromString(Random.nextString(Random.nextInt(32)))
case BINARY => randomBytes(Random.nextInt(32))
case CALENDAR_INTERVAL =>
new CalendarInterval(Random.nextInt(), Random.nextInt(), Random.nextLong())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ class NullableColumnAccessorSuite extends SparkFunSuite {

Seq(
NULL, BOOLEAN, BYTE, SHORT, INT, LONG, FLOAT, DOUBLE,
STRING, BINARY, COMPACT_DECIMAL(15, 10), LARGE_DECIMAL(20, 10),
STRING(StringType), STRING(StringType("UTF8_LCASE")), STRING(StringType("UNICODE")),
STRING(StringType("UNICODE_CI")), BINARY, COMPACT_DECIMAL(15, 10), LARGE_DECIMAL(20, 10),
STRUCT(PhysicalStructType(Array(StructField("a", StringType)))),
ARRAY(PhysicalArrayType(IntegerType, true)),
MAP(PhysicalMapType(IntegerType, StringType, true)),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ class NullableColumnBuilderSuite extends SparkFunSuite {

Seq(
BOOLEAN, BYTE, SHORT, INT, LONG, FLOAT, DOUBLE,
STRING, BINARY, COMPACT_DECIMAL(15, 10), LARGE_DECIMAL(20, 10),
STRING(StringType), STRING(StringType("UTF8_LCASE")), STRING(StringType("UNICODE")),
STRING(StringType("UNICODE_CI")), BINARY, COMPACT_DECIMAL(15, 10), LARGE_DECIMAL(20, 10),
STRUCT(PhysicalStructType(Array(StructField("a", StringType)))),
ARRAY(PhysicalArrayType(IntegerType, true)),
MAP(PhysicalMapType(IntegerType, StringType, true)),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import org.apache.spark.benchmark.{Benchmark, BenchmarkBase}
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
import org.apache.spark.sql.catalyst.types.PhysicalDataType
import org.apache.spark.sql.execution.columnar.{BOOLEAN, INT, LONG, NativeColumnType, SHORT, STRING}
import org.apache.spark.sql.types.StringType
import org.apache.spark.util.Utils._

/**
Expand Down Expand Up @@ -231,8 +232,8 @@ object CompressionSchemeBenchmark extends BenchmarkBase with AllCompressionSchem
}
testData.rewind()

runEncodeBenchmark("STRING Encode", iters, count, STRING, testData)
runDecodeBenchmark("STRING Decode", iters, count, STRING, testData)
runEncodeBenchmark("STRING Encode", iters, count, STRING(StringType), testData)
runDecodeBenchmark("STRING Decode", iters, count, STRING(StringType), testData)
}

override def runBenchmarkSuite(mainArgs: Array[String]): Unit = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,18 @@ import org.apache.spark.sql.catalyst.types.PhysicalDataType
import org.apache.spark.sql.execution.columnar._
import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._
import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector
import org.apache.spark.sql.types.StringType

class DictionaryEncodingSuite extends SparkFunSuite {
val nullValue = -1
testDictionaryEncoding(new IntColumnStats, INT)
testDictionaryEncoding(new LongColumnStats, LONG)
testDictionaryEncoding(new StringColumnStats, STRING, false)
Seq(
"UTF8_BINARY", "UTF8_LCASE", "UNICODE", "UNICODE_CI"
).foreach(collation => {
val dt = StringType(collation)
testDictionaryEncoding(new StringColumnStats(dt), STRING(dt), false)
})

def testDictionaryEncoding[T <: PhysicalDataType](
columnStats: ColumnStats,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.types.PhysicalDataType
import org.apache.spark.sql.execution.columnar._
import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._
import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector
import org.apache.spark.sql.types.StringType

class RunLengthEncodingSuite extends SparkFunSuite {
val nullValue = -1
Expand All @@ -31,7 +32,12 @@ class RunLengthEncodingSuite extends SparkFunSuite {
testRunLengthEncoding(new ShortColumnStats, SHORT)
testRunLengthEncoding(new IntColumnStats, INT)
testRunLengthEncoding(new LongColumnStats, LONG)
testRunLengthEncoding(new StringColumnStats, STRING, false)
Seq(
"UTF8_BINARY", "UTF8_LCASE", "UNICODE", "UNICODE_CI"
).foreach(collation => {
val dt = StringType(collation)
testRunLengthEncoding(new StringColumnStats(dt), STRING(dt), false)
})

def testRunLengthEncoding[T <: PhysicalDataType](
columnStats: ColumnStats,
Expand Down