add tests for copy summary
sethah committed Nov 4, 2016
commit 9d2a64be0676718259fedd8d9090717fb2432457
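For orientation, every model diff below applies the same pattern: build the copy with copyValues, re-attach the training summary when one is defined, then set the parent; the new suite tests then assert that model.copy(ParamMap.empty) still reports hasSummary. A minimal, self-contained sketch of that pattern (hypothetical ToyModel/ToySummary classes, not Spark code) might look like this:

```scala
// Minimal sketch of the copy-preserves-summary pattern exercised by these tests.
// ToySummary and ToyModel are hypothetical stand-ins, not Spark classes.
class ToySummary(val numIterations: Int)

class ToyModel(val uid: String) {
  private var trainingSummary: Option[ToySummary] = None

  def setSummary(summary: ToySummary): this.type = {
    trainingSummary = Some(summary)
    this
  }

  def hasSummary: Boolean = trainingSummary.isDefined

  // Mirrors the fix: build the copy first, then carry the summary over if present.
  def copy(): ToyModel = {
    val copied = new ToyModel(uid)
    if (trainingSummary.isDefined) copied.setSummary(trainingSummary.get)
    copied
  }
}

object CopySummaryCheck extends App {
  val model = new ToyModel("toy_1").setSummary(new ToySummary(numIterations = 1))
  val copied = model.copy()
  assert(model.hasSummary && copied.hasSummary)  // what the new suite assertions check
  println("copy retains the training summary")
}
```

The same shape repeats in each file changed below, with the tests asserting hasSummary on a copy made with an empty ParamMap.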
@@ -94,8 +94,9 @@ class BisectingKMeansModel private[ml] (

  @Since("2.0.0")
  override def copy(extra: ParamMap): BisectingKMeansModel = {
-    val copied = new BisectingKMeansModel(uid, parentModel)
-    copyValues(copied, extra)
+    val copied = copyValues(new BisectingKMeansModel(uid, parentModel), extra)
+    if (trainingSummary.isDefined) copied.setSummary(trainingSummary.get)
+    copied.setParent(this.parent)
  }

  @Since("2.0.0")

@@ -89,8 +89,9 @@ class GaussianMixtureModel private[ml] (

  @Since("2.0.0")
  override def copy(extra: ParamMap): GaussianMixtureModel = {
-    val copied = new GaussianMixtureModel(uid, weights, gaussians)
-    copyValues(copied, extra).setParent(this.parent)
+    val copied = copyValues(new GaussianMixtureModel(uid, weights, gaussians), extra)
+    if (trainingSummary.isDefined) copied.setSummary(trainingSummary.get)
+    copied.setParent(this.parent)
  }

  @Since("2.0.0")
@@ -169,6 +170,13 @@ class GaussianMixtureModel private[ml] (
    throw new RuntimeException(
      s"No training summary available for the ${this.getClass.getSimpleName}")
  }
+
+  // @Since("2.1.0")
+  // override def copy(extra: ParamMap): GaussianMixtureModel = {
+  //   val newModel = copyValues(new GaussianMixtureModel(uid, weights, gaussians), extra)
+  //   if (trainingSummary.isDefined) newModel.setSummary(trainingSummary.get)
+  //   newModel.setParent(parent)
+  // }
}

@Since("2.0.0")

@@ -108,8 +108,9 @@ class KMeansModel private[ml] (

  @Since("1.5.0")
  override def copy(extra: ParamMap): KMeansModel = {
-    val copied = new KMeansModel(uid, parentModel)
-    copyValues(copied, extra)
+    val copied = copyValues(new KMeansModel(uid, parentModel), extra)
+    if (trainingSummary.isDefined) copied.setSummary(trainingSummary.get)
+    copied.setParent(this.parent)
  }

  /** @group setParam */

@@ -776,8 +776,10 @@ class GeneralizedLinearRegressionModel private[ml] (

  @Since("2.0.0")
  override def copy(extra: ParamMap): GeneralizedLinearRegressionModel = {
-    copyValues(new GeneralizedLinearRegressionModel(uid, coefficients, intercept), extra)
-      .setParent(parent)
+    val copied = copyValues(new GeneralizedLinearRegressionModel(uid, coefficients, intercept),
+      extra)
+    if (trainingSummary.isDefined) copied.setSummary(trainingSummary.get)
+    copied.setParent(parent)
  }

  /**

@@ -27,7 +27,7 @@ import org.apache.spark.ml.attribute.NominalAttribute
import org.apache.spark.ml.classification.LogisticRegressionSuite._
import org.apache.spark.ml.feature.{Instance, LabeledPoint}
import org.apache.spark.ml.linalg.{DenseMatrix, Matrices, SparseMatrix, SparseVector, Vector, Vectors}
-import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.ml.param.{ParamMap, ParamsSuite}
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.ml.util.TestingUtils._
import org.apache.spark.mllib.util.MLlibTestSparkContext
@@ -141,6 +141,10 @@ class LogisticRegressionSuite
    assert(model.getProbabilityCol === "probability")
    assert(model.intercept !== 0.0)
    assert(model.hasParent)
+
+    assert(model.hasSummary)
+    val copiedModel = model.copy(ParamMap.empty)
+    assert(copiedModel.hasSummary)
  }

  test("empty probabilityCol") {

@@ -18,7 +18,8 @@
package org.apache.spark.ml.clustering

import org.apache.spark.SparkFunSuite
-import org.apache.spark.ml.util.DefaultReadWriteTest
+import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.util.{MLTestingUtils, DefaultReadWriteTest}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.Dataset

@@ -41,6 +42,12 @@ class BisectingKMeansSuite
    assert(bkm.getPredictionCol === "prediction")
    assert(bkm.getMaxIter === 20)
    assert(bkm.getMinDivisibleClusterSize === 1.0)
+
+    val model = bkm.setMaxIter(1).fit(dataset)
+    assert(model.hasSummary)
+    val copiedModel = model.copy(ParamMap.empty)
+    assert(copiedModel.hasSummary)
+    MLTestingUtils.checkCopy(model)
  }

  test("setter/getter") {

@@ -18,8 +18,10 @@
package org.apache.spark.ml.clustering

import org.apache.spark.SparkFunSuite
-import org.apache.spark.ml.util.DefaultReadWriteTest
-import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.ml.linalg.Vectors
+import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.util.{MLTestingUtils, DefaultReadWriteTest}
+import org.apache.spark.mllib.util.{MLUtils, MLlibTestSparkContext}
import org.apache.spark.sql.Dataset

@@ -43,6 +45,12 @@ class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext
    assert(gm.getPredictionCol === "prediction")
    assert(gm.getMaxIter === 100)
    assert(gm.getTol === 0.01)
+
+    val model = gm.setMaxIter(1).fit(dataset)
+    assert(model.hasSummary)
+    val copiedModel = model.copy(ParamMap.empty)
+    assert(copiedModel.hasSummary)
+    MLTestingUtils.checkCopy(model)
  }

  test("set parameters") {

@@ -19,7 +19,8 @@ package org.apache.spark.ml.clustering

import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.linalg.{Vector, Vectors}
-import org.apache.spark.ml.util.DefaultReadWriteTest
+import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.util.{MLTestingUtils, DefaultReadWriteTest}
import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
@@ -47,6 +48,12 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR
    assert(kmeans.getInitMode === MLlibKMeans.K_MEANS_PARALLEL)
    assert(kmeans.getInitSteps === 2)
    assert(kmeans.getTol === 1e-4)
+
+    val model = kmeans.setMaxIter(1).fit(dataset)
+    assert(model.hasSummary)
+    val copiedModel = model.copy(ParamMap.empty)
+    assert(copiedModel.hasSummary)
+    MLTestingUtils.checkCopy(model)
  }

  test("set parameters") {

@@ -24,7 +24,7 @@ import org.apache.spark.ml.classification.LogisticRegressionSuite._
import org.apache.spark.ml.feature.Instance
import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.linalg.{BLAS, DenseVector, Vector, Vectors}
-import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.ml.param.{ParamMap, ParamsSuite}
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.ml.util.TestingUtils._
import org.apache.spark.mllib.random._
@@ -183,6 +183,9 @@ class GeneralizedLinearRegressionSuite

    // copied model must have the same parent.
    MLTestingUtils.checkCopy(model)
+    assert(model.hasSummary)
+    val copiedModel = model.copy(ParamMap.empty)
+    assert(copiedModel.hasSummary)

    assert(model.getFeaturesCol === "features")
    assert(model.getPredictionCol === "prediction")

@@ -23,7 +23,7 @@ import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.feature.Instance
import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.linalg.{DenseVector, Vector, Vectors}
-import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.ml.param.{ParamMap, ParamsSuite}
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.ml.util.TestingUtils._
import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
@@ -140,9 +140,13 @@ class LinearRegressionSuite
    assert(lir.getStandardization)
    assert(lir.getSolver == "auto")
    val model = lir.fit(datasetWithDenseFeature)
+    assert(model.hasSummary)

    // copied model must have the same parent.
    MLTestingUtils.checkCopy(model)
+    val copiedModel = model.copy(ParamMap.empty)
+    assert(copiedModel.hasSummary)
+

    model.transform(datasetWithDenseFeature)
      .select("label", "prediction")