@@ -108,7 +108,9 @@ final class Decimal extends Ordered[Decimal] with Serializable {
*/
def set(decimal: BigDecimal, precision: Int, scale: Int): Decimal = {
this.decimalVal = decimal.setScale(scale, ROUNDING_MODE)
require(decimalVal.precision <= precision, "Overflowed precision")
require(
decimalVal.precision <= precision,
s"Decimal precision ${decimalVal.precision} exceeds max precision $precision")
this.longVal = 0L
this._precision = precision
this._scale = scale
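The message now reports both the actual and the maximum precision, which makes overflow failures much easier to diagnose. A hedged sketch of the behavior, assuming the `Decimal(BigDecimal, Int, Int)` factory that delegates to `set`:

import org.apache.spark.sql.types.Decimal

// Illustrative only: 123.45 needs precision 5, so precision 4 should fail with
// something like "requirement failed: Decimal precision 5 exceeds max precision 4".
try {
  Decimal(BigDecimal("123.45"), 4, 2)
} catch {
  case e: IllegalArgumentException => println(e.getMessage)
}
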
@@ -95,7 +95,9 @@ private[parquet] class CatalystReadSupport extends ReadSupport[InternalRow] with
""".stripMargin
}

new CatalystRecordMaterializer(parquetRequestedSchema, catalystRequestedSchema)
new CatalystRecordMaterializer(
parquetRequestedSchema,
CatalystReadSupport.expandUDT(catalystRequestedSchema))
Contributor Author: Expands UDTs early so that CatalystRowConverter always receives a Catalyst schema without UDTs.

}
}

@@ -110,7 +112,10 @@ private[parquet] object CatalystReadSupport {
*/
def clipParquetSchema(parquetSchema: MessageType, catalystSchema: StructType): MessageType = {
val clippedParquetFields = clipParquetGroupFields(parquetSchema.asGroupType(), catalystSchema)
Types.buildMessage().addFields(clippedParquetFields: _*).named("root")
Types
.buildMessage()
.addFields(clippedParquetFields: _*)
.named(CatalystSchemaConverter.SPARK_PARQUET_SCHEMA_NAME)
}
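Both the read-path clipping above and the write-path `convert` method now name the root message via `SPARK_PARQUET_SCHEMA_NAME` ("spark_schema") instead of a hard-coded "root". A hedged standalone sketch of the parquet-mr builder call, with a made-up example field:

import org.apache.parquet.schema.Types
import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.INT32
import org.apache.parquet.schema.Type.Repetition.OPTIONAL

// Illustrative only: one optional INT32 field under the new root name.
val message = Types.buildMessage()
  .addField(Types.primitive(INT32, OPTIONAL).named("id"))
  .named("spark_schema")

println(message)
// message spark_schema {
//   optional int32 id;
// }
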

private def clipParquetType(parquetType: Type, catalystType: DataType): Type = {
@@ -271,4 +276,30 @@ private[parquet] object CatalystReadSupport {
.getOrElse(toParquet.convertField(f))
}
}

def expandUDT(schema: StructType): StructType = {
def expand(dataType: DataType): DataType = {
dataType match {
case t: ArrayType =>
t.copy(elementType = expand(t.elementType))

case t: MapType =>
t.copy(
keyType = expand(t.keyType),
valueType = expand(t.valueType))

case t: StructType =>
val expandedFields = t.fields.map(f => f.copy(dataType = expand(f.dataType)))
t.copy(fields = expandedFields)

case t: UserDefinedType[_] =>
t.sqlType

case t =>
t
}
}

expand(schema).asInstanceOf[StructType]
}
Contributor Author: Not sure whether this method is useful enough to be added as a method on all complex data types.

Contributor: Maybe not.
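A hedged sketch of what the expansion does, using a hypothetical Point class and PointUDT that are not part of this PR: expandUDT walks arrays, maps, and structs and replaces every UserDefinedType with its sqlType.

import org.apache.spark.sql.types._

case class Point(x: Double, y: Double)

// Hypothetical UDT; only `sqlType` matters for schema expansion.
class PointUDT extends UserDefinedType[Point] {
  override def sqlType: DataType =
    StructType(Seq(StructField("x", DoubleType), StructField("y", DoubleType)))
  override def serialize(obj: Any): Any = ???        // irrelevant to this sketch
  override def deserialize(datum: Any): Point = ???  // irrelevant to this sketch
  override def userClass: Class[Point] = classOf[Point]
}

val schema = StructType(Seq(
  StructField("id", IntegerType),
  StructField("location", new PointUDT)))

// Within the parquet package, CatalystReadSupport.expandUDT(schema) is expected
// to be equivalent to: struct<id: int, location: struct<x: double, y: double>>
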

}
@@ -114,7 +114,8 @@ private[parquet] class CatalystPrimitiveConverter(val updater: ParentContainerUp
* any "parent" container.
*
* @param parquetType Parquet schema of Parquet records
* @param catalystType Spark SQL schema that corresponds to the Parquet record type
* @param catalystType Spark SQL schema that corresponds to the Parquet record type. User-defined
* types should have been expanded.
* @param updater An updater which propagates converted field values to the parent container
*/
private[parquet] class CatalystRowConverter(
@@ -133,6 +134,12 @@ private[parquet] class CatalystRowConverter(
|${catalystType.prettyJson}
""".stripMargin)

assert(
!catalystType.existsRecursively(_.isInstanceOf[UserDefinedType[_]]),
s"""User-defined types in Catalyst schema should have already been expanded:
|${catalystType.prettyJson}
""".stripMargin)

logDebug(
s"""Building row converter for the following schema:
|
@@ -268,13 +275,6 @@ private[parquet] class CatalystRowConverter(
override def set(value: Any): Unit = updater.set(value.asInstanceOf[InternalRow].copy())
})

case t: UserDefinedType[_] =>
val catalystTypeForUDT = t.sqlType
val nullable = parquetType.isRepetition(Repetition.OPTIONAL)
val field = StructField("udt", catalystTypeForUDT, nullable)
val parquetTypeForUDT = new CatalystSchemaConverter().convertField(field)
newConverter(parquetTypeForUDT, catalystTypeForUDT, updater)

case _ =>
throw new RuntimeException(
s"Unable to create Parquet converter for data type ${catalystType.json}")
@@ -340,30 +340,36 @@ private[parquet] class CatalystRowConverter(
val scale = decimalType.scale

if (precision <= CatalystSchemaConverter.MAX_PRECISION_FOR_INT64) {
// Constructs a `Decimal` with an unscaled `Long` value if possible. The underlying
// `ByteBuffer` implementation is guaranteed to be `HeapByteBuffer`, so here we are using
// `Binary.toByteBuffer.array()` to steal the underlying byte array without copying it.
val buffer = value.toByteBuffer
val bytes = buffer.array()
val start = buffer.position()
val end = buffer.limit()

var unscaled = 0L
var i = start

while (i < end) {
unscaled = (unscaled << 8) | (bytes(i) & 0xff)
i += 1
}

val bits = 8 * (end - start)
unscaled = (unscaled << (64 - bits)) >> (64 - bits)
// Constructs a `Decimal` with an unscaled `Long` value if possible.
val unscaled = binaryToUnscaledLong(value)
Decimal(unscaled, precision, scale)
} else {
// Otherwise, resorts to an unscaled `BigInteger` instead.
Decimal(new BigDecimal(new BigInteger(value.getBytes), scale), precision, scale)
}
}

private def binaryToUnscaledLong(binary: Binary): Long = {
// The underlying `ByteBuffer` implementation is guaranteed to be `HeapByteBuffer`, so here
// we are using `Binary.toByteBuffer.array()` to steal the underlying byte array without
// copying it.
val buffer = binary.toByteBuffer
val bytes = buffer.array()
val start = buffer.position()
val end = buffer.limit()

var unscaled = 0L
var i = start

while (i < end) {
unscaled = (unscaled << 8) | (bytes(i) & 0xff)
i += 1
}

val bits = 8 * (end - start)
unscaled = (unscaled << (64 - bits)) >> (64 - bits)
unscaled
}
}
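The extracted helper reads the unscaled value big-endian and then sign-extends it by shifting it up to bit 63 and back. A standalone sketch of that trick on a hypothetical two-byte value:

// Illustrative only: bytes 0xFF 0x85 are the 16-bit two's-complement encoding
// of -123. Accumulate big-endian, then sign-extend via shift left/right.
val bytes = Array(0xff.toByte, 0x85.toByte)

var unscaled = 0L
for (b <- bytes) {
  unscaled = (unscaled << 8) | (b & 0xff)
}
// unscaled == 0xFF85 == 65413 before sign extension

val bits = 8 * bytes.length
unscaled = (unscaled << (64 - bits)) >> (64 - bits)
assert(unscaled == -123L)
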

/**
@@ -47,7 +47,7 @@ import org.apache.spark.sql.{AnalysisException, SQLConf}
* [[StructType]]. Note that Spark SQL [[TimestampType]] is similar to Hive timestamp, which
has optional nanosecond precision, but different from `TIME_MILLIS` and `TIMESTAMP_MILLIS`
* described in Parquet format spec. This argument only affects Parquet read path.
* @param writeLegacyParquetFormat Whether to use legacy Parquet format compatible with Spark 1.4
* @param writeLegacyParquetFormat Whether to use legacy Parquet format compatible with Spark 1.5
* and prior versions when converting a Catalyst [[StructType]] to a Parquet [[MessageType]].
* When set to false, use standard format defined in parquet-format spec. This argument only
* affects Parquet write path.
@@ -121,7 +121,7 @@ private[parquet] class CatalystSchemaConverter(
val precision = field.getDecimalMetadata.getPrecision
val scale = field.getDecimalMetadata.getScale

CatalystSchemaConverter.analysisRequire(
CatalystSchemaConverter.checkConversionRequirement(
maxPrecision == -1 || 1 <= precision && precision <= maxPrecision,
s"Invalid decimal precision: $typeName cannot store $precision digits (max $maxPrecision)")

Expand Down Expand Up @@ -155,7 +155,7 @@ private[parquet] class CatalystSchemaConverter(
}

case INT96 =>
CatalystSchemaConverter.analysisRequire(
CatalystSchemaConverter.checkConversionRequirement(
assumeInt96IsTimestamp,
"INT96 is not supported unless it's interpreted as timestamp. " +
s"Please try to set ${SQLConf.PARQUET_INT96_AS_TIMESTAMP.key} to true.")
@@ -197,11 +197,11 @@ private[parquet] class CatalystSchemaConverter(
//
// See: https://github.com/apache/parquet-format/blob/master/LogicalTypes.md#lists
case LIST =>
CatalystSchemaConverter.analysisRequire(
CatalystSchemaConverter.checkConversionRequirement(
field.getFieldCount == 1, s"Invalid list type $field")

val repeatedType = field.getType(0)
CatalystSchemaConverter.analysisRequire(
CatalystSchemaConverter.checkConversionRequirement(
repeatedType.isRepetition(REPEATED), s"Invalid list type $field")

if (isElementType(repeatedType, field.getName)) {
@@ -217,17 +217,17 @@ private[parquet] class CatalystSchemaConverter(
// See: https://github.com/apache/parquet-format/blob/master/LogicalTypes.md#backward-compatibility-rules-1
// scalastyle:on
case MAP | MAP_KEY_VALUE =>
CatalystSchemaConverter.analysisRequire(
CatalystSchemaConverter.checkConversionRequirement(
field.getFieldCount == 1 && !field.getType(0).isPrimitive,
s"Invalid map type: $field")

val keyValueType = field.getType(0).asGroupType()
CatalystSchemaConverter.analysisRequire(
CatalystSchemaConverter.checkConversionRequirement(
keyValueType.isRepetition(REPEATED) && keyValueType.getFieldCount == 2,
s"Invalid map type: $field")

val keyType = keyValueType.getType(0)
CatalystSchemaConverter.analysisRequire(
CatalystSchemaConverter.checkConversionRequirement(
keyType.isPrimitive,
s"Map key type is expected to be a primitive type, but found: $keyType")

@@ -299,7 +299,10 @@ private[parquet] class CatalystSchemaConverter(
* Converts a Spark SQL [[StructType]] to a Parquet [[MessageType]].
*/
def convert(catalystSchema: StructType): MessageType = {
Types.buildMessage().addFields(catalystSchema.map(convertField): _*).named("root")
Types
.buildMessage()
.addFields(catalystSchema.map(convertField): _*)
.named(CatalystSchemaConverter.SPARK_PARQUET_SCHEMA_NAME)
}

/**
@@ -347,10 +350,10 @@ private[parquet] class CatalystSchemaConverter(
// NOTE: Spark SQL TimestampType is NOT a well defined type in Parquet format spec.
//
// As stated in PARQUET-323, Parquet `INT96` was originally introduced to represent nanosecond
// timestamp in Impala for some historical reasons, it's not recommended to be used for any
// other types and will probably be deprecated in future Parquet format spec. That's the
// reason why Parquet format spec only defines `TIMESTAMP_MILLIS` and `TIMESTAMP_MICROS` which
// are both logical types annotating `INT64`.
// timestamp in Impala for some historical reasons. It's not recommended to be used for any
// other types and will probably be deprecated in some future version of parquet-format spec.
// That's the reason why parquet-format spec only defines `TIMESTAMP_MILLIS` and
// `TIMESTAMP_MICROS` which are both logical types annotating `INT64`.
//
// Originally, Spark SQL uses the same nanosecond timestamp type as Impala and Hive. Starting
// from Spark 1.5.0, we resort to a timestamp type with 100 ns precision so that we can store
@@ -361,7 +364,7 @@ private[parquet] class CatalystSchemaConverter(
// currently not implemented yet because parquet-mr 1.7.0 (the version we're currently using)
// hasn't implemented `TIMESTAMP_MICROS` yet.
//
// TODO Implements `TIMESTAMP_MICROS` once parquet-mr has that.
// TODO Converts `TIMESTAMP_MICROS` once parquet-mr implements that.
case TimestampType =>
Types.primitive(INT96, repetition).named(field.name)

@@ -372,7 +375,7 @@ private[parquet] class CatalystSchemaConverter(
// Decimals (legacy mode)
// ======================

// Spark 1.4.x and prior versions only support decimals with a maximum precision of 18 and
// Spark 1.5.x and prior versions only support decimals with a maximum precision of 18 and
// always store decimals in fixed-length byte arrays. To keep compatibility with these older
// versions, here we convert decimals with all precisions to `FIXED_LEN_BYTE_ARRAY` annotated
// by `DECIMAL`.
@@ -423,7 +426,7 @@ private[parquet] class CatalystSchemaConverter(
// ArrayType and MapType (legacy mode)
// ===================================

// Spark 1.4.x and prior versions convert `ArrayType` with nullable elements into a 3-level
// Spark 1.5.x and prior versions convert `ArrayType` with nullable elements into a 3-level
// `LIST` structure. This behavior is somewhat a hybrid of parquet-hive and parquet-avro
// (1.6.0rc3): the 3-level structure is similar to parquet-hive while the 3rd level element
// field name "array" is borrowed from parquet-avro.
@@ -442,7 +445,7 @@ private[parquet] class CatalystSchemaConverter(
.addField(convertField(StructField("array", elementType, nullable)))
.named("bag"))

// Spark 1.4.x and prior versions convert ArrayType with non-nullable elements into a 2-level
// Spark 1.5.x and prior versions convert ArrayType with non-nullable elements into a 2-level
// LIST structure. This behavior mimics parquet-avro (1.6.0rc3). Note that this case is
// covered by the backwards-compatibility rules implemented in `isElementType()`.
case ArrayType(elementType, nullable @ false) if writeLegacyParquetFormat =>
@@ -455,7 +458,7 @@ private[parquet] class CatalystSchemaConverter(
// "array" is the name chosen by parquet-avro (1.7.0 and prior version)
convertField(StructField("array", elementType, nullable), REPEATED))

// Spark 1.4.x and prior versions convert MapType into a 3-level group annotated by
// Spark 1.5.x and prior versions convert MapType into a 3-level group annotated by
// MAP_KEY_VALUE. This is covered by `convertGroupField(field: GroupType): DataType`.
case MapType(keyType, valueType, valueContainsNull) if writeLegacyParquetFormat =>
// <map-repetition> group <name> (MAP) {
@@ -523,11 +526,12 @@ private[parquet] class CatalystSchemaConverter(
}
}


private[parquet] object CatalystSchemaConverter {
val SPARK_PARQUET_SCHEMA_NAME = "spark_schema"

def checkFieldName(name: String): Unit = {
// ,;{}()\n\t= and space are special characters in Parquet schema
analysisRequire(
checkConversionRequirement(
!name.matches(".*[ ,;{}()\n\t=].*"),
s"""Attribute name "$name" contains invalid character(s) among " ,;{}()\\n\\t=".
|Please use alias to rename it.
@@ -539,7 +543,7 @@ private[parquet] object CatalystSchemaConverter {
schema
}

def analysisRequire(f: => Boolean, message: String): Unit = {
def checkConversionRequirement(f: => Boolean, message: String): Unit = {
if (!f) {
throw new AnalysisException(message)
}
@@ -553,16 +557,8 @@ private[parquet] object CatalystSchemaConverter {
numBytes
}

private val MIN_BYTES_FOR_PRECISION = Array.tabulate[Int](39)(computeMinBytesForPrecision)

// Returns the minimum number of bytes needed to store a decimal with a given `precision`.
def minBytesForPrecision(precision : Int) : Int = {
if (precision < MIN_BYTES_FOR_PRECISION.length) {
MIN_BYTES_FOR_PRECISION(precision)
} else {
computeMinBytesForPrecision(precision)
}
}
val minBytesForPrecision = Array.tabulate[Int](39)(computeMinBytesForPrecision)
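For reference, a hedged re-derivation of the quantity being tabulated (assumed semantics, using BigInt rather than the original arithmetic): the smallest signed byte width whose range can hold any unscaled value with the given number of decimal digits. The helper name below is made up for illustration.

// Hypothetical sketch, not the PR's code.
def minBytes(precision: Int): Int = {
  var numBytes = 1
  // Smallest numBytes such that a signed numBytes-byte integer can hold 10^precision - 1.
  while (BigInt(2).pow(8 * numBytes - 1) - 1 < BigInt(10).pow(precision) - 1) {
    numBytes += 1
  }
  numBytes
}

assert(minBytes(9) == 4)   // consistent with MAX_PRECISION_FOR_INT32 = 9
assert(minBytes(18) == 8)  // consistent with MAX_PRECISION_FOR_INT64 = 18
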

val MAX_PRECISION_FOR_INT32 = maxPrecisionForBytes(4) /* 9 */
