Add comments, minor fixes
Liquan Pei committed Aug 1, 2014
commit 0aafb1b02a19fe4f1689543baf1882a49a7ff11a
69 changes: 46 additions & 23 deletions mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
@@ -31,6 +31,9 @@ import org.apache.spark.SparkContext._
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.HashPartitioner

/**
* Entry in vocabulary
*/
private case class VocabWord(
var word: String,
var cn: Int,
@@ -39,6 +42,9 @@ private case class VocabWord(
var codeLen:Int
)
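
For reference, the fields appear to mirror the vocab_word struct in the reference C implementation. A hedged annotation sketch only; the comments are an assumption, not part of this commit, and the fields hidden by the collapsed hunk above are indicated rather than reconstructed:

    private case class VocabWord(
      var word: String,   // the token itself
      var cn: Int,        // occurrence count of the token in the training corpus
      // ... fields collapsed in the diff view above ...
      var codeLen: Int    // length of the token's Huffman code
    )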

/**
* Vector representation of word
*/
class Word2Vec(
Contributor: We need more docs here, for example, link to the C implementation and the original papers for word2vec.

Contributor: and briefly explain what it does.

val size: Int,
val startingAlpha: Double,
Contributor: Is word2vec sensitive to alpha? If not, we should try to expose fewer parameters to users.

Contributor (Author): word2vec is sensitive to alpha. A larger alpha may generate meaningless results.

Contributor: Maybe we can suggest a reasonable default value in the doc.
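
Taking the doc and default-value comments above together, a hedged sketch of what the class-level ScalaDoc could look like; the 0.025 default is the skip-gram default of the reference C tool and is offered here as an assumption, not something settled in this thread:

    /**
     * Word2Vec computes distributed vector representations of words using the
     * skip-gram model with hierarchical softmax, following Mikolov et al.,
     * "Efficient Estimation of Word Representations in Vector Space" and
     * "Distributed Representations of Words and Phrases and their Compositionality",
     * and the reference C implementation at https://code.google.com/p/word2vec/.
     *
     * @param size dimension of the word vectors
     * @param startingAlpha initial learning rate (the C tool defaults to 0.025 for skip-gram)
     * @param window context window size, i.e. words in [-window, window]
     * @param minCount minimum word frequency for a word to be kept in the vocabulary
     */
    class Word2Vec(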

@@ -51,7 +57,8 @@ class Word2Vec(
private val MAX_CODE_LENGTH = 40
private val MAX_SENTENCE_LENGTH = 1000
private val layer1Size = size

private val modelPartitionNum = 100

private var trainWordsCount = 0
private var vocabSize = 0
private var vocab: Array[VocabWord] = null
@@ -169,6 +176,7 @@ class Word2Vec(
* Computes the vector representation of each word in
* vocabulary
Contributor: move "vocabulary" to the previous line and add a period.

* @param dataset an RDD of strings
Contributor: need more information about what each record should be, a word, a sentence, or a paragraph?

* @return a Word2VecModel
*/

def fit(dataset:RDD[String]): Word2VecModel = {
@@ -274,11 +282,14 @@ class Word2Vec(
wordMap(i) = (word, vector)
i += 1
}
val modelRDD = sc.parallelize(wordMap,100).partitionBy(new HashPartitioner(100))
val modelRDD = sc.parallelize(wordMap, modelPartitionNum).partitionBy(new HashPartitioner(modelPartitionNum))
Contributor: line too wide

new Word2VecModel(modelRDD)
Contributor: please call newSentences.unpersist() before exit

}
}
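
Putting the "line too wide" and unpersist comments together, the tail of fit() might end up roughly like this (a sketch only; newSentences is the cached sentence RDD the reviewer refers to, built earlier in fit):

        val modelRDD = sc
          .parallelize(wordMap, modelPartitionNum)
          .partitionBy(new HashPartitioner(modelPartitionNum))
        newSentences.unpersist()  // release the cached training data before returning
        new Word2VecModel(modelRDD)
      }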

/**
* Word2Vec model
*/
class Word2VecModel (val _model:RDD[(String, Array[Double])]) extends Serializable {
Contributor: _model is not necessary. Shall we change the constructor to

    class Word2VecModel private (private val model: RDD[(String, Array[Double])]) {

Serializable may be unnecessary.


val model = _model
@@ -292,22 +303,46 @@ class Word2VecModel (val _model:RDD[(String, Array[Double])]) extends Serializable {
blas.ddot(n, v1, 1, v2,1) / norm1 / norm2
}
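
The hunk above only shows the tail of the similarity helper. Purely as an illustration of what a netlib-BLAS cosine similarity ending in that line looks like (the signature and the zero-norm guard are assumptions, not the committed code; blas is the BLAS binding the shown line already uses):

    private def cosineSimilarity(v1: Array[Double], v2: Array[Double]): Double = {
      require(v1.length == v2.length, "Vectors should have the same length")
      val n = v1.length
      val norm1 = blas.dnrm2(n, v1, 1)
      val norm2 = blas.dnrm2(n, v2, 1)
      if (norm1 == 0 || norm2 == 0) return 0.0
      blas.ddot(n, v1, 1, v2, 1) / norm1 / norm2
    }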

/**
* Transforms a word to its vector representation
* @param word a word
* @return vector representation of word
*/

Contributor: remove empty line

def transform(word: String): Array[Double] = {
Contributor: Shall we use mllib.linalg.Vector as the return type?

val result = model.lookup(word)
if (result.isEmpty) Array[Double]()
Contributor: should create an empty vector of the same size or throw an exception

else result(0)
}

/**
* Transforms an RDD to its vector representation
* @param dataset an RDD of words
* @return RDD of vector representation
*/

def transform(dataset: RDD[String]): RDD[Array[Double]] = {
Contributor: The return type should be RDD[Vector].

dataset.map(word => transform(word))
}
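
A hedged sketch of how the two transform methods might look after the return-type and missing-word comments above; Vectors.dense (from org.apache.spark.mllib.linalg) and the IllegalStateException choice are assumptions, not decisions made in this thread:

    import org.apache.spark.mllib.linalg.{Vector, Vectors}  // at the top of the file

    def transform(word: String): Vector = {
      val result = model.lookup(word)
      if (result.isEmpty) {
        throw new IllegalStateException(s"$word not in vocabulary")
      }
      Vectors.dense(result(0))
    }

    def transform(dataset: RDD[String]): RDD[Vector] = {
      // mirrors the structure in the diff; only the element type changes
      dataset.map(word => transform(word))
    }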

/**
* Find synonyms of a word
* @param word a word
* @param num number of synonyms to find
* @return array of (word, similarity)
*/
def findSynonyms(word: String, num: Int): Array[(String, Double)] = {
val vector = transform(word)
if (vector.isEmpty) Array[(String, Double)]()
else findSynonyms(vector,num)
}

/**
* Find synonyms of the vector representation of a word
* @param vector vector representation of a word
* @param num number of synonyms to find
* @return array of (word, similarity)
*/
def findSynonyms(vector: Array[Double], num: Int): Array[(String, Double)] = {
require(num > 0, "Number of similar words should be > 0")
val topK = model.map(
@@ -321,6 +356,15 @@ class Word2VecModel (val _model:RDD[(String, Array[Double])]) extends Serializable {
}
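
The body of findSynonyms(vector, num) is collapsed above. Purely to illustrate the intent (not the committed code), a cosine-similarity top-K over the model RDD could be written as:

    def findSynonyms(vector: Array[Double], num: Int): Array[(String, Double)] = {
      require(num > 0, "Number of similar words should be > 0")
      model
        .map { case (word, vec) => (word, cosineSimilarity(vector, vec)) }
        .top(num)(Ordering.by[(String, Double), Double](_._2))  // the query word itself may appear in the result
    }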

object Word2Vec extends Serializable with Logging {
Contributor: Serializable and Logging are not used.

/**
* Train Word2Vec model
* @param input RDD of words
* @param size vector dimension
* @param startingAlpha initial learning rate
* @param window context words from [-window, window]
* @param minCount minimum frequency to consider a vocabulary word
* @return Word2Vec model
*/
def train(
input: RDD[String],
size: Int,
@@ -329,25 +373,4 @@ object Word2Vec extends Serializable with Logging {
minCount: Int): Word2VecModel = {
new Word2Vec(size,startingAlpha, window, minCount).fit(input)
}

def main(args: Array[String]) {
if (args.length < 6) {
println("Usage: word2vec input size startingAlpha window minCount num")
sys.exit(1)
}
val conf = new SparkConf()
.setAppName("word2vec")

val sc = new SparkContext(conf)
val input = sc.textFile(args(0))
val size = args(1).toInt
val startingAlpha = args(2).toDouble
val window = args(3).toInt
val minCount = args(4).toInt
val num = args(5).toInt
val model = train(input, size, startingAlpha, window, minCount)
val vec = model.findSynonyms("china", num)
for((w, dist) <- vec) logInfo(w.toString + " " + dist.toString)
sc.stop()
}
}
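
A minimal usage sketch of the API documented above, assuming an existing SparkContext sc and a whitespace-tokenized corpus at the hypothetical path data/corpus.txt (the parameter values are illustrative, not recommendations):

    val input = sc.textFile("data/corpus.txt")
    val model = Word2Vec.train(input, size = 100, startingAlpha = 0.025, window = 5, minCount = 5)
    val synonyms = model.findSynonyms("china", 10)
    for ((word, similarity) <- synonyms) println(s"$word $similarity")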