Skip to content

Commit d125529

Browse files
huaxingao authored and srowen committed
[SPARK-19939][ML] Add support for association rules in ML
### What changes were proposed in this pull request? Adding support to Association Rules in Spark ml.fpm. ### Why are the changes needed? Support is an indication of how frequently the itemset of an association rule appears in the database and suggests if the rules are generally applicable to the dateset. Refer to [wiki](https://en.wikipedia.org/wiki/Association_rule_learning#Support) for more details. ### Does this PR introduce _any_ user-facing change? Yes. Associate Rules now have support measure ### How was this patch tested? existing and new unit test Closes apache#28903 from huaxingao/fpm. Authored-by: Huaxin Gao <[email protected]> Signed-off-by: Sean Owen <[email protected]>
1 parent bbb2cba commit d125529

File tree

7 files changed

+57
-26
lines changed

7 files changed

+57
-26
lines changed

R/pkg/R/mllib_fpm.R

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,11 +122,12 @@ setMethod("spark.freqItemsets", signature(object = "FPGrowthModel"),
122122
# Get association rules.
123123

124124
#' @return A \code{SparkDataFrame} with association rules.
125-
#' The \code{SparkDataFrame} contains four columns:
125+
#' The \code{SparkDataFrame} contains five columns:
126126
#' \code{antecedent} (an array of the same type as the input column),
127127
#' \code{consequent} (an array of the same type as the input column),
128128
#' \code{condfidence} (confidence for the rule)
129-
#' and \code{lift} (lift for the rule)
129+
#' \code{lift} (lift for the rule)
130+
#' and \code{support} (support for the rule)
130131
#' @rdname spark.fpGrowth
131132
#' @aliases associationRules,FPGrowthModel-method
132133
#' @note spark.associationRules(FPGrowthModel) since 2.2.0

R/pkg/tests/fulltests/test_mllib_fpm.R

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,8 @@ test_that("spark.fpGrowth", {
4545
antecedent = I(list(list("2"), list("3"))),
4646
consequent = I(list(list("1"), list("1"))),
4747
confidence = c(1, 1),
48-
lift = c(1, 1)
48+
lift = c(1, 1),
49+
support = c(0.75, 0.5)
4950
)
5051

5152
expect_equivalent(expected_association_rules, collect(spark.associationRules(model)))

mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -244,17 +244,18 @@ class FPGrowthModel private[ml] (
244244
@transient private var _cachedRules: DataFrame = _
245245

246246
/**
247-
* Get association rules fitted using the minConfidence. Returns a dataframe with four fields,
248-
* "antecedent", "consequent", "confidence" and "lift", where "antecedent" and "consequent" are
249-
* Array[T], whereas "confidence" and "lift" are Double.
247+
* Get association rules fitted using the minConfidence. Returns a dataframe with five fields,
248+
* "antecedent", "consequent", "confidence", "lift" and "support", where "antecedent" and
249+
* "consequent" are Array[T], whereas "confidence", "lift" and "support" are Double.
250250
*/
251251
@Since("2.2.0")
252252
@transient def associationRules: DataFrame = {
253253
if ($(minConfidence) == _cachedMinConf) {
254254
_cachedRules
255255
} else {
256256
_cachedRules = AssociationRules
257-
.getAssociationRulesFromFP(freqItemsets, "items", "freq", $(minConfidence), itemSupport)
257+
.getAssociationRulesFromFP(freqItemsets, "items", "freq", $(minConfidence), itemSupport,
258+
numTrainingRecords)
258259
_cachedMinConf = $(minConfidence)
259260
_cachedRules
260261
}
@@ -385,6 +386,7 @@ private[fpm] object AssociationRules {
385386
* @param freqCol column name for appearance count of the frequent itemsets
386387
* @param minConfidence minimum confidence for generating the association rules
387388
* @param itemSupport map containing an item and its support
389+
* @param numTrainingRecords count of training Dataset
388390
* @return a DataFrame("antecedent"[Array], "consequent"[Array], "confidence"[Double],
389391
* "lift" [Double]) containing the association rules.
390392
*/
@@ -393,21 +395,23 @@ private[fpm] object AssociationRules {
393395
itemsCol: String,
394396
freqCol: String,
395397
minConfidence: Double,
396-
itemSupport: scala.collection.Map[T, Double]): DataFrame = {
397-
398+
itemSupport: scala.collection.Map[T, Double],
399+
numTrainingRecords: Long): DataFrame = {
398400
val freqItemSetRdd = dataset.select(itemsCol, freqCol).rdd
399401
.map(row => new FreqItemset(row.getSeq[T](0).toArray, row.getLong(1)))
400402
val rows = new MLlibAssociationRules()
401403
.setMinConfidence(minConfidence)
402404
.run(freqItemSetRdd, itemSupport)
403-
.map(r => Row(r.antecedent, r.consequent, r.confidence, r.lift.orNull))
405+
.map(r => Row(r.antecedent, r.consequent, r.confidence, r.lift.orNull,
406+
r.freqUnion / numTrainingRecords))
404407

405408
val dt = dataset.schema(itemsCol).dataType
406409
val schema = StructType(Seq(
407410
StructField("antecedent", dt, nullable = false),
408411
StructField("consequent", dt, nullable = false),
409412
StructField("confidence", DoubleType, nullable = false),
410-
StructField("lift", DoubleType)))
413+
StructField("lift", DoubleType),
414+
StructField("support", DoubleType, nullable = false)))
411415
val rules = dataset.sparkSession.createDataFrame(rows, schema)
412416
rules
413417
}

mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ object AssociationRules {
124124
class Rule[Item] private[fpm] (
125125
@Since("1.5.0") val antecedent: Array[Item],
126126
@Since("1.5.0") val consequent: Array[Item],
127-
freqUnion: Double,
127+
private[spark] val freqUnion: Double,
128128
freqAntecedent: Double,
129129
freqConsequent: Option[Double]) extends Serializable {
130130

mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,9 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
3939
val model = new FPGrowth().setMinSupport(0.5).fit(data)
4040
val generatedRules = model.setMinConfidence(0.5).associationRules
4141
val expectedRules = spark.createDataFrame(Seq(
42-
(Array("2"), Array("1"), 1.0, 1.0),
43-
(Array("1"), Array("2"), 0.75, 1.0)
44-
)).toDF("antecedent", "consequent", "confidence", "lift")
42+
(Array("2"), Array("1"), 1.0, 1.0, 0.75),
43+
(Array("1"), Array("2"), 0.75, 1.0, 0.75)
44+
)).toDF("antecedent", "consequent", "confidence", "lift", "support")
4545
.withColumn("antecedent", col("antecedent").cast(ArrayType(dt)))
4646
.withColumn("consequent", col("consequent").cast(ArrayType(dt)))
4747
assert(expectedRules.sort("antecedent").rdd.collect().sameElements(
@@ -61,6 +61,31 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
6161
}
6262
}
6363

64+
test("FPGrowth associationRules") {
65+
val dataset = spark.createDataFrame(Seq(
66+
(1, Array("1", "2")),
67+
(2, Array("3")),
68+
(3, Array("4", "5")),
69+
(4, Array("1", "2", "3")),
70+
(5, Array("2"))
71+
)).toDF("id", "items")
72+
val model = new FPGrowth().setMinSupport(0.1).setMinConfidence(0.1).fit(dataset)
73+
val expectedRules = spark.createDataFrame(Seq(
74+
(Array("2"), Array("1"), 0.6666666666666666, 1.6666666666666665, 0.4),
75+
(Array("2"), Array("3"), 0.3333333333333333, 0.8333333333333333, 0.2),
76+
(Array("3"), Array("1"), 0.5, 1.25, 0.2),
77+
(Array("3"), Array("2"), 0.5, 0.8333333333333334, 0.2),
78+
(Array("1", "3"), Array("2"), 1.0, 1.6666666666666667, 0.2),
79+
(Array("1", "2"), Array("3"), 0.5, 1.25, 0.2),
80+
(Array("4"), Array("5"), 1.0, 5.0, 0.2),
81+
(Array("5"), Array("4"), 1.0, 5.0, 0.2),
82+
(Array("1"), Array("3"), 0.5, 1.25, 0.2),
83+
(Array("1"), Array("2"), 1.0, 1.6666666666666667, 0.4),
84+
(Array("3", "2"), Array("1"), 1.0, 2.5, 0.2)
85+
)).toDF("antecedent", "consequent", "confidence", "lift", "support")
86+
assert(expectedRules.collect().toSet.equals(model.associationRules.collect().toSet))
87+
}
88+
6489
test("FPGrowth getFreqItems") {
6590
val model = new FPGrowth().setMinSupport(0.7).fit(dataset)
6691
val expectedFreq = spark.createDataFrame(Seq(

python/pyspark/ml/fpm.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -180,15 +180,15 @@ class FPGrowth(JavaEstimator, _FPGrowthParams, JavaMLWritable, JavaMLReadable):
180180
only showing top 5 rows
181181
...
182182
>>> fpm.associationRules.show(5)
183-
+----------+----------+----------+----+
184-
|antecedent|consequent|confidence|lift|
185-
+----------+----------+----------+----+
186-
| [t, s]| [y]| 1.0| 2.0|
187-
| [t, s]| [x]| 1.0| 1.5|
188-
| [t, s]| [z]| 1.0| 1.2|
189-
| [p]| [r]| 1.0| 2.0|
190-
| [p]| [z]| 1.0| 1.2|
191-
+----------+----------+----------+----+
183+
+----------+----------+----------+----+------------------+
184+
|antecedent|consequent|confidence|lift| support|
185+
+----------+----------+----------+----+------------------+
186+
| [t, s]| [y]| 1.0| 2.0|0.3333333333333333|
187+
| [t, s]| [x]| 1.0| 1.5|0.3333333333333333|
188+
| [t, s]| [z]| 1.0| 1.2|0.3333333333333333|
189+
| [p]| [r]| 1.0| 2.0|0.3333333333333333|
190+
| [p]| [z]| 1.0| 1.2|0.3333333333333333|
191+
+----------+----------+----------+----+------------------+
192192
only showing top 5 rows
193193
...
194194
>>> new_data = spark.createDataFrame([(["t", "s"], )], ["items"])

python/pyspark/ml/tests/test_algorithms.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -226,8 +226,8 @@ def test_association_rules(self):
226226
fpm = fp.fit(self.data)
227227

228228
expected_association_rules = self.spark.createDataFrame(
229-
[([3], [1], 1.0, 1.0), ([2], [1], 1.0, 1.0)],
230-
["antecedent", "consequent", "confidence", "lift"]
229+
[([3], [1], 1.0, 1.0, 0.5), ([2], [1], 1.0, 1.0, 0.75)],
230+
["antecedent", "consequent", "confidence", "lift", "support"]
231231
)
232232
actual_association_rules = fpm.associationRules
233233

0 commit comments

Comments
 (0)