9 changes: 5 additions & 4 deletions python/pyspark/sql/context.py
@@ -86,7 +86,8 @@ def __init__(self, sparkContext, sqlContext=None):
>>> df.registerTempTable("allTypes")
>>> sqlContext.sql('select i+1, d+1, not b, list[1], dict["s"], time, row.a '
... 'from allTypes where b and i > 0').collect()
[Row(c0=2, c1=2.0, c2=False, c3=2, c4=0, time=datetime.datetime(2014, 8, 1, 14, 1, 5), a=1)]
[Row(_c0=2, _c1=2.0, _c2=False, _c3=2, _c4=0, \
time=datetime.datetime(2014, 8, 1, 14, 1, 5), a=1)]
>>> df.map(lambda x: (x.i, x.s, x.d, x.l, x.b, x.time, x.row.a, x.list)).collect()
[(1, u'string', 1.0, 1, True, datetime.datetime(2014, 8, 1, 14, 1, 5), 1, [1, 2, 3])]
"""
@@ -176,17 +177,17 @@ def registerFunction(self, name, f, returnType=StringType()):

>>> sqlContext.registerFunction("stringLengthString", lambda x: len(x))
>>> sqlContext.sql("SELECT stringLengthString('test')").collect()
[Row(c0=u'4')]
[Row(_c0=u'4')]

>>> from pyspark.sql.types import IntegerType
>>> sqlContext.registerFunction("stringLengthInt", lambda x: len(x), IntegerType())
>>> sqlContext.sql("SELECT stringLengthInt('test')").collect()
[Row(c0=4)]
[Row(_c0=4)]

>>> from pyspark.sql.types import IntegerType
>>> sqlContext.udf.register("stringLengthInt", lambda x: len(x), IntegerType())
>>> sqlContext.sql("SELECT stringLengthInt('test')").collect()
[Row(c0=4)]
[Row(_c0=4)]
"""
func = lambda _, it: map(lambda x: f(*x), it)
ser = AutoBatchedSerializer(PickleSerializer())
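The doctest changes above reflect the user-visible effect of this patch: auto-generated column names for unaliased select expressions gain a leading underscore ("c0" becomes "_c0"). A minimal sketch of the difference, assuming a sqlContext and a hypothetical table src(key INT, value STRING):

// Illustrative only; the table and column names are assumptions.
val df = sqlContext.sql("SELECT key + 1, value FROM src")
df.columns
// before this patch: Array("c0", "value")
// after this patch:  Array("_c0", "value")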
@@ -99,13 +99,6 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser {
protected val WHERE = Keyword("WHERE")
protected val WITH = Keyword("WITH")

protected def assignAliases(exprs: Seq[Expression]): Seq[NamedExpression] = {
exprs.zipWithIndex.map {
case (ne: NamedExpression, _) => ne
case (e, i) => Alias(e, s"c$i")()
}
}

protected lazy val start: Parser[LogicalPlan] =
start1 | insert | cte

@@ -130,8 +123,8 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser {
val base = r.getOrElse(OneRowRelation)
val withFilter = f.map(Filter(_, base)).getOrElse(base)
val withProjection = g
.map(Aggregate(_, assignAliases(p), withFilter))
.getOrElse(Project(assignAliases(p), withFilter))
.map(Aggregate(_, p.map(UnresolvedAlias(_)), withFilter))
.getOrElse(Project(p.map(UnresolvedAlias(_)), withFilter))
val withDistinct = d.map(_ => Distinct(withProjection)).getOrElse(withProjection)
val withHaving = h.map(Filter(_, withDistinct)).getOrElse(withDistinct)
val withOrder = o.map(_(withHaving)).getOrElse(withHaving)
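With assignAliases removed from the parser, select-list expressions are wrapped rather than named at parse time, and naming is deferred to the analyzer. A rough sketch of the plan shape the parser now produces (the expression below is illustrative):

import org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, UnresolvedAttribute}
import org.apache.spark.sql.catalyst.expressions.{Add, Literal}
import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project}

// "SELECT key + 1" now parses to a projection of UnresolvedAlias placeholders:
val projectList = Seq(UnresolvedAlias(Add(UnresolvedAttribute("key"), Literal(1))))
val parsed = Project(projectList, OneRowRelation)
// Previously the parser itself emitted Alias(..., "c0")() at this point.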
@@ -17,8 +17,6 @@

package org.apache.spark.sql.catalyst.analysis

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.{SimpleCatalystConf, CatalystConf}
import org.apache.spark.sql.catalyst.expressions._
@@ -74,10 +72,10 @@ class Analyzer(
ResolveSortReferences ::
ResolveGenerate ::
ResolveFunctions ::
ResolveAliases ::
ExtractWindowExpressions ::
GlobalAggregates ::
UnresolvedHavingClauseAttributes ::
TrimGroupingAliases ::
typeCoercionRules ++
extendedResolutionRules : _*)
)
@@ -132,12 +130,38 @@ class Analyzer(
}

/**
* Removes no-op Alias expressions from the plan.
* Replaces [[UnresolvedAlias]]s with concrete aliases.
*/
object TrimGroupingAliases extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case Aggregate(groups, aggs, child) =>
Aggregate(groups.map(_.transform { case Alias(c, _) => c }), aggs, child)
object ResolveAliases extends Rule[LogicalPlan] {
private def assignAliases(exprs: Seq[NamedExpression]) = {
// The `UnresolvedAlias`s will appear only at the root of an expression tree; we
// don't need to transform down the whole tree.
exprs.zipWithIndex.map {
case (u @ UnresolvedAlias(child), i) =>
child match {
case _: UnresolvedAttribute => u
case ne: NamedExpression => ne
case ev: ExtractValueWithStruct => Alias(ev, ev.field.name)()
case g: Generator if g.resolved && g.elementTypes.size > 1 => MultiAlias(g, Nil)
Contributor: Resolve the names for Generator

Contributor Author: But we will do it in ResolveGenerate, right?

case e if !e.resolved => u
case other => Alias(other, s"_c$i")()
}
case (other, _) => other
}
}

def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
case Aggregate(groups, aggs, child)
if child.resolved && aggs.exists(_.isInstanceOf[UnresolvedAlias]) =>
Aggregate(groups, assignAliases(aggs), child)

case g: GroupingAnalytics
if g.child.resolved && g.aggregations.exists(_.isInstanceOf[UnresolvedAlias]) =>
g.withNewAggs(assignAliases(g.aggregations))

case Project(projectList, child)
if child.resolved && projectList.exists(_.isInstanceOf[UnresolvedAlias]) =>
Project(assignAliases(projectList), child)
}
}
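
Stripped of its special cases (struct-field extraction, multi-column generators, unresolved children), the rule reduces to the positional naming the parser used to do, now with an underscore prefix. A simplified sketch, not the full rule:

import org.apache.spark.sql.catalyst.analysis.UnresolvedAlias
import org.apache.spark.sql.catalyst.expressions.{Alias, NamedExpression}

// Simplified form of assignAliases above, for intuition only.
def nameByPosition(exprs: Seq[NamedExpression]): Seq[NamedExpression] =
  exprs.zipWithIndex.map {
    case (UnresolvedAlias(ne: NamedExpression), _) => ne          // already named
    case (UnresolvedAlias(child), i) => Alias(child, s"_c$i")()   // positional name
    case (other, _) => other
  }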

@@ -228,7 +252,7 @@
}

def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case i@InsertIntoTable(u: UnresolvedRelation, _, _, _, _) =>
case i @ InsertIntoTable(u: UnresolvedRelation, _, _, _, _) =>
i.copy(table = EliminateSubQueries(getTable(u)))
case u: UnresolvedRelation =>
getTable(u)
@@ -248,24 +272,24 @@
Project(
projectList.flatMap {
case s: Star => s.expand(child.output, resolver)
case Alias(f @ UnresolvedFunction(_, args), name) if containsStar(args) =>
case UnresolvedAlias(f @ UnresolvedFunction(_, args)) if containsStar(args) =>
val expandedArgs = args.flatMap {
case s: Star => s.expand(child.output, resolver)
case o => o :: Nil
}
Alias(child = f.copy(children = expandedArgs), name)() :: Nil
case Alias(c @ CreateArray(args), name) if containsStar(args) =>
UnresolvedAlias(child = f.copy(children = expandedArgs)) :: Nil
case UnresolvedAlias(c @ CreateArray(args)) if containsStar(args) =>
val expandedArgs = args.flatMap {
case s: Star => s.expand(child.output, resolver)
case o => o :: Nil
}
Alias(c.copy(children = expandedArgs), name)() :: Nil
case Alias(c @ CreateStruct(args), name) if containsStar(args) =>
UnresolvedAlias(c.copy(children = expandedArgs)) :: Nil
case UnresolvedAlias(c @ CreateStruct(args)) if containsStar(args) =>
val expandedArgs = args.flatMap {
case s: Star => s.expand(child.output, resolver)
case o => o :: Nil
}
Alias(c.copy(children = expandedArgs), name)() :: Nil
UnresolvedAlias(c.copy(children = expandedArgs)) :: Nil
case o => o :: Nil
},
child)
@@ -353,7 +377,9 @@
case u @ UnresolvedAttribute(nameParts) =>
// Leave unchanged if resolution fails. Hopefully will be resolved next round.
val result =
withPosition(u) { q.resolveChildren(nameParts, resolver).getOrElse(u) }
withPosition(u) {
q.resolveChildren(nameParts, resolver).map(trimUnresolvedAlias).getOrElse(u)
}
logDebug(s"Resolving $u to $result")
result
case UnresolvedExtractValue(child, fieldExpr) if child.resolved =>
@@ -379,6 +405,11 @@
exprs.exists(_.collect { case _: Star => true }.nonEmpty)
}

private def trimUnresolvedAlias(ne: NamedExpression) = ne match {
case UnresolvedAlias(child) => child
case other => other
}

private def resolveSortOrders(ordering: Seq[SortOrder], plan: LogicalPlan, throws: Boolean) = {
ordering.map { order =>
// Resolve SortOrder in one round.
@@ -388,7 +419,7 @@
try {
val newOrder = order transformUp {
case u @ UnresolvedAttribute(nameParts) =>
plan.resolve(nameParts, resolver).getOrElse(u)
plan.resolve(nameParts, resolver).map(trimUnresolvedAlias).getOrElse(u)
case UnresolvedExtractValue(child, fieldName) if child.resolved =>
ExtractValue(child, fieldName, resolver)
}
@@ -586,18 +617,6 @@
/** Extracts a [[Generator]] expression and any names assigned by aliases to their output. */
private object AliasedGenerator {
def unapply(e: Expression): Option[(Generator, Seq[String])] = e match {
case Alias(g: Generator, name)
if g.resolved &&
g.elementTypes.size > 1 &&
java.util.regex.Pattern.matches("_c[0-9]+", name) => {
// Assume the default name given by parser is "_c[0-9]+",
// TODO in long term, move the naming logic from Parser to Analyzer.
// In projection, Parser gave default name for TGF as does for normal UDF,
// but the TGF probably have multiple output columns/names.
// e.g. SELECT explode(map(key, value)) FROM src;
// Let's simply ignore the default given name for this case.
Some((g, Nil))
}
case Alias(g: Generator, name) if g.resolved && g.elementTypes.size > 1 =>
// If not given the default names, and the TGF with multiple output columns
failAnalysis(
@@ -95,14 +95,7 @@ trait CheckAnalysis {
case e => e.children.foreach(checkValidAggregateExpression)
}

val cleaned = aggregateExprs.map(_.transform {
// Should trim aliases around `GetField`s. These aliases are introduced while
// resolving struct field accesses, because `GetField` is not a `NamedExpression`.
// (Should we just turn `GetField` into a `NamedExpression`?)
case Alias(g, _) => g
})

cleaned.foreach(checkValidAggregateExpression)
aggregateExprs.foreach(checkValidAggregateExpression)

case _ => // Fallbacks to the following checks
}
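The trimming step is obsolete because struct-field accesses are no longer eagerly wrapped in an Alias during resolution: a grouping expression and its occurrence in the aggregate list now resolve to the same bare extraction. A sketch of why the direct comparison succeeds (types and names are assumptions):

import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions.GetStructField
import org.apache.spark.sql.types.{IntegerType, StructField}

// For GROUP BY s.a with SELECT s.a, both sides now resolve to this expression:
val extraction = GetStructField(UnresolvedAttribute("s"), StructField("a", IntegerType), 0)
extraction semanticEquals extraction // true
// Previously the SELECT side carried Alias(extraction, "a")(), which had to be
// stripped before the comparison could succeed.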
@@ -17,7 +17,6 @@

package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.catalyst
import org.apache.spark.sql.catalyst.{errors, trees}
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.expressions._
@@ -206,3 +205,22 @@ case class UnresolvedExtractValue(child: Expression, extraction: Expression)

override def toString: String = s"$child[$extraction]"
}

/**
* Holds the expression that has yet to be aliased.
*/
case class UnresolvedAlias(child: Expression) extends NamedExpression
with trees.UnaryNode[Expression] {

override def toAttribute: Attribute = throw new UnresolvedException(this, "toAttribute")
override def qualifiers: Seq[String] = throw new UnresolvedException(this, "qualifiers")
override def exprId: ExprId = throw new UnresolvedException(this, "exprId")
override def nullable: Boolean = throw new UnresolvedException(this, "nullable")
override def dataType: DataType = throw new UnresolvedException(this, "dataType")
override def name: String = throw new UnresolvedException(this, "name")

override lazy val resolved = false

override def eval(input: InternalRow = null): Any =
throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")
}
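
Every accessor on UnresolvedAlias throws, so the node is strictly a placeholder: it can occupy a slot in a projection list (it is a NamedExpression), but the plan stays unresolved until ResolveAliases rewrites it. A small sketch with an illustrative expression:

import org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, UnresolvedAttribute}
import org.apache.spark.sql.catalyst.expressions.{Add, Literal}

val ua = UnresolvedAlias(Add(UnresolvedAttribute("a"), Literal(1)))
ua.resolved // false
// ua.name, ua.dataType, ua.toAttribute, ... all throw UnresolvedException
// until the analyzer replaces the node with a concrete Alias.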
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions

import scala.collection.Map

import org.apache.spark.sql.{catalyst, AnalysisException}
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.types._

@@ -41,16 +41,22 @@ object ExtractValue {
resolver: Resolver): ExtractValue = {

(child.dataType, extraction) match {
case (StructType(fields), Literal(fieldName, StringType)) =>
val ordinal = findField(fields, fieldName.toString, resolver)
GetStructField(child, fields(ordinal), ordinal)
case (ArrayType(StructType(fields), containsNull), Literal(fieldName, StringType)) =>
val ordinal = findField(fields, fieldName.toString, resolver)
GetArrayStructFields(child, fields(ordinal), ordinal, containsNull)
case (StructType(fields), NonNullLiteral(v, StringType)) =>
val fieldName = v.toString
val ordinal = findField(fields, fieldName, resolver)
GetStructField(child, fields(ordinal).copy(name = fieldName), ordinal)

case (ArrayType(StructType(fields), containsNull), NonNullLiteral(v, StringType)) =>
val fieldName = v.toString
val ordinal = findField(fields, fieldName, resolver)
GetArrayStructFields(child, fields(ordinal).copy(name = fieldName), ordinal, containsNull)

case (_: ArrayType, _) if extraction.dataType.isInstanceOf[IntegralType] =>
GetArrayItem(child, extraction)

case (_: MapType, _) =>
GetMapValue(child, extraction)

case (otherType, _) =>
val errorMsg = otherType match {
case StructType(_) | ArrayType(StructType(_), _) =>
@@ -94,16 +100,21 @@ trait ExtractValue extends UnaryExpression {
self: Product =>
}

abstract class ExtractValueWithStruct extends ExtractValue {
self: Product =>

def field: StructField
override def toString: String = s"$child.${field.name}"
}

/**
* Returns the value of fields in the Struct `child`.
*/
case class GetStructField(child: Expression, field: StructField, ordinal: Int)
extends ExtractValue {
extends ExtractValueWithStruct {

override def dataType: DataType = field.dataType
override def nullable: Boolean = child.nullable || field.nullable
override def foldable: Boolean = child.foldable
override def toString: String = s"$child.${field.name}"

override def eval(input: InternalRow): Any = {
val baseValue = child.eval(input).asInstanceOf[InternalRow]
@@ -118,12 +129,9 @@ case class GetArrayStructFields(
child: Expression,
field: StructField,
ordinal: Int,
containsNull: Boolean) extends ExtractValue {
containsNull: Boolean) extends ExtractValueWithStruct {

override def dataType: DataType = ArrayType(field.dataType, containsNull)
override def nullable: Boolean = child.nullable
override def foldable: Boolean = child.foldable
override def toString: String = s"$child.${field.name}"

override def eval(input: InternalRow): Any = {
val baseValue = child.eval(input).asInstanceOf[Seq[InternalRow]]
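Factoring field and toString into ExtractValueWithStruct also gives ResolveAliases a single type to match when it names a struct-field access after the field itself (the Alias(ev, ev.field.name)() case in the Analyzer diff above). For illustration, with a hypothetical struct column s:

import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions.GetStructField
import org.apache.spark.sql.types.{IntegerType, StructField}

val ev = GetStructField(UnresolvedAttribute("s"), StructField("a", IntegerType), 0)
ev.field.name // "a", so SELECT s.a yields a column named "a" rather than "_c0"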
@@ -156,12 +156,8 @@ object PartialAggregation {
partialEvaluations(new TreeNodeRef(e)).finalEvaluation

case e: Expression =>
// Should trim aliases around `GetField`s. These aliases are introduced while
// resolving struct field accesses, because `GetField` is not a `NamedExpression`.
// (Should we just turn `GetField` into a `NamedExpression`?)
val trimmed = e.transform { case Alias(g: ExtractValue, _) => g }
namedGroupingExpressions.collectFirst {
case (expr, ne) if expr semanticEquals trimmed => ne.toAttribute
case (expr, ne) if expr semanticEquals e => ne.toAttribute
}.getOrElse(e)
}).asInstanceOf[Seq[NamedExpression]]

@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.plans.logical

import org.apache.spark.Logging
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, EliminateSubQueries, Resolver}
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.trees.TreeNode
@@ -252,14 +252,13 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
// One match, but we also need to extract the requested nested field.
case Seq((a, nestedFields)) =>
// The foldLeft adds ExtractValues for every remaining part of the identifier,
// and aliases it with the last part of the identifier.
// and wraps it in UnresolvedAlias, which will be removed later.
// For example, consider "a.b.c", where "a" is resolved to an existing attribute.
// Then this will add ExtractValue("c", ExtractValue("b", a)), and alias
// the final expression as "c".
// Then this will add ExtractValue("c", ExtractValue("b", a)), and wrap it as
// UnresolvedAlias(ExtractValue("c", ExtractValue("b", a))).
val fieldExprs = nestedFields.foldLeft(a: Expression)((expr, fieldName) =>
ExtractValue(expr, Literal(fieldName), resolver))
val aliasName = nestedFields.last
Some(Alias(fieldExprs, aliasName)())
Some(UnresolvedAlias(fieldExprs))

// No matches.
case Seq() =>
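Concretely, for a reference "a.b.c" where only "a" matches an attribute, the foldLeft builds the nested extraction and a single UnresolvedAlias wraps the result; ResolveAliases later recovers the name "c" from the extracted field. A sketch under assumed types:

import org.apache.spark.sql.catalyst.analysis.caseInsensitiveResolution
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._

// An attribute "a" of type struct<b: struct<c: int>>, for illustration.
val a = AttributeReference("a",
  StructType(Seq(StructField("b", StructType(Seq(StructField("c", IntegerType)))))))()
val fieldExprs = Seq("b", "c").foldLeft(a: Expression) { (expr, fieldName) =>
  ExtractValue(expr, Literal(fieldName), caseInsensitiveResolution)
}
// fieldExprs is GetStructField(GetStructField(a, ..., 0), ..., 0); resolve()
// now returns UnresolvedAlias(fieldExprs) instead of Alias(fieldExprs, "c")().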