export numFeatures in ML PredictionModel
vectorijk committed Jun 27, 2016
commit 872d384be9710eb5fd5c381ac4dd9eb7e40fa00a
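
This commit removes the shared HasNumFeaturesModel mixin and instead exposes a numFeatures property on each Python model wrapper, forwarded to the underlying Java model via _call_java("numFeatures"). As a minimal usage sketch (not part of the diff; it assumes an active SparkSession bound to spark, as in the doctests below, and Spark 2.0-style import paths), the property can be read off any fitted model:

>>> from pyspark.ml.classification import LogisticRegression
>>> from pyspark.ml.linalg import Vectors
>>> df = spark.createDataFrame(
...     [(1.0, Vectors.dense(1.0)), (0.0, Vectors.sparse(1, [], []))],
...     ["label", "features"])
>>> lr_model = LogisticRegression(maxIter=5, regParam=0.01).fit(df)
>>> lr_model.numFeatures  # width of the training feature vectors
1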
13 changes: 0 additions & 13 deletions python/pyspark/ml/base.py
@@ -116,16 +116,3 @@ class Model(Transformer):
"""

__metaclass__ = ABCMeta


class HasNumFeaturesModel:
"""
Provides getter of the number of features especially for model class
It should be mixin with JavaModel.
"""
@property
def numFeatures(self):
"""
The number of features used to train the model.
"""
return self._call_java("numFeatures")
84 changes: 71 additions & 13 deletions python/pyspark/ml/classification.py
@@ -20,7 +20,6 @@

from pyspark import since, keyword_only
from pyspark.ml import Estimator, Model
from pyspark.ml.base import HasNumFeaturesModel
from pyspark.ml.param.shared import *
from pyspark.ml.regression import DecisionTreeModel, DecisionTreeRegressionModel, \
RandomForestParams, TreeEnsembleModels, TreeEnsembleParams
@@ -66,6 +65,8 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
DenseVector([5.5...])
>>> model.intercept
-2.68...
>>> model.numFeatures
1
>>> test0 = sc.parallelize([Row(features=Vectors.dense(-1.0))]).toDF()
>>> result = model.transform(test0).head()
>>> result.prediction
@@ -93,6 +94,8 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
True
>>> model.intercept == model2.intercept
True
>>> model.numFeatures == model2.numFeatures
True

.. versionadded:: 1.3.0
"""
@@ -215,7 +218,7 @@ def _checkThresholdConsistency(self):
" threshold (%g) and thresholds (equivalent to %g)" % (t2, t))


class LogisticRegressionModel(HasNumFeaturesModel, JavaModel, JavaMLWritable, JavaMLReadable):
class LogisticRegressionModel(JavaModel, JavaMLWritable, JavaMLReadable):
"""
.. note:: Experimental

@@ -240,6 +243,14 @@ def intercept(self):
"""
return self._call_java("intercept")

@property
@since("2.0.0")
def numFeatures(self):
"""
Number of features the model was trained on.
"""
return self._call_java("numFeatures")

@property
@since("2.0.0")
def summary(self):
@@ -525,6 +536,8 @@ class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
1
>>> model.featureImportances
SparseVector(1, {0: 1.0})
>>> model.numFeatures
1
>>> print(model.toDebugString)
DecisionTreeClassificationModel (uid=...) of depth 1 with 3 nodes...
>>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
@@ -549,8 +562,8 @@ class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
>>> model2 = DecisionTreeClassificationModel.load(model_path)
>>> model.featureImportances == model2.featureImportances
True
>>> model.numFeatures
1
>>> model.numFeatures == model2.numFeatures
True

.. versionadded:: 1.4.0
"""
@@ -600,8 +613,7 @@ def _create_model(self, java_model):


@inherit_doc
class DecisionTreeClassificationModel(HasNumFeaturesModel, DecisionTreeModel, JavaMLWritable,
JavaMLReadable):
class DecisionTreeClassificationModel(DecisionTreeModel, JavaMLWritable, JavaMLReadable):
"""
.. note:: Experimental

@@ -631,6 +643,14 @@ def featureImportances(self):
"""
return self._call_java("featureImportances")

@property
@since("2.0.0")
def numFeatures(self):
"""
Number of features the model was trained on.
"""
return self._call_java("numFeatures")


@inherit_doc
class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasSeed,
@@ -672,6 +692,8 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
>>> test1 = spark.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
>>> model.transform(test1).head().prediction
1.0
>>> model.numFeatures
1
>>> model.trees
[DecisionTreeClassificationModel (uid=...) of depth..., DecisionTreeClassificationModel...]
>>> rfc_path = temp_path + "/rfc"
@@ -734,8 +756,7 @@ def _create_model(self, java_model):
return RandomForestClassificationModel(java_model)


class RandomForestClassificationModel(TreeEnsembleModels, JavaMLWritable, JavaMLReadable,
HasNumFeaturesModel):
class RandomForestClassificationModel(TreeEnsembleModels, JavaMLWritable, JavaMLReadable):
"""
.. note:: Experimental

@@ -759,6 +780,14 @@ def featureImportances(self):
"""
return self._call_java("featureImportances")

@property
@since("2.0.0")
def numFeatures(self):
"""
Number of features the model was trained on.
"""
return self._call_java("numFeatures")

@property
@since("2.0.0")
def trees(self):
@@ -811,6 +840,8 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol
1.0
>>> model.totalNumNodes
15
>>> model.numFeatures
1
>>> print(model.toDebugString)
GBTClassificationModel (uid=...)...with 5 trees...
>>> gbtc_path = temp_path + "gbtc"
@@ -892,7 +923,7 @@ def getLossType(self):
return self.getOrDefault(self.lossType)


class GBTClassificationModel(TreeEnsembleModels, JavaMLWritable, JavaMLReadable, HasNumFeaturesModel):
class GBTClassificationModel(TreeEnsembleModels, JavaMLWritable, JavaMLReadable):
"""
.. note:: Experimental

@@ -916,6 +947,14 @@ def featureImportances(self):
"""
return self._call_java("featureImportances")

@property
@since("2.0.0")
def numFeatures(self):
"""
Number of features the model was trained on.
"""
return self._call_java("numFeatures")

@property
@since("2.0.0")
def trees(self):
@@ -961,6 +1000,8 @@ class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, H
>>> test1 = sc.parallelize([Row(features=Vectors.sparse(2, [0], [1.0]))]).toDF()
>>> model.transform(test1).head().prediction
1.0
>>> model.numFeatures
2
>>> nb_path = temp_path + "/nb"
>>> nb.save(nb_path)
>>> nb2 = NaiveBayes.load(nb_path)
@@ -979,7 +1020,7 @@ class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, H
>>> result.prediction
0.0
>>> model.numFeatures == model2.numFeatures
2
True

.. versionadded:: 1.5.0
"""
@@ -1052,7 +1093,7 @@ def getModelType(self):
return self.getOrDefault(self.modelType)


class NaiveBayesModel(JavaModel, JavaMLWritable, JavaMLReadable, HasNumFeaturesModel):
class NaiveBayesModel(JavaModel, JavaMLWritable, JavaMLReadable):
"""
.. note:: Experimental

@@ -1077,6 +1118,14 @@ def theta(self):
"""
return self._call_java("theta")

@property
@since("2.0.0")
def numFeatures(self):
"""
Number of features the model was trained on.
"""
return self._call_java("numFeatures")


@inherit_doc
class MultilayerPerceptronClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
@@ -1102,6 +1151,8 @@ class MultilayerPerceptronClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol,
[2, 2, 2]
>>> model.weights.size
12
>>> model.numFeatures
2
>>> testDF = spark.createDataFrame([
... (Vectors.dense([1.0, 0.0]),),
... (Vectors.dense([0.0, 0.0]),)], ["features"])
@@ -1255,8 +1306,7 @@ def getInitialWeights(self):
return self.getOrDefault(self.initialWeights)


class MultilayerPerceptronClassificationModel(JavaModel, JavaMLWritable, JavaMLReadable,
HasNumFeaturesModel):
class MultilayerPerceptronClassificationModel(JavaModel, JavaMLWritable, JavaMLReadable):
"""
.. note:: Experimental

Expand All @@ -1281,6 +1331,14 @@ def weights(self):
"""
return self._call_java("weights")

@property
@since("2.0.0")
def numFeatures(self):
"""
Number of features the model was trained on.
"""
return self._call_java("numFeatures")


class OneVsRestParams(HasFeaturesCol, HasLabelCol, HasPredictionCol):
"""