Commit "some cleanup"
marmbrus committed Nov 12, 2015
commit ab2dbb9fcd1146e5fdb22a595a2d1722805291a4
10 changes: 5 additions & 5 deletions sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala
@@ -84,7 +84,7 @@ object Encoders {
   private def tuple(encoders: Seq[ExpressionEncoder[_]]): ExpressionEncoder[_] = {
     assert(encoders.length > 1)
     // make sure all encoders are resolved, i.e. `Attribute` has been resolved to `BoundReference`.
-    assert(encoders.forall(_.constructExpression.find(_.isInstanceOf[Attribute]).isEmpty))
+    assert(encoders.forall(_.fromRowExpression.find(_.isInstanceOf[Attribute]).isEmpty))
 
     val schema = StructType(encoders.zipWithIndex.map {
       case (e, i) => StructField(s"_${i + 1}", if (e.flat) e.schema.head.dataType else e.schema)
@@ -93,8 +93,8 @@ object Encoders {
     val cls = Utils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple${encoders.size}")
 
     val extractExpressions = encoders.map {
-      case e if e.flat => e.extractExpressions.head
-      case other => CreateStruct(other.extractExpressions)
+      case e if e.flat => e.toRowExpressions.head
+      case other => CreateStruct(other.toRowExpressions)
     }.zipWithIndex.map { case (expr, index) =>
       expr.transformUp {
         case BoundReference(0, t: ObjectType, _) =>
@@ -107,11 +107,11 @@
 
     val constructExpressions = encoders.zipWithIndex.map { case (enc, index) =>
       if (enc.flat) {
-        enc.constructExpression.transform {
+        enc.fromRowExpression.transform {
           case b: BoundReference => b.copy(ordinal = index)
         }
       } else {
-        enc.constructExpression.transformUp {
+        enc.fromRowExpression.transformUp {
           case BoundReference(ordinal, dt, _) =>
             GetInternalRowField(BoundReference(index, enc.schema, nullable = true), ordinal, dt)
         }
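The zipWithIndex logic above is the heart of tuple composition: a flat encoder contributes its single field's type directly, while a multi-field encoder is embedded as a nested struct. A minimal sketch of the resulting schema shape, using hypothetical component schemas rather than anything from this patch:

```scala
import org.apache.spark.sql.types._

// Hypothetical stand-ins: a flat encoder exposes exactly one field, a
// case-class encoder exposes a multi-field struct.
val intSchema = StructType(StructField("value", IntegerType) :: Nil)
val pairSchema = StructType(Seq(
  StructField("a", IntegerType),
  StructField("b", StringType)))

// Mirrors StructField(s"_${i + 1}", ...) above: flat -> the field's dataType,
// non-flat -> the whole schema as a nested struct.
val tupleSchema = StructType(Seq(
  StructField("_1", intSchema.head.dataType), // IntegerType
  StructField("_2", pairSchema)))             // struct<a:int,b:string>
```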
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -712,7 +712,6 @@ object HiveTypeCoercion {
 
       case e: ImplicitCastInputTypes if e.inputTypes.nonEmpty =>
         val children: Seq[Expression] = e.children.zip(e.inputTypes).map { case (in, expected) =>
-          println(s"$in -> $expected")
           // If we cannot do the implicit cast, just use the original input.
           implicitCast(in, expected).getOrElse(in)
         }
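With the stray debug println removed, the remaining logic is a plain Option fallback: when implicitCast cannot produce a legal cast, the original child is kept unchanged. A toy sketch of just that behavior, with a stand-in function (not Spark's implicitCast):

```scala
// Stand-in: None models "no legal implicit cast exists".
def implicitCast(in: String, expected: String): Option[String] =
  if (in == expected) Some(in) else None

val children = Seq("int" -> "long", "string" -> "string").map {
  case (in, expected) => implicitCast(in, expected).getOrElse(in)
}
// children == Seq("int", "string"): originals survive where no cast applied
```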
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -61,43 +61,35 @@ object ExpressionEncoder {
 
   /**
    * Given a set of N encoders, constructs a new encoder that produce objects as items in an
-   * N-tuple. Note that these encoders should first be bound correctly to the combined input
-   * schema.
+   * N-tuple. Note that these encoders should be unresolved so that information about
+   * name/positional binding is preserved.
    */
   def tuple(encoders: Seq[ExpressionEncoder[_]]): ExpressionEncoder[_] = {
+    encoders.foreach(_.assertUnresolved())
+
     val schema =
       StructType(
-        encoders.zipWithIndex.map { case (e, i) => StructField(s"_${i + 1}", if (e.flat) e.schema.head.dataType else e.schema)})
+        encoders.zipWithIndex.map {
+          case (e, i) => StructField(s"_${i + 1}", if (e.flat) e.schema.head.dataType else e.schema)
+        })
     val cls = Utils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple${encoders.size}")
 
-    // Rebind the encoders to the nested schema that will be produced by the aggregation.
+    // Rebind the encoders to the nested schema.
     val newConstructExpressions = encoders.zipWithIndex.map {
-      case (e, i) if !e.flat =>
-        println(s"=== $i - nested ===")
-        println(e.constructExpression.treeString)
-        println()
-        println(e.nested(i).constructExpression.treeString)
-
-        e.nested(i).constructExpression
-      case (e, i) =>
-        println(s"=== $i - flat ===")
-        println(e.constructExpression.treeString)
-        println()
-        println(e.shift(i).constructExpression.treeString)
-
-        e.shift(i).constructExpression
+      case (e, i) if !e.flat => e.nested(i).fromRowExpression
+      case (e, i) => e.shift(i).fromRowExpression
     }
 
     val constructExpression =
       NewInstance(cls, newConstructExpressions, false, ObjectType(cls))
 
     val input = BoundReference(0, ObjectType(cls), false)
     val extractExpressions = encoders.zipWithIndex.map {
-      case (e, i) if !e.flat => CreateStruct(e.extractExpressions.map(_ transformUp {
+      case (e, i) if !e.flat => CreateStruct(e.toRowExpressions.map(_ transformUp {
         case b: BoundReference =>
           Invoke(input, s"_${i + 1}", b.dataType, Nil)
       }))
-      case (e, i) => e.extractExpressions.head transformUp {
+      case (e, i) => e.toRowExpressions.head transformUp {
         case b: BoundReference =>
           Invoke(input, s"_${i + 1}", b.dataType, Nil)
       }
@@ -122,26 +114,27 @@
  * A generic encoder for JVM objects.
  *
  * @param schema The schema after converting `T` to a Spark SQL row.
- * @param extractExpressions A set of expressions, one for each top-level field that can be used to
- *                           extract the values from a raw object.
+ * @param toRowExpressions A set of expressions, one for each top-level field that can be used to
+ *                         extract the values from a raw object into an [[InternalRow]].
+ * @param fromRowExpression An expression that will construct an object given an [[InternalRow]].
  * @param clsTag A classtag for `T`.
  */
 case class ExpressionEncoder[T](
     schema: StructType,
     flat: Boolean,
-    extractExpressions: Seq[Expression],
-    constructExpression: Expression,
+    toRowExpressions: Seq[Expression],
+    fromRowExpression: Expression,
     clsTag: ClassTag[T])
   extends Encoder[T] {
 
-  if (flat) require(extractExpressions.size == 1)
+  if (flat) require(toRowExpressions.size == 1)
 
   @transient
-  private lazy val extractProjection = GenerateUnsafeProjection.generate(extractExpressions)
+  private lazy val extractProjection = GenerateUnsafeProjection.generate(toRowExpressions)
   private val inputRow = new GenericMutableRow(1)
 
   @transient
-  private lazy val constructProjection = GenerateSafeProjection.generate(constructExpression :: Nil)
+  private lazy val constructProjection = GenerateSafeProjection.generate(fromRowExpression :: Nil)
 
   /**
    * Returns an encoded version of `t` as a Spark SQL row. Note that multiple calls to
@@ -153,7 +146,8 @@ case class ExpressionEncoder[T](
       extractProjection(inputRow)
     } catch {
       case e: Exception =>
-        throw new RuntimeException(s"Error while encoding: $e\n${extractExpressions.map(_.treeString).mkString("\n")}", e)
+        throw new RuntimeException(
+          s"Error while encoding: $e\n${toRowExpressions.map(_.treeString).mkString("\n")}", e)
     }
 
   /**
@@ -165,17 +159,35 @@
       constructProjection(row).get(0, ObjectType(clsTag.runtimeClass)).asInstanceOf[T]
     } catch {
       case e: Exception =>
-        throw new RuntimeException(s"Error while decoding: $e\n${constructExpression.treeString}", e)
+        throw new RuntimeException(s"Error while decoding: $e\n${fromRowExpression.treeString}", e)
     }
 
+  /**
+   * The process of resolution to a given schema throws away information about where a given field
+   * is being bound by ordinal instead of by name. This method checks to make sure this process
+   * has not been done already in places where we plan to do later composition of encoders.
+   */
+  def assertUnresolved(): Unit = {
+    (fromRowExpression +: toRowExpressions).foreach(_.foreach {
+      case a: AttributeReference =>
+        sys.error(s"Unresolved encoder expected, but $a was found.")
+      case _ =>
+    })
+  }
+
   /**
    * Returns a new copy of this encoder, where the expressions used by `fromRow` are resolved to the
    * given schema.
    */
   def resolve(schema: Seq[Attribute]): ExpressionEncoder[T] = {
-    val plan = Project(Alias(constructExpression, "")() :: Nil, LocalRelation(schema))
+    val positionToAttribute = AttributeMap.toIndex(schema)
+    val unbound = fromRowExpression transform {
+      case b: BoundReference => positionToAttribute(b.ordinal)
+    }
+
+    val plan = Project(Alias(unbound, "")() :: Nil, LocalRelation(schema))
     val analyzedPlan = SimpleAnalyzer.execute(plan)
-    copy(constructExpression = analyzedPlan.expressions.head.children.head)
+    copy(fromRowExpression = analyzedPlan.expressions.head.children.head)
   }
 
   /**
@@ -184,39 +196,14 @@
    * resolve before bind.
    */
   def bind(schema: Seq[Attribute]): ExpressionEncoder[T] = {
-    copy(constructExpression = BindReferences.bindReference(constructExpression, schema))
+    copy(fromRowExpression = BindReferences.bindReference(fromRowExpression, schema))
   }
 
   /**
-   * Replaces any bound references in the schema with the attributes at the corresponding ordinal
-   * in the provided schema. This can be used to "relocate" a given encoder to pull values from
-   * a different schema than it was initially bound to. It can also be used to assign attributes
-   * to ordinal based extraction (i.e. because the input data was a tuple).
+   * Returns a new encoder with input columns shifted by `delta` ordinals
    */
-  def unbind(schema: Seq[Attribute]): ExpressionEncoder[T] = {
-    val positionToAttribute = AttributeMap.toIndex(schema)
-    copy(constructExpression = constructExpression transform {
-      case b: BoundReference => positionToAttribute(b.ordinal)
-    })
-  }
-
-  /**
-   * Given an encoder that has already been bound to a given schema, returns a new encoder
-   * where the positions are mapped from `oldSchema` to `newSchema`. This can be used, for example,
-   * when you are trying to use an encoder on grouping keys that were originally part of a larger
-   * row, but now you have projected out only the key expressions.
-   */
-  def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): ExpressionEncoder[T] = {
-    val positionToAttribute = AttributeMap.toIndex(oldSchema)
-    val attributeToNewPosition = AttributeMap.byIndex(newSchema)
-    copy(constructExpression = constructExpression transform {
-      case r: BoundReference =>
-        r.copy(ordinal = attributeToNewPosition(positionToAttribute(r.ordinal)))
-    })
-  }
-
   def shift(delta: Int): ExpressionEncoder[T] = {
-    copy(constructExpression = constructExpression transform {
+    copy(fromRowExpression = fromRowExpression transform {
       case r: BoundReference => r.copy(ordinal = r.ordinal + delta)
     })
   }
@@ -226,9 +213,11 @@
    * input row have been modified to pull the object out from a nested struct, instead of the
    * top level fields.
    */
-  def nested(i: Int): ExpressionEncoder[T] = {
-    val input = BoundReference(i, NullType, true)
-    copy(constructExpression = constructExpression transformUp {
+  private def nested(i: Int): ExpressionEncoder[T] = {
+    // We don't always know our input type at this point since it might be unresolved.
+    // We fill in null and it will get unbound to the actual attribute at this position.
+    val input = BoundReference(i, NullType, nullable = true)
+    copy(fromRowExpression = fromRowExpression transformUp {
       case u: Attribute =>
         UnresolvedExtractValue(input, Literal(u.name))
       case b: BoundReference =>
@@ -239,24 +228,7 @@
     })
   }
 
-  /**
-   * Returns a copy of this encoder where the expressions used to create an object given an
-   * input row have been modified to pull the object out from a nested struct, instead of the
-   * top level fields.
-   */
-  def nested(input: Expression = BoundReference(0, schema, true)): ExpressionEncoder[T] = {
-    copy(constructExpression = constructExpression transform {
-      case u: Attribute if u != input =>
-        UnresolvedExtractValue(input, Literal(u.name))
-      case b: BoundReference if b != input =>
-        GetStructField(
-          input,
-          StructField(s"i[${b.ordinal}]", b.dataType),
-          b.ordinal)
-    })
-  }
-
-  protected val attrs = extractExpressions.flatMap(_.collect {
+  protected val attrs = toRowExpressions.flatMap(_.collect {
     case _: UnresolvedAttribute => ""
     case a: Attribute => s"#${a.exprId}"
     case b: BoundReference => s"[${b.ordinal}]"
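Note that `shift` (and `nested`) are now pure rewrites of `BoundReference` ordinals inside `fromRowExpression`. A standalone sketch of the `shift` transform using only catalyst expression classes, with the surrounding encoder elided:

```scala
import org.apache.spark.sql.catalyst.expressions.{BoundReference, Expression}
import org.apache.spark.sql.types.IntegerType

// An expression that reads ordinal 0 of the input row.
val fromRow: Expression = BoundReference(0, IntegerType, nullable = true)

// The same rewrite shift(2) performs via copy(fromRowExpression = ...).
val shifted = fromRow transform {
  case r: BoundReference => r.copy(ordinal = r.ordinal + 2)
}
// shifted == BoundReference(2, IntegerType, true)
```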
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/package.scala
@@ -18,10 +18,19 @@
 package org.apache.spark.sql.catalyst
 
 import org.apache.spark.sql.Encoder
+import org.apache.spark.sql.catalyst.expressions.AttributeReference
 
 package object encoders {
+  /**
+   * Returns an internal encoder object that can be used to serialize / deserialize JVM objects
+   * into Spark SQL rows. The implicit encoder should always be unresolved (i.e. have no attribute
+   * references from a specific schema.) This requirement allows us to preserve whether a given
+   * object type is being bound by name or by ordinal when doing resolution.
+   */
   private[sql] def encoderFor[A : Encoder]: ExpressionEncoder[A] = implicitly[Encoder[A]] match {
-    case e: ExpressionEncoder[A] => e
+    case e: ExpressionEncoder[A] =>
+      e.assertUnresolved()
+      e
     case _ => sys.error(s"Only expression encoders are supported today")
   }
 }
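To see what the new guard catches: `resolve` replaces unresolved attributes with `AttributeReference`s tied to a concrete schema, and `assertUnresolved()` then rejects the encoder before it can be composed with stale bindings. A sketch, assuming the reflection-based `ExpressionEncoder[T]()` constructor used elsewhere in this codebase:

```scala
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder

val enc = ExpressionEncoder[(Int, String)]() // freshly derived: still unresolved
enc.assertUnresolved()                       // passes

val resolved = enc.resolve(enc.schema.toAttributes)
// resolved.assertUnresolved()               // would now fail with
// "Unresolved encoder expected, but <attribute> was found."
```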
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
@@ -97,6 +97,7 @@ object ExtractValue {
  * Returns the value of fields in the Struct `child`.
  *
  * No need to do type checking since it is handled by [[ExtractValue]].
+ * TODO: Unify with [[GetInternalRowField]], remove the need to specify a [[StructField]].
  */
 case class GetStructField(child: Expression, field: StructField, ordinal: Int)
   extends UnaryExpression {
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -469,9 +469,12 @@ case class MapPartitions[T, U](
 
 /** Factory for constructing new `AppendColumn` nodes. */
 object AppendColumn {
-  def apply[T : Encoder, U : Encoder](func: T => U, child: LogicalPlan): AppendColumn[T, U] = {
+  def apply[T, U : Encoder](
+      func: T => U,
+      tEncoder: ExpressionEncoder[T],
+      child: LogicalPlan): AppendColumn[T, U] = {
     val attrs = encoderFor[U].schema.toAttributes
-    new AppendColumn[T, U](func, encoderFor[T], encoderFor[U], attrs, child)
+    new AppendColumn[T, U](func, tEncoder, encoderFor[U], attrs, child)
   }
 }
 
@@ -492,14 +495,16 @@
 
 /** Factory for constructing new `MapGroups` nodes. */
 object MapGroups {
-  def apply[K : Encoder, T : Encoder, U : Encoder](
+  def apply[K, T, U : Encoder](
       func: (K, Iterator[T]) => TraversableOnce[U],
+      kEncoder: ExpressionEncoder[K],
+      tEncoder: ExpressionEncoder[T],
       groupingAttributes: Seq[Attribute],
       child: LogicalPlan): MapGroups[K, T, U] = {
     new MapGroups(
       func,
-      encoderFor[K],
-      encoderFor[T],
+      kEncoder,
+      tEncoder,
       encoderFor[U],
       groupingAttributes,
       encoderFor[U].schema.toAttributes,
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
@@ -237,7 +237,7 @@ class ExpressionEncoderSuite extends SparkFunSuite {
     }
     val convertedData = encoder.toRow(inputData)
     val schema = encoder.schema.toAttributes
-    val boundEncoder = encoder.resolve(schema).bind(schema)
+    val boundEncoder = encoder.resolve(schema)
     val convertedBack = try boundEncoder.fromRow(convertedData) catch {
       case e: Exception =>
         fail(
19 changes: 19 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -93,6 +93,25 @@ class Column(protected[sql] val expr: Expression) extends Logging {
   /** Creates a column based on the given expression. */
   private def withExpr(newExpr: Expression): Column = new Column(newExpr)
 
+  /**
+   * Returns the expression for this column either with an existing or auto assigned name.
+   */
+  private[sql] def named: NamedExpression = expr match {
+    // Wrap UnresolvedAttribute with UnresolvedAlias, as when we resolve UnresolvedAttribute, we
+    // will remove intermediate Alias for ExtractValue chain, and we need to alias it again to
+    // make it a NamedExpression.
+    case u: UnresolvedAttribute => UnresolvedAlias(u)
+
+    case expr: NamedExpression => expr
+
+    // Leave an unaliased generator with an empty list of names since the analyzer will generate
+    // the correct defaults after the nested expression's type has been resolved.
+    case explode: Explode => MultiAlias(explode, Nil)
+    case jt: JsonTuple => MultiAlias(jt, Nil)
+
+    case expr: Expression => Alias(expr, expr.prettyString)()
+  }
+
   override def toString: String = expr.prettyString
 
   override def equals(that: Any): Boolean = that match {
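A usage-level illustration of `named` (hypothetical expressions, not part of this patch; `named` is `private[sql]`, so this only compiles inside the `org.apache.spark.sql` package): an unresolved column defers naming to the analyzer, while an anonymous computed expression gets an auto-generated alias from `prettyString`:

```scala
import org.apache.spark.sql.Column
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions.{Add, Literal}

// UnresolvedAttribute -> UnresolvedAlias: re-aliased once resolution has
// collapsed any intermediate ExtractValue aliases.
new Column(UnresolvedAttribute("a")).named

// Anonymous arithmetic -> Alias named from expr.prettyString, e.g. "(a + 1)".
new Column(Add(UnresolvedAttribute("a"), Literal(1))).named
```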
sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -735,7 +735,7 @@ class DataFrame private[sql](
    */
   @scala.annotation.varargs
   def select(cols: Column*): DataFrame = withPlan {
-    Project(nameColumns(cols), logicalPlan)
+    Project(cols.map(_.named), logicalPlan)
   }
 
   /**