diff --git a/sdk/eventhub/azure-eventhubs/azure/eventhub/_consumer_producer_mixin.py b/sdk/eventhub/azure-eventhubs/azure/eventhub/_consumer_producer_mixin.py index ff1a921d9220..341639213569 100644 --- a/sdk/eventhub/azure-eventhubs/azure/eventhub/_consumer_producer_mixin.py +++ b/sdk/eventhub/azure-eventhubs/azure/eventhub/_consumer_producer_mixin.py @@ -64,6 +64,8 @@ def _open(self, timeout_time=None): """ # pylint: disable=protected-access if not self.running: + if self._handler: + self._handler.close() if self.redirected: alt_creds = { "username": self.client._auth_config.get("iot_username"), diff --git a/sdk/eventhub/azure-eventhubs/azure/eventhub/aio/_consumer_producer_mixin_async.py b/sdk/eventhub/azure-eventhubs/azure/eventhub/aio/_consumer_producer_mixin_async.py index 68587637f1c3..23b8bd6a8fa6 100644 --- a/sdk/eventhub/azure-eventhubs/azure/eventhub/aio/_consumer_producer_mixin_async.py +++ b/sdk/eventhub/azure-eventhubs/azure/eventhub/aio/_consumer_producer_mixin_async.py @@ -65,6 +65,8 @@ async def _open(self, timeout_time=None): """ # pylint: disable=protected-access if not self.running: + if self._handler: + await self._handler.close_async() if self.redirected: alt_creds = { "username": self.client._auth_config.get("iot_username"), diff --git a/sdk/eventhub/azure-eventhubs/azure/eventhub/producer.py b/sdk/eventhub/azure-eventhubs/azure/eventhub/producer.py index c4a30d81b189..570bd8609964 100644 --- a/sdk/eventhub/azure-eventhubs/azure/eventhub/producer.py +++ b/sdk/eventhub/azure-eventhubs/azure/eventhub/producer.py @@ -223,7 +223,7 @@ def send(self, event_data, partition_key=None, timeout=None): wrapper_event_data = event_data else: if partition_key: - event_data = self._set_partition_key(event_data, partition_key) + event_data = _set_partition_key(event_data, partition_key) wrapper_event_data = EventDataBatch._from_batch(event_data, partition_key) # pylint: disable=protected-access wrapper_event_data.message.on_send_complete = self._on_outcome self.unsent_events = [wrapper_event_data.message] diff --git a/sdk/eventhub/azure-eventhubs/tests/asynctests/test_negative_async.py b/sdk/eventhub/azure-eventhubs/tests/asynctests/test_negative_async.py index 0ab4fe53f006..4406da855f59 100644 --- a/sdk/eventhub/azure-eventhubs/tests/asynctests/test_negative_async.py +++ b/sdk/eventhub/azure-eventhubs/tests/asynctests/test_negative_async.py @@ -30,6 +30,7 @@ async def test_send_with_invalid_hostname_async(invalid_hostname, connstr_receiv sender = client.create_producer() with pytest.raises(AuthenticationError): await sender.send(EventData("test data")) + await sender.close() @pytest.mark.liveTest @@ -39,6 +40,7 @@ async def test_receive_with_invalid_hostname_async(invalid_hostname): receiver = client.create_consumer(consumer_group="$default", partition_id="0", event_position=EventPosition("-1")) with pytest.raises(AuthenticationError): await receiver.receive(timeout=3) + await receiver.close() @pytest.mark.liveTest @@ -49,6 +51,7 @@ async def test_send_with_invalid_key_async(invalid_key, connstr_receivers): sender = client.create_producer() with pytest.raises(AuthenticationError): await sender.send(EventData("test data")) + await sender.close() @pytest.mark.liveTest @@ -58,6 +61,7 @@ async def test_receive_with_invalid_key_async(invalid_key): receiver = client.create_consumer(consumer_group="$default", partition_id="0", event_position=EventPosition("-1")) with pytest.raises(AuthenticationError): await receiver.receive(timeout=3) + await receiver.close() @pytest.mark.liveTest @@ -68,6 +72,7 @@ async def test_send_with_invalid_policy_async(invalid_policy, connstr_receivers) sender = client.create_producer() with pytest.raises(AuthenticationError): await sender.send(EventData("test data")) + await sender.close() @pytest.mark.liveTest @@ -77,6 +82,7 @@ async def test_receive_with_invalid_policy_async(invalid_policy): receiver = client.create_consumer(consumer_group="$default", partition_id="0", event_position=EventPosition("-1")) with pytest.raises(AuthenticationError): await receiver.receive(timeout=3) + await receiver.close() @pytest.mark.liveTest diff --git a/sdk/eventhub/azure-eventhubs/tests/test_negative.py b/sdk/eventhub/azure-eventhubs/tests/test_negative.py index 3682461f9db2..a1fee7605818 100644 --- a/sdk/eventhub/azure-eventhubs/tests/test_negative.py +++ b/sdk/eventhub/azure-eventhubs/tests/test_negative.py @@ -27,6 +27,7 @@ def test_send_with_invalid_hostname(invalid_hostname, connstr_receivers): sender = client.create_producer() with pytest.raises(AuthenticationError): sender.send(EventData("test data")) + sender.close() @pytest.mark.liveTest @@ -47,6 +48,7 @@ def test_send_with_invalid_key(invalid_key, connstr_receivers): sender.send(EventData("test data")) sender.close() + @pytest.mark.liveTest def test_receive_with_invalid_key_sync(invalid_key): client = EventHubClient.from_connection_string(invalid_key, network_tracing=False) @@ -96,13 +98,13 @@ def test_non_existing_entity_sender(connection_str): sender = client.create_producer(partition_id="1") with pytest.raises(AuthenticationError): sender.send(EventData("test data")) + sender.close() @pytest.mark.liveTest def test_non_existing_entity_receiver(connection_str): client = EventHubClient.from_connection_string(connection_str, event_hub_path="nemo", network_tracing=False) receiver = client.create_consumer(consumer_group="$default", partition_id="0", event_position=EventPosition("-1")) - with pytest.raises(AuthenticationError): receiver.receive(timeout=5) receiver.close() diff --git a/sdk/eventhub/azure-eventhubs/tests/test_receive.py b/sdk/eventhub/azure-eventhubs/tests/test_receive.py index 35c5e39c992b..d241a8e6e585 100644 --- a/sdk/eventhub/azure-eventhubs/tests/test_receive.py +++ b/sdk/eventhub/azure-eventhubs/tests/test_receive.py @@ -148,10 +148,10 @@ def test_receive_with_custom_datetime_sync(connstr_senders): receiver = client.create_consumer(consumer_group="$default", partition_id="0", event_position=EventPosition(offset)) with receiver: all_received = [] - received = receiver.receive(timeout=1) + received = receiver.receive(timeout=5) while received: all_received.extend(received) - received = receiver.receive(timeout=1) + received = receiver.receive(timeout=5) assert len(all_received) == 5 for received_event in all_received: