[SPARK-18408][ML] API Improvements for LSH #15874
LSH.scala
```diff
@@ -33,10 +33,10 @@ import org.apache.spark.sql.types._
   */
 private[ml] trait LSHParams extends HasInputCol with HasOutputCol {
   /**
-   * Param for the dimension of LSH OR-amplification.
+   * Param for the number of hash tables used in LSH OR-amplification.
    *
-   * LSH OR-amplification can be used to reduce the false negative rate. The higher the dimension
-   * is, the lower the false negative rate.
+   * LSH OR-amplification can be used to reduce the false negative rate. Higher values for this
+   * param lead to a reduced false negative rate, at the expense of added computational complexity.
    * @group param
    */
   final val numHashTables: IntParam = new IntParam(this, "numHashTables", "number of hash " +
```
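The renamed param reads naturally with the usual setter convention. A minimal sketch of configuring it; the column names are illustrative, not taken from this PR:

```scala
import org.apache.spark.ml.feature.MinHashLSH

// More hash tables lower the false negative rate of OR-amplification,
// at the cost of extra hashing work and storage per row.
val mh = new MinHashLSH()
  .setNumHashTables(5)
  .setInputCol("features")
  .setOutputCol("hashes")
```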
```diff
@@ -66,7 +66,7 @@ private[ml] abstract class LSHModel[T <: LSHModel[T]]
   self: T =>

   /**
-   * The hash function of LSH, mapping an input feature to multiple vectors
+   * The hash function of LSH, mapping an input feature vector to multiple hash vectors.
    * @return The mapping of LSH function.
    */
   protected[ml] val hashFunction: Vector => Array[Vector]
```
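To make the contract concrete, here is a toy hash family with the same shape, `Vector => Array[Vector]`. The coefficients and modulus are made up for illustration and are not the PR's actual hashing scheme:

```scala
import org.apache.spark.ml.linalg.{Vector, Vectors}

// Each of the numHashTables entries is one hash vector; under OR-amplification,
// two inputs become candidate neighbors if ANY of their hash vectors match.
def toyHashFunction(numHashTables: Int): Vector => Array[Vector] = {
  val coefficients = Array.tabulate(numHashTables)(i => 31 * i + 7)
  (v: Vector) =>
    coefficients.map(c => Vectors.dense((v.toArray.sum * c).floor % 11))
}
```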
```diff
@@ -99,26 +99,7 @@ private[ml] abstract class LSHModel[T <: LSHModel[T]]
     validateAndTransformSchema(schema)
   }

-  /**
-   * Given a large dataset and an item, approximately find at most k items which have the closest
-   * distance to the item. If the [[outputCol]] is missing, the method will transform the data; if
-   * the [[outputCol]] exists, it will use the [[outputCol]]. This allows caching of the
-   * transformed data when necessary.
-   *
-   * This method implements two ways of fetching k nearest neighbors:
-   *  - Single-probe: Fast, returns at most k elements (probing only one bucket)
-   *  - Multi-probe: Slow, returns exactly k elements (probing multiple buckets close to the key)
-   *
-   * Currently it is made private since more discussion is needed for Multi-probe.
-   *
-   * @param dataset the dataset to search for nearest neighbors of the key
-   * @param key Feature vector representing the item to search for
-   * @param numNearestNeighbors The maximum number of nearest neighbors
-   * @param singleProbe True for using single-probe; false for multi-probe
-   * @param distCol Output column for storing the distance between each result row and the key
-   * @return A dataset containing at most k items closest to the key. A distCol is added to show
-   *         the distance between each row and the key.
-   */
+  // TODO: Fix the MultiProbe NN Search in SPARK-18454
   private[feature] def approxNearestNeighbors(
       dataset: Dataset[_],
       key: Vector,
```
```diff
@@ -179,7 +160,7 @@ private[ml] abstract class LSHModel[T <: LSHModel[T]]
    * @return A dataset containing at most k items closest to the key. A distCol is added to show
    *         the distance between each row and the key.
    */
-  private[feature] def approxNearestNeighbors(
+  def approxNearestNeighbors(
       dataset: Dataset[_],
       key: Vector,
       numNearestNeighbors: Int,
```
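With the single-probe variant now public, calling it looks like the sketch below. Here `df` is assumed to be an input DataFrame with a "features" column; the key vector and counts are made up:

```scala
import org.apache.spark.ml.feature.MinHashLSH
import org.apache.spark.ml.linalg.Vectors

val key = Vectors.sparse(10, Array(1, 3), Array(1.0, 1.0))
val model = new MinHashLSH()
  .setNumHashTables(3)
  .setInputCol("features")
  .setOutputCol("hashes")
  .fit(df)

// Returns at most 3 rows of df closest to key, with a distance column added.
model.approxNearestNeighbors(df, key, 3).show()
```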
MinHashLSH.scala
```diff
@@ -32,31 +32,31 @@ import org.apache.spark.sql.types.StructType
  * :: Experimental ::
  *
  * Model produced by [[MinHashLSH]], where multiple hash functions are stored. Each hash function is
- * a perfect hash function for a specific set `S` with cardinality equal to a half of `numEntries`:
- * `h_i(x) = ((x \cdot k_i) \mod prime) \mod numEntries`
+ * a perfect hash function for a specific set `S` with cardinality equal to `numEntries`:
+ * `h_i(x) = ((x \cdot a_i + b_i) \mod prime) \mod numEntries`
  *
  * @param numEntries The number of entries of the hash functions.
  * @param randCoefficients An array of random coefficients, each used by one hash function.
  */
 @Experimental
 @Since("2.1.0")
-class MinHashModel private[ml] (
+class MinHashLSHModel private[ml](
     override val uid: String,
-    @Since("2.1.0") private[ml] val numEntries: Int,
-    @Since("2.1.0") private[ml] val randCoefficients: Array[Int])
-  extends LSHModel[MinHashModel] {
+    private[ml] val numEntries: Int,
+    private[ml] val randCoefficients: Array[(Int, Int)])
+  extends LSHModel[MinHashLSHModel] {

   @Since("2.1.0")
   override protected[ml] val hashFunction: Vector => Array[Vector] = {
     elems: Vector => {
       require(elems.numNonzeros > 0, "Must have at least 1 non zero entry.")
       val elemsList = elems.toSparse.indices.toList
-      val hashValues = randCoefficients.map({ randCoefficient: Int =>
-        elemsList.map({ elem: Int =>
-          (1 + elem) * randCoefficient.toLong % MinHashLSH.prime % numEntries
-        }).min.toDouble
+      val hashValues = randCoefficients.map({ case (a: Int, b: Int) =>
+        elemsList.map { elem: Int =>
+          ((1 + elem) * a + b) % MinHashLSH.HASH_PRIME % numEntries
+        }.min.toDouble
       })
-      // TODO: For AND-amplification, output vectors of dimension numHashFunctions
+      // TODO: Output vectors of dimension numHashFunctions in SPARK-18450
       hashValues.grouped(1).map(Vectors.dense).toArray
     }
   }
```
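As a sanity check on the new hash form, here is the arithmetic worked out for one made-up hash function with coefficients `(a, b) = (7, 3)` over the index set {2, 3, 5}:

```scala
val HASH_PRIME = 2038074743        // mirrors MinHashLSH.HASH_PRIME
val numEntries = 10                // made-up size for the example
val (a, b) = (7, 3)
val indices = List(2, 3, 5)        // non-zero indices of the input vector

// ((1+2)*7+3) = 24 -> 4; ((1+3)*7+3) = 31 -> 1; ((1+5)*7+3) = 45 -> 5
val minHash = indices.map(e => ((1 + e) * a + b) % HASH_PRIME % numEntries).min
assert(minHash == 1)
```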
```diff
@@ -74,7 +74,7 @@ class MinHashModel private[ml] (
   @Since("2.1.0")
   override protected[ml] def hashDistance(x: Seq[Vector], y: Seq[Vector]): Double = {
     // Since it's generated by hashing, it will be a pair of dense vectors.
-    // TODO: This hashDistance function is controversial. Requires more discussion.
+    // TODO: This hashDistance function requires more discussion in SPARK-18454
     x.zip(y).map(vectorPair =>
```
[Review comment (Contributor)] At this point, I'm quite unsure, but this does not look to me like what was discussed here. @jkbradley Can you confirm this is what you wanted?

[Review comment (Author)] Since it's still under discussion, I am not sure which
```diff
       vectorPair._1.toArray.zip(vectorPair._2.toArray).count(pair => pair._1 != pair._2)
     ).min
```
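The behavior under discussion, in isolation: for each pair of hash vectors, count the positions where they disagree, then take the minimum across hash tables. The values below are made up:

```scala
import org.apache.spark.ml.linalg.Vectors

val x = Seq(Vectors.dense(1.0), Vectors.dense(4.0), Vectors.dense(2.0))
val y = Seq(Vectors.dense(1.0), Vectors.dense(3.0), Vectors.dense(2.0))

// Per-pair disagreement counts are 0, 1, 0; the minimum wins.
val dist = x.zip(y).map { case (u, v) =>
  u.toArray.zip(v.toArray).count { case (a, b) => a != b }
}.min
assert(dist == 0)
```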
```diff
@@ -84,7 +84,7 @@ class MinHashModel private[ml] (
   override def copy(extra: ParamMap): this.type = defaultCopy(extra)

   @Since("2.1.0")
-  override def write: MLWriter = new MinHashModel.MinHashModelWriter(this)
+  override def write: MLWriter = new MinHashLSHModel.MinHashLSHModelWriter(this)
 }

 /**
```
```diff
@@ -93,17 +93,17 @@ class MinHashModel private[ml] (
  * LSH class for Jaccard distance.
  *
  * The input can be dense or sparse vectors, but it is more efficient if it is sparse. For example,
- * `Vectors.sparse(10, Array[(2, 1.0), (3, 1.0), (5, 1.0)])`
- * means there are 10 elements in the space. This set contains elem 2, elem 3 and elem 5.
- * Also, any input vector must have at least 1 non-zero indices, and all non-zero values are treated
- * as binary "1" values.
+ * `Vectors.sparse(10, Array((2, 1.0), (3, 1.0), (5, 1.0)))`
+ * means there are 10 elements in the space. This set contains non-zero values at indices 2, 3, and
+ * 5. Also, any input vector must have at least 1 non-zero index, and all non-zero values are
+ * treated as binary "1" values.
  *
  * References:
  * [[https://en.wikipedia.org/wiki/MinHash Wikipedia on MinHash]]
  */
 @Experimental
 @Since("2.1.0")
-class MinHashLSH(override val uid: String) extends LSH[MinHashModel] with HasSeed {
+class MinHashLSH(override val uid: String) extends LSH[MinHashLSHModel] with HasSeed {

   @Since("2.1.0")
   override def setInputCol(value: String): this.type = super.setInputCol(value)
```
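Putting the renamed estimator and the sparse-input convention together, a hedged end-to-end sketch; it assumes an existing SparkSession named `spark`, and the data is invented:

```scala
import org.apache.spark.ml.feature.MinHashLSH
import org.apache.spark.ml.linalg.Vectors

// Two sets over a space of 10 elements, encoded as binary sparse vectors.
val df = spark.createDataFrame(Seq(
  (0, Vectors.sparse(10, Array(2, 3, 5), Array(1.0, 1.0, 1.0))),
  (1, Vectors.sparse(10, Array(3, 5, 7), Array(1.0, 1.0, 1.0)))
)).toDF("id", "features")

val model = new MinHashLSH()
  .setNumHashTables(3)
  .setInputCol("features")
  .setOutputCol("hashes")
  .fit(df)

model.transform(df).show(truncate = false)   // adds the "hashes" column
```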
```diff
@@ -116,21 +116,23 @@ class MinHashLSH(override val uid: String) extends LSH[MinHashModel] with HasSee

   @Since("2.1.0")
   def this() = {
-    this(Identifiable.randomUID("min hash"))
+    this(Identifiable.randomUID("mh-lsh"))
   }

   /** @group setParam */
   @Since("2.1.0")
   def setSeed(value: Long): this.type = set(seed, value)

   @Since("2.1.0")
-  override protected[ml] def createRawLSHModel(inputDim: Int): MinHashModel = {
-    require(inputDim <= MinHashLSH.prime / 2,
-      s"The input vector dimension $inputDim exceeds the threshold ${MinHashLSH.prime / 2}.")
+  override protected[ml] def createRawLSHModel(inputDim: Int): MinHashLSHModel = {
+    require(inputDim <= MinHashLSH.HASH_PRIME,
+      s"The input vector dimension $inputDim exceeds the threshold ${MinHashLSH.HASH_PRIME}.")
     val rand = new Random($(seed))
-    val numEntry = inputDim * 2
-    val randCoofs: Array[Int] = Array.fill($(numHashTables))(1 + rand.nextInt(MinHashLSH.prime - 1))
-    new MinHashModel(uid, numEntry, randCoofs)
+    val numEntry = inputDim
+    val randCoefs: Array[(Int, Int)] = Array.fill(2 * $(numHashTables)) {
+      (1 + rand.nextInt(MinHashLSH.HASH_PRIME - 1), rand.nextInt(MinHashLSH.HASH_PRIME - 1))
+    }
+    new MinHashLSHModel(uid, numEntry, randCoefs)
   }

   @Since("2.1.0")
```
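The coefficient draw is reproducible given the seed, which is what mixing in HasSeed buys the estimator. Standalone, with made-up sizes:

```scala
import scala.util.Random

val HASH_PRIME = 2038074743                  // mirrors MinHashLSH.HASH_PRIME
val rand = new Random(42L)                   // fixed seed: same coefficients every run
val randCoefs: Array[(Int, Int)] = Array.fill(6) {
  (1 + rand.nextInt(HASH_PRIME - 1), rand.nextInt(HASH_PRIME - 1))
}
```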
```diff
@@ -146,46 +148,49 @@ class MinHashLSH(override val uid: String) extends LSH[MinHashModel] with HasSee
 @Since("2.1.0")
 object MinHashLSH extends DefaultParamsReadable[MinHashLSH] {
   // A large prime smaller than sqrt(2^63 − 1)
-  private[ml] val prime = 2038074743
+  private[ml] val HASH_PRIME = 2038074743

   @Since("2.1.0")
   override def load(path: String): MinHashLSH = super.load(path)
 }

 @Since("2.1.0")
-object MinHashModel extends MLReadable[MinHashModel] {
+object MinHashLSHModel extends MLReadable[MinHashLSHModel] {

   @Since("2.1.0")
-  override def read: MLReader[MinHashModel] = new MinHashModelReader
+  override def read: MLReader[MinHashLSHModel] = new MinHashLSHModelReader

   @Since("2.1.0")
-  override def load(path: String): MinHashModel = super.load(path)
+  override def load(path: String): MinHashLSHModel = super.load(path)

-  private[MinHashModel] class MinHashModelWriter(instance: MinHashModel) extends MLWriter {
+  private[MinHashLSHModel] class MinHashLSHModelWriter(instance: MinHashLSHModel)
+    extends MLWriter {

     private case class Data(numEntries: Int, randCoefficients: Array[Int])

     override protected def saveImpl(path: String): Unit = {
       DefaultParamsWriter.saveMetadata(instance, path, sc)
-      val data = Data(instance.numEntries, instance.randCoefficients)
+      val data = Data(instance.numEntries, instance.randCoefficients
+        .flatMap(tuple => Array(tuple._1, tuple._2)))
       val dataPath = new Path(path, "data").toString
       sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
     }
   }

-  private class MinHashModelReader extends MLReader[MinHashModel] {
+  private class MinHashLSHModelReader extends MLReader[MinHashLSHModel] {

     /** Checked against metadata when loading model */
-    private val className = classOf[MinHashModel].getName
+    private val className = classOf[MinHashLSHModel].getName

-    override def load(path: String): MinHashModel = {
+    override def load(path: String): MinHashLSHModel = {
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)

       val dataPath = new Path(path, "data").toString
       val data = sparkSession.read.parquet(dataPath).select("numEntries", "randCoefficients").head()
       val numEntries = data.getAs[Int](0)
-      val randCoefficients = data.getAs[Seq[Int]](1).toArray
-      val model = new MinHashModel(metadata.uid, numEntries, randCoefficients)
+      val randCoefficients = data.getAs[Seq[Int]](1).grouped(2)
+        .map(tuple => (tuple(0), tuple(1))).toArray
+      val model = new MinHashLSHModel(metadata.uid, numEntries, randCoefficients)

       DefaultParamsReader.getAndSetParams(model, metadata)
       model
```
[Review comment (Contributor)] style nit: This should go on the previous line.

[Review comment (Author)] Done.
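The persistence change above stores the `Array[(Int, Int)]` coefficients by flattening them into an `Array[Int]` for Parquet and regrouping consecutive pairs on load. The round-trip, in isolation with made-up coefficients:

```scala
val coefs = Array((7, 3), (11, 5))

// Save path: flatten pairs into a flat Array[Int] that Parquet can store.
val flat = coefs.flatMap(t => Array(t._1, t._2))          // Array(7, 3, 11, 5)

// Load path: regroup consecutive elements back into (a, b) tuples.
val restored = flat.grouped(2).map(t => (t(0), t(1))).toArray
assert(restored.sameElements(coefs))
```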