Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions mockafka/aiokafka/aiokafka_consumer.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,24 @@ async def getmany(

return dict(result)

def __aiter__(self):
if self._is_closed:
raise ConsumerStoppedError()
return self

async def __anext__(self) -> ConsumerRecord[bytes, bytes]:
while True:
try:
result = await self.getone()
if result is None:
# Follow the lead of `getone`, though note that we should
# address this as part of any fix to
# https://github.com/alm0ra/mockafka-py/issues/117
raise StopAsyncIteration
return result
except ConsumerStoppedError:
raise StopAsyncIteration from None

async def __aenter__(self) -> Self:
await self.start()
return self
Expand Down
65 changes: 55 additions & 10 deletions tests/test_aiokafka/test_aiokafka_consumer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import sys
import itertools
from unittest import IsolatedAsyncioTestCase

Expand All @@ -17,6 +18,13 @@
)
from mockafka.kafka_store import KafkaStore

if sys.version_info < (3, 10):
def aiter(async_iterable): # noqa: A001
return async_iterable.__aiter__()

async def anext(async_iterable): # noqa: A001
return await async_iterable.__anext__()


@pytest.mark.asyncio
class TestAIOKAFKAFakeConsumer(IsolatedAsyncioTestCase):
Expand All @@ -40,7 +48,7 @@ def topic(self):
def create_topic(self):
self.kafka.create_partition(topic=self.test_topic, partitions=16)

async def produce_message(self):
async def produce_two_messages(self):
await self.producer.send(
topic=self.test_topic, partition=0, key=b"test", value=b"test"
)
Expand All @@ -51,6 +59,40 @@ async def produce_message(self):
async def test_consume(self):
await self.test_poll_with_commit()

async def test_async_iterator(self):
self.create_topic()
await self.produce_two_messages()
self.consumer.subscribe(topics=[self.test_topic])
await self.consumer.start()

iterator = aiter(self.consumer)
message = await anext(iterator)
self.assertEqual(message.value, b"test")

message = await anext(iterator)
self.assertEqual(message.value, b"test1")

# Technically at this point aiokafka's consumer would block
# indefinitely, however since that's not useful in tests we instead stop
# iterating.
with pytest.raises(StopAsyncIteration):
await anext(iterator)

async def test_async_iterator_closed_early(self):
self.create_topic()
await self.produce_two_messages()
self.consumer.subscribe(topics=[self.test_topic])
await self.consumer.start()

iterator = aiter(self.consumer)
message = await anext(iterator)
self.assertEqual(message.value, b"test")

await self.consumer.stop()

with pytest.raises(StopAsyncIteration):
await anext(iterator)

async def test_start(self):
# check consumer store is empty
await self.consumer.start()
Expand All @@ -69,7 +111,7 @@ async def test_start(self):

async def test_poll_without_commit(self):
self.create_topic()
await self.produce_message()
await self.produce_two_messages()
self.consumer.subscribe(topics=[self.test_topic])
await self.consumer.start()

Expand All @@ -83,7 +125,7 @@ async def test_poll_without_commit(self):

async def test_partition_specific_poll_without_commit(self):
self.create_topic()
await self.produce_message()
await self.produce_two_messages()
self.consumer.subscribe(topics=[self.test_topic])
await self.consumer.start()

Expand All @@ -99,7 +141,7 @@ async def test_partition_specific_poll_without_commit(self):

async def test_poll_with_commit(self):
self.create_topic()
await self.produce_message()
await self.produce_two_messages()
self.consumer.subscribe(topics=[self.test_topic])
await self.consumer.start()

Expand All @@ -116,7 +158,7 @@ async def test_poll_with_commit(self):

async def test_getmany_without_commit(self):
self.create_topic()
await self.produce_message()
await self.produce_two_messages()
await self.producer.send(
topic=self.test_topic, partition=2, key=b"test2", value=b"test2"
)
Expand Down Expand Up @@ -145,7 +187,7 @@ async def test_getmany_without_commit(self):

async def test_getmany_with_limit_without_commit(self):
self.create_topic()
await self.produce_message()
await self.produce_two_messages()
await self.producer.send(
topic=self.test_topic, partition=0, key=b"test2", value=b"test2"
)
Expand Down Expand Up @@ -182,7 +224,7 @@ async def test_getmany_with_limit_without_commit(self):

async def test_getmany_specific_poll_without_commit(self):
self.create_topic()
await self.produce_message()
await self.produce_two_messages()
await self.producer.send(
topic=self.test_topic, partition=1, key=b"test2", value=b"test2"
)
Expand Down Expand Up @@ -210,7 +252,7 @@ async def test_getmany_specific_poll_without_commit(self):

async def test_getmany_with_commit(self):
self.create_topic()
await self.produce_message()
await self.produce_two_messages()
await self.producer.send(
topic=self.test_topic, partition=2, key=b"test2", value=b"test2"
)
Expand Down Expand Up @@ -287,7 +329,7 @@ async def test_lifecycle(self):

self.assertEqual(self.consumer.subscribed_topic, topics)

await self.produce_message()
await self.produce_two_messages()

messages = {
tp: self.summarise(msgs)
Expand Down Expand Up @@ -336,7 +378,7 @@ async def test_context_manager(self):

async with self.consumer as consumer:
self.assertEqual(self.consumer, consumer)
await self.produce_message()
await self.produce_two_messages()

messages = {
tp: self.summarise(msgs)
Expand Down Expand Up @@ -373,3 +415,6 @@ async def test_consumer_is_stopped(self):
self.consumer.subscribe(topics=topics)
with self.assertRaises(ConsumerStoppedError):
await self.consumer.getone()

with self.assertRaises(ConsumerStoppedError):
aiter(self.consumer)