Skip to content
Closed
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
[SPARK-19852][PYSPARK][ML] Update Python API for StringIndexer setHan…
…dleInvalid

This PR reflect the changes made in SPARK-17498 on pyspark to support a new option
'keep' in StringIndexer to handle unseen labels

Signed-off-by: VinceShieh <[email protected]>
  • Loading branch information
VinceShieh committed Mar 10, 2017
commit d94dc68a8c1a5c082cf3de6c7e4d429bfd24d817
30 changes: 28 additions & 2 deletions python/pyspark/ml/feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -1917,8 +1917,7 @@ def mean(self):


@inherit_doc
class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid, JavaMLReadable,
JavaMLWritable):
class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):
"""
A label indexer that maps a string column of labels to an ML column of label indices.
If the input column is numeric, we cast it to string and index the string values.
Expand All @@ -1936,6 +1935,14 @@ class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid,
>>> sorted(set([(i[0], str(i[1])) for i in itd.select(itd.id, itd.label2).collect()]),
... key=lambda x: x[0])
[(0, 'a'), (1, 'b'), (2, 'c'), (3, 'a'), (4, 'a'), (5, 'c')]
>>> testData2 = sc.parallelize([Row(id=0, label="a"), Row(id=1, label="d"),
... Row(id=2, label="e")], 2)
>>> dfKeep= spark.createDataFrame(testData2)
>>> tdKeep = stringIndexer.setHandleInvalid("keep").fit(stringIndDf).transform(dfKeep)
>>> itdKeep = inverter.transform(tdKeep)
>>> sorted(set([(i[0], str(i[1])) for i in itdKeep.select(itdKeep.id, itdKeep.label2).collect()]),
... key=lambda x: x[0])
[(0, 'a'), (6, 'd'), (6, 'e')]
>>> stringIndexerPath = temp_path + "/string-indexer"
>>> stringIndexer.save(stringIndexerPath)
>>> loadedIndexer = StringIndexer.load(stringIndexerPath)
Expand All @@ -1955,6 +1962,11 @@ class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid,
.. versionadded:: 1.4.0
"""

handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle unseen labels. " +
"Options are 'skip' (filter out rows with unseen labels), " +
"error (throw an error), or 'keep' (put unseen labels in a special " +
"additional bucket, at index numLabels).",
typeConverter=TypeConverters.toString)
@keyword_only
def __init__(self, inputCol=None, outputCol=None, handleInvalid="error"):
"""
Expand All @@ -1979,6 +1991,20 @@ def setParams(self, inputCol=None, outputCol=None, handleInvalid="error"):
def _create_model(self, java_model):
return StringIndexerModel(java_model)

@since("2.2.0")
def setHandleInvalid(self, value):
"""
Sets the value of :py:attr:`handleInvalid`.
"""
return self._set(handleInvalid=value)

@since("2.2.0")
def getHandleInvalid(self):
"""
Gets the value of :py:attr:`handleInvalid` or its default value.
"""
return self.getOrDefault(self.handleInvalid)


class StringIndexerModel(JavaModel, JavaMLReadable, JavaMLWritable):
"""
Expand Down