modify according to feedback
Liquan Pei committed Aug 2, 2014
commit 2e92b5991ad8f3f73bbeab9a056f452c4b532b3c
146 changes: 86 additions & 60 deletions mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
@@ -1,33 +1,33 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.mllib.feature

import scala.util.{Random => Random}
import scala.util.Random
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable

import com.github.fommil.netlib.BLAS.{getInstance => blas}
Contributor: Force using F2jBLAS may be better because only level-1 operations are used. I will send you a PR on this.
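A minimal sketch of that suggestion, assuming the pure-Java F2jBLAS class that ships in the same com.github.fommil.netlib package:

import com.github.fommil.netlib.F2jBLAS

// Pin the pure-Java backend instead of the autodetected native instance;
// only level-1 routines (daxpy, ddot, dnrm2) are called in this file.
val blas = new F2jBLAS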


import org.apache.spark._
import org.apache.spark.annotation.Experimental
import org.apache.spark.Logging
import org.apache.spark.rdd._
import org.apache.spark.SparkContext._
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.HashPartitioner

/**
@@ -42,8 +42,27 @@ private case class VocabWord(
)

/**
* Vector representation of word
* :: Experimental ::
* Word2Vec creates vector representations of words in a text corpus.
* The algorithm first constructs a vocabulary from the corpus
* and then learns vector representations of words in the vocabulary.
* The vector representations can be used as features in
* natural language processing and machine learning algorithms.
*
* We use the skip-gram model in our implementation and hierarchical softmax
* to train the model.
*
* For the original C implementation, see https://code.google.com/p/word2vec/
* For research papers, see
* Efficient Estimation of Word Representations in Vector Space
* and
* Distributed Representations of Words and Phrases and their Compositionality
* @param size vector dimension
* @param startingAlpha initial learning rate
* @param window context words from [-window, window]
* @param minCount minimum frequency to consider a vocabulary word
*/
@Experimental
class Word2Vec(
Contributor: We need more docs here, for example, link to the C implementation and the original papers for word2vec.

Contributor: And briefly explain what it does.

val size: Int,
val startingAlpha: Double,
Contributor: Is word2vec sensitive to alpha? If not, we should try to expose fewer parameters to users.

Contributor (author): word2vec is sensitive to alpha. A larger alpha may generate meaningless results.

Contributor: Maybe we can suggest a reasonable default value in the doc.
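For reference, the original word2vec.c defaults to alpha = 0.025 for skip-gram (and window = 5, min_count = 5). A sketch of documenting defaults directly in the signature; the default arguments are an assumption, not part of this commit:

class Word2Vec(
    val size: Int,
    val startingAlpha: Double = 0.025, // default borrowed from word2vec.c
    val window: Int = 5,
    val minCount: Int = 5)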

@@ -64,11 +83,15 @@ class Word2Vec(
private var vocabHash = mutable.HashMap.empty[String, Int]
private var alpha = startingAlpha

private def learnVocab(dataset: RDD[String]) {
vocab = dataset.flatMap(line => line.split(" "))
.map(w => (w, 1))
private def learnVocab(words: RDD[String]) {
vocab = words.map(w => (w, 1))
.reduceByKey(_ + _)
.map(x => VocabWord(x._1, x._2, new Array[Int](MAX_CODE_LENGTH), new Array[Int](MAX_CODE_LENGTH), 0))
.map(x => VocabWord(
x._1,
x._2,
new Array[Int](MAX_CODE_LENGTH),
new Array[Int](MAX_CODE_LENGTH),
0))
.filter(_.cn >= minCount)
.collect()
.sortWith((a, b) => a.cn > b.cn)
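As an aside, a Spark-free toy run of the counting pipeline above (hypothetical input, minCount = 2):

val words = Seq("a", "b", "a", "c", "a", "b")
val vocab = words.groupBy(identity).mapValues(_.size).toSeq
  .filter { case (_, cn) => cn >= 2 } // drop words rarer than minCount
  .sortBy { case (_, cn) => -cn }     // descending frequency
// vocab == Seq(("a", 3), ("b", 2))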
@@ -172,15 +195,16 @@ class Word2Vec(
}

/**
* Computes the vector representation of each word in
* vocabulary
* @param dataset an RDD of strings
* Computes the vector representation of each word in vocabulary.
* @param dataset an RDD of sentences, each expressed as an Iterable of words
* @return a Word2VecModel
*/

def fit(dataset:RDD[String]): Word2VecModel = {
def fit[S <: Iterable[String]](dataset: RDD[S]): Word2VecModel = {

learnVocab(dataset)
val words = dataset.flatMap(x => x)

learnVocab(words)

createBinaryTree()

@@ -190,9 +214,10 @@ class Word2Vec(
val V = sc.broadcast(vocab)
Contributor: Is V the name used in the original implementation? If not, we should rename it to bcVocab.

val VHash = sc.broadcast(vocabHash)
Contributor: Same issue here: bcVocabHash.


val sentences = dataset.flatMap(line => line.split(" ")).mapPartitions {
val sentences = words.mapPartitions {
Contributor: sentences => sentences: RDD[Array[Int]], for code readability.

iter => { new Iterator[Array[Int]] {
def hasNext = iter.hasNext

def next = {
Contributor: Insert an empty line between method definitions.

var sentence = new ArrayBuffer[Int]
var sentenceLength = 0
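The iterator body is truncated by the diff. As a self-contained sketch, the same words-to-sentences chunking can be written with Iterator.grouped; the 1000-token cap is an assumption carried over from MAX_SENTENCE_LENGTH in the original C code:

// Chunk a stream of word indices into sentences of at most maxLen tokens.
def toSentences(iter: Iterator[Int], maxLen: Int = 1000): Iterator[Array[Int]] =
  iter.grouped(maxLen).map(_.toArray)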
@@ -215,7 +240,8 @@ class Word2Vec(
val newSentences = sentences.repartition(1).cache()
Contributor: Shall we make it configurable?

val temp = Array.fill[Double](vocabSize * layer1Size)((Random.nextDouble - 0.5) / layer1Size)
Contributor: Try to fix the seed to make the computation reproducible.
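A sketch folding in both suggestions above; numPartitions and seed are hypothetical parameters, not part of this commit:

val newSentences = sentences.repartition(numPartitions).cache()
val rng = new scala.util.Random(seed) // fixed seed => reproducible runs
val temp = Array.fill[Double](vocabSize * layer1Size)(
  (rng.nextDouble() - 0.5) / layer1Size)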

val (aggSyn0, _, _, _) =
// TODO: broadcast temp instead of serializing it directly or initialize the model in each executor
// TODO: broadcast temp instead of serializing it directly
// or initialize the model in each executor
newSentences.aggregate((temp.clone(), new Array[Double](vocabSize * layer1Size), 0, 0))(
seqOp = (c, v) => (c, v) match { case ((syn0, syn1, lastWordCount, wordCount), sentence) =>
var lwc = lastWordCount
Expand All @@ -241,7 +267,7 @@ class Word2Vec(
val lastWord = sentence(c)
val l1 = lastWord * layer1Size
val neu1e = new Array[Double](layer1Size)
//HS
// Hierarchical softmax
var d = 0
while (d < vocab(word).codeLen) {
val l2 = vocab(word).point(d) * layer1Size
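The hunk is truncated here. A hedged sketch of the remainder of the hierarchical-softmax step, following the original word2vec formulation; the explicit sigmoid stands in for the C implementation's precomputed expTable, and the names match the surrounding code:

def sigmoid(x: Double): Double = 1.0 / (1.0 + math.exp(-x))

// f: inner product of the input vector (syn0 at offset l1) and the
// Huffman-node vector (syn1 at offset l2); g: scaled gradient.
val f = blas.ddot(layer1Size, syn0, l1, 1, syn1, l2, 1)
val g = (1 - vocab(word).code(d) - sigmoid(f)) * alpha
blas.daxpy(layer1Size, g, syn1, l2, 1, neu1e, 0, 1) // accumulate input-side gradient
blas.daxpy(layer1Size, g, syn0, l1, 1, syn1, l2, 1) // update the node vector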
@@ -265,11 +291,12 @@ class Word2Vec(
}
(syn0, syn1, lwc, wc)
},
combOp = (c1, c2) => (c1, c2) match { case ((syn0_1, syn1_1, lwc_1, wc_1), (syn0_2, syn1_2, lwc_2, wc_2)) =>
val n = syn0_1.length
blas.daxpy(n, 1.0, syn0_2, 1, syn0_1, 1)
blas.daxpy(n, 1.0, syn1_2, 1, syn1_1, 1)
(syn0_1, syn0_2, lwc_1 + lwc_2, wc_1 + wc_2)
combOp = (c1, c2) => (c1, c2) match {
case ((syn0_1, syn1_1, lwc_1, wc_1), (syn0_2, syn1_2, lwc_2, wc_2)) =>
val n = syn0_1.length
blas.daxpy(n, 1.0, syn0_2, 1, syn0_1, 1)
Contributor: Should use weighted sum to handle imbalanced partitions.

blas.daxpy(n, 1.0, syn1_2, 1, syn1_1, 1)
(syn0_1, syn0_2, lwc_1 + lwc_2, wc_1 + wc_2)
Contributor: syn0_2 -> syn1_1

})
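Taking the two comments above together, a corrected combOp would return syn1_1 in the second slot, since returning syn0_2 drops the aggregated syn1; a count-weighted average, as suggested, could further scale each daxpy by its partition's share of the total word count. A sketch with the same names:

combOp = (c1, c2) => (c1, c2) match {
  case ((syn0_1, syn1_1, lwc_1, wc_1), (syn0_2, syn1_2, lwc_2, wc_2)) =>
    val n = syn0_1.length
    blas.daxpy(n, 1.0, syn0_2, 1, syn0_1, 1) // sum syn0 contributions
    blas.daxpy(n, 1.0, syn1_2, 1, syn1_1, 1) // sum syn1 contributions
    (syn0_1, syn1_1, lwc_1 + lwc_2, wc_1 + wc_2) // syn1_1, not syn0_2
}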

val wordMap = new Array[(String, Array[Double])](vocabSize)
@@ -281,19 +308,18 @@ class Word2Vec(
wordMap(i) = (word, vector)
i += 1
}
val modelRDD = sc.parallelize(wordMap, modelPartitionNum).partitionBy(new HashPartitioner(modelPartitionNum))
val modelRDD = sc.parallelize(wordMap, modelPartitionNum)
.partitionBy(new HashPartitioner(modelPartitionNum))
new Word2VecModel(modelRDD)
Contributor: Please call newSentences.unpersist() before exit.
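A sketch of the requested cleanup, with placement assumed at the end of fit:

newSentences.unpersist() // release the cached training data
new Word2VecModel(modelRDD)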

}
}

/**
* Word2Vec model
*/
class Word2VecModel (val _model:RDD[(String, Array[Double])]) extends Serializable {

val model = _model
class Word2VecModel(private val model: RDD[(String, Array[Double])]) extends Serializable {

private def distance(v1: Array[Double], v2: Array[Double]): Double = {
private def cosineSimilarity(v1: Array[Double], v2: Array[Double]): Double = {
require(v1.length == v2.length, "Vectors should have the same length")
val n = v1.length
val norm1 = blas.dnrm2(n, v1, 1)
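The method body is truncated by the diff. A completed sketch, assuming the hidden lines compute the second norm and the dot product the same way:

private def cosineSimilarity(v1: Array[Double], v2: Array[Double]): Double = {
  require(v1.length == v2.length, "Vectors should have the same length")
  val n = v1.length
  val norm1 = blas.dnrm2(n, v1, 1)
  val norm2 = blas.dnrm2(n, v2, 1)
  if (norm1 == 0 || norm2 == 0) return 0.0
  blas.ddot(n, v1, 1, v2, 1) / norm1 / norm2
}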
@@ -307,20 +333,20 @@ class Word2VecModel (val _model:RDD[(String, Array[Double])]) extends Serializab
* @param word a word
* @return vector representation of word
*/

def transform(word: String): Array[Double] = {
def transform(word: String): Vector = {
val result = model.lookup(word)
if (result.isEmpty) Array[Double]()
else result(0)
if (result.isEmpty) {
throw new IllegalStateException(s"${word} not in vocabulary")
}
else Vectors.dense(result(0))
}

/**
* Transforms an RDD to its vector representation
* @param dataset an RDD of words
* @return RDD of vector representation
*/

def transform(dataset: RDD[String]): RDD[Array[Double]] = {
def transform(dataset: RDD[String]): RDD[Vector] = {
dataset.map(word => transform(word))
}

@@ -332,44 +358,44 @@ class Word2VecModel (val _model:RDD[(String, Array[Double])]) extends Serializab
*/
def findSynonyms(word: String, num: Int): Array[(String, Double)] = {
val vector = transform(word)
if (vector.isEmpty) Array[(String, Double)]()
else findSynonyms(vector,num)
findSynonyms(vector, num)
}

/**
* Find synonyms of the vector representation of a word
* @param vector vector representation of a word
* @param num number of synonyms to find
* @return array of (word, similarity)
* @return array of (word, cosineSimilarity)
*/
def findSynonyms(vector: Array[Double], num: Int): Array[(String, Double)] = {
def findSynonyms(vector: Vector, num: Int): Array[(String, Double)] = {
require(num > 0, "Number of similar words should be > 0")
val topK = model.map(
{case(w, vec) => (distance(vector, vec), w)})
val topK = model.map { case (w, vec) =>
(cosineSimilarity(vector.toArray, vec), w) }
.sortByKey(ascending = false)
.take(num + 1)
.map({case (dist, w) => (w, dist)}).drop(1)
.map(_.swap)
.tail

topK
}
}

object Word2Vec extends Serializable with Logging {
object Word2Vec {
/**
* Train Word2Vec model
* @param input RDD of words
* @param size vectoer dimension
* @param size vector dimension
* @param startingAlpha initial learning rate
* @param window context words from [-window, window]
* @param minCount minimum frequency to consider a vocabulary word
* @return Word2Vec model
*/
def train(
input: RDD[String],
def train[S <: Iterable[String]](
input: RDD[S],
size: Int,
startingAlpha: Double,
window: Int,
minCount: Int): Word2VecModel = {
new Word2Vec(size,startingAlpha, window, minCount).fit(input)
new Word2Vec(size, startingAlpha, window, minCount).fit[S](input)
}
}
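An end-to-end usage sketch of the API in this file; the corpus path and parameter values are hypothetical, and an existing SparkContext sc is assumed:

import org.apache.spark.mllib.feature.Word2Vec
import org.apache.spark.rdd.RDD

val input: RDD[Seq[String]] = sc.textFile("corpus.txt").map(_.split(" ").toSeq)
val model = Word2Vec.train(input, size = 100, startingAlpha = 0.025,
  window = 5, minCount = 5)

val vec = model.transform("spark")             // Vector for a single word
val synonyms = model.findSynonyms("spark", 10) // Array[(word, cosineSimilarity)]
synonyms.foreach { case (w, sim) => println(s"$w $sim") }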
@@ -1,24 +1,24 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.mllib.feature

import org.scalatest.FunSuite

import org.apache.spark.SparkContext._
Contributor: Insert an empty line to separate 3rd-party imports from Spark imports.

import org.apache.spark.mllib.util.LocalSparkContext
