diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java index d92293b918703..721e6a60befe2 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java @@ -49,31 +49,43 @@ public int numElements() { return length; } + /** + * Sets all the appropriate null bits in the input UnsafeArrayData. + * + * @param arrayData The UnsafeArrayData to set the null bits for + * @return The UnsafeArrayData with the null bits set + */ + private UnsafeArrayData setNullBits(UnsafeArrayData arrayData) { + if (data.hasNull()) { + for (int i = 0; i < length; i++) { + if (data.isNullAt(i)) { + arrayData.setNullAt(i); + } + } + } + return arrayData; + } + @Override public ArrayData copy() { DataType dt = data.dataType(); - if (data.hasNull()) { - // UnsafeArrayData cannot be used if there are any nulls. - return new GenericArrayData(toObjectArray(dt)).copy(); - } - if (dt instanceof BooleanType) { - return UnsafeArrayData.fromPrimitiveArray(toBooleanArray()); + return setNullBits(UnsafeArrayData.fromPrimitiveArray(toBooleanArray())); } else if (dt instanceof ByteType) { - return UnsafeArrayData.fromPrimitiveArray(toByteArray()); + return setNullBits(UnsafeArrayData.fromPrimitiveArray(toByteArray())); } else if (dt instanceof ShortType) { - return UnsafeArrayData.fromPrimitiveArray(toShortArray()); + return setNullBits(UnsafeArrayData.fromPrimitiveArray(toShortArray())); } else if (dt instanceof IntegerType || dt instanceof DateType || dt instanceof YearMonthIntervalType) { - return UnsafeArrayData.fromPrimitiveArray(toIntArray()); + return setNullBits(UnsafeArrayData.fromPrimitiveArray(toIntArray())); } else if (dt instanceof LongType || dt instanceof TimestampType || dt instanceof DayTimeIntervalType) { - return UnsafeArrayData.fromPrimitiveArray(toLongArray()); + return setNullBits(UnsafeArrayData.fromPrimitiveArray(toLongArray())); } else if (dt instanceof FloatType) { - return UnsafeArrayData.fromPrimitiveArray(toFloatArray()); + return setNullBits(UnsafeArrayData.fromPrimitiveArray(toFloatArray())); } else if (dt instanceof DoubleType) { - return UnsafeArrayData.fromPrimitiveArray(toDoubleArray()); + return setNullBits(UnsafeArrayData.fromPrimitiveArray(toDoubleArray())); } else { return new GenericArrayData(toObjectArray(dt)).copy(); // ensure the elements are copied. }