Skip to content
Closed
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import scala.collection.mutable
import scala.reflect.ClassTag

import org.apache.spark.QueryContext
import org.apache.spark.SparkException.internalError
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion, UnresolvedAttribute, UnresolvedSeed}
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
Expand All @@ -41,7 +42,6 @@ import org.apache.spark.sql.types._
import org.apache.spark.sql.util.SQLOpenHashSet
import org.apache.spark.unsafe.UTF8StringBuilder
import org.apache.spark.unsafe.array.ByteArrayMethods
import org.apache.spark.unsafe.array.ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH
import org.apache.spark.unsafe.types.{ByteArray, CalendarInterval, UTF8String}

/**
Expand Down Expand Up @@ -3122,6 +3122,34 @@ case class Sequence(
}

object Sequence {
private def prettyName: String = "sequence"

def sequenceLength(start: Long, stop: Long, step: Long): Int = {
try {
val delta = Math.subtractExact(stop, start)
if (delta == Long.MinValue && step == -1L) {
// We must special-case division of Long.MinValue by -1 to catch a potential unchecked
// overflow in the next operation. Division does not have a built-in overflow check. We
// previously special-cased div-by-zero.
throw new ArithmeticException("Long overflow (Long.MinValue / -1)")
}
val len = if (stop == start) 1L else Math.addExact(1L, (delta / step))
if (len > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
throw QueryExecutionErrors.createArrayWithElementsExceedLimitError(prettyName, len)
}
len.toInt
} catch {
// Overflows raised in the try block above are handled here: the length is recomputed with
// BigInt so that an appropriate exception can be raised.
case _: ArithmeticException =>
val safeLen =
BigInt(1) + (BigInt(stop) - BigInt(start)) / BigInt(step)
if (safeLen > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe just use an assert? An assertion error is also treated as an internal error.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I personally like the current exception better since it's more descriptive of the actual problem -- trying to create too large an array (with the user's intended size) and what the limit is. If there is a strong opinion otherwise, I can change it to an assertion.

throw QueryExecutionErrors.createArrayWithElementsExceedLimitError(prettyName, safeLen)
}
throw internalError("Unreachable code reached.")
case e: Exception => throw e
}
}

private type LessThanOrEqualFn = (Any, Any) => Boolean

Expand Down Expand Up @@ -3493,13 +3521,7 @@ object Sequence {
|| (estimatedStep == num.zero && start == stop),
s"Illegal sequence boundaries: $start to $stop by $step")

val len = if (start == stop) 1L else 1L + (stop.toLong - start.toLong) / estimatedStep.toLong

require(
len <= MAX_ROUNDED_ARRAY_LENGTH,
s"Too long sequence: $len. Should be <= $MAX_ROUNDED_ARRAY_LENGTH")

len.toInt
sequenceLength(start.toLong, stop.toLong, estimatedStep.toLong)
}

private def genSequenceLengthCode(
Expand All @@ -3509,20 +3531,15 @@ object Sequence {
step: String,
estimatedStep: String,
len: String): String = {
val longLen = ctx.freshName("longLen")
val calcFn = classOf[Sequence].getName + ".sequenceLength"
s"""
|if (!(($estimatedStep > 0 && $start <= $stop) ||
| ($estimatedStep < 0 && $start >= $stop) ||
| ($estimatedStep == 0 && $start == $stop))) {
| throw new IllegalArgumentException(
| "Illegal sequence boundaries: " + $start + " to " + $stop + " by " + $step);
|}
|long $longLen = $stop == $start ? 1L : 1L + ((long) $stop - $start) / $estimatedStep;
|if ($longLen > $MAX_ROUNDED_ARRAY_LENGTH) {
| throw new IllegalArgumentException(
| "Too long sequence: " + $longLen + ". Should be <= $MAX_ROUNDED_ARRAY_LENGTH");
|}
|int $len = (int) $longLen;
|int $len = $calcFn((long) $start, (long) $stop, (long) $estimatedStep);
""".stripMargin
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ import org.apache.spark.sql.catalyst.util.IntervalUtils._
import org.apache.spark.sql.errors.DataTypeErrorsBase
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.array.ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH
import org.apache.spark.unsafe.array.ByteArrayMethods
import org.apache.spark.unsafe.types.UTF8String

class CollectionExpressionsSuite
Expand Down Expand Up @@ -795,10 +795,6 @@ class CollectionExpressionsSuite

// test sequence boundaries checking

checkExceptionInExpression[IllegalArgumentException](
new Sequence(Literal(Int.MinValue), Literal(Int.MaxValue), Literal(1)),
EmptyRow, s"Too long sequence: 4294967296. Should be <= $MAX_ROUNDED_ARRAY_LENGTH")

checkExceptionInExpression[IllegalArgumentException](
new Sequence(Literal(1), Literal(2), Literal(0)), EmptyRow, "boundaries: 1 to 2 by 0")
checkExceptionInExpression[IllegalArgumentException](
Expand All @@ -808,6 +804,56 @@ class CollectionExpressionsSuite
checkExceptionInExpression[IllegalArgumentException](
new Sequence(Literal(1), Literal(2), Literal(-1)), EmptyRow, "boundaries: 1 to 2 by -1")

// SPARK-43393: test Sequence overflow checking
checkErrorInExpression[SparkRuntimeException](
new Sequence(Literal(Int.MinValue), Literal(Int.MaxValue), Literal(1)),
errorClass = "COLLECTION_SIZE_LIMIT_EXCEEDED.PARAMETER",
parameters = Map(
"numberOfElements" -> (BigInt(Int.MaxValue) - BigInt { Int.MinValue } + 1).toString,
"functionName" -> toSQLId("sequence"),
"maxRoundedArrayLength" -> ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH.toString(),
"parameter" -> toSQLId("count")))
checkErrorInExpression[SparkRuntimeException](
new Sequence(Literal(0L), Literal(Long.MaxValue), Literal(1L)),
errorClass = "COLLECTION_SIZE_LIMIT_EXCEEDED.PARAMETER",
parameters = Map(
"numberOfElements" -> (BigInt(Long.MaxValue) + 1).toString,
"functionName" -> toSQLId("sequence"),
"maxRoundedArrayLength" -> ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH.toString(),
"parameter" -> toSQLId("count")))
checkErrorInExpression[SparkRuntimeException](
new Sequence(Literal(0L), Literal(Long.MinValue), Literal(-1L)),
errorClass = "COLLECTION_SIZE_LIMIT_EXCEEDED.PARAMETER",
parameters = Map(
"numberOfElements" -> ((0 - BigInt(Long.MinValue)) + 1).toString(),
"functionName" -> toSQLId("sequence"),
"maxRoundedArrayLength" -> ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH.toString(),
"parameter" -> toSQLId("count")))
checkErrorInExpression[SparkRuntimeException](
new Sequence(Literal(Long.MinValue), Literal(Long.MaxValue), Literal(1L)),
errorClass = "COLLECTION_SIZE_LIMIT_EXCEEDED.PARAMETER",
parameters = Map(
"numberOfElements" -> (BigInt(Long.MaxValue) - BigInt { Long.MinValue } + 1).toString,
"functionName" -> toSQLId("sequence"),
"maxRoundedArrayLength" -> ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH.toString(),
"parameter" -> toSQLId("count")))
checkErrorInExpression[SparkRuntimeException](
new Sequence(Literal(Long.MaxValue), Literal(Long.MinValue), Literal(-1L)),
errorClass = "COLLECTION_SIZE_LIMIT_EXCEEDED.PARAMETER",
parameters = Map(
"numberOfElements" -> (BigInt(Long.MaxValue) - BigInt { Long.MinValue } + 1).toString,
"functionName" -> toSQLId("sequence"),
"maxRoundedArrayLength" -> ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH.toString(),
"parameter" -> toSQLId("count")))
checkErrorInExpression[SparkRuntimeException](
new Sequence(Literal(Long.MaxValue), Literal(-1L), Literal(-1L)),
errorClass = "COLLECTION_SIZE_LIMIT_EXCEEDED.PARAMETER",
parameters = Map(
"numberOfElements" -> (BigInt(Long.MaxValue) - BigInt { -1L } + 1).toString,
"functionName" -> toSQLId("sequence"),
"maxRoundedArrayLength" -> ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH.toString(),
"parameter" -> toSQLId("count")))

// test sequence with one element (zero step or equal start and stop)

checkEvaluation(new Sequence(Literal(1), Literal(1), Literal(-1)), Seq(1))
Expand Down