Skip to content
Closed
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Added class method to construct CountVectorizerModel from vocab; not yet
working because the param _copyValues from estimator to model is missing
  • Loading branch information
BryanCutler committed Mar 5, 2018
commit 01e5a4bf8e7951e8fe2a5cc80bcd856ba2da7ee8
31 changes: 29 additions & 2 deletions python/pyspark/ml/feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
if sys.version > '3':
basestring = str

from pyspark import since, keyword_only
from pyspark import since, keyword_only, SparkContext
from pyspark.rdd import ignore_unicode_prefix
from pyspark.ml.linalg import _convert_to_vector
from pyspark.ml.param.shared import *
Expand Down Expand Up @@ -437,6 +437,16 @@ class CountVectorizer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable,
>>> loadedModel = CountVectorizerModel.load(modelPath)
>>> loadedModel.vocabulary == model.vocabulary
True
>>> fromVocabModel = CountVectorizerModel.fromVocabulary(model.vocabulary,
... inputCol="raw", outputCol="vectors")
>>> fromVocabModel.transform(df).show(truncate=False)
+-----+---------------+-------------------------+
|label|raw |vectors |
+-----+---------------+-------------------------+
|0 |[a, b, c] |(3,[0,1,2],[1.0,1.0,1.0])|
|1 |[a, b, b, c, a]|(3,[0,1,2],[2.0,2.0,1.0])|
+-----+---------------+-------------------------+
...

.. versionadded:: 1.6.0
"""
Expand Down Expand Up @@ -550,13 +560,30 @@ def _create_model(self, java_model):
return CountVectorizerModel(java_model)


class CountVectorizerModel(JavaModel, JavaMLReadable, JavaMLWritable):
class CountVectorizerModel(JavaModel, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):
"""
Model fitted by :py:class:`CountVectorizer`.

.. versionadded:: 1.6.0
"""

@classmethod
@since("2.2.0")
def fromVocabulary(cls, vocab, inputCol, outputCol=None):
    """
    Construct the model directly from a vocabulary list, requires
    an active SparkContext.

    :param vocab: list of vocabulary terms; it is handed to the JVM as an
                  array of ``java.lang.String``.
    :param inputCol: name of the input column, set on the new model.
    :param outputCol: optional name of the output column; it is only set
                      when it is not None.
    :return: the new model, backed by a JVM
             ``org.apache.spark.ml.feature.CountVectorizerModel`` built
             from ``vocab``, with ``inputCol`` (and ``outputCol``, if
             given) set.
    :raises RuntimeError: if there is no active SparkContext.
    """
    sc = SparkContext._active_spark_context
    if sc is None:
        # Fail with an actionable message; without this guard the gateway
        # lookup below dies with "'NoneType' object has no attribute
        # '_gateway'", which does not tell the caller what is missing.
        raise RuntimeError(
            "CountVectorizerModel.fromVocabulary requires an active SparkContext")

    java_class = sc._gateway.jvm.java.lang.String
    # Resolve the helpers through ``cls`` rather than the hard-coded class
    # name, as is conventional for a classmethod alternate constructor.
    jvocab = cls._new_java_array(vocab, java_class)
    model = cls._create_from_java_class(
        "org.apache.spark.ml.feature.CountVectorizerModel", jvocab)

    model.setInputCol(inputCol)
    if outputCol is not None:
        model.setOutputCol(outputCol)
    return model

@property
@since("1.6.0")
def vocabulary(self):
Expand Down