1717
1818package org .apache .spark .mllib .stat
1919
20- import breeze .linalg .{DenseVector => BDV }
21-
2220import org .apache .spark .annotation .DeveloperApi
23- import org .apache .spark .mllib .linalg .{DenseVector , SparseVector , Vectors , Vector }
21+ import org .apache .spark .mllib .linalg .{Vectors , Vector }
2422
2523/**
2624 * :: DeveloperApi ::
@@ -40,35 +38,14 @@ import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors, Vector
4038class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with Serializable {
4139
// Dimension of the summarized vectors. Starts at 0 and is assigned from `sample.size`
// at the point where the per-column buffers below are allocated.
private var n = 0

// Per-column accumulators. They are only updated for non-zero entries of a sample.

// Running mean of the non-zero entries seen in each column.
private var currMean: Array[Double] = _
// Running accumulator of `(value - newMean) * (value - oldMean)` for each column.
private var currM2n: Array[Double] = _
// Sum of squared values for each column.
private var currM2: Array[Double] = _
// Sum of absolute values for each column.
private var currL1: Array[Double] = _

// Number of samples added so far.
private var totalCnt: Long = 0

// Number of non-zero entries seen in each column (kept as Double).
private var nnz: Array[Double] = _
// Largest non-zero value seen in each column; buffers are filled with Double.MinValue.
private var currMax: Array[Double] = _
// Smallest non-zero value seen in each column; buffers are filled with Double.MaxValue.
private var currMin: Array[Double] = _
7249
7350 /**
7451 * Add a new sample to this summarizer, and update the statistical summary.
@@ -81,19 +58,37 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
8158 require(sample.size > 0 , s " Vector should have dimension larger than zero. " )
8259 n = sample.size
8360
84- currMean = BDV .zeros [Double ](n)
85- currM2n = BDV .zeros [Double ](n)
86- currM2 = BDV .zeros [Double ](n)
87- currL1 = BDV .zeros [Double ](n)
88- nnz = BDV .zeros [Double ](n)
89- currMax = BDV .fill(n)(Double .MinValue )
90- currMin = BDV .fill(n)(Double .MaxValue )
61+ currMean = Array .ofDim [Double ](n)
62+ currM2n = Array .ofDim [Double ](n)
63+ currM2 = Array .ofDim [Double ](n)
64+ currL1 = Array .ofDim [Double ](n)
65+ nnz = Array .ofDim [Double ](n)
66+ currMax = Array .fill[ Double ] (n)(Double .MinValue )
67+ currMin = Array .fill[ Double ] (n)(Double .MaxValue )
9168 }
9269
9370 require(n == sample.size, s " Dimensions mismatch when adding new sample. " +
9471 s " Expecting $n but got ${sample.size}. " )
9572
96- sample.foreach(true )(x => add(x._1, x._2))
73+ sample.foreachActive((index, value) => {
74+ if (value != 0.0 ){
75+ if (currMax(index) < value) {
76+ currMax(index) = value
77+ }
78+ if (currMin(index) > value) {
79+ currMin(index) = value
80+ }
81+
82+ val prevMean = currMean(index)
83+ val diff = value - prevMean
84+ currMean(index) = prevMean + diff / (nnz(index) + 1.0 )
85+ currM2n(index) += (value - currMean(index)) * diff
86+ currM2(index) += value * value
87+ currL1(index) += math.abs(value)
88+
89+ nnz(index) += 1.0
90+ }
91+ })
9792
9893 totalCnt += 1
9994 this
@@ -135,34 +130,34 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
135130 }
136131 } else if (totalCnt == 0 && other.totalCnt != 0 ) {
137132 this .n = other.n
138- this .currMean = other.currMean.copy
139- this .currM2n = other.currM2n.copy
140- this .currM2 = other.currM2.copy
141- this .currL1 = other.currL1.copy
133+ this .currMean = other.currMean.clone
134+ this .currM2n = other.currM2n.clone
135+ this .currM2 = other.currM2.clone
136+ this .currL1 = other.currL1.clone
142137 this .totalCnt = other.totalCnt
143- this .nnz = other.nnz.copy
144- this .currMax = other.currMax.copy
145- this .currMin = other.currMin.copy
138+ this .nnz = other.nnz.clone
139+ this .currMax = other.currMax.clone
140+ this .currMin = other.currMin.clone
146141 }
147142 this
148143 }
149144
/**
 * Per-column sample mean over every sample added so far.
 *
 * `currMean` only tracks the mean of the non-zero entries of each column, so each entry is
 * rescaled by the fraction of samples that were non-zero in that column (`nnz / totalCnt`).
 *
 * @return a dense vector of length `n`
 * @throws IllegalArgumentException if no sample has been added yet
 */
override def mean: Vector = {
  require(totalCnt > 0, s"Nothing has been added to this summarizer.")

  val scaledMean = Array.tabulate(n) { col =>
    currMean(col) * (nnz(col) / totalCnt)
  }
  Vectors.dense(scaledMean)
}
161156
162157 override def variance : Vector = {
163158 require(totalCnt > 0 , s " Nothing has been added to this summarizer. " )
164159
165- val realVariance = BDV .zeros [Double ](n)
160+ val realVariance = Array .ofDim [Double ](n)
166161
167162 val denominator = totalCnt - 1.0
168163
@@ -177,16 +172,15 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
177172 i += 1
178173 }
179174 }
180-
181- Vectors .fromBreeze(realVariance)
175+ Vectors .dense(realVariance)
182176 }
183177
/** Total number of samples that have been added to this summarizer. */
override def count: Long = totalCnt
185179
/**
 * Number of non-zero entries observed in each column.
 *
 * @return a dense vector of length `n` holding a snapshot of the per-column counts
 * @throws IllegalArgumentException if no sample has been added yet
 */
override def numNonzeros: Vector = {
  require(totalCnt > 0, s"Nothing has been added to this summarizer.")

  // Hand out a copy so the returned vector does not share storage with the running
  // counters: later updates to this summarizer must not change a vector a caller already
  // holds, and writes through the vector's backing array must not corrupt `nnz`.
  Vectors.dense(nnz.clone)
}
191185
192186 override def max : Vector = {
@@ -197,7 +191,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
197191 if ((nnz(i) < totalCnt) && (currMax(i) < 0.0 )) currMax(i) = 0.0
198192 i += 1
199193 }
200- Vectors .fromBreeze (currMax)
194+ Vectors .dense (currMax)
201195 }
202196
203197 override def min : Vector = {
@@ -208,25 +202,25 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
208202 if ((nnz(i) < totalCnt) && (currMin(i) > 0.0 )) currMin(i) = 0.0
209203 i += 1
210204 }
211- Vectors .fromBreeze (currMin)
205+ Vectors .dense (currMin)
212206 }
213207
/**
 * Euclidean (L2) norm of each column.
 *
 * @return a dense vector of length `n` whose entries are the square roots of the
 *         per-column sums of squares
 * @throws IllegalArgumentException if no sample has been added yet
 */
override def normL2: Vector = {
  require(totalCnt > 0, s"Nothing has been added to this summarizer.")

  // `currM2` accumulates value * value per column; its square root is the L2 norm.
  val norms = new Array[Double](n)
  for (col <- currM2.indices) {
    norms(col) = math.sqrt(currM2(col))
  }
  Vectors.dense(norms)
}
227220
/**
 * L1 norm (sum of absolute values) of each column.
 *
 * @return a dense vector of length `n` holding a snapshot of the per-column sums
 * @throws IllegalArgumentException if no sample has been added yet
 */
override def normL1: Vector = {
  require(totalCnt > 0, s"Nothing has been added to this summarizer.")

  // Hand out a copy so the returned vector does not share storage with the running
  // accumulator: later updates to this summarizer must not change a vector a caller
  // already holds, and writes through the vector's backing array must not corrupt `currL1`.
  Vectors.dense(currL1.clone)
}
232226}
0 commit comments