[SPARK-12424][ML] The implementation of ParamMap#filter is wrong. #10381
Changes from 3 commits

    @@ -17,7 +17,10 @@

    package org.apache.spark.ml.param

    import java.io.{ByteArrayOutputStream, NotSerializableException, ObjectOutputStream}

    import org.apache.spark.SparkFunSuite
    import org.apache.spark.ml.util.MyParams
    import org.apache.spark.mllib.linalg.{Vector, Vectors}

    class ParamsSuite extends SparkFunSuite {

    @@ -349,6 +352,41 @@ class ParamsSuite extends SparkFunSuite {
        val t3 = t.copy(ParamMap(t.maxIter -> 20))
        assert(t3.isSet(t3.maxIter))
      }

      test("Filtering ParamMap") {
        val params1 = new MyParams("my_params1")
        val params2 = new MyParams("my_params2")
        val paramMap = ParamMap(
          params1.intParam -> 1,
          params2.intParam -> 1,
          params1.doubleParam -> 0.2,
          params2.doubleParam -> 0.2)
        val filteredParamMap = paramMap.filter(params1)

        assert(filteredParamMap.size === 2)
        filteredParamMap.toSeq.foreach {
          case ParamPair(p, _) =>
            assert(p.parent === params1.uid)
        }

        // At the previous implementation of ParamMap#filter,
        // mutable.Map#filterKeys was used internally but
        // the return type of the method is not serializable (see SI-6654).
        // Now mutable.Map#filter is used instead of filterKeys and the return type is serializable.
        // So let's ensure serializability.
        val objOut = new ObjectOutputStream(new ByteArrayOutputStream())
        try {
          objOut.writeObject(filteredParamMap)
        } catch {
          case _: NotSerializableException =>
            fail("The field of ParamMap 'map' may not be serializable. " +
              "See SI-6654 and the implementation of ParamMap#filter")
          case e: Exception =>
            fail(s"Exception was thrown unexpectedly during the serializability test: ${e.getMessage}")
        } finally {
          objOut.close()
        }
      }
    }

    object ParamsSuite extends SparkFunSuite {

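
The serializability concern described in the test comment can be tried outside Spark. What follows is a minimal sketch, not code from this PR: `FilterSerializationSketch` and `isJavaSerializable` are hypothetical names, and the map contents are made up. It writes the result of `mutable.Map#filter` and of `mutable.Map#filterKeys` through an `ObjectOutputStream` and reports whether each write succeeds. On Scala versions affected by SI-6654 the `filterKeys` result is expected to fail, but the behavior depends on the Scala version (and `filterKeys` is deprecated in 2.13), so no output is claimed here.

```scala
import java.io.{ByteArrayOutputStream, NotSerializableException, ObjectOutputStream}
import scala.collection.mutable

object FilterSerializationSketch {
  /** Returns true if `obj` can be written with Java serialization. (Hypothetical helper.) */
  def isJavaSerializable(obj: AnyRef): Boolean = {
    val out = new ObjectOutputStream(new ByteArrayOutputStream())
    try {
      out.writeObject(obj)
      true
    } catch {
      case _: NotSerializableException => false
    } finally {
      out.close()
    }
  }

  def main(args: Array[String]): Unit = {
    // Made-up data: keys stand in for parameters, values for their settings.
    val m = mutable.Map("a_1" -> 1, "b_1" -> 2)

    // filter builds a new, strict mutable.Map.
    val strict = m.filter { case (k, _) => k.startsWith("a") }

    // filterKeys returns a wrapper/view over the original map.
    val viewLike = m.filterKeys(_.startsWith("a"))

    println(s"filter result serializable:     ${isJavaSerializable(strict)}")
    println(s"filterKeys result serializable: ${isJavaSerializable(viewLike)}")
  }
}
```

The helper mirrors the `try`/`catch`/`finally` shape of the test in the diff, but returns a Boolean instead of calling `fail`, so both variants can be compared in one run.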
Hasn't this changed the logic slightly? Now you compare to `parent.uid`.

I think the original logic is wrong because the type of `.parent` is `String` (it is a member of `Param`), while the type of the method parameter `parent` is `Params`, so the two can never be equal.

According to the implementation of `Param`, the member `parent` of `Param` is set to the `uid` of an `Identifiable`, which is a trait that `Params` extends.
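
The type mismatch described in this reply can be illustrated with stand-in types. This is a sketch under assumptions: `Identifiable`, `Params`, and `Param` below are simplified, hypothetical stand-ins rather than the Spark classes, and `filterByReference` and `filterByUid` are names invented for illustration. It only shows why comparing a `String` parent to a `Params` instance never matches, while comparing it to `parent.uid` does.

```scala
object ParentUidSketch {
  // Simplified stand-ins for illustration only; not the Spark definitions.
  trait Identifiable { def uid: String }
  class Params(val uid: String) extends Identifiable

  // A parameter stores its owner's uid as a String, as the reply describes.
  class Param[T](val parent: String, val name: String) {
    def this(parent: Identifiable, name: String) = this(parent.uid, name)
  }

  // Compares a String with a Params instance: never equal, so nothing is kept.
  // The upcast to Any only avoids a compiler warning about unrelated types.
  def filterByReference(ps: Seq[Param[_]], parent: Params): Seq[Param[_]] =
    ps.filter(p => (p.parent: Any) == parent)

  // Compares String with String: keeps the parameters owned by `parent`.
  def filterByUid(ps: Seq[Param[_]], parent: Params): Seq[Param[_]] =
    ps.filter(p => p.parent == parent.uid)

  def main(args: Array[String]): Unit = {
    val params1 = new Params("my_params1")
    val params2 = new Params("my_params2")
    val all: Seq[Param[_]] = Seq(
      new Param[Int](params1, "intParam"),
      new Param[Int](params2, "intParam"),
      new Param[Double](params1, "doubleParam"),
      new Param[Double](params2, "doubleParam"))

    println(filterByReference(all, params1).size) // 0
    println(filterByUid(all, params1).size)       // 2
  }
}
```

The uid-based variant matches what the new test asserts: two entries remain, and each has `parent` equal to `params1.uid`.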