Add test for Word2Vec algorithm, minor fixes
Liquan Pei committed Aug 2, 2014
commit 720b5a3ea697a881fc7d7c286b65ef110421f89e
17 changes: 10 additions & 7 deletions mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
@@ -50,24 +50,27 @@ private case class VocabWord(
* natural language processing and machine learning algorithms.
*
* We use the skip-gram model in our implementation and the hierarchical softmax
* method to train the model.
* method to train the model. The variable names in the implementation
* match those in the original C implementation.
*
* For the original C implementation, see https://code.google.com/p/word2vec/
* For research papers, see
* Efficient Estimation of Word Representations in Vector Space
* and
* Distributed Representations of Words and Phrases and their Compositionality
* Distributed Representations of Words and Phrases and their Compositionality.
* @param size vector dimension
* @param startingAlpha initial learning rate
* @param window context words from [-window, window]
* @param minCount minimum frequency to consider a vocabulary word
* @param parallelism number of partitions to run Word2Vec
*/
@Experimental
class Word2Vec(
Contributor
We need more docs here, for example, link to the C implementation and the original papers for word2vec.

Contributor
and briefly explain what it does.

val size: Int,
val startingAlpha: Double,
Contributor
Is word2vec sensitive to alpha? If not, we should try to expose fewer parameters to users.

Contributor Author
word2vec is sensitive to alpha. A larger alpha may generate meaningless results.

Contributor
Maybe we can suggest a reasonable default value in the doc.
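For example, something along these lines (0.025 is the default in the original C implementation and the value the new test uses; the exact wording is a suggestion, not part of this commit):

    * @param startingAlpha initial learning rate; 0.025, the default of the
    *                      original C implementation, is a reasonable choice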

val window: Int,
val minCount: Int)
val minCount: Int,
val parallelism: Int = 1)
extends Serializable with Logging {

Contributor
Please leave a note that the variable/method names are to match the original C implementation. Then people understand why, e.g., we map size to layer1Size.
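For example, a note along these lines (wording is a suggestion):

    // Variable and method names deliberately follow the original C
    // implementation (https://code.google.com/p/word2vec/), e.g. the
    // `size` parameter maps to `layer1Size`, to ease cross-checking.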

private val EXP_TABLE_SIZE = 1000
@@ -237,7 +240,7 @@ class Word2Vec(
}
}

val newSentences = sentences.repartition(1).cache()
val newSentences = sentences.repartition(parallelism).cache()
val temp = Array.fill[Double](vocabSize * layer1Size)((Random.nextDouble - 0.5) / layer1Size)
Contributor
Try to fix the seed to make the computation reproducible.
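A minimal sketch of seeded initialization, reusing the surrounding names (the literal seed 42L is an assumption; it could also be a constructor parameter):

    import scala.util.Random
    // A fixed seed makes the random initial weights, and therefore the
    // trained vectors, reproducible across runs.
    val initRandom = new Random(42L)
    val temp = Array.fill[Double](vocabSize * layer1Size)(
      (initRandom.nextDouble() - 0.5) / layer1Size)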

val (aggSyn0, _, _, _) =
// TODO: broadcast temp instead of serializing it directly
@@ -248,7 +251,7 @@
var wc = wordCount
if (wordCount - lastWordCount > 10000) {
lwc = wordCount
alpha = startingAlpha * (1 - wordCount.toDouble / (trainWordsCount + 1))
alpha = startingAlpha * (1 - parallelism * wordCount.toDouble / (trainWordsCount + 1))
if (alpha < startingAlpha * 0.0001) alpha = startingAlpha * 0.0001
logInfo("wordCount = " + wordCount + ", alpha = " + alpha)
}
@@ -296,7 +299,7 @@
val n = syn0_1.length
blas.daxpy(n, 1.0, syn0_2, 1, syn0_1, 1)
Contributor
Should use weighted sum to handle imbalanced partitions.
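A sketch of what the weighted combination could look like, reusing the names from the surrounding reduce (the count-proportional weighting scheme is an assumption, not part of this commit):

    // Weight each partition's result by the number of words it processed,
    // so partitions with more data contribute proportionally more.
    val total = (wc_1 + wc_2).toDouble
    val w1 = wc_1 / total
    val w2 = wc_2 / total
    blas.dscal(n, w1, syn0_1, 1)             // syn0_1 := w1 * syn0_1
    blas.daxpy(n, w2, syn0_2, 1, syn0_1, 1)  // syn0_1 += w2 * syn0_2
    blas.dscal(n, w1, syn1_1, 1)
    blas.daxpy(n, w2, syn1_2, 1, syn1_1, 1)
    (syn0_1, syn1_1, lwc_1 + lwc_2, wc_1 + wc_2)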

blas.daxpy(n, 1.0, syn1_2, 1, syn1_1, 1)
(syn0_1, syn0_2, lwc_1 + lwc_2, wc_1 + wc_2)
(syn0_1, syn1_1, lwc_1 + lwc_2, wc_1 + wc_2)
})

val wordMap = new Array[(String, Array[Double])](vocabSize)
@@ -309,7 +312,7 @@
i += 1
}
val modelRDD = sc.parallelize(wordMap, modelPartitionNum)
.partitionBy(new HashPartitioner(modelPartitionNum))
.partitionBy(new HashPartitioner(modelPartitionNum)).cache()
Contributor
Change .cache() to .persist(MEMORY_AND_DISK), so it won't be kicked out by later jobs.
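I.e., something like (import shown for completeness):

    import org.apache.spark.storage.StorageLevel
    val modelRDD = sc.parallelize(wordMap, modelPartitionNum)
      .partitionBy(new HashPartitioner(modelPartitionNum))
      // MEMORY_AND_DISK spills evicted partitions to disk instead of
      // dropping them, so later jobs don't silently trigger recomputation.
      .persist(StorageLevel.MEMORY_AND_DISK)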

new Word2VecModel(modelRDD)
Contributor
please call newSentences.unpersist() before exit
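I.e., something like:

    val model = new Word2VecModel(modelRDD)
    // Release the cached, repartitioned training data before returning.
    newSentences.unpersist()
    model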

}
}
@@ -23,7 +23,27 @@ import org.apache.spark.SparkContext._
import org.apache.spark.mllib.util.LocalSparkContext

class Word2VecSuite extends FunSuite with LocalSparkContext {
test("word2vec") {
test("Word2Vec") {
val sentence = "a b " * 100 + "a c " * 10
val localDoc = Seq(sentence, sentence)
val doc = sc.parallelize(localDoc)
.map(line => line.split(" ").toSeq)
val size = 10
val startingAlpha = 0.025
val window = 2
val minCount = 2
val num = 2
val word = "a"

val model = Word2Vec.train(doc, size, startingAlpha, window, minCount)
val syms = model.findSynonyms(word, num)
assert(syms.length == num)
assert(syms(0)._1 == "b")
assert(syms(1)._1 == "c")
}
test("Word2VecModel") {
val num = 2
val localModel = Seq(
("china" , Array(0.50, 0.50, 0.50, 0.50)),