35 changes: 35 additions & 0 deletions python/pyspark/sql/functions.py
@@ -2333,6 +2333,41 @@ def check_string_field(field, fieldName):
return Column(res)


def session_window(timeColumn, gapDuration):
"""
Generates a session window given a timestamp specifying column.
Session window is one of dynamic windows, which means the length of the window varies
according to the given inputs. The length of a session window is defined as "the timestamp
of the latest input of the session + gap duration", so when new inputs are bound to the
current session window, the end time of the session window can be expanded according to the
new inputs.
Windows can support microsecond precision. Windows in the order of months are not supported.
For a streaming query, you may use the function `current_timestamp` to generate windows on
processing time.
gapDuration is provided as a string, e.g. '1 second', '1 day 12 hours', '2 minutes'. Valid
interval strings are 'week', 'day', 'hour', 'minute', 'second', 'millisecond', 'microsecond'.
The output column will be a struct called 'session_window' by default with the nested columns
'start' and 'end', where 'start' and 'end' will be of :class:`pyspark.sql.types.TimestampType`.
.. versionadded:: 3.2.0
Examples
--------
>>> df = spark.createDataFrame([("2016-03-11 09:00:07", 1)]).toDF("date", "val")
>>> w = df.groupBy(session_window("date", "5 seconds")).agg(sum("val").alias("sum"))
>>> w.select(w.session_window.start.cast("string").alias("start"),
... w.session_window.end.cast("string").alias("end"), "sum").collect()
[Row(start='2016-03-11 09:00:07', end='2016-03-11 09:00:12', sum=1)]
"""
def check_string_field(field, fieldName):
if not field or type(field) is not str:
raise TypeError("%s should be provided as a string" % fieldName)

sc = SparkContext._active_spark_context
time_col = _to_java_column(timeColumn)
check_string_field(gapDuration, "gapDuration")
res = sc._jvm.functions.session_window(time_col, gapDuration)
return Column(res)


# ---------------------------- misc functions ----------------------------------

def crc32(col):
1 change: 1 addition & 0 deletions python/pyspark/sql/functions.pyi
@@ -135,6 +135,7 @@ def window(
slideDuration: Optional[str] = ...,
startTime: Optional[str] = ...,
) -> Column: ...
def session_window(timeColumn: ColumnOrName, gapDuration: str) -> Column: ...
def crc32(col: ColumnOrName) -> Column: ...
def md5(col: ColumnOrName) -> Column: ...
def sha1(col: ColumnOrName) -> Column: ...
@@ -296,6 +296,7 @@ class Analyzer(override val catalogManager: CatalogManager)
GlobalAggregates ::
ResolveAggregateFunctions ::
TimeWindowing ::
SessionWindowing ::
ResolveInlineTables ::
ResolveHigherOrderFunctions(catalogManager) ::
ResolveLambdaVariables ::
@@ -3856,9 +3857,13 @@ object TimeWindowing extends Rule[LogicalPlan] {
val windowExpressions =
p.expressions.flatMap(_.collect { case t: TimeWindow => t }).toSet

val numWindowExpr = windowExpressions.size
val numWindowExpr = p.expressions.flatMap(_.collect {
case s: SessionWindow => s
case t: TimeWindow => t
}).toSet.size

// Only support a single window expression for now
if (numWindowExpr == 1 &&
if (numWindowExpr == 1 && windowExpressions.nonEmpty &&
windowExpressions.head.timeColumn.resolved &&
windowExpressions.head.checkInputDataTypes().isSuccess) {

@@ -3933,6 +3938,83 @@
}
}

/** Maps a time column to a session window. */
object SessionWindowing extends Rule[LogicalPlan] {
import org.apache.spark.sql.catalyst.dsl.expressions._

private final val SESSION_COL_NAME = "session_window"
private final val SESSION_START = "start"
private final val SESSION_END = "end"

/**
* Generates the logical plan for generating session windows on a timestamp column.
* Each session window is initially defined as [timestamp, timestamp + gap).
*
* This also adds a marker to the session column so that downstream operators can easily find
* the session window column.
*/
def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
case p: LogicalPlan if p.children.size == 1 =>
val child = p.children.head
val sessionExpressions =
p.expressions.flatMap(_.collect { case s: SessionWindow => s }).toSet

val numWindowExpr = p.expressions.flatMap(_.collect {
case s: SessionWindow => s
case t: TimeWindow => t
}).toSet.size

// Only support a single session expression for now
if (numWindowExpr == 1 && sessionExpressions.nonEmpty &&
sessionExpressions.head.timeColumn.resolved &&
sessionExpressions.head.checkInputDataTypes().isSuccess) {

val session = sessionExpressions.head

val metadata = session.timeColumn match {
case a: Attribute => a.metadata
case _ => Metadata.empty
}

val newMetadata = new MetadataBuilder()
.withMetadata(metadata)
.putBoolean(SessionWindow.marker, true)
.build()

val sessionAttr = AttributeReference(
SESSION_COL_NAME, session.dataType, metadata = newMetadata)()

val sessionStart = PreciseTimestampConversion(session.timeColumn, TimestampType, LongType)
val sessionEnd = sessionStart + session.gapDuration

val literalSessionStruct = CreateNamedStruct(
Literal(SESSION_START) ::
PreciseTimestampConversion(sessionStart, LongType, TimestampType) ::
Literal(SESSION_END) ::
PreciseTimestampConversion(sessionEnd, LongType, TimestampType) ::
Nil)

val sessionStruct = Alias(literalSessionStruct, SESSION_COL_NAME)(
exprId = sessionAttr.exprId, explicitMetadata = Some(newMetadata))

val replacedPlan = p transformExpressions {
case s: SessionWindow => sessionAttr
}

// As with tumbling windows, we add a filter to filter out nulls.
val filterExpr = IsNotNull(session.timeColumn)

replacedPlan.withNewChildren(
Project(sessionStruct +: child.output,
Filter(filterExpr, child)) :: Nil)
} else if (numWindowExpr > 1) {
throw QueryCompilationErrors.multiTimeWindowExpressionsNotSupportedError(p)
} else {
p // Return unchanged. Analyzer will throw exception later
}
}
}
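
To make the rewrite concrete, here is a rough DataFrame-level sketch of what the rule amounts to for a query like df.groupBy(session_window(col("ts"), "5 seconds")).count(). The df and its ts column are assumptions for illustration, and the real transformation happens on the logical plan rather than through the DataFrame API:

import org.apache.spark.sql.functions._

// df with a TimestampType column `ts` is assumed.
// The SessionWindow expression is replaced by a projected struct spanning
// [ts, ts + gap), and a filter drops null timestamps, mirroring
// Filter(IsNotNull(session.timeColumn)) in the rule above.
val rewritten = df
  .where(col("ts").isNotNull)
  .withColumn("session_window", struct(
    col("ts").as("start"),
    (col("ts") + expr("INTERVAL 5 SECONDS")).as("end")))
  .groupBy(col("session_window"))
  .count()

The actual rule additionally propagates the SessionWindow.marker metadata onto the session_window attribute so that the planner can later recognize the grouping column as a session window.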

/**
* Resolve expressions if they contain [[NamePlaceholder]]s.
*/
@@ -552,6 +552,7 @@ object FunctionRegistry {
expression[WeekOfYear]("weekofyear"),
expression[Year]("year"),
expression[TimeWindow]("window"),
expression[SessionWindow]("session_window"),
expression[MakeDate]("make_date"),
expression[MakeTimestamp]("make_timestamp"),
expression[MakeTimestampNTZ]("make_timestamp_ntz", true),
@@ -0,0 +1,77 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure
import org.apache.spark.sql.types._

/**
* Represents a session window.
*
* @param timeColumn the start time of session window
* @param gapDuration the duration of the session gap, meaning the session will close if no new
* element appears within "the last element in the session + gap".
*/
case class SessionWindow(timeColumn: Expression, gapDuration: Long) extends UnaryExpression
Review comment (Member): Could you add a few simple comments here? e.g., what gapDuration stands for.

with ImplicitCastInputTypes
with Unevaluable
with NonSQLExpression {

//////////////////////////
// SQL Constructors
//////////////////////////

def this(timeColumn: Expression, gapDuration: Expression) = {
this(timeColumn, TimeWindow.parseExpression(gapDuration))
}

override def child: Expression = timeColumn
override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType)
override def dataType: DataType = new StructType()
.add(StructField("start", TimestampType))
.add(StructField("end", TimestampType))

// This expression is replaced in the analyzer.
override lazy val resolved = false

/** Validate the inputs for the gap duration in addition to the input data type. */
override def checkInputDataTypes(): TypeCheckResult = {
val dataTypeCheck = super.checkInputDataTypes()
if (dataTypeCheck.isSuccess) {
if (gapDuration <= 0) {
return TypeCheckFailure(s"The gap duration ($gapDuration) must be greater than 0.")
}
}
dataTypeCheck
}

override protected def withNewChildInternal(newChild: Expression): Expression =
copy(timeColumn = newChild)
}

object SessionWindow {
val marker = "spark.sessionWindow"

def apply(
timeColumn: Expression,
gapDuration: String): SessionWindow = {
SessionWindow(timeColumn,
TimeWindow.getIntervalInMicroSeconds(gapDuration))
}
}
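
A minimal sketch of constructing the expression directly, for illustration only; the catalyst classes used here are internal API. "5 seconds" is converted by TimeWindow.getIntervalInMicroSeconds, so the resulting gapDuration is 5000000 microseconds:

import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions.SessionWindow

// The companion apply parses the gap string into microseconds.
val sw = SessionWindow(UnresolvedAttribute("ts"), "5 seconds")
assert(sw.gapDuration == 5000000L)
assert(!sw.resolved) // the analyzer replaces this expression via SessionWindowing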
@@ -109,7 +109,7 @@ object TimeWindow {
* @return The interval duration in microseconds. Spark SQL casts TimestampType to have
* microsecond precision.
*/
private def getIntervalInMicroSeconds(interval: String): Long = {
def getIntervalInMicroSeconds(interval: String): Long = {
val cal = IntervalUtils.stringToInterval(UTF8String.fromString(interval))
if (cal.months != 0) {
throw new IllegalArgumentException(
Expand All @@ -122,7 +122,7 @@ object TimeWindow {
* Parses the duration expression to generate the long value for the original constructor so
* that we can use `window` in SQL.
*/
private def parseExpression(expr: Expression): Long = expr match {
def parseExpression(expr: Expression): Long = expr match {
case NonNullLiteral(s, StringType) => getIntervalInMicroSeconds(s.toString)
case IntegerLiteral(i) => i.toLong
case NonNullLiteral(l, LongType) => l.toString.toLong
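
With the method now accessible outside TimeWindow, the gap strings accepted by session_window map to microseconds as follows. This is a sketch; the values are plain arithmetic on the interval components:

import org.apache.spark.sql.catalyst.expressions.TimeWindow

TimeWindow.getIntervalInMicroSeconds("5 seconds")      // 5000000L
TimeWindow.getIntervalInMicroSeconds("1 day 12 hours") // 129600000000L
// Month-based intervals are rejected because their length is not fixed:
// TimeWindow.getIntervalInMicroSeconds("1 month")     // throws IllegalArgumentException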
@@ -366,8 +366,9 @@ private[spark] object QueryCompilationErrors {
}

def multiTimeWindowExpressionsNotSupportedError(t: TreeNode[_]): Throwable = {
new AnalysisException("Multiple time window expressions would result in a cartesian product " +
"of rows, therefore they are currently not supported.", t.origin.line, t.origin.startPosition)
new AnalysisException("Multiple time/session window expressions would result in a cartesian " +
"product of rows, therefore they are currently not supported.", t.origin.line,
t.origin.startPosition)
}

def viewOutputNumberMismatchQueryColumnNamesError(
@@ -1610,6 +1610,27 @@ object SQLConf {
.checkValue(v => Set(1, 2).contains(v), "Valid versions are 1 and 2")
.createWithDefault(2)

val STREAMING_SESSION_WINDOW_MERGE_SESSIONS_IN_LOCAL_PARTITION =
buildConf("spark.sql.streaming.sessionWindow.merge.sessions.in.local.partition")
Review comment (Member): How about spark.sql.streaming.sessionWindow.localMerge.enabled or spark.sql.streaming.sessionWindow.mergeSessionsInLocalPartition.enabled?

Review comment (Member): Comparing with the similar logic of AggUtils, maybe we can also remove this config? Just always do the local merge?

Review comment (Member): Yea, is it necessary to have this config? Seems we can always do it.

Review comment (Contributor Author): As I explained on the doc method, this would incur an additional "logical sort", so it is only useful when there are lots of input rows that are going to be consolidated into the same session. The benefit depends on the characteristics of the data.
If we want to pick one of the two to simplify, it would probably be safer to remove the local aggregation.

.internal()
.doc("When true, streaming session window sorts and merge sessions in local partition " +
"prior to shuffle. This is to reduce the rows to shuffle, but only beneficial when " +
"there're lots of rows in a batch being assigned to same sessions.")
.version("3.2.0")
.booleanConf
Review comment (Member): nit: version

.createWithDefault(false)

val STREAMING_SESSION_WINDOW_STATE_FORMAT_VERSION =
buildConf("spark.sql.streaming.sessionWindow.stateFormatVersion")
.internal()
.doc("State format version used by streaming session window in a streaming query. " +
"State between versions are tend to be incompatible, so state format version shouldn't " +
"be modified after running.")
.version("3.2.0")
.intConf
.checkValue(v => Set(1).contains(v), "Valid version is 1")
.createWithDefault(1)

val UNSUPPORTED_OPERATION_CHECK_ENABLED =
buildConf("spark.sql.streaming.unsupportedOperationCheck")
.internal()
@@ -3676,6 +3697,9 @@ class SQLConf extends Serializable with Logging {

def fastHashAggregateRowMaxCapacityBit: Int = getConf(FAST_HASH_AGGREGATE_MAX_ROWS_CAPACITY_BIT)

def streamingSessionWindowMergeSessionInLocalPartition: Boolean =
getConf(STREAMING_SESSION_WINDOW_MERGE_SESSIONS_IN_LOCAL_PARTITION)

def datetimeJava8ApiEnabled: Boolean = getConf(DATETIME_JAVA8API_ENABLED)

def uiExplainMode: String = getConf(UI_EXPLAIN_MODE)
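
For reference, a hedged sketch of toggling the two new confs from a user session, assuming a SparkSession named spark; both confs are internal, and whether the local-merge flag should exist at all is still being discussed in the review comments above:

// Conf names are taken verbatim from the definitions above.
spark.conf.set("spark.sql.streaming.sessionWindow.merge.sessions.in.local.partition", "true")
// The state format version should not be changed once a query has written state.
spark.conf.set("spark.sql.streaming.sessionWindow.stateFormatVersion", "1")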
@@ -324,7 +324,9 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
throw QueryCompilationErrors.groupAggPandasUDFUnsupportedByStreamingAggError()
}

val stateVersion = conf.getConf(SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION)
val sessionWindowOption = namedGroupingExpressions.find { p =>
p.metadata.contains(SessionWindow.marker)
}

// Ideally this should be done in `NormalizeFloatingNumbers`, but we do it here because
// `groupingExpressions` is not extracted during logical phase.
@@ -335,12 +337,29 @@
}
}

AggUtils.planStreamingAggregation(
normalizedGroupingExpressions,
aggregateExpressions.map(expr => expr.asInstanceOf[AggregateExpression]),
rewrittenResultExpressions,
stateVersion,
planLater(child))
sessionWindowOption match {
case Some(sessionWindow) =>
val stateVersion = conf.getConf(SQLConf.STREAMING_SESSION_WINDOW_STATE_FORMAT_VERSION)

AggUtils.planStreamingAggregationForSession(
normalizedGroupingExpressions,
sessionWindow,
aggregateExpressions.map(expr => expr.asInstanceOf[AggregateExpression]),
rewrittenResultExpressions,
stateVersion,
conf.streamingSessionWindowMergeSessionInLocalPartition,
planLater(child))

case None =>
val stateVersion = conf.getConf(SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION)
Review comment (Contributor Author): I'll remove the above one. Nice finding!


AggUtils.planStreamingAggregation(
normalizedGroupingExpressions,
aggregateExpressions.map(expr => expr.asInstanceOf[AggregateExpression]),
rewrittenResultExpressions,
stateVersion,
planLater(child))
}

case _ => Nil
}
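
To tie the planner change back to the user-facing API, here is a hedged end-to-end sketch of a streaming query that would take the new planStreamingAggregationForSession path; the socket source, column names, and trigger are illustrative assumptions rather than part of this patch:

import org.apache.spark.sql.functions._
import org.apache.spark.sql.streaming.Trigger

// Events arrive as lines on a socket; processing time is used as the event time,
// as suggested in the session_window docstring for processing-time windows.
val events = spark.readStream
  .format("socket").option("host", "localhost").option("port", 9999).load()
  .selectExpr("CAST(value AS STRING) AS user", "current_timestamp() AS eventTime")

// Grouping by session_window tags the grouping attribute with SessionWindow.marker,
// which is what the strategy above looks for when picking the session-window plan.
val sessionCounts = events
  .withWatermark("eventTime", "10 seconds")
  .groupBy(session_window(col("eventTime"), "5 seconds"), col("user"))
  .count()

sessionCounts.writeStream
  .format("console")
  .outputMode("append")
  .trigger(Trigger.ProcessingTime("5 seconds"))
  .start()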