From 99accaa9542f0f38e1d41c25f75e143ef3f9be08 Mon Sep 17 00:00:00 2001 From: lewuathe Date: Mon, 31 Aug 2015 22:16:38 +0900 Subject: [PATCH 01/13] [SPARK-10117] Implement SQL data source API for reading LIBSVM data --- .../ml/source/libsvm/LibSVMRelation.scala | 104 ++++++++++++++++++ .../spark/ml/source/libsvm/package.scala | 33 ++++++ .../spark/ml/source/LibSVMRelationSuite.scala | 74 +++++++++++++ 3 files changed, 211 insertions(+) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala create mode 100644 mllib/src/main/scala/org/apache/spark/ml/source/libsvm/package.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/source/LibSVMRelationSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala new file mode 100644 index 000000000000..52b808e01a0b --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.source.libsvm + +import com.google.common.base.Objects +import org.apache.spark.Logging +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.util.MLUtils +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.types._ +import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.sql.sources.{DataSourceRegister, PrunedScan, BaseRelation, RelationProvider} + + +class LibSVMRelation(val path: String, val numFeatures: Int, val featuresType: String) + (@transient val sqlContext: SQLContext) + extends BaseRelation with PrunedScan with Logging { + + private final val vectorType: DataType + = classOf[Vector].getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance() + + + override def schema: StructType = StructType( + StructField("label", DoubleType, nullable = false) :: + StructField("features", vectorType, nullable = false) :: Nil + ) + + override def buildScan(requiredColumns: Array[String]): RDD[Row] = { + val sc = sqlContext.sparkContext + val baseRdd = MLUtils.loadLibSVMFile(sc, path, numFeatures) + + val rowBuilders = requiredColumns.map { + case "label" => (pt: LabeledPoint) => Seq(pt.label) + case "features" if featuresType == "sparse" => (pt: LabeledPoint) => Seq(pt.features.toSparse) + case "features" if featuresType == "dense" => (pt: LabeledPoint) => Seq(pt.features.toDense) + } + + baseRdd.map(pt => { + Row.fromSeq(rowBuilders.map(_(pt)).reduceOption(_ ++ _).getOrElse(Seq.empty)) + }) + } + + override def hashCode(): Int = { + Objects.hashCode(path, schema) + } + + override def equals(other: Any): Boolean = other match { + case that: LibSVMRelation => (this.path == that.path) && this.schema.equals(that.schema) + case _ => false + } + +} + +class DefaultSource extends RelationProvider with DataSourceRegister { + + /** + * The string that represents the format that this data source provider uses. This is + * overridden by children to provide a nice alias for the data source. For example: + * + * {{{ + * override def format(): String = "parquet" + * }}} + * + * @since 1.5.0 + */ + override def shortName(): String = "libsvm" + + private def checkPath(parameters: Map[String, String]): String = { + parameters.getOrElse("path", sys.error("'path' must be specified")) + } + + /** + * Returns a new base relation with the given parameters. + * Note: the parameters' keywords are case insensitive and this insensitivity is enforced + * by the Map that is passed to the function. + */ + override def createRelation(sqlContext: SQLContext, parameters: Map[String, String]): + BaseRelation = { + val path = checkPath(parameters) + val numFeatures = parameters.getOrElse("numFeatures", "-1").toInt + /** + * featuresType can be selected "dense" or "sparse". + * This parameter decides the type of returned feature vector. + */ + val featuresType = parameters.getOrElse("featuresType", "sparse") + new LibSVMRelation(path, numFeatures, featuresType)(sqlContext) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/package.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/package.scala new file mode 100644 index 000000000000..92c021e4b4e6 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/package.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.source + +import org.apache.spark.sql.{DataFrame, DataFrameReader} + +package object libsvm { + + /** + * Implicit declaration in order to be used from SQLContext. + * It is necessary to import org.apache.spark.ml.source.libsvm._ + * @param read + */ + implicit class LibSVMReader(read: DataFrameReader) { + def libsvm(filePath: String): DataFrame + = read.format(classOf[DefaultSource].getName).load(filePath) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/source/LibSVMRelationSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/source/LibSVMRelationSuite.scala new file mode 100644 index 000000000000..accf37d9886a --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/source/LibSVMRelationSuite.scala @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.source + +import java.io.File + +import com.google.common.base.Charsets +import com.google.common.io.Files +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.source.libsvm._ +import org.apache.spark.mllib.linalg.{SparseVector, Vectors, DenseVector} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.util.Utils + +class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext { + var path: String = _ + + override def beforeAll(): Unit = { + super.beforeAll() + val lines = + """ + |1 1:1.0 3:2.0 5:3.0 + |0 + |0 2:4.0 4:5.0 6:6.0 + """.stripMargin + val tempDir = Utils.createTempDir() + val file = new File(tempDir.getPath, "part-00000") + Files.write(lines, file, Charsets.US_ASCII) + path = tempDir.toURI.toString + } + + test("select as sparse vector") { + val df = sqlContext.read.options(Map("numFeatures" -> "6")).libsvm(path) + assert(df.columns(0) == "label") + assert(df.columns(1) == "features") + val row1 = df.first() + assert(row1.getDouble(0) == 1.0) + assert(row1.getAs[SparseVector](1) == Vectors.sparse(6, Seq((0, 1.0), (2, 2.0), (4, 3.0)))) + } + + test("select as dense vector") { + val df = sqlContext.read.options(Map("numFeatures" -> "6", "featuresType" -> "dense")) + .libsvm(path) + assert(df.columns(0) == "label") + assert(df.columns(1) == "features") + assert(df.count() == 3) + val row1 = df.first() + assert(row1.getDouble(0) == 1.0) + assert(row1.getAs[DenseVector](1) == Vectors.dense(1.0, 0.0, 2.0, 0.0, 3.0, 0.0)) + } + + test("select without any option") { + val df = sqlContext.read.libsvm(path) + val row1 = df.first() + assert(row1.getAs[SparseVector](1) == Vectors.sparse(6, Seq((0, 1.0), (2, 2.0), (4, 3.0)))) + } + + +} From 40d30276658df1f12b7ba0feae2338add94103db Mon Sep 17 00:00:00 2001 From: lewuathe Date: Wed, 2 Sep 2015 23:35:18 +0900 Subject: [PATCH 02/13] Add Java test --- .../ml/source/libsvm/LibSVMRelation.scala | 37 +++++------- .../ml/source/JavaLibSVMRelationSuite.java | 59 +++++++++++++++++++ .../spark/ml/source/LibSVMRelationSuite.scala | 2 - 3 files changed, 74 insertions(+), 24 deletions(-) create mode 100644 mllib/src/test/java/org/apache/spark/ml/source/JavaLibSVMRelationSuite.java diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala index 52b808e01a0b..bf10536f3955 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala @@ -19,7 +19,8 @@ package org.apache.spark.ml.source.libsvm import com.google.common.base.Objects import org.apache.spark.Logging -import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.annotation.Since +import org.apache.spark.mllib.linalg.{VectorUDT, Vector} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD @@ -27,18 +28,20 @@ import org.apache.spark.sql.types._ import org.apache.spark.sql.{Row, SQLContext} import org.apache.spark.sql.sources.{DataSourceRegister, PrunedScan, BaseRelation, RelationProvider} - -class LibSVMRelation(val path: String, val numFeatures: Int, val featuresType: String) +/** + * LibSVMRelation provides the DataFrame constructed from LibSVM format data. 
+ * @param path + * @param numFeatures + * @param vectorType + * @param sqlContext + */ +private[ml] class LibSVMRelation(val path: String, val numFeatures: Int, val vectorType: String) (@transient val sqlContext: SQLContext) extends BaseRelation with PrunedScan with Logging { - private final val vectorType: DataType - = classOf[Vector].getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance() - - override def schema: StructType = StructType( StructField("label", DoubleType, nullable = false) :: - StructField("features", vectorType, nullable = false) :: Nil + StructField("features", new VectorUDT(), nullable = false) :: Nil ) override def buildScan(requiredColumns: Array[String]): RDD[Row] = { @@ -47,8 +50,8 @@ class LibSVMRelation(val path: String, val numFeatures: Int, val featuresType: S val rowBuilders = requiredColumns.map { case "label" => (pt: LabeledPoint) => Seq(pt.label) - case "features" if featuresType == "sparse" => (pt: LabeledPoint) => Seq(pt.features.toSparse) - case "features" if featuresType == "dense" => (pt: LabeledPoint) => Seq(pt.features.toDense) + case "features" if vectorType == "sparse" => (pt: LabeledPoint) => Seq(pt.features.toSparse) + case "features" if vectorType == "dense" => (pt: LabeledPoint) => Seq(pt.features.toDense) } baseRdd.map(pt => { @@ -69,16 +72,6 @@ class LibSVMRelation(val path: String, val numFeatures: Int, val featuresType: S class DefaultSource extends RelationProvider with DataSourceRegister { - /** - * The string that represents the format that this data source provider uses. This is - * overridden by children to provide a nice alias for the data source. For example: - * - * {{{ - * override def format(): String = "parquet" - * }}} - * - * @since 1.5.0 - */ override def shortName(): String = "libsvm" private def checkPath(parameters: Map[String, String]): String = { @@ -90,8 +83,8 @@ class DefaultSource extends RelationProvider with DataSourceRegister { * Note: the parameters' keywords are case insensitive and this insensitivity is enforced * by the Map that is passed to the function. */ - override def createRelation(sqlContext: SQLContext, parameters: Map[String, String]): - BaseRelation = { + override def createRelation(sqlContext: SQLContext, parameters: Map[String, String]) + : BaseRelation = { val path = checkPath(parameters) val numFeatures = parameters.getOrElse("numFeatures", "-1").toInt /** diff --git a/mllib/src/test/java/org/apache/spark/ml/source/JavaLibSVMRelationSuite.java b/mllib/src/test/java/org/apache/spark/ml/source/JavaLibSVMRelationSuite.java new file mode 100644 index 000000000000..0464988f9980 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/source/JavaLibSVMRelationSuite.java @@ -0,0 +1,59 @@ +package org.apache.spark.ml.source; + +import com.google.common.base.Charsets; +import com.google.common.io.Files; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.util.Utils; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import java.io.File; +import java.io.IOException; + +/** + * Test LibSVMRelation in Java. 
+ */
+public class JavaLibSVMRelationSuite {
+  private transient JavaSparkContext jsc;
+  private transient SQLContext jsql;
+  private transient DataFrame dataset;
+
+  private File path;
+
+  @Before
+  public void setUp() throws IOException {
+    jsc = new JavaSparkContext("local", "JavaLibSVMRelationSuite");
+    jsql = new SQLContext(jsc);
+
+    path = Utils.createTempDir(System.getProperty("java.io.tmpdir"),
+      "datasource").getCanonicalFile();
+    if (path.exists()) {
+      path.delete();
+    }
+
+    String s = "1 1:1.0 3:2.0 5:3.0\n0\n0 2:4.0 4:5.0 6:6.0";
+    Files.write(s, path, Charsets.US_ASCII);
+  }
+
+  @After
+  public void tearDown() {
+    jsc.stop();
+    jsc = null;
+  }
+
+  @Test
+  public void verifyLibSvmDF() {
+    dataset = jsql.read().format("libsvm").load();
+    Assert.assertEquals(dataset.columns()[0], "label");
+    Assert.assertEquals(dataset.columns()[1], "features");
+    Row r = dataset.first();
+    Assert.assertTrue(r.getDouble(0) == 1.0);
+    Assert.assertEquals(r.getAs(1), Vectors.dense(1.0, 0.0, 2.0, 0.0, 3.0, 0.0));
+  }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/source/LibSVMRelationSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/source/LibSVMRelationSuite.scala
index accf37d9886a..960ab8575fa5 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/source/LibSVMRelationSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/source/LibSVMRelationSuite.scala
@@ -69,6 +69,4 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext {
     val row1 = df.first()
     assert(row1.getAs[SparseVector](1) == Vectors.sparse(6, Seq((0, 1.0), (2, 2.0), (4, 3.0))))
   }
-
-
 }
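
Usage sketch: at this point in the series, the data source is reachable through the
implicit LibSVMReader defined in package.scala. The following is a minimal sketch of
how that reader is meant to be driven, assuming the API state after PATCH 02 (the
option key is still "featuresType" here; PATCH 09 later renames it to "vectorType",
and the input path is a placeholder):

    import org.apache.spark.sql.SQLContext
    // Brings the implicit libsvm() method on DataFrameReader into scope.
    import org.apache.spark.ml.source.libsvm._

    val sqlContext = new SQLContext(sc)  // sc: an existing SparkContext
    val df = sqlContext.read
      .options(Map("numFeatures" -> "6", "featuresType" -> "dense"))
      .libsvm("/path/to/data.libsvm")    // placeholder path
    df.show()

Because libsvm() calls read.format(classOf[DefaultSource].getName).load(filePath) on
the same DataFrameReader, any options set beforehand are carried through to
createRelation.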
From aef95643878eef2dd88a5ccd86d60fa7bbb2251d Mon Sep 17 00:00:00 2001
From: lewuathe
Date: Thu, 3 Sep 2015 20:53:45 +0900
Subject: [PATCH 05/13] Fix

---
 .../ml/source/libsvm/LibSVMRelation.scala     | 36 ++++++++++---------
 .../spark/ml/source/libsvm/package.scala      |  2 +-
 .../ml/source/JavaLibSVMRelationSuite.java    | 32 +++++++++++++----
 3 files changed, 47 insertions(+), 23 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
index bf10536f3955..abddcc98ef88 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
@@ -18,41 +18,44 @@ package org.apache.spark.ml.source.libsvm
 
 import com.google.common.base.Objects
+
 import org.apache.spark.Logging
-import org.apache.spark.annotation.Since
-import org.apache.spark.mllib.linalg.{VectorUDT, Vector}
+import org.apache.spark.mllib.linalg.VectorUDT
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.util.MLUtils
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.types._
+import org.apache.spark.sql.types.{StructType, StructField, DoubleType}
 import org.apache.spark.sql.{Row, SQLContext}
-import org.apache.spark.sql.sources.{DataSourceRegister, PrunedScan, BaseRelation, RelationProvider}
+import org.apache.spark.sql.sources._
 
 /**
  * LibSVMRelation provides the DataFrame constructed from LibSVM format data.
- * @param path
- * @param numFeatures
- * @param vectorType
- * @param sqlContext
+ * @param path File path of LibSVM format
+ * @param numFeatures The number of features
+ * @param vectorType The type of vector. It can be 'sparse' or 'dense'
+ * @param sqlContext The Spark SQLContext
  */
 private[ml] class LibSVMRelation(val path: String, val numFeatures: Int, val vectorType: String)
     (@transient val sqlContext: SQLContext)
-  extends BaseRelation with PrunedScan with Logging {
+  extends BaseRelation with TableScan with Logging {
 
   override def schema: StructType = StructType(
     StructField("label", DoubleType, nullable = false) ::
       StructField("features", new VectorUDT(), nullable = false) :: Nil
   )
 
-  override def buildScan(requiredColumns: Array[String]): RDD[Row] = {
+  override def buildScan(): RDD[Row] = {
     val sc = sqlContext.sparkContext
     val baseRdd = MLUtils.loadLibSVMFile(sc, path, numFeatures)
 
-    val rowBuilders = requiredColumns.map {
-      case "label" => (pt: LabeledPoint) => Seq(pt.label)
-      case "features" if vectorType == "sparse" => (pt: LabeledPoint) => Seq(pt.features.toSparse)
-      case "features" if vectorType == "dense" => (pt: LabeledPoint) => Seq(pt.features.toDense)
-    }
+    val rowBuilders = Array(
+      (pt: LabeledPoint) => Seq(pt.label),
+      if (vectorType == "dense") {
+        (pt: LabeledPoint) => Seq(pt.features.toSparse)
+      } else {
+        (pt: LabeledPoint) => Seq(pt.features.toDense)
+      }
+    )
 
     baseRdd.map(pt => {
       Row.fromSeq(rowBuilders.map(_(pt)).reduceOption(_ ++ _).getOrElse(Seq.empty))
     })
   }
@@ -75,7 +78,8 @@ class DefaultSource extends RelationProvider with DataSourceRegister {
   override def shortName(): String = "libsvm"
 
   private def checkPath(parameters: Map[String, String]): String = {
-    parameters.getOrElse("path", sys.error("'path' must be specified"))
+    require(parameters.contains("path"), "'path' must be specified")
+    parameters.get("path")
   }
 
   /**
diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/package.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/package.scala
index 92c021e4b4e6..f15253c7657c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/package.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/package.scala
@@ -24,7 +24,7 @@ package object libsvm {
   /**
    * Implicit declaration in order to be used from SQLContext.
    * It is necessary to import org.apache.spark.ml.source.libsvm._
-   * @param read
+   * @param read Given original DataFrameReader
    */
   implicit class LibSVMReader(read: DataFrameReader) {
     def libsvm(filePath: String): DataFrame
diff --git a/mllib/src/test/java/org/apache/spark/ml/source/JavaLibSVMRelationSuite.java b/mllib/src/test/java/org/apache/spark/ml/source/JavaLibSVMRelationSuite.java
index 0464988f9980..e719027315fa 100644
--- a/mllib/src/test/java/org/apache/spark/ml/source/JavaLibSVMRelationSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/source/JavaLibSVMRelationSuite.java
@@ -1,13 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + package org.apache.spark.ml.source; import com.google.common.base.Charsets; import com.google.common.io.Files; + import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.linalg.Vectors; import org.apache.spark.sql.DataFrame; import org.apache.spark.sql.Row; import org.apache.spark.sql.SQLContext; import org.apache.spark.util.Utils; + import org.junit.After; import org.junit.Assert; import org.junit.Before; @@ -31,8 +50,8 @@ public void setUp() throws IOException { jsc = new JavaSparkContext("local", "JavaLibSVMRelationSuite"); jsql = new SQLContext(jsc); - path = Utils.createTempDir(System.getProperty("java.io.tmpdir"), - "datasource").getCanonicalFile(); + path = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "datasource") + .getCanonicalFile(); if (path.exists()) { path.delete(); } @@ -45,15 +64,16 @@ public void setUp() throws IOException { public void tearDown() { jsc.stop(); jsc = null; + path.delete(); } @Test - public void verifyLibSvmDF() { + public void verifyLibSVMDF() { dataset = jsql.read().format("libsvm").load(); - Assert.assertEquals(dataset.columns()[0], "label"); - Assert.assertEquals(dataset.columns()[1], "features"); + Assert.assertEquals("label", dataset.columns()[0]); + Assert.assertEquals("features", dataset.columns()[1]); Row r = dataset.first(); - Assert.assertTrue(r.getDouble(0) == 1.0); + Assert.assertEquals(Double.valueOf(r.getDouble(0)), Double.valueOf(1.0)); Assert.assertEquals(r.getAs(1), Vectors.dense(1.0, 0.0, 2.0, 0.0, 3.0, 0.0)); } } From a97ee97dc7fbc72c466ab089c661bf4e058bf67d Mon Sep 17 00:00:00 2001 From: lewuathe Date: Thu, 3 Sep 2015 21:35:50 +0900 Subject: [PATCH 06/13] Fix some points --- .../org/apache/spark/ml/source/libsvm/LibSVMRelation.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala index abddcc98ef88..e437a2a6a0df 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala @@ -51,9 +51,9 @@ private[ml] class LibSVMRelation(val path: String, val numFeatures: Int, val vec val rowBuilders = Array( (pt: LabeledPoint) => Seq(pt.label), if (vectorType == "dense") { - (pt: LabeledPoint) => Seq(pt.features.toSparse) - } else { (pt: LabeledPoint) => Seq(pt.features.toDense) + } else { + (pt: LabeledPoint) => Seq(pt.features.toSparse) } ) @@ -79,7 +79,7 @@ class DefaultSource extends RelationProvider with DataSourceRegister { private def checkPath(parameters: Map[String, String]): String = { require(parameters.contains("path"), "'path' must be specified") - parameters.get("path") + parameters.get("path").get } /** From 2c128940e9631c0b2147f2c076fef8f78a34e3ea Mon Sep 17 00:00:00 2001 From: lewuathe Date: Thu, 3 Sep 2015 22:04:16 +0900 Subject: [PATCH 07/13] Remove unnecessary tag --- .../scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala index cb32ad41d58e..e437a2a6a0df 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala @@ -44,7 +44,6 @@ private[ml] class LibSVMRelation(val path: String, val 
numFeatures: Int, val vec StructField("features", new VectorUDT(), nullable = false) :: Nil ) -<<<<<<< HEAD override def buildScan(): RDD[Row] = { val sc = sqlContext.sparkContext val baseRdd = MLUtils.loadLibSVMFile(sc, path, numFeatures) From 8660d0e2a815b367cc9f34251926e315bc95f9c1 Mon Sep 17 00:00:00 2001 From: lewuathe Date: Fri, 4 Sep 2015 21:36:05 +0900 Subject: [PATCH 08/13] Fix Java unit test --- .../org/apache/spark/ml/source/JavaLibSVMRelationSuite.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/test/java/org/apache/spark/ml/source/JavaLibSVMRelationSuite.java b/mllib/src/test/java/org/apache/spark/ml/source/JavaLibSVMRelationSuite.java index e719027315fa..d3d41b219c1f 100644 --- a/mllib/src/test/java/org/apache/spark/ml/source/JavaLibSVMRelationSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/source/JavaLibSVMRelationSuite.java @@ -69,7 +69,7 @@ public void tearDown() { @Test public void verifyLibSVMDF() { - dataset = jsql.read().format("libsvm").load(); + dataset = jsql.read().format("org.apache.spark.ml.source.libsvm").load(path.getPath()); Assert.assertEquals("label", dataset.columns()[0]); Assert.assertEquals("features", dataset.columns()[1]); Row r = dataset.first(); From 4f40891a96112776f96bd3ac7cef4bd674bbc947 Mon Sep 17 00:00:00 2001 From: lewuathe Date: Sun, 6 Sep 2015 23:33:06 +0900 Subject: [PATCH 09/13] Improve test suites --- .../ml/source/libsvm/LibSVMRelation.scala | 23 +++++--------- .../ml/source/JavaLibSVMRelationSuite.java | 31 ++++++++++--------- .../spark/ml/source/LibSVMRelationSuite.scala | 26 +++++++++++----- 3 files changed, 42 insertions(+), 38 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala index e437a2a6a0df..92114d56a026 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala @@ -21,7 +21,6 @@ import com.google.common.base.Objects import org.apache.spark.Logging import org.apache.spark.mllib.linalg.VectorUDT -import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD import org.apache.spark.sql.types.{StructType, StructField, DoubleType} @@ -37,7 +36,7 @@ import org.apache.spark.sql.sources._ */ private[ml] class LibSVMRelation(val path: String, val numFeatures: Int, val vectorType: String) (@transient val sqlContext: SQLContext) - extends BaseRelation with TableScan with Logging { + extends BaseRelation with TableScan with Logging with Serializable { override def schema: StructType = StructType( StructField("label", DoubleType, nullable = false) :: @@ -48,18 +47,10 @@ private[ml] class LibSVMRelation(val path: String, val numFeatures: Int, val vec val sc = sqlContext.sparkContext val baseRdd = MLUtils.loadLibSVMFile(sc, path, numFeatures) - val rowBuilders = Array( - (pt: LabeledPoint) => Seq(pt.label), - if (vectorType == "dense") { - (pt: LabeledPoint) => Seq(pt.features.toDense) - } else { - (pt: LabeledPoint) => Seq(pt.features.toSparse) - } - ) - - baseRdd.map(pt => { - Row.fromSeq(rowBuilders.map(_(pt)).reduceOption(_ ++ _).getOrElse(Seq.empty)) - }) + baseRdd.map { pt => + val features = if (vectorType == "dense") pt.features.toDense else pt.features.toSparse + Row(pt.label, features) + } } override def hashCode(): Int = { @@ -95,7 +86,7 @@ class DefaultSource extends RelationProvider with 
DataSourceRegister { * featuresType can be selected "dense" or "sparse". * This parameter decides the type of returned feature vector. */ - val featuresType = parameters.getOrElse("featuresType", "sparse") - new LibSVMRelation(path, numFeatures, featuresType)(sqlContext) + val vectorType = parameters.getOrElse("vectorType", "sparse") + new LibSVMRelation(path, numFeatures, vectorType)(sqlContext) } } diff --git a/mllib/src/test/java/org/apache/spark/ml/source/JavaLibSVMRelationSuite.java b/mllib/src/test/java/org/apache/spark/ml/source/JavaLibSVMRelationSuite.java index d3d41b219c1f..a00820d23773 100644 --- a/mllib/src/test/java/org/apache/spark/ml/source/JavaLibSVMRelationSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/source/JavaLibSVMRelationSuite.java @@ -17,9 +17,18 @@ package org.apache.spark.ml.source; +import java.io.File; +import java.io.IOException; + import com.google.common.base.Charsets; import com.google.common.io.Files; +import org.apache.spark.mllib.linalg.DenseVector; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.linalg.Vectors; import org.apache.spark.sql.DataFrame; @@ -27,13 +36,6 @@ import org.apache.spark.sql.SQLContext; import org.apache.spark.util.Utils; -import org.junit.After; -import org.junit.Assert; -import org.junit.Before; -import org.junit.Test; - -import java.io.File; -import java.io.IOException; /** * Test LibSVMRelation in Java. @@ -50,11 +52,8 @@ public void setUp() throws IOException { jsc = new JavaSparkContext("local", "JavaLibSVMRelationSuite"); jsql = new SQLContext(jsc); - path = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "datasource") - .getCanonicalFile(); - if (path.exists()) { - path.delete(); - } + File tmpDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "datasource"); + path = File.createTempFile("datasource", "libsvm-relation", tmpDir); String s = "1 1:1.0 3:2.0 5:3.0\n0\n0 2:4.0 4:5.0 6:6.0"; Files.write(s, path, Charsets.US_ASCII); @@ -69,11 +68,13 @@ public void tearDown() { @Test public void verifyLibSVMDF() { - dataset = jsql.read().format("org.apache.spark.ml.source.libsvm").load(path.getPath()); + dataset = jsql.read().format("org.apache.spark.ml.source.libsvm").option("vectorType", "dense") + .load(path.getPath()); Assert.assertEquals("label", dataset.columns()[0]); Assert.assertEquals("features", dataset.columns()[1]); Row r = dataset.first(); - Assert.assertEquals(Double.valueOf(r.getDouble(0)), Double.valueOf(1.0)); - Assert.assertEquals(r.getAs(1), Vectors.dense(1.0, 0.0, 2.0, 0.0, 3.0, 0.0)); + Assert.assertEquals(Double.valueOf(1.0), Double.valueOf(r.getDouble(0))); + DenseVector v = r.getAs(1); + Assert.assertEquals(Vectors.dense(1.0, 0.0, 2.0, 0.0, 3.0, 0.0), v); } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/source/LibSVMRelationSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/source/LibSVMRelationSuite.scala index 960ab8575fa5..8fa51f1d521b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/source/LibSVMRelationSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/source/LibSVMRelationSuite.scala @@ -45,28 +45,40 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext { } test("select as sparse vector") { - val df = sqlContext.read.options(Map("numFeatures" -> "6")).libsvm(path) + val df = sqlContext.read.libsvm(path) assert(df.columns(0) == "label") assert(df.columns(1) == "features") val row1 = 
df.first() assert(row1.getDouble(0) == 1.0) - assert(row1.getAs[SparseVector](1) == Vectors.sparse(6, Seq((0, 1.0), (2, 2.0), (4, 3.0)))) + val v = row1.getAs[SparseVector](1) + assert(v == Vectors.sparse(6, Seq((0, 1.0), (2, 2.0), (4, 3.0)))) } test("select as dense vector") { - val df = sqlContext.read.options(Map("numFeatures" -> "6", "featuresType" -> "dense")) + val df = sqlContext.read.options(Map("vectorType" -> "dense")) .libsvm(path) assert(df.columns(0) == "label") assert(df.columns(1) == "features") assert(df.count() == 3) val row1 = df.first() assert(row1.getDouble(0) == 1.0) - assert(row1.getAs[DenseVector](1) == Vectors.dense(1.0, 0.0, 2.0, 0.0, 3.0, 0.0)) + val v = row1.getAs[DenseVector](1) + assert(v == Vectors.dense(1.0, 0.0, 2.0, 0.0, 3.0, 0.0)) } - test("select without any option") { - val df = sqlContext.read.libsvm(path) + test("select long vector with specifying the number of features") { + val lines = + """ + |1 1:1 10:2 20:3 30:4 40:5 50:6 60:7 70:8 80:9 90:10 100:1 + |0 1:1 10:10 20:9 30:8 40:7 50:6 60:5 70:4 80:3 90:2 100:1 + """.stripMargin + val tempDir = Utils.createTempDir() + val file = new File(tempDir.getPath, "part-00001") + Files.write(lines, file, Charsets.US_ASCII) + val df = sqlContext.read.option("numFeatures", "100").libsvm(tempDir.toURI.toString) val row1 = df.first() - assert(row1.getAs[SparseVector](1) == Vectors.sparse(6, Seq((0, 1.0), (2, 2.0), (4, 3.0)))) + val v = row1.getAs[SparseVector](1) + assert(v == Vectors.sparse(100, Seq((0, 1.0), (9, 2.0), (19, 3.0), (29, 4.0), (39, 5.0), + (49, 6.0), (59, 7.0), (69, 8.0), (79, 9.0), (89, 10.0), (99, 1.0)))) } } From 0ea1c1c5f4127c66015b5831943460456f102bdd Mon Sep 17 00:00:00 2001 From: lewuathe Date: Tue, 8 Sep 2015 00:22:17 +0900 Subject: [PATCH 10/13] LibSVMRelation is registered into META-INF --- .../ml/source/libsvm/LibSVMRelation.scala | 4 +++ .../spark/ml/source/libsvm/package.scala | 33 ------------------- .../ml/source/JavaLibSVMRelationSuite.java | 3 +- .../spark/ml/source/LibSVMRelationSuite.scala | 9 ++--- ...pache.spark.sql.sources.DataSourceRegister | 1 + 5 files changed, 11 insertions(+), 39 deletions(-) delete mode 100644 mllib/src/main/scala/org/apache/spark/ml/source/libsvm/package.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala index 92114d56a026..b0fc6603ced1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala @@ -64,6 +64,10 @@ private[ml] class LibSVMRelation(val path: String, val numFeatures: Int, val vec } +/** + * This is used for creating DataFrame from LibSVM format file. + * The LibSVM file path must be specified to DefaultSource. + */ class DefaultSource extends RelationProvider with DataSourceRegister { override def shortName(): String = "libsvm" diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/package.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/package.scala deleted file mode 100644 index f15253c7657c..000000000000 --- a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/package.scala +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. 
- * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.ml.source - -import org.apache.spark.sql.{DataFrame, DataFrameReader} - -package object libsvm { - - /** - * Implicit declaration in order to be used from SQLContext. - * It is necessary to import org.apache.spark.ml.source.libsvm._ - * @param read Given original DataFrameReader - */ - implicit class LibSVMReader(read: DataFrameReader) { - def libsvm(filePath: String): DataFrame - = read.format(classOf[DefaultSource].getName).load(filePath) - } -} diff --git a/mllib/src/test/java/org/apache/spark/ml/source/JavaLibSVMRelationSuite.java b/mllib/src/test/java/org/apache/spark/ml/source/JavaLibSVMRelationSuite.java index a00820d23773..5ccbafb640cf 100644 --- a/mllib/src/test/java/org/apache/spark/ml/source/JavaLibSVMRelationSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/source/JavaLibSVMRelationSuite.java @@ -68,8 +68,7 @@ public void tearDown() { @Test public void verifyLibSVMDF() { - dataset = jsql.read().format("org.apache.spark.ml.source.libsvm").option("vectorType", "dense") - .load(path.getPath()); + dataset = jsql.read().format("libsvm").option("vectorType", "dense").load(path.getPath()); Assert.assertEquals("label", dataset.columns()[0]); Assert.assertEquals("features", dataset.columns()[1]); Row r = dataset.first(); diff --git a/mllib/src/test/scala/org/apache/spark/ml/source/LibSVMRelationSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/source/LibSVMRelationSuite.scala index 8fa51f1d521b..dc9980d526e9 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/source/LibSVMRelationSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/source/LibSVMRelationSuite.scala @@ -45,7 +45,7 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext { } test("select as sparse vector") { - val df = sqlContext.read.libsvm(path) + val df = sqlContext.read.format("libsvm").load(path) assert(df.columns(0) == "label") assert(df.columns(1) == "features") val row1 = df.first() @@ -55,8 +55,8 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext { } test("select as dense vector") { - val df = sqlContext.read.options(Map("vectorType" -> "dense")) - .libsvm(path) + val df = sqlContext.read.format("libsvm").options(Map("vectorType" -> "dense")) + .load(path) assert(df.columns(0) == "label") assert(df.columns(1) == "features") assert(df.count() == 3) @@ -75,7 +75,8 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext { val tempDir = Utils.createTempDir() val file = new File(tempDir.getPath, "part-00001") Files.write(lines, file, Charsets.US_ASCII) - val df = sqlContext.read.option("numFeatures", "100").libsvm(tempDir.toURI.toString) + val df = sqlContext.read.option("numFeatures", "100").format("libsvm") + .load(tempDir.toURI.toString) val row1 = df.first() val v = row1.getAs[SparseVector](1) assert(v == Vectors.sparse(100, Seq((0, 1.0), (9, 2.0), (19, 3.0), (29, 4.0), (39, 5.0), diff --git 
a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister index ca50000b4756..55bebf96dabb 100644 --- a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister +++ b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -1,3 +1,4 @@ org.apache.spark.sql.execution.datasources.jdbc.DefaultSource org.apache.spark.sql.execution.datasources.json.DefaultSource org.apache.spark.sql.execution.datasources.parquet.DefaultSource +org.apache.spark.ml.source.libsvm.DefaultSource From 9ce63c737a31be993edf542dc29b9a2f37e2b067 Mon Sep 17 00:00:00 2001 From: lewuathe Date: Tue, 8 Sep 2015 19:37:53 +0900 Subject: [PATCH 11/13] Rewrite service loader file --- .../services/org.apache.spark.sql.sources.DataSourceRegister | 1 + .../services/org.apache.spark.sql.sources.DataSourceRegister | 1 - 2 files changed, 1 insertion(+), 1 deletion(-) create mode 100644 mllib/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister diff --git a/mllib/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/mllib/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister new file mode 100644 index 000000000000..f632dd603c44 --- /dev/null +++ b/mllib/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -0,0 +1 @@ +org.apache.spark.ml.source.libsvm.DefaultSource diff --git a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister index 55bebf96dabb..ca50000b4756 100644 --- a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister +++ b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -1,4 +1,3 @@ org.apache.spark.sql.execution.datasources.jdbc.DefaultSource org.apache.spark.sql.execution.datasources.json.DefaultSource org.apache.spark.sql.execution.datasources.parquet.DefaultSource -org.apache.spark.ml.source.libsvm.DefaultSource From 11d513f32046af0f40bc7e73df1d437d69ef9e35 Mon Sep 17 00:00:00 2001 From: lewuathe Date: Thu, 10 Sep 2015 00:18:11 +0900 Subject: [PATCH 12/13] Fix some reviews --- .../spark/ml/source/libsvm/LibSVMRelation.scala | 3 +++ .../spark/ml/source/JavaLibSVMRelationSuite.java | 11 ++++++----- .../spark/ml/source/LibSVMRelationSuite.scala | 15 +++------------ 3 files changed, 12 insertions(+), 17 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala index b0fc6603ced1..b12cb62a4ef1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala @@ -20,6 +20,7 @@ package org.apache.spark.ml.source.libsvm import com.google.common.base.Objects import org.apache.spark.Logging +import org.apache.spark.annotation.Since import org.apache.spark.mllib.linalg.VectorUDT import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD @@ -68,8 +69,10 @@ private[ml] class LibSVMRelation(val path: String, val numFeatures: Int, val vec * This is used for creating DataFrame from LibSVM format file. 
* The LibSVM file path must be specified to DefaultSource. */ +@Since("1.6.0") class DefaultSource extends RelationProvider with DataSourceRegister { + @Since("1.6.0") override def shortName(): String = "libsvm" private def checkPath(parameters: Map[String, String]): String = { diff --git a/mllib/src/test/java/org/apache/spark/ml/source/JavaLibSVMRelationSuite.java b/mllib/src/test/java/org/apache/spark/ml/source/JavaLibSVMRelationSuite.java index 5ccbafb640cf..11fa4eec0ccf 100644 --- a/mllib/src/test/java/org/apache/spark/ml/source/JavaLibSVMRelationSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/source/JavaLibSVMRelationSuite.java @@ -23,13 +23,13 @@ import com.google.common.base.Charsets; import com.google.common.io.Files; -import org.apache.spark.mllib.linalg.DenseVector; import org.junit.After; import org.junit.Assert; import org.junit.Before; import org.junit.Test; import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.linalg.DenseVector; import org.apache.spark.mllib.linalg.Vectors; import org.apache.spark.sql.DataFrame; import org.apache.spark.sql.Row; @@ -45,6 +45,7 @@ public class JavaLibSVMRelationSuite { private transient SQLContext jsql; private transient DataFrame dataset; + private File tmpDir; private File path; @Before @@ -52,8 +53,8 @@ public void setUp() throws IOException { jsc = new JavaSparkContext("local", "JavaLibSVMRelationSuite"); jsql = new SQLContext(jsc); - File tmpDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "datasource"); - path = File.createTempFile("datasource", "libsvm-relation", tmpDir); + tmpDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "datasource"); + path = new File(tmpDir.getPath(), "part-00000"); String s = "1 1:1.0 3:2.0 5:3.0\n0\n0 2:4.0 4:5.0 6:6.0"; Files.write(s, path, Charsets.US_ASCII); @@ -63,7 +64,7 @@ public void setUp() throws IOException { public void tearDown() { jsc.stop(); jsc = null; - path.delete(); + Utils.deleteRecursively(tmpDir); } @Test @@ -72,7 +73,7 @@ public void verifyLibSVMDF() { Assert.assertEquals("label", dataset.columns()[0]); Assert.assertEquals("features", dataset.columns()[1]); Row r = dataset.first(); - Assert.assertEquals(Double.valueOf(1.0), Double.valueOf(r.getDouble(0))); + Assert.assertEquals(1.0, r.getDouble(0), 1e-15); DenseVector v = r.getAs(1); Assert.assertEquals(Vectors.dense(1.0, 0.0, 2.0, 0.0, 3.0, 0.0), v); } diff --git a/mllib/src/test/scala/org/apache/spark/ml/source/LibSVMRelationSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/source/LibSVMRelationSuite.scala index dc9980d526e9..0860220cc8e4 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/source/LibSVMRelationSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/source/LibSVMRelationSuite.scala @@ -21,8 +21,8 @@ import java.io.File import com.google.common.base.Charsets import com.google.common.io.Files + import org.apache.spark.SparkFunSuite -import org.apache.spark.ml.source.libsvm._ import org.apache.spark.mllib.linalg.{SparseVector, Vectors, DenseVector} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.util.Utils @@ -67,19 +67,10 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext { } test("select long vector with specifying the number of features") { - val lines = - """ - |1 1:1 10:2 20:3 30:4 40:5 50:6 60:7 70:8 80:9 90:10 100:1 - |0 1:1 10:10 20:9 30:8 40:7 50:6 60:5 70:4 80:3 90:2 100:1 - """.stripMargin - val tempDir = Utils.createTempDir() - val file = new File(tempDir.getPath, "part-00001") 
- Files.write(lines, file, Charsets.US_ASCII) val df = sqlContext.read.option("numFeatures", "100").format("libsvm") - .load(tempDir.toURI.toString) + .load(path) val row1 = df.first() val v = row1.getAs[SparseVector](1) - assert(v == Vectors.sparse(100, Seq((0, 1.0), (9, 2.0), (19, 3.0), (29, 4.0), (39, 5.0), - (49, 6.0), (59, 7.0), (69, 8.0), (79, 9.0), (89, 10.0), (99, 1.0)))) + assert(v == Vectors.sparse(100, Seq((0, 1.0), (2, 2.0), (4, 3.0)))) } } From 986999d9d878ff2e52e506a10ebc0abe715f6871 Mon Sep 17 00:00:00 2001 From: lewuathe Date: Thu, 10 Sep 2015 00:20:25 +0900 Subject: [PATCH 13/13] Change unit test phrase --- .../scala/org/apache/spark/ml/source/LibSVMRelationSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/source/LibSVMRelationSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/source/LibSVMRelationSuite.scala index 0860220cc8e4..8ed134128c8d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/source/LibSVMRelationSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/source/LibSVMRelationSuite.scala @@ -66,7 +66,7 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext { assert(v == Vectors.dense(1.0, 0.0, 2.0, 0.0, 3.0, 0.0)) } - test("select long vector with specifying the number of features") { + test("select a vector with specifying the longer dimension") { val df = sqlContext.read.option("numFeatures", "100").format("libsvm") .load(path) val row1 = df.first()
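
End-state usage sketch: once PATCH 10 and PATCH 11 register DefaultSource in the
mllib META-INF services file, the source is addressable by its short name "libsvm"
with no MLlib-specific import. A minimal sketch against the post-series API,
assuming a Spark 1.6-era SQLContext (the input path and feature count below are
placeholders, not values from the patches):

    val df = sqlContext.read
      .format("libsvm")
      .option("numFeatures", "100")   // optional; inferred from the data when omitted
      .option("vectorType", "dense")  // defaults to "sparse"
      .load("/path/to/data.libsvm")   // placeholder path

    // Schema is [label: double, features: vector], per LibSVMRelation.schema.
    df.printSchema()
    df.select("label", "features").show(5)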