Skip to content

Commit 68abc1b

Browse files
hhbyyhsrowen
authored andcommitted
[SPARK-14814][MLLIB] API: Java compatibility, docs
## What changes were proposed in this pull request? jira: https://issues.apache.org/jira/browse/SPARK-14814 fix a java compatibility function in mllib DecisionTreeModel. As synced in jira, other compatibility issues don't need fixes. ## How was this patch tested? existing ut Author: Yuhao Yang <[email protected]> Closes apache#12971 from hhbyyh/javacompatibility.
1 parent 635ef40 commit 68abc1b

File tree

2 files changed

+12
-2
lines changed

2 files changed

+12
-2
lines changed

mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,8 @@ class DecisionTreeModel @Since("1.0.0") (
7575
* @return JavaRDD of predictions for each of the given data points
7676
*/
7777
@Since("1.2.0")
78-
def predict(features: JavaRDD[Vector]): JavaRDD[Double] = {
79-
predict(features.rdd)
78+
def predict(features: JavaRDD[Vector]): JavaRDD[java.lang.Double] = {
79+
predict(features.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Double]]
8080
}
8181

8282
/**

mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@
2828

2929
import org.apache.spark.api.java.JavaRDD;
3030
import org.apache.spark.api.java.JavaSparkContext;
31+
import org.apache.spark.api.java.function.Function;
32+
import org.apache.spark.mllib.linalg.Vector;
3133
import org.apache.spark.mllib.regression.LabeledPoint;
3234
import org.apache.spark.mllib.tree.configuration.Algo;
3335
import org.apache.spark.mllib.tree.configuration.Strategy;
@@ -95,6 +97,14 @@ public void runDTUsingStaticMethods() {
9597

9698
DecisionTreeModel model = DecisionTree$.MODULE$.train(rdd.rdd(), strategy);
9799

100+
// java compatibility test
101+
JavaRDD<Double> predictions = model.predict(rdd.map(new Function<LabeledPoint, Vector>() {
102+
@Override
103+
public Vector call(LabeledPoint v1) {
104+
return v1.features();
105+
}
106+
}));
107+
98108
int numCorrect = validatePrediction(arr, model);
99109
Assert.assertTrue(numCorrect == rdd.count());
100110
}

0 commit comments

Comments
 (0)