Skip to content

Conversation

@shahidki31
Copy link
Contributor

@shahidki31 shahidki31 commented Oct 6, 2019

What changes were proposed in this pull request?

Currently pyspark doesn't write/read avgMetrics in CrossValidatorModel, whereas scala supports it.

Why are the changes needed?

Test step to reproduce it:

dataset = spark.createDataFrame([(Vectors.dense([0.0]), 0.0),
     (Vectors.dense([0.4]), 1.0),
     (Vectors.dense([0.5]), 0.0),
      (Vectors.dense([0.6]), 1.0),
      (Vectors.dense([1.0]), 1.0)] * 10,
     ["features", "label"])
lr = LogisticRegression()
grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()
evaluator = BinaryClassificationEvaluator()
cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator,parallelism=2)
cvModel = cv.fit(dataset)
cvModel.write().save("/tmp/model")
cvModel2 = CrossValidatorModel.read().load("/tmp/model")
print(cvModel.avgMetrics) # prints non empty result as expected
print(cvModel2.avgMetrics) # Bug: prints an empty result.

Does this PR introduce any user-facing change?

No

How was this patch tested?

Manually tested

Before patch:

>>> cvModel.write().save("/tmp/model_0")
>>> cvModel2 = CrossValidatorModel.read().load("/tmp/model_0")
>>> print(cvModel2.avgMetrics)
[]

After patch:

>>> cvModel2 = CrossValidatorModel.read().load("/tmp/model_2")
>>> print(cvModel2.avgMetrics[0])
0.5

@SparkQA
Copy link

SparkQA commented Oct 6, 2019

Test build #111826 has finished for PR 26038 at commit 1c594da.

  • This patch fails Python style tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@SparkQA
Copy link

SparkQA commented Oct 6, 2019

Test build #111827 has finished for PR 26038 at commit 13e3a59.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@shahidki31
Copy link
Contributor Author

cc @zhengruifeng Kindly review

@zhengruifeng
Copy link
Contributor

In [1]: from pyspark.ml.classification import LogisticRegression

In [2]: from pyspark.ml.evaluation import BinaryClassificationEvaluator

In [3]: from pyspark.ml.linalg import Vectors

In [4]: dataset = spark.createDataFrame(
   ...:     ...     [(Vectors.dense([0.0]), 0.0),
   ...:     ...      (Vectors.dense([0.4]), 1.0),
   ...:     ...      (Vectors.dense([0.5]), 0.0),
   ...:     ...      (Vectors.dense([0.6]), 1.0),
   ...:     ...      (Vectors.dense([1.0]), 1.0)] * 10,
   ...:     ...     ["features", "label"]).repartition(1)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-4-47bd70df4aa7> in <module>()
      1 dataset = spark.createDataFrame(
      2     ...     [(Vectors.dense([0.0]), 0.0),
----> 3     ...      (Vectors.dense([0.4]), 1.0),
      4     ...      (Vectors.dense([0.5]), 0.0),
      5     ...      (Vectors.dense([0.6]), 1.0),

TypeError: 'ellipsis' object is not callable

In [5]: dataset = spark.createDataFrame([(Vectors.dense([0.0]), 0.0),(Vectors.dense([0.4]), 1.0),(Vectors.dense([0.5]), 0.0),(Vectors.dense([0.6]), 1.0),(Vectors.dense([1.0]), 1.0)] * 10,["features", "label
   ...: "]).repartition(1)

In [6]: lr = LogisticRegression()

In [7]: grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()
---------------------------------------------------------------------------
NameError                                 Traceback (most recent call last)
<ipython-input-7-045a988cd0ea> in <module>()
----> 1 grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()

NameError: name 'ParamGridBuilder' is not defined

In [8]: from pyspark.ml.tuning import *

In [9]: grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()

In [10]: evaluator = BinaryClassificationEvaluator()

In [11]: tvs = TrainValidationSplit(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator, parallelism=1, seed=42)

In [12]: tvsModel = tvs.fit(dataset)
19/10/09 09:36:51 WARN BLAS: Failed to load implementation from: com.github.fommil.netlib.NativeSystemBLAS
19/10/09 09:36:51 WARN BLAS: Failed to load implementation from: com.github.fommil.netlib.NativeRefBLAS

In [13]: tvsModel.save("/tmp/model")

In [14]: tvsModel2 = TrainValidationSplitModel.load("/tmp/model")

In [15]: tvsModel.validationMetrics
Out[15]: [0.5, 0.8857142857142857]

In [16]: tvsModel2.validationMetrics
Out[16]: []

@shahidki31 Same issue also exist in TrainValidationSplitModel, can you also fix it in this pr?
BTW, what about adding doctests for model savle/load? (also check the loaded metrics)

@shahidki31
Copy link
Contributor Author

Thanks @zhengruifeng I will add metrics for TrainValidationSplitModel too.

@srowen
Copy link
Member

srowen commented Oct 16, 2019

If you'll make the changes @shahidki31 I think we can merge this.

@shahidki31
Copy link
Contributor Author

Thanks @srowen . I will update it today. Actually, there seems an issue. I think AvgMetrics need to convert from java to python object, while reading.

@SparkQA
Copy link

SparkQA commented Oct 17, 2019

Test build #112229 has finished for PR 26038 at commit 6068f66.

  • This patch fails PySpark unit tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@SparkQA
Copy link

SparkQA commented Oct 17, 2019

Test build #112234 has finished for PR 26038 at commit b0f1975.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@SparkQA
Copy link

SparkQA commented Oct 17, 2019

Test build #112235 has finished for PR 26038 at commit 5a79a8a.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@shahidki31 shahidki31 requested a review from srowen October 17, 2019 22:04
@shahidki31
Copy link
Contributor Author

Updated the PR. Locally verified.

@SparkQA
Copy link

SparkQA commented Oct 17, 2019

Test build #112232 has finished for PR 26038 at commit b0f1975.

  • This patch passes all tests.
  • This patch does not merge cleanly.
  • This patch adds no public classes.

@SparkQA
Copy link

SparkQA commented Oct 17, 2019

Test build #112233 has finished for PR 26038 at commit 5e39d5a.

  • This patch passes all tests.
  • This patch does not merge cleanly.
  • This patch adds the following public classes (experimental):
  • class _ValidatorParams(HasSeed):
  • class _CrossValidatorParams(_ValidatorParams):
  • class CrossValidator(Estimator, _CrossValidatorParams, HasParallelism, HasCollectSubModels,
  • class CrossValidatorModel(Model, _CrossValidatorParams, MLReadable, MLWritable):
  • class _TrainValidationSplitParams(_ValidatorParams):
  • class TrainValidationSplit(Estimator, _TrainValidationSplitParams, HasParallelism,
  • class TrainValidationSplitModel(Model, _TrainValidationSplitParams, MLReadable, MLWritable):

@shahidki31
Copy link
Contributor Author

retest this please

@SparkQA
Copy link

SparkQA commented Oct 18, 2019

Test build #112240 has finished for PR 26038 at commit 5a79a8a.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@SparkQA
Copy link

SparkQA commented Oct 18, 2019

Test build #112241 has finished for PR 26038 at commit 2755376.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

self.uid,
self.bestModel._to_java(),
_py2java(sc, []))
self.validationMetrics)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems fine but out of curiosity why is the _py2java call no longer needed here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we will be converting _py2java here?

def _new_java_obj(java_class, *args):
"""
Returns a new Java object.
"""
sc = SparkContext._active_spark_context
java_obj = _jvm()
for name in java_class.split("."):
java_obj = getattr(java_obj, name)
java_args = [_py2java(sc, arg) for arg in args]
return java_obj(*java_args)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I compared with _py2java here and without here, both cases the written metadata file is same. I'll add _py2java here, for consistency.

@SparkQA
Copy link

SparkQA commented Oct 18, 2019

Test build #112292 has finished for PR 26038 at commit 00c4258.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@srowen srowen closed this in 4a6005c Oct 19, 2019
@srowen
Copy link
Member

srowen commented Oct 19, 2019

Merged to master

@shahidki31
Copy link
Contributor Author

Thanks @srowen @zhengruifeng

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants