109 changes: 90 additions & 19 deletions mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala
@@ -18,6 +18,8 @@
package org.apache.spark.ml.fpm

import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.param._
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.mllib.fpm.{PrefixSpan => mllibPrefixSpan}
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions.col
@@ -35,7 +37,87 @@ import org.apache.spark.sql.types.{ArrayType, LongType, StructField, StructType}
*/
@Since("2.4.0")
@Experimental
object PrefixSpan {
final class PrefixSpan(@Since("2.4.0") override val uid: String) extends Params {
Contributor: In the doc, mention that this class is not yet an Estimator/Transformer and link to the findFrequentSequentialPatterns method.
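One possible wording for that class-level Scaladoc, sketched here only as an illustration (the exact phrasing is an assumption, not part of the PR):

```scala
/**
 * :: Experimental ::
 * A parallel PrefixSpan algorithm to mine frequent sequential patterns.
 * This class is not yet an Estimator/Transformer; call
 * `findFrequentSequentialPatterns` to run the PrefixSpan algorithm.
 */
```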


@Since("2.4.0")
def this() = this(Identifiable.randomUID("prefixSpan"))

/**
* the minimal support level of the sequential pattern, any pattern that
Contributor: Use uppercase for the first char:

"""
Param for the minimal support level (default: 0.1).
Sequential patterns that appear more than (minSupport * size-of-the-dataset) times are identified as frequent sequential patterns.
"""

* appears more than (minSupport * size-of-the-dataset) times will be output
* (default value: `0.1`).
* @group param
*/
@Since("2.4.0")
val minSupport = new DoubleParam(this, "minSupport", "the minimal support level of the " +
"sequential pattern, any pattern that appears more than (minSupport * size-of-the-dataset) " +
"times will be output", ParamValidators.gt(0.0))
Contributor: gt -> gtEq
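A minimal sketch of the minSupport declaration with both suggestions above applied (uppercase doc text and a gtEq validator); the exact wording is assumed:

```scala
/**
 * Param for the minimal support level (default: 0.1).
 * Sequential patterns that appear more than (minSupport * size-of-the-dataset) times
 * are identified as frequent sequential patterns.
 * @group param
 */
@Since("2.4.0")
val minSupport = new DoubleParam(this, "minSupport", "the minimal support level of the " +
  "sequential pattern, any pattern that appears more than (minSupport * size-of-the-dataset) " +
  "times will be output", ParamValidators.gtEq(0.0))
```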


/** @group getParam */
@Since("2.4.0")
def getMinSupport: Double = $(minSupport)

/**
* Set the minSupport parameter.
* Default is 1.0.
Contributor: The default is wrong. We don't need doc for the setters and getters. Just leave "@group setParam".

*
* @group setParam
*/
@Since("1.3.0")
Contributor: This is also wrong. The method is new.
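A sketch of the setter with both comments above addressed (no setter doc beyond @group setParam, and the annotation corrected to 2.4.0 since the method is new); this is an assumption about the intended fix:

```scala
/** @group setParam */
@Since("2.4.0")
def setMinSupport(value: Double): this.type = set(minSupport, value)
```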

def setMinSupport(value: Double): this.type = set(minSupport, value)

/**
* the maximal length of the sequential pattern
Contributor: Param for the maximal pattern length (default: `10`).

Just copy the doc from mllib.fpm.PrefixSpan.

* (default value: `10`).
* @group param
*/
@Since("2.4.0")
val maxPatternLength = new IntParam(this, "maxPatternLength",
"the maximal length of the sequential pattern",
ParamValidators.gt(0))

/** @group getParam */
@Since("2.4.0")
def getMaxPatternLength: Double = $(maxPatternLength)
Contributor: return Int
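Assuming the fix is simply correcting the declared return type, the getter would look like:

```scala
/** @group getParam */
@Since("2.4.0")
def getMaxPatternLength: Int = $(maxPatternLength)
```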


/**
* Set the maxPatternLength parameter.
* Default is 10.
*
* @group setParam
*/
@Since("2.4.0")
def setMaxPatternLength(value: Int): this.type = set(maxPatternLength, value)

/**
* The maximum number of items (including delimiters used in the
Contributor: Same here. Just copy the doc from mllib.fpm.PrefixSpan.

* internal storage format) allowed in a projected database before
* local processing. If a projected database exceeds this size, another
* iteration of distributed prefix growth is run
* (default value: `32000000`).
* @group param
*/
@Since("2.4.0")
val maxLocalProjDBSize = new LongParam(this, "maxLocalProjDBSize",
"The maximum number of items (including delimiters used in the internal storage format) " +
"allowed in a projected database before local processing. If a projected database exceeds " +
"this size, another iteration of distributed prefix growth is run",
ParamValidators.gt(0))

/** @group getParam */
@Since("2.4.0")
def getMaxLocalProjDBSize: Double = $(maxLocalProjDBSize)
Contributor: Long
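Likewise, a sketch of this getter with the return type corrected to Long:

```scala
/** @group getParam */
@Since("2.4.0")
def getMaxLocalProjDBSize: Long = $(maxLocalProjDBSize)
```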


/**
* Set the maxLocalProjDBSize parameter.
* Default is 32000000.
*
* @group setParam
*/
@Since("2.4.0")
def setMaxLocalProjDBSize(value: Long): this.type = set(maxLocalProjDBSize, value)

setDefault(minSupport -> 0.1, maxPatternLength -> 10, maxLocalProjDBSize -> 32000000)

/**
* :: Experimental ::
Expand All @@ -45,16 +127,6 @@ object PrefixSpan {
* {{{Seq[Seq[_]]}}} type
* @param sequenceCol the name of the sequence column in dataset, rows with nulls in this column
* are ignored
* @param minSupport the minimal support level of the sequential pattern, any pattern that
* appears more than (minSupport * size-of-the-dataset) times will be output
* (recommended value: `0.1`).
* @param maxPatternLength the maximal length of the sequential pattern
* (recommended value: `10`).
* @param maxLocalProjDBSize The maximum number of items (including delimiters used in the
* internal storage format) allowed in a projected database before
* local processing. If a projected database exceeds this size, another
* iteration of distributed prefix growth is run
* (recommended value: `32000000`).
* @return A `DataFrame` that contains columns of sequence and corresponding frequency.
* The schema of it will be:
* - `sequence: Seq[Seq[T]]` (T is the item type)
@@ -63,26 +135,22 @@
@Since("2.4.0")
def findFrequentSequentialPatterns(
dataset: Dataset[_],
sequenceCol: String,
minSupport: Double,
maxPatternLength: Int,
maxLocalProjDBSize: Long): DataFrame = {
sequenceCol: String): DataFrame = {
Contributor: Why not make it a param?
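A hypothetical sketch of how sequenceCol could be exposed as a Param with a default instead of a method argument (the names and the default value are assumptions for illustration, not part of this PR):

```scala
/**
 * Param for the name of the sequence column in the dataset (default: "sequence").
 * Rows with nulls in this column are ignored.
 * @group param
 */
@Since("2.4.0")
val sequenceCol = new Param[String](this, "sequenceCol",
  "the name of the sequence column in dataset, rows with nulls in this column are ignored")

/** @group getParam */
@Since("2.4.0")
def getSequenceCol: String = $(sequenceCol)

/** @group setParam */
@Since("2.4.0")
def setSequenceCol(value: String): this.type = set(sequenceCol, value)

setDefault(sequenceCol -> "sequence")
```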


val inputType = dataset.schema(sequenceCol).dataType
require(inputType.isInstanceOf[ArrayType] &&
inputType.asInstanceOf[ArrayType].elementType.isInstanceOf[ArrayType],
s"The input column must be ArrayType and the array element type must also be ArrayType, " +
s"but got $inputType.")


val data = dataset.select(sequenceCol)
val sequences = data.where(col(sequenceCol).isNotNull).rdd
.map(r => r.getAs[Seq[Seq[Any]]](0).map(_.toArray).toArray)

val mllibPrefixSpan = new mllibPrefixSpan()
.setMinSupport(minSupport)
.setMaxPatternLength(maxPatternLength)
.setMaxLocalProjDBSize(maxLocalProjDBSize)
.setMinSupport($(minSupport))
.setMaxPatternLength($(maxPatternLength))
.setMaxLocalProjDBSize($(maxLocalProjDBSize))

val rows = mllibPrefixSpan.run(sequences).freqSequences.map(f => Row(f.sequence, f.freq))
val schema = StructType(Seq(
@@ -93,4 +161,7 @@
freqSequences
}

@Since("2.4.0")
override def copy(extra: ParamMap): PrefixSpan = defaultCopy(extra)

}
28 changes: 20 additions & 8 deletions mllib/src/test/scala/org/apache/spark/ml/fpm/PrefixSpanSuite.scala
@@ -29,8 +29,11 @@ class PrefixSpanSuite extends MLTest {

test("PrefixSpan projections with multiple partial starts") {
val smallDataset = Seq(Seq(Seq(1, 2), Seq(1, 2, 3))).toDF("sequence")
val result = PrefixSpan.findFrequentSequentialPatterns(smallDataset, "sequence",
minSupport = 1.0, maxPatternLength = 2, maxLocalProjDBSize = 32000000)
val result = new PrefixSpan()
.setMinSupport(1.0)
.setMaxPatternLength(2)
.setMaxLocalProjDBSize(32000000)
.findFrequentSequentialPatterns(smallDataset, "sequence")
.as[(Seq[Seq[Int]], Long)].collect()
val expected = Array(
(Seq(Seq(1)), 1L),
@@ -90,17 +93,23 @@

test("PrefixSpan Integer type, variable-size itemsets") {
val df = smallTestData.toDF("sequence")
val result = PrefixSpan.findFrequentSequentialPatterns(df, "sequence",
minSupport = 0.5, maxPatternLength = 5, maxLocalProjDBSize = 32000000)
val result = new PrefixSpan()
.setMinSupport(0.5)
.setMaxPatternLength(5)
.setMaxLocalProjDBSize(32000000)
.findFrequentSequentialPatterns(df, "sequence")
.as[(Seq[Seq[Int]], Long)].collect()

compareResults[Int](smallTestDataExpectedResult, result)
}

test("PrefixSpan input row with nulls") {
val df = (smallTestData :+ null).toDF("sequence")
val result = PrefixSpan.findFrequentSequentialPatterns(df, "sequence",
minSupport = 0.5, maxPatternLength = 5, maxLocalProjDBSize = 32000000)
val result = new PrefixSpan()
.setMinSupport(0.5)
.setMaxPatternLength(5)
.setMaxLocalProjDBSize(32000000)
.findFrequentSequentialPatterns(df, "sequence")
.as[(Seq[Seq[Int]], Long)].collect()

compareResults[Int](smallTestDataExpectedResult, result)
@@ -111,8 +120,11 @@
val df = smallTestData
.map(seq => seq.map(itemSet => itemSet.map(intToString)))
.toDF("sequence")
val result = PrefixSpan.findFrequentSequentialPatterns(df, "sequence",
minSupport = 0.5, maxPatternLength = 5, maxLocalProjDBSize = 32000000)
val result = new PrefixSpan()
.setMinSupport(0.5)
.setMaxPatternLength(5)
.setMaxLocalProjDBSize(32000000)
.findFrequentSequentialPatterns(df, "sequence")
.as[(Seq[Seq[String]], Long)].collect()

val expected = smallTestDataExpectedResult.map { case (seq, freq) =>