diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala
index 4e2224b058a0..baaccedd2d53 100644
--- a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala
+++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala
@@ -225,6 +225,7 @@ class AvroDeserializer(rootAvroType: Schema, rootCatalystType: DataType) {
       case (UNION, _) =>
         val allTypes = avroType.getTypes.asScala
         val nonNullTypes = allTypes.filter(_.getType != NULL)
+        val nonNullAvroType = Schema.createUnion(nonNullTypes.asJava)
         if (nonNullTypes.nonEmpty) {
           if (nonNullTypes.length == 1) {
             newWriter(nonNullTypes.head, catalystType, path)
@@ -253,7 +254,7 @@ class AvroDeserializer(rootAvroType: Schema, rootCatalystType: DataType) {
               (updater, ordinal, value) => {
                 val row = new SpecificInternalRow(st)
                 val fieldUpdater = new RowUpdater(row)
-                val i = GenericData.get().resolveUnion(avroType, value)
+                val i = GenericData.get().resolveUnion(nonNullAvroType, value)
                 fieldWriters(i)(fieldUpdater, i, value)
                 updater.set(ordinal, row)
               }
diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala
index 81a5cb7cd31b..b3f5248bae4e 100644
--- a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala
+++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala
@@ -247,6 +247,32 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
     }
   }
 
+  test("SPARK-27858 Union type: More than one non-null type") {
+    withTempDir { dir =>
+      val complexNullUnionType = Schema.createUnion(
+        List(Schema.create(Type.INT), Schema.create(Type.NULL), Schema.create(Type.STRING)).asJava)
+      val fields = Seq(
+        new Field("field1", complexNullUnionType, "doc", null.asInstanceOf[AnyVal])).asJava
+      val schema = Schema.createRecord("name", "docs", "namespace", false)
+      schema.setFields(fields)
+      val datumWriter = new GenericDatumWriter[GenericRecord](schema)
+      val dataFileWriter = new DataFileWriter[GenericRecord](datumWriter)
+      dataFileWriter.create(schema, new File(s"$dir.avro"))
+      val avroRec = new GenericData.Record(schema)
+      avroRec.put("field1", 42)
+      dataFileWriter.append(avroRec)
+      val avroRec2 = new GenericData.Record(schema)
+      avroRec2.put("field1", "Alice")
+      dataFileWriter.append(avroRec2)
+      dataFileWriter.flush()
+      dataFileWriter.close()
+
+      val df = spark.read.format("avro").load(s"$dir.avro")
+      assert(df.schema === StructType.fromDDL("field1 struct<member0: int, member1: string>"))
+      assert(df.collect().toSet == Set(Row(Row(42, null)), Row(Row(null, "Alice"))))
+    }
+  }
+
   test("Complex Union Type") {
     withTempPath { dir =>
       val fixedSchema = Schema.createFixed("fixed_name", "doc", "namespace", 4)
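
The off-by-one the first file fixes can be seen in isolation. Below is a minimal standalone sketch (not part of the patch; the object name UnionIndexSketch and the sample value are illustrative) showing why resolving against the full union mis-indexes the writers: fieldWriters is built only over the non-null branches, while the full union still contains NULL, so for a ["int", "null", "string"] union a string value resolves to branch 2 even though only writers 0 and 1 exist.

    // Minimal sketch, assuming only Avro on the classpath; not part of the patch.
    import org.apache.avro.Schema
    import org.apache.avro.Schema.Type
    import org.apache.avro.generic.GenericData

    import scala.collection.JavaConverters._

    object UnionIndexSketch {
      def main(args: Array[String]): Unit = {
        // Same union shape as the regression test: ["int", "null", "string"].
        val fullUnion = Schema.createUnion(
          List(Schema.create(Type.INT), Schema.create(Type.NULL), Schema.create(Type.STRING)).asJava)
        // Mirrors the new nonNullAvroType: the deserializer builds its
        // fieldWriters over the non-null branches only.
        val nonNullUnion = Schema.createUnion(
          fullUnion.getTypes.asScala.filter(_.getType != Type.NULL).asJava)

        val stringValue: AnyRef = "Alice"
        // Resolves to branch 2 of [int, null, string]; with only two field
        // writers, fieldWriters(2) overruns the array.
        println(GenericData.get().resolveUnion(fullUnion, stringValue))    // prints 2
        // Resolves to branch 1 of [int, string], matching the writers' layout.
        println(GenericData.get().resolveUnion(nonNullUnion, stringValue)) // prints 1
      }
    }

This is why the second hunk switches resolveUnion from avroType to nonNullAvroType, and why the new test appends both an Int and a String record to the same union field.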