2 changes: 1 addition & 1 deletion core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -395,7 +395,7 @@ class SparkContext(
* (a-hdfs-path/part-nnnnn, its content)
* }}}
*
* @note Small files are perferred, large file is also allowable, but may cause bad performance.
* @note Small files are preferred, as each file will be loaded fully in memory.
*/
def wholeTextFiles(path: String): RDD[(String, String)] = {
newAPIHadoopFile(
@@ -177,7 +177,7 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
* (a-hdfs-path/part-nnnnn, its content)
* }}}
*
* @note Small files are perferred, large file is also allowable, but may cause bad performance.
* @note Small files are preferred, as each file will be loaded fully in memory.
*/
def wholeTextFiles(path: String): JavaPairRDD[String, String] =
new JavaPairRDD(sc.wholeTextFiles(path))
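For reference, a minimal PySpark usage sketch of the wholeTextFiles API documented above (the PySpark method itself is added later in this diff); the path, app name, and word-count logic are illustrative only, not part of this change:

from pyspark import SparkContext

sc = SparkContext("local[2]", "wholeTextFilesExample")

# wholeTextFiles yields (path, content) pairs; each file is still read
# fully into memory on an executor, per the note above.
files = sc.wholeTextFiles("hdfs://a-hdfs-path")
word_counts = files.mapValues(lambda content: len(content.split()))
print(word_counts.collect())  # list of (path, word count) pairs
sc.stop()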
@@ -19,6 +19,7 @@ package org.apache.spark.api.python

import java.io._
import java.net._
import java.nio.charset.Charset
import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, Collections}

import scala.collection.JavaConversions._
@@ -206,6 +207,7 @@ private object SpecialLengths {
}

private[spark] object PythonRDD {
val UTF8 = Charset.forName("UTF-8")

def readRDDFromFile(sc: JavaSparkContext, filename: String, parallelism: Int):
JavaRDD[Array[Byte]] = {
@@ -266,7 +268,7 @@ private[spark] object PythonRDD {
}

def writeUTF(str: String, dataOut: DataOutputStream) {
val bytes = str.getBytes("UTF-8")
val bytes = str.getBytes(UTF8)
dataOut.writeInt(bytes.length)
dataOut.write(bytes)
}
@@ -286,7 +288,7 @@

private
class BytesToString extends org.apache.spark.api.java.function.Function[Array[Byte], String] {
override def call(arr: Array[Byte]) : String = new String(arr, "UTF-8")
override def call(arr: Array[Byte]) : String = new String(arr, PythonRDD.UTF8)
}

/**
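The writeUTF helper above fixes the wire format the Python side has to read: a 4-byte big-endian length (from DataOutputStream.writeInt) followed by the string's UTF-8 bytes. Below is a minimal sketch of reading one such frame in Python, assuming a file-like stream; it mirrors what pyspark's UTF8Deserializer must do but is not its literal implementation:

import struct

def read_utf8_frame(stream):
    # DataOutputStream.writeInt emits a big-endian 4-byte length,
    # followed by exactly that many UTF-8 bytes.
    length_bytes = stream.read(4)
    if len(length_bytes) < 4:
        raise EOFError("stream exhausted")
    (length,) = struct.unpack("!i", length_bytes)
    return stream.read(length).decode("utf-8")

# Example (with `import io`):
# read_utf8_frame(io.BytesIO(b"\x00\x00\x00\x05hello")) == u"hello"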
44 changes: 42 additions & 2 deletions python/pyspark/context.py
@@ -28,7 +28,8 @@
from pyspark.conf import SparkConf
from pyspark.files import SparkFiles
from pyspark.java_gateway import launch_gateway
from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer
from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer, \
PairDeserializer
from pyspark.storagelevel import StorageLevel
from pyspark import rdd
from pyspark.rdd import RDD
@@ -257,6 +258,45 @@ def textFile(self, name, minSplits=None):
return RDD(self._jsc.textFile(name, minSplits), self,
UTF8Deserializer())

def wholeTextFiles(self, path):
"""
Read a directory of text files from HDFS, a local file system
(available on all nodes), or any Hadoop-supported file system
URI. Each file is read as a single record and returned in a
key-value pair, where the key is the path of each file and the
value is the content of that file.

For example, if you have the following files::

hdfs://a-hdfs-path/part-00000
hdfs://a-hdfs-path/part-00001
...
hdfs://a-hdfs-path/part-nnnnn

Do C{rdd = sparkContext.wholeTextFiles("hdfs://a-hdfs-path")},
then C{rdd} contains::

(a-hdfs-path/part-00000, its content)
(a-hdfs-path/part-00001, its content)
...
(a-hdfs-path/part-nnnnn, its content)

NOTE: Small files are preferred, as each file will be loaded
fully in memory.

>>> dirPath = os.path.join(tempdir, "files")
>>> os.mkdir(dirPath)
>>> with open(os.path.join(dirPath, "1.txt"), "w") as file1:
... file1.write("1")
>>> with open(os.path.join(dirPath, "2.txt"), "w") as file2:
... file2.write("2")
>>> textFiles = sc.wholeTextFiles(dirPath)
>>> sorted(textFiles.collect())
[(u'.../1.txt', u'1'), (u'.../2.txt', u'2')]
"""
return RDD(self._jsc.wholeTextFiles(path), self,
PairDeserializer(UTF8Deserializer(), UTF8Deserializer()))

def _checkpointFile(self, name, input_deserializer):
jrdd = self._jsc.checkpointFile(name)
return RDD(jrdd, self, input_deserializer)
@@ -425,7 +465,7 @@ def _test():
globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2)
globs['tempdir'] = tempfile.mkdtemp()
atexit.register(lambda: shutil.rmtree(globs['tempdir']))
(failure_count, test_count) = doctest.testmod(globs=globs)
(failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
globs['sc'].stop()
if failure_count:
exit(-1)
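The new doctest.ELLIPSIS flag is what allows the wholeTextFiles doctest above to write its expected output as u'.../1.txt' even though the real paths start with an absolute temp directory. A standalone sketch of the same mechanism (the function and the path are made up for illustration):

import doctest

def show_path():
    """
    >>> print("/tmp/spark-test/files/1.txt")
    .../1.txt
    """

# Passes only because ELLIPSIS lets "..." match the temp-directory prefix.
doctest.testmod(optionflags=doctest.ELLIPSIS)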
2 changes: 1 addition & 1 deletion python/pyspark/serializers.py
@@ -290,7 +290,7 @@ class MarshalSerializer(FramedSerializer):

class UTF8Deserializer(Serializer):
"""
Deserializes streams written by getBytes.
Deserializes streams written by String.getBytes.
"""

def loads(self, stream):