Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Address some review comments
  • Loading branch information
Lewuathe committed Sep 9, 2015
commit 11d513f32046af0f40bc7e73df1d437d69ef9e35
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package org.apache.spark.ml.source.libsvm
import com.google.common.base.Objects

import org.apache.spark.Logging
import org.apache.spark.annotation.Since
import org.apache.spark.mllib.linalg.VectorUDT
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
Expand Down Expand Up @@ -68,8 +69,10 @@ private[ml] class LibSVMRelation(val path: String, val numFeatures: Int, val vec
* This is used for creating DataFrame from LibSVM format file.
* The LibSVM file path must be specified to DefaultSource.
*/
@Since("1.6.0")
class DefaultSource extends RelationProvider with DataSourceRegister {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also missing doc

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


@Since("1.6.0")
override def shortName(): String = "libsvm"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add @Since("1.6.0")


private def checkPath(parameters: Map[String, String]): String = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,13 @@
import com.google.common.base.Charsets;
import com.google.common.io.Files;

import org.apache.spark.mllib.linalg.DenseVector;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

import org.apache.spark.api.java.JavaSparkContext;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto. organize imports

import org.apache.spark.mllib.linalg.DenseVector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
Expand All @@ -45,15 +45,16 @@ public class JavaLibSVMRelationSuite {
private transient SQLContext jsql;
private transient DataFrame dataset;

private File tmpDir;
private File path;

@Before
public void setUp() throws IOException {
jsc = new JavaSparkContext("local", "JavaLibSVMRelationSuite");
jsql = new SQLContext(jsc);

File tmpDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "datasource");
path = File.createTempFile("datasource", "libsvm-relation", tmpDir);
tmpDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "datasource");
path = new File(tmpDir.getPath(), "part-00000");

String s = "1 1:1.0 3:2.0 5:3.0\n0\n0 2:4.0 4:5.0 6:6.0";
Files.write(s, path, Charsets.US_ASCII);
Expand All @@ -63,7 +64,7 @@ public void setUp() throws IOException {
public void tearDown() {
jsc.stop();
jsc = null;
path.delete();
Utils.deleteRecursively(tmpDir);
}

@Test
Expand All @@ -72,7 +73,7 @@ public void verifyLibSVMDF() {
Assert.assertEquals("label", dataset.columns()[0]);
Assert.assertEquals("features", dataset.columns()[1]);
Row r = dataset.first();
Assert.assertEquals(Double.valueOf(1.0), Double.valueOf(r.getDouble(0)));
Assert.assertEquals(1.0, r.getDouble(0), 1e-15);
DenseVector v = r.getAs(1);
Assert.assertEquals(Vectors.dense(1.0, 0.0, 2.0, 0.0, 3.0, 0.0), v);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ import java.io.File

import com.google.common.base.Charsets
import com.google.common.io.Files

import org.apache.spark.SparkFunSuite
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto. organize imports

import org.apache.spark.ml.source.libsvm._
import org.apache.spark.mllib.linalg.{SparseVector, Vectors, DenseVector}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.util.Utils
Expand Down Expand Up @@ -67,19 +67,10 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext {
}

test("select long vector with specifying the number of features") {
val lines =
"""
|1 1:1 10:2 20:3 30:4 40:5 50:6 60:7 70:8 80:9 90:10 100:1
|0 1:1 10:10 20:9 30:8 40:7 50:6 60:5 70:4 80:3 90:2 100:1
""".stripMargin
val tempDir = Utils.createTempDir()
val file = new File(tempDir.getPath, "part-00001")
Files.write(lines, file, Charsets.US_ASCII)
val df = sqlContext.read.option("numFeatures", "100").format("libsvm")
.load(tempDir.toURI.toString)
.load(path)
val row1 = df.first()
val v = row1.getAs[SparseVector](1)
assert(v == Vectors.sparse(100, Seq((0, 1.0), (9, 2.0), (19, 3.0), (29, 4.0), (39, 5.0),
(49, 6.0), (59, 7.0), (69, 8.0), (79, 9.0), (89, 10.0), (99, 1.0))))
assert(v == Vectors.sparse(100, Seq((0, 1.0), (2, 2.0), (4, 3.0))))
}
}