[SPARK-5992][ML] Locality Sensitive Hashing #15148
Changes from 1 commit: … functions in Min Hash

File: LSH.scala (package org.apache.spark.ml.feature)
@@ -19,7 +19,6 @@ package org.apache.spark.ml.feature

 import scala.util.Random

-import org.apache.spark.annotation.Since
 import org.apache.spark.ml.{Estimator, Model}
 import org.apache.spark.ml.linalg.{Vector, VectorUDT}
 import org.apache.spark.ml.param.{IntParam, ParamValidators}
@@ -40,13 +39,11 @@ private[ml] trait LSHParams extends HasInputCol with HasOutputCol {
   * higher the dimension is, the lower the false negative rate.
   * @group param
   */
-  @Since("2.1.0")
   final val outputDim: IntParam = new IntParam(this, "outputDim", "output dimension, where " +
     "increasing dimensionality lowers the false negative rate, and decreasing dimensionality" +
     " improves the running performance", ParamValidators.gt(0))

   /** @group getParam */
-  @Since("2.1.0")
   final def getOutputDim: Int = $(outputDim)

   setDefault(outputDim -> 1, outputCol -> "lshFeatures")
@@ -56,7 +53,6 @@ private[ml] trait LSHParams extends HasInputCol with HasOutputCol {
   * @param schema The schema of the input dataset without [[outputCol]]
   * @return A derived schema with [[outputCol]] added
   */
-  @Since("2.1.0")
   protected[this] final def validateAndTransformSchema(schema: StructType): StructType = {
     SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT)
Member: The inputCol cannot be checked here since its type may be algorithm-dependent, but it should be checked in transformSchema or a similar validateAndTransformSchema in the MinHash and RP algorithms below.

Contributor (Author): Sorry, I did not get it, there is no check for inputCol here.

Member: I meant that transformSchema should validate that inputCol has the correct DataType. That can be done by putting a line in each algorithm's transformSchema.

Contributor (Author): I see. I will add that.
   }
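A minimal sketch of the fix agreed on in the thread above, assuming the `SchemaUtils.checkColumnType` helper from `org.apache.spark.ml.util` (a `private[spark]` object, so this only compiles inside the Spark tree) and a vector-typed inputCol; the merged code may differ:

```scala
import org.apache.spark.ml.linalg.VectorUDT
import org.apache.spark.ml.util.SchemaUtils
import org.apache.spark.sql.types.StructType

// Inside a concrete algorithm (e.g. MinHash): verify inputCol's DataType
// before delegating to the shared validateAndTransformSchema, which only
// appends outputCol.
override def transformSchema(schema: StructType): StructType = {
  SchemaUtils.checkColumnType(schema, $(inputCol), new VectorUDT)
  validateAndTransformSchema(schema)
}
```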
|
|
@@ -73,7 +69,6 @@ private[ml] abstract class LSHModel[T <: LSHModel[T]]
   * The hash function of LSH, mapping a predefined KeyType to a Vector
   * @return The mapping of LSH function.
   */
-  @Since("2.1.0")
   protected[ml] val hashFunction: Vector => Vector

   /**
@@ -83,7 +78,6 @@ private[ml] abstract class LSHModel[T <: LSHModel[T]]
   * @param y One input vector in the metric space
   * @return The distance between x and y
   */
-  @Since("2.1.0")
   protected[ml] def keyDistance(x: Vector, y: Vector): Double

   /**
@@ -93,17 +87,14 @@ private[ml] abstract class LSHModel[T <: LSHModel[T]]
   * @param y Another hash vector
   * @return The distance between hash vectors x and y
   */
-  @Since("2.1.0")
   protected[ml] def hashDistance(x: Vector, y: Vector): Double

-  @Since("2.1.0")
   override def transform(dataset: Dataset[_]): DataFrame = {
Member: No need to copy documentation for overridden methods, unless the docs are specialized for this class.

Contributor (Author): Done.
     transformSchema(dataset.schema, logging = true)
     val transformUDF = udf(hashFunction, new VectorUDT)
     dataset.withColumn($(outputCol), transformUDF(dataset($(inputCol))))
   }

-  @Since("2.1.0")
   override def transformSchema(schema: StructType): StructType = {
     validateAndTransformSchema(schema)
   }
@@ -126,7 +117,6 @@ private[ml] abstract class LSHModel[T <: LSHModel[T]]
   * @return A dataset containing at most k items closest to the key. A distCol is added to show
   *         the distance between each row and the key.
   */
-  @Since("2.1.0")
   def approxNearestNeighbors(
       dataset: Dataset[_],
       key: Vector,
@@ -168,7 +158,6 @@ private[ml] abstract class LSHModel[T <: LSHModel[T]]
   * Overloaded method for approxNearestNeighbors. Use Single Probing as default way to search
   * nearest neighbors and "distCol" as default distCol.
   */
-  @Since("2.1.0")
   def approxNearestNeighbors(
       dataset: Dataset[_],
       key: Vector,
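A hedged usage sketch of this overload; `model`, `dfA`, and the neighbor count 3 are hypothetical, and the trailing parameters of the signature are elided above:

```scala
import org.apache.spark.ml.linalg.Vectors

// Hypothetical call on a fitted model: return at most 3 rows of dfA
// nearest to `key`, with true distances in the default "distCol" column.
val key = Vectors.dense(1.0, 0.0)
val neighbors = model.approxNearestNeighbors(dfA, key, 3)
neighbors.orderBy("distCol").show()
```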
|
|
@@ -185,7 +174,6 @@ private[ml] abstract class LSHModel[T <: LSHModel[T]] | |
| * @param explodeCols The alias for the exploded columns, must be a seq of two strings. | ||
| * @return A dataset containing idCol, inputCol and explodeCols | ||
| */ | ||
| @Since("2.1.0") | ||
| private[this] def processDataset( | ||
| dataset: Dataset[_], | ||
| inputName: String, | ||
|
|
@@ -211,7 +199,6 @@ private[ml] abstract class LSHModel[T <: LSHModel[T]] | |
| * @param tmpColName A temporary column name which does not conflict with existing columns | ||
| * @return | ||
| */ | ||
| @Since("2.1.0") | ||
| private[this] def recreateCol( | ||
| dataset: Dataset[_], | ||
| colName: String, | ||
|
|
@@ -235,7 +222,6 @@ private[ml] abstract class LSHModel[T <: LSHModel[T]] | |
| * @return A joined dataset containing pairs of rows. The original rows are in columns | ||
| * "datasetA" and "datasetB", and a distCol is added to show the distance of each pair | ||
| */ | ||
| @Since("2.1.0") | ||
| def approxSimilarityJoin( | ||
|
Member: This too should document that it transforms data if needed, just like approxNearestNeighbors.

Contributor (Author): Done.
       datasetA: Dataset[_],
       datasetB: Dataset[_],
@@ -273,7 +259,6 @@ private[ml] abstract class LSHModel[T <: LSHModel[T]]
   /**
    * Overloaded method for approxSimilarityJoin. Use "distCol" as default distCol.
    */
-  @Since("2.1.0")
   def approxSimilarityJoin(
Member: The default distCol needs to be documented.

Contributor (Author): Scaladoc added.
       datasetA: Dataset[_],
       datasetB: Dataset[_],
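A hedged usage sketch of the overload with the default distCol; `dfA`, `dfB`, and the 1.5 threshold are hypothetical:

```scala
// Hypothetical join: keep pairs of rows from dfA and dfB whose distance is
// below 1.5. Per the thread above, untransformed inputs are hashed on the
// fly; the result has columns "datasetA", "datasetB", and "distCol".
val pairs = model.approxSimilarityJoin(dfA, dfB, 1.5)
pairs.select("datasetA", "datasetB", "distCol").show()
```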
|
|
@@ -302,15 +287,12 @@ private[ml] abstract class LSH[T <: LSHModel[T]]
   self: Estimator[T] =>

   /** @group setParam */
-  @Since("2.1.0")
   def setInputCol(value: String): this.type = set(inputCol, value)

   /** @group setParam */
-  @Since("2.1.0")
   def setOutputCol(value: String): this.type = set(outputCol, value)

   /** @group setParam */
-  @Since("2.1.0")
   def setOutputDim(value: Int): this.type = set(outputDim, value)

   /**
@@ -320,10 +302,8 @@ private[ml] abstract class LSH[T <: LSHModel[T]]
   * @param inputDim The dimension of the input dataset
   * @return A new LSHModel instance without any params
   */
-  @Since("2.1.0")
   protected[this] def createRawLSHModel(inputDim: Int): T

-  @Since("2.1.0")
   override def fit(dataset: Dataset[_]): T = {
     transformSchema(dataset.schema, logging = true)
     val inputDim = dataset.select(col($(inputCol))).head().get(0).asInstanceOf[Vector].size
Member: I'd call transformSchema here before extracting inputDim.

Contributor (Author): Done.
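A paraphrase of the resulting fit flow, as a sketch; the `getAs`/`copyValues` details are assumptions rather than this revision's verbatim code:

```scala
// Validate the schema first, then read one row to learn the input
// dimension, then create the unfitted model and copy params onto it.
override def fit(dataset: Dataset[_]): T = {
  transformSchema(dataset.schema, logging = true)
  val inputDim = dataset.select(col($(inputCol))).head().getAs[Vector](0).size
  copyValues(createRawLSHModel(inputDim).setParent(this))
}
```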

File: MinHash.scala
@@ -30,7 +30,14 @@ import org.apache.spark.sql.types.StructType

 /**
  * :: Experimental ::
Member: No need to mark a private class Experimental.

Contributor (Author): Removed.
- * Model produced by [[MinHash]]
+ * Model produced by [[MinHash]], where multiple hash functions are stored. Each hash function
+ * is a perfect hash function:
+ *    g_i(x) = (x * k_i mod prime) mod numEntries
+ * where k_i is the i-th coefficient
+ *
+ * Reference:
+ * https://en.wikipedia.org/wiki/Perfect_hash_function
+ *
  * @param numEntries The number of entries of the hash functions.
Member: The doc could be clearer here. If I did not read the code, I might not know what "entries" are. One good way to explain this would be to just state what the hash function is here.
  * @param randCoefficients An array of random coefficients, each used by one hash function.
  */
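To make the documented g_i concrete, a sketch under the assumption that the model hashes the active indices of a sparse binary vector and keeps the minimum; the 1-based index shift is also an assumption:

```scala
// One MinHash function: hash every active index with g_i and keep the
// minimum bucket. Long arithmetic avoids overflow before the mod.
def minHashOf(activeIndices: Seq[Int], coeff: Int, prime: Int, numEntries: Int): Double = {
  require(activeIndices.nonEmpty, "Must have at least one non-zero entry.")
  activeIndices.map(i => ((1 + i).toLong * coeff % prime % numEntries).toDouble).min
}
```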
|
|
@@ -117,7 +124,7 @@ class MinHash(override val uid: String) extends LSH[MinHashModel] with HasSeed {
   @Since("2.1.0")
   override protected[ml] def createRawLSHModel(inputDim: Int): MinHashModel = {
     require(inputDim <= MinHash.prime / 2,
-      "The input vector dimension is too large for MinHash to handle.")
+      s"The input vector dimension $inputDim exceeds the threshold ${MinHash.prime / 2}.")
     val rand = new Random($(seed))
     val numEntry = inputDim * 2
Member: This could overflow. Use …

Contributor (Author): Done.
     val randCoofs: Array[Int] = Array.fill($(outputDim))(1 + rand.nextInt(MinHash.prime - 1))
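One way to address the overflow comment, as a sketch; whether the PR adopted `Math.multiplyExact` is an assumption (the reviewer's suggested call is truncated above):

```scala
// inputDim * 2 wraps around silently when inputDim > Int.MaxValue / 2;
// java.lang.Math.multiplyExact (Java 8+) throws ArithmeticException
// on overflow instead of returning a wrong value.
val numEntry = Math.multiplyExact(inputDim, 2)
```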
|
|
@@ -158,7 +165,6 @@ object MinHashModel extends MLReadable[MinHashModel] {

   override protected def saveImpl(path: String): Unit = {
     DefaultParamsWriter.saveMetadata(instance, path, sc)
-    // Save model data: pi, theta
     val data = Data(instance.numEntries, instance.randCoefficients)
     val dataPath = new Path(path, "data").toString
     sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)

File: RandomProjection.scala
@@ -44,13 +44,11 @@ private[ml] trait RandomProjectionParams extends Params {
   * reasonable value
   * @group param
   */
-  @Since("2.1.0")
   val bucketLength: DoubleParam = new DoubleParam(this, "bucketLength",
Member: Add Scala doc for bucketLength. Some guidance on reasonable value ranges would be good. E.g., "If input vectors have unit norm, then …". In doc, put bucketLength in …

Contributor (Author): Done.
     "the length of each hash bucket, a larger bucket lowers the false negative rate.",
     ParamValidators.gt(0))

   /** @group getParam */
-  @Since("2.1.0")
   final def getBucketLength: Double = $(bucketLength)
 }
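For intuition about the param, a sketch of the bucketed random-projection hash that bucketLength controls; the dot product is written out for self-containment, and the exact formula in this revision may differ:

```scala
import org.apache.spark.ml.linalg.Vector

// Project x onto a random unit vector, then cut the projected line into
// buckets of width bucketLength. Larger buckets make near neighbors
// collide more often (lower false negative rate, more false positives).
def bucketHash(x: Vector, randUnitVector: Vector, bucketLength: Double): Double = {
  val dot = x.toArray.zip(randUnitVector.toArray).map { case (a, b) => a * b }.sum
  math.floor(dot / bucketLength)
}
```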
|
|
|
|
@@ -180,7 +178,6 @@ object RandomProjectionModel extends MLReadable[RandomProjectionModel] {

   override protected def saveImpl(path: String): Unit = {
     DefaultParamsWriter.saveMetadata(instance, path, sc)
-    // Save model data: pi, theta
     val numRows = instance.randUnitVectors.length
     require(numRows > 0)
     val numCols = instance.randUnitVectors.head.size
|
|
Member: Does increasing dimensionality lower the false negative rate? I think increasing dimensionality should lower the false positive rate, right?

Contributor (Author): No. Since we are implementing OR-amplification, increasing dimensionality lowers the false negative rate. In AND-amplification, increasing dimensionality would lower the false positive rate.
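A numeric sketch of that argument, using generic LSH amplification math rather than code from this PR: if a single hash collides with probability p, OR-amplification over d hash functions declares a collision when any hash matches, so its collision probability rises with d and false negatives fall; AND-amplification is the mirror image.

```scala
// OR-amplification:  P(collision) = 1 - (1 - p)^d  (increasing in d)
// AND-amplification: P(collision) = p^d            (decreasing in d)
def orCollision(p: Double, d: Int): Double = 1.0 - math.pow(1.0 - p, d)
def andCollision(p: Double, d: Int): Double = math.pow(p, d)

// Example with p = 0.3:
//   orCollision(0.3, 1)  = 0.30    andCollision(0.3, 1) = 0.30
//   orCollision(0.3, 5) ~= 0.832   andCollision(0.3, 5) ~= 0.0024
```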