Changes from 1 commit
Commits (36)
685fd07  use UTF8String instead of String for StringType  (Mar 31, 2015)
21f67c6  cleanup  (Mar 31, 2015)
4699c3a  use Array[Byte] in UTF8String  (Mar 31, 2015)
d32abd1  fix utf8 for python api  (Mar 31, 2015)
a85fb27  refactor  (Mar 31, 2015)
6b499ac  fix style  (Apr 1, 2015)
5f9e120  fix sql tests  (Apr 1, 2015)
38c303e  fix python sql tests  (Apr 1, 2015)
c7dd4d2  fix some catalyst tests  (Apr 1, 2015)
bb52e44  fix scala style  (Apr 1, 2015)
8b45864  fix codegen with UTF8String  (Apr 1, 2015)
23a766c  refactor  (Apr 1, 2015)
9dc32d1  fix some hive tests  (Apr 2, 2015)
73e4363  Merge branch 'master' of github.com:apache/spark into string  (Apr 2, 2015)
956b0a4  fix hive tests  (Apr 2, 2015)
9f4c194  convert data type for data source  (Apr 2, 2015)
537631c  some comment about Date  (Apr 2, 2015)
28d6f32  refactor  (Apr 2, 2015)
28f3d81  Merge branch 'master' of github.com:apache/spark into string  (Apr 3, 2015)
e5fa5b8  remove clone in UTF8String  (Apr 3, 2015)
8d17f21  fix hive compatibility tests  (Apr 3, 2015)
fd11364  optimize UTF8String  (Apr 3, 2015)
ac18ae6  address comment  (Apr 3, 2015)
2089d24  add hashcode check back  (Apr 3, 2015)
13d9d42  Merge branch 'master' of github.com:apache/spark into string  (Apr 3, 2015)
867bf50  fix String filter push down  (Apr 4, 2015)
1314a37  address comments from Yin  (Apr 8, 2015)
5116b43  rollback unrelated changes  (Apr 8, 2015)
08d897b  Merge branch 'master' of github.com:apache/spark into string  (Apr 9, 2015)
b04a19c  add comment for getString/setString  (Apr 10, 2015)
744788f  Merge branch 'master' of github.com:apache/spark into string  (Apr 13, 2015)
341ec2c  turn off scala style check in UTF8StringSuite  (Apr 13, 2015)
59025c8  address comments from @marmbrus  (Apr 15, 2015)
6d776a9  Merge branch 'master' of github.com:apache/spark into string  (Apr 15, 2015)
2772f0d  fix new test failure  (Apr 15, 2015)
3b7bfa8  fix schema of AddJar  (Apr 15, 2015)
refactor
Davies Liu committed Apr 2, 2015
commit 28d6f32eda151ed51f35117eb5beb1ec6b6882d1
@@ -78,10 +78,28 @@ trait ScalaReflection {
case (d: BigDecimal, _) => Decimal(d)
case (d: java.math.BigDecimal, _) => Decimal(d)
case (d: java.sql.Date, _) => DateUtils.fromJavaDate(d)
case (s: String, st: StringType) => UTF8String(s)
case (s: String, _) => UTF8String(s)
case (other, _) => other
}

/**
* Converts Scala objects to catalyst rows / types.
* Note: This should be called before doing evaluation on a Row
* (It does not support UDT)
*/
Contributor comment: It will be good to make it clear when we need this function.

def convertToCatalyst(a: Any): Any = a match {
case s: String => UTF8String(s)
case d: java.sql.Date => DateUtils.fromJavaDate(d)
case d: BigDecimal => Decimal(d)
case d: java.math.BigDecimal => Decimal(d)
case seq: Seq[Any] => seq.map(convertToCatalyst)
case r: Row => Row(r.toSeq.map(convertToCatalyst): _*)
case arr: Array[Any] => arr.toSeq.map(convertToCatalyst).toArray
case m: Map[Any, Any] =>
m.map { case (k, v) => (convertToCatalyst(k), convertToCatalyst(v)) }.toMap
case other => other
}

/** Converts Catalyst types used internally in rows to standard Scala types */
def convertToScala(a: Any, dataType: DataType): Any = (a, dataType) match {
// Check UDT first since UDTs can override other types
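Since the review asks when convertToCatalyst is needed, here is a minimal, dependency-free sketch of the recursive conversion pattern the new method implements. CatalystString is a hypothetical stand-in for Spark's internal UTF8String; only the recursion structure is illustrated, not the real Catalyst types.

```scala
object ConvertSketch {
  // Hypothetical stand-in for org.apache.spark.sql.types.UTF8String.
  final case class CatalystString(bytes: Array[Byte])

  def convert(a: Any): Any = a match {
    case s: String       => CatalystString(s.getBytes("UTF-8"))               // String -> internal form
    case seq: Seq[Any]   => seq.map(convert)                                  // recurse into sequences
    case arr: Array[Any] => arr.map(convert)                                  // recurse into object arrays
    case m: Map[_, _]    => m.map { case (k, v) => convert(k) -> convert(v) } // recurse into maps
    case other           => other                                             // primitives pass through
  }

  def main(args: Array[String]): Unit = {
    // Nested values are converted element by element, mirroring the cases above.
    println(convert(Seq("a", 1, Map("k" -> "v"))))
  }
}
```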
@@ -128,7 +128,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
}"""
case other =>
q"""
override def ${accessorForType(dataType)}(i: Int):${termForType(dataType)} = {
override def ${accessorForType(dataType)}(i: Int): ${termForType(dataType)} = {
..$ifStatements;
$accessorFailure
}"""
@@ -148,13 +148,13 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
dataType match {
case StringType =>
q"""
override def setString(i: Int, value: String): Unit = {
override def setString(i: Int, value: String) {
..$ifStatements;
$accessorFailure
}"""
case other =>
q"""
override def ${mutatorForType(dataType)}(i: Int, value: ${termForType(dataType)}):Unit = {
override def ${mutatorForType(dataType)}(i: Int, value: ${termForType(dataType)}) {
..$ifStatements;
$accessorFailure
}"""
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions

import java.sql.{Date, Timestamp}

import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.types._

object Literal {
@@ -42,20 +43,9 @@ object Literal {
throw new RuntimeException("Unsupported literal type " + v.getClass + " " + v)
}

/**
* convert String in `v` as UTF8String
*/
def convertToUTF8String(v: Any): Any = v match {
case s: String => UTF8String(s)
case seq: Seq[Any] => seq.map(convertToUTF8String)
case r: Row => Row(r.toSeq.map(convertToUTF8String): _*)
case arr: Array[Any] => arr.toSeq.map(convertToUTF8String).toArray
case m: Map[Any, Any] =>
m.map { case (k, v) => (convertToUTF8String(k), convertToUTF8String(v)) }.toMap
case other => other
def create(v: Any, dataType: DataType): Literal = {
Literal(ScalaReflection.convertToCatalyst(v), dataType)
}

def create(v: Any, dataType: DataType): Literal = Literal(convertToUTF8String(v), dataType)
}

/**
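A hedged usage sketch of the reworked Literal.create, assuming this branch's Catalyst classes are on the classpath: the value is routed through ScalaReflection.convertToCatalyst before the Literal is built, so a Scala String is stored in its internal UTF8String form.

```scala
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.types.StringType

object LiteralCreateSketch {
  def main(args: Array[String]): Unit = {
    // The String is converted to Catalyst's internal representation first.
    val lit = Literal.create("hello", StringType)
    println(lit.value)    // internal UTF8String form of "hello", not java.lang.String
    println(lit.dataType) // StringType
  }
}
```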
@@ -196,10 +196,7 @@ class GenericMutableRow(v: Array[Any]) extends GenericRow(v) with MutableRow {
override def setFloat(ordinal: Int, value: Float): Unit = { values(ordinal) = value }
override def setInt(ordinal: Int, value: Int): Unit = { values(ordinal) = value }
override def setLong(ordinal: Int, value: Long): Unit = { values(ordinal) = value }
override def setString(ordinal: Int, value: String): Unit = {
// TODO(davies): need this?
values(ordinal) = UTF8String(value)
}
override def setString(ordinal: Int, value: String) { values(ordinal) = UTF8String(value) }
override def setNullAt(i: Int): Unit = { values(i) = null }

override def setShort(ordinal: Int, value: Short): Unit = { values(ordinal) = value }
@@ -19,9 +19,6 @@ package org.apache.spark.sql.catalyst.expressions

import java.util.regex.Pattern

import scala.collection.IndexedSeqOptimized


import org.apache.spark.sql.catalyst.analysis.UnresolvedException
import org.apache.spark.sql.types._

@@ -226,8 +223,7 @@ case class Substring(str: Expression, pos: Expression, len: Expression) extends
override def children: Seq[Expression] = str :: pos :: len :: Nil

@inline
def slice[T, C <: Any](str: C, startPos: Int, sliceLen: Int)
(implicit ev: (C=>IndexedSeqOptimized[T,_])): Any = {
def slicePos(startPos: Int, sliceLen: Int, length: () => Int): (Int, Int) = {
// Hive and SQL use one-based indexing for SUBSTR arguments but also accept zero and
// negative indices for start positions. If a start index i is greater than 0, it
// refers to element i-1 in the sequence. If a start index i is less than 0, it refers
@@ -236,29 +232,7 @@ case class Substring(str: Expression, pos: Expression, len: Expression) extends

val start = startPos match {
case pos if pos > 0 => pos - 1
case neg if neg < 0 => str.length + neg
case _ => 0
}

val end = sliceLen match {
case max if max == Integer.MAX_VALUE => max
case x => start + x
}

str.slice(start, end)
}

@inline
def slice(str: UTF8String, startPos: Int, sliceLen: Int): Any = {
// Hive and SQL use one-based indexing for SUBSTR arguments but also accept zero and
// negative indices for start positions. If a start index i is greater than 0, it
// refers to element i-1 in the sequence. If a start index i is less than 0, it refers
// to the -ith element before the end of the sequence. If a start index i is 0, it
// refers to the first element.

val start = startPos match {
case pos if pos > 0 => pos - 1
case neg if neg < 0 => str.length + neg
case neg if neg < 0 => length() + neg
case _ => 0
}

@@ -267,24 +241,26 @@ case class Substring(str: Expression, pos: Expression, len: Expression) extends
case x => start + x
}

str.slice(start, end)
(start, end)
}

override def eval(input: Row): Any = {
val string = str.eval(input)

val po = pos.eval(input)
val ln = len.eval(input)

if ((string == null) || (po == null) || (ln == null)) {
null
} else {
val start = po.asInstanceOf[Int]
val length = ln.asInstanceOf[Int]

val length = ln.asInstanceOf[Int]
string match {
case ba: Array[Byte] => slice(ba, start, length)
case s: UTF8String => slice(s, start, length)
case ba: Array[Byte] =>
val (st, end) = slicePos(start, length, () => ba.length)
ba.slice(st, end)
case s: UTF8String =>
val (st, end) = slicePos(start, length, () => s.length)
s.slice(st, end)
}
}
}
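To make the indexing rules in the comment above concrete, here is a standalone sketch of slicePos with two worked SUBSTR examples; it is plain Scala with no Spark dependencies and mirrors the helper in this diff rather than calling it.

```scala
object SubstrSketch {
  // Maps SQL-style one-based (pos, len) arguments to zero-based (start, end) offsets.
  def slicePos(startPos: Int, sliceLen: Int, length: () => Int): (Int, Int) = {
    val start = startPos match {
      case pos if pos > 0 => pos - 1         // one-based index -> zero-based
      case neg if neg < 0 => length() + neg  // negative index counts back from the end
      case _              => 0               // 0 behaves like 1
    }
    val end = sliceLen match {
      case max if max == Integer.MAX_VALUE => max
      case x                               => start + x
    }
    (start, end)
  }

  def main(args: Array[String]): Unit = {
    val s = "Spark SQL"
    val (s1, e1) = slicePos(1, 5, () => s.length)
    println(s.slice(s1, e1)) // "Spark"  -- SUBSTR('Spark SQL', 1, 5)
    val (s2, e2) = slicePos(-3, 3, () => s.length)
    println(s.slice(s2, e2)) // "SQL"    -- SUBSTR('Spark SQL', -3, 3)
  }
}
```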
@@ -25,6 +25,7 @@ import org.scalactic.TripleEqualsSupport.Spread
import org.scalatest.FunSuite
import org.scalatest.Matchers._

import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.analysis.UnresolvedGetField
import org.apache.spark.sql.types._
@@ -60,7 +61,7 @@ class ExpressionEvaluationBaseSuite extends FunSuite {
class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite {

def create_row(values: Array[Any]): Row = {
Contributor comment: Could we just use Row.apply here?

Contributor (author) reply: Row.apply will call ScalaReflection.convertToCatalyst(); we still need a wrapper to do that.

new GenericRow(values.toSeq.map(Literal.convertToUTF8String).toArray)
new GenericRow(values.toSeq.map(ScalaReflection.convertToCatalyst).toArray)
}

test("literals") {
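For context on the exchange above, a rough, dependency-free sketch of what the create_row wrapper does: every plain Scala value is converted to its Catalyst representation before the row is constructed. The row and string classes here are hypothetical stand-ins, not Spark's.

```scala
object CreateRowSketch {
  final case class CatalystString(s: String)    // stand-in for UTF8String
  final case class RowStandIn(values: Seq[Any]) // stand-in for GenericRow

  // Simplified stand-in for ScalaReflection.convertToCatalyst.
  def convertToCatalyst(a: Any): Any = a match {
    case s: String => CatalystString(s)
    case other     => other
  }

  // Mirrors create_row in the diff: convert the values, then build the row.
  def createRow(values: Array[Any]): RowStandIn =
    RowStandIn(values.toSeq.map(convertToCatalyst))

  def main(args: Array[String]): Unit = {
    // Strings are already in internal form when expressions are evaluated against the row.
    println(createRow(Array[Any]("abc", 1, null)))
  }
}
```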
6 changes: 3 additions & 3 deletions sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -396,11 +396,11 @@ class SQLContext(@transient val sparkContext: SparkContext)
// schema differs from the existing schema on any field data type.
def needsConversion(dt: DataType): Boolean = dt match {
case StringType => true
case DateType => true
case DecimalType() => true
case dt: ArrayType => needsConversion(dt.elementType)
case dt: MapType => needsConversion(dt.keyType) || needsConversion(dt.valueType)
case dt: StructType =>
!dt.fields.forall(f => !needsConversion(f.dataType))
// TODO(davies): check other types and values
case dt: StructType => !dt.fields.forall(f => !needsConversion(f.dataType))
case other => false
}
val convertedRdd = if (needsConversion(schema)) {
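A hedged illustration of the kind of schema that takes the conversion path above, assuming Spark SQL's types are on the classpath; needsConversion itself is a local helper, so this only shows the inputs it would flag.

```scala
import org.apache.spark.sql.types._

object NeedsConversionSketch {
  def main(args: Array[String]): Unit = {
    val schema = StructType(Seq(
      StructField("name",  StringType),  // String -> UTF8String: needs conversion
      StructField("born",  DateType),    // java.sql.Date -> internal days: needs conversion
      StructField("score", DoubleType))) // primitive: passes through unchanged
    // For a schema like this, needsConversion(schema) is true, so the input rows
    // are converted before the DataFrame is built from them.
    println(schema)
  }
}
```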
@@ -230,7 +230,6 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child:

def execute(): RDD[Row] = {
// TODO: Clean up after ourselves?
// TODO(davies): convert internal type to Scala Type
val childResults = child.execute().map(_.copy()).cache()

val parent = childResults.mapPartitions { iter =>
@@ -450,7 +450,7 @@ private[sql] object JsonRDD extends Logging {
private[sql] def rowToJSON(rowSchema: StructType, gen: JsonGenerator)(row: Row) = {
def valWriter: (DataType, Any) => Unit = {
case (_, null) | (NullType, _) => gen.writeNull()
case (StringType, v: String) => gen.writeString(v.toString)
case (StringType, v: String) => gen.writeString(v)
case (TimestampType, v: java.sql.Timestamp) => gen.writeString(v.toString)
case (IntegerType, v: Int) => gen.writeNumber(v)
case (ShortType, v: Short) => gen.writeNumber(v)
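A brief sketch of the JsonGenerator calls that valWriter dispatches to, assuming Jackson 2.x on the classpath; the field names are hypothetical, and after this change the String case passes v straight through instead of calling toString.

```scala
import java.io.StringWriter

import com.fasterxml.jackson.core.JsonFactory

object JsonWriterSketch {
  def main(args: Array[String]): Unit = {
    val out = new StringWriter()
    val gen = new JsonFactory().createGenerator(out)
    gen.writeStartObject()
    gen.writeFieldName("name"); gen.writeString("alice") // (StringType, v: String)   => gen.writeString(v)
    gen.writeFieldName("age");  gen.writeNumber(42)      // (IntegerType, v: Int)     => gen.writeNumber(v)
    gen.writeFieldName("note"); gen.writeNull()          // (_, null) | (NullType, _) => gen.writeNull()
    gen.writeEndObject()
    gen.close()
    println(out.toString) // {"name":"alice","age":42,"note":null}
  }
}
```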
@@ -20,7 +20,7 @@ import scala.collection.mutable.ArrayBuffer

import org.apache.hadoop.fs.Path

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.parquet.ParquetRelation2._
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.types._
@@ -17,25 +17,22 @@

package org.apache.spark.sql.hive

import org.apache.spark.sql.catalyst.expressions.Row

import scala.collection.JavaConversions._

import org.apache.spark.annotation.Experimental
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.{Row, _}
import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate
import org.apache.spark.sql.catalyst.planning._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.sources.DescribeCommand
import org.apache.spark.sql.execution.{DescribeCommand => RunnableDescribeCommand}
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.{DescribeCommand => RunnableDescribeCommand, _}
import org.apache.spark.sql.hive.execution._
import org.apache.spark.sql.parquet.ParquetRelation
import org.apache.spark.sql.sources.{CreateTableUsingAsSelect, CreateTableUsing}
import org.apache.spark.sql.types.{UTF8String, StringType}
import org.apache.spark.sql.sources.{CreateTableUsing, CreateTableUsingAsSelect, DescribeCommand}
import org.apache.spark.sql.types.StringType


private[hive] trait HiveStrategies {
@@ -131,10 +128,7 @@ private[hive] trait HiveStrategies {
val partitionValues = part.getValues
var i = 0
while (i < partitionValues.size()) {
inputData(i) = partitionValues(i) match {
case s: String => UTF8String(s)
case other => other
}
inputData(i) = ScalaReflection.convertToCatalyst(partitionValues(i))
i += 1
}
pruningCondition(inputData)
@@ -23,12 +23,14 @@ import java.util.{Properties, ArrayList => JArrayList}
import scala.collection.JavaConversions._
import scala.language.implicitConversions

import com.esotericsoftware.kryo.Kryo
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.hadoop.hive.common.StatsSetupConst
import org.apache.hadoop.hive.common.`type`.HiveDecimal
import org.apache.hadoop.hive.conf.HiveConf
import org.apache.hadoop.hive.ql.Context
import org.apache.hadoop.hive.ql.exec.{UDF, Utilities}
import org.apache.hadoop.hive.ql.metadata.{Hive, Partition, Table}
import org.apache.hadoop.hive.ql.plan.{CreateTableDesc, FileSinkDesc, TableDesc}
import org.apache.hadoop.hive.ql.processors.CommandProcessorFactory
@@ -45,7 +47,6 @@ import org.apache.hadoop.{io => hadoopIo}
import org.apache.spark.Logging
import org.apache.spark.sql.types.{Decimal, DecimalType, UTF8String}


/**
* This class provides the UDF creation and also the UDF instance serialization and
* de-serialization cross process boundary.
@@ -60,19 +61,14 @@ private[hive] case class HiveFunctionWrapper(var functionClassName: String)
// for Serialization
def this() = this(null)

import java.io.{InputStream, OutputStream}

import com.esotericsoftware.kryo.Kryo
import org.apache.hadoop.hive.ql.exec.{UDF, Utilities}

import org.apache.spark.util.Utils._
import org.apache.spark.util.Utils._

@transient
private val methodDeSerialize = {
val method = classOf[Utilities].getDeclaredMethod(
"deserializeObjectByKryo",
classOf[Kryo],
classOf[InputStream],
classOf[java.io.InputStream],
classOf[Class[_]])
method.setAccessible(true)

@@ -85,7 +81,7 @@ import org.apache.spark.util.Utils._
"serializeObjectByKryo",
classOf[Kryo],
classOf[Object],
classOf[OutputStream])
classOf[java.io.OutputStream])
method.setAccessible(true)

method
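The serialize/deserialize helpers above rely on a common reflection idiom: look up the method with getDeclaredMethod, call setAccessible(true), cache it, and invoke it later. A self-contained sketch of that idiom, using java.lang.Integer.parseInt as a hypothetical stand-in for Hive's Utilities methods:

```scala
object ReflectionSketch {
  // Cache the Method once, as the @transient private vals above do.
  private val parseIntMethod = {
    val m = classOf[java.lang.Integer].getDeclaredMethod("parseInt", classOf[String])
    m.setAccessible(true) // permit invocation regardless of the member's declared visibility
    m
  }

  def main(args: Array[String]): Unit = {
    // The receiver is null because the target method is static.
    val result = parseIntMethod.invoke(null, "42")
    println(result) // 42
  }
}
```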