diff --git a/streaming/src/main/scala/org/apache/spark/streaming/storage/HdfsUtils.scala b/streaming/src/main/scala/org/apache/spark/streaming/storage/HdfsUtils.scala
index de53c59c6826..079b2fef904a 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/storage/HdfsUtils.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/storage/HdfsUtils.scala
@@ -21,11 +21,10 @@
 import org.apache.hadoop.fs.{FSDataInputStream, FSDataOutputStream, Path}
 
 private[streaming] object HdfsUtils {
 
-  def getOutputStream(path: String): FSDataOutputStream = {
+  def getOutputStream(path: String, conf: Configuration): FSDataOutputStream = {
     // HDFS is not thread-safe when getFileSystem is called, so synchronize on that
     val dfsPath = new Path(path)
-    val conf = new Configuration()
     val dfs =
       this.synchronized {
         dfsPath.getFileSystem(conf)
@@ -45,10 +44,10 @@ private[streaming] object HdfsUtils {
     stream
   }
 
-  def getInputStream(path: String): FSDataInputStream = {
+  def getInputStream(path: String, conf: Configuration): FSDataInputStream = {
     val dfsPath = new Path(path)
     val dfs = this.synchronized {
-      dfsPath.getFileSystem(new Configuration())
+      dfsPath.getFileSystem(conf)
     }
     val instream = dfs.open(dfsPath)
     instream
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/storage/WriteAheadLogRandomReader.scala b/streaming/src/main/scala/org/apache/spark/streaming/storage/WriteAheadLogRandomReader.scala
index aee5d192102e..3df024834f7a 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/storage/WriteAheadLogRandomReader.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/storage/WriteAheadLogRandomReader.scala
@@ -17,13 +17,17 @@
 package org.apache.spark.streaming.storage
 
 import java.io.Closeable
+import java.nio.ByteBuffer
 
-private[streaming] class WriteAheadLogRandomReader(path: String) extends Closeable {
+import org.apache.hadoop.conf.Configuration
 
-  private val instream = HdfsUtils.getInputStream(path)
+private[streaming] class WriteAheadLogRandomReader(path: String, conf: Configuration)
+  extends Closeable {
+
+  private val instream = HdfsUtils.getInputStream(path, conf)
   private var closed = false
 
-  def read(segment: FileSegment): Array[Byte] = synchronized {
+  def read(segment: FileSegment): ByteBuffer = synchronized {
     assertOpen()
     instream.seek(segment.offset)
     val nextLength = instream.readInt()
@@ -31,7 +35,7 @@ private[streaming] class WriteAheadLogRandomReader(path: String) extends Closeab
       "Expected message length to be " + segment.length + ", " + "but was " + nextLength)
     val buffer = new Array[Byte](nextLength)
     instream.readFully(buffer)
-    buffer
+    ByteBuffer.wrap(buffer)
   }
 
   override def close(): Unit = synchronized {
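Note (not part of the patch): a minimal sketch of how the segment-based random-read path above is expected to be used once this patch is applied. The path and the FileSegment values here are placeholders; in practice the segment would be the one recorded when the entry was written (see the WriteAheadLogWriter changes later in this patch).

import java.nio.ByteBuffer

import org.apache.hadoop.conf.Configuration

import org.apache.spark.streaming.storage.{FileSegment, WriteAheadLogRandomReader}

object RandomReadSketch {
  def main(args: Array[String]): Unit = {
    val conf = new Configuration()                  // forwarded to HdfsUtils.getInputStream
    val path = "file:///tmp/wal-0"                  // placeholder log file path
    val segment = new FileSegment(path, 0L, 100)    // normally returned by WriteAheadLogWriter.write
    val reader = new WriteAheadLogRandomReader(path, conf)
    try {
      // Seeks to segment.offset, verifies the stored length, and returns the payload
      val record: ByteBuffer = reader.read(segment)
      println(s"read ${record.remaining()} bytes")
    } finally {
      reader.close()
    }
  }
}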
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/storage/WriteAheadLogReader.scala b/streaming/src/main/scala/org/apache/spark/streaming/storage/WriteAheadLogReader.scala
index 75791c247018..724549e216e9 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/storage/WriteAheadLogReader.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/storage/WriteAheadLogReader.scala
@@ -16,36 +16,37 @@
  */
 package org.apache.spark.streaming.storage
 
-import java.io.Closeable
+import java.io.{EOFException, Closeable}
+import java.nio.ByteBuffer
 
-private[streaming] class WriteAheadLogReader(path: String)
-  extends Iterator[Array[Byte]] with Closeable {
+import org.apache.hadoop.conf.Configuration
 
-  private val instream = HdfsUtils.getInputStream(path)
+private[streaming] class WriteAheadLogReader(path: String, conf: Configuration)
+  extends Iterator[ByteBuffer] with Closeable {
+
+  private val instream = HdfsUtils.getInputStream(path, conf)
   private var closed = false
-  private var nextItem: Option[Array[Byte]] = None
+  private var nextItem: Option[ByteBuffer] = None
 
   override def hasNext: Boolean = synchronized {
     assertOpen()
     if (nextItem.isDefined) { // handle the case where hasNext is called without calling next
       true
     } else {
-      val available = instream.available()
-      if (available < 4) { // Length of next block (which is an Int = 4 bytes) of data is unavailable!
-        false
-      }
-      val length = instream.readInt()
-      if (instream.available() < length) {
-        false
+      try {
+        val length = instream.readInt()
+        val buffer = new Array[Byte](length)
+        instream.readFully(buffer)
+        nextItem = Some(ByteBuffer.wrap(buffer))
+        true
+      } catch {
+        case e: EOFException => false
+        case e: Exception => throw e
       }
-      val buffer = new Array[Byte](length)
-      instream.readFully(buffer)
-      nextItem = Some(buffer)
-      true
     }
   }
 
-  override def next(): Array[Byte] = synchronized {
+  override def next(): ByteBuffer = synchronized {
     // TODO: Possible error case where there are not enough bytes in the stream
     // TODO: How to handle that?
     val data = nextItem.getOrElse {
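Note (illustrative, not part of the patch): the reader above exposes the log as an Iterator[ByteBuffer], where hasNext eagerly reads the next length-prefixed record and treats EOFException as end-of-log. A minimal consumption sketch, assuming the patch is applied; the path is a placeholder for a file produced by WriteAheadLogWriter.

import org.apache.hadoop.conf.Configuration

import org.apache.spark.streaming.storage.WriteAheadLogReader

object SequentialReadSketch {
  def main(args: Array[String]): Unit = {
    // Placeholder path; any log written with the length-prefixed framing will do
    val reader = new WriteAheadLogReader("file:///tmp/wal-0", new Configuration())
    try {
      // hasNext pre-fetches the next record; iteration stops on EOFException
      reader.foreach { record =>
        println(s"record of ${record.remaining()} bytes")
      }
    } finally {
      reader.close()
    }
  }
}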
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/storage/WriteAheadLogWriter.scala b/streaming/src/main/scala/org/apache/spark/streaming/storage/WriteAheadLogWriter.scala
index f151c17ff66d..8a2db8305a7e 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/storage/WriteAheadLogWriter.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/storage/WriteAheadLogWriter.scala
@@ -17,21 +17,38 @@
 package org.apache.spark.streaming.storage
 
 import java.io.Closeable
+import java.lang.reflect.Method
+import java.nio.ByteBuffer
 
-private[streaming] class WriteAheadLogWriter(path: String) extends Closeable {
-  private val stream = HdfsUtils.getOutputStream(path)
+import scala.util.Try
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.FSDataOutputStream
+
+private[streaming] class WriteAheadLogWriter(path: String, conf: Configuration) extends Closeable {
+  private val stream = HdfsUtils.getOutputStream(path, conf)
   private var nextOffset = stream.getPos
   private var closed = false
+  private val hflushMethod = getHflushOrSync()
 
   // Data is always written as:
   // - Length - Long
   // - Data - of length = Length
-  def write(data: Array[Byte]): FileSegment = synchronized {
+  def write(data: ByteBuffer): FileSegment = synchronized {
     assertOpen()
-    val segment = new FileSegment(path, nextOffset, data.length)
-    stream.writeInt(data.length)
-    stream.write(data)
-    stream.hflush()
+    data.rewind() // Rewind to ensure all data in the buffer is retrieved
+    val lengthToWrite = data.remaining()
+    val segment = new FileSegment(path, nextOffset, lengthToWrite)
+    stream.writeInt(lengthToWrite)
+    if (data.hasArray) {
+      stream.write(data.array())
+    } else {
+      // If the buffer is not backed by an array we need to write the data byte by byte
+      while (data.hasRemaining) {
+        stream.write(data.get())
+      }
+    }
+    hflushOrSync()
     nextOffset = stream.getPos
     segment
   }
@@ -41,6 +58,19 @@ private[streaming] class WriteAheadLogWriter(path: String, conf: Configuration)
     stream.close()
   }
 
+  private def hflushOrSync() {
+    hflushMethod.foreach(_.invoke(stream))
+  }
+
+  private def getHflushOrSync(): Option[Method] = {
+    Try {
+      Some(classOf[FSDataOutputStream].getMethod("hflush"))
+    }.recover {
+      case e: NoSuchMethodException =>
+        Some(classOf[FSDataOutputStream].getMethod("sync"))
+    }.getOrElse(None)
+  }
+
   private def assertOpen() {
     HdfsUtils.checkState(!closed, "Stream is closed. Create a new Writer to write to file.")
   }
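Note (illustrative, not part of the patch): each write above appends a 4-byte length followed by the payload and then flushes, calling hflush when available and falling back to sync on older Hadoop releases via reflection. A minimal writing sketch, assuming the patch is applied; the path is a placeholder.

import java.nio.ByteBuffer

import org.apache.hadoop.conf.Configuration

import org.apache.spark.streaming.storage.{FileSegment, WriteAheadLogWriter}

object WriterSketch {
  def main(args: Array[String]): Unit = {
    // Placeholder path; each write returns the FileSegment (path, offset, length) of the record
    val writer = new WriteAheadLogWriter("file:///tmp/wal-0", new Configuration())
    try {
      val segments: Seq[FileSegment] =
        Seq("a", "bb", "ccc").map(s => writer.write(ByteBuffer.wrap(s.getBytes("UTF-8"))))
      segments.foreach(s => println(s"offset=${s.offset} length=${s.length}"))
    } finally {
      writer.close()
    }
  }
}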
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/storage/WriteAheadLogSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/storage/WriteAheadLogSuite.scala
new file mode 100644
index 000000000000..ed21bdbb399f
--- /dev/null
+++ b/streaming/src/test/scala/org/apache/spark/streaming/storage/WriteAheadLogSuite.scala
@@ -0,0 +1,160 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.streaming.storage
+
+import java.io.{RandomAccessFile, File}
+import java.nio.ByteBuffer
+import java.util.Random
+
+import scala.collection.mutable.ArrayBuffer
+
+import com.google.common.io.Files
+import org.apache.hadoop.conf.Configuration
+
+import org.apache.spark.streaming.TestSuiteBase
+
+class WriteAheadLogSuite extends TestSuiteBase {
+
+  val hadoopConf = new Configuration()
+  val random = new Random()
+
+  test("Test successful writes") {
+    val dir = Files.createTempDir()
+    val file = new File(dir, "TestWriter")
+    try {
+      val dataToWrite = for (i <- 1 to 50) yield generateRandomData()
+      val writer = new WriteAheadLogWriter("file:///" + file.toString, hadoopConf)
+      val segments = dataToWrite.map(writer.write)
+      writer.close()
+      val writtenData = readData(segments, file)
+      assert(writtenData.toArray === dataToWrite.toArray)
+    } finally {
+      file.delete()
+      dir.delete()
+    }
+  }
+
+  test("Test successful reads using random reader") {
+    val file = File.createTempFile("TestRandomReads", "")
+    file.deleteOnExit()
+    val writtenData = writeData(50, file)
+    val reader = new WriteAheadLogRandomReader("file:///" + file.toString, hadoopConf)
+    writtenData.foreach {
+      x =>
+        val length = x._1.remaining()
+        assert(x._1 === reader.read(new FileSegment(file.toString, x._2, length)))
+    }
+    reader.close()
+  }
+
+  test("Test reading data using random reader written with writer") {
+    val dir = Files.createTempDir()
+    val file = new File(dir, "TestRandomReads")
+    try {
+      val dataToWrite = for (i <- 1 to 50) yield generateRandomData()
+      val segments = writeUsingWriter(file, dataToWrite)
+      val iter = dataToWrite.iterator
+      val reader = new WriteAheadLogRandomReader("file:///" + file.toString, hadoopConf)
+      val writtenData = segments.map { x =>
+        reader.read(x)
+      }
+      assert(dataToWrite.toArray === writtenData.toArray)
+    } finally {
+      file.delete()
+      dir.delete()
+    }
+  }
+
+  test("Test successful reads using sequential reader") {
+    val file = File.createTempFile("TestSequentialReads", "")
+    file.deleteOnExit()
+    val writtenData = writeData(50, file)
+    val reader = new WriteAheadLogReader("file:///" + file.toString, hadoopConf)
+    val iter = writtenData.iterator
+    iter.foreach { x =>
+      assert(reader.hasNext === true)
+      assert(reader.next() === x._1)
+    }
+    reader.close()
+  }
+
+  test("Test reading data using sequential reader written with writer") {
+    val dir = Files.createTempDir()
+    val file = new File(dir, "TestWriter")
+    try {
+      val dataToWrite = for (i <- 1 to 50) yield generateRandomData()
+      val segments = writeUsingWriter(file, dataToWrite)
+      val iter = dataToWrite.iterator
+      val reader = new WriteAheadLogReader("file:///" + file.toString, hadoopConf)
+      reader.foreach { x =>
+        assert(x === iter.next())
+      }
+    } finally {
+      file.delete()
+      dir.delete()
+    }
+  }
+
+  /**
+   * Writes data to the file and returns an array of the bytes written.
+   * @param count
+   * @return
+   */
+  // We don't want to be using the WAL writer to test the reader - it will be painful to figure
+  // out where the bug is. Instead generate the file by hand and see if the WAL reader can
+  // handle it.
+  def writeData(count: Int, file: File): ArrayBuffer[(ByteBuffer, Long)] = {
+    val writtenData = new ArrayBuffer[(ByteBuffer, Long)]()
+    val writer = new RandomAccessFile(file, "rw")
+    var i = 0
+    while (i < count) {
+      val data = generateRandomData()
+      writtenData += ((data, writer.getFilePointer))
+      data.rewind()
+      writer.writeInt(data.remaining())
+      writer.write(data.array())
+      i += 1
+    }
+    writer.close()
+    writtenData
+  }
+
+  def readData(segments: Seq[FileSegment], file: File): Seq[ByteBuffer] = {
+    val reader = new RandomAccessFile(file, "r")
+    segments.map { x =>
+      reader.seek(x.offset)
+      val data = new Array[Byte](x.length)
+      reader.readInt()
+      reader.readFully(data)
+      ByteBuffer.wrap(data)
+    }
+  }
+
+  def generateRandomData(): ByteBuffer = {
+    val data = new Array[Byte](random.nextInt(50))
+    random.nextBytes(data)
+    ByteBuffer.wrap(data)
+  }
+
+  def writeUsingWriter(file: File, input: Seq[ByteBuffer]): Seq[FileSegment] = {
+    val writer = new WriteAheadLogWriter(file.toString, hadoopConf)
+    val segments = input.map(writer.write)
+    writer.close()
+    segments
+  }
+}
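Note (illustrative only, not part of the patch): a condensed end-to-end round trip mirroring what the suite above exercises, writing with WriteAheadLogWriter and reading back with both readers. It assumes the patch is applied and uses a temporary file on the local filesystem.

import java.io.File
import java.nio.ByteBuffer

import org.apache.hadoop.conf.Configuration

import org.apache.spark.streaming.storage.{WriteAheadLogRandomReader, WriteAheadLogReader, WriteAheadLogWriter}

object RoundTripSketch {
  def main(args: Array[String]): Unit = {
    val conf = new Configuration()
    val file = File.createTempFile("WalRoundTrip", "")
    file.deleteOnExit()
    val path = "file:///" + file.toString

    // Write a few records and keep the returned segments
    val writer = new WriteAheadLogWriter(path, conf)
    val segments = (1 to 3).map(i => writer.write(ByteBuffer.wrap(Array.fill(i)(i.toByte))))
    writer.close()

    // Random access by segment
    val randomReader = new WriteAheadLogRandomReader(path, conf)
    segments.foreach(s => assert(randomReader.read(s).remaining() == s.length))
    randomReader.close()

    // Sequential scan of the whole log
    val seqReader = new WriteAheadLogReader(path, conf)
    assert(seqReader.size == 3)
    seqReader.close()
  }
}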