Skip to content
Closed
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Update Word2Vec.scala
  • Loading branch information
dgai91 authored Jun 20, 2017
commit d204612a8159cd0672633c753e75335cc99da7ff
11 changes: 6 additions & 5 deletions mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -274,13 +274,14 @@ class Word2VecModel private[ml] (
wordVectors.findSynonyms(word, num)
}


def doc2Vector(text: String, d: Int): SDV = {
val bVectors = wordVectors.getVectors.collect()
val textArray = text.split(" ")
/**
* using model.getVectors can get the wordVectors then you must convert the DataFrame
* to an Array.
*/
def doc2Vector(textArray: Array[String], d: Int, wordVectors: Array[Row]): SDV = {
var sum = Vectors.zeros(d)
textArray.foreach { word =>
bVectors.value.filter(_.getAs[String]("word") == word).foreach { v =>
wordVectors.value.filter(_.getAs[String]("word") == word).foreach { v =>
val sv = v.getAs[SDV]("vector")
BLAS.axpy(1.0, sv, sum)
}
Expand Down