
Commit 626fb64

dwmclary authored and James Z.M. Gao committed
SPARK-1246: add min and max to StatCounter
Here's the addition of min and max to statscounter.py and min and max methods to rdd.py.

Author: Dan McClary <[email protected]>

Closes apache#144 from dwmclary/SPARK-1246-add-min-max-to-stat-counter and squashes the following commits:

fd3fd4b [Dan McClary] fixed error, updated test
82cde0e [Dan McClary] flipped incorrectly assigned inf values in StatCounter
5d96799 [Dan McClary] added max and min to StatCounter repr for pyspark
21dd366 [Dan McClary] added max and min to StatCounter output, updated doc
1a97558 [Dan McClary] added max and min to StatCounter output, updated doc
a5c13b0 [Dan McClary] Added min and max to Scala and Java RDD, added min and max to StatCounter
ed67136 [Dan McClary] broke min/max out into separate transaction, added to rdd.py
1e7056d [Dan McClary] added underscore to getBucket
37a7dea [Dan McClary] cleaned up boundaries for histogram -- uses real min/max when buckets are derived
29981f2 [Dan McClary] fixed indentation on doctest comment
eaf89d9 [Dan McClary] added correct doctest for histogram
4916016 [Dan McClary] added histogram method, added max and min to statscounter
1 parent 48ee62f commit 626fb64

File tree

7 files changed: 93 additions & 5 deletions

core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala

Lines changed: 20 additions & 0 deletions
@@ -439,6 +439,26 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
     new java.util.ArrayList(arr)
   }

+  /**
+   * Returns the maximum element from this RDD as defined by the specified
+   * Comparator[T].
+   * @param comp the comparator that defines ordering
+   * @return the maximum of the RDD
+   */
+  def max(comp: Comparator[T]): T = {
+    rdd.max()(Ordering.comparatorToOrdering(comp))
+  }
+
+  /**
+   * Returns the minimum element from this RDD as defined by the specified
+   * Comparator[T].
+   * @param comp the comparator that defines ordering
+   * @return the minimum of the RDD
+   */
+  def min(comp: Comparator[T]): T = {
+    rdd.min()(Ordering.comparatorToOrdering(comp))
+  }
+
   /**
    * Returns the first K elements from this RDD using the
    * natural ordering for T while maintaining the order.
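
PySpark gains no Comparator overload here, but the Java methods above are just reduce applied to a comparator-derived ordering. A rough Python analogue of that mechanism, sketched assuming a running SparkContext named sc (max_by, min_by, and comp are illustrative, not part of any API):

    # Sketch: emulate Comparator-style max/min with reduce.
    # `comp` mirrors a Java Comparator: negative if a < b, zero if equal,
    # positive if a > b. `max_by`/`min_by` are hypothetical helpers.
    def max_by(rdd, comp):
        return rdd.reduce(lambda a, b: a if comp(a, b) >= 0 else b)

    def min_by(rdd, comp):
        return rdd.reduce(lambda a, b: a if comp(a, b) <= 0 else b)

    words = sc.parallelize(["spark", "rdd", "statcounter"])
    longest = max_by(words, lambda a, b: len(a) - len(b))   # "statcounter"
    shortest = min_by(words, lambda a, b: len(a) - len(b))  # "rdd"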

core/src/main/scala/org/apache/spark/rdd/RDD.scala

Lines changed: 12 additions & 0 deletions
@@ -941,6 +941,18 @@ abstract class RDD[T: ClassTag](
    */
   def takeOrdered(num: Int)(implicit ord: Ordering[T]): Array[T] = top(num)(ord.reverse)

+  /**
+   * Returns the max of this RDD as defined by the implicit Ordering[T].
+   * @return the maximum element of the RDD
+   */
+  def max()(implicit ord: Ordering[T]): T = this.reduce(ord.max)
+
+  /**
+   * Returns the min of this RDD as defined by the implicit Ordering[T].
+   * @return the minimum element of the RDD
+   */
+  def min()(implicit ord: Ordering[T]): T = this.reduce(ord.min)
+
   /**
    * Save this RDD as a text file, using string representations of elements.
    */
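
Both methods delegate to reduce with the implicit ordering, so they make a single pass over the data and throw on an empty RDD rather than returning a sentinel. A minimal check of the new behavior, sketched in PySpark assuming a live SparkContext sc (the Scala calls are analogous):

    # Sketch: the new RDD-level max/min, assuming a SparkContext `sc`.
    nums = sc.parallelize([2.0, 3.0, 4.0])
    assert nums.max() == 4.0
    assert nums.min() == 2.0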

core/src/main/scala/org/apache/spark/util/StatCounter.scala

Lines changed: 16 additions & 2 deletions
@@ -29,6 +29,8 @@ class StatCounter(values: TraversableOnce[Double]) extends Serializable {
   private var n: Long = 0     // Running count of our values
   private var mu: Double = 0  // Running mean of our values
   private var m2: Double = 0  // Running variance numerator (sum of (x - mean)^2)
+  private var maxValue: Double = Double.NegativeInfinity // Running max of our values
+  private var minValue: Double = Double.PositiveInfinity // Running min of our values

   merge(values)

@@ -41,6 +43,8 @@ class StatCounter(values: TraversableOnce[Double]) extends Serializable {
     n += 1
     mu += delta / n
     m2 += delta * (value - mu)
+    maxValue = math.max(maxValue, value)
+    minValue = math.min(minValue, value)
     this
   }

@@ -58,7 +62,9 @@ class StatCounter(values: TraversableOnce[Double]) extends Serializable {
     if (n == 0) {
       mu = other.mu
       m2 = other.m2
-      n = other.n
+      n = other.n
+      maxValue = other.maxValue
+      minValue = other.minValue
     } else if (other.n != 0) {
       val delta = other.mu - mu
       if (other.n * 10 < n) {
@@ -70,6 +76,8 @@ class StatCounter(values: TraversableOnce[Double]) extends Serializable {
       }
       m2 += other.m2 + (delta * delta * n * other.n) / (n + other.n)
       n += other.n
+      maxValue = math.max(maxValue, other.maxValue)
+      minValue = math.min(minValue, other.minValue)
     }
     this
   }
@@ -81,6 +89,8 @@ class StatCounter(values: TraversableOnce[Double]) extends Serializable {
     other.n = n
     other.mu = mu
     other.m2 = m2
+    other.maxValue = maxValue
+    other.minValue = minValue
     other
   }

@@ -90,6 +100,10 @@ class StatCounter(values: TraversableOnce[Double]) extends Serializable {

   def sum: Double = n * mu

+  def max: Double = maxValue
+
+  def min: Double = minValue
+
   /** Return the variance of the values. */
   def variance: Double = {
     if (n == 0)
@@ -119,7 +133,7 @@ class StatCounter(values: TraversableOnce[Double]) extends Serializable {
   def sampleStdev: Double = math.sqrt(sampleVariance)

   override def toString: String = {
-    "(count: %d, mean: %f, stdev: %f)".format(count, mean, stdev)
+    "(count: %d, mean: %f, stdev: %f, max: %f, min: %f)".format(count, mean, stdev, max, min)
   }
 }
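
Seeding the running max with negative infinity and the running min with positive infinity makes a fresh StatCounter the identity element for merge; the squashed commit "flipped incorrectly assigned inf values in StatCounter" fixed exactly this. Since the Python StatCounter mirrors the Scala logic, the invariant can be checked without a cluster (a sketch; assumes only that pyspark is importable):

    # Sketch: min/max survive merging partial summaries, and merging an
    # empty StatCounter (maxValue = -inf, minValue = +inf) is a no-op.
    from pyspark.statcounter import StatCounter

    a = StatCounter([2.0, 3.0])
    a.mergeStats(StatCounter([4.0]))   # combine two partial summaries
    assert a.max() == 4.0 and a.min() == 2.0

    a.mergeStats(StatCounter())        # merging the identity changes nothing
    assert a.max() == 4.0 and a.min() == 2.0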

core/src/test/scala/org/apache/spark/PartitioningSuite.scala

Lines changed: 2 additions & 0 deletions
@@ -147,6 +147,8 @@ class PartitioningSuite extends FunSuite with SharedSparkContext {
     assert(abs(6.0/2 - rdd.mean) < 0.01)
     assert(abs(1.0 - rdd.variance) < 0.01)
     assert(abs(1.0 - rdd.stdev) < 0.01)
+    assert(stats.max === 4.0)
+    assert(stats.min === 2.0)

     // Add other tests here for classes that should be able to handle empty partitions correctly
   }

core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala

Lines changed: 2 additions & 0 deletions
@@ -45,6 +45,8 @@ class RDDSuite extends FunSuite with SharedSparkContext {
     assert(nums.glom().map(_.toList).collect().toList === List(List(1, 2), List(3, 4)))
     assert(nums.collect({ case i if i >= 3 => i.toString }).collect().toList === List("3", "4"))
     assert(nums.keyBy(_.toString).collect().toList === List(("1", 1), ("2", 2), ("3", 3), ("4", 4)))
+    assert(nums.max() === 4)
+    assert(nums.min() === 1)
     val partitionSums = nums.mapPartitions(iter => Iterator(iter.reduceLeft(_ + _)))
     assert(partitionSums.collect().toList === List(3, 7))

python/pyspark/rdd.py

Lines changed: 19 additions & 0 deletions
@@ -526,7 +526,26 @@ def func(iterator):
         return reduce(op, vals, zeroValue)

     # TODO: aggregate
+

+    def max(self):
+        """
+        Find the maximum item in this RDD.
+
+        >>> sc.parallelize([1.0, 5.0, 43.0, 10.0]).max()
+        43.0
+        """
+        return self.reduce(max)
+
+    def min(self):
+        """
+        Find the minimum item in this RDD.
+
+        >>> sc.parallelize([1.0, 5.0, 43.0, 10.0]).min()
+        1.0
+        """
+        return self.reduce(min)
+
     def sum(self):
         """
         Add up the elements in this RDD.
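
Because the new methods reduce with Python's built-in max and min, they order any mutually comparable elements, not just numbers (a sketch, again assuming a running SparkContext sc):

    # Sketch: reduce-based max/min works on any mutually comparable items.
    words = sc.parallelize(["pear", "apple", "banana"])
    assert words.max() == "pear"     # lexicographic order
    assert words.min() == "apple"

    pairs = sc.parallelize([(2, "b"), (1, "z"), (2, "a")])
    assert pairs.max() == (2, "b")   # tuples compare element-wise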

python/pyspark/statcounter.py

Lines changed: 22 additions & 3 deletions
@@ -26,7 +26,9 @@ def __init__(self, values=[]):
         self.n = 0L    # Running count of our values
         self.mu = 0.0  # Running mean of our values
         self.m2 = 0.0  # Running variance numerator (sum of (x - mean)^2)
-
+        self.maxValue = float("-inf")
+        self.minValue = float("inf")
+
         for v in values:
             self.merge(v)

@@ -36,6 +38,11 @@ def merge(self, value):
         self.n += 1
         self.mu += delta / self.n
         self.m2 += delta * (value - self.mu)
+        if self.maxValue < value:
+            self.maxValue = value
+        if self.minValue > value:
+            self.minValue = value
+
         return self

     # Merge another StatCounter into this one, adding up the internal statistics.
@@ -49,7 +56,10 @@ def mergeStats(self, other):
         if self.n == 0:
             self.mu = other.mu
             self.m2 = other.m2
-            self.n = other.n
+            self.n = other.n
+            self.maxValue = other.maxValue
+            self.minValue = other.minValue
+
         elif other.n != 0:
             delta = other.mu - self.mu
             if other.n * 10 < self.n:
@@ -58,6 +68,9 @@ def mergeStats(self, other):
                 self.mu = other.mu - (delta * self.n) / (self.n + other.n)
             else:
                 self.mu = (self.mu * self.n + other.mu * other.n) / (self.n + other.n)
+
+            self.maxValue = max(self.maxValue, other.maxValue)
+            self.minValue = min(self.minValue, other.minValue)

             self.m2 += other.m2 + (delta * delta * self.n * other.n) / (self.n + other.n)
             self.n += other.n
@@ -76,6 +89,12 @@ def mean(self):
     def sum(self):
         return self.n * self.mu

+    def min(self):
+        return self.minValue
+
+    def max(self):
+        return self.maxValue
+
     # Return the variance of the values.
     def variance(self):
         if self.n == 0:
@@ -105,5 +124,5 @@ def sampleStdev(self):
         return math.sqrt(self.sampleVariance())

     def __repr__(self):
-        return "(count: %s, mean: %s, stdev: %s)" % (self.count(), self.mean(), self.stdev())
+        return "(count: %s, mean: %s, stdev: %s, max: %s, min: %s)" % (self.count(), self.mean(), self.stdev(), self.max(), self.min())
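
With min and max folded into the summary, one stats() pass now surfaces all five quantities, and the enriched __repr__ prints them directly (a sketch assuming a running SparkContext sc; exact float formatting may vary):

    # Sketch: the enriched stats() summary, assuming a SparkContext `sc`.
    st = sc.parallelize([2.0, 3.0, 4.0]).stats()
    print(st)              # (count: 3, mean: 3.0, stdev: 0.816..., max: 4.0, min: 2.0)
    assert st.max() == 4.0
    assert st.min() == 2.0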
