diff --git a/datafusion/core/src/execution/context.rs b/datafusion/core/src/execution/context.rs index ca3bca61d80a2..30d8bc9cce358 100644 --- a/datafusion/core/src/execution/context.rs +++ b/datafusion/core/src/execution/context.rs @@ -2212,8 +2212,8 @@ mod tests { ) -> Result { // define schema for data source (csv file) let schema = Arc::new(Schema::new(vec![ - Field::new("c1", DataType::UInt32, false), - Field::new("c2", DataType::UInt64, false), + Field::new("c1", DataType::Int32, false), + Field::new("c2", DataType::Int64, false), Field::new("c3", DataType::Boolean, false), ])); diff --git a/datafusion/core/tests/sql/mod.rs b/datafusion/core/tests/sql/mod.rs index 566baefab7f42..b9041033b99a6 100644 --- a/datafusion/core/tests/sql/mod.rs +++ b/datafusion/core/tests/sql/mod.rs @@ -610,8 +610,8 @@ fn populate_csv_partitions( ) -> Result { // define schema for data source (csv file) let schema = Arc::new(Schema::new(vec![ - Field::new("c1", DataType::UInt32, false), - Field::new("c2", DataType::UInt64, false), + Field::new("c1", DataType::Int32, false), + Field::new("c2", DataType::Int64, false), Field::new("c3", DataType::Boolean, false), ])); diff --git a/datafusion/physical-expr/src/aggregate/sum.rs b/datafusion/physical-expr/src/aggregate/sum.rs index c369e7af00813..dffda879612a6 100644 --- a/datafusion/physical-expr/src/aggregate/sum.rs +++ b/datafusion/physical-expr/src/aggregate/sum.rs @@ -36,10 +36,11 @@ use datafusion_expr::Accumulator; use crate::aggregate::row_accumulator::RowAccumulator; use crate::expressions::format_state_name; -use arrow::array::Array; use arrow::array::DecimalArray; +use arrow::array::{Array, Float16Array}; use arrow::compute::cast; use datafusion_row::accessor::RowAccessor; +use std::any::type_name; /// SUM aggregate expression #[derive(Debug)] @@ -158,11 +159,9 @@ fn sum_decimal_batch( scale: &usize, ) -> Result { let array = values.as_any().downcast_ref::().unwrap(); - if array.null_count() == array.len() { return Ok(ScalarValue::Decimal128(None, *precision, *scale)); } - let mut result = 0_i128; for i in 0..array.len() { if array.is_valid(i) { @@ -198,18 +197,6 @@ pub(crate) fn sum_batch(values: &ArrayRef, sum_type: &DataType) -> Result {{ - ScalarValue::$SCALAR(match ($OLD_VALUE, $DELTA) { - (None, None) => None, - (Some(a), None) => Some(a.clone()), - (None, Some(b)) => Some(b.clone() as $TYPE), - (Some(a), Some(b)) => Some(a + (*b as $TYPE)), - }) - }}; -} - macro_rules! sum_row { ($INDEX:ident, $ACC:ident, $DELTA:expr, $TYPE:ident) => {{ paste::item! { @@ -262,98 +249,83 @@ fn sum_decimal_with_diff_scale( } } +macro_rules! downcast_arg { + ($ARG:expr, $NAME:expr, $ARRAY_TYPE:ident) => {{ + $ARG.as_any().downcast_ref::<$ARRAY_TYPE>().ok_or_else(|| { + DataFusionError::Internal(format!( + "could not cast {} to {}", + $NAME, + type_name::<$ARRAY_TYPE>() + )) + })? + }}; +} + +macro_rules! union_arrays { + ($LHS: expr, $RHS: expr, $DTYPE: expr, $ARR_DTYPE: ident, $NAME: expr) => {{ + let lhs_casted = &cast(&$LHS.to_array(), $DTYPE)?; + let rhs_casted = &cast(&$RHS.to_array(), $DTYPE)?; + let lhs_prim_array = downcast_arg!(lhs_casted, $NAME, $ARR_DTYPE); + let rhs_prim_array = downcast_arg!(rhs_casted, $NAME, $ARR_DTYPE); + + let chained = lhs_prim_array + .iter() + .chain(rhs_prim_array.iter()) + .collect::<$ARR_DTYPE>(); + + Arc::new(chained) + }}; +} + pub(crate) fn sum(lhs: &ScalarValue, rhs: &ScalarValue) -> Result { - Ok(match (lhs, rhs) { - (ScalarValue::Decimal128(v1, p1, s1), ScalarValue::Decimal128(v2, p2, s2)) => { + let result = match (lhs.get_datatype(), rhs.get_datatype()) { + (DataType::Decimal(p1, s1), DataType::Decimal(p2, s2)) => { let max_precision = p1.max(p2); - if s1.eq(s2) { - // s1 = s2 - sum_decimal(v1, v2, max_precision, s1) - } else if s1.gt(s2) { - // s1 > s2 - sum_decimal_with_diff_scale(v1, v2, max_precision, s1, s2) - } else { - // s1 < s2 - sum_decimal_with_diff_scale(v2, v1, max_precision, s2, s1) + + match (lhs, rhs) { + ( + ScalarValue::Decimal128(v1, _, _), + ScalarValue::Decimal128(v2, _, _), + ) => { + Ok(if s1.eq(&s2) { + // s1 = s2 + sum_decimal(v1, v2, &max_precision, &s1) + } else if s1.gt(&s2) { + // s1 > s2 + sum_decimal_with_diff_scale(v1, v2, &max_precision, &s1, &s2) + } else { + // s1 < s2 + sum_decimal_with_diff_scale(v2, v1, &max_precision, &s2, &s1) + }) + } + _ => Err(DataFusionError::Internal( + "Internal state error on sum decimals ".to_string(), + )), } } - // float64 coerces everything to f64 - (ScalarValue::Float64(lhs), ScalarValue::Float64(rhs)) => { - typed_sum!(lhs, rhs, Float64, f64) - } - (ScalarValue::Float64(lhs), ScalarValue::Float32(rhs)) => { - typed_sum!(lhs, rhs, Float64, f64) - } - (ScalarValue::Float64(lhs), ScalarValue::Int64(rhs)) => { - typed_sum!(lhs, rhs, Float64, f64) - } - (ScalarValue::Float64(lhs), ScalarValue::Int32(rhs)) => { - typed_sum!(lhs, rhs, Float64, f64) - } - (ScalarValue::Float64(lhs), ScalarValue::Int16(rhs)) => { - typed_sum!(lhs, rhs, Float64, f64) - } - (ScalarValue::Float64(lhs), ScalarValue::Int8(rhs)) => { - typed_sum!(lhs, rhs, Float64, f64) + (DataType::Float64, _) | (_, DataType::Float64) => { + let data: ArrayRef = + union_arrays!(lhs, rhs, &DataType::Float64, Float64Array, "f64"); + sum_batch(&data, &arrow::datatypes::DataType::Float64) } - (ScalarValue::Float64(lhs), ScalarValue::UInt64(rhs)) => { - typed_sum!(lhs, rhs, Float64, f64) - } - (ScalarValue::Float64(lhs), ScalarValue::UInt32(rhs)) => { - typed_sum!(lhs, rhs, Float64, f64) - } - (ScalarValue::Float64(lhs), ScalarValue::UInt16(rhs)) => { - typed_sum!(lhs, rhs, Float64, f64) - } - (ScalarValue::Float64(lhs), ScalarValue::UInt8(rhs)) => { - typed_sum!(lhs, rhs, Float64, f64) - } - // float32 has no cast - (ScalarValue::Float32(lhs), ScalarValue::Float32(rhs)) => { - typed_sum!(lhs, rhs, Float32, f32) - } - // u64 coerces u* to u64 - (ScalarValue::UInt64(lhs), ScalarValue::UInt64(rhs)) => { - typed_sum!(lhs, rhs, UInt64, u64) + (DataType::Float32, _) | (_, DataType::Float32) => { + let data: ArrayRef = + union_arrays!(lhs, rhs, &DataType::Float32, Float32Array, "f32"); + sum_batch(&data, &arrow::datatypes::DataType::Float32) } - (ScalarValue::UInt64(lhs), ScalarValue::UInt32(rhs)) => { - typed_sum!(lhs, rhs, UInt64, u64) + (DataType::Float16, _) | (_, DataType::Float16) => { + let data: ArrayRef = + union_arrays!(lhs, rhs, &DataType::Float16, Float16Array, "f16"); + sum_batch(&data, &arrow::datatypes::DataType::Float16) } - (ScalarValue::UInt64(lhs), ScalarValue::UInt16(rhs)) => { - typed_sum!(lhs, rhs, UInt64, u64) + _ => { + let data: ArrayRef = + union_arrays!(lhs, rhs, &DataType::Int64, Int64Array, "i64"); + sum_batch(&data, &arrow::datatypes::DataType::Int64) } - (ScalarValue::UInt64(lhs), ScalarValue::UInt8(rhs)) => { - typed_sum!(lhs, rhs, UInt64, u64) - } - // i64 coerces i* to i64 - (ScalarValue::Int64(lhs), ScalarValue::Int64(rhs)) => { - typed_sum!(lhs, rhs, Int64, i64) - } - (ScalarValue::Int64(lhs), ScalarValue::Int32(rhs)) => { - typed_sum!(lhs, rhs, Int64, i64) - } - (ScalarValue::Int64(lhs), ScalarValue::Int16(rhs)) => { - typed_sum!(lhs, rhs, Int64, i64) - } - (ScalarValue::Int64(lhs), ScalarValue::Int8(rhs)) => { - typed_sum!(lhs, rhs, Int64, i64) - } - (ScalarValue::Int64(lhs), ScalarValue::UInt32(rhs)) => { - typed_sum!(lhs, rhs, Int64, i64) - } - (ScalarValue::Int64(lhs), ScalarValue::UInt16(rhs)) => { - typed_sum!(lhs, rhs, Int64, i64) - } - (ScalarValue::Int64(lhs), ScalarValue::UInt8(rhs)) => { - typed_sum!(lhs, rhs, Int64, i64) - } - e => { - return Err(DataFusionError::Internal(format!( - "Sum is not expected to receive a scalar {:?}", - e - ))); - } - }) + }?; + + Ok(result) } pub(crate) fn add_to_row( @@ -440,7 +412,12 @@ impl Accumulator for SumAccumulator { } fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + if values.is_empty() { + return Ok(()); + } + let values = &values[0]; + self.sum = sum(&self.sum, &sum_batch(values, &self.sum.get_datatype())?)?; Ok(()) } @@ -668,19 +645,6 @@ mod tests { ) } - #[test] - fn sum_u32() -> Result<()> { - let a: ArrayRef = - Arc::new(UInt32Array::from(vec![1_u32, 2_u32, 3_u32, 4_u32, 5_u32])); - generic_test_op!( - a, - DataType::UInt32, - Sum, - ScalarValue::from(15u64), - DataType::UInt64 - ) - } - #[test] fn sum_f32() -> Result<()> { let a: ArrayRef =