Added test suite for spark.ml LDA
jkbradley committed Nov 6, 2015
commit ffb68c5af18e7755fe2e834fc16ce2cd413786d3
33 changes: 31 additions & 2 deletions mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
@@ -172,7 +172,7 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM
case "online" => setOptimizer(new OnlineLDAOptimizer)
case "em" => setOptimizer(new EMLDAOptimizer)
case _ => throw new IllegalArgumentException(
s"LDA was given unknown optimizer \"$value\". Supported values: em, online")
s"LDA was given unknown optimizer '$value'. Supported values: em, online")
}

/**
@@ -228,6 +228,14 @@ class LDAModel private[ml] (
@Since("1.6.0") @transient protected val sqlContext: SQLContext)
extends Model[LDAModel] with LDAParams with Logging {

override def validateParams(): Unit = {
if (getDocConcentration.length != 1) {
require(getDocConcentration.length == getK, s"LDA docConcentration was of length" +
s" ${getDocConcentration.length}, but k = $getK. docConcentration must be either" +
s" length 1 (scalar) or an array of length k.")
}
}
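
For context, a minimal sketch of the rule this new check enforces; the setters mirror the test suite added below, and the concrete values are illustrative only:

// docConcentration must be a scalar (length 1) or provide one value per topic (length k).
val lda = new LDA().setK(3)
lda.setDocConcentration(0.5)                   // length 1: symmetric prior, OK
lda.setDocConcentration(Array(0.5, 0.4, 0.3))  // length k: asymmetric prior, OK
lda.setDocConcentration(Array(0.5, 0.4))       // length 2 != k: validateParams() throws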

/** Returns underlying spark.mllib model */
@Since("1.6.0")
protected def getModel: OldLDAModel = oldLocalModel match {
@@ -431,10 +439,13 @@ class DistributedLDAModel private[ml] (
* given the current parameter estimates:
* log P(docs | topics, topic distributions for docs, alpha, eta)
Contributor:
alpha and especially eta are confusing in this context, where the implementation is in a whole different file

Member Author:
ok

*
* Note:
* Notes:
* - This excludes the prior; for that, use [[logPrior]].
* - Even with [[logPrior]], this is NOT the same as the data log likelihood given the
* hyperparameters.
* - This is computed from the topic distributions computed during training. If you call
* [[logLikelihood()]] on the same training dataset, the topic distributions will be computed
* again, possibly giving different results.
*/
@Since("1.6.0")
lazy val trainingLogLikelihood: Double = oldDistributedModel.logLikelihood
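
To make the thread above concrete: in spark.ml naming, alpha is docConcentration and eta is topicConcentration. A short sketch of the distinction drawn in the Notes, assuming a fitted DistributedLDAModel bound to model and its training DataFrame bound to docs (both names are placeholders):

// Cached at training time, from the topic distributions inferred during fitting.
val llTraining = model.trainingLogLikelihood
// Recomputed on demand: topic distributions for docs are inferred again, so the
// result can differ from llTraining even when docs is exactly the training set.
val llRecomputed = model.logLikelihood(docs)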
@@ -649,6 +660,10 @@ class OnlineLDAOptimizer @Since("1.6.0") (
@Since("1.6.0")
def getTau0: Double = $(tau0)

/** @group setParam */
@Since("1.6.0")
def setTau0(value: Double): this.type = set(tau0, value)

/**
* Learning rate, set as an exponential decay rate.
* This should be between (0.5, 1.0] to guarantee asymptotic convergence.
@@ -662,6 +677,10 @@ class OnlineLDAOptimizer @Since("1.6.0") (

setDefault(kappa -> 0.51)

/** @group setParam */
@Since("1.6.0")
def setKappa(value: Double): this.type = set(kappa, value)

/** @group getParam */
@Since("1.6.0")
def getKappa: Double = $(kappa)
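
For reference, the online variational Bayes updates weight iteration t by rho_t = (tau0 + t)^(-kappa); this formula comes from the online LDA literature rather than this diff, but it shows how tau0 downweights early iterations and why kappa must stay in (0.5, 1.0]:

// Per-iteration learning rate given tau0, kappa, and iteration number t.
def rho(tau0: Double, kappa: Double)(t: Int): Double = math.pow(tau0 + t, -kappa)
// With the defaults tau0 = 1024 and kappa = 0.51:
//   rho(1024, 0.51)(1)    ≈ 0.029
//   rho(1024, 0.51)(1000) ≈ 0.021
// a slow decay that still satisfies the convergence conditions.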
@@ -691,6 +710,12 @@ class OnlineLDAOptimizer @Since("1.6.0") (
@Since("1.6.0")
def getSubsamplingRate: Double = $(subsamplingRate)

// TODO: Move this to the shared params trait.

/** @group setParam */
@Since("1.6.0")
def setSubsamplingRate(value: Double): this.type = set(subsamplingRate, value)

/**
* Indicates whether the docConcentration (Dirichlet parameter for
* document-topic distribution) will be optimized during training.
@@ -708,4 +733,8 @@ class OnlineLDAOptimizer @Since("1.6.0") (
/** @group getParam */
@Since("1.6.0")
def getOptimizeDocConcentration: Boolean = $(optimizeDocConcentration)

/** @group setParam */
@Since("1.6.0")
def setOptimizeDocConcentration(value: Boolean): this.type = set(optimizeDocConcentration, value)
}
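
Taken together, the new setters let callers configure the online optimizer in a single chain before handing it to the estimator. A minimal usage sketch with illustrative values:

val optimizer = new OnlineLDAOptimizer()
  .setTau0(1024)                      // downweights early iterations; must be > 0
  .setKappa(0.51)                     // exponential decay rate in (0.5, 1.0]
  .setSubsamplingRate(0.05)           // fraction of the corpus sampled each iteration
  .setOptimizeDocConcentration(true)  // also optimize the document-topic prior
val lda = new LDA().setK(10).setOptimizer(optimizer)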
228 changes: 228 additions & 0 deletions mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
@@ -0,0 +1,228 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.ml.clustering

import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.util.MLTestingUtils
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{DataFrame, Row, SQLContext}


object LDASuite {
def generateLDAData(
    sql: SQLContext,
    rows: Int,
    k: Int,
    vocabSize: Int): DataFrame = {
  val avgWC = 1  // average instances of each word in a doc
  val sc = sql.sparkContext
  val rng = new java.util.Random(1)  // fixed seed for deterministic test data
  val rdd = sc.parallelize(1 to rows).map { i =>
    // Each document is a length-vocabSize vector of term counts, as LDA expects.
    Vectors.dense(Array.fill(vocabSize)(rng.nextInt(2 * avgWC).toDouble))
  }.map(v => new TestRow(v))
  sql.createDataFrame(rdd)
}
}


class LDASuite extends SparkFunSuite with MLlibTestSparkContext {

val k: Int = 5
val vocabSize: Int = 30
@transient var dataset: DataFrame = _

override def beforeAll(): Unit = {
  super.beforeAll()
  dataset = LDASuite.generateLDAData(sqlContext, 50, k, vocabSize)
}

test("default parameters") {
val lda = new LDA()

assert(lda.getFeaturesCol === "features")
assert(lda.getMaxIter === 20)
assert(lda.isDefined(lda.seed))
assert(!lda.isDefined(lda.checkpointInterval))
assert(lda.getK === 10)
assert(lda.getDocConcentration === Array(-1.0))
assert(lda.getTopicConcentration === -1.0)
assert(lda.getOptimizer.isInstanceOf[OnlineLDAOptimizer])
val optimizer = lda.getOptimizer.asInstanceOf[OnlineLDAOptimizer]
assert(optimizer.getKappa === 0.51)
assert(optimizer.getTau0 === 1024)
assert(optimizer.getSubsamplingRate === 0.05)
assert(optimizer.getOptimizeDocConcentration)
assert(lda.getTopicDistributionCol === "topicDistribution")
}

test("set parameters") {
val lda = new LDA()
.setFeaturesCol("test_feature")
.setMaxIter(33)
.setSeed(123)
.setCheckpointInterval(7)
.setK(9)
.setTopicConcentration(0.56)
.setTopicDistributionCol("myOutput")

assert(lda.getFeaturesCol === "test_feature")
assert(lda.getMaxIter === 33)
assert(lda.getSeed === 123)
assert(lda.getCheckpointInterval === 7)
assert(lda.getK === 9)
assert(lda.getTopicConcentration === 0.56)
assert(lda.getTopicDistributionCol === "myOutput")


// setOptimizer
lda.setOptimizer("em")
assert(lda.getOptimizer.isInstanceOf[EMLDAOptimizer])
lda.setOptimizer("online")
assert(lda.getOptimizer.isInstanceOf[OnlineLDAOptimizer])
val optimizer = lda.getOptimizer.asInstanceOf[OnlineLDAOptimizer]
optimizer.setKappa(0.53)
assert(optimizer.getKappa === 0.53)
optimizer.setTau0(1027)
assert(optimizer.getTau0 === 1027)
optimizer.setSubsamplingRate(0.06)
assert(optimizer.getSubsamplingRate === 0.06)
optimizer.setOptimizeDocConcentration(false)
assert(!optimizer.getOptimizeDocConcentration)
}

test("parameters validation") {
val lda = new LDA()

// misc Params
intercept[IllegalArgumentException] {
new LDA().setK(1)
}
intercept[IllegalArgumentException] {
new LDA().setOptimizer("no_such_optimizer")
}
intercept[IllegalArgumentException] {
new LDA().setDocConcentration(-1.1)
}
intercept[IllegalArgumentException] {
new LDA().setTopicConcentration(-1.1)
}

// validateParams()
lda.setDocConcentration(-1)
assert(lda.getDocConcentration === Array(-1.0))
lda.validateParams()
lda.setDocConcentration(0.1)
lda.validateParams()
lda.setDocConcentration(Range(0, lda.getK).map(_ + 2.0).toArray)
lda.validateParams()
lda.setDocConcentration(Range(0, lda.getK - 1).map(_ + 2.0).toArray)
withClue("LDA docConcentration validity check failed for bad array length") {
intercept[IllegalArgumentException] {
lda.validateParams()
}
}

// OnlineLDAOptimizer
intercept[IllegalArgumentException] {
new OnlineLDAOptimizer().setTau0(0)
}
intercept[IllegalArgumentException] {
new OnlineLDAOptimizer().setKappa(0)
}
intercept[IllegalArgumentException] {
new OnlineLDAOptimizer().setSubsamplingRate(0)
}
intercept[IllegalArgumentException] {
new OnlineLDAOptimizer().setSubsamplingRate(1.1)
}
}

test("fit & transform with Online LDA") {
val lda = new LDA().setK(k).setSeed(1).setOptimizer("online").setMaxIter(2)
val model = lda.fit(dataset)

MLTestingUtils.checkCopy(model)

assert(!model.isInstanceOf[DistributedLDAModel])
assert(model.vocabSize === vocabSize)
assert(model.estimatedDocConcentration.size === k)
assert(model.topicsMatrix.numRows === vocabSize)
assert(model.topicsMatrix.numCols === k)
assert(!model.isDistributed)

// transform()
val transformed = model.transform(dataset)
val expectedColumns = Array("features", lda.getTopicDistributionCol)
expectedColumns.foreach { column =>
assert(transformed.columns.contains(column))
}
transformed.select(lda.getTopicDistributionCol).collect().foreach { r =>
val topicDistribution = r.getAs[Vector](0)
assert(topicDistribution.size === k)
assert(topicDistribution.toArray.forall(w => w >= 0.0 && w <= 1.0))
}

// logLikelihood, logPerplexity
val ll = model.logLikelihood(dataset)
assert(ll <= 0.0 && ll != Double.NegativeInfinity)
val lp = model.logPerplexity(dataset)
assert(lp >= 0.0 && lp != Double.PositiveInfinity)

// describeTopics
val topics = model.describeTopics(3)
assert(topics.count() === k)
assert(topics.select("topic").map(_.getInt(0)).collect().toSet === Range(0, k).toSet)
assert(topics.select("termIndices").collect().forall { case r: Row =>
val termIndices = r.getAs[Seq[Int]](0)
termIndices.length === 3 && termIndices.toSet.size === 3
})
assert(topics.select("termWeights").collect().forall { case r: Row =>
val termWeights = r.getAs[Seq[Double]](0)
termWeights.length === 3 && termWeights.forall(w => w >= 0.0 && w <= 1.0)
})
}

test("fit & transform with EM LDA") {
val lda = new LDA().setK(k).setSeed(1).setOptimizer("em").setMaxIter(2)
val model_ = lda.fit(dataset)

MLTestingUtils.checkCopy(model_)

assert(model_.isInstanceOf[DistributedLDAModel])
val model = model_.asInstanceOf[DistributedLDAModel]
assert(model.vocabSize === vocabSize)
assert(model.estimatedDocConcentration.size === k)
assert(model.topicsMatrix.numRows === vocabSize)
assert(model.topicsMatrix.numCols === k)
assert(model.isDistributed)

val localModel = model.toLocal
assert(!localModel.isInstanceOf[DistributedLDAModel])

// training logLikelihood, logPrior
val ll = model.trainingLogLikelihood
assert(ll <= 0.0 && ll != Double.NegativeInfinity)
val lp = model.logPrior
assert(lp <= 0.0 && lp != Double.NegativeInfinity)
}
}
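
To run just this suite from a Spark checkout, one typical invocation is sbt's testOnly task (the task name and module layout here reflect the Spark build at the time; treat it as a sketch):

build/sbt "mllib/testOnly org.apache.spark.ml.clustering.LDASuite"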