python/pyspark/ml/tuning.py (37 changes: 28 additions, 9 deletions)
@@ -22,7 +22,7 @@
 
 from pyspark import since, keyword_only
 from pyspark.ml import Estimator, Model
-from pyspark.ml.common import _py2java
+from pyspark.ml.common import _py2java, _java2py
 from pyspark.ml.param import Params, Param, TypeConverters
 from pyspark.ml.param.shared import HasCollectSubModels, HasParallelism, HasSeed
 from pyspark.ml.util import *
@@ -216,6 +216,8 @@ class CrossValidator(Estimator, _CrossValidatorParams, HasParallelism, HasCollectSubModels,
     >>> from pyspark.ml.classification import LogisticRegression
     >>> from pyspark.ml.evaluation import BinaryClassificationEvaluator
     >>> from pyspark.ml.linalg import Vectors
+    >>> from pyspark.ml.tuning import CrossValidatorModel
+    >>> import tempfile
     >>> dataset = spark.createDataFrame(
     ...     [(Vectors.dense([0.0]), 0.0),
     ...      (Vectors.dense([0.4]), 1.0),
@@ -233,6 +235,12 @@ class CrossValidator(Estimator, _CrossValidatorParams, HasParallelism, HasCollectSubModels,
     3
     >>> cvModel.avgMetrics[0]
     0.5
+    >>> path = tempfile.mkdtemp()
+    >>> model_path = path + "/model"
+    >>> cvModel.write().save(model_path)
+    >>> cvModelRead = CrossValidatorModel.read().load(model_path)
+    >>> cvModelRead.avgMetrics
+    [0.5, ...
     >>> evaluator.evaluate(cvModel.transform(dataset))
     0.8333...
 
@@ -483,10 +491,12 @@ def _from_java(cls, java_stage):
         Given a Java CrossValidatorModel, create and return a Python wrapper of it.
         Used for ML persistence.
         """
+        sc = SparkContext._active_spark_context
         bestModel = JavaParams._from_java(java_stage.bestModel())
+        avgMetrics = _java2py(sc, java_stage.avgMetrics())
         estimator, epms, evaluator = super(CrossValidatorModel, cls)._from_java_impl(java_stage)
 
-        py_stage = cls(bestModel=bestModel).setEstimator(estimator)
+        py_stage = cls(bestModel=bestModel, avgMetrics=avgMetrics).setEstimator(estimator)
         py_stage = py_stage.setEstimatorParamMaps(epms).setEvaluator(evaluator)
 
         if java_stage.hasSubModels():
@@ -504,12 +514,10 @@ def _to_java(self):
         :return: Java object equivalent to this instance.
         """
 
-        sc = SparkContext._active_spark_context
-        # TODO: persist average metrics as well
         _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.tuning.CrossValidatorModel",
                                              self.uid,
                                              self.bestModel._to_java(),
-                                             _py2java(sc, []))
+                                             self.avgMetrics)
         estimator, epms, evaluator = super(CrossValidatorModel, self)._to_java_impl()
 
         _java_obj.set("evaluator", evaluator)
@@ -551,6 +559,8 @@ class TrainValidationSplit(Estimator, _TrainValidationSplitParams, HasParallelism,
     >>> from pyspark.ml.classification import LogisticRegression
     >>> from pyspark.ml.evaluation import BinaryClassificationEvaluator
     >>> from pyspark.ml.linalg import Vectors
+    >>> from pyspark.ml.tuning import TrainValidationSplitModel
+    >>> import tempfile
     >>> dataset = spark.createDataFrame(
     ...     [(Vectors.dense([0.0]), 0.0),
     ...      (Vectors.dense([0.4]), 1.0),
@@ -566,6 +576,14 @@ class TrainValidationSplit(Estimator, _TrainValidationSplitParams, HasParallelism,
     >>> tvsModel = tvs.fit(dataset)
     >>> tvsModel.getTrainRatio()
     0.75
+    >>> tvsModel.validationMetrics
+    [0.5, ...
+    >>> path = tempfile.mkdtemp()
+    >>> model_path = path + "/model"
+    >>> tvsModel.write().save(model_path)
+    >>> tvsModelRead = TrainValidationSplitModel.read().load(model_path)
+    >>> tvsModelRead.validationMetrics
+    [0.5, ...
     >>> evaluator.evaluate(tvsModel.transform(dataset))
     0.833...
 
@@ -809,11 +827,14 @@ def _from_java(cls, java_stage):
         """
 
         # Load information from java_stage to the instance.
+        sc = SparkContext._active_spark_context
         bestModel = JavaParams._from_java(java_stage.bestModel())
+        validationMetrics = _java2py(sc, java_stage.validationMetrics())
         estimator, epms, evaluator = super(TrainValidationSplitModel,
                                            cls)._from_java_impl(java_stage)
         # Create a new instance of this stage.
-        py_stage = cls(bestModel=bestModel).setEstimator(estimator)
+        py_stage = cls(bestModel=bestModel,
+                       validationMetrics=validationMetrics).setEstimator(estimator)
         py_stage = py_stage.setEstimatorParamMaps(epms).setEvaluator(evaluator)
 
         if java_stage.hasSubModels():
@@ -829,13 +850,11 @@ def _to_java(self):
         :return: Java object equivalent to this instance.
         """
 
-        sc = SparkContext._active_spark_context
-        # TODO: persst validation metrics as well
        _java_obj = JavaParams._new_java_obj(
             "org.apache.spark.ml.tuning.TrainValidationSplitModel",
             self.uid,
             self.bestModel._to_java(),
-            _py2java(sc, []))
+            self.validationMetrics)

Member:
This seems fine, but out of curiosity: why is the _py2java call no longer needed here?

Contributor Author:
I think we will already be converting via _py2java inside _new_java_obj here:

def _new_java_obj(java_class, *args):
    """
    Returns a new Java object.
    """
    sc = SparkContext._active_spark_context
    java_obj = _jvm()
    for name in java_class.split("."):
        java_obj = getattr(java_obj, name)
    java_args = [_py2java(sc, arg) for arg in args]
    return java_obj(*java_args)

Contributor Author:
I compared the output with the explicit _py2java call here and without it; in both cases the written metadata file is the same. I'll add _py2java here for consistency.
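
For context, here is a simplified sketch of the dispatch in pyspark.ml.common._py2java, reduced to the two cases relevant to this thread. This is an illustration under stated assumptions, not the verbatim Spark source: the real function also handles RDDs, DataFrames, SparkContexts, and JavaObjects, and pickles everything else through MLSerDe.

def _py2java_sketch(sc, obj):
    """Convert a Python value for the py4j gateway (simplified sketch)."""
    if isinstance(obj, list):
        # Convert element-wise; py4j's auto-convert (PySpark creates its
        # gateway with auto_convert=True) then turns the Python list into a
        # java.util.ArrayList when it crosses the gateway.
        return [_py2java_sketch(sc, x) for x in obj]
    if isinstance(obj, (bool, int, float, str, bytes)):
        return obj  # primitives cross the gateway unchanged
    raise TypeError("sketch covers only lists and primitives")

Since _new_java_obj maps _py2java over every positional argument, passing the plain Python list self.validationMetrics builds the same Java object as passing _py2java(sc, self.validationMetrics), which matches the identical metadata files observed above.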

         estimator, epms, evaluator = super(TrainValidationSplitModel, self)._to_java_impl()
 
         _java_obj.set("evaluator", evaluator)