Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Add to turn any aggregate function into a window function
  • Loading branch information
timsaucer committed Sep 17, 2024
commit d50aacf1dc725e5ae360f7acbb99f657dacf0f5d
37 changes: 37 additions & 0 deletions python/datafusion/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,6 +542,43 @@ def window_frame(self, window_frame: WindowFrame) -> ExprFuncBuilder:
"""
return ExprFuncBuilder(self.expr.window_frame(window_frame.window_frame))

def over(
self,
partition_by: Optional[list[Expr]] = None,
window_frame: Optional[WindowFrame] = None,
order_by: Optional[list[SortExpr | Expr]] = None,
null_treatment: Optional[NullTreatment] = None,
) -> Expr:
"""Turn an aggregate function into a window function.

This function turns any aggregate function into a window function. With the
exception of ``partition_by``, how each of the parameters is used is determined
by the underlying aggregate function.

Args:
partition_by: Expressions to partition the window frame on
window_frame: Specify the window frame parameters
order_by: Set ordering within the window frame
null_treatment: Set how to handle null values
"""
partition_by_raw = expr_list_to_raw_expr_list(partition_by)
order_by_raw = sort_list_to_raw_sort_list(order_by)
window_frame_raw = (
window_frame.window_frame if window_frame is not None else None
)
null_treatment_raw = (
null_treatment.value if null_treatment is not None else None
)

return Expr(
self.expr.over(
partition_by=partition_by_raw,
order_by=order_by_raw,
window_frame=window_frame_raw,
null_treatment=null_treatment_raw,
)
)


class ExprFuncBuilder:
def __init__(self, builder: expr_internal.ExprFuncBuilder):
Expand Down
34 changes: 13 additions & 21 deletions python/datafusion/tests/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,38 +386,30 @@ def test_distinct():
),
[-1, -1, None, 7, -1, -1, None],
),
# TODO update all aggregate functions as windows once upstream merges https://github.com/apache/datafusion-python/issues/833
pytest.param(
(
"first_value",
f.window(
"first_value",
[column("a")],
order_by=[f.order_by(column("b"))],
partition_by=[column("c")],
f.first_value(column("a")).over(
partition_by=[column("c")], order_by=[column("b")]
),
[1, 1, 1, 1, 5, 5, 5],
),
pytest.param(
(
"last_value",
f.window("last_value", [column("a")])
.window_frame(WindowFrame("rows", 0, None))
.order_by(column("b"))
.partition_by(column("c"))
.build(),
f.last_value(column("a")).over(
partition_by=[column("c")],
order_by=[column("b")],
window_frame=WindowFrame("rows", None, None),
),
[3, 3, 3, 3, 6, 6, 6],
),
pytest.param(
(
"3rd_value",
f.window(
"nth_value",
[column("b"), literal(3)],
order_by=[f.order_by(column("a"))],
),
f.nth_value(column("b"), 3).over(order_by=[column("a")]),
[None, None, 7, 7, 7, 7, 7],
),
pytest.param(
(
"avg",
f.round(f.window("avg", [column("b")], order_by=[column("a")]), literal(3)),
f.round(f.avg(column("b")).over(order_by=[column("a")]), literal(3)),
[7.0, 7.0, 7.0, 7.333, 7.75, 7.75, 8.0],
),
]
Expand Down
44 changes: 43 additions & 1 deletion src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
// under the License.

use datafusion::logical_expr::utils::exprlist_to_fields;
use datafusion::logical_expr::{ExprFuncBuilder, ExprFunctionExt, LogicalPlan};
use datafusion::logical_expr::{
ExprFuncBuilder, ExprFunctionExt, LogicalPlan, WindowFunctionDefinition,
};
use pyo3::{basic::CompareOp, prelude::*};
use std::convert::{From, Into};
use std::sync::Arc;
Expand All @@ -39,6 +41,7 @@ use crate::expr::aggregate_expr::PyAggregateFunction;
use crate::expr::binary_expr::PyBinaryExpr;
use crate::expr::column::PyColumn;
use crate::expr::literal::PyLiteral;
use crate::functions::add_builder_fns_to_window;
use crate::sql::logical::PyLogicalPlan;

use self::alias::PyAlias;
Expand Down Expand Up @@ -558,6 +561,45 @@ impl PyExpr {
pub fn window_frame(&self, window_frame: PyWindowFrame) -> PyExprFuncBuilder {
self.expr.clone().window_frame(window_frame.into()).into()
}

#[pyo3(signature = (partition_by=None, window_frame=None, order_by=None, null_treatment=None))]
pub fn over(
&self,
partition_by: Option<Vec<PyExpr>>,
window_frame: Option<PyWindowFrame>,
order_by: Option<Vec<PySortExpr>>,
null_treatment: Option<NullTreatment>,
) -> PyResult<PyExpr> {
match &self.expr {
Expr::AggregateFunction(agg_fn) => {
let window_fn = Expr::WindowFunction(WindowFunction::new(
WindowFunctionDefinition::AggregateUDF(agg_fn.func.clone()),
agg_fn.args.clone(),
));

add_builder_fns_to_window(
window_fn,
partition_by,
window_frame,
order_by,
null_treatment,
)
}
Expr::WindowFunction(_) => add_builder_fns_to_window(
self.expr.clone(),
partition_by,
window_frame,
order_by,
null_treatment,
),
_ => Err(
DataFusionError::ExecutionError(datafusion::error::DataFusionError::Plan(
"Using `over` requires an aggregate function.".to_string(),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
"Using `over` requires an aggregate function.".to_string(),
format!("Using {} with `over` is not allowed. Must use an aggregate or window function.", self.expr.variant_name())

))
.into(),
),
}
}
}

#[pyclass(name = "ExprFuncBuilder", module = "datafusion.expr", subclass)]
Expand Down
31 changes: 19 additions & 12 deletions src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
// specific language governing permissions and limitations
// under the License.

use std::ptr::null;

use datafusion::functions_aggregate::all_default_aggregate_functions;
use datafusion::logical_expr::window_function;
use datafusion::logical_expr::ExprFunctionExt;
Expand Down Expand Up @@ -711,14 +713,15 @@ pub fn string_agg(
add_builder_fns_to_aggregate(agg_fn, distinct, filter, order_by, null_treatment)
}

fn add_builder_fns_to_window(
pub(crate) fn add_builder_fns_to_window(
window_fn: Expr,
partition_by: Option<Vec<PyExpr>>,
window_frame: Option<PyWindowFrame>,
order_by: Option<Vec<PySortExpr>>,
null_treatment: Option<NullTreatment>,
) -> PyResult<PyExpr> {
// Since ExprFuncBuilder::new() is private, set an empty partition and then
// override later if appropriate.
let mut builder = window_fn.partition_by(vec![]);
let null_treatment = null_treatment.map(|n| n.into());
let mut builder = window_fn.null_treatment(null_treatment);

if let Some(partition_cols) = partition_by {
builder = builder.partition_by(
Expand All @@ -734,6 +737,10 @@ fn add_builder_fns_to_window(
builder = builder.order_by(order_by_cols);
}

if let Some(window_frame) = window_frame {
builder = builder.window_frame(window_frame.into());
}

builder.build().map(|e| e.into()).map_err(|err| err.into())
}

Expand All @@ -748,7 +755,7 @@ pub fn lead(
) -> PyResult<PyExpr> {
let window_fn = window_function::lead(arg.expr, Some(shift_offset), default_value);

add_builder_fns_to_window(window_fn, partition_by, order_by)
add_builder_fns_to_window(window_fn, partition_by, None, order_by, None)
}

#[pyfunction]
Expand All @@ -762,7 +769,7 @@ pub fn lag(
) -> PyResult<PyExpr> {
let window_fn = window_function::lag(arg.expr, Some(shift_offset), default_value);

add_builder_fns_to_window(window_fn, partition_by, order_by)
add_builder_fns_to_window(window_fn, partition_by, None, order_by, None)
}

#[pyfunction]
Expand All @@ -773,7 +780,7 @@ pub fn row_number(
) -> PyResult<PyExpr> {
let window_fn = datafusion::functions_window::expr_fn::row_number();

add_builder_fns_to_window(window_fn, partition_by, order_by)
add_builder_fns_to_window(window_fn, partition_by, None, order_by, None)
}

#[pyfunction]
Expand All @@ -784,7 +791,7 @@ pub fn rank(
) -> PyResult<PyExpr> {
let window_fn = window_function::rank();

add_builder_fns_to_window(window_fn, partition_by, order_by)
add_builder_fns_to_window(window_fn, partition_by, None, order_by, None)
}

#[pyfunction]
Expand All @@ -795,7 +802,7 @@ pub fn dense_rank(
) -> PyResult<PyExpr> {
let window_fn = window_function::dense_rank();

add_builder_fns_to_window(window_fn, partition_by, order_by)
add_builder_fns_to_window(window_fn, partition_by, None, order_by, None)
}

#[pyfunction]
Expand All @@ -806,7 +813,7 @@ pub fn percent_rank(
) -> PyResult<PyExpr> {
let window_fn = window_function::percent_rank();

add_builder_fns_to_window(window_fn, partition_by, order_by)
add_builder_fns_to_window(window_fn, partition_by, None, order_by, None)
}

#[pyfunction]
Expand All @@ -817,7 +824,7 @@ pub fn cume_dist(
) -> PyResult<PyExpr> {
let window_fn = window_function::cume_dist();

add_builder_fns_to_window(window_fn, partition_by, order_by)
add_builder_fns_to_window(window_fn, partition_by, None, order_by, None)
}

#[pyfunction]
Expand All @@ -829,7 +836,7 @@ pub fn ntile(
) -> PyResult<PyExpr> {
let window_fn = window_function::ntile(arg.into());

add_builder_fns_to_window(window_fn, partition_by, order_by)
add_builder_fns_to_window(window_fn, partition_by, None, order_by, None)
}

pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
Expand Down