Skip to content
Closed
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -934,8 +934,8 @@ class LogisticRegressionModel private[spark] (
@Since("2.1.0") val interceptVector: Vector,
@Since("1.3.0") override val numClasses: Int,
private val isMultinomial: Boolean)
extends ProbabilisticClassificationModel[Vector, LogisticRegressionModel]
with LogisticRegressionParams with MLWritable {
extends ProbabilisticClassificationModel[Vector, LogisticRegressionModel] with MLWritable
with LogisticRegressionParams with HasTrainingSummary[LogisticRegressionTrainingSummary] {

require(coefficientMatrix.numRows == interceptVector.size, s"Dimension mismatch! Expected " +
s"coefficientMatrix.numRows == interceptVector.size, but ${coefficientMatrix.numRows} != " +
Expand Down Expand Up @@ -1018,16 +1018,12 @@ class LogisticRegressionModel private[spark] (
@Since("1.6.0")
override val numFeatures: Int = coefficientMatrix.numCols

private var trainingSummary: Option[LogisticRegressionTrainingSummary] = None

/**
* Gets summary of model on training set. An exception is thrown
* if `trainingSummary == None`.
*/
@Since("1.5.0")
def summary: LogisticRegressionTrainingSummary = trainingSummary.getOrElse {
throw new SparkException("No training summary available for this LogisticRegressionModel")
}
override def summary: LogisticRegressionTrainingSummary = super.summary

/**
* Gets summary of model on training set. An exception is thrown
Expand Down Expand Up @@ -1062,16 +1058,6 @@ class LogisticRegressionModel private[spark] (
(model, model.getProbabilityCol, model.getPredictionCol)
}

private[classification]
def setSummary(summary: Option[LogisticRegressionTrainingSummary]): this.type = {
this.trainingSummary = summary
this
}

/** Indicates whether a training summary exists for this model instance. */
@Since("1.5.0")
def hasSummary: Boolean = trainingSummary.isDefined

/**
* Evaluates the model on a test dataset.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,9 @@ private[clustering] trait BisectingKMeansParams extends Params with HasMaxIter
@Since("2.0.0")
class BisectingKMeansModel private[ml] (
@Since("2.0.0") override val uid: String,
private val parentModel: MLlibBisectingKMeansModel
) extends Model[BisectingKMeansModel] with BisectingKMeansParams with MLWritable {
private val parentModel: MLlibBisectingKMeansModel)
extends Model[BisectingKMeansModel] with BisectingKMeansParams with MLWritable
with HasTrainingSummary[BisectingKMeansSummary] {

@Since("2.0.0")
override def copy(extra: ParamMap): BisectingKMeansModel = {
Expand Down Expand Up @@ -143,28 +144,12 @@ class BisectingKMeansModel private[ml] (
@Since("2.0.0")
override def write: MLWriter = new BisectingKMeansModel.BisectingKMeansModelWriter(this)

private var trainingSummary: Option[BisectingKMeansSummary] = None

private[clustering] def setSummary(summary: Option[BisectingKMeansSummary]): this.type = {
this.trainingSummary = summary
this
}

/**
* Return true if there exists summary of model.
*/
@Since("2.1.0")
def hasSummary: Boolean = trainingSummary.nonEmpty

/**
* Gets summary of model on training set. An exception is
* thrown if `trainingSummary == None`.
*/
@Since("2.1.0")
def summary: BisectingKMeansSummary = trainingSummary.getOrElse {
throw new SparkException(
s"No training summary available for the ${this.getClass.getSimpleName}")
}
override def summary: BisectingKMeansSummary = super.summary
}

object BisectingKMeansModel extends MLReadable[BisectingKMeansModel] {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,8 @@ class GaussianMixtureModel private[ml] (
@Since("2.0.0") override val uid: String,
@Since("2.0.0") val weights: Array[Double],
@Since("2.0.0") val gaussians: Array[MultivariateGaussian])
extends Model[GaussianMixtureModel] with GaussianMixtureParams with MLWritable {
extends Model[GaussianMixtureModel] with GaussianMixtureParams with MLWritable
with HasTrainingSummary[GaussianMixtureSummary] {

/** @group setParam */
@Since("2.1.0")
Expand Down Expand Up @@ -160,28 +161,13 @@ class GaussianMixtureModel private[ml] (
@Since("2.0.0")
override def write: MLWriter = new GaussianMixtureModel.GaussianMixtureModelWriter(this)

private var trainingSummary: Option[GaussianMixtureSummary] = None

private[clustering] def setSummary(summary: Option[GaussianMixtureSummary]): this.type = {
this.trainingSummary = summary
this
}

/**
* Return true if there exists summary of model.
*/
@Since("2.0.0")
def hasSummary: Boolean = trainingSummary.nonEmpty

/**
* Gets summary of model on training set. An exception is
* thrown if `trainingSummary == None`.
*/
@Since("2.0.0")
def summary: GaussianMixtureSummary = trainingSummary.getOrElse {
throw new RuntimeException(
s"No training summary available for the ${this.getClass.getSimpleName}")
}
override def summary: GaussianMixtureSummary = super.summary

}

@Since("2.0.0")
Expand Down
21 changes: 3 additions & 18 deletions mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,8 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe
class KMeansModel private[ml] (
@Since("1.5.0") override val uid: String,
private[clustering] val parentModel: MLlibKMeansModel)
extends Model[KMeansModel] with KMeansParams with GeneralMLWritable {
extends Model[KMeansModel] with KMeansParams with GeneralMLWritable
with HasTrainingSummary[KMeansSummary] {

@Since("1.5.0")
override def copy(extra: ParamMap): KMeansModel = {
Expand Down Expand Up @@ -153,28 +154,12 @@ class KMeansModel private[ml] (
@Since("1.6.0")
override def write: GeneralMLWriter = new GeneralMLWriter(this)

private var trainingSummary: Option[KMeansSummary] = None

private[clustering] def setSummary(summary: Option[KMeansSummary]): this.type = {
this.trainingSummary = summary
this
}

/**
* Return true if there exists summary of model.
*/
@Since("2.0.0")
def hasSummary: Boolean = trainingSummary.nonEmpty

/**
* Gets summary of model on training set. An exception is
* thrown if `trainingSummary == None`.
*/
@Since("2.0.0")
def summary: KMeansSummary = trainingSummary.getOrElse {
throw new SparkException(
s"No training summary available for the ${this.getClass.getSimpleName}")
}
override def summary: KMeansSummary = super.summary
}

/** Helper class for storing model data */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1001,7 +1001,8 @@ class GeneralizedLinearRegressionModel private[ml] (
@Since("2.0.0") val coefficients: Vector,
@Since("2.0.0") val intercept: Double)
extends RegressionModel[Vector, GeneralizedLinearRegressionModel]
with GeneralizedLinearRegressionBase with MLWritable {
with GeneralizedLinearRegressionBase with MLWritable
with HasTrainingSummary[GeneralizedLinearRegressionTrainingSummary]{
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: space before braces

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated. Thanks

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like there still isn't a space here.


/**
* Sets the link prediction (linear predictor) column name.
Expand Down Expand Up @@ -1054,29 +1055,12 @@ class GeneralizedLinearRegressionModel private[ml] (
output.toDF()
}

private var trainingSummary: Option[GeneralizedLinearRegressionTrainingSummary] = None

/**
* Gets R-like summary of model on training set. An exception is
* thrown if there is no summary available.
*/
@Since("2.0.0")
def summary: GeneralizedLinearRegressionTrainingSummary = trainingSummary.getOrElse {
throw new SparkException(
"No training summary available for this GeneralizedLinearRegressionModel")
}

/**
* Indicates if [[summary]] is available.
*/
@Since("2.0.0")
def hasSummary: Boolean = trainingSummary.nonEmpty

private[regression]
def setSummary(summary: Option[GeneralizedLinearRegressionTrainingSummary]): this.type = {
this.trainingSummary = summary
this
}
override def summary: GeneralizedLinearRegressionTrainingSummary = super.summary

/**
* Evaluate the model on the given dataset, returning a summary of the results.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -647,33 +647,20 @@ class LinearRegressionModel private[ml] (
@Since("1.3.0") val intercept: Double,
@Since("2.3.0") val scale: Double)
extends RegressionModel[Vector, LinearRegressionModel]
with LinearRegressionParams with GeneralMLWritable {
with LinearRegressionParams with GeneralMLWritable
with HasTrainingSummary[LinearRegressionTrainingSummary] {

private[ml] def this(uid: String, coefficients: Vector, intercept: Double) =
this(uid, coefficients, intercept, 1.0)

private var trainingSummary: Option[LinearRegressionTrainingSummary] = None

override val numFeatures: Int = coefficients.size

/**
* Gets summary (e.g. residuals, mse, r-squared ) of model on training set. An exception is
* thrown if `trainingSummary == None`.
*/
@Since("1.5.0")
def summary: LinearRegressionTrainingSummary = trainingSummary.getOrElse {
throw new SparkException("No training summary available for this LinearRegressionModel")
}

private[regression]
def setSummary(summary: Option[LinearRegressionTrainingSummary]): this.type = {
this.trainingSummary = summary
this
}

/** Indicates whether a training summary exists for this model instance. */
@Since("1.5.0")
def hasSummary: Boolean = trainingSummary.isDefined
override def summary: LinearRegressionTrainingSummary = super.summary
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On the one hand, you don't need these overrides for this to work correctly, right? but I suppose it's necessary to preserve the @Since tag, which varies across implementations. But these were mostly introduced in 1.5.0, and where they have a later @Since tag, it matches when the class was introduced. I think it would also be coherent, for Spark 3.0, to remove these overrides, and mark the methods in the new trait as @Since 1.5.0. The result would be similar to what would happen if this had been introduced at the start. I don't feel strongly about it but what do you think? would clean up the code a little more.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I got an error message from Java side when removing summary

 /home/yuhao/workspace/github/hhbyyh/spark/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java:145: error: incompatible types: Object cannot be converted to LogisticRegressionTrainingSummary
[error]     LogisticRegressionTrainingSummary summary = model.summary();

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah OK nevermind then. Thanks for checking.


/**
* Evaluates the model on a test dataset.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.ml.util

import org.apache.spark.SparkException
import org.apache.spark.annotation.Since


/**
* Trait for models that provides Training summary.
*
* @tparam T Summary instance type
*/
@Since("3.0.0")
private[ml] trait HasTrainingSummary[T] {

private[ml] final var trainingSummary: Option[T] = None

/** Indicates whether a training summary exists for this model instance. */
@Since("3.0.0")
def hasSummary: Boolean = trainingSummary.isDefined

/**
* Gets summary of model on training set. An exception is
* thrown if `trainingSummary == None`.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: from the callers perspective they don't know what trainingSummary is. "if hasSummary is false"?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One more nit @hhbyyh - can we change this one too?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure. Thanks for checking.

*/
@Since("3.0.0")
def summary: T = trainingSummary.getOrElse {
throw new SparkException(
s"No training summary available for this ${this.getClass.getSimpleName}")
}

private[ml] def setSummary(summary: Option[T]): this.type = {
this.trainingSummary = summary
this
}
}