@@ -27,11 +27,9 @@ import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, Vector => BV}
 
 import org.apache.spark.SparkException
 import org.apache.spark.mllib.util.NumericParser
-import org.apache.spark.sql.catalyst.UDTRegistry
 import org.apache.spark.sql.catalyst.annotation.SQLUserDefinedType
 import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
 import org.apache.spark.sql.catalyst.types._
-import org.apache.spark.sql.Row
 
 /**
  * Represents a numeric vector, whose index type is Int and value type is Double.
@@ -86,12 +84,6 @@ sealed trait Vector extends Serializable {
  */
 object Vectors {
 
-  // Note: Explicit registration is only needed for Vector and SparseVector;
-  // the annotation works for DenseVector.
-  UDTRegistry.registerType(scala.reflect.runtime.universe.typeOf[Vector], new VectorUDT())
-  UDTRegistry.registerType(scala.reflect.runtime.universe.typeOf[SparseVector],
-    new SparseVectorUDT())
-
   /**
    * Creates a dense vector from its values.
    */
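
For context: the block deleted above was the runtime-registration path for UDTs. Only the concrete DenseVector could carry the @SQLUserDefinedType annotation, so the Vector trait and SparseVector were wired into SQL's type system with explicit UDTRegistry calls instead. A minimal sketch of that registry idea, using a hypothetical ToyUDTRegistry (not Spark's API) so it compiles standalone:

import scala.reflect.runtime.universe.{typeOf, Type}

// Hypothetical stand-in for the removed UDTRegistry: types that cannot be
// annotated (traits, third-party classes) are mapped to their UDT at runtime.
object ToyUDTRegistry {
  private val udts = scala.collection.mutable.LinkedHashMap.empty[Type, AnyRef]

  def registerType(tpe: Type, udt: AnyRef): Unit = udts(tpe) = udt

  // Look up the UDT for a type, honoring subtyping so a registration made
  // for a trait also covers its concrete subclasses.
  def udtFor(tpe: Type): Option[AnyRef] =
    udts.collectFirst { case (t, u) if tpe <:< t => u }
}

// Usage, mirroring the removed calls:
//   ToyUDTRegistry.registerType(typeOf[Vector], new VectorUDT())
//   ToyUDTRegistry.registerType(typeOf[SparseVector], new SparseVectorUDT())
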
@@ -202,7 +194,6 @@ object Vectors {
 /**
  * A dense vector represented by a value array.
  */
-@SQLUserDefinedType(udt = classOf[DenseVectorUDT])
 class DenseVector(val values: Array[Double]) extends Vector {
 
   override def size: Int = values.length
@@ -254,105 +245,3 @@ class SparseVector(
 
   private[mllib] override def toBreeze: BV[Double] = new BSV[Double](indices, values, size)
 }
-
-/**
- * User-defined type for [[Vector]] which allows easy interaction with SQL
- * via [[org.apache.spark.sql.SchemaRDD]].
- */
-private[spark] class VectorUDT extends UserDefinedType[Vector] {
-
-  /**
-   * vectorType: 0 = dense, 1 = sparse.
-   * dense, sparse: One element holds the vector, and the other is null.
-   */
-  override def sqlType: StructType = StructType(Seq(
-    StructField("vectorType", ByteType, nullable = false),
-    StructField("dense", new DenseVectorUDT, nullable = true),
-    StructField("sparse", new SparseVectorUDT, nullable = true)))
-
-  override def serialize(obj: Any): Row = {
-    val row = new GenericMutableRow(3)
-    obj match {
-      case v: DenseVector =>
-        row.setByte(0, 0)
-        row.update(1, new DenseVectorUDT().serialize(obj))
-        row.setNullAt(2)
-      case v: SparseVector =>
-        row.setByte(0, 1)
-        row.setNullAt(1)
-        row.update(2, new SparseVectorUDT().serialize(obj))
-    }
-    row
-  }
-
-  override def deserialize(datum: Any): Vector = {
-    datum match {
-      case row: Row =>
-        require(row.length == 3,
-          s"VectorUDT.deserialize given row with length ${row.length} but requires length == 3")
-        val vectorType = row.getByte(0)
-        vectorType match {
-          case 0 =>
-            new DenseVectorUDT().deserialize(row.getAs[Row](1))
-          case 1 =>
-            new SparseVectorUDT().deserialize(row.getAs[Row](2))
-        }
-    }
-  }
-}
-
-/**
- * User-defined type for [[DenseVector]] which allows easy interaction with SQL
- * via [[org.apache.spark.sql.SchemaRDD]].
- */
-private[spark] class DenseVectorUDT extends UserDefinedType[DenseVector] {
-
-  override def sqlType: DataType = ArrayType(DoubleType, containsNull = false)
-
-  override def serialize(obj: Any): Seq[Double] = {
-    obj match {
-      case v: DenseVector =>
-        v.values.toSeq
-    }
-  }
-
-  override def deserialize(datum: Any): DenseVector = {
-    datum match {
-      case values: Seq[_] =>
-        new DenseVector(values.asInstanceOf[Seq[Double]].toArray)
-    }
-  }
-}
-
-/**
- * User-defined type for [[SparseVector]] which allows easy interaction with SQL
- * via [[org.apache.spark.sql.SchemaRDD]].
- */
-private[spark] class SparseVectorUDT extends UserDefinedType[SparseVector] {
-
-  override def sqlType: StructType = StructType(Seq(
-    StructField("size", IntegerType, nullable = false),
-    StructField("indices", ArrayType(IntegerType, containsNull = false), nullable = false),
-    StructField("values", ArrayType(DoubleType, containsNull = false), nullable = false)))
-
-  override def serialize(obj: Any): Row = obj match {
-    case v: SparseVector =>
-      val row: GenericMutableRow = new GenericMutableRow(3)
-      row.setInt(0, v.size)
-      row.update(1, v.indices.toSeq)
-      row.update(2, v.values.toSeq)
-      row
-  }
-
-  override def deserialize(datum: Any): SparseVector = {
-    datum match {
-      case row: Row =>
-        require(row.length == 3,
-          s"SparseVectorUDT.deserialize given row with length ${row.length} but expect 3.")
-        val vSize = row.getInt(0)
-        val indices = row.getAs[Seq[Int]](1).toArray
-        val values = row.getAs[Seq[Double]](2).toArray
-        new SparseVector(vSize, indices, values)
-    }
-  }
-}
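
The removed VectorUDT is essentially a tagged union: a one-byte discriminator plus a dense slot and a sparse slot, exactly one of which is non-null. A dependency-free sketch of that encoding (illustrative names, with Option standing in for SQL's nullable columns; this is not Spark's API):

sealed trait Vec
final case class Dense(values: Array[Double]) extends Vec
final case class Sparse(size: Int, indices: Array[Int], values: Array[Double]) extends Vec

// Serialized shape mirrors the removed sqlType:
// (vectorType: Byte, dense: nullable, sparse: nullable).
final case class VecRow(
    vectorType: Byte,
    dense: Option[Seq[Double]],
    sparse: Option[(Int, Seq[Int], Seq[Double])])

object VecCodec {
  def serialize(v: Vec): VecRow = v match {
    case Dense(values) =>
      VecRow(0, Some(values.toSeq), None)                          // tag 0 = dense
    case Sparse(size, indices, values) =>
      VecRow(1, None, Some((size, indices.toSeq, values.toSeq)))   // tag 1 = sparse
  }

  def deserialize(row: VecRow): Vec = row match {
    case VecRow(0, Some(values), None) =>
      Dense(values.toArray)
    case VecRow(1, None, Some((size, indices, values))) =>
      Sparse(size, indices.toArray, values.toArray)
    case other =>
      throw new IllegalArgumentException(s"Malformed vector row: $other")
  }
}

Keeping the tag in its own non-null column means a scan can dispatch on vector kind without touching either payload column, which is the same reason the removed sqlType put vectorType first.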