diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 9240ae6a8c51..6e7b1f7cfafd 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2333,6 +2333,41 @@ def check_string_field(field, fieldName): return Column(res) +def session_window(timeColumn, gapDuration): + """ + Generates session window given a timestamp specifying column. + Session window is one of dynamic windows, which means the length of window is varying + according to the given inputs. The length of session window is defined as "the timestamp + of latest input of the session + gap duration", so when the new inputs are bound to the + current session window, the end time of session window can be expanded according to the new + inputs. + Windows can support microsecond precision. Windows in the order of months are not supported. + For a streaming query, you may use the function `current_timestamp` to generate windows on + processing time. + gapDuration is provided as strings, e.g. '1 second', '1 day 12 hours', '2 minutes'. Valid + interval strings are 'week', 'day', 'hour', 'minute', 'second', 'millisecond', 'microsecond'. + The output column will be a struct called 'session_window' by default with the nested columns + 'start' and 'end', where 'start' and 'end' will be of :class:`pyspark.sql.types.TimestampType`. + .. versionadded:: 3.2.0 + Examples + -------- + >>> df = spark.createDataFrame([("2016-03-11 09:00:07", 1)]).toDF("date", "val") + >>> w = df.groupBy(session_window("date", "5 seconds")).agg(sum("val").alias("sum")) + >>> w.select(w.session_window.start.cast("string").alias("start"), + ... w.session_window.end.cast("string").alias("end"), "sum").collect() + [Row(start='2016-03-11 09:00:07', end='2016-03-11 09:00:12', sum=1)] + """ + def check_string_field(field, fieldName): + if not field or type(field) is not str: + raise TypeError("%s should be provided as a string" % fieldName) + + sc = SparkContext._active_spark_context + time_col = _to_java_column(timeColumn) + check_string_field(gapDuration, "gapDuration") + res = sc._jvm.functions.session_window(time_col, gapDuration) + return Column(res) + + # ---------------------------- misc functions ---------------------------------- def crc32(col): diff --git a/python/pyspark/sql/functions.pyi b/python/pyspark/sql/functions.pyi index 0a4aabf1bf71..051a6f1dbc53 100644 --- a/python/pyspark/sql/functions.pyi +++ b/python/pyspark/sql/functions.pyi @@ -135,6 +135,7 @@ def window( slideDuration: Optional[str] = ..., startTime: Optional[str] = ..., ) -> Column: ... +def session_window(timeColumn: ColumnOrName, gapDuration: str) -> Column: ... def crc32(col: ColumnOrName) -> Column: ... def md5(col: ColumnOrName) -> Column: ... def sha1(col: ColumnOrName) -> Column: ... 
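For context on the user-facing API added above, here is a minimal batch-mode sketch of the same call from the Scala side, mirroring the PySpark docstring example. The data, column names, and expected session bounds below are illustrative assumptions for this sketch, not part of the patch itself.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{session_window, sum}

object SessionWindowSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("session-window-sketch").getOrCreate()
    import spark.implicits._

    val df = Seq(
      ("2016-03-11 09:00:07", 1),
      ("2016-03-11 09:00:10", 2), // within the 5-second gap, so the first session's end moves to 09:00:15
      ("2016-03-11 09:00:30", 3)  // gap elapsed, so a new session starts
    ).toDF("time", "val")

    // The output struct column is named "session_window", with nested "start" and "end" fields.
    df.groupBy(session_window($"time", "5 seconds"))
      .agg(sum("val").as("sum"))
      .select($"session_window.start".cast("string"), $"session_window.end".cast("string"), $"sum")
      .show(truncate = false)
    // Expected, given the "last event + gap duration" rule: [09:00:07, 09:00:15) with sum 3,
    // and [09:00:30, 09:00:35) with sum 3.

    spark.stop()
  }
}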
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 49390d35bd1d..565624b6fd6a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -296,6 +296,7 @@ class Analyzer(override val catalogManager: CatalogManager) GlobalAggregates :: ResolveAggregateFunctions :: TimeWindowing :: + SessionWindowing :: ResolveInlineTables :: ResolveHigherOrderFunctions(catalogManager) :: ResolveLambdaVariables :: @@ -3856,9 +3857,13 @@ object TimeWindowing extends Rule[LogicalPlan] { val windowExpressions = p.expressions.flatMap(_.collect { case t: TimeWindow => t }).toSet - val numWindowExpr = windowExpressions.size + val numWindowExpr = p.expressions.flatMap(_.collect { + case s: SessionWindow => s + case t: TimeWindow => t + }).toSet.size + // Only support a single window expression for now - if (numWindowExpr == 1 && + if (numWindowExpr == 1 && windowExpressions.nonEmpty && windowExpressions.head.timeColumn.resolved && windowExpressions.head.checkInputDataTypes().isSuccess) { @@ -3933,6 +3938,83 @@ object TimeWindowing extends Rule[LogicalPlan] { } } +/** Maps a time column to a session window. */ +object SessionWindowing extends Rule[LogicalPlan] { + import org.apache.spark.sql.catalyst.dsl.expressions._ + + private final val SESSION_COL_NAME = "session_window" + private final val SESSION_START = "start" + private final val SESSION_END = "end" + + /** + * Generates the logical plan for generating session window on a timestamp column. + * Each session window is initially defined as [timestamp, timestamp + gap). + * + * This also adds a marker to the session column so that downstream can easily find the column + * on session window. 
+ */ + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { + case p: LogicalPlan if p.children.size == 1 => + val child = p.children.head + val sessionExpressions = + p.expressions.flatMap(_.collect { case s: SessionWindow => s }).toSet + + val numWindowExpr = p.expressions.flatMap(_.collect { + case s: SessionWindow => s + case t: TimeWindow => t + }).toSet.size + + // Only support a single session expression for now + if (numWindowExpr == 1 && sessionExpressions.nonEmpty && + sessionExpressions.head.timeColumn.resolved && + sessionExpressions.head.checkInputDataTypes().isSuccess) { + + val session = sessionExpressions.head + + val metadata = session.timeColumn match { + case a: Attribute => a.metadata + case _ => Metadata.empty + } + + val newMetadata = new MetadataBuilder() + .withMetadata(metadata) + .putBoolean(SessionWindow.marker, true) + .build() + + val sessionAttr = AttributeReference( + SESSION_COL_NAME, session.dataType, metadata = newMetadata)() + + val sessionStart = PreciseTimestampConversion(session.timeColumn, TimestampType, LongType) + val sessionEnd = sessionStart + session.gapDuration + + val literalSessionStruct = CreateNamedStruct( + Literal(SESSION_START) :: + PreciseTimestampConversion(sessionStart, LongType, TimestampType) :: + Literal(SESSION_END) :: + PreciseTimestampConversion(sessionEnd, LongType, TimestampType) :: + Nil) + + val sessionStruct = Alias(literalSessionStruct, SESSION_COL_NAME)( + exprId = sessionAttr.exprId, explicitMetadata = Some(newMetadata)) + + val replacedPlan = p transformExpressions { + case s: SessionWindow => sessionAttr + } + + // As same as tumbling window, we add a filter to filter out nulls. + val filterExpr = IsNotNull(session.timeColumn) + + replacedPlan.withNewChildren( + Project(sessionStruct +: child.output, + Filter(filterExpr, child)) :: Nil) + } else if (numWindowExpr > 1) { + throw QueryCompilationErrors.multiTimeWindowExpressionsNotSupportedError(p) + } else { + p // Return unchanged. Analyzer will throw exception later + } + } +} + /** * Resolve expressions if they contains [[NamePlaceholder]]s. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 60ca1e96a5e3..234da76e06a5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -552,6 +552,7 @@ object FunctionRegistry { expression[WeekOfYear]("weekofyear"), expression[Year]("year"), expression[TimeWindow]("window"), + expression[SessionWindow]("session_window"), expression[MakeDate]("make_date"), expression[MakeTimestamp]("make_timestamp"), expression[MakeTimestampNTZ]("make_timestamp_ntz", true), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SessionWindow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SessionWindow.scala new file mode 100644 index 000000000000..60b07444f9fd --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SessionWindow.scala @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure +import org.apache.spark.sql.types._ + +/** + * Represent the session window. + * + * @param timeColumn the start time of session window + * @param gapDuration the duration of session gap, meaning the session will close if there is + * no new element appeared within "the last element in session + gap". + */ +case class SessionWindow(timeColumn: Expression, gapDuration: Long) extends UnaryExpression + with ImplicitCastInputTypes + with Unevaluable + with NonSQLExpression { + + ////////////////////////// + // SQL Constructors + ////////////////////////// + + def this(timeColumn: Expression, gapDuration: Expression) = { + this(timeColumn, TimeWindow.parseExpression(gapDuration)) + } + + override def child: Expression = timeColumn + override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType) + override def dataType: DataType = new StructType() + .add(StructField("start", TimestampType)) + .add(StructField("end", TimestampType)) + + // This expression is replaced in the analyzer. + override lazy val resolved = false + + /** Validate the inputs for the gap duration in addition to the input data type. */ + override def checkInputDataTypes(): TypeCheckResult = { + val dataTypeCheck = super.checkInputDataTypes() + if (dataTypeCheck.isSuccess) { + if (gapDuration <= 0) { + return TypeCheckFailure(s"The window duration ($gapDuration) must be greater than 0.") + } + } + dataTypeCheck + } + + override protected def withNewChildInternal(newChild: Expression): Expression = + copy(timeColumn = newChild) +} + +object SessionWindow { + val marker = "spark.sessionWindow" + + def apply( + timeColumn: Expression, + gapDuration: String): SessionWindow = { + SessionWindow(timeColumn, + TimeWindow.getIntervalInMicroSeconds(gapDuration)) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala index 5b13872e566a..e79e8d767c95 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala @@ -109,7 +109,7 @@ object TimeWindow { * @return The interval duration in microseconds. SparkSQL casts TimestampType has microsecond * precision. */ - private def getIntervalInMicroSeconds(interval: String): Long = { + def getIntervalInMicroSeconds(interval: String): Long = { val cal = IntervalUtils.stringToInterval(UTF8String.fromString(interval)) if (cal.months != 0) { throw new IllegalArgumentException( @@ -122,7 +122,7 @@ object TimeWindow { * Parses the duration expression to generate the long value for the original constructor so * that we can use `window` in SQL. 
*/ - private def parseExpression(expr: Expression): Long = expr match { + def parseExpression(expr: Expression): Long = expr match { case NonNullLiteral(s, StringType) => getIntervalInMicroSeconds(s.toString) case IntegerLiteral(i) => i.toLong case NonNullLiteral(l, LongType) => l.toString.toLong diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala index 2cee614eb614..7a33d52a6573 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala @@ -366,8 +366,9 @@ private[spark] object QueryCompilationErrors { } def multiTimeWindowExpressionsNotSupportedError(t: TreeNode[_]): Throwable = { - new AnalysisException("Multiple time window expressions would result in a cartesian product " + - "of rows, therefore they are currently not supported.", t.origin.line, t.origin.startPosition) + new AnalysisException("Multiple time/session window expressions would result in a cartesian " + + "product of rows, therefore they are currently not supported.", t.origin.line, + t.origin.startPosition) } def viewOutputNumberMismatchQueryColumnNamesError( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index f1bfb1465b2f..a1c344a25d7c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1610,6 +1610,27 @@ object SQLConf { .checkValue(v => Set(1, 2).contains(v), "Valid versions are 1 and 2") .createWithDefault(2) + val STREAMING_SESSION_WINDOW_MERGE_SESSIONS_IN_LOCAL_PARTITION = + buildConf("spark.sql.streaming.sessionWindow.merge.sessions.in.local.partition") + .internal() + .doc("When true, streaming session window sorts and merge sessions in local partition " + + "prior to shuffle. This is to reduce the rows to shuffle, but only beneficial when " + + "there're lots of rows in a batch being assigned to same sessions.") + .version("3.2.0") + .booleanConf + .createWithDefault(false) + + val STREAMING_SESSION_WINDOW_STATE_FORMAT_VERSION = + buildConf("spark.sql.streaming.sessionWindow.stateFormatVersion") + .internal() + .doc("State format version used by streaming session window in a streaming query. 
" + + "State between versions are tend to be incompatible, so state format version shouldn't " + + "be modified after running.") + .version("3.2.0") + .intConf + .checkValue(v => Set(1).contains(v), "Valid version is 1") + .createWithDefault(1) + val UNSUPPORTED_OPERATION_CHECK_ENABLED = buildConf("spark.sql.streaming.unsupportedOperationCheck") .internal() @@ -3676,6 +3697,9 @@ class SQLConf extends Serializable with Logging { def fastHashAggregateRowMaxCapacityBit: Int = getConf(FAST_HASH_AGGREGATE_MAX_ROWS_CAPACITY_BIT) + def streamingSessionWindowMergeSessionInLocalPartition: Boolean = + getConf(STREAMING_SESSION_WINDOW_MERGE_SESSIONS_IN_LOCAL_PARTITION) + def datetimeJava8ApiEnabled: Boolean = getConf(DATETIME_JAVA8API_ENABLED) def uiExplainMode: String = getConf(UI_EXPLAIN_MODE) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 65a592302c6a..6d10fa83f432 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -324,7 +324,9 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { throw QueryCompilationErrors.groupAggPandasUDFUnsupportedByStreamingAggError() } - val stateVersion = conf.getConf(SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION) + val sessionWindowOption = namedGroupingExpressions.find { p => + p.metadata.contains(SessionWindow.marker) + } // Ideally this should be done in `NormalizeFloatingNumbers`, but we do it here because // `groupingExpressions` is not extracted during logical phase. @@ -335,12 +337,29 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } } - AggUtils.planStreamingAggregation( - normalizedGroupingExpressions, - aggregateExpressions.map(expr => expr.asInstanceOf[AggregateExpression]), - rewrittenResultExpressions, - stateVersion, - planLater(child)) + sessionWindowOption match { + case Some(sessionWindow) => + val stateVersion = conf.getConf(SQLConf.STREAMING_SESSION_WINDOW_STATE_FORMAT_VERSION) + + AggUtils.planStreamingAggregationForSession( + normalizedGroupingExpressions, + sessionWindow, + aggregateExpressions.map(expr => expr.asInstanceOf[AggregateExpression]), + rewrittenResultExpressions, + stateVersion, + conf.streamingSessionWindowMergeSessionInLocalPartition, + planLater(child)) + + case None => + val stateVersion = conf.getConf(SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION) + + AggUtils.planStreamingAggregation( + normalizedGroupingExpressions, + aggregateExpressions.map(expr => expr.asInstanceOf[AggregateExpression]), + rewrittenResultExpressions, + stateVersion, + planLater(child)) + } case _ => Nil } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala index 58d341107347..0f239b457fd1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala @@ -17,10 +17,11 @@ package org.apache.spark.sql.execution.aggregate +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.execution.streaming.{StateStoreRestoreExec, StateStoreSaveExec} +import 
org.apache.spark.sql.execution.streaming._ /** * Utility functions used by the query planner to convert our plan to new aggregation code path. @@ -113,6 +114,11 @@ object AggUtils { resultExpressions = partialResultExpressions, child = child) + // If we have session window expression in aggregation, we add MergingSessionExec to + // merge sessions with calculating aggregation values. + val interExec: SparkPlan = mayAppendMergingSessionExec(groupingExpressions, + aggregateExpressions, partialAggregate) + // 2. Create an Aggregate Operator for final aggregations. val finalAggregateExpressions = aggregateExpressions.map(_.copy(mode = Final)) // The attributes of the final aggregation buffer, which is presented as input to the result @@ -126,7 +132,7 @@ object AggUtils { aggregateAttributes = finalAggregateAttributes, initialInputBufferOffset = groupingExpressions.length, resultExpressions = resultExpressions, - child = partialAggregate) + child = interExec) finalAggregate :: Nil } @@ -140,6 +146,11 @@ object AggUtils { resultExpressions: Seq[NamedExpression], child: SparkPlan): Seq[SparkPlan] = { + // If we have session window expression in aggregation, we add UpdatingSessionsExec to + // calculate sessions for input rows and update rows' session column, so that further + // aggregations can aggregate input rows for the same session. + val maySessionChild = mayAppendUpdatingSessionExec(groupingExpressions, child) + val distinctAttributes = normalizedNamedDistinctExpressions.map(_.toAttribute) val groupingAttributes = groupingExpressions.map(_.toAttribute) @@ -156,7 +167,7 @@ object AggUtils { aggregateAttributes = aggregateAttributes, resultExpressions = groupingAttributes ++ distinctAttributes ++ aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes), - child = child) + child = maySessionChild) } // 2. Create an Aggregate Operator for partial merge aggregations. 
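To make the session-merging comments above concrete: conceptually, rows sorted by (grouping keys, session start) collapse into sessions by extending the current session whenever the next event starts before the current session's end, which is the property UpdatingSessionsExec / MergingSessionsExec rely on. The standalone helper below is only an illustrative sketch of that rule, with an invented Session type and Long timestamps standing in for the microsecond values produced by the analyzer rule; it is not the operators' actual implementation.

// Illustrative sketch only (not Spark's implementation): merge time-sorted events
// into session windows of the form [start, lastEvent + gap).
case class Session(start: Long, end: Long)

def mergeSortedIntoSessions(eventTimes: Seq[Long], gap: Long): Seq[Session] = {
  require(gap > 0, "gap duration must be positive")
  eventTimes.sorted.foldLeft(List.empty[Session]) {
    // First event opens the first session: [ts, ts + gap).
    case (Nil, ts) => Session(ts, ts + gap) :: Nil
    // Event starts before the current session ends: same session, extend its end.
    case (current :: done, ts) if ts < current.end =>
      current.copy(end = math.max(current.end, ts + gap)) :: done
    // Otherwise the gap has elapsed: open a new session.
    case (sessions, ts) => Session(ts, ts + gap) :: sessions
  }.reverse
}

// e.g. mergeSortedIntoSessions(Seq(0L, 3L, 10L), gap = 5L) == Seq(Session(0, 8), Session(10, 15))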
@@ -345,4 +356,177 @@ object AggUtils { finalAndCompleteAggregate :: Nil } + + /** + * Plans a streaming session aggregation using the following progression: + * + * - Partial Aggregation + * - all tuples will have aggregated columns with initial value + * - (If "spark.sql.streaming.sessionWindow.merge.sessions.in.local.partition" is enabled) + * - Sort within partition (sort: all keys) + * - MergingSessionExec + * - calculate session among tuples, and aggregate tuples in session with partial merge + * - Shuffle & Sort (distribution: keys "without" session, sort: all keys) + * - SessionWindowStateStoreRestore (group: keys "without" session) + * - merge input tuples with stored tuples (sessions) respecting sort order + * - MergingSessionExec + * - calculate session among tuples, and aggregate tuples in session with partial merge + * - NOTE: it leverages the fact that the output of SessionWindowStateStoreRestore is sorted + * - now there is at most 1 tuple per group, key with session + * - SessionWindowStateStoreSave (group: keys "without" session) + * - saves tuple(s) for the next batch (multiple sessions could co-exist at the same time) + * - Complete (output the current result of the aggregation) + */ + def planStreamingAggregationForSession( + groupingExpressions: Seq[NamedExpression], + sessionExpression: NamedExpression, + functionsWithoutDistinct: Seq[AggregateExpression], + resultExpressions: Seq[NamedExpression], + stateFormatVersion: Int, + mergeSessionsInLocalPartition: Boolean, + child: SparkPlan): Seq[SparkPlan] = { + + val groupWithoutSessionExpression = groupingExpressions.filterNot { p => + p.semanticEquals(sessionExpression) + } + + if (groupWithoutSessionExpression.isEmpty) { + throw new AnalysisException("Global aggregation with session window in streaming query" + + " is not supported.") + } + + val groupingWithoutSessionAttributes = groupWithoutSessionExpression.map(_.toAttribute) + + val groupingAttributes = groupingExpressions.map(_.toAttribute) + + // Here doing partial merge is to have aggregated columns with default value for each row. 
+ val partialAggregate: SparkPlan = { + val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Partial)) + val aggregateAttributes = aggregateExpressions.map(_.resultAttribute) + createAggregate( + groupingExpressions = groupingExpressions, + aggregateExpressions = aggregateExpressions, + aggregateAttributes = aggregateAttributes, + resultExpressions = groupingAttributes ++ + aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes), + child = child) + } + + val partialMerged1: SparkPlan = if (mergeSessionsInLocalPartition) { + val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge)) + val aggregateAttributes = aggregateExpressions.map(_.resultAttribute) + + // sort happens here to merge sessions on each partition + // this is to reduce amount of rows to shuffle + MergingSessionsExec( + requiredChildDistributionExpressions = None, + requiredChildDistributionOption = None, + groupingExpressions = groupingAttributes, + sessionExpression = sessionExpression, + aggregateExpressions = aggregateExpressions, + aggregateAttributes = aggregateAttributes, + initialInputBufferOffset = groupingAttributes.length, + resultExpressions = groupingAttributes ++ + aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes), + child = partialAggregate + ) + } else { + partialAggregate + } + + // shuffle & sort happens here: most of details are also handled in this physical plan + val restored = SessionWindowStateStoreRestoreExec(groupingWithoutSessionAttributes, + sessionExpression.toAttribute, stateInfo = None, eventTimeWatermark = None, + stateFormatVersion, partialMerged1) + + val mergedSessions = { + val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge)) + val aggregateAttributes = aggregateExpressions.map(_.resultAttribute) + MergingSessionsExec( + requiredChildDistributionExpressions = None, + requiredChildDistributionOption = Some(restored.requiredChildDistribution), + groupingExpressions = groupingAttributes, + sessionExpression = sessionExpression, + aggregateExpressions = aggregateExpressions, + aggregateAttributes = aggregateAttributes, + initialInputBufferOffset = groupingAttributes.length, + resultExpressions = groupingAttributes ++ + aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes), + child = restored + ) + } + + // Note: stateId and returnAllStates are filled in later with preparation rules + // in IncrementalExecution. 
+ val saved = SessionWindowStateStoreSaveExec( + groupingWithoutSessionAttributes, + sessionExpression.toAttribute, + stateInfo = None, + outputMode = None, + eventTimeWatermark = None, + stateFormatVersion, mergedSessions) + + val finalAndCompleteAggregate: SparkPlan = { + val finalAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Final)) + // The attributes of the final aggregation buffer, which is presented as input to the result + // projection: + val finalAggregateAttributes = finalAggregateExpressions.map(_.resultAttribute) + + createAggregate( + requiredChildDistributionExpressions = Some(groupingAttributes), + groupingExpressions = groupingAttributes, + aggregateExpressions = finalAggregateExpressions, + aggregateAttributes = finalAggregateAttributes, + initialInputBufferOffset = groupingAttributes.length, + resultExpressions = resultExpressions, + child = saved) + } + + finalAndCompleteAggregate :: Nil + } + + private def mayAppendUpdatingSessionExec( + groupingExpressions: Seq[NamedExpression], + maybeChildPlan: SparkPlan): SparkPlan = { + groupingExpressions.find(_.metadata.contains(SessionWindow.marker)) match { + case Some(sessionExpression) => + UpdatingSessionsExec( + groupingExpressions.map(_.toAttribute), + sessionExpression.toAttribute, + maybeChildPlan) + + case None => maybeChildPlan + } + } + + private def mayAppendMergingSessionExec( + groupingExpressions: Seq[NamedExpression], + aggregateExpressions: Seq[AggregateExpression], + partialAggregate: SparkPlan): SparkPlan = { + groupingExpressions.find(_.metadata.contains(SessionWindow.marker)) match { + case Some(sessionExpression) => + val aggExpressions = aggregateExpressions.map(_.copy(mode = PartialMerge)) + val aggAttributes = aggregateExpressions.map(_.resultAttribute) + + val groupingAttributes = groupingExpressions.map(_.toAttribute) + val groupingWithoutSessionExpressions = groupingExpressions.diff(Seq(sessionExpression)) + val groupingWithoutSessionsAttributes = groupingWithoutSessionExpressions + .map(_.toAttribute) + + MergingSessionsExec( + requiredChildDistributionExpressions = Some(groupingWithoutSessionsAttributes), + requiredChildDistributionOption = None, + groupingExpressions = groupingAttributes, + sessionExpression = sessionExpression, + aggregateExpressions = aggExpressions, + aggregateAttributes = aggAttributes, + initialInputBufferOffset = groupingAttributes.length, + resultExpressions = groupingAttributes ++ + aggExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes), + child = partialAggregate + ) + + case None => partialAggregate + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UpdatingSessionsIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UpdatingSessionsIterator.scala index bb474a19222c..0a60ddc0a98e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UpdatingSessionsIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UpdatingSessionsIterator.scala @@ -181,7 +181,8 @@ class UpdatingSessionsIterator( private val valueProj = GenerateUnsafeProjection.generate(valuesExpressions, inputSchema) private val restoreProj = GenerateUnsafeProjection.generate(inputSchema, - groupingExpressions.map(_.toAttribute) ++ valuesExpressions.map(_.toAttribute)) + groupingWithoutSession.map(_.toAttribute) ++ Seq(sessionExpression.toAttribute) ++ + valuesExpressions.map(_.toAttribute)) private def generateGroupingKey(): InternalRow = { val newRow = new 
SpecificInternalRow(Seq(sessionExpression.toAttribute).toStructType) @@ -190,19 +191,21 @@ class UpdatingSessionsIterator( } private def closeCurrentSession(keyChanged: Boolean): Unit = { - assert(returnRowsIter == null || !returnRowsIter.hasNext) - returnRows = rowsForCurrentSession rowsForCurrentSession = null - val groupingKey = generateGroupingKey() + val groupingKey = generateGroupingKey().copy() val currentRowsIter = returnRows.generateIterator().map { internalRow => val valueRow = valueProj(internalRow) restoreProj(join2(groupingKey, valueRow)).copy() } - returnRowsIter = currentRowsIter + if (returnRowsIter != null && returnRowsIter.hasNext) { + returnRowsIter = returnRowsIter ++ currentRowsIter + } else { + returnRowsIter = currentRowsIter + } if (keyChanged) processedKeys.add(currentKeys) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala index 5019008ec5e3..69802b143c11 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning} import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan, UnaryExecNode} +import org.apache.spark.sql.execution.aggregate.UpdatingSessionsIterator import org.apache.spark.sql.types.{DataType, StructField, StructType} import org.apache.spark.sql.util.ArrowUtils import org.apache.spark.util.Utils @@ -53,6 +54,17 @@ case class AggregateInPandasExec( override def producedAttributes: AttributeSet = AttributeSet(output) + val sessionWindowOption = groupingExpressions.find { p => + p.metadata.contains(SessionWindow.marker) + } + + val groupingWithoutSessionExpressions = sessionWindowOption match { + case Some(sessionExpression) => + groupingExpressions.filterNot { p => p.semanticEquals(sessionExpression) } + + case None => groupingExpressions + } + override def requiredChildDistribution: Seq[Distribution] = { if (groupingExpressions.isEmpty) { AllTuples :: Nil @@ -61,6 +73,14 @@ case class AggregateInPandasExec( } } + override def requiredChildOrdering: Seq[Seq[SortOrder]] = sessionWindowOption match { + case Some(sessionExpression) => + Seq((groupingWithoutSessionExpressions ++ Seq(sessionExpression)) + .map(SortOrder(_, Ascending))) + + case None => Seq(groupingExpressions.map(SortOrder(_, Ascending))) + } + private def collectFunctions(udf: PythonUDF): (ChainedPythonFunctions, Seq[Expression]) = { udf.children match { case Seq(u: PythonUDF) => @@ -73,9 +93,6 @@ case class AggregateInPandasExec( } } - override def requiredChildOrdering: Seq[Seq[SortOrder]] = - Seq(groupingExpressions.map(SortOrder(_, Ascending))) - override protected def doExecute(): RDD[InternalRow] = { val inputRDD = child.execute() @@ -107,13 +124,18 @@ case class AggregateInPandasExec( // Map grouped rows to ArrowPythonRunner results, Only execute if partition is not empty inputRDD.mapPartitionsInternal { iter => if (iter.isEmpty) iter else { + // If we have session window expression in aggregation, we wrap iterator with + // UpdatingSessionIterator to calculate sessions for input rows and update + // rows' session column, so that further aggregations can aggregate input 
rows + // for the same session. + val newIter: Iterator[InternalRow] = mayAppendUpdatingSessionIterator(iter) val prunedProj = UnsafeProjection.create(allInputs.toSeq, child.output) val grouped = if (groupingExpressions.isEmpty) { // Use an empty unsafe row as a place holder for the grouping key - Iterator((new UnsafeRow(), iter)) + Iterator((new UnsafeRow(), newIter)) } else { - GroupedIterator(iter, groupingExpressions, child.output) + GroupedIterator(newIter, groupingExpressions, child.output) }.map { case (key, rows) => (key, rows.map(prunedProj)) } @@ -157,4 +179,21 @@ case class AggregateInPandasExec( override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = copy(child = newChild) + + + private def mayAppendUpdatingSessionIterator( + iter: Iterator[InternalRow]): Iterator[InternalRow] = { + val newIter = sessionWindowOption match { + case Some(sessionExpression) => + val inMemoryThreshold = conf.windowExecBufferInMemoryThreshold + val spillThreshold = conf.windowExecBufferSpillThreshold + + new UpdatingSessionsIterator(iter, groupingWithoutSessionExpressions, sessionExpression, + child.output, inMemoryThreshold, spillThreshold) + + case None => iter + } + + newIter + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index e98996b8e37d..3e772e104648 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -149,6 +149,26 @@ class IncrementalExecution( stateFormatVersion, child) :: Nil)) + case SessionWindowStateStoreSaveExec(keys, session, None, None, None, stateFormatVersion, + UnaryExecNode(agg, + SessionWindowStateStoreRestoreExec(_, _, None, None, _, child))) => + val aggStateInfo = nextStatefulOperationStateInfo + SessionWindowStateStoreSaveExec( + keys, + session, + Some(aggStateInfo), + Some(outputMode), + Some(offsetSeqMetadata.batchWatermarkMs), + stateFormatVersion, + agg.withNewChildren( + SessionWindowStateStoreRestoreExec( + keys, + session, + Some(aggStateInfo), + Some(offsetSeqMetadata.batchWatermarkMs), + stateFormatVersion, + child) :: Nil)) + case StreamingDeduplicateExec(keys, child, None, None) => StreamingDeduplicateExec( keys, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StreamingSessionWindowStateManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StreamingSessionWindowStateManager.scala index 6561286448b4..5130933f52ef 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StreamingSessionWindowStateManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StreamingSessionWindowStateManager.scala @@ -68,8 +68,11 @@ sealed trait StreamingSessionWindowStateManager extends Serializable { * {@code extractKeyWithoutSession}. * @param sessions The all sessions including existing sessions if it's active. * Existing sessions which aren't included in this parameter will be removed. + * @return A tuple having two elements + * 1. number of added/updated rows + * 2. 
number of deleted rows */ - def updateSessions(store: StateStore, key: UnsafeRow, sessions: Seq[UnsafeRow]): Unit + def updateSessions(store: StateStore, key: UnsafeRow, sessions: Seq[UnsafeRow]): (Long, Long) /** * Removes using a predicate on values, with returning removed values via iterator. @@ -168,7 +171,7 @@ class StreamingSessionWindowStateManagerImplV1( override def updateSessions( store: StateStore, key: UnsafeRow, - sessions: Seq[UnsafeRow]): Unit = { + sessions: Seq[UnsafeRow]): (Long, Long) = { // Below two will be used multiple times - need to make sure this is not a stream or iterator. val newValues = sessions.toList val savedStates = getSessionsWithKeys(store, key) @@ -225,7 +228,7 @@ class StreamingSessionWindowStateManagerImplV1( store: StateStore, key: UnsafeRow, oldValues: List[(UnsafeRow, UnsafeRow)], - values: List[UnsafeRow]): Unit = { + values: List[UnsafeRow]): (Long, Long) = { // Here the key doesn't represent the state key - we need to construct the key for state val keyAndValues = values.map { row => val sessionStart = helper.extractTimePair(row)._1 @@ -236,16 +239,24 @@ class StreamingSessionWindowStateManagerImplV1( val keysForValues = keyAndValues.map(_._1) val keysForOldValues = oldValues.map(_._1) + var upsertedRows = 0L + var deletedRows = 0L + // We should "replace" the value instead of "delete" and "put" if the start time // equals to. This will remove unnecessary tombstone being written to the delta, which is // implementation details on state store implementations. + keysForOldValues.filterNot(keysForValues.contains).foreach { oldKey => store.remove(oldKey) + deletedRows += 1 } keyAndValues.foreach { case (key, value) => store.put(key, value) + upsertedRows += 1 } + + (upsertedRows, deletedRows) } override def abortIfNeeded(store: StateStore): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index 3f6a7ba1a0da..2dd91decfa99 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -20,7 +20,9 @@ package org.apache.spark.sql.execution.streaming import java.util.UUID import java.util.concurrent.TimeUnit._ +import scala.annotation.tailrec import scala.collection.JavaConverters._ +import scala.collection.mutable import org.apache.spark.SparkContext import org.apache.spark.rdd.RDD @@ -511,6 +513,293 @@ case class StateStoreSaveExec( copy(child = newChild) } +/** + * This class sorts input rows and existing sessions in state and provides output rows as + * sorted by "group keys + start time of session window". + * + * Refer [[MergingSortWithSessionWindowStateIterator]] for more details. 
+ */ +case class SessionWindowStateStoreRestoreExec( + keyWithoutSessionExpressions: Seq[Attribute], + sessionExpression: Attribute, + stateInfo: Option[StatefulOperatorStateInfo], + eventTimeWatermark: Option[Long], + stateFormatVersion: Int, + child: SparkPlan) + extends UnaryExecNode with StateStoreReader with WatermarkSupport { + + override def keyExpressions: Seq[Attribute] = keyWithoutSessionExpressions + + assert(keyExpressions.nonEmpty, "Grouping key must be specified when using sessionWindow") + + private val stateManager = StreamingSessionWindowStateManager.createStateManager( + keyWithoutSessionExpressions, sessionExpression, child.output, stateFormatVersion) + + override protected def doExecute(): RDD[InternalRow] = { + val numOutputRows = longMetric("numOutputRows") + + child.execute().mapPartitionsWithReadStateStore( + getStateInfo, + stateManager.getStateKeySchema, + stateManager.getStateValueSchema, + numColsPrefixKey = stateManager.getNumColsForPrefixKey, + session.sessionState, + Some(session.streams.stateStoreCoordinator)) { case (store, iter) => + + // We need to filter out outdated inputs + val filteredIterator = watermarkPredicateForData match { + case Some(predicate) => iter.filter((row: InternalRow) => !predicate.eval(row)) + case None => iter + } + + new MergingSortWithSessionWindowStateIterator( + filteredIterator, + stateManager, + store, + keyWithoutSessionExpressions, + sessionExpression, + child.output).map { row => + numOutputRows += 1 + row + } + } + } + + override def output: Seq[Attribute] = child.output + + override def outputPartitioning: Partitioning = child.outputPartitioning + + override def outputOrdering: Seq[SortOrder] = { + (keyWithoutSessionExpressions ++ Seq(sessionExpression)).map(SortOrder(_, Ascending)) + } + + override def requiredChildDistribution: Seq[Distribution] = { + ClusteredDistribution(keyWithoutSessionExpressions, stateInfo.map(_.numPartitions)) :: Nil + } + + override def requiredChildOrdering: Seq[Seq[SortOrder]] = { + Seq((keyWithoutSessionExpressions ++ Seq(sessionExpression)).map(SortOrder(_, Ascending))) + } + + override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = + copy(child = newChild) +} + +/** + * For each input tuple, the key is calculated and the tuple is `put` into the [[StateStore]]. 
+ */ +case class SessionWindowStateStoreSaveExec( + keyWithoutSessionExpressions: Seq[Attribute], + sessionExpression: Attribute, + stateInfo: Option[StatefulOperatorStateInfo] = None, + outputMode: Option[OutputMode] = None, + eventTimeWatermark: Option[Long] = None, + stateFormatVersion: Int, + child: SparkPlan) + extends UnaryExecNode with StateStoreWriter with WatermarkSupport { + + override def keyExpressions: Seq[Attribute] = keyWithoutSessionExpressions + + private val stateManager = StreamingSessionWindowStateManager.createStateManager( + keyWithoutSessionExpressions, sessionExpression, child.output, stateFormatVersion) + + override protected def doExecute(): RDD[InternalRow] = { + metrics // force lazy init at driver + assert(outputMode.nonEmpty, + "Incorrect planning in IncrementalExecution, outputMode has not been set") + assert(keyExpressions.nonEmpty, + "Grouping key must be specified when using sessionWindow") + + child.execute().mapPartitionsWithStateStore( + getStateInfo, + stateManager.getStateKeySchema, + stateManager.getStateValueSchema, + numColsPrefixKey = stateManager.getNumColsForPrefixKey, + session.sessionState, + Some(session.streams.stateStoreCoordinator)) { case (store, iter) => + + val numOutputRows = longMetric("numOutputRows") + val numRemovedStateRows = longMetric("numRemovedStateRows") + val allUpdatesTimeMs = longMetric("allUpdatesTimeMs") + val allRemovalsTimeMs = longMetric("allRemovalsTimeMs") + val commitTimeMs = longMetric("commitTimeMs") + + outputMode match { + // Update and output all rows in the StateStore. + case Some(Complete) => + allUpdatesTimeMs += timeTakenMs { + putToStore(iter, store) + } + commitTimeMs += timeTakenMs { + stateManager.commit(store) + } + setStoreMetrics(store) + stateManager.iterator(store).map { row => + numOutputRows += 1 + row + } + + // Update and output only rows being evicted from the StateStore + // Assumption: watermark predicates must be non-empty if append mode is allowed + case Some(Append) => + allUpdatesTimeMs += timeTakenMs { + val filteredIter = applyRemovingRowsOlderThanWatermark(iter, + watermarkPredicateForData.get) + putToStore(filteredIter, store) + } + + val removalStartTimeNs = System.nanoTime + new NextIterator[InternalRow] { + private val removedIter = stateManager.removeByValueCondition( + store, watermarkPredicateForData.get.eval) + + override protected def getNext(): InternalRow = { + if (!removedIter.hasNext) { + finished = true + null + } else { + numRemovedStateRows += 1 + numOutputRows += 1 + removedIter.next() + } + } + + override protected def close(): Unit = { + allRemovalsTimeMs += NANOSECONDS.toMillis(System.nanoTime - removalStartTimeNs) + commitTimeMs += timeTakenMs { store.commit() } + setStoreMetrics(store) + setOperatorMetrics() + } + } + + case Some(Update) => + val baseIterator = watermarkPredicateForData match { + case Some(predicate) => applyRemovingRowsOlderThanWatermark(iter, predicate) + case None => iter + } + val iterPutToStore = iteratorPutToStore(baseIterator, store, + returnOnlyUpdatedRows = true) + new NextIterator[InternalRow] { + private val updatesStartTimeNs = System.nanoTime + + override protected def getNext(): InternalRow = { + if (iterPutToStore.hasNext) { + val row = iterPutToStore.next() + numOutputRows += 1 + row + } else { + finished = true + null + } + } + + override protected def close(): Unit = { + allUpdatesTimeMs += NANOSECONDS.toMillis(System.nanoTime - updatesStartTimeNs) + + allRemovalsTimeMs += timeTakenMs { + if (watermarkPredicateForData.nonEmpty) 
{ + val removedIter = stateManager.removeByValueCondition( + store, watermarkPredicateForData.get.eval) + while (removedIter.hasNext) { + numRemovedStateRows += 1 + removedIter.next() + } + } + } + commitTimeMs += timeTakenMs { store.commit() } + setStoreMetrics(store) + setOperatorMetrics() + } + } + + case _ => throw QueryExecutionErrors.invalidStreamingOutputModeError(outputMode) + } + } + } + + override def output: Seq[Attribute] = child.output + + override def outputPartitioning: Partitioning = child.outputPartitioning + + override def requiredChildDistribution: Seq[Distribution] = { + ClusteredDistribution(keyExpressions, stateInfo.map(_.numPartitions)) :: Nil + } + + override def shouldRunAnotherBatch(newMetadata: OffsetSeqMetadata): Boolean = { + (outputMode.contains(Append) || outputMode.contains(Update)) && + eventTimeWatermark.isDefined && + newMetadata.batchWatermarkMs > eventTimeWatermark.get + } + + private def iteratorPutToStore( + iter: Iterator[InternalRow], + store: StateStore, + returnOnlyUpdatedRows: Boolean): Iterator[InternalRow] = { + val numUpdatedStateRows = longMetric("numUpdatedStateRows") + val numRemovedStateRows = longMetric("numRemovedStateRows") + + new NextIterator[InternalRow] { + var curKey: UnsafeRow = null + val curValuesOnKey = new mutable.ArrayBuffer[UnsafeRow]() + + private def applyChangesOnKey(): Unit = { + if (curValuesOnKey.nonEmpty) { + val (upserted, deleted) = stateManager.updateSessions(store, curKey, curValuesOnKey.toSeq) + numUpdatedStateRows += upserted + numRemovedStateRows += deleted + curValuesOnKey.clear + } + } + + @tailrec + override protected def getNext(): InternalRow = { + if (!iter.hasNext) { + applyChangesOnKey() + finished = true + return null + } + + val row = iter.next().asInstanceOf[UnsafeRow] + val key = stateManager.extractKeyWithoutSession(row) + + if (curKey == null || curKey != key) { + // new group appears + applyChangesOnKey() + curKey = key.copy() + } + + // must copy the row, for this row is a reference in iterator and + // will change when iter.next + curValuesOnKey += row.copy + + if (!returnOnlyUpdatedRows) { + row + } else { + if (stateManager.newOrModified(store, row)) { + row + } else { + // current row isn't the "updated" row, continue to the next row + getNext() + } + } + } + + override protected def close(): Unit = {} + } + } + + private def putToStore(baseIter: Iterator[InternalRow], store: StateStore): Unit = { + val iterPutToStore = iteratorPutToStore(baseIter, store, returnOnlyUpdatedRows = false) + while (iterPutToStore.hasNext) { + iterPutToStore.next() + } + } + + override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = + copy(child = newChild) +} + + /** Physical operator for executing streaming Deduplicate. */ case class StreamingDeduplicateExec( keyExpressions: Seq[Attribute], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 3b39d9790dff..7db8e8a90a59 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3630,6 +3630,36 @@ object functions { window(timeColumn, windowDuration, windowDuration, "0 second") } + /** + * Generates session window given a timestamp specifying column. + * + * Session window is one of dynamic windows, which means the length of window is varying + * according to the given inputs. 
The length of session window is defined as "the timestamp + * of latest input of the session + gap duration", so when the new inputs are bound to the + * current session window, the end time of session window can be expanded according to the new + * inputs. + * + * Windows can support microsecond precision. gapDuration in the order of months are not + * supported. + * + * For a streaming query, you may use the function `current_timestamp` to generate windows on + * processing time. + * + * @param timeColumn The column or the expression to use as the timestamp for windowing by time. + * The time column must be of TimestampType. + * @param gapDuration A string specifying the timeout of the session, e.g. `10 minutes`, + * `1 second`. Check `org.apache.spark.unsafe.types.CalendarInterval` for + * valid duration identifiers. + * + * @group datetime_funcs + * @since 3.2.0 + */ + def session_window(timeColumn: Column, gapDuration: String): Column = { + withExpr { + SessionWindow(timeColumn.expr, gapDuration) + }.as("session_window") + } + /** * Creates timestamp from the number of seconds since UTC epoch. * @group datetime_funcs diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md index c13a1d4e93d4..41692d20ed56 100644 --- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md +++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md @@ -1,8 +1,8 @@ ## Summary - - Number of queries: 360 - - Number of expressions that missing example: 13 - - Expressions missing examples: bigint,binary,boolean,date,decimal,double,float,int,smallint,string,timestamp,tinyint,window + - Number of queries: 361 + - Number of expressions that missing example: 14 + - Expressions missing examples: bigint,binary,boolean,date,decimal,double,float,int,smallint,string,timestamp,tinyint,session_window,window ## Schema of Built-in Functions | Class name | Function name or alias | Query example | Output schema | | ---------- | ---------------------- | ------------- | ------------- | @@ -244,6 +244,7 @@ | org.apache.spark.sql.catalyst.expressions.SecondsToTimestamp | timestamp_seconds | SELECT timestamp_seconds(1230219000) | struct | | org.apache.spark.sql.catalyst.expressions.Sentences | sentences | SELECT sentences('Hi there! 
Good morning.') | struct>> | | org.apache.spark.sql.catalyst.expressions.Sequence | sequence | SELECT sequence(1, 5) | struct> | +| org.apache.spark.sql.catalyst.expressions.SessionWindow | session_window | N/A | N/A | | org.apache.spark.sql.catalyst.expressions.Sha1 | sha | SELECT sha('Spark') | struct | | org.apache.spark.sql.catalyst.expressions.Sha1 | sha1 | SELECT sha1('Spark') | struct | | org.apache.spark.sql.catalyst.expressions.Sha2 | sha2 | SELECT sha2('Spark', 256) | struct | @@ -365,4 +366,4 @@ | org.apache.spark.sql.catalyst.expressions.xml.XPathList | xpath | SELECT xpath('b1b2b3c1c2','a/b/text()') | structb1b2b3c1c2, a/b/text()):array> | | org.apache.spark.sql.catalyst.expressions.xml.XPathLong | xpath_long | SELECT xpath_long('12', 'sum(a/b)') | struct12, sum(a/b)):bigint> | | org.apache.spark.sql.catalyst.expressions.xml.XPathShort | xpath_short | SELECT xpath_short('12', 'sum(a/b)') | struct12, sum(a/b)):smallint> | -| org.apache.spark.sql.catalyst.expressions.xml.XPathString | xpath_string | SELECT xpath_string('bcc','a/c') | structbcc, a/c):string> | \ No newline at end of file +| org.apache.spark.sql.catalyst.expressions.xml.XPathString | xpath_string | SELECT xpath_string('bcc','a/c') | structbcc, a/c):string> | diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSessionWindowingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSessionWindowingSuite.scala new file mode 100644 index 000000000000..b70b2c670732 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSessionWindowingSuite.scala @@ -0,0 +1,290 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql + +import org.scalatest.BeforeAndAfterEach + +import org.apache.spark.sql.catalyst.plans.logical.Expand +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types.StringType + +class DataFrameSessionWindowingSuite extends QueryTest with SharedSparkSession + with BeforeAndAfterEach { + + import testImplicits._ + + test("simple session window with record at window start") { + val df = Seq( + ("2016-03-27 19:39:30", 1, "a")).toDF("time", "value", "id") + + checkAnswer( + df.groupBy(session_window($"time", "10 seconds")) + .agg(count("*").as("counts")) + .orderBy($"session_window.start".asc) + .select($"session_window.start".cast("string"), $"session_window.end".cast("string"), + $"counts"), + Seq( + Row("2016-03-27 19:39:30", "2016-03-27 19:39:40", 1) + ) + ) + } + + test("session window groupBy statement") { + val df = Seq( + ("2016-03-27 19:39:34", 1, "a"), + ("2016-03-27 19:39:56", 2, "a"), + ("2016-03-27 19:39:27", 4, "b")).toDF("time", "value", "id") + + // session window handles sort while applying group by + // whereas time window doesn't + + checkAnswer( + df.groupBy(session_window($"time", "10 seconds")) + .agg(count("*").as("counts")) + .orderBy($"session_window.start".asc) + .select("counts"), + Seq(Row(2), Row(1)) + ) + } + + test("session window groupBy with multiple keys statement") { + val df = Seq( + ("2016-03-27 19:39:34", 1, "a"), + ("2016-03-27 19:39:39", 1, "a"), + ("2016-03-27 19:39:56", 2, "a"), + ("2016-03-27 19:40:04", 2, "a"), + ("2016-03-27 19:39:27", 4, "b")).toDF("time", "value", "id") + + // session window handles sort while applying group by + // whereas time window doesn't + + // expected sessions + // key "a" => (19:39:34 ~ 19:39:49) (19:39:56 ~ 19:40:14) + // key "b" => (19:39:27 ~ 19:39:37) + + checkAnswer( + df.groupBy(session_window($"time", "10 seconds"), 'id) + .agg(count("*").as("counts"), sum("value").as("sum")) + .orderBy($"session_window.start".asc) + .selectExpr("CAST(session_window.start AS STRING)", "CAST(session_window.end AS STRING)", + "id", "counts", "sum"), + + Seq( + Row("2016-03-27 19:39:27", "2016-03-27 19:39:37", "b", 1, 4), + Row("2016-03-27 19:39:34", "2016-03-27 19:39:49", "a", 2, 2), + Row("2016-03-27 19:39:56", "2016-03-27 19:40:14", "a", 2, 4) + ) + ) + } + + test("session window groupBy with multiple keys statement - one distinct") { + val df = Seq( + ("2016-03-27 19:39:34", 1, "a"), + ("2016-03-27 19:39:39", 1, "a"), + ("2016-03-27 19:39:56", 2, "a"), + ("2016-03-27 19:40:04", 2, "a"), + ("2016-03-27 19:39:27", 4, "b")).toDF("time", "value", "id") + + // session window handles sort while applying group by + // whereas time window doesn't + + // expected sessions + // key "a" => (19:39:34 ~ 19:39:49) (19:39:56 ~ 19:40:14) + // key "b" => (19:39:27 ~ 19:39:37) + + checkAnswer( + df.groupBy(session_window($"time", "10 seconds"), 'id) + .agg(count("*").as("counts"), sum_distinct(col("value")).as("sum")) + .orderBy($"session_window.start".asc) + .selectExpr("CAST(session_window.start AS STRING)", "CAST(session_window.end AS STRING)", + "id", "counts", "sum"), + Seq( + Row("2016-03-27 19:39:27", "2016-03-27 19:39:37", "b", 1, 4), + Row("2016-03-27 19:39:34", "2016-03-27 19:39:49", "a", 2, 1), + Row("2016-03-27 19:39:56", "2016-03-27 19:40:14", "a", 2, 2) + ) + ) + } + + test("session window groupBy with multiple keys statement - two distinct") { + val df = Seq( + ("2016-03-27 19:39:34", 1, 2, "a"), + ("2016-03-27 19:39:39", 1, 2, 
"a"), + ("2016-03-27 19:39:56", 2, 4, "a"), + ("2016-03-27 19:40:04", 2, 4, "a"), + ("2016-03-27 19:39:27", 4, 8, "b")).toDF("time", "value", "value2", "id") + + // session window handles sort while applying group by + // whereas time window doesn't + + // expected sessions + // key "a" => (19:39:34 ~ 19:39:49) (19:39:56 ~ 19:40:14) + // key "b" => (19:39:27 ~ 19:39:37) + + checkAnswer( + df.groupBy(session_window($"time", "10 seconds"), 'id) + .agg(sum_distinct(col("value")).as("sum"), sum_distinct(col("value2")).as("sum2")) + .orderBy($"session_window.start".asc) + .selectExpr("CAST(session_window.start AS STRING)", "CAST(session_window.end AS STRING)", + "id", "sum", "sum2"), + Seq( + Row("2016-03-27 19:39:27", "2016-03-27 19:39:37", "b", 4, 8), + Row("2016-03-27 19:39:34", "2016-03-27 19:39:49", "a", 1, 2), + Row("2016-03-27 19:39:56", "2016-03-27 19:40:14", "a", 2, 4) + ) + ) + } + + test("session window groupBy with multiple keys statement - keys overlapped with sessions") { + val df = Seq( + ("2016-03-27 19:39:34", 1, "a"), + ("2016-03-27 19:39:39", 1, "b"), + ("2016-03-27 19:39:40", 2, "a"), + ("2016-03-27 19:39:45", 2, "b"), + ("2016-03-27 19:39:27", 4, "b")).toDF("time", "value", "id") + + // session window handles sort while applying group by + // whereas time window doesn't + + // expected sessions + // a => (19:39:34 ~ 19:39:50) + // b => (19:39:27 ~ 19:39:37), (19:39:39 ~ 19:39:55) + + checkAnswer( + df.groupBy(session_window($"time", "10 seconds"), 'id) + .agg(count("*").as("counts"), sum("value").as("sum")) + .orderBy($"session_window.start".asc) + .selectExpr("CAST(session_window.start AS STRING)", "CAST(session_window.end AS STRING)", + "id", "counts", "sum"), + + Seq( + Row("2016-03-27 19:39:27", "2016-03-27 19:39:37", "b", 1, 4), + Row("2016-03-27 19:39:34", "2016-03-27 19:39:50", "a", 2, 3), + Row("2016-03-27 19:39:39", "2016-03-27 19:39:55", "b", 2, 3) + ) + ) + } + + test("session window with multi-column projection") { + val df = Seq( + ("2016-03-27 19:39:34", 1, "a"), + ("2016-03-27 19:39:56", 2, "a"), + ("2016-03-27 19:39:27", 4, "b")).toDF("time", "value", "id") + .select(session_window($"time", "10 seconds"), $"value") + .orderBy($"session_window.start".asc) + .select($"session_window.start".cast("string"), $"session_window.end".cast("string"), + $"value") + + val expands = df.queryExecution.optimizedPlan.find(_.isInstanceOf[Expand]) + assert(expands.isEmpty, "Session windows shouldn't require expand") + + checkAnswer( + df, + Seq( + Row("2016-03-27 19:39:27", "2016-03-27 19:39:37", 4), + Row("2016-03-27 19:39:34", "2016-03-27 19:39:44", 1), + Row("2016-03-27 19:39:56", "2016-03-27 19:40:06", 2) + ) + ) + } + + test("session window combined with explode expression") { + val df = Seq( + ("2016-03-27 19:39:34", 1, Seq("a", "b")), + ("2016-03-27 19:39:56", 2, Seq("a", "c", "d"))).toDF("time", "value", "ids") + + checkAnswer( + df.select(session_window($"time", "10 seconds"), $"value", explode($"ids")) + .orderBy($"session_window.start".asc).select("value"), + // first window exploded to two rows for "a", and "b", second window exploded to 3 rows + Seq(Row(1), Row(1), Row(2), Row(2), Row(2)) + ) + } + + test("null timestamps") { + val df = Seq( + ("2016-03-27 09:00:05", 1), + ("2016-03-27 09:00:32", 2), + (null, 3), + (null, 4)).toDF("time", "value") + + checkDataset( + df.select(session_window($"time", "10 seconds"), $"value") + .orderBy($"session_window.start".asc) + .select("value") + .as[Int], + 1, 2) // null columns are dropped + } + + // NOTE: unlike time 
window, joining session windows without grouping + // doesn't merge sessions, so two rows will be joined only if their session ranges are exactly the same + + test("multiple session windows in a single operator throws nice exception") { + val df = Seq( + ("2016-03-27 09:00:02", 3), + ("2016-03-27 09:00:35", 6)).toDF("time", "value") + val e = intercept[AnalysisException] { + df.select(session_window($"time", "10 second"), session_window($"time", "15 second")) + .collect() + } + assert(e.getMessage.contains( + "Multiple time/session window expressions would result in a cartesian product")) + } + + test("aliased session windows") { + val df = Seq( + ("2016-03-27 19:39:34", 1, Seq("a", "b")), + ("2016-03-27 19:39:56", 2, Seq("a", "c", "d"))).toDF("time", "value", "ids") + + checkAnswer( + df.select(session_window($"time", "10 seconds").as("session_window"), $"value") + .orderBy($"session_window.start".asc) + .select("value"), + Seq(Row(1), Row(2)) + ) + } + + private def withTempTable(f: String => Unit): Unit = { + val tableName = "temp" + Seq( + ("2016-03-27 19:39:34", 1), + ("2016-03-27 19:39:56", 2), + ("2016-03-27 19:39:27", 4)).toDF("time", "value").createOrReplaceTempView(tableName) + try { + f(tableName) + } finally { + spark.catalog.dropTempView(tableName) + } + } + + test("session window in SQL with single string expression") { + withTempTable { table => + checkAnswer( + spark.sql(s"""select session_window(time, "10 seconds"), value from $table""") + .select($"session_window.start".cast(StringType), $"session_window.end".cast(StringType), + $"value"), + Seq( + Row("2016-03-27 19:39:27", "2016-03-27 19:39:37", 4), + Row("2016-03-27 19:39:34", "2016-03-27 19:39:44", 1), + Row("2016-03-27 19:39:56", "2016-03-27 19:40:06", 2) + ) + ) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala index 4fdaeb57ad50..2ef43dcf562c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala @@ -239,7 +239,7 @@ class DataFrameTimeWindowingSuite extends QueryTest with SharedSparkSession { df.select(window($"time", "10 second"), window($"time", "15 second")).collect() } assert(e.getMessage.contains( - "Multiple time window expressions would result in a cartesian product")) + "Multiple time/session window expressions would result in a cartesian product")) } test("aliased windows") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index b0d5c8932bcb..1e23c115ff24 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -136,7 +136,8 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark test("SPARK-14415: All functions should have own descriptions") { for (f <- spark.sessionState.functionRegistry.listFunction()) { - if (!Seq("cube", "grouping", "grouping_id", "rollup", "window").contains(f.unquotedString)) { + if (!Seq("cube", "grouping", "grouping_id", "rollup", "window", + "session_window").contains(f.unquotedString)) { checkKeywordsNotExist(sql(s"describe function $f"), "N/A.") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/UpdatingSessionsIteratorSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/UpdatingSessionsIteratorSuite.scala index 2a4245d3b363..045901bc20ca 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/UpdatingSessionsIteratorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/UpdatingSessionsIteratorSuite.scala @@ -199,9 +199,9 @@ class UpdatingSessionsIteratorSuite extends SharedSparkSession { val row6 = createRow("a", 2, 115, 125, 20, 1.2) val rows3 = List(row5, row6) + // This is to test the edge case that the last input row creates a new session. val row7 = createRow("a", 2, 127, 137, 30, 1.3) - val row8 = createRow("a", 2, 135, 145, 40, 1.4) - val rows4 = List(row7, row8) + val rows4 = List(row7) val rowsAll = rows1 ++ rows2 ++ rows3 ++ rows4 @@ -244,8 +244,8 @@ class UpdatingSessionsIteratorSuite extends SharedSparkSession { } retRows4.zip(rows4).foreach { case (retRow, expectedRow) => - // session being expanded to (127 ~ 145) - assertRowsEqualsWithNewSession(expectedRow, retRow, 127, 145) + // session being expanded to (127 ~ 137) + assertRowsEqualsWithNewSession(expectedRow, retRow, 127, 137) } assert(iterator.hasNext === false) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala index 30ee97a89b4e..08e21d537122 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala @@ -133,6 +133,7 @@ class ExpressionInfoSuite extends SparkFunSuite with SharedSparkSession { val ignoreSet = Set( // Explicitly inherits NonSQLExpression, and has no ExpressionDescription "org.apache.spark.sql.catalyst.expressions.TimeWindow", + "org.apache.spark.sql.catalyst.expressions.SessionWindow", // Cast aliases do not need examples "org.apache.spark.sql.catalyst.expressions.Cast") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingSessionWindowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingSessionWindowSuite.scala new file mode 100644 index 000000000000..a381d069df2d --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingSessionWindowSuite.scala @@ -0,0 +1,460 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.streaming + +import java.util.Locale + +import org.scalatest.BeforeAndAfter +import org.scalatest.matchers.must.Matchers + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.execution.streaming.MemoryStream +import org.apache.spark.sql.execution.streaming.state.{HDFSBackedStateStoreProvider, RocksDBStateStoreProvider} +import org.apache.spark.sql.functions.{count, session_window, sum} +import org.apache.spark.sql.internal.SQLConf + +class StreamingSessionWindowSuite extends StreamTest + with BeforeAndAfter with Matchers with Logging { + + import testImplicits._ + + after { + sqlContext.streams.active.foreach(_.stop()) + } + + def testWithAllOptions(name: String, confPairs: (String, String)*) + (func: => Any): Unit = { + val mergingSessionOptions = Seq(true, false).map { value => + (SQLConf.STREAMING_SESSION_WINDOW_MERGE_SESSIONS_IN_LOCAL_PARTITION.key, value) + } + val providerOptions = Seq( + classOf[HDFSBackedStateStoreProvider].getCanonicalName, + classOf[RocksDBStateStoreProvider].getCanonicalName + ).map { value => + (SQLConf.STATE_STORE_PROVIDER_CLASS.key, value.stripSuffix("$")) + } + + val availableOptions = for ( + opt1 <- mergingSessionOptions; + opt2 <- providerOptions + ) yield (opt1, opt2) + + for (option <- availableOptions) { + test(s"$name - merging sessions in local partition: ${option._1._2} / " + + s"provider: ${option._2._2}") { + withSQLConf(confPairs ++ + Seq( + option._1._1 -> option._1._2.toString, + option._2._1 -> option._2._2): _*) { + func + } + } + } + } + + testWithAllOptions("complete mode - session window") { + // Implements StructuredSessionization.scala leveraging "session" function + // as a test, to verify the sessionization works with simple example + + // note that complete mode doesn't honor watermark: even it is specified, watermark will be + // always Unix timestamp 0 + + val inputData = MemoryStream[(String, Long)] + + // Split the lines into words, treat words as sessionId of events + val events = inputData.toDF() + .select($"_1".as("value"), $"_2".as("timestamp")) + .withColumn("eventTime", $"timestamp".cast("timestamp")) + .selectExpr("explode(split(value, ' ')) AS sessionId", "eventTime") + + val sessionUpdates = events + .groupBy(session_window($"eventTime", "10 seconds") as 'session, 'sessionId) + .agg(count("*").as("numEvents")) + .selectExpr("sessionId", "CAST(session.start AS LONG)", "CAST(session.end AS LONG)", + "CAST(session.end AS LONG) - CAST(session.start AS LONG) AS durationMs", + "numEvents") + + sessionUpdates.explain() + + testStream(sessionUpdates, OutputMode.Complete())( + AddData(inputData, + ("hello world spark streaming", 40L), + ("world hello structured streaming", 41L) + ), + CheckNewAnswer( + ("hello", 40, 51, 11, 2), + ("world", 40, 51, 11, 2), + ("streaming", 40, 51, 11, 2), + ("spark", 40, 50, 10, 1), + ("structured", 41, 51, 10, 1) + ), + + // placing new sessions "before" previous sessions + AddData(inputData, ("spark streaming", 25L)), + CheckNewAnswer( + ("spark", 25, 35, 10, 1), + ("streaming", 25, 35, 10, 1), + ("hello", 40, 51, 11, 2), + ("world", 40, 51, 11, 2), + ("streaming", 40, 51, 11, 2), + ("spark", 40, 50, 10, 1), + ("structured", 41, 51, 10, 1) + ), + + // concatenating multiple previous sessions into one + AddData(inputData, ("spark streaming", 30L)), + CheckNewAnswer( + ("spark", 25, 50, 25, 3), + ("streaming", 25, 51, 26, 4), + ("hello", 40, 51, 11, 2), + ("world", 40, 51, 11, 2), + ("structured", 41, 51, 10, 1) + ), + + // placing new sessions 
after previous sessions + AddData(inputData, ("hello apache spark", 60L)), + CheckNewAnswer( + ("spark", 25, 50, 25, 3), + ("streaming", 25, 51, 26, 4), + ("hello", 40, 51, 11, 2), + ("world", 40, 51, 11, 2), + ("structured", 41, 51, 10, 1), + ("hello", 60, 70, 10, 1), + ("apache", 60, 70, 10, 1), + ("spark", 60, 70, 10, 1) + ), + + AddData(inputData, ("structured streaming", 90L)), + CheckNewAnswer( + ("spark", 25, 50, 25, 3), + ("streaming", 25, 51, 26, 4), + ("hello", 40, 51, 11, 2), + ("world", 40, 51, 11, 2), + ("structured", 41, 51, 10, 1), + ("hello", 60, 70, 10, 1), + ("apache", 60, 70, 10, 1), + ("spark", 60, 70, 10, 1), + ("structured", 90, 100, 10, 1), + ("streaming", 90, 100, 10, 1) + ) + ) + } + + testWithAllOptions("complete mode - session window - no key") { + // complete mode doesn't honor watermark: even if it is specified, the watermark will + // always be Unix timestamp 0 + + val inputData = MemoryStream[Int] + + val windowedAggregation = inputData.toDF() + .selectExpr("*") + .withColumn("eventTime", $"value".cast("timestamp")) + .groupBy(session_window($"eventTime", "5 seconds") as 'session) + .agg(count("*") as 'count, sum("value") as 'sum) + .select($"session".getField("start").cast("long").as[Long], + $"session".getField("end").cast("long").as[Long], $"count".as[Long], $"sum".as[Long]) + + val e = intercept[StreamingQueryException] { + testStream(windowedAggregation, OutputMode.Complete())( + AddData(inputData, 40), + CheckAnswer() // this is just to trigger the exception + ) + } + Seq("Global aggregation with session window", "not supported").foreach { m => + assert(e.getMessage.toLowerCase(Locale.ROOT).contains(m.toLowerCase(Locale.ROOT))) + } + } + + testWithAllOptions("append mode - session window") { + // Implements StructuredSessionization.scala leveraging "session" function + // as a test, to verify that the sessionization works with a simple example + + val inputData = MemoryStream[(String, Long)] + + // Split the lines into words, treat words as sessionId of events + val events = inputData.toDF() + .select($"_1".as("value"), $"_2".as("timestamp")) + .withColumn("eventTime", $"timestamp".cast("timestamp")) + .selectExpr("explode(split(value, ' ')) AS sessionId", "eventTime") + .withWatermark("eventTime", "30 seconds") + + val sessionUpdates = events + .groupBy(session_window($"eventTime", "10 seconds") as 'session, 'sessionId) + .agg(count("*").as("numEvents")) + .selectExpr("sessionId", "CAST(session.start AS LONG)", "CAST(session.end AS LONG)", + "CAST(session.end AS LONG) - CAST(session.start AS LONG) AS durationMs", + "numEvents") + + testStream(sessionUpdates, OutputMode.Append())( + AddData(inputData, + ("hello world spark streaming", 40L), + ("world hello structured streaming", 41L) + ), + + // watermark: 11 + // current sessions + // ("hello", 40, 51, 11, 2), + // ("world", 40, 51, 11, 2), + // ("streaming", 40, 51, 11, 2), + // ("spark", 40, 50, 10, 1), + // ("structured", 41, 51, 10, 1) + CheckNewAnswer( + ), + + // placing new sessions "before" previous sessions + AddData(inputData, ("spark streaming", 25L)), + // watermark: 11 + // current sessions + // ("spark", 25, 35, 10, 1), + // ("streaming", 25, 35, 10, 1), + // ("hello", 40, 51, 11, 2), + // ("world", 40, 51, 11, 2), + // ("streaming", 40, 51, 11, 2), + // ("spark", 40, 50, 10, 1), + // ("structured", 41, 51, 10, 1) + CheckNewAnswer( + ), + + // late event whose session end (10) is earlier than watermark (11): should be dropped + AddData(inputData, ("spark streaming", 0L)), + // watermark: 11 + //
current sessions + // ("spark", 25, 35, 10, 1), + // ("streaming", 25, 35, 10, 1), + // ("hello", 40, 51, 11, 2), + // ("world", 40, 51, 11, 2), + // ("streaming", 40, 51, 11, 2), + // ("spark", 40, 50, 10, 1), + // ("structured", 41, 51, 10, 1) + CheckNewAnswer( + ), + + // concatenating multiple previous sessions into one + AddData(inputData, ("spark streaming", 30L)), + // watermark: 11 + // current sessions + // ("spark", 25, 50, 25, 3), + // ("streaming", 25, 51, 26, 4), + // ("hello", 40, 51, 11, 2), + // ("world", 40, 51, 11, 2), + // ("structured", 41, 51, 10, 1) + CheckNewAnswer( + ), + + // placing new sessions after previous sessions + AddData(inputData, ("hello apache spark", 60L)), + // watermark: 30 + // current sessions + // ("spark", 25, 50, 25, 3), + // ("streaming", 25, 51, 26, 4), + // ("hello", 40, 51, 11, 2), + // ("world", 40, 51, 11, 2), + // ("structured", 41, 51, 10, 1), + // ("hello", 60, 70, 10, 1), + // ("apache", 60, 70, 10, 1), + // ("spark", 60, 70, 10, 1) + CheckNewAnswer( + ), + + AddData(inputData, ("structured streaming", 90L)), + // watermark: 60 + // current sessions + // ("hello", 60, 70, 10, 1), + // ("apache", 60, 70, 10, 1), + // ("spark", 60, 70, 10, 1), + // ("structured", 90, 100, 10, 1), + // ("streaming", 90, 100, 10, 1) + CheckNewAnswer( + ("spark", 25, 50, 25, 3), + ("streaming", 25, 51, 26, 4), + ("hello", 40, 51, 11, 2), + ("world", 40, 51, 11, 2), + ("structured", 41, 51, 10, 1) + ) + ) + } + + testWithAllOptions("append mode - session window - no key") { + val inputData = MemoryStream[Int] + + val windowedAggregation = inputData.toDF() + .selectExpr("*") + .withColumn("eventTime", $"value".cast("timestamp")) + .withWatermark("eventTime", "10 seconds") + .groupBy(session_window($"eventTime", "5 seconds") as 'session) + .agg(count("*") as 'count, sum("value") as 'sum) + .select($"session".getField("start").cast("long").as[Long], + $"session".getField("end").cast("long").as[Long], $"count".as[Long], $"sum".as[Long]) + + val e = intercept[StreamingQueryException] { + testStream(windowedAggregation)( + AddData(inputData, 40), + CheckAnswer() // this is just to trigger the exception + ) + } + Seq("Global aggregation with session window", "not supported").foreach { m => + assert(e.getMessage.toLowerCase(Locale.ROOT).contains(m.toLowerCase(Locale.ROOT))) + } + } + + testWithAllOptions("update mode - session window") { + // Implements StructuredSessionization.scala leveraging "session" function + // as a test, to verify the sessionization works with simple example + + val inputData = MemoryStream[(String, Long)] + + // Split the lines into words, treat words as sessionId of events + val events = inputData.toDF() + .select($"_1".as("value"), $"_2".as("timestamp")) + .withColumn("eventTime", $"timestamp".cast("timestamp")) + .selectExpr("explode(split(value, ' ')) AS sessionId", "eventTime") + .withWatermark("eventTime", "10 seconds") + + val sessionUpdates = events + .groupBy(session_window($"eventTime", "10 seconds") as 'session, 'sessionId) + .agg(count("*").as("numEvents")) + .selectExpr("sessionId", "CAST(session.start AS LONG)", "CAST(session.end AS LONG)", + "CAST(session.end AS LONG) - CAST(session.start AS LONG) AS durationMs", + "numEvents") + + testStream(sessionUpdates, OutputMode.Update())( + AddData(inputData, + ("hello world spark streaming", 40L), + ("world hello structured streaming", 41L) + ), + // watermark: 11 + // current sessions + // ("hello", 40, 51, 11, 2), + // ("world", 40, 51, 11, 2), + // ("streaming", 40, 51, 11, 2), + 
// ("spark", 40, 50, 10, 1), + // ("structured", 41, 51, 10, 1) + CheckNewAnswer( + ("hello", 40, 51, 11, 2), + ("world", 40, 51, 11, 2), + ("streaming", 40, 51, 11, 2), + ("spark", 40, 50, 10, 1), + ("structured", 41, 51, 10, 1) + ), + + // placing new sessions "before" previous sessions + AddData(inputData, ("spark streaming", 25L)), + // watermark: 11 + // current sessions + // ("spark", 25, 35, 10, 1), + // ("streaming", 25, 35, 10, 1), + // ("hello", 40, 51, 11, 2), + // ("world", 40, 51, 11, 2), + // ("streaming", 40, 51, 11, 2), + // ("spark", 40, 50, 10, 1), + // ("structured", 41, 51, 10, 1) + CheckNewAnswer( + ("spark", 25, 35, 10, 1), + ("streaming", 25, 35, 10, 1) + ), + + // late event which session's end 10 would be later than watermark 11: should be dropped + AddData(inputData, ("spark streaming", 0L)), + // watermark: 11 + // current sessions + // ("spark", 25, 35, 10, 1), + // ("streaming", 25, 35, 10, 1), + // ("hello", 40, 51, 11, 2), + // ("world", 40, 51, 11, 2), + // ("streaming", 40, 51, 11, 2), + // ("spark", 40, 50, 10, 1), + // ("structured", 41, 51, 10, 1) + CheckNewAnswer( + ), + + // concatenating multiple previous sessions into one + AddData(inputData, ("spark streaming", 30L)), + // watermark: 11 + // current sessions + // ("spark", 25, 50, 25, 3), + // ("streaming", 25, 51, 26, 4), + // ("hello", 40, 51, 11, 2), + // ("world", 40, 51, 11, 2), + // ("structured", 41, 51, 10, 1) + CheckNewAnswer( + ("spark", 25, 50, 25, 3), + ("streaming", 25, 51, 26, 4) + ), + + // placing new sessions after previous sessions + AddData(inputData, ("hello apache spark", 60L)), + // watermark: 30 + // current sessions + // ("spark", 25, 50, 25, 3), + // ("streaming", 25, 51, 26, 4), + // ("hello", 40, 51, 11, 2), + // ("world", 40, 51, 11, 2), + // ("structured", 41, 51, 10, 1), + // ("hello", 60, 70, 10, 1), + // ("apache", 60, 70, 10, 1), + // ("spark", 60, 70, 10, 1) + CheckNewAnswer( + ("hello", 60, 70, 10, 1), + ("apache", 60, 70, 10, 1), + ("spark", 60, 70, 10, 1) + ), + + AddData(inputData, ("structured streaming", 90L)), + // watermark: 60 + // current sessions + // ("hello", 60, 70, 10, 1), + // ("apache", 60, 70, 10, 1), + // ("spark", 60, 70, 10, 1), + // ("structured", 90, 100, 10, 1), + // ("streaming", 90, 100, 10, 1) + // evicted + // ("spark", 25, 50, 25, 3), + // ("streaming", 25, 51, 26, 4), + // ("hello", 40, 51, 11, 2), + // ("world", 40, 51, 11, 2), + // ("structured", 41, 51, 10, 1) + CheckNewAnswer( + ("structured", 90, 100, 10, 1), + ("streaming", 90, 100, 10, 1) + ) + ) + } + + testWithAllOptions("update mode - session window - no key") { + val inputData = MemoryStream[Int] + + val windowedAggregation = inputData.toDF() + .selectExpr("*") + .withColumn("eventTime", $"value".cast("timestamp")) + .withWatermark("eventTime", "10 seconds") + .groupBy(session_window($"eventTime", "5 seconds") as 'session) + .agg(count("*") as 'count, sum("value") as 'sum) + .select($"session".getField("start").cast("long").as[Long], + $"session".getField("end").cast("long").as[Long], $"count".as[Long], $"sum".as[Long]) + + val e = intercept[StreamingQueryException] { + testStream(windowedAggregation, OutputMode.Update())( + AddData(inputData, 40), + CheckAnswer() // this is just to trigger the exception + ) + } + Seq("Global aggregation with session window", "not supported").foreach { m => + assert(e.getMessage.toLowerCase(Locale.ROOT).contains(m.toLowerCase(Locale.ROOT))) + } + } +}