File tree Expand file tree Collapse file tree 1 file changed +3
-3
lines changed
Expand file tree Collapse file tree 1 file changed +3
-3
lines changed Original file line number Diff line number Diff line change @@ -447,9 +447,10 @@ def _from_java(cls, java_stage):
447447 Used for ML persistence.
448448 """
449449 bestModel = JavaParams ._from_java (java_stage .bestModel ())
450+ avgMetrics = java_stage .avgMetrics ()
450451 estimator , epms , evaluator = super (CrossValidatorModel , cls )._from_java_impl (java_stage )
451452
452- py_stage = cls (bestModel = bestModel ).setEstimator (estimator )
453+ py_stage = cls (bestModel = bestModel , avgMetrics = avgMetrics ).setEstimator (estimator )
453454 py_stage = py_stage .setEstimatorParamMaps (epms ).setEvaluator (evaluator )
454455
455456 if java_stage .hasSubModels ():
@@ -468,11 +469,10 @@ def _to_java(self):
468469 """
469470
470471 sc = SparkContext ._active_spark_context
471- # TODO: persist average metrics as well
472472 _java_obj = JavaParams ._new_java_obj ("org.apache.spark.ml.tuning.CrossValidatorModel" ,
473473 self .uid ,
474474 self .bestModel ._to_java (),
475- _py2java ( sc , []) )
475+ self . avgMetrics )
476476 estimator , epms , evaluator = super (CrossValidatorModel , self )._to_java_impl ()
477477
478478 _java_obj .set ("evaluator" , evaluator )
You can’t perform that action at this time.
0 commit comments