Commits (24)
559c099  [SPARK-18334] MinHash should use binary hash distance (Nov 7, 2016)
517a97b  Remove misleading documentation as requested (Yunni, Nov 8, 2016)
b546dbd  Add warning for multi-probe in MinHash (Nov 8, 2016)
a3cd928  Merge branch 'SPARK-18334-yunn-minhash-bug' of https://github.com/Yun… (Nov 8, 2016)
c8243c7  (1) Fix documentation as CR suggested (2) Fix typo in unit test (Nov 9, 2016)
6aac8b3  Fix typo in unit test (Nov 9, 2016)
9870743  [SPARK-18408] API Improvements for LSH (Nov 14, 2016)
0e9250b  (1) Fix description for numHashFunctions (2) Make numEntries in MinHa… (Nov 14, 2016)
adbbefe  Add assertion for hashFunction in BucketedRandomProjectionLSHSuite (Nov 14, 2016)
c115ed3  Revert AND-amplification for a future PR (Nov 14, 2016)
033ae5d  Code Review Comments (Nov 15, 2016)
c597f4c  Add unit tests to run on Jenkins. (Nov 16, 2016)
d759875  Add unit tests to run on Jenkins. (Nov 16, 2016)
596eb06  CR comments (Nov 17, 2016)
00d08bf  Merge branch 'master' of https://github.com/apache/spark into SPARK-1… (Nov 17, 2016)
3d0810f  Update comments (Nov 17, 2016)
257ef19  Add scaladoc for approximately min-wise independence (Yunni, Nov 18, 2016)
2c264b7  Change documentation reference (Yunni, Nov 18, 2016)
36ca278  Removing modulo numEntries (Nov 19, 2016)
4508393  Merge branch 'SPARK-18408-yunn-api-improvements' of https://github.co… (Nov 19, 2016)
939e9d5  Code Review Comments (Nov 22, 2016)
8b9403d  Minimize the test cases by directly using artificial models (Nov 22, 2016)
f0ebcb7  Code review comments (Nov 22, 2016)
e198080  Merge branch 'master' of https://github.com/apache/spark into SPARK-1… (Yunni, Nov 28, 2016)
Viewing changes from a single commit: 033ae5db1092ab2cd426f974c3e8de594461ca20 ("Code Review Comments"), committed by Yun Ni on Nov 15, 2016.
mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala
@@ -36,7 +36,7 @@ import org.apache.spark.sql.types.StructType
*
* Params for [[BucketedRandomProjectionLSH]].
*/
private[ml] trait BucketedRandomProjectionParams extends Params {
private[ml] trait BucketedRandomProjectionLSHParams extends Params {

/**
* The length of each hash bucket, a larger bucket lowers the false negative rate. The number of
@@ -68,18 +68,18 @@ private[ml] trait BucketedRandomProjectionParams extends Params {
*/
@Experimental
@Since("2.1.0")
class BucketedRandomProjectionModel private[ml](
class BucketedRandomProjectionLSHModel private[ml](
override val uid: String,
@Since("2.1.0") private[ml] val randUnitVectors: Array[Vector])
extends LSHModel[BucketedRandomProjectionModel] with BucketedRandomProjectionParams {
private[ml] val randUnitVectors: Array[Vector])
extends LSHModel[BucketedRandomProjectionLSHModel] with BucketedRandomProjectionLSHParams {

@Since("2.1.0")
override protected[ml] val hashFunction: Vector => Array[Vector] = {
key: Vector => {
val hashValues: Array[Double] = randUnitVectors.map({
randUnitVector => Math.floor(BLAS.dot(key, randUnitVector) / $(bucketLength))
})
// TODO: For AND-amplification, output vectors of dimension numHashFunctions
// TODO: Output vectors of dimension numHashFunctions in SPARK-18450
hashValues.grouped(1).map(Vectors.dense).toArray
}
}
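
For reference while reviewing this hunk: a minimal standalone sketch of the projection hash, with plain Scala arrays standing in for Spark's Vector and BLAS.dot. It assumes the projection vectors are already unit-normalized and bucketLength > 0; the object and method names are illustrative, not part of this PR.

```scala
// Standalone sketch of hashFunction above: project the key onto each random
// unit vector, then floor(dot / bucketLength) gives the bucket index along
// that direction. Names here are hypothetical, not Spark API.
object BrpHashSketch {
  def hash(key: Array[Double],
           randUnitVectors: Seq[Array[Double]],
           bucketLength: Double): Array[Double] =
    randUnitVectors.map { u =>
      val dot = key.zip(u).map { case (k, v) => k * v }.sum
      math.floor(dot / bucketLength) // bucket index along this projection
    }.toArray

  def main(args: Array[String]): Unit = {
    val key = Array(1.0, 2.0)
    val units = Seq(Array(1.0, 0.0), Array(0.0, 1.0))
    println(hash(key, units, 0.5).mkString(", ")) // prints: 2.0, 4.0
  }
}
```
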
@@ -100,7 +100,7 @@ class BucketedRandomProjectionModel private[ml](

@Since("2.1.0")
override def write: MLWriter = {
new BucketedRandomProjectionModel.BucketedRandomProjectionModelWriter(this)
new BucketedRandomProjectionLSHModel.BucketedRandomProjectionLSHModelWriter(this)
}
}

@@ -111,8 +111,8 @@ class BucketedRandomProjectionModel private[ml](
* Euclidean distance metrics.
*
* The input is dense or sparse vectors, each of which represents a point in the Euclidean
* distance space. The output will be vectors of configurable dimension. Hash value in the same
* dimension is calculated by the same hash function.
* distance space. The output will be vectors of configurable dimension. Hash values in the
* same dimension are calculated by the same hash function.
*
* References:
*
@@ -125,7 +125,8 @@ class BucketedRandomProjectionModel private[ml](
@Experimental
@Since("2.1.0")
class BucketedRandomProjectionLSH(override val uid: String)
extends LSH[BucketedRandomProjectionModel] with BucketedRandomProjectionParams with HasSeed {
extends LSH[BucketedRandomProjectionLSHModel]
with BucketedRandomProjectionLSHParams with HasSeed {

@Since("2.1.0")
override def setInputCol(value: String): this.type = super.setInputCol(value)
@@ -138,7 +139,7 @@ class BucketedRandomProjectionLSH(override val uid: String)

@Since("2.1.0")
def this() = {
this(Identifiable.randomUID("random projection"))
this(Identifiable.randomUID("brp-lsh"))
}

/** @group setParam */
@@ -150,15 +151,17 @@ class BucketedRandomProjectionLSH(override val uid: String)
def setSeed(value: Long): this.type = set(seed, value)

@Since("2.1.0")
override protected[this] def createRawLSHModel(inputDim: Int): BucketedRandomProjectionModel = {
override protected[this] def createRawLSHModel(
inputDim: Int
): BucketedRandomProjectionLSHModel = {
Review comment (Member): style nit: This should go on the previous line.
Reply (Contributor Author): Done.

val rand = new Random($(seed))
val randUnitVectors: Array[Vector] = {
Array.fill($(numHashTables)) {
val randArray = Array.fill(inputDim)(rand.nextGaussian())
Vectors.fromBreeze(normalize(breeze.linalg.Vector(randArray)))
}
}
new BucketedRandomProjectionModel(uid, randUnitVectors)
new BucketedRandomProjectionLSHModel(uid, randUnitVectors)
}
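
As a side note on this hunk: the Gaussian-then-normalize construction can be sketched without Breeze. The helper name below is hypothetical; it only illustrates why this yields a direction uniform on the unit sphere.

```scala
// Standalone sketch of the randUnitVectors construction above: fill a vector
// with i.i.d. Gaussians and normalize it. Because the Gaussian is rotation
// invariant, the resulting direction is uniform on the unit sphere.
import scala.util.Random

def randUnitVector(dim: Int, rand: Random): Array[Double] = {
  val g = Array.fill(dim)(rand.nextGaussian())
  val norm = math.sqrt(g.map(x => x * x).sum)
  g.map(_ / norm)
}
```
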

@Since("2.1.0")
@@ -179,18 +182,18 @@ object BucketedRandomProjectionLSH extends DefaultParamsReadable[BucketedRandomP
}

@Since("2.1.0")
object BucketedRandomProjectionModel extends MLReadable[BucketedRandomProjectionModel] {
object BucketedRandomProjectionLSHModel extends MLReadable[BucketedRandomProjectionLSHModel] {

@Since("2.1.0")
override def read: MLReader[BucketedRandomProjectionModel] = {
new BucketedRandomProjectionModelReader
override def read: MLReader[BucketedRandomProjectionLSHModel] = {
new BucketedRandomProjectionLSHModelReader
}

@Since("2.1.0")
override def load(path: String): BucketedRandomProjectionModel = super.load(path)
override def load(path: String): BucketedRandomProjectionLSHModel = super.load(path)

private[BucketedRandomProjectionModel] class BucketedRandomProjectionModelWriter(
instance: BucketedRandomProjectionModel) extends MLWriter {
private[BucketedRandomProjectionLSHModel] class BucketedRandomProjectionLSHModelWriter(
instance: BucketedRandomProjectionLSHModel) extends MLWriter {

// TODO: Save using the existing format of Array[Vector] once SPARK-12878 is resolved.
private case class Data(randUnitVectors: Matrix)
@@ -208,21 +211,22 @@ object BucketedRandomProjectionModel extends MLReadable[BucketedRandomProjection
}
}

private class BucketedRandomProjectionModelReader
extends MLReader[BucketedRandomProjectionModel] {
private class BucketedRandomProjectionLSHModelReader
extends MLReader[BucketedRandomProjectionLSHModel] {

/** Checked against metadata when loading model */
private val className = classOf[BucketedRandomProjectionModel].getName
private val className = classOf[BucketedRandomProjectionLSHModel].getName

override def load(path: String): BucketedRandomProjectionModel = {
override def load(path: String): BucketedRandomProjectionLSHModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)

val dataPath = new Path(path, "data").toString
val data = sparkSession.read.parquet(dataPath)
val Row(randUnitVectors: Matrix) = MLUtils.convertMatrixColumnsToML(data, "randUnitVectors")
.select("randUnitVectors")
.head()
val model = new BucketedRandomProjectionModel(metadata.uid, randUnitVectors.rowIter.toArray)
val model = new BucketedRandomProjectionLSHModel(metadata.uid,
randUnitVectors.rowIter.toArray)

DefaultParamsReader.getAndSetParams(model, metadata)
model
31 changes: 6 additions & 25 deletions mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala
@@ -33,10 +33,10 @@ import org.apache.spark.sql.types._
*/
private[ml] trait LSHParams extends HasInputCol with HasOutputCol {
/**
* Param for the dimension of LSH OR-amplification.
* Param for the number of hash tables used in LSH OR-amplification.
*
* LSH OR-amplification can be used to reduce the false negative rate. The higher the dimension
* is, the lower the false negative rate.
* LSH OR-amplification can be used to reduce the false negative rate. Higher values for this
* param lead to a reduced false negative rate, at the expense of added computational complexity.
* @group param
*/
final val numHashTables: IntParam = new IntParam(this, "numHashTables", "number of hash " +
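
To make the trade-off concrete: if two items collide in a single hash table with probability p, OR-amplification over L tables reports them as candidates with probability 1 - (1 - p)^L. A small illustrative computation, not part of this PR:

```scala
// Illustrative only: probability that two items collide in at least one of
// L hash tables, given per-table collision probability p.
def collisionProb(p: Double, numHashTables: Int): Double =
  1.0 - math.pow(1.0 - p, numHashTables)

// For p = 0.2: L = 1 -> 0.200, L = 5 -> 0.672, L = 10 -> 0.893,
// i.e. more tables sharply cut false negatives, at extra compute cost.
Seq(1, 5, 10).foreach { l => println(f"L=$l%2d -> ${collisionProb(0.2, l)}%.3f") }
```
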
@@ -66,7 +66,7 @@ private[ml] abstract class LSHModel[T <: LSHModel[T]]
self: T =>

/**
* The hash function of LSH, mapping an input feature to multiple vectors
* The hash function of LSH, mapping an input feature vector to multiple hash vectors.
* @return The mapping of LSH function.
*/
protected[ml] val hashFunction: Vector => Array[Vector]
@@ -99,26 +99,7 @@ private[ml] abstract class LSHModel[T <: LSHModel[T]]
validateAndTransformSchema(schema)
}

/**
* Given a large dataset and an item, approximately find at most k items which have the closest
* distance to the item. If the [[outputCol]] is missing, the method will transform the data; if
* the [[outputCol]] exists, it will use the [[outputCol]]. This allows caching of the
* transformed data when necessary.
*
* This method implements two ways of fetching k nearest neighbors:
* - Single-probe: Fast, return at most k elements (Probing only one buckets)
* - Multi-probe: Slow, return exact k elements (Probing multiple buckets close to the key)
*
* Currently it is made private since more discussion is needed for Multi-probe
*
* @param dataset the dataset to search for nearest neighbors of the key
* @param key Feature vector representing the item to search for
* @param numNearestNeighbors The maximum number of nearest neighbors
* @param singleProbe True for using single-probe; false for multi-probe
* @param distCol Output column for storing the distance between each result row and the key
* @return A dataset containing at most k items closest to the key. A distCol is added to show
* the distance between each row and the key.
*/
// TODO: Fix the MultiProbe NN Search in SPARK-18454
private[feature] def approxNearestNeighbors(
dataset: Dataset[_],
key: Vector,
@@ -179,7 +160,7 @@ private[ml] abstract class LSHModel[T <: LSHModel[T]]
* @return A dataset containing at most k items closest to the key. A distCol is added to show
Review comment (Contributor): minor: A column "distCol" is added ...
Reply (Contributor Author): Done.

* the distance between each row and the key.
*/
private[feature] def approxNearestNeighbors(
def approxNearestNeighbors(
dataset: Dataset[_],
key: Vector,
numNearestNeighbors: Int,
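Since this overload is now public, a hedged usage sketch may help. It assumes a live SparkSession named `spark`, the MinHashLSH estimator from this PR, and the standard setNumHashTables setter; the data values are made up for illustration.

```scala
// Hedged usage sketch of the now-public approxNearestNeighbors overload.
import org.apache.spark.ml.feature.MinHashLSH
import org.apache.spark.ml.linalg.Vectors

val df = spark.createDataFrame(Seq(
  (0, Vectors.sparse(6, Seq((0, 1.0), (1, 1.0)))),
  (1, Vectors.sparse(6, Seq((2, 1.0), (3, 1.0)))),
  (2, Vectors.sparse(6, Seq((0, 1.0), (4, 1.0))))
)).toDF("id", "features")

val model = new MinHashLSH()
  .setInputCol("features")
  .setOutputCol("hashes")
  .setNumHashTables(3)
  .fit(df)

val key = Vectors.sparse(6, Seq((0, 1.0), (1, 1.0)))
// Returns at most 2 rows; a distance column ("distCol") is appended.
model.approxNearestNeighbors(df, key, 2).show()
```
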
77 changes: 41 additions & 36 deletions mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala
@@ -32,31 +32,31 @@ import org.apache.spark.sql.types.StructType
* :: Experimental ::
*
* Model produced by [[MinHashLSH]], where multiple hash functions are stored. Each hash function is
* a perfect hash function for a specific set `S` with cardinality equal to a half of `numEntries`:
* `h_i(x) = ((x \cdot k_i) \mod prime) \mod numEntries`
* a perfect hash function for a specific set `S` with cardinality equal to `numEntries`:
Review comment (Member): Looking more at the Wikipedia entry, I'm still doubtful about whether this is a perfect hash function. It looks like the first of 2 parts in the construction of a perfect hash function. I also still don't see why mentioning "perfect hash functions" will help users.
Reply (Contributor Author): Yes, removed.

* `h_i(x) = ((x \cdot a_i + b_i) \mod prime) \mod numEntries`
Review comment (Contributor): We should remove the numEntries part here if we have removed it from the code. I still haven't gotten around to properly digging into this. For now, I'd like to change the second sentence to: "Each hash function is picked from the following family of hash functions, where a_i and b_i are randomly chosen integers less than prime:"
And I prefer this paper: "http://www.combinatorics.org/ojs/index.php/eljc/article/download/v7i1r26/pdf" as a reference because it is concise and easier to parse. That said, since it's a direct download link we could maybe not put the link in the doc, and just list the reference.
Reply (Contributor Author): Fixed

*
* @param numEntries The number of entries of the hash functions.
* @param randCoefficients An array of random coefficients, each used by one hash function.
Review comment (Member): Need to update description
Reply (Contributor Author): Updated

*/
@Experimental
@Since("2.1.0")
class MinHashModel private[ml] (
class MinHashLSHModel private[ml](
override val uid: String,
@Since("2.1.0") private[ml] val numEntries: Int,
@Since("2.1.0") private[ml] val randCoefficients: Array[Int])
extends LSHModel[MinHashModel] {
private[ml] val numEntries: Int,
private[ml] val randCoefficients: Array[(Int, Int)])
extends LSHModel[MinHashLSHModel] {

@Since("2.1.0")
override protected[ml] val hashFunction: Vector => Array[Vector] = {
elems: Vector => {
require(elems.numNonzeros > 0, "Must have at least 1 non zero entry.")
val elemsList = elems.toSparse.indices.toList
val hashValues = randCoefficients.map({ randCoefficient: Int =>
elemsList.map({ elem: Int =>
(1 + elem) * randCoefficient.toLong % MinHashLSH.prime % numEntries
}).min.toDouble
val hashValues = randCoefficients.map({ case (a: Int, b: Int) =>
Review comment (Contributor): minor: the "({" is redundant. Also, I don't think the type annotations are necessary.
Reply (Contributor Author): Removed the parentheses.

elemsList.map { elem: Int =>
((1 + elem) * a + b) % MinHashLSH.HASH_PRIME % numEntries
Review comment (Contributor): I'm still looking at it, but I don't think this is correct. Why do we tack on % numEntries here? Could you point me to a resource? The paper linked above (and many other references that I've seen) use (ax + b) mod p, where p is a large prime.
I see the formula listed under the wiki article for perfect hashing functions lists (kx mod p) mod n, but that's not the full picture. They are referencing a paper which simply uses that formula as the first part of a multilevel scheme.
If it's helpful, this seems to be the original paper on MinHash. The author mentions: "This is further explored in [5] where it is shown that random linear transformations are likely to suffice in practice." Reference 5 is here, which seems to be a more concise version of your reference. In that paper, they describe (ax + b) mod p.

}.min.toDouble
})
// TODO: For AND-amplification, output vectors of dimension numHashFunctions
// TODO: Output vectors of dimension numHashFunctions in SPARK-18450
hashValues.grouped(1).map(Vectors.dense).toArray
Review comment (Contributor): Why not: hashValues.map(Vectors.dense(_))? We can just add the grouping later when it's needed. Same for BRP.
Reply (Contributor Author): Vectors.dense takes an array instead of a single number.
Reply (Contributor): There is an alternate constructor which takes a single value (or multiple values). I guess I just think the grouped(1) is a bit confusing, not really an efficiency concern.
Reply (Contributor Author): I see. It's dense(firstValue: Double, otherValues: Double*).

}
}
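
For reviewers checking the arithmetic, here is a standalone sketch of the min-hash computation above. It is plain Scala with hypothetical names, widens to Long to sidestep Int overflow, and keeps the mod-numEntries step even though that step is still under discussion in the thread above.

```scala
// Standalone sketch of hashFunction above: for each (a, b) coefficient pair,
// hash every active index of the sparse vector and keep the minimum.
object MinHashSketch {
  val HashPrime = 2038074743 // mirrors MinHashLSH.HASH_PRIME in the diff

  def minHash(activeIndices: Seq[Int],
              coefficients: Seq[(Int, Int)],
              numEntries: Int): Array[Double] =
    coefficients.map { case (a, b) =>
      activeIndices.map { elem =>
        // Widened to Long so (1 + elem) * a + b cannot overflow Int.
        (((1 + elem).toLong * a + b) % HashPrime % numEntries).toDouble
      }.min
    }.toArray

  def main(args: Array[String]): Unit = {
    // The set {2, 3, 5} from the scaladoc example, two hash functions.
    val indices = Seq(2, 3, 5)
    val coefs = Seq((3, 7), (11, 1))
    println(minHash(indices, coefs, 10).mkString(", ")) // prints: 5.0, 4.0
  }
}
```
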
@@ -74,7 +74,7 @@ class MinHashModel private[ml] (
@Since("2.1.0")
override protected[ml] def hashDistance(x: Seq[Vector], y: Seq[Vector]): Double = {
// Since it's generated by hashing, it will be a pair of dense vectors.
// TODO: This hashDistance function is controversial. Requires more discussion.
// TODO: This hashDistance function requires more discussion in SPARK-18454
x.zip(y).map(vectorPair =>
Review comment (Contributor): At this point, I'm quite unsure, but this does not look to me like what was discussed here. @jkbradley Can you confirm this is what you wanted?
Reply (Contributor Author): Since it's still under discussion, I am not sure which hashDistance to leave in the code. Do you just want me to change to the hashDistance @jkbradley suggested?

vectorPair._1.toArray.zip(vectorPair._2.toArray).count(pair => pair._1 != pair._2)
).min
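
While the TODO above is being discussed, a standalone restatement of this binary hash distance may help; plain Scala arrays stand in for ml.linalg vectors, illustrative only.

```scala
// Restatement of hashDistance above: for each pair of hash vectors, count the
// coordinates that differ, then take the minimum over all pairs.
def hashDistance(x: Seq[Array[Double]], y: Seq[Array[Double]]): Double =
  x.zip(y).map { case (u, v) =>
    u.zip(v).count { case (a, b) => a != b }
  }.min.toDouble

// ([1,2] vs [1,3]) differs in 1 slot; ([4] vs [4]) in 0 -> distance 0.0
println(hashDistance(Seq(Array(1.0, 2.0), Array(4.0)),
                     Seq(Array(1.0, 3.0), Array(4.0))))
```
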
@@ -84,7 +84,7 @@ class MinHashModel private[ml] (
override def copy(extra: ParamMap): this.type = defaultCopy(extra)

@Since("2.1.0")
override def write: MLWriter = new MinHashModel.MinHashModelWriter(this)
override def write: MLWriter = new MinHashLSHModel.MinHashLSHModelWriter(this)
}

/**
@@ -93,17 +93,17 @@ class MinHashModel private[ml] (
* LSH class for Jaccard distance.
*
* The input can be dense or sparse vectors, but it is more efficient if it is sparse. For example,
* `Vectors.sparse(10, Array[(2, 1.0), (3, 1.0), (5, 1.0)])`
* means there are 10 elements in the space. This set contains elem 2, elem 3 and elem 5.
* Also, any input vector must have at least 1 non-zero indices, and all non-zero values are treated
* as binary "1" values.
* `Vectors.sparse(10, Array((2, 1.0), (3, 1.0), (5, 1.0)))`
* means there are 10 elements in the space. This set contains non-zero values at indices 2, 3, and
Review comment (Member): I prefer the old terminology since all non-zero values are treated the same. How about "This set contains elements 2, 3, and 5."?

* 5. Also, any input vector must have at least 1 non-zero index, and all non-zero values are
* treated as binary "1" values.
*
* References:
* [[https://en.wikipedia.org/wiki/MinHash Wikipedia on MinHash]]
*/
@Experimental
@Since("2.1.0")
class MinHashLSH(override val uid: String) extends LSH[MinHashModel] with HasSeed {
class MinHashLSH(override val uid: String) extends LSH[MinHashLSHModel] with HasSeed {

@Since("2.1.0")
override def setInputCol(value: String): this.type = super.setInputCol(value)
@@ -116,21 +116,23 @@ class MinHashLSH(override val uid: String) extends LSH[MinHashModel] with HasSee

@Since("2.1.0")
def this() = {
this(Identifiable.randomUID("min hash"))
this(Identifiable.randomUID("mh-lsh"))
}

/** @group setParam */
@Since("2.1.0")
def setSeed(value: Long): this.type = set(seed, value)

@Since("2.1.0")
override protected[ml] def createRawLSHModel(inputDim: Int): MinHashModel = {
require(inputDim <= MinHashLSH.prime / 2,
s"The input vector dimension $inputDim exceeds the threshold ${MinHashLSH.prime / 2}.")
override protected[ml] def createRawLSHModel(inputDim: Int): MinHashLSHModel = {
require(inputDim <= MinHashLSH.HASH_PRIME,
s"The input vector dimension $inputDim exceeds the threshold ${MinHashLSH.HASH_PRIME}.")
val rand = new Random($(seed))
val numEntry = inputDim * 2
val randCoofs: Array[Int] = Array.fill($(numHashTables))(1 + rand.nextInt(MinHashLSH.prime - 1))
new MinHashModel(uid, numEntry, randCoofs)
val numEntry = inputDim
val randCoefs: Array[(Int, Int)] = Array.fill(2 * $(numHashTables)) {
Review comment (Contributor): Why is it 2 * $(numHashTables)?
Reply (Contributor Author): Fixed
Reply (Contributor): If this was an error before, we should have a unit test that catches this. Basically, the output of transform should be a vector of length equal to numHashTables.
Reply (Contributor Author): Unit tests added in LSHTest.scala

(1 + rand.nextInt(MinHashLSH.HASH_PRIME - 1), rand.nextInt(MinHashLSH.HASH_PRIME - 1))
}
new MinHashLSHModel(uid, numEntry, randCoefs)
}

@Since("2.1.0")
@@ -146,46 +148,49 @@ class MinHashLSH(override val uid: String) extends LSH[MinHashModel] with HasSee
@Since("2.1.0")
object MinHashLSH extends DefaultParamsReadable[MinHashLSH] {
// A large prime smaller than sqrt(2^63 − 1)
private[ml] val prime = 2038074743
private[ml] val HASH_PRIME = 2038074743

@Since("2.1.0")
override def load(path: String): MinHashLSH = super.load(path)
}

@Since("2.1.0")
object MinHashModel extends MLReadable[MinHashModel] {
object MinHashLSHModel extends MLReadable[MinHashLSHModel] {

@Since("2.1.0")
override def read: MLReader[MinHashModel] = new MinHashModelReader
override def read: MLReader[MinHashLSHModel] = new MinHashLSHModelReader

@Since("2.1.0")
override def load(path: String): MinHashModel = super.load(path)
override def load(path: String): MinHashLSHModel = super.load(path)

private[MinHashModel] class MinHashModelWriter(instance: MinHashModel) extends MLWriter {
private[MinHashLSHModel] class MinHashLSHModelWriter(instance: MinHashLSHModel)
extends MLWriter {

private case class Data(numEntries: Int, randCoefficients: Array[Int])

override protected def saveImpl(path: String): Unit = {
DefaultParamsWriter.saveMetadata(instance, path, sc)
val data = Data(instance.numEntries, instance.randCoefficients)
val data = Data(instance.numEntries, instance.randCoefficients
.flatMap(tuple => Array(tuple._1, tuple._2)))
val dataPath = new Path(path, "data").toString
sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
}
}
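
The pair serialization here can be sanity-checked with a tiny round trip; the values below are made up for illustration.

```scala
// Round trip of the (a, b) coefficients used by this writer and the reader
// below: flatten the pairs into a flat Array[Int] for Parquet, then regroup
// in twos on load.
val coefs = Array((3, 7), (11, 1))
val flat = coefs.flatMap { case (a, b) => Array(a, b) }       // Array(3, 7, 11, 1)
val restored = flat.grouped(2).map(p => (p(0), p(1))).toArray // Array((3,7), (11,1))
assert(coefs.sameElements(restored))
```
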

private class MinHashModelReader extends MLReader[MinHashModel] {
private class MinHashLSHModelReader extends MLReader[MinHashLSHModel] {

/** Checked against metadata when loading model */
private val className = classOf[MinHashModel].getName
private val className = classOf[MinHashLSHModel].getName

override def load(path: String): MinHashModel = {
override def load(path: String): MinHashLSHModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)

val dataPath = new Path(path, "data").toString
val data = sparkSession.read.parquet(dataPath).select("numEntries", "randCoefficients").head()
val numEntries = data.getAs[Int](0)
val randCoefficients = data.getAs[Seq[Int]](1).toArray
val model = new MinHashModel(metadata.uid, numEntries, randCoefficients)
val randCoefficients = data.getAs[Seq[Int]](1).grouped(2)
.map(tuple => (tuple(0), tuple(1))).toArray
val model = new MinHashLSHModel(metadata.uid, numEntries, randCoefficients)

DefaultParamsReader.getAndSetParams(model, metadata)
model