@@ -45,7 +45,7 @@ import org.apache.spark.sql.util.ArrowUtils
 import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnarBatchRow, ColumnVector}
 import org.apache.spark.tags.ExtendedSQLTest
 import org.apache.spark.unsafe.Platform
-import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
+import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String, VariantVal}
 import org.apache.spark.util.ArrayImplicits._
 
 @ExtendedSQLTest
@@ -1650,6 +1650,7 @@ class ColumnarBatchSuite extends SparkFunSuite {
       StructField("int_to_int", MapType(IntegerType, IntegerType)) ::
       StructField("binary", BinaryType) ::
       StructField("ts_ntz", TimestampNTZType) ::
+      StructField("variant", VariantType) ::
       Nil)
     var mapBuilder = new ArrayBasedMapBuilder(IntegerType, IntegerType)
     mapBuilder.put(1, 10)
@@ -1664,6 +1665,9 @@ class ColumnarBatchSuite extends SparkFunSuite {
     val tsNTZ2 =
       DateTimeUtils.localDateTimeToMicros(LocalDateTime.parse(tsString2.replace(" ", "T")))
 
+    val variantVal1 = new VariantVal(Array[Byte](1, 2, 3), Array[Byte](4, 5))
+    val variantVal2 = new VariantVal(Array[Byte](6), Array[Byte](7, 8))
+
     val row1 = new GenericInternalRow(Array[Any](
       UTF8String.fromString("a string"),
       true,
@@ -1681,7 +1685,8 @@ class ColumnarBatchSuite extends SparkFunSuite {
       new GenericInternalRow(Array[Any](5.asInstanceOf[Any], 10)),
       mapBuilder.build(),
       "Spark SQL".getBytes(),
-      tsNTZ1
+      tsNTZ1,
+      variantVal1
     ))
 
     mapBuilder = new ArrayBasedMapBuilder(IntegerType, IntegerType)
@@ -1704,7 +1709,8 @@ class ColumnarBatchSuite extends SparkFunSuite {
       new GenericInternalRow(Array[Any](20.asInstanceOf[Any], null)),
       mapBuilder.build(),
       "Parquet".getBytes(),
-      tsNTZ2
+      tsNTZ2,
+      variantVal2
     ))
 
     val row3 = new GenericInternalRow(Array[Any](
@@ -1724,6 +1730,7 @@ class ColumnarBatchSuite extends SparkFunSuite {
       null,
       null,
       null,
+      null,
       null
     ))
 
@@ -1852,6 +1859,13 @@ class ColumnarBatchSuite extends SparkFunSuite {
       assert(columns(16).getLong(0) == tsNTZ1)
       assert(columns(16).getLong(1) == tsNTZ2)
       assert(columns(16).isNullAt(2))
+
+      assert(columns(17).dataType() == VariantType)
+      assert(columns(17).getVariant(0).debugString() == variantVal1.debugString())
+      assert(columns(17).getVariant(1).debugString() == variantVal2.debugString())
+      assert(columns(17).isNullAt(2))
+      assert(columns(17).getChild(0).isNullAt(2))
+      assert(columns(17).getChild(1).isNullAt(2))
     } finally {
       batch.close()
     }
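
A note on what the new assertions exercise: a VariantVal pairs two byte arrays, the encoded value and its metadata, and a VariantType column vector is represented as two binary children (child 0 holding the value bytes, child 1 the metadata bytes), which is why the test checks getChild(0) and getChild(1) alongside the top-level null. Below is a minimal sketch of that round trip, assuming the writable OnHeapColumnVector allocates those two children for VariantType; it is an illustration, not part of the patch:

    import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector
    import org.apache.spark.sql.types.VariantType
    import org.apache.spark.unsafe.types.VariantVal

    // Sketch (assumption: a VariantType vector carries two binary children,
    // child 0 = encoded value bytes, child 1 = metadata bytes).
    val vector = new OnHeapColumnVector(1, VariantType)
    val v = new VariantVal(Array[Byte](1, 2, 3), Array[Byte](4, 5))

    // Write both byte arrays into the child vectors for row 0.
    vector.getChild(0).putByteArray(0, v.getValue)
    vector.getChild(1).putByteArray(0, v.getMetadata)

    // getVariant(rowId) reassembles a VariantVal from the two children;
    // debugString() gives a stable textual form for comparison.
    assert(vector.getVariant(0).debugString() == v.debugString())
    vector.close()

Under this layout a null variant shows up as a null at the top level with both children null at the same row, which is exactly what the isNullAt(2) checks on columns(17) and its children verify.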