Skip to content

Commit 5e100b2

Browse files
committed
Super-messy WIP on external sort
1 parent 595923a commit 5e100b2

File tree

8 files changed

+663
-77
lines changed

8 files changed

+663
-77
lines changed

core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java

Lines changed: 19 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@
1717

1818
package org.apache.spark.shuffle.unsafe;
1919

20+
import org.apache.spark.*;
21+
import org.apache.spark.unsafe.sort.UnsafeExternalSortSpillMerger;
22+
import org.apache.spark.unsafe.sort.UnsafeExternalSorter;
2023
import scala.Option;
2124
import scala.Product2;
2225
import scala.reflect.ClassTag;
@@ -30,10 +33,6 @@
3033

3134
import com.esotericsoftware.kryo.io.ByteBufferOutputStream;
3235

33-
import org.apache.spark.Partitioner;
34-
import org.apache.spark.ShuffleDependency;
35-
import org.apache.spark.SparkEnv;
36-
import org.apache.spark.TaskContext;
3736
import org.apache.spark.executor.ShuffleWriteMetrics;
3837
import org.apache.spark.scheduler.MapStatus;
3938
import org.apache.spark.scheduler.MapStatus$;
@@ -54,7 +53,6 @@
5453
// IntelliJ gets confused and claims that this class should be abstract, but this actually compiles
5554
public class UnsafeShuffleWriter<K, V> implements ShuffleWriter<K, V> {
5655

57-
private static final int PAGE_SIZE = 1024 * 1024; // TODO: tune this
5856
private static final int SER_BUFFER_SIZE = 1024 * 1024; // TODO: tune this
5957
private static final ClassTag<Object> OBJECT_CLASS_TAG = ClassTag$.MODULE$.Object();
6058

@@ -70,9 +68,6 @@ public class UnsafeShuffleWriter<K, V> implements ShuffleWriter<K, V> {
7068
private final int fileBufferSize;
7169
private MapStatus mapStatus = null;
7270

73-
private MemoryBlock currentPage = null;
74-
private long currentPagePosition = -1;
75-
7671
/**
7772
* Are we in the process of stopping? Because map tasks can call stop() with success = true
7873
* and then call stop() with success = false if they get an exception, we want to make sure
@@ -109,39 +104,20 @@ public void write(scala.collection.Iterator<Product2<K, V>> records) {
109104
}
110105
}
111106

112-
private void ensureSpaceInDataPage(long requiredSpace) throws Exception {
113-
final long spaceInCurrentPage;
114-
if (currentPage != null) {
115-
spaceInCurrentPage = PAGE_SIZE - (currentPagePosition - currentPage.getBaseOffset());
116-
} else {
117-
spaceInCurrentPage = 0;
118-
}
119-
if (requiredSpace > PAGE_SIZE) {
120-
// TODO: throw a more specific exception?
121-
throw new Exception("Required space " + requiredSpace + " is greater than page size (" +
122-
PAGE_SIZE + ")");
123-
} else if (requiredSpace > spaceInCurrentPage) {
124-
currentPage = memoryManager.allocatePage(PAGE_SIZE);
125-
currentPagePosition = currentPage.getBaseOffset();
126-
allocatedPages.add(currentPage);
127-
}
128-
}
129-
130107
private void freeMemory() {
131-
final Iterator<MemoryBlock> iter = allocatedPages.iterator();
132-
while (iter.hasNext()) {
133-
memoryManager.freePage(iter.next());
134-
iter.remove();
135-
}
108+
// TODO: free sorter memory
136109
}
137110

138-
private Iterator<RecordPointerAndKeyPrefix> sortRecords(
139-
scala.collection.Iterator<? extends Product2<K, V>> records) throws Exception {
140-
final UnsafeSorter sorter = new UnsafeSorter(
111+
private Iterator<UnsafeExternalSortSpillMerger.RecordAddressAndKeyPrefix> sortRecords(
112+
scala.collection.Iterator<? extends Product2<K, V>> records) throws Exception {
113+
final UnsafeExternalSorter sorter = new UnsafeExternalSorter(
141114
memoryManager,
115+
SparkEnv$.MODULE$.get().shuffleMemoryManager(),
116+
SparkEnv$.MODULE$.get().blockManager(),
142117
RECORD_COMPARATOR,
143118
PREFIX_COMPARATOR,
144-
4096 // Initial size (TODO: tune this!)
119+
4096, // Initial size (TODO: tune this!)
120+
SparkEnv$.MODULE$.get().conf()
145121
);
146122

147123
final byte[] serArray = new byte[SER_BUFFER_SIZE];
@@ -161,30 +137,16 @@ private Iterator<RecordPointerAndKeyPrefix> sortRecords(
161137

162138
final int serializedRecordSize = serByteBuffer.position();
163139
assert (serializedRecordSize > 0);
164-
// Need 4 bytes to store the record length.
165-
ensureSpaceInDataPage(serializedRecordSize + 4);
166-
167-
final long recordAddress =
168-
memoryManager.encodePageNumberAndOffset(currentPage, currentPagePosition);
169-
final Object baseObject = currentPage.getBaseObject();
170-
PlatformDependent.UNSAFE.putInt(baseObject, currentPagePosition, serializedRecordSize);
171-
currentPagePosition += 4;
172-
PlatformDependent.copyMemory(
173-
serArray,
174-
PlatformDependent.BYTE_ARRAY_OFFSET,
175-
baseObject,
176-
currentPagePosition,
177-
serializedRecordSize);
178-
currentPagePosition += serializedRecordSize;
179140

180-
sorter.insertRecord(recordAddress, partitionId);
141+
sorter.insertRecord(
142+
serArray, PlatformDependent.BYTE_ARRAY_OFFSET, serializedRecordSize, partitionId);
181143
}
182144

183145
return sorter.getSortedIterator();
184146
}
185147

186148
private long[] writeSortedRecordsToFile(
187-
Iterator<RecordPointerAndKeyPrefix> sortedRecords) throws IOException {
149+
Iterator<UnsafeExternalSortSpillMerger.RecordAddressAndKeyPrefix> sortedRecords) throws IOException {
188150
final File outputFile = shuffleBlockManager.getDataFile(shuffleId, mapId);
189151
final ShuffleBlockId blockId =
190152
new ShuffleBlockId(shuffleId, mapId, IndexShuffleBlockManager.NOOP_REDUCE_ID());
@@ -195,7 +157,7 @@ private long[] writeSortedRecordsToFile(
195157

196158
final byte[] arr = new byte[SER_BUFFER_SIZE];
197159
while (sortedRecords.hasNext()) {
198-
final RecordPointerAndKeyPrefix recordPointer = sortedRecords.next();
160+
final UnsafeExternalSortSpillMerger.RecordAddressAndKeyPrefix recordPointer = sortedRecords.next();
199161
final int partition = (int) recordPointer.keyPrefix;
200162
assert (partition >= currentPartition);
201163
if (partition != currentPartition) {
@@ -209,17 +171,14 @@ private long[] writeSortedRecordsToFile(
209171
blockManager.getDiskWriter(blockId, outputFile, serializer, fileBufferSize, writeMetrics);
210172
}
211173

212-
final Object baseObject = memoryManager.getPage(recordPointer.recordPointer);
213-
final long baseOffset = memoryManager.getOffsetInPage(recordPointer.recordPointer);
214-
final int recordLength = (int) PlatformDependent.UNSAFE.getLong(baseObject, baseOffset);
215174
PlatformDependent.copyMemory(
216-
baseObject,
217-
baseOffset + 4,
175+
recordPointer.baseObject,
176+
recordPointer.baseOffset + 4,
218177
arr,
219178
PlatformDependent.BYTE_ARRAY_OFFSET,
220-
recordLength);
179+
recordPointer.recordLength);
221180
assert (writer != null); // To suppress an IntelliJ warning
222-
writer.write(arr, 0, recordLength);
181+
writer.write(arr, 0, recordPointer.recordLength);
223182
// TODO: add a test that detects whether we leave this call out:
224183
writer.recordWritten();
225184
}
Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.unsafe.sort;
19+
20+
import java.util.Comparator;
21+
import java.util.Iterator;
22+
import java.util.PriorityQueue;
23+
24+
import static org.apache.spark.unsafe.sort.UnsafeSorter.*;
25+
26+
public final class UnsafeExternalSortSpillMerger {
27+
28+
private final PriorityQueue<MergeableIterator> priorityQueue;
29+
30+
public static abstract class MergeableIterator {
31+
public abstract boolean hasNext();
32+
33+
public abstract void advanceRecord();
34+
35+
public abstract long getPrefix();
36+
37+
public abstract Object getBaseObject();
38+
39+
public abstract long getBaseOffset();
40+
}
41+
42+
public static final class RecordAddressAndKeyPrefix {
43+
public Object baseObject;
44+
public long baseOffset;
45+
public int recordLength;
46+
public long keyPrefix;
47+
}
48+
49+
public UnsafeExternalSortSpillMerger(
50+
final RecordComparator recordComparator,
51+
final UnsafeSorter.PrefixComparator prefixComparator) {
52+
final Comparator<MergeableIterator> comparator = new Comparator<MergeableIterator>() {
53+
54+
@Override
55+
public int compare(MergeableIterator left, MergeableIterator right) {
56+
final int prefixComparisonResult =
57+
prefixComparator.compare(left.getPrefix(), right.getPrefix());
58+
if (prefixComparisonResult == 0) {
59+
return recordComparator.compare(
60+
left.getBaseObject(), left.getBaseOffset(),
61+
right.getBaseObject(), right.getBaseOffset());
62+
} else {
63+
return prefixComparisonResult;
64+
}
65+
}
66+
};
67+
priorityQueue = new PriorityQueue<MergeableIterator>(10, comparator);
68+
}
69+
70+
public void addSpill(MergeableIterator spillReader) {
71+
priorityQueue.add(spillReader);
72+
}
73+
74+
public Iterator<RecordAddressAndKeyPrefix> getSortedIterator() {
75+
return new Iterator<RecordAddressAndKeyPrefix>() {
76+
77+
private MergeableIterator spillReader;
78+
private final RecordAddressAndKeyPrefix record = new RecordAddressAndKeyPrefix();
79+
80+
@Override
81+
public boolean hasNext() {
82+
return spillReader.hasNext() || !priorityQueue.isEmpty();
83+
}
84+
85+
@Override
86+
public RecordAddressAndKeyPrefix next() {
87+
if (spillReader != null) {
88+
if (spillReader.hasNext()) {
89+
priorityQueue.add(spillReader);
90+
}
91+
}
92+
spillReader = priorityQueue.poll();
93+
record.baseObject = spillReader.getBaseObject();
94+
record.baseOffset = spillReader.getBaseOffset();
95+
record.keyPrefix = spillReader.getPrefix();
96+
return record;
97+
}
98+
99+
@Override
100+
public void remove() {
101+
throw new UnsupportedOperationException();
102+
}
103+
};
104+
}
105+
106+
}

0 commit comments

Comments
 (0)