
Commit b8a8b9a

More fixes to short and byte conversion
1 parent 63d1b57 commit b8a8b9a

6 files changed (+100, -41 lines)


sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala

Lines changed: 2 additions & 1 deletion
@@ -154,7 +154,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
       case logical.WriteToFile(path, child) =>
         val relation =
           ParquetRelation.create(path, child, sparkContext.hadoopConfiguration)
-        InsertIntoParquetTable(relation, planLater(child), overwrite=true)(sparkContext) :: Nil
+        // Note: overwrite=false because otherwise the metadata we just created will be deleted
+        InsertIntoParquetTable(relation, planLater(child), overwrite=false)(sparkContext) :: Nil
       case logical.InsertIntoTable(table: ParquetRelation, partition, child, overwrite) =>
         InsertIntoParquetTable(table, planLater(child), overwrite)(sparkContext) :: Nil
       case PhysicalOperation(projectList, filters: Seq[Expression], relation: ParquetRelation) =>
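
The new comment above captures an ordering constraint: the planner first materializes the relation, which writes the Parquet metadata under `path`, and an overwriting insert into the same directory would delete that metadata again. A hedged, generic illustration of that hazard using plain Hadoop filesystem calls and a hypothetical path (not this commit's code):

```scala
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path}

// Step 1 mirrors, conceptually, what ParquetRelation.create does: metadata is written first.
// Step 2 shows what overwrite semantics on the same directory would imply afterwards.
val conf = new Configuration()
val fs   = FileSystem.get(conf)
val dir  = new Path("/tmp/parquet-table")        // hypothetical table directory
fs.mkdirs(dir)
fs.create(new Path(dir, "_metadata")).close()    // step 1: metadata file now exists
fs.delete(dir, true)                             // step 2: an overwrite would wipe it again
```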

sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala

Lines changed: 46 additions & 21 deletions
@@ -93,14 +93,30 @@ private[sql] object CatalystConverter {
         fieldIndex,
         parent)
     }
-    case ctype: NativeType => {
-      // note: for some reason matching for StringType fails so use this ugly if instead
-      if (ctype == StringType) {
-        new CatalystPrimitiveStringConverter(parent, fieldIndex)
-      } else {
-        new CatalystPrimitiveConverter(parent, fieldIndex)
+    // Strings, Shorts and Bytes do not have a corresponding type in Parquet
+    // so we need to treat them separately
+    case StringType => {
+      new CatalystPrimitiveConverter(parent, fieldIndex) {
+        override def addBinary(value: Binary): Unit =
+          parent.updateString(fieldIndex, value)
       }
     }
+    case ShortType => {
+      new CatalystPrimitiveConverter(parent, fieldIndex) {
+        override def addInt(value: Int): Unit =
+          parent.updateShort(fieldIndex, value.asInstanceOf[ShortType.JvmType])
+      }
+    }
+    case ByteType => {
+      new CatalystPrimitiveConverter(parent, fieldIndex) {
+        override def addInt(value: Int): Unit =
+          parent.updateByte(fieldIndex, value.asInstanceOf[ByteType.JvmType])
+      }
+    }
+    // All other primitive types use the default converter
+    case ctype: NativeType => { // note: need the type tag here!
+      new CatalystPrimitiveConverter(parent, fieldIndex)
+    }
     case _ => throw new RuntimeException(
       s"unable to convert datatype ${field.dataType.toString} in CatalystConverter")
   }
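
A note on the converters added above: Parquet's primitive types include no 16-bit or 8-bit integers, so SHORT and BYTE columns arrive through `addInt` as INT32 values and must be narrowed back to the Catalyst JVM type. A minimal standalone sketch of that idea against Parquet's raw `PrimitiveConverter` (hypothetical class and callback, not part of this change):

```scala
import parquet.io.api.PrimitiveConverter

// `update` stands in for parent.updateShort(fieldIndex, _) in the converter above.
class NarrowingShortConverter(update: Short => Unit) extends PrimitiveConverter {
  override def addInt(value: Int): Unit = update(value.toShort)  // INT32 -> Short
}

object NarrowingDemo extends App {
  var captured: Short = 0
  new NarrowingShortConverter(v => captured = v).addInt(42)  // simulate an INT32 cell
  println(captured)  // prints 42, now held as a Short
}
```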
@@ -153,6 +169,12 @@ private[parquet] trait CatalystConverter {
   protected[parquet] def updateLong(fieldIndex: Int, value: Long): Unit =
     updateField(fieldIndex, value)
 
+  protected[parquet] def updateShort(fieldIndex: Int, value: Short): Unit =
+    updateField(fieldIndex, value)
+
+  protected[parquet] def updateByte(fieldIndex: Int, value: Byte): Unit =
+    updateField(fieldIndex, value)
+
   protected[parquet] def updateDouble(fieldIndex: Int, value: Double): Unit =
     updateField(fieldIndex, value)
 
@@ -309,6 +331,12 @@ private[parquet] class CatalystPrimitiveRowConverter(
   override protected[parquet] def updateLong(fieldIndex: Int, value: Long): Unit =
     current.setLong(fieldIndex, value)
 
+  override protected[parquet] def updateShort(fieldIndex: Int, value: Short): Unit =
+    current.setShort(fieldIndex, value)
+
+  override protected[parquet] def updateByte(fieldIndex: Int, value: Byte): Unit =
+    current.setByte(fieldIndex, value)
+
   override protected[parquet] def updateDouble(fieldIndex: Int, value: Double): Unit =
     current.setDouble(fieldIndex, value)
 
@@ -350,21 +378,6 @@ private[parquet] class CatalystPrimitiveConverter(
     parent.updateLong(fieldIndex, value)
 }
 
-/**
- * A `parquet.io.api.PrimitiveConverter` that converts Parquet strings (fixed-length byte arrays)
- * into Catalyst Strings.
- *
- * @param parent The parent group converter.
- * @param fieldIndex The index inside the record.
- */
-private[parquet] class CatalystPrimitiveStringConverter(
-    parent: CatalystConverter,
-    fieldIndex: Int)
-  extends CatalystPrimitiveConverter(parent, fieldIndex) {
-  override def addBinary(value: Binary): Unit =
-    parent.updateString(fieldIndex, value)
-}
-
 object CatalystArrayConverter {
   val INITIAL_ARRAY_SIZE = 20
 }
@@ -486,6 +499,18 @@ private[parquet] class CatalystNativeArrayConverter(
     elements += 1
   }
 
+  override protected[parquet] def updateShort(fieldIndex: Int, value: Short): Unit = {
+    checkGrowBuffer()
+    buffer(elements) = value.asInstanceOf[NativeType]
+    elements += 1
+  }
+
+  override protected[parquet] def updateByte(fieldIndex: Int, value: Byte): Unit = {
+    checkGrowBuffer()
+    buffer(elements) = value.asInstanceOf[NativeType]
+    elements += 1
+  }
+
   override protected[parquet] def updateLong(fieldIndex: Int, value: Long): Unit = {
     checkGrowBuffer()
     buffer(elements) = value.asInstanceOf[NativeType]

sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala

Lines changed: 1 addition & 5 deletions
@@ -55,11 +55,7 @@ private[sql] case class ParquetRelation(val path: String)
     .getSchema
 
   /** Attributes */
-  // TODO: THIS POTENTIALLY LOOSES TYPE INFORMATION!!!!
-  // e.g. short <-> INT32 and byte <-> INT32
-  override val output =
-    ParquetTypesConverter
-      .convertToAttributes(parquetSchema)
+  override val output = ParquetTypesConverter.readSchemaFromFile(new Path(path))
 
   override def newInstance = ParquetRelation(path).asInstanceOf[this.type]
 
sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala

Lines changed: 4 additions & 1 deletion
@@ -65,10 +65,13 @@ case class ParquetTableScan(
       NewFileInputFormat.addInputPath(job, path)
     }
 
-    // Store Parquet schema in `Configuration`
+    // Store both requested and original schema in `Configuration`
     conf.set(
       RowReadSupport.SPARK_ROW_REQUESTED_SCHEMA,
       ParquetTypesConverter.convertToString(output))
+    conf.set(
+      RowWriteSupport.SPARK_ROW_SCHEMA,
+      ParquetTypesConverter.convertToString(relation.output))
 
     // Store record filtering predicate in `Configuration`
     // Note 1: the input format ignores all predicates that cannot be expressed
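
For orientation, a hedged sketch of the read side implied by the extra `conf.set` above: the two schema strings can be pulled back out of the job `Configuration` and turned into Catalyst attributes with `convertFromString`. The keys and the converter call come from this commit's code; the call site itself is my own example (assumed to live alongside the other classes in `org.apache.spark.sql.parquet`), not code from this commit:

```scala
import org.apache.hadoop.conf.Configuration

def schemasFrom(conf: Configuration) = {
  // Pruned projection requested by the scan
  val requested = ParquetTypesConverter.convertFromString(
    conf.get(RowReadSupport.SPARK_ROW_REQUESTED_SCHEMA))
  // Full schema of the underlying relation
  val full = ParquetTypesConverter.convertFromString(
    conf.get(RowWriteSupport.SPARK_ROW_SCHEMA))
  (requested, full)
}
```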

sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala

Lines changed: 31 additions & 10 deletions
@@ -26,20 +26,19 @@ import org.apache.hadoop.mapreduce.Job
 import parquet.hadoop.{ParquetFileReader, Footer, ParquetFileWriter}
 import parquet.hadoop.metadata.{ParquetMetadata, FileMetaData}
 import parquet.hadoop.util.ContextUtil
-import parquet.schema.{Type => ParquetType, PrimitiveType => ParquetPrimitiveType, MessageType, MessageTypeParser}
+import parquet.schema.{Type => ParquetType, PrimitiveType => ParquetPrimitiveType, MessageType}
 import parquet.schema.{GroupType => ParquetGroupType, OriginalType => ParquetOriginalType, ConversionPatterns}
 import parquet.schema.PrimitiveType.{PrimitiveTypeName => ParquetPrimitiveTypeName}
 import parquet.schema.Type.Repetition
 
+import org.apache.spark.Logging
 import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Attribute}
 import org.apache.spark.sql.catalyst.types._
-import com.google.common.io.BaseEncoding
-import org.apache.spark.sql.execution.SparkSqlSerializer
 
 // Implicits
 import scala.collection.JavaConversions._
 
-private[parquet] object ParquetTypesConverter {
+private[parquet] object ParquetTypesConverter extends Logging {
   def isPrimitiveType(ctype: DataType): Boolean =
     classOf[PrimitiveType] isAssignableFrom ctype.getClass
 
@@ -62,7 +61,7 @@ private[parquet] object ParquetTypesConverter {
    * Converts a given Parquet `Type` into the corresponding
    * [[org.apache.spark.sql.catalyst.types.DataType]].
    *
-   * Note that we apply the following conversion rules:
+   * We apply the following conversion rules:
    * <ul>
    * <li> Primitive types are converter to the corresponding primitive type.</li>
    * <li> Group types that have a single field that is itself a group, which has repetition
@@ -97,6 +96,7 @@ private[parquet] object ParquetTypesConverter {
       keyValueGroup.getFields.apply(1).getName == CatalystConverter.MAP_VALUE_SCHEMA_NAME
     }
   }
+
   def correspondsToArray(groupType: ParquetGroupType): Boolean = {
     groupType.getFieldCount == 1 &&
       groupType.getFieldName(0) == CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME &&
@@ -188,7 +188,7 @@ private[parquet] object ParquetTypesConverter {
    * <li> Primitive types are converted into Parquet's primitive types.</li>
    * <li> [[org.apache.spark.sql.catalyst.types.StructType]]s are converted
    *      into Parquet's `GroupType` with the corresponding field types.</li>
-   * <li> [[org.apache.spark.sql.catalyst.types.ArrayType]]s are converterd
+   * <li> [[org.apache.spark.sql.catalyst.types.ArrayType]]s are converted
    *      into a 2-level nested group, where the outer group has the inner
    *      group as sole field. The inner group has name `values` and
    *      repetition level `REPEATED` and has the element type of
@@ -269,9 +269,6 @@ private[parquet] object ParquetTypesConverter {
     }
   }
 
-  def getSchema(schemaString: String) : MessageType =
-    MessageTypeParser.parseMessageType(schemaString)
-
   def convertToAttributes(parquetSchema: ParquetType): Seq[Attribute] = {
     parquetSchema
       .asGroupType()
@@ -302,7 +299,7 @@ private[parquet] object ParquetTypesConverter {
     StructType.fromAttributes(schema).toString
   }
 
-  def writeMetaData(attributes: Seq[Attribute], origPath: Path, conf: Configuration) {
+  def writeMetaData(attributes: Seq[Attribute], origPath: Path, conf: Configuration): Unit = {
     if (origPath == null) {
       throw new IllegalArgumentException("Unable to write Parquet metadata: path is null")
     }
@@ -385,4 +382,28 @@ private[parquet] object ParquetTypesConverter {
       footers(0).getParquetMetadata
     }
   }
+
+  /**
+   * Reads in Parquet Metadata from the given path and tries to extract the schema
+   * (Catalyst attributes) from the application-specific key-value map. If this
+   * is empty it falls back to converting from the Parquet file schema which
+   * may lead to an upcast of types (e.g., {byte, short} to int).
+   *
+   * @param origPath The path at which we expect one (or more) Parquet files.
+   * @return A list of attributes that make up the schema.
+   */
+  def readSchemaFromFile(origPath: Path): Seq[Attribute] = {
+    val keyValueMetadata: java.util.Map[String, String] =
+      readMetaData(origPath)
+        .getFileMetaData
+        .getKeyValueMetaData
+    if (keyValueMetadata.get(RowReadSupport.SPARK_METADATA_KEY) != null) {
+      convertFromString(keyValueMetadata.get(RowReadSupport.SPARK_METADATA_KEY))
+    } else {
+      val attributes = convertToAttributes(
+        readMetaData(origPath).getFileMetaData.getSchema)
+      log.warn(s"Falling back to schema conversion from Parquet types; result: $attributes")
+      attributes
+    }
+  }
 }
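
A short usage sketch for the new helper (my example with a hypothetical path, not commit code), mirroring how `ParquetRelation` now resolves its output: when the Spark-written key-value metadata is present, byte and short columns keep their Catalyst types instead of coming back widened to `IntegerType`.

```scala
import org.apache.hadoop.fs.Path

val attributes = ParquetTypesConverter.readSchemaFromFile(new Path("/tmp/people.parquet"))
attributes.foreach(a => println(s"${a.name}: ${a.dataType}"))
```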

sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala

Lines changed: 16 additions & 3 deletions
@@ -111,10 +111,23 @@ class ParquetQuerySuite extends QueryTest with FunSuite with BeforeAndAfterAll {
   }
 
   test("Read/Write All Types") {
-    val data = AllDataTypes("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true)
     val tempDir = getTempFilePath("parquetTest").getCanonicalPath
-    sparkContext.parallelize(data :: Nil).saveAsParquetFile(tempDir)
-    assert(parquetFile(tempDir).collect().head === data)
+    val range = (0 to 255)
+    TestSQLContext.sparkContext.parallelize(range)
+      .map(x => AllDataTypes(s"$x", x, x.toLong, x.toFloat, x.toDouble, x.toShort, x.toByte, x % 2 == 0))
+      .saveAsParquetFile(tempDir)
+    val result = parquetFile(tempDir).collect()
+    range.foreach {
+      i =>
+        assert(result(i).getString(0) == s"$i", s"row $i String field did not match, got ${result(i).getString(0)}")
+        assert(result(i).getInt(1) === i)
+        assert(result(i).getLong(2) === i.toLong)
+        assert(result(i).getFloat(3) === i.toFloat)
+        assert(result(i).getDouble(4) === i.toDouble)
+        assert(result(i).getShort(5) === i.toShort)
+        assert(result(i).getByte(6) === i.toByte)
+        assert(result(i).getBoolean(7) === (i % 2 == 0))
+    }
   }
 
   test("self-join parquet files") {