Changes from 1 commit
40 commits
d640d9c
online lda initial checkin
hhbyyh Feb 6, 2015
043e786
Merge remote-tracking branch 'upstream/master' into ldaonline
hhbyyh Feb 6, 2015
26dca1b
style fix and make class private
hhbyyh Feb 6, 2015
f41c5ca
style fix
hhbyyh Feb 6, 2015
45884ab
Merge remote-tracking branch 'upstream/master' into ldaonline
hhbyyh Feb 8, 2015
fa408a8
Merge remote-tracking branch 'upstream/master' into ldaonline
hhbyyh Feb 9, 2015
0d0f3ee
replace random split with sliding
hhbyyh Feb 10, 2015
0dd3947
Merge remote-tracking branch 'upstream/master' into ldaonline
hhbyyh Feb 10, 2015
3a06526
merge with new example
hhbyyh Feb 10, 2015
aa365d1
merge upstream master
hhbyyh Mar 2, 2015
20328d1
Merge remote-tracking branch 'upstream/master' into ldaonline
hhbyyh Mar 2, 2015
37af91a
Merge remote-tracking branch 'upstream/master' into ldaonline
hhbyyh Mar 2, 2015
581c623
separate API and adjust batch split
hhbyyh Mar 2, 2015
e271eb1
remove non ascii
hhbyyh Mar 2, 2015
4a3f27e
Merge remote-tracking branch 'upstream/master' into ldaonline
hhbyyh Mar 5, 2015
a570c9a
use sample to pick up batch
hhbyyh Mar 11, 2015
d86cdec
Merge remote-tracking branch 'upstream/master' into ldaonline
hhbyyh Mar 11, 2015
f6d47ca
Merge branch 'ldaonline' of https://github.com/hhbyyh/spark into ldao…
hhbyyh Mar 11, 2015
02d0373
fix style in comment
hhbyyh Mar 12, 2015
62405cc
Merge remote-tracking branch 'upstream/master' into ldaonline
hhbyyh Mar 20, 2015
8cb16a6
Merge remote-tracking branch 'upstream/master' into ldaonline
hhbyyh Mar 23, 2015
f367cc9
change to optimization
hhbyyh Mar 23, 2015
e7bf3b0
move to separate file
hhbyyh Mar 27, 2015
97b9e1a
Merge remote-tracking branch 'upstream/master' into ldaonline
hhbyyh Mar 27, 2015
d19ef55
change OnlineLDA to class
hhbyyh Apr 2, 2015
b29193b
Merge remote-tracking branch 'upstream/master' into ldaonline
hhbyyh Apr 16, 2015
15be071
Merge remote-tracking branch 'upstream/master' into ldaonline
hhbyyh Apr 17, 2015
dbe3cff
Merge remote-tracking branch 'upstream/master' into ldaonline
hhbyyh Apr 28, 2015
b1178cf
fit into the optimizer framework
hhbyyh Apr 28, 2015
a996a82
respond to comments
hhbyyh Apr 29, 2015
61d60df
Minor cleanups:
jkbradley Apr 29, 2015
9e910d9
small fix
jkbradley Apr 29, 2015
138bfed
Merge pull request #1 from jkbradley/hhbyyh-ldaonline-update
hhbyyh Apr 29, 2015
4041723
add ut
hhbyyh Apr 29, 2015
68c2318
add a java ut
hhbyyh Apr 30, 2015
54cf8da
some style change
hhbyyh May 1, 2015
cf0007d
Merge remote-tracking branch 'upstream/master' into ldaonline
hhbyyh May 1, 2015
6149ca6
fix for setOptimizer
hhbyyh May 1, 2015
cf376ff
For private vars needed for testing, I made them private and added ac…
jkbradley May 2, 2015
1045eec
Merge pull request #2 from jkbradley/hhbyyh-ldaonline2
hhbyyh May 3, 2015
separate API and adjust batch split
hhbyyh committed Mar 2, 2015
commit 581c623106f38d91497fb8123f47c4e661057071
@@ -58,7 +58,7 @@ public Tuple2<Long, Vector> call(Tuple2<Vector, Long> doc_id) {
     corpus.cache();

     // Cluster the documents into three topics using LDA
-    DistributedLDAModel ldaModel = (DistributedLDAModel) new LDA().setK(3).run(corpus);
+    DistributedLDAModel ldaModel = new LDA().setK(3).run(corpus);

     // Output topics. Each is a distribution over words (matching word count vectors)
     System.out.println("Learned topics (as distributions over vocab of " + ldaModel.vocabSize()
@@ -159,7 +159,7 @@ object LDAExample {
       }
       println()
     }
-
+    sc.stop()
  }

  /**
126 changes: 61 additions & 65 deletions mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala
@@ -32,7 +32,6 @@ import org.apache.spark.mllib.impl.PeriodicGraphCheckpointer
 import org.apache.spark.mllib.linalg.{Vector, DenseVector, SparseVector, Matrices}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.util.Utils
-import org.apache.spark.mllib.rdd.RDDFunctions._


 /**
@@ -223,10 +222,6 @@ class LDA private (
    this
  }

-  object LDAMode extends Enumeration {
-    val EM, Online = Value
-  }
-
  /**
   * Learn an LDA model using the given dataset.
   *
@@ -236,37 +231,30 @@ class LDA private (
   *                  Document IDs must be unique and >= 0.
   * @return  Inferred LDA model
   */
-  def run(documents: RDD[(Long, Vector)], mode: LDAMode.Value = LDAMode.EM ): LDAModel = {
-    mode match {
-      case LDAMode.EM =>
-        val state = LDA.initialState(documents, k, getDocConcentration, getTopicConcentration, seed,
-          checkpointInterval)
-        var iter = 0
-        val iterationTimes = Array.fill[Double](maxIterations)(0)
-        while (iter < maxIterations) {
-          val start = System.nanoTime()
-          state.next()
-          val elapsedSeconds = (System.nanoTime() - start) / 1e9
-          iterationTimes(iter) = elapsedSeconds
-          iter += 1
-        }
-        state.graphCheckpointer.deleteAllCheckpoints()
-        new DistributedLDAModel(state, iterationTimes)
-      case LDAMode.Online =>
-        val vocabSize = documents.first._2.size
-        val onlineLDA = new LDA.OnlineLDAOptimizer(documents, k, vocabSize)
-        var iter = 0
-        while (iter < onlineLDA.batchNumber) {
-          onlineLDA.next()
-          iter += 1
-        }
-        new LocalLDAModel(Matrices.fromBreeze(onlineLDA._lambda).transpose)
-      case _ => throw new IllegalArgumentException(s"Do not support mode $mode.")
+  def run(documents: RDD[(Long, Vector)]): DistributedLDAModel = {
+    val state = LDA.initialState(documents, k, getDocConcentration, getTopicConcentration, seed,
+      checkpointInterval)
+    var iter = 0
+    val iterationTimes = Array.fill[Double](maxIterations)(0)
+    while (iter < maxIterations) {
+      val start = System.nanoTime()
+      state.next()
+      val elapsedSeconds = (System.nanoTime() - start) / 1e9
+      iterationTimes(iter) = elapsedSeconds
+      iter += 1
+    }
+    state.graphCheckpointer.deleteAllCheckpoints()
+    new DistributedLDAModel(state, iterationTimes)
  }

+  def runOnlineLDA(documents: RDD[(Long, Vector)]): LDAModel = {
+    val onlineLDA = new LDA.OnlineLDAOptimizer(documents, k)
+    (0 until onlineLDA.batchNumber).map(_ => onlineLDA.next())
+    new LocalLDAModel(Matrices.fromBreeze(onlineLDA.lambda).transpose)
+  }

  /** Java-friendly version of [[run()]] */
-  def run(documents: JavaPairRDD[java.lang.Long, Vector]): LDAModel = {
+  def run(documents: JavaPairRDD[java.lang.Long, Vector]): DistributedLDAModel = {
    run(documents.rdd.asInstanceOf[RDD[(Long, Vector)]])
  }
}
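Note: with this commit the two inference paths are separate entry points instead of a mode argument. run keeps the EM path and returns a DistributedLDAModel, while runOnlineLDA makes one pass over the corpus with the online optimizer. A minimal sketch of the resulting call shapes, assuming an existing SparkContext named sc; the toy corpus is illustrative only, since the online path expects contiguous document IDs 0..D-1 and enough documents to form at least one mini-batch:

import org.apache.spark.mllib.clustering.{DistributedLDAModel, LDA, LDAModel}
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.rdd.RDD

// (docId, term-count vector) pairs over a 5-term vocabulary.
val corpus: RDD[(Long, Vector)] = sc.parallelize(Seq(
  (0L, Vectors.dense(1.0, 2.0, 0.0, 0.0, 1.0)),
  (1L, Vectors.dense(0.0, 0.0, 3.0, 1.0, 0.0)),
  (2L, Vectors.dense(1.0, 0.0, 0.0, 2.0, 2.0))))

val lda = new LDA().setK(3).setMaxIterations(10)

// EM path: distributed model that keeps per-document topic distributions.
val emModel: DistributedLDAModel = lda.run(corpus)

// Online path: deterministic mini-batches (docId % batchNumber), local model out.
val onlineModel: LDAModel = lda.runOnlineLDA(corpus)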
@@ -418,58 +406,66 @@ private[clustering] object LDA {

  }

-  // todo: add reference to paper and Hoffman
+  /**
+   * Optimizer for Online LDA algorithm which breaks corpus into mini-batches and scans only once.
+   * Hoffman, Blei and Bach, "Online Learning for Latent Dirichlet Allocation." NIPS, 2010.
+   */
  private[clustering] class OnlineLDAOptimizer(
-      val documents: RDD[(Long, Vector)],
-      val k: Int,
-      val vocabSize: Int) extends Serializable{
+      private val documents: RDD[(Long, Vector)],
+      private val k: Int) extends Serializable{

-    private val kappa = 0.5  // (0.5, 1] how quickly old information is forgotten
-    private val tau0 = 1024  // down weights early iterations
-    private val D = documents.count()
+    private val vocabSize = documents.first._2.size
+    private val D = documents.count().toInt
    private val batchSize = if (D / 1000 > 4096) 4096
      else if (D / 1000 < 4) 4
      else D / 1000
-    val batchNumber = (D/batchSize + 1).toInt
-    private val batches = documents.sliding(batchNumber).collect()
+    val batchNumber = D/batchSize

    // Initialize the variational distribution q(beta|lambda)
-    var _lambda = getGammaMatrix(k, vocabSize)               // K * V
-    private var _Elogbeta = dirichlet_expectation(_lambda)  // K * V
-    private var _expElogbeta = exp(_Elogbeta)                // K * V
+    var lambda = getGammaMatrix(k, vocabSize)              // K * V
+    private var Elogbeta = dirichlet_expectation(lambda)  // K * V
+    private var expElogbeta = exp(Elogbeta)                // K * V

-    private var batchCount = 0
+    private var batchId = 0
    def next(): Unit = {
-      // weight of the mini-batch.
-      val rhot = math.pow(tau0 + batchCount, -kappa)
+      require(batchId < batchNumber)
+      // weight of the mini-batch. 1024 down weights early iterations
+      val weight = math.pow(1024 + batchId, -0.5)
+      val batch = documents.filter(doc => doc._1 % batchNumber == batchId)

+      // Given a mini-batch of documents, estimates the parameters gamma controlling the
+      // variational distribution over the topic weights for each document in the mini-batch.
      var stat = BDM.zeros[Double](k, vocabSize)
-      stat = batches(batchCount).aggregate(stat)(seqOp, _ += _)
-
-      stat = stat :* _expElogbeta
-      _lambda = _lambda * (1 - rhot) + (stat * D.toDouble / batchSize.toDouble + 1.0 / k) * rhot
-      _Elogbeta = dirichlet_expectation(_lambda)
-      _expElogbeta = exp(_Elogbeta)
-      batchCount += 1
+      stat = batch.aggregate(stat)(seqOp, _ += _)
+      stat = stat :* expElogbeta
+
+      // Update lambda based on documents.
+      lambda = lambda * (1 - weight) + (stat * D.toDouble / batchSize.toDouble + 1.0 / k) * weight
+      Elogbeta = dirichlet_expectation(lambda)
+      expElogbeta = exp(Elogbeta)
+      batchId += 1
    }

-    private def seqOp(other: BDM[Double], doc: (Long, Vector)): BDM[Double] = {
+    // for each document d update that document's gamma and phi
+    private def seqOp(stat: BDM[Double], doc: (Long, Vector)): BDM[Double] = {
      val termCounts = doc._2
      val (ids, cts) = termCounts match {
        case v: DenseVector => (((0 until v.size).toList), v.values)
        case v: SparseVector => (v.indices.toList, v.values)
        case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass)
      }

+      // Initialize the variational distribution q(theta|gamma) for the mini-batch
      var gammad = new Gamma(100, 1.0 / 100.0).samplesVector(k).t // 1 * K
      var Elogthetad = vector_dirichlet_expectation(gammad.t).t   // 1 * K
      var expElogthetad = exp(Elogthetad.t).t                     // 1 * K
-      val expElogbetad = _expElogbeta(::, ids).toDenseMatrix      // K * ids
+      val expElogbetad = expElogbeta(::, ids).toDenseMatrix       // K * ids

      var phinorm = expElogthetad * expElogbetad + 1e-100         // 1 * ids
      var meanchange = 1D
-      val ctsVector = new BDV[Double](cts).t                      // 1 * ids
+      val ctsVector = new BDV[Double](cts).t                     // 1 * ids

+      // Iterate between gamma and phi until convergence
      while (meanchange > 1e-6) {
        val lastgamma = gammad
        //        1*K                  1 * ids               ids * k
@@ -480,30 +476,30 @@
        meanchange = sum(abs((gammad - lastgamma).t)) / gammad.t.size.toDouble
      }

-      val v1 = expElogthetad.t.toDenseMatrix.t
-      val v2 = (ctsVector / phinorm).t.toDenseMatrix
-      val outerResult = kron(v1, v2) // K * ids
+      val m1 = expElogthetad.t.toDenseMatrix.t
+      val m2 = (ctsVector / phinorm).t.toDenseMatrix
+      val outerResult = kron(m1, m2) // K * ids
      for (i <- 0 until ids.size) {
-        other(::, ids(i)) := (other(::, ids(i)) + outerResult(::, i))
+        stat(::, ids(i)) := (stat(::, ids(i)) + outerResult(::, i))
      }
-      other
+      stat
    }

-    def getGammaMatrix(row:Int, col:Int): BDM[Double] ={
+    private def getGammaMatrix(row:Int, col:Int): BDM[Double] ={
      val gammaRandomGenerator = new Gamma(100, 1.0 / 100.0)
      val temp = gammaRandomGenerator.sample(row * col).toArray
      (new BDM[Double](col, row, temp)).t
    }

-    def dirichlet_expectation(alpha : BDM[Double]): BDM[Double] = {
+    private def dirichlet_expectation(alpha : BDM[Double]): BDM[Double] = {
      val rowSum = sum(alpha(breeze.linalg.*, ::))
      val digAlpha = digamma(alpha)
      val digRowSum = digamma(rowSum)
      val result = digAlpha(::, breeze.linalg.*) - digRowSum
      result
    }

-    def vector_dirichlet_expectation(v : BDV[Double]): (BDV[Double]) ={
+    private def vector_dirichlet_expectation(v : BDV[Double]): (BDV[Double]) ={
      digamma(v) - digamma(sum(v))
    }
  }
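For reference, the update that next() applies is the online variational Bayes step of Hoffman, Blei and Bach (2010), with tau0 = 1024 and kappa = 0.5 hard-coded in this revision and topic smoothing eta = 1/k. A sketch of the step for mini-batch B_t at step t:

    \rho_t = (\tau_0 + t)^{-\kappa}, \qquad
    \lambda \leftarrow (1 - \rho_t)\,\lambda + \rho_t \left( \eta + \frac{D}{|B_t|}\,\bar{s}_t \right)

where \bar{s}_t is the sufficient statistic accumulated by the seqOp/aggregate pass over the mini-batch, already scaled elementwise by \exp(\mathbb{E}[\log\beta]), and the expectation is the Dirichlet identity computed by dirichlet_expectation: \mathbb{E}[\log\beta_{kv} \mid \lambda_k] = \psi(\lambda_{kv}) - \psi(\sum_{v'} \lambda_{kv'}).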
@@ -88,7 +88,7 @@ public void distributedLDAModel() {
      .setMaxIterations(5)
      .setSeed(12345);

-    DistributedLDAModel model = (DistributedLDAModel)lda.run(corpus);
+    DistributedLDAModel model = lda.run(corpus);

    // Check: basic parameters
    LocalLDAModel localModel = model.toLocal();
@@ -68,7 +68,7 @@ class LDASuite extends FunSuite with MLlibTestSparkContext {
      .setSeed(12345)
    val corpus = sc.parallelize(tinyCorpus, 2)

-    val model: DistributedLDAModel = lda.run(corpus).asInstanceOf[DistributedLDAModel]
+    val model: DistributedLDAModel = lda.run(corpus)

    // Check: basic parameters
    val localModel = model.toLocal