Skip to content

Commit 7a573b9

Browse files
gene-dbcloud-fan
authored andcommitted
[SPARK-48019] Fix incorrect behavior in ColumnVector/ColumnarArray with dictionary and nulls
This fixes how `ColumnVector` handles copying arrays when the vector has a dictionary and null values. The possible issues with the previous implementation: - An `ArrayIndexOutOfBoundsException` may be thrown when the `ColumnVector` has nulls and dictionaries. This is because the dictionary id for `null` entries might be invalid and should not be used for `null` entries. - Copying a `ColumnarArray` (which contains a `ColumnVector`) is incorrect, if it contains `null` entries. This is because copying a primitive array does not take into account the `null` entries, so all the null entries get lost. These changes are needed to avoid `ArrayIndexOutOfBoundsException` and to produce correct results when copying `ColumnarArray`. The only user facing changes are to fix existing errors and incorrect results. Added new unit tests. No. Closes #46254 from gene-db/dictionary-nulls. Authored-by: Gene Pang <[email protected]> Signed-off-by: Wenchen Fan <[email protected]> (cherry picked from commit 76ce6b0) Signed-off-by: Wenchen Fan <[email protected]>
1 parent fdc0cee commit 7a573b9

File tree

4 files changed

+215
-12
lines changed

4 files changed

+215
-12
lines changed

sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,11 @@ public int numElements() {
5151
public ArrayData copy() {
5252
DataType dt = data.dataType();
5353

54+
if (data.hasNull()) {
55+
// UnsafeArrayData cannot be used if there are any nulls.
56+
return new GenericArrayData(toObjectArray(dt)).copy();
57+
}
58+
5459
if (dt instanceof BooleanType) {
5560
return UnsafeArrayData.fromPrimitiveArray(toBooleanArray());
5661
} else if (dt instanceof ByteType) {

sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,9 @@ public byte[] getBytes(int rowId, int count) {
215215
Platform.copyMemory(null, data + rowId, array, Platform.BYTE_ARRAY_OFFSET, count);
216216
} else {
217217
for (int i = 0; i < count; i++) {
218-
array[i] = getByte(rowId + i);
218+
if (!isNullAt(rowId + i)) {
219+
array[i] = (byte) dictionary.decodeToInt(dictionaryIds.getDictId(rowId + i));
220+
}
219221
}
220222
}
221223
return array;
@@ -276,7 +278,9 @@ public short[] getShorts(int rowId, int count) {
276278
Platform.copyMemory(null, data + rowId * 2L, array, Platform.SHORT_ARRAY_OFFSET, count * 2L);
277279
} else {
278280
for (int i = 0; i < count; i++) {
279-
array[i] = getShort(rowId + i);
281+
if (!isNullAt(rowId + i)) {
282+
array[i] = (short) dictionary.decodeToInt(dictionaryIds.getDictId(rowId + i));
283+
}
280284
}
281285
}
282286
return array;
@@ -342,7 +346,9 @@ public int[] getInts(int rowId, int count) {
342346
Platform.copyMemory(null, data + rowId * 4L, array, Platform.INT_ARRAY_OFFSET, count * 4L);
343347
} else {
344348
for (int i = 0; i < count; i++) {
345-
array[i] = getInt(rowId + i);
349+
if (!isNullAt(rowId + i)) {
350+
array[i] = dictionary.decodeToInt(dictionaryIds.getDictId(rowId + i));
351+
}
346352
}
347353
}
348354
return array;
@@ -420,7 +426,9 @@ public long[] getLongs(int rowId, int count) {
420426
Platform.copyMemory(null, data + rowId * 8L, array, Platform.LONG_ARRAY_OFFSET, count * 8L);
421427
} else {
422428
for (int i = 0; i < count; i++) {
423-
array[i] = getLong(rowId + i);
429+
if (!isNullAt(rowId + i)) {
430+
array[i] = dictionary.decodeToLong(dictionaryIds.getDictId(rowId + i));
431+
}
424432
}
425433
}
426434
return array;
@@ -484,7 +492,9 @@ public float[] getFloats(int rowId, int count) {
484492
Platform.copyMemory(null, data + rowId * 4L, array, Platform.FLOAT_ARRAY_OFFSET, count * 4L);
485493
} else {
486494
for (int i = 0; i < count; i++) {
487-
array[i] = getFloat(rowId + i);
495+
if (!isNullAt(rowId + i)) {
496+
array[i] = dictionary.decodeToFloat(dictionaryIds.getDictId(rowId + i));
497+
}
488498
}
489499
}
490500
return array;
@@ -550,7 +560,9 @@ public double[] getDoubles(int rowId, int count) {
550560
count * 8L);
551561
} else {
552562
for (int i = 0; i < count; i++) {
553-
array[i] = getDouble(rowId + i);
563+
if (!isNullAt(rowId + i)) {
564+
array[i] = dictionary.decodeToDouble(dictionaryIds.getDictId(rowId + i));
565+
}
554566
}
555567
}
556568
return array;

sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,9 @@ public byte[] getBytes(int rowId, int count) {
213213
System.arraycopy(byteData, rowId, array, 0, count);
214214
} else {
215215
for (int i = 0; i < count; i++) {
216-
array[i] = getByte(rowId + i);
216+
if (!isNullAt(rowId + i)) {
217+
array[i] = (byte) dictionary.decodeToInt(dictionaryIds.getDictId(rowId + i));
218+
}
217219
}
218220
}
219221
return array;
@@ -273,7 +275,9 @@ public short[] getShorts(int rowId, int count) {
273275
System.arraycopy(shortData, rowId, array, 0, count);
274276
} else {
275277
for (int i = 0; i < count; i++) {
276-
array[i] = getShort(rowId + i);
278+
if (!isNullAt(rowId + i)) {
279+
array[i] = (short) dictionary.decodeToInt(dictionaryIds.getDictId(rowId + i));
280+
}
277281
}
278282
}
279283
return array;
@@ -334,7 +338,9 @@ public int[] getInts(int rowId, int count) {
334338
System.arraycopy(intData, rowId, array, 0, count);
335339
} else {
336340
for (int i = 0; i < count; i++) {
337-
array[i] = getInt(rowId + i);
341+
if (!isNullAt(rowId + i)) {
342+
array[i] = dictionary.decodeToInt(dictionaryIds.getDictId(rowId + i));
343+
}
338344
}
339345
}
340346
return array;
@@ -406,7 +412,9 @@ public long[] getLongs(int rowId, int count) {
406412
System.arraycopy(longData, rowId, array, 0, count);
407413
} else {
408414
for (int i = 0; i < count; i++) {
409-
array[i] = getLong(rowId + i);
415+
if (!isNullAt(rowId + i)) {
416+
array[i] = dictionary.decodeToLong(dictionaryIds.getDictId(rowId + i));
417+
}
410418
}
411419
}
412420
return array;
@@ -463,7 +471,9 @@ public float[] getFloats(int rowId, int count) {
463471
System.arraycopy(floatData, rowId, array, 0, count);
464472
} else {
465473
for (int i = 0; i < count; i++) {
466-
array[i] = getFloat(rowId + i);
474+
if (!isNullAt(rowId + i)) {
475+
array[i] = dictionary.decodeToFloat(dictionaryIds.getDictId(rowId + i));
476+
}
467477
}
468478
}
469479
return array;
@@ -522,7 +532,9 @@ public double[] getDoubles(int rowId, int count) {
522532
System.arraycopy(doubleData, rowId, array, 0, count);
523533
} else {
524534
for (int i = 0; i < count; i++) {
525-
array[i] = getDouble(rowId + i);
535+
if (!isNullAt(rowId + i)) {
536+
array[i] = dictionary.decodeToDouble(dictionaryIds.getDictId(rowId + i));
537+
}
526538
}
527539
}
528540
return array;

sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala

Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -473,6 +473,180 @@ class ColumnVectorSuite extends SparkFunSuite {
473473
assert(testVector.getDoubles(0, 3)(2) == 1342.17729d)
474474
}
475475

476+
def check(expected: Seq[Any], testVector: WritableColumnVector): Unit = {
477+
expected.zipWithIndex.foreach {
478+
case (v: Integer, idx) =>
479+
assert(testVector.getInt(idx) == v)
480+
assert(testVector.getInts(0, testVector.capacity)(idx) == v)
481+
case (v: Short, idx) =>
482+
assert(testVector.getShort(idx) == v)
483+
assert(testVector.getShorts(0, testVector.capacity)(idx) == v)
484+
case (v: Byte, idx) =>
485+
assert(testVector.getByte(idx) == v)
486+
assert(testVector.getBytes(0, testVector.capacity)(idx) == v)
487+
case (v: Long, idx) =>
488+
assert(testVector.getLong(idx) == v)
489+
assert(testVector.getLongs(0, testVector.capacity)(idx) == v)
490+
case (v: Float, idx) =>
491+
assert(testVector.getFloat(idx) == v)
492+
assert(testVector.getFloats(0, testVector.capacity)(idx) == v)
493+
case (v: Double, idx) =>
494+
assert(testVector.getDouble(idx) == v)
495+
assert(testVector.getDoubles(0, testVector.capacity)(idx) == v)
496+
case (null, idx) => testVector.isNullAt(idx)
497+
case (_, idx) => assert(false, s"Unexpected value at $idx")
498+
}
499+
500+
// Verify ColumnarArray.copy() works as expected
501+
val arr = new ColumnarArray(testVector, 0, testVector.capacity)
502+
assert(arr.toSeq(testVector.dataType) == expected)
503+
assert(arr.copy().toSeq(testVector.dataType) == expected)
504+
}
505+
506+
testVectors("getInts with dictionary and nulls", 3, IntegerType) { testVector =>
507+
// Validate without dictionary
508+
val expected = Seq(1, null, 3)
509+
expected.foreach {
510+
case i: Integer => testVector.appendInt(i)
511+
case _ => testVector.appendNull()
512+
}
513+
check(expected, testVector)
514+
515+
// Validate with dictionary
516+
val expectedDictionary = Seq(7, null, 9)
517+
val dictArray = (Seq(-1, -1) ++ expectedDictionary.map {
518+
case i: Integer => i.toInt
519+
case _ => -1
520+
}).toArray
521+
val dict = new ColumnDictionary(dictArray)
522+
testVector.setDictionary(dict)
523+
testVector.reserveDictionaryIds(3)
524+
testVector.getDictionaryIds.putInt(0, 2)
525+
testVector.getDictionaryIds.putInt(1, -1) // This is a null, so the entry should be ignored
526+
testVector.getDictionaryIds.putInt(2, 4)
527+
check(expectedDictionary, testVector)
528+
}
529+
530+
testVectors("getShorts with dictionary and nulls", 3, ShortType) { testVector =>
531+
// Validate without dictionary
532+
val expected = Seq(1.toShort, null, 3.toShort)
533+
expected.foreach {
534+
case i: Short => testVector.appendShort(i)
535+
case _ => testVector.appendNull()
536+
}
537+
check(expected, testVector)
538+
539+
// Validate with dictionary
540+
val expectedDictionary = Seq(7.toShort, null, 9.toShort)
541+
val dictArray = (Seq(-1, -1) ++ expectedDictionary.map {
542+
case i: Short => i.toInt
543+
case _ => -1
544+
}).toArray
545+
val dict = new ColumnDictionary(dictArray)
546+
testVector.setDictionary(dict)
547+
testVector.reserveDictionaryIds(3)
548+
testVector.getDictionaryIds.putInt(0, 2)
549+
testVector.getDictionaryIds.putInt(1, -1) // This is a null, so the entry should be ignored
550+
testVector.getDictionaryIds.putInt(2, 4)
551+
check(expectedDictionary, testVector)
552+
}
553+
554+
testVectors("getBytes with dictionary and nulls", 3, ByteType) { testVector =>
555+
// Validate without dictionary
556+
val expected = Seq(1.toByte, null, 3.toByte)
557+
expected.foreach {
558+
case i: Byte => testVector.appendByte(i)
559+
case _ => testVector.appendNull()
560+
}
561+
check(expected, testVector)
562+
563+
// Validate with dictionary
564+
val expectedDictionary = Seq(7.toByte, null, 9.toByte)
565+
val dictArray = (Seq(-1, -1) ++ expectedDictionary.map {
566+
case i: Byte => i.toInt
567+
case _ => -1
568+
}).toArray
569+
val dict = new ColumnDictionary(dictArray)
570+
testVector.setDictionary(dict)
571+
testVector.reserveDictionaryIds(3)
572+
testVector.getDictionaryIds.putInt(0, 2)
573+
testVector.getDictionaryIds.putInt(1, -1) // This is a null, so the entry should be ignored
574+
testVector.getDictionaryIds.putInt(2, 4)
575+
check(expectedDictionary, testVector)
576+
}
577+
578+
testVectors("getLongs with dictionary and nulls", 3, LongType) { testVector =>
579+
// Validate without dictionary
580+
val expected = Seq(2147483L, null, 2147485L)
581+
expected.foreach {
582+
case i: Long => testVector.appendLong(i)
583+
case _ => testVector.appendNull()
584+
}
585+
check(expected, testVector)
586+
587+
// Validate with dictionary
588+
val expectedDictionary = Seq(2147483648L, null, 2147483650L)
589+
val dictArray = (Seq(-1L, -1L) ++ expectedDictionary.map {
590+
case i: Long => i
591+
case _ => -1L
592+
}).toArray
593+
val dict = new ColumnDictionary(dictArray)
594+
testVector.setDictionary(dict)
595+
testVector.reserveDictionaryIds(3)
596+
testVector.getDictionaryIds.putInt(0, 2)
597+
testVector.getDictionaryIds.putInt(1, -1) // This is a null, so the entry should be ignored
598+
testVector.getDictionaryIds.putInt(2, 4)
599+
check(expectedDictionary, testVector)
600+
}
601+
602+
testVectors("getFloats with dictionary and nulls", 3, FloatType) { testVector =>
603+
// Validate without dictionary
604+
val expected = Seq(1.1f, null, 3.3f)
605+
expected.foreach {
606+
case i: Float => testVector.appendFloat(i)
607+
case _ => testVector.appendNull()
608+
}
609+
check(expected, testVector)
610+
611+
// Validate with dictionary
612+
val expectedDictionary = Seq(0.1f, null, 0.3f)
613+
val dictArray = (Seq(-1f, -1f) ++ expectedDictionary.map {
614+
case i: Float => i
615+
case _ => -1f
616+
}).toArray
617+
val dict = new ColumnDictionary(dictArray)
618+
testVector.setDictionary(dict)
619+
testVector.reserveDictionaryIds(3)
620+
testVector.getDictionaryIds.putInt(0, 2)
621+
testVector.getDictionaryIds.putInt(1, -1) // This is a null, so the entry should be ignored
622+
testVector.getDictionaryIds.putInt(2, 4)
623+
check(expectedDictionary, testVector)
624+
}
625+
626+
testVectors("getDoubles with dictionary and nulls", 3, DoubleType) { testVector =>
627+
// Validate without dictionary
628+
val expected = Seq(1.1d, null, 3.3d)
629+
expected.foreach {
630+
case i: Double => testVector.appendDouble(i)
631+
case _ => testVector.appendNull()
632+
}
633+
check(expected, testVector)
634+
635+
// Validate with dictionary
636+
val expectedDictionary = Seq(1342.17727d, null, 1342.17729d)
637+
val dictArray = (Seq(-1d, -1d) ++ expectedDictionary.map {
638+
case i: Double => i
639+
case _ => -1d
640+
}).toArray
641+
val dict = new ColumnDictionary(dictArray)
642+
testVector.setDictionary(dict)
643+
testVector.reserveDictionaryIds(3)
644+
testVector.getDictionaryIds.putInt(0, 2)
645+
testVector.getDictionaryIds.putInt(1, -1) // This is a null, so the entry should be ignored
646+
testVector.getDictionaryIds.putInt(2, 4)
647+
check(expectedDictionary, testVector)
648+
}
649+
476650
test("[SPARK-22092] off-heap column vector reallocation corrupts array data") {
477651
withVector(new OffHeapColumnVector(8, arrayType)) { testVector =>
478652
val data = testVector.arrayData()

0 commit comments

Comments
 (0)