Changes from 1 commit
took changes from GayathriMurali branch
jkbradley committed Sep 9, 2016
commit 97d569791dbe4aba59c8c9a54a932a1767ddb635
85 changes: 68 additions & 17 deletions mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
@@ -18,6 +18,9 @@
package org.apache.spark.ml.clustering

import org.apache.hadoop.fs.Path
+import org.json4s.DefaultFormats
+import org.json4s.JsonAST.JObject
+import org.json4s.jackson.JsonMethods._

import org.apache.spark.annotation.{DeveloperApi, Experimental, Since}
import org.apache.spark.internal.Logging
@@ -26,19 +29,21 @@ import org.apache.spark.ml.linalg.{Matrix, Vector, Vectors, VectorUDT}
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared.{HasCheckpointInterval, HasFeaturesCol, HasMaxIter, HasSeed}
import org.apache.spark.ml.util._
+import org.apache.spark.ml.util.DefaultParamsReader.Metadata
import org.apache.spark.mllib.clustering.{DistributedLDAModel => OldDistributedLDAModel,
  EMLDAOptimizer => OldEMLDAOptimizer, LDA => OldLDA, LDAModel => OldLDAModel,
  LDAOptimizer => OldLDAOptimizer, LocalLDAModel => OldLocalLDAModel,
  OnlineLDAOptimizer => OldOnlineLDAOptimizer}
import org.apache.spark.mllib.impl.PeriodicCheckpointer
-import org.apache.spark.mllib.linalg.{Matrices => OldMatrices, Vector => OldVector,
-  Vectors => OldVectors}
+import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors}
import org.apache.spark.mllib.linalg.MatrixImplicits._
import org.apache.spark.mllib.linalg.VectorImplicits._
+import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
import org.apache.spark.sql.functions.{col, monotonically_increasing_id, udf}
import org.apache.spark.sql.types.StructType
+import org.apache.spark.util.VersionUtils


private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasMaxIter
@@ -80,6 +85,7 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM
   * - Values should be >= 0
   * - default = uniformly (1.0 / k), following the implementation from
   *   [[https://github.com/Blei-Lab/onlineldavb]].
+   *
   * @group param
   */
  @Since("1.6.0")
@@ -121,6 +127,7 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM
   * - Value should be >= 0
   * - default = (1.0 / k), following the implementation from
   *   [[https://github.com/Blei-Lab/onlineldavb]].
+   *
   * @group param
   */
  @Since("1.6.0")
@@ -354,6 +361,39 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM
  }
}
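
As an aside on the two concentration parameters documented above: both are symmetric Dirichlet priors that default to 1.0 / k when unset. A minimal sketch of setting them explicitly (the values here are assumptions for illustration, not part of this patch):

import org.apache.spark.ml.clustering.LDA

// Illustrative values only; both priors default to 1.0 / k.
val lda = new LDA()
  .setK(10)
  .setDocConcentration(0.5)    // document-topic prior ("alpha")
  .setTopicConcentration(0.5)  // topic-word prior ("beta"/"eta")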

+private object LDAParams {
+
+  /**
+   * Equivalent to [[DefaultParamsReader.getAndSetParams()]], but handles [[LDA]] and [[LDAModel]]
+   * formats saved with Spark 1.6, which differ from the formats in Spark 2.0+.
+   *
+   * @param model    [[LDA]] or [[LDAModel]] instance. This instance will be modified with
+   *                 [[Param]] values extracted from metadata.
+   * @param metadata Loaded model metadata
+   */
+  def getAndSetParams(model: LDAParams, metadata: Metadata): Unit = {
+    VersionUtils.majorMinorVersion(metadata.sparkVersion) match {
+      case (1, 6) =>
[Review thread on this line]
Contributor: we didn't support LDA serialization in 1.5, right?
Member Author: Nope, it was 1.6

+        implicit val format = DefaultFormats
+        metadata.params match {
+          case JObject(pairs) =>
+            pairs.foreach { case (paramName, jsonValue) =>
+              val origParam =
+                if (paramName == "topicDistribution") "topicDistributionCol" else paramName
+              val param = model.getParam(origParam)
+              val value = param.jsonDecode(compact(render(jsonValue)))
+              model.set(param, value)
+            }
+          case _ =>
+            throw new IllegalArgumentException(
+              s"Cannot recognize JSON metadata: ${metadata.metadataJson}.")
+        }
+      case _ => // 2.0+
+        DefaultParamsReader.getAndSetParams(model, metadata)
+    }
+  }
+}
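
To make the 1.6-format handling above concrete, here is a minimal standalone sketch of the param remapping, using made-up metadata JSON (the key/value content is an assumption for illustration, not taken from any real save file):

import org.json4s._
import org.json4s.jackson.JsonMethods._

object Lda16MetadataDemo {
  def main(args: Array[String]): Unit = {
    implicit val formats: Formats = DefaultFormats
    // Hypothetical params object as a Spark 1.6 save might contain it; note the
    // old key "topicDistribution" for what Spark 2.0 calls "topicDistributionCol".
    val json = """{"topicDistribution": "topics", "maxIter": 50}"""
    parse(json) match {
      case JObject(pairs) =>
        pairs.foreach { case (name, value) =>
          val renamed = if (name == "topicDistribution") "topicDistributionCol" else name
          // The real loader calls model.getParam(renamed).jsonDecode(...) at this point.
          println(s"$name -> $renamed = ${compact(render(value))}")
        }
      case other =>
        sys.error(s"Unexpected metadata shape: $other")
    }
  }
}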


/**
 * :: Experimental ::
@@ -418,11 +458,11 @@ sealed abstract class LDAModel private[ml] (
      val transformer = oldLocalModel.getTopicDistributionMethod(sparkSession.sparkContext)

      val t = udf { (v: Vector) => transformer(OldVectors.fromML(v)).asML }
-      dataset.withColumn($(topicDistributionCol), t(col($(featuresCol)))).toDF
+      dataset.withColumn($(topicDistributionCol), t(col($(featuresCol)))).toDF()
    } else {
      logWarning("LDAModel.transform was called without any output columns. Set an output column" +
        " such as topicDistributionCol to produce results.")
-      dataset.toDF
+      dataset.toDF()
    }
  }
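
For reference, a minimal usage sketch of this transform path; `model` (a fitted LDAModel) and `docs` (a DataFrame with the default "features" vector column) are assumptions for illustration, not part of this patch:

// Illustrative only: score documents and inspect the per-document topic mixture.
val scored = model
  .setTopicDistributionCol("topicDistribution")
  .transform(docs)
scored.select("topicDistribution").show(3, false)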

@@ -578,18 +618,15 @@ object LocalLDAModel extends MLReadable[LocalLDAModel] {
      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
      val dataPath = new Path(path, "data").toString
      val data = sparkSession.read.parquet(dataPath)
-        .select("vocabSize", "topicsMatrix", "docConcentration", "topicConcentration",
-          "gammaShape")
-        .head()
-      val vocabSize = data.getAs[Int](0)
-      val topicsMatrix = data.getAs[Matrix](1)
-      val docConcentration = data.getAs[Vector](2)
-      val topicConcentration = data.getAs[Double](3)
-      val gammaShape = data.getAs[Double](4)
+      val vectorConverted = MLUtils.convertVectorColumnsToML(data, "docConcentration")
+      val Row(vocabSize: Int, topicsMatrix: Matrix, docConcentration: Vector,
+        topicConcentration: Double, gammaShape: Double) = MLUtils.convertMatrixColumnsToML(
+        vectorConverted, "topicsMatrix").select("vocabSize", "topicsMatrix", "docConcentration",
+        "topicConcentration", "gammaShape").head()
      val oldModel = new OldLocalLDAModel(topicsMatrix, docConcentration, topicConcentration,
        gammaShape)
      val model = new LocalLDAModel(metadata.uid, vocabSize, oldModel, sparkSession)
-      DefaultParamsReader.getAndSetParams(model, metadata)
+      LDAParams.getAndSetParams(model, metadata)
      model
    }
  }
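
A hedged round-trip sketch of what this reader enables (the path and variable names are assumptions): models written by Spark 1.6 load through the same entry point, with their params remapped by LDAParams.getAndSetParams.

// Illustrative only.
localModel.write.overwrite().save("/tmp/lda-local")
val reloaded = LocalLDAModel.load("/tmp/lda-local")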
@@ -735,9 +772,9 @@ object DistributedLDAModel extends MLReadable[DistributedLDAModel] {
      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
      val modelPath = new Path(path, "oldModel").toString
      val oldModel = OldDistributedLDAModel.load(sc, modelPath)
-      val model = new DistributedLDAModel(
-        metadata.uid, oldModel.vocabSize, oldModel, sparkSession, None)
-      DefaultParamsReader.getAndSetParams(model, metadata)
+      val model = new DistributedLDAModel(metadata.uid, oldModel.vocabSize,
+        oldModel, sparkSession, None)
+      LDAParams.getAndSetParams(model, metadata)
      model
    }
  }
@@ -885,7 +922,7 @@ class LDA @Since("1.6.0") (
}

@Since("2.0.0")
-object LDA extends DefaultParamsReadable[LDA] {
+object LDA extends MLReadable[LDA] {

  /** Get dataset for spark.mllib LDA */
  private[clustering] def getOldDataset(
@@ -900,6 +937,20 @@ object LDA extends DefaultParamsReadable[LDA] {
    }
  }

+  private class LDAReader extends MLReader[LDA] {
+
+    private val className = classOf[LDA].getName
+
+    override def load(path: String): LDA = {
+      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+      val model = new LDA(metadata.uid)
+      LDAParams.getAndSetParams(model, metadata)
+      model
+    }
+  }
+
+  override def read: MLReader[LDA] = new LDAReader

@Since("2.0.0")
override def load(path: String): LDA = super.load(path)
}
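
Finally, a small sketch of the estimator-level round trip that the new LDAReader supports (the path and settings are illustrative assumptions):

// Save an unfitted LDA estimator and load it back; metadata written by
// Spark 1.6 would be remapped via LDAParams.getAndSetParams on load.
val lda = new LDA().setK(10).setMaxIter(20)
lda.write.overwrite().save("/tmp/lda-estimator")
val restored = LDA.load("/tmp/lda-estimator")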