Skip to content

Commit f98e711

Browse files
author
DB Tsai
committed
address feedback
1 parent a784321 commit f98e711

File tree

2 files changed

+44
-25
lines changed

2 files changed

+44
-25
lines changed

mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala

Lines changed: 34 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -107,14 +107,17 @@ class LogisticRegression
107107
if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK)
108108

109109
val (summarizer, labelSummarizer) = instances.treeAggregate(
110-
(new MultivariateOnlineSummarizer, new MultiClassSummarizer))( {
111-
case ((summarizer: MultivariateOnlineSummarizer, labelSummarizer: MultiClassSummarizer),
112-
(label: Double, features: Vector)) =>
113-
(summarizer.add(features), labelSummarizer.add(label))
114-
}, {
115-
case ((summarizer1: MultivariateOnlineSummarizer, labelSummarizer1: MultiClassSummarizer),
116-
(summarizer2: MultivariateOnlineSummarizer, labelSummarizer2: MultiClassSummarizer)) =>
117-
(summarizer1.merge(summarizer2), labelSummarizer1.merge(labelSummarizer2))
110+
(new MultivariateOnlineSummarizer, new MultiClassSummarizer))(
111+
seqOp = (c, v) => (c, v) match {
112+
case ((summarizer: MultivariateOnlineSummarizer, labelSummarizer: MultiClassSummarizer),
113+
(label: Double, features: Vector)) =>
114+
(summarizer.add(features), labelSummarizer.add(label))
115+
},
116+
combOp = (c1, c2) => (c1, c2) match {
117+
case ((summarizer1: MultivariateOnlineSummarizer,
118+
classSummarizer1: MultiClassSummarizer), (summarizer2: MultivariateOnlineSummarizer,
119+
classSummarizer2: MultiClassSummarizer)) =>
120+
(summarizer1.merge(summarizer2), classSummarizer1.merge(classSummarizer2))
118121
})
119122

120123
val histogram = labelSummarizer.histogram
@@ -123,15 +126,17 @@ class LogisticRegression
123126
val numFeatures = summarizer.mean.size
124127

125128
if (numInvalid != 0) {
126-
logError("Classification labels should be in {0 to " + (numClasses - 1) + "}. " +
127-
"Found " + numInvalid + " invalid labels.")
128-
throw new SparkException("Input validation failed.")
129+
val msg = s"Classification labels should be in {0 to ${numClasses - 1} " +
130+
s"Found $numInvalid invalid labels."
131+
logError(msg)
132+
throw new SparkException(msg)
129133
}
130134

131135
if (numClasses > 2) {
132-
logError("Currently, LogisticRegression with ElasticNet in ML package only supports " +
133-
"binary classification. Found " + numClasses + " in the input dataset.")
134-
throw new SparkException("Input validation failed.")
136+
val msg = s"Currently, LogisticRegression with ElasticNet in ML package only supports " +
137+
s"binary classification. Found $numClasses in the input dataset."
138+
logError(msg)
139+
throw new SparkException(msg)
135140
}
136141

137142
val featuresMean = summarizer.mean.toArray
@@ -361,10 +366,13 @@ class MultiClassSummarizer private[ml] extends Serializable {
361366
largeMap
362367
}
363368

369+
/** @return The total invalid input counts. */
364370
def countInvalid: Long = totalInvalidCnt
365371

372+
/** @return The number of distinct labels in the input dataset. */
366373
def numClasses: Int = distinctMap.keySet.max + 1
367374

375+
/** @return The counts of each label in the input dataset. */
368376
def histogram: Array[Long] = {
369377
val result = Array.ofDim[Long](numClasses)
370378
var i = 0
@@ -377,11 +385,20 @@ class MultiClassSummarizer private[ml] extends Serializable {
377385
}
378386

379387
/**
380-
* :: DeveloperApi ::
388+
* LogisticAggregator computes the gradient and loss for binary logistic loss function, as used
389+
* in binary classification for samples in sparse or dense vector in a online fashion.
390+
*
391+
* Note that multinomial logistic loss is not supported yet!
392+
*
393+
* Two LogisticAggregator can be merged together to have a summary of loss and gradient of
394+
* the corresponding joint dataset.
381395
*
396+
* @param weights The weights/coefficients corresponding to the features.
382397
* @param numClasses the number of possible outcomes for k classes classification problem in
383-
* Multinomial Logistic Regression. By default, it is binary logistic regression
384-
* so numClasses will be set to 2.
398+
* Multinomial Logistic Regression.
399+
* @param fitIntercept Whether to fit an intercept term.
400+
* @param featuresStd The standard deviation values of the features.
401+
* @param featuresMean The mean values of the features.
385402
*/
386403
private class LogisticAggregator(
387404
weights: Vector,

mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -105,14 +105,16 @@ class LinearRegression extends Regressor[Vector, LinearRegression, LinearRegress
105105
if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK)
106106

107107
val (summarizer, statCounter) = instances.treeAggregate(
108-
(new MultivariateOnlineSummarizer, new StatCounter))( {
109-
case ((summarizer: MultivariateOnlineSummarizer, statCounter: StatCounter),
110-
(label: Double, features: Vector)) =>
111-
(summarizer.add(features), statCounter.merge(label))
112-
}, {
113-
case ((summarizer1: MultivariateOnlineSummarizer, statCounter1: StatCounter),
114-
(summarizer2: MultivariateOnlineSummarizer, statCounter2: StatCounter)) =>
115-
(summarizer1.merge(summarizer2), statCounter1.merge(statCounter2))
108+
(new MultivariateOnlineSummarizer, new StatCounter))(
109+
seqOp = (c, v) => (c, v) match {
110+
case ((summarizer: MultivariateOnlineSummarizer, statCounter: StatCounter),
111+
(label: Double, features: Vector)) =>
112+
(summarizer.add(features), statCounter.merge(label))
113+
},
114+
combOp = (c1, c2) => (c1, c2) match {
115+
case ((summarizer1: MultivariateOnlineSummarizer, statCounter1: StatCounter),
116+
(summarizer2: MultivariateOnlineSummarizer, statCounter2: StatCounter)) =>
117+
(summarizer1.merge(summarizer2), statCounter1.merge(statCounter2))
116118
})
117119

118120
val numFeatures = summarizer.mean.size

0 commit comments

Comments
 (0)