Skip to content

Commit 39c025f

Browse files
committed
rename to modelPreservePath
1 parent c7e0bcd commit 39c025f

File tree

5 files changed

+24
-21
lines changed

5 files changed

+24
-21
lines changed

mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -92,13 +92,13 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
9292
def setSeed(value: Long): this.type = set(seed, value)
9393

9494
/**
95-
* If set, all the models fitted during the cross validation will be preserved
96-
* under the specific directory path. By default the models will not be saved.
95+
* Optional parameter. If set, all the models trained during cross validation will be
96+
* saved in the specified path. By default the models will not be preserved.
9797
*
9898
* @group expertSetParam
9999
*/
100100
@Since("2.3.0")
101-
def setModelPath(value: String): this.type = set(modelPath, value)
101+
def setModelPreservePath(value: String): this.type = set(modelPreservePath, value)
102102

103103
@Since("2.0.0")
104104
override def fit(dataset: Dataset[_]): CrossValidatorModel = {
@@ -128,13 +128,15 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
128128
// TODO: duplicate evaluator to take extra params from input
129129
val metric = eval.evaluate(models(i).transform(validationDataset, epm(i)))
130130
logDebug(s"Got metric $metric for model trained with ${epm(i)}.")
131-
if (isDefined(modelPath)) {
131+
if (isDefined(modelPreservePath)) {
132132
models(i) match {
133133
case w: MLWritable =>
134-
val path = new Path($(modelPath), epm(i).toSeq.map(p => p.param.name + "-" + p.value)
135-
.mkString("-") + s"-split$splitIndex-${math.rint(metric * 1000) / 1000}")
136-
w.save(path.toString)
134+
// e.g. maxIter-5-regParam-0.001-split0-0.859
135+
val fileName = epm(i).toSeq.map(p => p.param.name + "-" + p.value).sorted
136+
.mkString("-") + s"-split$splitIndex-${math.rint(metric * 1000) / 1000}"
137+
w.save(new Path($(modelPreservePath), fileName).toString)
137138
case _ =>
139+
// for third-party algorithms
138140
logWarning(models(i).uid + " did not implement MLWritable. Serialization omitted.")
139141
}
140142
}

mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -88,13 +88,13 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St
8888
def setSeed(value: Long): this.type = set(seed, value)
8989

9090
/**
91-
* If set, all the models fitted during the training will be preserved
91+
* Optional parameter. If set, all the models fitted during the training will be saved
9292
* under the specified directory path. By default the models will not be saved.
9393
*
9494
* @group expertSetParam
9595
*/
9696
@Since("2.3.0")
97-
def setModelPath(value: String): this.type = set(modelPath, value)
97+
def setModelPreservePath(value: String): this.type = set(modelPreservePath, value)
9898

9999
@Since("2.0.0")
100100
override def fit(dataset: Dataset[_]): TrainValidationSplitModel = {
@@ -124,12 +124,13 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St
124124
// TODO: duplicate evaluator to take extra params from input
125125
val metric = eval.evaluate(models(i).transform(validationDataset, epm(i)))
126126
logDebug(s"Got metric $metric for model trained with ${epm(i)}.")
127-
if (isDefined(modelPath)) {
127+
if (isDefined(modelPreservePath)) {
128128
models(i) match {
129129
case w: MLWritable =>
130-
val path = new Path($(modelPath), epm(i).toSeq.map(p => p.param.name + "-" + p.value)
131-
.mkString("-") + s"-${math.rint(metric * 1000) / 1000}")
132-
w.save(path.toString)
130+
// e.g. maxIter-5-regParam-0.001-0.859
131+
val fileName = epm(i).toSeq.map(p => p.param.name + "-" + p.value).sorted
132+
.mkString("-") + s"-${math.rint(metric * 1000) / 1000}"
133+
w.save(new Path($(modelPreservePath), fileName).toString)
133134
case _ =>
134135
logWarning(models(i).uid + " did not implement MLWritable. Serialization omitted.")
135136
}

mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -70,18 +70,18 @@ private[ml] trait ValidatorParams extends HasSeed with Params {
7070

7171

7272
/**
73-
* Optional parameter. If set, all the models fitted during the cross validation will be
74-
* saved in the specific path. By default the models will not be saved.
73+
* Optional parameter. If set, all the models trained during the tuning grid search will be
74+
* saved in the specified path. By default the models will not be preserved.
7575
*
7676
* @group expertParam
7777
*/
78-
val modelPath: Param[String] = new Param(this, "modelPath",
78+
val modelPreservePath: Param[String] = new Param(this, "modelPath",
7979
"Optional parameter. If set, all the models fitted during the cross validation will be" +
8080
" saved in the path")
8181

8282
/** @group expertGetParam */
8383
@Since("2.3.0")
84-
def getModelPath: String = $(modelPath)
84+
def getModelPreservePath: String = $(modelPreservePath)
8585

8686
protected def transformSchemaImpl(schema: StructType): StructType = {
8787
require($(estimatorParamMaps).nonEmpty, s"Validator requires non-empty estimatorParamMaps")

mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ class CrossValidatorSuite
5757
.setEstimatorParamMaps(lrParamMaps)
5858
.setEvaluator(eval)
5959
.setNumFolds(3)
60-
assert(!cv.isDefined(cv.modelPath))
60+
assert(!cv.isDefined(cv.modelPreservePath))
6161
val cvModel = cv.fit(dataset)
6262

6363
MLTestingUtils.checkCopyAndUids(cv, cvModel)
@@ -258,7 +258,7 @@ class CrossValidatorSuite
258258
.setEstimatorParamMaps(lrParamMaps)
259259
.setEvaluator(eval)
260260
.setNumFolds(3)
261-
.setModelPath(path)
261+
.setModelPreservePath(path)
262262
try {
263263
cv.fit(dataset)
264264
assert(tempDir.list().length === 3 * 2 * 2)

mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ class TrainValidationSplitSuite
5454
.setSeed(42L)
5555
val tvsModel = tvs.fit(dataset)
5656
val parent = tvsModel.bestModel.parent.asInstanceOf[LogisticRegression]
57-
assert(!tvs.isDefined(tvs.modelPath))
57+
assert(!tvs.isDefined(tvs.modelPreservePath))
5858
assert(tvs.getTrainRatio === 0.5)
5959
assert(parent.getRegParam === 0.001)
6060
assert(parent.getMaxIter === 10)
@@ -136,7 +136,7 @@ class TrainValidationSplitSuite
136136
.setEvaluator(eval)
137137
.setTrainRatio(0.5)
138138
.setSeed(42L)
139-
.setModelPath(path)
139+
.setModelPreservePath(path)
140140
try {
141141
tvs.fit(dataset)
142142
assert(tempDir.list().length === 2 * 2)

0 commit comments

Comments
 (0)