31 changes: 26 additions & 5 deletions python/pyspark/ml/feature.py
@@ -920,7 +920,7 @@ def mean(self):


@inherit_doc
class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol):
class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid):
Contributor:
Some other classes (e.g. StandardScaler, RegexTokenizer) have no shared accessor class corresponding to their arbitrary properties. It might be better to keep handleInvalid only inside StringIndexer, or to create shared accessor classes for StandardScaler etc. as well, to standardize the way arbitrary properties are accessed.

Contributor Author:
handleInvalid is a common property that other Transformers/Estimators may use in the future, so I think we need to make it a shared param. Another reason is that we want to keep Python consistent with Scala, and handleInvalid in Scala is a shared param.
@jkbradley
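
For context, a minimal usage sketch of the new parameter (assuming DataFrames train_df and test_df with a string column; the column names are hypothetical):

from pyspark.ml.feature import StringIndexer

# handleInvalid="error" (the default) raises on labels unseen during fit;
# handleInvalid="skip" silently filters such rows out of the output.
indexer = StringIndexer(inputCol="category", outputCol="categoryIndex",
                        handleInvalid="skip")
model = indexer.fit(train_df)       # learns the label -> index mapping
indexed = model.transform(test_df)  # rows with unseen labels are dropped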

"""
.. note:: Experimental

@@ -943,19 +943,20 @@ class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol):
"""

@keyword_only
def __init__(self, inputCol=None, outputCol=None):
def __init__(self, inputCol=None, outputCol=None, handleInvalid="error"):
"""
__init__(self, inputCol=None, outputCol=None)
__init__(self, inputCol=None, outputCol=None, handleInvalid="error")
"""
super(StringIndexer, self).__init__()
self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.StringIndexer", self.uid)
self._setDefault(handleInvalid="error")
kwargs = self.__init__._input_kwargs
self.setParams(**kwargs)

@keyword_only
def setParams(self, inputCol=None, outputCol=None):
def setParams(self, inputCol=None, outputCol=None, handleInvalid="error"):
"""
setParams(self, inputCol=None, outputCol=None)
setParams(self, inputCol=None, outputCol=None, handleInvalid="error")
Sets params for this StringIndexer.
"""
kwargs = self.setParams._input_kwargs
@@ -1235,6 +1236,10 @@ class VectorIndexer(JavaEstimator, HasInputCol, HasOutputCol):
>>> model = indexer.fit(df)
>>> model.transform(df).head().indexed
DenseVector([1.0, 0.0])
>>> model.numFeatures
2
>>> model.categoryMaps
{0: {0.0: 0, -1.0: 1}}
>>> indexer.setParams(outputCol="test").fit(df).transform(df).collect()[1].test
DenseVector([0.0, 1.0])
>>> params = {indexer.maxCategories: 3, indexer.outputCol: "vector"}
@@ -1297,6 +1302,22 @@ class VectorIndexerModel(JavaModel):
Model fitted by VectorIndexer.
"""

@property
def numFeatures(self):
"""
Number of features, i.e., length of Vectors which this transforms.
"""
return self._call_java("numFeatures")

@property
def categoryMaps(self):
"""
Feature value index. Keys are categorical feature indices (column indices).
Values are maps from original feature values to 0-based category indices.
If a feature is not in this map, it is treated as continuous.
"""
return self._call_java("javaCategoryMaps")
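
A short sketch of reading the new fitted metadata (reusing the model from the doctest above; the loop variable names are illustrative):

model.numFeatures   # 2: the length of the vectors this model transforms
for idx, mapping in model.categoryMaps.items():
    # mapping takes each original feature value to its 0-based category index
    print("feature %d is categorical: %s" % (idx, mapping))
# any feature index absent from categoryMaps is treated as continuous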


@inherit_doc
@ignore_unicode_prefix
5 changes: 4 additions & 1 deletion python/pyspark/ml/param/_shared_params_code_gen.py
@@ -121,7 +121,10 @@ def get$Name(self):
("checkpointInterval", "checkpoint interval (>= 1)", None),
("seed", "random seed", "hash(type(self).__name__)"),
("tol", "the convergence tolerance for iterative algorithms", None),
("stepSize", "Step size to be used for each iteration of optimization.", None)]
("stepSize", "Step size to be used for each iteration of optimization.", None),
("handleInvalid", "how to handle invalid entries. Options are skip (which will filter " +
"out rows with bad values), or error (which will throw an errror). More options may be " +
"added later.", None)]
code = []
for name, doc, defaultValueStr in shared:
param_code = _gen_param_header(name, doc, defaultValueStr)
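Note that shared.py is not edited by hand: from memory the workflow is to regenerate it from this list (roughly python _shared_params_code_gen.py > shared.py), so the HasHandleInvalid mixin in the next file is the generated output of the tuple added here.
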
31 changes: 29 additions & 2 deletions python/pyspark/ml/param/shared.py
@@ -432,6 +432,33 @@ def getStepSize(self):
return self.getOrDefault(self.stepSize)


class HasHandleInvalid(Params):
"""
Mixin for param handleInvalid: how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an error). More options may be added later.
"""

# a placeholder to make it appear in the generated doc
handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an error). More options may be added later.")

def __init__(self):
super(HasHandleInvalid, self).__init__()
#: param for how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an error). More options may be added later.
self.handleInvalid = Param(self, "handleInvalid", "how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an error). More options may be added later.")

def setHandleInvalid(self, value):
"""
Sets the value of :py:attr:`handleInvalid`.
"""
self._paramMap[self.handleInvalid] = value
return self

def getHandleInvalid(self):
"""
Gets the value of handleInvalid or its default value.
"""
return self.getOrDefault(self.handleInvalid)
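
As an illustration, any estimator mixing in HasHandleInvalid picks up this setter/getter pair; with the StringIndexer change above (which defaults the param to "error"):

si = StringIndexer(inputCol="category", outputCol="categoryIndex")
si.getHandleInvalid()        # 'error', the default from _setDefault
si.setHandleInvalid("skip")  # returns self, so calls can be chained
si.getHandleInvalid()        # 'skip'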


class DecisionTreeParams(Params):
"""
Mixin for Decision Tree parameters.
@@ -444,7 +471,7 @@ class DecisionTreeParams(Params):
minInfoGain = Param(Params._dummy(), "minInfoGain", "Minimum information gain for a split to be considered at a tree node.")
maxMemoryInMB = Param(Params._dummy(), "maxMemoryInMB", "Maximum memory in MB allocated to histogram aggregation.")
cacheNodeIds = Param(Params._dummy(), "cacheNodeIds", "If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees.")


def __init__(self):
super(DecisionTreeParams, self).__init__()
@@ -460,7 +487,7 @@ def __init__(self):
self.maxMemoryInMB = Param(self, "maxMemoryInMB", "Maximum memory in MB allocated to histogram aggregation.")
#: param for If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees.
self.cacheNodeIds = Param(self, "cacheNodeIds", "If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees.")

def setMaxDepth(self, value):
"""
Sets the value of :py:attr:`maxDepth`.