diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala index 0c3d8fdab6bd..41756c0586a7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.parquet +import java.math.BigDecimal +import java.math.BigInteger import java.nio.ByteOrder import scala.collection.JavaConversions._ @@ -263,17 +265,21 @@ private[parquet] class CatalystRowConverter( val scale = decimalType.scale val bytes = value.getBytes - var unscaled = 0L - var i = 0 + if (value.length <= 8) { + var unscaled = 0L + var i = 0 - while (i < bytes.length) { - unscaled = (unscaled << 8) | (bytes(i) & 0xff) - i += 1 - } + while (i < bytes.length) { + unscaled = (unscaled << 8) | (bytes(i) & 0xff) + i += 1 + } - val bits = 8 * bytes.length - unscaled = (unscaled << (64 - bits)) >> (64 - bits) - Decimal(unscaled, precision, scale) + val bits = 8 * bytes.length + unscaled = (unscaled << (64 - bits)) >> (64 - bits) + Decimal(unscaled, precision, scale) + } else { + Decimal(new BigDecimal(new BigInteger(bytes), scale), precision, scale) + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala index 1ea6926af6d5..1317f7ce5cc5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala @@ -388,6 +388,8 @@ private[parquet] class CatalystSchemaConverter( // Spark 1.4.x and prior versions only support decimals with a maximum precision of 18 and // always store decimals in fixed-length byte arrays. + // Always storing FIXED_LEN_BYTE_ARRAY is thus compatible with spark <= 1.4.x, except for + // precisions > 18. 
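// A standalone sketch (not part of the patch) of the decode strategy used by the
// CatalystRowConverter change above. Parquet hands the reader the unscaled decimal
// as a big-endian two's-complement byte array; at most 8 bytes can be folded into a
// Long and sign-extended, while wider values have to go through BigInteger. The
// helper name and the assertion below are illustrative only.
import java.math.BigInteger

def decodeUnscaled(bytes: Array[Byte]): BigInteger = {
  if (bytes.length <= 8) {
    // Accumulate the bytes into the low end of a Long.
    var unscaled = 0L
    var i = 0
    while (i < bytes.length) {
      unscaled = (unscaled << 8) | (bytes(i) & 0xff)
      i += 1
    }
    // Sign-extend: shift left so the value's sign bit becomes bit 63, then shift
    // back arithmetically so the sign propagates through the unused upper bytes.
    val bits = 8 * bytes.length
    unscaled = (unscaled << (64 - bits)) >> (64 - bits)
    BigInteger.valueOf(unscaled)
  } else {
    // More than 8 bytes cannot fit in a Long; BigInteger interprets the
    // two's-complement bytes directly, matching the new fallback branch.
    new BigInteger(bytes)
  }
}

// Example: 0xFF 0x38 is -200 in 16-bit two's complement.
assert(decodeUnscaled(Array(0xff.toByte, 0x38.toByte)).intValue == -200)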
case DecimalType.Fixed(precision, scale) if precision <= maxPrecisionForBytes(8) && !followParquetFormatSpec => Types @@ -411,7 +413,7 @@ private[parquet] class CatalystSchemaConverter( // Uses INT32 for 1 <= precision <= 9 case DecimalType.Fixed(precision, scale) - if precision <= maxPrecisionForBytes(4) && followParquetFormatSpec => + if followParquetFormatSpec && precision <= maxPrecisionForBytes(4) => Types .primitive(INT32, repetition) .as(DECIMAL) @@ -421,7 +423,7 @@ private[parquet] class CatalystSchemaConverter( // Uses INT64 for 1 <= precision <= 18 case DecimalType.Fixed(precision, scale) - if precision <= maxPrecisionForBytes(8) && followParquetFormatSpec => + if followParquetFormatSpec && precision <= maxPrecisionForBytes(8) => Types .primitive(INT64, repetition) .as(DECIMAL) @@ -553,14 +555,25 @@ private[parquet] class CatalystSchemaConverter( .asInstanceOf[Int] } - // Min byte counts needed to store decimals with various precisions - private val minBytesForPrecision: Array[Int] = Array.tabulate(38) { precision => + private def minBytesForPrecisionCompute(precision : Int) : Int = { var numBytes = 1 while (math.pow(2.0, 8 * numBytes - 1) < math.pow(10.0, precision)) { numBytes += 1 } numBytes } + + private val minBytesForPrecisionStatic: Array[Int] = Array.tabulate(39) { + minBytesForPrecisionCompute + } + + // Min byte counts needed to store decimals with various precisions + private def minBytesForPrecision(precision : Int) : Int = + if (precision < minBytesForPrecisionStatic.length) { + minBytesForPrecisionStatic(precision) + } else { + minBytesForPrecisionCompute(precision) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala index 9058b0937529..9dc37412cca3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala @@ -206,7 +206,13 @@ private[sql] case class ParquetTableScan( * @return Pruned TableScan. */ def pruneColumns(prunedAttributes: Seq[Attribute]): ParquetTableScan = { - val success = validateProjection(prunedAttributes) + val sc = sqlContext.sparkContext + val job = new Job(sc.hadoopConfiguration) + ParquetInputFormat.setReadSupportClass(job, classOf[RowReadSupport]) + + val conf: Configuration = ContextUtil.getConfiguration(job) + + val success = validateProjection(prunedAttributes, conf) if (success) { ParquetTableScan(prunedAttributes, relation, columnPruningPred) } else { @@ -221,9 +227,9 @@ private[sql] case class ParquetTableScan( * @param projection The candidate projection. * @return True if the projection is valid, false otherwise. 
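// A quick, self-contained check (not part of the patch) of the byte-width bound
// behind minBytesForPrecision above: an n-byte two's-complement value can hold any
// unscaled decimal of precision p as soon as 2^(8n - 1) >= 10^p. This version uses
// exact BigInt arithmetic instead of the double-based math.pow test in the code,
// purely to make the arithmetic obvious.
def minBytes(precision: Int): Int = {
  var n = 1
  while (BigInt(2).pow(8 * n - 1) < BigInt(10).pow(precision)) {
    n += 1
  }
  n
}

// precision 9  -> 4 bytes (hence INT32 in spec mode),
// precision 18 -> 8 bytes (hence INT64),
// precision 19 -> 9 bytes, precision 38 -> 16 bytes of FIXED_LEN_BYTE_ARRAY.
Seq(9, 18, 19, 38).foreach(p => println(s"precision $p needs ${minBytes(p)} byte(s)"))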
*/ - private def validateProjection(projection: Seq[Attribute]): Boolean = { + private def validateProjection(projection: Seq[Attribute], conf : Configuration): Boolean = { val original: MessageType = relation.parquetSchema - val candidate: MessageType = ParquetTypesConverter.convertFromAttributes(projection) + val candidate: MessageType = ParquetTypesConverter.convertFromAttributes(projection, conf) Try(original.checkContains(candidate)).isSuccess } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala index e8851ddb6802..582fc959ebc6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala @@ -35,6 +35,8 @@ import org.apache.spark.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -184,6 +186,9 @@ private[parquet] object RowReadSupport { */ private[parquet] class RowWriteSupport extends WriteSupport[InternalRow] with Logging { + var followParquetFormatSpec: Boolean = + SQLConf.PARQUET_FOLLOW_PARQUET_FORMAT_SPEC.defaultValue.get + private[parquet] var writer: RecordConsumer = null private[parquet] var attributes: Array[Attribute] = null @@ -198,7 +203,8 @@ private[parquet] class RowWriteSupport extends WriteSupport[InternalRow] with Lo log.debug(s"write support initialized for requested schema $attributes") ParquetRelation.enableLogForwarding() - new WriteSupport.WriteContext(ParquetTypesConverter.convertFromAttributes(attributes), metadata) + new WriteSupport.WriteContext( + ParquetTypesConverter.convertFromAttributes(attributes, configuration), metadata) } override def prepareForWrite(recordConsumer: RecordConsumer): Unit = { @@ -261,8 +267,9 @@ private[parquet] class RowWriteSupport extends WriteSupport[InternalRow] with Lo case BinaryType => writer.addBinary( Binary.fromByteArray(value.asInstanceOf[Array[Byte]])) case d: DecimalType => - if (d.precisionInfo == None || d.precisionInfo.get.precision > 18) { - sys.error(s"Unsupported datatype $d, cannot write to consumer") + if (d.precisionInfo == None) { + throw new AnalysisException( + s"Unsupported datatype $d, decimal precision and scale must be specified") } writeDecimal(value.asInstanceOf[Decimal], d.precisionInfo.get.precision) case _ => sys.error(s"Do not know how to writer $schema to consumer") @@ -346,19 +353,54 @@ private[parquet] class RowWriteSupport extends WriteSupport[InternalRow] with Lo } // Scratch array used to write decimals as fixed-length binary - private[this] val scratchBytes = new Array[Byte](8) + private[this] val scratchBytes = new Array[Byte](16) private[parquet] def writeDecimal(decimal: Decimal, precision: Int): Unit = { - val numBytes = ParquetTypesConverter.BYTES_FOR_PRECISION(precision) - val unscaledLong = decimal.toUnscaledLong - var i = 0 - var shift = 8 * (numBytes - 1) - while (i < numBytes) { - scratchBytes(i) = (unscaledLong >> shift).toByte - i += 1 - shift -= 8 + if (precision <= 18) { + if (followParquetFormatSpec) { + if (ParquetTypesConverter.bytesForPrecision(precision) <= 4) { + writer.addInteger(decimal.toUnscaledLong.toInt) + } else { + 
writer.addLong(decimal.toUnscaledLong) + } + } else { + val numBytes = ParquetTypesConverter.bytesForPrecision(precision) + val unscaledLong = decimal.toUnscaledLong + var i = 0 + var shift = 8 * (numBytes - 1) + while (i < numBytes) { + scratchBytes(i) = (unscaledLong >> shift).toByte + i += 1 + shift -= 8 + } + writer.addBinary(Binary.fromByteArray(scratchBytes, 0, numBytes)) + } + } else { + val numBytes = ParquetTypesConverter.bytesForPrecision(precision) + val bytes = decimal.toBigDecimal.underlying.unscaledValue.toByteArray() + val bin = + if (bytes.length == numBytes) { + Binary.fromByteArray(bytes) + } else { + if (numBytes <= scratchBytes.length) { + if (bytes(0) >= 0) { + java.util.Arrays.fill(scratchBytes, 0, numBytes - bytes.length, 0.toByte) + } else { + java.util.Arrays.fill(scratchBytes, 0, numBytes - bytes.length, -1.toByte) + } + System.arraycopy(bytes, 0, scratchBytes, numBytes - bytes.length, bytes.length) + Binary.fromByteArray(scratchBytes, 0, numBytes) + } else { + val buf = new Array[Byte](numBytes) + if (bytes(0) < 0) { + java.util.Arrays.fill(buf, 0, numBytes - bytes.length, -1.toByte) + } + System.arraycopy(bytes, 0, buf, numBytes - bytes.length, bytes.length) + Binary.fromByteArray(buf) + } + } + writer.addBinary(bin) } - writer.addBinary(Binary.fromByteArray(scratchBytes, 0, numBytes)) } // array used to write Timestamp as Int96 (fixed-length binary) @@ -415,7 +457,7 @@ private[parquet] class MutableRowWriteSupport extends RowWriteSupport { case BinaryType => writer.addBinary( Binary.fromByteArray(record(index).asInstanceOf[Array[Byte]])) case d: DecimalType => - if (d.precisionInfo == None || d.precisionInfo.get.precision > 18) { + if (d.precisionInfo == None) { sys.error(s"Unsupported datatype $d, cannot write to consumer") } writeDecimal(record(index).asInstanceOf[Decimal], d.precisionInfo.get.precision) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala index e748bd7857bd..760440e31cf4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala @@ -42,28 +42,40 @@ private[parquet] object ParquetTypesConverter extends Logging { case _ => false } - /** - * Compute the FIXED_LEN_BYTE_ARRAY length needed to represent a given DECIMAL precision. - */ - private[parquet] val BYTES_FOR_PRECISION = Array.tabulate[Int](38) { precision => - var length = 1 + private[this] def bytesForPrecisionCompute(precision : Int) : Int = { + var length = (precision / math.log10(2) - 1).toInt / 8 while (math.pow(2.0, 8 * length - 1) < math.pow(10.0, precision)) { length += 1 } length } + private[this] val bytesForPrecisionStatic = (0 to 39).map(bytesForPrecisionCompute).toArray + + /** + * bytesForPrecision computes the number of bytes required to store a value of a certain decimal + * precision. 
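// A standalone sketch (not part of the patch) of the write-side encodings used in
// writeDecimal above. With followParquetFormatSpec set and precision <= 18, the
// unscaled value is written as a plain INT32 (precision <= 9) or INT64; the real
// code calls writer.addInteger / writer.addLong on the Parquet RecordConsumer,
// whereas this helper just returns the value. Names here are illustrative.
import java.math.{BigDecimal, BigInteger}

def unscaledFor(value: BigDecimal, precision: Int): Either[Int, Long] = {
  // unscaledValue of 3.14 at scale 2 is 314; precision <= 18 guarantees a Long fits.
  val unscaled = value.unscaledValue.longValueExact
  if (precision <= 9) Left(unscaled.toInt)   // fits the 4-byte INT32 physical type
  else Right(unscaled)                       // precision 10..18 fits INT64
}

// For precision > 18 the value is written as FIXED_LEN_BYTE_ARRAY. BigInteger's
// toByteArray gives the minimal two's-complement encoding, so it has to be
// left-padded to the declared width while preserving the sign: 0x00 bytes for
// non-negative values, 0xFF bytes for negative ones, as in the branch above.
def toFixedWidth(unscaled: BigInteger, numBytes: Int): Array[Byte] = {
  val bytes = unscaled.toByteArray
  require(bytes.length <= numBytes, s"value needs ${bytes.length} bytes, only $numBytes available")
  val padByte: Byte = if (unscaled.signum < 0) -1 else 0
  val out = Array.fill[Byte](numBytes)(padByte)
  System.arraycopy(bytes, 0, out, numBytes - bytes.length, bytes.length)
  out
}

// Examples: DECIMAL(5, 2) value 3.14 becomes the INT32 value 314, and -200 padded
// to 4 bytes becomes 0xFF 0xFF 0xFF 0x38, which still decodes to -200.
assert(unscaledFor(new BigDecimal("3.14"), precision = 5) == Left(314))
assert(new BigInteger(toFixedWidth(BigInteger.valueOf(-200), 4)).intValue == -200)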
+ */ + private[parquet] def bytesForPrecision(precision : Int) : Int = + if (precision < bytesForPrecisionStatic.length) { + bytesForPrecisionStatic(precision) + } else { + bytesForPrecisionCompute(precision) + } + def convertToAttributes( parquetSchema: MessageType, isBinaryAsString: Boolean, - isInt96AsTimestamp: Boolean): Seq[Attribute] = { + isInt96AsTimestamp: Boolean, + followParquetFormatSpec : Boolean = false): Seq[Attribute] = { val converter = new CatalystSchemaConverter( - isBinaryAsString, isInt96AsTimestamp, followParquetFormatSpec = false) + isBinaryAsString, isInt96AsTimestamp, followParquetFormatSpec) converter.convert(parquetSchema).toAttributes } - def convertFromAttributes(attributes: Seq[Attribute]): MessageType = { - val converter = new CatalystSchemaConverter() + def convertFromAttributes(attributes: Seq[Attribute], + conf : Configuration): MessageType = { + val converter = new CatalystSchemaConverter(conf) converter.convert(StructType.fromAttributes(attributes)) } @@ -107,7 +119,7 @@ private[parquet] object ParquetTypesConverter extends Logging { ParquetTypesConverter.convertToString(attributes)) // TODO: add extra data, e.g., table name, date, etc.? - val parquetSchema: MessageType = ParquetTypesConverter.convertFromAttributes(attributes) + val parquetSchema: MessageType = ParquetTypesConverter.convertFromAttributes(attributes, conf) val metaData: FileMetaData = new FileMetaData( parquetSchema, extraMetadata, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala index 7b16eba00d6f..a43f18b2f158 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala @@ -107,29 +107,48 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { // Parquet doesn't allow column names with spaces, have to add an alias here .select($"_1" cast decimal as "dec") - for ((precision, scale) <- Seq((5, 2), (1, 0), (1, 1), (18, 10), (18, 17))) { - withTempPath { dir => - val data = makeDecimalRDD(DecimalType(precision, scale)) - data.write.parquet(dir.getCanonicalPath) - checkAnswer(sqlContext.read.parquet(dir.getCanonicalPath), data.collect().toSeq) + withSQLConf(SQLConf.PARQUET_FOLLOW_PARQUET_FORMAT_SPEC.key -> "false")({ + for ((precision, scale) <- Seq((5, 2), (1, 0), (1, 1), (18, 10), (18, 17))) { + withTempPath { dir => + val data = makeDecimalRDD(DecimalType(precision, scale)) + data.write.parquet(dir.getCanonicalPath) + checkAnswer(sqlContext.read.parquet(dir.getCanonicalPath), data.collect().toSeq) + } } - } - // Decimals with precision above 18 are not yet supported - intercept[Throwable] { - withTempPath { dir => - makeDecimalRDD(DecimalType(19, 10)).write.parquet(dir.getCanonicalPath) - sqlContext.read.parquet(dir.getCanonicalPath).collect() + // Decimals with precision above 18 are not supported in compatibility mode + intercept[Throwable] { + withTempPath { dir => + makeDecimalRDD(DecimalType(19, 10)).write.parquet(dir.getCanonicalPath) + sqlContext.read.parquet(dir.getCanonicalPath).collect() + } } - } - // Unlimited-length decimals are not yet supported - intercept[Throwable] { - withTempPath { dir => - makeDecimalRDD(DecimalType.Unlimited).write.parquet(dir.getCanonicalPath) - sqlContext.read.parquet(dir.getCanonicalPath).collect() + // Unlimited-length decimals are not yet supported + intercept[Throwable] { + withTempPath { dir => + 
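// A small sanity check (not part of the patch) of the precision-18 boundary the
// compatibility-mode tests above rely on: every unscaled value of a DECIMAL(18, s)
// is below 10^18 and therefore fits in a signed 64-bit Long (max ~9.22 * 10^18),
// which is why the compatibility (non-spec) write path can always go through
// decimal.toUnscaledLong. At precision 19 that guarantee is gone, hence the
// intercept[Throwable] for DecimalType(19, 10) when the spec flag is off.
assert(BigInt(10).pow(18) < BigInt(Long.MaxValue))   // 10^18 fits in a Long
assert(BigInt(10).pow(19) > BigInt(Long.MaxValue))   // 10^19 does not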
makeDecimalRDD(DecimalType.Unlimited).write.parquet(dir.getCanonicalPath) + sqlContext.read.parquet(dir.getCanonicalPath).collect() + } } - } + }) + withSQLConf(SQLConf.PARQUET_FOLLOW_PARQUET_FORMAT_SPEC.key -> "true")({ + for ((precision, scale) <- Seq((5, 2), (1, 0), (1, 1), (6, 3), (18, 17), (19, 0), (38, 37), + (90, 0))) { + withTempPath { dir => + val data = makeDecimalRDD(DecimalType(precision, scale)) + data.write.parquet(dir.getCanonicalPath) + checkAnswer(sqlContext.read.parquet(dir.getCanonicalPath), data.collect().toSeq) + } + } + + intercept[Throwable] { + withTempPath { dir => + makeDecimalRDD(DecimalType.Unlimited).write.parquet(dir.getCanonicalPath) + sqlContext.read.parquet(dir.getCanonicalPath).collect() + } + } + }) } test("date type") { @@ -302,7 +321,7 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { val metaData = ParquetTypesConverter.readMetaData(path, Some(configuration)) val actualSchema = metaData.getFileMetaData.getSchema - val expectedSchema = ParquetTypesConverter.convertFromAttributes(attributes) + val expectedSchema = ParquetTypesConverter.convertFromAttributes(attributes, configuration) actualSchema.checkContains(expectedSchema) expectedSchema.checkContains(actualSchema)