-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-23615][ML][PYSPARK]Add maxDF Parameter to Python CountVectorizer #20777
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -422,6 +422,14 @@ class _CountVectorizerParams(JavaParams, HasInputCol, HasOutputCol): | |
| " If this is an integer >= 1, this specifies the number of documents the term must" + | ||
| " appear in; if this is a double in [0,1), then this specifies the fraction of documents." + | ||
| " Default 1.0", typeConverter=TypeConverters.toFloat) | ||
| maxDF = Param( | ||
| Params._dummy(), "maxDF", "Specifies the maximum number of" + | ||
| " different documents a term could appear in to be included in the vocabulary." + | ||
| " A term that appears more than the threshold will be ignored. If this is an" + | ||
| " integer >= 1, this specifies the maximum number of documents the term could appear in;" + | ||
| " if this is a double in [0,1), then this specifies the maximum" + | ||
| " fraction of documents the term could appear in." + | ||
| " Default (2^63) - 1", typeConverter=TypeConverters.toFloat) | ||
| vocabSize = Param( | ||
| Params._dummy(), "vocabSize", "max size of the vocabulary. Default 1 << 18.", | ||
| typeConverter=TypeConverters.toInt) | ||
|
|
@@ -433,7 +441,7 @@ class _CountVectorizerParams(JavaParams, HasInputCol, HasOutputCol): | |
|
|
||
| def __init__(self, *args): | ||
| super(_CountVectorizerParams, self).__init__(*args) | ||
| self._setDefault(minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False) | ||
| self._setDefault(minTF=1.0, minDF=1.0, maxDF=2 ** 63 - 1, vocabSize=1 << 18, binary=False) | ||
|
|
||
| @since("1.6.0") | ||
| def getMinTF(self): | ||
|
|
@@ -449,6 +457,13 @@ def getMinDF(self): | |
| """ | ||
| return self.getOrDefault(self.minDF) | ||
|
|
||
| @since("2.4.0") | ||
| def getMaxDF(self): | ||
| """ | ||
| Gets the value of maxDF or its default value. | ||
| """ | ||
| return self.getOrDefault(self.maxDF) | ||
|
|
||
| @since("1.6.0") | ||
| def getVocabSize(self): | ||
| """ | ||
|
|
@@ -513,11 +528,11 @@ class CountVectorizer(JavaEstimator, _CountVectorizerParams, JavaMLReadable, Jav | |
| """ | ||
|
|
||
| @keyword_only | ||
| def __init__(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False, inputCol=None, | ||
| outputCol=None): | ||
| def __init__(self, minTF=1.0, minDF=1.0, maxDF=2 ** 63 - 1, vocabSize=1 << 18, binary=False, | ||
|
||
| inputCol=None, outputCol=None): | ||
| """ | ||
| __init__(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False, inputCol=None,\ | ||
| outputCol=None) | ||
| __init__(self, minTF=1.0, minDF=1.0, maxDF=2 ** 63 - 1, vocabSize=1 << 18, binary=False,\ | ||
| inputCol=None, outputCol=None) | ||
| """ | ||
| super(CountVectorizer, self).__init__() | ||
| self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.CountVectorizer", | ||
|
|
@@ -527,11 +542,11 @@ def __init__(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False, inputC | |
|
|
||
| @keyword_only | ||
| @since("1.6.0") | ||
| def setParams(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False, inputCol=None, | ||
| outputCol=None): | ||
| def setParams(self, minTF=1.0, minDF=1.0, maxDF=2 ** 63 - 1, vocabSize=1 << 18, binary=False, | ||
| inputCol=None, outputCol=None): | ||
| """ | ||
| setParams(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False, inputCol=None,\ | ||
| outputCol=None) | ||
| setParams(self, minTF=1.0, minDF=1.0, maxDF=2 ** 63 - 1, vocabSize=1 << 18, binary=False,\ | ||
| inputCol=None, outputCol=None) | ||
| Set the params for the CountVectorizer | ||
| """ | ||
| kwargs = self._input_kwargs | ||
|
|
@@ -551,6 +566,13 @@ def setMinDF(self, value): | |
| """ | ||
| return self._set(minDF=value) | ||
|
|
||
| @since("2.4.0") | ||
| def setMaxDF(self, value): | ||
| """ | ||
| Sets the value of :py:attr:`maxDF`. | ||
| """ | ||
| return self._set(maxDF=value) | ||
|
|
||
| @since("1.6.0") | ||
| def setVocabSize(self, value): | ||
| """ | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -679,6 +679,31 @@ def test_count_vectorizer_with_binary(self): | |
| feature, expected = r | ||
| self.assertEqual(feature, expected) | ||
|
|
||
| def test_count_vectorizer_with_maxDF(self): | ||
| dataset = self.spark.createDataFrame([ | ||
| (0, "a b c d".split(' '), SparseVector(3, {0: 1.0, 1: 1.0, 2: 1.0}),), | ||
| (1, "a b c".split(' '), SparseVector(3, {0: 1.0, 1: 1.0}),), | ||
| (2, "a b".split(' '), SparseVector(3, {0: 1.0}),), | ||
| (3, "a".split(' '), SparseVector(3, {}),)], ["id", "words", "expected"]) | ||
| cv = CountVectorizer(inputCol="words", outputCol="features") | ||
| model1 = cv.setMaxDF(3).fit(dataset) | ||
|
||
| self.assertEqual(model1.vocabulary, ['b', 'c', 'd']) | ||
|
|
||
| transformedList1 = model1.transform(dataset).select("features", "expected").collect() | ||
|
|
||
| for r in transformedList1: | ||
| feature, expected = r | ||
| self.assertEqual(feature, expected) | ||
|
|
||
| model2 = cv.setMaxDF(0.75).fit(dataset) | ||
| self.assertEqual(model2.vocabulary, ['b', 'c', 'd']) | ||
|
|
||
| transformedList2 = model2.transform(dataset).select("features", "expected").collect() | ||
|
|
||
| for r in transformedList2: | ||
| feature, expected = r | ||
| self.assertEqual(feature, expected) | ||
|
|
||
| def test_count_vectorizer_from_vocab(self): | ||
| model = CountVectorizerModel.from_vocabulary(["a", "b", "c"], inputCol="words", | ||
| outputCol="features", minTF=2) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@srowen do these doc changes look ok to you? It was a little confusing before saying that the term "must appear" when it's a max value.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Agree, your wording is clearer.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @srowen !