address nits
BryanCutler committed Jun 22, 2018
commit c593650aa527241da1ddd7e433c626db0fa26bb5
AggregateInPandasExec.scala
@@ -82,7 +82,7 @@ case class AggregateInPandasExec(
     val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536)
     val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true)
     val sessionLocalTimeZone = conf.sessionLocalTimeZone
-    val runnerConf = ArrowUtils.getPythonRunnerConfMap(conf)
+    val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf)

     val (pyFuncs, inputs) = udfExpressions.map(collectFunctions).unzip

@@ -143,7 +143,7 @@ case class AggregateInPandasExec(
       argOffsets,
       aggInputSchema,
       sessionLocalTimeZone,
-      runnerConf).compute(projectedRowIter, context.partitionId(), context)
+      pythonRunnerConf).compute(projectedRowIter, context.partitionId(), context)

     val joinedAttributes =
       groupingExpressions.map(_.toAttribute) ++ udfExpressions.map(_.resultAttribute)
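All four operators touched by this commit rename the same value: the map returned by ArrowUtils.getPythonRunnerConfMap, which collects the SQL conf entries the Python worker needs. For orientation only, such a helper might look like the sketch below; this is a hypothetical sketch, not the actual ArrowUtils code.

    // Orientation-only sketch, not the actual ArrowUtils implementation.
    // Assumes a SQLConf-like source exposing the session time zone.
    def getPythonRunnerConfMap(conf: SQLConf): Map[String, String] = {
      // The worker needs the session time zone to convert Arrow timestamps
      // the same way the JVM side does ("spark.sql.session.timeZone").
      Map("spark.sql.session.timeZone" -> conf.sessionLocalTimeZone)
    }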
ArrowEvalPythonExec.scala
@@ -64,7 +64,7 @@ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi

   private val batchSize = conf.arrowMaxRecordsPerBatch
   private val sessionLocalTimeZone = conf.sessionLocalTimeZone
-  private val runnerConf = ArrowUtils.getPythonRunnerConfMap(conf)
+  private val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf)

   protected override def evaluate(
       funcs: Seq[ChainedPythonFunctions],
@@ -88,7 +88,7 @@ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi
       argOffsets,
       schema,
       sessionLocalTimeZone,
-      runnerConf).compute(batchIter, context.partitionId(), context)
+      pythonRunnerConf).compute(batchIter, context.partitionId(), context)

     new Iterator[InternalRow] {
ArrowPythonRunner.scala
@@ -58,11 +58,14 @@ class ArrowPythonRunner(
     new WriterThread(env, worker, inputIterator, partitionIndex, context) {

       protected override def writeCommand(dataOut: DataOutputStream): Unit = {
+
+        // Write config for the worker as a number of key -> value pairs of strings
         dataOut.writeInt(conf.size)
         for ((k, v) <- conf) {
           PythonRDD.writeUTF(k, dataOut)
           PythonRDD.writeUTF(v, dataOut)
         }
+
         PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets)
       }

Review thread on dataOut.writeInt(conf.size):

icexelloss (Contributor) commented on Jun 19, 2018:
maybe put this in a writeConf method to be more specific?

BryanCutler (Member, Author) replied:
I think it's fine, but I will add some comments

icexelloss (Contributor) replied:
Ok, SGTM.
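For illustration, the writeConf extraction suggested in the thread might have looked like the sketch below. This is hypothetical, not code from the PR, which kept the loop inline and added the comment instead.

    // Hypothetical helper per icexelloss's suggestion, as it might sit inside
    // ArrowPythonRunner; not part of the actual change. Uses the same
    // PythonRDD.writeUTF calls that appear in the diff above.
    def writeConf(dataOut: java.io.DataOutputStream, conf: Map[String, String]): Unit = {
      // Entry count first, then each key -> value pair as UTF strings,
      // so the Python worker can read the pairs back in order.
      dataOut.writeInt(conf.size)
      for ((k, v) <- conf) {
        PythonRDD.writeUTF(k, dataOut)
        PythonRDD.writeUTF(v, dataOut)
      }
    }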
FlatMapGroupsInPandasExec.scala
@@ -78,7 +78,7 @@ case class FlatMapGroupsInPandasExec(
     val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true)
     val chainedFunc = Seq(ChainedPythonFunctions(Seq(pandasFunction)))
     val sessionLocalTimeZone = conf.sessionLocalTimeZone
-    val runnerConf = ArrowUtils.getPythonRunnerConfMap(conf)
+    val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf)

     // Deduplicate the grouping attributes.
     // If a grouping attribute also appears in data attributes, then we don't need to send the
@@ -147,7 +147,7 @@ case class FlatMapGroupsInPandasExec(
       argOffsets,
       dedupSchema,
       sessionLocalTimeZone,
-      runnerConf).compute(grouped, context.partitionId(), context)
+      pythonRunnerConf).compute(grouped, context.partitionId(), context)

     columnarBatchIter.flatMap(_.rowIterator.asScala).map(UnsafeProjection.create(output, output))
   }
WindowInPandasExec.scala
@@ -98,7 +98,7 @@ case class WindowInPandasExec(
     val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536)
     val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true)
     val sessionLocalTimeZone = conf.sessionLocalTimeZone
-    val runnerConf = ArrowUtils.getPythonRunnerConfMap(conf)
+    val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf)

     // Extract window expressions and window functions
     val expressions = windowExpression.flatMap(_.collect { case e: WindowExpression => e })
@@ -162,7 +162,7 @@ case class WindowInPandasExec(
       argOffsets,
       windowInputSchema,
       sessionLocalTimeZone,
-      runnerConf).compute(pythonInput, context.partitionId(), context)
+      pythonRunnerConf).compute(pythonInput, context.partitionId(), context)

     val joined = new JoinedRow
     val resultProj = createResultProjection(expressions)