@@ -0,0 +1,68 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql

import org.apache.avro.Schema
import org.apache.avro.generic.GenericDatumReader
import org.apache.avro.io.{BinaryDecoder, DecoderFactory}

import org.apache.spark.sql.avro.{AvroDeserializer, SchemaConverters}
import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, UnaryExpression}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
import org.apache.spark.sql.types.{AbstractDataType, BinaryType, DataType}

case class AvroDataToCatalyst(child: Expression, jsonFormatSchema: String)
  extends UnaryExpression with ExpectsInputTypes {

  override def inputTypes: Seq[AbstractDataType] = Seq(BinaryType)

  override lazy val dataType: DataType = SchemaConverters.toSqlType(avroSchema).dataType

  override def nullable: Boolean = true

  @transient private lazy val avroSchema = new Schema.Parser().parse(jsonFormatSchema)

  @transient private lazy val reader = new GenericDatumReader[Any](avroSchema)

  @transient private lazy val deserializer = new AvroDeserializer(avroSchema, dataType)

  @transient private var decoder: BinaryDecoder = _

  @transient private var result: Any = _

  override def nullSafeEval(input: Any): Any = {
    val binary = input.asInstanceOf[Array[Byte]]
    decoder = DecoderFactory.get().binaryDecoder(binary, 0, binary.length, decoder)
    result = reader.read(result, decoder)
    deserializer.deserialize(result)
  }

  override def simpleString: String = {
    s"from_avro(${child.sql}, ${dataType.simpleString})"
  }

[Review thread]
Member: Shall we use catalogString for datatype?
Contributor: IIRC simpleString will be used in the plan string and should not be too long.
Member: But this is being used in sql though. Do we prefer the truncated string form in sql too?
Contributor (@cloud-fan, Jul 20, 2018): It's not used in sql; we can override sql here and use the untruncated version.

  override def sql: String = {
    s"from_avro(${child.sql}, ${dataType.catalogString})"
  }

  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val expr = ctx.addReferenceObj("this", this)
    defineCodeGen(ctx, ev, input =>
      s"(${CodeGenerator.boxedType(dataType)})$expr.nullSafeEval($input)")
  }
}
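As context, a minimal standalone sketch of the Avro decode path that nullSafeEval wraps, using only the Avro APIs imported above. Here decodeOnce is a hypothetical helper (not part of this patch); the expression additionally reuses the decoder and the previously read datum across rows via its @transient var fields:

import org.apache.avro.Schema
import org.apache.avro.generic.GenericDatumReader
import org.apache.avro.io.DecoderFactory

// Decode one Avro-encoded datum from raw bytes against the given writer schema.
def decodeOnce(schemaJson: String, bytes: Array[Byte]): Any = {
  val schema = new Schema.Parser().parse(schemaJson)
  val reader = new GenericDatumReader[Any](schema)
  // Passing null instead of a previous decoder/datum forgoes the per-row
  // reuse that AvroDataToCatalyst performs.
  val decoder = DecoderFactory.get().binaryDecoder(bytes, 0, bytes.length, null)
  reader.read(null, decoder)
}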
@@ -0,0 +1,69 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql

import java.io.ByteArrayOutputStream

import org.apache.avro.generic.GenericDatumWriter
import org.apache.avro.io.{BinaryEncoder, EncoderFactory}

import org.apache.spark.sql.avro.{AvroSerializer, SchemaConverters}
import org.apache.spark.sql.catalyst.expressions.{Expression, UnaryExpression}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
import org.apache.spark.sql.types.{BinaryType, DataType}

case class CatalystDataToAvro(child: Expression) extends UnaryExpression {

  override def dataType: DataType = BinaryType

  @transient private lazy val avroType =
    SchemaConverters.toAvroType(child.dataType, child.nullable)

  @transient private lazy val serializer =
    new AvroSerializer(child.dataType, avroType, child.nullable)

  @transient private lazy val writer =
    new GenericDatumWriter[Any](avroType)

  @transient private var encoder: BinaryEncoder = _

  @transient private lazy val out = new ByteArrayOutputStream

  override def nullSafeEval(input: Any): Any = {
    out.reset()
    encoder = EncoderFactory.get().directBinaryEncoder(out, encoder)
    val avroData = serializer.serialize(input)
    writer.write(avroData, encoder)
    encoder.flush()
    out.toByteArray
  }

  override def simpleString: String = {
    s"to_avro(${child.sql}, ${child.dataType.simpleString})"
  }

[Review thread]
Member: ditto for catalogString

  override def sql: String = {
    s"to_avro(${child.sql}, ${child.dataType.catalogString})"
  }

  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val expr = ctx.addReferenceObj("this", this)
    defineCodeGen(ctx, ev, input =>
      s"(byte[]) $expr.nullSafeEval($input)")
  }
}
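Symmetrically, a minimal sketch of the encode path behind nullSafeEval here. encodeOnce is a hypothetical helper; the expression instead reuses the output stream and encoder across rows to avoid per-call allocation:

import java.io.ByteArrayOutputStream
import org.apache.avro.Schema
import org.apache.avro.generic.GenericDatumWriter
import org.apache.avro.io.EncoderFactory

// Encode one datum (already in Avro's generic representation) to raw bytes.
def encodeOnce(schema: Schema, datum: Any): Array[Byte] = {
  val out = new ByteArrayOutputStream
  val encoder = EncoderFactory.get().directBinaryEncoder(out, null)
  val writer = new GenericDatumWriter[Any](schema)
  writer.write(datum, encoder)
  encoder.flush()
  out.toByteArray
}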
@@ -17,6 +17,10 @@

package org.apache.spark.sql

import org.apache.avro.Schema

import org.apache.spark.annotation.Experimental

package object avro {
  /**
   * Adds a method, `avro`, to DataFrameWriter that allows you to write avro files using
[...]
@@ -36,4 +40,31 @@ package object avro {
    @scala.annotation.varargs
    def avro(sources: String*): DataFrame = reader.format("avro").load(sources: _*)
  }

[Review thread]
Contributor: Why are these two functions not part of org.apache.spark.sql.functions?
Contributor: Because the avro data source is an external package, like the kafka data source. It's not available in org.apache.spark.sql.functions.
Member (Author): Yes, having a function that depends on an external package in org.apache.spark.sql.functions would be weird.
Contributor: Thanks guys for your explanation!

  /**
   * Converts a binary column of avro format into its corresponding catalyst value. The specified
   * schema must match the read data, otherwise the behavior is undefined: it may fail or return
   * an arbitrary result.
   *
   * @param data the binary column.
   * @param jsonFormatSchema the avro schema in JSON string format.
   *
   * @since 2.4.0
   */
  @Experimental
  def from_avro(data: Column, jsonFormatSchema: String): Column = {
    new Column(AvroDataToCatalyst(data.expr, jsonFormatSchema))
  }

[Review thread]
Member: Shall we add @since?
Contributor: +1

  /**
   * Converts a column into binary of avro format.
   *
   * @param data the data column.
   *
   * @since 2.4.0
   */
  @Experimental
  def to_avro(data: Column): Column = {
    new Column(CatalystDataToAvro(data.expr))
  }
}
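Taken together, the two functions added above are meant to be used roughly like this (a sketch: df with a binary Avro column "value" and the writer schema string schemaJson are assumed names, not part of this patch):

import org.apache.spark.sql.avro._
import org.apache.spark.sql.functions.col

// Decode Avro binary into a Catalyst value, then encode it back to binary.
val decoded = df.select(from_avro(col("value"), schemaJson).as("record"))
val roundTripped = decoded.select(to_avro(col("record")).as("value"))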
@@ -0,0 +1,175 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.avro

import org.apache.avro.Schema

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.{AvroDataToCatalyst, CatalystDataToAvro, RandomDataGenerator}
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.expressions.{ExpressionEvalHelper, GenericInternalRow, Literal}
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, MapData}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

class AvroCatalystDataConversionSuite extends SparkFunSuite with ExpressionEvalHelper {

  private def roundTripTest(data: Literal): Unit = {
    val avroType = SchemaConverters.toAvroType(data.dataType, data.nullable)
    checkResult(data, avroType.toString, data.eval())
  }

  private def checkResult(data: Literal, schema: String, expected: Any): Unit = {
    checkEvaluation(
      AvroDataToCatalyst(CatalystDataToAvro(data), schema),
      prepareExpectedResult(expected))
  }

  private def assertFail(data: Literal, schema: String): Unit = {
    intercept[java.io.EOFException] {
      AvroDataToCatalyst(CatalystDataToAvro(data), schema).eval()
    }
  }

  private val testingTypes = Seq(
    BooleanType,
    ByteType,
    ShortType,
    IntegerType,
    LongType,
    FloatType,
    DoubleType,
    DecimalType(8, 0), // 32 bits decimal without fraction
    DecimalType(8, 4), // 32 bits decimal
    DecimalType(16, 0), // 64 bits decimal without fraction
    DecimalType(16, 11), // 64 bits decimal
    DecimalType(38, 0),
    DecimalType(38, 38),
    StringType,
    BinaryType)

  protected def prepareExpectedResult(expected: Any): Any = expected match {
    // Spark decimal is converted to avro string
    case d: Decimal => UTF8String.fromString(d.toString)
    // Spark byte and short both map to avro int
    case b: Byte => b.toInt
    case s: Short => s.toInt
    case row: GenericInternalRow => InternalRow.fromSeq(row.values.map(prepareExpectedResult))
    case array: GenericArrayData => new GenericArrayData(array.array.map(prepareExpectedResult))
    case map: MapData =>
      val keys = new GenericArrayData(
        map.keyArray().asInstanceOf[GenericArrayData].array.map(prepareExpectedResult))
      val values = new GenericArrayData(
        map.valueArray().asInstanceOf[GenericArrayData].array.map(prepareExpectedResult))
      new ArrayBasedMapData(keys, values)
    case other => other
  }

  testingTypes.foreach { dt =>
    val seed = scala.util.Random.nextLong()
    test(s"single $dt with seed $seed") {
      val rand = new scala.util.Random(seed)
      val data = RandomDataGenerator.forType(dt, rand = rand).get.apply()
      val converter = CatalystTypeConverters.createToCatalystConverter(dt)
      val input = Literal.create(converter(data), dt)
      roundTripTest(input)
    }
  }

  for (_ <- 1 to 5) {
    val seed = scala.util.Random.nextLong()
    val rand = new scala.util.Random(seed)
    val schema = RandomDataGenerator.randomSchema(rand, 5, testingTypes)
    test(s"flat schema ${schema.catalogString} with seed $seed") {
      val data = RandomDataGenerator.randomRow(rand, schema)
      val converter = CatalystTypeConverters.createToCatalystConverter(schema)
      val input = Literal.create(converter(data), schema)
      roundTripTest(input)
    }
  }

  for (_ <- 1 to 5) {
    val seed = scala.util.Random.nextLong()
    val rand = new scala.util.Random(seed)
    val schema = RandomDataGenerator.randomNestedSchema(rand, 10, testingTypes)
    test(s"nested schema ${schema.catalogString} with seed $seed") {
      val data = RandomDataGenerator.randomRow(rand, schema)
      val converter = CatalystTypeConverters.createToCatalystConverter(schema)
      val input = Literal.create(converter(data), schema)
      roundTripTest(input)
    }
  }

test("read int as string") {
val data = Literal(1)
val avroTypeJson =
s"""
|{
| "type": "string",
| "name": "my_string"
|}
""".stripMargin

// When read int as string, avro reader is not able to parse the binary and fail.
assertFail(data, avroTypeJson)
}

test("read string as int") {
val data = Literal("abc")
val avroTypeJson =
s"""
|{
| "type": "int",
| "name": "my_int"
|}
""".stripMargin

// When read string data as int, avro reader is not able to find the type mismatch and read
// the string length as int value.
checkResult(data, avroTypeJson, 3)
}
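The result 3 follows from Avro's binary encoding: a string is written as a zig-zag varint length followed by its UTF-8 bytes, so an int reader consumes only the leading varint. A minimal sketch demonstrating this with plain Avro APIs:

import java.io.ByteArrayOutputStream
import org.apache.avro.io.{DecoderFactory, EncoderFactory}

val out = new ByteArrayOutputStream
val enc = EncoderFactory.get().directBinaryEncoder(out, null)
enc.writeString("abc") // writes 0x06 'a' 'b' 'c': zig-zag length 3, then UTF-8 bytes
enc.flush()
val dec = DecoderFactory.get().binaryDecoder(out.toByteArray, null)
assert(dec.readInt() == 3) // decodes only the leading length varint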

test("read float as double") {
val data = Literal(1.23f)
val avroTypeJson =
s"""
|{
| "type": "double",
| "name": "my_double"
|}
""".stripMargin

// When read float data as double, avro reader fails(trying to read 8 bytes while the data have
// only 4 bytes).
assertFail(data, avroTypeJson)
}

test("read double as float") {
val data = Literal(1.23)
val avroTypeJson =
s"""
|{
| "type": "float",
| "name": "my_float"
|}
""".stripMargin

// avro reader reads the first 4 bytes of a double as a float, the result is totally undefined.
checkResult(data, avroTypeJson, 5.848603E35f)
}
}