Changes from 1 commit (of 25 commits)

Commits
b69f201
Added tunable parallelism to the pyspark implementation of one vs. re…
ajaysaini725 Jun 12, 2017
e750d3e
Fixed python style.
ajaysaini725 Jun 12, 2017
81d458b
Added functionality for tuning parellelism in the Scala implementatio…
ajaysaini725 Jun 13, 2017
2133378
Fixed code according to comments. Added both annotations and unit tes…
ajaysaini725 Jun 13, 2017
c59b1d8
Modified parallel one vs rest to use futures.
ajaysaini725 Jun 22, 2017
5f635a2
Put the parallelism parameter as well as the function for getting an …
ajaysaini725 Jun 23, 2017
4431ffc
Responded to pull request comments.
ajaysaini725 Jun 23, 2017
a841b3e
Made changes based on pull request comments.
ajaysaini725 Jul 6, 2017
a95a8af
Fixed based on pull request comments
ajaysaini725 Jul 14, 2017
d45bc23
Fixed based on comments
ajaysaini725 Jul 18, 2017
30ac62d
Reverting merge and adding change that would fix merge conflict (maki…
ajaysaini725 Jul 19, 2017
cc634d2
Merge branch 'master' into spark-21027
ajaysaini725 Jul 19, 2017
ce14172
Style fix with docstring
ajaysaini725 Jul 20, 2017
1c9de16
Fixed based on comments.
ajaysaini725 Jul 27, 2017
9f34404
Fixed style issue.
ajaysaini725 Jul 27, 2017
585a3f8
Fixed merge conflict
ajaysaini725 Aug 12, 2017
f65381a
Fixed remaining part of merge conflict.
ajaysaini725 Aug 23, 2017
2a335fe
Fixed style problem
ajaysaini725 Aug 23, 2017
049f371
Merge branch 'master' into spark-21027
WeichenXu123 Sep 2, 2017
ddc2ff4
address review feedback issues
WeichenXu123 Sep 3, 2017
fc6fd5e
update migration guide
WeichenXu123 Sep 3, 2017
7d0849e
update desc
WeichenXu123 Sep 6, 2017
edcf85c
fix style
WeichenXu123 Sep 6, 2017
7a1d404
merge master & resolve conflicts
WeichenXu123 Sep 6, 2017
c24d4e2
update out-of-date shared.py
WeichenXu123 Sep 12, 2017
Commit b69f201bc51f8de87adc3869d4843e3df6750972 (ajaysaini725, committed Jun 12, 2017)

Added tunable parallelism to the pyspark implementation of one vs. rest classification. Added a parallelism parameter to the Scala implementation of one vs. rest (for Python persistence), but it is not yet used to parallelize training on the Scala side.
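For orientation, here is roughly how the new knob is meant to be used from PySpark (a minimal sketch based on the diff below; it assumes an active SparkSession named spark, and the default of parallelism=8 is this commit's choice, not necessarily final):

    from pyspark.ml.classification import LogisticRegression, OneVsRest
    from pyspark.ml.linalg import Vectors

    # Three-class toy data, mirroring the test fixtures in this PR.
    df = spark.createDataFrame([(0.0, Vectors.dense(1.0, 0.8)),
                                (1.0, Vectors.sparse(2, [], [])),
                                (2.0, Vectors.dense(0.5, 0.5))] * 10,
                               ["label", "features"])
    lr = LogisticRegression(maxIter=5, regParam=0.01)
    ovr = OneVsRest(classifier=lr, parallelism=4)  # fit up to 4 per-class models concurrently
    model = ovr.fit(df)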
14 changes: 13 additions & 1 deletion mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
@@ -33,7 +33,7 @@ import org.apache.spark.annotation.Since
import org.apache.spark.ml._
import org.apache.spark.ml.attribute._
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.param.{Param, ParamMap, ParamPair, Params}
import org.apache.spark.ml.param.{IntParam, Param, ParamMap, ParamPair, Params, ParamValidators}
import org.apache.spark.ml.util._
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions._
@@ -65,6 +65,12 @@ private[ml] trait OneVsRestParams extends PredictorParams with ClassifierTypeTrait

/** @group getParam */
def getClassifier: ClassifierType = $(classifier)

/** @group expertParam */
val parallelism = new IntParam(this, "parallelism",
  "the maximum number of threads to use when fitting models in parallel (>= 1)",
  ParamValidators.gtEq(1))

/** @group getParam */
def getParallelism: Int = $(parallelism)
}

private[ml] object OneVsRestParams extends ClassifierTypeTrait {
@@ -282,6 +288,12 @@ final class OneVsRest @Since("1.4.0") (
set(classifier, value.asInstanceOf[ClassifierType])
}

/** @group setParam */
@Since("1.4.0")
def setParallelism(value: Int): this.type = {
set(parallelism, value)
}

/** @group setParam */
@Since("1.5.0")
def setLabelCol(value: String): this.type = set(labelCol, value)
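The commit list above shows the Scala side being reworked later to train the per-class models with Futures (commit c59b1d8), which is not part of this first commit. For intuition only, a sketch of that pattern, written here in Python with concurrent.futures and hypothetical names:

    from concurrent.futures import ThreadPoolExecutor

    def fit_all_classes(train_single_class, num_classes, parallelism):
        # Submit one training task per class; the executor caps concurrency at
        # `parallelism` worker threads, and map() keeps results in class order.
        with ThreadPoolExecutor(max_workers=parallelism) as executor:
            return list(executor.map(train_single_class, range(num_classes)))

The actual Scala version would express the same shape with scala.concurrent.Future and an ExecutionContext sized by the parallelism param.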
31 changes: 26 additions & 5 deletions python/pyspark/ml/classification.py
@@ -16,6 +16,7 @@
#

import operator
from multiprocessing.pool import ThreadPool

from pyspark import since, keyword_only
from pyspark.ml import Estimator, Model
@@ -1510,21 +1511,25 @@ class OneVsRest(Estimator, OneVsRestParams, MLReadable, MLWritable):

.. versionadded:: 2.0.0
"""
parallelism = Param(Params._dummy(), "parallelism",
"Number of models to fit in parallel",
typeConverter=TypeConverters.toInt)

@keyword_only
def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
classifier=None):
classifier=None, parallelism=8):
"""
__init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
         classifier=None, parallelism=8)
"""
super(OneVsRest, self).__init__()
self._setDefault(parallelism=8)
kwargs = self._input_kwargs
self._set(**kwargs)

@keyword_only
@since("2.0.0")
def setParams(self, featuresCol=None, labelCol=None, predictionCol=None, classifier=None):
def setParams(self, featuresCol=None, labelCol=None, predictionCol=None, classifier=None,
              parallelism=None):
"""
setParams(self, featuresCol=None, labelCol=None, predictionCol=None, classifier=None, \
          parallelism=None):
Sets params for OneVsRest.
Expand Down Expand Up @@ -1561,13 +1566,28 @@ def trainSingleClass(index):
return classifier.fit(trainingDataset, paramMap)

models = [trainSingleClass(i) for i in range(numClasses)]
# Train the per-class models in parallel, using at most `parallelism` threads;
# capping the pool at numClasses avoids spawning idle workers.
pool = ThreadPool(processes=min(self.getParallelism(), numClasses))

models = pool.map(trainSingleClass, range(numClasses))

if handlePersistence:
multiclassLabeled.unpersist()

return self._copyValues(OneVsRestModel(models=models))

def setParallelism(self, value):
"""
Sets the value of :py:attr:`parallelism`.
"""
return self._set(parallelism=value)

def getParallelism(self):
"""
Gets the value of parallelism or its default value.
"""
return self.getOrDefault(self.parallelism)

@since("2.0.0")
def copy(self, extra=None):
"""
Expand Down Expand Up @@ -1611,8 +1631,9 @@ def _from_java(cls, java_stage):
labelCol = java_stage.getLabelCol()
predictionCol = java_stage.getPredictionCol()
classifier = JavaParams._from_java(java_stage.getClassifier())
parallelism = java_stage.getParallelism()
py_stage = cls(featuresCol=featuresCol, labelCol=labelCol, predictionCol=predictionCol,
classifier=classifier)
classifier=classifier, parallelism=parallelism)
py_stage._resetUid(java_stage.uid())
return py_stage

@@ -1625,12 +1646,12 @@ def _to_java(self):
_java_obj = JavaParams._new_java_obj("org.apache.spark.ml.classification.OneVsRest",
self.uid)
_java_obj.setClassifier(self.getClassifier()._to_java())
_java_obj.setParallelism(self.getParallelism())
_java_obj.setFeaturesCol(self.getFeaturesCol())
_java_obj.setLabelCol(self.getLabelCol())
_java_obj.setPredictionCol(self.getPredictionCol())
return _java_obj


class OneVsRestModel(Model, OneVsRestParams, MLReadable, MLWritable):
"""
.. note:: Experimental
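One property the _fit change relies on: ThreadPool.map returns results in argument order, not completion order, so models[i] stays aligned with class i even when trainings finish out of order. A self-contained toy check of that property (no Spark required):

    import random
    import time
    from multiprocessing.pool import ThreadPool

    def train_single_class(index):
        # Stand-in for fitting the binary model for class `index`,
        # with a randomized delay so threads finish out of order.
        time.sleep(random.uniform(0.0, 0.05))
        return "model-%d" % index

    pool = ThreadPool(processes=4)
    models = pool.map(train_single_class, range(3))
    # Order matches the input range, regardless of which thread finished first.
    assert models == ["model-0", "model-1", "model-2"]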
32 changes: 29 additions & 3 deletions python/pyspark/ml/tests.py
@@ -951,7 +951,7 @@ def test_onevsrest(self):
(2.0, Vectors.dense(0.5, 0.5))] * 10,
["label", "features"])
lr = LogisticRegression(maxIter=5, regParam=0.01)
ovr = OneVsRest(classifier=lr)
ovr = OneVsRest(classifier=lr, parallelism=8)
model = ovr.fit(df)
ovrPath = temp_path + "/ovr"
ovr.save(ovrPath)
@@ -1215,7 +1215,7 @@ def test_copy(self):
(2.0, Vectors.dense(0.5, 0.5))],
["label", "features"])
lr = LogisticRegression(maxIter=5, regParam=0.01)
ovr = OneVsRest(classifier=lr)
ovr = OneVsRest(classifier=lr, parallelism=1)
ovr1 = ovr.copy({lr.maxIter: 10})
self.assertEqual(ovr.getClassifier().getMaxIter(), 5)
self.assertEqual(ovr1.getClassifier().getMaxIter(), 10)
@@ -1229,11 +1229,37 @@ def test_output_columns(self):
(2.0, Vectors.dense(0.5, 0.5))],
["label", "features"])
lr = LogisticRegression(maxIter=5, regParam=0.01)
ovr = OneVsRest(classifier=lr)
ovr = OneVsRest(classifier=lr, parallelism=1)
model = ovr.fit(df)
output = model.transform(df)
self.assertEqual(output.columns, ["label", "features", "prediction"])


class ParOneVsRestTests(SparkSessionTestCase):

def test_copy(self):
df = self.spark.createDataFrame([(0.0, Vectors.dense(1.0, 0.8)),
(1.0, Vectors.sparse(2, [], [])),
(2.0, Vectors.dense(0.5, 0.5))],
["label", "features"])
lr = LogisticRegression(maxIter=5, regParam=0.01)
ovr = OneVsRest(classifier=lr, parallelism=8)
ovr1 = ovr.copy({lr.maxIter: 10})
self.assertEqual(ovr.getClassifier().getMaxIter(), 5)
self.assertEqual(ovr1.getClassifier().getMaxIter(), 10)
model = ovr.fit(df)
model1 = model.copy({model.predictionCol: "indexed"})
self.assertEqual(model1.getPredictionCol(), "indexed")

def test_output_columns(self):
df = self.spark.createDataFrame([(0.0, Vectors.dense(1.0, 0.8)),
(1.0, Vectors.sparse(2, [], [])),
(2.0, Vectors.dense(0.5, 0.5))],
["label", "features"])
lr = LogisticRegression(maxIter=5, regParam=0.01)
ovr = OneVsRest(classifier=lr, parallelism=8)
model = ovr.fit(df)
output = model.transform(df)
self.assertEqual(output.columns, ["label", "features", "prediction"])


class HashingTFTest(SparkSessionTestCase):

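Beyond the correctness tests above, a quick way to see the knob doing something is to compare wall-clock fit time at different settings (a rough sketch reusing lr and df from the usage example at the top; any speedup depends on the base classifier, the cluster, and the Spark scheduler):

    import time

    def time_fit(parallelism):
        ovr = OneVsRest(classifier=lr, parallelism=parallelism)
        start = time.time()
        ovr.fit(df)
        return time.time() - start

    # parallelism=1 is valid on the Python side, and with the gtEq(1)
    # validator on the Scala param it round-trips through persistence too.
    print("parallelism=1: %.1fs" % time_fit(1))
    print("parallelism=8: %.1fs" % time_fit(8))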