fix + address comments
mgaido91 committed Aug 27, 2018
commit 4c8b7beb7fe4f28d9f33306410d6237f19cadf72
23 changes: 11 additions & 12 deletions mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala
@@ -218,12 +218,9 @@ object FPGrowth extends DefaultParamsReadable[FPGrowth] {
class FPGrowthModel private[ml] (
@Since("2.2.0") override val uid: String,
@Since("2.2.0") @transient val freqItemsets: DataFrame,
- private val itemSupport: Map[Any, Long])
+ private val itemSupport: scala.collection.Map[Any, Double])
extends Model[FPGrowthModel] with FPGrowthParams with MLWritable {

- private[ml] def this(uid: String, freqItemsets: DataFrame) =
-   this(uid, freqItemsets, Map.empty)

/** @group setParam */
@Since("2.2.0")
def setMinConfidence(value: Double): this.type = set(minConfidence, value)
@@ -332,15 +329,16 @@ object FPGrowthModel extends MLReadable[FPGrowthModel] {
instance.freqItemsets.write.parquet(dataPath)
val itemDataType = instance.freqItemsets.schema(instance.getItemsCol).dataType match {
case ArrayType(et, _) => et
- case other => other // we should never get here
+ case other => throw new RuntimeException(s"Expected ${ArrayType.simpleString}, but got " +
+   other.catalogString + ".")

Member:
I slightly prefer subclasses like IllegalArgumentException or IllegalStateException, but it's just a matter of taste. You can interpolate the second argument and probably get it on one line if you break before the message starts.

Contributor Author:
sure, I'll do, thanks.
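For illustration only, a minimal sketch of the reviewer's suggestion: an exception subclass plus a single interpolated message. The IllegalStateException choice and the exact wording are assumptions, not the committed code:

  val itemDataType = instance.freqItemsets.schema(instance.getItemsCol).dataType match {
    case ArrayType(et, _) => et
    // Interpolate everything into one expression, breaking before the message starts.
    case other => throw new IllegalStateException(
      s"Expected ${ArrayType.simpleString}, but got ${other.catalogString}.")
  }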
}
val itemSupportPath = new Path(path, "itemSupport").toString
val itemSupportRows = instance.itemSupport.map {
case (item, support) => Row(item, support)
}.toSeq
val schema = StructType(Seq(
StructField("item", itemDataType, nullable = false),
StructField("support", LongType, nullable = false)))
StructField("support", DoubleType, nullable = false)))
sparkSession.createDataFrame(sc.parallelize(itemSupportRows), schema)
.repartition(1).write.parquet(itemSupportPath)
}
@@ -358,11 +356,11 @@ object FPGrowthModel extends MLReadable[FPGrowthModel] {
val itemSupportPath = new Path(path, "itemSupport")
val fs = FileSystem.get(sc.hadoopConfiguration)
val itemSupport = if (fs.exists(itemSupportPath)) {
- sparkSession.read.parquet(itemSupportPath.toString).rdd.collect().map {
-   case Row(item: Any, support: Long) => item -> support
- }.toMap
+ sparkSession.read.parquet(itemSupportPath.toString).rdd.map {
+   case Row(item: Any, support: Double) => item -> support
+ }.collectAsMap()
} else {
- Map.empty[Any, Long]
+ Map.empty[Any, Double]
}
val model = new FPGrowthModel(metadata.uid, frequentItems, itemSupport)
metadata.getAndSetParams(model)
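The reload path above now builds the pairs on the executors and calls collectAsMap(), which gathers a pair RDD directly into a driver-side scala.collection.Map instead of going through collect().toMap. A minimal standalone sketch of that pattern, with made-up values:

  // Given an existing SparkContext `sc`:
  val pairs = sc.parallelize(Seq("a" -> 1.0, "b" -> 0.75))
  // collectAsMap() is available on pair RDDs and returns scala.collection.Map[K, V],
  // which matches the widened itemSupport type in this change.
  val itemSupport: scala.collection.Map[String, Double] = pairs.collectAsMap()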
@@ -380,6 +378,7 @@ private[fpm] object AssociationRules {
* @param itemsCol column name for frequent itemsets
* @param freqCol column name for appearance count of the frequent itemsets
* @param minConfidence minimum confidence for generating the association rules
+ * @param itemSupport map containing an item and its support
* @return a DataFrame("antecedent"[Array], "consequent"[Array], "confidence"[Double])
* containing the association rules.
*/
@@ -388,13 +387,13 @@ private[fpm] object AssociationRules {
itemsCol: String,
freqCol: String,
minConfidence: Double,
- itemSupport: Map[Any, Long]): DataFrame = {
+ itemSupport: scala.collection.Map[T, Double]): DataFrame = {

val freqItemSetRdd = dataset.select(itemsCol, freqCol).rdd
.map(row => new FreqItemset(row.getSeq[T](0).toArray, row.getLong(1)))
val rows = new MLlibAssociationRules()
.setMinConfidence(minConfidence)
- .run(freqItemSetRdd, itemSupport.asInstanceOf[Map[T, Long]])
+ .run(freqItemSetRdd, itemSupport)
.map(r => Row(r.antecedent, r.consequent, r.confidence, r.lift.orNull))

val dt = dataset.schema(itemsCol).dataType
mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala

@@ -56,23 +56,24 @@ class AssociationRules private[fpm] (
/**
* Computes the association rules with confidence above `minConfidence`.
* @param freqItemsets frequent itemset model obtained from [[FPGrowth]]
- * @return a `Set[Rule[Item]]` containing the association rules.
+ * @return a `RDD[Rule[Item]]` containing the association rules.
*
*/
@Since("1.5.0")
def run[Item: ClassTag](freqItemsets: RDD[FreqItemset[Item]]): RDD[Rule[Item]] = {
- run(freqItemsets, Map.empty[Item, Long])
+ run(freqItemsets, Map.empty[Item, Double])
}

/**
* Computes the association rules with confidence above `minConfidence`.
* @param freqItemsets frequent itemset model obtained from [[FPGrowth]]
- * @return a `Set[Rule[Item]]` containing the association rules. The rules will be able to
+ * @param itemSupport map containing an item and its support
+ * @return a `RDD[Rule[Item]]` containing the association rules. The rules will be able to
* compute also the lift metric.
*/
@Since("2.4.0")
def run[Item: ClassTag](freqItemsets: RDD[FreqItemset[Item]],
- itemSupport: Map[Item, Long]): RDD[Rule[Item]] = {
+ itemSupport: scala.collection.Map[Item, Double]): RDD[Rule[Item]] = {
// For candidate rule X => Y, generate (X, (Y, freq(X union Y)))
val candidates = freqItemsets.flatMap { itemset =>
val items = itemset.items
@@ -125,7 +126,7 @@ object AssociationRules {
@Since("1.5.0") val consequent: Array[Item],
freqUnion: Double,
freqAntecedent: Double,
- freqConsequent: Option[Long]) extends Serializable {
+ freqConsequent: Option[Double]) extends Serializable {

Member:
Ideally these frequencies would have been Longs I think, but too late. Yes, stay consistent.
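A hedged sketch of how a lift metric can be derived from these now-Double fields; the body is illustrative, not necessarily the committed implementation:

  // lift(X => Y) = confidence(X => Y) / support(Y), where confidence is
  // freqUnion / freqAntecedent; None when the consequent's support is unknown.
  def lift: Option[Double] =
    freqConsequent.map(supportY => (freqUnion / freqAntecedent) / supportY)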

/**
* Returns the confidence of the rule.
mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala

@@ -50,7 +50,7 @@ import org.apache.spark.storage.StorageLevel
@Since("1.3.0")
class FPGrowthModel[Item: ClassTag] @Since("2.4.0") (
@Since("1.3.0") val freqItemsets: RDD[FreqItemset[Item]],
@Since("2.4.0") val itemSupport: Map[Item, Long])
@Since("2.4.0") val itemSupport: Map[Item, Double])
extends Saveable with Serializable {

@Since("1.3.0")
@@ -220,7 +220,10 @@ class FPGrowth private[spark] (
val partitioner = new HashPartitioner(numParts)
val freqItemsCount = genFreqItems(data, minCount, partitioner)
val freqItemsets = genFreqItemsets(data, minCount, freqItemsCount.map(_._1), partitioner)
- new FPGrowthModel(freqItemsets, freqItemsCount.toMap)
+ val itemSupport = freqItemsCount.map {
+   case (item, cnt) => item -> cnt.toDouble / count
+ }.toMap
+ new FPGrowthModel(freqItemsets, itemSupport)
}
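The new itemSupport map turns raw item counts into support fractions by dividing by the total number of transactions, which is what the lift computation consumes downstream. A toy run-through of that arithmetic, with made-up counts:

  // Four transactions; "1" appears in all four, "2" in three.
  val count = 4L
  val freqItemsCount = Seq("1" -> 4L, "2" -> 3L)

  // Same normalization as in run(): count -> fraction of transactions.
  val itemSupport = freqItemsCount.map {
    case (item, cnt) => item -> cnt.toDouble / count
  }.toMap
  // Map("1" -> 1.0, "2" -> 0.75)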

@@ -236,7 +239,7 @@
/**
* Generates frequent items by filtering the input data using minimal support level.
* @param minCount minimum count for frequent itemsets
* @param partitioner partitioner used to distribute items
- * @return array of frequent pattern ordered by their frequencies
+ * @return array of frequent patterns and their frequencies ordered by their frequencies
*/
private def genFreqItems[Item: ClassTag](
data: RDD[Array[Item]],
mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala

@@ -39,8 +39,8 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
val model = new FPGrowth().setMinSupport(0.5).fit(data)
val generatedRules = model.setMinConfidence(0.5).associationRules
val expectedRules = spark.createDataFrame(Seq(
(Array("2"), Array("1"), 1.0, 0.25),
(Array("1"), Array("2"), 0.75, 0.25)
(Array("2"), Array("1"), 1.0, 1.0),
(Array("1"), Array("2"), 0.75, 1.0)
)).toDF("antecedent", "consequent", "confidence", "lift")
.withColumn("antecedent", col("antecedent").cast(ArrayType(dt)))
.withColumn("consequent", col("consequent").cast(ArrayType(dt)))
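The expected lift values change from 0.25 to 1.0 because itemSupport now holds support fractions rather than raw counts, and

  lift(X => Y) = confidence(X => Y) / support(Y)

Reading the supports off the expected confidences (an inference, not shown in the hunk): support("1") = 1.0 and support("2") = 0.75, so lift({2} => {1}) = 1.0 / 1.0 = 1.0 and lift({1} => {2}) = 0.75 / 0.75 = 1.0. The old count-based values were 1.0 / 4 = 0.25 and 0.75 / 3 = 0.25.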
1 change: 1 addition & 0 deletions project/MimaExcludes.scala
@@ -37,6 +37,7 @@ object MimaExcludes {
// Exclude rules for 2.4.x
lazy val v24excludes = v23excludes ++ Seq(
// [SPARK-10697][ML] Add lift to Association rules
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.fpm.FPGrowthModel.this"),
Member:
These are for the private[ml] constructors right? OK to suppress, yes

Contributor Author:
yes, they are the private ones.
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.fpm.AssociationRules#Rule.this"),
Contributor Author:
note for reviewers and myself: this method is private (private[fpm])
// [SPARK-24296][CORE] Replicate large blocks as a stream.
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.network.netty.NettyBlockRpcServer.this"),
4 changes: 2 additions & 2 deletions python/pyspark/ml/tests.py
@@ -2158,8 +2158,8 @@ def test_association_rules(self):
fpm = fp.fit(self.data)

expected_association_rules = self.spark.createDataFrame(
- [([3], [1], 1.0), ([2], [1], 1.0)],
- ["antecedent", "consequent", "confidence"]
+ [([3], [1], 1.0, 1.0), ([2], [1], 1.0, 1.0)],
+ ["antecedent", "consequent", "confidence", "lift"]
)
actual_association_rules = fpm.associationRules
