[SPARK-5992][ML] Locality Sensitive Hashing #15148
Changes from 1 commit
@@ -45,7 +45,8 @@ private[ml] trait LSHParams extends HasInputCol with HasOutputCol {
    */
   @Since("2.1.0")
   final val outputDim: IntParam = new IntParam(this, "outputDim", "output dimension, where" +
-    "increasing dimensionality lowers the false negative rate", ParamValidators.gt(0))
+    "increasing dimensionality lowers the false negative rate, and decreasing dimensionality" +
+    " improves the running performance", ParamValidators.gt(0))

   /** @group getParam */
   @Since("2.1.0")

Member: Does increasing dimensionality lower the false negative rate?

Contributor (Author): Yes. Since we are implementing OR-amplification, increasing dimensionality lowers the false negative rate; with AND-amplification, increasing dimensionality would lower the false positive rate instead.
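A quick numeric sketch (not part of the PR; the collision probability and dimensions are illustrative) of the OR-amplification point above. If each of the d hash functions collides on a near pair with probability p, a near pair is missed only when all d hashes disagree; under AND-amplification, a far pair survives only when all d hashes agree:

// OR-amplification: P(false negative) = (1 - p)^d, shrinking as d (outputDim) grows.
// AND-amplification: P(false positive) = q^d, shrinking as d grows instead.
object AmplificationSketch {
  def falseNegativeOr(p: Double, d: Int): Double = math.pow(1 - p, d)
  def falsePositiveAnd(q: Double, d: Int): Double = math.pow(q, d)

  def main(args: Array[String]): Unit = {
    Seq(1, 2, 4, 8).foreach { d =>
      println(f"d = $d: P(false negative, OR) = ${falseNegativeOr(0.7, d)}%.4f")
    }
  }
}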
@@ -56,8 +57,8 @@ private[ml] trait LSHParams extends HasInputCol with HasOutputCol {

   /**
    * Transform the Schema for LSH
-   * @param schema The schema of the input dataset without outputCol
-   * @return A derived schema with outputCol added
+   * @param schema The schema of the input dataset without [[outputCol]]
+   * @return A derived schema with [[outputCol]] added
    */
   @Since("2.1.0")
   protected[this] final def validateAndTransformSchema(schema: StructType): StructType = {
@@ -117,9 +118,9 @@ private[ml] abstract class LSHModel[T <: LSHModel[T]] extends Model[T] with LSHParams {

   /**
    * Given a large dataset and an item, approximately find at most k items which have the closest
-   * distance to the item. If the outputCol is missing, the method will transform the data; if the
-   * the outputCol exists, it will use the outputCol. This allows caching of the transformed data
-   * when necessary.
+   * distance to the item. If the [[outputCol]] is missing, the method will transform the data; if
+   * the [[outputCol]] exists, it will use the [[outputCol]]. This allows caching of the
+   * transformed data when necessary.
    *
    * This method implements two ways of fetching k nearest neighbors:
    *  - Single Probing: Fast, return at most k elements (Probing only one buckets)

Member: This method needs to document that it checks for the outputCol and transforms the data if it is missing, allowing caching of the transformed data when necessary.

Contributor (Author): Done.
@@ -135,11 +136,11 @@ private[ml] abstract class LSHModel[T <: LSHModel[T]] extends Model[T] with LSHParams {
    */
   @Since("2.1.0")
   def approxNearestNeighbors(
-      @Since("2.1.0") dataset: Dataset[_],
-      @Since("2.1.0") key: Vector,
-      @Since("2.1.0") numNearestNeighbors: Int,
-      @Since("2.1.0") singleProbing: Boolean,
-      @Since("2.1.0") distCol: String): Dataset[_] = {
+      dataset: Dataset[_],
+      key: Vector,
+      numNearestNeighbors: Int,
+      singleProbing: Boolean,
+      distCol: String): Dataset[_] = {
     require(numNearestNeighbors > 0, "The number of nearest neighbors cannot be less than 1")
     // Get Hash Value of the key
     val keyHash = hashFunction(key)
@@ -177,21 +178,24 @@ private[ml] abstract class LSHModel[T <: LSHModel[T]] extends Model[T] with LSHParams {
    */
   @Since("2.1.0")
   def approxNearestNeighbors(
-      @Since("2.1.0") dataset: Dataset[_],
-      @Since("2.1.0") key: Vector,
-      @Since("2.1.0") numNearestNeighbors: Int): Dataset[_] = {
+      dataset: Dataset[_],
+      key: Vector,
+      numNearestNeighbors: Int): Dataset[_] = {
     approxNearestNeighbors(dataset, key, numNearestNeighbors, true, "distCol")
   }
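A minimal usage sketch of the two overloads above (not part of the diff; the data, column names, and parameter values are illustrative, and the RandomProjection setters are assumed to exist as named in this PR):

import org.apache.spark.ml.feature.RandomProjection
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[2]").appName("lsh-knn").getOrCreate()
import spark.implicits._

val df = Seq(
  (0, Vectors.dense(1.0, 1.0)),
  (1, Vectors.dense(1.0, -1.0)),
  (2, Vectors.dense(-1.0, -1.0))
).toDF("id", "keys")

val model = new RandomProjection()
  .setInputCol("keys")
  .setOutputCol("hashes")
  .setOutputDim(2)
  .setBucketLength(2.0)
  .fit(df)

// Three-argument overload: defaults to single probing and a "distCol" output column.
model.approxNearestNeighbors(df, Vectors.dense(1.0, 0.0), 2).show()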
   /**
-   * Preprocess step for approximate similarity join. Transform and explode the outputCol to
+   * Preprocess step for approximate similarity join. Transform and explode the [[outputCol]] to
    * explodeCols.
    *
    * @param dataset The dataset to transform and explode.
    * @param explodeCols The alias for the exploded columns, must be a seq of two strings.
    * @return A dataset containing idCol, inputCol and explodeCols
    */
   @Since("2.1.0")
-  private[this] def processDataset(dataset: Dataset[_], explodeCols: Seq[String]): Dataset[_] = {
+  private[this] def processDataset(
+      dataset: Dataset[_],
+      inputName: String,
+      explodeCols: Seq[String]): Dataset[_] = {
     require(explodeCols.size == 2, "explodeCols must be two strings.")
     val vectorToMap: UserDefinedFunction = udf((x: Vector) => x.asBreeze.iterator.toMap,
       MapType(DataTypes.IntegerType, DataTypes.DoubleType))
@@ -200,7 +204,9 @@ private[ml] abstract class LSHModel[T <: LSHModel[T]] extends Model[T] with LSHParams {
     } else {
       dataset.toDF()
     }
-    modelDataset.select(col("*"), explode(vectorToMap(col($(outputCol)))).as(explodeCols))
+    modelDataset.select(
+      struct(col("*")).as(inputName),
+      explode(vectorToMap(col($(outputCol)))).as(explodeCols))
   }
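A standalone sketch of the explode step (not from the PR; names are illustrative): each row's hash vector becomes one row per (entry, hashValue) pair, and wrapping the original columns in a struct keeps them addressable after the join:

import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, explode, struct, udf}

val spark = SparkSession.builder().master("local[2]").appName("explode-sketch").getOrCreate()
import spark.implicits._

val hashed = Seq((0, Vectors.dense(1.0, 3.0)), (1, Vectors.dense(2.0, 3.0))).toDF("id", "hashes")

// {entry -> hashValue} per row, then one output row per map entry.
val vectorToMap = udf((v: Vector) => v.toArray.zipWithIndex.map(_.swap).toMap)

val exploded = hashed.select(
  struct(col("*")).as("input"),
  explode(vectorToMap(col("hashes"))).as(Seq("entry", "hashValue")))

// Nested fields of the struct stay reachable, e.g. "input.id".
exploded.select(col("input.id"), col("entry"), col("hashValue")).show()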
   /**
@@ -213,18 +219,21 @@ private[ml] abstract class LSHModel[T <: LSHModel[T]] extends Model[T] with LSHParams {
    */
   @Since("2.1.0")
   private[this] def recreateCol(
-      @Since("2.1.0") dataset: Dataset[_],
-      @Since("2.1.0") colName: String,
-      @Since("2.1.0") tmpColName: String): Dataset[_] = {
+      dataset: Dataset[_],
+      colName: String,
+      tmpColName: String): Dataset[_] = {
     dataset
       .withColumnRenamed(colName, tmpColName)
       .withColumn(colName, col(tmpColName))
       .drop(tmpColName)
   }
   /**
-   * Join two dataset to approximately find all pairs of records whose distance are smaller
-   * than the threshold.
+   * Join two dataset to approximately find all pairs of records whose distance are smaller than
+   * the threshold. If the [[outputCol]] is missing, the method will transform the data; if the
+   * [[outputCol]] exists, it will use the [[outputCol]]. This allows caching of the transformed
+   * data when necessary.
    *
    * @param datasetA One of the datasets to join
    * @param datasetB Another dataset to join
    * @param threshold The threshold for the distance of record pairs
@@ -234,21 +243,22 @@ private[ml] abstract class LSHModel[T <: LSHModel[T]] extends Model[T] with LSHParams {
    */
   @Since("2.1.0")
   def approxSimilarityJoin(
-      @Since("2.1.0") datasetA: Dataset[_],
-      @Since("2.1.0") datasetB: Dataset[_],
-      @Since("2.1.0") threshold: Double,
-      @Since("2.1.0") distCol: String): Dataset[_] = {
+      datasetA: Dataset[_],
+      datasetB: Dataset[_],
+      threshold: Double,
+      distCol: String): Dataset[_] = {

-    val explodeCols = Seq("lsh#entry", "lsh#hashValue")
-    val explodedA = processDataset(datasetA, explodeCols)
+    val explodeCols = Seq("entry", "hashValue")
+    val inputName = "input"
+    val explodedA = processDataset(datasetA, inputName, explodeCols)

     // If this is a self join, we need to recreate the inputCol of datasetB to avoid ambiguity.
     // TODO: Remove recreateCol logic once SPARK-17154 is resolved.
     val explodedB = if (datasetA != datasetB) {
-      processDataset(datasetB, explodeCols)
+      processDataset(datasetB, inputName, explodeCols)
     } else {
       val recreatedB = recreateCol(datasetB, $(inputCol), s"${$(inputCol)}#${Random.nextString(5)}")
-      processDataset(recreatedB, explodeCols)
+      processDataset(recreatedB, inputName, explodeCols)
     }

     // Do a hash join on where the exploded hash values are equal.

Member: This too should document that it transforms data if needed, just like approxNearestNeighbors.

Contributor (Author): Done.
@@ -258,7 +268,8 @@ private[ml] abstract class LSHModel[T <: LSHModel[T]] extends Model[T] with LSHParams {
     // Add a new column to store the distance of the two records.
     val distUDF = udf((x: Vector, y: Vector) => keyDistance(x, y), DataTypes.DoubleType)
     val joinedDatasetWithDist = joinedDataset.select(col("*"),
-      distUDF(explodedA($(inputCol)), explodedB($(inputCol))).as(distCol)
+      distUDF(explodedA(s"$inputName.${$(inputCol)}"),
+        explodedB(s"$inputName.${$(inputCol)}")).as(distCol)
     )

     // Filter the joined datasets where the distance are smaller than the threshold.
@@ -270,9 +281,9 @@ private[ml] abstract class LSHModel[T <: LSHModel[T]] extends Model[T] with LSHParams {
    */
   @Since("2.1.0")
   def approxSimilarityJoin(
-      @Since("2.1.0") datasetA: Dataset[_],
-      @Since("2.1.0") datasetB: Dataset[_],
-      @Since("2.1.0") threshold: Double): Dataset[_] = {
+      datasetA: Dataset[_],
+      datasetB: Dataset[_],
+      threshold: Double): Dataset[_] = {
     approxSimilarityJoin(datasetA, datasetB, threshold, "distCol")
   }
 }

Member: The default distCol needs to be documented.

Contributor (Author): Scaladoc added.
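A usage sketch of the join API (not part of the diff; model, dfA, and dfB are assumed to be prepared as in the kNN example above):

// Keep only pairs whose key distance is below the threshold; the
// two-dataset overload without distCol writes distances to "distCol".
val joined = model.approxSimilarityJoin(dfA, dfB, 1.5)
joined.show()

// A self join also works; it goes through the recreateCol workaround above.
model.approxSimilarityJoin(dfA, dfA, 1.5).show()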
@@ -282,19 +293,17 @@ private[ml] abstract class LSHModel[T <: LSHModel[T]] extends Model[T] with LSHParams {
  * hash column, approximate nearest neighbor search with a dataset and a key, and approximate
  * similarity join of two datasets.
  *
  * Currently the following LSH family is implemented:
  *  - Euclidean Distance: Random Projection
  *
  * References:
  * (1) Gionis, Aristides, Piotr Indyk, and Rajeev Motwani. "Similarity search in high dimensions
  *     via hashing." VLDB 7 Sep. 1999: 518-529.
  * (2) Wang, Jingdong et al. "Hashing for similarity search: A survey." arXiv preprint
  *     arXiv:1408.2927 (2014).
  * @tparam T The class type of lsh
  */
 @Experimental
 @Since("2.1.0")
 private[ml] abstract class LSH[T <: LSHModel[T]] extends Estimator[T] with LSHParams {
   self: Estimator[T] =>

   /** @group setParam */
   @Since("2.1.0")
   def setInputCol(value: String): this.type = set(inputCol, value)
@@ -322,13 +331,9 @@ private[ml] abstract class LSH[T <: LSHModel[T]] extends Estimator[T] with LSHParams {

   @Since("2.1.0")
   override def fit(dataset: Dataset[_]): T = {
     transformSchema(dataset.schema, logging = true)
     val inputDim = dataset.select(col($(inputCol))).head().get(0).asInstanceOf[Vector].size
     val model = createRawLSHModel(inputDim).setParent(this)
     copyValues(model)
   }

   @Since("2.1.0")
   override def transformSchema(schema: StructType): StructType = {
     validateAndTransformSchema(schema)
   }
 }

Member: I'd call transformSchema here before extracting inputDim.

Contributor (Author): Done.
@@ -20,8 +20,9 @@ package org.apache.spark.ml.feature
 import scala.util.Random

 import org.apache.spark.annotation.{Experimental, Since}
-import org.apache.spark.ml.linalg.{Vector, Vectors}
+import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT}
 import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.sql.types.StructType

 /**
  * Model produced by [[MinHash]]
@@ -87,7 +88,7 @@ class MinHash private[ml] (override val uid: String) extends LSH[MinHashModel] {
   @Since("2.1.0")
   override protected[this] def createRawLSHModel(inputDim: Int): MinHashModel = {
     val numEntry = inputDim * 2
-    assert(numEntry < prime, "The input vector dimension is too large for MinHash to handle.")
+    require(numEntry < prime, "The input vector dimension is too large for MinHash to handle.")
     val hashFunctions: Seq[Int => Long] = {
       (0 until $(outputDim)).map { i: Int =>
         // Perfect Hash function, use 2n buckets to reduce collision.

Member: This could overflow. Use […]

Contributor (Author): Done.
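The reviewer's concern above: inputDim * 2 is 32-bit Int arithmetic and can wrap before the bound check runs. A hedged sketch of one possible guard (the PR's actual fix is cut off above; the prime value here is an illustrative stand-in):

object OverflowGuardSketch {
  val prime: Long = 2038074743L // illustrative stand-in for MinHash's prime

  def checkedNumEntries(inputDim: Int): Long = {
    val numEntry = inputDim.toLong * 2 // Long arithmetic: no Int wrap-around
    require(numEntry < prime, "The input vector dimension is too large for MinHash to handle.")
    numEntry
  }

  def main(args: Array[String]): Unit = {
    println(checkedNumEntries(1000)) // 2000
    // checkedNumEntries(Int.MaxValue) would fail the require instead of silently wrapping.
  }
}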
@@ -96,4 +97,11 @@ class MinHash private[ml] (override val uid: String) extends LSH[MinHashModel] {
     }
     new MinHashModel(uid, hashFunctions)
   }
+
+  @Since("2.1.0")
+  override def transformSchema(schema: StructType): StructType = {
+    require(schema.apply($(inputCol)).dataType.sameType(new VectorUDT),
+      s"${$(inputCol)} must be vectors")
+    validateAndTransformSchema(schema)
+  }
 }
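An end-to-end MinHash sketch (not from the PR; data and parameter values are illustrative, and the setters are assumed to match this PR's API). The new transformSchema override is what makes a non-vector inputCol fail fast here:

import org.apache.spark.ml.feature.MinHash
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[2]").appName("minhash-sketch").getOrCreate()
import spark.implicits._

// Sparse binary vectors: the kind of input MinHash is meant for.
val df = Seq(
  (0, Vectors.sparse(6, Seq((0, 1.0), (1, 1.0), (2, 1.0)))),
  (1, Vectors.sparse(6, Seq((2, 1.0), (3, 1.0), (4, 1.0))))
).toDF("id", "keys")

val model = new MinHash()
  .setInputCol("keys")
  .setOutputCol("hashes")
  .setOutputDim(1)
  .fit(df)

model.transform(df).show() // adds the "hashes" column

// Passing a non-vector column now fails in transformSchema before any job runs:
// new MinHash().setInputCol("id").setOutputCol("h").fit(df) // IllegalArgumentException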
@@ -22,9 +22,10 @@ import scala.util.Random
 import breeze.linalg.normalize

 import org.apache.spark.annotation.{Experimental, Since}
-import org.apache.spark.ml.linalg.{BLAS, Vector, Vectors}
+import org.apache.spark.ml.linalg.{BLAS, Vector, Vectors, VectorUDT}
 import org.apache.spark.ml.param.{DoubleParam, Params, ParamValidators}
 import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.sql.types.StructType

 /**
  * Params for [[RandomProjection]].
@@ -43,7 +44,7 @@ private[ml] trait RandomProjectionParams extends Params {
 }

 /**
- * Model produced by [[LSH]]
+ * Model produced by [[RandomProjection]]
  */
 @Experimental
 @Since("2.1.0")
@@ -116,4 +117,11 @@ class RandomProjection private[ml] (
     }
     new RandomProjectionModel(uid, randUnitVectors)
   }
+
+  @Since("2.1.0")
+  override def transformSchema(schema: StructType): StructType = {
+    require(schema.apply($(inputCol)).dataType.sameType(new VectorUDT),
+      s"${$(inputCol)} must be vectors")
+    validateAndTransformSchema(schema)
+  }
 }
Reviewer: These Since annotations within LSHParams, LSH, and LSHModel need to be removed as well. They are correct now, but if a new subclass is added in, say, Spark 2.2, they will be incorrect for that subclass. Sorry for misdirecting before!