
Commit 0336579

Refactor Flume unit tests and also add tests for Python API
1 parent 9f33873 commit 0336579

6 files changed: +546, -218 lines


FlumeTestUtils.scala (new file): 116 additions & 0 deletions

@@ -0,0 +1,116 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.streaming.flume

import java.net.{InetSocketAddress, ServerSocket}
import java.nio.ByteBuffer
import java.util.{List => JList}

import scala.collection.JavaConversions._

import com.google.common.base.Charsets.UTF_8
import org.apache.avro.ipc.NettyTransceiver
import org.apache.avro.ipc.specific.SpecificRequestor
import org.apache.commons.lang3.RandomUtils
import org.apache.flume.source.avro
import org.apache.flume.source.avro.{AvroSourceProtocol, AvroFlumeEvent}
import org.jboss.netty.channel.ChannelPipeline
import org.jboss.netty.channel.socket.SocketChannel
import org.jboss.netty.channel.socket.nio.NioClientSocketChannelFactory
import org.jboss.netty.handler.codec.compression.{ZlibDecoder, ZlibEncoder}

import org.apache.spark.util.Utils
import org.apache.spark.SparkConf

/**
 * Shared utility code for the Scala and Python unit tests.
 */
private[flume] class FlumeTestUtils {

  private var transceiver: NettyTransceiver = null

  private val testPort: Int = findFreePort()

  def getTestPort(): Int = testPort

  /** Find a free port by probing a random candidate and retrying until one binds. */
  private def findFreePort(): Int = {
    val candidatePort = RandomUtils.nextInt(1024, 65536)
    Utils.startServiceOnPort(candidatePort, (trialPort: Int) => {
      val socket = new ServerSocket(trialPort)
      socket.close()
      (null, trialPort)
    }, new SparkConf())._2
  }

  /** Send data to the Flume receiver. */
  def writeInput(input: JList[String], enableCompression: Boolean): Unit = {
    val testAddress = new InetSocketAddress("localhost", testPort)

    val inputEvents = input.map { item =>
      val event = new AvroFlumeEvent
      event.setBody(ByteBuffer.wrap(item.getBytes(UTF_8)))
      event.setHeaders(Map[CharSequence, CharSequence]("test" -> "header"))
      event
    }

    // If a transceiver from a previous attempt is still open, close it first
    close()

    // Create the transceiver, optionally compressing the channel with zlib
    transceiver = {
      if (enableCompression) {
        new NettyTransceiver(testAddress, new CompressionChannelFactory(6))
      } else {
        new NettyTransceiver(testAddress)
      }
    }

    // Create an Avro client with the transceiver
    val client = SpecificRequestor.getClient(classOf[AvroSourceProtocol], transceiver)
    if (client == null) {
      throw new AssertionError("Cannot create client")
    }

    // Send the data
    val status = client.appendBatch(inputEvents.toList)
    if (status != avro.Status.OK) {
      throw new AssertionError("Failed to send events")
    }
  }

  def close(): Unit = {
    if (transceiver != null) {
      transceiver.close()
      transceiver = null
    }
  }

  /** Socket channel factory that adds zlib compression to each channel's pipeline. */
  private class CompressionChannelFactory(compressionLevel: Int)
    extends NioClientSocketChannelFactory {

    override def newChannel(pipeline: ChannelPipeline): SocketChannel = {
      val encoder = new ZlibEncoder(compressionLevel)
      pipeline.addFirst("deflater", encoder)
      pipeline.addFirst("inflater", new ZlibDecoder())
      super.newChannel(pipeline)
    }
  }

}
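
For context, a minimal usage sketch for the utility above (hypothetical, not part of this commit): it assumes a Flume receiver is already listening on getTestPort(); otherwise the transceiver cannot connect.

// Hypothetical test body; the receiver setup is assumed to happen elsewhere.
import java.util.Arrays

val utils = new FlumeTestUtils
try {
  // Send three events over an uncompressed channel; passing true instead
  // exercises the zlib path (the receiving side must then also decompress).
  utils.writeInput(Arrays.asList("a", "b", "c"), enableCompression = false)
} finally {
  utils.close()  // release the transceiver
}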

PollingFlumeTestUtils.scala (new file): 211 additions & 0 deletions

@@ -0,0 +1,211 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.streaming.flume

import java.util.concurrent._
import java.util.{List => JList, Map => JMap}

import scala.collection.JavaConversions._
import scala.collection.mutable.ArrayBuffer

import com.google.common.base.Charsets.UTF_8
import org.apache.flume.event.EventBuilder
import org.apache.flume.Context
import org.apache.flume.channel.MemoryChannel
import org.apache.flume.conf.Configurables

import org.apache.spark.streaming.flume.sink.{SparkSinkConfig, SparkSink}

/**
 * Shared utility code for the Scala and Python unit tests.
 */
private[flume] class PollingFlumeTestUtils {

  private val batchCount = 5
  private val eventsPerBatch = 100
  private val totalEventsPerChannel = batchCount * eventsPerBatch
  private val channelCapacity = 5000

  def getEventsPerBatch: Int = eventsPerBatch

  def getTotalEvents: Int = totalEventsPerChannel * channels.size

  private val channels = new ArrayBuffer[MemoryChannel]
  private val sinks = new ArrayBuffer[SparkSink]

  /**
   * Start a sink and return its port.
   */
  def startSingleSink(): Int = {
    channels.clear()
    sinks.clear()

    // Start the channel and sink.
    val context = new Context()
    context.put("capacity", channelCapacity.toString)
    context.put("transactionCapacity", "1000")
    context.put("keep-alive", "0")
    val channel = new MemoryChannel()
    Configurables.configure(channel, context)

    // Port 0 lets the sink bind any free port; the actual port is read back below.
    val sink = new SparkSink()
    context.put(SparkSinkConfig.CONF_HOSTNAME, "localhost")
    context.put(SparkSinkConfig.CONF_PORT, String.valueOf(0))
    Configurables.configure(sink, context)
    sink.setChannel(channel)
    sink.start()

    channels += channel
    sinks += sink

    sink.getPort()
  }

  /**
   * Start two sinks and return their ports.
   */
  def startMultipleSinks(): JList[Int] = {
    channels.clear()
    sinks.clear()

    // Start the channels and sinks.
    val context = new Context()
    context.put("capacity", channelCapacity.toString)
    context.put("transactionCapacity", "1000")
    context.put("keep-alive", "0")
    val channel = new MemoryChannel()
    Configurables.configure(channel, context)

    val channel2 = new MemoryChannel()
    Configurables.configure(channel2, context)

    val sink = new SparkSink()
    context.put(SparkSinkConfig.CONF_HOSTNAME, "localhost")
    context.put(SparkSinkConfig.CONF_PORT, String.valueOf(0))
    Configurables.configure(sink, context)
    sink.setChannel(channel)
    sink.start()

    val sink2 = new SparkSink()
    context.put(SparkSinkConfig.CONF_HOSTNAME, "localhost")
    context.put(SparkSinkConfig.CONF_PORT, String.valueOf(0))
    Configurables.configure(sink2, context)
    sink2.setChannel(channel2)
    sink2.start()

    sinks += sink
    sinks += sink2
    channels += channel
    channels += channel2

    sinks.map(_.getPort())
  }

  /**
   * Send data and wait until all of it has been received.
   */
  def sendDataAndEnsureAllDataHasBeenReceived(): Unit = {
    val executor = Executors.newCachedThreadPool()
    val executorCompletion = new ExecutorCompletionService[Void](executor)

    // Each sink counts the latch down once per batch it receives.
    val latch = new CountDownLatch(batchCount * channels.size)
    sinks.foreach(_.countdownWhenBatchReceived(latch))

    channels.foreach(channel => {
      executorCompletion.submit(new TxnSubmitter(channel))
    })

    for (i <- 0 until channels.size) {
      executorCompletion.take()
    }

    // Ensure all data has been received; fail loudly on timeout.
    if (!latch.await(15, TimeUnit.SECONDS)) {
      throw new AssertionError("Timed out waiting for all batches to be received")
    }
  }

  /**
   * A Python-friendly method to assert the output.
   */
  def assertOutput(
      outputHeaders: JList[JMap[String, String]], outputBodies: JList[String]): Unit = {
    require(outputHeaders.size == outputBodies.size)
    val eventSize = outputHeaders.size
    if (eventSize != totalEventsPerChannel * channels.size) {
      throw new AssertionError(
        s"Expected ${totalEventsPerChannel * channels.size} events, but was $eventSize")
    }
    var counter = 0
    for (k <- 0 until channels.size; i <- 0 until totalEventsPerChannel) {
      val eventBodyToVerify = s"${channels(k).getName}-$i"
      val eventHeaderToVerify: JMap[String, String] = Map[String, String](s"test-$i" -> "header")
      var found = false
      var j = 0
      while (j < eventSize && !found) {
        if (eventBodyToVerify == outputBodies.get(j) &&
          eventHeaderToVerify == outputHeaders.get(j)) {
          found = true
          counter += 1
        }
        j += 1
      }
    }
    if (counter != totalEventsPerChannel * channels.size) {
      throw new AssertionError(
        s"Expected ${totalEventsPerChannel * channels.size} events, but was $counter")
    }
  }

  def assertChannelsAreEmpty(): Unit = {
    channels.foreach(assertChannelIsEmpty)
  }

  private def assertChannelIsEmpty(channel: MemoryChannel): Unit = {
    // Read the channel's private "queueRemaining" semaphore via reflection;
    // the channel is empty when all capacity permits are available.
    val queueRemaining = channel.getClass.getDeclaredField("queueRemaining")
    queueRemaining.setAccessible(true)
    val m = queueRemaining.get(channel).getClass.getDeclaredMethod("availablePermits")
    if (m.invoke(queueRemaining.get(channel)).asInstanceOf[Int] != channelCapacity) {
      throw new AssertionError(s"Channel ${channel.getName} is not empty")
    }
  }

  def close(): Unit = {
    sinks.foreach(_.stop())
    sinks.clear()
    channels.foreach(_.stop())
    channels.clear()
  }

  /** Puts batchCount batches of eventsPerBatch events each into the given channel. */
  private class TxnSubmitter(channel: MemoryChannel) extends Callable[Void] {
    override def call(): Void = {
      var t = 0
      for (i <- 0 until batchCount) {
        val tx = channel.getTransaction
        tx.begin()
        for (j <- 0 until eventsPerBatch) {
          channel.put(EventBuilder.withBody(s"${channel.getName}-$t".getBytes(UTF_8),
            Map[String, String](s"test-$t" -> "header")))
          t += 1
        }
        tx.commit()
        tx.close()
        Thread.sleep(500) // Allow some time for the events to reach the sink
      }
      null
    }
  }

}
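
Similarly, a hypothetical sketch for the polling utility (not part of this commit): the polling receiver that consumes from the sink's port, and the outputHeaders/outputBodies it would collect, are assumed to exist elsewhere in the test.

// Hypothetical test body; the polling receiver is assumed to be started elsewhere.
val utils = new PollingFlumeTestUtils
try {
  val port = utils.startSingleSink()
  // ... start a receiver polling localhost:port and collecting its output ...
  utils.sendDataAndEnsureAllDataHasBeenReceived()
  // utils.assertOutput(outputHeaders, outputBodies)  // verify the collected output
  utils.assertChannelsAreEmpty()  // passes once the receiver has drained the channel
} finally {
  utils.close()  // stop the sinks and channels
}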
