[MLlib] [SPARK-2510] Word2Vec: Distributed Representation of Words #1719
mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala

```diff
@@ -17,20 +17,19 @@
 package org.apache.spark.mllib.feature
 
-import scala.util.Random
-import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+import scala.util.Random
 
 import com.github.fommil.netlib.BLAS.{getInstance => blas}
 
-import org.apache.spark.annotation.Experimental
-import org.apache.spark.Logging
-import org.apache.spark.rdd._
+import org.apache.spark.{HashPartitioner, Logging}
 import org.apache.spark.SparkContext._
+import org.apache.spark.annotation.Experimental
 import org.apache.spark.mllib.linalg.{Vector, Vectors}
-import org.apache.spark.HashPartitioner
-import org.apache.spark.storage.StorageLevel
 import org.apache.spark.mllib.rdd.RDDFunctions._
+import org.apache.spark.rdd._
+import org.apache.spark.storage.StorageLevel
 
 /**
  * Entry in vocabulary
  */
@@ -52,7 +51,7 @@ private case class VocabWord(
  *
  * We used skip-gram model in our implementation and hierarchical softmax
  * method to train the model. The variable names in the implementation
- * mathes the original C implementation.
+ * matches the original C implementation.
  *
  * For original C implementation, see https://code.google.com/p/word2vec/
 * For research papers, see
@@ -61,34 +60,41 @@ private case class VocabWord(
  * Distributed Representations of Words and Phrases and their Compositionality.
  * @param size vector dimension
  * @param startingAlpha initial learning rate
- * @param window context words from [-window, window]
- * @param minCount minimum frequncy to consider a vocabulary word
- * @param parallelisum number of partitions to run Word2Vec
+ * @param parallelism number of partitions to run Word2Vec (using a small number for accuracy)
+ * @param numIterations number of iterations to run, should be smaller than or equal to parallelism
  */
+@Experimental
 class Word2Vec(
```
**Contributor:** We need more docs here, for example, link to the C implementation and the original papers for word2vec.

**Contributor:** And briefly explain what it does.

**Contributor:** Btw, this is definitely an experimental feature. Please add `@Experimental`.
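As background for those docs (not part of the patch): skip-gram maximizes the average log-probability of the words within `window` positions of each input word, and hierarchical softmax factorizes each of those probabilities over the Huffman path to the output word, so each update costs O(log V) instead of O(V). A sketch in the notation of the Mikolov et al. papers cited in the scaladoc:

$$\frac{1}{T} \sum_{t=1}^{T} \sum_{-w \le j \le w,\; j \ne 0} \log p(w_{t+j} \mid w_t),
\qquad
p(w_O \mid w_I) = \prod_{j=1}^{L(w_O)-1} \sigma\!\left( [\![ n(w_O, j{+}1) = \mathrm{ch}(n(w_O, j)) ]\!] \, {v'_{n(w_O, j)}}^{\top} v_{w_I} \right)$$

where $L(w_O)$ is the length of the Huffman path, $\mathrm{ch}(n)$ is a fixed child of node $n$, and $[\![ x ]\!]$ is $1$ if $x$ holds and $-1$ otherwise.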
```diff
     val size: Int,
     val startingAlpha: Double,
```
**Contributor:** Is word2vec sensitive to alpha? If not, we should try to expose fewer parameters to users.

**Author:** word2vec is sensitive to alpha. A larger alpha may generate meaningless results.

**Contributor:** Maybe we can suggest a reasonable default value in the doc.
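For reference, the original C tool ships with a default starting alpha of 0.025 for skip-gram, which would be a natural value to suggest in the doc. A hypothetical construction under this patch's signature (the 0.025 is carried over from the C defaults, not set anywhere in this PR):

```scala
// Hypothetical usage sketch: 0.025 mirrors the C implementation's default
// skip-gram learning rate; it is an assumption, not a value from this patch.
val word2vec = new Word2Vec(size = 100, startingAlpha = 0.025, parallelism = 1, numIterations = 1)
```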
```diff
-    val window: Int,
-    val minCount: Int,
-    val parallelism:Int = 1,
-    val numIterations:Int = 1)
-  extends Serializable with Logging {
+    val parallelism: Int,
+    val numIterations: Int) extends Serializable with Logging {
 
+  /**
+   * Word2Vec with a single thread.
+   */
+  def this(size: Int, startingAlpha: Int) = this(size, startingAlpha, 1, 1)
 
   private val EXP_TABLE_SIZE = 1000
   private val MAX_EXP = 6
   private val MAX_CODE_LENGTH = 40
   private val MAX_SENTENCE_LENGTH = 1000
   private val layer1Size = size
+  private val modelPartitionNum = 100
 
+  /** context words from [-window, window] */
+  private val window = 5
 
+  /** minimum frequency to consider a vocabulary word */
+  private val minCount = 5
 
   private var trainWordsCount = 0
   private var vocabSize = 0
   private var vocab: Array[VocabWord] = null
   private var vocabHash = mutable.HashMap.empty[String, Int]
   private var alpha = startingAlpha
 
-  private def learnVocab(words:RDD[String]){
+  private def learnVocab(words:RDD[String]): Unit = {
     vocab = words.map(w => (w, 1))
       .reduceByKey(_ + _)
       .map(x => VocabWord(
@@ -99,7 +105,7 @@ class Word2Vec(
         0))
       .filter(_.cn >= minCount)
       .collect()
-      .sortWith((a, b)=> a.cn > b.cn)
+      .sortWith((a, b) => a.cn > b.cn)
 
     vocabSize = vocab.length
     var a = 0
@@ -111,22 +117,18 @@ class Word2Vec(
     logInfo("trainWordsCount = " + trainWordsCount)
   }
 
-  private def learnVocabPerPartition(words:RDD[String]) {
-
-  }
-
-  private def createExpTable(): Array[Double] = {
-    val expTable = new Array[Double](EXP_TABLE_SIZE)
+  private def createExpTable(): Array[Float] = {
+    val expTable = new Array[Float](EXP_TABLE_SIZE)
     var i = 0
     while (i < EXP_TABLE_SIZE) {
       val tmp = math.exp((2.0 * i / EXP_TABLE_SIZE - 1.0) * MAX_EXP)
-      expTable(i) = tmp / (tmp + 1)
+      expTable(i) = (tmp / (tmp + 1.0)).toFloat
       i += 1
     }
     expTable
   }
 
-  private def createBinaryTree() {
+  private def createBinaryTree(): Unit = {
     val count = new Array[Long](vocabSize * 2 + 1)
     val binary = new Array[Int](vocabSize * 2 + 1)
     val parentNode = new Array[Int](vocabSize * 2 + 1)
```
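An aside on `createExpTable` above: it caches the logistic function at `EXP_TABLE_SIZE` evenly spaced points on `(-MAX_EXP, MAX_EXP)`, so the inner training loop can replace `math.exp` with an array lookup. A self-contained sketch of the same trick (plain Scala, no Spark assumed):

```scala
// Minimal sketch of the precomputed-sigmoid table, mirroring createExpTable.
val EXP_TABLE_SIZE = 1000
val MAX_EXP = 6

// expTable(i) = sigmoid(x_i) for x_i sweeping (-MAX_EXP, MAX_EXP) in EXP_TABLE_SIZE steps.
val expTable: Array[Float] = Array.tabulate(EXP_TABLE_SIZE) { i =>
  val x = (2.0 * i / EXP_TABLE_SIZE - 1.0) * MAX_EXP
  (1.0 / (1.0 + math.exp(-x))).toFloat
}

// The lookup the training loop performs for a score f in (-MAX_EXP, MAX_EXP).
def fastSigmoid(f: Double): Float =
  expTable(((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2.0)).toInt)
```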
```diff
@@ -208,8 +210,7 @@ class Word2Vec(
    * @param dataset an RDD of words
    * @return a Word2VecModel
    */
-
-  def fit[S <: Iterable[String]](dataset:RDD[S]): Word2VecModel = {
+  def fit[S <: Iterable[String]](dataset: RDD[S]): Word2VecModel = {
 
     val words = dataset.flatMap(x => x)
```
```diff
@@ -223,39 +224,37 @@ class Word2Vec(
     val bcVocab = sc.broadcast(vocab)
     val bcVocabHash = sc.broadcast(vocabHash)
 
-    val sentences: RDD[Array[Int]] = words.mapPartitions {
-      iter => { new Iterator[Array[Int]] {
-          def hasNext = iter.hasNext
-
-          def next = {
-            var sentence = new ArrayBuffer[Int]
-            var sentenceLength = 0
-            while (iter.hasNext && sentenceLength < MAX_SENTENCE_LENGTH) {
-              val word = bcVocabHash.value.get(iter.next)
-              word match {
-                case Some(w) => {
-                  sentence += w
-                  sentenceLength += 1
-                }
-                case None =>
-              }
-            }
-            sentence.toArray
-          }
-        }
-      }
-    }
+    val sentences: RDD[Array[Int]] = words.mapPartitions { iter =>
+      new Iterator[Array[Int]] {
+        def hasNext: Boolean = iter.hasNext
+
+        def next(): Array[Int] = {
+          var sentence = new ArrayBuffer[Int]
+          var sentenceLength = 0
+          while (iter.hasNext && sentenceLength < MAX_SENTENCE_LENGTH) {
+            val word = bcVocabHash.value.get(iter.next())
+            word match {
+              case Some(w) =>
+                sentence += w
+                sentenceLength += 1
+              case None =>
+            }
+          }
+          sentence.toArray
+        }
+      }
+    }
 
     val newSentences = sentences.repartition(parallelism).cache()
-    var syn0Global
-    = Array.fill[Double](vocabSize * layer1Size)((Random.nextDouble - 0.5) / layer1Size)
-    var syn1Global = new Array[Double](vocabSize * layer1Size)
+    var syn0Global =
+      Array.fill[Float](vocabSize * layer1Size)((Random.nextFloat() - 0.5f) / layer1Size)
+    var syn1Global = new Array[Float](vocabSize * layer1Size)
 
     for(iter <- 1 to numIterations) {
       val (aggSyn0, aggSyn1, _, _) =
-        // TODO: broadcast temp instead of serializing it directly
-        newSentences.treeAggregate((syn0Global.clone(), syn1Global.clone(), 0, 0))(
+        // TODO: broadcast temp instead of serializing it directly
+        // or initialize the model in each executor
+        newSentences.treeAggregate((syn0Global, syn1Global, 0, 0))(
         seqOp = (c, v) => (c, v) match {
           case ((syn0, syn1, lastWordCount, wordCount), sentence) =>
             var lwc = lastWordCount
```
```diff
@@ -280,23 +279,23 @@ class Word2Vec(
               if (c >= 0 && c < sentence.size) {
                 val lastWord = sentence(c)
                 val l1 = lastWord * layer1Size
-                val neu1e = new Array[Double](layer1Size)
+                val neu1e = new Array[Float](layer1Size)
                 // Hierarchical softmax
                 var d = 0
                 while (d < bcVocab.value(word).codeLen) {
                   val l2 = bcVocab.value(word).point(d) * layer1Size
                   // Propagate hidden -> output
-                  var f = blas.ddot(layer1Size, syn0, l1, 1, syn1, l2, 1)
+                  var f = blas.sdot(layer1Size, syn0, l1, 1, syn1, l2, 1)
                   if (f > -MAX_EXP && f < MAX_EXP) {
                     val ind = ((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2.0)).toInt
                     f = expTable.value(ind)
-                    val g = (1 - bcVocab.value(word).code(d) - f) * alpha
-                    blas.daxpy(layer1Size, g, syn1, l2, 1, neu1e, 0, 1)
-                    blas.daxpy(layer1Size, g, syn0, l1, 1, syn1, l2, 1)
+                    val g = ((1 - bcVocab.value(word).code(d) - f) * alpha).toFloat
+                    blas.saxpy(layer1Size, g, syn1, l2, 1, neu1e, 0, 1)
+                    blas.saxpy(layer1Size, g, syn0, l1, 1, syn1, l2, 1)
                   }
                   d += 1
                 }
-                blas.daxpy(layer1Size, 1.0, neu1e, 0, 1, syn0, l1, 1)
+                blas.saxpy(layer1Size, 1.0f, neu1e, 0, 1, syn0, l1, 1)
               }
             }
             a += 1
```
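Decoding the hunk above: `f` is the sigmoid (read from the exp table) of the dot product between the current word vector (`syn0` at offset `l1`) and the inner-node vector (`syn1` at offset `l2`), and the target for each node is the complement of its Huffman code bit. The two `saxpy` calls then perform, per tree node $d$ (a sketch of the update in the diff's variable names):

$$g = \alpha \left( (1 - \mathrm{code}_d) - \sigma(f) \right), \qquad
\mathrm{neu1e} \mathrel{+}= g \cdot \mathrm{syn1}_{l2}, \qquad
\mathrm{syn1}_{l2} \mathrel{+}= g \cdot \mathrm{syn0}_{l1}$$

with the accumulated `neu1e` folded back into `syn0` by the final `saxpy` once the path loop ends.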
```diff
@@ -308,24 +307,24 @@ class Word2Vec(
         combOp = (c1, c2) => (c1, c2) match {
           case ((syn0_1, syn1_1, lwc_1, wc_1), (syn0_2, syn1_2, lwc_2, wc_2)) =>
             val n = syn0_1.length
-            val weight1 = 1.0 * wc_1 / (wc_1 + wc_2)
-            val weight2 = 1.0 * wc_2 / (wc_1 + wc_2)
-            blas.dscal(n, weight1, syn0_1, 1)
-            blas.dscal(n, weight1, syn1_1, 1)
-            blas.daxpy(n, weight2, syn0_2, 1, syn0_1, 1)
-            blas.daxpy(n, weight2, syn1_2, 1, syn1_1, 1)
+            val weight1 = 1.0f * wc_1 / (wc_1 + wc_2)
+            val weight2 = 1.0f * wc_2 / (wc_1 + wc_2)
+            blas.sscal(n, weight1, syn0_1, 1)
+            blas.sscal(n, weight1, syn1_1, 1)
+            blas.saxpy(n, weight2, syn0_2, 1, syn0_1, 1)
+            blas.saxpy(n, weight2, syn1_2, 1, syn1_1, 1)
             (syn0_1, syn1_1, lwc_1 + lwc_2, wc_1 + wc_2)
         })
       syn0Global = aggSyn0
       syn1Global = aggSyn1
     }
     newSentences.unpersist()
 
-    val wordMap = new Array[(String, Array[Double])](vocabSize)
+    val wordMap = new Array[(String, Array[Float])](vocabSize)
     var i = 0
     while (i < vocabSize) {
       val word = bcVocab.value(i).word
-      val vector = new Array[Double](layer1Size)
+      val vector = new Array[Float](layer1Size)
       Array.copy(syn0Global, i * layer1Size, vector, 0, layer1Size)
       wordMap(i) = (word, vector)
       i += 1
```
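The `combOp` above merges two partial models as a convex combination weighted by the number of words each side processed:

$$\mathrm{syn} \leftarrow \frac{wc_1}{wc_1 + wc_2} \, \mathrm{syn}^{(1)} + \frac{wc_2}{wc_1 + wc_2} \, \mathrm{syn}^{(2)}$$

implemented in place with one `sscal` (scale the left operand) plus one `saxpy` (add the scaled right operand) per weight array.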
```diff
@@ -341,15 +340,15 @@
 /**
  * Word2Vec model
  */
-class Word2VecModel (private val model:RDD[(String, Array[Double])]) extends Serializable {
+class Word2VecModel(private val model: RDD[(String, Array[Float])]) extends Serializable {
 
-  private def cosineSimilarity(v1: Array[Double], v2: Array[Double]): Double = {
+  private def cosineSimilarity(v1: Array[Float], v2: Array[Float]): Double = {
     require(v1.length == v2.length, "Vectors should have the same length")
     val n = v1.length
-    val norm1 = blas.dnrm2(n, v1, 1)
-    val norm2 = blas.dnrm2(n, v2, 1)
+    val norm1 = blas.snrm2(n, v1, 1)
+    val norm2 = blas.snrm2(n, v2, 1)
     if (norm1 == 0 || norm2 == 0) return 0.0
-    blas.ddot(n, v1, 1, v2,1) / norm1 / norm2
+    blas.sdot(n, v1, 1, v2,1) / norm1 / norm2
   }
 
   /**
```
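For reference, `cosineSimilarity` above is the standard

$$\cos(v_1, v_2) = \frac{v_1 \cdot v_2}{\lVert v_1 \rVert \, \lVert v_2 \rVert}$$

with an early return of 0.0 when either norm is zero.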
```diff
@@ -360,9 +359,9 @@ class Word2VecModel (private val model:RDD[(String, Array[Double])]) extends Ser
   def transform(word: String): Vector = {
     val result = model.lookup(word)
     if (result.isEmpty) {
-      throw new IllegalStateException(s"${word} not in vocabulary")
+      throw new IllegalStateException(s"$word not in vocabulary")
     }
-    else Vectors.dense(result(0))
+    else Vectors.dense(result(0).map(_.toDouble))
   }
 
   /**
```
```diff
@@ -394,7 +393,7 @@ class Word2VecModel (private val model:RDD[(String, Array[Double])]) extends Ser
   def findSynonyms(vector: Vector, num: Int): Array[(String, Double)] = {
     require(num > 0, "Number of similar words should > 0")
     val topK = model.map { case(w, vec) =>
-      (cosineSimilarity(vector.toArray, vec), w) }
+      (cosineSimilarity(vector.toArray.map(_.toFloat), vec), w) }
       .sortByKey(ascending = false)
       .take(num + 1)
       .map(_.swap)
```
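A hypothetical call against the model API shown above (assuming `model` is a trained `Word2VecModel`); note that `take(num + 1)` fetches one extra candidate, presumably so the query word itself can be dropped from its own synonym list:

```scala
// Hypothetical usage; `model` is assumed to be a trained Word2VecModel.
val synonyms: Array[(String, Double)] = model.findSynonyms("china", 2)
synonyms.foreach { case (word, similarity) => println(s"$word: $similarity") }
```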
```diff
@@ -410,18 +409,16 @@ object Word2Vec{
    * @param input RDD of words
    * @param size vector dimension
    * @param startingAlpha initial learning rate
-   * @param window context words from [-window, window]
-   * @param minCount minimum frequncy to consider a vocabulary word
-   * @return Word2Vec model
-   */
+   * @param parallelism number of partitions to run Word2Vec (using a small number for accuracy)
+   * @param numIterations number of iterations, should be smaller than or equal to parallelism
+   * @return Word2Vec model
+   */
   def train[S <: Iterable[String]](
       input: RDD[S],
       size: Int,
       startingAlpha: Double,
-      window: Int,
-      minCount: Int,
       parallelism: Int = 1,
       numIterations:Int = 1): Word2VecModel = {
-    new Word2Vec(size,startingAlpha, window, minCount, parallelism, numIterations).fit[S](input)
+    new Word2Vec(size,startingAlpha, parallelism, numIterations).fit[S](input)
   }
 }
```
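Putting the public API in this file together, a hypothetical end-to-end run (the app setup and input path are placeholders, and 0.025 again borrows the C tool's default alpha):

```scala
// Hypothetical end-to-end sketch against this PR's API.
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD

val sc = new SparkContext("local[2]", "word2vec-demo")
// Assume one whitespace-tokenized sentence per line; the path is a placeholder.
val input: RDD[Seq[String]] = sc.textFile("data/corpus.txt").map(_.split(" ").toSeq)
val model = Word2Vec.train(input, size = 100, startingAlpha = 0.025)
model.findSynonyms("spark", 5).foreach { case (w, sim) => println(s"$w: $sim") }
sc.stop()
```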
mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala

```diff
@@ -19,10 +19,12 @@ package org.apache.spark.mllib.feature
 
 import org.scalatest.FunSuite
 
+import org.apache.spark.SparkContext._
 import org.apache.spark.mllib.util.LocalSparkContext
 
 class Word2VecSuite extends FunSuite with LocalSparkContext {
 
+  // TODO: add more tests
+
   test("Word2Vec") {
     val sentence = "a b " * 100 + "a c " * 10
     val localDoc = Seq(sentence, sentence)
```
```diff
@@ -33,28 +35,27 @@ class Word2VecSuite extends FunSuite with LocalSparkContext {
     val window = 2
     val minCount = 2
     val num = 2
-    val word = "a"
 
     val model = Word2Vec.train(doc, size, startingAlpha, window, minCount)
-    val synons = model.findSynonyms("a", 2)
-    assert(synons.length == num)
-    assert(synons(0)._1 == "b")
-    assert(synons(1)._1 == "c")
+    val syms = model.findSynonyms("a", 2)
+    assert(syms.length == num)
+    assert(syms(0)._1 == "b")
+    assert(syms(1)._1 == "c")
   }
 
   test("Word2VecModel") {
     val num = 2
     val localModel = Seq(
-      ("china" , Array(0.50, 0.50, 0.50, 0.50)),
-      ("japan" , Array(0.40, 0.50, 0.50, 0.50)),
-      ("taiwan", Array(0.60, 0.50, 0.50, 0.50)),
-      ("korea" , Array(0.45, 0.60, 0.60, 0.60))
+      ("china" , Array(0.50f, 0.50f, 0.50f, 0.50f)),
+      ("japan" , Array(0.40f, 0.50f, 0.50f, 0.50f)),
+      ("taiwan", Array(0.60f, 0.50f, 0.50f, 0.50f)),
+      ("korea" , Array(0.45f, 0.60f, 0.60f, 0.60f))
     )
     val model = new Word2VecModel(sc.parallelize(localModel, 2))
-    val synons = model.findSynonyms("china", num)
-    assert(synons.length == num)
-    assert(synons(0)._1 == "taiwan")
-    assert(synons(1)._1 == "japan")
+    val syms = model.findSynonyms("china", num)
+    assert(syms.length == num)
+    assert(syms(0)._1 == "taiwan")
+    assert(syms(1)._1 == "japan")
   }
 }
```
**Contributor:** Forcing `F2jBLAS` may be better because only level-1 operations are used. I will send you a PR on this.