2 changes: 1 addition & 1 deletion python/pyspark/sql/context.py
@@ -84,7 +84,7 @@ def __init__(self, sparkContext, sqlContext=None):
>>> df.registerTempTable("allTypes")
>>> sqlCtx.sql('select i+1, d+1, not b, list[1], dict["s"], time, row.a '
... 'from allTypes where b and i > 0').collect()
- [Row(c0=2, c1=2.0, c2=False, c3=2, c4=0...8, 1, 14, 1, 5), a=1)]
+ [Row(c0=2, c1=2.0, c2=False, c3=2, c4=0...8, 1, 14, 1, 5), row.a=1)]
>>> df.map(lambda x: (x.i, x.s, x.d, x.l, x.b, x.time,
... x.row.a, x.list)).collect()
[(1, u'string', 1.0, 1, True, ...(2014, 8, 1, 14, 1, 5), 1, [1, 2, 3])]
@@ -385,7 +385,7 @@ class SqlParser extends AbstractSparkSQLParser {

protected lazy val dotExpressionHeader: Parser[Expression] =
(ident <~ ".") ~ ident ~ rep("." ~> ident) ^^ {
-      case i1 ~ i2 ~ rest => UnresolvedAttribute(i1 + "." + i2 + rest.mkString(".", ".", ""))
+      case i1 ~ i2 ~ rest => UnresolvedAttribute((Seq(i1, i2) ++ rest).mkString("."))
Contributor Author:

val i1 = "a"
val i2 = "b"
val rest: Seq[String] = Nil
println(i1 + "." + i2 + rest.mkString(".", ".", ""))

This outputs "a.b." (note the trailing dot), but what we expect is "a.b".

Contributor:

It was my mistake... I didn't test the mkString method on Nil.

Contributor:

How about:
UnresolvedAttribute(i1 + "." + i2 + rest.mkString("."))

Contributor Author:

val i1 = "a"
val i2 = "b"
val rest: Seq[String] = "c" :: Nil
println(i1 + "." + i2 + rest.mkString("."))

This outputs "a.bc" instead of the expected "a.b.c".

Contributor:

That makes sense, thanks!
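
For reference, a minimal sketch (plain Scala, runnable in a REPL) comparing the three variants discussed in this thread on both an empty and a non-empty rest:

val i1 = "a"
val i2 = "b"

for (rest <- Seq(Nil, List("c"))) {
  // Original code: mkString(start, sep, end) on an empty list still emits
  // start + end, which leaves a trailing dot ("a.b.").
  val original = i1 + "." + i2 + rest.mkString(".", ".", "")
  // Suggested alternative: correct on Nil, but drops the separator before
  // the first element otherwise ("a.bc").
  val alternative = i1 + "." + i2 + rest.mkString(".")
  // The fix in this diff: correct in both cases ("a.b" and "a.b.c").
  val fixed = (Seq(i1, i2) ++ rest).mkString(".")
  println(s"rest=$rest  original=$original  alternative=$alternative  fixed=$fixed")
}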

}

protected lazy val dataType: Parser[DataType] =
@@ -268,7 +268,9 @@ class Analyzer(catalog: Catalog,
logDebug(s"Resolving $u to $result")
result
case UnresolvedGetField(child, fieldName) if child.resolved =>
-        resolveGetField(child, fieldName)
+        val result = q.resolveGetField(child, fieldName, resolver)
+        logDebug(s"Resolving $fieldName of $child to $result")
+        result
}
}

@@ -277,36 +279,6 @@ class Analyzer(catalog: Catalog,
*/
protected def containsStar(exprs: Seq[Expression]): Boolean =
exprs.exists(_.collect { case _: Star => true }.nonEmpty)

-  /**
-   * Returns the resolved `GetField`, and report error if no desired field or over one
-   * desired fields are found.
-   */
-  protected def resolveGetField(expr: Expression, fieldName: String): Expression = {
-    def findField(fields: Array[StructField]): Int = {
-      val checkField = (f: StructField) => resolver(f.name, fieldName)
-      val ordinal = fields.indexWhere(checkField)
-      if (ordinal == -1) {
-        throw new AnalysisException(
-          s"No such struct field $fieldName in ${fields.map(_.name).mkString(", ")}")
-      } else if (fields.indexWhere(checkField, ordinal + 1) != -1) {
-        throw new AnalysisException(
-          s"Ambiguous reference to fields ${fields.filter(checkField).mkString(", ")}")
-      } else {
-        ordinal
-      }
-    }
-    expr.dataType match {
-      case StructType(fields) =>
-        val ordinal = findField(fields)
-        StructGetField(expr, fields(ordinal), ordinal)
-      case ArrayType(StructType(fields), containsNull) =>
-        val ordinal = findField(fields)
-        ArrayGetField(expr, fields(ordinal), ordinal, containsNull)
-      case otherType =>
-        throw new AnalysisException(s"GetField is not valid on fields of type $otherType")
-    }
-  }
}

/**
@@ -320,8 +292,7 @@ class Analyzer(catalog: Catalog,
case s @ Sort(ordering, global, p @ Project(projectList, child))
if !s.resolved && p.resolved =>
val unresolved = ordering.flatMap(_.collect { case UnresolvedAttribute(name) => name })
-        val resolved = unresolved.flatMap(child.resolve(_, resolver))
-        val requiredAttributes = AttributeSet(resolved.collect { case a: Attribute => a })
+        val requiredAttributes = AttributeSet(unresolved.flatMap(child.resolve(_, resolver)))

val missingInProject = requiredAttributes -- p.output
if (missingInProject.nonEmpty) {
@@ -19,11 +19,12 @@ package org.apache.spark.sql.catalyst.plans.logical

import org.apache.spark.Logging
import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.analysis.{UnresolvedGetField, Resolver}
+import org.apache.spark.sql.catalyst.analysis.Resolver
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.trees.TreeNode
import org.apache.spark.sql.catalyst.trees
+import org.apache.spark.sql.types.{ArrayType, StructType, StructField}


abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
@@ -192,14 +193,17 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {

// One match, but we also need to extract the requested nested field.
case Seq((a, nestedFields)) =>
-      // The foldLeft adds UnresolvedGetField for every remaining parts of the name,
-      // and aliased it with the last part of the name.
-      // For example, consider name "a.b.c", where "a" is resolved to an existing attribute.
-      // Then this will add UnresolvedGetField("b") and UnresolvedGetField("c"), and alias
-      // the final expression as "c".
-      val fieldExprs = nestedFields.foldLeft(a: Expression)(UnresolvedGetField)
-      val aliasName = nestedFields.last
-      Some(Alias(fieldExprs, aliasName)())
+      // The foldLeft resolves each remaining part of the name against the data
+      // type of the expression built so far, to get the matching nested field.
+      val fieldExprs = nestedFields.foldLeft(a: Expression) { case (e, fieldName) =>
+        resolveGetField(e, fieldName, resolver)
+      }
+
+      // TODO the alias name is quite tricky to me; set it to _col1, _col2, ...?
+      // Setting it to the original attribute name like "a.b.c" still seems confusing,
+      // and we may never reference this column by its name (with "."), except when
+      // people write SQL like: SELECT a.b.c AS newCol FROM nestedTable, which
+      // explicitly specifies the alias name for the output column.
+      Some(Alias(fieldExprs, name)())
Member:

The Python test failure is caused by replacing aliasName with name here. Is that okay? SELECT a.b.c FROM table now gets an attribute named a.b.c, instead of c as before.

Contributor Author:

Thank you @viirya, I've updated the Python code.

Member:

I meant: should it be that? In Hive it would be c instead of a.b.c?

Contributor Author:

I am not so sure how Hive handles that, but it cannot be c; otherwise references may become ambiguous in the parent logical plan.

E.g., assume we have a table tbl with schema Struct<a: Struct<b: Int, c: Int>, b: Int>:

SELECT b FROM (SELECT a.b, b FROM tbl)
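
A minimal reproduction sketch of that collision, using the jsonRDD/registerTempTable API that appears in this PR's tests (the table name and sample data are illustrative):

// Hypothetical session. If a.b were aliased as plain "b", the inner query
// would expose two columns named "b" and the outer reference would be
// ambiguous; with the full name "a.b" as the alias, the outer "b" can only
// mean the top-level column.
sqlCtx.jsonRDD(sqlCtx.sparkContext.parallelize(
  """{"a": {"b": 1, "c": 2}, "b": 3}""" :: Nil)).registerTempTable("tbl")
sqlCtx.sql("SELECT b FROM (SELECT a.b, b FROM tbl) t").collect()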

Contributor:

I don't think we can change the default alias when extracting nested fields. I believe we match Hive's behavior now, and this would break existing queries.

Contributor Author:

Yes, I agree we shouldn't break the existing logic, but I believe this is a bug in Hive.

hive> create table struct1 as select named_struct("a", key, "b", value) as a, key as b from src limit 1;
hive> select a.b, b from struct1;    -- works
hive> create table struct2 as select a.b, b from struct1;
FAILED: SemanticException [Error 10036]: Duplicate column name: b

I am wondering if we can break Hive's naming rule for nested data type references, which always causes ambiguity.

Contributor:

How is this a bug? These are pretty contrived examples. How often do you actually have nested structures where the outside name is the same as the inside name?

Contributor Author:

Yeah, I shouldn't say "always" but rather "possibly"; it may happen quite often with joins.


// No matches.
case Seq() =>
@@ -212,6 +216,36 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
s"Ambiguous references to $name: ${ambiguousReferences.mkString(",")}")
}
}

+  /**
+   * Returns the resolved `GetField`, and reports an error if no matching field
+   * or more than one matching field is found.
+   */
+  def resolveGetField(expr: Expression, fieldName: String, resolver: Resolver): Expression = {
+    def findField(fields: Array[StructField]): Int = {
+      val checkField = (f: StructField) => resolver(f.name, fieldName)
+      val ordinal = fields.indexWhere(checkField)
+      if (ordinal == -1) {
+        throw new AnalysisException(
+          s"No such struct field $fieldName in ${fields.map(_.name).mkString(", ")}")
+      } else if (fields.indexWhere(checkField, ordinal + 1) != -1) {
+        throw new AnalysisException(
+          s"Ambiguous reference to fields ${fields.filter(checkField).mkString(", ")}")
+      } else {
+        ordinal
+      }
+    }
+    expr.dataType match {
+      case StructType(fields) =>
+        val ordinal = findField(fields)
+        StructGetField(expr, fields(ordinal), ordinal)
+      case ArrayType(StructType(fields), containsNull) =>
+        val ordinal = findField(fields)
+        ArrayGetField(expr, fields(ordinal), ordinal, containsNull)
+      case otherType =>
+        throw new AnalysisException(s"GetField is not valid on fields of type $otherType")
+    }
+  }
}
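
As a side note, here is a minimal sketch (plain Scala, outside Catalyst; Expr, Attr, and GetField are stand-in types, not Catalyst classes) of the chain the foldLeft above builds when resolving a name like "a.b.c":

sealed trait Expr
case class Attr(name: String) extends Expr
case class GetField(child: Expr, field: String) extends Expr

// "a" has been resolved to an attribute; "b" and "c" are the remaining parts.
val nestedFields = Seq("b", "c")
val resolved = nestedFields.foldLeft(Attr("a"): Expr)(GetField)
println(resolved)  // GetField(GetField(Attr(a),b),c)

Each remaining name part wraps the expression one level deeper (resolveGetField produces the analogous StructGetField/ArrayGetField nodes), and after this change the whole chain is aliased with the full name "a.b.c" rather than just "c".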

/**
18 changes: 17 additions & 1 deletion sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -43,6 +43,22 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
)
}

test("SPARK-6145 order by the nested data #1") {
sqlCtx.jsonRDD(sqlCtx.sparkContext.parallelize(
"""{"a": {"b": {"d": 1}}, "c": 1}""" :: Nil)).registerTempTable("nestedOrder")

checkAnswer(sqlCtx.sql("SELECT 1 FROM nestedOrder ORDER BY c"), Row(1))
checkAnswer(sqlCtx.sql("SELECT 1 FROM nestedOrder ORDER BY a.b.d"), Row(1))
checkAnswer(sqlCtx.sql("SELECT a.b.d FROM nestedOrder ORDER BY a.b.d"), Row(1))
}

test("SPARK-6145 order by the nested data #2") {
sqlCtx.jsonRDD(sqlCtx.sparkContext.parallelize(
"""{"a": {"a": {"a": 1}}, "c": 1}""" :: Nil)).registerTempTable("nestedOrder")

checkAnswer(sqlCtx.sql("SELECT a.a.a FROM nestedOrder ORDER BY a.a.a"), Row(1))
}

test("grouping on nested fields") {
jsonRDD(sparkContext.parallelize("""{"nested": {"attribute": 1}, "value": 2}""" :: Nil))
.registerTempTable("rows")
@@ -52,7 +68,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
"""
|select attribute, sum(cnt)
|from (
-        |  select nested.attribute, count(*) as cnt
+        |  select nested.attribute as attribute, count(*) as cnt
| from rows
| group by nested.attribute) a
|group by attribute
@@ -80,7 +80,7 @@ class HiveResolutionSuite extends HiveComparisonTest {
.toDF().registerTempTable("caseSensitivityTest")

val query = sql("SELECT a, b, A, B, n.a, n.b, n.A, n.B FROM caseSensitivityTest")
-    assert(query.schema.fields.map(_.name) === Seq("a", "b", "A", "B", "a", "b", "A", "B"),
+    assert(query.schema.fields.map(_.name) === Seq("a", "b", "A", "B", "n.a", "n.b", "n.A", "n.B"),
"The output schema did not preserve the case of the query.")
query.collect()
}
@@ -48,6 +48,15 @@ class SQLQuerySuite extends QueryTest {
Row(1) :: Row(2) :: Row(3) :: Nil)
}

test("SPARK-6145 insert into table by selecting data from a nested table") {
jsonRDD(sparkContext.parallelize(
"""{"a": {"a": {"a": 1}}, "c": 1}""" :: Nil)).registerTempTable("nestedOrder")

sql("CREATE TABLE gen_tmp_6145 (key Int)")
sql("INSERT INTO table gen_tmp_6145 SELECT a.a.a from nestedOrder")
sql("DROP TABLE gen_tmp_6145")
}

test("SPARK-4512 Fix attribute reference resolution error when using SORT BY") {
checkAnswer(
sql("SELECT * FROM (SELECT key + key AS a FROM src SORT BY value) t ORDER BY t.a"),