Creating converters for ScalaReflection stuff, and more
vlyubin committed Apr 6, 2015
commit 41b2aa944520c45e0eea9ea2ee70ecb3c8d7055a
@@ -80,6 +80,90 @@ trait ScalaReflection {
    case (other, _) => other
  }

  /**
   * Creates a converter function that will convert Scala objects to the specified catalyst type.
   */
  private[sql] def createCatalystConverter(dataType: DataType): (Any) => Any = {
    def extractOption(item: Any) = item match {
      case o: Some[_] => o.get
      case other => other
    }

    dataType match {
      // Check UDT first since UDTs can override other types
      case udt: UserDefinedType[_] => (item) => {
Contributor: Please break the line after the first `=>`.

        if (item == None) null else udt.serialize(extractOption(item))
      }

      case arrayType: ArrayType => {
Contributor: Generally, please don't add braces in a case clause, namely, prefer

    foo match {
      case ... =>
        expr1
        expr2
    }

over

    foo match {
      case ... => {
        expr1
        expr2
      }
    }

This rule applies to the whole PR.

Author: Ok

        val elementConverter = createCatalystConverter(arrayType.elementType)
        (item: Any) => {
          if (item == None) {
            null
          } else {
            extractOption(item) match {
              case a: Array[_] => a.toSeq.map(elementConverter)
              case s: Seq[_] => s.map(elementConverter)
            }
          }
        }
      }
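Illustrative sketch (not part of the diff): one way this `ArrayType` case could look after applying the reviewer's no-braces rule, with the case body simply indented under the `=>`; all names are taken from the code above.

```scala
case arrayType: ArrayType =>
  val elementConverter = createCatalystConverter(arrayType.elementType)
  (item: Any) => {
    if (item == None) {
      null
    } else {
      extractOption(item) match {
        case a: Array[_] => a.toSeq.map(elementConverter)
        case s: Seq[_] => s.map(elementConverter)
      }
    }
  }
```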

      case mapType: MapType => {
        val keyConverter = createCatalystConverter(mapType.keyType)
        val valueConverter = createCatalystConverter(mapType.valueType)
        (item: Any) => {
          if (item == None) {
            null
          } else {
            extractOption(item) match {
              case m: Map[_, _] => m.map{ case (k, v) =>
Contributor: Space before `{`.

                keyConverter(k) -> valueConverter(v) }
Contributor: Put the `}` on a separate line.

              case other => other
            }
          }
        }
      }

      case structType: StructType => {
        val converters = new Array[(Any) => Any](structType.length)
Contributor: I think `Any => Any` should be OK.

Author: It is. BTW, both `(Any) => Any` and `Any => Any` are used in the codebase. I guess one can make a separate PR to replace `(Any) => Any` with `Any => Any`.

        val iter = structType.fields.iterator
        var idx = 0
        while (iter.hasNext) {
          converters(idx) = createCatalystConverter(iter.next().dataType)
          idx += 1
        }
        (item: Any) => {
          if (item == None) {
            null
          } else {
            extractOption(item) match {
              case p: Product => {
                val ar = new Array[Any](structType.size)
                val iter = p.productIterator
                var idx = 0
                while (idx < structType.size) {
                  ar(idx) = converters(idx)(iter.next())
                  idx += 1
                }
                new GenericRowWithSchema(ar, structType)
              }
              case other => other
            }
          }
        }
      }

      case _ => (item: Any) => extractOption(item) match {
Contributor: Break the line after the first `=>`.

        case None => null
        case d: BigDecimal => Decimal(d)
        case d: java.math.BigDecimal => Decimal(d)
        case d: java.sql.Date => DateUtils.fromJavaDate(d)
        case other => other
      }
    }
  }
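Illustrative sketch (not part of the diff): the intended usage is to build the converter once per schema and then apply it per element. The example below is hypothetical — it assumes it runs inside the `org.apache.spark.sql` package (the method is `private[sql]`) and uses the Spark-1.3-era imports that appear elsewhere in this PR.

```scala
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.types._

val schema = StructType(Seq(
  StructField("name", StringType),
  StructField("scores", ArrayType(IntegerType))))

// Build the converter once for the whole schema...
val toCatalyst = ScalaReflection.createCatalystConverter(schema)

// ...then reuse it for every value; a tuple matches the Product branch
// and comes back as a GenericRowWithSchema.
val converted = toCatalyst(("alice", Seq(1, 2, 3)))
```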

  /** Converts Catalyst types used internally in rows to standard Scala types */
  def convertToScala(a: Any, dataType: DataType): Any = (a, dataType) match {
    // Check UDT first since UDTs can override other types

@@ -94,11 +178,92 @@ trait ScalaReflection {
    case (other, _) => other
  }

  /**
   * Creates a converter function that will convert Catalyst types to Scala types.
   */
  private[sql] def createScalaConverter(dataType: DataType): (Any) => Any = dataType match {
Contributor: `Any => Any`

    // Check UDT first since UDTs can override other types
    case udt: UserDefinedType[_] => (item: Any) => udt.deserialize(item)

    case arrayType: ArrayType => {
      val elementConverter = createScalaConverter(arrayType.elementType)
      (item: Any) => item match {
        case s: Seq[_] => s.map(elementConverter)
        case other => other
Contributor: Is this needed? It may lead to other problems (the items inside `other` are not converted).

I think the only case you need here is null. Then you could just use `if` and `item.isInstanceOf[Seq[_]]`.

Author: It seems to be breaking java beans.

      }
    }
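Illustrative sketch (not part of the diff): roughly what the reviewer's suggested alternative could look like. It assumes `null` is the only non-`Seq` input, which is exactly the assumption the author says broke Java beans, so the catch-all case was kept.

```scala
case arrayType: ArrayType =>
  val elementConverter = createScalaConverter(arrayType.elementType)
  (item: Any) =>
    // Assumes every non-null input is a Seq; non-Seq inputs (e.g. Java beans) would fail here.
    if (item == null) null else item.asInstanceOf[Seq[_]].map(elementConverter)
```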

    case mapType: MapType => {
      val keyConverter = createScalaConverter(mapType.keyType)
      val valueConverter = createScalaConverter(mapType.valueType)
      (item: Any) => item match {
        case m: Map[_, _] => m.map { case (k, v) =>
          keyConverter(k) -> valueConverter(v)
        }
        case other => other
      }
    }

    case s: StructType => {
      val converters = createConvertersForStruct(s)
      (item: Any) => item match {
        case r: Row => convertRowToScalaWithConverters(r, s, converters)
        case other => other
      }
    }

    case _: DecimalType => (item: Any) => item match {
      case d: Decimal => d.toJavaBigDecimal
      case other => other
    }

    case DateType => (item: Any) => item match {
      case i: Int => DateUtils.toJavaDate(i)
      case other => other
    }

    case other => (item: Any) => item
  }

  def convertRowToScala(r: Row, schema: StructType): Row = {
    // TODO: This is very slow!!!
    new GenericRowWithSchema(
      r.toSeq.zip(schema.fields.map(_.dataType))
        .map(r_dt => convertToScala(r_dt._1, r_dt._2)).toArray, schema)
    val ar = new Array[Any](r.size)
    var idx = 0
    while (idx < r.size) {
      ar(idx) = convertToScala(r(idx), schema.fields(idx).dataType)
      idx += 1
    }
    new GenericRowWithSchema(ar, schema)
  }

  /**
   * Creates Catalyst->Scala converter functions for each field of the given StructType.
   */
  private[sql] def createConvertersForStruct(s: StructType): Array[(Any) => Any] = {
    val converters = new Array[(Any) => Any](s.length)
    val iter = s.fields.iterator
    var idx = 0
    while (iter.hasNext) {
Contributor: nit: you could use

    s.fields.toSeq.map(createScalaConverter(_.dataType)).toArray

Contributor: Probably reasonable since this is not per-tuple.

Author: Ok, will do.

      converters(idx) = createScalaConverter(iter.next().dataType)
      idx += 1
    }
    converters
  }
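Illustrative sketch (not part of the diff): the one-liner the reviewer suggests, with the lambda spelled out explicitly, since the `_` placeholder form quoted above binds to the inner call and would not type-check as written.

```scala
private[sql] def createConvertersForStruct(s: StructType): Array[(Any) => Any] =
  s.fields.map(f => createScalaConverter(f.dataType)).toArray
```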

  /**
   * Converts a row with Catalyst types to a row with Scala types using the provided set of
   * converter functions.
   */
  private[sql] def convertRowToScalaWithConverters(
      row: Row,
      schema: StructType,
      converters: Array[(Any) => Any]): Row = {
    val ar = new Array[Any](row.size)
    var idx = 0
    while (idx < row.size) {
      ar(idx) = converters(idx)(row(idx))
      idx += 1
    }
    new GenericRowWithSchema(ar, schema)
  }
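Illustrative sketch (not part of the diff): the pattern these two helpers are designed for, mirroring the call sites later in this PR — build the per-field converters once per schema, then reuse them for every row. `schema` and `catalystRows` are assumed to be in scope, inside the sql package since both helpers are `private[sql]`.

```scala
// Build once per schema...
val converters = ScalaReflection.createConvertersForStruct(schema)
// ...then apply per row.
val scalaRows = catalystRows.map(
  ScalaReflection.convertRowToScalaWithConverters(_, schema, converters))
```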

  /** Returns a Sequence of attributes for the given case class type. */
@@ -21,23 +21,20 @@ import java.text.SimpleDateFormat

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema


private[sql] object DataTypeConversions {

  def productToRow(product: Product, schema: StructType): Row = {
Contributor: This is only used from LocalRelation AFAICT, but I think we should also add a comment that this method is slow, and that users should use CatalystTypeConverters.createToCatalystConverter for batch conversions -- we don't want future contributors to rely on such a convenient method without at least exposing the performance characteristics.

Author: I replaced the usage inside LocalRelation with the new stuff, so now this method isn't used anywhere. Should we leave it around with a warning that it's slow, or ask people to use converters at all times?

Contributor: I don't see the change here yet -- but if no one is using this method I would be very happy to remove it. (It's not public outside of sql, so should be safe.)

Author: Removed.

    val mutableRow = new GenericMutableRow(product.productArity)
    val schemaFields = schema.fields.toArray

    val ar = new Array[Any](schema.length)
Contributor: maybe choose a better name, like "elementConverters" or something, idk

Author: Those are already converted elements. I'll rename to converted.

    var i = 0
    while (i < mutableRow.length) {
      mutableRow(i) =
        ScalaReflection.convertToCatalyst(product.productElement(i), schemaFields(i).dataType)
    while (i < schema.length) {
      ar(i) =
        ScalaReflection.convertToCatalyst(product.productElement(i), schema.fields(i).dataType)
      i += 1
    }

    mutableRow
    new GenericRowWithSchema(ar, schema)
Author: This isn't quite related to this PR, but I don't think it was necessary to use GenericMutableRow here.

Contributor: I agree. The original version uses a mutable row mostly because of the updates in the while loop I guess.
  }
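Illustrative sketch (not part of the diff): the converter-based batch shape the reviewers prefer over calling `convertToCatalyst` per element. The `productsToRows` helper below is hypothetical — it is not part of the PR — and assumes it lives in the sql package; it builds one converter per field and reuses them across all products, much like the RDDConversions change later in this PR.

```scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.types.StructType

// Hypothetical batch helper: one converter per field, reused for every product.
def productsToRows(products: Seq[Product], schema: StructType): Seq[Row] = {
  val converters = schema.fields.map(f => ScalaReflection.createCatalystConverter(f.dataType))
  products.map { p =>
    val converted = new Array[Any](schema.length)
    var i = 0
    while (i < schema.length) {
      converted(i) = converters(i)(p.productElement(i))
      i += 1
    }
    new GenericRowWithSchema(converted, schema)
  }
}
```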

  def stringToTime(s: String): java.util.Date = {
5 changes: 4 additions & 1 deletion sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -961,7 +961,10 @@ class DataFrame private[sql](
  lazy val rdd: RDD[Row] = {
    // use a local variable to make sure the map closure doesn't capture the whole DataFrame
    val schema = this.schema
    queryExecution.executedPlan.execute().map(ScalaReflection.convertRowToScala(_, schema))
    queryExecution.executedPlan.execute().mapPartitions(rows => {
      val converters = ScalaReflection.createConvertersForStruct(schema)
      rows.map(ScalaReflection.convertRowToScalaWithConverters(_, schema, converters))
    })
Contributor: Similar to the case clause in pattern matching, we usually use

    .mapPartitions { rows =>
      ...
    }

or

    .mapPartitions(rows => ...)

but not

    .mapPartitions(rows => {
      ...
    })

  }
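Illustrative sketch (not part of the diff): the same snippet in the brace style the reviewer prefers.

```scala
queryExecution.executedPlan.execute().mapPartitions { rows =>
  val converters = ScalaReflection.createConvertersForStruct(schema)
  rows.map(ScalaReflection.convertRowToScalaWithConverters(_, schema, converters))
}
```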

  /**
@@ -22,12 +22,10 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericMutableRow}
import org.apache.spark.sql.catalyst.expressions.{SpecificMutableRow, Attribute}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics}
import org.apache.spark.sql.types.StructType

import scala.collection.immutable

/**
 * :: DeveloperApi ::
 */
@@ -39,13 +37,13 @@ object RDDConversions {
        Iterator.empty
      } else {
        val bufferedIterator = iterator.buffered
        val mutableRow = new GenericMutableRow(bufferedIterator.head.productArity)
        val mutableRow = new SpecificMutableRow(schema.fields.map(_.dataType))
        val schemaFields = schema.fields.toArray
        val converters = schemaFields.map(f => ScalaReflection.createCatalystConverter(f.dataType))
        bufferedIterator.map { r =>
          var i = 0
          while (i < mutableRow.length) {
            mutableRow(i) =
              ScalaReflection.convertToCatalyst(r.productElement(i), schemaFields(i).dataType)
            mutableRow(i) = converters(i)(r.productElement(i))
            i += 1
          }

@@ -32,9 +32,16 @@ case class LocalTableScan(output: Seq[Attribute], rows: Seq[Row]) extends LeafNo

  override def execute(): RDD[Row] = rdd

  override def executeCollect(): Array[Row] =
    rows.map(ScalaReflection.convertRowToScala(_, schema)).toArray

  override def executeTake(limit: Int): Array[Row] =
    rows.map(ScalaReflection.convertRowToScala(_, schema)).take(limit).toArray
  override def executeCollect(): Array[Row] = {
    val converters = ScalaReflection.createConvertersForStruct(schema)
    rows.map(ScalaReflection.convertRowToScalaWithConverters(_, schema, converters)).toArray
  }


  override def executeTake(limit: Int): Array[Row] = {
    val converters = ScalaReflection.createConvertersForStruct(schema)
    rows.map(ScalaReflection.convertRowToScalaWithConverters(_, schema, converters))
      .take(limit).toArray
  }
}
@@ -80,8 +80,12 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
  /**
   * Runs this query returning the result as an array.
   */

  def executeCollect(): Array[Row] = {
    execute().map(ScalaReflection.convertRowToScala(_, schema)).collect()
    execute().mapPartitions(iter => {
      val converters = ScalaReflection.createConvertersForStruct(schema)
      iter.map(ScalaReflection.convertRowToScalaWithConverters(_, schema, converters))
    }).collect()
Contributor: Same style issue as above.

  }

  /**

@@ -125,7 +129,8 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
      partsScanned += numPartsToTry
    }

    buf.toArray.map(ScalaReflection.convertRowToScala(_, this.schema))
    val converters = ScalaReflection.createConvertersForStruct(schema)
    buf.toArray.map(ScalaReflection.convertRowToScalaWithConverters(_, schema, converters))
  }

  protected def newProjection(
@@ -140,8 +140,11 @@ case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan)
  private def collectData(): Array[Row] = child.execute().map(_.copy()).takeOrdered(limit)(ord)

  // TODO: Is this copying for no reason?
Contributor: (Is this an outdated comment?)

Author: It wasn't relevant before the change as well (that copy call wasn't there anymore). I'll remove it.

  override def executeCollect(): Array[Row] =
    collectData().map(ScalaReflection.convertRowToScala(_, this.schema))

  override def executeCollect(): Array[Row] = {
    val converters = ScalaReflection.createConvertersForStruct(this.schema)
    collectData().map(ScalaReflection.convertRowToScalaWithConverters(_, schema, converters))
  }

  // TODO: Terminal split should be implemented differently from non-terminal split.
  // TODO: Pick num splits based on |limit|.