[SPARK-21027][ML][PYTHON] Added tunable parallelism to one vs. rest in both Scala mllib and Pyspark #18281
@@ -33,7 +33,11 @@ import org.apache.spark.ml._
 import org.apache.spark.ml.attribute._
 import org.apache.spark.ml.linalg.Vector
 import org.apache.spark.ml.param.{Param, ParamMap, ParamPair, Params}
+<<<<<<< HEAD
 import org.apache.spark.ml.param.shared.HasParallelism
+=======
+import org.apache.spark.ml.param.shared.HasWeightCol
+>>>>>>> master
 import org.apache.spark.ml.util._
 import org.apache.spark.sql.{DataFrame, Dataset, Row}
 import org.apache.spark.sql.functions._
@@ -309,6 +313,18 @@ final class OneVsRest @Since("1.4.0") (
     set(parallelism, value)
   }

+  /**
+   * Sets the value of param [[weightCol]].
+   *
+   * This is ignored if weight is not supported by [[classifier]].
+   * If this is not set or empty, we treat all instance weights as 1.0.
+   * Default is not set, so all instances have weight one.
+   *
+   * @group setParam
+   */
+  @Since("2.3.0")
+  def setWeightCol(value: String): this.type = set(weightCol, value)
+
   @Since("1.4.0")
   override def transformSchema(schema: StructType): StructType = {
     validateAndTransformSchema(schema, fitting = true, getClassifier.featuresDataType)
@@ -368,12 +384,17 @@ final class OneVsRest @Since("1.4.0") (
       paramMap.put(classifier.featuresCol -> getFeaturesCol)
       paramMap.put(classifier.predictionCol -> getPredictionCol)
       Future {
-        classifier.fit(trainingDataset, paramMap)
+        if (weightColIsUsed) {
+          val classifier_ = classifier.asInstanceOf[ClassifierType with HasWeightCol]
+          paramMap.put(classifier_.weightCol -> getWeightCol)
+          classifier_.fit(trainingDataset, paramMap)
+        } else {
+          classifier.fit(trainingDataset, paramMap)
+        }
       }(executionContext)
     }
     val models = modelFutures
       .map(ThreadUtils.awaitResult(_, Duration.Inf)).toArray[ClassificationModel[_, _]]
Contributor: Would it be more idiomatic to use …
Contributor: @holdenk Because …
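For readers following this thread: the loop above dispatches one `Future` per class on `executionContext` and then blocks on `ThreadUtils.awaitResult`. Below is a minimal pure-Python sketch of the same submit-all-then-await pattern, including the weight-column capability check; the helper names are hypothetical, not PySpark internals.

```python
from concurrent.futures import ThreadPoolExecutor

def fit_classes_in_parallel(classifier, per_class_datasets,
                            weight_col=None, parallelism=1):
    """Fit one binary model per class, mirroring the Future-per-class loop."""
    def fit_one(dataset):
        clf = classifier.copy()
        # Forward the weight column only when the classifier supports it,
        # analogous to the asInstanceOf[... with HasWeightCol] branch above.
        if weight_col is not None and hasattr(clf, "setWeightCol"):
            clf.setWeightCol(weight_col)
        return clf.fit(dataset)

    # Submit every fit before awaiting any result, so up to `parallelism`
    # models train concurrently; collecting results is the awaitResult step.
    with ThreadPoolExecutor(max_workers=parallelism) as pool:
        futures = [pool.submit(fit_one, d) for d in per_class_datasets]
        return [f.result() for f in futures]
```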

     instr.logNumFeatures(models.head.numFeatures)

     if (handlePersistence) {
@@ -1608,10 +1608,10 @@ class OneVsRest(Estimator, OneVsRestParams, HasParallelism, JavaMLReadable, Java

     @keyword_only
     def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
-                 classifier=None, parallelism=1):
+                 classifier=None, weightCol=None, parallelism=1):
         """
         __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
-                 classifier=None, parallelism=1)
+                 classifier=None, weightCol=None, parallelism=1)
         """
         super(OneVsRest, self).__init__()
         self._setDefault(parallelism=1)
@@ -1621,10 +1621,10 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred
     @keyword_only
     @since("2.0.0")
     def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
-                  classifier=None, parallelism=1):
+                  classifier=None, weightCol=None, parallelism=1):
         """
         setParams(self, featuresCol=None, labelCol=None, predictionCol=None, \
Contributor: The default args here in the doc should match the method (for …
-                  classifier=None, parallelism=1):
+                  classifier=None, weightCol=None, parallelism=1):
         Sets params for OneVsRest.
         """
         kwargs = self._input_kwargs
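With the constructor change above, threading a weight column through `OneVsRest` would look roughly like this; column and variable names are illustrative.

```python
from pyspark.ml.classification import LogisticRegression, OneVsRest

lr = LogisticRegression(maxIter=5, regParam=0.01)
# weightCol is forwarded to the base classifier when it supports weights;
# otherwise it is ignored and every instance effectively has weight 1.0.
ovr = OneVsRest(classifier=lr, weightCol="weight", parallelism=2)
# model = ovr.fit(train_df)  # train_df has label, features, weight columns
```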
@@ -1473,6 +1473,20 @@ def test_parallelism_doesnt_change_output(self):
                 modelPar2.models[i].coefficients.toArray(), atol=1E-4))
             self.assertTrue(np.allclose(model.intercept, modelPar2.models[i].intercept, atol=1E-4))

+    def test_support_for_weightCol(self):
+        df = self.spark.createDataFrame([(0.0, Vectors.dense(1.0, 0.8), 1.0),
+                                         (1.0, Vectors.sparse(2, [], []), 1.0),
+                                         (2.0, Vectors.dense(0.5, 0.5), 1.0)],
+                                        ["label", "features", "weight"])
+        # classifier inherits hasWeightCol
+        lr = LogisticRegression(maxIter=5, regParam=0.01)
+        ovr = OneVsRest(classifier=lr, weightCol="weight")
+        self.assertIsNotNone(ovr.fit(df))
+        # classifier doesn't inherit hasWeightCol
+        dt = DecisionTreeClassifier()
+        ovr2 = OneVsRest(classifier=dt, weightCol="weight")
+        self.assertIsNotNone(ovr2.fit(df))
Contributor: Similar comment, can we add a test to make sure that we are actually training in parallel? This is perhaps especially important in Python because I could see us accidentally blocking on something unexpected.
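One way such a test could be shaped, sketched here with a toy stand-in rather than a real Spark estimator; the overlap counter and the sleep are the hypothetical parts.

```python
import threading
import time
from concurrent.futures import ThreadPoolExecutor

class BlockingFit:
    """Toy 'classifier' whose fit() sleeps; overlapping calls prove concurrency."""
    def __init__(self):
        self._lock = threading.Lock()
        self.active = 0
        self.max_active = 0

    def fit(self, _dataset):
        with self._lock:
            self.active += 1
            self.max_active = max(self.max_active, self.active)
        time.sleep(0.2)  # stand-in for training work
        with self._lock:
            self.active -= 1

clf = BlockingFit()
with ThreadPoolExecutor(max_workers=4) as pool:
    for fut in [pool.submit(clf.fit, None) for _ in range(4)]:
        fut.result()
assert clf.max_active > 1, "fits never overlapped; training ran serially"
```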


 class HashingTFTest(SparkSessionTestCase):
Reviewer comment (on the `<<<<<<< HEAD` conflict markers in the Scala imports above): some sort of merge problem? This shouldn't be in a commit.