[SPARK-5563][mllib] LDA with online variational inference #4419
Closed
Changes from 1 commit. This pull request contains 40 commits:

d640d9c  online lda initial checkin  (hhbyyh)
043e786  Merge remote-tracking branch 'upstream/master' into ldaonline  (hhbyyh)
26dca1b  style fix and make class private  (hhbyyh)
f41c5ca  style fix  (hhbyyh)
45884ab  Merge remote-tracking branch 'upstream/master' into ldaonline  (hhbyyh)
fa408a8  Merge remote-tracking branch 'upstream/master' into ldaonline  (hhbyyh)
0d0f3ee  replace random split with sliding  (hhbyyh)
0dd3947  Merge remote-tracking branch 'upstream/master' into ldaonline  (hhbyyh)
3a06526  merge with new example  (hhbyyh)
aa365d1  merge upstream master  (hhbyyh)
20328d1  Merge remote-tracking branch 'upstream/master' into ldaonline  (hhbyyh)
37af91a  Merge remote-tracking branch 'upstream/master' into ldaonline  (hhbyyh)
581c623  separate API and adjust batch split  (hhbyyh)
e271eb1  remove non ascii  (hhbyyh)
4a3f27e  Merge remote-tracking branch 'upstream/master' into ldaonline  (hhbyyh)
a570c9a  use sample to pick up batch  (hhbyyh)
d86cdec  Merge remote-tracking branch 'upstream/master' into ldaonline  (hhbyyh)
f6d47ca  Merge branch 'ldaonline' of https://github.com/hhbyyh/spark into ldao…  (hhbyyh)
02d0373  fix style in comment  (hhbyyh)
62405cc  Merge remote-tracking branch 'upstream/master' into ldaonline  (hhbyyh)
8cb16a6  Merge remote-tracking branch 'upstream/master' into ldaonline  (hhbyyh)
f367cc9  change to optimization  (hhbyyh)
e7bf3b0  move to separate file  (hhbyyh)
97b9e1a  Merge remote-tracking branch 'upstream/master' into ldaonline  (hhbyyh)
d19ef55  change OnlineLDA to class  (hhbyyh)
b29193b  Merge remote-tracking branch 'upstream/master' into ldaonline  (hhbyyh)
15be071  Merge remote-tracking branch 'upstream/master' into ldaonline  (hhbyyh)
dbe3cff  Merge remote-tracking branch 'upstream/master' into ldaonline  (hhbyyh)
b1178cf  fit into the optimizer framework  (hhbyyh)
a996a82  respond to comments  (hhbyyh)
61d60df  Minor cleanups  (jkbradley)
9e910d9  small fix  (jkbradley)
138bfed  Merge pull request #1 from jkbradley/hhbyyh-ldaonline-update  (hhbyyh)
4041723  add ut  (hhbyyh)
68c2318  add a java ut  (hhbyyh)
54cf8da  some style change  (hhbyyh)
cf0007d  Merge remote-tracking branch 'upstream/master' into ldaonline  (hhbyyh)
6149ca6  fix for setOptimizer  (hhbyyh)
cf376ff  For private vars needed for testing, I made them private and added ac…  (jkbradley)
1045eec  Merge pull request #2 from jkbradley/hhbyyh-ldaonline2  (hhbyyh)
Viewing changes from commit 4041723d4fe0dcfab845c9d2a2c72be2ed87895e ("add ut").
```diff
@@ -227,8 +227,8 @@ class OnlineLDAOptimizer extends LDAOptimizer {
   private var k: Int = 0
   private var corpusSize: Long = 0
   private var vocabSize: Int = 0
-  private var alpha: Double = 0
-  private var eta: Double = 0
+  private[clustering] var alpha: Double = 0
+  private[clustering] var eta: Double = 0
   private var randomGenerator: java.util.Random = null

   // Online LDA specific parameters
```
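Widening `alpha` and `eta` from `private` to `private[clustering]` makes them readable by test code in the same package without exposing them publicly. A minimal illustration of that access rule (hypothetical names, not from the patch):

```scala
package org.apache.spark.mllib.clustering {
  class Opt {
    private[clustering] var alpha: Double = 0 // visible throughout the clustering package
    private var hidden: Double = 0            // visible only inside Opt
  }
  class SamePackageSuite {
    def read(o: Opt): Double = o.alpha        // compiles: same package
    // def bad(o: Opt): Double = o.hidden     // would not compile
  }
}
// From any other package, o.alpha is as inaccessible as a plain `private` member.
```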
```diff
@@ -238,12 +238,11 @@ class OnlineLDAOptimizer extends LDAOptimizer {
   // internal data structure
   private var docs: RDD[(Long, Vector)] = null
-  private var lambda: BDM[Double] = null
-  private var Elogbeta: BDM[Double] = null
-  private var expElogbeta: BDM[Double] = null
+  private[clustering] var lambda: BDM[Double] = null

   // count of invocation to next, which helps deciding the weight for each iteration
   private var iteration: Int = 0
+  private var gammaShape: Double = 100

   /**
    * A (positive) learning parameter that downweights early iterations. Larger values make early
```
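With `Elogbeta` and `expElogbeta` no longer cached as fields, they are recomputed from `lambda` on each mini-batch via `dirichletExpectation`. For reference, this is the standard variational-LDA identity (stated here for context, not quoted from the patch): for topic k with Dirichlet parameter row lambda_k,

```latex
\mathbb{E}_{q}\!\left[\log \beta_{kw}\right]
  = \psi(\lambda_{kw}) - \psi\!\Bigl(\textstyle\sum_{w'} \lambda_{kw'}\Bigr),
\qquad \text{where } \psi \text{ is the digamma function.}
```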
```diff
@@ -295,7 +294,24 @@ class OnlineLDAOptimizer extends LDAOptimizer {
     this
   }

-  override private[clustering] def initialize(docs: RDD[(Long, Vector)], lda: LDA): LDAOptimizer = {
+  /**
+   * The function is for test only now. In the future, it can help support training strop/resume
+   */
+  private[clustering] def setLambda(lambda: BDM[Double]): this.type = {
+    this.lambda = lambda
+    this
+  }
+
+  /**
+   * Used to control the gamma distribution. Larger value produces values closer to 1.0.
+   */
+  private[clustering] def setGammaShape(shape: Double): this.type = {
+    this.gammaShape = shape
+    this
+  }
+
+  override private[clustering] def initialize(docs: RDD[(Long, Vector)], lda: LDA):
+    OnlineLDAOptimizer = {
     this.k = lda.getK
     this.corpusSize = docs.count()
     this.vocabSize = docs.first()._2.size
```

Member comment, on the split `initialize` signature: scala style: If this can't fit on 1 line (100 chars), then put 1 argument per line.
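The `setGammaShape` doc comment says larger values produce draws closer to 1.0. That follows from the parameterization `Gamma(shape, 1.0 / shape)`: the mean is shape * (1/shape) = 1, and the variance is shape * (1/shape)^2 = 1/shape, which shrinks as shape grows. A quick sketch of my own (not part of the PR), assuming the same breeze `Gamma` class the patch uses, in a breeze version where the implicit `RandBasis` has a default:

```scala
import breeze.stats.distributions.Gamma

// Empirically check that Gamma(shape, 1.0 / shape) concentrates around 1.0 as
// shape grows: the sample mean stays near 1 while the variance shrinks like 1/shape.
object GammaShapeSketch extends App {
  for (shape <- Seq(1.0, 10.0, 100.0, 10000.0)) {
    val xs = new Gamma(shape, 1.0 / shape).sample(100000)
    val mean = xs.sum / xs.size
    val variance = xs.map(x => (x - mean) * (x - mean)).sum / xs.size
    println(f"shape=$shape%9.1f  mean=$mean%.4f  variance=$variance%.6f")
  }
}
```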
```diff
@@ -307,26 +323,30 @@ class OnlineLDAOptimizer extends LDAOptimizer {
     // Initialize the variational distribution q(beta|lambda)
     this.lambda = getGammaMatrix(k, vocabSize)
-    this.Elogbeta = dirichletExpectation(lambda)
-    this.expElogbeta = exp(Elogbeta)
     this.iteration = 0
     this
   }

+  override private[clustering] def next(): OnlineLDAOptimizer = {
+    val batch = docs.sample(withReplacement = true, miniBatchFraction, randomGenerator.nextLong())
+    if (batch.isEmpty()) return this
+    submitMiniBatch(batch)
+  }
+
+
   /**
    * Submit a subset (like 1%, decide by the miniBatchFraction) of the corpus to the Online LDA
    * model, and it will update the topic distribution adaptively for the terms appearing in the
    * subset.
    */
-  override private[clustering] def next(): OnlineLDAOptimizer = {
+  private[clustering] def submitMiniBatch(batch: RDD[(Long, Vector)]): OnlineLDAOptimizer = {
     iteration += 1
-    val batch = docs.sample(withReplacement = true, miniBatchFraction, randomGenerator.nextLong())
-    if (batch.isEmpty()) return this

     val k = this.k
     val vocabSize = this.vocabSize
-    val expElogbeta = this.expElogbeta
+    val Elogbeta = dirichletExpectation(lambda)
+    val expElogbeta = exp(Elogbeta)
     val alpha = this.alpha
+    val gammaShape = this.gammaShape

     val stats: RDD[BDM[Double]] = batch.mapPartitions { docs =>
       val stat = BDM.zeros[Double](k, vocabSize)
```

Member comment, on the doubled blank line after `next()`: scala style: remove extra newline.
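This split is what enables the commit's new unit tests ("add ut"): `next()` keeps the nondeterministic `docs.sample(...)`, while a test in the same `clustering` package can feed a handpicked batch to `submitMiniBatch` and pin the state with the new setters. A hedged sketch of that pattern (illustrative, not the actual suite; it assumes an `RDD[(Long, Vector)]` named `docs`, an `LDA` named `lda` configured with this optimizer, and an `initialLambda: BDM[Double]`):

```scala
val op = new OnlineLDAOptimizer()
op.initialize(docs, lda)       // sets k, corpusSize, vocabSize, alpha, eta, lambda, ...
op.setLambda(initialLambda)    // private[clustering]: reachable only from this package
  .setGammaShape(1e10)         // near-deterministic gamma draws for reproducibility
op.submitMiniBatch(docs)       // deterministic batch, bypassing docs.sample in next()
assert(op.lambda.rows == lda.getK)
```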
```diff
@@ -340,7 +360,7 @@ class OnlineLDAOptimizer extends LDAOptimizer {
       }

       // Initialize the variational distribution q(theta|gamma) for the mini-batch
-      var gammad = new Gamma(100, 1.0 / 100.0).samplesVector(k).t // 1 * K
+      var gammad = new Gamma(gammaShape, 1.0 / gammaShape).samplesVector(k).t // 1 * K
       var Elogthetad = digamma(gammad) - digamma(sum(gammad)) // 1 * K
       var expElogthetad = exp(Elogthetad) // 1 * K
       val expElogbetad = expElogbeta(::, ids).toDenseMatrix // K * ids
```
```diff
@@ -350,7 +370,7 @@ class OnlineLDAOptimizer extends LDAOptimizer {
       val ctsVector = new BDV[Double](cts).t // 1 * ids

       // Iterate between gamma and phi until convergence
-      while (meanchange > 1e-5) {
+      while (meanchange > 1e-3) {
         val lastgamma = gammad
         // 1*K 1 * ids ids * k
         gammad = (expElogthetad :* ((ctsVector / phinorm) * expElogbetad.t)) + alpha
```
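For context, this loop is the per-document E-step of Hoffman et al.'s online variational Bayes: gamma_d and phi are updated alternately until the mean absolute change in gamma_d drops below the threshold, which this commit relaxes from 1e-5 to 1e-3. A self-contained restatement of that fixed point using column vectors instead of the patch's row vectors (my sketch, not the PR's code):

```scala
import breeze.linalg.{sum, DenseMatrix => BDM, DenseVector => BDV}
import breeze.numerics.{abs, digamma, exp}

// cts: term counts over the doc's distinct term ids (length = ids);
// expElogbetad: exp(E[log beta]) restricted to those ids (K x ids);
// alpha: symmetric document-topic prior. Returns the converged gamma_d.
def inferDocGamma(cts: BDV[Double], expElogbetad: BDM[Double], alpha: Double): BDV[Double] = {
  val k = expElogbetad.rows
  var gammad = BDV.fill(k)(1.0) // deterministic init; the patch samples Gamma(gammaShape, 1/gammaShape)
  var expElogthetad = exp(digamma(gammad) - digamma(sum(gammad)))
  var meanchange = 1.0
  while (meanchange > 1e-3) {   // same threshold the commit relaxes to
    val lastgamma = gammad.copy
    val phinorm = expElogbetad.t * expElogthetad + 1e-100       // length ids
    gammad = (expElogthetad :* (expElogbetad * (cts / phinorm))) + alpha
    expElogthetad = exp(digamma(gammad) - digamma(sum(gammad)))
    meanchange = sum(abs(gammad - lastgamma)) / k               // mean |change in gamma|
  }
  gammad
}
```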
```diff
@@ -372,7 +392,10 @@ class OnlineLDAOptimizer extends LDAOptimizer {
       Iterator(stat)
     }

-    val batchResult: BDM[Double] = stats.reduce(_ += _)
+    val statsSum: BDM[Double] = stats.reduce(_ += _)
+    val batchResult = statsSum :* expElogbeta
+
+    // Note that this is an optimization to avoid batch.count
     update(batchResult, iteration, (miniBatchFraction * corpusSize).toInt)
     this
   }
```
```diff
@@ -384,28 +407,23 @@ class OnlineLDAOptimizer extends LDAOptimizer {
   /**
    * Update lambda based on the batch submitted. batchSize can be different for each iteration.
    */
-  private def update(raw: BDM[Double], iter: Int, batchSize: Int): Unit = {
+  private[clustering] def update(stat: BDM[Double], iter: Int, batchSize: Int): Unit = {
     val tau_0 = this.getTau_0
     val kappa = this.getKappa

     // weight of the mini-batch.
     val weight = math.pow(tau_0 + iter, -kappa)

-    // This step finishes computing the sufficient statistics for the M step
-    val stat = raw :* expElogbeta
-
     // Update lambda based on documents.
     lambda = lambda * (1 - weight) +
       (stat * (corpusSize.toDouble / batchSize.toDouble) + eta) * weight
-    Elogbeta = dirichletExpectation(lambda)
-    expElogbeta = exp(Elogbeta)
   }

   /**
    * Get a random matrix to initialize lambda
    */
   private def getGammaMatrix(row: Int, col: Int): BDM[Double] = {
-    val gammaRandomGenerator = new Gamma(100, 1.0 / 100.0)
+    val gammaRandomGenerator = new Gamma(gammaShape, 1.0 / gammaShape)
     val temp = gammaRandomGenerator.sample(row * col).toArray
     new BDM[Double](col, row, temp).t
   }
```
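The `update` method is the online variational Bayes M-step: lambda is blended toward the mini-batch estimate with a step size rho_t = (tau_0 + t)^(-kappa) that decays across iterations, so early noisy batches are progressively downweighted. A quick numeric illustration (tau_0 = 1024 and kappa = 0.51 are example settings from Hoffman et al.'s paper, not values quoted from this diff):

```scala
// How the mini-batch weight rho_t = (tau_0 + t)^(-kappa) decays with the iteration count.
object WeightDecaySketch extends App {
  val tau0 = 1024.0 // example value; larger tau_0 downweights early iterations more
  val kappa = 0.51  // example value; kappa in (0.5, 1] gives the usual convergence guarantee
  for (iter <- Seq(1, 10, 100, 1000, 10000)) {
    val weight = math.pow(tau0 + iter, -kappa)
    println(f"iter=$iter%5d  weight=$weight%.6f")
  }
  // lambda is then blended as:
  //   lambda = (1 - weight) * lambda + weight * (stat * corpusSize / batchSize + eta)
}
```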
Member comment, on the `setLambda` doc comment: typo: "strop" -> "stop"