
Commit 83c0c56

Merge remote-tracking branch 'origin/master' into unsafe-by-default
2 parents: f4cc859 + c581593

54 files changed: +1255 −330 lines

Note: this is a large commit; only a subset of the 54 changed files is shown below.

core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java

Lines changed: 16 additions & 0 deletions
@@ -29,6 +29,7 @@ private PrefixComparators() {}
 
   public static final StringPrefixComparator STRING = new StringPrefixComparator();
   public static final IntegralPrefixComparator INTEGRAL = new IntegralPrefixComparator();
+  public static final FloatPrefixComparator FLOAT = new FloatPrefixComparator();
   public static final DoublePrefixComparator DOUBLE = new DoublePrefixComparator();
 
   public static final class StringPrefixComparator extends PrefixComparator {
@@ -54,6 +55,21 @@ public int compare(long a, long b) {
     public final long NULL_PREFIX = Long.MIN_VALUE;
   }
 
+  public static final class FloatPrefixComparator extends PrefixComparator {
+    @Override
+    public int compare(long aPrefix, long bPrefix) {
+      float a = Float.intBitsToFloat((int) aPrefix);
+      float b = Float.intBitsToFloat((int) bPrefix);
+      return Utils.nanSafeCompareFloats(a, b);
+    }
+
+    public long computePrefix(float value) {
+      return Float.floatToIntBits(value) & 0xffffffffL;
+    }
+
+    public final long NULL_PREFIX = computePrefix(Float.NEGATIVE_INFINITY);
+  }
+
   public static final class DoublePrefixComparator extends PrefixComparator {
     @Override
     public int compare(long aPrefix, long bPrefix) {
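
For context, the new FloatPrefixComparator stores a float's raw IEEE-754 bits in the low 32 bits of the 8-byte sort prefix, then decodes both prefixes back to floats and compares them NaN-safely instead of comparing the raw bits (which would mis-order negative values). A minimal standalone sketch of that round trip; nanSafeCompare below only approximates the assumed behavior of Utils.nanSafeCompareFloats (all NaNs equal, NaN greater than any non-NaN value):

    object FloatPrefixSketch {
      // Pack the raw IEEE-754 bits into the low 32 bits of a long, as computePrefix does.
      def computePrefix(value: Float): Long =
        java.lang.Float.floatToIntBits(value) & 0xffffffffL

      // Assumed semantics of the NaN-safe comparison: all NaNs compare equal to each
      // other and greater than any non-NaN value (an approximation, not Spark's Utils).
      def nanSafeCompare(a: Float, b: Float): Int = {
        if ((a.isNaN && b.isNaN) || a == b) 0
        else if (a.isNaN) 1
        else if (b.isNaN) -1
        else if (a > b) 1
        else -1
      }

      // Decode both prefixes back to floats before comparing, as the comparator above does;
      // comparing the raw prefixes as longs would put negative floats after positive ones.
      def compare(aPrefix: Long, bPrefix: Long): Int =
        nanSafeCompare(java.lang.Float.intBitsToFloat(aPrefix.toInt),
                       java.lang.Float.intBitsToFloat(bPrefix.toInt))

      def main(args: Array[String]): Unit = {
        val nan = java.lang.Float.intBitsToFloat(0x7f800001)  // a non-canonical NaN
        println(compare(computePrefix(nan), computePrefix(Float.MaxValue)))  // 1: NaN sorts last
        println(compare(computePrefix(-1.5f), computePrefix(2.5f)))          // -1
      }
    }

The first comparison matches the new test added below, which checks that a NaN prefix compares greater than the prefix of Float.MaxValue.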

sql/core/src/main/scala/org/apache/spark/sql/execution/SqlNewHadoopRDD.scala renamed to core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala

Lines changed: 29 additions & 5 deletions
@@ -15,12 +15,13 @@
  * limitations under the License.
  */
 
-package org.apache.spark.sql.execution
+package org.apache.spark.rdd
 
 import java.text.SimpleDateFormat
 import java.util.Date
 
-import org.apache.spark.{Partition => SparkPartition, _}
+import scala.reflect.ClassTag
+
 import org.apache.hadoop.conf.{Configurable, Configuration}
 import org.apache.hadoop.io.Writable
 import org.apache.hadoop.mapreduce._
@@ -30,12 +31,12 @@ import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.deploy.SparkHadoopUtil
 import org.apache.spark.executor.DataReadMethod
 import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil
+import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.{Partition => SparkPartition, _}
 import org.apache.spark.rdd.NewHadoopRDD.NewHadoopMapPartitionsWithSplitRDD
-import org.apache.spark.rdd.{HadoopRDD, RDD}
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.{SerializableConfiguration, Utils}
 
-import scala.reflect.ClassTag
 
 private[spark] class SqlNewHadoopPartition(
     rddId: Int,
@@ -62,7 +63,7 @@ private[spark] class SqlNewHadoopPartition(
  * changes based on [[org.apache.spark.rdd.HadoopRDD]]. In future, this functionality will be
  * folded into core.
  */
-private[sql] class SqlNewHadoopRDD[K, V](
+private[spark] class SqlNewHadoopRDD[K, V](
    @transient sc : SparkContext,
    broadcastedConf: Broadcast[SerializableConfiguration],
    @transient initDriverSideJobFuncOpt: Option[Job => Unit],
@@ -128,6 +129,12 @@ private[sql] class SqlNewHadoopRDD[K, V](
       val inputMetrics = context.taskMetrics
         .getInputMetricsForReadMethod(DataReadMethod.Hadoop)
 
+      // Sets the thread local variable for the file's name
+      split.serializableHadoopSplit.value match {
+        case fs: FileSplit => SqlNewHadoopRDD.setInputFileName(fs.getPath.toString)
+        case _ => SqlNewHadoopRDD.unsetInputFileName()
+      }
+
       // Find a function that will return the FileSystem bytes read by this thread. Do this before
       // creating RecordReader, because RecordReader's constructor might read some bytes
       val bytesReadCallback = inputMetrics.bytesReadCallback.orElse {
@@ -188,6 +195,8 @@ private[sql] class SqlNewHadoopRDD[K, V](
         reader.close()
         reader = null
 
+        SqlNewHadoopRDD.unsetInputFileName()
+
         if (bytesReadCallback.isDefined) {
           inputMetrics.updateBytesRead()
         } else if (split.serializableHadoopSplit.value.isInstanceOf[FileSplit] ||
@@ -250,6 +259,21 @@ private[sql] class SqlNewHadoopRDD[K, V](
 }
 
 private[spark] object SqlNewHadoopRDD {
+
+  /**
+   * The thread variable for the name of the current file being read. This is used by
+   * the InputFileName function in Spark SQL.
+   */
+  private[this] val inputFileName: ThreadLocal[UTF8String] = new ThreadLocal[UTF8String] {
+    override protected def initialValue(): UTF8String = UTF8String.fromString("")
+  }
+
+  def getInputFileName(): UTF8String = inputFileName.get()
+
+  private[spark] def setInputFileName(file: String) = inputFileName.set(UTF8String.fromString(file))
+
+  private[spark] def unsetInputFileName(): Unit = inputFileName.remove()
+
   /**
    * Analogous to [[org.apache.spark.rdd.MapPartitionsRDD]], but passes in an InputSplit to
    * the given function rather than the index of the partition.
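
The per-thread file name introduced above is a plain JDK ThreadLocal: each task thread records the path of the split it is about to read, code running later on the same thread (for example, the InputFileName function mentioned in the scaladoc) reads it back via getInputFileName(), and the completion path clears it. A self-contained sketch of that pattern outside Spark, with String standing in for UTF8String and made-up part-file names:

    object InputFileNameSketch {
      private val inputFileName = new ThreadLocal[String] {
        override protected def initialValue(): String = ""
      }

      def getInputFileName(): String = inputFileName.get()
      def setInputFileName(file: String): Unit = inputFileName.set(file)
      def unsetInputFileName(): Unit = inputFileName.remove()

      def main(args: Array[String]): Unit = {
        // Two "tasks" on separate threads never observe each other's file name.
        val threads = Seq("part-00000", "part-00001").map { file =>
          new Thread(() => {
            setInputFileName(file)          // done once per split, before reading it
            try {
              println(s"${Thread.currentThread().getName} -> ${getInputFileName()}")
            } finally {
              unsetInputFileName()          // mirrors the cleanup when the reader is closed
            }
          })
        }
        threads.foreach(_.start())
        threads.foreach(_.join())
      }
    }

Because the value lives in a ThreadLocal, concurrent tasks reading different files stay isolated, and the remove() call prevents a stale name from leaking into the next task that reuses the thread.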

core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala

Lines changed: 12 additions & 0 deletions
@@ -55,6 +55,18 @@ class PrefixComparatorsSuite extends SparkFunSuite with PropertyChecks {
     forAll { (s1: String, s2: String) => testPrefixComparison(s1, s2) }
   }
 
+  test("float prefix comparator handles NaN properly") {
+    val nan1: Float = java.lang.Float.intBitsToFloat(0x7f800001)
+    val nan2: Float = java.lang.Float.intBitsToFloat(0x7fffffff)
+    assert(nan1.isNaN)
+    assert(nan2.isNaN)
+    val nan1Prefix = PrefixComparators.FLOAT.computePrefix(nan1)
+    val nan2Prefix = PrefixComparators.FLOAT.computePrefix(nan2)
+    assert(nan1Prefix === nan2Prefix)
+    val floatMaxPrefix = PrefixComparators.FLOAT.computePrefix(Float.MaxValue)
+    assert(PrefixComparators.FLOAT.compare(nan1Prefix, floatMaxPrefix) === 1)
+  }
+
   test("double prefix comparator handles NaNs properly") {
     val nan1: Double = java.lang.Double.longBitsToDouble(0x7ff0000000000001L)
     val nan2: Double = java.lang.Double.longBitsToDouble(0x7fffffffffffffffL)

dev/run-tests.py

Lines changed: 16 additions & 0 deletions
@@ -85,6 +85,13 @@ def identify_changed_files_from_git_commits(patch_sha, target_branch=None, targe
     return [f for f in raw_output.split('\n') if f]
 
 
+def setup_test_environ(environ):
+    print("[info] Setup the following environment variables for tests: ")
+    for (k, v) in environ.items():
+        print("%s=%s" % (k, v))
+        os.environ[k] = v
+
+
 def determine_modules_to_test(changed_modules):
     """
     Given a set of modules that have changed, compute the transitive closure of those modules'
@@ -455,6 +462,15 @@ def main():
         print("[info] Found the following changed modules:",
               ", ".join(x.name for x in changed_modules))
 
+    # setup environment variables
+    # note - the 'root' module doesn't collect environment variables for all modules. Because the
+    # environment variables should not be set if a module is not changed, even if running the 'root'
+    # module. So here we should use changed_modules rather than test_modules.
+    test_environ = {}
+    for m in changed_modules:
+        test_environ.update(m.environ)
+    setup_test_environ(test_environ)
+
     test_modules = determine_modules_to_test(changed_modules)
 
     # license checks

dev/sparktestsupport/modules.py

Lines changed: 12 additions & 2 deletions
@@ -29,7 +29,7 @@ class Module(object):
     changed.
     """
 
-    def __init__(self, name, dependencies, source_file_regexes, build_profile_flags=(),
+    def __init__(self, name, dependencies, source_file_regexes, build_profile_flags=(), environ={},
                  sbt_test_goals=(), python_test_goals=(), blacklisted_python_implementations=(),
                  should_run_r_tests=False):
         """
@@ -43,6 +43,8 @@ def __init__(self, name, dependencies, source_file_regexes, build_profile_flags=
             filename strings.
         :param build_profile_flags: A set of profile flags that should be passed to Maven or SBT in
             order to build and test this module (e.g. '-PprofileName').
+        :param environ: A dict of environment variables that should be set when files in this
+            module are changed.
         :param sbt_test_goals: A set of SBT test goals for testing this module.
         :param python_test_goals: A set of Python test goals for testing this module.
         :param blacklisted_python_implementations: A set of Python implementations that are not
@@ -55,6 +57,7 @@ def __init__(self, name, dependencies, source_file_regexes, build_profile_flags=
         self.source_file_prefixes = source_file_regexes
         self.sbt_test_goals = sbt_test_goals
         self.build_profile_flags = build_profile_flags
+        self.environ = environ
         self.python_test_goals = python_test_goals
         self.blacklisted_python_implementations = blacklisted_python_implementations
         self.should_run_r_tests = should_run_r_tests
@@ -126,15 +129,22 @@ def contains_file(self, filename):
     )
 
 
+# Don't set the dependencies because changes in other modules should not trigger Kinesis tests.
+# Kinesis tests depends on external Amazon kinesis service. We should run these tests only when
+# files in streaming_kinesis_asl are changed, so that if Kinesis experiences an outage, we don't
+# fail other PRs.
 streaming_kinesis_asl = Module(
     name="kinesis-asl",
-    dependencies=[streaming],
+    dependencies=[],
     source_file_regexes=[
         "extras/kinesis-asl/",
     ],
     build_profile_flags=[
         "-Pkinesis-asl",
    ],
+    environ={
+        "ENABLE_KINESIS_TESTS": "1"
+    },
     sbt_test_goals=[
        "kinesis-asl/test",
    ]

mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala

Lines changed: 3 additions & 3 deletions
@@ -142,8 +142,8 @@ final class EMLDAOptimizer extends LDAOptimizer {
     this.k = k
     this.vocabSize = docs.take(1).head._2.size
     this.checkpointInterval = lda.getCheckpointInterval
-    this.graphCheckpointer = new
-      PeriodicGraphCheckpointer[TopicCounts, TokenCount](graph, checkpointInterval)
+    this.graphCheckpointer = new PeriodicGraphCheckpointer[TopicCounts, TokenCount](
+      checkpointInterval, graph.vertices.sparkContext)
     this.globalTopicTotals = computeGlobalTopicTotals()
     this
   }
@@ -188,7 +188,7 @@ final class EMLDAOptimizer extends LDAOptimizer {
     // Update the vertex descriptors with the new counts.
     val newGraph = GraphImpl.fromExistingRDDs(docTopicDistributions, graph.edges)
     graph = newGraph
-    graphCheckpointer.updateGraph(newGraph)
+    graphCheckpointer.update(newGraph)
     globalTopicTotals = computeGlobalTopicTotals()
     this
   }

mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala

Lines changed: 154 additions & 0 deletions
@@ -0,0 +1,154 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.impl
+
+import scala.collection.mutable
+
+import org.apache.hadoop.fs.{Path, FileSystem}
+
+import org.apache.spark.{SparkContext, Logging}
+import org.apache.spark.storage.StorageLevel
+
+
+/**
+ * This abstraction helps with persisting and checkpointing RDDs and types derived from RDDs
+ * (such as Graphs and DataFrames). In documentation, we use the phrase "Dataset" to refer to
+ * the distributed data type (RDD, Graph, etc.).
+ *
+ * Specifically, this abstraction automatically handles persisting and (optionally) checkpointing,
+ * as well as unpersisting and removing checkpoint files.
+ *
+ * Users should call update() when a new Dataset has been created,
+ * before the Dataset has been materialized. After updating [[PeriodicCheckpointer]], users are
+ * responsible for materializing the Dataset to ensure that persisting and checkpointing actually
+ * occur.
+ *
+ * When update() is called, this does the following:
+ *  - Persist new Dataset (if not yet persisted), and put in queue of persisted Datasets.
+ *  - Unpersist Datasets from queue until there are at most 3 persisted Datasets.
+ *  - If using checkpointing and the checkpoint interval has been reached,
+ *    - Checkpoint the new Dataset, and put in a queue of checkpointed Datasets.
+ *    - Remove older checkpoints.
+ *
+ * WARNINGS:
+ *  - This class should NOT be copied (since copies may conflict on which Datasets should be
+ *    checkpointed).
+ *  - This class removes checkpoint files once later Datasets have been checkpointed.
+ *    However, references to the older Datasets will still return isCheckpointed = true.
+ *
+ * @param checkpointInterval  Datasets will be checkpointed at this interval
+ * @param sc  SparkContext for the Datasets given to this checkpointer
+ * @tparam T  Dataset type, such as RDD[Double]
+ */
+private[mllib] abstract class PeriodicCheckpointer[T](
+    val checkpointInterval: Int,
+    val sc: SparkContext) extends Logging {
+
+  /** FIFO queue of past checkpointed Datasets */
+  private val checkpointQueue = mutable.Queue[T]()
+
+  /** FIFO queue of past persisted Datasets */
+  private val persistedQueue = mutable.Queue[T]()
+
+  /** Number of times [[update()]] has been called */
+  private var updateCount = 0
+
+  /**
+   * Update with a new Dataset. Handle persistence and checkpointing as needed.
+   * Since this handles persistence and checkpointing, this should be called before the Dataset
+   * has been materialized.
+   *
+   * @param newData  New Dataset created from previous Datasets in the lineage.
+   */
+  def update(newData: T): Unit = {
+    persist(newData)
+    persistedQueue.enqueue(newData)
+    // We try to maintain 2 Datasets in persistedQueue to support the semantics of this class:
+    // Users should call [[update()]] when a new Dataset has been created,
+    // before the Dataset has been materialized.
+    while (persistedQueue.size > 3) {
+      val dataToUnpersist = persistedQueue.dequeue()
+      unpersist(dataToUnpersist)
+    }
+    updateCount += 1
+
+    // Handle checkpointing (after persisting)
+    if ((updateCount % checkpointInterval) == 0 && sc.getCheckpointDir.nonEmpty) {
+      // Add new checkpoint before removing old checkpoints.
+      checkpoint(newData)
+      checkpointQueue.enqueue(newData)
+      // Remove checkpoints before the latest one.
+      var canDelete = true
+      while (checkpointQueue.size > 1 && canDelete) {
+        // Delete the oldest checkpoint only if the next checkpoint exists.
+        if (isCheckpointed(checkpointQueue.head)) {
+          removeCheckpointFile()
+        } else {
+          canDelete = false
+        }
+      }
+    }
+  }
+
+  /** Checkpoint the Dataset */
+  protected def checkpoint(data: T): Unit
+
+  /** Return true iff the Dataset is checkpointed */
+  protected def isCheckpointed(data: T): Boolean
+
+  /**
+   * Persist the Dataset.
+   * Note: This should handle checking the current [[StorageLevel]] of the Dataset.
+   */
+  protected def persist(data: T): Unit
+
+  /** Unpersist the Dataset */
+  protected def unpersist(data: T): Unit
+
+  /** Get list of checkpoint files for this given Dataset */
+  protected def getCheckpointFiles(data: T): Iterable[String]
+
+  /**
+   * Call this at the end to delete any remaining checkpoint files.
+   */
+  def deleteAllCheckpoints(): Unit = {
+    while (checkpointQueue.nonEmpty) {
+      removeCheckpointFile()
+    }
+  }
+
+  /**
+   * Dequeue the oldest checkpointed Dataset, and remove its checkpoint files.
+   * This prints a warning but does not fail if the files cannot be removed.
+   */
+  private def removeCheckpointFile(): Unit = {
+    val old = checkpointQueue.dequeue()
+    // Since the old checkpoint is not deleted by Spark, we manually delete it.
+    val fs = FileSystem.get(sc.hadoopConfiguration)
+    getCheckpointFiles(old).foreach { checkpointFile =>
+      try {
+        fs.delete(new Path(checkpointFile), true)
+      } catch {
+        case e: Exception =>
+          logWarning("PeriodicCheckpointer could not remove old checkpoint file: " +
+            checkpointFile)
+      }
+    }
+  }
+
+}
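
PeriodicCheckpointer leaves five hooks abstract (persist, unpersist, checkpoint, isCheckpointed, getCheckpointFiles). The commit's concrete RDD and Graph subclasses are not shown in this view, so the following is only an illustrative sketch of what an RDD-based subclass could look like; the class name is hypothetical, and it is assumed to live in org.apache.spark.mllib.impl since the base class is private[mllib]:

    package org.apache.spark.mllib.impl

    import org.apache.spark.SparkContext
    import org.apache.spark.rdd.RDD
    import org.apache.spark.storage.StorageLevel

    // Illustrative only: the name and details below are not the commit's actual subclass.
    class SketchRDDCheckpointer[T](checkpointInterval: Int, sc: SparkContext)
      extends PeriodicCheckpointer[RDD[T]](checkpointInterval, sc) {

      // Persist only if no StorageLevel is set yet, as the persist() doc above requires.
      override protected def persist(data: RDD[T]): Unit = {
        if (data.getStorageLevel == StorageLevel.NONE) {
          data.persist()
        }
      }

      override protected def unpersist(data: RDD[T]): Unit = data.unpersist(blocking = false)

      override protected def checkpoint(data: RDD[T]): Unit = data.checkpoint()

      override protected def isCheckpointed(data: RDD[T]): Boolean = data.isCheckpointed

      // An RDD has at most one checkpoint file, so the Option becomes a 0- or 1-element list.
      override protected def getCheckpointFiles(data: RDD[T]): Iterable[String] =
        data.getCheckpointFile.toList
    }

An iterative caller (as EMLDAOptimizer does above with its graph checkpointer) would construct the checkpointer once, call update(newData) each time it builds the next RDD in the lineage and before materializing it, and call deleteAllCheckpoints() once the algorithm finishes.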
