
Commit 27524a3

Yuming Wang authored and mengxr committed
[SPARK-11626][ML] ml.feature.Word2Vec.transform() function very slow
org.apache.spark.ml.feature.Word2Vec.transform() is very slow: the UDF should not read the broadcast model for every sentence.

Author: Yuming Wang <[email protected]>
Author: yuming.wang <[email protected]>
Author: Xiangrui Meng <[email protected]>

Closes #9592 from 979969786/master.
1 parent 1510c52 commit 27524a3

File tree

1 file changed: +16, -18 lines


mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala

Lines changed: 16 additions & 18 deletions
@@ -17,18 +17,16 @@
 
 package org.apache.spark.ml.feature
 
-import org.apache.spark.annotation.Experimental
 import org.apache.spark.SparkContext
+import org.apache.spark.annotation.Experimental
 import org.apache.spark.ml.{Estimator, Model}
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
 import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
 import org.apache.spark.mllib.feature
-import org.apache.spark.mllib.linalg.{VectorUDT, Vector, Vectors}
-import org.apache.spark.mllib.linalg.BLAS._
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.mllib.linalg.{BLAS, Vector, VectorUDT, Vectors}
+import org.apache.spark.sql.{DataFrame, SQLContext}
 import org.apache.spark.sql.functions._
-import org.apache.spark.sql.SQLContext
 import org.apache.spark.sql.types._
 
 /**
@@ -148,10 +146,9 @@ final class Word2Vec(override val uid: String) extends Estimator[Word2VecModel]
 @Experimental
 class Word2VecModel private[ml] (
     override val uid: String,
-    wordVectors: feature.Word2VecModel)
+    @transient wordVectors: feature.Word2VecModel)
   extends Model[Word2VecModel] with Word2VecBase {
 
-
   /**
    * Returns a dataframe with two fields, "word" and "vector", with "word" being a String and
    * and the vector the DenseVector that it is mapped to.
@@ -197,22 +194,23 @@ class Word2VecModel private[ml] (
    */
   override def transform(dataset: DataFrame): DataFrame = {
     transformSchema(dataset.schema, logging = true)
-    val bWordVectors = dataset.sqlContext.sparkContext.broadcast(wordVectors)
+    val vectors = wordVectors.getVectors
+      .mapValues(vv => Vectors.dense(vv.map(_.toDouble)))
+      .map(identity) // mapValues doesn't return a serializable map (SI-7005)
+    val bVectors = dataset.sqlContext.sparkContext.broadcast(vectors)
+    val d = $(vectorSize)
     val word2Vec = udf { sentence: Seq[String] =>
       if (sentence.size == 0) {
-        Vectors.sparse($(vectorSize), Array.empty[Int], Array.empty[Double])
+        Vectors.sparse(d, Array.empty[Int], Array.empty[Double])
       } else {
-        val cum = Vectors.zeros($(vectorSize))
-        val model = bWordVectors.value.getVectors
-        for (word <- sentence) {
-          if (model.contains(word)) {
-            axpy(1.0, bWordVectors.value.transform(word), cum)
-          } else {
-            // pass words which not belong to model
+        val sum = Vectors.zeros(d)
+        sentence.foreach { word =>
+          bVectors.value.get(word).foreach { v =>
+            BLAS.axpy(1.0, v, sum)
           }
         }
-        scal(1.0 / sentence.size, cum)
-        cum
+        BLAS.scal(1.0 / sentence.size, sum)
+        sum
       }
     }
     dataset.withColumn($(outputCol), word2Vec(col($(inputCol))))
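
For context, here is a minimal usage sketch of the code path this change speeds up, written against the Spark 1.x-era API this file targets (SparkContext, SQLContext, DataFrame). The object name, column names, and toy sentences are illustrative assumptions, not part of the commit.

import org.apache.spark.ml.feature.Word2Vec
import org.apache.spark.sql.SQLContext
import org.apache.spark.{SparkConf, SparkContext}

object Word2VecTransformExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setAppName("Word2VecTransformExample").setMaster("local[*]"))
    val sqlContext = new SQLContext(sc)

    // One tokenized sentence per row (toy data).
    val documentDF = sqlContext.createDataFrame(Seq(
      "Hi I heard about Spark".split(" "),
      "I wish Java could use case classes".split(" "),
      "Logistic regression models are neat".split(" ")
    ).map(Tuple1.apply)).toDF("text")

    val word2Vec = new Word2Vec()
      .setInputCol("text")
      .setOutputCol("result")
      .setVectorSize(3)
      .setMinCount(0)
    val model = word2Vec.fit(documentDF)

    // transform() is the path this commit optimizes: the word -> vector map is now
    // broadcast once, and each sentence is averaged via cheap lookups into that map.
    model.transform(documentDF).show()

    sc.stop()
  }
}

As the inline comment in the diff notes, the .map(identity) after mapValues is needed because mapValues returns a lazy view that is not serializable (SI-7005); forcing it into a concrete Map lets the precomputed vectors be broadcast once, which is what removes the per-sentence work from the UDF.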
