Skip to content
Closed
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
minor updates per code review
  • Loading branch information
jkbradley committed Nov 12, 2015
commit b3e9341498840e137d277ec3347a1cc4fce54179
40 changes: 27 additions & 13 deletions mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,13 @@ sealed abstract class LDAModel private[ml] (
@Since("1.6.0")
def setSeed(value: Long): this.type = set(seed, value)

/**
* Transforms the input dataset.
*
* WARNING: If this model is an instance of [[DistributedLDAModel]] (produced when [[optimizer]]
* is set to "em"), this involves collecting a large [[topicsMatrix]] to the driver.
* This implementation may be changed in the future.
*/
@Since("1.6.0")
override def transform(dataset: DataFrame): DataFrame = {
if ($(topicDistributionCol).nonEmpty) {
Expand Down Expand Up @@ -382,9 +389,9 @@ sealed abstract class LDAModel private[ml] (
* This is a matrix of size vocabSize x k, where each column is a topic.
* No guarantees are given about the ordering of the topics.
*
* WARNING: If this model is actually a [[DistributedLDAModel]] instance from EM,
* then this method could involve collecting a large amount of data to the driver
* (on the order of vocabSize x k).
* WARNING: If this model is actually a [[DistributedLDAModel]] instance produced by
* the Expectation-Maximization ("em") [[optimizer]], then this method could involve
* collecting a large amount of data to the driver (on the order of vocabSize x k).
*/
@Since("1.6.0")
def topicsMatrix: Matrix = oldLocalModel.topicsMatrix
Expand All @@ -398,9 +405,9 @@ sealed abstract class LDAModel private[ml] (
*
* See Equation (16) in the Online LDA paper (Hoffman et al., 2010).
*
* WARNING: If this model was learned via a [[DistributedLDAModel]], this involves collecting
* a large [[topicsMatrix]] to the driver. This implementation may be changed in the
* future.
* WARNING: If this model is an instance of [[DistributedLDAModel]] (produced when [[optimizer]]
* is set to "em"), this involves collecting a large [[topicsMatrix]] to the driver.
* This implementation may be changed in the future.
*
* @param dataset test corpus to use for calculating log likelihood
* @return variational lower bound on the log likelihood of the entire corpus
Expand All @@ -415,6 +422,10 @@ sealed abstract class LDAModel private[ml] (
* Calculate an upper bound bound on perplexity. (Lower is better.)
* See Equation (16) in the Online LDA paper (Hoffman et al., 2010).
*
* WARNING: If this model is an instance of [[DistributedLDAModel]] (produced when [[optimizer]]
* is set to "em"), this involves collecting a large [[topicsMatrix]] to the driver.
* This implementation may be changed in the future.
*
* @param dataset test corpus to use for calculating perplexity
* @return Variational upper bound on log perplexity per token.
*/
Expand Down Expand Up @@ -486,19 +497,20 @@ class LocalLDAModel private[ml] (
*
* This model stores the inferred topics, the full training dataset, and the topic distribution
* for each training document.
*
* @param oldLocalModelOption Used to implement [[oldLocalModel]] as a lazy val, but keeping
* [[copy()]] cheap.
*/
@Since("1.6.0")
@Experimental
class DistributedLDAModel private[ml] (
uid: String,
vocabSize: Int,
private val oldDistributedModel: OldDistributedLDAModel,
sqlContext: SQLContext)
sqlContext: SQLContext,
private var oldLocalModelOption: Option[OldLocalLDAModel])
extends LDAModel(uid, vocabSize, sqlContext) {

/** Used to implement [[oldLocalModel]] as a lazy val, but with cheap [[copy()]] */
private var oldLocalModelOption: Option[OldLocalLDAModel] = None

override protected def oldLocalModel: OldLocalLDAModel = {
if (oldLocalModelOption.isEmpty) {
oldLocalModelOption = Some(oldDistributedModel.toLocal)
Expand All @@ -511,14 +523,16 @@ class DistributedLDAModel private[ml] (
/**
* Convert this distributed model to a local representation. This discards info about the
* training dataset.
*
* WARNING: This involves collecting a large [[topicsMatrix]] to the driver.
*/
@Since("1.6.0")
def toLocal: LocalLDAModel = new LocalLDAModel(uid, vocabSize, oldLocalModel, sqlContext)

@Since("1.6.0")
override def copy(extra: ParamMap): DistributedLDAModel = {
val copied = new DistributedLDAModel(uid, vocabSize, oldDistributedModel, sqlContext)
copied.oldLocalModelOption = oldLocalModelOption
val copied =
new DistributedLDAModel(uid, vocabSize, oldDistributedModel, sqlContext, oldLocalModelOption)
copyValues(copied, extra).setParent(parent)
copied
}
Expand Down Expand Up @@ -669,7 +683,7 @@ class LDA @Since("1.6.0") (
case m: OldLocalLDAModel =>
new LocalLDAModel(uid, m.vocabSize, m, dataset.sqlContext)
case m: OldDistributedLDAModel =>
new DistributedLDAModel(uid, m.vocabSize, m, dataset.sqlContext)
new DistributedLDAModel(uid, m.vocabSize, m, dataset.sqlContext, None)
}
copyValues(newModel).setParent(this)
}
Expand Down