Skip to content

Commit 6dc2c9e

Browse files
committed
MemorySink statistics for joining
1 parent 95ec4e2 commit 6dc2c9e

File tree

2 files changed

+30
-1
lines changed

2 files changed

+30
-1
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ import org.apache.spark.internal.Logging
2727
import org.apache.spark.sql._
2828
import org.apache.spark.sql.catalyst.encoders.encoderFor
2929
import org.apache.spark.sql.catalyst.expressions.Attribute
30-
import org.apache.spark.sql.catalyst.plans.logical.LeafNode
30+
import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics}
3131
import org.apache.spark.sql.streaming.OutputMode
3232
import org.apache.spark.sql.types.StructType
3333
import org.apache.spark.util.Utils
@@ -212,4 +212,8 @@ class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink wi
212212
*/
213213
case class MemoryPlan(sink: MemorySink, output: Seq[Attribute]) extends LeafNode {
  def this(sink: MemorySink) = this(sink, sink.schema.toAttributes)

  // Estimated bytes per row: the sum of the default sizes of the sink schema's column types.
  private val sizePerRow = sink.schema.toAttributes.map(_.dataType.defaultSize).sum

  /**
   * Size estimate for this plan: per-row estimate times the number of rows currently
   * returned by `sink.allData`.
   *
   * The product is computed as a `BigInt`. Multiplying the two `Int` operands directly
   * can overflow and wrap around (possibly to a negative value) before the result is
   * widened, which would report a bogus size for a large sink.
   */
  override def statistics: Statistics = Statistics(BigInt(sizePerRow) * sink.allData.size)
}

sql/core/src/test/scala/org/apache/spark/sql/streaming/MemorySinkSuite.scala

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,31 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter {
187187
query.stop()
188188
}
189189

190+
test("MemoryPlan statistics for joining") {
  val input = MemoryStream[Int]
  // Write the stream with the "memory" format under the query name "memStream".
  val query = input.toDF().writeStream.format("memory").queryName("memStream").start()

  // Read back the table named "memStream" as a typed Dataset.
  val memStream = spark.table("memStream").as[Int]

  // Cross join of the table with a copy of itself whose column is renamed to "value2".
  // Declared as a def so the join is built afresh for every check, as in the inline form.
  def selfCrossJoin = {
    val renamed = memStream.withColumnRenamed("value", "value2")
    memStream.crossJoin(renamed).as[(Int, Int)]
  }

  // One row in the sink: the cross join yields exactly one pair.
  input.addData(1)
  query.processAllAvailable()
  checkDatasetUnorderly(selfCrossJoin, (1, 1))

  // Two rows in the sink: the cross join yields all four pairs.
  input.addData(2)
  query.processAllAvailable()
  checkDatasetUnorderly(selfCrossJoin, (1, 1), (1, 2), (2, 1), (2, 2))

  query.stop()
}
214+
190215
ignore("stress test") {
191216
// Ignore the stress test as it takes several minutes to run
192217
(0 until 1000).foreach { _ =>

0 commit comments

Comments
 (0)