Skip to content

Commit a0af0e3

Browse files
hhbyyhsrowen
authored andcommitted
[SPARK-11898][MLLIB] Use broadcast for the global tables in Word2Vec
jira: https://issues.apache.org/jira/browse/SPARK-11898 syn0Global and sync1Global in word2vec are quite large objects with size (vocab * vectorSize * 8), yet they are passed to worker using basic task serialization. Use broadcast can greatly improve the performance. My benchmark shows that, for 1M vocabulary and default vectorSize 100, changing to broadcast can help, 1. decrease the worker memory consumption by 45%. 2. decrease running time by 40%. This will also help extend the upper limit for Word2Vec. Author: Yuhao Yang <hhbyyh@gmail.com> Closes apache#9878 from hhbyyh/w2vBC.
1 parent 9693b0d commit a0af0e3

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -316,12 +316,15 @@ class Word2Vec extends Serializable with Logging {
316316
Array.fill[Float](vocabSize * vectorSize)((initRandom.nextFloat() - 0.5f) / vectorSize)
317317
val syn1Global = new Array[Float](vocabSize * vectorSize)
318318
var alpha = learningRate
319+
319320
for (k <- 1 to numIterations) {
321+
val bcSyn0Global = sc.broadcast(syn0Global)
322+
val bcSyn1Global = sc.broadcast(syn1Global)
320323
val partial = newSentences.mapPartitionsWithIndex { case (idx, iter) =>
321324
val random = new XORShiftRandom(seed ^ ((idx + 1) << 16) ^ ((-k - 1) << 8))
322325
val syn0Modify = new Array[Int](vocabSize)
323326
val syn1Modify = new Array[Int](vocabSize)
324-
val model = iter.foldLeft((syn0Global, syn1Global, 0, 0)) {
327+
val model = iter.foldLeft((bcSyn0Global.value, bcSyn1Global.value, 0, 0)) {
325328
case ((syn0, syn1, lastWordCount, wordCount), sentence) =>
326329
var lwc = lastWordCount
327330
var wc = wordCount
@@ -405,6 +408,8 @@ class Word2Vec extends Serializable with Logging {
405408
}
406409
i += 1
407410
}
411+
bcSyn0Global.unpersist(false)
412+
bcSyn1Global.unpersist(false)
408413
}
409414
newSentences.unpersist()
410415

0 commit comments

Comments
 (0)