[SPARK-12153][SPARK-7617][MLlib] add support of arbitrary length sentence and other tuning for Word2Vec #10152
Changes from all commits
@@ -76,6 +76,18 @@ class Word2Vec extends Serializable with Logging {
   private var numIterations = 1
   private var seed = Utils.random.nextLong()
   private var minCount = 5
+  private var maxSentenceLength = 1000
+
+  /**
+   * Sets the maximum length (in words) of each sentence in the input data.
+   * Any sentence longer than this threshold will be divided into chunks of
+   * up to `maxSentenceLength` size (default: 1000)
+   */
+  @Since("2.0.0")
+  def setMaxSentenceLength(maxSentenceLength: Int): this.type = {
Contributor: It is not clear from the doc what "sentence length" means, number of words or number of characters. We can either update the doc or change the param name to …

Contributor: The param name comes from the original Google implementation. Either option (or both) works, but I guess I'd be marginally more in favour of amending the first line of the doc to read …

Contributor: Sounds good.
+    this.maxSentenceLength = maxSentenceLength
+    this
+  }

   /**
    * Sets vector size (default: 100).
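For context, a minimal usage sketch of the new setter (hypothetical code: it assumes a live SparkContext `sc` and a whitespace-tokenized corpus file, neither of which is part of this diff):

import org.apache.spark.mllib.feature.Word2Vec
import org.apache.spark.rdd.RDD

// With this change, sentence boundaries in the input RDD are respected:
// previously all words were flattened into one stream and re-cut every
// MAX_SENTENCE_LENGTH words, regardless of where sentences ended.
val sentences: RDD[Seq[String]] = sc.textFile("corpus.txt") // hypothetical input path
  .map(_.split(" ").toSeq)

val model = new Word2Vec()
  .setVectorSize(100)
  .setMinCount(5)
  .setMaxSentenceLength(500) // sentences longer than 500 words are split into chunks
  .fit(sentences)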
@@ -146,7 +158,6 @@ class Word2Vec extends Serializable with Logging {
   private val EXP_TABLE_SIZE = 1000
   private val MAX_EXP = 6
   private val MAX_CODE_LENGTH = 40
-  private val MAX_SENTENCE_LENGTH = 1000

   /** context words from [-window, window] */
   private var window = 5
@@ -156,7 +167,9 @@ class Word2Vec extends Serializable with Logging {
   @transient private var vocab: Array[VocabWord] = null
   @transient private var vocabHash = mutable.HashMap.empty[String, Int]

-  private def learnVocab(words: RDD[String]): Unit = {
+  private def learnVocab[S <: Iterable[String]](dataset: RDD[S]): Unit = {
+    val words = dataset.flatMap(x => x)
+
     vocab = words.map(w => (w, 1))
       .reduceByKey(_ + _)
       .filter(_._2 >= minCount)
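As a plain-Scala illustration of what the vocabulary step above computes (toy data, no Spark required; groupBy stands in for reduceByKey):

// Count word frequencies across all sentences and keep only words that
// occur at least minCount times.
val minCount = 2
val dataset = Seq(Seq("spark", "mllib"), Seq("spark", "word2vec", "spark"))
val words = dataset.flatten // the flatMap(x => x) step
val vocabCounts = words.groupBy(identity)
  .map { case (w, occurrences) => (w, occurrences.size) }
  .filter { case (_, count) => count >= minCount }
// vocabCounts == Map("spark" -> 3)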
@@ -272,15 +285,14 @@ class Word2Vec extends Serializable with Logging {

   /**
    * Computes the vector representation of each word in vocabulary.
-   * @param dataset an RDD of words
+   * @param dataset an RDD of sentences,
+   *                each sentence is expressed as an iterable collection of words
Contributor: I think we should remove L286 and change …

Author: Good point. Will do.
   * @return a Word2VecModel
   */
  @Since("1.1.0")
  def fit[S <: Iterable[String]](dataset: RDD[S]): Word2VecModel = {

-    val words = dataset.flatMap(x => x)
-
-    learnVocab(words)
+    learnVocab(dataset)

    createBinaryTree()
@@ -289,25 +301,15 @@ class Word2Vec extends Serializable with Logging {
     val expTable = sc.broadcast(createExpTable())
     val bcVocab = sc.broadcast(vocab)
     val bcVocabHash = sc.broadcast(vocabHash)

-    val sentences: RDD[Array[Int]] = words.mapPartitions { iter =>
-      new Iterator[Array[Int]] {
-        def hasNext: Boolean = iter.hasNext
-
-        def next(): Array[Int] = {
-          val sentence = ArrayBuilder.make[Int]
-          var sentenceLength = 0
-          while (iter.hasNext && sentenceLength < MAX_SENTENCE_LENGTH) {
-            val word = bcVocabHash.value.get(iter.next())
-            word match {
-              case Some(w) =>
-                sentence += w
-                sentenceLength += 1
-              case None =>
-            }
-          }
-          sentence.result()
-        }
-      }
-    }
+    // each partition is a collection of sentences,
+    // each of which is translated into arrays of word indices
+    val sentences: RDD[Array[Int]] = dataset.mapPartitions { sentenceIter =>
Contributor: I think we can simply do:

val sentences: RDD[Array[Int]] = dataset.mapPartitions { iter =>
  new Iterator[Array[Int]] {
    def hasNext: Boolean = iter.hasNext
    def next(): Array[Int] = {
      val sentence = ArrayBuilder.make[Int]
      var sentenceLength = 0
      val wordIter = iter.next().iterator
      while (wordIter.hasNext && sentenceLength < MAX_SENTENCE_LENGTH) {
        val word = bcVocabHash.value.get(wordIter.next())
        word match {
          case Some(w) =>
            sentence += w
            sentenceLength += 1
          case None =>
        }
      }
      sentence.result()
    }
  }
}

Author: This is very close to my original version, except that it throws away all the words after MAX_SENTENCE_LENGTH. Also, are you suggesting keeping maxSentenceLength as a static config? My latest version still uses the remaining words of a sentence for training after cutting at maxSentenceLength. For example, a 2200-word sentence is used as three cut sentences, just like in the old version, except that the last (third) chunk from the cut is 200 words long, with no words padded in from the next sentence. This way we maximize the usage of our data while respecting both sentence boundaries and the sentence length restriction.

Contributor: Ah right, yes you're correct. We should chunk up the sentences. (I had copy-pasted the …)
+      // Each sentence will map to 0 or more Array[Int]
+      sentenceIter.flatMap { sentence =>
+        // Sentence of words, some of which map to a word index
+        val wordIndexes = sentence.flatMap(bcVocabHash.value.get)
+        // break wordIndexes into chunks of maxSentenceLength when longer
+        wordIndexes.grouped(maxSentenceLength).map(_.toArray)
+      }
+    }
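To make the chunking semantics agreed on above concrete, a small self-contained sketch (plain Scala collections, toy data):

// A 2200-word sentence cut at maxSentenceLength = 1000 yields chunks of
// 1000, 1000 and 200 words: the 200-word tail is still used for training,
// but no words are padded in from the next sentence.
val maxSentenceLength = 1000
val sentence = Seq.fill(2200)("word")
val chunks = sentence.grouped(maxSentenceLength).toSeq
println(chunks.map(_.length)) // List(1000, 1000, 200)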
@@ -477,15 +479,6 @@ class Word2VecModel private[spark] (
     this(Word2VecModel.buildWordIndex(model), Word2VecModel.buildWordVectors(model))
   }

-  private def cosineSimilarity(v1: Array[Float], v2: Array[Float]): Double = {
-    require(v1.length == v2.length, "Vectors should have the same length")
-    val n = v1.length
-    val norm1 = blas.snrm2(n, v1, 1)
-    val norm2 = blas.snrm2(n, v2, 1)
-    if (norm1 == 0 || norm2 == 0) return 0.0
-    blas.sdot(n, v1, 1, v2, 1) / norm1 / norm2
-  }

   override protected def formatVersion = "1.0"

   @Since("1.4.0")
@@ -542,6 +535,7 @@ class Word2VecModel private[spark] (
     // Need not divide with the norm of the given vector since it is constant.
     val cosVec = cosineVec.map(_.toDouble)
     var ind = 0
+    val vecNorm = blas.snrm2(vectorSize, fVector, 1)
     while (ind < numWords) {
       val norm = wordVecNorms(ind)
       if (norm == 0.0) {
@@ -551,12 +545,17 @@ class Word2VecModel private[spark] (
       }
       ind += 1
     }
-    wordList.zip(cosVec)
+    var topResults = wordList.zip(cosVec)
       .toSeq
-      .sortBy(- _._2)
+      .sortBy(-_._2)
       .take(num + 1)
       .tail
-      .toArray
+    if (vecNorm != 0.0f) {
+      topResults = topResults.map { case (word, cosVal) =>
+        (word, cosVal / vecNorm)
+      }
+    }
+    topResults.toArray
   }
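As I read this change, cosVec(ind) holds dot(fVector, w_ind) / ||w_ind||, so the extra division by vecNorm turns the scores into true cosine similarities. A toy numeric sketch with made-up values:

// cosine(v, w) = dot(v, w) / (||v|| * ||w||). The while loop above divides
// by each word's norm; the new code then divides the surviving top results
// by the query vector's norm.
val dot = 3.0      // dot(fVector, w), made-up value
val wordNorm = 2.0 // ||w||, made-up value
val vecNorm = 1.5  // ||fVector||, made-up value
val partial = dot / wordNorm   // what cosVec holds: 1.5
val cosine = partial / vecNorm // true cosine similarity: 1.0

The user-visible effect is that findSynonyms now returns scores in [-1, 1] (up to floating-point noise) rather than similarities left unnormalized by the query vector's norm.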
@@ -568,6 +567,7 @@ class Word2VecModel private[spark] (
       (word, wordVectors.slice(vectorSize * ind, vectorSize * ind + vectorSize))
     }
   }
+
 }

 @Since("1.4.0")
Contributor: I think we should target 1.6.1 here: @Since("1.6.1"). Overall I'd view this PR as a bugfix (though adding the parameter is a minor extra feature). I think we'd want to include this in branch-1.6, and possibly even think about backporting the core changes to branch-1.5.