Skip to content

Commit a58c1af

Browse files
committed
[SPARK-100354] [MLLIB] fix some apparent memory issues in k-means|| initializaiton
* do not cache first cost RDD * change following cost RDD cache level to MEMORY_AND_DISK * remove Vector wrapper to save a object per instance Further improvements will be addressed in SPARK-10329 cc: yu-iskw HuJiayin Author: Xiangrui Meng <meng@databricks.com> Closes #8526 from mengxr/SPARK-10354. (cherry picked from commit f0f563a) Signed-off-by: Xiangrui Meng <meng@databricks.com>
1 parent e8b0564 commit a58c1af

File tree

1 file changed

+14
-7
lines changed
  • mllib/src/main/scala/org/apache/spark/mllib/clustering

1 file changed

+14
-7
lines changed

mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -281,7 +281,7 @@ class KMeans private (
281281
: Array[Array[VectorWithNorm]] = {
282282
// Initialize empty centers and point costs.
283283
val centers = Array.tabulate(runs)(r => ArrayBuffer.empty[VectorWithNorm])
284-
var costs = data.map(_ => Vectors.dense(Array.fill(runs)(Double.PositiveInfinity))).cache()
284+
var costs = data.map(_ => Array.fill(runs)(Double.PositiveInfinity))
285285

286286
// Initialize each run's first center to a random point.
287287
val seed = new XORShiftRandom(this.seed).nextInt()
@@ -306,21 +306,28 @@ class KMeans private (
306306
val bcNewCenters = data.context.broadcast(newCenters)
307307
val preCosts = costs
308308
costs = data.zip(preCosts).map { case (point, cost) =>
309-
Vectors.dense(
310309
Array.tabulate(runs) { r =>
311310
math.min(KMeans.pointCost(bcNewCenters.value(r), point), cost(r))
312-
})
313-
}.cache()
311+
}
312+
}.persist(StorageLevel.MEMORY_AND_DISK)
314313
val sumCosts = costs
315-
.aggregate(Vectors.zeros(runs))(
314+
.aggregate(new Array[Double](runs))(
316315
seqOp = (s, v) => {
317316
// s += v
318-
axpy(1.0, v, s)
317+
var r = 0
318+
while (r < runs) {
319+
s(r) += v(r)
320+
r += 1
321+
}
319322
s
320323
},
321324
combOp = (s0, s1) => {
322325
// s0 += s1
323-
axpy(1.0, s1, s0)
326+
var r = 0
327+
while (r < runs) {
328+
s0(r) += s1(r)
329+
r += 1
330+
}
324331
s0
325332
}
326333
)

0 commit comments

Comments
 (0)