Skip to content

Commit 85f3106

Browse files
committed
add CreateStruct
1 parent 4fc4d03 commit 85f3106

File tree

3 files changed

+40
-1
lines changed

3 files changed

+40
-1
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,12 @@ class Analyzer(catalog: Catalog,
212212
case o => o :: Nil
213213
}
214214
Alias(c.copy(children = expandedArgs), name)() :: Nil
215+
case Alias(c @ CreateStruct(args), name) if containsStar(args) =>
216+
val expandedArgs = args.flatMap {
217+
case s: Star => s.expand(child.output, resolver)
218+
case o => o :: Nil
219+
}
220+
Alias(c.copy(children = expandedArgs), name)() :: Nil
215221
case o => o :: Nil
216222
},
217223
child)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ case class ArrayGetField(child: Expression, field: StructField, ordinal: Int, co
120120
case class CreateArray(children: Seq[Expression]) extends Expression {
121121
override type EvaluatedType = Any
122122

123-
override def foldable: Boolean = !children.exists(!_.foldable)
123+
override def foldable: Boolean = children.forall(_.foldable)
124124

125125
lazy val childTypes = children.map(_.dataType).distinct
126126

@@ -142,3 +142,29 @@ case class CreateArray(children: Seq[Expression]) extends Expression {
142142

143143
override def toString: String = s"Array(${children.mkString(",")})"
144144
}
145+
146+
/**
147+
* Returns a Row containing the evaluation of all children expressions.
148+
*/
149+
case class CreateStruct(children: Seq[Expression]) extends Expression {
150+
override type EvaluatedType = Row
151+
152+
override def foldable: Boolean = children.forall(_.foldable)
153+
154+
override lazy val resolved: Boolean = childrenResolved
155+
156+
override def dataType: StructType = {
157+
assert(resolved, s"CreateStruct is called with unresolved children: $children.")
158+
val fields = children.map {
159+
case named: NamedExpression =>
160+
StructField(named.name, named.dataType, named.nullable, named.metadata)
161+
}
162+
StructType(fields)
163+
}
164+
165+
override def nullable: Boolean = false
166+
167+
override def eval(input: Row): EvaluatedType = {
168+
Row(children.map(_.eval(input)): _*)
169+
}
170+
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1080,4 +1080,11 @@ class ExpressionEvaluationSuite extends FunSuite {
10801080
checkEvaluation(c1 ^ c2, 3, row)
10811081
checkEvaluation(~c1, -2, row)
10821082
}
1083+
1084+
test("CreateStruct") {
1085+
val row = Row(1, 2, 3)
1086+
val c1 = 'a.int.at(0)
1087+
val c3 = 'a.int.at(2)
1088+
checkEvaluation(CreateStruct(Seq(c1, c3)), Row(1, 3), row)
1089+
}
10831090
}

0 commit comments

Comments
 (0)