From 551bf526a1d2cf2888660e73e953f5d9cc5e232a Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Fri, 25 Jun 2021 17:28:30 +0900 Subject: [PATCH 01/11] [SPARK-34893][SS] Support session window natively --- python/pyspark/sql/functions.py | 35 ++ python/pyspark/sql/functions.pyi | 1 + .../sql/catalyst/analysis/Analyzer.scala | 86 +++- .../catalyst/analysis/FunctionRegistry.scala | 1 + .../catalyst/expressions/SessionWindow.scala | 103 ++++ .../sql/errors/QueryCompilationErrors.scala | 5 +- .../apache/spark/sql/internal/SQLConf.scala | 23 + .../spark/sql/execution/SparkStrategies.scala | 33 +- .../sql/execution/aggregate/AggUtils.scala | 187 ++++++- .../aggregate/UpdatingSessionsIterator.scala | 3 +- .../python/AggregateInPandasExec.scala | 51 +- .../streaming/IncrementalExecution.scala | 20 + .../streaming/statefulOperators.scala | 292 +++++++++++ .../org/apache/spark/sql/functions.scala | 29 ++ .../sql-functions/sql-expression-schema.md | 7 +- .../sql/DataFrameSessionWindowingSuite.scala | 291 +++++++++++ .../sql/DataFrameTimeWindowingSuite.scala | 2 +- .../org/apache/spark/sql/SQLQuerySuite.scala | 3 +- .../sql/expressions/ExpressionInfoSuite.scala | 1 + .../StreamingSessionWindowSuite.scala | 456 ++++++++++++++++++ 20 files changed, 1598 insertions(+), 31 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SessionWindow.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/DataFrameSessionWindowingSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingSessionWindowSuite.scala diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 9240ae6a8c51..295372fa4cc1 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2333,6 +2333,41 @@ def check_string_field(field, fieldName): return Column(res) +def session_window(timeColumn, gapDuration): + """ + Generates session window given a timestamp specifying column. + Session window is the one of dynamic windows, which means the length of window is vary + according to the given inputs. The length of session window is defined as "the timestamp + of latest input of the session + gap duration", so when the new inputs are bound to the + current session window, the end time of session window can be expanded according to the new + inputs. + Windows can support microsecond precision. Windows in the order of months are not supported. + For a streaming query, you may use the function `current_timestamp` to generate windows on + processing time. + gapDuration is provided as strings, e.g. '1 second', '1 day 12 hours', '2 minutes'. Valid + interval strings are 'week', 'day', 'hour', 'minute', 'second', 'millisecond', 'microsecond'. + The output column will be a struct called 'session_window' by default with the nested columns + 'start' and 'end', where 'start' and 'end' will be of :class:`pyspark.sql.types.TimestampType`. + .. versionadded:: 3.2.0 + Examples + -------- + >>> df = spark.createDataFrame([("2016-03-11 09:00:07", 1)]).toDF("date", "val") + >>> w = df.groupBy(session_window("date", "5 seconds")).agg(sum("val").alias("sum")) + >>> w.select(w.session_window.start.cast("string").alias("start"), + ... w.session_window.end.cast("string").alias("end"), "sum").collect() + [Row(start='2016-03-11 09:00:07', end='2016-03-11 09:00:12', sum=1)] + """ + def check_string_field(field, fieldName): + if not field or type(field) is not str: + raise TypeError("%s should be provided as a string" % fieldName) + + sc = SparkContext._active_spark_context + time_col = _to_java_column(timeColumn) + check_string_field(gapDuration, "gapDuration") + res = sc._jvm.functions.session_window(time_col, gapDuration) + return Column(res) + + # ---------------------------- misc functions ---------------------------------- def crc32(col): diff --git a/python/pyspark/sql/functions.pyi b/python/pyspark/sql/functions.pyi index 0a4aabf1bf71..051a6f1dbc53 100644 --- a/python/pyspark/sql/functions.pyi +++ b/python/pyspark/sql/functions.pyi @@ -135,6 +135,7 @@ def window( slideDuration: Optional[str] = ..., startTime: Optional[str] = ..., ) -> Column: ... +def session_window(timeColumn: ColumnOrName, gapDuration: str) -> Column: ... def crc32(col: ColumnOrName) -> Column: ... def md5(col: ColumnOrName) -> Column: ... def sha1(col: ColumnOrName) -> Column: ... diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 49390d35bd1d..1c5f1338e0fd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -296,6 +296,7 @@ class Analyzer(override val catalogManager: CatalogManager) GlobalAggregates :: ResolveAggregateFunctions :: TimeWindowing :: + SessionWindowing :: ResolveInlineTables :: ResolveHigherOrderFunctions(catalogManager) :: ResolveLambdaVariables :: @@ -3856,9 +3857,13 @@ object TimeWindowing extends Rule[LogicalPlan] { val windowExpressions = p.expressions.flatMap(_.collect { case t: TimeWindow => t }).toSet - val numWindowExpr = windowExpressions.size + val numWindowExpr = p.expressions.flatMap(_.collect { + case s: SessionWindow => s + case t: TimeWindow => t + }).toSet.size + // Only support a single window expression for now - if (numWindowExpr == 1 && + if (numWindowExpr == 1 && windowExpressions.nonEmpty && windowExpressions.head.timeColumn.resolved && windowExpressions.head.checkInputDataTypes().isSuccess) { @@ -3933,6 +3938,83 @@ object TimeWindowing extends Rule[LogicalPlan] { } } +/** Maps a time column to a session window. */ +object SessionWindowing extends Rule[LogicalPlan] { + import org.apache.spark.sql.catalyst.dsl.expressions._ + + private final val SESSION_COL_NAME = "session_window" + private final val SESSION_START = "start" + private final val SESSION_END = "end" + + /** + * Generates the logical plan for generating session window on a timestamp column. + * Each session window is initially defined as [timestamp, timestamp + gap). + * + * This also adds a marker to the session column so that downstream can easily find the column + * on session window. + */ + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { + case p: LogicalPlan if p.children.size == 1 => + val child = p.children.head + val sessionExpressions = + p.expressions.flatMap(_.collect { case s: SessionWindow => s }).toSet + + val numWindowExpr = p.expressions.flatMap(_.collect { + case s: SessionWindow => s + case t: TimeWindow => t + }).toSet.size + + // Only support a single session expression for now + if (numWindowExpr == 1 && sessionExpressions.nonEmpty && + sessionExpressions.head.timeColumn.resolved && + sessionExpressions.head.checkInputDataTypes().isSuccess) { + + val session = sessionExpressions.head + + val metadata = session.timeColumn match { + case a: Attribute => a.metadata + case _ => Metadata.empty + } + + val newMetadata = new MetadataBuilder() + .withMetadata(metadata) + .putBoolean(SessionWindow.marker, true) + .build() + + val sessionAttr = AttributeReference( + SESSION_COL_NAME, session.dataType, metadata = newMetadata)() + + val sessionStart = PreciseTimestampConversion(session.timeColumn, TimestampType, LongType) + val sessionEnd = sessionStart + session.gapDuration + + val literalSessionStruct = CreateNamedStruct( + Literal(SESSION_START) :: + PreciseTimestampConversion(sessionStart, LongType, TimestampType) :: + Literal(SESSION_END) :: + PreciseTimestampConversion(sessionEnd, LongType, TimestampType) :: + Nil) + + val sessionStruct = Alias(literalSessionStruct, SESSION_COL_NAME)( + exprId = sessionAttr.exprId, explicitMetadata = Some(newMetadata)) + + val replacedPlan = p transformExpressions { + case s: SessionWindow => sessionAttr + } + + // For backwards compatibility we add a filter to filter out nulls + val filterExpr = IsNotNull(session.timeColumn) + + replacedPlan.withNewChildren( + Filter(filterExpr, + Project(sessionStruct +: child.output, child)) :: Nil) + } else if (numWindowExpr > 1) { + throw QueryCompilationErrors.multiTimeWindowExpressionsNotSupportedError(p) + } else { + p // Return unchanged. Analyzer will throw exception later + } + } +} + /** * Resolve expressions if they contains [[NamePlaceholder]]s. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 60ca1e96a5e3..234da76e06a5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -552,6 +552,7 @@ object FunctionRegistry { expression[WeekOfYear]("weekofyear"), expression[Year]("year"), expression[TimeWindow]("window"), + expression[SessionWindow]("session_window"), expression[MakeDate]("make_date"), expression[MakeTimestamp]("make_timestamp"), expression[MakeTimestampNTZ]("make_timestamp_ntz", true), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SessionWindow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SessionWindow.scala new file mode 100644 index 000000000000..ccc451235100 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SessionWindow.scala @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure +import org.apache.spark.sql.catalyst.util.{DateTimeConstants, IntervalUtils} +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +case class SessionWindow(timeColumn: Expression, gapDuration: Long) extends UnaryExpression + with ImplicitCastInputTypes + with Unevaluable + with NonSQLExpression { + + ////////////////////////// + // SQL Constructors + ////////////////////////// + + def this(timeColumn: Expression, gapDuration: Expression) = { + this(timeColumn, SessionWindow.parseExpression(gapDuration)) + } + + override def child: Expression = timeColumn + override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType) + override def dataType: DataType = new StructType() + .add(StructField("start", TimestampType)) + .add(StructField("end", TimestampType)) + + // This expression is replaced in the analyzer. + override lazy val resolved = false + + /** Validate the inputs for the gap duration in addition to the input data type. */ + override def checkInputDataTypes(): TypeCheckResult = { + val dataTypeCheck = super.checkInputDataTypes() + if (dataTypeCheck.isSuccess) { + if (gapDuration <= 0) { + return TypeCheckFailure(s"The window duration ($gapDuration) must be greater than 0.") + } + } + dataTypeCheck + } + + override protected def withNewChildInternal(newChild: Expression): Expression = + copy(timeColumn = newChild) +} + +object SessionWindow { + val marker = "spark.sessionWindow" + + /** + * Parses the interval string for a valid time duration. CalendarInterval expects interval + * strings to start with the string `interval`. For usability, we prepend `interval` to the string + * if the user omitted it. + * + * @param interval The interval string + * @return The interval duration in microseconds. SparkSQL casts TimestampType has microsecond + * precision. + */ + private def getIntervalInMicroSeconds(interval: String): Long = { + val cal = IntervalUtils.stringToInterval(UTF8String.fromString(interval)) + if (cal.months != 0) { + throw new IllegalArgumentException( + s"Intervals greater than a month is not supported ($interval).") + } + cal.days * DateTimeConstants.MICROS_PER_DAY + cal.microseconds + } + + /** + * Parses the duration expression to generate the long value for the original constructor so + * that we can use `window` in SQL. + */ + private def parseExpression(expr: Expression): Long = expr match { + case NonNullLiteral(s, StringType) => getIntervalInMicroSeconds(s.toString) + case IntegerLiteral(i) => i.toLong + case NonNullLiteral(l, LongType) => l.toString.toLong + case _ => throw new AnalysisException("The duration and time inputs to window must be " + + "an integer, long or string literal.") + } + + def apply( + timeColumn: Expression, + gapDuration: String): SessionWindow = { + SessionWindow(timeColumn, + getIntervalInMicroSeconds(gapDuration)) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala index 2cee614eb614..7a33d52a6573 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala @@ -366,8 +366,9 @@ private[spark] object QueryCompilationErrors { } def multiTimeWindowExpressionsNotSupportedError(t: TreeNode[_]): Throwable = { - new AnalysisException("Multiple time window expressions would result in a cartesian product " + - "of rows, therefore they are currently not supported.", t.origin.line, t.origin.startPosition) + new AnalysisException("Multiple time/session window expressions would result in a cartesian " + + "product of rows, therefore they are currently not supported.", t.origin.line, + t.origin.startPosition) } def viewOutputNumberMismatchQueryColumnNamesError( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index f1bfb1465b2f..fda99d6285d5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1610,6 +1610,26 @@ object SQLConf { .checkValue(v => Set(1, 2).contains(v), "Valid versions are 1 and 2") .createWithDefault(2) + val STREAMING_SESSION_WINDOW_MERGE_SESSIONS_IN_LOCAL_PARTITION = + buildConf("spark.sql.streaming.sessionWindow.merge.sessions.in.local.partition") + .internal() + .doc("When true, streaming session window sorts and merge sessions in local partition " + + "prior to shuffle. This is to reduce the rows to shuffle, but only beneficial when " + + "there're lots of rows in a batch being assigned to same sessions.") + .booleanConf + .createWithDefault(false) + + val STREAMING_SESSION_WINDOW_STATE_FORMAT_VERSION = + buildConf("spark.sql.streaming.sessionWindow.stateFormatVersion") + .internal() + .doc("State format version used by streaming session window in a streaming query. " + + "State between versions are tend to be incompatible, so state format version shouldn't " + + "be modified after running.") + .version("3.2.0") + .intConf + .checkValue(v => Set(1).contains(v), "Valid version is 1") + .createWithDefault(1) + val UNSUPPORTED_OPERATION_CHECK_ENABLED = buildConf("spark.sql.streaming.unsupportedOperationCheck") .internal() @@ -3676,6 +3696,9 @@ class SQLConf extends Serializable with Logging { def fastHashAggregateRowMaxCapacityBit: Int = getConf(FAST_HASH_AGGREGATE_MAX_ROWS_CAPACITY_BIT) + def streamingSessionWindowMergeSessionInLocalPartition: Boolean = + getConf(STREAMING_SESSION_WINDOW_MERGE_SESSIONS_IN_LOCAL_PARTITION) + def datetimeJava8ApiEnabled: Boolean = getConf(DATETIME_JAVA8API_ENABLED) def uiExplainMode: String = getConf(UI_EXPLAIN_MODE) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 65a592302c6a..b4af14df5be6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -326,6 +326,10 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { val stateVersion = conf.getConf(SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION) + val sessionWindowOption = namedGroupingExpressions.find { p => + p.metadata.contains(SessionWindow.marker) + } + // Ideally this should be done in `NormalizeFloatingNumbers`, but we do it here because // `groupingExpressions` is not extracted during logical phase. val normalizedGroupingExpressions = namedGroupingExpressions.map { e => @@ -335,12 +339,29 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } } - AggUtils.planStreamingAggregation( - normalizedGroupingExpressions, - aggregateExpressions.map(expr => expr.asInstanceOf[AggregateExpression]), - rewrittenResultExpressions, - stateVersion, - planLater(child)) + sessionWindowOption match { + case Some(sessionWindow) => + val stateVersion = conf.getConf(SQLConf.STREAMING_SESSION_WINDOW_STATE_FORMAT_VERSION) + + AggUtils.planStreamingAggregationForSession( + normalizedGroupingExpressions, + sessionWindow, + aggregateExpressions.map(expr => expr.asInstanceOf[AggregateExpression]), + rewrittenResultExpressions, + stateVersion, + conf.streamingSessionWindowMergeSessionInLocalPartition, + planLater(child)) + + case None => + val stateVersion = conf.getConf(SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION) + + AggUtils.planStreamingAggregation( + normalizedGroupingExpressions, + aggregateExpressions.map(expr => expr.asInstanceOf[AggregateExpression]), + rewrittenResultExpressions, + stateVersion, + planLater(child)) + } case _ => Nil } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala index 58d341107347..975a8edb049c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala @@ -17,10 +17,11 @@ package org.apache.spark.sql.execution.aggregate +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.execution.streaming.{StateStoreRestoreExec, StateStoreSaveExec} +import org.apache.spark.sql.execution.streaming._ /** * Utility functions used by the query planner to convert our plan to new aggregation code path. @@ -113,6 +114,9 @@ object AggUtils { resultExpressions = partialResultExpressions, child = child) + val interExec: SparkPlan = mayAppendMergingSessionExec(groupingExpressions, + aggregateExpressions, partialAggregate) + // 2. Create an Aggregate Operator for final aggregations. val finalAggregateExpressions = aggregateExpressions.map(_.copy(mode = Final)) // The attributes of the final aggregation buffer, which is presented as input to the result @@ -126,7 +130,7 @@ object AggUtils { aggregateAttributes = finalAggregateAttributes, initialInputBufferOffset = groupingExpressions.length, resultExpressions = resultExpressions, - child = partialAggregate) + child = interExec) finalAggregate :: Nil } @@ -140,6 +144,8 @@ object AggUtils { resultExpressions: Seq[NamedExpression], child: SparkPlan): Seq[SparkPlan] = { + val maySessionChild = mayAppendUpdatingSessionExec(groupingExpressions, child) + val distinctAttributes = normalizedNamedDistinctExpressions.map(_.toAttribute) val groupingAttributes = groupingExpressions.map(_.toAttribute) @@ -156,7 +162,7 @@ object AggUtils { aggregateAttributes = aggregateAttributes, resultExpressions = groupingAttributes ++ distinctAttributes ++ aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes), - child = child) + child = maySessionChild) } // 2. Create an Aggregate Operator for partial merge aggregations. @@ -345,4 +351,179 @@ object AggUtils { finalAndCompleteAggregate :: Nil } + + /** + * Plans a streaming session aggregation using the following progression: + * + * - Partial Aggregation + * - all tuples will have aggregated columns with initial value + * - (If "spark.sql.streaming.sessionWindow.merge.sessions.in.local.partition" is enabled) + * - Sort within partition (sort: all keys) + * - MergingSessionExec + * - calculate session among tuples, and aggregate tuples in session with partial merge + * - Shuffle & Sort (distribution: keys "without" session, sort: all keys) + * - SessionWindowStateStoreRestore (group: keys "without" session) + * - merge input tuples with stored tuples (sessions) respecting sort order + * - MergingSessionExec + * - calculate session among tuples, and aggregate tuples in session with partial merge + * - NOTE: it leverages the fact that the output of SessionWindowStateStoreRestore is sorted + * - now there is at most 1 tuple per group, key with session + * - SessionWindowStateStoreSave (group: keys "without" session) + * - saves tuple(s) for the next batch (multiple sessions could co-exist at the same time) + * - Complete (output the current result of the aggregation) + */ + def planStreamingAggregationForSession( + groupingExpressions: Seq[NamedExpression], + sessionExpression: NamedExpression, + functionsWithoutDistinct: Seq[AggregateExpression], + resultExpressions: Seq[NamedExpression], + stateFormatVersion: Int, + mergeSessionsInLocalPartition: Boolean, + child: SparkPlan): Seq[SparkPlan] = { + + val groupWithoutSessionExpression = groupingExpressions.filterNot { p => + p.semanticEquals(sessionExpression) + } + + if (groupWithoutSessionExpression.isEmpty) { + throw new AnalysisException("Global aggregation with session window in streaming query" + + " is not supported.") + } + + val groupingWithoutSessionAttributes = groupWithoutSessionExpression.map(_.toAttribute) + + val groupingAttributes = groupingExpressions.map(_.toAttribute) + + // we don't do partial aggregate here, because it requires additional shuffle + // and there will be less rows which have same session start + // here doing partial merge is to have aggregated columns with default value for each row + val partialAggregate: SparkPlan = { + val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Partial)) + val aggregateAttributes = aggregateExpressions.map(_.resultAttribute) + createAggregate( + groupingExpressions = groupingExpressions, + aggregateExpressions = aggregateExpressions, + aggregateAttributes = aggregateAttributes, + resultExpressions = groupingAttributes ++ + aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes), + child = child) + } + + val partialMerged1: SparkPlan = if (mergeSessionsInLocalPartition) { + val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge)) + val aggregateAttributes = aggregateExpressions.map(_.resultAttribute) + + // sort happens here to merge sessions on each partition + // this is to reduce amount of rows to shuffle + MergingSessionsExec( + requiredChildDistributionExpressions = None, + requiredChildDistributionOption = None, + groupingExpressions = groupingAttributes, + sessionExpression = sessionExpression, + aggregateExpressions = aggregateExpressions, + aggregateAttributes = aggregateAttributes, + initialInputBufferOffset = groupingAttributes.length, + resultExpressions = groupingAttributes ++ + aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes), + child = partialAggregate + ) + } else { + partialAggregate + } + + // shuffle & sort happens here: most of details are also handled in this physical plan + val restored = SessionWindowStateStoreRestoreExec(groupingWithoutSessionAttributes, + sessionExpression.toAttribute, stateInfo = None, eventTimeWatermark = None, + stateFormatVersion, partialMerged1) + + val mergedSessions = { + val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge)) + val aggregateAttributes = aggregateExpressions.map(_.resultAttribute) + MergingSessionsExec( + requiredChildDistributionExpressions = None, + requiredChildDistributionOption = Some(restored.requiredChildDistribution), + groupingExpressions = groupingAttributes, + sessionExpression = sessionExpression, + aggregateExpressions = aggregateExpressions, + aggregateAttributes = aggregateAttributes, + initialInputBufferOffset = groupingAttributes.length, + resultExpressions = groupingAttributes ++ + aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes), + child = restored + ) + } + + // Note: stateId and returnAllStates are filled in later with preparation rules + // in IncrementalExecution. + val saved = SessionWindowStateStoreSaveExec( + groupingWithoutSessionAttributes, + sessionExpression.toAttribute, + stateInfo = None, + outputMode = None, + eventTimeWatermark = None, + stateFormatVersion, mergedSessions) + + val finalAndCompleteAggregate: SparkPlan = { + val finalAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Final)) + // The attributes of the final aggregation buffer, which is presented as input to the result + // projection: + val finalAggregateAttributes = finalAggregateExpressions.map(_.resultAttribute) + + createAggregate( + requiredChildDistributionExpressions = Some(groupingAttributes), + groupingExpressions = groupingAttributes, + aggregateExpressions = finalAggregateExpressions, + aggregateAttributes = finalAggregateAttributes, + initialInputBufferOffset = groupingAttributes.length, + resultExpressions = resultExpressions, + child = saved) + } + + finalAndCompleteAggregate :: Nil + } + + private def mayAppendUpdatingSessionExec( + groupingExpressions: Seq[NamedExpression], + maybeChildPlan: SparkPlan): SparkPlan = { + groupingExpressions.find(_.metadata.contains(SessionWindow.marker)) match { + case Some(sessionExpression) => + UpdatingSessionsExec( + groupingExpressions.map(_.toAttribute), + sessionExpression.toAttribute, + maybeChildPlan) + + case None => maybeChildPlan + } + } + + private def mayAppendMergingSessionExec( + groupingExpressions: Seq[NamedExpression], + aggregateExpressions: Seq[AggregateExpression], + partialAggregate: SparkPlan): SparkPlan = { + groupingExpressions.find(_.metadata.contains(SessionWindow.marker)) match { + case Some(sessionExpression) => + val aggExpressions = aggregateExpressions.map(_.copy(mode = PartialMerge)) + val aggAttributes = aggregateExpressions.map(_.resultAttribute) + + val groupingAttributes = groupingExpressions.map(_.toAttribute) + val groupingWithoutSessionExpressions = groupingExpressions.diff(Seq(sessionExpression)) + val groupingWithoutSessionsAttributes = groupingWithoutSessionExpressions + .map(_.toAttribute) + + MergingSessionsExec( + requiredChildDistributionExpressions = Some(groupingWithoutSessionsAttributes), + requiredChildDistributionOption = None, + groupingExpressions = groupingAttributes, + sessionExpression = sessionExpression, + aggregateExpressions = aggExpressions, + aggregateAttributes = aggAttributes, + initialInputBufferOffset = groupingAttributes.length, + resultExpressions = groupingAttributes ++ + aggExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes), + child = partialAggregate + ) + + case None => partialAggregate + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UpdatingSessionsIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UpdatingSessionsIterator.scala index bb474a19222c..2c611e3d0d22 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UpdatingSessionsIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UpdatingSessionsIterator.scala @@ -181,7 +181,8 @@ class UpdatingSessionsIterator( private val valueProj = GenerateUnsafeProjection.generate(valuesExpressions, inputSchema) private val restoreProj = GenerateUnsafeProjection.generate(inputSchema, - groupingExpressions.map(_.toAttribute) ++ valuesExpressions.map(_.toAttribute)) + groupingWithoutSession.map(_.toAttribute) ++ Seq(sessionExpression.toAttribute) ++ + valuesExpressions.map(_.toAttribute)) private def generateGroupingKey(): InternalRow = { val newRow = new SpecificInternalRow(Seq(sessionExpression.toAttribute).toStructType) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala index 5019008ec5e3..7fd39146c95f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala @@ -26,8 +26,9 @@ import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning} +import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan, UnaryExecNode} +import org.apache.spark.sql.execution.aggregate.UpdatingSessionsIterator import org.apache.spark.sql.types.{DataType, StructField, StructType} import org.apache.spark.sql.util.ArrowUtils import org.apache.spark.util.Utils @@ -53,12 +54,23 @@ case class AggregateInPandasExec( override def producedAttributes: AttributeSet = AttributeSet(output) - override def requiredChildDistribution: Seq[Distribution] = { - if (groupingExpressions.isEmpty) { - AllTuples :: Nil - } else { - ClusteredDistribution(groupingExpressions) :: Nil - } + val sessionWindowOption = groupingExpressions.find { p => + p.metadata.contains(SessionWindow.marker) + } + + val groupingWithoutSessionExpressions = sessionWindowOption match { + case Some(sessionExpression) => + groupingExpressions.filterNot { p => p.semanticEquals(sessionExpression) } + + case None => groupingExpressions + } + + override def requiredChildOrdering: Seq[Seq[SortOrder]] = sessionWindowOption match { + case Some(sessionExpression) => + Seq((groupingWithoutSessionExpressions ++ Seq(sessionExpression)) + .map(SortOrder(_, Ascending))) + + case None => Seq(groupingExpressions.map(SortOrder(_, Ascending))) } private def collectFunctions(udf: PythonUDF): (ChainedPythonFunctions, Seq[Expression]) = { @@ -73,9 +85,6 @@ case class AggregateInPandasExec( } } - override def requiredChildOrdering: Seq[Seq[SortOrder]] = - Seq(groupingExpressions.map(SortOrder(_, Ascending))) - override protected def doExecute(): RDD[InternalRow] = { val inputRDD = child.execute() @@ -107,13 +116,14 @@ case class AggregateInPandasExec( // Map grouped rows to ArrowPythonRunner results, Only execute if partition is not empty inputRDD.mapPartitionsInternal { iter => if (iter.isEmpty) iter else { + val newIter: Iterator[InternalRow] = mayAppendUpdatingSessionIterator(iter) val prunedProj = UnsafeProjection.create(allInputs.toSeq, child.output) val grouped = if (groupingExpressions.isEmpty) { // Use an empty unsafe row as a place holder for the grouping key - Iterator((new UnsafeRow(), iter)) + Iterator((new UnsafeRow(), newIter)) } else { - GroupedIterator(iter, groupingExpressions, child.output) + GroupedIterator(newIter, groupingExpressions, child.output) }.map { case (key, rows) => (key, rows.map(prunedProj)) } @@ -157,4 +167,21 @@ case class AggregateInPandasExec( override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = copy(child = newChild) + + + private def mayAppendUpdatingSessionIterator( + iter: Iterator[InternalRow]): Iterator[InternalRow] = { + val newIter = sessionWindowOption match { + case Some(sessionExpression) => + val inMemoryThreshold = conf.windowExecBufferInMemoryThreshold + val spillThreshold = conf.windowExecBufferSpillThreshold + + new UpdatingSessionsIterator(iter, groupingWithoutSessionExpressions, sessionExpression, + child.output, inMemoryThreshold, spillThreshold) + + case None => iter + } + + newIter + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index e98996b8e37d..3e772e104648 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -149,6 +149,26 @@ class IncrementalExecution( stateFormatVersion, child) :: Nil)) + case SessionWindowStateStoreSaveExec(keys, session, None, None, None, stateFormatVersion, + UnaryExecNode(agg, + SessionWindowStateStoreRestoreExec(_, _, None, None, _, child))) => + val aggStateInfo = nextStatefulOperationStateInfo + SessionWindowStateStoreSaveExec( + keys, + session, + Some(aggStateInfo), + Some(outputMode), + Some(offsetSeqMetadata.batchWatermarkMs), + stateFormatVersion, + agg.withNewChildren( + SessionWindowStateStoreRestoreExec( + keys, + session, + Some(aggStateInfo), + Some(offsetSeqMetadata.batchWatermarkMs), + stateFormatVersion, + child) :: Nil)) + case StreamingDeduplicateExec(keys, child, None, None) => StreamingDeduplicateExec( keys, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index 3f6a7ba1a0da..74743c5ec058 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -20,7 +20,9 @@ package org.apache.spark.sql.execution.streaming import java.util.UUID import java.util.concurrent.TimeUnit._ +import scala.annotation.tailrec import scala.collection.JavaConverters._ +import scala.collection.mutable import org.apache.spark.SparkContext import org.apache.spark.rdd.RDD @@ -511,6 +513,296 @@ case class StateStoreSaveExec( copy(child = newChild) } +/** + * This class sorts input rows and existing sessions in state and provides output rows as + * sorted by "group keys + start time of session window". + * + * Refer [[MergingSortWithSessionWindowStateIterator]] for more details. + */ +case class SessionWindowStateStoreRestoreExec( + keyWithoutSessionExpressions: Seq[Attribute], + sessionExpression: Attribute, + stateInfo: Option[StatefulOperatorStateInfo], + eventTimeWatermark: Option[Long], + stateFormatVersion: Int, + child: SparkPlan) + extends UnaryExecNode with StateStoreReader with WatermarkSupport { + + override def keyExpressions: Seq[Attribute] = keyWithoutSessionExpressions + + private val stateManager = StreamingSessionWindowStateManager.createStateManager( + keyWithoutSessionExpressions, sessionExpression, child.output, stateFormatVersion) + + override protected def doExecute(): RDD[InternalRow] = { + val numOutputRows = longMetric("numOutputRows") + assert(keyExpressions.nonEmpty, "Grouping key must be specified when using sessionWindow") + + child.execute().mapPartitionsWithReadStateStore( + getStateInfo, + stateManager.getStateKeySchema, + stateManager.getStateValueSchema, + numColsPrefixKey = stateManager.getNumColsForPrefixKey, + session.sessionState, + Some(session.streams.stateStoreCoordinator)) { case (store, iter) => + + // We need to filter out outdated inputs + val filteredIterator = watermarkPredicateForData match { + case Some(predicate) => iter.filter((row: InternalRow) => !predicate.eval(row)) + case None => iter + } + + new MergingSortWithSessionWindowStateIterator( + filteredIterator, + stateManager, + store, + keyWithoutSessionExpressions, + sessionExpression, + child.output).map { row => + numOutputRows += 1 + row + } + } + } + + override def output: Seq[Attribute] = child.output + + override def outputPartitioning: Partitioning = child.outputPartitioning + + override def outputOrdering: Seq[SortOrder] = { + (keyWithoutSessionExpressions ++ Seq(sessionExpression)).map(SortOrder(_, Ascending)) + } + + override def requiredChildDistribution: Seq[Distribution] = { + if (keyWithoutSessionExpressions.isEmpty) { + AllTuples :: Nil + } else { + ClusteredDistribution(keyWithoutSessionExpressions, stateInfo.map(_.numPartitions)) :: Nil + } + } + + override def requiredChildOrdering: Seq[Seq[SortOrder]] = { + Seq((keyWithoutSessionExpressions ++ Seq(sessionExpression)).map(SortOrder(_, Ascending))) + } + + override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = + copy(child = newChild) +} + +/** + * For each input tuple, the key is calculated and the tuple is `put` into the [[StateStore]]. + */ +case class SessionWindowStateStoreSaveExec( + keyExpressions: Seq[Attribute], + sessionExpression: Attribute, + stateInfo: Option[StatefulOperatorStateInfo] = None, + outputMode: Option[OutputMode] = None, + eventTimeWatermark: Option[Long] = None, + stateFormatVersion: Int, + child: SparkPlan) + extends UnaryExecNode with StateStoreWriter with WatermarkSupport { + + private val keyWithoutSessionExpressions = keyExpressions.filterNot { p => + p.semanticEquals(sessionExpression) + } + + private val stateManager = StreamingSessionWindowStateManager.createStateManager( + keyWithoutSessionExpressions, sessionExpression, child.output, stateFormatVersion) + + override protected def doExecute(): RDD[InternalRow] = { + metrics // force lazy init at driver + assert(outputMode.nonEmpty, + "Incorrect planning in IncrementalExecution, outputMode has not been set") + assert(keyExpressions.nonEmpty, + "Grouping key must be specified when using sessionWindow") + + child.execute().mapPartitionsWithStateStore( + getStateInfo, + stateManager.getStateKeySchema, + stateManager.getStateValueSchema, + numColsPrefixKey = stateManager.getNumColsForPrefixKey, + session.sessionState, + Some(session.streams.stateStoreCoordinator)) { case (store, iter) => + + val numOutputRows = longMetric("numOutputRows") + val allUpdatesTimeMs = longMetric("allUpdatesTimeMs") + val allRemovalsTimeMs = longMetric("allRemovalsTimeMs") + val commitTimeMs = longMetric("commitTimeMs") + + outputMode match { + // Update and output all rows in the StateStore. + case Some(Complete) => + allUpdatesTimeMs += timeTakenMs { + putToStore(iter, store, false) + } + allRemovalsTimeMs += 0 + commitTimeMs += timeTakenMs { + stateManager.commit(store) + } + setStoreMetrics(store) + stateManager.iterator(store).map { row => + numOutputRows += 1 + row + } + + // Update and output only rows being evicted from the StateStore + // Assumption: watermark predicates must be non-empty if append mode is allowed + case Some(Append) => + allUpdatesTimeMs += timeTakenMs { + putToStore(iter, store, true) + } + + val removalStartTimeNs = System.nanoTime + new NextIterator[InternalRow] { + private val removedIter = stateManager.removeByValueCondition( + store, watermarkPredicateForData.get.eval) + + override protected def getNext(): InternalRow = { + if (!removedIter.hasNext) { + finished = true + null + } else { + numOutputRows += 1 + removedIter.next() + } + } + + override protected def close(): Unit = { + allRemovalsTimeMs += NANOSECONDS.toMillis(System.nanoTime - removalStartTimeNs) + commitTimeMs += timeTakenMs { store.commit() } + setStoreMetrics(store) + } + } + + case Some(Update) => + val iterPutToStore = iteratorPutToStore(iter, store, true, true) + new NextIterator[InternalRow] { + private val updatesStartTimeNs = System.nanoTime + + override protected def getNext(): InternalRow = { + if (iterPutToStore.hasNext) { + iterPutToStore.next() + } else { + finished = true + null + } + } + + override protected def close(): Unit = { + allUpdatesTimeMs += NANOSECONDS.toMillis(System.nanoTime - updatesStartTimeNs) + + allRemovalsTimeMs += timeTakenMs { + if (watermarkPredicateForData.nonEmpty) { + val removedIter = stateManager.removeByValueCondition( + store, watermarkPredicateForData.get.eval) + while (removedIter.hasNext) { + removedIter.next() + } + } + } + commitTimeMs += timeTakenMs { store.commit() } + setStoreMetrics(store) + } + } + + case _ => throw new UnsupportedOperationException(s"Invalid output mode: $outputMode") + } + } + } + + override def output: Seq[Attribute] = child.output + + override def outputPartitioning: Partitioning = child.outputPartitioning + + override def requiredChildDistribution: Seq[Distribution] = { + if (keyExpressions.isEmpty) { + AllTuples :: Nil + } else { + ClusteredDistribution(keyExpressions, stateInfo.map(_.numPartitions)) :: Nil + } + } + + override def shouldRunAnotherBatch(newMetadata: OffsetSeqMetadata): Boolean = { + (outputMode.contains(Append) || outputMode.contains(Update)) && + eventTimeWatermark.isDefined && + newMetadata.batchWatermarkMs > eventTimeWatermark.get + } + + private def iteratorPutToStore( + baseIter: Iterator[InternalRow], + store: StateStore, + needFilter: Boolean, + returnOnlyUpdatedRows: Boolean): Iterator[InternalRow] = { + val numUpdatedStateRows = longMetric("numUpdatedStateRows") + val iter = if (needFilter) { + baseIter.filter(row => !watermarkPredicateForData.get.eval(row)) + } else { + baseIter + } + + new NextIterator[InternalRow] { + var curKey: UnsafeRow = null + val curValuesOnKey = new mutable.ArrayBuffer[UnsafeRow]() + + private def applyChangesOnKey(): Unit = { + if (curValuesOnKey.nonEmpty) { + val updatedRows = stateManager.updateSessions(store, curKey, curValuesOnKey) + numUpdatedStateRows += updatedRows + curValuesOnKey.clear + } + } + + @tailrec + override protected def getNext(): InternalRow = { + if (!iter.hasNext) { + applyChangesOnKey() + finished = true + return null + } + + val row = iter.next().asInstanceOf[UnsafeRow] + val key = stateManager.extractKeyWithoutSession(row) + + if (curKey == null || curKey != key) { + // new group appears + applyChangesOnKey() + curKey = key.copy() + } + + // must copy the row, for this row is a reference in iterator and + // will change when iter.next + curValuesOnKey += row.copy + + if (!returnOnlyUpdatedRows) { + row + } else { + if (stateManager.newOrModified(store, row)) { + row + } else { + // current row isn't the "updated" row, continue to the next row + getNext() + } + } + } + + override protected def close(): Unit = {} + } + } + + private def putToStore( + baseIter: Iterator[InternalRow], + store: StateStore, + needFilter: Boolean) { + val iterPutToStore = iteratorPutToStore(baseIter, store, needFilter, false) + while (iterPutToStore.hasNext) { + iterPutToStore.next() + } + } + + override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = + copy(child = newChild) +} + + /** Physical operator for executing streaming Deduplicate. */ case class StreamingDeduplicateExec( keyExpressions: Seq[Attribute], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 3b39d9790dff..688842d0e2a5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3630,6 +3630,35 @@ object functions { window(timeColumn, windowDuration, windowDuration, "0 second") } + /** + * Generates session window given a timestamp specifying column. + * + * Session window is the one of dynamic windows, which means the length of window is vary + * according to the given inputs. The length of session window is defined as "the timestamp + * of latest input of the session + gap duration", so when the new inputs are bound to the + * current session window, the end time of session window can be expanded according to the new + * inputs. + * + * Windows can support microsecond precision. Windows in the order of months are not supported. + * + * For a streaming query, you may use the function `current_timestamp` to generate windows on + * processing time. + * + * @param timeColumn The column or the expression to use as the timestamp for windowing by time. + * The time column must be of TimestampType. + * @param gapDuration A string specifying the timeout of the session, e.g. `10 minutes`, + * `1 second`. Check `org.apache.spark.unsafe.types.CalendarInterval` for + * valid duration identifiers. + * + * @group datetime_funcs + * @since 3.2.0 + */ + def session_window(timeColumn: Column, gapDuration: String): Column = { + withExpr { + SessionWindow(timeColumn.expr, gapDuration) + }.as("session_window") + } + /** * Creates timestamp from the number of seconds since UTC epoch. * @group datetime_funcs diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md index c13a1d4e93d4..75be26a26bd0 100644 --- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md +++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md @@ -1,8 +1,8 @@ ## Summary - Number of queries: 360 - - Number of expressions that missing example: 13 - - Expressions missing examples: bigint,binary,boolean,date,decimal,double,float,int,smallint,string,timestamp,tinyint,window + - Number of expressions that missing example: 14 + - Expressions missing examples: bigint,binary,boolean,date,decimal,double,float,int,smallint,string,timestamp,tinyint,session_window,window ## Schema of Built-in Functions | Class name | Function name or alias | Query example | Output schema | | ---------- | ---------------------- | ------------- | ------------- | @@ -244,6 +244,7 @@ | org.apache.spark.sql.catalyst.expressions.SecondsToTimestamp | timestamp_seconds | SELECT timestamp_seconds(1230219000) | struct | | org.apache.spark.sql.catalyst.expressions.Sentences | sentences | SELECT sentences('Hi there! Good morning.') | struct>> | | org.apache.spark.sql.catalyst.expressions.Sequence | sequence | SELECT sequence(1, 5) | struct> | +| org.apache.spark.sql.catalyst.expressions.SessionWindow | session_window | N/A | N/A | | org.apache.spark.sql.catalyst.expressions.Sha1 | sha | SELECT sha('Spark') | struct | | org.apache.spark.sql.catalyst.expressions.Sha1 | sha1 | SELECT sha1('Spark') | struct | | org.apache.spark.sql.catalyst.expressions.Sha2 | sha2 | SELECT sha2('Spark', 256) | struct | @@ -365,4 +366,4 @@ | org.apache.spark.sql.catalyst.expressions.xml.XPathList | xpath | SELECT xpath('b1b2b3c1c2','a/b/text()') | structb1b2b3c1c2, a/b/text()):array> | | org.apache.spark.sql.catalyst.expressions.xml.XPathLong | xpath_long | SELECT xpath_long('12', 'sum(a/b)') | struct12, sum(a/b)):bigint> | | org.apache.spark.sql.catalyst.expressions.xml.XPathShort | xpath_short | SELECT xpath_short('12', 'sum(a/b)') | struct12, sum(a/b)):smallint> | -| org.apache.spark.sql.catalyst.expressions.xml.XPathString | xpath_string | SELECT xpath_string('bcc','a/c') | structbcc, a/c):string> | \ No newline at end of file +| org.apache.spark.sql.catalyst.expressions.xml.XPathString | xpath_string | SELECT xpath_string('bcc','a/c') | structbcc, a/c):string> | diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSessionWindowingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSessionWindowingSuite.scala new file mode 100644 index 000000000000..6dc1860bfa7a --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSessionWindowingSuite.scala @@ -0,0 +1,291 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.scalatest.BeforeAndAfterEach + +import org.apache.spark.sql.catalyst.plans.logical.Expand +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types.StringType + +class DataFrameSessionWindowingSuite extends QueryTest with SharedSparkSession + with BeforeAndAfterEach { + + import testImplicits._ + + test("simple session window with record at window start") { + val df = Seq( + ("2016-03-27 19:39:30", 1, "a")).toDF("time", "value", "id") + + checkAnswer( + df.groupBy(session_window($"time", "10 seconds")) + .agg(count("*").as("counts")) + .orderBy($"session_window.start".asc) + .select($"session_window.start".cast("string"), $"session_window.end".cast("string"), + $"counts"), + Seq( + Row("2016-03-27 19:39:30", "2016-03-27 19:39:40", 1) + ) + ) + } + + test("session window groupBy statement") { + val df = Seq( + ("2016-03-27 19:39:34", 1, "a"), + ("2016-03-27 19:39:56", 2, "a"), + ("2016-03-27 19:39:27", 4, "b")).toDF("time", "value", "id") + + // session window handles sort while applying group by + // whereas time window doesn't + + checkAnswer( + df.groupBy(session_window($"time", "10 seconds")) + .agg(count("*").as("counts")) + .orderBy($"session_window.start".asc) + .select("counts"), + Seq(Row(2), Row(1)) + ) + } + + test("session window groupBy with multiple keys statement") { + val df = Seq( + ("2016-03-27 19:39:34", 1, "a"), + ("2016-03-27 19:39:39", 1, "a"), + ("2016-03-27 19:39:56", 2, "a"), + ("2016-03-27 19:40:04", 2, "a"), + ("2016-03-27 19:39:27", 4, "b")).toDF("time", "value", "id") + + // session window handles sort while applying group by + // whereas time window doesn't + + // expected sessions + // key "a" => (19:39:34 ~ 19:39:49) (19:39:56 ~ 19:40:14) + // key "b" => (19:39:27 ~ 19:39:37) + + checkAnswer( + df.groupBy(session_window($"time", "10 seconds"), 'id) + .agg(count("*").as("counts"), sum("value").as("sum")) + .orderBy($"session_window.start".asc) + .selectExpr("CAST(session_window.start AS STRING)", "CAST(session_window.end AS STRING)", + "id", "counts", "sum"), + + Seq( + Row("2016-03-27 19:39:27", "2016-03-27 19:39:37", "b", 1, 4), + Row("2016-03-27 19:39:34", "2016-03-27 19:39:49", "a", 2, 2), + Row("2016-03-27 19:39:56", "2016-03-27 19:40:14", "a", 2, 4) + ) + ) + } + + // FIXME: fix the failing test - check if it still fails or not + test("session window groupBy with multiple keys statement - one distinct") { + val df = Seq( + ("2016-03-27 19:39:34", 1, "a"), + ("2016-03-27 19:39:39", 1, "a"), + ("2016-03-27 19:39:56", 2, "a"), + ("2016-03-27 19:40:04", 2, "a"), + ("2016-03-27 19:39:27", 4, "b")).toDF("time", "value", "id") + + // session window handles sort while applying group by + // whereas time window doesn't + + // expected sessions + // key "a" => (19:39:34 ~ 19:39:49) (19:39:56 ~ 19:40:14) + // key "b" => (19:39:27 ~ 19:39:37) + + checkAnswer( + df.groupBy(session_window($"time", "10 seconds"), 'id) + .agg(count("*").as("counts"), sum_distinct(col("value")).as("sum")) + .orderBy($"session_window.start".asc) + .selectExpr("CAST(session_window.start AS STRING)", "CAST(session_window.end AS STRING)", + "id", "counts", "sum"), + Seq( + Row("2016-03-27 19:39:27", "2016-03-27 19:39:37", "b", 1, 4), + Row("2016-03-27 19:39:34", "2016-03-27 19:39:49", "a", 2, 1), + Row("2016-03-27 19:39:56", "2016-03-27 19:40:14", "a", 2, 2) + ) + ) + } + + test("session window groupBy with multiple keys statement - two distinct") { + val df = Seq( + ("2016-03-27 19:39:34", 1, 2, "a"), + ("2016-03-27 19:39:39", 1, 2, "a"), + ("2016-03-27 19:39:56", 2, 4, "a"), + ("2016-03-27 19:40:04", 2, 4, "a"), + ("2016-03-27 19:39:27", 4, 8, "b")).toDF("time", "value", "value2", "id") + + // session window handles sort while applying group by + // whereas time window doesn't + + // expected sessions + // key "a" => (19:39:34 ~ 19:39:49) (19:39:56 ~ 19:40:14) + // key "b" => (19:39:27 ~ 19:39:37) + + checkAnswer( + df.groupBy(session_window($"time", "10 seconds"), 'id) + .agg(sum_distinct(col("value")).as("sum"), sum_distinct(col("value2")).as("sum2")) + .orderBy($"session_window.start".asc) + .selectExpr("CAST(session_window.start AS STRING)", "CAST(session_window.end AS STRING)", + "id", "sum", "sum2"), + Seq( + Row("2016-03-27 19:39:27", "2016-03-27 19:39:37", "b", 4, 8), + Row("2016-03-27 19:39:34", "2016-03-27 19:39:49", "a", 1, 2), + Row("2016-03-27 19:39:56", "2016-03-27 19:40:14", "a", 2, 4) + ) + ) + } + + test("session window groupBy with multiple keys statement - keys overlapped with sessions") { + val df = Seq( + ("2016-03-27 19:39:34", 1, "a"), + ("2016-03-27 19:39:39", 1, "b"), + ("2016-03-27 19:39:40", 2, "a"), + ("2016-03-27 19:39:45", 2, "b"), + ("2016-03-27 19:39:27", 4, "b")).toDF("time", "value", "id") + + // session window handles sort while applying group by + // whereas time window doesn't + + // expected sessions + // a => (19:39:34 ~ 19:39:50) + // b => (19:39:27 ~ 19:39:37), (19:39:39 ~ 19:39:55) + + checkAnswer( + df.groupBy(session_window($"time", "10 seconds"), 'id) + .agg(count("*").as("counts"), sum("value").as("sum")) + .orderBy($"session_window.start".asc) + .selectExpr("CAST(session_window.start AS STRING)", "CAST(session_window.end AS STRING)", + "id", "counts", "sum"), + + Seq( + Row("2016-03-27 19:39:27", "2016-03-27 19:39:37", "b", 1, 4), + Row("2016-03-27 19:39:34", "2016-03-27 19:39:50", "a", 2, 3), + Row("2016-03-27 19:39:39", "2016-03-27 19:39:55", "b", 2, 3) + ) + ) + } + + test("session window with multi-column projection") { + val df = Seq( + ("2016-03-27 19:39:34", 1, "a"), + ("2016-03-27 19:39:56", 2, "a"), + ("2016-03-27 19:39:27", 4, "b")).toDF("time", "value", "id") + .select(session_window($"time", "10 seconds"), $"value") + .orderBy($"session_window.start".asc) + .select($"session_window.start".cast("string"), $"session_window.end".cast("string"), + $"value") + + val expands = df.queryExecution.optimizedPlan.find(_.isInstanceOf[Expand]) + assert(expands.isEmpty, "Session windows shouldn't require expand") + + checkAnswer( + df, + Seq( + Row("2016-03-27 19:39:27", "2016-03-27 19:39:37", 4), + Row("2016-03-27 19:39:34", "2016-03-27 19:39:44", 1), + Row("2016-03-27 19:39:56", "2016-03-27 19:40:06", 2) + ) + ) + } + + test("session window combined with explode expression") { + val df = Seq( + ("2016-03-27 19:39:34", 1, Seq("a", "b")), + ("2016-03-27 19:39:56", 2, Seq("a", "c", "d"))).toDF("time", "value", "ids") + + checkAnswer( + df.select(session_window($"time", "10 seconds"), $"value", explode($"ids")) + .orderBy($"session_window.start".asc).select("value"), + // first window exploded to two rows for "a", and "b", second window exploded to 3 rows + Seq(Row(1), Row(1), Row(2), Row(2), Row(2)) + ) + } + + test("null timestamps") { + val df = Seq( + ("2016-03-27 09:00:05", 1), + ("2016-03-27 09:00:32", 2), + (null, 3), + (null, 4)).toDF("time", "value") + + checkDataset( + df.select(session_window($"time", "10 seconds"), $"value") + .orderBy($"session_window.start".asc) + .select("value") + .as[Int], + 1, 2) // null columns are dropped + } + + // NOTE: unlike time window, joining session windows without grouping + // doesn't arrange session, so two rows will be joined only if session range is exactly same + + test("multiple session windows in a single operator throws nice exception") { + val df = Seq( + ("2016-03-27 09:00:02", 3), + ("2016-03-27 09:00:35", 6)).toDF("time", "value") + val e = intercept[AnalysisException] { + df.select(session_window($"time", "10 second"), session_window($"time", "15 second")) + .collect() + } + assert(e.getMessage.contains( + "Multiple time/session window expressions would result in a cartesian product")) + } + + test("aliased session windows") { + val df = Seq( + ("2016-03-27 19:39:34", 1, Seq("a", "b")), + ("2016-03-27 19:39:56", 2, Seq("a", "c", "d"))).toDF("time", "value", "ids") + + checkAnswer( + df.select(session_window($"time", "10 seconds").as("session_window"), $"value") + .orderBy($"session_window.start".asc) + .select("value"), + Seq(Row(1), Row(2)) + ) + } + + private def withTempTable(f: String => Unit): Unit = { + val tableName = "temp" + Seq( + ("2016-03-27 19:39:34", 1), + ("2016-03-27 19:39:56", 2), + ("2016-03-27 19:39:27", 4)).toDF("time", "value").createOrReplaceTempView(tableName) + try { + f(tableName) + } finally { + spark.catalog.dropTempView(tableName) + } + } + + test("time window in SQL with single string expression") { + withTempTable { table => + checkAnswer( + spark.sql(s"""select session_window(time, "10 seconds"), value from $table""") + .select($"session_window.start".cast(StringType), $"session_window.end".cast(StringType), + $"value"), + Seq( + Row("2016-03-27 19:39:27", "2016-03-27 19:39:37", 4), + Row("2016-03-27 19:39:34", "2016-03-27 19:39:44", 1), + Row("2016-03-27 19:39:56", "2016-03-27 19:40:06", 2) + ) + ) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala index 4fdaeb57ad50..2ef43dcf562c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala @@ -239,7 +239,7 @@ class DataFrameTimeWindowingSuite extends QueryTest with SharedSparkSession { df.select(window($"time", "10 second"), window($"time", "15 second")).collect() } assert(e.getMessage.contains( - "Multiple time window expressions would result in a cartesian product")) + "Multiple time/session window expressions would result in a cartesian product")) } test("aliased windows") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index b0d5c8932bcb..1e23c115ff24 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -136,7 +136,8 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark test("SPARK-14415: All functions should have own descriptions") { for (f <- spark.sessionState.functionRegistry.listFunction()) { - if (!Seq("cube", "grouping", "grouping_id", "rollup", "window").contains(f.unquotedString)) { + if (!Seq("cube", "grouping", "grouping_id", "rollup", "window", + "session_window").contains(f.unquotedString)) { checkKeywordsNotExist(sql(s"describe function $f"), "N/A.") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala index 30ee97a89b4e..08e21d537122 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala @@ -133,6 +133,7 @@ class ExpressionInfoSuite extends SparkFunSuite with SharedSparkSession { val ignoreSet = Set( // Explicitly inherits NonSQLExpression, and has no ExpressionDescription "org.apache.spark.sql.catalyst.expressions.TimeWindow", + "org.apache.spark.sql.catalyst.expressions.SessionWindow", // Cast aliases do not need examples "org.apache.spark.sql.catalyst.expressions.Cast") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingSessionWindowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingSessionWindowSuite.scala new file mode 100644 index 000000000000..d1d21e35ef78 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingSessionWindowSuite.scala @@ -0,0 +1,456 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming + +import java.util.Locale + +import org.scalatest.BeforeAndAfter +import org.scalatest.matchers.must.Matchers + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.execution.streaming.MemoryStream +import org.apache.spark.sql.execution.streaming.state.HDFSBackedStateStoreProvider +import org.apache.spark.sql.functions.{count, session_window, sum} +import org.apache.spark.sql.internal.SQLConf + +class StreamingSessionWindowSuite extends StreamTest + with BeforeAndAfter with Matchers with Logging { + + import testImplicits._ + + after { + sqlContext.streams.active.foreach(_.stop()) + } + + def testWithAllOptions(name: String, confPairs: (String, String)*) + (func: => Any): Unit = { + val mergingSessionOptions = Seq(true, false).map { value => + (SQLConf.STREAMING_SESSION_WINDOW_MERGE_SESSIONS_IN_LOCAL_PARTITION.key, value) + } + val providerOptions = Seq( + classOf[HDFSBackedStateStoreProvider].getCanonicalName).map { value => + (SQLConf.STATE_STORE_PROVIDER_CLASS.key, value.stripSuffix("$")) + } + + val availableOptions = for ( + opt1 <- mergingSessionOptions; + opt2 <- providerOptions + ) yield (opt1, opt2) + + for (option <- availableOptions) { + test(s"$name - merging sessions in local partition: ${option._1._2} / " + + s"provider: ${option._2._2}") { + withSQLConf(confPairs ++ + Seq( + option._1._1 -> option._1._2.toString, + option._2._1 -> option._2._2): _*) { + func + } + } + } + } + + testWithAllOptions("complete mode - session window") { + // Implements StructuredSessionization.scala leveraging "session" function + // as a test, to verify the sessionization works with simple example + + // note that complete mode doesn't honor watermark: even it is specified, watermark will be + // always Unix timestamp 0 + + val inputData = MemoryStream[(String, Long)] + + // Split the lines into words, treat words as sessionId of events + val events = inputData.toDF() + .select($"_1".as("value"), $"_2".as("timestamp")) + .withColumn("eventTime", $"timestamp".cast("timestamp")) + .selectExpr("explode(split(value, ' ')) AS sessionId", "eventTime") + + val sessionUpdates = events + .groupBy(session_window($"eventTime", "10 seconds") as 'session, 'sessionId) + .agg(count("*").as("numEvents")) + .selectExpr("sessionId", "CAST(session.start AS LONG)", "CAST(session.end AS LONG)", + "CAST(session.end AS LONG) - CAST(session.start AS LONG) AS durationMs", + "numEvents") + + testStream(sessionUpdates, OutputMode.Complete())( + AddData(inputData, + ("hello world spark streaming", 40L), + ("world hello structured streaming", 41L) + ), + CheckNewAnswer( + ("hello", 40, 51, 11, 2), + ("world", 40, 51, 11, 2), + ("streaming", 40, 51, 11, 2), + ("spark", 40, 50, 10, 1), + ("structured", 41, 51, 10, 1) + ), + + // placing new sessions "before" previous sessions + AddData(inputData, ("spark streaming", 25L)), + CheckNewAnswer( + ("spark", 25, 35, 10, 1), + ("streaming", 25, 35, 10, 1), + ("hello", 40, 51, 11, 2), + ("world", 40, 51, 11, 2), + ("streaming", 40, 51, 11, 2), + ("spark", 40, 50, 10, 1), + ("structured", 41, 51, 10, 1) + ), + + // concatenating multiple previous sessions into one + AddData(inputData, ("spark streaming", 30L)), + CheckNewAnswer( + ("spark", 25, 50, 25, 3), + ("streaming", 25, 51, 26, 4), + ("hello", 40, 51, 11, 2), + ("world", 40, 51, 11, 2), + ("structured", 41, 51, 10, 1) + ), + + // placing new sessions after previous sessions + AddData(inputData, ("hello apache spark", 60L)), + CheckNewAnswer( + ("spark", 25, 50, 25, 3), + ("streaming", 25, 51, 26, 4), + ("hello", 40, 51, 11, 2), + ("world", 40, 51, 11, 2), + ("structured", 41, 51, 10, 1), + ("hello", 60, 70, 10, 1), + ("apache", 60, 70, 10, 1), + ("spark", 60, 70, 10, 1) + ), + + AddData(inputData, ("structured streaming", 90L)), + CheckNewAnswer( + ("spark", 25, 50, 25, 3), + ("streaming", 25, 51, 26, 4), + ("hello", 40, 51, 11, 2), + ("world", 40, 51, 11, 2), + ("structured", 41, 51, 10, 1), + ("hello", 60, 70, 10, 1), + ("apache", 60, 70, 10, 1), + ("spark", 60, 70, 10, 1), + ("structured", 90, 100, 10, 1), + ("streaming", 90, 100, 10, 1) + ) + ) + } + + testWithAllOptions("complete mode - session window - no key") { + // complete mode doesn't honor watermark: even it is specified, watermark will be + // always Unix timestamp 0 + + val inputData = MemoryStream[Int] + + val windowedAggregation = inputData.toDF() + .selectExpr("*") + .withColumn("eventTime", $"value".cast("timestamp")) + .groupBy(session_window($"eventTime", "5 seconds") as 'session) + .agg(count("*") as 'count, sum("value") as 'sum) + .select($"session".getField("start").cast("long").as[Long], + $"session".getField("end").cast("long").as[Long], $"count".as[Long], $"sum".as[Long]) + + val e = intercept[StreamingQueryException] { + testStream(windowedAggregation, OutputMode.Complete())( + AddData(inputData, 40), + CheckAnswer() // this is just to trigger the exception + ) + } + Seq("Global aggregation with session window", "not supported").foreach { m => + assert(e.getMessage.toLowerCase(Locale.ROOT).contains(m.toLowerCase(Locale.ROOT))) + } + } + + testWithAllOptions("append mode - session window") { + // Implements StructuredSessionization.scala leveraging "session" function + // as a test, to verify the sessionization works with simple example + + val inputData = MemoryStream[(String, Long)] + + // Split the lines into words, treat words as sessionId of events + val events = inputData.toDF() + .select($"_1".as("value"), $"_2".as("timestamp")) + .withColumn("eventTime", $"timestamp".cast("timestamp")) + .selectExpr("explode(split(value, ' ')) AS sessionId", "eventTime") + .withWatermark("eventTime", "30 seconds") + + val sessionUpdates = events + .groupBy(session_window($"eventTime", "10 seconds") as 'session, 'sessionId) + .agg(count("*").as("numEvents")) + .selectExpr("sessionId", "CAST(session.start AS LONG)", "CAST(session.end AS LONG)", + "CAST(session.end AS LONG) - CAST(session.start AS LONG) AS durationMs", + "numEvents") + + testStream(sessionUpdates, OutputMode.Append())( + AddData(inputData, + ("hello world spark streaming", 40L), + ("world hello structured streaming", 41L) + ), + + // watermark: 11 + // current sessions + // ("hello", 40, 51, 11, 2), + // ("world", 40, 51, 11, 2), + // ("streaming", 40, 51, 11, 2), + // ("spark", 40, 50, 10, 1), + // ("structured", 41, 51, 10, 1) + CheckNewAnswer( + ), + + // placing new sessions "before" previous sessions + AddData(inputData, ("spark streaming", 25L)), + // watermark: 11 + // current sessions + // ("spark", 25, 35, 10, 1), + // ("streaming", 25, 35, 10, 1), + // ("hello", 40, 51, 11, 2), + // ("world", 40, 51, 11, 2), + // ("streaming", 40, 51, 11, 2), + // ("spark", 40, 50, 10, 1), + // ("structured", 41, 51, 10, 1) + CheckNewAnswer( + ), + + // late event which session's end 10 would be later than watermark 11: should be dropped + AddData(inputData, ("spark streaming", 0L)), + // watermark: 11 + // current sessions + // ("spark", 25, 35, 10, 1), + // ("streaming", 25, 35, 10, 1), + // ("hello", 40, 51, 11, 2), + // ("world", 40, 51, 11, 2), + // ("streaming", 40, 51, 11, 2), + // ("spark", 40, 50, 10, 1), + // ("structured", 41, 51, 10, 1) + CheckNewAnswer( + ), + + // concatenating multiple previous sessions into one + AddData(inputData, ("spark streaming", 30L)), + // watermark: 11 + // current sessions + // ("spark", 25, 50, 25, 3), + // ("streaming", 25, 51, 26, 4), + // ("hello", 40, 51, 11, 2), + // ("world", 40, 51, 11, 2), + // ("structured", 41, 51, 10, 1) + CheckNewAnswer( + ), + + // placing new sessions after previous sessions + AddData(inputData, ("hello apache spark", 60L)), + // watermark: 30 + // current sessions + // ("spark", 25, 50, 25, 3), + // ("streaming", 25, 51, 26, 4), + // ("hello", 40, 51, 11, 2), + // ("world", 40, 51, 11, 2), + // ("structured", 41, 51, 10, 1), + // ("hello", 60, 70, 10, 1), + // ("apache", 60, 70, 10, 1), + // ("spark", 60, 70, 10, 1) + CheckNewAnswer( + ), + + AddData(inputData, ("structured streaming", 90L)), + // watermark: 60 + // current sessions + // ("hello", 60, 70, 10, 1), + // ("apache", 60, 70, 10, 1), + // ("spark", 60, 70, 10, 1), + // ("structured", 90, 100, 10, 1), + // ("streaming", 90, 100, 10, 1) + CheckNewAnswer( + ("spark", 25, 50, 25, 3), + ("streaming", 25, 51, 26, 4), + ("hello", 40, 51, 11, 2), + ("world", 40, 51, 11, 2), + ("structured", 41, 51, 10, 1) + ) + ) + } + + testWithAllOptions("append mode - session window - no key") { + val inputData = MemoryStream[Int] + + val windowedAggregation = inputData.toDF() + .selectExpr("*") + .withColumn("eventTime", $"value".cast("timestamp")) + .withWatermark("eventTime", "10 seconds") + .groupBy(session_window($"eventTime", "5 seconds") as 'session) + .agg(count("*") as 'count, sum("value") as 'sum) + .select($"session".getField("start").cast("long").as[Long], + $"session".getField("end").cast("long").as[Long], $"count".as[Long], $"sum".as[Long]) + + val e = intercept[StreamingQueryException] { + testStream(windowedAggregation)( + AddData(inputData, 40), + CheckAnswer() // this is just to trigger the exception + ) + } + Seq("Global aggregation with session window", "not supported").foreach { m => + assert(e.getMessage.toLowerCase(Locale.ROOT).contains(m.toLowerCase(Locale.ROOT))) + } + } + + testWithAllOptions("update mode - session window") { + // Implements StructuredSessionization.scala leveraging "session" function + // as a test, to verify the sessionization works with simple example + + val inputData = MemoryStream[(String, Long)] + + // Split the lines into words, treat words as sessionId of events + val events = inputData.toDF() + .select($"_1".as("value"), $"_2".as("timestamp")) + .withColumn("eventTime", $"timestamp".cast("timestamp")) + .selectExpr("explode(split(value, ' ')) AS sessionId", "eventTime") + .withWatermark("eventTime", "10 seconds") + + val sessionUpdates = events + .groupBy(session_window($"eventTime", "10 seconds") as 'session, 'sessionId) + .agg(count("*").as("numEvents")) + .selectExpr("sessionId", "CAST(session.start AS LONG)", "CAST(session.end AS LONG)", + "CAST(session.end AS LONG) - CAST(session.start AS LONG) AS durationMs", + "numEvents") + + testStream(sessionUpdates, OutputMode.Update())( + AddData(inputData, + ("hello world spark streaming", 40L), + ("world hello structured streaming", 41L) + ), + // watermark: 11 + // current sessions + // ("hello", 40, 51, 11, 2), + // ("world", 40, 51, 11, 2), + // ("streaming", 40, 51, 11, 2), + // ("spark", 40, 50, 10, 1), + // ("structured", 41, 51, 10, 1) + CheckNewAnswer( + ("hello", 40, 51, 11, 2), + ("world", 40, 51, 11, 2), + ("streaming", 40, 51, 11, 2), + ("spark", 40, 50, 10, 1), + ("structured", 41, 51, 10, 1) + ), + + // placing new sessions "before" previous sessions + AddData(inputData, ("spark streaming", 25L)), + // watermark: 11 + // current sessions + // ("spark", 25, 35, 10, 1), + // ("streaming", 25, 35, 10, 1), + // ("hello", 40, 51, 11, 2), + // ("world", 40, 51, 11, 2), + // ("streaming", 40, 51, 11, 2), + // ("spark", 40, 50, 10, 1), + // ("structured", 41, 51, 10, 1) + CheckNewAnswer( + ("spark", 25, 35, 10, 1), + ("streaming", 25, 35, 10, 1) + ), + + // late event which session's end 10 would be later than watermark 11: should be dropped + AddData(inputData, ("spark streaming", 0L)), + // watermark: 11 + // current sessions + // ("spark", 25, 35, 10, 1), + // ("streaming", 25, 35, 10, 1), + // ("hello", 40, 51, 11, 2), + // ("world", 40, 51, 11, 2), + // ("streaming", 40, 51, 11, 2), + // ("spark", 40, 50, 10, 1), + // ("structured", 41, 51, 10, 1) + CheckNewAnswer( + ), + + // concatenating multiple previous sessions into one + AddData(inputData, ("spark streaming", 30L)), + // watermark: 11 + // current sessions + // ("spark", 25, 50, 25, 3), + // ("streaming", 25, 51, 26, 4), + // ("hello", 40, 51, 11, 2), + // ("world", 40, 51, 11, 2), + // ("structured", 41, 51, 10, 1) + CheckNewAnswer( + ("spark", 25, 50, 25, 3), + ("streaming", 25, 51, 26, 4) + ), + + // placing new sessions after previous sessions + AddData(inputData, ("hello apache spark", 60L)), + // watermark: 30 + // current sessions + // ("spark", 25, 50, 25, 3), + // ("streaming", 25, 51, 26, 4), + // ("hello", 40, 51, 11, 2), + // ("world", 40, 51, 11, 2), + // ("structured", 41, 51, 10, 1), + // ("hello", 60, 70, 10, 1), + // ("apache", 60, 70, 10, 1), + // ("spark", 60, 70, 10, 1) + CheckNewAnswer( + ("hello", 60, 70, 10, 1), + ("apache", 60, 70, 10, 1), + ("spark", 60, 70, 10, 1) + ), + + AddData(inputData, ("structured streaming", 90L)), + // watermark: 60 + // current sessions + // ("hello", 60, 70, 10, 1), + // ("apache", 60, 70, 10, 1), + // ("spark", 60, 70, 10, 1), + // ("structured", 90, 100, 10, 1), + // ("streaming", 90, 100, 10, 1) + // evicted + // ("spark", 25, 50, 25, 3), + // ("streaming", 25, 51, 26, 4), + // ("hello", 40, 51, 11, 2), + // ("world", 40, 51, 11, 2), + // ("structured", 41, 51, 10, 1) + CheckNewAnswer( + ("structured", 90, 100, 10, 1), + ("streaming", 90, 100, 10, 1) + ) + ) + } + + testWithAllOptions("update mode - session window - no key") { + val inputData = MemoryStream[Int] + + val windowedAggregation = inputData.toDF() + .selectExpr("*") + .withColumn("eventTime", $"value".cast("timestamp")) + .withWatermark("eventTime", "10 seconds") + .groupBy(session_window($"eventTime", "5 seconds") as 'session) + .agg(count("*") as 'count, sum("value") as 'sum) + .select($"session".getField("start").cast("long").as[Long], + $"session".getField("end").cast("long").as[Long], $"count".as[Long], $"sum".as[Long]) + + val e = intercept[StreamingQueryException] { + testStream(windowedAggregation, OutputMode.Update())( + AddData(inputData, 40), + CheckAnswer() // this is just to trigger the exception + ) + } + Seq("Global aggregation with session window", "not supported").foreach { m => + assert(e.getMessage.toLowerCase(Locale.ROOT).contains(m.toLowerCase(Locale.ROOT))) + } + } +} From 78c6f2d07656dd134886dab2fcf4d450ef91cf84 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Wed, 14 Jul 2021 13:35:28 +0900 Subject: [PATCH 02/11] Address metrics stuff --- .../StreamingSessionWindowStateManager.scala | 17 ++++++++++++++--- .../execution/streaming/statefulOperators.scala | 6 ++++-- 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StreamingSessionWindowStateManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StreamingSessionWindowStateManager.scala index 6561286448b4..5130933f52ef 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StreamingSessionWindowStateManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StreamingSessionWindowStateManager.scala @@ -68,8 +68,11 @@ sealed trait StreamingSessionWindowStateManager extends Serializable { * {@code extractKeyWithoutSession}. * @param sessions The all sessions including existing sessions if it's active. * Existing sessions which aren't included in this parameter will be removed. + * @return A tuple having two elements + * 1. number of added/updated rows + * 2. number of deleted rows */ - def updateSessions(store: StateStore, key: UnsafeRow, sessions: Seq[UnsafeRow]): Unit + def updateSessions(store: StateStore, key: UnsafeRow, sessions: Seq[UnsafeRow]): (Long, Long) /** * Removes using a predicate on values, with returning removed values via iterator. @@ -168,7 +171,7 @@ class StreamingSessionWindowStateManagerImplV1( override def updateSessions( store: StateStore, key: UnsafeRow, - sessions: Seq[UnsafeRow]): Unit = { + sessions: Seq[UnsafeRow]): (Long, Long) = { // Below two will be used multiple times - need to make sure this is not a stream or iterator. val newValues = sessions.toList val savedStates = getSessionsWithKeys(store, key) @@ -225,7 +228,7 @@ class StreamingSessionWindowStateManagerImplV1( store: StateStore, key: UnsafeRow, oldValues: List[(UnsafeRow, UnsafeRow)], - values: List[UnsafeRow]): Unit = { + values: List[UnsafeRow]): (Long, Long) = { // Here the key doesn't represent the state key - we need to construct the key for state val keyAndValues = values.map { row => val sessionStart = helper.extractTimePair(row)._1 @@ -236,16 +239,24 @@ class StreamingSessionWindowStateManagerImplV1( val keysForValues = keyAndValues.map(_._1) val keysForOldValues = oldValues.map(_._1) + var upsertedRows = 0L + var deletedRows = 0L + // We should "replace" the value instead of "delete" and "put" if the start time // equals to. This will remove unnecessary tombstone being written to the delta, which is // implementation details on state store implementations. + keysForOldValues.filterNot(keysForValues.contains).foreach { oldKey => store.remove(oldKey) + deletedRows += 1 } keyAndValues.foreach { case (key, value) => store.put(key, value) + upsertedRows += 1 } + + (upsertedRows, deletedRows) } override def abortIfNeeded(store: StateStore): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index 74743c5ec058..3c7a004c3184 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -733,6 +733,7 @@ case class SessionWindowStateStoreSaveExec( needFilter: Boolean, returnOnlyUpdatedRows: Boolean): Iterator[InternalRow] = { val numUpdatedStateRows = longMetric("numUpdatedStateRows") + val numRemovedStateRows = longMetric("numRemovedStateRows") val iter = if (needFilter) { baseIter.filter(row => !watermarkPredicateForData.get.eval(row)) } else { @@ -745,8 +746,9 @@ case class SessionWindowStateStoreSaveExec( private def applyChangesOnKey(): Unit = { if (curValuesOnKey.nonEmpty) { - val updatedRows = stateManager.updateSessions(store, curKey, curValuesOnKey) - numUpdatedStateRows += updatedRows + val (upserted, deleted) = stateManager.updateSessions(store, curKey, curValuesOnKey) + numUpdatedStateRows += upserted + numRemovedStateRows += deleted curValuesOnKey.clear } } From 79f21977c77b640075c219e066b9655db343584c Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Wed, 14 Jul 2021 19:39:27 +0900 Subject: [PATCH 03/11] Fix a bug on UpdatingSessionsIterator which made some UT failing --- .../execution/aggregate/UpdatingSessionsIterator.scala | 10 ++++++---- .../spark/sql/DataFrameSessionWindowingSuite.scala | 1 - .../streaming/UpdatingSessionsIteratorSuite.scala | 8 ++++---- 3 files changed, 10 insertions(+), 9 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UpdatingSessionsIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UpdatingSessionsIterator.scala index 2c611e3d0d22..0a60ddc0a98e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UpdatingSessionsIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UpdatingSessionsIterator.scala @@ -191,19 +191,21 @@ class UpdatingSessionsIterator( } private def closeCurrentSession(keyChanged: Boolean): Unit = { - assert(returnRowsIter == null || !returnRowsIter.hasNext) - returnRows = rowsForCurrentSession rowsForCurrentSession = null - val groupingKey = generateGroupingKey() + val groupingKey = generateGroupingKey().copy() val currentRowsIter = returnRows.generateIterator().map { internalRow => val valueRow = valueProj(internalRow) restoreProj(join2(groupingKey, valueRow)).copy() } - returnRowsIter = currentRowsIter + if (returnRowsIter != null && returnRowsIter.hasNext) { + returnRowsIter = returnRowsIter ++ currentRowsIter + } else { + returnRowsIter = currentRowsIter + } if (keyChanged) processedKeys.add(currentKeys) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSessionWindowingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSessionWindowingSuite.scala index 6dc1860bfa7a..b70b2c670732 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSessionWindowingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSessionWindowingSuite.scala @@ -93,7 +93,6 @@ class DataFrameSessionWindowingSuite extends QueryTest with SharedSparkSession ) } - // FIXME: fix the failing test - check if it still fails or not test("session window groupBy with multiple keys statement - one distinct") { val df = Seq( ("2016-03-27 19:39:34", 1, "a"), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/UpdatingSessionsIteratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/UpdatingSessionsIteratorSuite.scala index 2a4245d3b363..045901bc20ca 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/UpdatingSessionsIteratorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/UpdatingSessionsIteratorSuite.scala @@ -199,9 +199,9 @@ class UpdatingSessionsIteratorSuite extends SharedSparkSession { val row6 = createRow("a", 2, 115, 125, 20, 1.2) val rows3 = List(row5, row6) + // This is to test the edge case that the last input row creates a new session. val row7 = createRow("a", 2, 127, 137, 30, 1.3) - val row8 = createRow("a", 2, 135, 145, 40, 1.4) - val rows4 = List(row7, row8) + val rows4 = List(row7) val rowsAll = rows1 ++ rows2 ++ rows3 ++ rows4 @@ -244,8 +244,8 @@ class UpdatingSessionsIteratorSuite extends SharedSparkSession { } retRows4.zip(rows4).foreach { case (retRow, expectedRow) => - // session being expanded to (127 ~ 145) - assertRowsEqualsWithNewSession(expectedRow, retRow, 127, 145) + // session being expanded to (127 ~ 137) + assertRowsEqualsWithNewSession(expectedRow, retRow, 127, 137) } assert(iterator.hasNext === false) From 1359123d6b835089fcf9a6eb29d8aa8ca517b16e Mon Sep 17 00:00:00 2001 From: Yuanjian Li Date: Tue, 23 Mar 2021 08:01:19 +0900 Subject: [PATCH 04/11] Giving co-authorship to Yuanjian Li From 0f8ad99ac937b6f4b17fe8a6eac8aee4d01e9b9d Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 23 Mar 2021 08:02:42 +0900 Subject: [PATCH 05/11] Giving co-authorship to Liang-Chi Hsieh From 37c98dee2491d6d1727878d7e7ce17f167e5e969 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Wed, 14 Jul 2021 23:24:29 +0900 Subject: [PATCH 06/11] fix --- .../src/test/resources/sql-functions/sql-expression-schema.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md index 75be26a26bd0..41692d20ed56 100644 --- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md +++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md @@ -1,6 +1,6 @@ ## Summary - - Number of queries: 360 + - Number of queries: 361 - Number of expressions that missing example: 14 - Expressions missing examples: bigint,binary,boolean,date,decimal,double,float,int,smallint,string,timestamp,tinyint,session_window,window ## Schema of Built-in Functions From 441d6e5f8fa8dfee91b703cb23f4cd72db0851e9 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Thu, 15 Jul 2021 09:56:52 +0900 Subject: [PATCH 07/11] Fix for Scala 2.13 --- .../spark/sql/execution/streaming/statefulOperators.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index 3c7a004c3184..7090a0cb258e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -746,7 +746,7 @@ case class SessionWindowStateStoreSaveExec( private def applyChangesOnKey(): Unit = { if (curValuesOnKey.nonEmpty) { - val (upserted, deleted) = stateManager.updateSessions(store, curKey, curValuesOnKey) + val (upserted, deleted) = stateManager.updateSessions(store, curKey, curValuesOnKey.toSeq) numUpdatedStateRows += upserted numRemovedStateRows += deleted curValuesOnKey.clear From b9c0357dece38ebe2eb1a3724dac818780b5fbff Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Thu, 15 Jul 2021 11:54:19 +0900 Subject: [PATCH 08/11] Another fix for Scala 2.13 --- .../spark/sql/execution/streaming/statefulOperators.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index 7090a0cb258e..bb566ba925bf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -793,7 +793,7 @@ case class SessionWindowStateStoreSaveExec( private def putToStore( baseIter: Iterator[InternalRow], store: StateStore, - needFilter: Boolean) { + needFilter: Boolean): Unit = { val iterPutToStore = iteratorPutToStore(baseIter, store, needFilter, false) while (iterPutToStore.hasNext) { iterPutToStore.next() From a4fa37ba69f3d2bbe5b443dcfbc6ee361a0b58e1 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Thu, 15 Jul 2021 16:53:43 +0900 Subject: [PATCH 09/11] Reflect review comments --- python/pyspark/sql/functions.py | 2 +- .../sql/catalyst/analysis/Analyzer.scala | 6 +-- .../catalyst/expressions/SessionWindow.scala | 44 ++++--------------- .../sql/catalyst/expressions/TimeWindow.scala | 4 +- .../apache/spark/sql/internal/SQLConf.scala | 1 + .../spark/sql/execution/SparkStrategies.scala | 2 - .../sql/execution/aggregate/AggUtils.scala | 9 ++-- .../python/AggregateInPandasExec.scala | 14 +++++- .../org/apache/spark/sql/functions.scala | 5 ++- .../StreamingSessionWindowSuite.scala | 2 + 10 files changed, 40 insertions(+), 49 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 295372fa4cc1..6e7b1f7cfafd 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2336,7 +2336,7 @@ def check_string_field(field, fieldName): def session_window(timeColumn, gapDuration): """ Generates session window given a timestamp specifying column. - Session window is the one of dynamic windows, which means the length of window is vary + Session window is one of dynamic windows, which means the length of window is varying according to the given inputs. The length of session window is defined as "the timestamp of latest input of the session + gap duration", so when the new inputs are bound to the current session window, the end time of session window can be expanded according to the new diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 1c5f1338e0fd..565624b6fd6a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -4001,12 +4001,12 @@ object SessionWindowing extends Rule[LogicalPlan] { case s: SessionWindow => sessionAttr } - // For backwards compatibility we add a filter to filter out nulls + // As same as tumbling window, we add a filter to filter out nulls. val filterExpr = IsNotNull(session.timeColumn) replacedPlan.withNewChildren( - Filter(filterExpr, - Project(sessionStruct +: child.output, child)) :: Nil) + Project(sessionStruct +: child.output, + Filter(filterExpr, child)) :: Nil) } else if (numWindowExpr > 1) { throw QueryCompilationErrors.multiTimeWindowExpressionsNotSupportedError(p) } else { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SessionWindow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SessionWindow.scala index ccc451235100..60b07444f9fd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SessionWindow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SessionWindow.scala @@ -17,13 +17,17 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure -import org.apache.spark.sql.catalyst.util.{DateTimeConstants, IntervalUtils} import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String +/** + * Represent the session window. + * + * @param timeColumn the start time of session window + * @param gapDuration the duration of session gap, meaning the session will close if there is + * no new element appeared within "the last element in session + gap". + */ case class SessionWindow(timeColumn: Expression, gapDuration: Long) extends UnaryExpression with ImplicitCastInputTypes with Unevaluable @@ -34,7 +38,7 @@ case class SessionWindow(timeColumn: Expression, gapDuration: Long) extends Unar ////////////////////////// def this(timeColumn: Expression, gapDuration: Expression) = { - this(timeColumn, SessionWindow.parseExpression(gapDuration)) + this(timeColumn, TimeWindow.parseExpression(gapDuration)) } override def child: Expression = timeColumn @@ -64,40 +68,10 @@ case class SessionWindow(timeColumn: Expression, gapDuration: Long) extends Unar object SessionWindow { val marker = "spark.sessionWindow" - /** - * Parses the interval string for a valid time duration. CalendarInterval expects interval - * strings to start with the string `interval`. For usability, we prepend `interval` to the string - * if the user omitted it. - * - * @param interval The interval string - * @return The interval duration in microseconds. SparkSQL casts TimestampType has microsecond - * precision. - */ - private def getIntervalInMicroSeconds(interval: String): Long = { - val cal = IntervalUtils.stringToInterval(UTF8String.fromString(interval)) - if (cal.months != 0) { - throw new IllegalArgumentException( - s"Intervals greater than a month is not supported ($interval).") - } - cal.days * DateTimeConstants.MICROS_PER_DAY + cal.microseconds - } - - /** - * Parses the duration expression to generate the long value for the original constructor so - * that we can use `window` in SQL. - */ - private def parseExpression(expr: Expression): Long = expr match { - case NonNullLiteral(s, StringType) => getIntervalInMicroSeconds(s.toString) - case IntegerLiteral(i) => i.toLong - case NonNullLiteral(l, LongType) => l.toString.toLong - case _ => throw new AnalysisException("The duration and time inputs to window must be " + - "an integer, long or string literal.") - } - def apply( timeColumn: Expression, gapDuration: String): SessionWindow = { SessionWindow(timeColumn, - getIntervalInMicroSeconds(gapDuration)) + TimeWindow.getIntervalInMicroSeconds(gapDuration)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala index 5b13872e566a..e79e8d767c95 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala @@ -109,7 +109,7 @@ object TimeWindow { * @return The interval duration in microseconds. SparkSQL casts TimestampType has microsecond * precision. */ - private def getIntervalInMicroSeconds(interval: String): Long = { + def getIntervalInMicroSeconds(interval: String): Long = { val cal = IntervalUtils.stringToInterval(UTF8String.fromString(interval)) if (cal.months != 0) { throw new IllegalArgumentException( @@ -122,7 +122,7 @@ object TimeWindow { * Parses the duration expression to generate the long value for the original constructor so * that we can use `window` in SQL. */ - private def parseExpression(expr: Expression): Long = expr match { + def parseExpression(expr: Expression): Long = expr match { case NonNullLiteral(s, StringType) => getIntervalInMicroSeconds(s.toString) case IntegerLiteral(i) => i.toLong case NonNullLiteral(l, LongType) => l.toString.toLong diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index fda99d6285d5..a1c344a25d7c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1616,6 +1616,7 @@ object SQLConf { .doc("When true, streaming session window sorts and merge sessions in local partition " + "prior to shuffle. This is to reduce the rows to shuffle, but only beneficial when " + "there're lots of rows in a batch being assigned to same sessions.") + .version("3.2.0") .booleanConf .createWithDefault(false) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index b4af14df5be6..6d10fa83f432 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -324,8 +324,6 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { throw QueryCompilationErrors.groupAggPandasUDFUnsupportedByStreamingAggError() } - val stateVersion = conf.getConf(SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION) - val sessionWindowOption = namedGroupingExpressions.find { p => p.metadata.contains(SessionWindow.marker) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala index 975a8edb049c..0f239b457fd1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala @@ -114,6 +114,8 @@ object AggUtils { resultExpressions = partialResultExpressions, child = child) + // If we have session window expression in aggregation, we add MergingSessionExec to + // merge sessions with calculating aggregation values. val interExec: SparkPlan = mayAppendMergingSessionExec(groupingExpressions, aggregateExpressions, partialAggregate) @@ -144,6 +146,9 @@ object AggUtils { resultExpressions: Seq[NamedExpression], child: SparkPlan): Seq[SparkPlan] = { + // If we have session window expression in aggregation, we add UpdatingSessionsExec to + // calculate sessions for input rows and update rows' session column, so that further + // aggregations can aggregate input rows for the same session. val maySessionChild = mayAppendUpdatingSessionExec(groupingExpressions, child) val distinctAttributes = normalizedNamedDistinctExpressions.map(_.toAttribute) @@ -394,9 +399,7 @@ object AggUtils { val groupingAttributes = groupingExpressions.map(_.toAttribute) - // we don't do partial aggregate here, because it requires additional shuffle - // and there will be less rows which have same session start - // here doing partial merge is to have aggregated columns with default value for each row + // Here doing partial merge is to have aggregated columns with default value for each row. val partialAggregate: SparkPlan = { val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Partial)) val aggregateAttributes = aggregateExpressions.map(_.resultAttribute) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala index 7fd39146c95f..69802b143c11 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala @@ -26,7 +26,7 @@ import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.physical.Partitioning +import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning} import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan, UnaryExecNode} import org.apache.spark.sql.execution.aggregate.UpdatingSessionsIterator import org.apache.spark.sql.types.{DataType, StructField, StructType} @@ -65,6 +65,14 @@ case class AggregateInPandasExec( case None => groupingExpressions } + override def requiredChildDistribution: Seq[Distribution] = { + if (groupingExpressions.isEmpty) { + AllTuples :: Nil + } else { + ClusteredDistribution(groupingExpressions) :: Nil + } + } + override def requiredChildOrdering: Seq[Seq[SortOrder]] = sessionWindowOption match { case Some(sessionExpression) => Seq((groupingWithoutSessionExpressions ++ Seq(sessionExpression)) @@ -116,6 +124,10 @@ case class AggregateInPandasExec( // Map grouped rows to ArrowPythonRunner results, Only execute if partition is not empty inputRDD.mapPartitionsInternal { iter => if (iter.isEmpty) iter else { + // If we have session window expression in aggregation, we wrap iterator with + // UpdatingSessionIterator to calculate sessions for input rows and update + // rows' session column, so that further aggregations can aggregate input rows + // for the same session. val newIter: Iterator[InternalRow] = mayAppendUpdatingSessionIterator(iter) val prunedProj = UnsafeProjection.create(allInputs.toSeq, child.output) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 688842d0e2a5..7db8e8a90a59 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3633,13 +3633,14 @@ object functions { /** * Generates session window given a timestamp specifying column. * - * Session window is the one of dynamic windows, which means the length of window is vary + * Session window is one of dynamic windows, which means the length of window is varying * according to the given inputs. The length of session window is defined as "the timestamp * of latest input of the session + gap duration", so when the new inputs are bound to the * current session window, the end time of session window can be expanded according to the new * inputs. * - * Windows can support microsecond precision. Windows in the order of months are not supported. + * Windows can support microsecond precision. gapDuration in the order of months are not + * supported. * * For a streaming query, you may use the function `current_timestamp` to generate windows on * processing time. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingSessionWindowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingSessionWindowSuite.scala index d1d21e35ef78..eb9fb76154c3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingSessionWindowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingSessionWindowSuite.scala @@ -87,6 +87,8 @@ class StreamingSessionWindowSuite extends StreamTest "CAST(session.end AS LONG) - CAST(session.start AS LONG) AS durationMs", "numEvents") + sessionUpdates.explain() + testStream(sessionUpdates, OutputMode.Complete())( AddData(inputData, ("hello world spark streaming", 40L), From e7a2a37227034c92daf404f556d1790dd6c29c3a Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Fri, 16 Jul 2021 10:31:52 +0900 Subject: [PATCH 10/11] Reflect review comments & fix metrics --- .../streaming/statefulOperators.scala | 65 +++++++++---------- 1 file changed, 30 insertions(+), 35 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index bb566ba925bf..2dd91decfa99 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -530,12 +530,13 @@ case class SessionWindowStateStoreRestoreExec( override def keyExpressions: Seq[Attribute] = keyWithoutSessionExpressions + assert(keyExpressions.nonEmpty, "Grouping key must be specified when using sessionWindow") + private val stateManager = StreamingSessionWindowStateManager.createStateManager( keyWithoutSessionExpressions, sessionExpression, child.output, stateFormatVersion) override protected def doExecute(): RDD[InternalRow] = { val numOutputRows = longMetric("numOutputRows") - assert(keyExpressions.nonEmpty, "Grouping key must be specified when using sessionWindow") child.execute().mapPartitionsWithReadStateStore( getStateInfo, @@ -558,8 +559,8 @@ case class SessionWindowStateStoreRestoreExec( keyWithoutSessionExpressions, sessionExpression, child.output).map { row => - numOutputRows += 1 - row + numOutputRows += 1 + row } } } @@ -573,11 +574,7 @@ case class SessionWindowStateStoreRestoreExec( } override def requiredChildDistribution: Seq[Distribution] = { - if (keyWithoutSessionExpressions.isEmpty) { - AllTuples :: Nil - } else { - ClusteredDistribution(keyWithoutSessionExpressions, stateInfo.map(_.numPartitions)) :: Nil - } + ClusteredDistribution(keyWithoutSessionExpressions, stateInfo.map(_.numPartitions)) :: Nil } override def requiredChildOrdering: Seq[Seq[SortOrder]] = { @@ -592,7 +589,7 @@ case class SessionWindowStateStoreRestoreExec( * For each input tuple, the key is calculated and the tuple is `put` into the [[StateStore]]. */ case class SessionWindowStateStoreSaveExec( - keyExpressions: Seq[Attribute], + keyWithoutSessionExpressions: Seq[Attribute], sessionExpression: Attribute, stateInfo: Option[StatefulOperatorStateInfo] = None, outputMode: Option[OutputMode] = None, @@ -601,9 +598,7 @@ case class SessionWindowStateStoreSaveExec( child: SparkPlan) extends UnaryExecNode with StateStoreWriter with WatermarkSupport { - private val keyWithoutSessionExpressions = keyExpressions.filterNot { p => - p.semanticEquals(sessionExpression) - } + override def keyExpressions: Seq[Attribute] = keyWithoutSessionExpressions private val stateManager = StreamingSessionWindowStateManager.createStateManager( keyWithoutSessionExpressions, sessionExpression, child.output, stateFormatVersion) @@ -624,6 +619,7 @@ case class SessionWindowStateStoreSaveExec( Some(session.streams.stateStoreCoordinator)) { case (store, iter) => val numOutputRows = longMetric("numOutputRows") + val numRemovedStateRows = longMetric("numRemovedStateRows") val allUpdatesTimeMs = longMetric("allUpdatesTimeMs") val allRemovalsTimeMs = longMetric("allRemovalsTimeMs") val commitTimeMs = longMetric("commitTimeMs") @@ -632,9 +628,8 @@ case class SessionWindowStateStoreSaveExec( // Update and output all rows in the StateStore. case Some(Complete) => allUpdatesTimeMs += timeTakenMs { - putToStore(iter, store, false) + putToStore(iter, store) } - allRemovalsTimeMs += 0 commitTimeMs += timeTakenMs { stateManager.commit(store) } @@ -648,7 +643,9 @@ case class SessionWindowStateStoreSaveExec( // Assumption: watermark predicates must be non-empty if append mode is allowed case Some(Append) => allUpdatesTimeMs += timeTakenMs { - putToStore(iter, store, true) + val filteredIter = applyRemovingRowsOlderThanWatermark(iter, + watermarkPredicateForData.get) + putToStore(filteredIter, store) } val removalStartTimeNs = System.nanoTime @@ -661,6 +658,7 @@ case class SessionWindowStateStoreSaveExec( finished = true null } else { + numRemovedStateRows += 1 numOutputRows += 1 removedIter.next() } @@ -670,17 +668,25 @@ case class SessionWindowStateStoreSaveExec( allRemovalsTimeMs += NANOSECONDS.toMillis(System.nanoTime - removalStartTimeNs) commitTimeMs += timeTakenMs { store.commit() } setStoreMetrics(store) + setOperatorMetrics() } } case Some(Update) => - val iterPutToStore = iteratorPutToStore(iter, store, true, true) + val baseIterator = watermarkPredicateForData match { + case Some(predicate) => applyRemovingRowsOlderThanWatermark(iter, predicate) + case None => iter + } + val iterPutToStore = iteratorPutToStore(baseIterator, store, + returnOnlyUpdatedRows = true) new NextIterator[InternalRow] { private val updatesStartTimeNs = System.nanoTime override protected def getNext(): InternalRow = { if (iterPutToStore.hasNext) { - iterPutToStore.next() + val row = iterPutToStore.next() + numOutputRows += 1 + row } else { finished = true null @@ -695,16 +701,18 @@ case class SessionWindowStateStoreSaveExec( val removedIter = stateManager.removeByValueCondition( store, watermarkPredicateForData.get.eval) while (removedIter.hasNext) { + numRemovedStateRows += 1 removedIter.next() } } } commitTimeMs += timeTakenMs { store.commit() } setStoreMetrics(store) + setOperatorMetrics() } } - case _ => throw new UnsupportedOperationException(s"Invalid output mode: $outputMode") + case _ => throw QueryExecutionErrors.invalidStreamingOutputModeError(outputMode) } } } @@ -714,11 +722,7 @@ case class SessionWindowStateStoreSaveExec( override def outputPartitioning: Partitioning = child.outputPartitioning override def requiredChildDistribution: Seq[Distribution] = { - if (keyExpressions.isEmpty) { - AllTuples :: Nil - } else { - ClusteredDistribution(keyExpressions, stateInfo.map(_.numPartitions)) :: Nil - } + ClusteredDistribution(keyExpressions, stateInfo.map(_.numPartitions)) :: Nil } override def shouldRunAnotherBatch(newMetadata: OffsetSeqMetadata): Boolean = { @@ -728,17 +732,11 @@ case class SessionWindowStateStoreSaveExec( } private def iteratorPutToStore( - baseIter: Iterator[InternalRow], + iter: Iterator[InternalRow], store: StateStore, - needFilter: Boolean, returnOnlyUpdatedRows: Boolean): Iterator[InternalRow] = { val numUpdatedStateRows = longMetric("numUpdatedStateRows") val numRemovedStateRows = longMetric("numRemovedStateRows") - val iter = if (needFilter) { - baseIter.filter(row => !watermarkPredicateForData.get.eval(row)) - } else { - baseIter - } new NextIterator[InternalRow] { var curKey: UnsafeRow = null @@ -790,11 +788,8 @@ case class SessionWindowStateStoreSaveExec( } } - private def putToStore( - baseIter: Iterator[InternalRow], - store: StateStore, - needFilter: Boolean): Unit = { - val iterPutToStore = iteratorPutToStore(baseIter, store, needFilter, false) + private def putToStore(baseIter: Iterator[InternalRow], store: StateStore): Unit = { + val iterPutToStore = iteratorPutToStore(baseIter, store, returnOnlyUpdatedRows = false) while (iterPutToStore.hasNext) { iterPutToStore.next() } From bbade3501f16e9437ba9af4feca3e2029785d273 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Fri, 16 Jul 2021 18:49:58 +0900 Subject: [PATCH 11/11] Reflect review comment --- .../spark/sql/streaming/StreamingSessionWindowSuite.scala | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingSessionWindowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingSessionWindowSuite.scala index eb9fb76154c3..a381d069df2d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingSessionWindowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingSessionWindowSuite.scala @@ -24,7 +24,7 @@ import org.scalatest.matchers.must.Matchers import org.apache.spark.internal.Logging import org.apache.spark.sql.execution.streaming.MemoryStream -import org.apache.spark.sql.execution.streaming.state.HDFSBackedStateStoreProvider +import org.apache.spark.sql.execution.streaming.state.{HDFSBackedStateStoreProvider, RocksDBStateStoreProvider} import org.apache.spark.sql.functions.{count, session_window, sum} import org.apache.spark.sql.internal.SQLConf @@ -43,7 +43,9 @@ class StreamingSessionWindowSuite extends StreamTest (SQLConf.STREAMING_SESSION_WINDOW_MERGE_SESSIONS_IN_LOCAL_PARTITION.key, value) } val providerOptions = Seq( - classOf[HDFSBackedStateStoreProvider].getCanonicalName).map { value => + classOf[HDFSBackedStateStoreProvider].getCanonicalName, + classOf[RocksDBStateStoreProvider].getCanonicalName + ).map { value => (SQLConf.STATE_STORE_PROVIDER_CLASS.key, value.stripSuffix("$")) }