Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Rename Window to WindowExpr so we can define Window to mean a window definition to be reused
  • Loading branch information
timsaucer committed Sep 17, 2024
commit 64d341549d2c10eb9ebd9fb3031ac7a03f8dd104
50 changes: 34 additions & 16 deletions python/datafusion/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@
Union = expr_internal.Union
Unnest = expr_internal.Unnest
UnnestExpr = expr_internal.UnnestExpr
Window = expr_internal.Window
WindowExpr = expr_internal.WindowExpr

__all__ = [
"Expr",
Expand Down Expand Up @@ -154,6 +154,7 @@
"Partitioning",
"Repartition",
"Window",
"WindowExpr",
"WindowFrame",
"WindowFrameBound",
]
Expand Down Expand Up @@ -542,32 +543,25 @@ def window_frame(self, window_frame: WindowFrame) -> ExprFuncBuilder:
"""
return ExprFuncBuilder(self.expr.window_frame(window_frame.window_frame))

def over(
self,
partition_by: Optional[list[Expr]] = None,
window_frame: Optional[WindowFrame] = None,
order_by: Optional[list[SortExpr | Expr]] = None,
null_treatment: Optional[NullTreatment] = None,
) -> Expr:
def over(self, window: Window) -> Expr:
"""Turn an aggregate function into a window function.

This function turns any aggregate function into a window function. With the
exception of ``partition_by``, how each of the parameters is used is determined
by the underlying aggregate function.

Args:
partition_by: Expressions to partition the window frame on
window_frame: Specify the window frame parameters
order_by: Set ordering within the window frame
null_treatment: Set how to handle null values
window: Window definition
"""
partition_by_raw = expr_list_to_raw_expr_list(partition_by)
order_by_raw = sort_list_to_raw_sort_list(order_by)
partition_by_raw = expr_list_to_raw_expr_list(window._partition_by)
order_by_raw = sort_list_to_raw_sort_list(window._order_by)
window_frame_raw = (
window_frame.window_frame if window_frame is not None else None
window._window_frame.window_frame
if window._window_frame is not None
else None
)
null_treatment_raw = (
null_treatment.value if null_treatment is not None else None
window._null_treatment.value if window._null_treatment is not None else None
)

return Expr(
Expand Down Expand Up @@ -621,6 +615,30 @@ def build(self) -> Expr:
return Expr(self.builder.build())


class Window:
    """Reusable bundle of window parameters.

    An instance only records the settings it is given; ``Expr.over`` reads
    them back when turning an aggregate function into a window function.
    """

    def __init__(
        self,
        partition_by: Optional[list[Expr]] = None,
        window_frame: Optional[WindowFrame] = None,
        order_by: Optional[list[SortExpr | Expr]] = None,
        null_treatment: Optional[NullTreatment] = None,
    ) -> None:
        """Create a window definition.

        Every argument is optional and defaults to ``None``.

        Args:
            partition_by: Expressions used to partition the window operation
            window_frame: Start and end bounds of the window frame
            order_by: Ordering to apply
            null_treatment: How null values are to be treated
        """
        # Values are stored exactly as passed in; nothing is copied or
        # validated at construction time.
        self._partition_by, self._order_by = partition_by, order_by
        self._window_frame, self._null_treatment = window_frame, null_treatment


class WindowFrame:
"""Defines a window frame for performing window operations."""

Expand Down
15 changes: 9 additions & 6 deletions python/datafusion/tests/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
literal,
udf,
)
from datafusion.expr import Window


@pytest.fixture
Expand Down Expand Up @@ -389,27 +390,29 @@ def test_distinct():
(
"first_value",
f.first_value(column("a")).over(
partition_by=[column("c")], order_by=[column("b")]
Window(partition_by=[column("c")], order_by=[column("b")])
),
[1, 1, 1, 1, 5, 5, 5],
),
(
"last_value",
f.last_value(column("a")).over(
partition_by=[column("c")],
order_by=[column("b")],
window_frame=WindowFrame("rows", None, None),
Window(
partition_by=[column("c")],
order_by=[column("b")],
window_frame=WindowFrame("rows", None, None),
)
),
[3, 3, 3, 3, 6, 6, 6],
),
(
"3rd_value",
f.nth_value(column("b"), 3).over(order_by=[column("a")]),
f.nth_value(column("b"), 3).over(Window(order_by=[column("a")])),
[None, None, 7, 7, 7, 7, 7],
),
(
"avg",
f.round(f.avg(column("b")).over(order_by=[column("a")]), literal(3)),
f.round(f.avg(column("b")).over(Window(order_by=[column("a")])), literal(3)),
[7.0, 7.0, 7.0, 7.333, 7.75, 7.75, 8.0],
),
]
Expand Down
2 changes: 1 addition & 1 deletion src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -791,7 +791,7 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<drop_table::PyDropTable>()?;
m.add_class::<repartition::PyPartitioning>()?;
m.add_class::<repartition::PyRepartition>()?;
m.add_class::<window::PyWindow>()?;
m.add_class::<window::PyWindowExpr>()?;
m.add_class::<window::PyWindowFrame>()?;
m.add_class::<window::PyWindowFrameBound>()?;
Ok(())
Expand Down
20 changes: 10 additions & 10 deletions src/expr/window.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@ use super::py_expr_list;

use crate::errors::py_datafusion_err;

#[pyclass(name = "Window", module = "datafusion.expr", subclass)]
#[pyclass(name = "WindowExpr", module = "datafusion.expr", subclass)]
#[derive(Clone)]
pub struct PyWindow {
pub struct PyWindowExpr {
window: Window,
}

Expand Down Expand Up @@ -62,15 +62,15 @@ pub struct PyWindowFrameBound {
frame_bound: WindowFrameBound,
}

impl From<PyWindow> for Window {
fn from(window: PyWindow) -> Window {
impl From<PyWindowExpr> for Window {
fn from(window: PyWindowExpr) -> Window {
window.window
}
}

impl From<Window> for PyWindow {
fn from(window: Window) -> PyWindow {
PyWindow { window }
impl From<Window> for PyWindowExpr {
fn from(window: Window) -> PyWindowExpr {
PyWindowExpr { window }
}
}

Expand All @@ -80,7 +80,7 @@ impl From<WindowFrameBound> for PyWindowFrameBound {
}
}

impl Display for PyWindow {
impl Display for PyWindowExpr {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
write!(
f,
Expand All @@ -103,7 +103,7 @@ impl Display for PyWindowFrame {
}

#[pymethods]
impl PyWindow {
impl PyWindowExpr {
/// Returns the schema of the Window
pub fn schema(&self) -> PyResult<PyDFSchema> {
Ok(self.window.schema.as_ref().clone().into())
Expand Down Expand Up @@ -283,7 +283,7 @@ impl PyWindowFrameBound {
}
}

impl LogicalNode for PyWindow {
impl LogicalNode for PyWindowExpr {
fn inputs(&self) -> Vec<PyLogicalPlan> {
vec![self.window.input.as_ref().clone().into()]
}
Expand Down
2 changes: 0 additions & 2 deletions src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
// specific language governing permissions and limitations
// under the License.

use std::ptr::null;

use datafusion::functions_aggregate::all_default_aggregate_functions;
use datafusion::logical_expr::window_function;
use datafusion::logical_expr::ExprFunctionExt;
Expand Down
4 changes: 2 additions & 2 deletions src/sql/logical.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ use crate::expr::subquery::PySubquery;
use crate::expr::subquery_alias::PySubqueryAlias;
use crate::expr::table_scan::PyTableScan;
use crate::expr::unnest::PyUnnest;
use crate::expr::window::PyWindow;
use crate::expr::window::PyWindowExpr;
use datafusion::logical_expr::LogicalPlan;
use pyo3::prelude::*;

Expand Down Expand Up @@ -80,7 +80,7 @@ impl PyLogicalPlan {
LogicalPlan::Subquery(plan) => PySubquery::from(plan.clone()).to_variant(py),
LogicalPlan::SubqueryAlias(plan) => PySubqueryAlias::from(plan.clone()).to_variant(py),
LogicalPlan::Unnest(plan) => PyUnnest::from(plan.clone()).to_variant(py),
LogicalPlan::Window(plan) => PyWindow::from(plan.clone()).to_variant(py),
LogicalPlan::Window(plan) => PyWindowExpr::from(plan.clone()).to_variant(py),
LogicalPlan::Repartition(_)
| LogicalPlan::Union(_)
| LogicalPlan::Statement(_)
Expand Down