Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
address comments
  • Loading branch information
WeichenXu123 committed May 23, 2018
commit 90d71e84f36075aeaab19b496eee87792877c48b
92 changes: 46 additions & 46 deletions mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ import org.apache.spark.sql.types.{ArrayType, LongType, StructField, StructType}
* The PrefixSpan algorithm is described in J. Pei, et al., PrefixSpan: Mining Sequential Patterns
* Efficiently by Prefix-Projected Pattern Growth
* (see <a href="http://doi.org/10.1109/ICDE.2001.914830">here</a>).
* This class is not yet an Estimator/Transformer; use the `findFrequentSequentialPatterns`
* method to run the PrefixSpan algorithm.
*
* @see <a href="https://en.wikipedia.org/wiki/Sequential_Pattern_Mining">Sequential Pattern Mining
* (Wikipedia)</a>
Expand All @@ -43,108 +45,106 @@ final class PrefixSpan(@Since("2.4.0") override val uid: String) extends Params
def this() = this(Identifiable.randomUID("prefixSpan"))

/**
* the minimal support level of the sequential pattern, any pattern that
* appears more than (minSupport * size-of-the-dataset) times will be output
* (default value: `0.1`).
* Param for the minimal support level (default: `0.1`).
* Sequential patterns that appear more than (minSupport * size-of-the-dataset) times are
* identified as frequent sequential patterns.
* @group param
*/
@Since("2.4.0")
val minSupport = new DoubleParam(this, "minSupport", "the minimal support level of the " +
"sequential pattern, any pattern that appears more than (minSupport * size-of-the-dataset) " +
"times will be output", ParamValidators.gt(0.0))
// The doc string is assembled from concatenated fragments, so every fragment that continues
// a sentence must end with a space (and must not end with a premature period).
val minSupport = new DoubleParam(this, "minSupport", "The minimal support level of the " +
  "sequential pattern. Sequential patterns that appear more than " +
  "(minSupport * size-of-the-dataset) " +
  "times will be output.", ParamValidators.gtEq(0.0))

/**
 * Returns the value of the `minSupport` param.
 * The default registered through `setDefault` is `0.1`.
 *
 * @group getParam
 */
@Since("2.4.0")
def getMinSupport: Double = $(minSupport)

/**
* Set the minSupport parameter.
* Default is 1.0.
*
* @group setParam
*/
@Since("1.3.0")
/**
 * Sets the `minSupport` param and returns this instance so that setter calls can be chained.
 *
 * @param value the minimal support level to use
 * @group setParam
 */
@Since("2.4.0")
def setMinSupport(value: Double): this.type = set(minSupport, value)

/**
* the maximal length of the sequential pattern
* (default value: `10`).
* Param for the maximal pattern length (default: `10`).
* @group param
*/
@Since("2.4.0")
val maxPatternLength = new IntParam(this, "maxPatternLength",
"the maximal length of the sequential pattern",
"The maximal length of the sequential pattern.",
ParamValidators.gt(0))

/** @group getParam */
@Since("2.4.0")
def getMaxPatternLength: Double = $(maxPatternLength)
def getMaxPatternLength: Int = $(maxPatternLength)

/**
* Set the maxPatternLength parameter.
* Default is 10.
*
* @group setParam
*/
/**
 * Sets the `maxPatternLength` param and returns this instance so that setter calls can be
 * chained. The param is declared with `ParamValidators.gt(0)`; its registered default is `10`.
 *
 * @param value the maximal length of the sequential pattern
 * @group setParam
 */
@Since("2.4.0")
def setMaxPatternLength(value: Int): this.type = set(maxPatternLength, value)

/**
* The maximum number of items (including delimiters used in the
* internal storage format) allowed in a projected database before
* local processing. If a projected database exceeds this size, another
* iteration of distributed prefix growth is run
* (default value: `32000000`).
* Param for the maximum number of items (including delimiters used in the internal storage
* format) allowed in a projected database before local processing (default: `32000000`).
* If a projected database exceeds this size, another iteration of distributed prefix growth
* is run.
* @group param
*/
@Since("2.4.0")
val maxLocalProjDBSize = new LongParam(this, "maxLocalProjDBSize",
"The maximum number of items (including delimiters used in the internal storage format) " +
"allowed in a projected database before local processing. If a projected database exceeds " +
"this size, another iteration of distributed prefix growth is run",
"this size, another iteration of distributed prefix growth is run.",
ParamValidators.gt(0))

/** @group getParam */
@Since("2.4.0")
def getMaxLocalProjDBSize: Double = $(maxLocalProjDBSize)
def getMaxLocalProjDBSize: Long = $(maxLocalProjDBSize)

/**
 * Sets the `maxLocalProjDBSize` param and returns this instance so that setter calls can be
 * chained. The param is declared with `ParamValidators.gt(0)`; its registered default is
 * `32000000`.
 *
 * @param value the maximum number of items allowed in a projected database before local
 *              processing
 * @group setParam
 */
@Since("2.4.0")
def setMaxLocalProjDBSize(value: Long): this.type = set(maxLocalProjDBSize, value)

/**
* Set the maxLocalProjDBSize parameter.
* Default is 32000000.
*
* @group setParam
* Param for the name of the sequence column in the dataset (default: `sequence`).
* Rows with nulls in this column are ignored.
* @group param
*/
@Since("2.4.0")
def setMaxLocalProjDBSize(value: Long): this.type = set(maxLocalProjDBSize, value)
val sequenceCol = new Param[String](this, "sequenceCol", "The name of the sequence column in " +
"dataset, rows with nulls in this column are ignored.")

setDefault(minSupport -> 0.1, maxPatternLength -> 10, maxLocalProjDBSize -> 32000000)
/**
 * Returns the value of the `sequenceCol` param, i.e. the name of the sequence column.
 * The default registered through `setDefault` is `"sequence"`.
 *
 * @group getParam
 */
@Since("2.4.0")
def getSequenceCol: String = $(sequenceCol)

/**
 * Sets the `sequenceCol` param (the name of the column that
 * `findFrequentSequentialPatterns` reads sequences from) and returns this instance so that
 * setter calls can be chained.
 *
 * @param value the name of the sequence column in the dataset
 * @group setParam
 */
@Since("2.4.0")
def setSequenceCol(value: String): this.type = set(sequenceCol, value)

setDefault(minSupport -> 0.1, maxPatternLength -> 10, maxLocalProjDBSize -> 32000000,
sequenceCol -> "sequence")

/**
* :: Experimental ::
* Finds the complete set of frequent sequential patterns in the input sequences of itemsets.
*
* @param dataset A dataset or a dataframe containing a sequence column of type
* `Seq[Seq[_]]`
* @param sequenceCol the name of the sequence column in dataset, rows with nulls in this column
* are ignored
* @return A `DataFrame` that contains columns of sequence and corresponding frequency.
* The schema of it will be:
* - `sequence: Seq[Seq[T]]` (T is the item type)
* - `freq: Long`
*/
@Since("2.4.0")
def findFrequentSequentialPatterns(
dataset: Dataset[_],
sequenceCol: String): DataFrame = {

val inputType = dataset.schema(sequenceCol).dataType
def findFrequentSequentialPatterns(dataset: Dataset[_]): DataFrame = {
val sequenceColParam = $(sequenceCol)
val inputType = dataset.schema(sequenceColParam).dataType
require(inputType.isInstanceOf[ArrayType] &&
inputType.asInstanceOf[ArrayType].elementType.isInstanceOf[ArrayType],
s"The input column must be ArrayType and the array element type must also be ArrayType, " +
s"but got $inputType.")

val data = dataset.select(sequenceCol)
val sequences = data.where(col(sequenceCol).isNotNull).rdd
val data = dataset.select(sequenceColParam)
val sequences = data.where(col(sequenceColParam).isNotNull).rdd
.map(r => r.getAs[Seq[Seq[Any]]](0).map(_.toArray).toArray)

val mllibPrefixSpan = new mllibPrefixSpan()
Expand All @@ -154,7 +154,7 @@ final class PrefixSpan(@Since("2.4.0") override val uid: String) extends Params

val rows = mllibPrefixSpan.run(sequences).freqSequences.map(f => Row(f.sequence, f.freq))
val schema = StructType(Seq(
StructField("sequence", dataset.schema(sequenceCol).dataType, nullable = false),
StructField("sequence", dataset.schema(sequenceColParam).dataType, nullable = false),
StructField("freq", LongType, nullable = false)))
val freqSequences = dataset.sparkSession.createDataFrame(rows, schema)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class PrefixSpanSuite extends MLTest {
.setMinSupport(1.0)
.setMaxPatternLength(2)
.setMaxLocalProjDBSize(32000000)
.findFrequentSequentialPatterns(smallDataset, "sequence")
.findFrequentSequentialPatterns(smallDataset)
.as[(Seq[Seq[Int]], Long)].collect()
val expected = Array(
(Seq(Seq(1)), 1L),
Expand Down Expand Up @@ -97,7 +97,7 @@ class PrefixSpanSuite extends MLTest {
.setMinSupport(0.5)
.setMaxPatternLength(5)
.setMaxLocalProjDBSize(32000000)
.findFrequentSequentialPatterns(df, "sequence")
.findFrequentSequentialPatterns(df)
.as[(Seq[Seq[Int]], Long)].collect()

compareResults[Int](smallTestDataExpectedResult, result)
Expand All @@ -109,7 +109,7 @@ class PrefixSpanSuite extends MLTest {
.setMinSupport(0.5)
.setMaxPatternLength(5)
.setMaxLocalProjDBSize(32000000)
.findFrequentSequentialPatterns(df, "sequence")
.findFrequentSequentialPatterns(df)
.as[(Seq[Seq[Int]], Long)].collect()

compareResults[Int](smallTestDataExpectedResult, result)
Expand All @@ -124,7 +124,7 @@ class PrefixSpanSuite extends MLTest {
.setMinSupport(0.5)
.setMaxPatternLength(5)
.setMaxLocalProjDBSize(32000000)
.findFrequentSequentialPatterns(df, "sequence")
.findFrequentSequentialPatterns(df)
.as[(Seq[Seq[String]], Long)].collect()

val expected = smallTestDataExpectedResult.map { case (seq, freq) =>
Expand Down