 
 package org.apache.spark.shuffle.unsafe;
 
+import org.apache.spark.*;
+import org.apache.spark.unsafe.sort.UnsafeExternalSortSpillMerger;
+import org.apache.spark.unsafe.sort.UnsafeExternalSorter;
 import scala.Option;
 import scala.Product2;
 import scala.reflect.ClassTag;
 
 import com.esotericsoftware.kryo.io.ByteBufferOutputStream;
 
-import org.apache.spark.Partitioner;
-import org.apache.spark.ShuffleDependency;
-import org.apache.spark.SparkEnv;
-import org.apache.spark.TaskContext;
 import org.apache.spark.executor.ShuffleWriteMetrics;
 import org.apache.spark.scheduler.MapStatus;
 import org.apache.spark.scheduler.MapStatus$;
 // IntelliJ gets confused and claims that this class should be abstract, but this actually compiles
 public class UnsafeShuffleWriter<K, V> implements ShuffleWriter<K, V> {
 
-  private static final int PAGE_SIZE = 1024 * 1024;  // TODO: tune this
   private static final int SER_BUFFER_SIZE = 1024 * 1024;  // TODO: tune this
   private static final ClassTag<Object> OBJECT_CLASS_TAG = ClassTag$.MODULE$.Object();
 
@@ -70,9 +68,6 @@ public class UnsafeShuffleWriter<K, V> implements ShuffleWriter<K, V> {
   private final int fileBufferSize;
   private MapStatus mapStatus = null;
 
-  private MemoryBlock currentPage = null;
-  private long currentPagePosition = -1;
-
   /**
    * Are we in the process of stopping? Because map tasks can call stop() with success = true
    * and then call stop() with success = false if they get an exception, we want to make sure
@@ -109,39 +104,20 @@ public void write(scala.collection.Iterator<Product2<K, V>> records) {
     }
   }
 
-  private void ensureSpaceInDataPage(long requiredSpace) throws Exception {
-    final long spaceInCurrentPage;
-    if (currentPage != null) {
-      spaceInCurrentPage = PAGE_SIZE - (currentPagePosition - currentPage.getBaseOffset());
-    } else {
-      spaceInCurrentPage = 0;
-    }
-    if (requiredSpace > PAGE_SIZE) {
-      // TODO: throw a more specific exception?
-      throw new Exception("Required space " + requiredSpace + " is greater than page size (" +
-        PAGE_SIZE + ")");
-    } else if (requiredSpace > spaceInCurrentPage) {
-      currentPage = memoryManager.allocatePage(PAGE_SIZE);
-      currentPagePosition = currentPage.getBaseOffset();
-      allocatedPages.add(currentPage);
-    }
-  }
-
   private void freeMemory() {
-    final Iterator<MemoryBlock> iter = allocatedPages.iterator();
-    while (iter.hasNext()) {
-      memoryManager.freePage(iter.next());
-      iter.remove();
-    }
+    // TODO: free sorter memory
   }
 
-  private Iterator<RecordPointerAndKeyPrefix> sortRecords(
-      scala.collection.Iterator<? extends Product2<K, V>> records) throws Exception {
-    final UnsafeSorter sorter = new UnsafeSorter(
+  private Iterator<UnsafeExternalSortSpillMerger.RecordAddressAndKeyPrefix> sortRecords(
+      scala.collection.Iterator<? extends Product2<K, V>> records) throws Exception {
+    final UnsafeExternalSorter sorter = new UnsafeExternalSorter(
       memoryManager,
+      SparkEnv$.MODULE$.get().shuffleMemoryManager(),
+      SparkEnv$.MODULE$.get().blockManager(),
       RECORD_COMPARATOR,
       PREFIX_COMPARATOR,
-      4096  // Initial size (TODO: tune this!)
+      4096,  // Initial size (TODO: tune this!)
+      SparkEnv$.MODULE$.get().conf()
     );
 
     final byte[] serArray = new byte[SER_BUFFER_SIZE];
@@ -161,30 +137,16 @@ private Iterator<RecordPointerAndKeyPrefix> sortRecords(
 
       final int serializedRecordSize = serByteBuffer.position();
       assert (serializedRecordSize > 0);
-      // Need 4 bytes to store the record length.
-      ensureSpaceInDataPage(serializedRecordSize + 4);
-
-      final long recordAddress =
-        memoryManager.encodePageNumberAndOffset(currentPage, currentPagePosition);
-      final Object baseObject = currentPage.getBaseObject();
-      PlatformDependent.UNSAFE.putInt(baseObject, currentPagePosition, serializedRecordSize);
-      currentPagePosition += 4;
-      PlatformDependent.copyMemory(
-        serArray,
-        PlatformDependent.BYTE_ARRAY_OFFSET,
-        baseObject,
-        currentPagePosition,
-        serializedRecordSize);
-      currentPagePosition += serializedRecordSize;
 
-      sorter.insertRecord(recordAddress, partitionId);
+      sorter.insertRecord(
+        serArray, PlatformDependent.BYTE_ARRAY_OFFSET, serializedRecordSize, partitionId);
     }
 
     return sorter.getSortedIterator();
   }
 
   private long[] writeSortedRecordsToFile(
-      Iterator<RecordPointerAndKeyPrefix> sortedRecords) throws IOException {
+      Iterator<UnsafeExternalSortSpillMerger.RecordAddressAndKeyPrefix> sortedRecords) throws IOException {
     final File outputFile = shuffleBlockManager.getDataFile(shuffleId, mapId);
     final ShuffleBlockId blockId =
       new ShuffleBlockId(shuffleId, mapId, IndexShuffleBlockManager.NOOP_REDUCE_ID());
@@ -195,7 +157,7 @@ private long[] writeSortedRecordsToFile(
 
     final byte[] arr = new byte[SER_BUFFER_SIZE];
     while (sortedRecords.hasNext()) {
-      final RecordPointerAndKeyPrefix recordPointer = sortedRecords.next();
+      final UnsafeExternalSortSpillMerger.RecordAddressAndKeyPrefix recordPointer = sortedRecords.next();
       final int partition = (int) recordPointer.keyPrefix;
       assert (partition >= currentPartition);
       if (partition != currentPartition) {
@@ -209,17 +171,14 @@ private long[] writeSortedRecordsToFile(
           blockManager.getDiskWriter(blockId, outputFile, serializer, fileBufferSize, writeMetrics);
       }
 
-      final Object baseObject = memoryManager.getPage(recordPointer.recordPointer);
-      final long baseOffset = memoryManager.getOffsetInPage(recordPointer.recordPointer);
-      final int recordLength = (int) PlatformDependent.UNSAFE.getLong(baseObject, baseOffset);
       PlatformDependent.copyMemory(
-        baseObject,
-        baseOffset + 4,
+        recordPointer.baseObject,
+        recordPointer.baseOffset + 4,
         arr,
         PlatformDependent.BYTE_ARRAY_OFFSET,
-        recordLength);
+        recordPointer.recordLength);
       assert (writer != null);  // To suppress an IntelliJ warning
-      writer.write(arr, 0, recordLength);
+      writer.write(arr, 0, recordPointer.recordLength);
       // TODO: add a test that detects whether we leave this call out:
       writer.recordWritten();
     }
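For orientation, here is a minimal, self-contained sketch of the record path this diff moves to: the writer serializes each record into a byte buffer, hands the bytes plus a partition-id key prefix to a sorter, and later drains a prefix-sorted iterator to emit one partition segment at a time. The `InMemorySorter` and `SortedRecord` types below are hypothetical stand-ins for `UnsafeExternalSorter` and `UnsafeExternalSortSpillMerger.RecordAddressAndKeyPrefix`; they are not the Spark classes, only an illustration of the contract the writer relies on.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;

// Hypothetical stand-in for a sorted record handle (serialized bytes + key prefix).
final class SortedRecord {
  final byte[] bytes;    // serialized record payload
  final long keyPrefix;  // here: the partition id used as the sort key
  SortedRecord(byte[] bytes, long keyPrefix) {
    this.bytes = bytes;
    this.keyPrefix = keyPrefix;
  }
}

// Hypothetical stand-in for the sorter the writer delegates to.
final class InMemorySorter {
  private final List<SortedRecord> records = new ArrayList<>();

  // Mirrors the new insertRecord(baseObject, baseOffset, length, prefix) call:
  // the sorter copies the serialized bytes and records the partition-id prefix.
  void insertRecord(byte[] buf, int offset, int length, long keyPrefix) {
    final byte[] copy = new byte[length];
    System.arraycopy(buf, offset, copy, 0, length);
    records.add(new SortedRecord(copy, keyPrefix));
  }

  // Mirrors getSortedIterator(): records come back ordered by key prefix,
  // so the caller can write each partition's records contiguously.
  Iterator<SortedRecord> getSortedIterator() {
    records.sort(Comparator.comparingLong(r -> r.keyPrefix));
    return records.iterator();
  }
}
```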