UnsafeExternalSorter.java
@@ -74,6 +74,8 @@ public final class UnsafeExternalSorter extends MemoryConsumer {
    */
   private final int numElementsForSpillThreshold;
 
+  private boolean resourceCleand = false;
Review comment (Member): cleand -> cleaned, but is it simpler to call this 'closed'?

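If that rename were adopted, the flag and its accessor might read as below; this is a sketch of the reviewer's suggestion, not code from this PR:

    private boolean closed = false;

    public boolean isClosed() {
      return closed;
    }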
+
   /**
    * Memory pages that hold the records being sorted. The pages in this list are freed when
    * spilling, although in principle we could recycle these pages across spills (on the other hand,
@@ -324,13 +326,18 @@ public void cleanupResources() {
     synchronized (this) {
       deleteSpillFiles();
       freeMemory();
+      this.resourceCleand = true;
       if (inMemSorter != null) {
         inMemSorter.free();
         inMemSorter = null;
       }
     }
   }
 
+  public boolean isResourceCleand() {
+    return resourceCleand;
+  }
+
   /**
    * Checks whether there is enough space to insert an additional record in to the sort pointer
    * array and grows the array if additional space is required. If the required space cannot be
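The contract this flag establishes: once cleanupResources() has run, every iterator backed by this sorter must report exhaustion instead of touching freed pages. A minimal caller-side sketch in Java (enoughConsumed and process are hypothetical placeholders):

    UnsafeSorterIterator it = sorter.getSortedIterator();
    while (it.hasNext() && !enoughConsumed) {
      it.loadNext();
      process(it.getBaseObject(), it.getBaseOffset(), it.getRecordLength());
    }
    sorter.cleanupResources(); // frees memory pages and spill files, then sets the flag
    assert !it.hasNext();      // guarded by isResourceCleand() after this patch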
@@ -464,7 +471,7 @@ public UnsafeSorterIterator getSortedIterator() throws IOException {
     assert(recordComparatorSupplier != null);
     if (spillWriters.isEmpty()) {
       assert(inMemSorter != null);
-      readingIterator = new SpillableIterator(inMemSorter.getSortedIterator());
+      readingIterator = new SpillableIterator(inMemSorter.getSortedIterator(), this);
       return readingIterator;
     } else {
       final UnsafeSorterSpillMerger spillMerger = new UnsafeSorterSpillMerger(
@@ -473,10 +480,10 @@ public UnsafeSorterIterator getSortedIterator() throws IOException {
         spillMerger.addSpillIfNotEmpty(spillWriter.getReader(serializerManager));
       }
       if (inMemSorter != null) {
-        readingIterator = new SpillableIterator(inMemSorter.getSortedIterator());
+        readingIterator = new SpillableIterator(inMemSorter.getSortedIterator(), this);
         spillMerger.addSpillIfNotEmpty(readingIterator);
       }
-      return spillMerger.getSortedIterator();
+      return spillMerger.getSortedIterator(this);
     }
   }

@@ -503,12 +510,14 @@ class SpillableIterator extends UnsafeSorterIterator {
     private UnsafeSorterIterator upstream;
     private UnsafeSorterIterator nextUpstream = null;
     private MemoryBlock lastPage = null;
+    private UnsafeExternalSorter sorter;
     private boolean loaded = false;
     private int numRecords = 0;
 
-    SpillableIterator(UnsafeSorterIterator inMemIterator) {
+    SpillableIterator(UnsafeSorterIterator inMemIterator, UnsafeExternalSorter sorter) {
       this.upstream = inMemIterator;
       this.numRecords = inMemIterator.getNumRecords();
+      this.sorter = sorter;
     }
 
     @Override
@@ -566,7 +575,7 @@ public long spill() throws IOException {
 
     @Override
     public boolean hasNext() {
-      return numRecords > 0;
+      return !sorter.isResourceCleand() && numRecords > 0;
     }
 
     @Override
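The guard added above, shown as a standalone pattern (a hypothetical class for illustration, not part of this PR): an iterator that reports exhaustion once its backing resources have been released, so next() can never touch freed memory.

    import java.util.Iterator;
    import java.util.NoSuchElementException;

    final class CloseGuardedIterator<T> implements Iterator<T> {
      private final Iterator<T> delegate;
      private boolean closed = false;

      CloseGuardedIterator(Iterator<T> delegate) {
        this.delegate = delegate;
      }

      void close() {
        // release pages, delete spill files, etc., then flip the flag
        closed = true;
      }

      @Override
      public boolean hasNext() {
        // mirrors SpillableIterator.hasNext(): report nothing after cleanup
        return !closed && delegate.hasNext();
      }

      @Override
      public T next() {
        if (!hasNext()) {
          throw new NoSuchElementException("iterator closed or exhausted");
        }
        return delegate.next();
      }
    }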
UnsafeSorterSpillMerger.java
@@ -60,7 +60,7 @@ public void addSpillIfNotEmpty(UnsafeSorterIterator spillReader) throws IOException {
     }
   }
 
-  public UnsafeSorterIterator getSortedIterator() throws IOException {
+  public UnsafeSorterIterator getSortedIterator(UnsafeExternalSorter sorter) throws IOException {
     return new UnsafeSorterIterator() {
 
       private UnsafeSorterIterator spillReader;
@@ -72,7 +72,8 @@ public int getNumRecords() {
 
       @Override
       public boolean hasNext() {
-        return !priorityQueue.isEmpty() || (spillReader != null && spillReader.hasNext());
+        return !sorter.isResourceCleand()
+          && (!priorityQueue.isEmpty() || (spillReader != null && spillReader.hasNext()));
       }
 
       @Override
UnsafeExternalRowIterator.java (new file)
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution;
+
+import java.io.Closeable;
+import java.io.IOException;
+
+import scala.collection.AbstractIterator;
+
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
+import org.apache.spark.sql.types.StructType;
+import org.apache.spark.unsafe.Platform;
+import org.apache.spark.util.collection.unsafe.sort.UnsafeSorterIterator;
+
+public abstract class UnsafeExternalRowIterator extends AbstractIterator<UnsafeRow>
+    implements Closeable {
+
+  private final UnsafeSorterIterator sortedIterator;
+  private UnsafeRow row;
+
+  UnsafeExternalRowIterator(StructType schema, UnsafeSorterIterator iterator) {
+    row = new UnsafeRow(schema.length());
+    sortedIterator = iterator;
+  }
+
+  @Override
+  public boolean hasNext() {
+    return sortedIterator.hasNext();
+  }
+
+  @Override
+  public UnsafeRow next() {
+    try {
+      sortedIterator.loadNext();
+      row.pointTo(
+        sortedIterator.getBaseObject(),
+        sortedIterator.getBaseOffset(),
+        sortedIterator.getRecordLength());
+      if (!hasNext()) {
+        UnsafeRow copy = row.copy(); // so that we don't have dangling pointers to freed page
+        row = null; // so that we don't keep references to the base object
+        close();
+        return copy;
+      } else {
+        return row;
+      }
+    } catch (IOException e) {
+      close();
+      // Scala iterators don't declare any checked exceptions, so we need to use this hack
+      // to re-throw the exception:
+      Platform.throwException(e);
+    }
+    throw new RuntimeException("Exception should have been re-thrown in next()");
+  }
+
+  /**
+   * Implementations should clean up resources used by this iterator, to prevent memory leaks.
+   */
+  @Override
+  public abstract void close();
+}
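Usage note: concrete subclasses put their cleanup in close(), which next() also invokes on exhaustion and on error. A condensed sketch of the intended use, mirroring UnsafeExternalRowSorter below (schema, sortedIterator and an enclosing cleanupResources() are assumed to be in scope):

    scala.collection.Iterator<UnsafeRow> rows =
      new UnsafeExternalRowIterator(schema, sortedIterator) {
        @Override
        public void close() {
          cleanupResources(); // release sorter pages and spill files
        }
      };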
BufferedRowIterator.java
@@ -95,4 +95,10 @@ public void incPeakExecutionMemory(long size) {
    * After it's called, if currentRow is still null, it means no more rows left.
    */
   protected abstract void processNext() throws IOException;
+
+  /**
+   * This enables the generated class to implement a method in order to properly release the
+   * resources if the iterator is not fully consumed. See SPARK-21492 for more details.
+   */
+  public void close() {}
 }
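The default no-op matters: generated subclasses override close() (see SortExec below), so consumers can call it unconditionally when they stop early. A hypothetical helper illustrating the calling convention (takeFirst is not part of this PR):

    static void takeFirst(BufferedRowIterator rows, int n) throws IOException {
      int taken = 0;
      while (taken < n && rows.hasNext()) {
        rows.next();
        taken++;
      }
      rows.close(); // release sorter resources even though rows may remain
    }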
UnsafeExternalRowSorter.java
@@ -20,7 +20,6 @@
 import java.io.IOException;
 import java.util.function.Supplier;
 
-import scala.collection.AbstractIterator;
 import scala.collection.Iterator;
 import scala.math.Ordering;
 
Expand All @@ -32,7 +31,6 @@
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.unsafe.Platform;
import org.apache.spark.util.collection.unsafe.sort.PrefixComparator;
import org.apache.spark.util.collection.unsafe.sort.RecordComparator;
import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter;
@@ -169,39 +167,13 @@ public Iterator<UnsafeRow> sort() throws IOException {
         // here in order to prevent memory leaks.
         cleanupResources();
       }
-      return new AbstractIterator<UnsafeRow>() {
+      return new UnsafeExternalRowIterator(schema, sortedIterator) {
 
-        private final int numFields = schema.length();
-        private UnsafeRow row = new UnsafeRow(numFields);
-
-        @Override
-        public boolean hasNext() {
-          return sortedIterator.hasNext();
-        }
-
         @Override
-        public UnsafeRow next() {
-          try {
-            sortedIterator.loadNext();
-            row.pointTo(
-              sortedIterator.getBaseObject(),
-              sortedIterator.getBaseOffset(),
-              sortedIterator.getRecordLength());
-            if (!hasNext()) {
-              UnsafeRow copy = row.copy(); // so that we don't have dangling pointers to freed page
-              row = null; // so that we don't keep references to the base object
-              cleanupResources();
-              return copy;
-            } else {
-              return row;
-            }
-          } catch (IOException e) {
-            cleanupResources();
-            // Scala iterators don't declare any checked exceptions, so we need to use this hack
-            // to re-throw the exception:
-            Platform.throwException(e);
-          }
-          throw new RuntimeException("Exception should have been re-thrown in next()");
+        public void close() {
+          // Caller should clean up resources if the iterator is not consumed in its entirety,
+          // for example, in a SortMergeJoin.
+          cleanupResources();
         }
       };
     } catch (IOException e) {
SortExec.scala
@@ -149,6 +149,16 @@ case class SortExec(
       | ${child.asInstanceOf[CodegenSupport].produce(ctx, this)}
       | }
     """.stripMargin.trim)
+    // Override the close method in BufferedRowIterator to release resources if the sortedIterator
+    // is not fully consumed.
+    ctx.addNewFunction("close",
+      s"""
+         | public void close() {
+         |   if ($sortedIterator != null) {
+         |     ((org.apache.spark.sql.execution.UnsafeExternalRowIterator)$sortedIterator).close();
+         |   }
+         | }
+       """.stripMargin, true)
 
     val outputRow = ctx.freshName("outputRow")
     val peakMemory = metricTerm(ctx, "peakMemory")
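For illustration, after code generation the function added above expands to roughly the following, where sortedIter_0 stands in for the fresh name bound to $sortedIterator (the concrete name varies per compilation):

    public void close() {
      if (sortedIter_0 != null) {
        ((org.apache.spark.sql.execution.UnsafeExternalRowIterator) sortedIter_0).close();
      }
    }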
WholeStageCodegenExec.scala
@@ -568,6 +568,14 @@ object WholeStageCodegenExec {
   }
 }
 
+/**
+ * A trait extending Scala's Iterator[InternalRow] that exposes the underlying
+ * BufferedRowIterator.
+ */
+trait ScalaIteratorWithBufferedIterator extends Iterator[InternalRow] {
+  def getBufferedRowIterator: BufferedRowIterator
+}
+
 /**
  * WholeStageCodegen compiles a subtree of plans that support codegen together into single Java
  * function.
@@ -721,13 +729,14 @@ case class WholeStageCodegenExec(child: SparkPlan)(val codegenStageId: Int)
         val (clazz, _) = CodeGenerator.compile(cleanedSource)
         val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator]
         buffer.init(index, Array(iter))
-        new Iterator[InternalRow] {
+        new ScalaIteratorWithBufferedIterator {
           override def hasNext: Boolean = {
             val v = buffer.hasNext
             if (!v) durationMs += buffer.durationMs()
             v
           }
           override def next: InternalRow = buffer.next()
+          override def getBufferedRowIterator: BufferedRowIterator = buffer
         }
       }
     } else {
@@ -740,13 +749,14 @@
         val (clazz, _) = CodeGenerator.compile(cleanedSource)
         val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator]
         buffer.init(index, Array(leftIter, rightIter))
-        new Iterator[InternalRow] {
+        new ScalaIteratorWithBufferedIterator {
           override def hasNext: Boolean = {
             val v = buffer.hasNext
             if (!v) durationMs += buffer.durationMs()
             v
           }
           override def next: InternalRow = buffer.next()
+          override def getBufferedRowIterator: BufferedRowIterator = buffer
         }
       }
     }
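Putting the pieces together: a downstream operator that abandons its input early (for example, SortMergeJoin) can now recover the generated iterator and trigger cleanup through the new hook. A hypothetical helper sketching that wiring (closeIfGenerated is not part of this PR):

    import org.apache.spark.sql.catalyst.InternalRow;
    import org.apache.spark.sql.execution.ScalaIteratorWithBufferedIterator;

    final class CodegenCleanup {
      static void closeIfGenerated(scala.collection.Iterator<InternalRow> rows) {
        if (rows instanceof ScalaIteratorWithBufferedIterator) {
          // reaches the generated close(), which closes the UnsafeExternalRowIterator
          ((ScalaIteratorWithBufferedIterator) rows).getBufferedRowIterator().close();
        }
      }
    }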