@@ -107,14 +107,17 @@ class LogisticRegression
107107 if (handlePersistence) instances.persist(StorageLevel .MEMORY_AND_DISK )
108108
109109 val (summarizer, labelSummarizer) = instances.treeAggregate(
110- (new MultivariateOnlineSummarizer , new MultiClassSummarizer ))( {
111- case ((summarizer : MultivariateOnlineSummarizer , labelSummarizer : MultiClassSummarizer ),
112- (label : Double , features : Vector )) =>
113- (summarizer.add(features), labelSummarizer.add(label))
114- }, {
115- case ((summarizer1 : MultivariateOnlineSummarizer , labelSummarizer1 : MultiClassSummarizer ),
116- (summarizer2 : MultivariateOnlineSummarizer , labelSummarizer2 : MultiClassSummarizer )) =>
117- (summarizer1.merge(summarizer2), labelSummarizer1.merge(labelSummarizer2))
110+ (new MultivariateOnlineSummarizer , new MultiClassSummarizer ))(
111+ seqOp = (c, v) => (c, v) match {
112+ case ((summarizer : MultivariateOnlineSummarizer , labelSummarizer : MultiClassSummarizer ),
113+ (label : Double , features : Vector )) =>
114+ (summarizer.add(features), labelSummarizer.add(label))
115+ },
116+ combOp = (c1, c2) => (c1, c2) match {
117+ case ((summarizer1 : MultivariateOnlineSummarizer ,
118+ classSummarizer1 : MultiClassSummarizer ), (summarizer2 : MultivariateOnlineSummarizer ,
119+ classSummarizer2 : MultiClassSummarizer )) =>
120+ (summarizer1.merge(summarizer2), classSummarizer1.merge(classSummarizer2))
118121 })
119122
120123 val histogram = labelSummarizer.histogram
@@ -123,15 +126,17 @@ class LogisticRegression
123126 val numFeatures = summarizer.mean.size
124127
125128 if (numInvalid != 0 ) {
126- logError(" Classification labels should be in {0 to " + (numClasses - 1 ) + " }. " +
127- " Found " + numInvalid + " invalid labels." )
128- throw new SparkException (" Input validation failed." )
129+ val msg = s " Classification labels should be in {0 to ${numClasses - 1 } " +
130+ s " Found $numInvalid invalid labels. "
131+ logError(msg)
132+ throw new SparkException (msg)
129133 }
130134
131135 if (numClasses > 2 ) {
132- logError(" Currently, LogisticRegression with ElasticNet in ML package only supports " +
133- " binary classification. Found " + numClasses + " in the input dataset." )
134- throw new SparkException (" Input validation failed." )
136+ val msg = s " Currently, LogisticRegression with ElasticNet in ML package only supports " +
137+ s " binary classification. Found $numClasses in the input dataset. "
138+ logError(msg)
139+ throw new SparkException (msg)
135140 }
136141
137142 val featuresMean = summarizer.mean.toArray
@@ -361,10 +366,13 @@ class MultiClassSummarizer private[ml] extends Serializable {
361366 largeMap
362367 }
363368
369+ /** @return The total invalid input counts. */
364370 def countInvalid : Long = totalInvalidCnt
365371
372+ /** @return The number of distinct labels in the input dataset. */
366373 def numClasses : Int = distinctMap.keySet.max + 1
367374
375+ /** @return The counts of each label in the input dataset. */
368376 def histogram : Array [Long ] = {
369377 val result = Array .ofDim[Long ](numClasses)
370378 var i = 0
@@ -377,11 +385,20 @@ class MultiClassSummarizer private[ml] extends Serializable {
377385}
378386
379387/**
380- * :: DeveloperApi ::
388+ * LogisticAggregator computes the gradient and loss for binary logistic loss function, as used
389+ * in binary classification for samples in sparse or dense vector in a online fashion.
390+ *
391+ * Note that multinomial logistic loss is not supported yet!
392+ *
393+ * Two LogisticAggregator can be merged together to have a summary of loss and gradient of
394+ * the corresponding joint dataset.
381395 *
396+ * @param weights The weights/coefficients corresponding to the features.
382397 * @param numClasses the number of possible outcomes for k classes classification problem in
383- * Multinomial Logistic Regression. By default, it is binary logistic regression
384- * so numClasses will be set to 2.
398+ * Multinomial Logistic Regression.
399+ * @param fitIntercept Whether to fit an intercept term.
400+ * @param featuresStd The standard deviation values of the features.
401+ * @param featuresMean The mean values of the features.
385402 */
386403private class LogisticAggregator (
387404 weights : Vector ,
0 commit comments