From 1fd0d390eb4cae1d80bae9bee3d6a042fa030446 Mon Sep 17 00:00:00 2001
From: lewuathe
Date: Wed, 19 Aug 2015 21:58:33 +0900
Subject: [PATCH] PySpark StringIndexer implements handleInvalid API

---
 python/pyspark/ml/feature.py | 25 +++++++++++++++++++++----
 python/pyspark/ml/tests.py   | 21 +++++++++++++++++++++
 2 files changed, 42 insertions(+), 4 deletions(-)

diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index 04b2b2ccc9e5..f785cf87900e 100644
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -790,18 +790,22 @@ class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol):
     [(0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0)]
     """
 
+    handleInvalid = Param(Params._dummy(), "handleInvalid", "A handler in case of invalid column.")
+
     @keyword_only
-    def __init__(self, inputCol=None, outputCol=None):
+    def __init__(self, handleInvalid="error", inputCol=None, outputCol=None):
         """
-        __init__(self, inputCol=None, outputCol=None)
+        __init__(self, handleInvalid="error", inputCol=None, outputCol=None)
         """
         super(StringIndexer, self).__init__()
         self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.StringIndexer", self.uid)
+        self.handleInvalid = Param(self, "handleInvalid", "A handler in case of invalid column.")
+        self._setDefault(handleInvalid="error")
         kwargs = self.__init__._input_kwargs
         self.setParams(**kwargs)
 
     @keyword_only
-    def setParams(self, inputCol=None, outputCol=None):
+    def setParams(self, handleInvalid="error", inputCol=None, outputCol=None):
         """
-        setParams(self, inputCol=None, outputCol=None)
+        setParams(self, handleInvalid="error", inputCol=None, outputCol=None)
         Sets params for this StringIndexer.
@@ -809,6 +813,19 @@ def setParams(self, inputCol=None, outputCol=None):
         kwargs = self.setParams._input_kwargs
         return self._set(**kwargs)
 
+    def setHandleInvalid(self, value):
+        """
+        Sets the value of :py:attr:`handleInvalid`.
+        """
+        self._paramMap[self.handleInvalid] = value
+        return self
+
+    def getHandleInvalid(self):
+        """
+        Gets the value of handleInvalid or its default value.
+        """
+        return self.getOrDefault(self.handleInvalid)
+
     def _create_model(self, java_model):
         return StringIndexerModel(java_model)
 
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index c151d21fd661..625624a87aa0 100644
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -263,6 +263,27 @@ def test_ngram(self):
         transformedDF = ngram0.transform(dataset)
         self.assertEquals(transformedDF.head().output, ["a b c d", "b c d e"])
 
+    def test_string_indexer(self):
+        sqlContext = SQLContext(self.sc)
+        dataset1 = sqlContext.createDataFrame([(0, "a"), (1, "b"), (4, "b")], ["id", "label"])
+        dataset2 = sqlContext.createDataFrame([(0, "a"), (1, "b"), (2, "c")], ["id", "label"])
+        stringIndexer = StringIndexer(inputCol="label", outputCol="indexed")
+        self.assertEqual(stringIndexer.getHandleInvalid(), "error")
+        model = stringIndexer.fit(dataset1)
+        try:
+            model.transform(dataset2).collect()
+            self.fail("handleInvalid='error' (the default) should raise on the unseen label 'c'")
+        except Exception:
+            pass
+
+        stringIndexer = StringIndexer(handleInvalid="skip", inputCol="label", outputCol="indexed")
+        model = stringIndexer.fit(dataset1)
+        try:
+            model.transform(dataset2).collect()
+            pass
+        except Exception:
+            self.fail("handleInvalid='skip' should not raise on the unseen label 'c'")
+
 
 if __name__ == "__main__":
     unittest.main()
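
For context, a minimal usage sketch of the parameter this patch exposes. The DataFrames and the `sqlContext` handle are illustrative assumptions mirroring the new test case, and the "skip" behaviour described in the comments follows the Scala StringIndexer, which filters out rows whose label was not seen during fit.

# Illustrative sketch only; assumes a live SQLContext named `sqlContext` (Spark 1.5-era API).
from pyspark.ml.feature import StringIndexer

train = sqlContext.createDataFrame([(0, "a"), (1, "b"), (2, "b")], ["id", "label"])
test = sqlContext.createDataFrame([(3, "a"), (4, "c")], ["id", "label"])  # "c" was never seen

# Default handleInvalid="error": transform() raises when it meets the unseen label "c".
model = StringIndexer(inputCol="label", outputCol="indexed").fit(train)

# handleInvalid="skip": rows with unseen labels are dropped instead of raising an error.
skip_model = StringIndexer(handleInvalid="skip", inputCol="label", outputCol="indexed").fit(train)
print(skip_model.transform(test).collect())  # only the row with label "a" remains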