diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java
index 8c0f1e125750..97a685a25a80 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java
@@ -611,6 +611,9 @@ public final int appendByteArray(byte[] value, int offset, int length) {
 
   public final int appendArray(int length) {
     reserve(elementsAppended + 1);
+    for (WritableColumnVector childColumn : childColumns) {
+      childColumn.reserve(childColumn.elementsAppended + length);
+    }
     putArray(elementsAppended, arrayData().elementsAppended, length);
     return elementsAppended++;
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Columnar.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Columnar.scala
index 0920f823d845..a45276573589 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Columnar.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Columnar.scala
@@ -264,12 +264,12 @@ private object RowToColumnConverter {
     case DoubleType => DoubleConverter
     case StringType => StringConverter
     case CalendarIntervalType => CalendarConverter
-    case at: ArrayType => new ArrayConverter(getConverterForType(at.elementType, nullable))
+    case at: ArrayType => ArrayConverter(getConverterForType(at.elementType, at.containsNull))
     case st: StructType => new StructConverter(st.fields.map(
       (f) => getConverterForType(f.dataType, f.nullable)))
     case dt: DecimalType => new DecimalConverter(dt)
-    case mt: MapType => new MapConverter(getConverterForType(mt.keyType, nullable),
-      getConverterForType(mt.valueType, nullable))
+    case mt: MapType => MapConverter(getConverterForType(mt.keyType, nullable = false),
+      getConverterForType(mt.valueType, mt.valueContainsNull))
     case unknown => throw new UnsupportedOperationException(
       s"Type $unknown not supported")
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowToColumnConverterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowToColumnConverterSuite.scala
new file mode 100644
index 000000000000..1afe742b988e
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowToColumnConverterSuite.scala
@@ -0,0 +1,145 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData}
+import org.apache.spark.sql.execution.vectorized.{OnHeapColumnVector, WritableColumnVector}
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
+
+class RowToColumnConverterSuite extends SparkFunSuite {
+  def convertRows(rows: Seq[InternalRow], schema: StructType): Seq[WritableColumnVector] = {
+    val converter = new RowToColumnConverter(schema)
+    val vectors =
+      schema.map(f => new OnHeapColumnVector(5, f.dataType)).toArray[WritableColumnVector]
+    for (row <- rows) {
+      converter.convert(row, vectors)
+    }
+    vectors
+  }
+
+  test("integer column") {
+    val schema = StructType(Seq(StructField("i", IntegerType)))
+    val rows = (0 until 100).map(i => InternalRow(i))
+    val vectors = convertRows(rows, schema)
+    rows.zipWithIndex.map { case (row, i) =>
+      assert(vectors.head.getInt(i) === row.getInt(0))
+    }
+  }
+
+  test("array column") {
+    val schema = StructType(Seq(StructField("a", ArrayType(IntegerType))))
+    val rows = (0 until 100).map { i =>
+      InternalRow(new GenericArrayData(0 until i))
+    }
+    val vectors = convertRows(rows, schema)
+    rows.zipWithIndex.map { case (row, i) =>
+      assert(vectors.head.getArray(i).array().array === row.getArray(0).array)
+    }
+  }
+
+  test("non-nullable array column with null elements") {
+    val arrayType = ArrayType(IntegerType, containsNull = true)
+    val schema = StructType(Seq(StructField("a", arrayType, nullable = false)))
+    val rows = (0 until 100).map { i =>
+      InternalRow(new GenericArrayData((0 until i).map { j =>
+        if (j % 3 == 0) {
+          null
+        } else {
+          j
+        }
+      }))
+    }
+    val vectors = convertRows(rows, schema)
+    rows.zipWithIndex.map { case (row, i) =>
+      assert(vectors.head.getArray(i).array().array === row.getArray(0).array)
+    }
+  }
+
+  test("nested array column") {
+    val arrayType = ArrayType(ArrayType(IntegerType))
+    val schema = StructType(Seq(StructField("a", arrayType)))
+    val rows = (0 until 100).map { i =>
+      InternalRow(new GenericArrayData((0 until i).map(j => new GenericArrayData(0 until j))))
+    }
+    val vectors = convertRows(rows, schema)
+    rows.zipWithIndex.map { case (row, i) =>
+      val result = vectors.head.getArray(i).array().array
+        .map(_.asInstanceOf[ArrayData].array)
+      val expected = row.getArray(0).array
+        .map(_.asInstanceOf[ArrayData].array)
+      assert(result === expected)
+    }
+  }
+
+  test("map column") {
+    val mapType = MapType(IntegerType, StringType)
+    val schema = StructType(Seq(StructField("m", mapType)))
+    val rows = (0 until 100).map { i =>
+      InternalRow(new ArrayBasedMapData(
+        new GenericArrayData(0 until i),
+        new GenericArrayData((0 until i).map(j => UTF8String.fromString(s"str$j")))))
+    }
+    val vectors = convertRows(rows, schema)
+    rows.zipWithIndex.map { case (row, i) =>
+      val result = vectors.head.getMap(i)
+      val expected = row.getMap(0)
+      assert(result.keyArray().array().array === expected.keyArray().array)
+      assert(result.valueArray().array().array === expected.valueArray().array)
+    }
+  }
+
+  test("non-nullable map column with null values") {
+    val mapType = MapType(IntegerType, StringType, valueContainsNull = true)
+    val schema = StructType(Seq(StructField("m", mapType, nullable = false)))
+    val rows = (0 until 100).map { i =>
+      InternalRow(new ArrayBasedMapData(
+        new GenericArrayData(0 until i),
+        new GenericArrayData((0 until i).map { j =>
+          if (j % 3 == 0) {
+            null
+          } else {
+            UTF8String.fromString(s"str$j")
+          }
+        })))
+    }
+    val vectors = convertRows(rows, schema)
+    rows.zipWithIndex.map { case (row, i) =>
+      val result = vectors.head.getMap(i)
+      val expected = row.getMap(0)
+      assert(result.keyArray().array().array === expected.keyArray().array)
+      assert(result.valueArray().array().array === expected.valueArray().array)
+    }
+  }
+
+  test("multiple columns") {
+    val schema = StructType(
+      Seq(StructField("s", ShortType), StructField("i", IntegerType), StructField("l", LongType)))
+    val rows = (0 until 100).map { i =>
+      InternalRow((3 * i).toShort, 3 * i + 1, (3 * i + 2).toLong)
+    }
+    val vectors = convertRows(rows, schema)
+    rows.zipWithIndex.map { case (row, i) =>
+      assert(vectors(0).getShort(i) === row.getShort(0))
+      assert(vectors(1).getInt(i) === row.getInt(1))
+      assert(vectors(2).getLong(i) === row.getLong(2))
+    }
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala
index 247efd5554a8..43f48abb9734 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala
@@ -243,6 +243,93 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach {
     assert(testVector.getArray(3).toIntArray() === Array(3, 4, 5))
   }
 
+  testVectors("SPARK-35898: array append", 1, arrayType) { testVector =>
+    // Populate it with arrays [0], [1, 2], [], [3, 4, 5]
+    val data = testVector.arrayData()
+    testVector.appendArray(1)
+    data.appendInt(0)
+    testVector.appendArray(2)
+    data.appendInt(1)
+    data.appendInt(2)
+    testVector.appendArray(0)
+    testVector.appendArray(3)
+    data.appendInt(3)
+    data.appendInt(4)
+    data.appendInt(5)
+
+    assert(testVector.getArray(0).toIntArray === Array(0))
+    assert(testVector.getArray(1).toIntArray === Array(1, 2))
+    assert(testVector.getArray(2).toIntArray === Array.empty[Int])
+    assert(testVector.getArray(3).toIntArray === Array(3, 4, 5))
+  }
+
+  val mapType: MapType = MapType(IntegerType, StringType)
+  testVectors("SPARK-35898: map", 5, mapType) { testVector =>
+    val keys = testVector.getChild(0)
+    val values = testVector.getChild(1)
+    var i = 0
+    while (i < 6) {
+      keys.appendInt(i)
+      val utf8 = s"str$i".getBytes("utf8")
+      values.appendByteArray(utf8, 0, utf8.length)
+      i += 1
+    }
+
+    testVector.putArray(0, 0, 1)
+    testVector.putArray(1, 1, 2)
+    testVector.putArray(2, 3, 0)
+    testVector.putArray(3, 3, 3)
+
+    assert(testVector.getMap(0).keyArray().toIntArray === Array(0))
+    assert(testVector.getMap(0).valueArray().toArray[UTF8String](StringType) ===
+      Array(UTF8String.fromString(s"str0")))
+    assert(testVector.getMap(1).keyArray().toIntArray === Array(1, 2))
+    assert(testVector.getMap(1).valueArray().toArray[UTF8String](StringType) ===
+      (1 to 2).map(i => UTF8String.fromString(s"str$i")).toArray)
+    assert(testVector.getMap(2).keyArray().toIntArray === Array.empty[Int])
+    assert(testVector.getMap(2).valueArray().toArray[UTF8String](StringType) ===
+      Array.empty[UTF8String])
+    assert(testVector.getMap(3).keyArray().toIntArray === Array(3, 4, 5))
+    assert(testVector.getMap(3).valueArray().toArray[UTF8String](StringType) ===
+      (3 to 5).map(i => UTF8String.fromString(s"str$i")).toArray)
+  }
+
+  testVectors("SPARK-35898: map append", 1, mapType) { testVector =>
+    val keys = testVector.getChild(0)
+    val values = testVector.getChild(1)
+    def appendPair(i: Int): Unit = {
+      keys.appendInt(i)
+      val utf8 = s"str$i".getBytes("utf8")
+      values.appendByteArray(utf8, 0, utf8.length)
+    }
+
+    // Populate it with the maps [0 -> str0], [1 -> str1, 2 -> str2], [],
+    // [3 -> str3, 4 -> str4, 5 -> str5]
+    testVector.appendArray(1)
+    appendPair(0)
+    testVector.appendArray(2)
+    appendPair(1)
+    appendPair(2)
+    testVector.appendArray(0)
+    testVector.appendArray(3)
+    appendPair(3)
+    appendPair(4)
+    appendPair(5)
+
+    assert(testVector.getMap(0).keyArray().toIntArray === Array(0))
+    assert(testVector.getMap(0).valueArray().toArray[UTF8String](StringType) ===
+      Array(UTF8String.fromString(s"str0")))
+    assert(testVector.getMap(1).keyArray().toIntArray === Array(1, 2))
+    assert(testVector.getMap(1).valueArray().toArray[UTF8String](StringType) ===
+      (1 to 2).map(i => UTF8String.fromString(s"str$i")).toArray)
+    assert(testVector.getMap(2).keyArray().toIntArray === Array.empty[Int])
+    assert(testVector.getMap(2).valueArray().toArray[UTF8String](StringType) ===
+      Array.empty[UTF8String])
+    assert(testVector.getMap(3).keyArray().toIntArray === Array(3, 4, 5))
+    assert(testVector.getMap(3).valueArray().toArray[UTF8String](StringType) ===
+      (3 to 5).map(i => UTF8String.fromString(s"str$i")).toArray)
+  }
+
   val structType: StructType = new StructType().add("int", IntegerType).add("double", DoubleType)
   testVectors("struct", 10, structType) { testVector =>
     val c1 = testVector.getChild(0)
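
A minimal standalone sketch of the behavior this patch establishes, not part of the patch itself: with the new loop in WritableColumnVector.appendArray, each call reserves room for the appended elements in every child column before putArray records the offset and length. The object name AppendArraySketch is hypothetical; it only uses the OnHeapColumnVector API already exercised by the tests above.

// Sketch only, assuming the public OnHeapColumnVector/WritableColumnVector API
// used in the tests above; AppendArraySketch is a hypothetical demo object.
import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector
import org.apache.spark.sql.types.{ArrayType, IntegerType}

object AppendArraySketch {
  def main(args: Array[String]): Unit = {
    // Capacity 1 forces the vector, and with this patch its child column,
    // to grow while arrays are appended.
    val vector = new OnHeapColumnVector(1, ArrayType(IntegerType))
    val data = vector.arrayData()

    // Append the arrays [0, 1, 2] and [3, 4]. appendArray(n) now reserves
    // space for n more elements in every child column before putArray runs.
    vector.appendArray(3)
    Seq(0, 1, 2).foreach(data.appendInt)
    vector.appendArray(2)
    Seq(3, 4).foreach(data.appendInt)

    assert(vector.getArray(0).toIntArray.sameElements(Array(0, 1, 2)))
    assert(vector.getArray(1).toIntArray.sameElements(Array(3, 4)))
    vector.close()
  }
}

The same reservation matters for maps, since a map vector stores its keys and values in two child vectors addressed by the same offsets, which is what the "SPARK-35898: map append" test above verifies.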