127 changes: 99 additions & 28 deletions mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala
@@ -18,6 +18,8 @@
package org.apache.spark.ml.fpm

import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.param._
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.mllib.fpm.{PrefixSpan => mllibPrefixSpan}
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions.col
@@ -29,68 +31,137 @@ import org.apache.spark.sql.types.{ArrayType, LongType, StructField, StructType}
* The PrefixSpan algorithm is described in J. Pei, et al., PrefixSpan: Mining Sequential Patterns
* Efficiently by Prefix-Projected Pattern Growth
* (see <a href="http://doi.org/10.1109/ICDE.2001.914830">here</a>).
* This class is not yet an Estimator/Transformer; use the `findFrequentSequentialPatterns` method to
* run the PrefixSpan algorithm.
*
* @see <a href="https://en.wikipedia.org/wiki/Sequential_Pattern_Mining">Sequential Pattern Mining
* (Wikipedia)</a>
*/
@Since("2.4.0")
@Experimental
object PrefixSpan {
final class PrefixSpan(@Since("2.4.0") override val uid: String) extends Params {
Contributor review comment: In the doc, mention that this class is not yet an Estimator/Transformer and link to findFrequentSequentialPatterns method.


@Since("2.4.0")
def this() = this(Identifiable.randomUID("prefixSpan"))

/**
* Param for the minimal support level (default: `0.1`).
* Sequential patterns that appear more than (minSupport * size-of-the-dataset) times are
* identified as frequent sequential patterns.
* @group param
*/
@Since("2.4.0")
val minSupport = new DoubleParam(this, "minSupport", "The minimal support level of the " +
"sequential pattern. Sequential patterns that appear more than " +
"(minSupport * size-of-the-dataset) times will be output.", ParamValidators.gtEq(0.0))

/** @group getParam */
@Since("2.4.0")
def getMinSupport: Double = $(minSupport)

/** @group setParam */
@Since("2.4.0")
def setMinSupport(value: Double): this.type = set(minSupport, value)

/**
* Param for the maximal pattern length (default: `10`).
* @group param
*/
@Since("2.4.0")
val maxPatternLength = new IntParam(this, "maxPatternLength",
"The maximal length of the sequential pattern.",
ParamValidators.gt(0))

/** @group getParam */
@Since("2.4.0")
def getMaxPatternLength: Int = $(maxPatternLength)

/** @group setParam */
@Since("2.4.0")
def setMaxPatternLength(value: Int): this.type = set(maxPatternLength, value)

/**
* Param for the maximum number of items (including delimiters used in the internal storage
* format) allowed in a projected database before local processing (default: `32000000`).
* If a projected database exceeds this size, another iteration of distributed prefix growth
* is run.
* @group param
*/
@Since("2.4.0")
val maxLocalProjDBSize = new LongParam(this, "maxLocalProjDBSize",
"The maximum number of items (including delimiters used in the internal storage format) " +
"allowed in a projected database before local processing. If a projected database exceeds " +
"this size, another iteration of distributed prefix growth is run.",
ParamValidators.gt(0))

/** @group getParam */
@Since("2.4.0")
def getMaxLocalProjDBSize: Long = $(maxLocalProjDBSize)

/** @group setParam */
@Since("2.4.0")
def setMaxLocalProjDBSize(value: Long): this.type = set(maxLocalProjDBSize, value)

/**
* Param for the name of the sequence column in dataset (default "sequence"). Rows with
* nulls in this column are ignored.
* @group param
*/
@Since("2.4.0")
val sequenceCol = new Param[String](this, "sequenceCol", "The name of the sequence column in " +
"dataset. Rows with nulls in this column are ignored.")

/** @group getParam */
@Since("2.4.0")
def getSequenceCol: String = $(sequenceCol)

/** @group setParam */
@Since("2.4.0")
def setSequenceCol(value: String): this.type = set(sequenceCol, value)

setDefault(minSupport -> 0.1, maxPatternLength -> 10, maxLocalProjDBSize -> 32000000,
sequenceCol -> "sequence")

/**
* :: Experimental ::
* Finds the complete set of frequent sequential patterns in the input sequences of itemsets.
*
* @param dataset A dataset or a dataframe containing a sequence column which is
* {{{Seq[Seq[_]]}}} type
* @param sequenceCol the name of the sequence column in dataset, rows with nulls in this column
* are ignored
* @param minSupport the minimal support level of the sequential pattern, any pattern that
* appears more than (minSupport * size-of-the-dataset) times will be output
* (recommended value: `0.1`).
* @param maxPatternLength the maximal length of the sequential pattern
* (recommended value: `10`).
* @param maxLocalProjDBSize The maximum number of items (including delimiters used in the
* internal storage format) allowed in a projected database before
* local processing. If a projected database exceeds this size, another
* iteration of distributed prefix growth is run
* (recommended value: `32000000`).
* @return A `DataFrame` that contains columns of sequence and corresponding frequency.
* The schema of it will be:
* - `sequence: Seq[Seq[T]]` (T is the item type)
* - `freq: Long`
*/
@Since("2.4.0")
def findFrequentSequentialPatterns(
dataset: Dataset[_],
sequenceCol: String,
minSupport: Double,
maxPatternLength: Int,
maxLocalProjDBSize: Long): DataFrame = {

val inputType = dataset.schema(sequenceCol).dataType
def findFrequentSequentialPatterns(dataset: Dataset[_]): DataFrame = {
val sequenceColParam = $(sequenceCol)
val inputType = dataset.schema(sequenceColParam).dataType
require(inputType.isInstanceOf[ArrayType] &&
inputType.asInstanceOf[ArrayType].elementType.isInstanceOf[ArrayType],
s"The input column must be ArrayType and the array element type must also be ArrayType, " +
s"but got $inputType.")


val data = dataset.select(sequenceCol)
val sequences = data.where(col(sequenceCol).isNotNull).rdd
val data = dataset.select(sequenceColParam)
val sequences = data.where(col(sequenceColParam).isNotNull).rdd
.map(r => r.getAs[Seq[Seq[Any]]](0).map(_.toArray).toArray)

val mllibPrefixSpan = new mllibPrefixSpan()
.setMinSupport(minSupport)
.setMaxPatternLength(maxPatternLength)
.setMaxLocalProjDBSize(maxLocalProjDBSize)
.setMinSupport($(minSupport))
.setMaxPatternLength($(maxPatternLength))
.setMaxLocalProjDBSize($(maxLocalProjDBSize))

val rows = mllibPrefixSpan.run(sequences).freqSequences.map(f => Row(f.sequence, f.freq))
val schema = StructType(Seq(
StructField("sequence", dataset.schema(sequenceCol).dataType, nullable = false),
StructField("sequence", dataset.schema(sequenceColParam).dataType, nullable = false),
StructField("freq", LongType, nullable = false)))
val freqSequences = dataset.sparkSession.createDataFrame(rows, schema)

freqSequences
}

@Since("2.4.0")
override def copy(extra: ParamMap): PrefixSpan = defaultCopy(extra)

}
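
For reference, the new Params-based API added above can be exercised as in the following minimal sketch. It assumes an active SparkSession named `spark` with its implicits imported, plus a small made-up dataset; the setter and method names are the ones introduced in this diff.

```scala
import org.apache.spark.ml.fpm.PrefixSpan
import spark.implicits._  // assumes an active SparkSession named `spark`

// Toy input: one row per sequence, each sequence being an array of itemsets,
// stored in the default "sequence" column.
val sequences = Seq(
  Seq(Seq(1, 2), Seq(3)),
  Seq(Seq(1), Seq(3, 2), Seq(1, 2)),
  Seq(Seq(1, 2), Seq(5)),
  Seq(Seq(6))
).toDF("sequence")

// Configure the wrapper through its Params setters and run the algorithm.
val patterns = new PrefixSpan()
  .setMinSupport(0.5)
  .setMaxPatternLength(5)
  .setMaxLocalProjDBSize(32000000)
  .setSequenceCol("sequence")
  .findFrequentSequentialPatterns(sequences)

// Per the return contract documented above, the result has columns
// `sequence` (array of itemsets) and `freq` (long).
patterns.show(truncate = false)
```
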
28 changes: 20 additions & 8 deletions mllib/src/test/scala/org/apache/spark/ml/fpm/PrefixSpanSuite.scala
@@ -29,8 +29,11 @@ class PrefixSpanSuite extends MLTest {

test("PrefixSpan projections with multiple partial starts") {
val smallDataset = Seq(Seq(Seq(1, 2), Seq(1, 2, 3))).toDF("sequence")
val result = PrefixSpan.findFrequentSequentialPatterns(smallDataset, "sequence",
minSupport = 1.0, maxPatternLength = 2, maxLocalProjDBSize = 32000000)
val result = new PrefixSpan()
.setMinSupport(1.0)
.setMaxPatternLength(2)
.setMaxLocalProjDBSize(32000000)
.findFrequentSequentialPatterns(smallDataset)
.as[(Seq[Seq[Int]], Long)].collect()
val expected = Array(
(Seq(Seq(1)), 1L),
@@ -90,17 +93,23 @@ class PrefixSpanSuite extends MLTest {

test("PrefixSpan Integer type, variable-size itemsets") {
val df = smallTestData.toDF("sequence")
val result = PrefixSpan.findFrequentSequentialPatterns(df, "sequence",
minSupport = 0.5, maxPatternLength = 5, maxLocalProjDBSize = 32000000)
val result = new PrefixSpan()
.setMinSupport(0.5)
.setMaxPatternLength(5)
.setMaxLocalProjDBSize(32000000)
.findFrequentSequentialPatterns(df)
.as[(Seq[Seq[Int]], Long)].collect()

compareResults[Int](smallTestDataExpectedResult, result)
}

test("PrefixSpan input row with nulls") {
val df = (smallTestData :+ null).toDF("sequence")
val result = PrefixSpan.findFrequentSequentialPatterns(df, "sequence",
minSupport = 0.5, maxPatternLength = 5, maxLocalProjDBSize = 32000000)
val result = new PrefixSpan()
.setMinSupport(0.5)
.setMaxPatternLength(5)
.setMaxLocalProjDBSize(32000000)
.findFrequentSequentialPatterns(df)
.as[(Seq[Seq[Int]], Long)].collect()

compareResults[Int](smallTestDataExpectedResult, result)
@@ -111,8 +120,11 @@ class PrefixSpanSuite extends MLTest {
val df = smallTestData
.map(seq => seq.map(itemSet => itemSet.map(intToString)))
.toDF("sequence")
val result = PrefixSpan.findFrequentSequentialPatterns(df, "sequence",
minSupport = 0.5, maxPatternLength = 5, maxLocalProjDBSize = 32000000)
val result = new PrefixSpan()
.setMinSupport(0.5)
.setMaxPatternLength(5)
.setMaxLocalProjDBSize(32000000)
.findFrequentSequentialPatterns(df)
.as[(Seq[Seq[String]], Long)].collect()

val expected = smallTestDataExpectedResult.map { case (seq, freq) =>
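
Taken together, the test changes show the call-site migration this PR implies. The sketch below condenses it; it is not a complete program: `df` stands for any DataFrame with a "sequence" column of nested-array type, the old form is kept only as a comment since it no longer compiles after this change, and the typed view requires `spark.implicits._` in scope.

```scala
import org.apache.spark.ml.fpm.PrefixSpan

// Old API (before this PR): a static helper on the PrefixSpan object, configured by arguments:
//   PrefixSpan.findFrequentSequentialPatterns(df, "sequence",
//     minSupport = 0.5, maxPatternLength = 5, maxLocalProjDBSize = 32000000)

// New API (this PR): a Params-based class, configured through setters before the call.
val result = new PrefixSpan()
  .setMinSupport(0.5)
  .setMaxPatternLength(5)
  .setMaxLocalProjDBSize(32000000)
  .findFrequentSequentialPatterns(df)  // `df` is a hypothetical DataFrame with a "sequence" column
  .as[(Seq[Seq[Int]], Long)]           // typed view used by the tests; needs spark.implicits._
  .collect()
```
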