Fixed merge conflict
ajaysaini725 committed Aug 12, 2017
commit 585a3f8ea21359f11cd5a19ba195df88e091d9e0
@@ -33,7 +33,11 @@ import org.apache.spark.ml._
import org.apache.spark.ml.attribute._
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.param.{Param, ParamMap, ParamPair, Params}
<<<<<<< HEAD
Member: some sort of merge problem? this shouldn't be in a commit

import org.apache.spark.ml.param.shared.HasParallelism
=======
import org.apache.spark.ml.param.shared.HasWeightCol
>>>>>>> master
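As the review note above says, these conflict markers were committed by mistake. Since the surrounding code uses both HasParallelism (for setParallelism) and HasWeightCol (for the new setWeightCol), the intended resolution is presumably to keep both shared-param imports; a sketch:

```scala
import org.apache.spark.ml.param.shared.{HasParallelism, HasWeightCol}
```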
import org.apache.spark.ml.util._
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions._
@@ -309,6 +313,18 @@ final class OneVsRest @Since("1.4.0") (
set(parallelism, value)
}

+  /**
+   * Sets the value of param [[weightCol]].
+   *
+   * This is ignored if weight is not supported by [[classifier]].
+   * If this is not set or empty, we treat all instance weights as 1.0.
+   * Default is not set, so all instances have weight one.
+   *
+   * @group setParam
+   */
+  @Since("2.3.0")
+  def setWeightCol(value: String): this.type = set(weightCol, value)
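
For illustration, a minimal sketch of how the new setter might be used from the Scala API; the `trainingData` DataFrame with a Double "weight" column is an assumption, not part of this diff:

```scala
import org.apache.spark.ml.classification.{LogisticRegression, OneVsRest}

// LogisticRegression mixes in HasWeightCol, so the per-instance
// weights in "weight" are forwarded to each binary sub-model.
val ovr = new OneVsRest()
  .setClassifier(new LogisticRegression().setMaxIter(5))
  .setWeightCol("weight") // ignored if the classifier lacks HasWeightCol
val ovrModel = ovr.fit(trainingData)
```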

@Since("1.4.0")
override def transformSchema(schema: StructType): StructType = {
validateAndTransformSchema(schema, fitting = true, getClassifier.featuresDataType)
@@ -368,12 +384,17 @@ final class OneVsRest @Since("1.4.0") (
paramMap.put(classifier.featuresCol -> getFeaturesCol)
paramMap.put(classifier.predictionCol -> getPredictionCol)
Future {
-            classifier.fit(trainingDataset, paramMap)
+            if (weightColIsUsed) {
+              val classifier_ = classifier.asInstanceOf[ClassifierType with HasWeightCol]
+              paramMap.put(classifier_.weightCol -> getWeightCol)
+              classifier_.fit(trainingDataset, paramMap)
+            } else {
+              classifier.fit(trainingDataset, paramMap)
+            }
}(executionContext)
}
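The weightColIsUsed flag tested above is defined outside this hunk. A plausible sketch of the check, verifying both that the param is set and that the underlying classifier actually supports weights (a sketch only, not necessarily the exact code in the PR):

```scala
val weightColIsUsed = isDefined(weightCol) && {
  getClassifier match {
    case _: HasWeightCol => true
    case other =>
      // Fall back to unweighted training rather than failing.
      logWarning(s"weightCol is ignored, as it is not supported by $other.")
      false
  }
}
```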
val models = modelFutures
.map(ThreadUtils.awaitResult(_, Duration.Inf)).toArray[ClassificationModel[_, _]]
Contributor: Would it be more idiomatic to use Future.sequence here and just wait on the Future[List]?

Contributor: @holdenk Because ThreadUtils.awaitResult wraps and re-throws any exceptions thrown by the underlying Await call, ensuring that this thread's stack trace appears in the logs.
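For reference, the Future.sequence variant would look roughly like the sketch below, reusing modelFutures and executionContext from the surrounding code; the trade-off described above is that a bare Await.result does not re-wrap exceptions the way ThreadUtils.awaitResult does:

```scala
import scala.concurrent.{Await, ExecutionContext, Future}
import scala.concurrent.duration.Duration

// Combine the per-label futures into one future and block once on the whole list.
implicit val ec: ExecutionContext = executionContext
val models = Await.result(Future.sequence(modelFutures), Duration.Inf)
  .toArray[ClassificationModel[_, _]]
```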


instr.logNumFeatures(models.head.numFeatures)

if (handlePersistence) {
8 changes: 4 additions & 4 deletions python/pyspark/ml/classification.py
@@ -1608,10 +1608,10 @@ class OneVsRest(Estimator, OneVsRestParams, HasParallelism, JavaMLReadable, JavaMLWritable):

@keyword_only
def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
-                 classifier=None, parallelism=1):
+                 classifier=None, weightCol=None, parallelism=1):
"""
__init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
-                 classifier=None, parallelism=1)
+                 classifier=None, weightCol=None, parallelism=1)
"""
super(OneVsRest, self).__init__()
self._setDefault(parallelism=1)
@@ -1621,10 +1621,10 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
@keyword_only
@since("2.0.0")
def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
-                  classifier=None, parallelism=1):
+                  classifier=None, weightCol=None, parallelism=1):
"""
setParams(self, featuresCol=None, labelCol=None, predictionCol=None, \
Contributor: The default args here in the doc should match the method (for featuresCol, labelCol and predictionCol)
-                  classifier=None, parallelism=1):
+                  classifier=None, weightCol=None, parallelism=1):
Sets params for OneVsRest.
"""
kwargs = self._input_kwargs
14 changes: 14 additions & 0 deletions python/pyspark/ml/tests.py
@@ -1473,6 +1473,20 @@ def test_parallelism_doesnt_change_output(self):
modelPar2.models[i].coefficients.toArray(), atol=1E-4))
self.assertTrue(np.allclose(model.intercept, modelPar2.models[i].intercept, atol=1E-4))

+    def test_support_for_weightCol(self):
+        df = self.spark.createDataFrame([(0.0, Vectors.dense(1.0, 0.8), 1.0),
+                                         (1.0, Vectors.sparse(2, [], []), 1.0),
+                                         (2.0, Vectors.dense(0.5, 0.5), 1.0)],
+                                        ["label", "features", "weight"])
+        # classifier inherits HasWeightCol
+        lr = LogisticRegression(maxIter=5, regParam=0.01)
+        ovr = OneVsRest(classifier=lr, weightCol="weight")
+        self.assertIsNotNone(ovr.fit(df))
+        # classifier doesn't inherit HasWeightCol
+        dt = DecisionTreeClassifier()
+        ovr2 = OneVsRest(classifier=dt, weightCol="weight")
+        self.assertIsNotNone(ovr2.fit(df))

Contributor: Similar comment, can we add a test to make sure that we are actually training in parallel? This is perhaps especially important in Python because I could see us accidentally blocking on something unexpected.


class HashingTFTest(SparkSessionTestCase):
