-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-10117][MLLIB] Implement SQL data source API for reading LIBSVM data #8537
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
99accaa
7056d4a
40d3027
3fd8dce
70ee4dd
aef9564
a97ee97
62010af
7d693c2
2c12894
b56a948
8660d0e
5ab62ab
4f40891
0ea1c1c
ba3657c
1fdd2df
9ce63c7
21600a4
11d513f
986999d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
- Loading branch information
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -20,6 +20,7 @@ package org.apache.spark.ml.source.libsvm | |
| import com.google.common.base.Objects | ||
|
|
||
| import org.apache.spark.Logging | ||
| import org.apache.spark.annotation.Since | ||
| import org.apache.spark.mllib.linalg.VectorUDT | ||
| import org.apache.spark.mllib.util.MLUtils | ||
| import org.apache.spark.rdd.RDD | ||
|
|
@@ -68,8 +69,10 @@ private[ml] class LibSVMRelation(val path: String, val numFeatures: Int, val vec | |
| * This is used for creating DataFrame from LibSVM format file. | ||
| * The LibSVM file path must be specified to DefaultSource. | ||
| */ | ||
| @Since("1.6.0") | ||
| class DefaultSource extends RelationProvider with DataSourceRegister { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you add |
||
|
|
||
| @Since("1.6.0") | ||
| override def shortName(): String = "libsvm" | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. add |
||
|
|
||
| private def checkPath(parameters: Map[String, String]): String = { | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -23,13 +23,13 @@ | |
| import com.google.common.base.Charsets; | ||
| import com.google.common.io.Files; | ||
|
|
||
| import org.apache.spark.mllib.linalg.DenseVector; | ||
| import org.junit.After; | ||
| import org.junit.Assert; | ||
| import org.junit.Before; | ||
| import org.junit.Test; | ||
|
|
||
| import org.apache.spark.api.java.JavaSparkContext; | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ditto. organize imports |
||
| import org.apache.spark.mllib.linalg.DenseVector; | ||
| import org.apache.spark.mllib.linalg.Vectors; | ||
| import org.apache.spark.sql.DataFrame; | ||
| import org.apache.spark.sql.Row; | ||
|
|
@@ -45,15 +45,16 @@ public class JavaLibSVMRelationSuite { | |
| private transient SQLContext jsql; | ||
| private transient DataFrame dataset; | ||
|
|
||
| private File tmpDir; | ||
| private File path; | ||
|
|
||
| @Before | ||
| public void setUp() throws IOException { | ||
| jsc = new JavaSparkContext("local", "JavaLibSVMRelationSuite"); | ||
| jsql = new SQLContext(jsc); | ||
|
|
||
| File tmpDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "datasource"); | ||
| path = File.createTempFile("datasource", "libsvm-relation", tmpDir); | ||
| tmpDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "datasource"); | ||
| path = new File(tmpDir.getPath(), "part-00000"); | ||
|
|
||
| String s = "1 1:1.0 3:2.0 5:3.0\n0\n0 2:4.0 4:5.0 6:6.0"; | ||
| Files.write(s, path, Charsets.US_ASCII); | ||
|
|
@@ -63,7 +64,7 @@ public void setUp() throws IOException { | |
| public void tearDown() { | ||
| jsc.stop(); | ||
| jsc = null; | ||
| path.delete(); | ||
| Utils.deleteRecursively(tmpDir); | ||
| } | ||
|
|
||
| @Test | ||
|
|
@@ -72,7 +73,7 @@ public void verifyLibSVMDF() { | |
| Assert.assertEquals("label", dataset.columns()[0]); | ||
| Assert.assertEquals("features", dataset.columns()[1]); | ||
| Row r = dataset.first(); | ||
| Assert.assertEquals(Double.valueOf(1.0), Double.valueOf(r.getDouble(0))); | ||
| Assert.assertEquals(1.0, r.getDouble(0), 1e-15); | ||
| DenseVector v = r.getAs(1); | ||
| Assert.assertEquals(Vectors.dense(1.0, 0.0, 2.0, 0.0, 3.0, 0.0), v); | ||
| } | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -21,8 +21,8 @@ import java.io.File | |
|
|
||
| import com.google.common.base.Charsets | ||
| import com.google.common.io.Files | ||
|
|
||
| import org.apache.spark.SparkFunSuite | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ditto. organize imports |
||
| import org.apache.spark.ml.source.libsvm._ | ||
| import org.apache.spark.mllib.linalg.{SparseVector, Vectors, DenseVector} | ||
| import org.apache.spark.mllib.util.MLlibTestSparkContext | ||
| import org.apache.spark.util.Utils | ||
|
|
@@ -67,19 +67,10 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext { | |
| } | ||
|
|
||
| test("select long vector with specifying the number of features") { | ||
| val lines = | ||
| """ | ||
| |1 1:1 10:2 20:3 30:4 40:5 50:6 60:7 70:8 80:9 90:10 100:1 | ||
| |0 1:1 10:10 20:9 30:8 40:7 50:6 60:5 70:4 80:3 90:2 100:1 | ||
| """.stripMargin | ||
| val tempDir = Utils.createTempDir() | ||
| val file = new File(tempDir.getPath, "part-00001") | ||
| Files.write(lines, file, Charsets.US_ASCII) | ||
| val df = sqlContext.read.option("numFeatures", "100").format("libsvm") | ||
| .load(tempDir.toURI.toString) | ||
| .load(path) | ||
| val row1 = df.first() | ||
| val v = row1.getAs[SparseVector](1) | ||
| assert(v == Vectors.sparse(100, Seq((0, 1.0), (9, 2.0), (19, 3.0), (29, 4.0), (39, 5.0), | ||
| (49, 6.0), (59, 7.0), (69, 8.0), (79, 9.0), (89, 10.0), (99, 1.0)))) | ||
| assert(v == Vectors.sparse(100, Seq((0, 1.0), (2, 2.0), (4, 3.0)))) | ||
| } | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
also missing doc