-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-14850][ML] convert primitive array from/to unsafe array directly in VectorUDT/MatrixUDT #12640
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SPARK-14850][ML] convert primitive array from/to unsafe array directly in VectorUDT/MatrixUDT #12640
Changes from 1 commit
55d3178
f4d2cbb
a7b7694
f6964f9
c6c3584
537e363
ae6f365
b10845c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
- Loading branch information
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -81,7 +81,7 @@ private void assertIndexIsValid(int ordinal) { | |
| } | ||
|
|
||
| public Object[] array() { | ||
| throw new UnsupportedOperationException("Only supported on GenericArrayData."); | ||
| throw new UnsupportedOperationException("Not supported on UnsafeArrayData."); | ||
| } | ||
|
|
||
| /** | ||
|
|
@@ -336,4 +336,62 @@ public UnsafeArrayData copy() { | |
| arrayCopy.pointTo(arrayDataCopy, Platform.BYTE_ARRAY_OFFSET, sizeInBytes); | ||
| return arrayCopy; | ||
| } | ||
|
|
||
| public int[] toPrimitiveIntArray() { | ||
| int[] result = new int[numElements]; | ||
| Platform.copyMemory(baseObject, baseOffset + 4 + 4 * numElements, | ||
|
||
| result, Platform.INT_ARRAY_OFFSET, 4 * numElements); | ||
| return result; | ||
| } | ||
|
|
||
| public double[] toPrimitiveDoubleArray() { | ||
| double[] result = new double[numElements]; | ||
| Platform.copyMemory(baseObject, baseOffset + 4 + 4 * numElements, | ||
| result, Platform.DOUBLE_ARRAY_OFFSET, 8 * numElements); | ||
| return result; | ||
| } | ||
|
|
||
| public static UnsafeArrayData fromPrimitiveArray(int[] arr) { | ||
| int offsetRegionSize = 4 * arr.length; | ||
|
||
| int valueRegionSize = 4 * arr.length; | ||
| int totalSize = 4 + offsetRegionSize + valueRegionSize; | ||
| byte[] data = new byte[totalSize]; | ||
|
||
|
|
||
| Platform.putInt(data, Platform.BYTE_ARRAY_OFFSET, arr.length); | ||
|
|
||
| int elementOffsetStart = 4 + offsetRegionSize; | ||
| for (int i = 0; i < arr.length; i++) { | ||
| Platform.putInt(data, Platform.BYTE_ARRAY_OFFSET + 4 + i * 4, elementOffsetStart + i * 4); | ||
|
||
| } | ||
|
|
||
| Platform.copyMemory(arr, Platform.INT_ARRAY_OFFSET, data, | ||
| Platform.BYTE_ARRAY_OFFSET + elementOffsetStart, valueRegionSize); | ||
|
|
||
| UnsafeArrayData result = new UnsafeArrayData(); | ||
| result.pointTo(data, Platform.BYTE_ARRAY_OFFSET, totalSize); | ||
| return result; | ||
| } | ||
|
|
||
| public static UnsafeArrayData fromPrimitiveArray(double[] arr) { | ||
| int offsetRegionSize = 4 * arr.length; | ||
| int valueRegionSize = 8 * arr.length; | ||
| int totalSize = 4 + offsetRegionSize + valueRegionSize; | ||
| byte[] data = new byte[totalSize]; | ||
|
|
||
| Platform.putInt(data, Platform.BYTE_ARRAY_OFFSET, arr.length); | ||
|
|
||
| int elementOffsetStart = 4 + offsetRegionSize; | ||
| for (int i = 0; i < arr.length; i++) { | ||
| Platform.putInt(data, Platform.BYTE_ARRAY_OFFSET + 4 + i * 4, elementOffsetStart + i * 8); | ||
| } | ||
|
|
||
| Platform.copyMemory(arr, Platform.DOUBLE_ARRAY_OFFSET, data, | ||
| Platform.BYTE_ARRAY_OFFSET + elementOffsetStart, valueRegionSize); | ||
|
|
||
| UnsafeArrayData result = new UnsafeArrayData(); | ||
| result.pointTo(data, Platform.BYTE_ARRAY_OFFSET, totalSize); | ||
| return result; | ||
| } | ||
|
|
||
| // TODO: add more specialized methods. | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,58 @@ | ||
| /* | ||
| * Licensed to the Apache Software Foundation (ASF) under one or more | ||
| * contributor license agreements. See the NOTICE file distributed with | ||
| * this work for additional information regarding copyright ownership. | ||
| * The ASF licenses this file to You under the Apache License, Version 2.0 | ||
| * (the "License"); you may not use this file except in compliance with | ||
| * the License. You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
|
|
||
| package org.apache.spark.sql.catalyst.util | ||
|
|
||
| import org.apache.spark.SparkFunSuite | ||
| import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData | ||
|
|
||
| class UnsafeArraySuite extends SparkFunSuite { | ||
|
|
||
| test("from primitive int array") { | ||
| val array = Array(1, 10, 100) | ||
| val unsafe = UnsafeArrayData.fromPrimitiveArray(array) | ||
| assert(unsafe.numElements == 3) | ||
| assert(unsafe.getSizeInBytes == 4 + 4 * 3 + 4 * 3) | ||
| assert(unsafe.getInt(0) == 1) | ||
| assert(unsafe.getInt(1) == 10) | ||
| assert(unsafe.getInt(2) == 100) | ||
| } | ||
|
|
||
| test("from primitive double array") { | ||
| val array = Array(1.1, 2.2, 3.3) | ||
| val unsafe = UnsafeArrayData.fromPrimitiveArray(array) | ||
| assert(unsafe.numElements == 3) | ||
| assert(unsafe.getSizeInBytes == 4 + 4 * 3 + 8 * 3) | ||
| assert(unsafe.getDouble(0) == 1.1) | ||
| assert(unsafe.getDouble(1) == 2.2) | ||
| assert(unsafe.getDouble(2) == 3.3) | ||
| } | ||
|
|
||
| test("to primitive int array") { | ||
| val array = Array(1, 10, 100) | ||
| val unsafe = UnsafeArrayData.fromPrimitiveArray(array) | ||
| val array2 = unsafe.toPrimitiveIntArray | ||
| assert(array.toSeq == array2.toSeq) | ||
| } | ||
|
|
||
| test("to primitive double array") { | ||
| val array = Array(1.1, 2.2, 3.3) | ||
| val unsafe = UnsafeArrayData.fromPrimitiveArray(array) | ||
| val array2 = unsafe.toPrimitiveDoubleArray | ||
| assert(array.toSeq == array2.toSeq) | ||
| } | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I didn't override `toIntArray`, but created this special method instead. This operation is dangerous: if some elements are null, we won't return 0, but may crash instead. The reason is that we don't write null values; if an element is null, we simply mark it as null in the offset region and skip it. For example, the data size of an unsafe int array may be less than `4 * numElements`, and the memory copy may crash.

Ideally, I think we need to improve the unsafe array format to handle primitive arrays better.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It would be hard to tell the difference between `toPrimitiveIntArray` and `toIntArray` by name and signature, because both return primitive arrays. How about `toIntArrayUnchecked`? Please add JavaDoc to explain the difference.