Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
update & resolve conflicts
  • Loading branch information
WeichenXu123 committed Nov 15, 2017
commit 56143905f23f83b1c130fbbe42ef48fe3a5c1b2d
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,9 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
val foldMetrics = new Array[Double](epm.length)
est.fit(trainingDataset, epm, true, executionContext,
(model: Model[_], paramMap: ParamMap, paramMapIndex: Int) => {
if (collectSubModelsParam) {
subModels.get(splitIndex)(paramMapIndex) = model
}
// TODO: duplicate evaluator to take extra params from input
val metric = eval.evaluate(model.transform(validationDataset, paramMap))
logDebug(s"Got metric $metric for model trained with $paramMap.")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,9 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St
val metrics = new Array[Double](epm.length)
est.fit(trainingDataset, epm, true, executionContext,
(model: Model[_], paramMap: ParamMap, paramMapIndex: Int) => {
if (collectSubModelsParam) {
subModels.get(paramMapIndex) = model
}
// TODO: duplicate evaluator to take extra params from input
val metric = eval.evaluate(model.transform(validationDataset, paramMap))
logDebug(s"Got metric $metric for model trained with $paramMap.")
Expand Down
You are viewing a condensed version of this merge commit. You can view the full changes here.