Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Add Java test
  • Loading branch information
Lewuathe committed Sep 2, 2015
commit 40d30276658df1f12b7ba0feae2338add94103db
Original file line number Diff line number Diff line change
Expand Up @@ -19,26 +19,29 @@ package org.apache.spark.ml.source.libsvm

import com.google.common.base.Objects
import org.apache.spark.Logging
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.annotation.Since
import org.apache.spark.mllib.linalg.{VectorUDT, Vector}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.types._
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.sql.sources.{DataSourceRegister, PrunedScan, BaseRelation, RelationProvider}


class LibSVMRelation(val path: String, val numFeatures: Int, val featuresType: String)
/**
* LibSVMRelation provides the DataFrame constructed from LibSVM format data.
* @param path
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fill in documentation

* @param numFeatures
* @param vectorType
* @param sqlContext
*/
private[ml] class LibSVMRelation(val path: String, val numFeatures: Int, val vectorType: String)
(@transient val sqlContext: SQLContext)
extends BaseRelation with PrunedScan with Logging {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since we need to read the entire file anyway, it doesn't save much with PrunedScan. Maybe TableScan is simpler but sufficient.


private final val vectorType: DataType
= classOf[Vector].getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance()


override def schema: StructType = StructType(
StructField("label", DoubleType, nullable = false) ::
StructField("features", vectorType, nullable = false) :: Nil
StructField("features", new VectorUDT(), nullable = false) :: Nil
)

override def buildScan(requiredColumns: Array[String]): RDD[Row] = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See my comments above about PrunedScan vs. TableScan.

Expand All @@ -47,8 +50,8 @@ class LibSVMRelation(val path: String, val numFeatures: Int, val featuresType: S

val rowBuilders = requiredColumns.map {
case "label" => (pt: LabeledPoint) => Seq(pt.label)
case "features" if featuresType == "sparse" => (pt: LabeledPoint) => Seq(pt.features.toSparse)
case "features" if featuresType == "dense" => (pt: LabeledPoint) => Seq(pt.features.toDense)
case "features" if vectorType == "sparse" => (pt: LabeledPoint) => Seq(pt.features.toSparse)
case "features" if vectorType == "dense" => (pt: LabeledPoint) => Seq(pt.features.toDense)
}

baseRdd.map(pt => {
Expand All @@ -69,16 +72,6 @@ class LibSVMRelation(val path: String, val numFeatures: Int, val featuresType: S

class DefaultSource extends RelationProvider with DataSourceRegister {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also missing doc

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


/**
* The string that represents the format that this data source provider uses. This is
* overridden by children to provide a nice alias for the data source. For example:
*
* {{{
* override def format(): String = "parquet"
* }}}
*
* @since 1.5.0
*/
override def shortName(): String = "libsvm"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add @Since("1.6.0")


private def checkPath(parameters: Map[String, String]): String = {
Expand All @@ -90,8 +83,8 @@ class DefaultSource extends RelationProvider with DataSourceRegister {
* Note: the parameters' keywords are case insensitive and this insensitivity is enforced
* by the Map that is passed to the function.
*/
override def createRelation(sqlContext: SQLContext, parameters: Map[String, String]):
BaseRelation = {
override def createRelation(sqlContext: SQLContext, parameters: Map[String, String])
: BaseRelation = {
val path = checkPath(parameters)
val numFeatures = parameters.getOrElse("numFeatures", "-1").toInt
/**
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
package org.apache.spark.ml.source;

import com.google.common.base.Charsets;
import com.google.common.io.Files;
import org.apache.spark.api.java.JavaSparkContext;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto. organize imports

import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.util.Utils;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

import java.io.File;
import java.io.IOException;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

organize imports: java, scala, 3rd-party, spark.


/**
* Test LibSVMRelation in Java.
*/
public class JavaLibSVMRelationSuite {
private transient JavaSparkContext jsc;
private transient SQLContext jsql;
private transient DataFrame dataset;

private File path;

@Before
public void setUp() throws IOException {
jsc = new JavaSparkContext("local", "JavaLibSVMRelationSuite");
jsql = new SQLContext(jsc);

path = Utils.createTempDir(System.getProperty("java.io.tmpdir"),
"datasource").getCanonicalFile();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

path = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "datasource")
  .getCanonicalFile();

if (path.exists()) {
path.delete();
}

String s = "1 1:1.0 3:2.0 5:3.0\n0\n0 2:4.0 4:5.0 6:6.0";
Files.write(s, path, Charsets.US_ASCII);
}

@After
public void tearDown() {
jsc.stop();
jsc = null;
}

@Test
public void verifyLibSvmDF() {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LibSVM to be consistent.

dataset = jsql.read().format("libsvm").load();
Assert.assertEquals(dataset.columns()[0], "label");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In JUnit, assertEquals takes the expected value as the first arg. So it should be "label", dataset.columns()[0] here.

Assert.assertEquals(dataset.columns()[1], "features");
Row r = dataset.first();
Assert.assertTrue(r.getDouble(0) == 1.0);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use assertEquals

Assert.assertEquals(r.getAs(1), Vectors.dense(1.0, 0.0, 2.0, 0.0, 3.0, 0.0));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to check the class name first or cast it to DenseVector directly:

DenseVector v = r.getAs(1)
Assert.assertEquals(Vectors.dense(...), v)

If it is a sparse vector, the first line will throw an error.

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,4 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext {
val row1 = df.first()
assert(row1.getAs[SparseVector](1) == Vectors.sparse(6, Seq((0, 1.0), (2, 2.0), (4, 3.0))))
}


}