-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-20114][ML][FOLLOW-UP] spark.ml parity for sequential pattern mining - PrefixSpan #21393
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -18,6 +18,8 @@ | |
| package org.apache.spark.ml.fpm | ||
|
|
||
| import org.apache.spark.annotation.{Experimental, Since} | ||
| import org.apache.spark.ml.param._ | ||
| import org.apache.spark.ml.util.Identifiable | ||
| import org.apache.spark.mllib.fpm.{PrefixSpan => mllibPrefixSpan} | ||
| import org.apache.spark.sql.{DataFrame, Dataset, Row} | ||
| import org.apache.spark.sql.functions.col | ||
|
|
@@ -35,7 +37,87 @@ import org.apache.spark.sql.types.{ArrayType, LongType, StructField, StructType} | |
| */ | ||
| @Since("2.4.0") | ||
| @Experimental | ||
| object PrefixSpan { | ||
| final class PrefixSpan(@Since("2.4.0") override val uid: String) extends Params { | ||
|
|
||
| @Since("2.4.0") | ||
| def this() = this(Identifiable.randomUID("prefixSpan")) | ||
|
|
||
| /** | ||
| * the minimal support level of the sequential pattern, any pattern that | ||
|
||
| * appears more than (minSupport * size-of-the-dataset) times will be output | ||
| * (default value: `0.1`). | ||
| * @group param | ||
| */ | ||
| @Since("2.4.0") | ||
| val minSupport = new DoubleParam(this, "minSupport", "the minimal support level of the " + | ||
| "sequential pattern, any pattern that appears more than (minSupport * size-of-the-dataset) " + | ||
| "times will be output", ParamValidators.gt(0.0)) | ||
|
||
|
|
||
| /** @group getParam */ | ||
| @Since("2.4.0") | ||
| def getMinSupport: Double = $(minSupport) | ||
|
|
||
| /** | ||
| * Set the minSupport parameter, i.e. the minimal support level of the sequential pattern. | ||
| * Default is 0.1. | ||
|
||
| * | ||
| * @param value the new minimal support level | ||
| * @return this instance, to allow chained setter calls | ||
| * @group setParam | ||
| */ | ||
| @Since("2.4.0") | ||
|
||
| def setMinSupport(value: Double): this.type = set(minSupport, value) | ||
|
|
||
| /** | ||
| * the maximal length of the sequential pattern | ||
|
||
| * (default value: `10`). | ||
| * @group param | ||
| */ | ||
| @Since("2.4.0") | ||
| val maxPatternLength = new IntParam(this, "maxPatternLength", | ||
| "the maximal length of the sequential pattern", | ||
| ParamValidators.gt(0)) | ||
|
|
||
| /** | ||
| * Returns the current value of the `maxPatternLength` param. | ||
| * The result is an `Int`, matching the `IntParam` declaration and | ||
| * `setMaxPatternLength(value: Int)`. | ||
| * | ||
| * @group getParam | ||
| */ | ||
| @Since("2.4.0") | ||
| def getMaxPatternLength: Int = $(maxPatternLength) | ||
|
||
|
|
||
| /** | ||
| * Set the maxPatternLength parameter. | ||
| * Default is 10. | ||
| * | ||
| * @group setParam | ||
| */ | ||
| @Since("2.4.0") | ||
| def setMaxPatternLength(value: Int): this.type = set(maxPatternLength, value) | ||
|
|
||
| /** | ||
| * The maximum number of items (including delimiters used in the | ||
|
||
| * internal storage format) allowed in a projected database before | ||
| * local processing. If a projected database exceeds this size, another | ||
| * iteration of distributed prefix growth is run | ||
| * (default value: `32000000`). | ||
| * @group param | ||
| */ | ||
| @Since("2.4.0") | ||
| val maxLocalProjDBSize = new LongParam(this, "maxLocalProjDBSize", | ||
| "The maximum number of items (including delimiters used in the internal storage format) " + | ||
| "allowed in a projected database before local processing. If a projected database exceeds " + | ||
| "this size, another iteration of distributed prefix growth is run", | ||
| ParamValidators.gt(0)) | ||
|
|
||
| /** | ||
| * Returns the current value of the `maxLocalProjDBSize` param. | ||
| * The result is a `Long`, matching the `LongParam` declaration and | ||
| * `setMaxLocalProjDBSize(value: Long)`; returning `Double` could lose precision. | ||
| * | ||
| * @group getParam | ||
| */ | ||
| @Since("2.4.0") | ||
| def getMaxLocalProjDBSize: Long = $(maxLocalProjDBSize) | ||
|
||
|
|
||
| /** | ||
| * Set the maxLocalProjDBSize parameter. | ||
| * Default is 32000000. | ||
| * | ||
| * @group setParam | ||
| */ | ||
| @Since("2.4.0") | ||
| def setMaxLocalProjDBSize(value: Long): this.type = set(maxLocalProjDBSize, value) | ||
|
|
||
| setDefault(minSupport -> 0.1, maxPatternLength -> 10, maxLocalProjDBSize -> 32000000) | ||
|
|
||
| /** | ||
| * :: Experimental :: | ||
|
|
@@ -45,16 +127,6 @@ object PrefixSpan { | |
| * {{{Seq[Seq[_]]}}} type | ||
| * @param sequenceCol the name of the sequence column in dataset, rows with nulls in this column | ||
| * are ignored | ||
| * @param minSupport the minimal support level of the sequential pattern, any pattern that | ||
| * appears more than (minSupport * size-of-the-dataset) times will be output | ||
| * (recommended value: `0.1`). | ||
| * @param maxPatternLength the maximal length of the sequential pattern | ||
| * (recommended value: `10`). | ||
| * @param maxLocalProjDBSize The maximum number of items (including delimiters used in the | ||
| * internal storage format) allowed in a projected database before | ||
| * local processing. If a projected database exceeds this size, another | ||
| * iteration of distributed prefix growth is run | ||
| * (recommended value: `32000000`). | ||
| * @return A `DataFrame` that contains columns of sequence and corresponding frequency. | ||
| * The schema of it will be: | ||
| * - `sequence: Seq[Seq[T]]` (T is the item type) | ||
|
|
@@ -63,26 +135,22 @@ object PrefixSpan { | |
| @Since("2.4.0") | ||
| def findFrequentSequentialPatterns( | ||
| dataset: Dataset[_], | ||
| sequenceCol: String, | ||
| minSupport: Double, | ||
| maxPatternLength: Int, | ||
| maxLocalProjDBSize: Long): DataFrame = { | ||
| sequenceCol: String): DataFrame = { | ||
|
||
|
|
||
| val inputType = dataset.schema(sequenceCol).dataType | ||
| require(inputType.isInstanceOf[ArrayType] && | ||
| inputType.asInstanceOf[ArrayType].elementType.isInstanceOf[ArrayType], | ||
| s"The input column must be ArrayType and the array element type must also be ArrayType, " + | ||
| s"but got $inputType.") | ||
|
|
||
|
|
||
| val data = dataset.select(sequenceCol) | ||
| val sequences = data.where(col(sequenceCol).isNotNull).rdd | ||
| .map(r => r.getAs[Seq[Seq[Any]]](0).map(_.toArray).toArray) | ||
|
|
||
| val mllibPrefixSpan = new mllibPrefixSpan() | ||
| .setMinSupport(minSupport) | ||
| .setMaxPatternLength(maxPatternLength) | ||
| .setMaxLocalProjDBSize(maxLocalProjDBSize) | ||
| .setMinSupport($(minSupport)) | ||
| .setMaxPatternLength($(maxPatternLength)) | ||
| .setMaxLocalProjDBSize($(maxLocalProjDBSize)) | ||
|
|
||
| val rows = mllibPrefixSpan.run(sequences).freqSequences.map(f => Row(f.sequence, f.freq)) | ||
| val schema = StructType(Seq( | ||
|
|
@@ -93,4 +161,7 @@ object PrefixSpan { | |
| freqSequences | ||
| } | ||
|
|
||
| @Since("2.4.0") | ||
| override def copy(extra: ParamMap): PrefixSpan = defaultCopy(extra) | ||
|
|
||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In the doc, mention that this class is not yet an Estimator/Transformer and link to
the `findFrequentSequentialPatterns` method.