diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 5425d311f8c7..19c85b8d3c94 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1397,6 +1397,23 @@ def check_string_field(field, fieldName): return Column(res) +@since(3.0) +@ignore_unicode_prefix +def session_window(timeColumn, gapDuration): + """ + # FIXME: python doc!! + """ + def check_string_field(field, fieldName): + if not field or type(field) is not str: + raise TypeError("%s should be provided as a string" % fieldName) + + sc = SparkContext._active_spark_context + time_col = _to_java_column(timeColumn) + check_string_field(gapDuration, "gapDuration") + res = sc._jvm.functions.session_window(time_col, gapDuration) + return Column(res) + + # ---------------------------- misc functions ---------------------------------- @since(1.5) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index d72e512e0df5..94e4a2fa00ff 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -181,6 +181,7 @@ class Analyzer( GlobalAggregates :: ResolveAggregateFunctions :: TimeWindowing :: + SessionWindowing :: ResolveInlineTables(conf) :: ResolveHigherOrderFunctions(catalog) :: ResolveLambdaVariables(conf) :: @@ -2643,9 +2644,13 @@ object TimeWindowing extends Rule[LogicalPlan] { val windowExpressions = p.expressions.flatMap(_.collect { case t: TimeWindow => t }).toSet - val numWindowExpr = windowExpressions.size + val numWindowExpr = p.expressions.flatMap(_.collect { + case s: SessionWindow => s + case t: TimeWindow => t + }).toSet.size + // Only support a single window expression for now - if (numWindowExpr == 1 && + if (numWindowExpr == 1 && windowExpressions.nonEmpty && windowExpressions.head.timeColumn.resolved && windowExpressions.head.checkInputDataTypes().isSuccess) { @@ -2713,8 +2718,80 @@ object TimeWindowing extends Rule[LogicalPlan] { renamedPlan.withNewChildren(substitutedPlan :: Nil) } } else if (numWindowExpr > 1) { - p.failAnalysis("Multiple time window expressions would result in a cartesian product " + - "of rows, therefore they are currently not supported.") + p.failAnalysis("Multiple time/session window expressions would result in a cartesian " + + "product of rows, therefore they are currently not supported.") + } else { + p // Return unchanged. 
Analyzer will throw exception later + } + } +} + +// FIXME: javadoc +object SessionWindowing extends Rule[LogicalPlan] { + import org.apache.spark.sql.catalyst.dsl.expressions._ + + private final val SESSION_COL_NAME = "session_window" + private final val SESSION_START = "start" + private final val SESSION_END = "end" + + // FIXME: javadoc + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { + case p: LogicalPlan if p.children.size == 1 => + val child = p.children.head + val sessionExpressions = + p.expressions.flatMap(_.collect { case s: SessionWindow => s }).toSet + + val numWindowExpr = p.expressions.flatMap(_.collect { + case s: SessionWindow => s + case t: TimeWindow => t + }).toSet.size + + // Only support a single session expression for now + if (numWindowExpr == 1 && sessionExpressions.nonEmpty && + sessionExpressions.head.timeColumn.resolved && + sessionExpressions.head.checkInputDataTypes().isSuccess) { + + val session = sessionExpressions.head + + val metadata = session.timeColumn match { + case a: Attribute => a.metadata + case _ => Metadata.empty + } + + val newMetadata = new MetadataBuilder() + .withMetadata(metadata) + .putBoolean(SessionWindow.marker, true) + .build() + + val sessionAttr = AttributeReference( + SESSION_COL_NAME, session.dataType, metadata = newMetadata)() + + val sessionStart = PreciseTimestampConversion(session.timeColumn, TimestampType, LongType) + val sessionEnd = sessionStart + session.gapDuration + + val literalSessionStruct = CreateNamedStruct( + Literal(SESSION_START) :: + PreciseTimestampConversion(sessionStart, LongType, TimestampType) :: + Literal(SESSION_END) :: + PreciseTimestampConversion(sessionEnd, LongType, TimestampType) :: + Nil) + + val sessionStruct = Alias(literalSessionStruct, SESSION_COL_NAME)( + exprId = sessionAttr.exprId, explicitMetadata = Some(newMetadata)) + + val replacedPlan = p transformExpressions { + case s: SessionWindow => sessionAttr + } + + // For backwards compatibility we add a filter to filter out nulls + val filterExpr = IsNotNull(session.timeColumn) + + replacedPlan.withNewChildren( + Filter(filterExpr, + Project(sessionStruct +: child.output, child)) :: Nil) + } else if (numWindowExpr > 1) { + p.failAnalysis("Multiple time/session window expressions would result in a " + + "cartesian product of rows, therefore they are currently not supported.") } else { p // Return unchanged. 
Analyzer will throw exception later } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 7dafebff7987..620b2b396ce4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -398,6 +398,7 @@ object FunctionRegistry { expression[WeekOfYear]("weekofyear"), expression[Year]("year"), expression[TimeWindow]("window"), + expression[SessionWindow]("session_window"), // collection functions expression[CreateArray]("array"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SessionWindow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SessionWindow.scala new file mode 100644 index 000000000000..54c516a93c18 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SessionWindow.scala @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.commons.lang3.StringUtils + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.CalendarInterval + +case class SessionWindow(timeColumn: Expression, gapDuration: Long) extends UnaryExpression + with ImplicitCastInputTypes + with Unevaluable + with NonSQLExpression { + + ////////////////////////// + // SQL Constructors + ////////////////////////// + + def this(timeColumn: Expression, gapDuration: Expression) = { + this(timeColumn, SessionWindow.parseExpression(gapDuration)) + } + + override def child: Expression = timeColumn + override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType) + override def dataType: DataType = new StructType() + .add(StructField("start", TimestampType)) + .add(StructField("end", TimestampType)) + + // This expression is replaced in the analyzer. + override lazy val resolved = false + + /** Validate the inputs for the gap duration in addition to the input data type. */ + override def checkInputDataTypes(): TypeCheckResult = { + val dataTypeCheck = super.checkInputDataTypes() + if (dataTypeCheck.isSuccess) { + if (gapDuration <= 0) { + return TypeCheckFailure(s"The window duration ($gapDuration) must be greater than 0.") + } + } + dataTypeCheck + } +} + +object SessionWindow { + val marker = "spark.sessionWindow" + + /** + * Parses the interval string for a valid time duration. 
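+   * For example, a gap duration given as "5 seconds" parses to 5,000,000 microseconds.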
CalendarInterval expects interval + * strings to start with the string `interval`. For usability, we prepend `interval` to the string + * if the user omitted it. + * + * @param interval The interval string + * @return The interval duration in microseconds. SparkSQL casts TimestampType has microsecond + * precision. + */ + private def getIntervalInMicroSeconds(interval: String): Long = { + if (StringUtils.isBlank(interval)) { + throw new IllegalArgumentException( + "The window duration, slide duration and start time cannot be null or blank.") + } + val intervalString = if (interval.startsWith("interval")) { + interval + } else { + "interval " + interval + } + val cal = CalendarInterval.fromString(intervalString) + if (cal == null) { + throw new IllegalArgumentException( + s"The provided interval ($interval) did not correspond to a valid interval string.") + } + if (cal.months > 0) { + throw new IllegalArgumentException( + s"Intervals greater than a month is not supported ($interval).") + } + cal.microseconds + } + + /** + * Parses the duration expression to generate the long value for the original constructor so + * that we can use `window` in SQL. + */ + private def parseExpression(expr: Expression): Long = expr match { + case NonNullLiteral(s, StringType) => getIntervalInMicroSeconds(s.toString) + case IntegerLiteral(i) => i.toLong + case NonNullLiteral(l, LongType) => l.toString.toLong + case _ => throw new AnalysisException("The duration and time inputs to window must be " + + "an integer, long or string literal.") + } + + def apply( + timeColumn: Expression, + gapDuration: String): SessionWindow = { + SessionWindow(timeColumn, + getIntervalInMicroSeconds(gapDuration)) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index d5857e060a2c..d47c161ae7d0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -1229,6 +1229,10 @@ object CodeGenerator extends Logging { // bytecode instruction final val MUTABLESTATEARRAY_SIZE_LIMIT = 32768 + // This is the threshold to print out debug information when code generation takes more + // than this value. + final val SLOW_CODEGEN_MILLIS_THRESHOLD = 100 + /** * Compile the Java source code into a Java class, using Janino. * @@ -1375,6 +1379,27 @@ object CodeGenerator extends Logging { CodegenMetrics.METRIC_SOURCE_CODE_SIZE.update(code.body.length) CodegenMetrics.METRIC_COMPILATION_TIME.update(timeMs.toLong) logInfo(s"Code generated in $timeMs ms") + + if (timeMs > SLOW_CODEGEN_MILLIS_THRESHOLD) { + logWarning(s"Code generation took more than $SLOW_CODEGEN_MILLIS_THRESHOLD ms." + + "Please set logger level to DEBUG to see further debug information.") + + logDebug(s"Printing out debug information - body: ${code.body}... 
/ " + + s"comment: ${code.comment}") + + def getRelevantStackTraceForDebug(): Array[StackTraceElement] = { + Thread.currentThread().getStackTrace.drop(1) + .filterNot { p => + p.getClassName.startsWith("com.google.common") || + p.getClassName.startsWith("org.apache.spark.sql.catalyst") || + p.getClassName.startsWith("org.apache.spark.rdd") + } + } + + logDebug(s"Stack trace - " + + s"${getRelevantStackTraceForDebug().take(30).map(_.toString).mkString("\n")}") + } + result } }) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index b699707d8523..88c513a3e46b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -915,6 +915,15 @@ object SQLConf { .checkValue(v => Set(1, 2).contains(v), "Valid versions are 1 and 2") .createWithDefault(2) + val STREAMING_SESSION_WINDOW_MERGE_SESSIONS_IN_LOCAL_PARTITION = + buildConf("spark.sql.streaming.sessionWindow.merge.sessions.in.local.partition") + .internal() + .doc("When true, streaming session window sorts and merge sessions in local partition " + + "prior to shuffle. This is to reduce the rows to shuffle, but only beneficial when " + + "there're lots of rows in a batch being assigned to same sessions.") + .booleanConf + .createWithDefault(false) + val UNSUPPORTED_OPERATION_CHECK_ENABLED = buildConf("spark.sql.streaming.unsupportedOperationCheck") .internal() @@ -1750,6 +1759,9 @@ class SQLConf extends Serializable with Logging { def fastHashAggregateRowMaxCapacityBit: Int = getConf(FAST_HASH_AGGREGATE_MAX_ROWS_CAPACITY_BIT) + def streamingSessionWindowMergeSessionInLocalPartition: Boolean = + getConf(STREAMING_SESSION_WINDOW_MERGE_SESSIONS_IN_LOCAL_PARTITION) + /** * Returns the [[Resolver]] for the current configuration, which can be used to determine if two * identifiers are equal. 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index dbc6db62bd82..2cc8ddaf9979 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -329,14 +329,31 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { "Streaming aggregation doesn't support group aggregate pandas UDF") } - val stateVersion = conf.getConf(SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION) - - aggregate.AggUtils.planStreamingAggregation( - namedGroupingExpressions, - aggregateExpressions.map(expr => expr.asInstanceOf[AggregateExpression]), - rewrittenResultExpressions, - stateVersion, - planLater(child)) + val sessionWindowOption = namedGroupingExpressions.find { p => + p.metadata.contains(SessionWindow.marker) + } + + sessionWindowOption match { + case Some(sessionWindow) => + + aggregate.AggUtils.planStreamingAggregationForSession( + namedGroupingExpressions, + sessionWindow, + aggregateExpressions.map(expr => expr.asInstanceOf[AggregateExpression]), + rewrittenResultExpressions, + conf.streamingSessionWindowMergeSessionInLocalPartition, + planLater(child)) + + case None => + val stateVersion = conf.getConf(SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION) + + aggregate.AggUtils.planStreamingAggregation( + namedGroupingExpressions, + aggregateExpressions.map(expr => expr.asInstanceOf[AggregateExpression]), + rewrittenResultExpressions, + stateVersion, + planLater(child)) + } case _ => Nil } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala index 6be88c463dbd..faca078041ef 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala @@ -19,9 +19,9 @@ package org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution} import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.execution.streaming.{StateStoreRestoreExec, StateStoreSaveExec} -import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.execution.streaming._ /** * Utility functions used by the query planner to convert our plan to new aggregation code path. @@ -98,6 +98,9 @@ object AggUtils { resultExpressions = partialResultExpressions, child = child) + val interExec: SparkPlan = mayAppendMergingSessionExec(groupingExpressions, + aggregateExpressions, partialAggregate) + // 2. Create an Aggregate Operator for final aggregations. 
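+    // Sketch of the intent, inferred from this patch: the final aggregate below now takes
+    // `interExec` as its child, so when a grouping expression carries the SessionWindow marker a
+    // MergingSessionsExec (see mayAppendMergingSessionExec at the bottom of this file) sits
+    // between the partial and final aggregates; otherwise interExec is just partialAggregate.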
val finalAggregateExpressions = aggregateExpressions.map(_.copy(mode = Final)) // The attributes of the final aggregation buffer, which is presented as input to the result @@ -111,7 +114,7 @@ object AggUtils { aggregateAttributes = finalAggregateAttributes, initialInputBufferOffset = groupingExpressions.length, resultExpressions = resultExpressions, - child = partialAggregate) + child = interExec) finalAggregate :: Nil } @@ -123,6 +126,8 @@ object AggUtils { resultExpressions: Seq[NamedExpression], child: SparkPlan): Seq[SparkPlan] = { + val maySessionChild = mayAppendUpdatingSessionExec(groupingExpressions, child) + // functionsWithDistinct is guaranteed to be non-empty. Even though it may contain more than one // DISTINCT aggregate function, all of those functions will have the same column expressions. // For example, it would be valid for functionsWithDistinct to be @@ -149,7 +154,7 @@ object AggUtils { aggregateAttributes = aggregateAttributes, resultExpressions = groupingAttributes ++ distinctAttributes ++ aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes), - child = child) + child = maySessionChild) } // 2. Create an Aggregate Operator for partial merge aggregations. @@ -204,6 +209,7 @@ object AggUtils { val partialAggregateResult = groupingAttributes ++ mergeAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes) ++ distinctAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes) + createAggregate( groupingExpressions = groupingAttributes, aggregateExpressions = mergeAggregateExpressions ++ distinctAggregateExpressions, @@ -338,4 +344,188 @@ object AggUtils { finalAndCompleteAggregate :: Nil } + + /** + * Plans a streaming session aggregation using the following progression: + * + * - Partial Aggregation + * - all tuples will have aggregated columns with initial value + * - (If "spark.sql.streaming.sessionWindow.merge.sessions.in.local.partition" is enabled) + * - Sort within partition (sort: all keys) + * - MergingSessionExec + * - calculate session among tuples, and aggregate tuples in session with partial merge + * - Shuffle & Sort (distribution: keys "without" session, sort: all keys) + * - SessionWindowStateStoreRestore (group: keys "without" session) + * - merge input tuples with stored tuples (sessions) respecting sort order + * - MergingSessionExec + * - calculate session among tuples, and aggregate tuples in session with partial merge + * - NOTE: it leverages the fact that the output of SessionWindowStateStoreRestore is sorted + * - now there is at most 1 tuple per group, key with session + * - SessionWindowStateStoreSave (group: keys "without" session) + * - saves tuple(s) for the next batch (multiple sessions could co-exist at the same time) + * - Complete (output the current result of the aggregation) + */ + def planStreamingAggregationForSession( + groupingExpressions: Seq[NamedExpression], + sessionExpression: NamedExpression, + functionsWithoutDistinct: Seq[AggregateExpression], + resultExpressions: Seq[NamedExpression], + mergeSessionsInLocalPartition: Boolean, + child: SparkPlan): Seq[SparkPlan] = { + + val groupWithoutSessionExpression = groupingExpressions.filterNot { p => + p.semanticEquals(sessionExpression) + } + + val groupingWithoutSessionAttributes = groupWithoutSessionExpression.map(_.toAttribute) + + val groupingAttributes = groupingExpressions.map(_.toAttribute) + + // we don't do partial aggregate here, because it requires additional shuffle + // and there will be less rows which have same 
session start + // here doing partial merge is to have aggregated columns with default value for each row + val partialAggregate: SparkPlan = { + val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Partial)) + val aggregateAttributes = aggregateExpressions.map(_.resultAttribute) + createAggregate( + groupingExpressions = groupingExpressions, + aggregateExpressions = aggregateExpressions, + aggregateAttributes = aggregateAttributes, + resultExpressions = groupingAttributes ++ + aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes), + child = child) + } + + val partialMerged1: SparkPlan = if (mergeSessionsInLocalPartition) { + val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge)) + val aggregateAttributes = aggregateExpressions.map(_.resultAttribute) + + // sort happens here to merge sessions on each partition + // this is to reduce amount of rows to shuffle + MergingSessionsExec( + requiredChildDistributionExpressions = None, + requiredChildDistributionOption = None, + groupingExpressions = groupingAttributes, + sessionExpression = sessionExpression, + aggregateExpressions = aggregateExpressions, + aggregateAttributes = aggregateAttributes, + initialInputBufferOffset = groupingAttributes.length, + resultExpressions = groupingAttributes ++ + aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes), + child = partialAggregate + ) + } else { + partialAggregate + } + + // shuffle & sort happens here: most of details are also handled in this physical plan + val restored = SessionWindowStateStoreRestoreExec(groupingWithoutSessionAttributes, + sessionExpression.toAttribute, stateInfo = None, eventTimeWatermark = None, + partialMerged1) + + val mergedSessions = { + val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge)) + val aggregateAttributes = aggregateExpressions.map(_.resultAttribute) + MergingSessionsExec( + requiredChildDistributionExpressions = None, + requiredChildDistributionOption = Some(restored.requiredChildDistribution), + groupingExpressions = groupingAttributes, + sessionExpression = sessionExpression, + aggregateExpressions = aggregateExpressions, + aggregateAttributes = aggregateAttributes, + initialInputBufferOffset = groupingAttributes.length, + resultExpressions = groupingAttributes ++ + aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes), + child = restored + ) + } + + // Note: stateId and returnAllStates are filled in later with preparation rules + // in IncrementalExecution. 
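+    // Rough illustration of the resulting plan (column names are only an example), for
+    //   df.groupBy($"user", session_window($"ts", "5 minutes")).count()
+    // built bottom-up: partial aggregate -> [sort + MergingSessionsExec, if local merge enabled]
+    //   -> SessionWindowStateStoreRestoreExec (shuffle & sort) -> MergingSessionsExec
+    //   -> SessionWindowStateStoreSaveExec -> final aggregate.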
+ val saved = + SessionWindowStateStoreSaveExec( + groupingWithoutSessionAttributes, + sessionExpression.toAttribute, + stateInfo = None, + outputMode = None, + eventTimeWatermark = None, + mergedSessions) + + val finalAndCompleteAggregate: SparkPlan = { + val finalAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Final)) + // The attributes of the final aggregation buffer, which is presented as input to the result + // projection: + val finalAggregateAttributes = finalAggregateExpressions.map(_.resultAttribute) + + createAggregate( + requiredChildDistributionExpressions = Some(groupingAttributes), + groupingExpressions = groupingAttributes, + aggregateExpressions = finalAggregateExpressions, + aggregateAttributes = finalAggregateAttributes, + initialInputBufferOffset = groupingAttributes.length, + resultExpressions = resultExpressions, + child = saved) + } + + finalAndCompleteAggregate :: Nil + } + + private def mayAppendUpdatingSessionExec(groupingExpressions: Seq[NamedExpression], + maybeChildPlan: SparkPlan): SparkPlan = + groupingExpressions.find(_.metadata.contains(SessionWindow.marker)) match { + case Some(sessionExpression) => + val groupWithoutSessionExpression = groupingExpressions.filterNot { + p => p.semanticEquals(sessionExpression) + } + + val groupingWithoutSessionAttributes = groupWithoutSessionExpression.map(_.toAttribute) + + val childDistribution = if (groupWithoutSessionExpression.isEmpty) { + AllTuples :: Nil + } else { + ClusteredDistribution(groupWithoutSessionExpression) :: Nil + } + val childOrdering = Seq((groupingWithoutSessionAttributes ++ Seq(sessionExpression)) + .map(SortOrder(_, Ascending))) + val updatedSession = UpdatingSessionExec( + groupingExpressions.map(_.toAttribute), + sessionExpression.toAttribute, + optRequiredChildDistribution = Some(childDistribution), + optRequiredChildOrdering = Some(childOrdering), + maybeChildPlan) + updatedSession + case None => maybeChildPlan + } + + private def mayAppendMergingSessionExec( + groupingExpressions: Seq[NamedExpression], + aggregateExpressions: Seq[AggregateExpression], + partialAggregate: SparkPlan): SparkPlan = { + + groupingExpressions.find(_.metadata.contains(SessionWindow.marker)) match { + case Some(sessionExpression) => + val aggExpressions = aggregateExpressions.map(_.copy(mode = PartialMerge)) + val aggAttributes = aggregateExpressions.map(_.resultAttribute) + + val groupingAttributes = groupingExpressions.map(_.toAttribute) + val groupingWithoutSessionExpressions = groupingExpressions.diff(Seq(sessionExpression)) + val groupingWithoutSessionsAttributes = groupingWithoutSessionExpressions + .map(_.toAttribute) + + MergingSessionsExec( + requiredChildDistributionExpressions = Some(groupingWithoutSessionsAttributes), + requiredChildDistributionOption = None, + groupingExpressions = groupingAttributes, + sessionExpression = sessionExpression, + aggregateExpressions = aggExpressions, + aggregateAttributes = aggAttributes, + initialInputBufferOffset = groupingAttributes.length, + resultExpressions = groupingAttributes ++ + aggExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes), + child = partialAggregate + ) + + case None => partialAggregate + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/MergingSessionsExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/MergingSessionsExec.scala new file mode 100644 index 000000000000..e2784c0d2bd0 --- /dev/null +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/MergingSessionsExec.scala @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.aggregate + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, AttributeSet, Expression, NamedExpression, SortOrder, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression +import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} +import org.apache.spark.sql.execution.metric.SQLMetrics + +// FIXME: javadoc should provide precondition that input must be sorted +// or both required child distribution as well as required child ordering should be presented +// to guarantee input will be sorted +case class MergingSessionsExec( + requiredChildDistributionExpressions: Option[Seq[Expression]], + requiredChildDistributionOption: Option[Seq[Distribution]], + groupingExpressions: Seq[NamedExpression], + sessionExpression: NamedExpression, + aggregateExpressions: Seq[AggregateExpression], + aggregateAttributes: Seq[Attribute], + initialInputBufferOffset: Int, + resultExpressions: Seq[NamedExpression], + child: SparkPlan) extends UnaryExecNode { + + val keyWithoutSessionExpressions = groupingExpressions.diff(Seq(sessionExpression)) + + private[this] val aggregateBufferAttributes = { + aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) + } + + override def producedAttributes: AttributeSet = + AttributeSet(aggregateAttributes) ++ + AttributeSet(resultExpressions.diff(groupingExpressions).map(_.toAttribute)) ++ + AttributeSet(aggregateBufferAttributes) + + override lazy val metrics = Map( + "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) + + override def output: Seq[Attribute] = child.output + + override def outputPartitioning: Partitioning = child.outputPartitioning + + override def outputOrdering: Seq[SortOrder] = child.outputOrdering + + override def requiredChildDistribution: Seq[Distribution] = { + requiredChildDistributionExpressions match { + case Some(exprs) if exprs.isEmpty => AllTuples :: Nil + case Some(exprs) if exprs.nonEmpty => ClusteredDistribution(exprs) :: Nil + case None => requiredChildDistributionOption match { + case Some(distributions) => distributions + case None => UnspecifiedDistribution :: Nil + } + } + } + + override def requiredChildOrdering: Seq[Seq[SortOrder]] = { + Seq((keyWithoutSessionExpressions ++ Seq(sessionExpression)).map(SortOrder(_, Ascending))) + } + + override protected def doExecute(): RDD[InternalRow] = { + val numOutputRows = longMetric("numOutputRows") + 
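+    // Per requiredChildDistribution/requiredChildOrdering above, rows in each partition arrive
+    // grouped by the non-session keys and sorted by session start, so the iterator below can
+    // merge overlapping sessions in a single pass.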
child.execute().mapPartitionsWithIndexInternal { (partIndex, iter) => + // Because the constructor of an aggregation iterator will read at least the first row, + // we need to get the value of iter.hasNext first. + val hasInput = iter.hasNext + if (!hasInput && groupingExpressions.nonEmpty) { + // This is a grouped aggregate and the input iterator is empty, + // so return an empty iterator. + Iterator[UnsafeRow]() + } else { + val outputIter = new MergingSessionsIterator( + partIndex, + groupingExpressions, + sessionExpression, + child.output, + iter, + aggregateExpressions, + aggregateAttributes, + initialInputBufferOffset, + resultExpressions, + (expressions, inputSchema) => + newMutableProjection(expressions, inputSchema, subexpressionEliminationEnabled), + numOutputRows) + if (!hasInput && groupingExpressions.isEmpty) { + // There is no input and there is no grouping expressions. + // We need to output a single row as the output. + numOutputRows += 1 + Iterator[UnsafeRow](outputIter.outputForEmptyGroupingKeyWithoutInput()) + } else { + outputIter + } + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/MergingSessionsIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/MergingSessionsIterator.scala new file mode 100644 index 000000000000..1b9f78274c56 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/MergingSessionsIterator.scala @@ -0,0 +1,242 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.aggregate + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, GenericInternalRow, JoinedRow, Literal, MutableProjection, NamedExpression, PreciseTimestampConversion, SpecificInternalRow, UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection +import org.apache.spark.sql.execution.metric.SQLMetric + +// FIXME: javadoc! 
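+// Doc sketch (pending the FIXME above; wording is not the author's): a sort-based aggregation
+// iterator whose input must be sorted by the grouping keys without the session column, then by
+// session start. Rows whose sessions overlap (next start <= current end) are folded into a
+// single session, extending the session end as needed; a non-overlapping session or a key change
+// closes the current group and emits one output row for it.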
+// FIXME: groupingExpressions should contain sessionExpression +class MergingSessionsIterator( + partIndex: Int, + groupingExpressions: Seq[NamedExpression], + sessionExpression: NamedExpression, + valueAttributes: Seq[Attribute], + inputIterator: Iterator[InternalRow], + aggregateExpressions: Seq[AggregateExpression], + aggregateAttributes: Seq[Attribute], + initialInputBufferOffset: Int, + resultExpressions: Seq[NamedExpression], + newMutableProjection: (Seq[Expression], Seq[Attribute]) => MutableProjection, + numOutputRows: SQLMetric) + extends AggregationIterator( + partIndex, + groupingExpressions, + valueAttributes, + aggregateExpressions, + aggregateAttributes, + initialInputBufferOffset, + resultExpressions, + newMutableProjection) { + + val groupingWithoutSession: Seq[NamedExpression] = + groupingExpressions.diff(Seq(sessionExpression)) + val groupingWithoutSessionAttributes: Seq[Attribute] = groupingWithoutSession.map(_.toAttribute) + + + /** + * Creates a new aggregation buffer and initializes buffer values + * for all aggregate functions. + */ + private def newBuffer: InternalRow = { + val bufferSchema = aggregateFunctions.flatMap(_.aggBufferAttributes) + val bufferRowSize: Int = bufferSchema.length + + val genericMutableBuffer = new GenericInternalRow(bufferRowSize) + val useUnsafeBuffer = bufferSchema.map(_.dataType).forall(UnsafeRow.isMutable) + + val buffer = if (useUnsafeBuffer) { + val unsafeProjection = + UnsafeProjection.create(bufferSchema.map(_.dataType)) + unsafeProjection.apply(genericMutableBuffer) + } else { + genericMutableBuffer + } + initializeBuffer(buffer) + buffer + } + + /////////////////////////////////////////////////////////////////////////// + // Mutable states for sort based aggregation. + /////////////////////////////////////////////////////////////////////////// + + // The partition key of the current partition. + private[this] var currentGroupingKey: UnsafeRow = _ + + private[this] var currentSession: UnsafeRow = _ + + // The partition key of next partition. + private[this] var nextGroupingKey: UnsafeRow = _ + + private[this] var nextGroupingSession: UnsafeRow = _ + + // The first row of next partition. + private[this] var firstRowInNextGroup: InternalRow = _ + + // Indicates if we has new group of rows from the sorted input iterator + private[this] var sortedInputHasNewGroup: Boolean = false + + // The aggregation buffer used by the sort-based aggregation. + private[this] val sortBasedAggregationBuffer: InternalRow = newBuffer + + private[this] val groupingWithoutSessionProjection: UnsafeProjection = + UnsafeProjection.create(groupingWithoutSession, valueAttributes) + + private[this] val sessionIndex = resultExpressions.indexOf(sessionExpression) + + private[this] val sessionProjection: UnsafeProjection = + UnsafeProjection.create(Seq(sessionExpression), valueAttributes) + + protected def initialize(): Unit = { + if (inputIterator.hasNext) { + initializeBuffer(sortBasedAggregationBuffer) + val inputRow = inputIterator.next() + nextGroupingKey = groupingWithoutSessionProjection(inputRow).copy() + val session = sessionProjection(inputRow) + nextGroupingSession = session.getStruct(0, 2).copy() + firstRowInNextGroup = inputRow.copy() + sortedInputHasNewGroup = true + } else { + // This inputIter is empty. + sortedInputHasNewGroup = false + } + } + + initialize() + + /** Processes rows in the current group. It will stop when it find a new group. 
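+   * (A row belongs to the current group when its grouping key matches and its session start is
+   * no later than the current session end; the session end is expanded to cover the row if
+   * needed.)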
*/ + protected def processCurrentSortedGroup(): Unit = { + currentGroupingKey = nextGroupingKey + currentSession = nextGroupingSession + + // Now, we will start to find all rows belonging to this group. + // We create a variable to track if we see the next group. + var findNextPartition = false + // firstRowInNextGroup is the first row of this group. We first process it. + processRow(sortBasedAggregationBuffer, firstRowInNextGroup) + + // The search will stop when we see the next group or there is no + // input row left in the iter. + while (!findNextPartition && inputIterator.hasNext) { + // Get the grouping key. + val currentRow = inputIterator.next() + val groupingKey = groupingWithoutSessionProjection(currentRow) + + val session = sessionProjection(currentRow) + val sessionStruct = session.getStruct(0, 2) + val sessionStart = getSessionStart(sessionStruct) + val sessionEnd = getSessionEnd(sessionStruct) + + // Check if the current row belongs the current input row. + if (currentGroupingKey == groupingKey) { + if (sessionStart < getSessionStart(currentSession)) { + throw new IllegalArgumentException("Input iterator is not sorted based on session!") + } else if (sessionStart <= getSessionEnd(currentSession)) { + // expanding session length if needed + expandEndOfCurrentSession(sessionEnd) + processRow(sortBasedAggregationBuffer, currentRow) + } else { + // We find a new group. + findNextPartition = true + startNewSession(currentRow, groupingKey, sessionStruct) + } + } else { + // We find a new group. + findNextPartition = true + startNewSession(currentRow, groupingKey, sessionStruct) + } + } + + // We have not seen a new group. It means that there is no new row in the input + // iter. The current group is the last group of the iter. + if (!findNextPartition) { + sortedInputHasNewGroup = false + } + } + + private def startNewSession(currentRow: InternalRow, groupingKey: UnsafeRow, + sessionStruct: UnsafeRow): Unit = { + nextGroupingKey = groupingKey.copy() + nextGroupingSession = sessionStruct.copy() + firstRowInNextGroup = currentRow.copy() + } + + private def getSessionStart(sessionStruct: UnsafeRow): Long = { + sessionStruct.getLong(0) + } + + private def getSessionEnd(sessionStruct: UnsafeRow): Long = { + sessionStruct.getLong(1) + } + + def updateSessionEnd(sessionStruct: UnsafeRow, sessionEnd: Long): Unit = { + sessionStruct.setLong(1, sessionEnd) + } + + private def expandEndOfCurrentSession(sessionEnd: Long): Unit = { + if (sessionEnd > getSessionEnd(currentSession)) { + updateSessionEnd(currentSession, sessionEnd) + } + } + + /////////////////////////////////////////////////////////////////////////// + // Iterator's public methods + /////////////////////////////////////////////////////////////////////////// + + override final def hasNext: Boolean = sortedInputHasNewGroup + + override final def next(): UnsafeRow = { + if (hasNext) { + // Process the current group. + processCurrentSortedGroup() + // Generate output row for the current group. + + val groupingKey = generateGroupingKey() + + val outputRow = generateOutput(groupingKey, sortBasedAggregationBuffer) + // Initialize buffer values for the next group. 
+ initializeBuffer(sortBasedAggregationBuffer) + numOutputRows += 1 + outputRow + } else { + // no more result + throw new NoSuchElementException + } + } + + private val join = new JoinedRow + + private val groupingKeyProj = GenerateUnsafeProjection.generate(groupingExpressions, + groupingWithoutSessionAttributes :+ sessionExpression.toAttribute) + + private def generateGroupingKey(): UnsafeRow = { + val newRow = new SpecificInternalRow(Seq(sessionExpression.toAttribute).toStructType) + newRow.update(0, currentSession) + val joined = join(currentGroupingKey, newRow) + + groupingKeyProj(joined) + } + + def outputForEmptyGroupingKeyWithoutInput(): UnsafeRow = { + initializeBuffer(sortBasedAggregationBuffer) + generateOutput(UnsafeRow.createFromByteArray(0, 0), sortBasedAggregationBuffer) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UpdatingSessionExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UpdatingSessionExec.scala new file mode 100644 index 000000000000..0c625f6cf1ce --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UpdatingSessionExec.scala @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.aggregate + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, SortOrder} +import org.apache.spark.sql.catalyst.plans.physical.{Distribution, Partitioning} +import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} + +// FIXME: javadoc should provide precondition that input must be sorted +// or both required child distribution as well as required child ordering should be presented +// to guarantee input will be sorted +case class UpdatingSessionExec( + keyExpressions: Seq[Attribute], + sessionExpression: Attribute, + optRequiredChildDistribution: Option[Seq[Distribution]], + optRequiredChildOrdering: Option[Seq[Seq[SortOrder]]], + child: SparkPlan) extends UnaryExecNode { + + override protected def doExecute(): RDD[InternalRow] = { + val inMemoryThreshold = sqlContext.conf.windowExecBufferInMemoryThreshold + val spillThreshold = sqlContext.conf.windowExecBufferSpillThreshold + + child.execute().mapPartitions { iter => + new UpdatingSessionIterator(iter, keyExpressions, sessionExpression, + child.output, inMemoryThreshold, spillThreshold) + } + } + + override def output: Seq[Attribute] = child.output + + override def outputPartitioning: Partitioning = child.outputPartitioning + + override def requiredChildDistribution: Seq[Distribution] = optRequiredChildDistribution match { + case Some(distribution) => distribution + case None => super.requiredChildDistribution + } + + override def requiredChildOrdering: Seq[Seq[SortOrder]] = optRequiredChildOrdering match { + case Some(ordering) => ordering + case None => super.requiredChildOrdering + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UpdatingSessionIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UpdatingSessionIterator.scala new file mode 100644 index 000000000000..ab54fa2220cb --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UpdatingSessionIterator.scala @@ -0,0 +1,220 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.aggregate + +import scala.collection.mutable + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection +import org.apache.spark.sql.execution.ExternalAppendOnlyUnsafeRowArray +import org.apache.spark.sql.types.{LongType, TimestampType} + +// FIXME: javadoc!! 
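+// Doc sketch (pending the FIXME above; wording is not the author's): expects input sorted by
+// the grouping keys without the session column, then by session start. Rows of the current
+// session are buffered in an ExternalAppendOnlyUnsafeRowArray; overlapping rows keep extending
+// the session end, and once the session closes the buffered rows are re-emitted with their
+// session column rewritten to the final merged session range.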
+class UpdatingSessionIterator( + iter: Iterator[InternalRow], + groupingExpressions: Seq[NamedExpression], + sessionExpression: NamedExpression, + inputSchema: Seq[Attribute], + inMemoryThreshold: Int, + spillThreshold: Int) extends Iterator[InternalRow] { + + val sessionIndex = inputSchema.indexOf(sessionExpression) + + private val groupingWithoutSession: Seq[NamedExpression] = + groupingExpressions.diff(Seq(sessionExpression)) + private val groupingWithoutSessionAttributes: Seq[Attribute] = + groupingWithoutSession.map(_.toAttribute) + private[this] val groupingWithoutSessionProjection: UnsafeProjection = + UnsafeProjection.create(groupingWithoutSession, inputSchema) + + val valuesExpressions: Seq[Attribute] = inputSchema.diff(groupingWithoutSession) + + private[this] val sessionProjection: UnsafeProjection = + UnsafeProjection.create(Seq(sessionExpression), inputSchema) + + var currentKeys: InternalRow = _ + var currentSession: UnsafeRow = _ + + var currentRows: ExternalAppendOnlyUnsafeRowArray = new ExternalAppendOnlyUnsafeRowArray( + inMemoryThreshold, spillThreshold) + + var returnRows: ExternalAppendOnlyUnsafeRowArray = _ + var returnRowsIter: Iterator[InternalRow] = _ + var errorOnIterator: Boolean = false + + val processedKeys: mutable.HashSet[InternalRow] = new mutable.HashSet[InternalRow]() + + override def hasNext: Boolean = { + assertIteratorNotCorrupted() + + if (returnRowsIter != null && returnRowsIter.hasNext) { + return true + } + + if (returnRowsIter != null) { + returnRowsIter = null + returnRows.clear() + } + + iter.hasNext + } + + override def next(): InternalRow = { + assertIteratorNotCorrupted() + + if (returnRowsIter != null && returnRowsIter.hasNext) { + return returnRowsIter.next() + } + + var exitCondition = false + while (iter.hasNext && !exitCondition) { + // we are going to modify the row, so we should make sure multiple objects are not + // referencing same memory, which could be possible when optimizing iterator + // without this, multiple rows in same key will be returned with same content + val row = iter.next().copy() + + val keys = groupingWithoutSessionProjection(row) + val session = sessionProjection(row) + val sessionStruct = session.getStruct(0, 2) + val sessionStart = getSessionStart(sessionStruct) + val sessionEnd = getSessionEnd(sessionStruct) + + if (currentKeys == null) { + startNewSession(row, keys, sessionStruct) + } else if (keys != currentKeys) { + closeCurrentSession(keyChanged = true) + processedKeys.add(currentKeys) + startNewSession(row, keys, sessionStruct) + exitCondition = true + } else { + if (sessionStart < getSessionStart(currentSession)) { + handleBrokenPreconditionForSort() + } else if (sessionStart <= getSessionEnd(currentSession)) { + // expanding session length if needed + expandEndOfCurrentSession(sessionEnd) + currentRows.add(row.asInstanceOf[UnsafeRow]) + } else { + closeCurrentSession(keyChanged = false) + startNewSession(row, keys, sessionStruct) + exitCondition = true + } + } + } + + if (!iter.hasNext) { + // no further row: closing session + closeCurrentSession(keyChanged = false) + } + + // here returnRowsIter should be able to provide at least one row + require(returnRowsIter != null && returnRowsIter.hasNext) + + returnRowsIter.next() + } + + private def startNewSession(currentRow: InternalRow, groupingKey: UnsafeRow, + sessionStruct: UnsafeRow): Unit = { + if (processedKeys.contains(groupingKey)) { + handleBrokenPreconditionForSort() + } + + currentKeys = groupingKey.copy() + currentSession = sessionStruct.copy() + 
+ currentRows.clear() + currentRows.add(currentRow.asInstanceOf[UnsafeRow]) + } + + private def getSessionStart(sessionStruct: UnsafeRow): Long = { + sessionStruct.getLong(0) + } + + private def getSessionEnd(sessionStruct: UnsafeRow): Long = { + sessionStruct.getLong(1) + } + + def updateSessionEnd(sessionStruct: UnsafeRow, sessionEnd: Long): Unit = { + sessionStruct.setLong(1, sessionEnd) + } + + private def expandEndOfCurrentSession(sessionEnd: Long): Unit = { + if (sessionEnd > getSessionEnd(currentSession)) { + updateSessionEnd(currentSession, sessionEnd) + } + } + + private def handleBrokenPreconditionForSort(): Unit = { + errorOnIterator = true + throw new IllegalStateException("The iterator must be sorted by key and session start!") + } + + private def createSessionRow(): InternalRow = { + val sessionRow = new SpecificInternalRow(Seq(sessionExpression.toAttribute).toStructType) + sessionRow.update(0, currentSession) + sessionRow + } + + private val join = new JoinedRow + private val join2 = new JoinedRow + + private val groupingKeyProj = GenerateUnsafeProjection.generate(groupingExpressions, + groupingWithoutSessionAttributes :+ sessionExpression.toAttribute) + private val valueProj = GenerateUnsafeProjection.generate(valuesExpressions, inputSchema) + private val restoreProj = GenerateUnsafeProjection.generate(inputSchema, + groupingExpressions.map(_.toAttribute) ++ valuesExpressions.map(_.toAttribute)) + + private def generateGroupingKey(): UnsafeRow = { + val newRow = new SpecificInternalRow(Seq(sessionExpression.toAttribute).toStructType) + newRow.update(0, currentSession) + val joined = join(currentKeys, newRow) + + groupingKeyProj(joined) + } + + private def closeCurrentSession(keyChanged: Boolean): Unit = { + returnRows = currentRows + currentRows = new ExternalAppendOnlyUnsafeRowArray( + inMemoryThreshold, spillThreshold) + + val groupingKey = generateGroupingKey() + + val currentRowsIter = returnRows.generateIterator().map { internalRow => + val valueRow = valueProj(internalRow) + restoreProj(join2(groupingKey, valueRow)).copy() + } + + if (returnRowsIter != null && returnRowsIter.hasNext) { + returnRowsIter = returnRowsIter ++ currentRowsIter + } else { + returnRowsIter = currentRowsIter + } + + if (keyChanged) processedKeys.add(currentKeys) + + currentKeys = null + currentSession = null + } + + private def assertIteratorNotCorrupted(): Unit = { + if (errorOnIterator) { + throw new IllegalStateException("The iterator is already corrupted.") + } + } + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala index 2ab7240556aa..80f5aee29e5c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning} import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan, UnaryExecNode} +import org.apache.spark.sql.execution.aggregate.UpdatingSessionIterator import org.apache.spark.sql.execution.arrow.ArrowUtils import org.apache.spark.sql.types.{DataType, StructField, StructType} import org.apache.spark.util.Utils @@ -53,14 +54,33 @@ case class AggregateInPandasExec( override 
def producedAttributes: AttributeSet = AttributeSet(output) + val sessionWindowOption = groupingExpressions.find { p => + p.metadata.contains(SessionWindow.marker) + } + + val groupingWithoutSessionExpressions = sessionWindowOption match { + case Some(sessionExpression) => + groupingExpressions.filterNot { p => p.semanticEquals(sessionExpression) } + + case None => groupingExpressions + } + override def requiredChildDistribution: Seq[Distribution] = { - if (groupingExpressions.isEmpty) { + if (groupingWithoutSessionExpressions.isEmpty) { AllTuples :: Nil } else { - ClusteredDistribution(groupingExpressions) :: Nil + ClusteredDistribution(groupingWithoutSessionExpressions) :: Nil } } + override def requiredChildOrdering: Seq[Seq[SortOrder]] = sessionWindowOption match { + case Some(sessionExpression) => + Seq((groupingWithoutSessionExpressions ++ Seq(sessionExpression)) + .map(SortOrder(_, Ascending))) + + case None => Seq(groupingExpressions.map(SortOrder(_, Ascending))) + } + private def collectFunctions(udf: PythonUDF): (ChainedPythonFunctions, Seq[Expression]) = { udf.children match { case Seq(u: PythonUDF) => @@ -73,9 +93,6 @@ case class AggregateInPandasExec( } } - override def requiredChildOrdering: Seq[Seq[SortOrder]] = - Seq(groupingExpressions.map(SortOrder(_, Ascending))) - override protected def doExecute(): RDD[InternalRow] = { val inputRDD = child.execute() @@ -106,13 +123,15 @@ case class AggregateInPandasExec( }) inputRDD.mapPartitionsInternal { iter => + val newIter: Iterator[InternalRow] = mayAppendUpdatingSessionIterator(iter) + val prunedProj = UnsafeProjection.create(allInputs, child.output) val grouped = if (groupingExpressions.isEmpty) { // Use an empty unsafe row as a place holder for the grouping key - Iterator((new UnsafeRow(), iter)) + Iterator((new UnsafeRow(), newIter)) } else { - GroupedIterator(iter, groupingExpressions, child.output) + GroupedIterator(newIter, groupingExpressions, child.output) }.map { case (key, rows) => (key, rows.map(prunedProj)) } @@ -153,4 +172,20 @@ case class AggregateInPandasExec( } } } + + private def mayAppendUpdatingSessionIterator(iter: Iterator[InternalRow]) + : Iterator[InternalRow] = { + val newIter = sessionWindowOption match { + case Some(sessionExpression) => + val inMemoryThreshold = sqlContext.conf.windowExecBufferInMemoryThreshold + val spillThreshold = sqlContext.conf.windowExecBufferSpillThreshold + + new UpdatingSessionIterator(iter, groupingWithoutSessionExpressions, sessionExpression, + child.output, inMemoryThreshold, spillThreshold) + + case None => iter + } + + newIter + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index fad287e28877..3a55cfab5dde 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -119,6 +119,24 @@ class IncrementalExecution( stateFormatVersion, child) :: Nil)) + case SessionWindowStateStoreSaveExec(keys, session, None, None, None, + UnaryExecNode(agg, + SessionWindowStateStoreRestoreExec(_, _, None, None, child))) => + val aggStateInfo = nextStatefulOperationStateInfo + SessionWindowStateStoreSaveExec( + keys, + session, + Some(aggStateInfo), + Some(outputMode), + Some(offsetSeqMetadata.batchWatermarkMs), + agg.withNewChildren( + SessionWindowStateStoreRestoreExec( + keys, + session, + 
Some(aggStateInfo), + Some(offsetSeqMetadata.batchWatermarkMs), + child) :: Nil)) + case StreamingDeduplicateExec(keys, child, None, None) => StreamingDeduplicateExec( keys, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MergingSortWithSessionWindowLinkedListStateIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MergingSortWithSessionWindowLinkedListStateIterator.scala new file mode 100644 index 000000000000..9955aa5045ae --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MergingSortWithSessionWindowLinkedListStateIterator.scala @@ -0,0 +1,244 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection +import org.apache.spark.sql.catalyst.util.TypeUtils +import org.apache.spark.sql.execution.streaming.state.SessionWindowLinkedListState + +// FIXME: javadoc!! 
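+// Doc sketch (pending the FIXME above; wording is not the author's): merges the sorted input
+// rows of the current batch with the sessions already kept in SessionWindowLinkedListState so
+// that the combined output stays ordered by (keys, session start). For each input row it looks
+// up the stored sessions enclosing the row's session start; an overlapping earlier session is
+// emitted just before the row, and an overlapping later session is held in stateRowWaitForEmit
+// until it sorts before the next input row.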
+class MergingSortWithSessionWindowLinkedListStateIterator( + iter: Iterator[InternalRow], + state: SessionWindowLinkedListState, + groupWithoutSessionExpressions: Seq[Attribute], + sessionExpression: Attribute, + keysProjection: UnsafeProjection, + sessionProjection: UnsafeProjection, + inputSchema: Seq[Attribute]) extends Iterator[InternalRow] { + + def this( + iter: Iterator[InternalRow], + state: SessionWindowLinkedListState, + groupWithoutSessionExpressions: Seq[Attribute], + sessionExpression: Attribute, + inputSchema: Seq[Attribute]) { + this(iter, state, groupWithoutSessionExpressions, sessionExpression, + GenerateUnsafeProjection.generate(groupWithoutSessionExpressions, inputSchema), + GenerateUnsafeProjection.generate(Seq(sessionExpression), inputSchema), + inputSchema) + } + + private case class SessionRowInformation(keys: UnsafeRow, sessionStart: Long, sessionEnd: Long, + row: InternalRow) + + private object SessionRowInformation { + def of(row: InternalRow): SessionRowInformation = { + val keys = keysProjection(row).copy() + val session = sessionProjection(row).copy() + val sessionRow = session.getStruct(0, 2) + val sessionStart = sessionRow.getLong(0) + val sessionEnd = sessionRow.getLong(1) + + SessionRowInformation(keys, sessionStart, sessionEnd, row) + } + } + + private var lastKey: UnsafeRow = _ + private var currentRow: SessionRowInformation = _ + private var lastCheckpointOnStateRows: Option[Long] = _ + private var stateRowWaitForEmit: SessionRowInformation = _ + + private val keyOrdering: Ordering[UnsafeRow] = TypeUtils.getInterpretedOrdering( + groupWithoutSessionExpressions.toStructType).asInstanceOf[Ordering[UnsafeRow]] + + override def hasNext: Boolean = { + currentRow != null || iter.hasNext || stateRowWaitForEmit != null + } + + override def next(): InternalRow = { + if (currentRow == null) { + mayFillCurrentRow() + } + + if (currentRow == null && stateRowWaitForEmit == null) { + throw new IllegalStateException("No Row to provide in next() which should not happen!") + } + + // early return on input rows vs state row waiting for emitting + val returnCurrentRow = if (currentRow == null) { + false + } else if (stateRowWaitForEmit == null) { + true + } else { + // compare between current row and state row waiting for emitting + if (!keyOrdering.equiv(currentRow.keys, stateRowWaitForEmit.keys)) { + // state row cannot advance to row in input, so state row should be lower + false + } else { + currentRow.sessionStart < stateRowWaitForEmit.sessionStart + } + } + + // if state row should be emitted, do emit + if (!returnCurrentRow) { + val stateRow = stateRowWaitForEmit + stateRowWaitForEmit = null + return stateRow.row + } + + if (lastKey == null || !keyOrdering.equiv(lastKey, currentRow.keys)) { + // new key + stateRowWaitForEmit = null + lastCheckpointOnStateRows = None + lastKey = currentRow.keys + } + + // we don't need to check against sessions which are already candidate to emit + // so we apply checkpoint to skip some sessions + val stateSessionsEnclosingCurrentRow = findSessionPointerEnclosingEvent(currentRow, + startPointer = lastCheckpointOnStateRows) + + var prevSessionToEmit: Option[SessionRowInformation] = None + stateSessionsEnclosingCurrentRow match { + case None => + case Some(x) => + x._1 match { + case Some(prev) => + val prevSession = SessionRowInformation.of(state.get(currentRow.keys, prev)) + + val sessionLaterThanCheckpoint = lastCheckpointOnStateRows match { + case Some(lastCheckpoint) => lastCheckpoint < prevSession.sessionStart + case None => 
true + } + + if (sessionLaterThanCheckpoint) { + // based on definition of session window and the fact that events are sorted, + // if the state session is not matched to this event, it will not be matched with + // later events as well + lastCheckpointOnStateRows = Some(prevSession.sessionStart) + + if (isSessionsOverlap(currentRow, prevSession)) { + prevSessionToEmit = Some(prevSession) + } + } + + case None => + } + + x._2 match { + case Some(next) => + val nextSession = SessionRowInformation.of(state.get(currentRow.keys, next)) + + val sessionLaterThanCheckpoint = lastCheckpointOnStateRows match { + case Some(lastCheckpoint) => lastCheckpoint < nextSession.sessionStart + case None => true + } + + if (sessionLaterThanCheckpoint) { + // next session could be matched to latter events even it doesn't match to + // current event, so unless it is added to rows to emit, don't add to checked set + if (isSessionsOverlap(currentRow, nextSession)) { + stateRowWaitForEmit = nextSession + lastCheckpointOnStateRows = Some(nextSession.sessionStart) + } + } + + case None => + } + } + + // emitting sessions always follows the pattern: + // previous sessions if any -> current event -> (later events) -> next sessions + prevSessionToEmit match { + case Some(prevSession) => prevSession.row + case None => emitCurrentRow() + } + } + + private def emitCurrentRow(): InternalRow = { + val ret = currentRow + currentRow = null + ret.row + } + + private def mayFillCurrentRow(): Unit = { + if (iter.hasNext) { + currentRow = SessionRowInformation.of(iter.next()) + } + } + + private def findSessionPointerEnclosingEvent(row: SessionRowInformation, + startPointer: Option[Long]) + : Option[(Option[Long], Option[Long])] = { + val startOption = startPointer match { + case None => state.getFirstSessionStart(currentRow.keys) + case _ => startPointer + } + + startOption match { + // empty list + case None => None + case Some(start) => + var currOption: Option[Long] = Some(start) + + var enclosingSessions: Option[(Option[Long], Option[Long])] = None + while (enclosingSessions.isEmpty && currOption.isDefined) { + val curr = currOption.get + val newPrev = state.getPrevSessionStart(currentRow.keys, curr) + val newNext = state.getNextSessionStart(currentRow.keys, curr) + + val isEventEnclosed = newPrev match { + case Some(prev) => + prev <= currentRow.sessionStart && currentRow.sessionStart <= curr + case None => currentRow.sessionStart <= curr + } + + val willNotBeEnclosed = newPrev match { + case Some(prev) => prev > currentRow.sessionStart + case None => false + } + + if (isEventEnclosed) { + enclosingSessions = Some((newPrev, currOption)) + } else if (willNotBeEnclosed) { + enclosingSessions = Some((None, None)) + } else if (newNext.isEmpty) { + // curr is the last session in state + if (currentRow.sessionStart >= curr) { + enclosingSessions = Some((currOption, None)) + } else { + enclosingSessions = Some((None, None)) + } + } + + currOption = newNext + } + + // enclosingSessions should not be None unless list is empty + enclosingSessions + } + } + + private def isSessionsOverlap(s1: SessionRowInformation, s2: SessionRowInformation): Boolean = { + (s1.sessionStart >= s2.sessionStart && s1.sessionStart <= s2.sessionEnd) || + (s2.sessionStart >= s1.sessionStart && s2.sessionStart <= s1.sessionEnd) + } + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SessionWindowLinkedListState.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SessionWindowLinkedListState.scala new file mode 100644 index 000000000000..eaac81d23724 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SessionWindowLinkedListState.scala @@ -0,0 +1,779 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.state + +import java.util.Locale + +import org.apache.hadoop.conf.Configuration + +import org.apache.spark.TaskContext +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression, Literal, SpecificInternalRow, UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.execution.streaming.StatefulOperatorStateInfo +import org.apache.spark.sql.types.{LongType, StructField, StructType} +import org.apache.spark.util.NextIterator + +// FIXME: javadoc!! +class SessionWindowLinkedListState( + storeNamePrefix: String, + inputValueAttributes: Seq[Attribute], + keys: Seq[Expression], + stateInfo: Option[StatefulOperatorStateInfo], + storeConf: StateStoreConf, + hadoopConf: Configuration) extends Logging { + + import SessionWindowLinkedListState._ + + /* + ===================================================== + Public methods + ===================================================== + */ + + def get(key: UnsafeRow): Iterator[UnsafeRow] = { + keyToHeadSessionStartStore.get(key) match { + case Some(headSessionStart) => + new NextIterator[UnsafeRow] { + var curSessionStart: Option[Long] = Some(headSessionStart) + + override protected def getNext(): UnsafeRow = { + curSessionStart match { + case Some(sessionStart) => + val ret = keyAndSessionStartToValueStore.get(key, sessionStart) + curSessionStart = keyAndSessionStartToPointerStore.get(key, sessionStart)._2 + ret + + case None => + finished = true + null + } + } + + override protected def close(): Unit = {} + } + + case None => + Seq.empty[UnsafeRow].iterator + } + } + + def get(key: UnsafeRow, sessionStart: Long): UnsafeRow = { + keyAndSessionStartToValueStore.get(key, sessionStart) + } + + def iteratePointers(key: UnsafeRow): Iterator[(Long, Option[Long], Option[Long])] = { + keyToHeadSessionStartStore.get(key) match { + case Some(headSessionStart) => + new NextIterator[(Long, Option[Long], Option[Long])] { + var curSessionStart: Option[Long] = Some(headSessionStart) + + override protected def getNext(): (Long, Option[Long], Option[Long]) = { + curSessionStart match { + case Some(sessionStart) => + val ret = keyAndSessionStartToPointerStore.get(key, sessionStart) + assertValidPointer(ret) + curSessionStart = ret._2 + (sessionStart, ret._1, ret._2) + + case None => + finished = true + null + } + } + + override protected def close(): 
Unit = {} + } + + case None => + Seq.empty[(Long, Option[Long], Option[Long])].iterator + } + } + + def setHead(key: UnsafeRow, sessionStart: Long, value: UnsafeRow): Unit = { + require(keyToHeadSessionStartStore.get(key).isEmpty, "Head should not be exist.") + + keyToHeadSessionStartStore.put(key, sessionStart) + keyAndSessionStartToPointerStore.put(key, sessionStart, None, None) + keyAndSessionStartToValueStore.put(key, sessionStart, value) + } + + def addBefore(key: UnsafeRow, sessionStart: Long, value: UnsafeRow, + targetSessionStart: Long): Unit = { + val targetPointer = keyAndSessionStartToPointerStore.get(key, targetSessionStart) + assertValidPointer(targetPointer) + + targetPointer._1 match { + case Some(prev) => + keyAndSessionStartToPointerStore.updateNext(key, prev, Some(sessionStart)) + keyAndSessionStartToPointerStore.updatePrev(key, targetSessionStart, Some(sessionStart)) + keyAndSessionStartToPointerStore.put(key, sessionStart, + Some(prev), Some(targetSessionStart)) + + case None => + // we're changing head + keyAndSessionStartToPointerStore.updatePrev(key, targetSessionStart, Some(sessionStart)) + keyAndSessionStartToPointerStore.put(key, sessionStart, None, Some(targetSessionStart)) + keyToHeadSessionStartStore.put(key, sessionStart) + } + + keyAndSessionStartToValueStore.put(key, sessionStart, value) + } + + def addAfter(key: UnsafeRow, sessionStart: Long, value: UnsafeRow, + targetSessionStart: Long): Unit = { + val targetPointer = keyAndSessionStartToPointerStore.get(key, targetSessionStart) + assertValidPointer(targetPointer) + + targetPointer._2 match { + case Some(next) => + keyAndSessionStartToPointerStore.updatePrev(key, next, Some(sessionStart)) + keyAndSessionStartToPointerStore.updateNext(key, targetSessionStart, Some(sessionStart)) + keyAndSessionStartToPointerStore.put(key, sessionStart, Some(targetSessionStart), + Some(next)) + + case None => + keyAndSessionStartToPointerStore.updateNext(key, targetSessionStart, Some(sessionStart)) + keyAndSessionStartToPointerStore.put(key, sessionStart, Some(targetSessionStart), None) + } + + keyAndSessionStartToValueStore.put(key, sessionStart, value) + } + + def update(key: UnsafeRow, sessionStart: Long, newValue: UnsafeRow): Unit = { + val targetPointer = keyAndSessionStartToPointerStore.get(key, sessionStart) + assertValidPointer(targetPointer) + keyAndSessionStartToValueStore.put(key, sessionStart, newValue) + } + + def isEmpty(key: UnsafeRow): Boolean = { + keyToHeadSessionStartStore.get(key).isEmpty + } + + def findFirstSessionStartEnsurePredicate(key: UnsafeRow, predicate: Long => Boolean, + startIndex: Long): Option[Long] = { + + val pointers = keyAndSessionStartToPointerStore.get(key, startIndex) + assertValidPointer(pointers) + + var currentSessionStart: Option[Long] = Some(startIndex) + var ret: Option[Long] = None + var found = false + + while (!found && currentSessionStart.isDefined) { + val cur = currentSessionStart.get + if (predicate.apply(cur)) { + ret = Some(cur) + found = true + } else { + currentSessionStart = getNextSessionStart(key, cur) + } + } + + ret + } + + def findFirstSessionStartEnsurePredicate(key: UnsafeRow, predicate: Long => Boolean) + : Option[Long] = { + val head = keyToHeadSessionStartStore.get(key) + if (head.isEmpty) { + return None + } + + findFirstSessionStartEnsurePredicate(key, predicate, head.get) + } + + def getSessionStartOnNearest(key: UnsafeRow, sessionStart: Long): (Option[Long], Option[Long]) = { + keyAndSessionStartToPointerStore.get(key, sessionStart) + } + + def 
getPrevSessionStart(key: UnsafeRow, sessionStart: Long): Option[Long] = { + val pointers = keyAndSessionStartToPointerStore.get(key, sessionStart) + assertValidPointer(pointers) + pointers._1 + } + + def getNextSessionStart(key: UnsafeRow, sessionStart: Long): Option[Long] = { + val pointers = keyAndSessionStartToPointerStore.get(key, sessionStart) + assertValidPointer(pointers) + pointers._2 + } + + // FIXME: cover with test cases + def getFirstSessionStart(key: UnsafeRow): Option[Long] = { + keyToHeadSessionStartStore.get(key) + } + + // FIXME: cover with test cases + def getLastSessionStart(key: UnsafeRow): Option[Long] = { + getFirstSessionStart(key) match { + case Some(start) => getLastSessionStart(key, start) + case None => None + } + } + + // FIXME: cover with test cases + def getLastSessionStart(key: UnsafeRow, startIndex: Long): Option[Long] = { + val pointers = keyAndSessionStartToPointerStore.get(key, startIndex) + assertValidPointer(pointers) + + var lastSessionStart = startIndex + while (getNextSessionStart(key, lastSessionStart).isDefined) { + lastSessionStart = getNextSessionStart(key, lastSessionStart).get + } + + Some(lastSessionStart) + } + + def remove(key: UnsafeRow, sessionStart: Long): Unit = { + val targetPointer = keyAndSessionStartToPointerStore.get(key, sessionStart) + assertValidPointer(targetPointer) + + val prevOption = targetPointer._1 + val nextOption = targetPointer._2 + + keyAndSessionStartToPointerStore.remove(key, sessionStart) + keyAndSessionStartToValueStore.remove(key, sessionStart) + + targetPointer match { + case (Some(prev), Some(next)) => + keyAndSessionStartToPointerStore.updateNext(key, prev, nextOption) + keyAndSessionStartToPointerStore.updatePrev(key, next, prevOption) + + case (Some(prev), None) => + keyAndSessionStartToPointerStore.updateNext(key, prev, None) + + case (None, Some(next)) => + keyAndSessionStartToPointerStore.updatePrev(key, next, None) + keyToHeadSessionStartStore.put(key, next) + + case (None, None) => + if (keyToHeadSessionStartStore.get(key).get != sessionStart) { + throw new IllegalStateException("The element has pointer information for head, " + + "but the list has different head.") + } + + keyToHeadSessionStartStore.remove(key) + } + + } + + def removeByValueCondition(removalCondition: UnsafeRow => Boolean, + stopOnConditionMismatch: Boolean = false): Iterator[UnsafeRowPair] = { + new NextIterator[UnsafeRowPair] { + + // Reuse this object to avoid creation+GC overhead. 
+ private val reusedPair = new UnsafeRowPair() + + private val allKeysToHeadSessionStarts = keyToHeadSessionStartStore.iterator + + private var currentKey: UnsafeRow = null + private var currentSessionStart: Option[Long] = None + + override protected def getNext(): UnsafeRowPair = { + + // first setup + if (currentKey == null) { + if (!setupNextKey()) { + finished = true + return null + } + } + + val retVal = findNextValueToRemove() + if (retVal == null) { + finished = true + return null + } + + reusedPair.withRows(currentKey.copy(), retVal) + } + + override protected def close(): Unit = {} + + private def setupNextKey(): Boolean = { + if (!allKeysToHeadSessionStarts.hasNext) { + false + } else { + val keyAndHeadSessionStart = allKeysToHeadSessionStarts.next() + currentKey = keyAndHeadSessionStart.key.copy() + currentSessionStart = Some(keyAndHeadSessionStart.sessionStart) + true + } + } + + private def findNextValueToRemove(): UnsafeRow = { + var nextValue: UnsafeRow = null + while (nextValue == null) { + currentSessionStart match { + case Some(sessionStart) => + val pointers = keyAndSessionStartToPointerStore.get(currentKey, sessionStart) + val session = keyAndSessionStartToValueStore.get(currentKey, sessionStart) + + if (pointers == null || session == null) { + throw new IllegalStateException("Should not happen!") + } + + if (removalCondition(session)) { + nextValue = session + remove(currentKey, sessionStart) + currentSessionStart = pointers._2 + } else { + if (stopOnConditionMismatch) { + currentSessionStart = None + } else { + currentSessionStart = pointers._2 + } + } + + case None => + if (!setupNextKey()) { + return null + } + } + } + + nextValue + } + } + } + + def getAllRowPairs: Iterator[UnsafeRowPair] = { + new NextIterator[UnsafeRowPair] { + // Reuse this object to avoid creation+GC overhead. 
+ private val reusedPair = new UnsafeRowPair() + + private val allKeysToHeadSessionStarts = keyToHeadSessionStartStore.iterator + + private var currentKey: UnsafeRow = _ + private var currentSessionStart: Option[Long] = None + + override def getNext(): UnsafeRowPair = { + // first setup + if (currentKey == null) { + if (!setupNextKey()) { + finished = true + return null + } + } + + val nextValue = findNextValue() + if (nextValue == null) { + finished = true + return null + } + + reusedPair.withRows(currentKey, nextValue) + } + + override def close(): Unit = {} + + private def setupNextKey(): Boolean = { + if (!allKeysToHeadSessionStarts.hasNext) { + false + } else { + val keyAndHeadSessionStart = allKeysToHeadSessionStarts.next() + currentKey = keyAndHeadSessionStart.key.copy() + currentSessionStart = Some(keyAndHeadSessionStart.sessionStart) + true + } + } + + private def findNextValue(): UnsafeRow = { + var nextValue: UnsafeRow = null + while (nextValue == null) { + currentSessionStart match { + case Some(sessionStart) => + val pointers = keyAndSessionStartToPointerStore.get(currentKey, sessionStart) + val session = keyAndSessionStartToValueStore.get(currentKey, sessionStart) + + currentSessionStart = pointers._2 + nextValue = session + + case None => + if (!setupNextKey()) { + finished = true + return null + } + } + } + + nextValue + } + + } + } + + /** Commit all the changes to all the state stores */ + def commit(): Unit = { + keyToHeadSessionStartStore.commit() + keyAndSessionStartToPointerStore.commit() + keyAndSessionStartToValueStore.commit() + } + + /** Abort any changes to the state stores if needed */ + def abortIfNeeded(): Unit = { + keyToHeadSessionStartStore.abortIfNeeded() + keyAndSessionStartToPointerStore.abortIfNeeded() + keyAndSessionStartToValueStore.abortIfNeeded() + } + + /** Get the combined metrics of all the state stores */ + def metrics: StateStoreMetrics = { + val keyToHeadSessionStartMetrics = keyToHeadSessionStartStore.metrics + val keyAndSessionStartToPointerMetrics = keyAndSessionStartToPointerStore.metrics + val keyAndSessionStartToValueMetrics = keyAndSessionStartToValueStore.metrics + def newDesc(desc: String): String = s"${storeNamePrefix.toUpperCase(Locale.ROOT)}: $desc" + + val totalSize = keyToHeadSessionStartMetrics.memoryUsedBytes + + keyAndSessionStartToPointerMetrics.memoryUsedBytes + + keyAndSessionStartToValueMetrics.memoryUsedBytes + StateStoreMetrics( + keyAndSessionStartToValueMetrics.numKeys, // represent each buffered row only once + totalSize, + keyAndSessionStartToValueMetrics.customMetrics.map { + case (s @ StateStoreCustomSumMetric(_, desc), value) => + s.copy(desc = newDesc(desc)) -> value + case (s @ StateStoreCustomSizeMetric(_, desc), value) => + s.copy(desc = newDesc(desc)) -> value + case (s @ StateStoreCustomTimingMetric(_, desc), value) => + s.copy(desc = newDesc(desc)) -> value + case (s, _) => + throw new IllegalArgumentException( + s"Unknown state store custom metric is found at metrics: $s") + } + ) + } + + private[sql] def getIteratorOfHeadPointers: Iterator[KeyAndHeadSessionStart] = { + keyToHeadSessionStartStore.iterator + } + + private[sql] def getIteratorOfRawPointers: Iterator[KeyWithSessionStartAndPointers] = { + keyAndSessionStartToPointerStore.iterator + } + + private[sql] def getIteratorOfRawValues: Iterator[KeyWithSessionStartAndValue] = { + keyAndSessionStartToValueStore.iterator + } + + /* + ===================================================== + Private methods and inner classes + 
===================================================== + */ + + private def assertValidPointer(targetPointer: (Option[Long], Option[Long])): Unit = { + if (targetPointer == null) { + throw new IllegalArgumentException("Invalid pointer is provided.") + } + } + + private val keySchema = StructType( + keys.zipWithIndex.map { case (k, i) => StructField(s"field$i", k.dataType, k.nullable) }) + private val keyAttributes = keySchema.toAttributes + + private val keyToHeadSessionStartStore = new KeyToHeadSessionStartStore() + private val keyAndSessionStartToPointerStore = new KeyAndSessionStartToPointerStore() + private val keyAndSessionStartToValueStore = new KeyAndSessionStartToValueStore() + + // Clean up any state store resources if necessary at the end of the task + Option(TaskContext.get()).foreach { _.addTaskCompletionListener[Unit] { _ => abortIfNeeded() } } + + /** Helper trait for invoking common functionalities of a state store. */ + private abstract class StateStoreHandler(stateStoreType: StateStoreType) extends Logging { + + /** StateStore that the subclasses of this class is going to operate on */ + protected def stateStore: StateStore + + def commit(): Unit = { + stateStore.commit() + logDebug("Committed, metrics = " + stateStore.metrics) + } + + def abortIfNeeded(): Unit = { + if (!stateStore.hasCommitted) { + logInfo(s"Aborted store ${stateStore.id}") + stateStore.abort() + } + } + + def metrics: StateStoreMetrics = stateStore.metrics + + /** Get the StateStore with the given schema */ + protected def getStateStore(keySchema: StructType, valueSchema: StructType): StateStore = { + val storeProviderId = StateStoreProviderId(stateInfo.get, TaskContext.getPartitionId(), + getStateStoreName(storeNamePrefix, stateStoreType)) + val store = StateStore.get( + storeProviderId, keySchema, valueSchema, None, + stateInfo.get.storeVersion, storeConf, hadoopConf) + logInfo(s"Loaded store ${store.id}") + store + } + } + + /** + * Helper class for representing data returned by [[KeyToHeadSessionStartStore]]. + * Designed for object reuse. + */ + private[state] case class KeyAndHeadSessionStart(var key: UnsafeRow = null, + var sessionStart: Long = 0) { + def withNew(newKey: UnsafeRow, newSessionStart: Long): this.type = { + this.key = newKey + this.sessionStart = newSessionStart + this + } + } + + /** + * Helper class for representing data returned by [[KeyAndSessionStartToPointerStore]]. + * Designed for object reuse. + */ + private[state] case class KeyWithSessionStartAndPointers( + var key: UnsafeRow = null, + var sessionStart: Long = 0, + var prevSessionStart: Option[Long] = None, + var nextSessionStart: Option[Long] = None) { + def withNew(newKey: UnsafeRow, sessionStart: Long, prevSessionStart: Option[Long], + nextSessionStart: Option[Long]): this.type = { + this.key = newKey + this.sessionStart = sessionStart + this.prevSessionStart = prevSessionStart + this.nextSessionStart = nextSessionStart + this + } + } + + /** + * Helper class for representing data returned by [[KeyAndSessionStartToValueStore]]. + * Designed for object reuse. 
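+ * Holds the grouping key, the session start and the value row stored for that
+ * (key, session start) pair.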
+ */ + private[state] case class KeyWithSessionStartAndValue( + var key: UnsafeRow = null, + var sessionStart: Long = 0, + var value: UnsafeRow = null) { + def withNew(newKey: UnsafeRow, sessionStart: Long, newValue: UnsafeRow): this.type = { + this.key = newKey + this.sessionStart = sessionStart + this.value = newValue + this + } + } + + private class KeyToHeadSessionStartStore extends StateStoreHandler(KeyToHeadSessionStartType) { + private val longValueSchema = new StructType().add("value", "long") + private val longToUnsafeRow = UnsafeProjection.create(longValueSchema) + private val valueRow = longToUnsafeRow(new SpecificInternalRow(longValueSchema)) + protected val stateStore: StateStore = getStateStore(keySchema, longValueSchema) + + /** Get the head of list via session start the key has */ + def get(key: UnsafeRow): Option[Long] = { + val longValueRow = stateStore.get(key) + if (longValueRow != null) { + Some(longValueRow.getLong(0)) + } else { + None + } + } + + /** Set the head of list via session start the key has */ + def put(key: UnsafeRow, sessionStart: Long): Unit = { + valueRow.setLong(0, sessionStart) + stateStore.put(key, valueRow) + } + + def remove(key: UnsafeRow): Unit = { + stateStore.remove(key) + } + + def iterator: Iterator[KeyAndHeadSessionStart] = { + val keyAndHeadSessionStart = KeyAndHeadSessionStart() + stateStore.getRange(None, None).map { pair => + keyAndHeadSessionStart.withNew(pair.key, pair.value.getLong(0)) + } + } + } + + private abstract class KeyAndSessionStartAsKeyStore(t: StateStoreType) + extends StateStoreHandler(t) { + protected val keyWithSessionStartExprs = keyAttributes :+ Literal(1L) + protected val keyWithSessionStartSchema = keySchema.add("sessionStart", LongType) + protected val indexOrdinalInKeyWithSessionStartRow = keyAttributes.size + + // Projection to generate (key + session start) row from key row + protected val keyWithSessionStartRowGenerator = UnsafeProjection.create( + keyWithSessionStartExprs, keyAttributes) + + // Projection to generate key row from (key + index) row + protected val keyRowGenerator = UnsafeProjection.create( + keyAttributes, keyAttributes :+ AttributeReference("sessionStart", LongType)()) + + /** Generated a row using the key and session start */ + protected def keyWithSessionStartRow(key: UnsafeRow, sessionStart: Long): UnsafeRow = { + val row = keyWithSessionStartRowGenerator(key) + row.setLong(indexOrdinalInKeyWithSessionStartRow, sessionStart) + row + } + } + + private class KeyAndSessionStartToPointerStore extends KeyAndSessionStartAsKeyStore( + KeyAndSessionStartToPointerType) { + private val doublyPointersValueSchema = new StructType() + .add("prev", "long", nullable = true).add("next", "long", nullable = true) + private val doublyPointersToUnsafeRow = UnsafeProjection.create(doublyPointersValueSchema) + private val valueRow = doublyPointersToUnsafeRow( + new SpecificInternalRow(doublyPointersValueSchema)) + protected val stateStore: StateStore = getStateStore(keySchema, doublyPointersValueSchema) + + /** Get the prev/next pointer of current session */ + def get(key: UnsafeRow, sessionStart: Long): (Option[Long], Option[Long]) = { + val actualRow = stateStore.get(keyWithSessionStartRow(key, sessionStart)) + if (actualRow != null) { + (getPrevSessionStart(actualRow), getNextSessionStart(actualRow)) + } else { + null + } + } + + def updatePrev(key: UnsafeRow, sessionStart: Long, prevSessionStart: Option[Long]): Unit = { + val actualKeyRow = keyWithSessionStartRow(key, sessionStart) + val row = 
stateStore.get(actualKeyRow).copy() + setPrevSessionStart(row, prevSessionStart) + stateStore.put(actualKeyRow, row) + } + + def updateNext(key: UnsafeRow, sessionStart: Long, nextSessionStart: Option[Long]): Unit = { + val actualKeyRow = keyWithSessionStartRow(key, sessionStart) + val row = stateStore.get(actualKeyRow).copy() + setNextSessionStart(row, nextSessionStart) + stateStore.put(actualKeyRow, row) + } + + /** Set the head of list via session start the key has */ + def put(key: UnsafeRow, sessionStart: Long, prevSessionStart: Option[Long], + nextSessionStart: Option[Long]): Unit = { + setPrevSessionStart(valueRow, prevSessionStart) + setNextSessionStart(valueRow, nextSessionStart) + stateStore.put(keyWithSessionStartRow(key, sessionStart), valueRow) + } + + def remove(key: UnsafeRow, sessionStart: Long): Unit = { + stateStore.remove(keyWithSessionStartRow(key, sessionStart)) + } + + def iterator: Iterator[KeyWithSessionStartAndPointers] = { + val keyWithSessionStartAndPointers = KeyWithSessionStartAndPointers() + stateStore.getRange(None, None).map { pair => + val keyPart = keyRowGenerator(pair.key) + val sessionStart = pair.key.getLong(indexOrdinalInKeyWithSessionStartRow) + val prevSessionStart = getPrevSessionStart(pair.value) + val nextSessionStart = getNextSessionStart(pair.value) + keyWithSessionStartAndPointers.withNew(keyPart, sessionStart, prevSessionStart, + nextSessionStart) + } + } + + private def getPrevSessionStart(value: UnsafeRow): Option[Long] = { + if (value.isNullAt(0)) { + None + } else { + Some(value.getLong(0)) + } + } + + private def setPrevSessionStart(value: UnsafeRow, sessionStart: Option[Long]): Unit = { + sessionStart match { + case Some(l) => value.setLong(0, l) + case None => value.setNullAt(0) + } + } + + private def getNextSessionStart(value: UnsafeRow): Option[Long] = { + if (value.isNullAt(1)) { + None + } else { + Some(value.getLong(1)) + } + } + + private def setNextSessionStart(value: UnsafeRow, sessionStart: Option[Long]): Unit = { + sessionStart match { + case Some(l) => value.setLong(1, l) + case None => value.setNullAt(1) + } + } + } + + private class KeyAndSessionStartToValueStore extends KeyAndSessionStartAsKeyStore( + KeyAndSessionStartToValueType) { + protected val stateStore = getStateStore(keyWithSessionStartSchema, + inputValueAttributes.toStructType) + + def get(key: UnsafeRow, sessionStart: Long): UnsafeRow = { + stateStore.get(keyWithSessionStartRow(key, sessionStart)) + } + + /** Put new value for key at the given index */ + def put(key: UnsafeRow, sessionStart: Long, value: UnsafeRow): Unit = { + val keyWithSessionStart = keyWithSessionStartRow(key, sessionStart) + stateStore.put(keyWithSessionStart, value) + } + + /** + * Remove key and value at given session start. 
+ */ + def remove(key: UnsafeRow, sessionStart: Long): Unit = { + stateStore.remove(keyWithSessionStartRow(key, sessionStart)) + } + + def iterator: Iterator[KeyWithSessionStartAndValue] = { + val keyWithSessionStartAndValue = KeyWithSessionStartAndValue() + stateStore.getRange(None, None).map { pair => + val keyPart = keyRowGenerator(pair.key) + val sessionStart = pair.key.getLong(indexOrdinalInKeyWithSessionStartRow) + val value = pair.value + keyWithSessionStartAndValue.withNew(keyPart, sessionStart, value) + } + } + } +} + +object SessionWindowLinkedListState { + sealed trait StateStoreType + + case object KeyToHeadSessionStartType extends StateStoreType { + override def toString(): String = "keyToHeadSessionStart" + } + + case object KeyAndSessionStartToPointerType extends StateStoreType { + override def toString(): String = "keyAndSessionStartToPointer" + } + + case object KeyAndSessionStartToValueType extends StateStoreType { + override def toString(): String = "keyAndSessionStartToValue" + } + + def getStateStoreName(storeNamePrefix: String, storeType: StateStoreType): String = { + s"$storeNamePrefix-$storeType" + } + + def getAllStateStoreName(storeNamePrefix: String): Seq[String] = { + val allStateStoreTypes: Seq[StateStoreType] = Seq(KeyToHeadSessionStartType, + KeyAndSessionStartToPointerType, KeyAndSessionStartToValueType) + allStateStoreTypes.map(getStateStoreName(storeNamePrefix, _)) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SessionWindowLinkedListStateStoreRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SessionWindowLinkedListStateStoreRDD.scala new file mode 100644 index 000000000000..51bb709deb20 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SessionWindowLinkedListStateStoreRDD.scala @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.state + +import scala.reflect.ClassTag + +import org.apache.spark.{Partition, TaskContext} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.execution.streaming.StatefulOperatorStateInfo +import org.apache.spark.sql.execution.streaming.continuous.EpochTracker +import org.apache.spark.sql.internal.SessionState +import org.apache.spark.sql.types.StructType +import org.apache.spark.util.SerializableConfiguration + +// FIXME: javadoc!! 
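+/**
+ * An RDD which computes each partition of `dataRDD` together with a partition-scoped
+ * [[SessionWindowLinkedListState]]. For every partition, a state instance is created from the
+ * underlying state stores (named with the "sessionwindow-<operatorId>" prefix) and passed to
+ * `storeUpdateFunction` along with the partition's input iterator. Preferred locations are the
+ * executors that already have the corresponding [[StateStoreProvider]]s loaded; under
+ * continuous processing the store version is taken from the current epoch rather than the
+ * version captured at planning time.
+ */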
+class SessionWindowLinkedListStateStoreRDD[T: ClassTag, U: ClassTag]( + dataRDD: RDD[T], + storeUpdateFunction: (SessionWindowLinkedListState, Iterator[T]) => Iterator[U], + stateInfo: StatefulOperatorStateInfo, + keySchema: StructType, + valueSchema: StructType, + indexOrdinal: Option[Int], + sessionState: SessionState, + @transient private val storeCoordinator: Option[StateStoreCoordinatorRef]) + extends RDD[U](dataRDD) { + + private val storeConf = new StateStoreConf(sessionState.conf) + + // A Hadoop Configuration can be about 10 KB, which is pretty big, so broadcast it + private val hadoopConfBroadcast = dataRDD.context.broadcast( + new SerializableConfiguration(sessionState.newHadoopConf())) + + private val stateStorePrefix: String = s"sessionwindow-${stateInfo.operatorId}" + + override protected def getPartitions: Array[Partition] = dataRDD.partitions + + /** + * Set the preferred location of each partition using the executor that has the related + * [[StateStoreProvider]] already loaded. + */ + override def getPreferredLocations(partition: Partition): Seq[String] = { + SessionWindowLinkedListState.getAllStateStoreName(stateStorePrefix).flatMap { storeName => + val stateStoreProviderId = StateStoreProviderId(stateInfo, partition.index, storeName) + storeCoordinator.flatMap(_.getLocation(stateStoreProviderId)) + }.distinct + } + + override def compute(partition: Partition, ctxt: TaskContext): Iterator[U] = { + // If we're in continuous processing mode, we should get the store version for the current + // epoch rather than the one at planning time. + val currentVersion = EpochTracker.getCurrentEpoch match { + case None => stateInfo.storeVersion + case Some(value) => value + } + + val modifiedStateInfo = stateInfo.copy(storeVersion = currentVersion) + + val state = new SessionWindowLinkedListState(stateStorePrefix, + valueSchema.toAttributes, keySchema.toAttributes, Some(modifiedStateInfo), storeConf, + hadoopConfBroadcast.value.value) + + val inputIter = dataRDD.iterator(partition, ctxt) + storeUpdateFunction(state, inputIter) + } + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala index b6021438e902..0495da280ae6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala @@ -80,5 +80,48 @@ package object state { sessionState, storeCoordinator) } + + /** Map each partition of an RDD along with data in a [[SessionWindowLinkedListState]]. */ + def mapPartitionsWithSessionWindowLinkedListState[U: ClassTag]( + sqlContext: SQLContext, + stateInfo: StatefulOperatorStateInfo, + keySchema: StructType, + valueSchema: StructType, + indexOrdinal: Option[Int])( + storeUpdateFunction: (SessionWindowLinkedListState, Iterator[T]) => Iterator[U]) + : SessionWindowLinkedListStateStoreRDD[T, U] = { + + mapPartitionsWithSessionWindowLinkedListState( + stateInfo, + keySchema, + valueSchema, + indexOrdinal, + sqlContext.sessionState, + Some(sqlContext.streams.stateStoreCoordinator))( + storeUpdateFunction) + } + + /** Map each partition of an RDD along with data in a [[SessionWindowLinkedListState]]. 
*/ + private[streaming] def mapPartitionsWithSessionWindowLinkedListState[U: ClassTag]( + stateInfo: StatefulOperatorStateInfo, + keySchema: StructType, + valueSchema: StructType, + indexOrdinal: Option[Int], + sessionState: SessionState, + storeCoordinator: Option[StateStoreCoordinatorRef])( + storeUpdateFunction: (SessionWindowLinkedListState, Iterator[T]) => Iterator[U]) + : SessionWindowLinkedListStateStoreRDD[T, U] = { + + val cleanedF = dataRDD.sparkContext.clean(storeUpdateFunction) + new SessionWindowLinkedListStateStoreRDD( + dataRDD, + cleanedF, + stateInfo, + keySchema, + valueSchema, + indexOrdinal, + sessionState, + storeCoordinator) + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index c11af345b024..e0a98bb064ac 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjecti import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ +import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.sql.execution.streaming.state._ @@ -424,6 +425,373 @@ case class StateStoreSaveExec( } } +// FIXME: javadoc! +case class SessionWindowStateStoreRestoreExec( + keyWithoutSessionExpressions: Seq[Attribute], + sessionExpression: Attribute, + stateInfo: Option[StatefulOperatorStateInfo], + eventTimeWatermark: Option[Long], + child: SparkPlan) + extends UnaryExecNode with StateStoreReader with WatermarkSupport { + + override def keyExpressions: Seq[Attribute] = keyWithoutSessionExpressions + + override protected def doExecute(): RDD[InternalRow] = { + val numOutputRows = longMetric("numOutputRows") + + child.execute().mapPartitionsWithSessionWindowLinkedListState( + getStateInfo, + keyExpressions.toStructType, + child.output.toStructType, + indexOrdinal = None, + sqlContext.sessionState, + Some(sqlContext.streams.stateStoreCoordinator)) { case (state, iter) => + + val keyWithoutSessionProjection = GenerateUnsafeProjection.generate( + keyWithoutSessionExpressions, child.output) + val sessionProjection = GenerateUnsafeProjection.generate(Seq(sessionExpression), + child.output) + + // We need to filter out outdated inputs + val filteredIterator = watermarkPredicateForData match { + case Some(predicate) => iter.filter((row: InternalRow) => !predicate.eval(row)) + case None => iter + } + + new MergingSortWithSessionWindowLinkedListStateIterator( + filteredIterator, + state, + keyWithoutSessionExpressions, + sessionExpression, + keyWithoutSessionProjection, + sessionProjection, + child.output).map { row => + numOutputRows += 1 + row + } + } + } + + override def output: Seq[Attribute] = child.output + + override def outputPartitioning: Partitioning = child.outputPartitioning + + override def outputOrdering: Seq[SortOrder] = { + (keyWithoutSessionExpressions ++ Seq(sessionExpression)).map(SortOrder(_, Ascending)) + } + + override def requiredChildDistribution: Seq[Distribution] = { + if 
(keyWithoutSessionExpressions.isEmpty) { + AllTuples :: Nil + } else { + ClusteredDistribution(keyWithoutSessionExpressions, stateInfo.map(_.numPartitions)) :: Nil + } + } + + override def requiredChildOrdering: Seq[Seq[SortOrder]] = { + Seq((keyWithoutSessionExpressions ++ Seq(sessionExpression)).map(SortOrder(_, Ascending))) + } +} + +/** + * For each input tuple, the key is calculated and sessions are being `put` into + * the [[SessionWindowLinkedListState]]. + */ +case class SessionWindowStateStoreSaveExec( + keyWithoutSessionExpressions: Seq[Attribute], + sessionExpression: Attribute, + stateInfo: Option[StatefulOperatorStateInfo] = None, + outputMode: Option[OutputMode] = None, + eventTimeWatermark: Option[Long] = None, + child: SparkPlan) + extends UnaryExecNode with StateStoreWriter with WatermarkSupport { + + override def keyExpressions: Seq[Attribute] = keyWithoutSessionExpressions + + override protected def doExecute(): RDD[InternalRow] = { + metrics // force lazy init at driver + assert(outputMode.nonEmpty, + "Incorrect planning in IncrementalExecution, outputMode has not been set") + + child.execute().mapPartitionsWithSessionWindowLinkedListState( + getStateInfo, + keyWithoutSessionExpressions.toStructType, + child.output.toStructType, + indexOrdinal = None, + sqlContext.sessionState, + Some(sqlContext.streams.stateStoreCoordinator)) { (state, iter) => + + val numOutputRows = longMetric("numOutputRows") + val numUpdatedStateRows = longMetric("numUpdatedStateRows") + val allUpdatesTimeMs = longMetric("allUpdatesTimeMs") + val allRemovalsTimeMs = longMetric("allRemovalsTimeMs") + val commitTimeMs = longMetric("commitTimeMs") + + val keyProjection = GenerateUnsafeProjection.generate(keyExpressions, child.output) + val sessionProjection = GenerateUnsafeProjection.generate(Seq(sessionExpression), + child.output) + + val keyOrdering = TypeUtils.getInterpretedOrdering(keyExpressions.toStructType) + .asInstanceOf[Ordering[UnsafeRow]] + + var lastSearchedSessionStartOption: Option[Long] = None + var stateFetchedKey: UnsafeRow = null + + def reflectNewSession(row: UnsafeRow): Boolean = { + val key = keyProjection(row) + val session = sessionProjection(row).getStruct(0, 2) + val sessionStart = session.getLong(0) + val sessionEnd = session.getLong(1) + + if (state.isEmpty(key)) { + state.setHead(key, sessionStart, row) + return true + } + + // need to find sessions which could be replaced with new session + // new session should enclose previous session(s) if it overlaps, + // since session always expands + + val nearestSessions = state.getSessionStartOnNearest(key, sessionStart) + if (nearestSessions != null) { + // there's rare chance that existing session and row is equivalent + // because in MergingSortWithSessionWindowLinkedListStateIterator, + // we emit existing sessions only when it overlaps with input row + // so unless aggregation make no difference, it will not happen + // always replace instead of comparing with actual value + + // if the old session can be replaced with new session, + // (condition: 1:1 match, no change on "session start") + // just replace it to avoid overhead on manipulating linked list + + nearestSessions._2 match { + case Some(next) if next > sessionEnd => + state.update(key, sessionStart, row) + return true + case None => + state.update(key, sessionStart, row) + return true + + case _ => + } + } + + if (stateFetchedKey == null || keyOrdering.equiv(stateFetchedKey, key)) { + stateFetchedKey = key + lastSearchedSessionStartOption = None + } + + // find the 
first state session which is enclosed by new session + val firstStateSessionEnclosedByNewSession = lastSearchedSessionStartOption match { + case Some(lastSearchedSessionStart) => + state.findFirstSessionStartEnsurePredicate(key, start => start >= sessionStart, + lastSearchedSessionStart) + + case None => + state.findFirstSessionStartEnsurePredicate(key, start => start >= sessionStart) + } + + firstStateSessionEnclosedByNewSession match { + case Some(firstStateSessionStart) => + // get previous earlier to enable addAfter on new session after removal + val prevForFirstStateSession = state.getPrevSessionStart(key, firstStateSessionStart) + + // search and remove sessions which is enclosed by new session + var currentStateSessionStart: Option[Long] = Some(firstStateSessionStart) + var stop = false + while (!stop && currentStateSessionStart.isDefined) { + val stateSession = state.get(key, currentStateSessionStart.get) + + val stateSessionStart = sessionProjection(stateSession).getStruct(0, 2).getLong(0) + val stateSessionEnd = sessionProjection(stateSession).getStruct(0, 2).getLong(1) + + require(stateSessionStart == currentStateSessionStart.get, + "Session pointer doesn't match with actual session start!") + + // get next to continue searching after removal + val nextStateSessionStart = state.getNextSessionStart(key, stateSessionStart) + + // remove session if it is enclosed + if (stateSessionStart >= sessionStart && stateSessionEnd <= sessionEnd) { + state.remove(key, stateSessionStart) + currentStateSessionStart = nextStateSessionStart + } else { + stop = true + } + } + + // currentStateSessionStart is now the earliest session in state which + // new session should be added before + (prevForFirstStateSession, currentStateSessionStart) match { + case (_, Some(next)) => + state.addBefore(key, sessionStart, row, next) + lastSearchedSessionStartOption = Some(sessionStart) + + case (Some(prev), None) => + state.addAfter(key, sessionStart, row, prev) + lastSearchedSessionStartOption = Some(sessionStart) + + case (None, None) => + // we removed all elements + require(state.isEmpty(key), "It must be empty list since all elements are removed!") + state.setHead(key, sessionStart, row) + } + + case None => + // add to last: we got rid of the case list is empty + val lastSessionStartOption = lastSearchedSessionStartOption match { + case Some(lastSearchedSessionStart) => + state.getLastSessionStart(key, lastSearchedSessionStart) + case None => state.getLastSessionStart(key) + } + + lastSessionStartOption match { + case Some(lastSessionStart) => + state.addAfter(key, sessionStart, row, lastSessionStart) + + case None => + throw new IllegalStateException("List should not be empty!") + } + } + + // we don't need to search before the start of new session, since new sessions are sorted + // by session start + lastSearchedSessionStartOption = Some(sessionStart) + + true + } + + // assuming late events were dropped before + + outputMode match { + case Some(Complete) => + allUpdatesTimeMs += timeTakenMs { + while (iter.hasNext) { + val row = iter.next().asInstanceOf[UnsafeRow] + + if (reflectNewSession(row)) { + numUpdatedStateRows += 1 + } + } + } + + CompletionIterator[InternalRow, Iterator[InternalRow]]( + state.getAllRowPairs.map(_.value), { + commitTimeMs += timeTakenMs { state.commit() } + setStoreMetrics(state) + } + ) + + // Update and output only sessions being evicted from the MultiValuesStateManager + // Assumption: watermark predicates must be non-empty if append mode is allowed + case Some(Append) 
=> + allUpdatesTimeMs += timeTakenMs { + while (iter.hasNext) { + val row = iter.next().asInstanceOf[UnsafeRow] + if (reflectNewSession(row)) { + numUpdatedStateRows += 1 + } + } + } + + val removalStartTimeNs = System.nanoTime + + val retIter = state.removeByValueCondition(row => watermarkPredicateForData match { + case Some(predicate) => predicate.eval(row) + case None => false + }, stopOnConditionMismatch = true).map { row => + numOutputRows += 1 + row.value + } + + CompletionIterator[InternalRow, Iterator[InternalRow]](retIter, { + allRemovalsTimeMs += NANOSECONDS.toMillis(System.nanoTime - removalStartTimeNs) + commitTimeMs += timeTakenMs { state.commit() } + setStoreMetrics(state) + }) + + // Update and output modified rows from the MultiValuesStateManager. + case Some(Update) => + + new NextIterator[InternalRow] { + private val updatesStartTimeNs = System.nanoTime + + override protected def getNext(): InternalRow = { + var ret: InternalRow = null + + while (ret == null && iter.hasNext) { + val row = iter.next().asInstanceOf[UnsafeRow] + if (reflectNewSession(row)) { + numUpdatedStateRows += 1 + ret = row + } + } + + if (ret == null && !iter.hasNext) { + finished = true + null + } else { + // !iter.hasNext && ret != null => can return ret, and next getNext() call will + // set finished = true + // iter.hasNext && (ret != null || ret == null) => not possible + numOutputRows += 1 + ret + } + } + + override protected def close(): Unit = { + allUpdatesTimeMs += NANOSECONDS.toMillis(System.nanoTime - updatesStartTimeNs) + + // Remove old aggregates if watermark specified + allRemovalsTimeMs += timeTakenMs { + // fully consume iterator to ensure all necessary elements are evicted + state.removeByValueCondition(row => watermarkPredicateForData match { + case Some(predicate) => predicate.eval(row) + case None => false + }, stopOnConditionMismatch = true).toList + } + commitTimeMs += timeTakenMs { state.commit() } + setStoreMetrics(state) + } + } + + case _ => throw new UnsupportedOperationException(s"Invalid output mode: $outputMode") + } + } + } + + override def output: Seq[Attribute] = child.output + + override def outputPartitioning: Partitioning = child.outputPartitioning + + override def requiredChildDistribution: Seq[Distribution] = { + if (keyWithoutSessionExpressions.isEmpty) { + AllTuples :: Nil + } else { + ClusteredDistribution(keyWithoutSessionExpressions, stateInfo.map(_.numPartitions)) :: Nil + } + } + + override def requiredChildOrdering: Seq[Seq[SortOrder]] = { + Seq((keyWithoutSessionExpressions ++ Seq(sessionExpression)).map(SortOrder(_, Ascending))) + } + + override def shouldRunAnotherBatch(newMetadata: OffsetSeqMetadata): Boolean = { + (outputMode.contains(Append) || outputMode.contains(Update)) && + eventTimeWatermark.isDefined && + newMetadata.batchWatermarkMs > eventTimeWatermark.get + } + + protected def setStoreMetrics(state: SessionWindowLinkedListState): Unit = { + val storeMetrics = state.metrics + longMetric("numTotalStateRows") += storeMetrics.numKeys + longMetric("stateMemory") += storeMetrics.memoryUsedBytes + storeMetrics.customMetrics.foreach { case (metric, value) => + longMetric(metric.name) += value + } + } +} + /** Physical operator for executing streaming Deduplicate. 
*/ case class StreamingDeduplicateExec( keyExpressions: Seq[Attribute], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 4247d3110f1e..192688ded73e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3262,6 +3262,13 @@ object functions { window(timeColumn, windowDuration, windowDuration, "0 second") } + // FIXME: javadoc! + def session_window(timeColumn: Column, gapDuration: String): Column = { + withExpr { + SessionWindow(timeColumn.expr, gapDuration) + }.as("session_window") + } + ////////////////////////////////////////////////////////////////////////////////////////////// // Collection functions ////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSessionWindowingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSessionWindowingSuite.scala new file mode 100644 index 000000000000..9f56d9f5962f --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSessionWindowingSuite.scala @@ -0,0 +1,290 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql + +import org.scalatest.BeforeAndAfterEach + +import org.apache.spark.sql.catalyst.plans.logical.Expand +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.StringType + +class DataFrameSessionWindowingSuite extends QueryTest with SharedSQLContext + with BeforeAndAfterEach { + + import testImplicits._ + + test("simple session window with record at window start") { + val df = Seq( + ("2016-03-27 19:39:30", 1, "a")).toDF("time", "value", "id") + + checkAnswer( + df.groupBy(session_window($"time", "10 seconds")) + .agg(count("*").as("counts")) + .orderBy($"session_window.start".asc) + .select($"session_window.start".cast("string"), $"session_window.end".cast("string"), + $"counts"), + Seq( + Row("2016-03-27 19:39:30", "2016-03-27 19:39:40", 1) + ) + ) + } + + test("session window groupBy statement") { + val df = Seq( + ("2016-03-27 19:39:34", 1, "a"), + ("2016-03-27 19:39:56", 2, "a"), + ("2016-03-27 19:39:27", 4, "b")).toDF("time", "value", "id") + + // session window handles sort while applying group by + // whereas time window doesn't + + checkAnswer( + df.groupBy(session_window($"time", "10 seconds")) + .agg(count("*").as("counts")) + .orderBy($"session_window.start".asc) + .select("counts"), + Seq(Row(2), Row(1)) + ) + } + + test("session window groupBy with multiple keys statement") { + val df = Seq( + ("2016-03-27 19:39:34", 1, "a"), + ("2016-03-27 19:39:39", 1, "a"), + ("2016-03-27 19:39:56", 2, "a"), + ("2016-03-27 19:40:04", 2, "a"), + ("2016-03-27 19:39:27", 4, "b")).toDF("time", "value", "id") + + // session window handles sort while applying group by + // whereas time window doesn't + + // expected sessions + // key "a" => (19:39:34 ~ 19:39:49) (19:39:56 ~ 19:40:14) + // key "b" => (19:39:27 ~ 19:39:37) + + checkAnswer( + df.groupBy(session_window($"time", "10 seconds"), 'id) + .agg(count("*").as("counts"), sum("value").as("sum")) + .orderBy($"session_window.start".asc) + .selectExpr("CAST(session_window.start AS STRING)", "CAST(session_window.end AS STRING)", + "id", "counts", "sum"), + + Seq( + Row("2016-03-27 19:39:27", "2016-03-27 19:39:37", "b", 1, 4), + Row("2016-03-27 19:39:34", "2016-03-27 19:39:49", "a", 2, 2), + Row("2016-03-27 19:39:56", "2016-03-27 19:40:14", "a", 2, 4) + ) + ) + } + + test("session window groupBy with multiple keys statement - one distinct") { + val df = Seq( + ("2016-03-27 19:39:34", 1, "a"), + ("2016-03-27 19:39:39", 1, "a"), + ("2016-03-27 19:39:56", 2, "a"), + ("2016-03-27 19:40:04", 2, "a"), + ("2016-03-27 19:39:27", 4, "b")).toDF("time", "value", "id") + + // session window handles sort while applying group by + // whereas time window doesn't + + // expected sessions + // key "a" => (19:39:34 ~ 19:39:49) (19:39:56 ~ 19:40:14) + // key "b" => (19:39:27 ~ 19:39:37) + + checkAnswer( + df.groupBy(session_window($"time", "10 seconds"), 'id) + .agg(count("*").as("counts"), sumDistinct("value").as("sum")) + .orderBy($"session_window.start".asc) + .selectExpr("CAST(session_window.start AS STRING)", "CAST(session_window.end AS STRING)", + "id", "counts", "sum"), + Seq( + Row("2016-03-27 19:39:27", "2016-03-27 19:39:37", "b", 1, 4), + Row("2016-03-27 19:39:34", "2016-03-27 19:39:49", "a", 2, 1), + Row("2016-03-27 19:39:56", "2016-03-27 19:40:14", "a", 2, 2) + ) + ) + } + + test("session window groupBy with multiple keys statement - two distinct") { + val df = Seq( + ("2016-03-27 19:39:34", 1, 2, "a"), + ("2016-03-27 19:39:39", 1, 2, "a"), + 
("2016-03-27 19:39:56", 2, 4, "a"), + ("2016-03-27 19:40:04", 2, 4, "a"), + ("2016-03-27 19:39:27", 4, 8, "b")).toDF("time", "value", "value2", "id") + + // session window handles sort while applying group by + // whereas time window doesn't + + // expected sessions + // key "a" => (19:39:34 ~ 19:39:49) (19:39:56 ~ 19:40:14) + // key "b" => (19:39:27 ~ 19:39:37) + + checkAnswer( + df.groupBy(session_window($"time", "10 seconds"), 'id) + .agg(sumDistinct("value").as("sum"), sumDistinct("value2").as("sum2")) + .orderBy($"session_window.start".asc) + .selectExpr("CAST(session_window.start AS STRING)", "CAST(session_window.end AS STRING)", + "id", "sum", "sum2"), + Seq( + Row("2016-03-27 19:39:27", "2016-03-27 19:39:37", "b", 4, 8), + Row("2016-03-27 19:39:34", "2016-03-27 19:39:49", "a", 1, 2), + Row("2016-03-27 19:39:56", "2016-03-27 19:40:14", "a", 2, 4) + ) + ) + } + + test("session window groupBy with multiple keys statement - keys overlapped with sessions") { + val df = Seq( + ("2016-03-27 19:39:34", 1, "a"), + ("2016-03-27 19:39:39", 1, "b"), + ("2016-03-27 19:39:40", 2, "a"), + ("2016-03-27 19:39:45", 2, "b"), + ("2016-03-27 19:39:27", 4, "b")).toDF("time", "value", "id") + + // session window handles sort while applying group by + // whereas time window doesn't + + // expected sessions + // a => (19:39:34 ~ 19:39:50) + // b => (19:39:27 ~ 19:39:37), (19:39:39 ~ 19:39:55) + + checkAnswer( + df.groupBy(session_window($"time", "10 seconds"), 'id) + .agg(count("*").as("counts"), sum("value").as("sum")) + .orderBy($"session_window.start".asc) + .selectExpr("CAST(session_window.start AS STRING)", "CAST(session_window.end AS STRING)", + "id", "counts", "sum"), + + Seq( + Row("2016-03-27 19:39:27", "2016-03-27 19:39:37", "b", 1, 4), + Row("2016-03-27 19:39:34", "2016-03-27 19:39:50", "a", 2, 3), + Row("2016-03-27 19:39:39", "2016-03-27 19:39:55", "b", 2, 3) + ) + ) + } + + test("session window with multi-column projection") { + val df = Seq( + ("2016-03-27 19:39:34", 1, "a"), + ("2016-03-27 19:39:56", 2, "a"), + ("2016-03-27 19:39:27", 4, "b")).toDF("time", "value", "id") + .select(session_window($"time", "10 seconds"), $"value") + .orderBy($"session_window.start".asc) + .select($"session_window.start".cast("string"), $"session_window.end".cast("string"), + $"value") + + val expands = df.queryExecution.optimizedPlan.find(_.isInstanceOf[Expand]) + assert(expands.isEmpty, "Session windows shouldn't require expand") + + checkAnswer( + df, + Seq( + Row("2016-03-27 19:39:27", "2016-03-27 19:39:37", 4), + Row("2016-03-27 19:39:34", "2016-03-27 19:39:44", 1), + Row("2016-03-27 19:39:56", "2016-03-27 19:40:06", 2) + ) + ) + } + + test("session window combined with explode expression") { + val df = Seq( + ("2016-03-27 19:39:34", 1, Seq("a", "b")), + ("2016-03-27 19:39:56", 2, Seq("a", "c", "d"))).toDF("time", "value", "ids") + + checkAnswer( + df.select(session_window($"time", "10 seconds"), $"value", explode($"ids")) + .orderBy($"session_window.start".asc).select("value"), + // first window exploded to two rows for "a", and "b", second window exploded to 3 rows + Seq(Row(1), Row(1), Row(2), Row(2), Row(2)) + ) + } + + test("null timestamps") { + val df = Seq( + ("2016-03-27 09:00:05", 1), + ("2016-03-27 09:00:32", 2), + (null, 3), + (null, 4)).toDF("time", "value") + + checkDataset( + df.select(session_window($"time", "10 seconds"), $"value") + .orderBy($"session_window.start".asc) + .select("value") + .as[Int], + 1, 2) // null columns are dropped + } + + // NOTE: unlike time window, joining session 
windows without grouping + // doesn't arrange sessions, so two rows will be joined only if the session ranges are exactly the same + + test("multiple session windows in a single operator throws nice exception") { + val df = Seq( + ("2016-03-27 09:00:02", 3), + ("2016-03-27 09:00:35", 6)).toDF("time", "value") + val e = intercept[AnalysisException] { + df.select(session_window($"time", "10 second"), session_window($"time", "15 second")) + .collect() + } + assert(e.getMessage.contains( + "Multiple time/session window expressions would result in a cartesian product")) + } + + test("aliased session windows") { + val df = Seq( + ("2016-03-27 19:39:34", 1, Seq("a", "b")), + ("2016-03-27 19:39:56", 2, Seq("a", "c", "d"))).toDF("time", "value", "ids") + + checkAnswer( + df.select(session_window($"time", "10 seconds").as("session_window"), $"value") + .orderBy($"session_window.start".asc) + .select("value"), + Seq(Row(1), Row(2)) + ) + } + + private def withTempTable(f: String => Unit): Unit = { + val tableName = "temp" + Seq( + ("2016-03-27 19:39:34", 1), + ("2016-03-27 19:39:56", 2), + ("2016-03-27 19:39:27", 4)).toDF("time", "value").createOrReplaceTempView(tableName) + try { + f(tableName) + } finally { + spark.catalog.dropTempView(tableName) + } + } + + test("session window in SQL with single string expression") { + withTempTable { table => + checkAnswer( + spark.sql(s"""select session_window(time, "10 seconds"), value from $table""") + .select($"session_window.start".cast(StringType), $"session_window.end".cast(StringType), + $"value"), + Seq( + Row("2016-03-27 19:39:27", "2016-03-27 19:39:37", 4), + Row("2016-03-27 19:39:34", "2016-03-27 19:39:44", 1), + Row("2016-03-27 19:39:56", "2016-03-27 19:40:06", 2) + ) + ) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala index 2953425b1db4..f1be46bf759c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala @@ -241,7 +241,7 @@ class DataFrameTimeWindowingSuite extends QueryTest with SharedSQLContext with B df.select(window($"time", "10 second"), window($"time", "15 second")).collect() } assert(e.getMessage.contains( - "Multiple time window expressions would result in a cartesian product")) + "Multiple time/session window expressions would result in a cartesian product")) } test("aliased windows") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 631ab1b7ece7..7b60ca0510aa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -107,7 +107,8 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("SPARK-14415: All functions should have own descriptions") { for (f <- spark.sessionState.functionRegistry.listFunction()) { - if (!Seq("cube", "grouping", "grouping_id", "rollup", "window").contains(f.unquotedString)) { + val excludes = Seq("cube", "grouping", "grouping_id", "rollup", "window", "session_window") + if (!excludes.contains(f.unquotedString)) { checkKeywordsNotExist(sql(s"describe function `$f`"), "N/A.") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MergingSortWithSessionWindowLinkedListStateIteratorSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MergingSortWithSessionWindowLinkedListStateIteratorSuite.scala new file mode 100644 index 000000000000..6e7b13c63fd0 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MergingSortWithSessionWindowLinkedListStateIteratorSuite.scala @@ -0,0 +1,432 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import java.util.UUID + +import org.apache.hadoop.conf.Configuration +import org.scalatest.exceptions.TestFailedException + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericInternalRow, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection +import org.apache.spark.sql.execution.streaming.state.{SessionWindowLinkedListState, StateStore, StateStoreConf} +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +class MergingSortWithSessionWindowLinkedListStateIteratorSuite extends SharedSQLContext { + + val rowSchema = new StructType().add("key1", StringType).add("key2", IntegerType) + .add("session", new StructType().add("start", LongType).add("end", LongType)) + .add("aggVal1", LongType).add("aggVal2", DoubleType) + val rowAttributes = rowSchema.toAttributes + + val keysWithoutSessionSchema = rowSchema.filter(st => List("key1", "key2").contains(st.name)) + val keysWithoutSessionAttributes = rowAttributes.filter { + attr => List("key1", "key2").contains(attr.name) + } + + val sessionSchema = rowSchema.filter(st => st.name == "session").head + val sessionAttribute = rowAttributes.filter(attr => attr.name == "session").head + + val valuesSchema = rowSchema.filter(st => List("aggVal1", "aggVal2").contains(st.name)) + val valuesAttributes = rowAttributes.filter { + attr => List("aggVal1", "aggVal2").contains(attr.name) + } + + val keyProjection = GenerateUnsafeProjection.generate(keysWithoutSessionAttributes, rowAttributes) + val sessionProjection = GenerateUnsafeProjection.generate(Seq(sessionAttribute), rowAttributes) + + test("no row in input data") { + withSessionWindowLinkedListState(rowAttributes, keysWithoutSessionAttributes) { state => + val iterator = new MergingSortWithSessionWindowLinkedListStateIterator(None.iterator, + state, keysWithoutSessionAttributes, sessionAttribute, rowAttributes) + + assert(!iterator.hasNext) + } + } + + test("no row in input data but having state") { + withSessionWindowLinkedListState(rowAttributes, keysWithoutSessionAttributes) { state => + val srow11 = createRow("a", 1, 55, 85, 50, 2.5) + val srow12 = createRow("a", 1, 105, 140, 30, 2.0) + + setRowsInState(state, 
keyProjection(srow11), srow11, srow12) + + val iterator = new MergingSortWithSessionWindowLinkedListStateIterator(None.iterator, + state, keysWithoutSessionAttributes, sessionAttribute, rowAttributes) + + assert(!iterator.hasNext) + } + } + + test("no previous state") { + withSessionWindowLinkedListState(rowAttributes, keysWithoutSessionAttributes) { state => + val row1 = createRow("a", 1, 100, 110, 10, 1.1) + val row2 = createRow("a", 1, 100, 110, 20, 1.2) + val row3 = createRow("a", 2, 110, 120, 10, 1.1) + val row4 = createRow("a", 2, 115, 125, 20, 1.2) + val rows = List(row1, row2, row3, row4) + + val iterator = new MergingSortWithSessionWindowLinkedListStateIterator(rows.iterator, + state, keysWithoutSessionAttributes, sessionAttribute, rowAttributes) + + rows.foreach { row => + assert(iterator.hasNext) + assertRowsEquals(row, iterator.next()) + } + + assert(!iterator.hasNext) + } + } + + test("single previous state") { + withSessionWindowLinkedListState(rowAttributes, keysWithoutSessionAttributes) { state => + // key1 - events are earlier than state + val row11 = createRow("a", 1, 100, 110, 10, 1.1) + val row12 = createRow("a", 1, 110, 120, 20, 1.2) + + // below will not be picked up since the session is not matched to new events + val srow11 = createRow("a", 1, 200, 220, 10, 1.3) + setRowsInState(state, keyProjection(srow11), srow11) + + // key2 - events are later than state + // below will not be picked up since the session is not matched to new events + val srow21 = createRow("a", 2, 50, 70, 10, 1.1) + + val row21 = createRow("a", 2, 100, 110, 10, 1.1) + val row22 = createRow("a", 2, 110, 120, 20, 1.2) + setRowsInState(state, keyProjection(srow21), srow21) + + // key3 - events are enclosing the state + val row31 = createRow("a", 3, 90, 100, 10, 1.1) + val srow31 = createRow("a", 3, 100, 110, 10, 1.1) + val row32 = createRow("a", 3, 105, 115, 20, 1.2) + setRowsInState(state, keyProjection(srow31), srow31) + + val rows = List(row11, row12) ++ List(row21, row22) ++ List(row31, row32) + + val expectedRows = List(row11, row12) ++ List(row21, row22) ++ + List(row31, srow31, row32) + + val iterator = new MergingSortWithSessionWindowLinkedListStateIterator(rows.iterator, + state, keysWithoutSessionAttributes, sessionAttribute, rowAttributes) + + expectedRows.foreach { row => + assert(iterator.hasNext, "Iterator.hasNext is false while we expected row " + + s"${getTupleFromRow(row)}") + assertRowsEquals(row, iterator.next()) + } + + assert(!iterator.hasNext) + } + } + + test("only emitting sessions in state which enclose events") { + withSessionWindowLinkedListState(rowAttributes, keysWithoutSessionAttributes) { state => + // below example is group by line separated + + val row1 = createRow("a", 1, 10, 20, 1, 1.1) + val row2 = createRow("a", 1, 20, 30, 1, 1.1) + val row3 = createRow("a", 1, 30, 40, 1, 1.1) + val srow1 = createRow("a", 1, 40, 60, 2, 2.2) + val row4 = createRow("a", 1, 40, 50, 1, 1.1) + + // below will not be picked up since the session is not matched to new events + val srow2 = createRow("a", 1, 80, 90, 2, 2.2) + val srow3 = createRow("a", 1, 100, 110, 2, 2.2) + + val srow4 = createRow("a", 1, 120, 130, 2, 2.2) + val row5 = createRow("a", 1, 125, 135, 1, 1.1) + val row6 = createRow("a", 1, 140, 150, 1, 1.1) + + // below will not be picked up since the session is not matched to new events + val srow5 = createRow("a", 1, 180, 200, 2, 2.2) + val srow6 = createRow("a", 1, 220, 260, 2, 2.2) + + setRowsInState(state, keyProjection(srow1), srow1, srow2, srow3, srow4, srow5, srow6) + 
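+      // expected behavior (per the notes above): state sessions are interleaved with the input
+      // rows in session start order, but only srow1 and srow4 overlap a new event; the other
+      // state sessions are skipped, as reflected in expectedRowSequence below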
+ val rows = List(row1, row2, row3, row4, row5, row6) + + val expectedRowSequence = List(row1, row2, row3, srow1, row4, srow4, + row5, row6) + + val iterator = new MergingSortWithSessionWindowLinkedListStateIterator(rows.iterator, + state, keysWithoutSessionAttributes, sessionAttribute, rowAttributes) + + expectedRowSequence.foreach { row => + assert(iterator.hasNext) + assertRowsEquals(row, iterator.next()) + } + + assert(!iterator.hasNext) + } + } + + test("multiple keys in input data and state") { + withSessionWindowLinkedListState(rowAttributes, keysWithoutSessionAttributes) { state => + // key 1 - placing sessions in state to start and end + val srow11 = createRow("a", 1, 85, 105, 50, 2.5) + val row11 = createRow("a", 1, 100, 110, 10, 1.1) + val row12 = createRow("a", 1, 100, 110, 20, 1.2) + val srow12 = createRow("a", 1, 105, 140, 30, 2.0) + + val key1 = keyProjection(srow11) + setRowsInState(state, key1, srow11, srow12) + + val rowsForKey1 = List(row11, row12) + val expectedForKey1 = List(srow11, row11, row12, srow12) + + // key 2 - no state + val row21 = createRow("a", 2, 110, 120, 10, 1.1) + val row22 = createRow("a", 2, 115, 125, 20, 1.2) + + val rowsForKey2 = List(row21, row22) + val expectedForKey2 = List(row21, row22) + + // key 3 - placing sessions in state to only start + + // below will not be picked up since the session is not matched to new events + val srow31 = createRow("a", 3, 105, 115, 30, 2.0) + val srow32 = createRow("a", 3, 120, 125, 30, 2.0) + + val row31 = createRow("a", 3, 130, 140, 10, 1.1) + val row32 = createRow("a", 3, 135, 145, 20, 1.2) + + val key3 = keyProjection(srow31) + setRowsInState(state, key3, srow31, srow32) + + val rowsForKey3 = List(row31, row32) + val expectedForKey3 = List(row31, row32) + + // key 4 - placing sessions in state to only end + val row41 = createRow("a", 4, 100, 110, 10, 1.1) + val row42 = createRow("a", 4, 100, 115, 20, 1.2) + + // below will not be picked up since the session is not matched to new events + val srow41 = createRow("a", 4, 120, 140, 30, 2.0) + val srow42 = createRow("a", 4, 150, 180, 30, 2.0) + + val key4 = keyProjection(srow41) + setRowsInState(state, key4, srow41, srow42) + + val rowsForKey4 = List(row41, row42) + val expectedForKey4 = List(row41, row42) + + // key 5 - placing sessions in state like one row and state session and another + val srow51 = createRow("a", 5, 90, 120, 30, 2.0) + val row51 = createRow("a", 5, 100, 110, 10, 1.1) + val srow52 = createRow("a", 5, 130, 155, 30, 2.0) + val row52 = createRow("a", 5, 140, 160, 20, 1.2) + val srow53 = createRow("a", 5, 160, 190, 30, 2.0) + + val key5 = keyProjection(srow51) + setRowsInState(state, key5, srow51, srow52, srow53) + + val rowsForKey5 = List(row51, row52) + val expectedForKey5 = List(srow51, row51, srow52, row52, srow53) + + val rows = rowsForKey1 ++ rowsForKey2 ++ rowsForKey3 ++ rowsForKey4 ++ rowsForKey5 + + val expectedRowSequence = expectedForKey1 ++ expectedForKey2 ++ expectedForKey3 ++ + expectedForKey4 ++ expectedForKey5 + + val iterator = new MergingSortWithSessionWindowLinkedListStateIterator(rows.iterator, + state, keysWithoutSessionAttributes, sessionAttribute, rowAttributes) + + expectedRowSequence.foreach { row => + assert(iterator.hasNext, s"Iterator closed while we expect ${getTupleFromRow(row)}") + assertRowsEquals(row, iterator.next()) + } + + assert(!iterator.hasNext) + } + } + + test("no keys in input data and state") { + val noKeyRowSchema = new StructType() + .add("session", new StructType().add("start", LongType).add("end", 
LongType)) + .add("aggVal1", LongType).add("aggVal2", DoubleType) + val noKeyRowAttributes = noKeyRowSchema.toAttributes + + val noKeySessionAttribute = noKeyRowAttributes.filter(attr => attr.name == "session").head + + def createNoKeyRow(sessionStart: Long, sessionEnd: Long, + aggVal1: Long, aggVal2: Double): UnsafeRow = { + val genericRow = new GenericInternalRow(4) + val session: Array[Any] = new Array[Any](2) + session(0) = sessionStart + session(1) = sessionEnd + + val sessionRow = new GenericInternalRow(session) + genericRow.update(0, sessionRow) + + genericRow.setLong(1, aggVal1) + genericRow.setDouble(2, aggVal2) + + val rowProjection = GenerateUnsafeProjection.generate(noKeyRowAttributes, noKeyRowAttributes) + rowProjection(genericRow) + } + + def assertNoKeyRowsEquals(expectedRow: InternalRow, retRow: InternalRow): Unit = { + assert(retRow.getStruct(0, 2).getLong(0) == expectedRow.getStruct(0, 2).getLong(0)) + assert(retRow.getStruct(0, 2).getLong(1) == expectedRow.getStruct(0, 2).getLong(1)) + assert(retRow.getLong(1) === expectedRow.getLong(1)) + assert(doubleEquals(retRow.getDouble(2), expectedRow.getDouble(2))) + } + + def setNoKeyRowsInState(state: SessionWindowLinkedListState, rows: UnsafeRow*) + : Unit = { + def getSessionStart(row: UnsafeRow): Long = { + row.getStruct(0, 2).getLong(0) + } + + val key = new UnsafeRow(0) + val iter = rows.sortBy(getSessionStart).iterator + + var prevSessionStart: Option[Long] = None + while (iter.hasNext) { + val row = iter.next() + val sessionStart = getSessionStart(row) + if (prevSessionStart.isDefined) { + state.addAfter(key, sessionStart, row, prevSessionStart.get) + } else { + state.setHead(key, sessionStart, row) + } + + prevSessionStart = Some(sessionStart) + } + } + + withSessionWindowLinkedListState(noKeyRowAttributes, Seq.empty[Attribute]) { state => + // this will not be picked up because the session in state is not enclosing events + val srow1 = createNoKeyRow(10, 16, 10, 21) + val srow2 = createNoKeyRow(17, 27, 2, 39) + + val srow3 = createNoKeyRow(35, 40, 1, 35) + val row1 = createNoKeyRow(40, 45, 10, 45) + setNoKeyRowsInState(state, srow1, srow2, srow3) + + val rows = List(row1) + + val expectedRowSequence = List(srow3, row1) + + val iterator = new MergingSortWithSessionWindowLinkedListStateIterator(rows.iterator, + state, Seq.empty[Attribute], noKeySessionAttribute, noKeyRowAttributes) + + expectedRowSequence.foreach { row => + assert(iterator.hasNext) + assertNoKeyRowsEquals(row, iterator.next()) + } + + assert(!iterator.hasNext) + } + } + + private def setRowsInState(state: SessionWindowLinkedListState, key: UnsafeRow, + rows: UnsafeRow*): Unit = { + def getSessionStart(row: UnsafeRow): Long = { + row.getStruct(2, 2).getLong(0) + } + + val iter = rows.sortBy(getSessionStart).iterator + + var prevSessionStart: Option[Long] = None + while (iter.hasNext) { + val row = iter.next() + val sessionStart = getSessionStart(row) + if (prevSessionStart.isDefined) { + state.addAfter(key, sessionStart, row, prevSessionStart.get) + } else { + state.setHead(key, sessionStart, row) + } + + prevSessionStart = Some(sessionStart) + } + } + + private def createRow(key1: String, key2: Int, sessionStart: Long, sessionEnd: Long, + aggVal1: Long, aggVal2: Double): UnsafeRow = { + val genericRow = new GenericInternalRow(6) + if (key1 != null) { + genericRow.update(0, UTF8String.fromString(key1)) + } else { + genericRow.setNullAt(0) + } + genericRow.setInt(1, key2) + + val session: Array[Any] = new Array[Any](2) + session(0) = sessionStart + 
session(1) = sessionEnd + + val sessionRow = new GenericInternalRow(session) + genericRow.update(2, sessionRow) + + genericRow.setLong(3, aggVal1) + genericRow.setDouble(4, aggVal2) + + val rowProjection = GenerateUnsafeProjection.generate(rowAttributes, rowAttributes) + rowProjection(genericRow) + } + + private def doubleEquals(value1: Double, value2: Double): Boolean = { + value1 > value2 - 0.000001 && value1 < value2 + 0.000001 + } + + private def getTupleFromRow(row: InternalRow): (String, Int, Long, Long, Long, Double) = { + (row.getString(0), row.getInt(1), row.getStruct(2, 2).getLong(0), + row.getStruct(2, 2).getLong(1), row.getLong(3), row.getDouble(4)) + } + + private def assertRowsEquals(expectedRow: InternalRow, retRow: InternalRow): Unit = { + val tupleFromExpectedRow = getTupleFromRow(expectedRow) + val tupleFromInternalRow = getTupleFromRow(retRow) + try { + assert(tupleFromExpectedRow._1 === tupleFromInternalRow._1) + assert(tupleFromExpectedRow._2 === tupleFromInternalRow._2) + assert(tupleFromExpectedRow._3 === tupleFromInternalRow._3) + assert(tupleFromExpectedRow._4 === tupleFromInternalRow._4) + assert(tupleFromExpectedRow._5 === tupleFromInternalRow._5) + assert(doubleEquals(tupleFromExpectedRow._6, tupleFromInternalRow._6)) + } catch { + case e: TestFailedException => + throw new TestFailedException(s"$tupleFromExpectedRow did not equal $tupleFromInternalRow", + e, e.failedCodeStackDepth) + } + } + + private def withSessionWindowLinkedListState( + inputValueAttribs: Seq[Attribute], + keyAttribs: Seq[Attribute])(f: SessionWindowLinkedListState => Unit): Unit = { + + withTempDir { file => + val storeConf = new StateStoreConf() + val stateInfo = StatefulOperatorStateInfo(file.getAbsolutePath, UUID.randomUUID, 0, 0, 5) + + val state = new SessionWindowLinkedListState(s"session-${stateInfo.operatorId}-", + inputValueAttribs, keyAttribs, Some(stateInfo), storeConf, new Configuration) + try { + f(state) + } finally { + state.abortIfNeeded() + } + } + StateStore.stop() + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/UpdatingSessionIteratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/UpdatingSessionIteratorSuite.scala new file mode 100644 index 000000000000..ff6d774200f4 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/UpdatingSessionIteratorSuite.scala @@ -0,0 +1,420 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.streaming + +import java.util.Properties + +import org.apache.spark._ +import org.apache.spark.memory.{TaskMemoryManager, TestMemoryManager} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericInternalRow, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection +import org.apache.spark.sql.execution.aggregate.UpdatingSessionIterator +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +class UpdatingSessionIteratorSuite extends SharedSQLContext { + + val rowSchema = new StructType().add("key1", StringType).add("key2", IntegerType) + .add("session", new StructType().add("start", LongType).add("end", LongType)) + .add("aggVal1", LongType).add("aggVal2", DoubleType) + val rowAttributes = rowSchema.toAttributes + + val keysWithSessionSchema = rowSchema.filter { attr => + List("key1", "key2", "session").contains(attr.name) + } + val keysWithSessionAttributes = rowAttributes.filter { attr => + List("key1", "key2", "session").contains(attr.name) + } + + val sessionSchema = rowSchema.filter(st => st.name == "session").head + val sessionAttribute = rowAttributes.filter(attr => attr.name == "session").head + + val valuesSchema = rowSchema.filter(st => List("aggVal1", "aggVal2").contains(st.name)) + val valuesAttributes = rowAttributes.filter { + attr => List("aggVal1", "aggVal2").contains(attr.name) + } + + override def beforeAll(): Unit = { + super.beforeAll() + val taskManager = new TaskMemoryManager(new TestMemoryManager(sqlContext.sparkContext.conf), 0) + TaskContext.setTaskContext( + new TaskContextImpl(0, 0, 0, 0, 0, taskManager, new Properties, null)) + } + + override def afterAll(): Unit = try { + TaskContext.unset() + } finally { + super.afterAll() + } + + // just copying default values to avoid bothering with SQLContext + val inMemoryThreshold = 4096 + val spillThreshold = Int.MaxValue + + test("no row") { + val iterator = new UpdatingSessionIterator(None.iterator, keysWithSessionAttributes, + sessionAttribute, rowAttributes, inMemoryThreshold, spillThreshold) + + assert(!iterator.hasNext) + } + + test("only one row") { + val rows = List(createRow("a", 1, 100, 110, 10, 1.1)) + + val iterator = new UpdatingSessionIterator(rows.iterator, keysWithSessionAttributes, + sessionAttribute, rowAttributes, inMemoryThreshold, spillThreshold) + + assert(iterator.hasNext) + + val retRow = iterator.next() + assertRowsEquals(retRow, rows.head) + + assert(!iterator.hasNext) + } + + test("one session per key, one key") { + val row1 = createRow("a", 1, 100, 110, 10, 1.1) + val row2 = createRow("a", 1, 100, 110, 20, 1.2) + val row3 = createRow("a", 1, 105, 115, 30, 1.3) + val row4 = createRow("a", 1, 113, 123, 40, 1.4) + val rows = List(row1, row2, row3, row4) + + val iterator = new UpdatingSessionIterator(rows.iterator, keysWithSessionAttributes, + sessionAttribute, rowAttributes, inMemoryThreshold, spillThreshold) + + val retRows = rows.indices.map { _ => + assert(iterator.hasNext) + iterator.next() + } + + retRows.zip(rows).foreach { case (retRow, expectedRow) => + // session being expanded to (100 ~ 123) + assertRowsEqualsWithNewSession(expectedRow, retRow, 100, 123) + } + + assert(iterator.hasNext === false) + } + + test("one session per key, multi keys") { + val row1 = createRow("a", 1, 100, 110, 10, 1.1) + val row2 = createRow("a", 1, 100, 110, 20, 1.2) + val row3 = 
createRow("a", 1, 105, 115, 30, 1.3) + val row4 = createRow("a", 1, 113, 123, 40, 1.4) + val rows1 = List(row1, row2, row3, row4) + + val row5 = createRow("a", 2, 110, 120, 10, 1.1) + val row6 = createRow("a", 2, 115, 125, 20, 1.2) + val row7 = createRow("a", 2, 117, 127, 30, 1.3) + val row8 = createRow("a", 2, 125, 135, 40, 1.4) + val rows2 = List(row5, row6, row7, row8) + + val rowsAll = rows1 ++ rows2 + + val iterator = new UpdatingSessionIterator(rowsAll.iterator, keysWithSessionAttributes, + sessionAttribute, rowAttributes, inMemoryThreshold, spillThreshold) + + val retRows1 = rows1.indices.map { _ => + assert(iterator.hasNext) + iterator.next() + } + val retRows2 = rows2.indices.map { _ => + assert(iterator.hasNext) + iterator.next() + } + + retRows1.zip(rows1).foreach { case (retRow, expectedRow) => + // session being expanded to (100 ~ 123) + assertRowsEqualsWithNewSession(expectedRow, retRow, 100, 123) + } + + retRows2.zip(rows2).foreach { case (retRow, expectedRow) => + // session being expanded to (110 ~ 135) + assertRowsEqualsWithNewSession(expectedRow, retRow, 110, 135) + } + + assert(iterator.hasNext === false) + } + + test("multiple sessions per key, single key") { + val row1 = createRow("a", 1, 100, 110, 10, 1.1) + val row2 = createRow("a", 1, 105, 115, 20, 1.2) + val rows1 = List(row1, row2) + + val row3 = createRow("a", 1, 125, 135, 30, 1.3) + val row4 = createRow("a", 1, 127, 137, 40, 1.4) + val rows2 = List(row3, row4) + + val rowsAll = rows1 ++ rows2 + + val iterator = new UpdatingSessionIterator(rowsAll.iterator, keysWithSessionAttributes, + sessionAttribute, rowAttributes, inMemoryThreshold, spillThreshold) + + val retRows1 = rows1.indices.map { _ => + assert(iterator.hasNext) + iterator.next() + } + + val retRows2 = rows2.indices.map { _ => + assert(iterator.hasNext) + iterator.next() + } + + retRows1.zip(rows1).foreach { case (retRow, expectedRow) => + // session being expanded to (100 ~ 115) + assertRowsEqualsWithNewSession(expectedRow, retRow, 100, 115) + } + + retRows2.zip(rows2).foreach { case (retRow, expectedRow) => + // session being expanded to (125 ~ 137) + assertRowsEqualsWithNewSession(expectedRow, retRow, 125, 137) + } + + assert(iterator.hasNext === false) + } + + test("multiple sessions per key, multi keys") { + val row1 = createRow("a", 1, 100, 110, 10, 1.1) + val row2 = createRow("a", 1, 100, 110, 20, 1.2) + val rows1 = List(row1, row2) + + val row3 = createRow("a", 1, 115, 125, 30, 1.3) + val row4 = createRow("a", 1, 119, 129, 40, 1.4) + val rows2 = List(row3, row4) + + val row5 = createRow("a", 2, 110, 120, 10, 1.1) + val row6 = createRow("a", 2, 115, 125, 20, 1.2) + val rows3 = List(row5, row6) + + val row7 = createRow("a", 2, 127, 137, 30, 1.3) + val row8 = createRow("a", 2, 135, 145, 40, 1.4) + val rows4 = List(row7, row8) + + val rowsAll = rows1 ++ rows2 ++ rows3 ++ rows4 + + val iterator = new UpdatingSessionIterator(rowsAll.iterator, keysWithSessionAttributes, + sessionAttribute, rowAttributes, inMemoryThreshold, spillThreshold) + + val retRows1 = rows1.indices.map { _ => + assert(iterator.hasNext) + iterator.next() + } + + val retRows2 = rows2.indices.map { _ => + assert(iterator.hasNext) + iterator.next() + } + + val retRows3 = rows3.indices.map { _ => + assert(iterator.hasNext) + iterator.next() + } + + val retRows4 = rows4.indices.map { _ => + assert(iterator.hasNext) + iterator.next() + } + + retRows1.zip(rows1).foreach { case (retRow, expectedRow) => + // session being expanded to (100 ~ 110) + 
assertRowsEqualsWithNewSession(expectedRow, retRow, 100, 110) + } + + retRows2.zip(rows2).foreach { case (retRow, expectedRow) => + // session being expanded to (115 ~ 129) + assertRowsEqualsWithNewSession(expectedRow, retRow, 115, 129) + } + + retRows3.zip(rows3).foreach { case (retRow, expectedRow) => + // session being expanded to (110 ~ 125) + assertRowsEqualsWithNewSession(expectedRow, retRow, 110, 125) + } + + retRows4.zip(rows4).foreach { case (retRow, expectedRow) => + // session being expanded to (127 ~ 145) + assertRowsEqualsWithNewSession(expectedRow, retRow, 127, 145) + } + + assert(iterator.hasNext === false) + } + + test("throws exception if data is not sorted by session start") { + val row1 = createRow("a", 1, 100, 110, 10, 1.1) + val row2 = createRow("a", 1, 100, 110, 20, 1.2) + val row3 = createRow("a", 1, 95, 105, 30, 1.3) + val row4 = createRow("a", 1, 113, 123, 40, 1.4) + val rows = List(row1, row2, row3, row4) + + val iterator = new UpdatingSessionIterator(rows.iterator, keysWithSessionAttributes, + sessionAttribute, rowAttributes, inMemoryThreshold, spillThreshold) + + // UpdatingSessionIterator can't detect error on hasNext + assert(iterator.hasNext) + + // when calling next() it can detect error and throws IllegalStateException + intercept[IllegalStateException] { + iterator.next() + } + + // afterwards, calling either hasNext() or next() will throw IllegalStateException + intercept[IllegalStateException] { + iterator.hasNext + } + + intercept[IllegalStateException] { + iterator.next() + } + } + + test("throws exception if data is not sorted by key") { + val row1 = createRow("a", 1, 100, 110, 10, 1.1) + val row2 = createRow("a", 2, 100, 110, 20, 1.2) + val row3 = createRow("a", 1, 113, 123, 40, 1.4) + val rows = List(row1, row2, row3) + + val iterator = new UpdatingSessionIterator(rows.iterator, keysWithSessionAttributes, + sessionAttribute, rowAttributes, inMemoryThreshold, spillThreshold) + + // UpdatingSessionIterator can't detect error on hasNext + assert(iterator.hasNext) + + assertRowsEquals(row1, iterator.next()) + + assert(iterator.hasNext) + + // second row itself is OK but while finding end of session it reads third row, and finds + // its key is already finished processing, hence precondition for sorting is broken, and + // it throws IllegalStateException + intercept[IllegalStateException] { + iterator.next() + } + + // afterwards, calling either hasNext() or next() will throw IllegalStateException + intercept[IllegalStateException] { + iterator.hasNext + } + + intercept[IllegalStateException] { + iterator.next() + } + } + + test("no key") { + val noKeyRowSchema = new StructType() + .add("session", new StructType().add("start", LongType).add("end", LongType)) + .add("aggVal1", LongType).add("aggVal2", DoubleType) + val noKeyRowAttributes = noKeyRowSchema.toAttributes + + val noKeySessionAttribute = noKeyRowAttributes.filter(attr => attr.name == "session").head + + def createNoKeyRow(sessionStart: Long, sessionEnd: Long, + aggVal1: Long, aggVal2: Double): UnsafeRow = { + val genericRow = new GenericInternalRow(4) + val session: Array[Any] = new Array[Any](2) + session(0) = sessionStart + session(1) = sessionEnd + + val sessionRow = new GenericInternalRow(session) + genericRow.update(0, sessionRow) + + genericRow.setLong(1, aggVal1) + genericRow.setDouble(2, aggVal2) + + val rowProjection = GenerateUnsafeProjection.generate(noKeyRowAttributes, noKeyRowAttributes) + rowProjection(genericRow) + } + + def assertNoKeyRowsEqualsWithNewSession(expectedRow: 
InternalRow, retRow: InternalRow, + newSessionStart: Long, newSessionEnd: Long): Unit = { + assert(retRow.getStruct(0, 2).getLong(0) == newSessionStart) + assert(retRow.getStruct(0, 2).getLong(1) == newSessionEnd) + assert(retRow.getLong(1) === expectedRow.getLong(1)) + assert(doubleEquals(retRow.getDouble(2), expectedRow.getDouble(2))) + } + + val row1 = createNoKeyRow(100, 110, 10, 1.1) + val row2 = createNoKeyRow(100, 110, 20, 1.2) + val row3 = createNoKeyRow(105, 115, 30, 1.3) + val row4 = createNoKeyRow(113, 123, 40, 1.4) + val rows = List(row1, row2, row3, row4) + + val iterator = new UpdatingSessionIterator(rows.iterator, Seq(noKeySessionAttribute), + noKeySessionAttribute, noKeyRowAttributes, inMemoryThreshold, spillThreshold) + + val retRows = rows.indices.map { _ => + assert(iterator.hasNext) + iterator.next() + } + + retRows.zip(rows).foreach { case (retRow, expectedRow) => + // session being expanded to (100 ~ 123) + assertNoKeyRowsEqualsWithNewSession(expectedRow, retRow, 100, 123) + } + + assert(iterator.hasNext === false) + } + + private def createRow(key1: String, key2: Int, sessionStart: Long, sessionEnd: Long, + aggVal1: Long, aggVal2: Double): UnsafeRow = { + val genericRow = new GenericInternalRow(6) + if (key1 != null) { + genericRow.update(0, UTF8String.fromString(key1)) + } else { + genericRow.setNullAt(0) + } + genericRow.setInt(1, key2) + + val session: Array[Any] = new Array[Any](2) + session(0) = sessionStart + session(1) = sessionEnd + + val sessionRow = new GenericInternalRow(session) + genericRow.update(2, sessionRow) + + genericRow.setLong(3, aggVal1) + genericRow.setDouble(4, aggVal2) + + val rowProjection = GenerateUnsafeProjection.generate(rowAttributes, rowAttributes) + rowProjection(genericRow) + } + + private def doubleEquals(value1: Double, value2: Double): Boolean = { + value1 > value2 - 0.000001 && value1 < value2 + 0.000001 + } + + private def assertRowsEquals(expectedRow: InternalRow, retRow: InternalRow): Unit = { + assert(retRow.getString(0) === expectedRow.getString(0)) + assert(retRow.getInt(1) === expectedRow.getInt(1)) + assert(retRow.getStruct(2, 2).getLong(0) == expectedRow.getStruct(2, 2).getLong(0)) + assert(retRow.getStruct(2, 2).getLong(1) == expectedRow.getStruct(2, 2).getLong(1)) + assert(retRow.getLong(3) === expectedRow.getLong(3)) + assert(doubleEquals(retRow.getDouble(4), expectedRow.getDouble(4))) + } + + private def assertRowsEqualsWithNewSession(expectedRow: InternalRow, retRow: InternalRow, + newSessionStart: Long, newSessionEnd: Long): Unit = { + assert(retRow.getString(0) === expectedRow.getString(0)) + assert(retRow.getInt(1) === expectedRow.getInt(1)) + assert(retRow.getStruct(2, 2).getLong(0) == newSessionStart) + assert(retRow.getStruct(2, 2).getLong(1) == newSessionEnd) + assert(retRow.getLong(3) === expectedRow.getLong(3)) + assert(doubleEquals(retRow.getDouble(4), expectedRow.getDouble(4))) + } + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SessionWindowLinkedListStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SessionWindowLinkedListStateSuite.scala new file mode 100644 index 000000000000..af7f2f986a30 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SessionWindowLinkedListStateSuite.scala @@ -0,0 +1,376 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.state + +import java.util.UUID + +import scala.util.Random + +import org.apache.hadoop.conf.Configuration +import org.scalatest.exceptions.TestFailedException + +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, GenericInternalRow, LessThanOrEqual, Literal, UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate +import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark +import org.apache.spark.sql.execution.streaming.StatefulOperatorStateInfo +import org.apache.spark.sql.streaming.StreamTest +import org.apache.spark.sql.types._ + +class SessionWindowLinkedListStateSuite extends StreamTest { + + test("add sessions - normal case") { + withSessionWindowLinkedListState(inputValueAttribs, keyExprs) { state => + implicit val st = state + + assert(get(20) === Seq.empty) + setHead(20, 3, time = 3) + assert(get(20) === Seq(3)) + assert(numRows === 1) + + // add element before head: 1 is the new head + addBefore(20, 1, time = 1, targetTime = 3) + assert(get(20) === Seq(1, 3)) + assert(numRows === 2) + + // add element before other element but after head + addBefore(20, 2, time = 2, targetTime = 3) + assert(get(20) === Seq(1, 2, 3)) + assert(numRows === 3) + + // add element at the end + addAfter(20, 5, time = 5, targetTime = 3) + assert(get(20) === Seq(1, 2, 3, 5)) + assert(numRows === 4) + + // add element after other element but before tail element + addAfter(20, 4, time = 4, targetTime = 3) + assert(get(20) === Seq(1, 2, 3, 4, 5)) + assert(numRows === 5) + + update(20, 100, time = 3) + assert(get(20) === Seq(1, 2, 100, 4, 5)) + assert(numRows === 5) + + assert(get(30) === Seq.empty) + setHead(30, 1, time = 1) + assert(get(30) === Seq(1)) + assert(get(20) === Seq(1, 2, 100, 4, 5)) + assert(numRows === 6) + } + } + + test("add sessions - improper usage") { + withSessionWindowLinkedListState(inputValueAttribs, keyExprs) { state => + implicit val st = state + + assert(get(20) === Seq.empty) + + setHead(20, 2, time = 2) + // setting head twice + intercept[IllegalArgumentException] { + setHead(20, 2, time = 2) + } + + // add element with dangling pointer + intercept[IllegalArgumentException] { + addBefore(20, 1, time = 1, targetTime = 3) + } + + // add element with dangling pointer + intercept[IllegalArgumentException] { + addAfter(20, 2, time = 5, targetTime = 3) + } + } + } + + test("remove sessions - normal usage") { + withSessionWindowLinkedListState(inputValueAttribs, keyExprs) { state => + implicit val st = state + + assert(numRows === 0) + + setHead(20, 1, time = 1) + addAfter(20, 2, time = 2, targetTime = 1) + addAfter(20, 3, time = 3, targetTime = 2) + addAfter(20, 4, time = 4, targetTime = 3) + assert(numRows === 4) + + // remove head which list has another elements as 
well + remove(20, time = 1) + assert(get(20) === Seq(2, 3, 4)) + assert(numRows === 3) + + // remove intermediate element + remove(20, time = 3) + assert(get(20) === Seq(2, 4)) + assert(numRows === 2) + + // remove tail element + remove(20, time = 4) + assert(get(20) === Seq(2)) + assert(numRows === 1) + + // remove head which list has only one element + remove(20, time = 2) + assert(get(20) === Seq.empty) + assert(numRows === 0) + } + } + + test("remove sessions - improper usage") { + withSessionWindowLinkedListState(inputValueAttribs, keyExprs) { state => + implicit val st = state + + assert(get(20) === Seq.empty) + setHead(20, 2, time = 2) + + // try to remove non-exist time + intercept[IllegalArgumentException] { + remove(20, 3) + } + + assert(get(20) === Seq(2)) + assert(numRows === 1) + } + } + + test("get all pairs, iterate pointers, find first") { + withSessionWindowLinkedListState(inputValueAttribs, keyExprs) { state => + implicit val st = state + assert(numRows === 0) + + setHead(20, 1, time = 1) + addAfter(20, 2, time = 2, targetTime = 1) + addAfter(20, 3, time = 3, targetTime = 2) + addAfter(20, 4, time = 4, targetTime = 3) + + setHead(30, 5, time = 5) + addAfter(30, 6, time = 6, targetTime = 5) + addAfter(30, 7, time = 7, targetTime = 6) + addAfter(30, 8, time = 8, targetTime = 7) + + setHead(40, 10, time = 10) + addAfter(40, 11, time = 11, targetTime = 10) + addAfter(40, 12, time = 12, targetTime = 11) + addAfter(40, 13, time = 13, targetTime = 12) + + assert(numRows === 12) + + // must keep input order per key + val groupedTuples = getAllRowPairs.groupBy(_._1) + assert(groupedTuples(20).map(_._2) === Seq(1, 2, 3, 4)) + assert(groupedTuples(30).map(_._2) === Seq(5, 6, 7, 8)) + assert(groupedTuples(40).map(_._2) === Seq(10, 11, 12, 13)) + + // iterate pointers + + val expected = Seq((1, None, Some(2)), (2, Some(1), Some(3)), (3, Some(2), Some(4)), + (4, Some(3), None)) + expected.foreach { case (current, expectedPrev, expectedNext) => + assert(getPrevTime(20, current) == expectedPrev) + assert(getNextTime(20, current) == expectedNext) + } + + assert(iterateTimes(20).toSeq === expected.map(s => (s._1, s._2, s._3))) + + // against non-exist key + assert(iterateTimes(100).toSeq === Seq.empty) + + // find first + + assert(findFirstTime(20, time => time > 0) === Some(1)) + assert(findFirstTime(20, time => time > 3) === Some(4)) + assert(findFirstTime(20, time => time > 5) === None) + + // using start time to skip elements + assert(findFirstTime(20, time => time > 0, startTime = 3) === Some(3)) + assert(findFirstTime(20, time => time > 3, startTime = 1) === Some(4)) + intercept[IllegalArgumentException] { + findFirstTime(20, time => time > 3, startTime = 7) + } + + // against non-exist key + assert(findFirstTime(100, time => time > 1) === None) + } + } + + test("remove by watermark - stop on condition mismatch == true") { + removeByWatermarkTest(stopOnConditionMismatch = true) + } + + test("remove by watermark - stop on condition mismatch == false") { + removeByWatermarkTest(stopOnConditionMismatch = false) + } + + private def removeByWatermarkTest(stopOnConditionMismatch: Boolean): Unit = { + withSessionWindowLinkedListState(inputValueAttribs, keyExprs) { state => + implicit val st = state + assert(numRows === 0) + + setHead(20, 1, time = 1) + addAfter(20, 2, time = 2, targetTime = 1) + addAfter(20, 3, time = 3, targetTime = 2) + addAfter(20, 4, time = 4, targetTime = 3) + + setHead(30, 5, time = 5) + addAfter(30, 6, time = 6, targetTime = 5) + addAfter(30, 7, time = 7, targetTime 
= 6) + addAfter(30, 8, time = 8, targetTime = 7) + + setHead(40, 10, time = 10) + addAfter(40, 11, time = 11, targetTime = 10) + addAfter(40, 12, time = 12, targetTime = 11) + addAfter(40, 13, time = 13, targetTime = 12) + + assert(numRows === 12) + + // must keep input order per key + val groupedTuples = removeByValue(6, stopOnConditionMismatch).groupBy(_._1) + assert(groupedTuples(20).map(_._2) === Seq(1, 2, 3, 4)) + assert(groupedTuples(30).map(_._2) === Seq(5, 6)) + assert(groupedTuples.get(40).isEmpty) + + assert(get(20) === Seq.empty) + assert(get(30) === Seq(7, 8)) + assert(get(40) === Seq(10, 11, 12, 13)) + assert(numRows === 6) + } + } + + val watermarkMetadata = new MetadataBuilder().putLong(EventTimeWatermark.delayKey, 10).build() + val inputValueSchema = new StructType() + .add(StructField("time", IntegerType, metadata = watermarkMetadata)) + .add(StructField("value", BooleanType)) + val inputValueAttribs = inputValueSchema.toAttributes + val inputValueAttribWithWatermark = inputValueAttribs(0) + val keyExprs = Seq[Expression](Literal(false), inputValueAttribWithWatermark, Literal(10.0)) + + val inputValueGen = UnsafeProjection.create(inputValueAttribs.map(_.dataType).toArray) + val keyGen = UnsafeProjection.create(keyExprs.map(_.dataType).toArray) + + def toInputValue(i: Int): UnsafeRow = { + inputValueGen.apply(new GenericInternalRow(Array[Any](i, false))) + } + + def toKeyRow(i: Int): UnsafeRow = { + keyGen.apply(new GenericInternalRow(Array[Any](false, i, 10.0))) + } + + def toKeyInt(inputKeyRow: UnsafeRow): Int = inputKeyRow.getInt(1) + + def toValueInt(inputValueRow: UnsafeRow): Int = inputValueRow.getInt(0) + + def setHead(key: Int, value: Int, time: Int) + (implicit state: SessionWindowLinkedListState): Unit = { + state.setHead(toKeyRow(key), time, toInputValue(value)) + } + + def addBefore(key: Int, value: Int, time: Int, targetTime: Int) + (implicit state: SessionWindowLinkedListState): Unit = { + state.addBefore(toKeyRow(key), time, toInputValue(value), targetTime) + } + + def addAfter(key: Int, value: Int, time: Int, targetTime: Int) + (implicit state: SessionWindowLinkedListState): Unit = { + state.addAfter(toKeyRow(key), time, toInputValue(value), targetTime) + } + + def update(key: Int, value: Int, time: Int) + (implicit state: SessionWindowLinkedListState): Unit = { + state.update(toKeyRow(key), time, toInputValue(value)) + } + + def remove(key: Int, time: Int)(implicit state: SessionWindowLinkedListState): Unit = { + state.remove(toKeyRow(key), time) + } + + def get(key: Int)(implicit state: SessionWindowLinkedListState): Seq[Int] = { + state.get(toKeyRow(key)).map(toValueInt).toSeq + } + + def iterateTimes(key: Int)(implicit state: SessionWindowLinkedListState) + : Iterator[(Int, Option[Int], Option[Int])] = { + state.iteratePointers(toKeyRow(key)).map { s => + (s._1.toInt, s._2.map(_.toInt), s._3.map(_.toInt)) + } + } + + def getPrevTime(key: Int, time: Int)(implicit state: SessionWindowLinkedListState) + : Option[Int] = { + state.getPrevSessionStart(toKeyRow(key), time).map(_.toInt) + } + + def getNextTime(key: Int, time: Int)(implicit state: SessionWindowLinkedListState) + : Option[Int] = { + state.getNextSessionStart(toKeyRow(key), time).map(_.toInt) + } + + def findFirstTime(key: Int, predicate: Int => Boolean) + (implicit state: SessionWindowLinkedListState): Option[Int] = { + val ret = state.findFirstSessionStartEnsurePredicate( + toKeyRow(key), (s: Long) => predicate.apply(s.intValue())) + ret.map(_.intValue()) + } + + def findFirstTime(key: Int, 
predicate: Int => Boolean, startTime: Int) + (implicit state: SessionWindowLinkedListState): Option[Int] = { + val ret = state.findFirstSessionStartEnsurePredicate( + toKeyRow(key), (s: Long) => predicate.apply(s.intValue()), startTime) + ret.map(_.intValue()) + } + + def getAllRowPairs(implicit state: SessionWindowLinkedListState): Seq[(Int, Int)] = { + state.getAllRowPairs + .map(pair => (toKeyInt(pair.key), toValueInt(pair.value))) + .toSeq + } + + /** Remove values where `time <= threshold` */ + def removeByValue(watermark: Long, stopOnConditionMismatch: Boolean) + (implicit state: SessionWindowLinkedListState) + : Seq[(Int, Int)] = { + val expr = LessThanOrEqual(inputValueAttribWithWatermark, Literal(watermark)) + state.removeByValueCondition( + GeneratePredicate.generate(expr, inputValueAttribs).eval _, + stopOnConditionMismatch) + .map(pair => (toKeyInt(pair.key), toValueInt(pair.value))) + .toSeq + } + + def numRows(implicit state: SessionWindowLinkedListState): Long = { + state.metrics.numKeys + } + + def withSessionWindowLinkedListState( + inputValueAttribs: Seq[Attribute], + keyExprs: Seq[Expression])(f: SessionWindowLinkedListState => Unit): Unit = { + + withTempDir { file => + val storeConf = new StateStoreConf() + val stateInfo = StatefulOperatorStateInfo(file.getAbsolutePath, UUID.randomUUID, 0, 0, 5) + val state = new SessionWindowLinkedListState("testing", inputValueAttribs, keyExprs, + Some(stateInfo), storeConf, new Configuration) + try { + f(state) + } finally { + state.abortIfNeeded() + } + } + StateStore.stop() + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala index c696204cecc2..1558669f3fd9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.{AnalysisException, Dataset} import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.sql.functions.{count, window} +import org.apache.spark.sql.functions.{count, max, session_window, sum, window} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.OutputMode._ import org.apache.spark.util.Utils diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingSessionWindowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingSessionWindowSuite.scala new file mode 100644 index 000000000000..1297b5408cf3 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingSessionWindowSuite.scala @@ -0,0 +1,508 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming + +import org.scalatest.{BeforeAndAfter, Matchers} + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.execution.streaming.MemoryStream +import org.apache.spark.sql.functions.{count, session_window, sum} +import org.apache.spark.sql.internal.SQLConf + +class StreamingSessionWindowSuite extends StreamTest + with BeforeAndAfter with Matchers with Logging { + + import testImplicits._ + + after { + sqlContext.streams.active.foreach(_.stop()) + } + + def testWithAllOptionsMergingSessionInLocalPartition(name: String, confPairs: (String, String)*) + (func: => Any): Unit = { + val key = SQLConf.STREAMING_SESSION_WINDOW_MERGE_SESSIONS_IN_LOCAL_PARTITION.key + val availableOptions = Seq(true, false) + + for (enabled <- availableOptions) { + test(s"$name - merging sessions in local partition: $enabled") { + withSQLConf(confPairs ++ Seq(key -> enabled.toString): _*) { + func + } + } + } + } + + testWithAllOptionsMergingSessionInLocalPartition("complete mode - session window") { + // Implements StructuredSessionization.scala leveraging "session" function + // as a test, to verify the sessionization works with simple example + + // note that complete mode doesn't honor watermark: even it is specified, watermark will be + // always Unix timestamp 0 + + val inputData = MemoryStream[(String, Long)] + + // Split the lines into words, treat words as sessionId of events + val events = inputData.toDF() + .select($"_1".as("value"), $"_2".as("timestamp")) + .withColumn("eventTime", $"timestamp".cast("timestamp")) + .selectExpr("explode(split(value, ' ')) AS sessionId", "eventTime") + + val sessionUpdates = events + .groupBy(session_window($"eventTime", "10 seconds") as 'session, 'sessionId) + .agg(count("*").as("numEvents")) + .selectExpr("sessionId", "CAST(session.start AS LONG)", "CAST(session.end AS LONG)", + "CAST(session.end AS LONG) - CAST(session.start AS LONG) AS durationMs", + "numEvents") + + testStream(sessionUpdates, OutputMode.Complete())( + AddData(inputData, + ("hello world spark streaming", 40L), + ("world hello structured streaming", 41L) + ), + CheckNewAnswer( + ("hello", 40, 51, 11, 2), + ("world", 40, 51, 11, 2), + ("streaming", 40, 51, 11, 2), + ("spark", 40, 50, 10, 1), + ("structured", 41, 51, 10, 1) + ), + + // placing new sessions "before" previous sessions + AddData(inputData, ("spark streaming", 25L)), + CheckNewAnswer( + ("spark", 25, 35, 10, 1), + ("streaming", 25, 35, 10, 1), + ("hello", 40, 51, 11, 2), + ("world", 40, 51, 11, 2), + ("streaming", 40, 51, 11, 2), + ("spark", 40, 50, 10, 1), + ("structured", 41, 51, 10, 1) + ), + + // concatenating multiple previous sessions into one + AddData(inputData, ("spark streaming", 30L)), + CheckNewAnswer( + ("spark", 25, 50, 25, 3), + ("streaming", 25, 51, 26, 4), + ("hello", 40, 51, 11, 2), + ("world", 40, 51, 11, 2), + ("structured", 41, 51, 10, 1) + ), + + // placing new sessions after previous sessions + AddData(inputData, ("hello apache spark", 60L)), + CheckNewAnswer( + ("spark", 25, 50, 25, 3), + ("streaming", 25, 51, 26, 4), + ("hello", 40, 51, 11, 2), + 
("world", 40, 51, 11, 2), + ("structured", 41, 51, 10, 1), + ("hello", 60, 70, 10, 1), + ("apache", 60, 70, 10, 1), + ("spark", 60, 70, 10, 1) + ), + + AddData(inputData, ("structured streaming", 90L)), + CheckNewAnswer( + ("spark", 25, 50, 25, 3), + ("streaming", 25, 51, 26, 4), + ("hello", 40, 51, 11, 2), + ("world", 40, 51, 11, 2), + ("structured", 41, 51, 10, 1), + ("hello", 60, 70, 10, 1), + ("apache", 60, 70, 10, 1), + ("spark", 60, 70, 10, 1), + ("structured", 90, 100, 10, 1), + ("streaming", 90, 100, 10, 1) + ) + ) + } + + testWithAllOptionsMergingSessionInLocalPartition("complete mode - session window - no key") { + // complete mode doesn't honor watermark: even it is specified, watermark will be + // always Unix timestamp 0 + + val inputData = MemoryStream[Int] + + val windowedAggregation = inputData.toDF() + .selectExpr("*") + .withColumn("eventTime", $"value".cast("timestamp")) + .groupBy(session_window($"eventTime", "5 seconds") as 'session) + .agg(count("*") as 'count, sum("value") as 'sum) + .select($"session".getField("start").cast("long").as[Long], + $"session".getField("end").cast("long").as[Long], $"count".as[Long], $"sum".as[Long]) + + testStream(windowedAggregation, OutputMode.Complete())( + AddData(inputData, 10, 11), + CheckNewAnswer((10, 16, 2, 21)), + + AddData(inputData, 17), + CheckNewAnswer( + (10, 16, 2, 21), + (17, 22, 1, 17) + ), + + AddData(inputData, 35), + CheckNewAnswer( + (10, 16, 2, 21), + (17, 22, 1, 17), + (35, 40, 1, 35) + ), + + // should reflect late row + AddData(inputData, 22), + CheckNewAnswer( + (10, 16, 2, 21), + (17, 27, 2, 39), + (35, 40, 1, 35) + ), + + AddData(inputData, 40), + CheckNewAnswer( + (10, 16, 2, 21), + (17, 27, 2, 39), + (35, 45, 2, 75) + ) + ) + } + + testWithAllOptionsMergingSessionInLocalPartition("append mode - session window") { + // Implements StructuredSessionization.scala leveraging "session" function + // as a test, to verify the sessionization works with simple example + + val inputData = MemoryStream[(String, Long)] + + // Split the lines into words, treat words as sessionId of events + val events = inputData.toDF() + .select($"_1".as("value"), $"_2".as("timestamp")) + .withColumn("eventTime", $"timestamp".cast("timestamp")) + .selectExpr("explode(split(value, ' ')) AS sessionId", "eventTime") + .withWatermark("eventTime", "30 seconds") + + val sessionUpdates = events + .groupBy(session_window($"eventTime", "10 seconds") as 'session, 'sessionId) + .agg(count("*").as("numEvents")) + .selectExpr("sessionId", "CAST(session.start AS LONG)", "CAST(session.end AS LONG)", + "CAST(session.end AS LONG) - CAST(session.start AS LONG) AS durationMs", + "numEvents") + + testStream(sessionUpdates, OutputMode.Append())( + AddData(inputData, + ("hello world spark streaming", 40L), + ("world hello structured streaming", 41L) + ), + + // watermark: 11 + // current sessions + // ("hello", 40, 51, 11, 2), + // ("world", 40, 51, 11, 2), + // ("streaming", 40, 51, 11, 2), + // ("spark", 40, 50, 10, 1), + // ("structured", 41, 51, 10, 1) + CheckNewAnswer( + ), + + // placing new sessions "before" previous sessions + AddData(inputData, ("spark streaming", 25L)), + // watermark: 11 + // current sessions + // ("spark", 25, 35, 10, 1), + // ("streaming", 25, 35, 10, 1), + // ("hello", 40, 51, 11, 2), + // ("world", 40, 51, 11, 2), + // ("streaming", 40, 51, 11, 2), + // ("spark", 40, 50, 10, 1), + // ("structured", 41, 51, 10, 1) + CheckNewAnswer( + ), + + // late event which session's end 10 would be later than watermark 11: should be dropped 
+ AddData(inputData, ("spark streaming", 0L)), + // watermark: 11 + // current sessions + // ("spark", 25, 35, 10, 1), + // ("streaming", 25, 35, 10, 1), + // ("hello", 40, 51, 11, 2), + // ("world", 40, 51, 11, 2), + // ("streaming", 40, 51, 11, 2), + // ("spark", 40, 50, 10, 1), + // ("structured", 41, 51, 10, 1) + CheckNewAnswer( + ), + + // concatenating multiple previous sessions into one + AddData(inputData, ("spark streaming", 30L)), + // watermark: 11 + // current sessions + // ("spark", 25, 50, 25, 3), + // ("streaming", 25, 51, 26, 4), + // ("hello", 40, 51, 11, 2), + // ("world", 40, 51, 11, 2), + // ("structured", 41, 51, 10, 1) + CheckNewAnswer( + ), + + // placing new sessions after previous sessions + AddData(inputData, ("hello apache spark", 60L)), + // watermark: 30 + // current sessions + // ("spark", 25, 50, 25, 3), + // ("streaming", 25, 51, 26, 4), + // ("hello", 40, 51, 11, 2), + // ("world", 40, 51, 11, 2), + // ("structured", 41, 51, 10, 1), + // ("hello", 60, 70, 10, 1), + // ("apache", 60, 70, 10, 1), + // ("spark", 60, 70, 10, 1) + CheckNewAnswer( + ), + + AddData(inputData, ("structured streaming", 90L)), + // watermark: 60 + // current sessions + // ("hello", 60, 70, 10, 1), + // ("apache", 60, 70, 10, 1), + // ("spark", 60, 70, 10, 1), + // ("structured", 90, 100, 10, 1), + // ("streaming", 90, 100, 10, 1) + CheckNewAnswer( + ("spark", 25, 50, 25, 3), + ("streaming", 25, 51, 26, 4), + ("hello", 40, 51, 11, 2), + ("world", 40, 51, 11, 2), + ("structured", 41, 51, 10, 1) + ) + ) + } + + testWithAllOptionsMergingSessionInLocalPartition("append mode - session window - no key") { + val inputData = MemoryStream[Int] + + val windowedAggregation = inputData.toDF() + .selectExpr("*") + .withColumn("eventTime", $"value".cast("timestamp")) + .withWatermark("eventTime", "10 seconds") + .groupBy(session_window($"eventTime", "5 seconds") as 'session) + .agg(count("*") as 'count, sum("value") as 'sum) + .select($"session".getField("start").cast("long").as[Long], + $"session".getField("end").cast("long").as[Long], $"count".as[Long], $"sum".as[Long]) + + testStream(windowedAggregation)( + AddData(inputData, 10, 11), // sessions: (10,16) + CheckNewAnswer(), + + AddData(inputData, 17), + // Advance watermark to 7 seconds + // sessions: (10,16), (17,23) + CheckNewAnswer(), + + AddData(inputData, 25), + // Advance watermark to 15 seconds + // sessions: (10,16), (17,23), (25,30) + CheckNewAnswer(), + + AddData(inputData, 35), + // Advance watermark to 25 seconds + // sessions: (10,16), (17,22), (25,30), (35,40) + // evicts: (10,16), (17,22) + CheckNewAnswer((10, 16, 2, 21), (17, 22, 1, 17)), + + AddData(inputData, 10), // Should not emit anything as data less than watermark + CheckNewAnswer(), + + AddData(inputData, 40), + // Advance watermark to 30 seconds + // sessions: (25,30) / (35,45) + // evicts: (25,30) + CheckNewAnswer((25, 30, 1, 25)) + ) + } + + testWithAllOptionsMergingSessionInLocalPartition("update mode - session window") { + // Implements StructuredSessionization.scala leveraging "session" function + // as a test, to verify the sessionization works with simple example + + val inputData = MemoryStream[(String, Long)] + + // Split the lines into words, treat words as sessionId of events + val events = inputData.toDF() + .select($"_1".as("value"), $"_2".as("timestamp")) + .withColumn("eventTime", $"timestamp".cast("timestamp")) + .selectExpr("explode(split(value, ' ')) AS sessionId", "eventTime") + .withWatermark("eventTime", "10 seconds") + + val sessionUpdates = events 
+      .groupBy(session_window($"eventTime", "10 seconds") as 'session, 'sessionId)
+      .agg(count("*").as("numEvents"))
+      .selectExpr("sessionId", "CAST(session.start AS LONG)", "CAST(session.end AS LONG)",
+        "CAST(session.end AS LONG) - CAST(session.start AS LONG) AS durationMs",
+        "numEvents")
+
+    testStream(sessionUpdates, OutputMode.Update())(
+      AddData(inputData,
+        ("hello world spark streaming", 40L),
+        ("world hello structured streaming", 41L)
+      ),
+      // watermark: 11
+      // current sessions
+      // ("hello", 40, 51, 11, 2),
+      // ("world", 40, 51, 11, 2),
+      // ("streaming", 40, 51, 11, 2),
+      // ("spark", 40, 50, 10, 1),
+      // ("structured", 41, 51, 10, 1)
+      CheckNewAnswer(
+        ("hello", 40, 51, 11, 2),
+        ("world", 40, 51, 11, 2),
+        ("streaming", 40, 51, 11, 2),
+        ("spark", 40, 50, 10, 1),
+        ("structured", 41, 51, 10, 1)
+      ),
+
+      // placing new sessions "before" previous sessions
+      AddData(inputData, ("spark streaming", 25L)),
+      // watermark: 11
+      // current sessions
+      // ("spark", 25, 35, 10, 1),
+      // ("streaming", 25, 35, 10, 1),
+      // ("hello", 40, 51, 11, 2),
+      // ("world", 40, 51, 11, 2),
+      // ("streaming", 40, 51, 11, 2),
+      // ("spark", 40, 50, 10, 1),
+      // ("structured", 41, 51, 10, 1)
+      CheckNewAnswer(
+        ("spark", 25, 35, 10, 1),
+        ("streaming", 25, 35, 10, 1)
+      ),
+
+      // late event whose session end (10) is earlier than the watermark (11): should be dropped
+      AddData(inputData, ("spark streaming", 0L)),
+      // watermark: 11
+      // current sessions
+      // ("spark", 25, 35, 10, 1),
+      // ("streaming", 25, 35, 10, 1),
+      // ("hello", 40, 51, 11, 2),
+      // ("world", 40, 51, 11, 2),
+      // ("streaming", 40, 51, 11, 2),
+      // ("spark", 40, 50, 10, 1),
+      // ("structured", 41, 51, 10, 1)
+      CheckNewAnswer(
+      ),
+
+      // concatenating multiple previous sessions into one
+      AddData(inputData, ("spark streaming", 30L)),
+      // watermark: 11
+      // current sessions
+      // ("spark", 25, 50, 25, 3),
+      // ("streaming", 25, 51, 26, 4),
+      // ("hello", 40, 51, 11, 2),
+      // ("world", 40, 51, 11, 2),
+      // ("structured", 41, 51, 10, 1)
+      CheckNewAnswer(
+        ("spark", 25, 50, 25, 3),
+        ("streaming", 25, 51, 26, 4)
+      ),
+
+      // placing new sessions after previous sessions
+      AddData(inputData, ("hello apache spark", 60L)),
+      // watermark: 30
+      // current sessions
+      // ("spark", 25, 50, 25, 3),
+      // ("streaming", 25, 51, 26, 4),
+      // ("hello", 40, 51, 11, 2),
+      // ("world", 40, 51, 11, 2),
+      // ("structured", 41, 51, 10, 1),
+      // ("hello", 60, 70, 10, 1),
+      // ("apache", 60, 70, 10, 1),
+      // ("spark", 60, 70, 10, 1)
+      CheckNewAnswer(
+        ("hello", 60, 70, 10, 1),
+        ("apache", 60, 70, 10, 1),
+        ("spark", 60, 70, 10, 1)
+      ),
+
+      AddData(inputData, ("structured streaming", 90L)),
+      // watermark: 60
+      // current sessions
+      // ("hello", 60, 70, 10, 1),
+      // ("apache", 60, 70, 10, 1),
+      // ("spark", 60, 70, 10, 1),
+      // ("structured", 90, 100, 10, 1),
+      // ("streaming", 90, 100, 10, 1)
+      // evicted
+      // ("spark", 25, 50, 25, 3),
+      // ("streaming", 25, 51, 26, 4),
+      // ("hello", 40, 51, 11, 2),
+      // ("world", 40, 51, 11, 2),
+      // ("structured", 41, 51, 10, 1)
+      CheckNewAnswer(
+        ("structured", 90, 100, 10, 1),
+        ("streaming", 90, 100, 10, 1)
+      )
+    )
+  }
+
+  testWithAllOptionsMergingSessionInLocalPartition("update mode - session window - no key") {
+    val inputData = MemoryStream[Int]
+
+    val windowedAggregation = inputData.toDF()
+      .selectExpr("*")
+      .withColumn("eventTime", $"value".cast("timestamp"))
+      .withWatermark("eventTime", "10 seconds")
+      .groupBy(session_window($"eventTime", "5 seconds") as 'session)
+      .agg(count("*") as 'count, sum("value") as 'sum)
+      .select($"session".getField("start").cast("long").as[Long],
+        $"session".getField("end").cast("long").as[Long], $"count".as[Long], $"sum".as[Long])
+
+    testStream(windowedAggregation, OutputMode.Update())(
+
+      AddData(inputData, 10, 11),
+      // Advance watermark to 1 second
+      // sessions: (10,16)
+      CheckNewAnswer((10, 16, 2, 21)),
+
+      AddData(inputData, 17),
+      // Advance watermark to 7 seconds
+      // sessions: (10,16), (17,22)
+      // updated: (17,22)
+      CheckNewAnswer((17, 22, 1, 17)),
+
+      AddData(inputData, 25),
+      // Advance watermark to 15 seconds
+      // sessions: (10,16), (17,22), (25,30)
+      // updated: (25,30)
+      CheckNewAnswer((25, 30, 1, 25)),
+
+      AddData(inputData, 35),
+      // Advance watermark to 25 seconds
+      // sessions: (10,16), (17,22), (25,30), (35,40)
+      // updated: (35,40)
+      // evicts: (10,16), (17,22)
+      CheckNewAnswer((35, 40, 1, 35)),
+
+      AddData(inputData, 10), // Should not emit anything as the data is less than the watermark
+      CheckNewAnswer(),
+
+      AddData(inputData, 40),
+      // Advance watermark to 30 seconds
+      // sessions: (25,30), (35,45)
+      // updated: (35,45)
+      CheckNewAnswer((35, 45, 2, 75))
+    )
+  }
+
+}