-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-15784] Add Power Iteration Clustering to spark.ml #21493
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closed
Closed
Changes from 1 commit
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -18,21 +18,20 @@ | |
| package org.apache.spark.ml.clustering | ||
|
|
||
| import org.apache.spark.annotation.{Experimental, Since} | ||
| import org.apache.spark.ml.Transformer | ||
| import org.apache.spark.ml.param._ | ||
| import org.apache.spark.ml.param.shared._ | ||
| import org.apache.spark.ml.util._ | ||
| import org.apache.spark.mllib.clustering.{PowerIterationClustering => MLlibPowerIterationClustering} | ||
| import org.apache.spark.rdd.RDD | ||
| import org.apache.spark.sql.{DataFrame, Dataset, Row} | ||
| import org.apache.spark.sql.functions.col | ||
| import org.apache.spark.sql.functions.{col, lit} | ||
| import org.apache.spark.sql.types._ | ||
|
|
||
| /** | ||
| * Common params for PowerIterationClustering | ||
| */ | ||
| private[clustering] trait PowerIterationClusteringParams extends Params with HasMaxIter | ||
| with HasPredictionCol { | ||
| with HasWeightCol { | ||
|
|
||
| /** | ||
| * The number of clusters to create (k). Must be > 1. Default: 2. | ||
|
|
@@ -66,62 +65,35 @@ private[clustering] trait PowerIterationClusteringParams extends Params with Has | |
| def getInitMode: String = $(initMode) | ||
|
|
||
| /** | ||
| * Param for the name of the input column for vertex IDs. | ||
| * Default: "id" | ||
| * Param for the name of the input column for source vertex IDs. | ||
| * Default: "src" | ||
| * @group param | ||
| */ | ||
| @Since("2.4.0") | ||
| val idCol = new Param[String](this, "idCol", "Name of the input column for vertex IDs.", | ||
| val srcCol = new Param[String](this, "srcCol", "Name of the input column for source vertex IDs.", | ||
| (value: String) => value.nonEmpty) | ||
|
|
||
| setDefault(idCol, "id") | ||
| setDefault(srcCol, "src") | ||
|
|
||
| /** @group getParam */ | ||
| @Since("2.4.0") | ||
| def getIdCol: String = getOrDefault(idCol) | ||
| def getSrcCol: String = getOrDefault(srcCol) | ||
|
|
||
| /** | ||
| * Param for the name of the input column for neighbors in the adjacency list representation. | ||
| * Default: "neighbors" | ||
| * Param for the name of the input column for destination vertex IDs. | |
| * Default: "dst" | ||
| * @group param | ||
| */ | ||
| @Since("2.4.0") | ||
| val neighborsCol = new Param[String](this, "neighborsCol", | ||
| "Name of the input column for neighbors in the adjacency list representation.", | ||
| val dstCol = new Param[String](this, "dstCol", | ||
| "Name of the input column for destination vertex IDs.", | ||
| (value: String) => value.nonEmpty) | ||
|
|
||
| setDefault(neighborsCol, "neighbors") | ||
| setDefault(dstCol, "dst") | ||
|
|
||
| /** @group getParam */ | ||
| @Since("2.4.0") | ||
| def getNeighborsCol: String = $(neighborsCol) | ||
|
|
||
| /** | ||
| * Param for the name of the input column for similarities (non-negative edge weights) between each vertex and its neighbors. | |
| * Default: "similarities" | ||
| * @group param | ||
| */ | ||
| @Since("2.4.0") | ||
| val similaritiesCol = new Param[String](this, "similaritiesCol", | ||
| "Name of the input column for neighbors in the adjacency list representation.", | ||
| (value: String) => value.nonEmpty) | ||
|
|
||
| setDefault(similaritiesCol, "similarities") | ||
|
|
||
| /** @group getParam */ | ||
| @Since("2.4.0") | ||
| def getSimilaritiesCol: String = $(similaritiesCol) | ||
|
|
||
| protected def validateAndTransformSchema(schema: StructType): StructType = { | ||
| SchemaUtils.checkColumnTypes(schema, $(idCol), Seq(IntegerType, LongType)) | ||
| SchemaUtils.checkColumnTypes(schema, $(neighborsCol), | ||
| Seq(ArrayType(IntegerType, containsNull = false), | ||
| ArrayType(LongType, containsNull = false))) | ||
| SchemaUtils.checkColumnTypes(schema, $(similaritiesCol), | ||
| Seq(ArrayType(FloatType, containsNull = false), | ||
| ArrayType(DoubleType, containsNull = false))) | ||
| SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType) | ||
| } | ||
| def getDstCol: String = $(dstCol) | ||
| } | ||
|
|
||
| /** | ||
|
|
@@ -131,21 +103,8 @@ private[clustering] trait PowerIterationClusteringParams extends Params with Has | |
| * PIC finds a very low-dimensional embedding of a dataset using truncated power | ||
| * iteration on a normalized pair-wise similarity matrix of the data. | ||
| * | ||
| * PIC takes an affinity matrix between items (or vertices) as input. An affinity matrix | ||
| * is a symmetric matrix whose entries are non-negative similarities between items. | ||
| * PIC takes this matrix (or graph) as an adjacency matrix. Specifically, each input row includes: | ||
| * - `idCol`: vertex ID | ||
| * - `neighborsCol`: neighbors of vertex in `idCol` | ||
| * - `similaritiesCol`: non-negative weights (similarities) of edges between the vertex | ||
| * in `idCol` and each neighbor in `neighborsCol` | ||
| * PIC returns a cluster assignment for each input vertex. It appends a new column `predictionCol` | ||
| * containing the cluster assignment in `[0,k)` for each row (vertex). | ||
| * | ||
| * Notes: | ||
| * - [[PowerIterationClustering]] is a transformer with an expensive [[transform]] operation. | ||
| * Transform runs the iterative PIC algorithm to cluster the whole input dataset. | ||
| * - Input validation: This validates that similarities are non-negative but does NOT validate | ||
| * that the input matrix is symmetric. | ||
| * This class is not yet an Estimator/Transformer; use the `assignClusters` method to run the | |
| * PowerIterationClustering algorithm. | |
| * | ||
| * @see <a href=http://en.wikipedia.org/wiki/Spectral_clustering> | ||
| * Spectral clustering (Wikipedia)</a> | ||
|
|
@@ -154,7 +113,7 @@ private[clustering] trait PowerIterationClusteringParams extends Params with Has | |
| @Experimental | ||
| class PowerIterationClustering private[clustering] ( | ||
| @Since("2.4.0") override val uid: String) | ||
| extends Transformer with PowerIterationClusteringParams with DefaultParamsWritable { | ||
| extends PowerIterationClusteringParams with DefaultParamsWritable { | ||
|
|
||
| setDefault( | ||
| k -> 2, | ||
|
|
@@ -164,10 +123,6 @@ class PowerIterationClustering private[clustering] ( | |
| @Since("2.4.0") | ||
| def this() = this(Identifiable.randomUID("PowerIterationClustering")) | ||
|
|
||
| /** @group setParam */ | ||
| @Since("2.4.0") | ||
| def setPredictionCol(value: String): this.type = set(predictionCol, value) | ||
|
|
||
| /** @group setParam */ | ||
| @Since("2.4.0") | ||
| def setK(value: Int): this.type = set(k, value) | ||
|
|
@@ -182,66 +137,60 @@ class PowerIterationClustering private[clustering] ( | |
|
|
||
| /** @group setParam */ | ||
| @Since("2.4.0") | ||
| def setIdCol(value: String): this.type = set(idCol, value) | ||
| def setSrcCol(value: String): this.type = set(srcCol, value) | ||
|
|
||
| /** @group setParam */ | ||
| @Since("2.4.0") | ||
| def setNeighborsCol(value: String): this.type = set(neighborsCol, value) | ||
| def setDstCol(value: String): this.type = set(dstCol, value) | ||
|
|
||
| /** @group setParam */ | ||
| @Since("2.4.0") | ||
| def setSimilaritiesCol(value: String): this.type = set(similaritiesCol, value) | ||
| def setWeightCol(value: String): this.type = set(weightCol, value) | ||
|
|
||
| /** | ||
| * @param dataset A dataset with columns src, dst, weight representing the affinity matrix, | ||
|
||
| * which is the matrix A in the PIC paper. Suppose the src column value is i, | ||
| * the dst column value is j, the weight column value is similarity s,,ij,, | ||
| * which must be nonnegative. This is a symmetric matrix and hence | ||
| * s,,ij,, = s,,ji,,. For any (i, j) with nonzero similarity, there should be | ||
| * either (i, j, s,,ij,,) or (j, i, s,,ji,,) in the input. Rows with i = j are | ||
| * ignored, because we assume s,,ij,, = 0.0. | ||
| * | ||
| * @return A dataset that contains columns of vertex id and the corresponding cluster for the id. | ||
| * The schema of it will be: | ||
| * - id: Long | ||
| * - cluster: Int | ||
| */ | ||
| @Since("2.4.0") | ||
| override def transform(dataset: Dataset[_]): DataFrame = { | ||
| transformSchema(dataset.schema, logging = true) | ||
| def assignClusters(dataset: Dataset[_]): DataFrame = { | ||
| val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) { | ||
| lit(1.0) | ||
| } else { | ||
| col($(weightCol)).cast(DoubleType) | ||
| } | ||
|
|
||
| val sparkSession = dataset.sparkSession | ||
| val idColValue = $(idCol) | ||
| val rdd: RDD[(Long, Long, Double)] = | ||
| dataset.select( | ||
| col($(idCol)).cast(LongType), | ||
| col($(neighborsCol)).cast(ArrayType(LongType, containsNull = false)), | ||
| col($(similaritiesCol)).cast(ArrayType(DoubleType, containsNull = false)) | ||
| ).rdd.flatMap { | ||
| case Row(id: Long, nbrs: Seq[_], sims: Seq[_]) => | ||
| require(nbrs.size == sims.size, s"The length of the neighbor ID list must be " + | ||
| s"equal to the the length of the neighbor similarity list. Row for ID " + | ||
| s"$idColValue=$id has neighbor ID list of length ${nbrs.length} but similarity list " + | ||
| s"of length ${sims.length}.") | ||
| nbrs.asInstanceOf[Seq[Long]].zip(sims.asInstanceOf[Seq[Double]]).map { | ||
| case (nbr, similarity) => (id, nbr, similarity) | ||
| } | ||
| } | ||
| SchemaUtils.checkColumnTypes(dataset.schema, $(srcCol), Seq(IntegerType, LongType)) | ||
| SchemaUtils.checkColumnTypes(dataset.schema, $(dstCol), Seq(IntegerType, LongType)) | ||
| val rdd: RDD[(Long, Long, Double)] = dataset.select( | ||
| col($(srcCol)).cast(LongType), | ||
| col($(dstCol)).cast(LongType), | ||
| w).rdd.map { | ||
| case Row(src: Long, dst: Long, weight: Double) => (src, dst, weight) | ||
| } | ||
| val algorithm = new MLlibPowerIterationClustering() | ||
| .setK($(k)) | ||
| .setInitializationMode($(initMode)) | ||
| .setMaxIterations($(maxIter)) | ||
| val model = algorithm.run(rdd) | ||
|
|
||
| val predictionsRDD: RDD[Row] = model.assignments.map { assignment => | ||
| val assignmentsRDD: RDD[Row] = model.assignments.map { assignment => | ||
|
||
| Row(assignment.id, assignment.cluster) | ||
| } | ||
| val assignmentsSchema = StructType(Seq( | ||
| StructField("id", LongType, nullable = false), | ||
| StructField("cluster", IntegerType, nullable = false))) | ||
|
|
||
| val predictionsSchema = StructType(Seq( | ||
| StructField($(idCol), LongType, nullable = false), | ||
| StructField($(predictionCol), IntegerType, nullable = false))) | ||
| val predictions = { | ||
| val uncastPredictions = sparkSession.createDataFrame(predictionsRDD, predictionsSchema) | ||
| dataset.schema($(idCol)).dataType match { | ||
| case _: LongType => | ||
| uncastPredictions | ||
| case otherType => | ||
| uncastPredictions.select(col($(idCol)).cast(otherType).alias($(idCol))) | ||
| } | ||
| } | ||
|
|
||
| dataset.join(predictions, $(idCol)) | ||
| } | ||
|
|
||
| @Since("2.4.0") | ||
| override def transformSchema(schema: StructType): StructType = { | ||
| validateAndTransformSchema(schema) | ||
| dataset.sparkSession.createDataFrame(assignmentsRDD, assignmentsSchema) | ||
| } | ||
|
|
||
| @Since("2.4.0") | ||
|
|
||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you put all default values in a single `setDefault` call?