Closed
add handleInvalid to QuantileDiscretizer
  • Loading branch information
techaddict committed Nov 9, 2016
commit 0e41b36493fcb5eee5f342f694b0d2bc2a1e6c41
32 changes: 27 additions & 5 deletions python/pyspark/ml/feature.py
@@ -1163,9 +1163,11 @@ class QuantileDiscretizer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadab

>>> df = spark.createDataFrame([(0.1,), (0.4,), (1.2,), (1.5,)], ["values"])
>>> qds = QuantileDiscretizer(numBuckets=2,
... inputCol="values", outputCol="buckets", relativeError=0.01)
... inputCol="values", outputCol="buckets", relativeError=0.01, handleInvalid="error")
>>> qds.getRelativeError()
0.01
>>> qds.getHandleInvalid()
'error'

Review comment (Contributor):
We didn't add anything to the doctest of Bucketizer. Actually, I think it would be nice in both places to set handleInvalid='skip' and then add an invalid value to the example data. That way we can show what we mean by invalid and prove that it works.

Reply (Contributor Author):
Good idea! Adding.
>>> bucketizer = qds.fit(df)
>>> splits = bucketizer.getSplits()
>>> splits[0]
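
A minimal sketch of the doctest the reviewer proposes above — not part of this commit, and it assumes Spark's documented behavior that handleInvalid="skip" filters rows with invalid values (such as NaN) at transform time; the NaN row and the count are illustrative:

>>> df = spark.createDataFrame([(0.1,), (0.4,), (1.2,), (1.5,), (float("nan"),)], ["values"])
>>> qds = QuantileDiscretizer(numBuckets=2, inputCol="values", outputCol="buckets",
...                           relativeError=0.01, handleInvalid="skip")
>>> qds.fit(df).transform(df).count()  # the NaN row is filtered out, leaving 4 rows
4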
@@ -1194,21 +1196,27 @@ class QuantileDiscretizer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadab
"Must be in the range [0, 1].",
typeConverter=TypeConverters.toFloat)

handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle " +
"invalid entries. Options are skip (filter out rows with invalid values), " +
"error (throw an error), or keep (keep invalid values in a special " +
"additional bucket).",
typeConverter=TypeConverters.toString)
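
To make the three options concrete, a hedged sketch of the "keep" case — illustrative only, assuming NaN is ignored while fitting and that invalid values land in one extra bucket indexed numBuckets:

df = spark.createDataFrame([(0.1,), (1.5,), (float("nan"),)], ["values"])
qds = QuantileDiscretizer(numBuckets=2, inputCol="values", outputCol="buckets",
                          handleInvalid="keep")
# With handleInvalid="keep", the NaN row is retained and assigned the extra
# bucket index 2.0, one past the regular buckets 0 and 1.
qds.fit(df).transform(df).show()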

@keyword_only
def __init__(self, numBuckets=2, inputCol=None, outputCol=None, relativeError=0.001):
def __init__(self, numBuckets=2, inputCol=None, outputCol=None, relativeError=0.001,
             handleInvalid="error"):
"""
__init__(self, numBuckets=2, inputCol=None, outputCol=None, relativeError=0.001)
__init__(self, numBuckets=2, inputCol=None, outputCol=None, relativeError=0.001, handleInvalid="error")
"""
super(QuantileDiscretizer, self).__init__()
self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.QuantileDiscretizer",
self.uid)
self._setDefault(numBuckets=2, relativeError=0.001)
self._setDefault(numBuckets=2, relativeError=0.001, handleInvalid="error")
kwargs = self.__init__._input_kwargs
self.setParams(**kwargs)

@keyword_only
@since("2.0.0")
def setParams(self, numBuckets=2, inputCol=None, outputCol=None, relativeError=0.001):
def setParams(self, numBuckets=2, inputCol=None, outputCol=None, relativeError=0.001,
              handleInvalid="error"):
"""
setParams(self, numBuckets=2, inputCol=None, outputCol=None, relativeError=0.001, handleInvalid="error")
Set the params for the QuantileDiscretizer
@@ -1244,6 +1252,20 @@ def getRelativeError(self):
"""
return self.getOrDefault(self.relativeError)

@since("2.1.0")
def setHandleInvalid(self, value):
"""
Sets the value of :py:attr:`handleInvalid`.
"""
return self._set(handleInvalid=value)

@since("2.1.0")
def getHandleInvalid(self):
"""
Gets the value of handleInvalid or its default value.
"""
return self.getOrDefault(self.handleInvalid)
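
A minimal usage sketch for the two accessors above (assumes an active SparkSession; the column names are illustrative):

qds = QuantileDiscretizer(numBuckets=2, inputCol="values", outputCol="buckets")
qds.setHandleInvalid("skip")  # _set returns self, so the call can be chained
qds.getHandleInvalid()        # 'skip'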

def _create_model(self, java_model):
"""
Private method to convert the java_model to a Python model.