address review comments
kiszk committed Apr 17, 2017
commit f695e50e38bd329db3b75951dd7af52fea3b3dde
@@ -104,20 +104,12 @@ object CombineTypedFilters extends Rule[LogicalPlan] {
* 1. Mapobject(e) where e is lambdavariable(), which means types for input output
Member:
This comment is obscure. Can we improve it a bit?

For example, MapObject(e) is confusing. Shall we clearly say the lambdaFunction of MapObject?
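For context, here is a paraphrase of the MapObjects node under discussion (field names taken from Spark 2.x Catalyst, simplified rather than quoted verbatim):

    import org.apache.spark.sql.catalyst.expressions.Expression
    import org.apache.spark.sql.types.DataType

    // Simplified sketch of the MapObjects expression. The rule under review
    // fires when lambdaFunction is a bare LambdaVariable, i.e. the
    // per-element function is the identity.
    case class MapObjects(
        loopValue: String,
        loopIsNull: String,
        loopVarDataType: DataType,
        lambdaFunction: Expression,   // the lambdaFunction referred to above
        inputData: Expression,
        customCollectionCls: Option[Class[_]])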

Member Author:
thanks, done

* are primitive types
* 2. no custom collection class specified
* representation of data item. For example back to back map operations.
Member:
Can we rephrase this comment too? It looks weird.

Member Author:
thanks, deleted

*/
object EliminateMapObjects extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
Contributor:
shall we call plan.transformAllExpressions?

Member Author:
Thanks, it works. done

case _ @ DeserializeToObject(Invoke(
MapObjects(_, _, _, Cast(LambdaVariable(_, _, dataType, _), castDataType, _),
inputData, None),
funcName, returnType: ObjectType, arguments, propagateNull, returnNullable),
outputObjAttr, child) if dataType == castDataType =>
DeserializeToObject(Invoke(
inputData, funcName, returnType, arguments, propagateNull, returnNullable),
outputObjAttr, child)
case _ @ DeserializeToObject(Invoke(
MapObjects(_, _, _, LambdaVariable(_, _, dataType, _), inputData, None),
case DeserializeToObject(Invoke(
MapObjects(_, _, _, _ : LambdaVariable, inputData, None),
Contributor:
can we just replace MapObjects with its child? Seems the only reason you match the whole DeserializeToObject is to make sure the returnType is object type, but that's guaranteed if the collectionClass is None.

Member Author:
Replacing MapObjects with its child would be a LogicalPlan => Expression rewrite, while this method requires LogicalPlan => LogicalPlan.
Is it fine to replace Invoke(MapObjects(..., inputData, None)...) with Invoke(inputData, ...)?

Member Author:
@cloud-fan, I misunderstood. Neither Expression nor Invoke is a LogicalPlan.
I think we have to replace some of the arguments of DeserializeToObject instead.
What do you think?
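A simplified view of the two Catalyst APIs involved (signatures paraphrased from Spark 2.x, not quoted verbatim) shows why the types do not line up:

    // transform rewrites whole plan nodes, so its rule must produce a
    // LogicalPlan. transformAllExpressions instead walks the expression trees
    // hanging off each plan node, so its rule maps Expression => Expression;
    // dropping a MapObjects in favor of its child is exactly such an
    // expression-level rewrite.
    trait Expression
    trait LogicalPlan {
      def transform(rule: PartialFunction[LogicalPlan, LogicalPlan]): LogicalPlan
      def transformAllExpressions(rule: PartialFunction[Expression, Expression]): LogicalPlan
    }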

Member Author:
@cloud-fan Unfortunately, this change caused two test failures:
"checkAnswer should compare map correctly" and "SPARK-18717: code generation works for both scala.collection.Map and scala.collection.immutable.Map" in DatasetSuite.
I will check what is happening very soon.

Member Author:
When we run the following code, an exception occurs because UnsafeArrayData.copy(), which is unsupported, is called.

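    // Assumes spark.implicits._ is in scope (e.g. in spark-shell). The
    // identity map(t => t) still forces a deserialize/serialize round trip
    // over the Map column, which exercises the MapObjects path this rule
    // eliminates.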
    val ds = Seq((1, Map(2 -> 3))).toDS.map(t => t)
    ds.collect.toSeq

Member Author:
The solution is to use the following match:

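  // The bare LambdaVariable means the per-element function is the identity.
  // Matching its nullable field as false and requiring customCollectionCls to
  // be None excludes conversions such as Object[] -> Integer[], where dropping
  // MapObjects would change the element representation.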
  def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
     case MapObjects(_, _, _, LambdaVariable(_, _, _, false), inputData, None) => inputData
  }

The previous match avoided cases such as Object[] -> Integer[]; without that restriction, the generated code incorrectly called UnsafeArrayData.array().

funcName, returnType: ObjectType, arguments, propagateNull, returnNullable),
outputObjAttr, child) =>
DeserializeToObject(Invoke(
@@ -28,16 +28,12 @@ import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.types._

class EliminateMapObjectsSuite extends PlanTest {
class Optimize(addSimplifyCast: Boolean) extends RuleExecutor[LogicalPlan] {
val batches = if (addSimplifyCast) {
object Optimize extends RuleExecutor[LogicalPlan] {
val batches = {
Batch("EliminateMapObjects", FixedPoint(50),
NullPropagation(conf),
SimplifyCasts,
EliminateMapObjects) :: Nil
} else {
Batch("EliminateMapObjects", FixedPoint(50),
NullPropagation(conf),
EliminateMapObjects) :: Nil
}
}

@@ -48,23 +44,19 @@ class EliminateMapObjectsSuite extends PlanTest {
val intObjType = ObjectType(classOf[Array[Int]])
val intInput = LocalRelation('a.array(ArrayType(IntegerType, false)))
val intQuery = intInput.deserialize[Array[Int]].analyze
Seq(true, false).foreach { addSimplifyCast =>
val intOptimized = new Optimize(addSimplifyCast).execute(intQuery)
val intExpected = DeserializeToObject(
Invoke(intInput.output(0), "toIntArray", intObjType, Nil, true, false),
AttributeReference("obj", intObjType, true)(), intInput)
comparePlans(intOptimized, intExpected)
}
val intOptimized = Optimize.execute(intQuery)
val intExpected = DeserializeToObject(
Invoke(intInput.output(0), "toIntArray", intObjType, Nil, true, false),
AttributeReference("obj", intObjType, true)(), intInput)
comparePlans(intOptimized, intExpected)

val doubleObjType = ObjectType(classOf[Array[Double]])
val doubleInput = LocalRelation('a.array(ArrayType(DoubleType, false)))
val doubleQuery = doubleInput.deserialize[Array[Double]].analyze
Seq(true, false).foreach { addSimplifyCast =>
val doubleOptimized = new Optimize(addSimplifyCast).execute(doubleQuery)
val doubleExpected = DeserializeToObject(
Invoke(doubleInput.output(0), "toDoubleArray", doubleObjType, Nil, true, false),
AttributeReference("obj", doubleObjType, true)(), doubleInput)
comparePlans(doubleOptimized, doubleExpected)
}
val doubleOptimized = Optimize.execute(doubleQuery)
val doubleExpected = DeserializeToObject(
Invoke(doubleInput.output(0), "toDoubleArray", doubleObjType, Nil, true, false),
AttributeReference("obj", doubleObjType, true)(), doubleInput)
comparePlans(doubleOptimized, doubleExpected)
}
}
@@ -20,8 +20,6 @@ package org.apache.spark.sql
import scala.collection.immutable.Queue
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.catalyst.expressions.objects.Invoke
import org.apache.spark.sql.execution.DeserializeToObjectExec
import org.apache.spark.sql.test.SharedSQLContext

case class IntClass(value: Int)