diff --git a/datafusion/functions-aggregate/src/count.rs b/datafusion/functions-aggregate/src/count.rs index bd0155df0271b..d06ee3946001e 100644 --- a/datafusion/functions-aggregate/src/count.rs +++ b/datafusion/functions-aggregate/src/count.rs @@ -44,7 +44,7 @@ use datafusion_expr::{ function::AccumulatorArgs, utils::format_state_name, Accumulator, AggregateUDFImpl, EmitTo, GroupsAccumulator, Signature, Volatility, }; -use datafusion_expr::{Expr, ReversedUDAF}; +use datafusion_expr::{Expr, ReversedUDAF, TypeSignature}; use datafusion_physical_expr_common::aggregate::groups_accumulator::accumulate::accumulate_indices; use datafusion_physical_expr_common::{ aggregate::count_distinct::{ @@ -95,7 +95,11 @@ impl Default for Count { impl Count { pub fn new() -> Self { Self { - signature: Signature::variadic_any(Volatility::Immutable), + signature: Signature::one_of( + // TypeSignature::Any(0) is required to handle `Count()` with no args + vec![TypeSignature::VariadicAny, TypeSignature::Any(0)], + Volatility::Immutable, + ), } } } diff --git a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs index 959ffdaaa2129..fa8aeb86ed31e 100644 --- a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs +++ b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs @@ -35,7 +35,7 @@ pub struct CountWildcardRule {} impl CountWildcardRule { pub fn new() -> Self { - CountWildcardRule {} + Self {} } } @@ -59,14 +59,14 @@ fn is_count_star_aggregate(aggregate_function: &AggregateFunction) -> bool { func_def: AggregateFunctionDefinition::UDF(udf), args, .. - } if udf.name() == "count" && args.len() == 1 && is_wildcard(&args[0])) + } if udf.name() == "count" && (args.len() == 1 && is_wildcard(&args[0]) || args.is_empty())) } fn is_count_star_window_aggregate(window_function: &WindowFunction) -> bool { let args = &window_function.args; matches!(window_function.fun, WindowFunctionDefinition::AggregateUDF(ref udaf) - if udaf.name() == "count" && args.len() == 1 && is_wildcard(&args[0])) + if udaf.name() == "count" && (args.len() == 1 && is_wildcard(&args[0]) || args.is_empty())) } fn analyze_internal(plan: LogicalPlan) -> Result> { diff --git a/datafusion/sqllogictest/test_files/count_star_rule.slt b/datafusion/sqllogictest/test_files/count_star_rule.slt new file mode 100644 index 0000000000000..99d358ad17f02 --- /dev/null +++ b/datafusion/sqllogictest/test_files/count_star_rule.slt @@ -0,0 +1,105 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +statement ok +CREATE TABLE t1 (a INTEGER, b INTEGER, c INTEGER); + +statement ok +INSERT INTO t1 VALUES +(1, 2, 3), +(1, 5, 6), +(2, 3, 5); + +statement ok +CREATE TABLE t2 (a INTEGER, b INTEGER, c INTEGER); + +query TT +EXPLAIN SELECT COUNT() FROM (SELECT 1 AS a, 2 AS b) AS t; +---- +logical_plan +01)Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) AS count()]] +02)--SubqueryAlias: t +03)----EmptyRelation +physical_plan +01)ProjectionExec: expr=[1 as count()] +02)--PlaceholderRowExec + +query TT +EXPLAIN SELECT t1.a, COUNT() FROM t1 GROUP BY t1.a; +---- +logical_plan +01)Aggregate: groupBy=[[t1.a]], aggr=[[count(Int64(1)) AS count()]] +02)--TableScan: t1 projection=[a] +physical_plan +01)AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[count()] +02)--CoalesceBatchesExec: target_batch_size=8192 +03)----RepartitionExec: partitioning=Hash([a@0], 4), input_partitions=4 +04)------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +05)--------AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[count()] +06)----------MemoryExec: partitions=1, partition_sizes=[1] + +query TT +EXPLAIN SELECT t1.a, COUNT() AS cnt FROM t1 GROUP BY t1.a HAVING COUNT() > 0; +---- +logical_plan +01)Projection: t1.a, count() AS cnt +02)--Filter: count() > Int64(0) +03)----Aggregate: groupBy=[[t1.a]], aggr=[[count(Int64(1)) AS count()]] +04)------TableScan: t1 projection=[a] +physical_plan +01)ProjectionExec: expr=[a@0 as a, count()@1 as cnt] +02)--CoalesceBatchesExec: target_batch_size=8192 +03)----FilterExec: count()@1 > 0 +04)------AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[count()] +05)--------CoalesceBatchesExec: target_batch_size=8192 +06)----------RepartitionExec: partitioning=Hash([a@0], 4), input_partitions=4 +07)------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +08)--------------AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[count()] +09)----------------MemoryExec: partitions=1, partition_sizes=[1] + +query II +SELECT t1.a, COUNT() AS cnt FROM t1 GROUP BY t1.a HAVING COUNT() > 1; +---- +1 2 + +query TT +EXPLAIN SELECT a, COUNT() OVER (PARTITION BY a) AS count_a FROM t1; +---- +logical_plan +01)Projection: t1.a, count() PARTITION BY [t1.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING AS count_a +02)--WindowAggr: windowExpr=[[count(Int64(1)) PARTITION BY [t1.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING AS count() PARTITION BY [t1.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]] +03)----TableScan: t1 projection=[a] +physical_plan +01)ProjectionExec: expr=[a@0 as a, count() PARTITION BY [t1.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@1 as count_a] +02)--WindowAggExec: wdw=[count() PARTITION BY [t1.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "count() PARTITION BY [t1.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] +03)----SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true] +04)------CoalesceBatchesExec: target_batch_size=8192 +05)--------RepartitionExec: partitioning=Hash([a@0], 4), input_partitions=1 +06)----------MemoryExec: partitions=1, partition_sizes=[1] + +query II +SELECT a, COUNT() OVER (PARTITION BY a) AS count_a FROM t1 ORDER BY a; +---- +1 2 +1 2 +2 1 + +statement ok +DROP TABLE t1; + +statement ok +DROP TABLE t2; diff --git a/datafusion/sqllogictest/test_files/errors.slt b/datafusion/sqllogictest/test_files/errors.slt index fa25f00974a9b..be7fdac71b57d 100644 --- a/datafusion/sqllogictest/test_files/errors.slt +++ b/datafusion/sqllogictest/test_files/errors.slt @@ -103,10 +103,6 @@ SELECT power(1, 2, 3); # Wrong window/aggregate function signature # -# AggregateFunction with wrong number of arguments -query error -select count(); - # AggregateFunction with wrong number of arguments query error select avg(c1, c12) from aggregate_test_100;