diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 07d33fa7d52ae..41fe0c3b60d9e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -387,6 +387,14 @@ object SQLConf {
     .booleanConf
     .createWithDefault(true)
 
+  val PARQUET_FILTER_PUSHDOWN_DECIMAL_ENABLED =
+    buildConf("spark.sql.parquet.filterPushdown.decimal")
+      .doc("If true, enables Parquet filter push-down optimization for Decimal. " +
+        "This configuration only has an effect when 'spark.sql.parquet.filterPushdown' is enabled.")
+      .internal()
+      .booleanConf
+      .createWithDefault(true)
+
   val PARQUET_FILTER_PUSHDOWN_STRING_STARTSWITH_ENABLED =
     buildConf("spark.sql.parquet.filterPushdown.string.startsWith")
       .doc("If true, enables Parquet filter push-down optimization for string startsWith function. " +
@@ -1505,6 +1513,8 @@ class SQLConf extends Serializable with Logging {
 
   def parquetFilterPushDownTimestamp: Boolean = getConf(PARQUET_FILTER_PUSHDOWN_TIMESTAMP_ENABLED)
 
+  def parquetFilterPushDownDecimal: Boolean = getConf(PARQUET_FILTER_PUSHDOWN_DECIMAL_ENABLED)
+
   def parquetFilterPushDownStringStartWith: Boolean =
     getConf(PARQUET_FILTER_PUSHDOWN_STRING_STARTSWITH_ENABLED)
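The new flag is read once per query through `parquetFilterPushDownDecimal` above and only matters while `spark.sql.parquet.filterPushdown` itself is on. A minimal way to see its effect from a session might look like the sketch below (the path, data, and query are made up for illustration):

```scala
// Sketch only: exercising the new knob in a Spark session.
spark.conf.set("spark.sql.parquet.filterPushdown", "true")
spark.conf.set("spark.sql.parquet.filterPushdown.decimal", "true")

spark.range(1000).selectExpr("CAST(id AS decimal(9, 2)) AS d").write.parquet("/tmp/dec")
spark.read.parquet("/tmp/dec").where("d = 42.00").explain()
// With the flag on, the scan's PushedFilters should include EqualTo(d,42.00);
// with it off, the predicate is only evaluated row by row after the scan.
```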
diff --git a/sql/core/benchmarks/FilterPushdownBenchmark-results.txt b/sql/core/benchmarks/FilterPushdownBenchmark-results.txt
index 4f38cc4cee96d..2215ed91e2018 100644
--- a/sql/core/benchmarks/FilterPushdownBenchmark-results.txt
+++ b/sql/core/benchmarks/FilterPushdownBenchmark-results.txt
@@ -292,120 +292,120 @@ Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz
 Select 1 decimal(9, 2) row (value = 7864320): Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
 ------------------------------------------------------------------------------------------------
-Parquet Vectorized                            3785 / 3867           4.2         240.6       1.0X
-Parquet Vectorized (Pushdown)                 3820 / 3928           4.1         242.9       1.0X
-Native ORC Vectorized                         3981 / 4049           4.0         253.1       1.0X
-Native ORC Vectorized (Pushdown)               702 / 735           22.4          44.6       5.4X
+Parquet Vectorized                            4546 / 4743           3.5         289.0       1.0X
+Parquet Vectorized (Pushdown)                  161 / 175           98.0          10.2      28.3X
+Native ORC Vectorized                         5721 / 5842           2.7         363.7       0.8X
+Native ORC Vectorized (Pushdown)              1019 / 1070          15.4          64.8       4.5X
 
 Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6
 Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz
 
 Select 10% decimal(9, 2) rows (value < 1572864): Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
 ------------------------------------------------------------------------------------------------
-Parquet Vectorized                            4694 / 4813           3.4         298.4       1.0X
-Parquet Vectorized (Pushdown)                 4839 / 4907           3.3         307.6       1.0X
-Native ORC Vectorized                         4943 / 5032           3.2         314.2       0.9X
-Native ORC Vectorized (Pushdown)              2043 / 2085           7.7         129.9       2.3X
+Parquet Vectorized                            6340 / 7236           2.5         403.1       1.0X
+Parquet Vectorized (Pushdown)                 3052 / 3164           5.2         194.1       2.1X
+Native ORC Vectorized                         8370 / 9214           1.9         532.1       0.8X
+Native ORC Vectorized (Pushdown)              4137 / 4242           3.8         263.0       1.5X
 
 Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6
 Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz
 
 Select 50% decimal(9, 2) rows (value < 7864320): Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
 ------------------------------------------------------------------------------------------------
-Parquet Vectorized                            8321 / 8472           1.9         529.0       1.0X
-Parquet Vectorized (Pushdown)                 8125 / 8471           1.9         516.6       1.0X
-Native ORC Vectorized                         8524 / 8616           1.8         541.9       1.0X
-Native ORC Vectorized (Pushdown)              7961 / 8383           2.0         506.1       1.0X
+Parquet Vectorized                           12976 / 13249          1.2         825.0       1.0X
+Parquet Vectorized (Pushdown)                12655 / 13570          1.2         804.6       1.0X
+Native ORC Vectorized                        15562 / 15950          1.0         989.4       0.8X
+Native ORC Vectorized (Pushdown)             15042 / 15668          1.0         956.3       0.9X
 
 Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6
 Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz
 
 Select 90% decimal(9, 2) rows (value < 14155776): Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
 ------------------------------------------------------------------------------------------------
-Parquet Vectorized                            9587 / 10112          1.6         609.5       1.0X
-Parquet Vectorized (Pushdown)                 9726 / 10370          1.6         618.3       1.0X
-Native ORC Vectorized                        10119 / 11147          1.6         643.4       0.9X
-Native ORC Vectorized (Pushdown)              9366 / 9497           1.7         595.5       1.0X
+Parquet Vectorized                           14303 / 14616          1.1         909.3       1.0X
+Parquet Vectorized (Pushdown)                14380 / 14649          1.1         914.3       1.0X
+Native ORC Vectorized                        16964 / 17358          0.9        1078.5       0.8X
+Native ORC Vectorized (Pushdown)             17255 / 17874          0.9        1097.0       0.8X
 
 Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6
 Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz
 
 Select 1 decimal(18, 2) row (value = 7864320): Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
 ------------------------------------------------------------------------------------------------
-Parquet Vectorized                            4060 / 4093           3.9         258.1       1.0X
-Parquet Vectorized (Pushdown)                 4037 / 4125           3.9         256.6       1.0X
-Native ORC Vectorized                         4756 / 4811           3.3         302.4       0.9X
-Native ORC Vectorized (Pushdown)               824 / 889           19.1          52.4       4.9X
+Parquet Vectorized                            4701 / 6416           3.3         298.9       1.0X
+Parquet Vectorized (Pushdown)                  128 / 164          122.8           8.1      36.7X
+Native ORC Vectorized                         5698 / 7904           2.8         362.3       0.8X
+Native ORC Vectorized (Pushdown)               913 / 942           17.2          58.0       5.2X
 
 Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6
 Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz
 
 Select 10% decimal(18, 2) rows (value < 1572864): Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
 ------------------------------------------------------------------------------------------------
-Parquet Vectorized                            5157 / 5271           3.0         327.9       1.0X
-Parquet Vectorized (Pushdown)                 5051 / 5141           3.1         321.1       1.0X
-Native ORC Vectorized                         5723 / 6146           2.7         363.9       0.9X
-Native ORC Vectorized (Pushdown)              2198 / 2317           7.2         139.8       2.3X
+Parquet Vectorized                            5376 / 5461           2.9         341.8       1.0X
+Parquet Vectorized (Pushdown)                 1479 / 1543          10.6          94.0       3.6X
+Native ORC Vectorized                         6640 / 6748           2.4         422.2       0.8X
+Native ORC Vectorized (Pushdown)              2438 / 2479           6.5         155.0       2.2X
 
 Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6
 Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz
 
 Select 50% decimal(18, 2) rows (value < 7864320): Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
 ------------------------------------------------------------------------------------------------
-Parquet Vectorized                            8608 / 8647           1.8         547.3       1.0X
-Parquet Vectorized (Pushdown)                 8471 / 8584           1.9         538.6       1.0X
-Native ORC Vectorized                         9249 / 10048          1.7         588.0       0.9X
-Native ORC Vectorized (Pushdown)              7645 / 8091           2.1         486.1       1.1X
+Parquet Vectorized                            9224 / 9356           1.7         586.5       1.0X
+Parquet Vectorized (Pushdown)                 7172 / 7415           2.2         456.0       1.3X
+Native ORC Vectorized                        11017 / 11408          1.4         700.4       0.8X
+Native ORC Vectorized (Pushdown)              8771 / 10218          1.8         557.7       1.1X
 
 Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6
 Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz
 
 Select 90% decimal(18, 2) rows (value < 14155776): Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
 ------------------------------------------------------------------------------------------------
-Parquet Vectorized                           11658 / 11888          1.3         741.2       1.0X
-Parquet Vectorized (Pushdown)                11812 / 12098          1.3         751.0       1.0X
-Native ORC Vectorized                        12943 / 13312          1.2         822.9       0.9X
-Native ORC Vectorized (Pushdown)             13139 / 13465          1.2         835.4       0.9X
+Parquet Vectorized                           13933 / 15990          1.1         885.8       1.0X
+Parquet Vectorized (Pushdown)                12683 / 12942          1.2         806.4       1.1X
+Native ORC Vectorized                        16344 / 20196          1.0        1039.1       0.9X
+Native ORC Vectorized (Pushdown)             15162 / 16627          1.0         964.0       0.9X
 
 Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6
 Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz
 
 Select 1 decimal(38, 2) row (value = 7864320): Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
 ------------------------------------------------------------------------------------------------
-Parquet Vectorized                            5491 / 5716           2.9         349.1       1.0X
-Parquet Vectorized (Pushdown)                 5515 / 5615           2.9         350.6       1.0X
-Native ORC Vectorized                         4582 / 4654           3.4         291.3       1.2X
-Native ORC Vectorized (Pushdown)               815 / 861           19.3          51.8       6.7X
+Parquet Vectorized                            7102 / 8282           2.2         451.5       1.0X
+Parquet Vectorized (Pushdown)                  124 / 150          126.4           7.9      57.1X
+Native ORC Vectorized                         5811 / 6883           2.7         369.5       1.2X
+Native ORC Vectorized (Pushdown)              1121 / 1502          14.0          71.3       6.3X
 
 Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6
 Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz
 
 Select 10% decimal(38, 2) rows (value < 1572864): Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
 ------------------------------------------------------------------------------------------------
-Parquet Vectorized                            6432 / 6527           2.4         409.0       1.0X
-Parquet Vectorized (Pushdown)                 6513 / 6607           2.4         414.1       1.0X
-Native ORC Vectorized                         5618 / 6085           2.8         357.2       1.1X
-Native ORC Vectorized (Pushdown)              2403 / 2443           6.5         152.8       2.7X
+Parquet Vectorized                            6894 / 7562           2.3         438.3       1.0X
+Parquet Vectorized (Pushdown)                 1863 / 1980           8.4         118.4       3.7X
+Native ORC Vectorized                         6812 / 6848           2.3         433.1       1.0X
+Native ORC Vectorized (Pushdown)              2511 / 2598           6.3         159.7       2.7X
 
 Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6
 Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz
 
 Select 50% decimal(38, 2) rows (value < 7864320): Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
 ------------------------------------------------------------------------------------------------
-Parquet Vectorized                           11041 / 11467          1.4         701.9       1.0X
-Parquet Vectorized (Pushdown)                10909 / 11484          1.4         693.5       1.0X
-Native ORC Vectorized                         9860 / 10436          1.6         626.9       1.1X
-Native ORC Vectorized (Pushdown)              7908 / 8069           2.0         502.8       1.4X
+Parquet Vectorized                           11732 / 12183          1.3         745.9       1.0X
+Parquet Vectorized (Pushdown)                 8912 / 9945           1.8         566.6       1.3X
+Native ORC Vectorized                        11499 / 12387          1.4         731.1       1.0X
+Native ORC Vectorized (Pushdown)              9328 / 9382           1.7         593.1       1.3X
 
 Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6
 Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz
 
 Select 90% decimal(38, 2) rows (value < 14155776): Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
 ------------------------------------------------------------------------------------------------
-Parquet Vectorized                           14816 / 16877          1.1         942.0       1.0X
-Parquet Vectorized (Pushdown)                15383 / 15740          1.0         978.0       1.0X
-Native ORC Vectorized                        14408 / 14771          1.1         916.0       1.0X
-Native ORC Vectorized (Pushdown)             13968 / 14805          1.1         888.1       1.1X
+Parquet Vectorized                           16272 / 16328          1.0        1034.6       1.0X
+Parquet Vectorized (Pushdown)                15714 / 18100          1.0         999.1       1.0X
+Native ORC Vectorized                        16539 / 18897          1.0        1051.5       1.0X
+Native ORC Vectorized (Pushdown)             16328 / 17306          1.0        1038.1       1.0X
 
 
================================================================================================
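The pattern in the updated numbers is that pushdown pays off in proportion to selectivity: the "Select 1 row" cases improve roughly 28x to 57x because parquet-mr can discard entire row groups whose min/max statistics cannot contain the value, while at 50% to 90% selectivity almost every row group survives the check and the pushdown rows converge back to ~1.0X. A sketch of the kind of predicate Spark now hands to parquet-mr for a `decimal(9, 2)` column stored as INT32 (column name is illustrative):

```scala
import org.apache.parquet.filter2.predicate.FilterApi

// 7864320.00 at scale 2 has unscaled (shifted) value 786432000, which fits in an Int.
val unscaled: Integer = new java.math.BigDecimal("7864320.00").unscaledValue().intValue()
val pred = FilterApi.eq(FilterApi.intColumn("value"), unscaled)
// parquet-mr evaluates this against each row group's min/max statistics and skips
// row groups whose range cannot contain 786432000, hence the one-row speedups above.
```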
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala
index 3ec33b2f4b540..295960b1c2d30 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala
@@ -342,6 +342,7 @@ class ParquetFileFormat
     val returningBatch = supportBatch(sparkSession, resultSchema)
     val pushDownDate = sqlConf.parquetFilterPushDownDate
     val pushDownTimestamp = sqlConf.parquetFilterPushDownTimestamp
+    val pushDownDecimal = sqlConf.parquetFilterPushDownDecimal
     val pushDownStringStartWith = sqlConf.parquetFilterPushDownStringStartWith
     val pushDownInFilterThreshold = sqlConf.parquetFilterPushDownInFilterThreshold
 
@@ -367,7 +368,7 @@ class ParquetFileFormat
       val pushed = if (enableParquetFilterPushDown) {
         val parquetSchema = ParquetFileReader.readFooter(sharedConf, filePath, SKIP_ROW_GROUPS)
           .getFileMetaData.getSchema
-        val parquetFilters = new ParquetFilters(pushDownDate, pushDownTimestamp,
+        val parquetFilters = new ParquetFilters(pushDownDate, pushDownTimestamp, pushDownDecimal,
           pushDownStringStartWith, pushDownInFilterThreshold)
         filters
           // Collects all converted Parquet filter predicates. Notice that not all predicates can be
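Throughout `ParquetFilters` below, the decimal support is added as guarded `case` clauses (`if pushDownDecimal`). The guard removes the case from the partial function's domain when the flag is off, so `makeEq.lift(...)` simply returns `None` and the filter is quietly not pushed, rather than failing. A toy model of that pattern (all names here are illustrative, not Spark code):

```scala
// Guarded-case gating: lifting a PartialFunction turns a failed guard into None.
val pushDownDecimal = false
val makeEq: PartialFunction[String, String => String] = {
  case "int32" => n => s"eq(int32 $n)"
  case "decimal" if pushDownDecimal => n => s"eq(decimal $n)"
}
assert(makeEq.lift("int32").isDefined)
assert(makeEq.lift("decimal").isEmpty) // guard failed: no pushdown, no error
```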
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala
index 0c146f2f6f915..58b4a769fcb62 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala
@@ -17,7 +17,8 @@
 
 package org.apache.spark.sql.execution.datasources.parquet
 
-import java.lang.{Long => JLong}
+import java.lang.{Boolean => JBoolean, Double => JDouble, Float => JFloat, Long => JLong}
+import java.math.{BigDecimal => JBigDecimal}
 import java.sql.{Date, Timestamp}
 
 import scala.collection.JavaConverters.asScalaBufferConverter
@@ -41,44 +42,65 @@ import org.apache.spark.unsafe.types.UTF8String
 private[parquet] class ParquetFilters(
     pushDownDate: Boolean,
     pushDownTimestamp: Boolean,
+    pushDownDecimal: Boolean,
     pushDownStartWith: Boolean,
     pushDownInFilterThreshold: Int) {
 
   private case class ParquetSchemaType(
       originalType: OriginalType,
       primitiveTypeName: PrimitiveTypeName,
+      length: Int,
       decimalMetadata: DecimalMetadata)
 
-  private val ParquetBooleanType = ParquetSchemaType(null, BOOLEAN, null)
-  private val ParquetByteType = ParquetSchemaType(INT_8, INT32, null)
-  private val ParquetShortType = ParquetSchemaType(INT_16, INT32, null)
-  private val ParquetIntegerType = ParquetSchemaType(null, INT32, null)
-  private val ParquetLongType = ParquetSchemaType(null, INT64, null)
-  private val ParquetFloatType = ParquetSchemaType(null, FLOAT, null)
-  private val ParquetDoubleType = ParquetSchemaType(null, DOUBLE, null)
-  private val ParquetStringType = ParquetSchemaType(UTF8, BINARY, null)
-  private val ParquetBinaryType = ParquetSchemaType(null, BINARY, null)
-  private val ParquetDateType = ParquetSchemaType(DATE, INT32, null)
-  private val ParquetTimestampMicrosType = ParquetSchemaType(TIMESTAMP_MICROS, INT64, null)
-  private val ParquetTimestampMillisType = ParquetSchemaType(TIMESTAMP_MILLIS, INT64, null)
+  private val ParquetBooleanType = ParquetSchemaType(null, BOOLEAN, 0, null)
+  private val ParquetByteType = ParquetSchemaType(INT_8, INT32, 0, null)
+  private val ParquetShortType = ParquetSchemaType(INT_16, INT32, 0, null)
+  private val ParquetIntegerType = ParquetSchemaType(null, INT32, 0, null)
+  private val ParquetLongType = ParquetSchemaType(null, INT64, 0, null)
+  private val ParquetFloatType = ParquetSchemaType(null, FLOAT, 0, null)
+  private val ParquetDoubleType = ParquetSchemaType(null, DOUBLE, 0, null)
+  private val ParquetStringType = ParquetSchemaType(UTF8, BINARY, 0, null)
+  private val ParquetBinaryType = ParquetSchemaType(null, BINARY, 0, null)
+  private val ParquetDateType = ParquetSchemaType(DATE, INT32, 0, null)
+  private val ParquetTimestampMicrosType = ParquetSchemaType(TIMESTAMP_MICROS, INT64, 0, null)
+  private val ParquetTimestampMillisType = ParquetSchemaType(TIMESTAMP_MILLIS, INT64, 0, null)
 
   private def dateToDays(date: Date): SQLDate = {
     DateTimeUtils.fromJavaDate(date)
   }
 
+  private def decimalToInt32(decimal: JBigDecimal): Integer = decimal.unscaledValue().intValue()
+
+  private def decimalToInt64(decimal: JBigDecimal): JLong = decimal.unscaledValue().longValue()
+
+  private def decimalToByteArray(decimal: JBigDecimal, numBytes: Int): Binary = {
+    val decimalBuffer = new Array[Byte](numBytes)
+    val bytes = decimal.unscaledValue().toByteArray
+
+    val fixedLengthBytes = if (bytes.length == numBytes) {
+      bytes
+    } else {
+      val signByte = if (bytes.head < 0) -1: Byte else 0: Byte
+      java.util.Arrays.fill(decimalBuffer, 0, numBytes - bytes.length, signByte)
+      System.arraycopy(bytes, 0, decimalBuffer, numBytes - bytes.length, bytes.length)
+      decimalBuffer
+    }
+    Binary.fromConstantByteArray(fixedLengthBytes, 0, numBytes)
+  }
+
   private val makeEq: PartialFunction[ParquetSchemaType, (String, Any) => FilterPredicate] = {
     case ParquetBooleanType =>
-      (n: String, v: Any) => FilterApi.eq(booleanColumn(n), v.asInstanceOf[java.lang.Boolean])
+      (n: String, v: Any) => FilterApi.eq(booleanColumn(n), v.asInstanceOf[JBoolean])
     case ParquetByteType | ParquetShortType | ParquetIntegerType =>
       (n: String, v: Any) => FilterApi.eq(
         intColumn(n),
         Option(v).map(_.asInstanceOf[Number].intValue.asInstanceOf[Integer]).orNull)
     case ParquetLongType =>
-      (n: String, v: Any) => FilterApi.eq(longColumn(n), v.asInstanceOf[java.lang.Long])
+      (n: String, v: Any) => FilterApi.eq(longColumn(n), v.asInstanceOf[JLong])
     case ParquetFloatType =>
-      (n: String, v: Any) => FilterApi.eq(floatColumn(n), v.asInstanceOf[java.lang.Float])
+      (n: String, v: Any) => FilterApi.eq(floatColumn(n), v.asInstanceOf[JFloat])
     case ParquetDoubleType =>
-      (n: String, v: Any) => FilterApi.eq(doubleColumn(n), v.asInstanceOf[java.lang.Double])
+      (n: String, v: Any) => FilterApi.eq(doubleColumn(n), v.asInstanceOf[JDouble])
 
     // Binary.fromString and Binary.fromByteArray don't accept null values
     case ParquetStringType =>
@@ -102,21 +124,34 @@
       (n: String, v: Any) => FilterApi.eq(
         longColumn(n),
         Option(v).map(_.asInstanceOf[Timestamp].getTime.asInstanceOf[JLong]).orNull)
+
+    case ParquetSchemaType(DECIMAL, INT32, _, _) if pushDownDecimal =>
+      (n: String, v: Any) => FilterApi.eq(
+        intColumn(n),
+        Option(v).map(d => decimalToInt32(d.asInstanceOf[JBigDecimal])).orNull)
+    case ParquetSchemaType(DECIMAL, INT64, _, _) if pushDownDecimal =>
+      (n: String, v: Any) => FilterApi.eq(
+        longColumn(n),
+        Option(v).map(d => decimalToInt64(d.asInstanceOf[JBigDecimal])).orNull)
+    case ParquetSchemaType(DECIMAL, FIXED_LEN_BYTE_ARRAY, length, _) if pushDownDecimal =>
+      (n: String, v: Any) => FilterApi.eq(
+        binaryColumn(n),
+        Option(v).map(d => decimalToByteArray(d.asInstanceOf[JBigDecimal], length)).orNull)
   }
 
   private val makeNotEq: PartialFunction[ParquetSchemaType, (String, Any) => FilterPredicate] = {
     case ParquetBooleanType =>
-      (n: String, v: Any) => FilterApi.notEq(booleanColumn(n), v.asInstanceOf[java.lang.Boolean])
+      (n: String, v: Any) => FilterApi.notEq(booleanColumn(n), v.asInstanceOf[JBoolean])
     case ParquetByteType | ParquetShortType | ParquetIntegerType =>
       (n: String, v: Any) => FilterApi.notEq(
         intColumn(n),
         Option(v).map(_.asInstanceOf[Number].intValue.asInstanceOf[Integer]).orNull)
     case ParquetLongType =>
-      (n: String, v: Any) => FilterApi.notEq(longColumn(n), v.asInstanceOf[java.lang.Long])
+      (n: String, v: Any) => FilterApi.notEq(longColumn(n), v.asInstanceOf[JLong])
     case ParquetFloatType =>
-      (n: String, v: Any) => FilterApi.notEq(floatColumn(n), v.asInstanceOf[java.lang.Float])
+      (n: String, v: Any) => FilterApi.notEq(floatColumn(n), v.asInstanceOf[JFloat])
     case ParquetDoubleType =>
-      (n: String, v: Any) => FilterApi.notEq(doubleColumn(n), v.asInstanceOf[java.lang.Double])
+      (n: String, v: Any) => FilterApi.notEq(doubleColumn(n), v.asInstanceOf[JDouble])
 
     case ParquetStringType =>
       (n: String, v: Any) => FilterApi.notEq(
@@ -139,6 +174,19 @@
       (n: String, v: Any) => FilterApi.notEq(
         longColumn(n),
         Option(v).map(_.asInstanceOf[Timestamp].getTime.asInstanceOf[JLong]).orNull)
+
+    case ParquetSchemaType(DECIMAL, INT32, _, _) if pushDownDecimal =>
+      (n: String, v: Any) => FilterApi.notEq(
+        intColumn(n),
+        Option(v).map(d => decimalToInt32(d.asInstanceOf[JBigDecimal])).orNull)
+    case ParquetSchemaType(DECIMAL, INT64, _, _) if pushDownDecimal =>
+      (n: String, v: Any) => FilterApi.notEq(
+        longColumn(n),
+        Option(v).map(d => decimalToInt64(d.asInstanceOf[JBigDecimal])).orNull)
+    case ParquetSchemaType(DECIMAL, FIXED_LEN_BYTE_ARRAY, length, _) if pushDownDecimal =>
+      (n: String, v: Any) => FilterApi.notEq(
+        binaryColumn(n),
+        Option(v).map(d => decimalToByteArray(d.asInstanceOf[JBigDecimal], length)).orNull)
   }
 
   private val makeLt: PartialFunction[ParquetSchemaType, (String, Any) => FilterPredicate] = {
@@ -146,11 +194,11 @@
       (n: String, v: Any) =>
         FilterApi.lt(intColumn(n), v.asInstanceOf[Number].intValue.asInstanceOf[Integer])
     case ParquetLongType =>
-      (n: String, v: Any) => FilterApi.lt(longColumn(n), v.asInstanceOf[java.lang.Long])
+      (n: String, v: Any) => FilterApi.lt(longColumn(n), v.asInstanceOf[JLong])
     case ParquetFloatType =>
-      (n: String, v: Any) => FilterApi.lt(floatColumn(n), v.asInstanceOf[java.lang.Float])
+      (n: String, v: Any) => FilterApi.lt(floatColumn(n), v.asInstanceOf[JFloat])
     case ParquetDoubleType =>
-      (n: String, v: Any) => FilterApi.lt(doubleColumn(n), v.asInstanceOf[java.lang.Double])
+      (n: String, v: Any) => FilterApi.lt(doubleColumn(n), v.asInstanceOf[JDouble])
 
     case ParquetStringType =>
       (n: String, v: Any) =>
@@ -169,6 +217,16 @@
       (n: String, v: Any) => FilterApi.lt(
         longColumn(n),
         v.asInstanceOf[Timestamp].getTime.asInstanceOf[JLong])
+
+    case ParquetSchemaType(DECIMAL, INT32, _, _) if pushDownDecimal =>
+      (n: String, v: Any) =>
+        FilterApi.lt(intColumn(n), decimalToInt32(v.asInstanceOf[JBigDecimal]))
+    case ParquetSchemaType(DECIMAL, INT64, _, _) if pushDownDecimal =>
+      (n: String, v: Any) =>
+        FilterApi.lt(longColumn(n), decimalToInt64(v.asInstanceOf[JBigDecimal]))
+    case ParquetSchemaType(DECIMAL, FIXED_LEN_BYTE_ARRAY, length, _) if pushDownDecimal =>
+      (n: String, v: Any) =>
+        FilterApi.lt(binaryColumn(n), decimalToByteArray(v.asInstanceOf[JBigDecimal], length))
   }
 
   private val makeLtEq: PartialFunction[ParquetSchemaType, (String, Any) => FilterPredicate] = {
@@ -176,11 +234,11 @@
       (n: String, v: Any) =>
         FilterApi.ltEq(intColumn(n), v.asInstanceOf[Number].intValue.asInstanceOf[Integer])
     case ParquetLongType =>
-      (n: String, v: Any) => FilterApi.ltEq(longColumn(n), v.asInstanceOf[java.lang.Long])
+      (n: String, v: Any) => FilterApi.ltEq(longColumn(n), v.asInstanceOf[JLong])
     case ParquetFloatType =>
-      (n: String, v: Any) => FilterApi.ltEq(floatColumn(n), v.asInstanceOf[java.lang.Float])
+      (n: String, v: Any) => FilterApi.ltEq(floatColumn(n), v.asInstanceOf[JFloat])
     case ParquetDoubleType =>
-      (n: String, v: Any) => FilterApi.ltEq(doubleColumn(n), v.asInstanceOf[java.lang.Double])
+      (n: String, v: Any) => FilterApi.ltEq(doubleColumn(n), v.asInstanceOf[JDouble])
 
     case ParquetStringType =>
       (n: String, v: Any) =>
@@ -199,6 +257,16 @@
       (n: String, v: Any) => FilterApi.ltEq(
         longColumn(n),
         v.asInstanceOf[Timestamp].getTime.asInstanceOf[JLong])
+
+    case ParquetSchemaType(DECIMAL, INT32, _, _) if pushDownDecimal =>
+      (n: String, v: Any) =>
+        FilterApi.ltEq(intColumn(n), decimalToInt32(v.asInstanceOf[JBigDecimal]))
+    case ParquetSchemaType(DECIMAL, INT64, _, _) if pushDownDecimal =>
+      (n: String, v: Any) =>
+        FilterApi.ltEq(longColumn(n), decimalToInt64(v.asInstanceOf[JBigDecimal]))
+    case ParquetSchemaType(DECIMAL, FIXED_LEN_BYTE_ARRAY, length, _) if pushDownDecimal =>
+      (n: String, v: Any) =>
+        FilterApi.ltEq(binaryColumn(n), decimalToByteArray(v.asInstanceOf[JBigDecimal], length))
   }
 
   private val makeGt: PartialFunction[ParquetSchemaType, (String, Any) => FilterPredicate] = {
@@ -206,11 +274,11 @@
       (n: String, v: Any) =>
         FilterApi.gt(intColumn(n), v.asInstanceOf[Number].intValue.asInstanceOf[Integer])
     case ParquetLongType =>
-      (n: String, v: Any) => FilterApi.gt(longColumn(n), v.asInstanceOf[java.lang.Long])
+      (n: String, v: Any) => FilterApi.gt(longColumn(n), v.asInstanceOf[JLong])
     case ParquetFloatType =>
-      (n: String, v: Any) => FilterApi.gt(floatColumn(n), v.asInstanceOf[java.lang.Float])
+      (n: String, v: Any) => FilterApi.gt(floatColumn(n), v.asInstanceOf[JFloat])
     case ParquetDoubleType =>
-      (n: String, v: Any) => FilterApi.gt(doubleColumn(n), v.asInstanceOf[java.lang.Double])
+      (n: String, v: Any) => FilterApi.gt(doubleColumn(n), v.asInstanceOf[JDouble])
 
     case ParquetStringType =>
       (n: String, v: Any) =>
@@ -229,6 +297,16 @@
       (n: String, v: Any) => FilterApi.gt(
         longColumn(n),
         v.asInstanceOf[Timestamp].getTime.asInstanceOf[JLong])
+
+    case ParquetSchemaType(DECIMAL, INT32, _, _) if pushDownDecimal =>
+      (n: String, v: Any) =>
+        FilterApi.gt(intColumn(n), decimalToInt32(v.asInstanceOf[JBigDecimal]))
+    case ParquetSchemaType(DECIMAL, INT64, _, _) if pushDownDecimal =>
+      (n: String, v: Any) =>
+        FilterApi.gt(longColumn(n), decimalToInt64(v.asInstanceOf[JBigDecimal]))
+    case ParquetSchemaType(DECIMAL, FIXED_LEN_BYTE_ARRAY, length, _) if pushDownDecimal =>
+      (n: String, v: Any) =>
+        FilterApi.gt(binaryColumn(n), decimalToByteArray(v.asInstanceOf[JBigDecimal], length))
   }
 
   private val makeGtEq: PartialFunction[ParquetSchemaType, (String, Any) => FilterPredicate] = {
@@ -236,11 +314,11 @@ private[parquet] class ParquetFilters(
       (n: String, v: Any) =>
         FilterApi.gtEq(intColumn(n), v.asInstanceOf[Number].intValue.asInstanceOf[Integer])
     case ParquetLongType =>
-      (n: String, v: Any) => FilterApi.gtEq(longColumn(n), v.asInstanceOf[java.lang.Long])
+      (n: String, v: Any) => FilterApi.gtEq(longColumn(n), v.asInstanceOf[JLong])
     case ParquetFloatType =>
-      (n: String, v: Any) => FilterApi.gtEq(floatColumn(n), v.asInstanceOf[java.lang.Float])
+      (n: String, v: Any) => FilterApi.gtEq(floatColumn(n), v.asInstanceOf[JFloat])
     case ParquetDoubleType =>
-      (n: String, v: Any) => FilterApi.gtEq(doubleColumn(n), v.asInstanceOf[java.lang.Double])
+      (n: String, v: Any) => FilterApi.gtEq(doubleColumn(n), v.asInstanceOf[JDouble])
 
     case ParquetStringType =>
       (n: String, v: Any) =>
@@ -259,6 +337,16 @@
       (n: String, v: Any) => FilterApi.gtEq(
         longColumn(n),
         v.asInstanceOf[Timestamp].getTime.asInstanceOf[JLong])
+
+    case ParquetSchemaType(DECIMAL, INT32, _, _) if pushDownDecimal =>
+      (n: String, v: Any) =>
+        FilterApi.gtEq(intColumn(n), decimalToInt32(v.asInstanceOf[JBigDecimal]))
+    case ParquetSchemaType(DECIMAL, INT64, _, _) if pushDownDecimal =>
+      (n: String, v: Any) =>
+        FilterApi.gtEq(longColumn(n), decimalToInt64(v.asInstanceOf[JBigDecimal]))
+    case ParquetSchemaType(DECIMAL, FIXED_LEN_BYTE_ARRAY, length, _) if pushDownDecimal =>
+      (n: String, v: Any) =>
+        FilterApi.gtEq(binaryColumn(n), decimalToByteArray(v.asInstanceOf[JBigDecimal], length))
   }
 
   /**
@@ -271,7 +359,7 @@
       // and it does not support to create filters for them.
       m.getFields.asScala.filter(_.isPrimitive).map(_.asPrimitiveType()).map { f =>
         f.getName -> ParquetSchemaType(
-          f.getOriginalType, f.getPrimitiveTypeName, f.getDecimalMetadata)
+          f.getOriginalType, f.getPrimitiveTypeName, f.getTypeLength, f.getDecimalMetadata)
       }.toMap
     case _ => Map.empty[String, ParquetSchemaType]
   }
 
@@ -282,21 +370,45 @@
   def createFilter(schema: MessageType, predicate: sources.Filter): Option[FilterPredicate] = {
     val nameToType = getFieldMap(schema)
 
+    // For a decimal column, the filter value's scale must match the scale in the file schema.
+    // Pushing down a filter whose scale does not match would corrupt the returned data.
+    def isDecimalMatched(value: Any, decimalMeta: DecimalMetadata): Boolean = value match {
+      case decimal: JBigDecimal =>
+        decimal.scale == decimalMeta.getScale
+      case _ => false
+    }
+
+    // The Parquet type of the column in the given file must match the type of the value in
+    // the pushed filter; otherwise the filter cannot be pushed down to Parquet.
+    def valueCanMakeFilterOn(name: String, value: Any): Boolean = {
+      value == null || (nameToType(name) match {
+        case ParquetBooleanType => value.isInstanceOf[JBoolean]
+        case ParquetByteType | ParquetShortType | ParquetIntegerType => value.isInstanceOf[Number]
+        case ParquetLongType => value.isInstanceOf[JLong]
+        case ParquetFloatType => value.isInstanceOf[JFloat]
+        case ParquetDoubleType => value.isInstanceOf[JDouble]
+        case ParquetStringType => value.isInstanceOf[String]
+        case ParquetBinaryType => value.isInstanceOf[Array[Byte]]
+        case ParquetDateType => value.isInstanceOf[Date]
+        case ParquetTimestampMicrosType | ParquetTimestampMillisType =>
+          value.isInstanceOf[Timestamp]
+        case ParquetSchemaType(DECIMAL, INT32, _, decimalMeta) =>
+          isDecimalMatched(value, decimalMeta)
+        case ParquetSchemaType(DECIMAL, INT64, _, decimalMeta) =>
+          isDecimalMatched(value, decimalMeta)
+        case ParquetSchemaType(DECIMAL, FIXED_LEN_BYTE_ARRAY, _, decimalMeta) =>
+          isDecimalMatched(value, decimalMeta)
+        case _ => false
+      })
+    }
+
     // Parquet does not allow dots in the column name because dots are used as a column path
     // delimiter. Since Parquet 1.8.2 (PARQUET-389), Parquet accepts the filter predicates
     // with missing columns. The incorrect results could be got from Parquet when we push down
     // filters for the column having dots in the names. Thus, we do not push down such filters.
     // See SPARK-20364.
-    def canMakeFilterOn(name: String): Boolean = nameToType.contains(name) && !name.contains(".")
-
-    // All DataTypes that support `makeEq` can provide better performance.
-    def shouldConvertInPredicate(name: String): Boolean = nameToType(name) match {
-      case ParquetBooleanType | ParquetByteType | ParquetShortType | ParquetIntegerType
-        | ParquetLongType | ParquetFloatType | ParquetDoubleType | ParquetStringType
-        | ParquetBinaryType => true
-      case ParquetDateType if pushDownDate => true
-      case ParquetTimestampMicrosType | ParquetTimestampMillisType if pushDownTimestamp => true
-      case _ => false
+    def canMakeFilterOn(name: String, value: Any): Boolean = {
+      nameToType.contains(name) && !name.contains(".") && valueCanMakeFilterOn(name, value)
     }
 
     // NOTE:
@@ -315,29 +427,29 @@
     // Probably I missed something and obviously this should be changed.
 
     predicate match {
-      case sources.IsNull(name) if canMakeFilterOn(name) =>
+      case sources.IsNull(name) if canMakeFilterOn(name, null) =>
         makeEq.lift(nameToType(name)).map(_(name, null))
-      case sources.IsNotNull(name) if canMakeFilterOn(name) =>
+      case sources.IsNotNull(name) if canMakeFilterOn(name, null) =>
         makeNotEq.lift(nameToType(name)).map(_(name, null))
 
-      case sources.EqualTo(name, value) if canMakeFilterOn(name) =>
+      case sources.EqualTo(name, value) if canMakeFilterOn(name, value) =>
         makeEq.lift(nameToType(name)).map(_(name, value))
-      case sources.Not(sources.EqualTo(name, value)) if canMakeFilterOn(name) =>
+      case sources.Not(sources.EqualTo(name, value)) if canMakeFilterOn(name, value) =>
         makeNotEq.lift(nameToType(name)).map(_(name, value))
 
-      case sources.EqualNullSafe(name, value) if canMakeFilterOn(name) =>
+      case sources.EqualNullSafe(name, value) if canMakeFilterOn(name, value) =>
         makeEq.lift(nameToType(name)).map(_(name, value))
-      case sources.Not(sources.EqualNullSafe(name, value)) if canMakeFilterOn(name) =>
+      case sources.Not(sources.EqualNullSafe(name, value)) if canMakeFilterOn(name, value) =>
         makeNotEq.lift(nameToType(name)).map(_(name, value))
 
-      case sources.LessThan(name, value) if canMakeFilterOn(name) =>
+      case sources.LessThan(name, value) if canMakeFilterOn(name, value) =>
         makeLt.lift(nameToType(name)).map(_(name, value))
-      case sources.LessThanOrEqual(name, value) if canMakeFilterOn(name) =>
+      case sources.LessThanOrEqual(name, value) if canMakeFilterOn(name, value) =>
         makeLtEq.lift(nameToType(name)).map(_(name, value))
 
-      case sources.GreaterThan(name, value) if canMakeFilterOn(name) =>
+      case sources.GreaterThan(name, value) if canMakeFilterOn(name, value) =>
         makeGt.lift(nameToType(name)).map(_(name, value))
-      case sources.GreaterThanOrEqual(name, value) if canMakeFilterOn(name) =>
+      case sources.GreaterThanOrEqual(name, value) if canMakeFilterOn(name, value) =>
         makeGtEq.lift(nameToType(name)).map(_(name, value))
 
       case sources.And(lhs, rhs) =>
@@ -362,13 +474,14 @@
       case sources.Not(pred) =>
         createFilter(schema, pred).map(FilterApi.not)
 
-      case sources.In(name, values) if canMakeFilterOn(name) && shouldConvertInPredicate(name)
+      case sources.In(name, values) if canMakeFilterOn(name, values.head)
         && values.distinct.length <= pushDownInFilterThreshold =>
         values.distinct.flatMap { v =>
           makeEq.lift(nameToType(name)).map(_(name, v))
         }.reduceLeftOption(FilterApi.or)
 
-      case sources.StringStartsWith(name, prefix) if pushDownStartWith && canMakeFilterOn(name) =>
+      case sources.StringStartsWith(name, prefix)
+          if pushDownStartWith && canMakeFilterOn(name, prefix) =>
         Option(prefix).map { v =>
           FilterApi.userDefined(binaryColumn(name),
             new UserDefinedPredicate[Binary] with Serializable {
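The least obvious conversion in this file is `decimalToByteArray`: for `FIXED_LEN_BYTE_ARRAY` storage, the unscaled value's two's-complement bytes must be sign-extended on the left to the exact width declared for the column, otherwise comparisons against the file's statistics would be misaligned. A standalone sketch of the same padding logic, usable outside Spark (it returns raw bytes instead of a parquet `Binary`):

```scala
import java.math.{BigDecimal => JBigDecimal}

// Re-implementation of the padding in decimalToByteArray, for experimentation.
def toFixedLength(decimal: JBigDecimal, numBytes: Int): Array[Byte] = {
  val bytes = decimal.unscaledValue().toByteArray // minimal two's-complement form
  val out = new Array[Byte](numBytes)
  val signByte: Byte = if (bytes.head < 0) -1 else 0
  java.util.Arrays.fill(out, 0, numBytes - bytes.length, signByte)
  System.arraycopy(bytes, 0, out, numBytes - bytes.length, bytes.length)
  out
}

// -1.00 has unscaled value -100 (one byte, 0x9C); padding must use 0xFF, not 0x00,
// or the padded value would decode as a large positive number.
assert(toFixedLength(new JBigDecimal("-1.00"), 4).toSeq == Seq(-1, -1, -1, -100).map(_.toByte))
```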
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/FilterPushdownBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/FilterPushdownBenchmark.scala
index 567a8ebf9d102..bdb60b44750c7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/FilterPushdownBenchmark.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/FilterPushdownBenchmark.scala
@@ -290,8 +290,12 @@ class FilterPushdownBenchmark extends SparkFunSuite with BenchmarkBeforeAndAfter
       s"decimal(${DecimalType.MAX_PRECISION}, 2)"
     ).foreach { dt =>
       val columns = (1 to width).map(i => s"CAST(id AS string) c$i")
-      val df = spark.range(numRows).selectExpr(columns: _*)
-        .withColumn("value", monotonically_increasing_id().cast(dt))
+      val valueCol = if (dt.equalsIgnoreCase(s"decimal(${Decimal.MAX_INT_DIGITS}, 2)")) {
+        monotonically_increasing_id() % 9999999
+      } else {
+        monotonically_increasing_id()
+      }
+      val df = spark.range(numRows).selectExpr(columns: _*).withColumn("value", valueCol.cast(dt))
       withTempTable("orcTable", "patquetTable") {
         saveAsTable(df, dir)
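The `% 9999999` cap exists because `decimal(9, 2)` keeps only 9 digits of precision, 2 of which hold the fraction: the largest representable value is 9999999.99, so any id of 10^7 or more would overflow the cast (and a non-ANSI cast overflow becomes null, which would change what the benchmark measures). A quick sanity check of the bound:

```scala
// Precision 9, scale 2 leaves 7 integral digits (Decimal.MAX_INT_DIGITS is 9 in Spark).
val precision = 9
val scale = 2
val maxIntegral = BigDecimal(10).pow(precision - scale) - 1
assert(maxIntegral == BigDecimal(9999999)) // so ids are kept strictly below 10^7
```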
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
index 924f136503656..be4f498c921ab 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.execution.datasources.parquet
 
+import java.math.{BigDecimal => JBigDecimal}
 import java.nio.charset.StandardCharsets
 import java.sql.{Date, Timestamp}
 
@@ -58,7 +59,8 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex
 
   private lazy val parquetFilters =
     new ParquetFilters(conf.parquetFilterPushDownDate, conf.parquetFilterPushDownTimestamp,
-      conf.parquetFilterPushDownStringStartWith, conf.parquetFilterPushDownInFilterThreshold)
+      conf.parquetFilterPushDownDecimal, conf.parquetFilterPushDownStringStartWith,
+      conf.parquetFilterPushDownInFilterThreshold)
 
   override def beforeEach(): Unit = {
     super.beforeEach()
@@ -86,6 +88,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex
         SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true",
         SQLConf.PARQUET_FILTER_PUSHDOWN_DATE_ENABLED.key -> "true",
         SQLConf.PARQUET_FILTER_PUSHDOWN_TIMESTAMP_ENABLED.key -> "true",
+        SQLConf.PARQUET_FILTER_PUSHDOWN_DECIMAL_ENABLED.key -> "true",
         SQLConf.PARQUET_FILTER_PUSHDOWN_STRING_STARTSWITH_ENABLED.key -> "true",
         SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") {
       val query = df
@@ -179,6 +182,13 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex
     }
   }
 
+  private def testDecimalPushDown(data: DataFrame)(f: DataFrame => Unit): Unit = {
+    withTempPath { file =>
+      data.write.parquet(file.getCanonicalPath)
+      readParquetFile(file.toString)(f)
+    }
+  }
+
   // This function tests that exactly go through the `canDrop` and `inverseCanDrop`.
   private def testStringStartsWith(dataFrame: DataFrame, filter: String): Unit = {
     withTempPath { dir =>
@@ -512,6 +522,84 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex
     }
   }
 
+  test("filter pushdown - decimal") {
+    Seq(true, false).foreach { legacyFormat =>
+      withSQLConf(SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key -> legacyFormat.toString) {
+        Seq(
+          s"a decimal(${Decimal.MAX_INT_DIGITS}, 2)", // 32BitDecimalType
+          s"a decimal(${Decimal.MAX_LONG_DIGITS}, 2)", // 64BitDecimalType
+          "a decimal(38, 18)" // ByteArrayDecimalType
+        ).foreach { schemaDDL =>
+          val schema = StructType.fromDDL(schemaDDL)
+          val rdd =
+            spark.sparkContext.parallelize((1 to 4).map(i => Row(new java.math.BigDecimal(i))))
+          val dataFrame = spark.createDataFrame(rdd, schema)
+          testDecimalPushDown(dataFrame) { implicit df =>
+            assert(df.schema === schema)
+            checkFilterPredicate('a.isNull, classOf[Eq[_]], Seq.empty[Row])
+            checkFilterPredicate('a.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_)))
+
+            checkFilterPredicate('a === 1, classOf[Eq[_]], 1)
+            checkFilterPredicate('a <=> 1, classOf[Eq[_]], 1)
+            checkFilterPredicate('a =!= 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_)))
+
+            checkFilterPredicate('a < 2, classOf[Lt[_]], 1)
+            checkFilterPredicate('a > 3, classOf[Gt[_]], 4)
+            checkFilterPredicate('a <= 1, classOf[LtEq[_]], 1)
+            checkFilterPredicate('a >= 4, classOf[GtEq[_]], 4)
+
+            checkFilterPredicate(Literal(1) === 'a, classOf[Eq[_]], 1)
+            checkFilterPredicate(Literal(1) <=> 'a, classOf[Eq[_]], 1)
+            checkFilterPredicate(Literal(2) > 'a, classOf[Lt[_]], 1)
+            checkFilterPredicate(Literal(3) < 'a, classOf[Gt[_]], 4)
+            checkFilterPredicate(Literal(1) >= 'a, classOf[LtEq[_]], 1)
+            checkFilterPredicate(Literal(4) <= 'a, classOf[GtEq[_]], 4)
+
+            checkFilterPredicate(!('a < 4), classOf[GtEq[_]], 4)
+            checkFilterPredicate('a < 2 || 'a > 3, classOf[Operators.Or], Seq(Row(1), Row(4)))
+          }
+        }
+      }
+    }
+  }
+
+  test("Ensure that filter value matches the parquet file schema") {
+    val scale = 2
+    val schema = StructType(Seq(
+      StructField("cint", IntegerType),
+      StructField("cdecimal1", DecimalType(Decimal.MAX_INT_DIGITS, scale)),
+      StructField("cdecimal2", DecimalType(Decimal.MAX_LONG_DIGITS, scale)),
+      StructField("cdecimal3", DecimalType(DecimalType.MAX_PRECISION, scale))
+    ))
+
+    val parquetSchema = new SparkToParquetSchemaConverter(conf).convert(schema)
+
+    val decimal = new JBigDecimal(10).setScale(scale)
+    val decimal1 = new JBigDecimal(10).setScale(scale + 1)
+    assert(decimal.scale() === scale)
+    assert(decimal1.scale() === scale + 1)
+
+    assertResult(Some(lt(intColumn("cdecimal1"), 1000: Integer))) {
+      parquetFilters.createFilter(parquetSchema, sources.LessThan("cdecimal1", decimal))
+    }
+    assertResult(None) {
+      parquetFilters.createFilter(parquetSchema, sources.LessThan("cdecimal1", decimal1))
+    }
+
+    assertResult(Some(lt(longColumn("cdecimal2"), 1000L: java.lang.Long))) {
+      parquetFilters.createFilter(parquetSchema, sources.LessThan("cdecimal2", decimal))
+    }
+    assertResult(None) {
+      parquetFilters.createFilter(parquetSchema, sources.LessThan("cdecimal2", decimal1))
+    }
+
+    assert(parquetFilters.createFilter(
+      parquetSchema, sources.LessThan("cdecimal3", decimal)).isDefined)
+    assertResult(None) {
+      parquetFilters.createFilter(parquetSchema, sources.LessThan("cdecimal3", decimal1))
+    }
+  }
+
   test("SPARK-6554: don't push down predicates which reference partition columns") {
     import testImplicits._