Commit 4f40891
Improve test suites
1 parent 5ab62ab commit 4f40891

3 files changed: +42 −38 lines

mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala

Lines changed: 7 additions & 16 deletions
@@ -21,7 +21,6 @@ import com.google.common.base.Objects
 
 import org.apache.spark.Logging
 import org.apache.spark.mllib.linalg.VectorUDT
-import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.util.MLUtils
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.types.{StructType, StructField, DoubleType}
@@ -37,7 +36,7 @@ import org.apache.spark.sql.sources._
  */
 private[ml] class LibSVMRelation(val path: String, val numFeatures: Int, val vectorType: String)
     (@transient val sqlContext: SQLContext)
-  extends BaseRelation with TableScan with Logging {
+  extends BaseRelation with TableScan with Logging with Serializable {
 
   override def schema: StructType = StructType(
     StructField("label", DoubleType, nullable = false) ::
@@ -48,18 +47,10 @@ private[ml] class LibSVMRelation(val path: String, val numFeatures: Int, val vec
     val sc = sqlContext.sparkContext
     val baseRdd = MLUtils.loadLibSVMFile(sc, path, numFeatures)
 
-    val rowBuilders = Array(
-      (pt: LabeledPoint) => Seq(pt.label),
-      if (vectorType == "dense") {
-        (pt: LabeledPoint) => Seq(pt.features.toDense)
-      } else {
-        (pt: LabeledPoint) => Seq(pt.features.toSparse)
-      }
-    )
-
-    baseRdd.map(pt => {
-      Row.fromSeq(rowBuilders.map(_(pt)).reduceOption(_ ++ _).getOrElse(Seq.empty))
-    })
+    baseRdd.map { pt =>
+      val features = if (vectorType == "dense") pt.features.toDense else pt.features.toSparse
+      Row(pt.label, features)
+    }
   }
 
   override def hashCode(): Int = {
@@ -95,7 +86,7 @@ class DefaultSource extends RelationProvider with DataSourceRegister {
      * featuresType can be selected "dense" or "sparse".
      * This parameter decides the type of returned feature vector.
      */
-    val featuresType = parameters.getOrElse("featuresType", "sparse")
-    new LibSVMRelation(path, numFeatures, featuresType)(sqlContext)
+    val vectorType = parameters.getOrElse("vectorType", "sparse")
+    new LibSVMRelation(path, numFeatures, vectorType)(sqlContext)
   }
 }
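
For reference, a minimal usage sketch of the data source touched above (not part of this commit): it reads a LibSVM file through the generic DataFrame reader with the renamed "vectorType" option, whose default is "sparse" per the getOrElse call in DefaultSource. The SQLContext and the input path are assumptions for illustration; the format name and option keys come from the diffs and tests in this commit.

import org.apache.spark.sql.{DataFrame, SQLContext}

// Assumes an existing SQLContext (as in the test suites below) and a
// hypothetical LibSVM-formatted input file at `path`.
def loadLibSVM(sqlContext: SQLContext, path: String): DataFrame = {
  sqlContext.read
    .format("org.apache.spark.ml.source.libsvm") // data source exercised by the Java suite
    .option("numFeatures", "6")                  // optional; matches the 6-feature test data
    .option("vectorType", "dense")               // "sparse" is the default
    .load(path)
}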

mllib/src/test/java/org/apache/spark/ml/source/JavaLibSVMRelationSuite.java

Lines changed: 16 additions & 15 deletions
@@ -17,23 +17,25 @@
 
 package org.apache.spark.ml.source;
 
+import java.io.File;
+import java.io.IOException;
+
 import com.google.common.base.Charsets;
 import com.google.common.io.Files;
 
+import org.apache.spark.mllib.linalg.DenseVector;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
 import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.mllib.linalg.Vectors;
 import org.apache.spark.sql.DataFrame;
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.SQLContext;
 import org.apache.spark.util.Utils;
 
-import org.junit.After;
-import org.junit.Assert;
-import org.junit.Before;
-import org.junit.Test;
-
-import java.io.File;
-import java.io.IOException;
 
 /**
  * Test LibSVMRelation in Java.
@@ -50,11 +52,8 @@ public void setUp() throws IOException {
     jsc = new JavaSparkContext("local", "JavaLibSVMRelationSuite");
     jsql = new SQLContext(jsc);
 
-    path = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "datasource")
-      .getCanonicalFile();
-    if (path.exists()) {
-      path.delete();
-    }
+    File tmpDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "datasource");
+    path = File.createTempFile("datasource", "libsvm-relation", tmpDir);
 
     String s = "1 1:1.0 3:2.0 5:3.0\n0\n0 2:4.0 4:5.0 6:6.0";
     Files.write(s, path, Charsets.US_ASCII);
@@ -69,11 +68,13 @@ public void tearDown() {
 
   @Test
   public void verifyLibSVMDF() {
-    dataset = jsql.read().format("org.apache.spark.ml.source.libsvm").load(path.getPath());
+    dataset = jsql.read().format("org.apache.spark.ml.source.libsvm").option("vectorType", "dense")
+      .load(path.getPath());
     Assert.assertEquals("label", dataset.columns()[0]);
     Assert.assertEquals("features", dataset.columns()[1]);
     Row r = dataset.first();
-    Assert.assertEquals(Double.valueOf(r.getDouble(0)), Double.valueOf(1.0));
-    Assert.assertEquals(r.getAs(1), Vectors.dense(1.0, 0.0, 2.0, 0.0, 3.0, 0.0));
+    Assert.assertEquals(Double.valueOf(1.0), Double.valueOf(r.getDouble(0)));
+    DenseVector v = r.getAs(1);
+    Assert.assertEquals(Vectors.dense(1.0, 0.0, 2.0, 0.0, 3.0, 0.0), v);
   }
 }

mllib/src/test/scala/org/apache/spark/ml/source/LibSVMRelationSuite.scala

Lines changed: 19 additions & 7 deletions
@@ -45,28 +45,40 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext {
   }
 
   test("select as sparse vector") {
-    val df = sqlContext.read.options(Map("numFeatures" -> "6")).libsvm(path)
+    val df = sqlContext.read.libsvm(path)
     assert(df.columns(0) == "label")
     assert(df.columns(1) == "features")
     val row1 = df.first()
     assert(row1.getDouble(0) == 1.0)
-    assert(row1.getAs[SparseVector](1) == Vectors.sparse(6, Seq((0, 1.0), (2, 2.0), (4, 3.0))))
+    val v = row1.getAs[SparseVector](1)
+    assert(v == Vectors.sparse(6, Seq((0, 1.0), (2, 2.0), (4, 3.0))))
   }
 
   test("select as dense vector") {
-    val df = sqlContext.read.options(Map("numFeatures" -> "6", "featuresType" -> "dense"))
+    val df = sqlContext.read.options(Map("vectorType" -> "dense"))
       .libsvm(path)
     assert(df.columns(0) == "label")
    assert(df.columns(1) == "features")
     assert(df.count() == 3)
     val row1 = df.first()
     assert(row1.getDouble(0) == 1.0)
-    assert(row1.getAs[DenseVector](1) == Vectors.dense(1.0, 0.0, 2.0, 0.0, 3.0, 0.0))
+    val v = row1.getAs[DenseVector](1)
+    assert(v == Vectors.dense(1.0, 0.0, 2.0, 0.0, 3.0, 0.0))
   }
 
-  test("select without any option") {
-    val df = sqlContext.read.libsvm(path)
+  test("select long vector with specifying the number of features") {
+    val lines =
+      """
+        |1 1:1 10:2 20:3 30:4 40:5 50:6 60:7 70:8 80:9 90:10 100:1
+        |0 1:1 10:10 20:9 30:8 40:7 50:6 60:5 70:4 80:3 90:2 100:1
+      """.stripMargin
+    val tempDir = Utils.createTempDir()
+    val file = new File(tempDir.getPath, "part-00001")
+    Files.write(lines, file, Charsets.US_ASCII)
+    val df = sqlContext.read.option("numFeatures", "100").libsvm(tempDir.toURI.toString)
     val row1 = df.first()
-    assert(row1.getAs[SparseVector](1) == Vectors.sparse(6, Seq((0, 1.0), (2, 2.0), (4, 3.0))))
+    val v = row1.getAs[SparseVector](1)
+    assert(v == Vectors.sparse(100, Seq((0, 1.0), (9, 2.0), (19, 3.0), (29, 4.0), (39, 5.0),
+      (49, 6.0), (59, 7.0), (69, 8.0), (79, 9.0), (89, 10.0), (99, 1.0))))
   }
 }
