Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions pyhap/accessory_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,24 @@ def async_subscribe_client_topic(self, client, topic, subscribe=True):
if not subscribed_clients:
del self.topics[topic]

def connection_lost(self, client):
"""Called when a connection is lost to a client.

This method must be run in the event loop.

:param client: A client (address, port) tuple that should be unsubscribed.
:type client: tuple <str, int>
"""
client_topics = []
for topic, subscribed_clients in self.topics.items():
if client in subscribed_clients:
# Make a copy to avoid changing
# self.topics during iteration
client_topics.append(topic)

for topic in client_topics:
self.async_subscribe_client_topic(client, topic, subscribe=False)

def publish(self, data, sender_client_addr=None):
"""Publishes an event to the client.

Expand Down
1 change: 1 addition & 0 deletions pyhap/hap_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def connection_lost(self, exc: Exception) -> None:
self.accessory_driver.accessory.display_name,
exc,
)
self.accessory_driver.connection_lost(self.peername)
self.close()

def connection_made(self, transport: asyncio.Transport) -> None:
Expand Down
2 changes: 2 additions & 0 deletions tests/test_accessory_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,6 +507,8 @@ def test_async_subscribe_client_topic(driver):
assert driver.topics == {topic: {addr_info}}
driver.async_subscribe_client_topic(addr_info, topic, False)
assert driver.topics == {}
driver.async_subscribe_client_topic(addr_info, "invalid", False)
assert driver.topics == {}


def test_mdns_service_info(driver):
Expand Down
22 changes: 22 additions & 0 deletions tests/test_hap_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,23 +36,45 @@ def test_connection_management(driver):
"""Verify closing the connection removes it from the pool."""
loop = MagicMock()
addr_info = ("1.2.3.4", 5)
addr_info2 = ("1.2.3.5", 6)

transport = MagicMock(get_extra_info=Mock(return_value=addr_info))
connections = {}
driver.add_accessory(Accessory(driver, "TestAcc"))
driver.async_subscribe_client_topic(addr_info, "1.1", True)
driver.async_subscribe_client_topic(addr_info, "2.2", True)
driver.async_subscribe_client_topic(addr_info2, "1.1", True)

assert "1.1" in driver.topics
assert "2.2" in driver.topics

assert addr_info in driver.topics["1.1"]
assert addr_info in driver.topics["2.2"]
assert addr_info2 in driver.topics["1.1"]

hap_proto = hap_protocol.HAPServerProtocol(loop, connections, driver)
hap_proto.connection_made(transport)
assert len(connections) == 1
assert connections[addr_info] == hap_proto
hap_proto.connection_lost(None)
assert len(connections) == 0
assert "1.1" in driver.topics
assert "2.2" not in driver.topics
assert addr_info not in driver.topics["1.1"]
assert addr_info2 in driver.topics["1.1"]

hap_proto.connection_made(transport)
assert len(connections) == 1
assert connections[addr_info] == hap_proto
hap_proto.close()
assert len(connections) == 0

hap_proto.connection_made(transport)
assert len(connections) == 1
assert connections[addr_info] == hap_proto
hap_proto.connection_lost(None)
assert len(connections) == 0


def test_pair_setup(driver):
"""Verify an non-encrypt request."""
Expand Down