@@ -19,26 +19,29 @@ package org.apache.spark.ml.source.libsvm
1919
2020import com .google .common .base .Objects
2121import org .apache .spark .Logging
22- import org .apache .spark .mllib .linalg .Vector
22+ import org .apache .spark .annotation .Since
23+ import org .apache .spark .mllib .linalg .{VectorUDT , Vector }
2324import org .apache .spark .mllib .regression .LabeledPoint
2425import org .apache .spark .mllib .util .MLUtils
2526import org .apache .spark .rdd .RDD
2627import org .apache .spark .sql .types ._
2728import org .apache .spark .sql .{Row , SQLContext }
2829import org .apache .spark .sql .sources .{DataSourceRegister , PrunedScan , BaseRelation , RelationProvider }
2930
30-
31- class LibSVMRelation (val path : String , val numFeatures : Int , val featuresType : String )
31+ /**
32+ * LibSVMRelation provides the DataFrame constructed from LibSVM format data.
33+ * @param path
34+ * @param numFeatures
35+ * @param vectorType
36+ * @param sqlContext
37+ */
38+ private [ml] class LibSVMRelation (val path : String , val numFeatures : Int , val vectorType : String )
3239 (@ transient val sqlContext : SQLContext )
3340 extends BaseRelation with PrunedScan with Logging {
3441
35- private final val vectorType : DataType
36- = classOf [Vector ].getAnnotation(classOf [SQLUserDefinedType ]).udt().newInstance()
37-
38-
3942 override def schema : StructType = StructType (
4043 StructField (" label" , DoubleType , nullable = false ) ::
41- StructField (" features" , vectorType , nullable = false ) :: Nil
44+ StructField (" features" , new VectorUDT () , nullable = false ) :: Nil
4245 )
4346
4447 override def buildScan (requiredColumns : Array [String ]): RDD [Row ] = {
@@ -47,8 +50,8 @@ class LibSVMRelation(val path: String, val numFeatures: Int, val featuresType: S
4750
4851 val rowBuilders = requiredColumns.map {
4952 case " label" => (pt : LabeledPoint ) => Seq (pt.label)
50- case " features" if featuresType == " sparse" => (pt : LabeledPoint ) => Seq (pt.features.toSparse)
51- case " features" if featuresType == " dense" => (pt : LabeledPoint ) => Seq (pt.features.toDense)
53+ case " features" if vectorType == " sparse" => (pt : LabeledPoint ) => Seq (pt.features.toSparse)
54+ case " features" if vectorType == " dense" => (pt : LabeledPoint ) => Seq (pt.features.toDense)
5255 }
5356
5457 baseRdd.map(pt => {
@@ -69,16 +72,6 @@ class LibSVMRelation(val path: String, val numFeatures: Int, val featuresType: S
6972
7073class DefaultSource extends RelationProvider with DataSourceRegister {
7174
72- /**
73- * The string that represents the format that this data source provider uses. This is
74- * overridden by children to provide a nice alias for the data source. For example:
75- *
76- * {{{
77- * override def format(): String = "parquet"
78- * }}}
79- *
80- * @since 1.5.0
81- */
8275 override def shortName (): String = " libsvm"
8376
8477 private def checkPath (parameters : Map [String , String ]): String = {
@@ -90,8 +83,8 @@ class DefaultSource extends RelationProvider with DataSourceRegister {
9083 * Note: the parameters' keywords are case insensitive and this insensitivity is enforced
9184 * by the Map that is passed to the function.
9285 */
93- override def createRelation (sqlContext : SQLContext , parameters : Map [String , String ]):
94- BaseRelation = {
86+ override def createRelation (sqlContext : SQLContext , parameters : Map [String , String ])
87+ : BaseRelation = {
9588 val path = checkPath(parameters)
9689 val numFeatures = parameters.getOrElse(" numFeatures" , " -1" ).toInt
9790 /**
0 commit comments