From 55d31788836fe94ba38eec11e99881c0c8562a15 Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Sat, 23 Apr 2016 23:10:47 +0800
Subject: [PATCH 1/7] specialize array data

---
 .../apache/spark/mllib/linalg/Matrices.scala | 10 ++--
 .../apache/spark/mllib/linalg/Vectors.scala  |  8 +--
 .../sql/catalyst/util/GenericArrayData.scala | 58 +++++++++++++++++++
 3 files changed, 67 insertions(+), 9 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
index 90fa4fbbc604..3d91c8322387 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
@@ -28,7 +28,7 @@ import org.apache.spark.annotation.Since
 import org.apache.spark.ml.{linalg => newlinalg}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
-import org.apache.spark.sql.catalyst.util.GenericArrayData
+import org.apache.spark.sql.catalyst.util.{DoubleArrayData, IntArrayData}
 import org.apache.spark.sql.types._
 
 /**
@@ -194,9 +194,9 @@ private[spark] class MatrixUDT extends UserDefinedType[Matrix] {
         row.setByte(0, 0)
         row.setInt(1, sm.numRows)
         row.setInt(2, sm.numCols)
-        row.update(3, new GenericArrayData(sm.colPtrs.map(_.asInstanceOf[Any])))
-        row.update(4, new GenericArrayData(sm.rowIndices.map(_.asInstanceOf[Any])))
-        row.update(5, new GenericArrayData(sm.values.map(_.asInstanceOf[Any])))
+        row.update(3, new IntArrayData(sm.colPtrs))
+        row.update(4, new IntArrayData(sm.rowIndices))
+        row.update(5, new DoubleArrayData(sm.values))
         row.setBoolean(6, sm.isTransposed)
 
       case dm: DenseMatrix =>
@@ -205,7 +205,7 @@ private[spark] class MatrixUDT extends UserDefinedType[Matrix] {
         row.setInt(2, dm.numCols)
         row.setNullAt(3)
         row.setNullAt(4)
-        row.update(5, new GenericArrayData(dm.values.map(_.asInstanceOf[Any])))
+        row.update(5, new DoubleArrayData(dm.values))
         row.setBoolean(6, dm.isTransposed)
     }
     row
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
index 6e3da6b701cb..398146b5f040 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
@@ -34,7 +34,7 @@ import org.apache.spark.ml.{linalg => newlinalg}
 import org.apache.spark.mllib.util.NumericParser
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
-import org.apache.spark.sql.catalyst.util.GenericArrayData
+import org.apache.spark.sql.catalyst.util.{DoubleArrayData, IntArrayData}
 import org.apache.spark.sql.types._
 
 /**
@@ -216,15 +216,15 @@ class VectorUDT extends UserDefinedType[Vector] {
         val row = new GenericMutableRow(4)
         row.setByte(0, 0)
         row.setInt(1, size)
-        row.update(2, new GenericArrayData(indices.map(_.asInstanceOf[Any])))
-        row.update(3, new GenericArrayData(values.map(_.asInstanceOf[Any])))
+        row.update(2, new IntArrayData(indices))
+        row.update(3, new DoubleArrayData(values))
         row
       case DenseVector(values) =>
         val row = new GenericMutableRow(4)
         row.setByte(0, 1)
         row.setNullAt(1)
         row.setNullAt(2)
-        row.update(3, new GenericArrayData(values.map(_.asInstanceOf[Any])))
+        row.update(3, new DoubleArrayData(values))
         row
     }
   }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala
index 2b8cdc1e23ab..eaa596d6f9a2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala
@@ -137,3 +137,61 @@ class GenericArrayData(val array: Array[Any]) extends ArrayData {
     result
   }
 }
+
+abstract class SpecializedArrayData extends ArrayData {
+  // Primitive arrays can't have null elements.
+  override def isNullAt(ordinal: Int): Boolean = false
+
+  private def fail() = {
+    throw new UnsupportedOperationException(
+      "Specialized array data should implement its corresponding get method")
+  }
+
+  override def get(ordinal: Int, elementType: DataType): AnyRef = fail()
+  override def getBoolean(ordinal: Int): Boolean = fail()
+  override def getByte(ordinal: Int): Byte = fail()
+  override def getShort(ordinal: Int): Short = fail()
+  override def getInt(ordinal: Int): Int = fail()
+  override def getLong(ordinal: Int): Long = fail()
+  override def getFloat(ordinal: Int): Float = fail()
+  override def getDouble(ordinal: Int): Double = fail()
+  override def getDecimal(ordinal: Int, precision: Int, scale: Int): Decimal = fail()
+  override def getUTF8String(ordinal: Int): UTF8String = fail()
+  override def getBinary(ordinal: Int): Array[Byte] = fail()
+  override def getInterval(ordinal: Int): CalendarInterval = fail()
+  override def getStruct(ordinal: Int, numFields: Int): InternalRow = fail()
+  override def getArray(ordinal: Int): ArrayData = fail()
+  override def getMap(ordinal: Int): MapData = fail()
+}
+
+class IntArrayData(val values: Array[Int]) extends SpecializedArrayData {
+
+  override def array(): Array[Any] = values.map(_.asInstanceOf[Any])
+
+  override def numElements(): Int = values.length
+
+  override def get(ordinal: Int, elementType: DataType): AnyRef =
+    values(ordinal).asInstanceOf[AnyRef]
+
+  override def getInt(ordinal: Int): Int = values(ordinal)
+
+  override def toIntArray(): Array[Int] = values
+
+  override def copy(): IntArrayData = new IntArrayData(values.clone())
+}
+
+class DoubleArrayData(val values: Array[Double]) extends SpecializedArrayData {
+
+  override def array(): Array[Any] = values.map(_.asInstanceOf[Any])
+
+  override def numElements(): Int = values.length
+
+  override def get(ordinal: Int, elementType: DataType): AnyRef =
+    values(ordinal).asInstanceOf[AnyRef]
+
+  override def getDouble(ordinal: Int): Double = values(ordinal)
+
+  override def toDoubleArray(): Array[Double] = values
+
+  override def copy(): DoubleArrayData = new DoubleArrayData(values.clone())
+}

From f4d2cbbefabdd7e42317835cc168ea92c26e040c Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Sun, 24 Apr 2016 16:03:48 +0800
Subject: [PATCH 2/7] update

---
 .../apache/spark/mllib/linalg/Matrices.scala  | 26 +++++---
 .../apache/spark/mllib/linalg/Vectors.scala   | 24 +++++---
 .../catalyst/expressions/UnsafeArrayData.java | 60 ++++++++++++++++++-
 .../sql/catalyst/util/GenericArrayData.scala  | 58 ------------------
 .../sql/catalyst/util/UnsafeArraySuite.scala  | 58 ++++++++++++++++++
 5 files changed, 150 insertions(+), 76 deletions(-)
 create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/UnsafeArraySuite.scala

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
index 3d91c8322387..1b0dcc623b43 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
@@ -27,8 +27,7 @@ import com.github.fommil.netlib.BLAS.{getInstance => blas}
 import org.apache.spark.annotation.Since
 import org.apache.spark.ml.{linalg => newlinalg}
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
-import org.apache.spark.sql.catalyst.util.{DoubleArrayData, IntArrayData}
+import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, UnsafeArrayData}
 import org.apache.spark.sql.types._
 
 /**
@@ -194,9 +193,9 @@ private[spark] class MatrixUDT extends UserDefinedType[Matrix] {
         row.setByte(0, 0)
         row.setInt(1, sm.numRows)
         row.setInt(2, sm.numCols)
-        row.update(3, new IntArrayData(sm.colPtrs))
-        row.update(4, new IntArrayData(sm.rowIndices))
-        row.update(5, new DoubleArrayData(sm.values))
+        row.update(3, UnsafeArrayData.fromPrimitiveArray(sm.colPtrs))
+        row.update(4, UnsafeArrayData.fromPrimitiveArray(sm.rowIndices))
+        row.update(5, UnsafeArrayData.fromPrimitiveArray(sm.values))
         row.setBoolean(6, sm.isTransposed)
 
       case dm: DenseMatrix =>
@@ -205,7 +204,7 @@ private[spark] class MatrixUDT extends UserDefinedType[Matrix] {
         row.setInt(2, dm.numCols)
         row.setNullAt(3)
         row.setNullAt(4)
-        row.update(5, new DoubleArrayData(dm.values))
+        row.update(5, UnsafeArrayData.fromPrimitiveArray(dm.values))
         row.setBoolean(6, dm.isTransposed)
     }
     row
@@ -219,12 +218,21 @@ private[spark] class MatrixUDT extends UserDefinedType[Matrix] {
     val tpe = row.getByte(0)
     val numRows = row.getInt(1)
     val numCols = row.getInt(2)
-    val values = row.getArray(5).toDoubleArray()
+    val values = row.getArray(5) match {
+      case u: UnsafeArrayData => u.toPrimitiveDoubleArray
+      case a => a.toDoubleArray()
+    }
     val isTransposed = row.getBoolean(6)
     tpe match {
       case 0 =>
-        val colPtrs = row.getArray(3).toIntArray()
-        val rowIndices = row.getArray(4).toIntArray()
+        val colPtrs = row.getArray(3) match {
+          case u: UnsafeArrayData => u.toPrimitiveIntArray
+          case a => a.toIntArray()
+        }
+        val rowIndices = row.getArray(4) match {
+          case u: UnsafeArrayData => u.toPrimitiveIntArray
+          case a => a.toIntArray()
+        }
         new SparseMatrix(numRows, numCols, colPtrs, rowIndices, values, isTransposed)
       case 1 => new DenseMatrix(numRows, numCols, values, isTransposed)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
index 398146b5f040..0779b4b6ea2d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
@@ -33,8 +33,7 @@ import org.apache.spark.annotation.{AlphaComponent, Since}
 import org.apache.spark.ml.{linalg => newlinalg}
 import org.apache.spark.mllib.util.NumericParser
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
-import org.apache.spark.sql.catalyst.util.{DoubleArrayData, IntArrayData}
+import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, UnsafeArrayData}
 import org.apache.spark.sql.types._
 
 /**
@@ -216,15 +215,15 @@ class VectorUDT extends UserDefinedType[Vector] {
         val row = new GenericMutableRow(4)
         row.setByte(0, 0)
         row.setInt(1, size)
-        row.update(2, new IntArrayData(indices))
-        row.update(3, new DoubleArrayData(values))
+        row.update(2, UnsafeArrayData.fromPrimitiveArray(indices))
+        row.update(3, UnsafeArrayData.fromPrimitiveArray(values))
         row
       case DenseVector(values) =>
         val row = new GenericMutableRow(4)
         row.setByte(0, 1)
         row.setNullAt(1)
         row.setNullAt(2)
-        row.update(3, new DoubleArrayData(values))
+        row.update(3, UnsafeArrayData.fromPrimitiveArray(values))
         row
     }
   }
@@ -238,11 +237,20 @@ class VectorUDT extends UserDefinedType[Vector] {
     tpe match {
       case 0 =>
         val size = row.getInt(1)
-        val indices = row.getArray(2).toIntArray()
-        val values = row.getArray(3).toDoubleArray()
+        val indices = row.getArray(2) match {
+          case u: UnsafeArrayData => u.toPrimitiveIntArray
+          case a => a.toIntArray()
+        }
+        val values = row.getArray(3) match {
+          case u: UnsafeArrayData => u.toPrimitiveDoubleArray
+          case a => a.toDoubleArray()
+        }
         new SparseVector(size, indices, values)
       case 1 =>
-        val values = row.getArray(3).toDoubleArray()
+        val values = row.getArray(3) match {
+          case u: UnsafeArrayData => u.toPrimitiveDoubleArray
+          case a => a.toDoubleArray()
+        }
         new DenseVector(values)
     }
   }
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
index 648625b2cc5d..82cba84e8f72 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
@@ -81,7 +81,7 @@ private void assertIndexIsValid(int ordinal) {
   }
 
   public Object[] array() {
-    throw new UnsupportedOperationException("Only supported on GenericArrayData.");
+    throw new UnsupportedOperationException("Not supported on UnsafeArrayData.");
   }
 
   /**
@@ -336,4 +336,62 @@ public UnsafeArrayData copy() {
     arrayCopy.pointTo(arrayDataCopy, Platform.BYTE_ARRAY_OFFSET, sizeInBytes);
     return arrayCopy;
   }
+
+  public int[] toPrimitiveIntArray() {
+    int[] result = new int[numElements];
+    Platform.copyMemory(baseObject, baseOffset + 4 + 4 * numElements,
+      result, Platform.INT_ARRAY_OFFSET, 4 * numElements);
+    return result;
+  }
+
+  public double[] toPrimitiveDoubleArray() {
+    double[] result = new double[numElements];
+    Platform.copyMemory(baseObject, baseOffset + 4 + 4 * numElements,
+      result, Platform.DOUBLE_ARRAY_OFFSET, 8 * numElements);
+    return result;
+  }
+
+  public static UnsafeArrayData fromPrimitiveArray(int[] arr) {
+    int offsetRegionSize = 4 * arr.length;
+    int valueRegionSize = 4 * arr.length;
+    int totalSize = 4 + offsetRegionSize + valueRegionSize;
+    byte[] data = new byte[totalSize];
+
+    Platform.putInt(data, Platform.BYTE_ARRAY_OFFSET, arr.length);
+
+    int elementOffsetStart = 4 + offsetRegionSize;
+    for (int i = 0; i < arr.length; i++) {
+      Platform.putInt(data, Platform.BYTE_ARRAY_OFFSET + 4 + i * 4, elementOffsetStart + i * 4);
+    }
+
+    Platform.copyMemory(arr, Platform.INT_ARRAY_OFFSET, data,
+      Platform.BYTE_ARRAY_OFFSET + elementOffsetStart, valueRegionSize);
+
+    UnsafeArrayData result = new UnsafeArrayData();
+    result.pointTo(data, Platform.BYTE_ARRAY_OFFSET, totalSize);
+    return result;
+  }
+
+  public static UnsafeArrayData fromPrimitiveArray(double[] arr) {
+    int offsetRegionSize = 4 * arr.length;
+    int valueRegionSize = 8 * arr.length;
+    int totalSize = 4 + offsetRegionSize + valueRegionSize;
+    byte[] data = new byte[totalSize];
+
+    Platform.putInt(data, Platform.BYTE_ARRAY_OFFSET, arr.length);
+
+    int elementOffsetStart = 4 + offsetRegionSize;
+    for (int i = 0; i < arr.length; i++) {
+      Platform.putInt(data, Platform.BYTE_ARRAY_OFFSET + 4 + i * 4, elementOffsetStart + i * 8);
+    }
+
+    Platform.copyMemory(arr, Platform.DOUBLE_ARRAY_OFFSET, data,
+      Platform.BYTE_ARRAY_OFFSET + elementOffsetStart, valueRegionSize);
+
+    UnsafeArrayData result = new UnsafeArrayData();
+    result.pointTo(data, Platform.BYTE_ARRAY_OFFSET, totalSize);
+    return result;
+  }
+
+  // TODO: add more specialized methods.
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala
index eaa596d6f9a2..2b8cdc1e23ab 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala
@@ -137,61 +137,3 @@ class GenericArrayData(val array: Array[Any]) extends ArrayData {
     result
   }
 }
-
-abstract class SpecializedArrayData extends ArrayData {
-  // Primitive arrays can't have null elements.
-  override def isNullAt(ordinal: Int): Boolean = false
-
-  private def fail() = {
-    throw new UnsupportedOperationException(
-      "Specialized array data should implement its corresponding get method")
-  }
-
-  override def get(ordinal: Int, elementType: DataType): AnyRef = fail()
-  override def getBoolean(ordinal: Int): Boolean = fail()
-  override def getByte(ordinal: Int): Byte = fail()
-  override def getShort(ordinal: Int): Short = fail()
-  override def getInt(ordinal: Int): Int = fail()
-  override def getLong(ordinal: Int): Long = fail()
-  override def getFloat(ordinal: Int): Float = fail()
-  override def getDouble(ordinal: Int): Double = fail()
-  override def getDecimal(ordinal: Int, precision: Int, scale: Int): Decimal = fail()
-  override def getUTF8String(ordinal: Int): UTF8String = fail()
-  override def getBinary(ordinal: Int): Array[Byte] = fail()
-  override def getInterval(ordinal: Int): CalendarInterval = fail()
-  override def getStruct(ordinal: Int, numFields: Int): InternalRow = fail()
-  override def getArray(ordinal: Int): ArrayData = fail()
-  override def getMap(ordinal: Int): MapData = fail()
-}
-
-class IntArrayData(val values: Array[Int]) extends SpecializedArrayData {
-
-  override def array(): Array[Any] = values.map(_.asInstanceOf[Any])
-
-  override def numElements(): Int = values.length
-
-  override def get(ordinal: Int, elementType: DataType): AnyRef =
-    values(ordinal).asInstanceOf[AnyRef]
-
-  override def getInt(ordinal: Int): Int = values(ordinal)
-
-  override def toIntArray(): Array[Int] = values
-
-  override def copy(): IntArrayData = new IntArrayData(values.clone())
-}
-
-class DoubleArrayData(val values: Array[Double]) extends SpecializedArrayData {
-
-  override def array(): Array[Any] = values.map(_.asInstanceOf[Any])
-
-  override def numElements(): Int = values.length
-
-  override def get(ordinal: Int, elementType: DataType): AnyRef =
-    values(ordinal).asInstanceOf[AnyRef]
-
-  override def getDouble(ordinal: Int): Double = values(ordinal)
-
-  override def toDoubleArray(): Array[Double] = values
-
-  override def copy(): DoubleArrayData = new DoubleArrayData(values.clone())
-}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/UnsafeArraySuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/UnsafeArraySuite.scala
new file mode 100644
index 000000000000..243a932f081f
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/UnsafeArraySuite.scala
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.util
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData
+
+class UnsafeArraySuite extends SparkFunSuite {
+
+  test("from primitive int array") {
+    val array = Array(1, 10, 100)
+    val unsafe = UnsafeArrayData.fromPrimitiveArray(array)
+    assert(unsafe.numElements == 3)
+    assert(unsafe.getSizeInBytes == 4 + 4 * 3 + 4 * 3)
+    assert(unsafe.getInt(0) == 1)
+    assert(unsafe.getInt(1) == 10)
+    assert(unsafe.getInt(2) == 100)
+  }
+
+  test("from primitive double array") {
+    val array = Array(1.1, 2.2, 3.3)
+    val unsafe = UnsafeArrayData.fromPrimitiveArray(array)
+    assert(unsafe.numElements == 3)
+    assert(unsafe.getSizeInBytes == 4 + 4 * 3 + 8 * 3)
+    assert(unsafe.getDouble(0) == 1.1)
+    assert(unsafe.getDouble(1) == 2.2)
+    assert(unsafe.getDouble(2) == 3.3)
+  }
+
+  test("to primitive int array") {
+    val array = Array(1, 10, 100)
+    val unsafe = UnsafeArrayData.fromPrimitiveArray(array)
+    val array2 = unsafe.toPrimitiveIntArray
+    assert(array.toSeq == array2.toSeq)
+  }
+
+  test("to primitive double array") {
+    val array = Array(1.1, 2.2, 3.3)
+    val unsafe = UnsafeArrayData.fromPrimitiveArray(array)
+    val array2 = unsafe.toPrimitiveDoubleArray
+    assert(array.toSeq == array2.toSeq)
+  }
+}

From a7b769459619eb8f5968fcde0a56caf8c7a0f499 Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Thu, 28 Apr 2016 16:44:09 +0800
Subject: [PATCH 3/7] address comments

---
 .../catalyst/expressions/UnsafeArrayData.java | 28 +++++++++++--------
 .../catalyst/expressions/UnsafeMapData.java   |  2 +-
 2 files changed, 18 insertions(+), 12 deletions(-)

diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
index 82cba84e8f72..160c4a35812b 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
@@ -47,7 +47,7 @@
  * Instances of `UnsafeArrayData` act as pointers to row data stored in this format.
  */
 // todo: there is a lot of duplicated code between UnsafeRow and UnsafeArrayData.
-public class UnsafeArrayData extends ArrayData {
+public final class UnsafeArrayData extends ArrayData {
 
   private Object baseObject;
   private long baseOffset;
@@ -339,15 +339,15 @@ public UnsafeArrayData copy() {
 
   public int[] toPrimitiveIntArray() {
     int[] result = new int[numElements];
-    Platform.copyMemory(baseObject, baseOffset + 4 + 4 * numElements,
-      result, Platform.INT_ARRAY_OFFSET, 4 * numElements);
+    Platform.copyMemory(baseObject, baseOffset + 4 + 4L * numElements,
+      result, Platform.INT_ARRAY_OFFSET, 4L * numElements);
     return result;
   }
 
   public double[] toPrimitiveDoubleArray() {
     double[] result = new double[numElements];
-    Platform.copyMemory(baseObject, baseOffset + 4 + 4 * numElements,
-      result, Platform.DOUBLE_ARRAY_OFFSET, 8 * numElements);
+    Platform.copyMemory(baseObject, baseOffset + 4 + 4L * numElements,
+      result, Platform.DOUBLE_ARRAY_OFFSET, 8L * numElements);
     return result;
   }
 
@@ -359,13 +359,16 @@ public static UnsafeArrayData fromPrimitiveArray(int[] arr) {
 
     Platform.putInt(data, Platform.BYTE_ARRAY_OFFSET, arr.length);
 
-    int elementOffsetStart = 4 + offsetRegionSize;
+    int offsetPosition = Platform.BYTE_ARRAY_OFFSET + 4;
+    int valueOffset = 4 + offsetRegionSize;
     for (int i = 0; i < arr.length; i++) {
-      Platform.putInt(data, Platform.BYTE_ARRAY_OFFSET + 4 + i * 4, elementOffsetStart + i * 4);
+      Platform.putInt(data, offsetPosition, valueOffset);
+      offsetPosition += 4;
+      valueOffset += 4;
     }
 
     Platform.copyMemory(arr, Platform.INT_ARRAY_OFFSET, data,
-      Platform.BYTE_ARRAY_OFFSET + elementOffsetStart, valueRegionSize);
+      Platform.BYTE_ARRAY_OFFSET + 4 + offsetRegionSize, valueRegionSize);
 
     UnsafeArrayData result = new UnsafeArrayData();
     result.pointTo(data, Platform.BYTE_ARRAY_OFFSET, totalSize);
@@ -380,13 +383,16 @@ public static UnsafeArrayData fromPrimitiveArray(double[] arr) {
 
     Platform.putInt(data, Platform.BYTE_ARRAY_OFFSET, arr.length);
 
-    int elementOffsetStart = 4 + offsetRegionSize;
+    int offsetPosition = Platform.BYTE_ARRAY_OFFSET + 4;
+    int valueOffset = 4 + offsetRegionSize;
     for (int i = 0; i < arr.length; i++) {
-      Platform.putInt(data, Platform.BYTE_ARRAY_OFFSET + 4 + i * 4, elementOffsetStart + i * 8);
+      Platform.putInt(data, offsetPosition, valueOffset);
+      offsetPosition += 4;
+      valueOffset += 8;
     }
 
     Platform.copyMemory(arr, Platform.DOUBLE_ARRAY_OFFSET, data,
-      Platform.BYTE_ARRAY_OFFSET + elementOffsetStart, valueRegionSize);
+      Platform.BYTE_ARRAY_OFFSET + 4 + offsetRegionSize, valueRegionSize);
 
     UnsafeArrayData result = new UnsafeArrayData();
     result.pointTo(data, Platform.BYTE_ARRAY_OFFSET, totalSize);
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java
index 651eb1ff0c56..0700148becab 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java
@@ -30,7 +30,7 @@
  * [unsafe key array numBytes] [unsafe key array] [unsafe value array]
  */
 // TODO: Use a more efficient format which doesn't depend on unsafe array.
-public class UnsafeMapData extends MapData {
+public final class UnsafeMapData extends MapData {
 
   private Object baseObject;
   private long baseOffset;

From f6964f9e8d0db473354941f31cf182628d0a3299 Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Thu, 28 Apr 2016 17:17:37 +0800
Subject: [PATCH 4/7] add benchmark

---
 .../linalg/UDTSerializationBenchmark.scala | 70 +++++++++++++++++++
 1 file changed, 70 insertions(+)
 create mode 100644 mllib/src/test/scala/org/apache/spark/mllib/linalg/UDTSerializationBenchmark.scala

diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/UDTSerializationBenchmark.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/UDTSerializationBenchmark.scala
new file mode 100644
index 000000000000..be7110ad6bbf
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/UDTSerializationBenchmark.scala
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.linalg
+
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.util.Benchmark
+
+/**
+ * Serialization benchmark for VectorUDT.
+ */
+object UDTSerializationBenchmark {
+
+  def main(args: Array[String]): Unit = {
+    val iters = 1e2.toInt
+    val numRows = 1e3.toInt
+
+    val encoder = ExpressionEncoder[Vector].defaultBinding
+
+    val vectors = (1 to numRows).map { i =>
+      Vectors.dense(Array.fill(1e5.toInt)(1.0 * i))
+    }.toArray
+    val rows = vectors.map(encoder.toRow)
+
+    val benchmark = new Benchmark("VectorUDT de/serialization", numRows, iters)
+
+    benchmark.addCase("serialize") { _ =>
+      var sum = 0
+      var i = 0
+      while (i < numRows) {
+        sum += encoder.toRow(vectors(i)).numFields
+        i += 1
+      }
+    }
+
+    benchmark.addCase("deserialize") { _ =>
+      var sum = 0
+      var i = 0
+      while (i < numRows) {
+        sum += encoder.fromRow(rows(i)).numActives
+        i += 1
+      }
+    }
+
+    /*
+    Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.11.4
+    Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz
+
+    VectorUDT de/serialization:  Best/Avg Time(ms)  Rate(M/s)  Per Row(ns)  Relative
+    -------------------------------------------------------------------------------------------
+    serialize                          380 /  392        0.0     379730.0      1.0X
+    deserialize                        138 /  142        0.0     137816.6      2.8X
+    */
+    benchmark.run()
+  }
+}

From c6c3584c6cc7c500ca31cb21a00391c06caf3aa6 Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Thu, 28 Apr 2016 17:49:43 +0800
Subject: [PATCH 5/7] rename

---
 .../org/apache/spark/mllib/linalg/Matrices.scala   |  6 +++---
 .../org/apache/spark/mllib/linalg/Vectors.scala    |  6 +++---
 .../sql/catalyst/expressions/UnsafeArrayData.java  | 14 ++++++++++++--
 .../spark/sql/catalyst/util/UnsafeArraySuite.scala |  8 ++++----
 4 files changed, 22 insertions(+), 12 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
index 1b0dcc623b43..955794d17d36 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
@@ -219,18 +219,18 @@ private[spark] class MatrixUDT extends UserDefinedType[Matrix] {
     val numRows = row.getInt(1)
     val numCols = row.getInt(2)
     val values = row.getArray(5) match {
-      case u: UnsafeArrayData => u.toPrimitiveDoubleArray
+      case u: UnsafeArrayData => u.toDoubleArrayUnchecked
       case a => a.toDoubleArray()
     }
     val isTransposed = row.getBoolean(6)
     tpe match {
       case 0 =>
         val colPtrs = row.getArray(3) match {
-          case u: UnsafeArrayData => u.toPrimitiveIntArray
+          case u: UnsafeArrayData => u.toIntArrayUnchecked
          case a => a.toIntArray()
         }
         val rowIndices = row.getArray(4) match {
-          case u: UnsafeArrayData => u.toPrimitiveIntArray
+          case u: UnsafeArrayData => u.toIntArrayUnchecked
           case a => a.toIntArray()
         }
         new SparseMatrix(numRows, numCols, colPtrs, rowIndices, values, isTransposed)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
index 0779b4b6ea2d..add027d2bfca 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
@@ -238,17 +238,17 @@ class VectorUDT extends UserDefinedType[Vector] {
       case 0 =>
         val size = row.getInt(1)
         val indices = row.getArray(2) match {
-          case u: UnsafeArrayData => u.toPrimitiveIntArray
+          case u: UnsafeArrayData => u.toIntArrayUnchecked
           case a => a.toIntArray()
         }
         val values = row.getArray(3) match {
-          case u: UnsafeArrayData => u.toPrimitiveDoubleArray
+          case u: UnsafeArrayData => u.toDoubleArrayUnchecked
           case a => a.toDoubleArray()
         }
         new SparseVector(size, indices, values)
       case 1 =>
         val values = row.getArray(3) match {
-          case u: UnsafeArrayData => u.toPrimitiveDoubleArray
+          case u: UnsafeArrayData => u.toDoubleArrayUnchecked
           case a => a.toDoubleArray()
         }
         new DenseVector(values)
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
index 160c4a35812b..f2f8c3783a0b 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
@@ -337,14 +337,24 @@ public UnsafeArrayData copy() {
     return arrayCopy;
   }
 
-  public int[] toPrimitiveIntArray() {
+  /**
+   * A faster version of `toIntArray`, which uses memory copy instead of iterating all elements.
+   * Note that this method is dangerous if this array contains null elements. We don't write
+   * null elements into the data region and memory copy will crash as the data size doesn't match.
+   */
+  public int[] toIntArrayUnchecked() {
     int[] result = new int[numElements];
     Platform.copyMemory(baseObject, baseOffset + 4 + 4L * numElements,
       result, Platform.INT_ARRAY_OFFSET, 4L * numElements);
     return result;
   }
 
-  public double[] toPrimitiveDoubleArray() {
+  /**
+   * A faster version of `toDoubleArray`, which uses memory copy instead of iterating all elements.
+   * Note that this method is dangerous if this array contains null elements. We don't write
+   * null elements into the data region and memory copy will crash as the data size doesn't match.
+   */
+  public double[] toDoubleArrayUnchecked() {
     double[] result = new double[numElements];
     Platform.copyMemory(baseObject, baseOffset + 4 + 4L * numElements,
       result, Platform.DOUBLE_ARRAY_OFFSET, 8L * numElements);
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/UnsafeArraySuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/UnsafeArraySuite.scala
index 243a932f081f..1632596e19a3 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/UnsafeArraySuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/UnsafeArraySuite.scala
@@ -42,17 +42,17 @@ class UnsafeArraySuite extends SparkFunSuite {
     assert(unsafe.getDouble(2) == 3.3)
   }
 
-  test("to primitive int array") {
+  test("to int array unchecked") {
     val array = Array(1, 10, 100)
     val unsafe = UnsafeArrayData.fromPrimitiveArray(array)
-    val array2 = unsafe.toPrimitiveIntArray
+    val array2 = unsafe.toIntArrayUnchecked
     assert(array.toSeq == array2.toSeq)
   }
 
-  test("to primitive double array") {
+  test("to double array unchecked") {
     val array = Array(1.1, 2.2, 3.3)
     val unsafe = UnsafeArrayData.fromPrimitiveArray(array)
-    val array2 = unsafe.toPrimitiveDoubleArray
+    val array2 = unsafe.toDoubleArrayUnchecked
     assert(array.toSeq == array2.toSeq)
   }
 }

From 537e3636da53523a14458006808dfb0303635470 Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Fri, 29 Apr 2016 14:56:51 +0800
Subject: [PATCH 6/7] add size check

---
 .../catalyst/expressions/UnsafeArrayData.java | 30 ++++++++++++-------
 1 file changed, 20 insertions(+), 10 deletions(-)

diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
index f2f8c3783a0b..9333d3e38421 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
@@ -343,7 +343,7 @@ public UnsafeArrayData copy() {
    * null elements into the data region and memory copy will crash as the data size doesn't match.
    */
   public int[] toIntArrayUnchecked() {
-    int[] result = new int[numElements];
+    final int[] result = new int[numElements];
     Platform.copyMemory(baseObject, baseOffset + 4 + 4L * numElements,
       result, Platform.INT_ARRAY_OFFSET, 4L * numElements);
     return result;
@@ -355,17 +355,22 @@ public int[] toIntArrayUnchecked() {
   * null elements into the data region and memory copy will crash as the data size doesn't match.
   */
  public double[] toDoubleArrayUnchecked() {
-    double[] result = new double[numElements];
+    final double[] result = new double[numElements];
     Platform.copyMemory(baseObject, baseOffset + 4 + 4L * numElements,
       result, Platform.DOUBLE_ARRAY_OFFSET, 8L * numElements);
     return result;
   }
 
   public static UnsafeArrayData fromPrimitiveArray(int[] arr) {
-    int offsetRegionSize = 4 * arr.length;
-    int valueRegionSize = 4 * arr.length;
-    int totalSize = 4 + offsetRegionSize + valueRegionSize;
-    byte[] data = new byte[totalSize];
+    if (arr.length > (Integer.MAX_VALUE - 4) / 8) {
+      throw new UnsupportedOperationException("Cannot convert this array to unsafe format as " +
+        "it's too big.");
+    }
+
+    final int offsetRegionSize = 4 * arr.length;
+    final int valueRegionSize = 4 * arr.length;
+    final int totalSize = 4 + offsetRegionSize + valueRegionSize;
+    final byte[] data = new byte[totalSize];
 
     Platform.putInt(data, Platform.BYTE_ARRAY_OFFSET, arr.length);
 
@@ -386,10 +391,15 @@ public static UnsafeArrayData fromPrimitiveArray(int[] arr) {
   }
 
   public static UnsafeArrayData fromPrimitiveArray(double[] arr) {
-    int offsetRegionSize = 4 * arr.length;
-    int valueRegionSize = 8 * arr.length;
-    int totalSize = 4 + offsetRegionSize + valueRegionSize;
-    byte[] data = new byte[totalSize];
+    if (arr.length > (Integer.MAX_VALUE - 4) / 12) {
+      throw new UnsupportedOperationException("Cannot convert this array to unsafe format as " +
+        "it's too big.");
+    }
+
+    final int offsetRegionSize = 4 * arr.length;
+    final int valueRegionSize = 8 * arr.length;
+    final int totalSize = 4 + offsetRegionSize + valueRegionSize;
+    final byte[] data = new byte[totalSize];
 
     Platform.putInt(data, Platform.BYTE_ARRAY_OFFSET, arr.length);

From b10845cb57a574c3337d3658d1c0cb42a22910a8 Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Sat, 30 Apr 2016 09:19:05 +0800
Subject: [PATCH 7/7] remove unchecked toArray

---
 .../apache/spark/mllib/linalg/Matrices.scala   | 15 +++------
 .../apache/spark/mllib/linalg/Vectors.scala    | 15 +++------
 .../catalyst/expressions/UnsafeArrayData.java  | 24 -------------------
 .../sql/catalyst/util/UnsafeArraySuite.scala   | 14 ----------
 4 files changed, 6 insertions(+), 62 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
index 955794d17d36..076cca6016ec 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
@@ -218,21 +218,12 @@ private[spark] class MatrixUDT extends UserDefinedType[Matrix] {
     val tpe = row.getByte(0)
     val numRows = row.getInt(1)
     val numCols = row.getInt(2)
-    val values = row.getArray(5) match {
-      case u: UnsafeArrayData => u.toDoubleArrayUnchecked
-      case a => a.toDoubleArray()
-    }
+    val values = row.getArray(5).toDoubleArray()
     val isTransposed = row.getBoolean(6)
     tpe match {
       case 0 =>
-        val colPtrs = row.getArray(3) match {
-          case u: UnsafeArrayData => u.toIntArrayUnchecked
-          case a => a.toIntArray()
-        }
-        val rowIndices = row.getArray(4) match {
-          case u: UnsafeArrayData => u.toIntArrayUnchecked
-          case a => a.toIntArray()
-        }
+        val colPtrs = row.getArray(3).toIntArray()
+        val rowIndices = row.getArray(4).toIntArray()
         new SparseMatrix(numRows, numCols, colPtrs, rowIndices, values, isTransposed)
       case 1 => new DenseMatrix(numRows, numCols, values, isTransposed)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
index add027d2bfca..132e54a8c3de 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
@@ -237,20 +237,11 @@ class VectorUDT extends UserDefinedType[Vector] {
     tpe match {
       case 0 =>
         val size = row.getInt(1)
-        val indices = row.getArray(2) match {
-          case u: UnsafeArrayData => u.toIntArrayUnchecked
-          case a => a.toIntArray()
-        }
-        val values = row.getArray(3) match {
-          case u: UnsafeArrayData => u.toDoubleArrayUnchecked
-          case a => a.toDoubleArray()
-        }
+        val indices = row.getArray(2).toIntArray()
+        val values = row.getArray(3).toDoubleArray()
         new SparseVector(size, indices, values)
       case 1 =>
-        val values = row.getArray(3) match {
-          case u: UnsafeArrayData => u.toDoubleArrayUnchecked
-          case a => a.toDoubleArray()
-        }
+        val values = row.getArray(3).toDoubleArray()
         new DenseVector(values)
     }
   }
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
index 9333d3e38421..02a863b2bb49 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
@@ -337,30 +337,6 @@ public UnsafeArrayData copy() {
     return arrayCopy;
   }
 
-  /**
-   * A faster version of `toIntArray`, which uses memory copy instead of iterating all elements.
-   * Note that this method is dangerous if this array contains null elements. We don't write
-   * null elements into the data region and memory copy will crash as the data size doesn't match.
-   */
-  public int[] toIntArrayUnchecked() {
-    final int[] result = new int[numElements];
-    Platform.copyMemory(baseObject, baseOffset + 4 + 4L * numElements,
-      result, Platform.INT_ARRAY_OFFSET, 4L * numElements);
-    return result;
-  }
-
-  /**
-   * A faster version of `toDoubleArray`, which uses memory copy instead of iterating all elements.
-   * Note that this method is dangerous if this array contains null elements. We don't write
-   * null elements into the data region and memory copy will crash as the data size doesn't match.
-   */
-  public double[] toDoubleArrayUnchecked() {
-    final double[] result = new double[numElements];
-    Platform.copyMemory(baseObject, baseOffset + 4 + 4L * numElements,
-      result, Platform.DOUBLE_ARRAY_OFFSET, 8L * numElements);
-    return result;
-  }
-
   public static UnsafeArrayData fromPrimitiveArray(int[] arr) {
     if (arr.length > (Integer.MAX_VALUE - 4) / 8) {
       throw new UnsupportedOperationException("Cannot convert this array to unsafe format as " +
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/UnsafeArraySuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/UnsafeArraySuite.scala
index 1632596e19a3..1685276ff120 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/UnsafeArraySuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/UnsafeArraySuite.scala
@@ -41,18 +41,4 @@ class UnsafeArraySuite extends SparkFunSuite {
     assert(unsafe.getDouble(1) == 2.2)
     assert(unsafe.getDouble(2) == 3.3)
   }
-
-  test("to int array unchecked") {
-    val array = Array(1, 10, 100)
-    val unsafe = UnsafeArrayData.fromPrimitiveArray(array)
-    val array2 = unsafe.toIntArrayUnchecked
-    assert(array.toSeq == array2.toSeq)
-  }
-
-  test("to double array unchecked") {
-    val array = Array(1.1, 2.2, 3.3)
-    val unsafe = UnsafeArrayData.fromPrimitiveArray(array)
-    val array2 = unsafe.toDoubleArrayUnchecked
-    assert(array.toSeq == array2.toSeq)
-  }
 }
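Appendix: the array layout built by fromPrimitiveArray above is a 4-byte element count, then one 4-byte offset slot per element (each offset is relative to the start of the array data and points into the value region), then the packed primitive values. The standalone Scala sketch below models that byte math with java.nio instead of Spark's Platform/Unsafe helpers; it is an illustration of the format as described in these patches, not code from them, and the object and method names here are invented for the example. It assumes native byte order, matching how Platform writes ints.

import java.nio.{ByteBuffer, ByteOrder}

object UnsafeArrayLayoutSketch {
  // Models UnsafeArrayData.fromPrimitiveArray(int[]):
  // [numElements: 4 bytes][one 4-byte offset per element][4-byte values].
  def fromIntArray(arr: Array[Int]): Array[Byte] = {
    val offsetRegionSize = 4 * arr.length        // one 4-byte offset slot per element
    val valueRegionSize = 4 * arr.length         // ints are 4 bytes each
    val buf = ByteBuffer
      .allocate(4 + offsetRegionSize + valueRegionSize)
      .order(ByteOrder.nativeOrder())            // assumption: native order, as Unsafe writes
    buf.putInt(arr.length)                       // header: the element count
    var valueOffset = 4 + offsetRegionSize       // first value sits right after the offset region
    for (_ <- arr.indices) {
      buf.putInt(valueOffset)                    // offset region entry for this element
      valueOffset += 4
    }
    arr.foreach(v => buf.putInt(v))              // value region: the packed elements
    buf.array()
  }

  def main(args: Array[String]): Unit = {
    // 4 + 4 * 3 + 4 * 3 = 28 bytes, the size UnsafeArraySuite asserts via getSizeInBytes.
    println(fromIntArray(Array(1, 10, 100)).length)
  }
}

For Array(1, 10, 100) the offsets written are 16, 20 and 24, and the total size is 4 + 12 + 12 = 28 bytes, matching the getSizeInBytes assertion in the test suite. It also shows why the size check added in patch 6 bounds arr.length by (Integer.MAX_VALUE - 4) / 8 for ints: each element costs 8 bytes, 4 for its offset slot plus 4 for its value.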