
Commit 830f4c0

Merge pull request databricks#1 from cfregly/master
initial checkin
2 parents 5d6cfa9 + b26989c commit 830f4c0

File tree

6 files changed, +452 -0 lines

.gitignore

Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
target/
project/target

README.md

Lines changed: 23 additions & 0 deletions
@@ -0,0 +1,23 @@
RedshiftInputFormat
===

Hadoop input format for Redshift tables unloaded with the ESCAPE option.

Usage in Spark Core:
```scala
import com.databricks.examples.redshift.input.RedshiftInputFormat

val records = sc.newAPIHadoopFile(
  path,
  classOf[RedshiftInputFormat],
  classOf[java.lang.Long],
  classOf[Array[String]])
```
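The resulting `records` is an RDD of (offset, fields) pairs: the `java.lang.Long` key is the byte offset at which the record starts and the `Array[String]` value holds the unescaped field values. A minimal, illustrative way to pull out a single column (not part of the original README):

```scala
// Take the first field of each record; column position 0 is assumed for illustration.
val firstColumn = records.map { case (offset, fields) => fields(0) }
firstColumn.take(5).foreach(println)
```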

Usage in Spark SQL:
```scala
import com.databricks.examples.redshift.input.RedshiftInputFormat._

// Call redshiftFile(), which returns a SchemaRDD with all string columns.
val records: SchemaRDD = sqlContext.redshiftFile(path, Seq("name", "age"))
```
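Because every column arrives as a string, a common next step is to register the SchemaRDD as a temporary table or convert rows to typed values. A small sketch under that assumption; the table name, case class, and conversions below are illustrative and not part of this commit:

```scala
// Register the all-string SchemaRDD and query it with Spark SQL.
records.registerTempTable("unloaded_users")
val names = sqlContext.sql("SELECT name FROM unloaded_users WHERE age = '30'").collect()

// Convert rows to typed values explicitly when needed.
case class User(name: String, age: Int)
val typed = records.map(row => User(row.getString(0), row.getString(1).toInt))
```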

build.sbt

Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
net.virtualvoid.sbt.graph.Plugin.graphSettings

organization := "com.databricks.examples.redshift"

name := "redshift-input-format"

version := "0.1"

scalaVersion := "2.10.4"

libraryDependencies += "org.apache.hadoop" % "hadoop-client" % "1.0.4"

libraryDependencies += "org.apache.spark" %% "spark-sql" % "1.1.0"

libraryDependencies += "com.google.guava" % "guava" % "14.0.1" % Test

libraryDependencies += "org.scalatest" %% "scalatest" % "2.1.5" % Test
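For reference, a downstream sbt project could pull this module in after publishing it to the local Ivy repository with `sbt publishLocal`; the coordinates below simply restate the organization, name, and version declared above, and the downstream build itself is hypothetical:

```scala
// Hypothetical downstream build.sbt; %% selects the Scala 2.10 artifact this build publishes by default.
scalaVersion := "2.10.4"

libraryDependencies += "com.databricks.examples.redshift" %% "redshift-input-format" % "0.1"
```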

project/plugins.sbt

Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
addSbtPlugin("com.github.mpeltonen" % "sbt-idea" % "1.6.0")

addSbtPlugin("net.virtual-void" % "sbt-dependency-graph" % "0.7.4")
Lines changed: 285 additions & 0 deletions
@@ -0,0 +1,285 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.databricks.examples.redshift.input

import java.lang.{Long => JavaLong}
import java.io.{BufferedInputStream, IOException}
import java.nio.charset.Charset

import scala.collection.mutable.ArrayBuffer

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{Path, FileSystem}
import org.apache.hadoop.io.compress.CompressionCodecFactory
import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat, FileSplit}
import org.apache.hadoop.mapreduce.{InputSplit, RecordReader, TaskAttemptContext}

import org.apache.spark.SparkContext._
import org.apache.spark.sql.{SQLContext, SchemaRDD, Row}
import org.apache.spark.sql.catalyst.types._

/**
 * Input format for text records saved with in-record delimiter and newline characters escaped.
 *
 * For example, a record containing two fields: `"a\n"` and `"|b\\"` saved with delimiter `|`
 * should be the following:
 * {{{
 * a\\\n|\\|b\\\\\n
 * }}},
 * where the in-record `|`, `\r`, `\n`, and `\\` characters are escaped by `\\`.
 * Users can configure the delimiter via [[RedshiftInputFormat$#KEY_DELIMITER]].
 * Its default value [[RedshiftInputFormat$#DEFAULT_DELIMITER]] is set to match Redshift's UNLOAD
 * with the ESCAPE option:
 * {{{
 * UNLOAD ('select_statement')
 * TO 's3://object_path_prefix'
 * ESCAPE
 * }}}
 *
 * @see org.apache.spark.SparkContext#newAPIHadoopFile
 */
class RedshiftInputFormat extends FileInputFormat[JavaLong, Array[String]] {

  override def createRecordReader(
      split: InputSplit,
      context: TaskAttemptContext): RecordReader[JavaLong, Array[String]] = {
    new RedshiftRecordReader
  }
}

object RedshiftInputFormat {

  /** configuration key for delimiter */
  val KEY_DELIMITER = "redshift.delimiter"
  /** default delimiter */
  val DEFAULT_DELIMITER = '|'

  /** Gets the delimiter char from conf or the default. */
  private[input] def getDelimiterOrDefault(conf: Configuration): Char = {
    val c = conf.get(KEY_DELIMITER, DEFAULT_DELIMITER.toString)
    if (c.length != 1) {
      throw new IllegalArgumentException(
        s"Expected the delimiter to be a single character but got '$c'.")
    } else {
      c.charAt(0)
    }
  }

  /**
   * Wrapper of SQLContext that provides the `redshiftFile` method.
   */
  class SQLContextWithRedshiftFile(sqlContext: SQLContext) {

    /**
     * Reads a file unloaded from Redshift into a SchemaRDD.
     * @param path input path
     * @param columns column names
     * @return a SchemaRDD
     */
    def redshiftFile(path: String, columns: Seq[String]): SchemaRDD = {
      val sc = sqlContext.sparkContext
      val rdd = sc.newAPIHadoopFile(path, classOf[RedshiftInputFormat],
        classOf[java.lang.Long], classOf[Array[String]], sc.hadoopConfiguration)
      val schema = StructType(columns.map(c => StructField(c, StringType, false)))
      sqlContext.applySchema(rdd.values.map(x => Row(x: _*)), schema)
    }
  }

  implicit def fromSQLContext(sqlContext: SQLContext): SQLContextWithRedshiftFile =
    new SQLContextWithRedshiftFile(sqlContext)
}

private[input] class RedshiftRecordReader extends RecordReader[JavaLong, Array[String]] {

  private var reader: BufferedInputStream = _

  private var key: JavaLong = _
  private var value: Array[String] = _

  private var start: Long = _
  private var end: Long = _
  private var cur: Long = _

  private var eof: Boolean = false

  private var delimiter: Byte = _
  @inline private[this] final val escapeChar: Byte = '\\'
  @inline private[this] final val lineFeed: Byte = '\n'
  @inline private[this] final val carriageReturn: Byte = '\r'

  @inline private[this] final val defaultBufferSize = 1024 * 1024

  private[this] val chars = ArrayBuffer.empty[Byte]

  override def initialize(inputSplit: InputSplit, context: TaskAttemptContext): Unit = {
    val split = inputSplit.asInstanceOf[FileSplit]
    val file = split.getPath
    val conf = context.getConfiguration
    delimiter = RedshiftInputFormat.getDelimiterOrDefault(conf).asInstanceOf[Byte]
    require(delimiter != escapeChar,
      s"The delimiter and the escape char cannot be the same but found $delimiter.")
    require(delimiter != lineFeed, "The delimiter cannot be the line feed character.")
    require(delimiter != carriageReturn, "The delimiter cannot be the carriage return character.")
    val compressionCodecs = new CompressionCodecFactory(conf)
    val codec = compressionCodecs.getCodec(file)
    if (codec != null) {
      throw new IOException(s"Compressed files are not supported, but found $file.")
    }
    val fs = file.getFileSystem(conf)
    val size = fs.getFileStatus(file).getLen
    start = findNext(fs, file, size, split.getStart)
    end = findNext(fs, file, size, split.getStart + split.getLength)
    cur = start
    val in = fs.open(file)
    if (cur > 0L) {
      in.seek(cur - 1L)
      in.read()
    }
    reader = new BufferedInputStream(in, defaultBufferSize)
  }

  override def getProgress: Float = {
    if (start >= end) {
      1.0f
    } else {
      math.min((cur - start).toFloat / (end - start), 1.0f)
    }
  }

  override def nextKeyValue(): Boolean = {
    if (cur < end && !eof) {
      key = cur
      value = nextValue()
      true
    } else {
      key = null
      value = null
      false
    }
  }

  override def getCurrentValue: Array[String] = value

  override def getCurrentKey: JavaLong = key

  override def close(): Unit = {
    if (reader != null) {
      reader.close()
    }
  }

  /**
   * Finds the start of the next record.
   * Because we don't know whether the first char is escaped or not, we need to first find a
   * position that is not escaped.
   *
   * @param fs file system
   * @param file file path
   * @param size file size
   * @param offset start offset
   * @return the start position of the next record
   */
  private def findNext(fs: FileSystem, file: Path, size: Long, offset: Long): Long = {
    if (offset == 0L) {
      return 0L
    } else if (offset >= size) {
      return size
    }
    val in = fs.open(file)
    var pos = offset
    in.seek(pos)
    val bis = new BufferedInputStream(in, defaultBufferSize)
    // Find the first unescaped char.
    var escaped = true
    var thisEof = false
    while (escaped && !thisEof) {
      val v = bis.read()
      if (v < 0) {
        thisEof = true
      } else {
        pos += 1
        if (v != escapeChar) {
          escaped = false
        }
      }
    }
    // Find the next unescaped line feed.
    var endOfRecord = false
    while ((escaped || !endOfRecord) && !thisEof) {
      val v = bis.read()
      if (v < 0) {
        thisEof = true
      } else {
        pos += 1
        if (v == escapeChar) {
          escaped = true
        } else {
          if (!escaped) {
            endOfRecord = v == lineFeed
          } else {
            escaped = false
          }
        }
      }
    }
    in.close()
    pos
  }

  private def nextValue(): Array[String] = {
    val fields = ArrayBuffer.empty[String]
    var escaped = false
    var endOfRecord = false
    while (!endOfRecord && !eof) {
      var endOfField = false
      chars.clear()
      while (!endOfField && !endOfRecord && !eof) {
        val v = reader.read()
        if (v < 0) {
          eof = true
        } else {
          cur += 1L
          val c = v.asInstanceOf[Byte]
          if (escaped) {
            if (c != escapeChar && c != delimiter && c != lineFeed && c != carriageReturn) {
              throw new IllegalStateException(
                s"Found `$c` (ASCII $v) after $escapeChar.")
            }
            chars.append(c)
            escaped = false
          } else {
            if (c == escapeChar) {
              escaped = true
            } else if (c == delimiter) {
              endOfField = true
            } else if (c == lineFeed) {
              endOfRecord = true
            } else {
              // also copy carriage return
              chars.append(c)
            }
          }
        }
      }
      // TODO: charset?
      fields.append(new String(chars.toArray, Charset.forName("UTF-8")))
    }
    if (escaped) {
      throw new IllegalStateException("Found a hanging escape char.")
    }
    fields.toArray
  }
}
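The delimiter defaults to `|` but can be overridden through the Hadoop configuration, as the scaladoc above notes via `RedshiftInputFormat.KEY_DELIMITER`. A hedged sketch (not part of this commit) of reading files that were unloaded with `DELIMITER ','` and `ESCAPE`; the S3 path is illustrative and `sc` is an existing SparkContext:

```scala
import com.databricks.examples.redshift.input.RedshiftInputFormat

// Override the default '|' delimiter before creating the RDD.
sc.hadoopConfiguration.set(RedshiftInputFormat.KEY_DELIMITER, ",")

val records = sc.newAPIHadoopFile(
  "s3n://my-bucket/unload-prefix",   // hypothetical UNLOAD output location
  classOf[RedshiftInputFormat],
  classOf[java.lang.Long],           // byte offset where each record starts
  classOf[Array[String]])            // unescaped field values
```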

0 commit comments
