Skip to content

Commit e77c82f

Browse files
committed
Add support for AggregateFunction
1 parent 04c9f69 commit e77c82f

File tree

7 files changed

+226
-3
lines changed

7 files changed

+226
-3
lines changed

clickhouse_sqlalchemy/drivers/base.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from .compilers.sqlcompiler import ClickHouseSQLCompiler
1414
from .compilers.typecompiler import ClickHouseTypeCompiler
1515
from .reflection import ClickHouseInspector
16-
from .util import get_inner_spec
16+
from .util import get_inner_spec, parse_arguments
1717
from .. import types
1818

1919
# Column specifications
@@ -54,6 +54,8 @@
5454
'_lowcardinality': types.LowCardinality,
5555
'_tuple': types.Tuple,
5656
'_map': types.Map,
57+
'_aggregatefunction': types.AggregateFunction,
58+
'_simpleaggregatefunction': types.SimpleAggregateFunction,
5759
}
5860

5961

@@ -230,6 +232,32 @@ def _get_column_type(self, name, spec):
230232
coltype = self.ischema_names['_lowcardinality']
231233
return coltype(self._get_column_type(name, inner))
232234

235+
elif spec.startswith('AggregateFunction'):
236+
params = spec[18:-1]
237+
238+
arguments = parse_arguments(params)
239+
agg_func, inner = arguments[0], arguments[1:]
240+
241+
inner_types = [
242+
self._get_column_type(name, param)
243+
for param in inner
244+
]
245+
coltype = self.ischema_names['_aggregatefunction']
246+
return coltype(agg_func, *inner_types)
247+
248+
elif spec.startswith('SimpleAggregateFunction'):
249+
params = spec[24:-1]
250+
251+
arguments = parse_arguments(params)
252+
agg_func, inner = arguments[0], arguments[1:]
253+
254+
inner_types = [
255+
self._get_column_type(name, param)
256+
for param in inner
257+
]
258+
coltype = self.ischema_names['_simpleaggregatefunction']
259+
return coltype(agg_func, *inner_types)
260+
233261
elif spec.startswith('Tuple'):
234262
inner = spec[6:-1]
235263
coltype = self.ischema_names['_tuple']

clickhouse_sqlalchemy/drivers/compilers/typecompiler.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,3 +131,29 @@ def visit_map(self, type_, **kw):
131131
self.process(key_type, **kw),
132132
self.process(value_type, **kw)
133133
)
134+
135+
def visit_aggregatefunction(self, type_, **kw):
    """Render the DDL string for an ``AggregateFunction`` column type.

    Every entry of ``type_.nested_types`` is coerced to a type instance and
    rendered through ``self.process``. ``type_.agg_func`` is emitted
    verbatim when it is a string; otherwise it is compiled against this
    compiler's dialect.

    :param type_: type object exposing ``agg_func`` and ``nested_types``.
    :param kw: extra keyword arguments forwarded to ``self.process``.
    :return: ``"AggregateFunction(<func>[, <type>, ...])"``.
    """
    type_strings = [
        self.process(type_api.to_instance(nested), **kw)
        for nested in type_.nested_types
    ]

    if isinstance(type_.agg_func, str):
        agg_str = type_.agg_func
    else:
        agg_str = str(type_.agg_func.compile(dialect=self.dialect))

    # Join the function and its argument types in one pass so that an empty
    # ``nested_types`` gives "AggregateFunction(f)" rather than the dangling
    # "AggregateFunction(f, )".
    return "AggregateFunction(%s)" % ", ".join([agg_str] + type_strings)
147+
148+
def visit_simpleaggregatefunction(self, type_, **kw):
    """Render the DDL string for a ``SimpleAggregateFunction`` column type.

    Every entry of ``type_.nested_types`` is coerced to a type instance and
    rendered through ``self.process``. ``type_.agg_func`` is emitted
    verbatim when it is a string; otherwise it is compiled against this
    compiler's dialect.

    :param type_: type object exposing ``agg_func`` and ``nested_types``.
    :param kw: extra keyword arguments forwarded to ``self.process``.
    :return: ``"SimpleAggregateFunction(<func>[, <type>, ...])"``.
    """
    type_strings = [
        self.process(type_api.to_instance(nested), **kw)
        for nested in type_.nested_types
    ]

    if isinstance(type_.agg_func, str):
        agg_str = type_.agg_func
    else:
        agg_str = str(type_.agg_func.compile(dialect=self.dialect))

    # Join the function and its argument types in one pass so that an empty
    # ``nested_types`` never produces a dangling ", " before the closing
    # parenthesis.
    return "SimpleAggregateFunction(%s)" % ", ".join(
        [agg_str] + type_strings
    )

clickhouse_sqlalchemy/types/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@
3333
'Nested',
3434
'Tuple',
3535
'Map',
36+
'AggregateFunction',
37+
'SimpleAggregateFunction',
3638
]
3739

3840
from .common import String
@@ -66,6 +68,8 @@
6668
from .common import Decimal
6769
from .common import Tuple
6870
from .common import Map
71+
from .common import AggregateFunction
72+
from .common import SimpleAggregateFunction
6973
from .ip import IPv4
7074
from .ip import IPv6
7175
from .nested import Nested

clickhouse_sqlalchemy/types/common.py

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
1-
from sqlalchemy.sql.type_api import to_instance
1+
from typing import Type, Union
2+
23
from sqlalchemy import types
4+
from sqlalchemy.sql.functions import Function
5+
from sqlalchemy.sql.type_api import to_instance
36

47

58
class ClickHouseTypeEngine(types.TypeEngine):
@@ -197,3 +200,49 @@ def __init__(self, key_type, value_type):
197200
self.key_type = key_type
198201
self.value_type = value_type
199202
super(Map, self).__init__()
203+
204+
205+
class AggregateFunction(ClickHouseTypeEngine):
    """Column type carrying an aggregate function and its argument types.

    :param agg_func: the aggregate function, either a SQLAlchemy function
        expression or a plain string.
    :param nested_types: argument types, given as type classes or
        instances; each one is normalised with ``to_instance``.
    """

    __visit_name__ = 'aggregatefunction'

    def __init__(
        self,
        agg_func: Union[Function, str],
        *nested_types: Union[Type[ClickHouseTypeEngine], ClickHouseTypeEngine],
    ):
        self.agg_func = agg_func
        self.nested_types = [to_instance(val) for val in nested_types]
        super(AggregateFunction, self).__init__()

    def __repr__(self) -> str:
        """Return a constructor-style representation of this type."""
        if isinstance(self.agg_func, str):
            # Quote string functions so the output reads as a string
            # argument instead of a bare Python name (``'sum'``, not the
            # builtin ``sum``).
            agg_str = repr(self.agg_func)
        else:
            agg_str = f'sa.func.{self.agg_func}'

        # Build one argument list so that no dangling ", " is emitted when
        # there are no nested types.
        args = [agg_str]
        args.extend(
            f'{val.__module__}.{val!r}' for val in self.nested_types
        )
        return f"AggregateFunction({', '.join(args)})"
226+
227+
228+
class SimpleAggregateFunction(ClickHouseTypeEngine):
    """Column type carrying a simple aggregate function and its types.

    :param agg_func: the aggregate function, either a SQLAlchemy function
        expression or a plain string.
    :param nested_types: argument types, given as type classes or
        instances; each one is normalised with ``to_instance``.
    """

    __visit_name__ = 'simpleaggregatefunction'

    def __init__(
        self,
        agg_func: Union[Function, str],
        *nested_types: Union[Type[ClickHouseTypeEngine], ClickHouseTypeEngine],
    ):
        self.agg_func = agg_func
        self.nested_types = [to_instance(val) for val in nested_types]
        super(SimpleAggregateFunction, self).__init__()

    def __repr__(self) -> str:
        """Return a constructor-style representation of this type."""
        if isinstance(self.agg_func, str):
            # Quote string functions so the output reads as a string
            # argument instead of a bare Python name (``'sum'``, not the
            # builtin ``sum``).
            agg_str = repr(self.agg_func)
        else:
            agg_str = f'sa.func.{self.agg_func}'

        # Build one argument list so that no dangling ", " is emitted when
        # there are no nested types.
        args = [agg_str]
        args.extend(
            f'{val.__module__}.{val!r}' for val in self.nested_types
        )
        return f"SimpleAggregateFunction({', '.join(args)})"

docs/features.rst

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -589,6 +589,26 @@ You can specify cluster for materialized view in inner table definition.
589589
{'clickhouse_cluster': 'my_cluster'}
590590
)
591591
592+
Materialized views can also store the aggregated data in a table using the
593+
``AggregatingMergeTree`` engine. The aggregate columns are defined using
594+
``AggregateFunction`` or ``SimpleAggregateFunction``.
595+
596+
.. code-block:: python
597+
598+
599+
# Define storage for Materialized View
600+
class GroupedStatistics(Base):
    date = Column(types.Date, primary_key=True)
    # Aggregate column: referenced through ``types`` and ``func`` like the
    # other names used in this example.
    metric1 = Column(
        types.SimpleAggregateFunction(func.sum(), types.Int32),
        nullable=False
    )

    __table_args__ = (
        engines.AggregatingMergeTree(
            partition_by=func.toYYYYMM(date),
            order_by=(date, )
        ),
    )
610+
611+
592612
Basic DDL support
593613
-----------------
594614

tests/test_ddl.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -311,6 +311,39 @@ def test_create_table_map(self):
311311
'ENGINE = Memory'
312312
)
313313

314+
def test_create_aggregate_function(self):
    # DDL for an AggregateFunction column built from a function expression.
    metadata = self.metadata()
    column_type = types.AggregateFunction(func.sum(), types.UInt32)
    table = Table(
        't1', metadata, Column('total', column_type), engines.Memory()
    )

    ddl = self.compile(CreateTable(table))
    self.assertEqual(
        ddl,
        'CREATE TABLE t1 ('
        'total AggregateFunction(sum(), UInt32)) '
        'ENGINE = Memory'
    )
327+
328+
@require_server_version(22, 8, 21)
def test_create_simple_aggregate_function(self):
    # DDL for a SimpleAggregateFunction column built from a function
    # expression.
    metadata = self.metadata()
    column_type = types.SimpleAggregateFunction(func.sum(), types.UInt32)
    table = Table(
        't1', metadata, Column('total', column_type), engines.Memory()
    )

    ddl = self.compile(CreateTable(table))
    self.assertEqual(
        ddl,
        'CREATE TABLE t1 ('
        'total SimpleAggregateFunction(sum(), UInt32)) '
        'ENGINE = Memory'
    )
346+
314347
def test_table_create_on_cluster(self):
315348
create_sql = (
316349
'CREATE TABLE t1 ON CLUSTER test_cluster '

tests/test_reflection.py

Lines changed: 64 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import enum
2-
from sqlalchemy import Column, inspect, types as sa_types
2+
from sqlalchemy import Column, func, inspect, types as sa_types
33

44
from clickhouse_sqlalchemy import types, engines, Table
5+
56
from tests.testcase import BaseTestCase
67
from tests.util import require_server_version, with_native_and_http_sessions
78

@@ -166,3 +167,65 @@ def test_datetime(self):
166167

167168
self.assertIsInstance(coltype, types.DateTime)
168169
self.assertIsNone(coltype.timezone)
170+
171+
def test_aggregate_function(self):
    def reflect(column_type):
        # Round-trip the type and return the reflected column type.
        return self._type_round_trip(column_type)[0]['type']

    # Function expression with a single argument type.
    reflected = reflect(types.AggregateFunction(func.sum(), types.UInt16))
    self.assertIsInstance(reflected, types.AggregateFunction)
    self.assertEqual(reflected.agg_func, 'sum')
    self.assertEqual(len(reflected.nested_types), 1)
    self.assertIsInstance(reflected.nested_types[0], types.UInt16)

    # Parametrised function given as a string.
    reflected = reflect(
        types.AggregateFunction('quantiles(0.5, 0.9)', types.UInt32)
    )
    self.assertIsInstance(reflected, types.AggregateFunction)
    self.assertEqual(reflected.agg_func, 'quantiles(0.5, 0.9)')
    self.assertEqual(len(reflected.nested_types), 1)
    self.assertIsInstance(reflected.nested_types[0], types.UInt32)

    # Function expression with two argument types.
    reflected = reflect(
        types.AggregateFunction(func.argMin(), types.Float32, types.Float32)
    )
    self.assertIsInstance(reflected, types.AggregateFunction)
    self.assertEqual(reflected.agg_func, 'argMin')
    self.assertEqual(len(reflected.nested_types), 2)
    self.assertIsInstance(reflected.nested_types[0], types.Float32)
    self.assertIsInstance(reflected.nested_types[1], types.Float32)

    # Argument type that carries its own parameters.
    reflected = reflect(types.AggregateFunction('sum', types.Decimal(18, 2)))
    self.assertIsInstance(reflected, types.AggregateFunction)
    self.assertEqual(reflected.agg_func, 'sum')
    self.assertEqual(len(reflected.nested_types), 1)
    decimal_type = reflected.nested_types[0]
    self.assertIsInstance(decimal_type, types.Decimal)
    self.assertEqual(decimal_type.precision, 18)
    self.assertEqual(decimal_type.scale, 2)
211+
212+
@require_server_version(22, 8, 21)
def test_simple_aggregate_function(self):
    def reflect(column_type):
        # Round-trip the type and return the reflected column type.
        return self._type_round_trip(column_type)[0]['type']

    # Function given as a SQLAlchemy expression.
    reflected = reflect(
        types.SimpleAggregateFunction(func.sum(), types.UInt64)
    )
    self.assertIsInstance(reflected, types.SimpleAggregateFunction)
    self.assertEqual(reflected.agg_func, 'sum')
    self.assertEqual(len(reflected.nested_types), 1)
    self.assertIsInstance(reflected.nested_types[0], types.UInt64)

    # Function given as a plain string.
    reflected = reflect(types.SimpleAggregateFunction('sum', types.Float64))
    self.assertIsInstance(reflected, types.SimpleAggregateFunction)
    self.assertEqual(reflected.agg_func, 'sum')
    self.assertEqual(len(reflected.nested_types), 1)
    self.assertIsInstance(reflected.nested_types[0], types.Float64)

0 commit comments

Comments
 (0)