Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
b80bb1f
add pic framework (model, class etc)
wangmiao1981 Jun 13, 2016
75004e8
change a comment
wangmiao1981 Jun 13, 2016
e1d9a33
add missing functions fit predict load save etc.
wangmiao1981 Jun 17, 2016
f8343e0
add unit test flie
wangmiao1981 Jun 18, 2016
c62a2c0
add test cases part 1
wangmiao1981 Jun 20, 2016
1277f75
add unit test part 2: test fit, parameters etc.
wangmiao1981 Jun 20, 2016
f50873d
fix a type issue
wangmiao1981 Jun 20, 2016
88a9ae0
add more unit tests
wangmiao1981 Jun 21, 2016
0618815
delete unused import and add comments
wangmiao1981 Jun 21, 2016
04fddbd
change version to 2.1.0
wangmiao1981 Oct 25, 2016
b49f4c7
change PIC as a Transformer
wangmiao1981 Nov 3, 2016
d3f86d0
add LabelCol
wangmiao1981 Nov 4, 2016
655bc67
change col implementation
wangmiao1981 Nov 4, 2016
d5975bc
address some of the comments
wangmiao1981 Feb 17, 2017
f012624
add additional test with dataset having more data
wangmiao1981 Feb 21, 2017
bef0594
change input data format
wangmiao1981 Mar 14, 2017
a4bee89
resolve warnings
wangmiao1981 Mar 15, 2017
0f97907
add neighbor and weight cols
wangmiao1981 Mar 16, 2017
015383a
address review comments 1
wangmiao1981 Aug 15, 2017
2d29570
fix style
wangmiao1981 Aug 15, 2017
af549e8
remove unused comments
wangmiao1981 Aug 15, 2017
9b4f3d5
add Since
wangmiao1981 Aug 15, 2017
e35fe54
fix missing >
wangmiao1981 Aug 17, 2017
73485d8
fix doc
wangmiao1981 Aug 17, 2017
bd5ca5d
Merge github.com:apache/spark into pic
wangmiao1981 Sep 12, 2017
3b0f71c
Merge github.com:apache/spark into pic
wangmiao1981 Oct 25, 2017
752b685
address review comments
wangmiao1981 Oct 25, 2017
cfa18af
fix unit test
wangmiao1981 Oct 30, 2017
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
add missing functions fit predict load save etc.
  • Loading branch information
wangmiao1981 committed Aug 16, 2017
commit e1d9a3320336ff8a54f4ea441c6587431ab9e81c
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,21 @@ import org.apache.spark.ml.linalg.{Vector, VectorUDT}
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util._
import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors}
import org.apache.spark.mllib.linalg.VectorImplicits._
import org.apache.spark.mllib.clustering.{PowerIterationClustering => MLlibPowerIterationClustering}
import org.apache.spark.mllib.clustering.{PowerIterationClusteringModel => MLlibPowerIterationClusteringModel}
import org.apache.spark.mllib.clustering.PowerIterationClustering.Assignment
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.sql.types.{IntegerType, StructType}
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

/**
 * Common params for PowerIterationClustering and PowerIterationClusteringModel.
 */
private[clustering] trait PowerIterationClusteringParams extends Params with HasMaxIter
with HasFeaturesCol with HasPredictionCol {

/*
/**
* The number of clusters to create (k). Must be > 1. Default: 2.
* @group param
*/
Expand Down Expand Up @@ -66,10 +68,10 @@ private[clustering] trait PowerIterationClusteringParams extends Params with Has
def getInitMode: String = $(initMode)

/**
* Validates and transforms the input schema.
* @param schema input schema
* @return output schema
*/
* Validates and transforms the input schema.
* @param schema input schema
* @return output schema
*/
protected def validateAndTransformSchema(schema: StructType): StructType = {
SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType)
Expand All @@ -80,26 +82,56 @@ private[clustering] trait PowerIterationClusteringParams extends Params with Has
@Since("2.0.0")
@Experimental
class PowerIterationClusteringModel private[ml] (
@Since("2.0.0") override val uid: String)
@Since("2.0.0") override val uid: String,
private val parentModel: MLlibPowerIterationClusteringModel)
extends Model[PowerIterationClusteringModel] with PowerIterationClusteringParams with MLWritable {

@Since("2.0.0")
override def copy(extra: ParamMap): PowerIterationClusteringModel = {
val copied = new PowerIterationClusteringModel(uid)
val copied = new PowerIterationClusteringModel(uid, parentModel)
copyValues(copied, extra).setParent(this.parent)
}

/**
 * Cluster assignments exposed by the wrapped MLlib model (`parentModel`).
 *
 * @return the `assignments` RDD of the underlying MLlib model
 */
def assignments: RDD[Assignment] = {
  parentModel.assignments
}

/**
 * Stores `value` as the `k` param on this model.
 *
 * @group setParam
 * @param value value to set for `k`
 * @return this model, to allow chaining
 */
@Since("2.0.0")
def saveK(value: Int): this.type = {
  set(k, value)
}

/**
 * Stores `value` as the `initMode` param on this model.
 *
 * @group expertSetParam
 * @param value value to set for `initMode`
 * @return this model, to allow chaining
 */
@Since("2.0.0")
def saveInitMode(value: String): this.type = {
  set(initMode, value)
}

/**
 * Stores `value` as the `maxIter` param on this model.
 *
 * @group setParam
 * @param value value to set for `maxIter`
 * @return this model, to allow chaining
 */
@Since("2.0.0")
def saveMaxIter(value: Int): this.type = {
  set(maxIter, value)
}

@Since("2.0.0")
override def transform(dataset: Dataset[_]): DataFrame = {
val predictUDF = udf((vector: Vector) => predict(vector))
dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
predict(dataset)
}

/**
 * Derives the output schema by delegating to `validateAndTransformSchema`.
 *
 * @param schema input schema
 * @return output schema
 */
@Since("2.0.0")
override def transformSchema(schema: StructType): StructType =
  validateAndTransformSchema(schema)

private[clustering] def predict(features: Vector): Int = ???
/**
 * Builds a prediction column for `features` by fitting a fresh
 * PowerIterationClustering estimator (configured with this model's `k`,
 * `initMode` and `maxIter`) on `features` and attaching the resulting
 * cluster ids under `predictionCol`.
 *
 * This method does not read this model's own `parentModel`; every call
 * runs a new fit.
 *
 * @param features dataset to cluster
 * @return `features` with the `predictionCol` column added
 */
private[clustering] def predict(features: Dataset[_]): DataFrame = {
val sparkSession = features.sparkSession
// Re-create an estimator mirroring this model's params and train it on the input.
val powerIterationClustering = new PowerIterationClustering().setK($(k))
.setInitMode($(initMode))
.setMaxIter($(maxIter))
val model = powerIterationClustering.fit(features)
// NOTE(review): the estimator above was already configured with these same
// values, so this re-setting looks redundant — confirm before removing.
model.saveK($(k))
.saveInitMode($(initMode))
.saveMaxIter($(maxIter))
// Keep only the cluster id of each assignment; nothing here ties a row of
// `rows` back to a particular row of `features`.
val rows: RDD[Row] = model.assignments.map {
case assignment: Assignment => Row(assignment.cluster)
}
val schema = new StructType(Array(StructField("cluster", IntegerType)))
val predict = sparkSession.createDataFrame(rows, schema)
// NOTE(review): `predict` is a separate DataFrame from `features`; Spark's
// withColumn normally cannot resolve a column that belongs to an unrelated
// DataFrame (it tends to fail analysis). Presumably a join on an id column
// is needed here — verify against a test that actually calls transform.
features.withColumn($(predictionCol), predict.col("cluster"))
}

@Since("2.0.0")
override def write: MLWriter =
Expand All @@ -113,15 +145,15 @@ class PowerIterationClusteringModel private[ml] (
}

/**
* Return true if there exists summary of model.
*/
* Return true if there exists summary of model.
*/
@Since("2.0.0")
def hasSummary: Boolean = trainingSummary.nonEmpty

/**
* Gets summary of model on training set. An exception is
* thrown if `trainingSummary == None`.
*/
* Gets summary of model on training set. An exception is
* thrown if `trainingSummary == None`.
*/
@Since("2.0.0")
def summary: PowerIterationClusteringSummary = trainingSummary.getOrElse {
throw new SparkException(
Expand All @@ -137,19 +169,34 @@ object PowerIterationClusteringModel extends MLReadable[PowerIterationClustering
new PowerIterationClusteringModelReader()

@Since("2.0.0")
override def load(path: String): PowerIterationClusteringModel = ???
override def load(path: String): PowerIterationClusteringModel = super.load(path)

/** [[MLWriter]] instance for [[PowerIterationClusteringModel]] */
private[PowerIterationClusteringModel] class PowerIterationClusteringModelWriter
(instance: PowerIterationClusteringModel) extends MLWriter {

override protected def saveImpl(path: String): Unit = ???
/**
 * Persists the model under `path`: first the ML metadata and params via
 * `DefaultParamsWriter`, then the wrapped MLlib model via its
 * `SaveLoadV1_0` writer.
 *
 * @param path destination directory for the saved model
 */
override protected def saveImpl(path: String): Unit = {
// Save metadata and Params
DefaultParamsWriter.saveMetadata(instance, path, sc)
// NOTE(review): the MLlib writer is handed the same `path` as the metadata
// writer above. If SaveLoadV1_0.save also writes a metadata directory under
// `path`, the second write may collide with the first — confirm, and if so
// save the MLlib model to a sub-directory (the reader's load must match).
MLlibPowerIterationClusteringModel.SaveLoadV1_0.save(sc, instance.parentModel, path)
}
}

private class PowerIterationClusteringModelReader
extends MLReader[PowerIterationClusteringModel] {

override def load(path: String): PowerIterationClusteringModel = ???
/** Checked against metadata when loading model */
private val className = classOf[PowerIterationClusteringModel].getName

/**
 * Loads a PowerIterationClusteringModel from `path`: reads the ML metadata
 * (validated against `className`), loads the wrapped MLlib model, rebuilds
 * the ML model with the stored uid and restores its params from metadata.
 *
 * @param path directory the model was saved to
 * @return the reconstructed model
 */
override def load(path: String): PowerIterationClusteringModel = {

val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
// NOTE(review): the MLlib model is read from the same `path` as the ML
// metadata; this must stay in sync with wherever the writer's saveImpl puts
// the MLlib model — confirm both layouts can coexist under one directory.
val parentModel = MLlibPowerIterationClusteringModel.SaveLoadV1_0.load(sc, path)

val model = new PowerIterationClusteringModel(metadata.uid, parentModel)
DefaultParamsReader.getAndSetParams(model, metadata)
model
}
}
}

Expand Down Expand Up @@ -192,7 +239,22 @@ class PowerIterationClustering @Since("2.0.0") (
def setMaxIter(value: Int): this.type = set(maxIter, value)

@Since("2.0.0")
override def fit(dataset: Dataset[_]): PowerIterationClusteringModel = ???
/**
 * Fits a PowerIterationClusteringModel by running the MLlib
 * PowerIterationClustering algorithm on the similarities read from
 * `featuresCol`.
 *
 * Each row's feature vector is read as a 3-element vector
 * `(srcId, dstId, similarity)`; the two ids are converted from Double to
 * Long by truncation.
 * NOTE(review): this encoding is inferred from the RDD[(Long, Long, Double)]
 * the MLlib algorithm is given here — confirm it matches the intended input
 * format.
 *
 * @param dataset input dataset containing the `featuresCol` vector column
 * @return the fitted model, with this estimator set as its parent
 * @throws IllegalArgumentException if a feature vector does not have exactly
 *                                  3 elements (raised when the RDD is evaluated)
 */
override def fit(dataset: Dataset[_]): PowerIterationClusteringModel = {
  // Resolve the column name on the driver so the closure below does not
  // need to capture this estimator.
  val featuresColName = $(featuresCol)
  val similarities: RDD[(Long, Long, Double)] =
    dataset.select(col(featuresColName)).rdd.map {
      case Row(point: Vector) =>
        // A Vector can never be cast to a Tuple3 (the previous
        // asInstanceOf always threw ClassCastException); unpack it instead.
        require(point.size == 3,
          s"Column $featuresColName must hold vectors of size 3 " +
            s"(srcId, dstId, similarity) but found size ${point.size}.")
        (point(0).toLong, point(1).toLong, point(2))
    }

  val algo = new MLlibPowerIterationClustering()
    .setK($(k))
    .setInitializationMode($(initMode))
    .setMaxIterations($(maxIter))
  val parentModel = algo.run(similarities)
  val model = copyValues(new PowerIterationClusteringModel(uid, parentModel).setParent(this))
  model.saveK($(k))
    .saveInitMode($(initMode))
    .saveMaxIter($(maxIter))
  model
}

@Since("2.0.0")
override def transformSchema(schema: StructType): StructType = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ object PowerIterationClusteringModel extends Loader[PowerIterationClusteringMode
PowerIterationClusteringModel.SaveLoadV1_0.load(sc, path)
}

private[clustering]
private[spark]
object SaveLoadV1_0 {

private val thisFormatVersion = "1.0"
Expand Down