Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Python API for mllib.feature
  • Loading branch information
davies committed Oct 16, 2014
commit 8a50584ed6ea38b5fccc64e6da3fc18d4513c9c5
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,7 @@ import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
import org.apache.spark.mllib.classification._
import org.apache.spark.mllib.clustering._
import org.apache.spark.mllib.feature.Word2Vec
import org.apache.spark.mllib.feature.Word2VecModel
import org.apache.spark.mllib.feature._
import org.apache.spark.mllib.optimization._
import org.apache.spark.mllib.linalg._
import org.apache.spark.mllib.random.{RandomRDDs => RG}
Expand Down Expand Up @@ -289,6 +288,43 @@ class PythonMLLibAPI extends Serializable {
ALS.trainImplicit(ratingsJRDD.rdd, rank, iterations, lambda, blocks, alpha)
}

/**
* Java stub for Normalizer.transform()
*/
def normalizeVector(p: Double, vector: Vector): Vector = {
new Normalizer(p).transform(vector)
}

/**
* Java stub for Normalizer.transform()
*/
def normalizeVector(p: Double, rdd: JavaRDD[Vector]): JavaRDD[Vector] = {
new Normalizer(p).transform(rdd)
}

/**
* Java stub for IDF.fit(). This stub returns a
* handle to the Java object instead of the content of the Java object.
* Extra care needs to be taken in the Python code to ensure it gets freed on
* exit; see the Py4J documentation.
*/
def fitStandardScaler(
withMean: Boolean,
withStd: Boolean,
data: JavaRDD[Vector]): StandardScalerModel = {
new StandardScaler(withMean, withStd).fit(data.rdd)
}

/**
* Java stub for IDF.fit(). This stub returns a
* handle to the Java object instead of the content of the Java object.
* Extra care needs to be taken in the Python code to ensure it gets freed on
* exit; see the Py4J documentation.
*/
def fitIDF(minDocFreq: Int, dataset: JavaRDD[Vector]): IDFModel = {
new IDF(minDocFreq).fit(dataset)
}

/**
* Java stub for Python mllib Word2Vec fit(). This stub returns a
* handle to the Java object instead of the content of the Java object.
Expand Down Expand Up @@ -326,6 +362,16 @@ class PythonMLLibAPI extends Serializable {
model.transform(word)
}

/**
* TODO: model is not serializable
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this an outdated comment?

* Transforms an RDD of words to its vector representation
* @param rdd an RDD of words
* @return an RDD of vector representations of words
*/
def transform(rdd: JavaRDD[String]): JavaRDD[Vector] = {
rdd.rdd.map(model.transform(_))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could remove "(_)"

}

def findSynonyms(word: String, num: Int): java.util.List[java.lang.Object] = {
val vec = transform(word)
findSynonyms(vec, num)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package org.apache.spark.mllib.feature
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.rdd.RDD
import org.apache.spark.api.java.JavaRDD
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

organize imports in alphabetical order


/**
* :: DeveloperApi ::
Expand Down Expand Up @@ -48,4 +49,14 @@ trait VectorTransformer extends Serializable {
data.map(x => this.transform(x))
}

/**
* Applies transformation on an JavaRDD[Vector].
*
* @param data JavaRDD[Vector] to be transformed.
* @return transformed JavaRDD[Vector].
*/
def transform(data: JavaRDD[Vector]): JavaRDD[Vector] = {
transform(data.rdd)
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -432,7 +432,7 @@ class Word2VecModel private[mllib] (
throw new IllegalStateException(s"$word not in vocabulary")
}
}

/**
* Find synonyms of a word
* @param word a word
Expand All @@ -443,7 +443,7 @@ class Word2VecModel private[mllib] (
val vector = transform(word)
findSynonyms(vector,num)
}

/**
* Find synonyms of the vector representation of a word
* @param vector vector representation of a word
Expand Down
Loading