114 changes: 53 additions & 61 deletions external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala
@@ -31,32 +31,24 @@ import org.apache.avro.generic.{GenericData, GenericDatumWriter, GenericRecord}
import org.apache.avro.generic.GenericData.{EnumSymbol, Fixed}
import org.apache.commons.io.FileUtils

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql._
import org.apache.spark.sql.avro.SchemaConverters.IncompatibleSchemaException
import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils}
import org.apache.spark.sql.types._

class AvroSuite extends SparkFunSuite {
class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
val episodesFile = "src/test/resources/episodes.avro"
val testFile = "src/test/resources/test.avro"

private var spark: SparkSession = _

override protected def beforeAll(): Unit = {
super.beforeAll()
spark = SparkSession.builder()
.master("local[2]")
.appName("AvroSuite")
.config("spark.sql.files.maxPartitionBytes", 1024)
.getOrCreate()
}

override protected def afterAll(): Unit = {
try {
spark.sparkContext.stop()
} finally {
super.afterAll()
}
spark.conf.set("spark.sql.files.maxPartitionBytes", 1024)
}

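// Round-trip check: reloads both files and verifies the re-saved copy holds the same rows as the original.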
def checkReloadMatchesSaved(originalFile: String, newFile: String): Unit = {
val originalEntries = spark.read.avro(originalFile).collect()
val newEntries = spark.read.avro(newFile)
checkAnswer(newEntries, originalEntries)
}

test("reading from multiple paths") {
@@ -68,7 +60,7 @@ class AvroSuite extends SparkFunSuite {
val df = spark.read.avro(episodesFile)
val fields = List("title", "air_date", "doctor")
for (field <- fields) {
TestUtils.withTempDir { dir =>
withTempPath { dir =>
val outputDir = s"$dir/${UUID.randomUUID}"
df.write.partitionBy(field).avro(outputDir)
val input = spark.read.avro(outputDir)
@@ -82,28 +74,29 @@ class AvroSuite extends SparkFunSuite {

test("request no fields") {
val df = spark.read.avro(episodesFile)
df.registerTempTable("avro_table")
df.createOrReplaceTempView("avro_table")
assert(spark.sql("select count(*) from avro_table").collect().head === Row(8))
}

test("convert formats") {
TestUtils.withTempDir { dir =>
withTempPath { dir =>
val df = spark.read.avro(episodesFile)
df.write.parquet(dir.getCanonicalPath)
assert(spark.read.parquet(dir.getCanonicalPath).count() === df.count)
}
}

test("rearrange internal schema") {
TestUtils.withTempDir { dir =>
withTempPath { dir =>
val df = spark.read.avro(episodesFile)
df.select("doctor", "title").write.avro(dir.getCanonicalPath)
}
}

test("test NULL avro type") {
TestUtils.withTempDir { dir =>
val fields = Seq(new Field("null", Schema.create(Type.NULL), "doc", null)).asJava
withTempPath { dir =>
val fields =
Seq(new Field("null", Schema.create(Type.NULL), "doc", null.asInstanceOf[Any])).asJava
val schema = Schema.createRecord("name", "docs", "namespace", false)
schema.setFields(fields)
val datumWriter = new GenericDatumWriter[GenericRecord](schema)
@@ -122,11 +115,11 @@ class AvroSuite extends SparkFunSuite {
}

test("union(int, long) is read as long") {
TestUtils.withTempDir { dir =>
withTempPath { dir =>
val avroSchema: Schema = {
val union =
Schema.createUnion(List(Schema.create(Type.INT), Schema.create(Type.LONG)).asJava)
val fields = Seq(new Field("field1", union, "doc", null)).asJava
val fields = Seq(new Field("field1", union, "doc", null.asInstanceOf[Any])).asJava
val schema = Schema.createRecord("name", "docs", "namespace", false)
schema.setFields(fields)
schema
@@ -150,11 +143,11 @@ class AvroSuite extends SparkFunSuite {
}

test("union(float, double) is read as double") {
TestUtils.withTempDir { dir =>
withTempPath { dir =>
val avroSchema: Schema = {
val union =
Schema.createUnion(List(Schema.create(Type.FLOAT), Schema.create(Type.DOUBLE)).asJava)
val fields = Seq(new Field("field1", union, "doc", null)).asJava
val fields = Seq(new Field("field1", union, "doc", null.asInstanceOf[Any])).asJava
val schema = Schema.createRecord("name", "docs", "namespace", false)
schema.setFields(fields)
schema
@@ -178,15 +171,15 @@ class AvroSuite extends SparkFunSuite {
}

test("union(float, double, null) is read as nullable double") {
TestUtils.withTempDir { dir =>
withTempPath { dir =>
val avroSchema: Schema = {
val union = Schema.createUnion(
List(Schema.create(Type.FLOAT),
Schema.create(Type.DOUBLE),
Schema.create(Type.NULL)
).asJava
)
val fields = Seq(new Field("field1", union, "doc", null)).asJava
val fields = Seq(new Field("field1", union, "doc", null.asInstanceOf[Any])).asJava
val schema = Schema.createRecord("name", "docs", "namespace", false)
schema.setFields(fields)
schema
@@ -210,9 +203,9 @@ class AvroSuite extends SparkFunSuite {
}

test("Union of a single type") {
TestUtils.withTempDir { dir =>
withTempPath { dir =>
val UnionOfOne = Schema.createUnion(List(Schema.create(Type.INT)).asJava)
val fields = Seq(new Field("field1", UnionOfOne, "doc", null)).asJava
val fields = Seq(new Field("field1", UnionOfOne, "doc", null.asInstanceOf[Any])).asJava
val schema = Schema.createRecord("name", "docs", "namespace", false)
schema.setFields(fields)

@@ -233,16 +226,16 @@ class AvroSuite extends SparkFunSuite {
}

test("Complex Union Type") {
TestUtils.withTempDir { dir =>
withTempPath { dir =>
val fixedSchema = Schema.createFixed("fixed_name", "doc", "namespace", 4)
val enumSchema = Schema.createEnum("enum_name", "doc", "namespace", List("e1", "e2").asJava)
val complexUnionType = Schema.createUnion(
List(Schema.create(Type.INT), Schema.create(Type.STRING), fixedSchema, enumSchema).asJava)
val fields = Seq(
new Field("field1", complexUnionType, "doc", null),
new Field("field2", complexUnionType, "doc", null),
new Field("field3", complexUnionType, "doc", null),
new Field("field4", complexUnionType, "doc", null)
new Field("field1", complexUnionType, "doc", null.asInstanceOf[Any]),
new Field("field2", complexUnionType, "doc", null.asInstanceOf[Any]),
new Field("field3", complexUnionType, "doc", null.asInstanceOf[Any]),
new Field("field4", complexUnionType, "doc", null.asInstanceOf[Any])
).asJava
val schema = Schema.createRecord("name", "docs", "namespace", false)
schema.setFields(fields)
@@ -271,7 +264,7 @@ class AvroSuite extends SparkFunSuite {
}

test("Lots of nulls") {
TestUtils.withTempDir { dir =>
withTempPath { dir =>
val schema = StructType(Seq(
StructField("binary", BinaryType, true),
StructField("timestamp", TimestampType, true),
@@ -290,7 +283,7 @@ class AvroSuite extends SparkFunSuite {
}

test("Struct field type") {
TestUtils.withTempDir { dir =>
withTempPath { dir =>
val schema = StructType(Seq(
StructField("float", FloatType, true),
StructField("short", ShortType, true),
@@ -309,7 +302,7 @@ class AvroSuite extends SparkFunSuite {
}

test("Date field type") {
TestUtils.withTempDir { dir =>
withTempPath { dir =>
val schema = StructType(Seq(
StructField("float", FloatType, true),
StructField("date", DateType, true)
@@ -329,7 +322,7 @@ class AvroSuite extends SparkFunSuite {
}

test("Array data types") {
TestUtils.withTempDir { dir =>
withTempPath { dir =>
val testSchema = StructType(Seq(
StructField("byte_array", ArrayType(ByteType), true),
StructField("short_array", ArrayType(ShortType), true),
@@ -363,13 +356,12 @@ class AvroSuite extends SparkFunSuite {
}

test("write with compression") {
TestUtils.withTempDir { dir =>
withTempPath { dir =>
val AVRO_COMPRESSION_CODEC = "spark.sql.avro.compression.codec"
val AVRO_DEFLATE_LEVEL = "spark.sql.avro.deflate.level"
val uncompressDir = s"$dir/uncompress"
val deflateDir = s"$dir/deflate"
val snappyDir = s"$dir/snappy"
val fakeDir = s"$dir/fake"

val df = spark.read.avro(testFile)
spark.conf.set(AVRO_COMPRESSION_CODEC, "uncompressed")
@@ -439,7 +431,7 @@ class AvroSuite extends SparkFunSuite {
test("sql test") {
spark.sql(
s"""
|CREATE TEMPORARY TABLE avroTable
|CREATE TEMPORARY VIEW avroTable
|USING avro
|OPTIONS (path "$episodesFile")
""".stripMargin.replaceAll("\n", " "))
@@ -450,24 +442,24 @@ class AvroSuite extends SparkFunSuite {
test("conversion to avro and back") {
// Note that test.avro includes a variety of types, some of which are nullable. We expect to
// get the same values back.
TestUtils.withTempDir { dir =>
withTempPath { dir =>
val avroDir = s"$dir/avro"
spark.read.avro(testFile).write.avro(avroDir)
TestUtils.checkReloadMatchesSaved(spark, testFile, avroDir)
checkReloadMatchesSaved(testFile, avroDir)
}
}

test("conversion to avro and back with namespace") {
// Note that test.avro includes a variety of types, some of which are nullable. We expect to
// get the same values back.
TestUtils.withTempDir { tempDir =>
withTempPath { tempDir =>
val name = "AvroTest"
val namespace = "com.databricks.spark.avro"
val parameters = Map("recordName" -> name, "recordNamespace" -> namespace)

val avroDir = tempDir + "/namedAvro"
spark.read.avro(testFile).write.options(parameters).avro(avroDir)
TestUtils.checkReloadMatchesSaved(spark, testFile, avroDir)
checkReloadMatchesSaved(testFile, avroDir)

// Look at the raw file and make sure it has the namespace info
val rawSaved = spark.sparkContext.textFile(avroDir)
@@ -478,7 +470,7 @@ class AvroSuite extends SparkFunSuite {
}

test("converting some specific sparkSQL types to avro") {
TestUtils.withTempDir { tempDir =>
withTempPath { tempDir =>
val testSchema = StructType(Seq(
StructField("Name", StringType, false),
StructField("Length", IntegerType, true),
@@ -520,7 +512,7 @@ class AvroSuite extends SparkFunSuite {
}

test("correctly read long as date/timestamp type") {
TestUtils.withTempDir { tempDir =>
withTempPath { tempDir =>
val sparkSession = spark
import sparkSession.implicits._

Expand Down Expand Up @@ -549,7 +541,7 @@ class AvroSuite extends SparkFunSuite {
}

test("does not coerce null date/timestamp value to 0 epoch.") {
TestUtils.withTempDir { tempDir =>
withTempPath { tempDir =>
val sparkSession = spark
import sparkSession.implicits._

@@ -610,7 +602,7 @@ class AvroSuite extends SparkFunSuite {

// Directory given has no avro files
intercept[AnalysisException] {
TestUtils.withTempDir(dir => spark.read.avro(dir.getCanonicalPath))
withTempPath(dir => spark.read.avro(dir.getCanonicalPath))
}

intercept[AnalysisException] {
@@ -624,7 +616,7 @@ class AvroSuite extends SparkFunSuite {
}

intercept[FileNotFoundException] {
TestUtils.withTempDir { dir =>
withTempPath { dir =>
FileUtils.touch(new File(dir, "test"))
spark.read.avro(dir.toString)
}
@@ -633,19 +625,19 @@ class AvroSuite extends SparkFunSuite {
}

test("SQL test insert overwrite") {
TestUtils.withTempDir { tempDir =>
withTempPath { tempDir =>
val tempEmptyDir = s"$tempDir/sqlOverwrite"
// Create a temp directory for the table that will be overwritten
new File(tempEmptyDir).mkdirs()
spark.sql(
s"""
|CREATE TEMPORARY TABLE episodes
|CREATE TEMPORARY VIEW episodes
|USING avro
|OPTIONS (path "$episodesFile")
""".stripMargin.replaceAll("\n", " "))
spark.sql(
s"""
|CREATE TEMPORARY TABLE episodesEmpty
|CREATE TEMPORARY VIEW episodesEmpty
|(name string, air_date string, doctor int)
|USING avro
|OPTIONS (path "$tempEmptyDir")
@@ -665,7 +657,7 @@ class AvroSuite extends SparkFunSuite {

test("test save and load") {
// Test if load works as expected
TestUtils.withTempDir { tempDir =>
withTempPath { tempDir =>
val df = spark.read.avro(episodesFile)
assert(df.count == 8)

Expand All @@ -679,7 +671,7 @@ class AvroSuite extends SparkFunSuite {

test("test load with non-Avro file") {
// Test if load works as expected
TestUtils.withTempDir { tempDir =>
withTempPath { tempDir =>
val df = spark.read.avro(episodesFile)
assert(df.count == 8)

@@ -737,7 +729,7 @@ class AvroSuite extends SparkFunSuite {
}

test("read avro file partitioned") {
TestUtils.withTempDir { dir =>
withTempPath { dir =>
val sparkSession = spark
import sparkSession.implicits._
val df = (0 to 1024 * 3).toDS.map(i => s"record${i}").toDF("records")
@@ -756,7 +748,7 @@ class AvroSuite extends SparkFunSuite {
case class NestedTop(id: Int, data: NestedMiddle)

test("saving avro that has nested records with the same name") {
TestUtils.withTempDir { tempDir =>
withTempPath { tempDir =>
// Save the Avro file to the output folder path
val writeDf = spark.createDataFrame(List(NestedTop(1, NestedMiddle(2, NestedBottom(3, "1")))))
val outputFolder = s"$tempDir/duplicate_names/"
@@ -773,7 +765,7 @@ class AvroSuite extends SparkFunSuite {
case class NestedTopArray(id: Int, data: NestedMiddleArray)

test("saving avro that has nested records with the same name inside an array") {
TestUtils.withTempDir { tempDir =>
withTempPath { tempDir =>
// Save the Avro file to the output folder path
val writeDf = spark.createDataFrame(
List(NestedTopArray(1, NestedMiddleArray(2, Array(
@@ -794,7 +786,7 @@ class AvroSuite extends SparkFunSuite {
case class NestedTopMap(id: Int, data: NestedMiddleMap)

test("saving avro that has nested records with the same name inside a map") {
TestUtils.withTempDir { tempDir =>
withTempPath { tempDir =>
// Save the Avro file to the output folder path
val writeDf = spark.createDataFrame(
List(NestedTopMap(1, NestedMiddleMap(2, Map(