From 77bf12da6a7045b243bf520bb1ecf001af32ea17 Mon Sep 17 00:00:00 2001 From: zero323 Date: Thu, 16 Mar 2017 18:46:50 +0100 Subject: [PATCH] Replace featuresCol with itemsCol in ml.fpm.FPGrowth --- .../org/apache/spark/ml/fpm/FPGrowth.scala | 35 +++++++++++++------ .../apache/spark/ml/fpm/FPGrowthSuite.scala | 14 ++++---- 2 files changed, 31 insertions(+), 18 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala index fa39dd954af5..e2bc270b38da 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala @@ -25,7 +25,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param._ -import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasPredictionCol} +import org.apache.spark.ml.param.shared.HasPredictionCol import org.apache.spark.ml.util._ import org.apache.spark.mllib.fpm.{AssociationRules => MLlibAssociationRules, FPGrowth => MLlibFPGrowth} @@ -37,7 +37,20 @@ import org.apache.spark.sql.types._ /** * Common params for FPGrowth and FPGrowthModel */ -private[fpm] trait FPGrowthParams extends Params with HasFeaturesCol with HasPredictionCol { +private[fpm] trait FPGrowthParams extends Params with HasPredictionCol { + + /** + * Items column name. + * Default: "items" + * @group param + */ + @Since("2.2.0") + val itemsCol: Param[String] = new Param[String](this, "itemsCol", "items column name") + setDefault(itemsCol -> "items") + + /** @group getParam */ + @Since("2.2.0") + def getItemsCol: String = $(itemsCol) /** * Minimal support level of the frequent pattern. [0.0, 1.0]. Any pattern that appears @@ -91,10 +104,10 @@ private[fpm] trait FPGrowthParams extends Params with HasFeaturesCol with HasPre */ @Since("2.2.0") protected def validateAndTransformSchema(schema: StructType): StructType = { - val inputType = schema($(featuresCol)).dataType + val inputType = schema($(itemsCol)).dataType require(inputType.isInstanceOf[ArrayType], s"The input column must be ArrayType, but got $inputType.") - SchemaUtils.appendColumn(schema, $(predictionCol), schema($(featuresCol)).dataType) + SchemaUtils.appendColumn(schema, $(predictionCol), schema($(itemsCol)).dataType) } } @@ -133,7 +146,7 @@ class FPGrowth @Since("2.2.0") ( /** @group setParam */ @Since("2.2.0") - def setFeaturesCol(value: String): this.type = set(featuresCol, value) + def setItemsCol(value: String): this.type = set(itemsCol, value) /** @group setParam */ @Since("2.2.0") @@ -146,8 +159,8 @@ class FPGrowth @Since("2.2.0") ( } private def genericFit[T: ClassTag](dataset: Dataset[_]): FPGrowthModel = { - val data = dataset.select($(featuresCol)) - val items = data.where(col($(featuresCol)).isNotNull).rdd.map(r => r.getSeq[T](0).toArray) + val data = dataset.select($(itemsCol)) + val items = data.where(col($(itemsCol)).isNotNull).rdd.map(r => r.getSeq[T](0).toArray) val mllibFP = new MLlibFPGrowth().setMinSupport($(minSupport)) if (isSet(numPartitions)) { mllibFP.setNumPartitions($(numPartitions)) @@ -156,7 +169,7 @@ class FPGrowth @Since("2.2.0") ( val rows = parentModel.freqItemsets.map(f => Row(f.items, f.freq)) val schema = StructType(Seq( - StructField("items", dataset.schema($(featuresCol)).dataType, nullable = false), + StructField("items", dataset.schema($(itemsCol)).dataType, nullable = false), StructField("freq", LongType, nullable = false))) val frequentItems = dataset.sparkSession.createDataFrame(rows, schema) copyValues(new FPGrowthModel(uid, frequentItems)).setParent(this) @@ -198,7 +211,7 @@ class FPGrowthModel private[ml] ( /** @group setParam */ @Since("2.2.0") - def setFeaturesCol(value: String): this.type = set(featuresCol, value) + def setItemsCol(value: String): this.type = set(itemsCol, value) /** @group setParam */ @Since("2.2.0") @@ -235,7 +248,7 @@ class FPGrowthModel private[ml] ( .collect().asInstanceOf[Array[(Seq[Any], Seq[Any])]] val brRules = dataset.sparkSession.sparkContext.broadcast(rules) - val dt = dataset.schema($(featuresCol)).dataType + val dt = dataset.schema($(itemsCol)).dataType // For each rule, examine the input items and summarize the consequents val predictUDF = udf((items: Seq[_]) => { if (items != null) { @@ -249,7 +262,7 @@ class FPGrowthModel private[ml] ( } else { Seq.empty }}, dt) - dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) + dataset.withColumn($(predictionCol), predictUDF(col($(itemsCol)))) } @Since("2.2.0") diff --git a/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala index 910d4b07d130..4603a618d2f9 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala @@ -34,7 +34,7 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul test("FPGrowth fit and transform with different data types") { Array(IntegerType, StringType, ShortType, LongType, ByteType).foreach { dt => - val data = dataset.withColumn("features", col("features").cast(ArrayType(dt))) + val data = dataset.withColumn("items", col("items").cast(ArrayType(dt))) val model = new FPGrowth().setMinSupport(0.5).fit(data) val generatedRules = model.setMinConfidence(0.5).associationRules val expectedRules = spark.createDataFrame(Seq( @@ -52,8 +52,8 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul (0, Array("1", "2"), Array.emptyIntArray), (0, Array("1", "2"), Array.emptyIntArray), (0, Array("1", "3"), Array(2)) - )).toDF("id", "features", "prediction") - .withColumn("features", col("features").cast(ArrayType(dt))) + )).toDF("id", "items", "prediction") + .withColumn("items", col("items").cast(ArrayType(dt))) .withColumn("prediction", col("prediction").cast(ArrayType(dt))) assert(expectedTransformed.collect().toSet.equals( transformed.collect().toSet)) @@ -79,7 +79,7 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul (1, Array("1", "2", "3", "5")), (2, Array("1", "2", "3", "4")), (3, null.asInstanceOf[Array[String]]) - )).toDF("id", "features") + )).toDF("id", "items") val model = new FPGrowth().setMinSupport(0.7).fit(dataset) val prediction = model.transform(df) assert(prediction.select("prediction").where("id=3").first().getSeq[String](0).isEmpty) @@ -108,11 +108,11 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul val dataset = spark.createDataFrame(Seq( Array("1", "3"), Array("2", "3") - ).map(Tuple1(_))).toDF("features") + ).map(Tuple1(_))).toDF("items") val model = new FPGrowth().fit(dataset) val prediction = model.transform( - spark.createDataFrame(Seq(Tuple1(Array("1", "2")))).toDF("features") + spark.createDataFrame(Seq(Tuple1(Array("1", "2")))).toDF("items") ).first().getAs[Seq[String]]("prediction") assert(prediction === Seq("3")) @@ -127,7 +127,7 @@ object FPGrowthSuite { (0, Array("1", "2")), (0, Array("1", "2")), (0, Array("1", "3")) - )).toDF("id", "features") + )).toDF("id", "items") } /**