mima fix
YY-OnCall committed Jul 30, 2018
commit 40cf4497432bd4e2cbeec8e6647b52c4c8e74072
23 changes: 17 additions & 6 deletions mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala
@@ -220,7 +220,7 @@ object FPGrowth extends DefaultParamsReadable[FPGrowth] {
 class FPGrowthModel private[ml] (
     @Since("2.2.0") override val uid: String,
     @Since("2.2.0") @transient val freqItemsets: DataFrame,
-    @Since("2.4.0") val numTrainingRecords: Long)
+    @Since("2.4.0") val numTrainingRecords: Long = -1)
   extends Model[FPGrowthModel] with FPGrowthParams with MLWritable {
 
   /** @group setParam */
@@ -359,7 +359,7 @@ private[fpm] object AssociationRules {
    * from algorithms like [[FPGrowth]].
    * @param itemsCol column name for frequent itemsets
    * @param freqCol column name for frequent itemsets count
-   * @param numTrainingRecords count of training Dataset
+   * @param numTrainingRecords count of training Dataset, default -1.
    * @param minConfidence minimum confidence for the result association rules
    * @return a DataFrame("antecedent", "consequent", "confidence", "support") containing the
    *         association rules.
@@ -376,15 +376,26 @@ private[fpm] object AssociationRules {
     val rows = new MLlibAssociationRules()
       .setMinConfidence(minConfidence)
       .run(freqItemSetRdd)
-      .map(r => Row(r.antecedent, r.consequent, r.confidence, r.freqUnion / numTrainingRecords))
+      .map { r =>
+        if (numTrainingRecords > 0) {
+          Row(r.antecedent, r.consequent, r.confidence, r.freqUnion / numTrainingRecords)
+        } else {
+          Row(r.antecedent, r.consequent, r.confidence)
+        }
+
+      }
 
     val dt = dataset.schema(itemsCol).dataType
     val schema = StructType(Seq(
       StructField("antecedent", dt, nullable = false),
       StructField("consequent", dt, nullable = false),
-      StructField("confidence", DoubleType, nullable = false),
-      StructField("support", DoubleType, nullable = false)))
-    val rules = dataset.sparkSession.createDataFrame(rows, schema)
+      StructField("confidence", DoubleType, nullable = false)))
+    val rulesSchema = if (numTrainingRecords > 0) {
+      schema.add(StructField("support", DoubleType, nullable = false))
+    } else {
+      schema
+    }
+    val rules = dataset.sparkSession.createDataFrame(rows, rulesSchema)
     rules
   }
 }
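For context, here is a minimal usage sketch of the behaviour this hunk introduces. It assumes, as the new constructor parameter and its doc suggest, that fit() records the training row count in numTrainingRecords; the data and printed values below are illustrative only. When the count is known, each rule's support is computed as freqUnion / numTrainingRecords and exposed as a "support" column; a model built with the default numTrainingRecords = -1 yields only (antecedent, consequent, confidence).

    import org.apache.spark.ml.fpm.FPGrowth

    // Assumes an active SparkSession named `spark`.
    import spark.implicits._

    // Four transactions; the itemset {1, 2} occurs in three of them.
    val transactions = Seq(
      Array("1", "2"),
      Array("1", "2"),
      Array("1", "2", "3"),
      Array("1", "3")
    ).toDF("items")

    val model = new FPGrowth()
      .setItemsCol("items")
      .setMinSupport(0.5)
      .setMinConfidence(0.5)
      .fit(transactions)

    // With numTrainingRecords = 4 recorded at fit time, the rule {2} => {1}
    // gets support = freqUnion / numTrainingRecords = 3 / 4 = 0.75 alongside
    // its confidence of 1.0; with numTrainingRecords = -1 the "support"
    // column would be omitted entirely.
    model.associationRules.show()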
@@ -38,6 +38,7 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
       val data = dataset.withColumn("items", col("items").cast(ArrayType(dt)))
       val model = new FPGrowth().setMinSupport(0.5).fit(data)
       val generatedRules = model.setMinConfidence(0.5).associationRules
+      generatedRules.show()
       val expectedRules = spark.createDataFrame(Seq(
         (Array("2"), Array("1"), 1.0, 0.75),
         (Array("1"), Array("2"), 0.75, 0.75)
3 changes: 3 additions & 0 deletions project/MimaExcludes.scala
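Adding the numTrainingRecords parameter changes the binary signature of FPGrowthModel's constructor, so MiMa flags a DirectMissingMethodProblem for org.apache.spark.ml.fpm.FPGrowthModel.this against the previous release. The exclusion added below acknowledges that intentional break, which is presumably what the "mima fix" commit message refers to.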
@@ -93,6 +93,9 @@ object MimaExcludes {
     ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.getValidationIndicatorCol"),
     ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.org$apache$spark$ml$param$shared$HasValidationIndicatorCol$_setter_$validationIndicatorCol_="),
     ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.validationIndicatorCol"),
+
+    // [SPARK-19939][ML] Add support for association rules in ML
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.fpm.FPGrowthModel.this")
   )
 
   // Exclude rules for 2.3.x