@@ -35,13 +35,13 @@ case class FilterNode(condition: Expression, child: LocalNode) extends UnaryLocalNode

   override def next(): Boolean = {
     var found = false
-    while (child.next() && !found) {
-      found = predicate.apply(child.get())
+    while (!found && child.next()) {
+      found = predicate.apply(child.fetch())
     }
     found
   }
 
-  override def get(): InternalRow = child.get()
+  override def fetch(): InternalRow = child.fetch()
 
   override def close(): Unit = child.close()
 }
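For context, every LocalNode follows a pull-based open()/next()/fetch()/close() protocol, and the get() → fetch() rename is the bulk of this diff. The reordered loop condition above also fixes a subtle bug: by checking !found first, the filter no longer advances the child past the row it just matched. A minimal consumer sketch, assuming a `node: LocalNode` built elsewhere; the try/finally mirrors the hardened collect() further down:

```scala
// Minimal consumer sketch; assumes `node: LocalNode` was built elsewhere.
node.open()
try {
  while (node.next()) {                  // advance; returns false once exhausted
    val row: InternalRow = node.fetch()  // current row, valid only after next() == true
    // ... consume row ...
  }
} finally {
  node.close()                           // released even if next()/fetch() throws
}
```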
@@ -0,0 +1,45 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.local

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute


case class LimitNode(limit: Int, child: LocalNode) extends UnaryLocalNode {
Contributor: Is there a need to distinguish Unary operators from others?

Contributor: Maybe more generally, if we are never going to do transformations of these iterator trees, do they need to inherit from TreeNode?

Member (author): I think we still need filter or map for these iterator trees. @rxin, is there anything I've misunderstood about the LocalNode design?

Contributor: It may or may not. We can remove them later though.
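For orientation in the discussion above: a sketch of the hierarchy as it can be reconstructed from the hunks in this diff. The exact definition of UnaryLocalNode is not shown here, so its body is an assumption:

```scala
// Sketch of the interface as it appears in this diff; not the complete file.
abstract class LocalNode extends TreeNode[LocalNode] {
  def output: Seq[Attribute]
  def open(): Unit
  def next(): Boolean
  def fetch(): InternalRow
  def close(): Unit
}

// UnaryLocalNode presumably pins `children` to a single child; that is
// what distinguishes unary operators from the n-ary UnionNode below.
abstract class UnaryLocalNode extends LocalNode {
  def child: LocalNode
  override def children: Seq[LocalNode] = Seq(child)
}
```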


  private[this] var count = 0

  override def output: Seq[Attribute] = child.output
Contributor: Why do iterators need to know their output?

Contributor: Hmm, I guess this is useful for collect, which is nice for debugging.
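Concretely, the collect() hunk further down is one consumer of output: the schema feeds the InternalRow-to-Row converter.

```scala
// From the collect() hunk below: `output` supplies the schema for
// converting each InternalRow to a Scala Row.
val converter =
  CatalystTypeConverters.createToScalaConverter(StructType.fromAttributes(output))
```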


  override def open(): Unit = child.open()
Contributor: Should this also reset the count?

Member (author): LocalNode cannot be reused, just like Iterator. So it's not necessary to reset it.

Contributor: Actually, it'd be great to have a reset method in addition to open so we can revisit the result of the iterator. We can do that later.
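A minimal sketch of the reviewer's suggestion (hypothetical, not part of this PR): having open() reset the counter would be the smallest change toward re-openable nodes.

```scala
// Hypothetical variant, not in this PR: re-opening resets the limit counter.
override def open(): Unit = {
  count = 0
  child.open()
}
```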


  override def close(): Unit = child.close()

  override def fetch(): InternalRow = child.fetch()

  override def next(): Boolean = {
    if (count < limit) {
      count += 1
      child.next()
    } else {
      false
    }
  }

}
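A hypothetical usage sketch: composing LimitNode over the SeqScanNode from this PR and draining it with collect(). The `attributes` and `rows` values are assumed inputs.

```scala
// Hypothetical composition; `attributes: Seq[Attribute]` and
// `rows: Seq[InternalRow]` are assumed to be defined elsewhere.
val plan = LimitNode(10, SeqScanNode(attributes, rows))
val firstTen: Seq[Row] = plan.collect()  // collect() drives open()/next()/fetch()/close()
```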
@@ -48,10 +48,10 @@ abstract class LocalNode extends TreeNode[LocalNode] {
   /**
    * Returns the current tuple.
    */
-  def get(): InternalRow
+  def fetch(): InternalRow
 
   /**
-   * Closes the iterator and releases all resources.
+   * Closes the iterator and releases all resources. It should be idempotent.
    *
    * Implementations of this must also call the `close()` function of its children.
    */
@@ -64,10 +64,13 @@ abstract class LocalNode extends TreeNode[LocalNode] {
     val converter = CatalystTypeConverters.createToScalaConverter(StructType.fromAttributes(output))
     val result = new scala.collection.mutable.ArrayBuffer[Row]
     open()
-    while (next()) {
-      result += converter.apply(get()).asInstanceOf[Row]
+    try {
+      while (next()) {
+        result += converter.apply(fetch()).asInstanceOf[Row]
+      }
+    } finally {
+      close()
     }
-    close()
     result
   }
 }
@@ -34,8 +34,8 @@ case class ProjectNode(projectList: Seq[NamedExpression], child: LocalNode) extends UnaryLocalNode

   override def next(): Boolean = child.next()
 
-  override def get(): InternalRow = {
-    project.apply(child.get())
+  override def fetch(): InternalRow = {
+    project.apply(child.fetch())
   }
 
   override def close(): Unit = child.close()
@@ -41,7 +41,7 @@ case class SeqScanNode(output: Seq[Attribute], data: Seq[InternalRow]) extends LocalNode
     }
   }
 
-  override def get(): InternalRow = currentRow
+  override def fetch(): InternalRow = currentRow
 
   override def close(): Unit = {
     // Do nothing
@@ -0,0 +1,72 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.local

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute

case class UnionNode(children: Seq[LocalNode]) extends LocalNode {
Contributor: Shall we add an assert to make sure all children have the same output?

Member (author): Different children may have different AttributeReferences in their output, although the types are the same. I checked org.apache.spark.sql.execution.Union and it also doesn't assert this. It should be safe, since this is a physical operator.
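For illustration, a guard along the lines discussed above could compare column types instead of attribute identity; hypothetical, not in the PR:

```scala
// Hypothetical guard, not in this PR: compare column types rather than
// attribute identity, since AttributeReferences legitimately differ.
require(
  children.tail.forall(_.output.map(_.dataType) == children.head.output.map(_.dataType)),
  "UnionNode children must produce the same column types")
```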

Contributor: Consider making this an Array[LocalNode]. In general, we should probably only be using Array at this level of execution.

Contributor: +1 on making this an Array.

Member (author): It overrides TreeNode.children, which uses Seq.


  override def output: Seq[Attribute] = children.head.output

  private[this] var currentChild: LocalNode = _

  private[this] var nextChildIndex: Int = _

  override def open(): Unit = {
    currentChild = children.head
    currentChild.open()
    nextChildIndex = 1
  }

  // Closes the current child, then opens subsequent children until one of
  // them produces a row or all children are exhausted.
  private def advanceToNextChild(): Boolean = {
    var found = false
    var exit = false
    while (!exit && !found) {
      if (currentChild != null) {
        currentChild.close()
      }
      if (nextChildIndex >= children.size) {
        found = false
        exit = true
      } else {
        currentChild = children(nextChildIndex)
        nextChildIndex += 1
        currentChild.open()
        found = currentChild.next()
      }
    }
    found
  }

  override def close(): Unit = {
    if (currentChild != null) {
      currentChild.close()
    }
  }

  override def fetch(): InternalRow = currentChild.fetch()

  override def next(): Boolean = {
    if (currentChild.next()) {
      true
    } else {
      advanceToNextChild()
    }
  }
}
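A hypothetical usage sketch: UnionNode concatenates its children, opening each one lazily as the previous child is exhausted. `attributes`, `rowsA` and `rowsB` are assumed inputs with matching schemas.

```scala
// Hypothetical usage; `attributes`, `rowsA` and `rowsB` are assumed inputs.
val union = UnionNode(Seq(
  SeqScanNode(attributes, rowsA),
  SeqScanNode(attributes, rowsB)))
val all = union.collect()  // rowsA in order, then rowsB
```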
@@ -24,7 +24,7 @@ import scala.util.control.NonFatal
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.{DataFrame, DataFrameHolder, Row, SQLContext}
 import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
-import org.apache.spark.sql.catalyst.util._
+import org.apache.spark.sql.test.SQLTestUtils

 /**
  * Base class for writing tests for individual physical operators. For an example of how this
@@ -184,7 +184,7 @@ object SparkPlanTest {
       return Some(errorMessage)
     }
 
-    compareAnswers(actualAnswer, expectedAnswer, sortAnswers).map { errorMessage =>
+    SQLTestUtils.compareAnswers(actualAnswer, expectedAnswer, sortAnswers).map { errorMessage =>
       s"""
          | Results do not match.
          | Actual result Spark plan:
@@ -229,7 +229,7 @@ object SparkPlanTest {
       return Some(errorMessage)
     }
 
-    compareAnswers(sparkAnswer, expectedAnswer, sortAnswers).map { errorMessage =>
+    SQLTestUtils.compareAnswers(sparkAnswer, expectedAnswer, sortAnswers).map { errorMessage =>
       s"""
          | Results do not match for Spark plan:
          | $outputPlan
@@ -238,46 +238,6 @@
     }
   }

-  private def compareAnswers(
-      sparkAnswer: Seq[Row],
-      expectedAnswer: Seq[Row],
-      sort: Boolean): Option[String] = {
-    def prepareAnswer(answer: Seq[Row]): Seq[Row] = {
-      // Converts data to types that we can do equality comparison using Scala collections.
-      // For BigDecimal type, the Scala type has a better definition of equality test (similar to
-      // Java's java.math.BigDecimal.compareTo).
-      // For binary arrays, we convert it to Seq to avoid of calling java.util.Arrays.equals for
-      // equality test.
-      // This function is copied from Catalyst's QueryTest
-      val converted: Seq[Row] = answer.map { s =>
-        Row.fromSeq(s.toSeq.map {
-          case d: java.math.BigDecimal => BigDecimal(d)
-          case b: Array[Byte] => b.toSeq
-          case o => o
-        })
-      }
-      if (sort) {
-        converted.sortBy(_.toString())
-      } else {
-        converted
-      }
-    }
-    if (prepareAnswer(expectedAnswer) != prepareAnswer(sparkAnswer)) {
-      val errorMessage =
-        s"""
-        | == Results ==
-        | ${sideBySide(
-          s"== Expected Answer - ${expectedAnswer.size} ==" +:
-            prepareAnswer(expectedAnswer).map(_.toString()),
-          s"== Actual Answer - ${sparkAnswer.size} ==" +:
-            prepareAnswer(sparkAnswer).map(_.toString())).mkString("\n")}
-        """.stripMargin
-      Some(errorMessage)
-    } else {
-      None
-    }
-  }

   private def executePlan(outputPlan: SparkPlan, _sqlContext: SQLContext): Seq[Row] = {
     // A very simple resolver to make writing tests easier. In contrast to the real resolver
     // this is always case sensitive and does not try to handle scoping or complex type resolution.
@@ -0,0 +1,41 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.local

import org.apache.spark.sql.test.SharedSQLContext

class FilterNodeSuite extends LocalNodeTest with SharedSQLContext {
Contributor: So I was thinking it would be great if we could get rid of the SQLContext in the test cases for this local stuff.

Member (author): I just want to reuse SQLTestData. It would be easy if we could use DataFrame to test the local stuff. I feel that if we make sure not to use SQLContext in the main code of the LocalNodes, we don't need to get rid of SQLContext in the test cases.

Contributor: Note that SQLTestData is going away though.... #7406

Member (author): Didn't notice that. I will add test data for each test case manually.

Member (author): Just realized we need to call DataFrame.resolve to create a Column. It looks hard to get rid of SQLContext only in the tests. I think it's better to do this in a separate PR that adds an Analyzer for LocalNode.


test("basic") {
val condition = (testData.col("key") % 2) === 0
checkAnswer(
testData,
node => FilterNode(condition.expr, node),
testData.filter(condition).collect()
)
}

test("empty") {
val condition = (emptyTestData.col("key") % 2) === 0
checkAnswer(
emptyTestData,
node => FilterNode(condition.expr, node),
emptyTestData.filter(condition).collect()
)
}
}
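For reference, checkAnswer comes from LocalNodeTest, which this diff does not include; a signature consistent with the call sites would look like this (hypothetical):

```scala
// Hypothetical signature inferred from the call sites above; the real
// definition lives in LocalNodeTest, which this diff does not include.
protected def checkAnswer(
    input: DataFrame,
    nodeFunction: LocalNode => LocalNode,
    expectedAnswer: Seq[Row]): Unit
```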
@@ -0,0 +1,39 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.local

import org.apache.spark.sql.test.SharedSQLContext

class LimitNodeSuite extends LocalNodeTest with SharedSQLContext {

  test("basic") {
    checkAnswer(
      testData,
      node => LimitNode(10, node),
      testData.limit(10).collect()
    )
  }

  test("empty") {
    checkAnswer(
      emptyTestData,
      node => LimitNode(10, node),
      emptyTestData.limit(10).collect()
    )
  }
}