Commit ea06e4e

uzadude authored and srowen committed
[SPARK-16469] enhanced simulate multiply
## What changes were proposed in this pull request?

We have a use case of multiplying very large sparse matrices: a multiplication of distributed block matrices of about 1000x1000. There, the simulate multiply step scales like O(n^4) (n being 1000) and takes about 1.5 hours. We modified it slightly to use a classical hash map, and it now runs in about 30 seconds, O(n^2).

## How was this patch tested?

We added a performance test and verified the reduced time.

Author: oraviv <[email protected]>

Closes #14068 from uzadude/master.
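As a rough check on the complexity figures above, the comparison counts can be written out. This is a back-of-the-envelope sketch, not part of the patch; the symbols are introduced here for illustration: $L$ and $R$ are the sets of block coordinates of the left and right matrices, $L^{k}$ is the set of left blocks in block column $k$, and $R_{k}$ is the set of right blocks in block row $k$.

```latex
\begin{align*}
C_{\text{filter}} &= \underbrace{|L|\,|R|}_{\text{left pass}} + \underbrace{|R|\,|L|}_{\text{right pass}} = 2\,|L|\,|R| \\
C_{\text{hashmap}} &= \underbrace{|L| + |R|}_{\text{two groupBy passes}} + 2\sum_{k} |L^{k}|\,|R_{k}|
\end{align*}
```

With a fully populated $n \times n$ block grid, $|L| = |R| = n^2$, so $C_{\text{filter}} = 2n^4$. If every index $k$ has at most $c$ right blocks, then $\sum_k |L^{k}|\,|R_{k}| \le c\,|L| \le c\,n^2$, which is consistent with the $O(n^2)$ behaviour reported for sparse inputs. For a fully dense grid the sum would instead be $n^3$.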
1 parent 51ade51 commit ea06e4e

1 file changed: +9 -4 lines changed

mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala

Lines changed: 9 additions & 4 deletions
@@ -426,16 +426,21 @@ class BlockMatrix @Since("1.3.0") (
       partitioner: GridPartitioner): (BlockDestinations, BlockDestinations) = {
     val leftMatrix = blockInfo.keys.collect() // blockInfo should already be cached
     val rightMatrix = other.blocks.keys.collect()
+
+    val rightCounterpartsHelper = rightMatrix.groupBy(_._1).mapValues(_.map(_._2))
     val leftDestinations = leftMatrix.map { case (rowIndex, colIndex) =>
-      val rightCounterparts = rightMatrix.filter(_._1 == colIndex)
-      val partitions = rightCounterparts.map(b => partitioner.getPartition((rowIndex, b._2)))
+      val rightCounterparts = rightCounterpartsHelper.getOrElse(colIndex, Array())
+      val partitions = rightCounterparts.map(b => partitioner.getPartition((rowIndex, b)))
       ((rowIndex, colIndex), partitions.toSet)
     }.toMap
+
+    val leftCounterpartsHelper = leftMatrix.groupBy(_._2).mapValues(_.map(_._1))
     val rightDestinations = rightMatrix.map { case (rowIndex, colIndex) =>
-      val leftCounterparts = leftMatrix.filter(_._2 == rowIndex)
-      val partitions = leftCounterparts.map(b => partitioner.getPartition((b._1, colIndex)))
+      val leftCounterparts = leftCounterpartsHelper.getOrElse(rowIndex, Array())
+      val partitions = leftCounterparts.map(b => partitioner.getPartition((b, colIndex)))
       ((rowIndex, colIndex), partitions.toSet)
     }.toMap
+
     (leftDestinations, rightDestinations)
   }
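The change in the diff above can be tried outside Spark with a small standalone sketch. It is an illustration under stated assumptions, not the code from `BlockMatrix.scala`: the object name `SimulateMultiplySketch`, the methods `naive` and `indexed`, and the `getPartition` function parameter (a stand-in for `GridPartitioner.getPartition`) are all hypothetical. The sketch uses `map` rather than `mapValues` so that it should compile across recent Scala versions.

```scala
object SimulateMultiplySketch {
  type BlockKey = (Int, Int) // (rowIndex, colIndex) of a block
  type BlockDestinations = Map[BlockKey, Set[Int]]

  // Pre-patch shape: every block scans the whole opposite matrix with filter.
  def naive(left: Array[BlockKey], right: Array[BlockKey],
            getPartition: BlockKey => Int): (BlockDestinations, BlockDestinations) = {
    val leftDest = left.map { case (rowIndex, colIndex) =>
      val rightCounterparts = right.filter(_._1 == colIndex)
      ((rowIndex, colIndex), rightCounterparts.map(b => getPartition((rowIndex, b._2))).toSet)
    }.toMap
    val rightDest = right.map { case (rowIndex, colIndex) =>
      val leftCounterparts = left.filter(_._2 == rowIndex)
      ((rowIndex, colIndex), leftCounterparts.map(b => getPartition((b._1, colIndex))).toSet)
    }.toMap
    (leftDest, rightDest)
  }

  // Post-patch shape: group once, then look counterparts up by key.
  def indexed(left: Array[BlockKey], right: Array[BlockKey],
              getPartition: BlockKey => Int): (BlockDestinations, BlockDestinations) = {
    // right block row -> column indices of the right blocks in that row
    val rightColsByRow: Map[Int, Array[Int]] =
      right.groupBy(_._1).map { case (k, v) => k -> v.map(_._2) }
    val leftDest = left.map { case (rowIndex, colIndex) =>
      val rightCounterparts = rightColsByRow.getOrElse(colIndex, Array.empty[Int])
      ((rowIndex, colIndex), rightCounterparts.map(c => getPartition((rowIndex, c))).toSet)
    }.toMap

    // left block column -> row indices of the left blocks in that column
    val leftRowsByCol: Map[Int, Array[Int]] =
      left.groupBy(_._2).map { case (k, v) => k -> v.map(_._1) }
    val rightDest = right.map { case (rowIndex, colIndex) =>
      val leftCounterparts = leftRowsByCol.getOrElse(rowIndex, Array.empty[Int])
      ((rowIndex, colIndex), leftCounterparts.map(r => getPartition((r, colIndex))).toSet)
    }.toMap
    (leftDest, rightDest)
  }
}
```

Both methods should return the same destinations for any input; only the lookup cost differs. Blocks without a counterpart fall through `getOrElse` to an empty array and so map to an empty partition set, which matches what `filter` yields in the original version.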
