21 changes: 19 additions & 2 deletions python/pyspark/ml/feature.py
@@ -790,25 +790,42 @@ class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol):
    [(0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0)]
    """

    handleInvalid = Param(Params._dummy(), "handleInvalid", "A handler in case of invalid column.")

    @keyword_only
    def __init__(self, inputCol=None, outputCol=None):
    def __init__(self, handleInvalid="error", inputCol=None, outputCol=None):
Contributor: I think we can make handleInvalid a shared param, just like what we do in the Scala API. I have submitted #8313 for this issue; please feel free to comment. (A sketch of such a mixin appears after this file's diff.)

"""
__init__(self, inputCol=None, outputCol=None)
"""
        super(StringIndexer, self).__init__()
        self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.StringIndexer", self.uid)
        self.handleInvalid = Param(self, "handleInvalid", "A handler in case of invalid column.")
        self._setDefault(handleInvalid="error")
        kwargs = self.__init__._input_kwargs
        self.setParams(**kwargs)

    @keyword_only
    def setParams(self, inputCol=None, outputCol=None):
    def setParams(self, handleInvalid="error", inputCol=None, outputCol=None):
"""
setParams(self, inputCol=None, outputCol=None)
Sets params for this StringIndexer.
"""
        kwargs = self.setParams._input_kwargs
        return self._set(**kwargs)

    def setHandleInvalid(self, value):
        """
        Sets the value of :py:attr:`handleInvalid`.
        """
        self._paramMap[self.handleInvalid] = value
        return self

    def getHandleInvalid(self):
        """
        Gets the value of handleInvalid or its default value.
        """
        return self.getOrDefault(self.handleInvalid)

    def _create_model(self, java_model):
        return StringIndexerModel(java_model)

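The review comment above proposes making handleInvalid a shared param, as in the Scala API (#8313). Below is a minimal sketch of what such a mixin could look like, following the existing shared-param pattern in pyspark.ml.param.shared; the class name and doc text are assumptions, not the merged API.

# Hypothetical sketch only, modeled on the shared-param mixins in pyspark.ml.param.shared.
from pyspark.ml.param import Param, Params


class HasHandleInvalid(Params):
    """
    Mixin for param handleInvalid: how to handle invalid entries.
    """

    # placeholder so the param shows up in the generated doc
    handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle invalid entries")

    def __init__(self):
        super(HasHandleInvalid, self).__init__()
        #: param for how to handle invalid entries
        self.handleInvalid = Param(self, "handleInvalid", "how to handle invalid entries")

    def setHandleInvalid(self, value):
        """
        Sets the value of :py:attr:`handleInvalid`.
        """
        self._paramMap[self.handleInvalid] = value
        return self

    def getHandleInvalid(self):
        """
        Gets the value of handleInvalid or its default value.
        """
        return self.getOrDefault(self.handleInvalid)

With such a mixin in place, StringIndexer could simply inherit from HasHandleInvalid instead of declaring the param, setter, and getter itself.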
21 changes: 21 additions & 0 deletions python/pyspark/ml/tests.py
@@ -263,6 +263,27 @@ def test_ngram(self):
        transformedDF = ngram0.transform(dataset)
        self.assertEquals(transformedDF.head().output, ["a b c d", "b c d e"])

    def test_string_indexer(self):
        sqlContext = SQLContext(self.sc)
        dataset1 = sqlContext.createDataFrame([(0, "a"), (1, "b"), (4, "b")], ["id", "label"])
        dataset2 = sqlContext.createDataFrame([(0, "a"), (1, "b"), (2, "c")], ["id", "label"])
        stringIndexer = StringIndexer(inputCol="label", outputCol="indexed")
        self.assertEqual(stringIndexer.getHandleInvalid(), "error")
        model = stringIndexer.fit(dataset1)
        try:
            model.transform(dataset2).collect()
            self.fail("handleInvalid='error' (the default) should fail on unseen labels.")
        except AssertionError:
            raise
        except Exception:
            pass

        stringIndexer = StringIndexer(handleInvalid="skip", inputCol="label", outputCol="indexed")
        model = stringIndexer.fit(dataset1)
        try:
            model.transform(dataset2).collect()
        except Exception:
            self.fail("handleInvalid='skip' should drop rows with unseen labels, not fail.")


if __name__ == "__main__":
    unittest.main()
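For reference, a minimal end-user sketch of the new parameter, reusing the toy data from the test above. It assumes an existing SparkContext named `sc`; it is an illustration, not part of the patch.

from pyspark.sql import SQLContext
from pyspark.ml.feature import StringIndexer

# Assumes an existing SparkContext `sc`.
sqlContext = SQLContext(sc)
train = sqlContext.createDataFrame([(0, "a"), (1, "b"), (4, "b")], ["id", "label"])
unseen = sqlContext.createDataFrame([(0, "a"), (1, "b"), (2, "c")], ["id", "label"])

# handleInvalid defaults to "error": transforming `unseen` would raise,
# because the label "c" was never seen during fitting.
indexer = StringIndexer(inputCol="label", outputCol="indexed")
model = indexer.fit(train)

# With handleInvalid="skip", rows carrying unseen labels are filtered out instead.
skip_indexer = StringIndexer(handleInvalid="skip", inputCol="label", outputCol="indexed")
print(skip_indexer.fit(train).transform(unseen).collect())  # the "c" row is dropped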