Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -661,7 +661,7 @@ class LogisticRegression @Since("1.2.0") (
$(labelCol),
$(featuresCol),
objectiveHistory)
model.setSummary(logRegSummary)
model.setSummary(Some(logRegSummary))
} else {
model
}
Expand Down Expand Up @@ -803,9 +803,9 @@ class LogisticRegressionModel private[spark] (
}
}

private[classification] def setSummary(
summary: LogisticRegressionTrainingSummary): this.type = {
this.trainingSummary = Some(summary)
private[classification]
def setSummary(summary: Option[LogisticRegressionTrainingSummary]): this.type = {
this.trainingSummary = summary
this
}

Expand Down Expand Up @@ -900,8 +900,7 @@ class LogisticRegressionModel private[spark] (
override def copy(extra: ParamMap): LogisticRegressionModel = {
val newModel = copyValues(new LogisticRegressionModel(uid, coefficientMatrix, interceptVector,
numClasses, isMultinomial), extra)
if (trainingSummary.isDefined) newModel.setSummary(trainingSummary.get)
newModel.setParent(parent)
newModel.setSummary(trainingSummary).setParent(parent)
}

override protected def raw2prediction(rawPrediction: Vector): Double = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,7 @@ class BisectingKMeansModel private[ml] (
@Since("2.0.0")
override def copy(extra: ParamMap): BisectingKMeansModel = {
val copied = copyValues(new BisectingKMeansModel(uid, parentModel), extra)
if (trainingSummary.isDefined) copied.setSummary(trainingSummary.get)
copied.setParent(this.parent)
copied.setSummary(trainingSummary).setParent(this.parent)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks better. Could you make the change for Scala LiR, LoR, GLM and KMeans as well? I think they should be consistent. Thanks.

Copy link
Contributor Author

@sethah sethah Nov 21, 2016

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated. I also added tests. Thanks for reviewing!

}

@Since("2.0.0")
Expand Down Expand Up @@ -132,8 +131,8 @@ class BisectingKMeansModel private[ml] (

private var trainingSummary: Option[BisectingKMeansSummary] = None

private[clustering] def setSummary(summary: BisectingKMeansSummary): this.type = {
this.trainingSummary = Some(summary)
private[clustering] def setSummary(summary: Option[BisectingKMeansSummary]): this.type = {
this.trainingSummary = summary
this
}

Expand Down Expand Up @@ -265,7 +264,7 @@ class BisectingKMeans @Since("2.0.0") (
val model = copyValues(new BisectingKMeansModel(uid, parentModel).setParent(this))
val summary = new BisectingKMeansSummary(
model.transform(dataset), $(predictionCol), $(featuresCol), $(k))
model.setSummary(summary)
model.setSummary(Some(summary))
instr.logSuccess(model)
model
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,7 @@ class GaussianMixtureModel private[ml] (
@Since("2.0.0")
override def copy(extra: ParamMap): GaussianMixtureModel = {
val copied = copyValues(new GaussianMixtureModel(uid, weights, gaussians), extra)
if (trainingSummary.isDefined) copied.setSummary(trainingSummary.get)
copied.setParent(this.parent)
copied.setSummary(trainingSummary).setParent(this.parent)
}

@Since("2.0.0")
Expand Down Expand Up @@ -150,8 +149,8 @@ class GaussianMixtureModel private[ml] (

private var trainingSummary: Option[GaussianMixtureSummary] = None

private[clustering] def setSummary(summary: GaussianMixtureSummary): this.type = {
this.trainingSummary = Some(summary)
private[clustering] def setSummary(summary: Option[GaussianMixtureSummary]): this.type = {
this.trainingSummary = summary
this
}

Expand Down Expand Up @@ -340,7 +339,7 @@ class GaussianMixture @Since("2.0.0") (
.setParent(this)
val summary = new GaussianMixtureSummary(model.transform(dataset),
$(predictionCol), $(probabilityCol), $(featuresCol), $(k))
model.setSummary(summary)
model.setSummary(Some(summary))
instr.logNumFeatures(model.gaussians.head.mean.size)
instr.logSuccess(model)
model
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,7 @@ class KMeansModel private[ml] (
@Since("1.5.0")
override def copy(extra: ParamMap): KMeansModel = {
val copied = copyValues(new KMeansModel(uid, parentModel), extra)
if (trainingSummary.isDefined) copied.setSummary(trainingSummary.get)
copied.setParent(this.parent)
copied.setSummary(trainingSummary).setParent(this.parent)
}

/** @group setParam */
Expand Down Expand Up @@ -165,8 +164,8 @@ class KMeansModel private[ml] (

private var trainingSummary: Option[KMeansSummary] = None

private[clustering] def setSummary(summary: KMeansSummary): this.type = {
this.trainingSummary = Some(summary)
private[clustering] def setSummary(summary: Option[KMeansSummary]): this.type = {
this.trainingSummary = summary
this
}

Expand Down Expand Up @@ -325,7 +324,7 @@ class KMeans @Since("1.5.0") (
val model = copyValues(new KMeansModel(uid, parentModel).setParent(this))
val summary = new KMeansSummary(
model.transform(dataset), $(predictionCol), $(featuresCol), $(k))
model.setSummary(summary)
model.setSummary(Some(summary))
instr.logSuccess(model)
model
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
.setParent(this))
val trainingSummary = new GeneralizedLinearRegressionTrainingSummary(dataset, model,
wlsModel.diagInvAtWA.toArray, 1, getSolver)
return model.setSummary(trainingSummary)
return model.setSummary(Some(trainingSummary))
}

// Fit Generalized Linear Model by iteratively reweighted least squares (IRLS).
Expand All @@ -284,7 +284,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
.setParent(this))
val trainingSummary = new GeneralizedLinearRegressionTrainingSummary(dataset, model,
irlsModel.diagInvAtWA.toArray, irlsModel.numIterations, getSolver)
model.setSummary(trainingSummary)
model.setSummary(Some(trainingSummary))
}

@Since("2.0.0")
Expand Down Expand Up @@ -761,8 +761,8 @@ class GeneralizedLinearRegressionModel private[ml] (
def hasSummary: Boolean = trainingSummary.nonEmpty

private[regression]
def setSummary(summary: GeneralizedLinearRegressionTrainingSummary): this.type = {
this.trainingSummary = Some(summary)
def setSummary(summary: Option[GeneralizedLinearRegressionTrainingSummary]): this.type = {
this.trainingSummary = summary
this
}

Expand All @@ -778,8 +778,7 @@ class GeneralizedLinearRegressionModel private[ml] (
override def copy(extra: ParamMap): GeneralizedLinearRegressionModel = {
val copied = copyValues(new GeneralizedLinearRegressionModel(uid, coefficients, intercept),
extra)
if (trainingSummary.isDefined) copied.setSummary(trainingSummary.get)
copied.setParent(parent)
copied.setSummary(trainingSummary).setParent(parent)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
model.diagInvAtWA.toArray,
model.objectiveHistory)

return lrModel.setSummary(trainingSummary)
return lrModel.setSummary(Some(trainingSummary))
}

val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
Expand Down Expand Up @@ -276,7 +276,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
model,
Array(0D),
Array(0D))
return model.setSummary(trainingSummary)
return model.setSummary(Some(trainingSummary))
} else {
require($(regParam) == 0.0, "The standard deviation of the label is zero. " +
"Model cannot be regularized.")
Expand Down Expand Up @@ -398,7 +398,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
model,
Array(0D),
objectiveHistory)
model.setSummary(trainingSummary)
model.setSummary(Some(trainingSummary))
}

@Since("1.4.0")
Expand Down Expand Up @@ -444,8 +444,9 @@ class LinearRegressionModel private[ml] (
throw new SparkException("No training summary available for this LinearRegressionModel")
}

private[regression] def setSummary(summary: LinearRegressionTrainingSummary): this.type = {
this.trainingSummary = Some(summary)
private[regression]
def setSummary(summary: Option[LinearRegressionTrainingSummary]): this.type = {
this.trainingSummary = summary
this
}

Expand Down Expand Up @@ -488,8 +489,7 @@ class LinearRegressionModel private[ml] (
@Since("1.4.0")
override def copy(extra: ParamMap): LinearRegressionModel = {
val newModel = copyValues(new LinearRegressionModel(uid, coefficients, intercept), extra)
if (trainingSummary.isDefined) newModel.setSummary(trainingSummary.get)
newModel.setParent(parent)
newModel.setSummary(trainingSummary).setParent(parent)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,8 @@ class LogisticRegressionSuite
assert(model.hasSummary)
val copiedModel = model.copy(ParamMap.empty)
assert(copiedModel.hasSummary)
model.setSummary(None)
assert(!model.hasSummary)
}

test("empty probabilityCol") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,9 @@ class BisectingKMeansSuite
assert(clusterSizes.length === k)
assert(clusterSizes.sum === numRows)
assert(clusterSizes.forall(_ >= 0))

model.setSummary(None)
assert(!model.hasSummary)
}

test("read/write") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,9 @@ class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext
assert(clusterSizes.length === k)
assert(clusterSizes.sum === numRows)
assert(clusterSizes.forall(_ >= 0))

model.setSummary(None)
assert(!model.hasSummary)
}

test("read/write") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,9 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR
assert(clusterSizes.length === k)
assert(clusterSizes.sum === numRows)
assert(clusterSizes.forall(_ >= 0))

model.setSummary(None)
assert(!model.hasSummary)
}

test("KMeansModel transform with non-default feature and prediction cols") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,8 @@ class GeneralizedLinearRegressionSuite
assert(model.hasSummary)
val copiedModel = model.copy(ParamMap.empty)
assert(copiedModel.hasSummary)
model.setSummary(None)
assert(!model.hasSummary)

assert(model.getFeaturesCol === "features")
assert(model.getPredictionCol === "prediction")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,8 @@ class LinearRegressionSuite
assert(model.hasSummary)
val copiedModel = model.copy(ParamMap.empty)
assert(copiedModel.hasSummary)
model.setSummary(None)
assert(!model.hasSummary)

model.transform(datasetWithDenseFeature)
.select("label", "prediction")
Expand Down
15 changes: 9 additions & 6 deletions python/pyspark/ml/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,13 +309,16 @@ def interceptVector(self):
@since("2.0.0")
def summary(self):
"""
Gets summary (e.g. residuals, mse, r-squared ) of model on
training set. An exception is thrown if
`trainingSummary is None`.
Gets summary (e.g. accuracy/precision/recall, objective history, total iterations) of model
trained on the training set. An exception is thrown if `trainingSummary is None`.
"""
java_blrt_summary = self._call_java("summary")
# Note: Once multiclass is added, update this to return correct summary
return BinaryLogisticRegressionTrainingSummary(java_blrt_summary)
if self.hasSummary:
java_blrt_summary = self._call_java("summary")
# Note: Once multiclass is added, update this to return correct summary
return BinaryLogisticRegressionTrainingSummary(java_blrt_summary)
else:
raise RuntimeError("No training summary available for this %s" %
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Before, this would throw a Py4JJavaError. I think it's slightly better to throw a RuntimeError here as is done in Scala.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think thats generally a good improvement, the Py4J errors are often confusing to end users.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like this change, we should always throw an exception easy to understand by users.

self.__class__.__name__)

@property
@since("2.0.0")
Expand Down
Loading