Skip to content

Commit 40d3027

Browse files
committed
Add Java test
1 parent 7056d4a commit 40d3027

File tree

3 files changed

+74
-24
lines changed

3 files changed

+74
-24
lines changed

mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala

Lines changed: 15 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -19,26 +19,29 @@ package org.apache.spark.ml.source.libsvm
1919

2020
import com.google.common.base.Objects
2121
import org.apache.spark.Logging
22-
import org.apache.spark.mllib.linalg.Vector
22+
import org.apache.spark.annotation.Since
23+
import org.apache.spark.mllib.linalg.{VectorUDT, Vector}
2324
import org.apache.spark.mllib.regression.LabeledPoint
2425
import org.apache.spark.mllib.util.MLUtils
2526
import org.apache.spark.rdd.RDD
2627
import org.apache.spark.sql.types._
2728
import org.apache.spark.sql.{Row, SQLContext}
2829
import org.apache.spark.sql.sources.{DataSourceRegister, PrunedScan, BaseRelation, RelationProvider}
2930

30-
31-
class LibSVMRelation(val path: String, val numFeatures: Int, val featuresType: String)
31+
/**
32+
* LibSVMRelation provides the DataFrame constructed from LibSVM format data.
33+
* @param path
34+
* @param numFeatures
35+
* @param vectorType
36+
* @param sqlContext
37+
*/
38+
private[ml] class LibSVMRelation(val path: String, val numFeatures: Int, val vectorType: String)
3239
(@transient val sqlContext: SQLContext)
3340
extends BaseRelation with PrunedScan with Logging {
3441

35-
private final val vectorType: DataType
36-
= classOf[Vector].getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance()
37-
38-
3942
override def schema: StructType = StructType(
4043
StructField("label", DoubleType, nullable = false) ::
41-
StructField("features", vectorType, nullable = false) :: Nil
44+
StructField("features", new VectorUDT(), nullable = false) :: Nil
4245
)
4346

4447
override def buildScan(requiredColumns: Array[String]): RDD[Row] = {
@@ -47,8 +50,8 @@ class LibSVMRelation(val path: String, val numFeatures: Int, val featuresType: S
4750

4851
val rowBuilders = requiredColumns.map {
4952
case "label" => (pt: LabeledPoint) => Seq(pt.label)
50-
case "features" if featuresType == "sparse" => (pt: LabeledPoint) => Seq(pt.features.toSparse)
51-
case "features" if featuresType == "dense" => (pt: LabeledPoint) => Seq(pt.features.toDense)
53+
case "features" if vectorType == "sparse" => (pt: LabeledPoint) => Seq(pt.features.toSparse)
54+
case "features" if vectorType == "dense" => (pt: LabeledPoint) => Seq(pt.features.toDense)
5255
}
5356

5457
baseRdd.map(pt => {
@@ -69,16 +72,6 @@ class LibSVMRelation(val path: String, val numFeatures: Int, val featuresType: S
6972

7073
class DefaultSource extends RelationProvider with DataSourceRegister {
7174

72-
/**
73-
* The string that represents the format that this data source provider uses. This is
74-
* overridden by children to provide a nice alias for the data source. For example:
75-
*
76-
* {{{
77-
* override def format(): String = "parquet"
78-
* }}}
79-
*
80-
* @since 1.5.0
81-
*/
8275
override def shortName(): String = "libsvm"
8376

8477
private def checkPath(parameters: Map[String, String]): String = {
@@ -90,8 +83,8 @@ class DefaultSource extends RelationProvider with DataSourceRegister {
9083
* Note: the parameters' keywords are case insensitive and this insensitivity is enforced
9184
* by the Map that is passed to the function.
9285
*/
93-
override def createRelation(sqlContext: SQLContext, parameters: Map[String, String]):
94-
BaseRelation = {
86+
override def createRelation(sqlContext: SQLContext, parameters: Map[String, String])
87+
: BaseRelation = {
9588
val path = checkPath(parameters)
9689
val numFeatures = parameters.getOrElse("numFeatures", "-1").toInt
9790
/**
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
package org.apache.spark.ml.source;

import com.google.common.base.Charsets;
import com.google.common.io.Files;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.util.Utils;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

import java.io.File;
import java.io.IOException;

/**
 * Test LibSVMRelation in Java.
 *
 * Writes a small LibSVM-formatted file in {@link #setUp()}, loads it through the
 * "libsvm" data source, and verifies the resulting DataFrame's schema and first row.
 */
public class JavaLibSVMRelationSuite {
  private transient JavaSparkContext jsc;
  private transient SQLContext jsql;
  private transient DataFrame dataset;

  // Temporary file holding the LibSVM-formatted test data written in setUp().
  private File path;

  @Before
  public void setUp() throws IOException {
    jsc = new JavaSparkContext("local", "JavaLibSVMRelationSuite");
    jsql = new SQLContext(jsc);

    // createTempDir makes a directory; delete it so the same path can be
    // reused as a plain data FILE below.
    path = Utils.createTempDir(System.getProperty("java.io.tmpdir"),
      "datasource").getCanonicalFile();
    if (path.exists()) {
      path.delete();
    }

    // Three records: label 1 with features at indices 1/3/5, an empty record
    // (label 0, no features), and label 0 with features at indices 2/4/6.
    String s = "1 1:1.0 3:2.0 5:3.0\n0\n0 2:4.0 4:5.0 6:6.0";
    Files.write(s, path, Charsets.US_ASCII);
  }

  @After
  public void tearDown() {
    // Fix: remove the temp file written in setUp(); previously it was leaked.
    if (path != null) {
      path.delete();
      path = null;
    }
    jsc.stop();
    jsc = null;
  }

  @Test
  public void verifyLibSvmDF() {
    // Fix: the original called load() with no path, so the data written in
    // setUp() was never read by the test.
    dataset = jsql.read().format("libsvm").load(path.getPath());
    // JUnit convention: expected value first, actual second.
    Assert.assertEquals("label", dataset.columns()[0]);
    Assert.assertEquals("features", dataset.columns()[1]);
    Row r = dataset.first();
    // Fix: compare doubles via assertEquals with a delta instead of raw ==.
    Assert.assertEquals(1.0, r.getDouble(0), 1e-15);
    Assert.assertEquals(Vectors.dense(1.0, 0.0, 2.0, 0.0, 3.0, 0.0), r.getAs(1));
  }
}

mllib/src/test/scala/org/apache/spark/ml/source/LibSVMRelationSuite.scala

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,4 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext {
6969
val row1 = df.first()
7070
assert(row1.getAs[SparseVector](1) == Vectors.sparse(6, Seq((0, 1.0), (2, 2.0), (4, 3.0))))
7171
}
72-
73-
7472
}

0 commit comments

Comments
 (0)