35 changes: 28 additions & 7 deletions python/pyspark/ml/tuning.py
Original file line number Diff line number Diff line change
@@ -22,7 +22,7 @@
 
 from pyspark import since, keyword_only
 from pyspark.ml import Estimator, Model
-from pyspark.ml.common import _py2java
+from pyspark.ml.common import _py2java, _java2py
 from pyspark.ml.param import Params, Param, TypeConverters
 from pyspark.ml.param.shared import HasCollectSubModels, HasParallelism, HasSeed
 from pyspark.ml.util import *
@@ -216,6 +216,8 @@ class CrossValidator(Estimator, _CrossValidatorParams, HasParallelism, HasCollec
     >>> from pyspark.ml.classification import LogisticRegression
     >>> from pyspark.ml.evaluation import BinaryClassificationEvaluator
     >>> from pyspark.ml.linalg import Vectors
+    >>> from pyspark.ml.tuning import CrossValidatorModel
+    >>> import tempfile
     >>> dataset = spark.createDataFrame(
     ...     [(Vectors.dense([0.0]), 0.0),
     ...      (Vectors.dense([0.4]), 1.0),
@@ -233,6 +235,12 @@ class CrossValidator(Estimator, _CrossValidatorParams, HasParallelism, HasCollec
     3
     >>> cvModel.avgMetrics[0]
     0.5
+    >>> path = tempfile.mkdtemp()
+    >>> model_path = path + "/model"
+    >>> cvModel.write().save(model_path)
+    >>> cvModelRead = CrossValidatorModel.read().load(model_path)
+    >>> cvModelRead.avgMetrics
+    [0.5, ...
     >>> evaluator.evaluate(cvModel.transform(dataset))
     0.8333...
 
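Note: the new doctest lines above exercise the full persistence round trip. As a plain-script sketch (assuming an active SparkSession named `spark` and a `cvModel` fitted as in the doctest setup), it reads:

    # Sketch: save a fitted CrossValidatorModel and confirm avgMetrics
    # survives the save/load cycle; before this change the metrics were
    # written out as an empty list and lost on reload.
    import tempfile
    from pyspark.ml.tuning import CrossValidatorModel

    model_path = tempfile.mkdtemp() + "/model"
    cvModel.write().save(model_path)
    reloaded = CrossValidatorModel.read().load(model_path)
    assert reloaded.avgMetrics == cvModel.avgMetrics  # expected to match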
@@ -483,10 +491,12 @@ def _from_java(cls, java_stage):
         Given a Java CrossValidatorModel, create and return a Python wrapper of it.
         Used for ML persistence.
         """
+        sc = SparkContext._active_spark_context
         bestModel = JavaParams._from_java(java_stage.bestModel())
+        avgMetrics = _java2py(sc, java_stage.avgMetrics())
         estimator, epms, evaluator = super(CrossValidatorModel, cls)._from_java_impl(java_stage)
 
-        py_stage = cls(bestModel=bestModel).setEstimator(estimator)
+        py_stage = cls(bestModel=bestModel, avgMetrics=avgMetrics).setEstimator(estimator)
         py_stage = py_stage.setEstimatorParamMaps(epms).setEvaluator(evaluator)
 
         if java_stage.hasSubModels():
@@ -505,11 +515,10 @@ def _to_java(self):
         """
 
         sc = SparkContext._active_spark_context
-        # TODO: persist average metrics as well
         _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.tuning.CrossValidatorModel",
                                              self.uid,
                                              self.bestModel._to_java(),
-                                             _py2java(sc, []))
+                                             _py2java(sc, self.avgMetrics))
         estimator, epms, evaluator = super(CrossValidatorModel, self)._to_java_impl()
 
         _java_obj.set("evaluator", evaluator)
@@ -551,6 +560,8 @@ class TrainValidationSplit(Estimator, _TrainValidationSplitParams, HasParallelis
     >>> from pyspark.ml.classification import LogisticRegression
     >>> from pyspark.ml.evaluation import BinaryClassificationEvaluator
     >>> from pyspark.ml.linalg import Vectors
+    >>> from pyspark.ml.tuning import TrainValidationSplitModel
+    >>> import tempfile
     >>> dataset = spark.createDataFrame(
     ...     [(Vectors.dense([0.0]), 0.0),
     ...      (Vectors.dense([0.4]), 1.0),
@@ -566,6 +577,14 @@ class TrainValidationSplit(Estimator, _TrainValidationSplitParams, HasParallelis
     >>> tvsModel = tvs.fit(dataset)
     >>> tvsModel.getTrainRatio()
     0.75
+    >>> tvsModel.validationMetrics
+    [0.5, ...
+    >>> path = tempfile.mkdtemp()
+    >>> model_path = path + "/model"
+    >>> tvsModel.write().save(model_path)
+    >>> tvsModelRead = TrainValidationSplitModel.read().load(model_path)
+    >>> tvsModelRead.validationMetrics
+    [0.5, ...
     >>> evaluator.evaluate(tvsModel.transform(dataset))
     0.833...
 
@@ -809,11 +828,14 @@ def _from_java(cls, java_stage):
         """
 
         # Load information from java_stage to the instance.
+        sc = SparkContext._active_spark_context
         bestModel = JavaParams._from_java(java_stage.bestModel())
+        validationMetrics = _java2py(sc, java_stage.validationMetrics())
         estimator, epms, evaluator = super(TrainValidationSplitModel,
                                            cls)._from_java_impl(java_stage)
         # Create a new instance of this stage.
-        py_stage = cls(bestModel=bestModel).setEstimator(estimator)
+        py_stage = cls(bestModel=bestModel,
+                       validationMetrics=validationMetrics).setEstimator(estimator)
         py_stage = py_stage.setEstimatorParamMaps(epms).setEvaluator(evaluator)
 
         if java_stage.hasSubModels():
@@ -830,12 +852,11 @@ def _to_java(self):
         """
 
         sc = SparkContext._active_spark_context
-        # TODO: persst validation metrics as well
         _java_obj = JavaParams._new_java_obj(
             "org.apache.spark.ml.tuning.TrainValidationSplitModel",
             self.uid,
             self.bestModel._to_java(),
-            _py2java(sc, []))
+            _py2java(sc, self.validationMetrics))
         estimator, epms, evaluator = super(TrainValidationSplitModel, self)._to_java_impl()
 
         _java_obj.set("evaluator", evaluator)
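Note: the TrainValidationSplitModel changes mirror the CrossValidatorModel ones above. An end-to-end sketch of the new behavior (assuming an active SparkSession named `spark`; data and variable names are illustrative, adapted from the doctest):

    import tempfile
    from pyspark.ml.classification import LogisticRegression
    from pyspark.ml.evaluation import BinaryClassificationEvaluator
    from pyspark.ml.linalg import Vectors
    from pyspark.ml.tuning import (ParamGridBuilder, TrainValidationSplit,
                                   TrainValidationSplitModel)

    dataset = spark.createDataFrame(
        [(Vectors.dense([0.0]), 0.0),
         (Vectors.dense([0.4]), 1.0),
         (Vectors.dense([0.5]), 0.0),
         (Vectors.dense([0.6]), 1.0),
         (Vectors.dense([1.0]), 1.0)] * 10,
        ["features", "label"])
    lr = LogisticRegression()
    grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()
    evaluator = BinaryClassificationEvaluator()
    tvs = TrainValidationSplit(estimator=lr, estimatorParamMaps=grid,
                               evaluator=evaluator)
    tvsModel = tvs.fit(dataset)

    model_path = tempfile.mkdtemp() + "/model"
    tvsModel.write().save(model_path)
    reloaded = TrainValidationSplitModel.read().load(model_path)
    # validationMetrics now survives persistence instead of coming back empty.
    assert reloaded.validationMetrics == tvsModel.validationMetrics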