-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-11847] [ML] Model export/import for spark.ml: LDA #9894
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
- Loading branch information
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,12 +17,13 @@ | |
|
|
||
| package org.apache.spark.ml.clustering | ||
|
|
||
| import org.apache.hadoop.fs.Path | ||
| import org.apache.spark.Logging | ||
| import org.apache.spark.annotation.{Experimental, Since} | ||
| import org.apache.spark.ml.util.{SchemaUtils, Identifiable} | ||
| import org.apache.spark.ml.{Estimator, Model} | ||
| import org.apache.spark.ml.param.shared.{HasCheckpointInterval, HasFeaturesCol, HasSeed, HasMaxIter} | ||
| import org.apache.spark.ml.param._ | ||
| import org.apache.spark.ml.util._ | ||
| import org.apache.spark.mllib.clustering.{DistributedLDAModel => OldDistributedLDAModel, | ||
| EMLDAOptimizer => OldEMLDAOptimizer, LDA => OldLDA, LDAModel => OldLDAModel, | ||
| LDAOptimizer => OldLDAOptimizer, LocalLDAModel => OldLocalLDAModel, | ||
|
|
@@ -322,7 +323,7 @@ sealed abstract class LDAModel private[ml] ( | |
| @Since("1.6.0") override val uid: String, | ||
| @Since("1.6.0") val vocabSize: Int, | ||
| @Since("1.6.0") @transient protected val sqlContext: SQLContext) | ||
| extends Model[LDAModel] with LDAParams with Logging { | ||
| extends Model[LDAModel] with LDAParams with Logging with MLWritable { | ||
|
|
||
| // NOTE to developers: | ||
| // This abstraction should contain all important functionality for basic LDA usage. | ||
|
|
@@ -486,6 +487,61 @@ class LocalLDAModel private[ml] ( | |
|
|
||
| @Since("1.6.0") | ||
| override def isDistributed: Boolean = false | ||
|
|
||
| @Since("1.6.0") | ||
| override def write: MLWriter = new LocalLDAModel.LocalLDAModelWriter(this) | ||
| } | ||
|
|
||
|
|
||
| @Since("1.6.0") | ||
| object LocalLDAModel extends MLReadable[LocalLDAModel] { | ||
|
|
||
| private[LocalLDAModel] | ||
| class LocalLDAModelWriter(instance: LocalLDAModel) extends MLWriter { | ||
|
|
||
| private case class Data(vocabSize: Int, | ||
| topicsMatrix: Matrix, | ||
| docConcentration: Vector, | ||
| topicConcentration: Double) | ||
|
|
||
| override protected def saveImpl(path: String): Unit = { | ||
| DefaultParamsWriter.saveMetadata(instance, path, sc) | ||
| val oldModel = instance.oldLocalModel | ||
| val data = Data(instance.vocabSize, | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. no need to put 1 argument per line in method calls |
||
| oldModel.topicsMatrix, | ||
| oldModel.docConcentration, | ||
| oldModel.topicConcentration) | ||
| val dataPath = new Path(path, "data").toString | ||
| sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) | ||
| } | ||
| } | ||
|
|
||
| private class LocalLDAModelReader extends MLReader[LocalLDAModel] { | ||
|
|
||
| private val className = classOf[LocalLDAModel].getName | ||
|
|
||
| override def load(path: String): LocalLDAModel = { | ||
| val metadata = DefaultParamsReader.loadMetadata(path, sc, className) | ||
| val dataPath = new Path(path, "data").toString | ||
| val data = sqlContext.read.parquet(dataPath) | ||
| .select("vocabSize", "topicsMatrix", "docConcentration", "topicConcentration") | ||
| .head() | ||
| val vocabSize = data.getAs[Int](0) | ||
| val topicsMatrix = data.getAs[Matrix](1) | ||
| val docConcentration = data.getAs[Vector](2) | ||
| val topicConcentration = data.getAs[Double](3) | ||
| val oldModel = new OldLocalLDAModel(topicsMatrix, docConcentration, topicConcentration) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We don't provide an API currently to set gammaShape, but I wonder if we will in the future. If so, then we might want to store gammaShape with the Data so that we don't need a new model format in the future. |
||
| val model = new LocalLDAModel(metadata.uid, vocabSize, oldModel, sqlContext) | ||
| DefaultParamsReader.getAndSetParams(model, metadata) | ||
| model | ||
| } | ||
| } | ||
|
|
||
| @Since("1.6.0") | ||
| override def read: MLReader[LocalLDAModel] = new LocalLDAModelReader | ||
|
|
||
| @Since("1.6.0") | ||
| override def load(path: String): LocalLDAModel = super.load(path) | ||
| } | ||
|
|
||
|
|
||
|
|
@@ -562,6 +618,43 @@ class DistributedLDAModel private[ml] ( | |
| */ | ||
| @Since("1.6.0") | ||
| lazy val logPrior: Double = oldDistributedModel.logPrior | ||
|
|
||
| @Since("1.6.0") | ||
| override def write: MLWriter = new DistributedLDAModel.DistributedLDAModelWriter(this) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Distributed
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sorry, but do you mean we can change it to |
||
| } | ||
|
|
||
|
|
||
| @Since("1.6.0") | ||
| object DistributedLDAModel extends MLReadable[DistributedLDAModel] { | ||
|
|
||
| private[DistributedLDAModel] | ||
| class DistributedLDAModelWriter(instance: DistributedLDAModel) extends MLWriter { | ||
|
|
||
| override protected def saveImpl(path: String): Unit = { | ||
| DefaultParamsWriter.saveMetadata(instance, path, sc) | ||
| instance.oldDistributedModel.save(sc, path) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This may not work; it will attempt and fail to overwrite the metadata saved on the previous line. Maybe put it in a subdirectory. |
||
| } | ||
| } | ||
|
|
||
| private class DistributedLDAModelReader extends MLReader[DistributedLDAModel] { | ||
|
|
||
| private val className = classOf[DistributedLDAModel].getName | ||
|
|
||
| override def load(path: String): DistributedLDAModel = { | ||
| val metadata = DefaultParamsReader.loadMetadata(path, sc, className) | ||
| val oldModel = OldDistributedLDAModel.load(sc, path) | ||
| val model = new DistributedLDAModel( | ||
| metadata.uid, oldModel.vocabSize, oldModel, sqlContext, None) | ||
| DefaultParamsReader.getAndSetParams(model, metadata) | ||
| model | ||
| } | ||
| } | ||
|
|
||
| @Since("1.6.0") | ||
| override def read: MLReader[DistributedLDAModel] = new DistributedLDAModelReader | ||
|
|
||
| @Since("1.6.0") | ||
| override def load(path: String): DistributedLDAModel = super.load(path) | ||
| } | ||
|
|
||
|
|
||
|
|
@@ -593,7 +686,8 @@ class DistributedLDAModel private[ml] ( | |
| @Since("1.6.0") | ||
| @Experimental | ||
| class LDA @Since("1.6.0") ( | ||
| @Since("1.6.0") override val uid: String) extends Estimator[LDAModel] with LDAParams { | ||
| @Since("1.6.0") override val uid: String) extends Estimator[LDAModel] | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. scala style: Put extends on next line with 2-space indentation |
||
| with LDAParams with DefaultParamsWritable { | ||
|
|
||
| @Since("1.6.0") | ||
| def this() = this(Identifiable.randomUID("lda")) | ||
|
|
@@ -695,7 +789,7 @@ class LDA @Since("1.6.0") ( | |
| } | ||
|
|
||
|
|
||
| private[clustering] object LDA { | ||
| private[clustering] object LDA extends DefaultParamsReadable[LDA]{ | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. style: space before brace |
||
|
|
||
| /** Get dataset for spark.mllib LDA */ | ||
| def getOldDataset(dataset: DataFrame, featuresCol: String): RDD[(Long, Vector)] = { | ||
|
|
@@ -706,4 +800,6 @@ private[clustering] object LDA { | |
| (docId, features) | ||
| } | ||
| } | ||
|
|
||
| override def load(path: String): LDA = super.load(path) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. add Since version |
||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -18,9 +18,10 @@ | |
| package org.apache.spark.ml.clustering | ||
|
|
||
| import org.apache.spark.SparkFunSuite | ||
| import org.apache.spark.ml.util.MLTestingUtils | ||
| import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} | ||
| import org.apache.spark.mllib.linalg.{Vector, Vectors} | ||
| import org.apache.spark.mllib.util.MLlibTestSparkContext | ||
| import org.apache.spark.mllib.util.TestingUtils._ | ||
| import org.apache.spark.sql.{DataFrame, Row, SQLContext} | ||
|
|
||
|
|
||
|
|
@@ -39,10 +40,24 @@ object LDASuite { | |
| }.map(v => new TestRow(v)) | ||
| sql.createDataFrame(rdd) | ||
| } | ||
|
|
||
| /** | ||
| * Mapping from all Params to valid settings which differ from the defaults. | ||
| * This is useful for tests which need to exercise all Params, such as save/load. | ||
| * This excludes input columns to simplify some tests. | ||
| */ | ||
| val allParamSettings: Map[String, Any] = Map( | ||
| "k" -> 3, | ||
| "maxIter" -> 10, | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Set to 2 for speed |
||
| "checkpointInterval" -> 30, | ||
| "learningOffset" -> 1023.0, | ||
| "learningDecay" -> 0.52, | ||
| "subsamplingRate" -> 0.051 | ||
| ) | ||
| } | ||
|
|
||
|
|
||
| class LDASuite extends SparkFunSuite with MLlibTestSparkContext { | ||
| class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { | ||
|
|
||
| val k: Int = 5 | ||
| val vocabSize: Int = 30 | ||
|
|
@@ -218,4 +233,16 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext { | |
| val lp = model.logPrior | ||
| assert(lp <= 0.0 && lp != Double.NegativeInfinity) | ||
| } | ||
|
|
||
| test("read/write LocalLDAModel") { | ||
| def checkModelData(model: LDAModel, model2: LDAModel): Unit = { | ||
| assert(model.vocabSize === model2.vocabSize) | ||
| assert(Vectors.dense(model.topicsMatrix.toArray) ~== | ||
| Vectors.dense(model2.topicsMatrix.toArray) absTol 1e-6) | ||
| assert(Vectors.dense(model.getDocConcentration) ~== | ||
| Vectors.dense(model2.getDocConcentration) absTol 1e-6) | ||
| } | ||
| val lda = new LDA() | ||
| testEstimatorAndModelReadWrite(lda, dataset, LDASuite.allParamSettings, checkModelData) | ||
| } | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
scala style: put vocabSize on next line