[MLlib] [SPARK-2510] Word2Vec: Distributed Representation of Words #1719
Changes from 1 commit
@@ -70,7 +70,8 @@ class Word2Vec(
     val startingAlpha: Double,
     val window: Int,
     val minCount: Int,
-    val parallelism:Int = 1)
+    val parallelism:Int = 1,
+    val numIterations:Int = 1)
   extends Serializable with Logging {
Contributor:
Please leave a note that the variable/method names are to match the original C implementation. Then people understand why, e.g., we map …
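A note along these lines could go in the class ScalaDoc. The mapping below is only a sketch based on the original C word2vec code (https://code.google.com/p/word2vec/), not wording from this PR:

    /**
     * Variable and method names deliberately follow the original C implementation,
     * so the two are easy to compare:
     *   syn0       - input word vectors, one block of layer1Size doubles per vocabulary word
     *   syn1       - vectors for the inner nodes of the Huffman tree used by hierarchical softmax
     *   layer1Size - dimensionality of the word vectors (layer1_size in the C code)
     *   alpha      - current learning rate, decayed from startingAlpha during training
     *   window     - maximum skip-gram context window
     */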
     private val EXP_TABLE_SIZE = 1000
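EXP_TABLE_SIZE sizes the precomputed sigmoid table that the training loop below reads through expTable.value(ind) instead of calling exp() on every update. A minimal sketch of how such a table is typically built, following word2vec.c (the MAX_EXP value and local names are assumptions, not this PR's exact code):

    // Hypothetical sketch: tabulate sigmoid(x) for x in (-MAX_EXP, MAX_EXP),
    // quantized into EXP_TABLE_SIZE buckets, as in the original C implementation.
    val EXP_TABLE_SIZE = 1000
    val MAX_EXP = 6
    val expTable = new Array[Double](EXP_TABLE_SIZE)
    var i = 0
    while (i < EXP_TABLE_SIZE) {
      val e = math.exp((i.toDouble / EXP_TABLE_SIZE * 2 - 1) * MAX_EXP) // e^x at this grid point
      expTable(i) = e / (e + 1)                                         // sigmoid(x) = e^x / (e^x + 1)
      i += 1
    }

The lookup used later, ((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2.0)).toInt, is the inverse of this grid mapping.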
@@ -241,73 +242,80 @@ class Word2Vec(
     }

     val newSentences = sentences.repartition(parallelism).cache()
-    val temp = Array.fill[Double](vocabSize * layer1Size)((Random.nextDouble - 0.5) / layer1Size)
-    val (aggSyn0, _, _, _) =
-      // TODO: broadcast temp instead of serializing it directly
-      // or initialize the model in each executor
-      newSentences.aggregate((temp.clone(), new Array[Double](vocabSize * layer1Size), 0, 0))(
-        seqOp = (c, v) => (c, v) match { case ((syn0, syn1, lastWordCount, wordCount), sentence) =>
+    var syn0Global
+      = Array.fill[Double](vocabSize * layer1Size)((Random.nextDouble - 0.5) / layer1Size)
+    var syn1Global = new Array[Double](vocabSize * layer1Size)
+
+    for(iter <- 1 to numIterations) {
+      val (aggSyn0, aggSyn1, _, _) =
+        // TODO: broadcast temp instead of serializing it directly
+        // or initialize the model in each executor
+        newSentences.aggregate((syn0Global.clone(), syn1Global.clone(), 0, 0))(
Contributor:
Do you mind changing it to …
+          seqOp = (c, v) => (c, v) match {
+            case ((syn0, syn1, lastWordCount, wordCount), sentence) =>
              var lwc = lastWordCount
              var wc = wordCount
              if (wordCount - lastWordCount > 10000) {
                lwc = wordCount
                alpha = startingAlpha * (1 - parallelism * wordCount.toDouble / (trainWordsCount + 1))
                if (alpha < startingAlpha * 0.0001) alpha = startingAlpha * 0.0001
                logInfo("wordCount = " + wordCount + ", alpha = " + alpha)
              }
              wc += sentence.size
              var pos = 0
              while (pos < sentence.size) {
                val word = sentence(pos)
                // TODO: fix random seed
                val b = Random.nextInt(window)
                // Train Skip-gram
                var a = b
                while (a < window * 2 + 1 - b) {
                  if (a != window) {
                    val c = pos - window + a
                    if (c >= 0 && c < sentence.size) {
                      val lastWord = sentence(c)
                      val l1 = lastWord * layer1Size
                      val neu1e = new Array[Double](layer1Size)
                      // Hierarchical softmax
                      var d = 0
                      while (d < vocab(word).codeLen) {
                        val l2 = vocab(word).point(d) * layer1Size
                        // Propagate hidden -> output
                        var f = blas.ddot(layer1Size, syn0, l1, 1, syn1, l2, 1)
                        if (f > -MAX_EXP && f < MAX_EXP) {
                          val ind = ((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2.0)).toInt
                          f = expTable.value(ind)
                          val g = (1 - vocab(word).code(d) - f) * alpha
                          blas.daxpy(layer1Size, g, syn1, l2, 1, neu1e, 0, 1)
                          blas.daxpy(layer1Size, g, syn0, l1, 1, syn1, l2, 1)
                        }
                        d += 1
                      }
                      blas.daxpy(layer1Size, 1.0, neu1e, 0, 1, syn0, l1, 1)
                    }
                  }
                  a += 1
                }
                pos += 1
              }
              (syn0, syn1, lwc, wc)
          },
          combOp = (c1, c2) => (c1, c2) match {
            case ((syn0_1, syn1_1, lwc_1, wc_1), (syn0_2, syn1_2, lwc_2, wc_2)) =>
              val n = syn0_1.length
              blas.daxpy(n, 1.0, syn0_2, 1, syn0_1, 1)

Contributor:
Using weighted sum may be more robust.

              blas.daxpy(n, 1.0, syn1_2, 1, syn1_1, 1)
              (syn0_1, syn1_1, lwc_1 + lwc_2, wc_1 + wc_2)
          })
+      syn0Global = aggSyn0
+      syn1Global = aggSyn1
+    }
     val wordMap = new Array[(String, Array[Double])](vocabSize)
     var i = 0
     while (i < vocabSize) {
       val word = vocab(i).word
       val vector = new Array[Double](layer1Size)
-      Array.copy(aggSyn0, i * layer1Size, vector, 0, layer1Size)
+      Array.copy(syn0Global, i * layer1Size, vector, 0, layer1Size)
       wordMap(i) = (word, vector)
       i += 1
     }
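To illustrate the reviewer's weighted-sum suggestion in combOp above: instead of adding the partial syn0/syn1 arrays with equal weight, each partition's contribution could be scaled by the number of words it processed. The sketch below is one possible way to do that, not code from this PR; it assumes the same (syn0, syn1, lwc, wc) tuple shape used by the combOp and the netlib BLAS import used elsewhere in MLlib:

    import com.github.fommil.netlib.BLAS.{getInstance => blas}

    // Hypothetical weighted merge: weight each partition's vectors by its word count.
    def weightedCombine(
        c1: (Array[Double], Array[Double], Int, Int),
        c2: (Array[Double], Array[Double], Int, Int)): (Array[Double], Array[Double], Int, Int) = {
      val (syn0_1, syn1_1, lwc_1, wc_1) = c1
      val (syn0_2, syn1_2, lwc_2, wc_2) = c2
      val total = (wc_1 + wc_2).toDouble
      val w1 = if (total > 0) wc_1 / total else 0.5 // weight of the first partial model
      val w2 = 1.0 - w1                             // weight of the second partial model
      val n = syn0_1.length
      // syn0_1 := w1 * syn0_1 + w2 * syn0_2, in place; likewise for syn1
      blas.dscal(n, w1, syn0_1, 1)
      blas.daxpy(n, w2, syn0_2, 1, syn0_1, 1)
      blas.dscal(n, w1, syn1_1, 1)
      blas.daxpy(n, w2, syn1_2, 1, syn1_1, 1)
      (syn0_1, syn1_1, lwc_1 + lwc_2, wc_1 + wc_2)
    }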
@@ -398,7 +406,9 @@ object Word2Vec{
       size: Int,
       startingAlpha: Double,
       window: Int,
-      minCount: Int): Word2VecModel = {
-    new Word2Vec(size,startingAlpha, window, minCount).fit[S](input)
+      minCount: Int,
+      parallelism: Int = 1,
+      numIterations:Int = 1): Word2VecModel = {
+    new Word2Vec(size,startingAlpha, window, minCount, parallelism, numIterations).fit[S](input)
   }
 }
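With the two added parameters, callers control both the number of partitions used for training and the number of passes over the corpus. A minimal usage sketch, assuming the hunk above shows the train method of object Word2Vec and that the input is an RDD of tokenized sentences (the parameter values are illustrative only, not defaults from this PR):

    // Hypothetical call site: 4-way parallel training with 3 passes over the data.
    // `sentences` is assumed to be an RDD[Iterable[String]].
    val model = Word2Vec.train(
      sentences,
      size = 100,            // word vector dimensionality (layer1Size)
      startingAlpha = 0.025, // initial learning rate
      window = 5,            // skip-gram context window
      minCount = 5,          // drop words seen fewer than 5 times
      parallelism = 4,
      numIterations = 3)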
Comment:
Is word2vec sensitive to alpha? If not, we should try to expose fewer parameters to users.

Comment:
word2vec is sensitive to alpha. A larger alpha may generate meaningless results.

Comment:
Maybe we can suggest a reasonable default value in the doc.
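Following the last suggestion, the parameter doc could point users at the value the original C tool uses. This is only a sketch; 0.025 is the skip-gram default from word2vec.c, not a value specified in this PR:

    /**
     * @param startingAlpha initial learning rate. Training quality is sensitive to this value;
     *                      too large an alpha can produce meaningless vectors. The original C
     *                      implementation defaults to 0.025 for skip-gram, which is a reasonable
     *                      starting point.
     */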