save load for ML localLDA
hhbyyh committed Nov 22, 2015
commit 58abcabef6ae892dfce47bb805d39ca421ab0a3f
104 changes: 100 additions & 4 deletions mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
@@ -17,12 +17,13 @@

package org.apache.spark.ml.clustering

import org.apache.hadoop.fs.Path
import org.apache.spark.Logging
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.util.{SchemaUtils, Identifiable}
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.param.shared.{HasCheckpointInterval, HasFeaturesCol, HasSeed, HasMaxIter}
import org.apache.spark.ml.param._
import org.apache.spark.ml.util._
import org.apache.spark.mllib.clustering.{DistributedLDAModel => OldDistributedLDAModel,
EMLDAOptimizer => OldEMLDAOptimizer, LDA => OldLDA, LDAModel => OldLDAModel,
LDAOptimizer => OldLDAOptimizer, LocalLDAModel => OldLocalLDAModel,
@@ -322,7 +323,7 @@ sealed abstract class LDAModel private[ml] (
@Since("1.6.0") override val uid: String,
@Since("1.6.0") val vocabSize: Int,
@Since("1.6.0") @transient protected val sqlContext: SQLContext)
extends Model[LDAModel] with LDAParams with Logging {
extends Model[LDAModel] with LDAParams with Logging with MLWritable {

// NOTE to developers:
// This abstraction should contain all important functionality for basic LDA usage.
@@ -486,6 +487,61 @@ class LocalLDAModel private[ml] (

@Since("1.6.0")
override def isDistributed: Boolean = false

@Since("1.6.0")
override def write: MLWriter = new LocalLDAModel.LocalLDAModelWriter(this)
}


@Since("1.6.0")
object LocalLDAModel extends MLReadable[LocalLDAModel] {

private[LocalLDAModel]
class LocalLDAModelWriter(instance: LocalLDAModel) extends MLWriter {

private case class Data(vocabSize: Int,
Review comment (Member): scala style: put vocabSize on next line
topicsMatrix: Matrix,
docConcentration: Vector,
topicConcentration: Double)
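
For reference, the declaration style the reviewer is asking for would look roughly like this (a sketch, not the committed code):

```scala
// Sketch of the suggested Scala style: no parameter on the opening line,
// each field on its own indented line.
private case class Data(
    vocabSize: Int,
    topicsMatrix: Matrix,
    docConcentration: Vector,
    topicConcentration: Double)
```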

override protected def saveImpl(path: String): Unit = {
DefaultParamsWriter.saveMetadata(instance, path, sc)
val oldModel = instance.oldLocalModel
val data = Data(instance.vocabSize,
Review comment (Member): no need to put 1 argument per line in method calls
oldModel.topicsMatrix,
oldModel.docConcentration,
oldModel.topicConcentration)
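
The compact form the reviewer has in mind would read something like this (sketch):

```scala
// Sketch: group the arguments rather than placing one per line.
val data = Data(instance.vocabSize, oldModel.topicsMatrix,
  oldModel.docConcentration, oldModel.topicConcentration)
```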
val dataPath = new Path(path, "data").toString
sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
}
}

private class LocalLDAModelReader extends MLReader[LocalLDAModel] {

private val className = classOf[LocalLDAModel].getName

override def load(path: String): LocalLDAModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
val dataPath = new Path(path, "data").toString
val data = sqlContext.read.parquet(dataPath)
.select("vocabSize", "topicsMatrix", "docConcentration", "topicConcentration")
.head()
val vocabSize = data.getAs[Int](0)
val topicsMatrix = data.getAs[Matrix](1)
val docConcentration = data.getAs[Vector](2)
val topicConcentration = data.getAs[Double](3)
val oldModel = new OldLocalLDAModel(topicsMatrix, docConcentration, topicConcentration)
Review comment (Member): We don't provide an API currently to set gammaShape, but I wonder if we will in the future. If so, then we might want to store gammaShape with the Data so that we don't need a new model format in the future.

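A minimal sketch of what persisting gammaShape could look like; the field, its position, and the 100.0 default (taken from spark.mllib's online LDA optimizer) are assumptions, not part of this commit:

```scala
// Hypothetical extension of Data: store gammaShape so that a future setter
// would not force a new on-disk model format.
private case class Data(
    vocabSize: Int,
    topicsMatrix: Matrix,
    docConcentration: Vector,
    topicConcentration: Double,
    gammaShape: Double = 100.0)  // assumed default of the online optimizer

// On load, the extra column would be read back the same way as the others:
// val gammaShape = data.getAs[Double](4)
```
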
val model = new LocalLDAModel(metadata.uid, vocabSize, oldModel, sqlContext)
DefaultParamsReader.getAndSetParams(model, metadata)
model
}
}

@Since("1.6.0")
override def read: MLReader[LocalLDAModel] = new LocalLDAModelReader

@Since("1.6.0")
override def load(path: String): LocalLDAModel = super.load(path)
}


@@ -562,6 +618,43 @@ class DistributedLDAModel private[ml] (
*/
@Since("1.6.0")
lazy val logPrior: Double = oldDistributedModel.logPrior

@Since("1.6.0")
override def write: MLWriter = new DistributedLDAModel.DistributedLDAModelWriter(this)
Review comment (Member): Distributed
Reply (Contributor Author): Sorry, but do you mean we can change it to DistributedLDAModel.DistributedWriter(this)?
}


@Since("1.6.0")
object DistributedLDAModel extends MLReadable[DistributedLDAModel] {

private[DistributedLDAModel]
class DistributedLDAModelWriter(instance: DistributedLDAModel) extends MLWriter {

override protected def saveImpl(path: String): Unit = {
DefaultParamsWriter.saveMetadata(instance, path, sc)
instance.oldDistributedModel.save(sc, path)
Review comment (Member): This may not work; it will attempt and fail to overwrite the metadata saved on the previous line. Maybe put it in a subdirectory.

}
}
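
A sketch of the subdirectory approach the reviewer suggests; the "oldModel" directory name is an assumption, and the reader below would need the matching change:

```scala
// Hypothetical fix: save the spark.mllib model under a subdirectory of
// `path` so it cannot collide with the ML metadata written at the root.
override protected def saveImpl(path: String): Unit = {
  DefaultParamsWriter.saveMetadata(instance, path, sc)
  val modelPath = new Path(path, "oldModel").toString
  instance.oldDistributedModel.save(sc, modelPath)
}

// The corresponding read would then be:
// val oldModel = OldDistributedLDAModel.load(sc, new Path(path, "oldModel").toString)
```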

private class DistributedLDAModelReader extends MLReader[DistributedLDAModel] {

private val className = classOf[DistributedLDAModel].getName

override def load(path: String): DistributedLDAModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
val oldModel = OldDistributedLDAModel.load(sc, path)
val model = new DistributedLDAModel(
metadata.uid, oldModel.vocabSize, oldModel, sqlContext, None)
DefaultParamsReader.getAndSetParams(model, metadata)
model
}
}

@Since("1.6.0")
override def read: MLReader[DistributedLDAModel] = new DistributedLDAModelReader

@Since("1.6.0")
override def load(path: String): DistributedLDAModel = super.load(path)
}


@@ -593,7 +686,8 @@ class DistributedLDAModel private[ml] (
@Since("1.6.0")
@Experimental
class LDA @Since("1.6.0") (
@Since("1.6.0") override val uid: String) extends Estimator[LDAModel] with LDAParams {
@Since("1.6.0") override val uid: String) extends Estimator[LDAModel]
Review comment (Member): scala style: Put extends on next line with 2-space indentation

with LDAParams with DefaultParamsWritable {
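
Concretely, the layout the reviewer is asking for would be (a sketch):

```scala
// Sketch of the suggested style: `extends` on the next line, indented two spaces.
class LDA @Since("1.6.0") (
    @Since("1.6.0") override val uid: String)
  extends Estimator[LDAModel] with LDAParams with DefaultParamsWritable {
  // ... class body unchanged ...
}
```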

@Since("1.6.0")
def this() = this(Identifiable.randomUID("lda"))
@@ -695,7 +789,7 @@ class LDA @Since("1.6.0") (
}


private[clustering] object LDA {
private[clustering] object LDA extends DefaultParamsReadable[LDA]{
Review comment (Member): style: space before brace


/** Get dataset for spark.mllib LDA */
def getOldDataset(dataset: DataFrame, featuresCol: String): RDD[(Long, Vector)] = {
@@ -706,4 +800,6 @@ private[clustering] object LDA {
(docId, features)
}
}

override def load(path: String): LDA = super.load(path)
Review comment (Member): add Since version

}
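
Putting the pieces together, the persistence support added in this commit would be exercised roughly as follows. This is a sketch: the paths and parameter values are illustrative, `dataset` is assumed to be a DataFrame with a features column, and the default online optimizer is assumed so that fit returns a LocalLDAModel:

```scala
// Hypothetical end-to-end usage of the new save/load support.
val lda = new LDA().setK(3).setMaxIter(10)
lda.save("/tmp/lda-estimator")                  // estimator save via DefaultParamsWritable
val sameLda = LDA.load("/tmp/lda-estimator")    // estimator load via DefaultParamsReadable

val model = lda.fit(dataset)                    // LocalLDAModel under the online optimizer
model.write.overwrite().save("/tmp/lda-model")  // model save via MLWritable
val restored = LocalLDAModel.load("/tmp/lda-model")
```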
mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
@@ -187,7 +187,7 @@ abstract class LDAModel private[clustering] extends Saveable {
* @param topics Inferred topics (vocabSize x k matrix).
*/
@Since("1.3.0")
class LocalLDAModel private[clustering] (
class LocalLDAModel private[spark] (
@Since("1.3.0") val topics: Matrix,
@Since("1.5.0") override val docConcentration: Vector,
@Since("1.5.0") override val topicConcentration: Double,
31 changes: 29 additions & 2 deletions mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
@@ -18,9 +18,10 @@
package org.apache.spark.ml.clustering

import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.util.MLTestingUtils
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.sql.{DataFrame, Row, SQLContext}


@@ -39,10 +40,24 @@ object LDASuite {
}.map(v => new TestRow(v))
sql.createDataFrame(rdd)
}

/**
* Mapping from all Params to valid settings which differ from the defaults.
* This is useful for tests which need to exercise all Params, such as save/load.
* This excludes input columns to simplify some tests.
*/
val allParamSettings: Map[String, Any] = Map(
"k" -> 3,
"maxIter" -> 10,
Review comment (Member): Set to 2 for speed

"checkpointInterval" -> 30,
"learningOffset" -> 1023.0,
"learningDecay" -> 0.52,
"subsamplingRate" -> 0.051
)
}


class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {

val k: Int = 5
val vocabSize: Int = 30
@@ -218,4 +233,16 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
val lp = model.logPrior
assert(lp <= 0.0 && lp != Double.NegativeInfinity)
}

test("read/write LocalLDAModel") {
def checkModelData(model: LDAModel, model2: LDAModel): Unit = {
assert(model.vocabSize === model2.vocabSize)
assert(Vectors.dense(model.topicsMatrix.toArray) ~==
Vectors.dense(model2.topicsMatrix.toArray) absTol 1e-6)
assert(Vectors.dense(model.getDocConcentration) ~==
Vectors.dense(model2.getDocConcentration) absTol 1e-6)
}
val lda = new LDA()
testEstimatorAndModelReadWrite(lda, dataset, LDASuite.allParamSettings, checkModelData)
}
}