Commit ec914dc
Merge remote-tracking branch 'upstream/master' into refactorDDLSuite
2 parents: d19d725 + 030acdd

30 files changed: 1516 additions & 430 deletions

R/pkg/R/DataFrame.R

Lines changed: 7 additions & 1 deletion
@@ -2642,6 +2642,7 @@ generateAliasesForIntersectedCols <- function (x, intersectedColNames, suffix) {
 #'
 #' Return a new SparkDataFrame containing the union of rows in this SparkDataFrame
 #' and another SparkDataFrame. This is equivalent to \code{UNION ALL} in SQL.
+#' Input SparkDataFrames can have different schemas (names and data types).
 #'
 #' Note: This does not remove duplicate rows across the two SparkDataFrames.
 #'
@@ -2685,7 +2686,8 @@ setMethod("unionAll",
 
 #' Union two or more SparkDataFrames
 #'
-#' Union two or more SparkDataFrames. This is equivalent to \code{UNION ALL} in SQL.
+#' Union two or more SparkDataFrames by row. As in R's \code{rbind}, this method
+#' requires that the input SparkDataFrames have the same column names.
 #'
 #' Note: This does not remove duplicate rows across the two SparkDataFrames.
 #'
@@ -2709,6 +2711,10 @@ setMethod("unionAll",
 setMethod("rbind",
           signature(... = "SparkDataFrame"),
           function(x, ..., deparse.level = 1) {
+            nm <- lapply(list(x, ...), names)
+            if (length(unique(nm)) != 1) {
+              stop("Names of input data frames are different.")
+            }
             if (nargs() == 3) {
               union(x, ...)
             } else {

R/pkg/inst/tests/testthat/test_sparkSQL.R

Lines changed: 7 additions & 0 deletions
@@ -1850,6 +1850,13 @@ test_that("union(), rbind(), except(), and intersect() on a DataFrame", {
   expect_equal(count(unioned2), 12)
   expect_equal(first(unioned2)$name, "Michael")
 
+  df3 <- df2
+  names(df3)[1] <- "newName"
+  expect_error(rbind(df, df3),
+               "Names of input data frames are different.")
+  expect_error(rbind(df, df2, df3),
+               "Names of input data frames are different.")
+
   excepted <- arrange(except(df, df2), desc(df$age))
   expect_is(unioned, "SparkDataFrame")
   expect_equal(count(excepted), 2)

core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala

Lines changed: 10 additions & 2 deletions
@@ -19,6 +19,7 @@ package org.apache.spark.scheduler
 
 import java.io.{DataInputStream, DataOutputStream}
 import java.nio.ByteBuffer
+import java.nio.charset.StandardCharsets
 import java.util.Properties
 
 import scala.collection.JavaConverters._
@@ -86,7 +87,10 @@ private[spark] object TaskDescription {
     dataOut.writeInt(taskDescription.properties.size())
     taskDescription.properties.asScala.foreach { case (key, value) =>
       dataOut.writeUTF(key)
-      dataOut.writeUTF(value)
+      // SPARK-19796 -- writeUTF doesn't work for long strings, which can happen for property values
+      val bytes = value.getBytes(StandardCharsets.UTF_8)
+      dataOut.writeInt(bytes.length)
+      dataOut.write(bytes)
     }
 
     // Write the task. The task is already serialized, so write it directly to the byte buffer.
@@ -124,7 +128,11 @@ private[spark] object TaskDescription {
     val properties = new Properties()
     val numProperties = dataIn.readInt()
     for (i <- 0 until numProperties) {
-      properties.setProperty(dataIn.readUTF(), dataIn.readUTF())
+      val key = dataIn.readUTF()
+      val valueLength = dataIn.readInt()
+      val valueBytes = new Array[Byte](valueLength)
+      dataIn.readFully(valueBytes)
+      properties.setProperty(key, new String(valueBytes, StandardCharsets.UTF_8))
     }
 
     // Create a sub-buffer for the serialized task into its own buffer (to be deserialized later).

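Background on the serialization change above (not part of the commit): java.io.DataOutputStream.writeUTF encodes the string length in an unsigned 16-bit prefix, so any value whose modified-UTF-8 form exceeds 65535 bytes fails with UTFDataFormatException. The standalone Scala sketch below round-trips an oversized property value using the same explicit length-plus-bytes scheme the patch adopts; all names are illustrative.

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
import java.nio.charset.StandardCharsets

object LengthPrefixedRoundTrip {
  def main(args: Array[String]): Unit = {
    val longValue = "1234567890" * 10000    // ~100 KB, well past writeUTF's 64 KB limit

    // Write: a 4-byte length prefix followed by raw UTF-8 bytes, instead of writeUTF's 2-byte prefix.
    val buffer = new ByteArrayOutputStream()
    val out = new DataOutputStream(buffer)
    val bytes = longValue.getBytes(StandardCharsets.UTF_8)
    out.writeInt(bytes.length)
    out.write(bytes)
    out.flush()

    // Read it back the same way TaskDescription.decode now does.
    val in = new DataInputStream(new ByteArrayInputStream(buffer.toByteArray))
    val valueBytes = new Array[Byte](in.readInt())
    in.readFully(valueBytes)
    assert(new String(valueBytes, StandardCharsets.UTF_8) == longValue)
  }
}
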
core/src/test/scala/org/apache/spark/scheduler/TaskDescriptionSuite.scala

Lines changed: 16 additions & 0 deletions
@@ -17,6 +17,7 @@
 
 package org.apache.spark.scheduler
 
+import java.io.{ByteArrayOutputStream, DataOutputStream, UTFDataFormatException}
 import java.nio.ByteBuffer
 import java.util.Properties
 
@@ -36,6 +37,21 @@ class TaskDescriptionSuite extends SparkFunSuite {
     val originalProperties = new Properties()
     originalProperties.put("property1", "18")
     originalProperties.put("property2", "test value")
+    // SPARK-19796 -- large property values (like a large job description for a long sql query)
+    // can cause problems for DataOutputStream, make sure we handle correctly
+    val sb = new StringBuilder()
+    (0 to 10000).foreach(_ => sb.append("1234567890"))
+    val largeString = sb.toString()
+    originalProperties.put("property3", largeString)
+    // make sure we've got a good test case
+    intercept[UTFDataFormatException] {
+      val out = new DataOutputStream(new ByteArrayOutputStream())
+      try {
+        out.writeUTF(largeString)
+      } finally {
+        out.close()
+      }
+    }
 
     // Create a dummy byte buffer for the task.
     val taskBuffer = ByteBuffer.wrap(Array[Byte](1, 2, 3, 4))

external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSink.scala

Lines changed: 43 additions & 0 deletions
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.kafka010
+
+import java.{util => ju}
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.{DataFrame, SQLContext}
+import org.apache.spark.sql.execution.streaming.Sink
+
+private[kafka010] class KafkaSink(
+    sqlContext: SQLContext,
+    executorKafkaParams: ju.Map[String, Object],
+    topic: Option[String]) extends Sink with Logging {
+  @volatile private var latestBatchId = -1L
+
+  override def toString(): String = "KafkaSink"
+
+  override def addBatch(batchId: Long, data: DataFrame): Unit = {
+    if (batchId <= latestBatchId) {
+      logInfo(s"Skipping already committed batch $batchId")
+    } else {
+      KafkaWriter.write(sqlContext.sparkSession,
+        data.queryExecution, executorKafkaParams, topic)
+      latestBatchId = batchId
+    }
+  }
+}

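A note on the latestBatchId guard in addBatch, with a hedged sketch that is not part of the commit (sqlContext, producerParams, and batchDf are assumed to already exist): after a failure and restart, the streaming engine can hand the sink a batch it has already processed, and the guard turns that repeat into a logged no-op instead of duplicate Kafka writes.

// Hypothetical re-delivery scenario; sqlContext, producerParams: java.util.Map[String, Object],
// and a batch DataFrame batchDf are assumed to be in scope.
val sink = new KafkaSink(sqlContext, producerParams, Some("topic1"))
sink.addBatch(0, batchDf)  // writes batch 0 via KafkaWriter.write and records latestBatchId = 0
sink.addBatch(0, batchDf)  // 0 <= latestBatchId, so this only logs "Skipping already committed batch 0"
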
external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala

Lines changed: 78 additions & 5 deletions
@@ -23,21 +23,27 @@ import java.util.UUID
 import scala.collection.JavaConverters._
 
 import org.apache.kafka.clients.consumer.ConsumerConfig
-import org.apache.kafka.common.serialization.ByteArrayDeserializer
+import org.apache.kafka.clients.producer.ProducerConfig
+import org.apache.kafka.common.serialization.{ByteArrayDeserializer, ByteArraySerializer}
 
 import org.apache.spark.internal.Logging
-import org.apache.spark.sql.SQLContext
-import org.apache.spark.sql.execution.streaming.Source
+import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SQLContext}
+import org.apache.spark.sql.execution.streaming.{Sink, Source}
 import org.apache.spark.sql.sources._
+import org.apache.spark.sql.streaming.OutputMode
 import org.apache.spark.sql.types.StructType
 
 /**
  * The provider class for the [[KafkaSource]]. This provider is designed such that it throws
  * IllegalArgumentException when the Kafka Dataset is created, so that it can catch
  * missing options even before the query is started.
  */
-private[kafka010] class KafkaSourceProvider extends DataSourceRegister with StreamSourceProvider
-  with RelationProvider with Logging {
+private[kafka010] class KafkaSourceProvider extends DataSourceRegister
+    with StreamSourceProvider
+    with StreamSinkProvider
+    with RelationProvider
+    with CreatableRelationProvider
+    with Logging {
   import KafkaSourceProvider._
 
   override def shortName(): String = "kafka"
@@ -152,6 +158,72 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister with Stre
       endingRelationOffsets)
   }
 
+  override def createSink(
+      sqlContext: SQLContext,
+      parameters: Map[String, String],
+      partitionColumns: Seq[String],
+      outputMode: OutputMode): Sink = {
+    val defaultTopic = parameters.get(TOPIC_OPTION_KEY).map(_.trim)
+    val specifiedKafkaParams = kafkaParamsForProducer(parameters)
+    new KafkaSink(sqlContext,
+      new ju.HashMap[String, Object](specifiedKafkaParams.asJava), defaultTopic)
+  }
+
+  override def createRelation(
+      outerSQLContext: SQLContext,
+      mode: SaveMode,
+      parameters: Map[String, String],
+      data: DataFrame): BaseRelation = {
+    mode match {
+      case SaveMode.Overwrite | SaveMode.Ignore =>
+        throw new AnalysisException(s"Save mode $mode not allowed for Kafka. " +
+          s"Allowed save modes are ${SaveMode.Append} and " +
+          s"${SaveMode.ErrorIfExists} (default).")
+      case _ => // good
+    }
+    val topic = parameters.get(TOPIC_OPTION_KEY).map(_.trim)
+    val specifiedKafkaParams = kafkaParamsForProducer(parameters)
+    KafkaWriter.write(outerSQLContext.sparkSession, data.queryExecution,
+      new ju.HashMap[String, Object](specifiedKafkaParams.asJava), topic)
+
+    /* This method is supposed to return a relation that reads the data that was written.
+     * We cannot support this for Kafka. Therefore, in order to make things consistent,
+     * we return an empty base relation.
+     */
+    new BaseRelation {
+      override def sqlContext: SQLContext = unsupportedException
+      override def schema: StructType = unsupportedException
+      override def needConversion: Boolean = unsupportedException
+      override def sizeInBytes: Long = unsupportedException
+      override def unhandledFilters(filters: Array[Filter]): Array[Filter] = unsupportedException
+      private def unsupportedException =
+        throw new UnsupportedOperationException("BaseRelation from Kafka write " +
+          "operation is not usable.")
+    }
+  }
+
+  private def kafkaParamsForProducer(parameters: Map[String, String]): Map[String, String] = {
+    val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase, v) }
+    if (caseInsensitiveParams.contains(s"kafka.${ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG}")) {
+      throw new IllegalArgumentException(
+        s"Kafka option '${ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG}' is not supported as keys "
+          + "are serialized with ByteArraySerializer.")
+    }
+
+    if (caseInsensitiveParams.contains(s"kafka.${ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG}"))
+    {
+      throw new IllegalArgumentException(
+        s"Kafka option '${ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG}' is not supported as "
+          + "values are serialized with ByteArraySerializer.")
+    }
+    parameters
+      .keySet
+      .filter(_.toLowerCase.startsWith("kafka."))
+      .map { k => k.drop(6).toString -> parameters(k) }
+      .toMap + (ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG -> classOf[ByteArraySerializer].getName,
+      ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG -> classOf[ByteArraySerializer].getName)
+  }
+
   private def kafkaParamsForDriver(specifiedKafkaParams: Map[String, String]) =
     ConfigUpdater("source", specifiedKafkaParams)
       .set(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, deserClassName)
@@ -381,6 +453,7 @@ private[kafka010] object KafkaSourceProvider {
   private val STARTING_OFFSETS_OPTION_KEY = "startingoffsets"
   private val ENDING_OFFSETS_OPTION_KEY = "endingoffsets"
   private val FAIL_ON_DATA_LOSS_OPTION_KEY = "failondataloss"
+  val TOPIC_OPTION_KEY = "topic"
 
   private val deserClassName = classOf[ByteArrayDeserializer].getName
 }
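
With createSink added and shortName() already returning "kafka", the sink becomes reachable from the DataStreamWriter API. The following usage sketch is not part of the commit: it assumes a streaming DataFrame df with a string value column plus placeholder broker, topic, and checkpoint values, and its option names follow TOPIC_OPTION_KEY and the "kafka."-prefixed producer options handled by kafkaParamsForProducer above.

// Illustrative only; broker address, topic name, and checkpoint path are placeholders.
val query = df
  .selectExpr("CAST(value AS STRING) AS value")
  .writeStream
  .format("kafka")                                   // resolves to KafkaSourceProvider via shortName()
  .option("kafka.bootstrap.servers", "host1:9092")   // "kafka." prefix is stripped and passed to the producer
  .option("topic", "topic1")                         // TOPIC_OPTION_KEY: default topic when rows have no 'topic' column
  .option("checkpointLocation", "/tmp/kafka-sink-checkpoint")
  .start()

The batch path goes through createRelation instead, for example df.write.format("kafka") with SaveMode.Append or the default ErrorIfExists.
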
external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala

Lines changed: 123 additions & 0 deletions
@@ -0,0 +1,123 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.kafka010
+
+import java.{util => ju}
+
+import org.apache.kafka.clients.producer.{KafkaProducer, _}
+import org.apache.kafka.common.serialization.ByteArraySerializer
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Literal, UnsafeProjection}
+import org.apache.spark.sql.types.{BinaryType, StringType}
+
+/**
+ * A simple trait for writing out data in a single Spark task, without any concerns about how
+ * to commit or abort tasks. Exceptions thrown by the implementation of this class will
+ * automatically trigger task aborts.
+ */
+private[kafka010] class KafkaWriteTask(
+    producerConfiguration: ju.Map[String, Object],
+    inputSchema: Seq[Attribute],
+    topic: Option[String]) {
+  // used to synchronize with Kafka callbacks
+  @volatile private var failedWrite: Exception = null
+  private val projection = createProjection
+  private var producer: KafkaProducer[Array[Byte], Array[Byte]] = _
+
+  /**
+   * Writes key value data out to topics.
+   */
+  def execute(iterator: Iterator[InternalRow]): Unit = {
+    producer = new KafkaProducer[Array[Byte], Array[Byte]](producerConfiguration)
+    while (iterator.hasNext && failedWrite == null) {
+      val currentRow = iterator.next()
+      val projectedRow = projection(currentRow)
+      val topic = projectedRow.getUTF8String(0)
+      val key = projectedRow.getBinary(1)
+      val value = projectedRow.getBinary(2)
+      if (topic == null) {
+        throw new NullPointerException(s"null topic present in the data. Use the " +
+          s"${KafkaSourceProvider.TOPIC_OPTION_KEY} option for setting a default topic.")
+      }
+      val record = new ProducerRecord[Array[Byte], Array[Byte]](topic.toString, key, value)
+      val callback = new Callback() {
+        override def onCompletion(recordMetadata: RecordMetadata, e: Exception): Unit = {
+          if (failedWrite == null && e != null) {
+            failedWrite = e
+          }
+        }
+      }
+      producer.send(record, callback)
+    }
+  }
+
+  def close(): Unit = {
+    if (producer != null) {
+      checkForErrors
+      producer.close()
+      checkForErrors
+      producer = null
+    }
+  }
+
+  private def createProjection: UnsafeProjection = {
+    val topicExpression = topic.map(Literal(_)).orElse {
+      inputSchema.find(_.name == KafkaWriter.TOPIC_ATTRIBUTE_NAME)
+    }.getOrElse {
+      throw new IllegalStateException(s"topic option required when no " +
+        s"'${KafkaWriter.TOPIC_ATTRIBUTE_NAME}' attribute is present")
+    }
+    topicExpression.dataType match {
+      case StringType => // good
+      case t =>
+        throw new IllegalStateException(s"${KafkaWriter.TOPIC_ATTRIBUTE_NAME} " +
+          s"attribute unsupported type $t. ${KafkaWriter.TOPIC_ATTRIBUTE_NAME} " +
+          s"must be a ${StringType}")
+    }
+    val keyExpression = inputSchema.find(_.name == KafkaWriter.KEY_ATTRIBUTE_NAME)
+      .getOrElse(Literal(null, BinaryType))
+    keyExpression.dataType match {
+      case StringType | BinaryType => // good
+      case t =>
+        throw new IllegalStateException(s"${KafkaWriter.KEY_ATTRIBUTE_NAME} " +
+          s"attribute unsupported type $t")
+    }
+    val valueExpression = inputSchema
+      .find(_.name == KafkaWriter.VALUE_ATTRIBUTE_NAME).getOrElse(
+        throw new IllegalStateException(s"Required attribute " +
+          s"'${KafkaWriter.VALUE_ATTRIBUTE_NAME}' not found")
+      )
+    valueExpression.dataType match {
+      case StringType | BinaryType => // good
+      case t =>
+        throw new IllegalStateException(s"${KafkaWriter.VALUE_ATTRIBUTE_NAME} " +
+          s"attribute unsupported type $t")
+    }
+    UnsafeProjection.create(
+      Seq(topicExpression, Cast(keyExpression, BinaryType),
+        Cast(valueExpression, BinaryType)), inputSchema)
+  }
+
+  private def checkForErrors: Unit = {
+    if (failedWrite != null) {
+      throw failedWrite
+    }
+  }
+}

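KafkaWriter.write, called from both KafkaSink and createRelation, is changed elsewhere in this commit but not shown on this page. The sketch below is only a rough guess at the per-partition pattern it implies, not the commit's KafkaWriter.scala; queryExecution, kafkaParams, and topic are assumed inputs, and the executor-side loop is simplified.

// Hypothetical driver/executor wiring: one KafkaWriteTask per partition,
// closed in finally so asynchronous producer failures are rethrown.
val attributes = queryExecution.analyzed.output
queryExecution.toRdd.foreachPartition { iter =>
  val writeTask = new KafkaWriteTask(kafkaParams, attributes, topic)
  try {
    writeTask.execute(iter)
  } finally {
    writeTask.close()
  }
}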