161 changes: 79 additions & 82 deletions mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
@@ -17,20 +17,19 @@

package org.apache.spark.mllib.feature

import scala.util.Random
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import scala.util.Random

import com.github.fommil.netlib.BLAS.{getInstance => blas}
Contributor: Force using F2jBLAS may be better because only level-1 operations are used. I will send you a PR on this.


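A minimal sketch of what that suggestion could look like (assuming netlib-java's pure-Java F2jBLAS class; illustrative only, not part of this PR):

import com.github.fommil.netlib.F2jBLAS

// Inside class Word2Vec: bind to the pure-Java BLAS instead of the
// native-backed singleton. Only level-1 (vector-vector) routines are used
// during training, so the JNI overhead of a native BLAS buys little here.
private val blas = new F2jBLAS
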
import org.apache.spark.annotation.Experimental
import org.apache.spark.Logging
import org.apache.spark.rdd._
import org.apache.spark.{HashPartitioner, Logging}
import org.apache.spark.SparkContext._
import org.apache.spark.annotation.Experimental
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.HashPartitioner
import org.apache.spark.storage.StorageLevel
import org.apache.spark.mllib.rdd.RDDFunctions._
import org.apache.spark.rdd._
import org.apache.spark.storage.StorageLevel

/**
* Entry in vocabulary
*/
@@ -52,7 +51,7 @@ private case class VocabWord(
*
* We used skip-gram model in our implementation and hierarchical softmax
* method to train the model. The variable names in the implementation
* mathes the original C implementation.
* matches the original C implementation.
*
* For original C implementation, see https://code.google.com/p/word2vec/
* For research papers, see
@@ -61,34 +60,41 @@ private case class VocabWord(
* Distributed Representations of Words and Phrases and their Compositionality.
* @param size vector dimension
* @param startingAlpha initial learning rate
* @param window context words from [-window, window]
* @param minCount minimum frequncy to consider a vocabulary word
* @param parallelisum number of partitions to run Word2Vec
* @param parallelism number of partitions to run Word2Vec (using a small number for accuracy)
* @param numIterations number of iterations to run, should be smaller than or equal to parallelism
*/
@Experimental
class Word2Vec(
Contributor: We need more docs here, for example, link to the C implementation and the original papers for word2vec.

Contributor: and briefly explain what it does.

val size: Int,
val startingAlpha: Double,
Contributor: Is word2vec sensitive to alpha? If not, we should try to expose fewer parameters to users.

Contributor (author): word2vec is sensitive to alpha. A larger alpha may produce meaningless results.

Contributor: Maybe we can suggest a reasonable default value in the doc.

(A hedged sketch of such a default appears after the single-thread constructor below.)

val window: Int,
val minCount: Int,
val parallelism:Int = 1,
val numIterations:Int = 1)
extends Serializable with Logging {

val parallelism: Int,
val numIterations: Int) extends Serializable with Logging {

/**
* Word2Vec with a single thread.
*/
def this(size: Int, startingAlpha: Int) = this(size, startingAlpha, 1, 1)
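
// Illustrative sketch, not part of this PR: one way to document a reasonable
// default, as suggested in the review thread above, is a convenience
// constructor. 0.025 is the skip-gram default learning rate in the original C
// implementation; a much larger alpha tends to produce meaningless vectors.
def this(size: Int) = this(size, 0.025, 1, 1)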

private val EXP_TABLE_SIZE = 1000
private val MAX_EXP = 6
private val MAX_CODE_LENGTH = 40
private val MAX_SENTENCE_LENGTH = 1000
private val layer1Size = size
private val modelPartitionNum = 100


/** context words from [-window, window] */
private val window = 5

/** minimum frequency to consider a vocabulary word */
private val minCount = 5

private var trainWordsCount = 0
private var vocabSize = 0
private var vocab: Array[VocabWord] = null
private var vocabHash = mutable.HashMap.empty[String, Int]
private var alpha = startingAlpha

private def learnVocab(words:RDD[String]){
private def learnVocab(words:RDD[String]): Unit = {
vocab = words.map(w => (w, 1))
.reduceByKey(_ + _)
.map(x => VocabWord(
@@ -99,7 +105,7 @@ class Word2Vec(
0))
.filter(_.cn >= minCount)
.collect()
.sortWith((a, b)=> a.cn > b.cn)
.sortWith((a, b) => a.cn > b.cn)

vocabSize = vocab.length
var a = 0
@@ -111,22 +117,18 @@ class Word2Vec(
logInfo("trainWordsCount = " + trainWordsCount)
}

private def learnVocabPerPartition(words:RDD[String]) {

}

private def createExpTable(): Array[Double] = {
val expTable = new Array[Double](EXP_TABLE_SIZE)
private def createExpTable(): Array[Float] = {
val expTable = new Array[Float](EXP_TABLE_SIZE)
var i = 0
while (i < EXP_TABLE_SIZE) {
val tmp = math.exp((2.0 * i / EXP_TABLE_SIZE - 1.0) * MAX_EXP)
expTable(i) = tmp / (tmp + 1)
expTable(i) = (tmp / (tmp + 1.0)).toFloat
i += 1
}
expTable
}

private def createBinaryTree() {
private def createBinaryTree(): Unit = {
val count = new Array[Long](vocabSize * 2 + 1)
val binary = new Array[Int](vocabSize * 2 + 1)
val parentNode = new Array[Int](vocabSize * 2 + 1)
@@ -208,8 +210,7 @@ class Word2Vec(
* @param dataset an RDD of words
* @return a Word2VecModel
*/

def fit[S <: Iterable[String]](dataset:RDD[S]): Word2VecModel = {
def fit[S <: Iterable[String]](dataset: RDD[S]): Word2VecModel = {

val words = dataset.flatMap(x => x)

@@ -223,39 +224,37 @@ class Word2Vec(
val bcVocab = sc.broadcast(vocab)
val bcVocabHash = sc.broadcast(vocabHash)

val sentences: RDD[Array[Int]] = words.mapPartitions {
iter => { new Iterator[Array[Int]] {
def hasNext = iter.hasNext

def next = {
var sentence = new ArrayBuffer[Int]
var sentenceLength = 0
while (iter.hasNext && sentenceLength < MAX_SENTENCE_LENGTH) {
val word = bcVocabHash.value.get(iter.next)
word match {
case Some(w) => {
sentence += w
sentenceLength += 1
}
case None =>
}
val sentences: RDD[Array[Int]] = words.mapPartitions { iter =>
new Iterator[Array[Int]] {
def hasNext: Boolean = iter.hasNext

def next(): Array[Int] = {
var sentence = new ArrayBuffer[Int]
var sentenceLength = 0
while (iter.hasNext && sentenceLength < MAX_SENTENCE_LENGTH) {
val word = bcVocabHash.value.get(iter.next())
word match {
case Some(w) =>
sentence += w
sentenceLength += 1
case None =>
}
sentence.toArray
}
sentence.toArray
}
}
}

val newSentences = sentences.repartition(parallelism).cache()
var syn0Global
= Array.fill[Double](vocabSize * layer1Size)((Random.nextDouble - 0.5) / layer1Size)
var syn1Global = new Array[Double](vocabSize * layer1Size)
var syn0Global =
Array.fill[Float](vocabSize * layer1Size)((Random.nextFloat() - 0.5f) / layer1Size)
var syn1Global = new Array[Float](vocabSize * layer1Size)

for(iter <- 1 to numIterations) {
val (aggSyn0, aggSyn1, _, _) =
// TODO: broadcast temp instead of serializing it directly
// TODO: broadcast temp instead of serializing it directly
// or initialize the model in each executor
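// A possible shape for the TODO above (illustrative sketch, not part of this PR):
//   val bcSyn0 = sc.broadcast(syn0Global)
//   val bcSyn1 = sc.broadcast(syn1Global)
// seqOp would then start from bcSyn0.value / bcSyn1.value, so the weights are
// shipped to each executor once per iteration rather than serialized into
// every task closure.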
newSentences.treeAggregate((syn0Global.clone(), syn1Global.clone(), 0, 0))(
newSentences.treeAggregate((syn0Global, syn1Global, 0, 0))(
seqOp = (c, v) => (c, v) match {
case ((syn0, syn1, lastWordCount, wordCount), sentence) =>
var lwc = lastWordCount
@@ -280,23 +279,23 @@ class Word2Vec(
if (c >= 0 && c < sentence.size) {
val lastWord = sentence(c)
val l1 = lastWord * layer1Size
val neu1e = new Array[Double](layer1Size)
val neu1e = new Array[Float](layer1Size)
// Hierarchical softmax
var d = 0
while (d < bcVocab.value(word).codeLen) {
val l2 = bcVocab.value(word).point(d) * layer1Size
// Propagate hidden -> output
var f = blas.ddot(layer1Size, syn0, l1, 1, syn1, l2, 1)
var f = blas.sdot(layer1Size, syn0, l1, 1, syn1, l2, 1)
if (f > -MAX_EXP && f < MAX_EXP) {
val ind = ((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2.0)).toInt
f = expTable.value(ind)
val g = (1 - bcVocab.value(word).code(d) - f) * alpha
blas.daxpy(layer1Size, g, syn1, l2, 1, neu1e, 0, 1)
blas.daxpy(layer1Size, g, syn0, l1, 1, syn1, l2, 1)
val g = ((1 - bcVocab.value(word).code(d) - f) * alpha).toFloat
blas.saxpy(layer1Size, g, syn1, l2, 1, neu1e, 0, 1)
blas.saxpy(layer1Size, g, syn0, l1, 1, syn1, l2, 1)
}
d += 1
}
blas.daxpy(layer1Size, 1.0, neu1e, 0, 1, syn0, l1, 1)
blas.saxpy(layer1Size, 1.0f, neu1e, 0, 1, syn0, l1, 1)
}
}
a += 1
@@ -308,24 +307,24 @@ class Word2Vec(
combOp = (c1, c2) => (c1, c2) match {
case ((syn0_1, syn1_1, lwc_1, wc_1), (syn0_2, syn1_2, lwc_2, wc_2)) =>
val n = syn0_1.length
val weight1 = 1.0 * wc_1 / (wc_1 + wc_2)
val weight2 = 1.0 * wc_2 / (wc_1 + wc_2)
blas.dscal(n, weight1, syn0_1, 1)
blas.dscal(n, weight1, syn1_1, 1)
blas.daxpy(n, weight2, syn0_2, 1, syn0_1, 1)
blas.daxpy(n, weight2, syn1_2, 1, syn1_1, 1)
val weight1 = 1.0f * wc_1 / (wc_1 + wc_2)
val weight2 = 1.0f * wc_2 / (wc_1 + wc_2)
blas.sscal(n, weight1, syn0_1, 1)
blas.sscal(n, weight1, syn1_1, 1)
blas.saxpy(n, weight2, syn0_2, 1, syn0_1, 1)
blas.saxpy(n, weight2, syn1_2, 1, syn1_1, 1)
(syn0_1, syn1_1, lwc_1 + lwc_2, wc_1 + wc_2)
})
syn0Global = aggSyn0
syn1Global = aggSyn1
}
newSentences.unpersist()

val wordMap = new Array[(String, Array[Double])](vocabSize)
val wordMap = new Array[(String, Array[Float])](vocabSize)
var i = 0
while (i < vocabSize) {
val word = bcVocab.value(i).word
val vector = new Array[Double](layer1Size)
val vector = new Array[Float](layer1Size)
Array.copy(syn0Global, i * layer1Size, vector, 0, layer1Size)
wordMap(i) = (word, vector)
i += 1
@@ -341,15 +340,15 @@ class Word2Vec(
/**
* Word2Vec model
*/
class Word2VecModel (private val model:RDD[(String, Array[Double])]) extends Serializable {
class Word2VecModel(private val model: RDD[(String, Array[Float])]) extends Serializable {

private def cosineSimilarity(v1: Array[Double], v2: Array[Double]): Double = {
private def cosineSimilarity(v1: Array[Float], v2: Array[Float]): Double = {
require(v1.length == v2.length, "Vectors should have the same length")
val n = v1.length
val norm1 = blas.dnrm2(n, v1, 1)
val norm2 = blas.dnrm2(n, v2, 1)
val norm1 = blas.snrm2(n, v1, 1)
val norm2 = blas.snrm2(n, v2, 1)
if (norm1 == 0 || norm2 == 0) return 0.0
blas.ddot(n, v1, 1, v2,1) / norm1 / norm2
blas.sdot(n, v1, 1, v2,1) / norm1 / norm2
}

/**
@@ -360,9 +359,9 @@ class Word2VecModel (private val model:RDD[(String, Array[Double])]) extends Ser
def transform(word: String): Vector = {
val result = model.lookup(word)
if (result.isEmpty) {
throw new IllegalStateException(s"${word} not in vocabulary")
throw new IllegalStateException(s"$word not in vocabulary")
}
else Vectors.dense(result(0))
else Vectors.dense(result(0).map(_.toDouble))
}

/**
@@ -394,7 +393,7 @@ class Word2VecModel (private val model:RDD[(String, Array[Double])]) extends Ser
def findSynonyms(vector: Vector, num: Int): Array[(String, Double)] = {
require(num > 0, "Number of similar words should > 0")
val topK = model.map { case(w, vec) =>
(cosineSimilarity(vector.toArray, vec), w) }
(cosineSimilarity(vector.toArray.map(_.toFloat), vec), w) }
.sortByKey(ascending = false)
.take(num + 1)
.map(_.swap)
@@ -410,18 +409,16 @@ object Word2Vec{
* @param input RDD of words
* @param size vector dimension
* @param startingAlpha initial learning rate
* @param window context words from [-window, window]
* @param minCount minimum frequncy to consider a vocabulary word
* @return Word2Vec model
*/
* @param parallelism number of partitions to run Word2Vec (using a small number for accuracy)
* @param numIterations number of iterations, should be smaller than or equal to parallelism
* @return Word2Vec model
*/
def train[S <: Iterable[String]](
input: RDD[S],
size: Int,
startingAlpha: Double,
window: Int,
minCount: Int,
parallelism: Int = 1,
numIterations:Int = 1): Word2VecModel = {
new Word2Vec(size,startingAlpha, window, minCount, parallelism, numIterations).fit[S](input)
new Word2Vec(size,startingAlpha, parallelism, numIterations).fit[S](input)
}
}
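
For reference, a minimal usage sketch of the API as changed in this diff; the SparkContext setup, input path, and hyperparameter values below are illustrative assumptions:

import org.apache.spark.SparkContext
import org.apache.spark.mllib.feature.{Word2Vec, Word2VecModel}

// Build a small local context and a corpus of tokenized sentences.
val sc = new SparkContext("local[4]", "word2vec-example")
val corpus = sc.textFile("data/sentences.txt").map(_.split(" ").toSeq)

// Train with the new (size, startingAlpha, parallelism, numIterations) parameters.
val model: Word2VecModel = Word2Vec.train(corpus, size = 100, startingAlpha = 0.025,
  parallelism = 1, numIterations = 1)

// Query the model: returns (word, cosine similarity) pairs.
val synonyms = model.findSynonyms("spark", 5)
synonyms.foreach { case (word, sim) => println(s"$word -> $sim") }
sc.stop()
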
Word2VecSuite.scala
@@ -19,10 +19,12 @@ package org.apache.spark.mllib.feature

import org.scalatest.FunSuite

import org.apache.spark.SparkContext._
import org.apache.spark.mllib.util.LocalSparkContext

class Word2VecSuite extends FunSuite with LocalSparkContext {

// TODO: add more tests

test("Word2Vec") {
val sentence = "a b " * 100 + "a c " * 10
val localDoc = Seq(sentence, sentence)
@@ -33,28 +35,27 @@ class Word2VecSuite extends FunSuite with LocalSparkContext {
val window = 2
val minCount = 2
val num = 2
val word = "a"

val model = Word2Vec.train(doc, size, startingAlpha, window, minCount)
val synons = model.findSynonyms("a", 2)
assert(synons.length == num)
assert(synons(0)._1 == "b")
assert(synons(1)._1 == "c")
val syms = model.findSynonyms("a", 2)
assert(syms.length == num)
assert(syms(0)._1 == "b")
assert(syms(1)._1 == "c")
}


test("Word2VecModel") {
val num = 2
val localModel = Seq(
("china" , Array(0.50, 0.50, 0.50, 0.50)),
("japan" , Array(0.40, 0.50, 0.50, 0.50)),
("taiwan", Array(0.60, 0.50, 0.50, 0.50)),
("korea" , Array(0.45, 0.60, 0.60, 0.60))
("china" , Array(0.50f, 0.50f, 0.50f, 0.50f)),
Contributor: " ," -> "," (remove the extra spaces before the commas)

Contributor (author): Fixed

("japan" , Array(0.40f, 0.50f, 0.50f, 0.50f)),
("taiwan", Array(0.60f, 0.50f, 0.50f, 0.50f)),
("korea" , Array(0.45f, 0.60f, 0.60f, 0.60f))
)
val model = new Word2VecModel(sc.parallelize(localModel, 2))
val synons = model.findSynonyms("china", num)
assert(synons.length == num)
assert(synons(0)._1 == "taiwan")
assert(synons(1)._1 == "japan")
val syms = model.findSynonyms("china", num)
assert(syms.length == num)
assert(syms(0)._1 == "taiwan")
assert(syms(1)._1 == "japan")
}
}