@@ -36,11 +36,11 @@ import org.apache.spark.util.Utils
 * independent Gaussian distributions with associated "mixing" weights
 * specifying each distribution's contribution to the composite.
 *
 * Given a set of sample points, this class will maximize the log-likelihood
 * for a mixture of k Gaussians, iterating until the log-likelihood changes by
 * less than convergenceTol, or until it has reached the max number of iterations.
 * While this process is generally guaranteed to converge, it is not guaranteed
 * to find a global optimum.
 *
 * Note: For high-dimensional data (with many features), this algorithm may perform poorly.
 * This is due to high-dimensional data (a) making it difficult to cluster at all (based
@@ -53,24 +53,24 @@ import org.apache.spark.util.Utils
 */
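// A minimal usage sketch (assumes an existing input `points: RDD[Vector]`;
// the setters shown are the ones defined on this class):
//
//   val gmm = new GaussianMixture()
//     .setK(3)
//     .setConvergenceTol(1e-3)
//     .setMaxIterations(200)
//     .run(points)
//   println(gmm.weights.mkString(", "))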
@Experimental
class GaussianMixture private (
    private var k: Int,
    private var convergenceTol: Double,
    private var maxIterations: Int,
    private var seed: Long) extends Serializable {

  /**
   * Constructs a default instance. The default parameters are {k: 2, convergenceTol: 0.01,
   * maxIterations: 100, seed: random}.
   */
  def this() = this(2, 0.01, 100, Utils.random.nextLong())

  // number of samples per cluster to use when initializing Gaussians
  private val nSamples = 5

  // an initializing GMM can be provided rather than using the
  // default random starting point
  private var initialModel: Option[GaussianMixtureModel] = None

  /** Set the initial GMM starting point, bypassing the random initialization.
   *  You must call setK() prior to calling this method, and the condition
   *  (model.k == this.k) must be met; failure will result in an IllegalArgumentException
@@ -83,37 +83,37 @@ class GaussianMixture private (
    }
    this
  }

  /** Return the user supplied initial GMM, if supplied */
  def getInitialModel: Option[GaussianMixtureModel] = initialModel

  /** Set the number of Gaussians in the mixture model. Default: 2 */
  def setK(k: Int): this.type = {
    this.k = k
    this
  }

  /** Return the number of Gaussians in the mixture model */
  def getK: Int = k

  /** Set the maximum number of iterations to run. Default: 100 */
  def setMaxIterations(maxIterations: Int): this.type = {
    this.maxIterations = maxIterations
    this
  }

  /** Return the maximum number of iterations to run */
  def getMaxIterations: Int = maxIterations

  /**
   * Set the largest change in log-likelihood at which convergence is
   * considered to have occurred.
   */
  def setConvergenceTol(convergenceTol: Double): this.type = {
    this.convergenceTol = convergenceTol
    this
  }

  /**
   * Return the largest change in log-likelihood at which convergence is
   * considered to have occurred.
@@ -132,41 +132,41 @@ class GaussianMixture private (
  /** Perform expectation maximization */
  def run(data: RDD[Vector]): GaussianMixtureModel = {
    val sc = data.sparkContext

    // we will operate on the data as breeze data
    val breezeData = data.map(_.toBreeze).cache()

    // Get length of the input vectors
    val d = breezeData.first().length

    // Determine initial weights and corresponding Gaussians.
    // If the user supplied an initial GMM, we use those values, otherwise
    // we start with uniform weights, a random mean from the data, and
    // diagonal covariance matrices using component variances
    // derived from the samples
    val (weights, gaussians) = initialModel match {
      case Some(gmm) => (gmm.weights, gmm.gaussians)

      case None => {
        val samples = breezeData.takeSample(withReplacement = true, k * nSamples, seed)
        (Array.fill(k)(1.0 / k), Array.tabulate(k) { i =>
          val slice = samples.view(i * nSamples, (i + 1) * nSamples)
          new MultivariateGaussian(vectorMean(slice), initCovariance(slice))
        })
      }
    }
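    // With the defaults k = 2 and nSamples = 5, for example, takeSample draws
    // 2 * 5 = 10 points: Gaussian 0 is seeded from samples 0-4, Gaussian 1
    // from samples 5-9, and each starts with weight 1.0 / 2.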

    var llh = Double.MinValue // current log-likelihood
    var llhp = 0.0            // previous log-likelihood
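    // llh starts at Double.MinValue so that |llh - llhp| exceeds any
    // reasonable convergenceTol and the first EM iteration always runs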

    var iter = 0
    while (iter < maxIterations && math.abs(llh - llhp) > convergenceTol) {
      // create and broadcast curried cluster contribution function
      val compute = sc.broadcast(ExpectationSum.add(weights, gaussians)_)
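      // Partially applying add() to the current weights and Gaussians yields
      // an (ExpectationSum, BV[Double]) => ExpectationSum closure, which is
      // the per-point accumulation function aggregate() below expects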

      // aggregate the cluster contribution for all sample points
      val sums = breezeData.aggregate(ExpectationSum.zero(k, d))(compute.value, _ += _)
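      // compute.value folds each point into a per-partition ExpectationSum
      // (the "E" step); _ += _ merges the partial sums across partitions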

      // Create new distributions based on the partial assignments
      // (often referred to as the "M" step in literature)
      val sumWeights = sums.weights.sum
@@ -179,22 +179,22 @@ class GaussianMixture private (
        gaussians(i) = new MultivariateGaussian(mu, sums.sigmas(i) / sums.weights(i))
        i = i + 1
      }

      llhp = llh // current becomes previous
      llh = sums.logLikelihood // this is the freshly computed log-likelihood
      iter += 1
    }

    new GaussianMixtureModel(weights, gaussians)
  }

  /** Average of dense breeze vectors */
  private def vectorMean(x: IndexedSeq[BV[Double]]): BDV[Double] = {
    val v = BDV.zeros[Double](x(0).length)
    x.foreach(xi => v += xi)
    v / x.length.toDouble
  }
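  // For example, vectorMean(IndexedSeq(BDV(1.0, 3.0), BDV(3.0, 5.0)))
  // returns BDV(2.0, 4.0)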

  /**
   * Construct matrix where diagonal entries are element-wise
   * variance of input vectors (computes biased variance)
@@ -210,14 +210,14 @@ class GaussianMixture private (
// companion object to provide zero constructor for ExpectationSum
private object ExpectationSum {
  def zero(k: Int, d: Int): ExpectationSum = {
    new ExpectationSum(0.0, Array.fill(k)(0.0),
      Array.fill(k)(BDV.zeros(d)), Array.fill(k)(BreezeMatrix.zeros(d, d)))
  }

  // compute cluster contributions for each input point
  // (U, T) => U for aggregation
  def add(
      weights: Array[Double],
      dists: Array[MultivariateGaussian])
      (sums: ExpectationSum, x: BV[Double]): ExpectationSum = {
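    // E-step sketch: each cluster j scores x with p(j) proportional to
    // weights(j) * dists(j).pdf(x); after normalizing by their sum, the p(j)
    // weight x's contributions to cluster j's weight, mean, and covariance
    // accumulators (the remainder of this method falls outside this hunk)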
    val p = weights.zip(dists).map {
@@ -235,7 +235,7 @@ private object ExpectationSum {
      i = i + 1
    }
    sums
  }
}

// Aggregation class for partial expectation results
@@ -244,9 +244,9 @@ private class ExpectationSum(
    val weights: Array[Double],
    val means: Array[BDV[Double]],
    val sigmas: Array[BreezeMatrix[Double]]) extends Serializable {

  val k = weights.length

  def +=(x: ExpectationSum): ExpectationSum = {
    var i = 0
    while (i < k) {
@@ -257,5 +257,5 @@ private class ExpectationSum(
    }
    logLikelihood += x.logLikelihood
    this
  }
}