passing conf map to runner, tests pass
BryanCutler committed Jun 18, 2018
commit 5a7edb2bc30dc7fe93d19504e78fbf83c1f525d9
9 changes: 7 additions & 2 deletions python/pyspark/worker.py
@@ -224,8 +224,13 @@ def read_udfs(pickleSer, infile, eval_type):
PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF):
timezone = utf8_deserializer.loads(infile)
ser = ArrowStreamPandasSerializer(timezone)
runner_conf = {}
num_conf = read_int(infile)
for i in range(num_conf):
k = utf8_deserializer.loads(infile)
v = utf8_deserializer.loads(infile)
runner_conf[k] = v
ser = ArrowStreamPandasSerializer(runner_conf.get("spark.sql.session.timeZone", None))
else:
ser = BatchedSerializer(PickleSerializer(), 100)

sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala
@@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning}
import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan, UnaryExecNode}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{DataType, StructField, StructType}
import org.apache.spark.util.Utils

@@ -134,11 +135,23 @@ case class AggregateInPandasExec(
rows
}

val timeZoneConf = if (pandasRespectSessionTimeZone) {
Seq(SQLConf.SESSION_LOCAL_TIMEZONE.key -> sessionLocalTimeZone)
} else {
Nil
}
val runnerConfEntries = Seq() ++ timeZoneConf
val runnerConf = Map(runnerConfEntries: _*)

val columnarBatchIter = new ArrowPythonRunner(
pyFuncs, bufferSize, reuseWorker,
PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF, argOffsets, aggInputSchema,
sessionLocalTimeZone, pandasRespectSessionTimeZone)
.compute(projectedRowIter, context.partitionId(), context)
pyFuncs,
bufferSize,
reuseWorker,
PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
argOffsets,
aggInputSchema,
sessionLocalTimeZone,
runnerConf).compute(projectedRowIter, context.partitionId(), context)

val joinedAttributes =
groupingExpressions.map(_.toAttribute) ++ udfExpressions.map(_.resultAttribute)
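Note: the same conditional conf-map construction is repeated in each exec node touched by this commit (AggregateInPandasExec, ArrowEvalPythonExec, FlatMapGroupsInPandasExec, WindowInPandasExec). A minimal standalone sketch of the pattern follows; the literal values stand in for what the real plans read from SQLConf, and the "example.extra.conf" entry is purely hypothetical, only to show how further settings would be appended alongside the timezone.

// Sketch of the runner-conf pattern used in the exec nodes below.
object RunnerConfSketch {
  def main(args: Array[String]): Unit = {
    val pandasRespectSessionTimeZone = true          // read from SQLConf in the real plan
    val sessionLocalTimeZone = "America/Los_Angeles" // read from SQLConf in the real plan

    // Only ship the session timezone when the respect-timezone flag is on.
    val timeZoneConf = if (pandasRespectSessionTimeZone) {
      Seq("spark.sql.session.timeZone" -> sessionLocalTimeZone)
    } else {
      Nil
    }
    // Hypothetical additional entry, showing how the map would grow.
    val extraConf = Seq("example.extra.conf" -> "value")

    val runnerConf = Map(timeZoneConf ++ extraConf: _*)
    println(runnerConf)
  }
}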
sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
@@ -24,6 +24,7 @@ import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.StructType

/**
@@ -79,11 +80,23 @@ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi
// DO NOT use iter.grouped(). See BatchIterator.
val batchIter = if (batchSize > 0) new BatchIterator(iter, batchSize) else Iterator(iter)

val timeZoneConf = if (pandasRespectSessionTimeZone) {
Seq(SQLConf.SESSION_LOCAL_TIMEZONE.key -> sessionLocalTimeZone)
} else {
Nil
}
val runnerConfEntries = Seq() ++ timeZoneConf
val runnerConf = Map(runnerConfEntries: _*)

val columnarBatchIter = new ArrowPythonRunner(
funcs, bufferSize, reuseWorker,
PythonEvalType.SQL_SCALAR_PANDAS_UDF, argOffsets, schema,
sessionLocalTimeZone, pandasRespectSessionTimeZone)
.compute(batchIter, context.partitionId(), context)
funcs,
bufferSize,
reuseWorker,
PythonEvalType.SQL_SCALAR_PANDAS_UDF,
argOffsets,
schema,
sessionLocalTimeZone,
runnerConf).compute(batchIter, context.partitionId(), context)

new Iterator[InternalRow] {

sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
@@ -45,7 +45,7 @@ class ArrowPythonRunner(
argOffsets: Array[Array[Int]],
schema: StructType,
timeZoneId: String,
respectTimeZone: Boolean)
conf: Map[String, String])
extends BasePythonRunner[Iterator[InternalRow], ColumnarBatch](
funcs, bufferSize, reuseWorker, evalType, argOffsets) {

@@ -59,17 +59,17 @@

protected override def writeCommand(dataOut: DataOutputStream): Unit = {
PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets)
if (respectTimeZone) {
PythonRDD.writeUTF(timeZoneId, dataOut)
} else {
dataOut.writeInt(SpecialLengths.NULL)
dataOut.writeInt(conf.size)
@icexelloss (Contributor) commented on Jun 19, 2018:
maybe put this in a writeConf method to be more specific?

@BryanCutler (Member, Author) replied:
I think it's fine, but I will add some comments

@icexelloss (Contributor) replied:
Ok, SGTM.

for ((k, v) <- conf) {
PythonRDD.writeUTF(k, dataOut)
PythonRDD.writeUTF(v, dataOut)
}
}

protected override def writeIteratorToStream(dataOut: DataOutputStream): Unit = {
val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId)
val allocator = ArrowUtils.rootAllocator.newChildAllocator(
s"stdout writer for $pythonExec", 0, Long.MaxValue)
val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId)
@BryanCutler (Member, Author) commented:
change this back, accidental

val root = VectorSchemaRoot.create(arrowSchema, allocator)

Utils.tryWithSafeFinally {
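For reference, a self-contained sketch of the conf wire format that writeCommand now emits and worker.py now reads: an entry count, then each key and value written the way PythonRDD.writeUTF does (a 4-byte big-endian length followed by UTF-8 bytes). The writeConf/readConf helpers below are hypothetical, along the lines of the writeConf method suggested in the review thread above, not part of this change.

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
import java.nio.charset.StandardCharsets

object ConfProtocolSketch {

  // Hypothetical helper: write the number of entries, then each key and value as
  // length-prefixed UTF-8 bytes, which is what the Python worker's
  // read_int/utf8_deserializer pair expects.
  def writeConf(dataOut: DataOutputStream, conf: Map[String, String]): Unit = {
    dataOut.writeInt(conf.size)
    for ((k, v) <- conf) {
      writeUTF(k, dataOut)
      writeUTF(v, dataOut)
    }
  }

  private def writeUTF(str: String, dataOut: DataOutputStream): Unit = {
    val bytes = str.getBytes(StandardCharsets.UTF_8)
    dataOut.writeInt(bytes.length)
    dataOut.write(bytes)
  }

  // Read side, mirroring the loop added to worker.py.
  def readConf(dataIn: DataInputStream): Map[String, String] = {
    val numConf = dataIn.readInt()
    (0 until numConf).map(_ => readUTF(dataIn) -> readUTF(dataIn)).toMap
  }

  private def readUTF(dataIn: DataInputStream): String = {
    val bytes = new Array[Byte](dataIn.readInt())
    dataIn.readFully(bytes)
    new String(bytes, StandardCharsets.UTF_8)
  }

  def main(args: Array[String]): Unit = {
    // Round-trip a conf map through an in-memory stream to check the framing.
    val conf = Map("spark.sql.session.timeZone" -> "America/Los_Angeles")
    val buffer = new ByteArrayOutputStream()
    writeConf(new DataOutputStream(buffer), conf)
    val decoded = readConf(new DataInputStream(new ByteArrayInputStream(buffer.toByteArray)))
    assert(decoded == conf)
  }
}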
sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning}
import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan, UnaryExecNode}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.StructType

/**
@@ -137,12 +138,23 @@ case class FlatMapGroupsInPandasExec(
}

val context = TaskContext.get()
val timeZoneConf = if (pandasRespectSessionTimeZone) {
Seq(SQLConf.SESSION_LOCAL_TIMEZONE.key -> sessionLocalTimeZone)
} else {
Nil
}
val runnerConfEntries = Seq() ++ timeZoneConf
val runnerConf = Map(runnerConfEntries: _*)

val columnarBatchIter = new ArrowPythonRunner(
chainedFunc, bufferSize, reuseWorker,
PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, argOffsets, dedupSchema,
sessionLocalTimeZone, pandasRespectSessionTimeZone)
.compute(grouped, context.partitionId(), context)
chainedFunc,
bufferSize,
reuseWorker,
PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
argOffsets,
dedupSchema,
sessionLocalTimeZone,
runnerConf).compute(grouped, context.partitionId(), context)

columnarBatchIter.flatMap(_.rowIterator.asScala).map(UnsafeProjection.create(output, output))
}
sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala
@@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan, UnaryExecNode}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{DataType, StructField, StructType}
import org.apache.spark.util.Utils

@@ -153,12 +154,23 @@ case class WindowInPandasExec(
}
}

val timeZoneConf = if (pandasRespectSessionTimeZone) {
Seq(SQLConf.SESSION_LOCAL_TIMEZONE.key -> sessionLocalTimeZone)
} else {
Nil
}
val runnerConfEntries = Seq() ++ timeZoneConf
val runnerConf = Map(runnerConfEntries: _*)

val windowFunctionResult = new ArrowPythonRunner(
pyFuncs, bufferSize, reuseWorker,
pyFuncs,
bufferSize,
reuseWorker,
PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF,
argOffsets, windowInputSchema,
sessionLocalTimeZone, pandasRespectSessionTimeZone)
.compute(pythonInput, context.partitionId(), context)
argOffsets,
windowInputSchema,
sessionLocalTimeZone,
runnerConf).compute(pythonInput, context.partitionId(), context)

val joined = new JoinedRow
val resultProj = createResultProjection(expressions)