Added persistence to LDA in Python
jkbradley committed Apr 27, 2016
commit 4f807e82cf4366ea83557aa6060b14245c7f5764
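
For context, this is the round trip the commit enables, written out as a minimal sketch rather than anything taken from the diff. The /tmp paths and app name are illustrative, and the imports assume the 2.0-era API, where pyspark.ml still used pyspark.mllib.linalg vectors:

    # Hypothetical usage sketch: save/load for the LDA estimator and its fitted models.
    from pyspark import SparkContext
    from pyspark.sql import SQLContext
    from pyspark.ml.clustering import LDA, DistributedLDAModel
    from pyspark.mllib.linalg import Vectors

    sc = SparkContext(appName="lda_persistence_sketch")
    sqlContext = SQLContext(sc)
    df = sqlContext.createDataFrame(
        [(1, Vectors.dense([0.0, 1.0])), (2, Vectors.sparse(2, {0: 1.0}))],
        ["id", "features"])

    lda = LDA(k=2, seed=1, optimizer="em")
    model = lda.fit(df)                      # the EM optimizer yields a DistributedLDAModel

    lda.save("/tmp/lda")                     # persist the estimator and its params
    model.save("/tmp/distLDAModel")          # persist the fitted model

    lda2 = LDA.load("/tmp/lda")              # round-trip both halves
    model2 = DistributedLDAModel.load("/tmp/distLDAModel")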
15 changes: 11 additions & 4 deletions python/pyspark/ml/clustering.py
@@ -529,7 +529,7 @@ def estimatedDocConcentration(self):
         return self._call_java("estimatedDocConcentration")
 
 
-class DistributedLDAModel(LDAModel):
+class DistributedLDAModel(LDAModel, JavaMLReadable, JavaMLWritable):
     """
Review comment (Contributor): .. note:: Experimental
     Distributed model fitted by :py:class:`LDA`.
     This type of model is currently only produced by Expectation-Maximization (EM).
@@ -542,6 +542,12 @@ class DistributedLDAModel(LDAModel):
 
     @since("2.0.0")
     def toLocal(self):
+        """
+        Convert this distributed model to a local representation. This discards info about the
+        training dataset.
+
+        WARNING: This involves collecting a large :py:func:`topicsMatrix` to the driver.
+        """
         return LocalLDAModel(self._call_java("toLocal"))
 
     @since("2.0.0")
@@ -584,7 +590,7 @@ def getCheckpointFiles(self):
         return self._call_java("getCheckpointFiles")
 
 
-class LocalLDAModel(LDAModel):
+class LocalLDAModel(LDAModel, JavaMLReadable, JavaMLWritable):
     """
     Local (non-distributed) model fitted by :py:class:`LDA`.
     This model stores the inferred topics only; it does not store info about the training dataset.
@@ -594,7 +600,8 @@ class LocalLDAModel(LDAModel):
     pass
 
 
-class LDA(JavaEstimator, HasFeaturesCol, HasMaxIter, HasSeed, HasCheckpointInterval):
+class LDA(JavaEstimator, HasFeaturesCol, HasMaxIter, HasSeed, HasCheckpointInterval,
Review comment (Contributor): @inherit_doc
+          JavaMLReadable, JavaMLWritable):
     """
     Latent Dirichlet Allocation (LDA), a topic model designed for text documents.
 
Review comment (Contributor): .. note:: Experimental, should be consistent with Scala.
@@ -768,7 +775,7 @@ def setLearningOffset(self, value):
 
         >>> algo = LDA().setLearningOffset(100)
         >>> algo.getLearningOffset()
-        100
+        100.0
         """
        return self._set(learningOffset=value)
 
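One note on the doctest fix above: learningOffset is a Double param on the Scala side and is stored as a Python float (likely via TypeConverters.toFloat), so the getter prints 100.0 even when the setter receives the int 100. A minimal check, under that assumption:

    algo = LDA().setLearningOffset(100)
    assert algo.getLearningOffset() == 100.0
    assert isinstance(algo.getLearningOffset(), float)
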
56 changes: 55 additions & 1 deletion python/pyspark/ml/tests.py
@@ -46,7 +46,7 @@
 from pyspark.ml import Estimator, Model, Pipeline, PipelineModel, Transformer
 from pyspark.ml.classification import (
     LogisticRegression, DecisionTreeClassifier, OneVsRest, OneVsRestModel)
-from pyspark.ml.clustering import KMeans
+from pyspark.ml.clustering import *
 from pyspark.ml.evaluation import BinaryClassificationEvaluator, RegressionEvaluator
 from pyspark.ml.feature import *
 from pyspark.ml.param import Param, Params, TypeConverters
@@ -782,6 +782,60 @@ def test_decisiontree_regressor(self):
         pass
 
 
+class LDATest(PySparkTestCase):
+
+    def _compare(self, m1, m2):
+        """
+        Temp method for comparing instances.
+        TODO: Replace with generic implementation once SPARK-14706 is merged.
+        """
+        self.assertEqual(m1.uid, m2.uid)
+        self.assertEqual(type(m1), type(m2))
+        self.assertEqual(len(m1.params), len(m2.params))
+        for p in m1.params:
+            if m1.isDefined(p):
+                self.assertEqual(m1.getOrDefault(p), m2.getOrDefault(p))
+                self.assertEqual(p.parent, m2.getParam(p.name).parent)
+        if isinstance(m1, LDAModel):
+            self.assertEqual(m1.vocabSize(), m2.vocabSize())
+            self.assertEqual(m1.topicsMatrix(), m2.topicsMatrix())
+
+    def test_persistence(self):
+        # Test save/load for LDA, LocalLDAModel, DistributedLDAModel.
+        sqlContext = SQLContext(self.sc)
+        df = sqlContext.createDataFrame([[1, Vectors.dense([0.0, 1.0])],
+                                         [2, Vectors.sparse(2, {0: 1.0})],],
+                                        ["id", "features"])
+        # Fit model
+        lda = LDA(k=2, seed=1, optimizer="em")
+        distributedModel = lda.fit(df)
+        self.assertTrue(distributedModel.isDistributed())
+        localModel = distributedModel.toLocal()
+        self.assertFalse(localModel.isDistributed())
+        # Define paths
+        path = tempfile.mkdtemp()
+        lda_path = path + "/lda"
+        dist_model_path = path + "/distLDAModel"
+        local_model_path = path + "/localLDAModel"
+        # Test LDA
+        lda.save(lda_path)
+        lda2 = LDA.load(lda_path)
+        self._compare(lda, lda2)
+        # Test DistributedLDAModel
+        distributedModel.save(dist_model_path)
+        distributedModel2 = DistributedLDAModel.load(dist_model_path)
+        self._compare(distributedModel, distributedModel2)
+        # Test LocalLDAModel
+        localModel.save(local_model_path)
+        localModel2 = LocalLDAModel.load(local_model_path)
+        self._compare(localModel, localModel2)
+        # Clean up
+        try:
+            rmtree(path)
+        except OSError:
+            pass
+
+
 class TrainingSummaryTest(PySparkTestCase):
 
     def test_linear_regression_summary(self):