Use a new flag isEmptyOrNulls to handle the scenario with all nulls as well and also to use it for identifying overflow scenarios
skambha committed Apr 27, 2020
commit cc2fec066f2ec1bfc7563ef7b37c935608656b11
@@ -62,91 +62,113 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast

private lazy val sum = AttributeReference("sum", sumDataType)()

private lazy val overflow = AttributeReference("overflow", BooleanType, false)()
private lazy val isEmptyOrNulls = AttributeReference("isEmptyOrNulls", BooleanType, false)()

private lazy val zero = Literal.default(sumDataType)

override lazy val aggBufferAttributes = sum :: overflow :: Nil
override lazy val aggBufferAttributes = sum :: isEmptyOrNulls :: Nil

override lazy val initialValues: Seq[Expression] = Seq(
/* sum = */ Literal.create(null, sumDataType),
/* overflow = */ Literal.create(false, BooleanType)
/* sum = */ zero,
/* isEmptyOrNulls = */ Literal.create(true, BooleanType)
)

/**
* For decimal types, update will do the following:
* We have a overflow flag and when it is true, it indicates overflow has happened
* 1. Start initial state with overflow = false, sum = null
* 2. Set sum to null if the value overflows else sum contains the intermediate sum
* 3. If overflow flag is true, keep sum as null
* 4. If overflow happened, then set overflow flag to true
For decimal types, when the child is nullable:
The isEmptyOrNulls flag is a boolean that indicates whether no rows have been seen or if all
rows seen so far were null. It is used in evaluateExpression to decide whether the final
result of sum should be null.
*
Update of the isEmptyOrNulls flag:
If the flag is false, keep it as is.
If the flag is true, check whether the incoming value is null: if it is null, keep the flag
true, otherwise set it to false.
Once the flag is switched to false, it remains false.
*
Update of the sum:
If sum is null, overflow has already happened, so keep sum as is.
If sum is not null and the incoming value is not null, perform the addition with overflow
checking. Note that if overflow occurs, sum will become null here.
Contributor:
Is it really necessary? We can let it overflow, and it will become null when we write it out to shuffle files.

Contributor (author):
Yes, it is necessary. Otherwise, when the HashAggregate has grouping keys, it can throw an exception while building the hash map if overflow happened.

As an example, one of the examples we have already seen and discussed won't work (it will throw an exception) if we remove this: df.select(expr(s"cast('$decimalStr' as decimal (38, 18)) as d"), lit(1).as("key")).groupBy("key").agg(sum($"d")).show
Please see this comment, which explains the reasoning: #27627 (comment)

* If the incoming value is null, keep the buffered sum as is and skip the null value.
*/
override lazy val updateExpressions: Seq[Expression] = {
if (child.nullable) {
resultType match {
case d: DecimalType =>
Seq(
If(overflow, sum, coalesce(
CheckOverflow(coalesce(sum, zero) + child.cast(sumDataType), d, true), sum)),
overflow ||
coalesce(HasOverflow(coalesce(sum, zero) + child.cast(sumDataType), d), false)
/* sum */
If(IsNull(sum), sum,
If(IsNotNull(child.cast(sumDataType)),
CheckOverflow(sum + child.cast(sumDataType), d, true), sum)),
/* isEmptyOrNulls */
If(isEmptyOrNulls, IsNull(child.cast(sumDataType)), isEmptyOrNulls)
)
case _ =>
Seq(
coalesce(sum + child.cast(sumDataType), sum),
If(isEmptyOrNulls, IsNull(child.cast(sumDataType)), isEmptyOrNulls)
)
case _ => Seq(coalesce(coalesce(sum, zero) + child.cast(sumDataType), sum), false)
}
} else {
resultType match {
case d: DecimalType =>
Seq(
If(overflow, sum,
CheckOverflow(coalesce(sum, zero) + child.cast(sumDataType), d, true)),
overflow ||
coalesce(HasOverflow(coalesce(sum, zero) + child.cast(sumDataType), d), false)
/* sum */
If(IsNull(sum), sum, CheckOverflow(sum + child.cast(sumDataType), d, true)),
/* isEmptyOrNulls */
false
)
case _ => Seq(coalesce(sum, zero) + child.cast(sumDataType), false)
case _ => Seq(sum + child.cast(sumDataType), false)
}
}
}
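// Not part of the patch: a minimal Scala sketch of the nullable-decimal update semantics
// described in the comment above, modeling the aggregation buffer as (sum, isEmptyOrNulls)
// with Option[BigDecimal] standing in for a nullable decimal. `checkOverflow` plays the role
// of CheckOverflow(..., nullOnOverflow = true): it returns None when the addition no longer
// fits the target precision and scale. All names here are illustrative.
case class SumBuffer(sum: Option[BigDecimal], isEmptyOrNulls: Boolean)

def updateBuffer(
    buf: SumBuffer,
    value: Option[BigDecimal],
    checkOverflow: BigDecimal => Option[BigDecimal]): SumBuffer = {
  val newSum = buf.sum match {
    case None => None                        // overflow already seen: keep sum null
    case Some(s) => value match {
      case Some(v) => checkOverflow(s + v)   // becomes None if the addition overflows
      case None    => Some(s)                // incoming null: keep the buffered sum
    }
  }
  // The flag stays true only while every value seen so far has been null.
  val newFlag = if (buf.isEmptyOrNulls) value.isEmpty else false
  SumBuffer(newSum, newFlag)
}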

/**
For decimal type:
The merge of the sum is as follows:
Check whether either left.sum or right.sum has overflowed.
If either has, the merged sum remains null.
If neither has, add sum.left and sum.right and check the result for overflow.
*
* Decimal handling:
* If any of the left or right portion of the agg buffers has the overflow flag to true,
* then sum is set to null else sum is added for both sum.left and sum.right
* and if the value overflows it is set to null.
* If we have already seen overflow , then set overflow to true, else check if the addition
* overflowed and update the overflow buffer.
* isEmptyOrNulls: Set to false if either the left or the right flag is false, which means
* we have seen at least one row that was not null. It is true only if both bufferLeft and
* bufferRight are true.
*/
override lazy val mergeExpressions: Seq[Expression] = {
resultType match {
case d: DecimalType =>
Seq(
If(coalesce(overflow.left, false) || coalesce(overflow.right, false),
Literal.create(null, d),
coalesce(CheckOverflow(coalesce(sum.left, zero) + sum.right, d, true), sum.left)),
If(coalesce(overflow.left, false) || coalesce(overflow.right, false),
true, HasOverflow(coalesce(sum.left, zero) + sum.right, d))
)
/* sum = */
If(And(IsNull(sum.left), EqualTo(isEmptyOrNulls.left, false)) ||
And(IsNull(sum.right), EqualTo(isEmptyOrNulls.right, false)),
Literal.create(null, resultType),
CheckOverflow(sum.left + sum.right, d, true)),
/* isEmptyOrNulls = */
And(isEmptyOrNulls.left, isEmptyOrNulls.right)
)
case _ =>
Seq(
coalesce(coalesce(sum.left, zero) + sum.right, sum.left),
false
coalesce(sum.left + sum.right, sum.left),
And(isEmptyOrNulls.left, isEmptyOrNulls.right)
)
}
}
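// Not part of the patch: the merge path described above, reusing the illustrative
// SumBuffer and checkOverflow model from the sketch after updateExpressions. A buffer whose
// sum is null while its isEmptyOrNulls flag is false has already overflowed, so the merged
// sum stays null; otherwise the partial sums are added with overflow checking, and the flag
// survives only if both sides are still empty-or-nulls.
def mergeBuffers(
    left: SumBuffer,
    right: SumBuffer,
    checkOverflow: BigDecimal => Option[BigDecimal]): SumBuffer = {
  val overflowed =
    (left.sum.isEmpty && !left.isEmptyOrNulls) ||
      (right.sum.isEmpty && !right.isEmptyOrNulls)
  val newSum =
    if (overflowed) None
    else checkOverflow(left.sum.getOrElse(BigDecimal(0)) + right.sum.getOrElse(BigDecimal(0)))
  SumBuffer(newSum, left.isEmptyOrNulls && right.isEmptyOrNulls)
}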

/**
* Decimal handling:
* If overflow buffer is true, and if ansiEnabled is true then throw exception, else return null
* If overflow did not happen, then return the sum value
If isEmptyOrNulls is true, either there are no rows or all rows were null, so the result
will be null.
If isEmptyOrNulls is false and sum is null, an overflow has happened: if ANSI mode is
enabled, throw an exception, otherwise return null.
If sum is not null, return the sum.
*/
override lazy val evaluateExpression: Expression = resultType match {
case d: DecimalType =>
If(EqualTo(overflow, true),
If(!SQLConf.get.ansiEnabled,
Literal.create(null, sumDataType),
OverflowException(resultType, "Arithmetic Operation overflow")),
sum)
case _ => sum
If(EqualTo(isEmptyOrNulls, true),
Literal.create(null, sumDataType),
If(And(SQLConf.get.ansiEnabled, IsNull(sum)),
OverflowException(resultType, "Arithmetic Operation overflow"), sum))
case _ => If(EqualTo(isEmptyOrNulls, true), Literal.create(null, resultType), sum)
}

}
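Not part of the diff: a rough end-to-end illustration of the evaluate semantics above, assuming an existing SparkSession named `spark`; the exact outcome of the second query depends on spark.sql.ansi.enabled.

import org.apache.spark.sql.functions._

// All rows null: isEmptyOrNulls stays true, so sum(d) evaluates to null under either ANSI setting.
spark.range(1, 4)
  .selectExpr("cast(null as decimal(38,18)) as d")
  .agg(sum("d"))
  .show()

// Values large enough to overflow decimal(38,18) when summed: the buffered sum becomes null
// while isEmptyOrNulls is false, so the result is null with ANSI mode off and an arithmetic
// overflow error with ANSI mode on.
spark.range(1, 10)
  .selectExpr("cast('99999999999999999999.999' as decimal(38,18)) as d")
  .agg(sum("d"))
  .show()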
@@ -146,35 +146,6 @@ case class CheckOverflow(
override def sql: String = child.sql
}

case class HasOverflow(
child: Expression,
inputType: DecimalType) extends UnaryExpression {

override def dataType: DataType = BooleanType

override def nullable: Boolean = true

override def nullSafeEval(input: Any): Any =
!input.asInstanceOf[Decimal].changePrecision(
inputType.precision,
inputType.scale,
Decimal.ROUND_HALF_UP)

override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
nullSafeCodeGen(ctx, ev, eval => {
s"""
|${ev.value} = !$eval.changePrecision(
| ${inputType.precision}, ${inputType.scale}, Decimal.ROUND_HALF_UP());
|${ev.isNull} = false;
""".stripMargin
})
}

override def toString: String = s"HasOverflow($child, $inputType)"

override def sql: String = child.sql
}

case class OverflowException(dtype: DataType, msg: String) extends LeafExpression {

override def dataType: DataType = dtype
8 changes: 4 additions & 4 deletions sql/core/src/test/resources/sql-tests/results/explain.sql.out
@@ -918,15 +918,15 @@ Input [2]: [key#x, val#x]
Input [2]: [key#x, val#x]
Keys: []
Functions [3]: [partial_count(val#x), partial_sum(cast(key#x as bigint)), partial_count(key#x) FILTER (WHERE (val#x > 1))]
Aggregate Attributes [4]: [count#xL, sum#xL, overflow#x, count#xL]
Results [4]: [count#xL, sum#xL, overflow#x, count#xL]
Aggregate Attributes [4]: [count#xL, sum#xL, isEmptyOrNulls#x, count#xL]
Results [4]: [count#xL, sum#xL, isEmptyOrNulls#x, count#xL]

(4) Exchange
Input [4]: [count#xL, sum#xL, overflow#x, count#xL]
Input [4]: [count#xL, sum#xL, isEmptyOrNulls#x, count#xL]
Arguments: SinglePartition, true, [id=#x]

(5) HashAggregate [codegen id : 2]
Input [4]: [count#xL, sum#xL, overflow#x, count#xL]
Input [4]: [count#xL, sum#xL, isEmptyOrNulls#x, count#xL]
Keys: []
Functions [3]: [count(val#x), sum(cast(key#x as bigint)), count(key#x)]
Aggregate Attributes [3]: [count(val#x)#xL, sum(cast(key#x as bigint))#xL, count(key#x)#xL]
@@ -224,7 +224,6 @@ class DataFrameSuite extends QueryTest
withSQLConf((SQLConf.ANSI_ENABLED.key, ansiEnabled)) {
val df = spark.range(1, 4, 1).select(expr(s"cast(null as decimal(38,18)) as d"))
checkAnswer(df.agg(sum($"d")), Row(null))
df.agg(sum($"d")).show
}
}
}
@@ -267,6 +266,12 @@ class DataFrameSuite extends QueryTest
val d5 = d3.select(expr(s"cast('$decStr' as decimal (38, 18)) as d"),
lit(1).as("key")).groupBy("key").agg(sum($"d").alias("sumd")).select($"sumd")
checkAnsi(d5, ansiEnabled)

val nullsDf = spark.range(1, 4, 1).select(expr(s"cast(null as decimal(38,18)) as d"))

val largeDecimals = Seq(BigDecimal("1"* 20 + ".123"), BigDecimal("9"* 20 + ".123")).
toDF("d")
checkAnsi(nullsDf.union(largeDecimals).agg(sum($"d")), ansiEnabled)
}
}
}