diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaAssociationRulesExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaAssociationRulesExample.java
index 189560e3fe1f..1e622f153c98 100644
--- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaAssociationRulesExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaAssociationRulesExample.java
@@ -45,7 +45,7 @@ public static void main(String[] args) {
     AssociationRules arules = new AssociationRules()
       .setMinConfidence(0.8);
-    JavaRDD<AssociationRules.Rule<String>> results = arules.run(freqItemsets);
+    JavaRDD<AssociationRules.Rule<String>> results = arules.run(freqItemsets, 50L);
 
     for (AssociationRules.Rule<String> rule : results.collect()) {
       System.out.println(
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/AssociationRulesExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/AssociationRulesExample.scala
index 11e18c9f040b..23633bfff2d2 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/AssociationRulesExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/AssociationRulesExample.scala
@@ -39,7 +39,7 @@ object AssociationRulesExample {
     val ar = new AssociationRules()
       .setMinConfidence(0.8)
-    val results = ar.run(freqItemsets)
+    val results = ar.run(freqItemsets, 50L)
 
     results.collect().foreach { rule =>
       println("[" + rule.antecedent.mkString(",")
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/FPGrowthModelWrapper.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/FPGrowthModelWrapper.scala
index e6d1dceebed4..fb0034cfa03e 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/FPGrowthModelWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/FPGrowthModelWrapper.scala
@@ -24,7 +24,7 @@ import org.apache.spark.rdd.RDD
 * A Wrapper of FPGrowthModel to provide helper method for Python
 */
 private[python] class FPGrowthModelWrapper(model: FPGrowthModel[Any])
-  extends FPGrowthModel(model.freqItemsets) {
+  extends FPGrowthModel(model.freqItemsets, model.dataSize) {
 
   def getFreqItemsets: RDD[Array[Any]] = {
     SerDe.fromTuple2RDD(model.freqItemsets.map(x => (x.javaItems, x.freq)))
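The two example programs above now pass the total transaction count (`50L`) as a second argument to `run`. As a minimal sketch of the updated call path (not part of the patch), assuming an existing `SparkContext` named `sc` and the toy item counts used in the examples:

```scala
import org.apache.spark.mllib.fpm.AssociationRules
import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset

// Toy frequent itemsets matching the figures in the examples above.
val freqItemsets = sc.parallelize(Seq(
  new FreqItemset(Array("a"), 15L),
  new FreqItemset(Array("b"), 35L),
  new FreqItemset(Array("a", "b"), 12L)
))

val ar = new AssociationRules().setMinConfidence(0.8)
// The new second argument is the total number of input transactions;
// it feeds only the new `support` value, not rule generation or filtering.
val results = ar.run(freqItemsets, 50L)
results.collect().foreach { rule =>
  println(rule.antecedent.mkString(",") + " => " + rule.consequent.mkString(",") +
    ", confidence: " + rule.confidence + ", support: " + rule.support)
}
```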
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala
index 9a63cc29dacb..c6650e09975c 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala
@@ -63,7 +63,7 @@ class AssociationRules private[fpm] (
    *
    */
   @Since("1.5.0")
-  def run[Item: ClassTag](freqItemsets: RDD[FreqItemset[Item]]): RDD[Rule[Item]] = {
+  def run[Item: ClassTag](freqItemsets: RDD[FreqItemset[Item]], dataSize: Long): RDD[Rule[Item]] = {
     // For candidate rule X => Y, generate (X, (Y, freq(X union Y)))
     val candidates = freqItemsets.flatMap { itemset =>
       val items = itemset.items
@@ -79,15 +79,15 @@ class AssociationRules private[fpm] (
     // Join to get (X, ((Y, freq(X union Y)), freq(X))), generate rules, and filter by confidence
     candidates.join(freqItemsets.map(x => (x.items.toSeq, x.freq)))
       .map { case (antecendent, ((consequent, freqUnion), freqAntecedent)) =>
-      new Rule(antecendent.toArray, consequent.toArray, freqUnion, freqAntecedent)
+      new Rule(antecendent.toArray, consequent.toArray, freqUnion, freqAntecedent, dataSize)
     }.filter(_.confidence >= minConfidence)
   }
 
   /** Java-friendly version of [[run]]. */
   @Since("1.5.0")
-  def run[Item](freqItemsets: JavaRDD[FreqItemset[Item]]): JavaRDD[Rule[Item]] = {
+  def run[Item](freqItemsets: JavaRDD[FreqItemset[Item]], dataSize: Long): JavaRDD[Rule[Item]] = {
     val tag = fakeClassTag[Item]
-    run(freqItemsets.rdd)(tag)
+    run(freqItemsets.rdd, dataSize)(tag)
   }
 }
@@ -111,7 +111,8 @@ object AssociationRules {
       @Since("1.5.0") val antecedent: Array[Item],
       @Since("1.5.0") val consequent: Array[Item],
       freqUnion: Double,
-      freqAntecedent: Double) extends Serializable {
+      freqAntecedent: Double,
+      dataSize: Long) extends Serializable {
 
     /**
      * Returns the confidence of the rule.
@@ -120,6 +121,13 @@ object AssociationRules {
     @Since("1.5.0")
     def confidence: Double = freqUnion.toDouble / freqAntecedent
 
+    /**
+     * Returns the support of the rule: the frequency with which antecedent and consequent
+     * co-occur, divided by the total number of transactions (`dataSize`).
+     */
+    @Since("2.1.0")
+    def support: Double = freqUnion.toDouble / dataSize
+
     require(antecedent.toSet.intersect(consequent.toSet).isEmpty, {
       val sharedItems = antecedent.toSet.intersect(consequent.toSet)
       s"A valid association rule must have disjoint antecedent and " +
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala
index 0f7fbe9556c5..c15c536a8a9e 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala
@@ -49,7 +49,8 @@ import org.apache.spark.storage.StorageLevel
 */
 @Since("1.3.0")
 class FPGrowthModel[Item: ClassTag] @Since("1.3.0") (
-    @Since("1.3.0") val freqItemsets: RDD[FreqItemset[Item]])
+    @Since("1.3.0") val freqItemsets: RDD[FreqItemset[Item]],
+    @Since("2.0.0") val dataSize: Long)
   extends Saveable with Serializable {
   /**
    * Generates association rules for the [[Item]]s in [[freqItemsets]].
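For concreteness, this is the arithmetic the `confidence` and `support` methods above perform, worked through with the figures from the examples (freq(a union b) = 12, freq(a) = 15, dataSize = 50); an illustration, not code from the patch:

```scala
// Confidence is a conditional frequency; support normalizes by dataset size.
val freqUnion = 12.0      // freq(a union b)
val freqAntecedent = 15.0 // freq(a)
val dataSize = 50L        // total number of transactions

val confidence = freqUnion / freqAntecedent // 0.8, passes setMinConfidence(0.8)
val support = freqUnion / dataSize          // 0.24, fraction of all transactions
```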
@@ -58,7 +59,7 @@ class FPGrowthModel[Item: ClassTag] @Since("1.3.0") (
   @Since("1.5.0")
   def generateAssociationRules(confidence: Double): RDD[AssociationRules.Rule[Item]] = {
     val associationRules = new AssociationRules(confidence)
-    associationRules.run(freqItemsets)
+    associationRules.run(freqItemsets, dataSize)
   }
 
   /**
@@ -102,7 +103,8 @@ object FPGrowthModel extends Loader[FPGrowthModel[_]] {
       val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
 
       val metadata = compact(render(
-        ("class" -> thisClassName) ~ ("version" -> thisFormatVersion)))
+        ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~
+          ("dataSize" -> model.dataSize)))
       sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))
 
       // Get the type of item class
@@ -128,19 +130,20 @@ object FPGrowthModel extends Loader[FPGrowthModel[_]] {
       val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path)
       assert(className == thisClassName)
       assert(formatVersion == thisFormatVersion)
-
+      val dataSize = (metadata \ "dataSize").extract[Long]
       val freqItemsets = spark.read.parquet(Loader.dataPath(path))
       val sample = freqItemsets.select("items").head().get(0)
-      loadImpl(freqItemsets, sample)
+      loadImpl(freqItemsets, sample, dataSize)
     }
 
-    def loadImpl[Item: ClassTag](freqItemsets: DataFrame, sample: Item): FPGrowthModel[Item] = {
+    def loadImpl[Item: ClassTag](freqItemsets: DataFrame, sample: Item,
+        dataSize: Long): FPGrowthModel[Item] = {
       val freqItemsetsRDD = freqItemsets.select("items", "freq").rdd.map { x =>
         val items = x.getAs[Seq[Item]](0).toArray
         val freq = x.getLong(1)
         new FreqItemset(items, freq)
       }
-      new FPGrowthModel(freqItemsetsRDD)
+      new FPGrowthModel(freqItemsetsRDD, dataSize)
     }
   }
 }
@@ -215,7 +218,7 @@ class FPGrowth private (
     val partitioner = new HashPartitioner(numParts)
     val freqItems = genFreqItems(data, minCount, partitioner)
     val freqItemsets = genFreqItemsets(data, minCount, freqItems, partitioner)
-    new FPGrowthModel(freqItemsets)
+    new FPGrowthModel(freqItemsets, count)
   }
 
   /** Java-friendly version of [[run]]. */
diff --git a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaAssociationRulesSuite.java b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaAssociationRulesSuite.java
index 3451e0773759..e4624bb67410 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaAssociationRulesSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaAssociationRulesSuite.java
@@ -36,6 +36,7 @@ public void runAssociationRules() {
       new FreqItemset<String>(new String[]{"a", "b"}, 12L)
     ));
 
-    JavaRDD<AssociationRules.Rule<String>> results = (new AssociationRules()).run(freqItemsets);
+    JavaRDD<AssociationRules.Rule<String>> results = (new AssociationRules()).run(
+      freqItemsets, 50L);
   }
 }
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/AssociationRulesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/AssociationRulesSuite.scala
index dcb1f398b04b..9786049b1d01 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/fpm/AssociationRulesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/AssociationRulesSuite.scala
@@ -38,7 +38,7 @@ class AssociationRulesSuite extends SparkFunSuite with MLlibTestSparkContext {
 
     val results1 = ar
       .setMinConfidence(0.9)
-      .run(freqItemsets)
+      .run(freqItemsets, 10L)
       .collect()
 
     /* Verify results using the `R` code:
@@ -67,7 +67,7 @@ class AssociationRulesSuite extends SparkFunSuite with MLlibTestSparkContext {
 
     val results2 = ar
       .setMinConfidence(0)
-      .run(freqItemsets)
+      .run(freqItemsets, 10L)
       .collect()
 
     /* Verify results using the `R` code:
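End to end, callers of the model do not need to supply the count themselves: `FPGrowth.run` records `data.count()` as the model's `dataSize`, and `generateAssociationRules` forwards it. A hedged usage sketch, assuming an existing `SparkContext` named `sc` and illustrative transactions:

```scala
import org.apache.spark.mllib.fpm.FPGrowth

// Four illustrative transactions; dataSize is captured as 4 inside run().
val transactions = sc.parallelize(Seq(
  Array("a", "b", "c"),
  Array("a", "b"),
  Array("b", "c"),
  Array("a", "c")
))

val model = new FPGrowth()
  .setMinSupport(0.5)
  .setNumPartitions(2)
  .run(transactions)

// No extra argument needed: the model already knows the transaction count,
// so each rule's support comes back normalized by it.
model.generateAssociationRules(0.6).collect().foreach { rule =>
  println(rule.antecedent.mkString(",") + " => " + rule.consequent.mkString(",") +
    ", support: " + rule.support)
}
```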