Closed

Changes from all commits (24 commits)
dab8425
add support of arbitrary length sentence by using the nature represen…
Dec 5, 2015
05d1eae
added more javadoc to make changes more clear for usage
Dec 5, 2015
c226e9b
removed distance functions
Dec 12, 2015
bd67405
fixed a wrong placement of statement and formatted more
Dec 12, 2015
7d4e9dd
addressed comments
Dec 12, 2015
d8d8b0b
make the normalization default behavior
Dec 16, 2015
6f9e9f0
adjust comments as advised.
Dec 16, 2015
f7c9296
adjust comments for passing lint constraint.
Dec 16, 2015
3309913
removed unnecessary getters
Dec 16, 2015
ced05d1
add since annotation according to suggestion from @MLnick
Dec 19, 2015
d6ae270
modify comments according to suggestion from @MLnick
Dec 22, 2015
214d0d9
change since tag to 2.0.0 per request
Dec 23, 2015
909dbbd
Merge pull request #1 from apache/master
ygcao Jan 18, 2016
76e8266
handle potential divide by zero, and adjust test case to reflect our …
Jan 18, 2016
141d7a2
adjust expected value for pySpark's test cases about findSynonyms
Jan 18, 2016
5aaff6e
adjust output formatting
Jan 20, 2016
71089c4
Merge pull request #2 from apache/master
ygcao Jan 20, 2016
e938208
Merge branch 'improvementForSentenceBoundary' of https://github.com/y…
Jan 20, 2016
443ec06
adopted suggestions for simplification
Jan 31, 2016
84a0bc4
handle the lint warning: add braces for else statement
Feb 8, 2016
32df78e
beautify according to comments
Feb 10, 2016
84feb3d
address comments for comments and var name
Feb 12, 2016
2e052e5
removed extra braces
Feb 16, 2016
a4abd40
removed unnecessary if statement
Feb 17, 2016
74 changes: 37 additions & 37 deletions mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
@@ -76,6 +76,18 @@ class Word2Vec extends Serializable with Logging {
private var numIterations = 1
private var seed = Utils.random.nextLong()
private var minCount = 5
private var maxSentenceLength = 1000

/**
* Sets the maximum length (in words) of each sentence in the input data.
* Any sentence longer than this threshold will be divided into chunks of
* up to `maxSentenceLength` size (default: 1000).
*/
@Since("2.0.0")
def setMaxSentenceLength(maxSentenceLength: Int): this.type = {
Contributor:
I think we should target 1.6.1 here: @Since("1.6.1") - overall I'd view this PR as a bugfix (though adding the parameter is a minor extra feature). I think we'd want to include this in branch-1.6, and possibly even think about backporting the core changes to branch-1.5.

Contributor:
It is not clear from the doc what "sentence length" means: the number of words or the number of characters. We can either update the doc or change the param name to maxWordsPerSentence to make this clear from the name.

Contributor:
The param name comes from the original Google implementation. Either option (or both) works, but I guess I'd be marginally more in favour of amending the first line of the doc to read "... maximum length (in words) of each ...", or something similar.

Contributor:
Sounds good.

this.maxSentenceLength = maxSentenceLength
this
}
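As a rough usage sketch of the new parameter (hypothetical values and corpus path; assumes an existing SparkContext sc), the setter chains with the existing builder methods:

    import org.apache.spark.mllib.feature.Word2Vec
    import org.apache.spark.rdd.RDD

    // Hypothetical input: an RDD of sentences, each an iterable of words.
    val sentences: RDD[Seq[String]] = sc.textFile("corpus.txt")
      .map(_.split(" ").toSeq)

    val model = new Word2Vec()
      .setVectorSize(100)
      .setMaxSentenceLength(500) // sentences longer than 500 words are chunked
      .fit(sentences)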

/**
* Sets vector size (default: 100).
@@ -146,7 +158,6 @@ class Word2Vec extends Serializable with Logging {
private val EXP_TABLE_SIZE = 1000
private val MAX_EXP = 6
private val MAX_CODE_LENGTH = 40
private val MAX_SENTENCE_LENGTH = 1000

/** context words from [-window, window] */
private var window = 5
@@ -156,7 +167,9 @@ class Word2Vec extends Serializable with Logging {
@transient private var vocab: Array[VocabWord] = null
@transient private var vocabHash = mutable.HashMap.empty[String, Int]

private def learnVocab(words: RDD[String]): Unit = {
private def learnVocab[S <: Iterable[String]](dataset: RDD[S]): Unit = {
val words = dataset.flatMap(x => x)

vocab = words.map(w => (w, 1))
.reduceByKey(_ + _)
.filter(_._2 >= minCount)
@@ -272,15 +285,14 @@ class Word2Vec extends Serializable with Logging {

/**
* Computes the vector representation of each word in vocabulary.
* @param dataset an RDD of words
* @param dataset an RDD of sentences,
* each sentence is expressed as an iterable collection of words
Contributor:
I think we should remove L286 and change learnVocab to learnVocab(dataset). The val words = dataset.flatMap(x => x) can then be moved into learnVocab. This seems clearer to me, to avoid confusion, since words is no longer used apart from the learn-vocab step.

Author:
Good point. Will do.
* @return a Word2VecModel
*/
@Since("1.1.0")
def fit[S <: Iterable[String]](dataset: RDD[S]): Word2VecModel = {

val words = dataset.flatMap(x => x)

learnVocab(words)
learnVocab(dataset)

createBinaryTree()

@@ -289,25 +301,15 @@ class Word2Vec extends Serializable with Logging {
val expTable = sc.broadcast(createExpTable())
val bcVocab = sc.broadcast(vocab)
val bcVocabHash = sc.broadcast(vocabHash)

val sentences: RDD[Array[Int]] = words.mapPartitions { iter =>
new Iterator[Array[Int]] {
def hasNext: Boolean = iter.hasNext

def next(): Array[Int] = {
val sentence = ArrayBuilder.make[Int]
var sentenceLength = 0
while (iter.hasNext && sentenceLength < MAX_SENTENCE_LENGTH) {
val word = bcVocabHash.value.get(iter.next())
word match {
case Some(w) =>
sentence += w
sentenceLength += 1
case None =>
}
}
sentence.result()
}
// Each partition is a collection of sentences,
// which will be translated into arrays of word-index integers.
val sentences: RDD[Array[Int]] = dataset.mapPartitions { sentenceIter =>
Contributor:
I think we can simply do:

    val sentences: RDD[Array[Int]] = dataset.mapPartitions { iter =>
      new Iterator[Array[Int]] {
        def hasNext: Boolean = iter.hasNext

        def next(): Array[Int] = {
          val sentence = ArrayBuilder.make[Int]
          var sentenceLength = 0
          val wordIter = iter.next().iterator
          while (wordIter.hasNext && sentenceLength < MAX_SENTENCE_LENGTH) {
            val word = bcVocabHash.value.get(wordIter.next())
            word match {
              case Some(w) =>
                sentence += w
                sentenceLength += 1
              case None =>
            }
          }
          sentence.result()
        }
      }
    }

Author:
This is very close to my original version, except that it will throw out all words after MAX_SENTENCE_LENGTH. And are you preferring to make maxSentenceLength a static config?

The latest version of mine still tries to make use of the remaining words of the sentence after cutting by maxSentenceLength. E.g. a 2200-word sentence will be used as three cut sentences, just like the old version, except that the last/third cut will be 200 words long, with no words padded in from the next sentence. This way, we maximize the usage of our data while respecting both sentence boundaries and the sentence length restriction.

Contributor:
Ah right, yes you're correct. We should chunk up the sentences.

(I had copy-pasted MAX_SENTENCE_LENGTH from the existing code; it should of course be the maxSentenceLength var.)

// Each sentence will map to 0 or more Array[Int]
sentenceIter.flatMap { sentence =>
// Sentence of words, some of which map to a word index
val wordIndexes = sentence.flatMap(bcVocabHash.value.get)
// break wordIndexes into chunks of maxSentenceLength when it has more
wordIndexes.grouped(maxSentenceLength).map(_.toArray)
}
}
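A minimal sketch of the chunking behaviour discussed in the thread above (plain Scala, illustrative values): grouped splits the index sequence into fixed-size chunks, with a shorter final chunk for the remainder, so a 2200-word sentence yields cuts of 1000, 1000, and 200.

    // Illustrative only: stand-in word indexes for a 2200-word sentence.
    val maxSentenceLength = 1000
    val wordIndexes = Seq.tabulate(2200)(identity)
    val chunks = wordIndexes.grouped(maxSentenceLength).map(_.toArray).toSeq
    assert(chunks.map(_.length) == Seq(1000, 1000, 200))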

@@ -477,15 +479,6 @@ class Word2VecModel private[spark] (
this(Word2VecModel.buildWordIndex(model), Word2VecModel.buildWordVectors(model))
}

private def cosineSimilarity(v1: Array[Float], v2: Array[Float]): Double = {
require(v1.length == v2.length, "Vectors should have the same length")
val n = v1.length
val norm1 = blas.snrm2(n, v1, 1)
val norm2 = blas.snrm2(n, v2, 1)
if (norm1 == 0 || norm2 == 0) return 0.0
blas.sdot(n, v1, 1, v2, 1) / norm1 / norm2
}

override protected def formatVersion = "1.0"

@Since("1.4.0")
@@ -542,6 +535,7 @@ class Word2VecModel private[spark] (
// Need not divide with the norm of the given vector since it is constant.
val cosVec = cosineVec.map(_.toDouble)
var ind = 0
val vecNorm = blas.snrm2(vectorSize, fVector, 1)
while (ind < numWords) {
val norm = wordVecNorms(ind)
if (norm == 0.0) {
@@ -551,12 +545,17 @@ }
}
ind += 1
}
wordList.zip(cosVec)
var topResults = wordList.zip(cosVec)
.toSeq
.sortBy(- _._2)
.sortBy(-_._2)
.take(num + 1)
.tail
.toArray
if (vecNorm != 0.0f) {
topResults = topResults.map { case (word, cosVal) =>
(word, cosVal / vecNorm)
}
}
topResults.toArray
}

/**
@@ -568,6 +567,7 @@ class Word2VecModel private[spark] (
(word, wordVectors.slice(vectorSize * ind, vectorSize * ind + vectorSize))
}
}

}

@Since("1.4.0")
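The findSynonyms change above divides the accumulated dot products by the query vector's norm (when it is nonzero), so the reported similarity is the standard cosine similarity dot(v1, v2) / (||v1|| * ||v2||), rather than a value scaled by the constant query norm. A minimal plain-Scala sketch of that formula, mirroring the removed blas-based helper and its zero-norm guard:

    // Cosine similarity of two equal-length vectors; returns 0.0 when
    // either vector has zero norm, matching the guard in findSynonyms.
    def cosineSimilarity(v1: Array[Float], v2: Array[Float]): Double = {
      require(v1.length == v2.length, "Vectors should have the same length")
      val dot = v1.zip(v2).map { case (a, b) => a.toDouble * b }.sum
      val norm1 = math.sqrt(v1.map(x => x.toDouble * x).sum)
      val norm2 = math.sqrt(v2.map(x => x.toDouble * x).sum)
      if (norm1 == 0.0 || norm2 == 0.0) 0.0 else dot / (norm1 * norm2)
    }

This also explains the updated expected values in the Scala and Python tests below: the magnitudes change because the results are now properly normalized.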
@@ -133,7 +133,7 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
.setSeed(42L)
.fit(docDF)

val expectedSimilarity = Array(0.18032623242822343, -0.5717976464798823)
val expectedSimilarity = Array(0.2608488929093532, -0.8271274846926078)
val (synonyms, similarity) = model.findSynonyms("a", 2).map {
case Row(w: String, sim: Double) => (w, sim)
}.collect().unzip
12 changes: 6 additions & 6 deletions python/pyspark/ml/feature.py
@@ -1837,12 +1837,12 @@ class Word2Vec(JavaEstimator, HasStepSize, HasMaxIter, HasSeed, HasInputCol, Has
+----+--------------------+
...
>>> model.findSynonyms("a", 2).show()
+----+--------------------+
|word| similarity|
+----+--------------------+
| b| 0.16782984556103436|
| c|-0.46761559092107646|
+----+--------------------+
+----+-------------------+
|word| similarity|
+----+-------------------+
| b| 0.2505344027513247|
| c|-0.6980510075367647|
+----+-------------------+
...
>>> model.transform(doc).head().model
DenseVector([0.5524, -0.4995, -0.3599, 0.0241, 0.3461])