Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
remove element type parameter from getArray
  • Loading branch information
cloud-fan committed Jul 30, 2015
commit 8cb8842e3c9fbdd4321ec14e5f53d578a30812bf
Original file line number Diff line number Diff line change
Expand Up @@ -154,9 +154,9 @@ private[spark] class MatrixUDT extends UserDefinedType[Matrix] {
row.setByte(0, 0)
row.setInt(1, sm.numRows)
row.setInt(2, sm.numCols)
row.update(3, sm.colPtrs.toSeq)
row.update(4, sm.rowIndices.toSeq)
row.update(5, sm.values.toSeq)
row.update(3, new GenericArrayData(sm.colPtrs.map(_.asInstanceOf[Any])))
row.update(4, new GenericArrayData(sm.rowIndices.map(_.asInstanceOf[Any])))
row.update(5, new GenericArrayData(sm.values.map(_.asInstanceOf[Any])))
row.setBoolean(6, sm.isTransposed)

case dm: DenseMatrix =>
Expand All @@ -165,7 +165,7 @@ private[spark] class MatrixUDT extends UserDefinedType[Matrix] {
row.setInt(2, dm.numCols)
row.setNullAt(3)
row.setNullAt(4)
row.update(5, dm.values.toSeq)
row.update(5, new GenericArrayData(dm.values.map(_.asInstanceOf[Any])))
row.setBoolean(6, dm.isTransposed)
}
row
Expand All @@ -179,14 +179,12 @@ private[spark] class MatrixUDT extends UserDefinedType[Matrix] {
val tpe = row.getByte(0)
val numRows = row.getInt(1)
val numCols = row.getInt(2)
val values = row.getAs[Seq[Double]](5, ArrayType(DoubleType, containsNull = false)).toArray
val values = row.getArray(5).toArray.map(_.asInstanceOf[Double])
val isTransposed = row.getBoolean(6)
tpe match {
case 0 =>
val colPtrs =
row.getAs[Seq[Int]](3, ArrayType(IntegerType, containsNull = false)).toArray
val rowIndices =
row.getAs[Seq[Int]](4, ArrayType(IntegerType, containsNull = false)).toArray
val colPtrs = row.getArray(3).toArray.map(_.asInstanceOf[Int])
val rowIndices = row.getArray(4).toArray.map(_.asInstanceOf[Int])
new SparseMatrix(numRows, numCols, colPtrs, rowIndices, values, isTransposed)
case 1 =>
new DenseMatrix(numRows, numCols, values, isTransposed)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -209,11 +209,11 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] {
tpe match {
case 0 =>
val size = row.getInt(1)
val indices = row.getArray(2, IntegerType).toArray().map(_.asInstanceOf[Int])
val values = row.getArray(3, DoubleType).toArray().map(_.asInstanceOf[Double])
val indices = row.getArray(2).toArray().map(_.asInstanceOf[Int])
val values = row.getArray(3).toArray().map(_.asInstanceOf[Double])
new SparseVector(size, indices, values)
case 1 =>
val values = row.getArray(3, DoubleType).toArray().map(_.asInstanceOf[Double])
val values = row.getArray(3).toArray().map(_.asInstanceOf[Double])
new DenseVector(values)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,5 +52,5 @@ public interface SpecializedGetters {

InternalRow getStruct(int ordinal, int numFields);

ArrayData getArray(int ordinal, DataType elementType);
ArrayData getArray(int ordinal);
}
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ object CatalystTypeConverters {
}

override def toScalaImpl(row: InternalRow, column: Int): Seq[Any] =
toScala(row.getArray(column, elementType))
toScala(row.getArray(column))
}

private case class MapConverter(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ abstract class InternalRow extends Serializable with SpecializedGetters {
override def getStruct(ordinal: Int, numFields: Int): InternalRow =
getAs[InternalRow](ordinal, null)

override def getArray(ordinal: Int, elementType: DataType): ArrayData = getAs(ordinal, null)
override def getArray(ordinal: Int): ArrayData = getAs(ordinal, null)

override def toString: String = s"[${this.mkString(",")}]"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,7 @@ class CodeGenContext {
case BinaryType => s"$getter.getBinary($ordinal)"
case CalendarIntervalType => s"$getter.getInterval($ordinal)"
case t: StructType => s"$getter.getStruct($ordinal, ${t.size})"
case a: ArrayType =>
val typeString = '"' + a.elementType.json.replace("\"", "\\\"") + '"'
s"$getter.getArray($ordinal, org.apache.spark.sql.types.DataType.fromJson($typeString))"
case a: ArrayType => s"$getter.getArray($ordinal)"
case _ => s"($jt)$getter.get($ordinal)" // todo: remove generic getter.
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ class GenericArrayData(array: Array[Any]) extends ArrayData {

override def getStruct(ordinal: Int, numFields: Int): InternalRow = getAs(ordinal)

override def getArray(ordinal: Int, elementType: DataType): ArrayData = getAs(ordinal)
override def getArray(ordinal: Int): ArrayData = getAs(ordinal)

override def numElements(): Int = array.length
}