Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Do not use LabeledPoint
  • Loading branch information
yu-iskw committed Jul 15, 2015
commit 77fd1b7b1c4960df79d0a597c0eae0010c35b666
Original file line number Diff line number Diff line change
Expand Up @@ -506,7 +506,7 @@ private[python] class PythonMLLibAPI extends Serializable {
* Java stub for Python mllib LDA.run()
*/
def trainLDAModel(
data: JavaRDD[LabeledPoint],
data: JavaRDD[java.util.List[Any]],
k: Int,
maxIterations: Int,
docConcentration: Double,
Expand All @@ -524,11 +524,14 @@ private[python] class PythonMLLibAPI extends Serializable {

if (seed != null) algo.setSeed(seed)

try {
algo.run(data.rdd.map(x => (x.label.toLong, x.features)))
} finally {
data.rdd.unpersist(blocking = false)
val documents = data.rdd.map(_.asScala.toArray).map { r =>
r(0).getClass.getSimpleName match {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

r(0) match {
case i: java.lang.Integer => i.toLong
case l: java.lang.Long => l
}

case "Integer" => (r(0).asInstanceOf[java.lang.Integer].toLong, r(1).asInstanceOf[Vector])
case "Long" => (r(0).asInstanceOf[java.lang.Long].toLong, r(1).asInstanceOf[Vector])
case _ => throw new IllegalArgumentException("input values contains invalid type value.")
}
}
algo.run(documents)
}


Expand Down
4 changes: 2 additions & 2 deletions python/pyspark/mllib/clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -593,8 +593,8 @@ class LDAModel(JavaModelWrapper):
>>> from collections import namedtuple
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not needed

>>> from numpy.testing import assert_almost_equal
>>> data = [
... LabeledPoint(1, [0.0, 1.0]),
... LabeledPoint(2, [1.0, 0.0]),
... [1, Vectors.dense([0.0, 1.0])],
... [2, SparseVector(2, {0: 1.0})],
... ]
>>> rdd = sc.parallelize(data)
>>> model = LDA.train(rdd, k=2)
Expand Down