Addressed review feedback
vlyubin committed Apr 8, 2015
commit dec680290e78aaef946b74a73373639e0375c16d
@@ -25,10 +25,7 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{DataFrame, Row, SQLContext}

@BeanInfo
- case class TokenizerTestData(rawText: String, wantedTokens: Seq[String]) {
- /** Constructor used in [[org.apache.spark.ml.feature.JavaTokenizerSuite]] */
- def this(rawText: String, wantedTokens: Array[String]) = this(rawText, wantedTokens.toSeq)
- }
+ case class TokenizerTestData(rawText: String, wantedTokens: Array[String])

class RegexTokenizerSuite extends FunSuite with MLlibTestSparkContext {
import org.apache.spark.ml.feature.RegexTokenizerSuite._
@@ -46,14 +43,14 @@ class RegexTokenizerSuite extends FunSuite with MLlibTestSparkContext {
.setOutputCol("tokens")

val dataset0 = sqlContext.createDataFrame(Seq(
TokenizerTestData("Test for tokenization.", Seq("Test", "for", "tokenization", ".")),
TokenizerTestData("Te,st. punct", Seq("Te", ",", "st", ".", "punct"))
TokenizerTestData("Test for tokenization.", Array("Test", "for", "tokenization", ".")),
TokenizerTestData("Te,st. punct", Array("Te", ",", "st", ".", "punct"))
))
testRegexTokenizer(tokenizer, dataset0)

val dataset1 = sqlContext.createDataFrame(Seq(
TokenizerTestData("Test for tokenization.", Seq("Test", "for", "tokenization")),
TokenizerTestData("Te,st. punct", Seq("punct"))
TokenizerTestData("Test for tokenization.", Array("Test", "for", "tokenization")),
TokenizerTestData("Te,st. punct", Array("punct"))
))

tokenizer.setMinTokenLength(3)
@@ -64,8 +61,8 @@ class RegexTokenizerSuite extends FunSuite with MLlibTestSparkContext {
.setGaps(true)
.setMinTokenLength(0)
val dataset2 = sqlContext.createDataFrame(Seq(
TokenizerTestData("Test for tokenization.", Seq("Test", "for", "tokenization.")),
TokenizerTestData("Te,st. punct", Seq("Te,st.", "", "punct"))
TokenizerTestData("Test for tokenization.", Array("Test", "for", "tokenization.")),
TokenizerTestData("Te,st. punct", Array("Te,st.", "", "punct"))
))
testRegexTokenizer(tokenizer, dataset2)
}
@@ -33,7 +33,8 @@ object CatalystTypeConverters {
import scala.collection.Map

/**
- * Converts Scala objects to catalyst rows / types.
+ * Converts Scala objects to catalyst rows / types. This method is slow, and for batch
+ * conversion you should use the converter produced by createToCatalystConverter.
* Note: This is always called after schemaFor has been called.
* This ordering is important for UDT registration.
*/
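
A minimal sketch of the batch pattern this new comment recommends (not part of the patch; it assumes code inside the sql package, where the private[sql] CatalystTypeConverters helpers are visible, and toCatalystRows is an illustrative helper name): build the converter once per schema, then reuse it for every row instead of converting field by field.

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.types.StructType

def toCatalystRows(rows: Seq[Row], schema: StructType): Seq[Row] = {
  // Build the converter once for the whole schema...
  val converter = CatalystTypeConverters.createToCatalystConverter(schema)
  // ...then apply it to every row; at the top level the result is a Row.
  rows.map(converter(_).asInstanceOf[Row])
}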
@@ -97,6 +98,8 @@ object CatalystTypeConverters {

/**
* Creates a converter function that will convert Scala objects to the specified catalyst type.
+ * A typical use case is converting a collection of rows that share the same schema: call
+ * this function once to get a converter, then apply it to every row.
*/
private[sql] def createToCatalystConverter(dataType: DataType): Any => Any = {
def extractOption(item: Any): Any = item match {
@@ -181,7 +184,10 @@ object CatalystTypeConverters {
}
}

- /** Converts Catalyst types used internally in rows to standard Scala types */
+ /** Converts Catalyst types used internally in rows to standard Scala types

Contributor: Wrong comment style.

+ * This method is slow, and for batch conversion you should use the converter
+ * produced by createToScalaConverter.
+ */
def convertToScala(a: Any, dataType: DataType): Any = (a, dataType) match {
// Check UDT first since UDTs can override other types
case (d, udt: UserDefinedType[_]) =>
@@ -210,6 +216,8 @@ object CatalystTypeConverters {

/**
* Creates a converter function that will convert Catalyst types to Scala type.

Contributor: Also add a note to this method and its counterpart describing the pattern they're expected to be used in. I was just reading executeCollect() and realized it may not be obvious to someone coming from there. Just something like "use this during batch conversion, such as within a mapPartitions, to generate a function which efficiently converts Catalyst types back to Scala types for a particular schema."

Contributor Author: Added.

+ * A typical use case is converting a collection of rows that share the same schema: call
+ * this function once to get a converter, then apply it to every row.
*/
private[sql] def createToScalaConverter(dataType: DataType): Any => Any = dataType match {
// Check UDT first since UDTs can override other types
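
The usage the reviewer asks to document is the one this patch adopts in SparkPlan.executeCollect further down. As a sketch (assuming a physical operator's execute() and schema are in scope), the converter is built once per partition inside mapPartitions and then applied to every row:

def executeCollect(): Array[Row] = {
  execute().mapPartitions { iter =>
    // One converter per partition/schema; each top-level result is a Row.
    val converter = CatalystTypeConverters.createToScalaConverter(schema)
    iter.map(converter(_).asInstanceOf[Row])
  }.collect()
}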
@@ -18,7 +18,7 @@
package org.apache.spark.sql.catalyst.plans.logical

import org.apache.spark.sql.Row
- import org.apache.spark.sql.catalyst.analysis
+ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, analysis}
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.types.{DataTypeConversions, StructType, StructField}

@@ -31,7 +31,8 @@ object LocalRelation {

def fromProduct(output: Seq[Attribute], data: Seq[Product]): LocalRelation = {
val schema = StructType.fromAttributes(output)
- LocalRelation(output, data.map(row => DataTypeConversions.productToRow(row, schema)))
+ val converter = CatalystTypeConverters.createToCatalystConverter(schema)
+ LocalRelation(output, data.map(converter(_).asInstanceOf[Row]))
}
}
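
A sketch of how fromProduct might be driven after this change, for example from reflection-based test helpers; Person is illustrative, and the use of ScalaReflection.attributesFor to derive the attributes is an assumption rather than something shown in this patch:

import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation

case class Person(name: String, age: Int)

// Each Product (case class instance) goes through the single schema-wide
// converter built once inside fromProduct, instead of per-field conversion.
val relation = LocalRelation.fromProduct(
  ScalaReflection.attributesFor[Person],
  Seq(Person("alice", 29), Person("bob", 31)))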

@@ -25,18 +25,6 @@ import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema


private[sql] object DataTypeConversions {

- def productToRow(product: Product, schema: StructType): Row = {
- val converted = new Array[Any](schema.length)
- var i = 0
- while (i < schema.length) {
- converted(i) = CatalystTypeConverters.convertToCatalyst(product.productElement(i),
- schema.fields(i).dataType)
- i += 1
- }
- new GenericRowWithSchema(converted, schema)
- }

def stringToTime(s: String): java.util.Date = {
if (!s.contains('T')) {
// JDBC escape string
6 changes: 2 additions & 4 deletions sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -962,10 +962,8 @@ class DataFrame private[sql](
// use a local variable to make sure the map closure doesn't capture the whole DataFrame
val schema = this.schema
queryExecution.executedPlan.execute().mapPartitions { rows =>
- val converters = schema.fields.map {
- f => CatalystTypeConverters.createToScalaConverter(f.dataType)
- }
- rows.map(CatalystTypeConverters.convertRowWithConverters(_, schema, converters))
+ val converter = CatalystTypeConverters.createToScalaConverter(schema)
+ rows.map(converter(_).asInstanceOf[Row])
}
}

6 changes: 2 additions & 4 deletions sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -404,10 +404,8 @@ class SQLContext(@transient val sparkContext: SparkContext)
// TODO: use MutableProjection when rowRDD is another DataFrame and the applied
// schema differs from the existing schema on any field data type.
val catalystRows = if (needsConversion) {
- val converters = schema.fields.map {
- f => CatalystTypeConverters.createToCatalystConverter(f.dataType)
- }
- rowRDD.map(CatalystTypeConverters.convertRowWithConverters(_, schema, converters))
+ val converter = CatalystTypeConverters.createToCatalystConverter(schema)
+ rowRDD.map(converter(_).asInstanceOf[Row])
} else {
rowRDD
}
@@ -34,18 +34,13 @@ case class LocalTableScan(output: Seq[Attribute], rows: Seq[Row]) extends LeafNo


override def executeCollect(): Array[Row] = {
- val converters = schema.fields.map {
- f => CatalystTypeConverters.createToScalaConverter(f.dataType)
- }
- rows.map(CatalystTypeConverters.convertRowWithConverters(_, schema, converters)).toArray
+ val converter = CatalystTypeConverters.createToScalaConverter(schema)
+ rows.map(converter(_).asInstanceOf[Row]).toArray
}


override def executeTake(limit: Int): Array[Row] = {
- val converters = schema.fields.map {
- f => CatalystTypeConverters.createToScalaConverter(f.dataType)
- }
- rows.map(CatalystTypeConverters.convertRowWithConverters(_, schema, converters))
- .take(limit).toArray
+ val converter = CatalystTypeConverters.createToScalaConverter(schema)
+ rows.map(converter(_).asInstanceOf[Row]).take(limit).toArray
}
}
@@ -83,10 +83,8 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ

def executeCollect(): Array[Row] = {
execute().mapPartitions { iter =>
- val converters = schema.fields.map {
- f => CatalystTypeConverters.createToScalaConverter(f.dataType)
- }
- iter.map(CatalystTypeConverters.convertRowWithConverters(_, schema, converters))
+ val converter = CatalystTypeConverters.createToScalaConverter(schema)
+ iter.map(converter(_).asInstanceOf[Row])
}.collect()
}

@@ -131,10 +129,8 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
partsScanned += numPartsToTry
}

- val converters = schema.fields.map {
- f => CatalystTypeConverters.createToScalaConverter(f.dataType)
- }
- buf.toArray.map(CatalystTypeConverters.convertRowWithConverters(_, schema, converters))
+ val converter = CatalystTypeConverters.createToScalaConverter(schema)
+ buf.toArray.map(converter(_).asInstanceOf[Row])
}

protected def newProjection(
@@ -140,10 +140,8 @@ case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan)
private def collectData(): Array[Row] = child.execute().map(_.copy()).takeOrdered(limit)(ord)

override def executeCollect(): Array[Row] = {
- val converters = schema.fields.map {
- f => CatalystTypeConverters.createToScalaConverter(f.dataType)
- }
- collectData().map(CatalystTypeConverters.convertRowWithConverters(_, schema, converters))
+ val converter = CatalystTypeConverters.createToScalaConverter(schema)
+ collectData().map(converter(_).asInstanceOf[Row])

Contributor: Is there any reason that createToScalaConverter doesn't return a Row => Row to avoid the need to cast everywhere else?

Contributor: I guess the answer here is "udfs".

Contributor Author: createToScalaConverter is used recursively, so it generates functions that return Rows only at the top level, while other calls generate functions that return arrays or decimals. Previously I used convertRowWithConverters, which explicitly returned Row but required an extra line to generate the converters. This is cleaner, but I need a cast here.

}

// TODO: Terminal split should be implemented differently from non-terminal split.
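
On the thread above about the Any => Any signature, a sketch of the author's point (assuming code inside the sql package, where the private[sql] factory is visible): the same factory is applied recursively to nested types, so only the outermost application is known to produce a Row, and callers cast at the top level.

import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.types._

val schema = StructType(Seq(
  StructField("name", StringType),
  StructField("scores", ArrayType(DoubleType))))

// Built once for the schema; typed Any => Any because the same factory also
// builds the nested element converter, whose results are Seqs, not Rows.
val toScala: Any => Any = CatalystTypeConverters.createToScalaConverter(schema)

// At the top level the caller knows the result is a Row and casts:
//   val scalaRow = toScala(catalystRow).asInstanceOf[Row]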